ombrac_server/
server.rs

1use std::io;
2use std::net::SocketAddr;
3use std::sync::Arc;
4use std::time::Instant;
5
6use tokio::sync::broadcast;
7#[cfg(feature = "tracing")]
8use tracing::Instrument;
9
10use ombrac_macros::{error, info};
11use ombrac_transport::{Acceptor, Connection};
12
13use crate::connection::{ConnectionDriver, ConnectionHandler};
14
/// Accepts incoming transport connections and hands each one, on its own
/// spawned task, to a shared connection handler.
pub struct Server<T, V> {
    /// Source of incoming connections; kept in an `Arc` so the server can be
    /// shared by reference while accepting.
    acceptor: Arc<T>,
    /// Handler passed (by reference) to `ConnectionDriver::handle` for every
    /// accepted connection; cloned per connection task.
    /// NOTE(review): named `validator` but bounded as a `ConnectionHandler`
    /// in the impl — consider renaming for clarity.
    validator: Arc<V>,
}
19
20impl<T: Acceptor, V: ConnectionHandler<T::Connection> + 'static> Server<T, V> {
21    pub fn new(acceptor: T, validator: V) -> Self {
22        Self {
23            acceptor: Arc::new(acceptor),
24            validator: Arc::new(validator),
25        }
26    }
27
28    pub async fn accept_loop(&self, mut shutdown_rx: broadcast::Receiver<()>) -> io::Result<()> {
29        loop {
30            tokio::select! {
31                _ = shutdown_rx.recv() => break,
32                accepted = self.acceptor.accept() => {
33                    match accepted {
34                        Ok(connection) => {
35                            let validator = Arc::clone(&self.validator);
36                            #[cfg(not(feature = "tracing"))]
37                            tokio::spawn(Self::handle_connection(connection, validator));
38                            #[cfg(feature = "tracing")]
39                            tokio::spawn(Self::handle_connection(connection, validator).in_current_span());
40                        },
41                        Err(_err) => error!("failed to accept connection: {}", _err)
42                    }
43                },
44            }
45        }
46
47        Ok(())
48    }
49
50    #[cfg_attr(feature = "tracing",
51        tracing::instrument(
52            name = "connection",
53            skip_all,
54            fields(
55                id = connection.id(),
56                from = tracing::field::Empty,
57                secret = tracing::field::Empty
58            )
59        )
60    )]
61    pub async fn handle_connection(connection: <T as Acceptor>::Connection, validator: Arc<V>) {
62        #[cfg(feature = "tracing")]
63        let created_at = Instant::now();
64
65        let peer_addr = match connection.remote_address() {
66            Ok(addr) => addr,
67            Err(_err) => {
68                return error!("failed to get remote address for incoming connection {_err}");
69            }
70        };
71
72        #[cfg(feature = "tracing")]
73        tracing::Span::current().record("from", tracing::field::display(peer_addr));
74
75        let reason: std::borrow::Cow<'static, str> = {
76            match ConnectionDriver::handle(connection, validator.as_ref()).await {
77                Ok(_) => "ok".into(),
78                Err(e) => {
79                    if matches!(
80                        e.kind(),
81                        io::ErrorKind::ConnectionReset
82                            | io::ErrorKind::BrokenPipe
83                            | io::ErrorKind::UnexpectedEof
84                    ) {
85                        format!("client disconnect: {}", e.kind()).into()
86                    } else {
87                        error!("connection handler failed: {e}");
88                        format!("error: {e}").into()
89                    }
90                }
91            }
92        };
93
94        info!(
95            duration = created_at.elapsed().as_millis(),
96            reason = %reason.as_ref(),
97            "connection closed"
98        );
99    }
100
101    pub fn local_addr(&self) -> io::Result<SocketAddr> {
102        self.acceptor.local_addr()
103    }
104}