// ombrac_server/service.rs

1use std::future::Future;
2use std::io;
3use std::marker::PhantomData;
4use std::net::UdpSocket;
5use std::sync::Arc;
6use std::time::Duration;
7
8use tokio::sync::broadcast;
9use tokio::task::JoinHandle;
10
11use ombrac_macros::{error, info, warn};
12#[cfg(feature = "transport-quic")]
13use ombrac_transport::quic::{
14    self, Connection as QuicConnection, TransportConfig as QuicTransportConfig,
15    server::{Config as QuicConfig, Server as QuicServer},
16};
17use ombrac_transport::{Acceptor, Connection};
18
19use crate::config::ServiceConfig;
20use crate::server::Server;
21
22#[cfg(feature = "transport-quic")]
23use crate::config::TlsMode;
24
25#[derive(thiserror::Error, Debug)]
26pub enum Error {
27    #[error("Configuration error: {0}")]
28    Config(String),
29
30    #[error("{0}")]
31    Io(#[from] io::Error),
32
33    #[error("Transport layer error: {0}")]
34    Transport(String),
35}
36
37pub type Result<T> = std::result::Result<T, Error>;
38
/// Unwraps an `Option`-valued config field into a `Result`.
///
/// Evaluates to `Ok(value)` for `Some(value)`. For `None` it evaluates to
/// `Err(Error::Config(..))` with a message naming `$field_name`. `$field_name` is
/// only evaluated in the `None` case. `Error` is resolved at the call site, and
/// callers typically follow the invocation with `?`.
macro_rules! require_config {
    ($config_opt:expr, $field_name:expr) => {
        match $config_opt {
            ::std::option::Option::Some(value) => ::std::result::Result::Ok(value),
            ::std::option::Option::None => ::std::result::Result::Err(Error::Config(
                ::std::format!("'{}' is required but was not provided", $field_name),
            )),
        }
    };
}
49
50pub trait ServiceBuilder {
51    type Acceptor: Acceptor<Connection = Self::Connection>;
52    type Connection: Connection;
53
54    fn build(
55        config: &Arc<ServiceConfig>,
56    ) -> impl Future<Output = Result<Arc<Server<Self::Acceptor>>>> + Send;
57}
58
/// [`ServiceBuilder`] that serves over the QUIC transport.
#[cfg(feature = "transport-quic")]
pub struct QuicServiceBuilder;

#[cfg(feature = "transport-quic")]
impl ServiceBuilder for QuicServiceBuilder {
    type Connection = QuicConnection;
    type Acceptor = QuicServer;

    /// Starts a QUIC acceptor from `config` and wraps it in a [`Server`].
    ///
    /// The server secret passed to [`Server::new`] is the BLAKE3 digest of
    /// `config.secret`, not the raw string.
    async fn build(config: &Arc<ServiceConfig>) -> Result<Arc<Server<Self::Acceptor>>> {
        let digest = blake3::hash(config.secret.as_bytes());
        let server = Server::new(quic_server_from_config(config).await?, *digest.as_bytes());
        Ok(Arc::new(server))
    }
}
74
75pub struct Service<T, C>
76where
77    T: Acceptor<Connection = C>,
78    C: Connection,
79{
80    handle: JoinHandle<Result<()>>,
81    shutdown_tx: broadcast::Sender<()>,
82    _acceptor: PhantomData<T>,
83    _connection: PhantomData<C>,
84}
85
86impl<T, C> Service<T, C>
87where
88    T: Acceptor<Connection = C> + Send + Sync + 'static,
89    C: Connection + Send + Sync + 'static,
90{
91    pub async fn build<Builder>(config: Arc<ServiceConfig>) -> Result<Self>
92    where
93        Builder: ServiceBuilder<Acceptor = T, Connection = C>,
94    {
95        let server = Builder::build(&config).await?;
96        let (shutdown_tx, shutdown_rx) = broadcast::channel(1);
97
98        let handle =
99            tokio::spawn(async move { server.accept_loop(shutdown_rx).await.map_err(Error::Io) });
100
101        Ok(Service {
102            handle,
103            shutdown_tx,
104            _acceptor: PhantomData,
105            _connection: PhantomData,
106        })
107    }
108
109    pub async fn shutdown(self) {
110        let _ = self.shutdown_tx.send(());
111        match self.handle.await {
112            Ok(Ok(_)) => {}
113            Ok(Err(_e)) => {
114                error!("The main server task exited with an error: {_e}");
115            }
116            Err(_e) => {
117                error!("The main server task failed to shut down cleanly: {_e}");
118            }
119        }
120        warn!("Service shutdown complete");
121    }
122}
123
/// Builds and binds the QUIC acceptor described by `config`.
///
/// In order:
/// 1. Apply the 0-RTT, ALPN and TLS settings.
/// 2. Build the QUIC transport parameters.
/// 3. Bind a UDP socket on `config.listen`.
/// 4. Start the QUIC server on that socket.
///
/// # Errors
/// - [`Error::Config`] when the selected TLS mode lacks a required path.
/// - [`Error::Transport`] when a transport parameter is rejected or the QUIC
///   server fails to start.
/// - [`Error::Io`] when the UDP socket cannot be bound.
#[cfg(feature = "transport-quic")]
async fn quic_server_from_config(config: &ServiceConfig) -> Result<QuicServer> {
    let mut quic_config = QuicConfig::new();
    apply_handshake_settings(config, &mut quic_config)?;
    quic_config.transport_config(quic_transport_params(config)?);

    info!("Binding UDP socket to {}", config.listen);
    let socket = UdpSocket::bind(config.listen)?;

    match QuicServer::new(socket, quic_config).await {
        Ok(server) => Ok(server),
        Err(e) => Err(Error::Transport(e.to_string())),
    }
}

/// Copies the 0-RTT, ALPN and TLS options from `config.transport` into `quic_config`.
///
/// 0-RTT is off unless explicitly enabled. The TLS mode defaults to
/// [`TlsMode::Tls`] when unset.
#[cfg(feature = "transport-quic")]
fn apply_handshake_settings(config: &ServiceConfig, quic_config: &mut QuicConfig) -> Result<()> {
    let transport = &config.transport;

    quic_config.enable_zero_rtt = transport.zero_rtt == Some(true);
    if let Some(alpn) = transport.alpn_protocols.as_ref() {
        quic_config.alpn_protocols = alpn.clone();
    }

    let mode = transport.tls_mode.unwrap_or(TlsMode::Tls);
    match mode {
        TlsMode::Insecure => {
            warn!("TLS is running in insecure mode. Self-signed certificates will be generated.");
            quic_config.enable_self_signed = true;
        }
        TlsMode::Tls => {
            let cert = require_config!(transport.tls_cert.clone(), "transport.tls_cert")?;
            let key = require_config!(transport.tls_key.clone(), "transport.tls_key")?;
            quic_config.tls_cert_key_paths = Some((cert, key));
        }
        TlsMode::MTls => {
            let cert =
                require_config!(transport.tls_cert.clone(), "transport.tls_cert for mTLS")?;
            let key = require_config!(transport.tls_key.clone(), "transport.tls_key for mTLS")?;
            quic_config.tls_cert_key_paths = Some((cert, key));
            let ca = require_config!(transport.ca_cert.clone(), "transport.ca_cert for mTLS")?;
            quic_config.root_ca_path = Some(ca);
        }
    }

    Ok(())
}

/// Translates the optional tuning knobs in `config.transport` into QUIC
/// transport parameters. Unset knobs keep their defaults.
///
/// `idle_timeout` and `keep_alive` are interpreted as milliseconds. `cwnd_init`
/// is only forwarded together with a configured `congestion` algorithm.
#[cfg(feature = "transport-quic")]
fn quic_transport_params(config: &ServiceConfig) -> Result<QuicTransportConfig> {
    fn transport_err(e: quic::error::Error) -> Error {
        Error::Transport(e.to_string())
    }

    let transport = &config.transport;
    let mut params = QuicTransportConfig::default();

    if let Some(ms) = transport.idle_timeout {
        params
            .max_idle_timeout(Duration::from_millis(ms))
            .map_err(transport_err)?;
    }
    if let Some(ms) = transport.keep_alive {
        params
            .keep_alive_period(Duration::from_millis(ms))
            .map_err(transport_err)?;
    }
    if let Some(limit) = transport.max_streams {
        params
            .max_open_bidirectional_streams(limit)
            .map_err(transport_err)?;
    }
    if let Some(algorithm) = transport.congestion {
        params
            .congestion(algorithm, transport.cwnd_init)
            .map_err(transport_err)?;
    }

    Ok(params)
}