1use std::io;
2use std::net::SocketAddr;
3use std::sync::Arc;
4use std::time::Instant;
5
6use tokio::sync::broadcast;
7#[cfg(feature = "tracing")]
8use tracing::Instrument;
9
10use ombrac_macros::{error, info};
11use ombrac_transport::{Acceptor, Connection};
12
13use crate::connection::{ConnectionDriver, ConnectionHandler};
14
/// Accepts incoming transport connections and dispatches each one to a
/// per-connection handler.
///
/// `T` is the transport [`Acceptor`]; `V` is the handler run for every accepted
/// connection (bounded as `ConnectionHandler<T::Connection>` on the `impl` block).
pub struct Server<T, V> {
    /// Source of incoming connections, stored behind an `Arc`.
    acceptor: Arc<T>,
    /// Per-connection handler; the accept loop moves an `Arc` clone of it into
    /// each spawned connection task.
    validator: Arc<V>,
}
19
20impl<T: Acceptor, V: ConnectionHandler<T::Connection> + 'static> Server<T, V> {
21 pub fn new(acceptor: T, validator: V) -> Self {
22 Self {
23 acceptor: Arc::new(acceptor),
24 validator: Arc::new(validator),
25 }
26 }
27
28 pub async fn accept_loop(&self, mut shutdown_rx: broadcast::Receiver<()>) -> io::Result<()> {
29 loop {
30 tokio::select! {
31 _ = shutdown_rx.recv() => break,
32 accepted = self.acceptor.accept() => {
33 match accepted {
34 Ok(connection) => {
35 let validator = Arc::clone(&self.validator);
36 #[cfg(not(feature = "tracing"))]
37 tokio::spawn(Self::handle_connection(connection, validator));
38 #[cfg(feature = "tracing")]
39 tokio::spawn(Self::handle_connection(connection, validator).in_current_span());
40 },
41 Err(_err) => error!("failed to accept connection: {}", _err)
42 }
43 },
44 }
45 }
46
47 Ok(())
48 }
49
50 #[cfg_attr(feature = "tracing",
51 tracing::instrument(
52 name = "connection",
53 skip_all,
54 fields(
55 id = connection.id(),
56 from = tracing::field::Empty,
57 secret = tracing::field::Empty
58 )
59 )
60 )]
61 pub async fn handle_connection(connection: <T as Acceptor>::Connection, validator: Arc<V>) {
62 #[cfg(feature = "tracing")]
63 let created_at = Instant::now();
64
65 let peer_addr = match connection.remote_address() {
66 Ok(addr) => addr,
67 Err(_err) => {
68 return error!("failed to get remote address for incoming connection {_err}");
69 }
70 };
71
72 #[cfg(feature = "tracing")]
73 tracing::Span::current().record("from", tracing::field::display(peer_addr));
74
75 let reason: std::borrow::Cow<'static, str> = {
76 match ConnectionDriver::handle(connection, validator.as_ref()).await {
77 Ok(_) => "ok".into(),
78 Err(e) => {
79 if matches!(
80 e.kind(),
81 io::ErrorKind::ConnectionReset
82 | io::ErrorKind::BrokenPipe
83 | io::ErrorKind::UnexpectedEof
84 ) {
85 format!("client disconnect: {}", e.kind()).into()
86 } else {
87 error!("connection handler failed: {e}");
88 format!("error: {e}").into()
89 }
90 }
91 }
92 };
93
94 info!(
95 duration = created_at.elapsed().as_millis(),
96 reason = %reason.as_ref(),
97 "connection closed"
98 );
99 }
100
101 pub fn local_addr(&self) -> io::Result<SocketAddr> {
102 self.acceptor.local_addr()
103 }
104}