ombrac_server/
service.rs

1use std::future::Future;
2use std::io;
3use std::marker::PhantomData;
4use std::net::UdpSocket;
5use std::sync::Arc;
6use std::time::Duration;
7
8use tokio::sync::broadcast;
9use tokio::task::JoinHandle;
10
11use ombrac_macros::{error, info, warn};
12#[cfg(feature = "transport-quic")]
13use ombrac_transport::quic::{
14    self, Connection as QuicConnection, TransportConfig as QuicTransportConfig,
15    server::{Config as QuicConfig, Server as QuicServer},
16};
17use ombrac_transport::{Acceptor, Connection};
18
19use crate::config::ServiceConfig;
20use crate::connection::ConnectionHandler;
21use crate::server::Server;
22
23#[cfg(feature = "transport-quic")]
24use crate::config::TlsMode;
25
26#[derive(thiserror::Error, Debug)]
27pub enum Error {
28    #[error("Configuration error: {0}")]
29    Config(String),
30
31    #[error("{0}")]
32    Io(#[from] io::Error),
33
34    #[error("Transport layer error: {0}")]
35    Transport(String),
36}
37
38pub type Result<T> = std::result::Result<T, Error>;
39
/// Unwraps an `Option` configuration value, or produces
/// `Err(Error::Config(..))` naming the missing field.
///
/// The expansion is a `Result`, so call sites follow it with `?`. The field
/// name is only formatted when the value is absent.
macro_rules! require_config {
    ($config_opt:expr, $field_name:expr) => {
        match $config_opt {
            Some(value) => Ok(value),
            None => Err(Error::Config(format!(
                "'{}' is required but was not provided",
                $field_name
            ))),
        }
    };
}
50
51pub trait ServiceBuilder {
52    type Acceptor: Acceptor<Connection = Self::Connection>;
53    type Connection: Connection;
54    type Validator: ConnectionHandler<Self::Connection> + 'static;
55
56    fn build(
57        config: &Arc<ServiceConfig>,
58    ) -> impl Future<Output = Result<Arc<Server<Self::Acceptor, Self::Validator>>>> + Send;
59}
60
/// [`ServiceBuilder`] backed by the QUIC transport.
#[cfg(feature = "transport-quic")]
pub struct QuicServiceBuilder;

#[cfg(feature = "transport-quic")]
impl ServiceBuilder for QuicServiceBuilder {
    type Acceptor = QuicServer;
    type Connection = QuicConnection;
    type Validator = ombrac::protocol::Secret;

    /// Creates the QUIC listener from `config` and wraps it in a [`Server`].
    /// The server is keyed by the BLAKE3 digest of `config.secret`.
    async fn build(
        config: &Arc<ServiceConfig>,
    ) -> Result<Arc<Server<Self::Acceptor, Self::Validator>>> {
        let listener = quic_server_from_config(config).await?;

        // The configured secret is not used directly: it is hashed into a
        // fixed-size key first.
        let digest = blake3::hash(config.secret.as_bytes());
        let secret = *digest.as_bytes();

        Ok(Arc::new(Server::new(listener, secret)))
    }
}
79
80pub struct Service<T, C>
81where
82    T: Acceptor<Connection = C>,
83    C: Connection,
84{
85    handle: JoinHandle<Result<()>>,
86    shutdown_tx: broadcast::Sender<()>,
87    _acceptor: PhantomData<T>,
88    _connection: PhantomData<C>,
89}
90
91impl<T, C> Service<T, C>
92where
93    T: Acceptor<Connection = C> + Send + Sync + 'static,
94    C: Connection + Send + Sync + 'static,
95{
96    pub async fn build<Builder, V>(config: Arc<ServiceConfig>) -> Result<Self>
97    where
98        Builder: ServiceBuilder<Acceptor = T, Connection = C, Validator = V>,
99        V: ConnectionHandler<C> + Send + Sync + 'static,
100    {
101        let server = Builder::build(&config).await?;
102        let (shutdown_tx, shutdown_rx) = broadcast::channel(1);
103
104        let handle =
105            tokio::spawn(async move { server.accept_loop(shutdown_rx).await.map_err(Error::Io) });
106
107        Ok(Service {
108            handle,
109            shutdown_tx,
110            _acceptor: PhantomData,
111            _connection: PhantomData,
112        })
113    }
114
115    pub async fn shutdown(self) {
116        let _ = self.shutdown_tx.send(());
117        match self.handle.await {
118            Ok(Ok(_)) => {}
119            Ok(Err(_e)) => {
120                error!("The main server task exited with an error: {_e}");
121            }
122            Err(_e) => {
123                error!("The main server task failed to shut down cleanly: {_e}");
124            }
125        }
126        warn!("Service shutdown complete");
127    }
128}
129
/// Builds and binds the QUIC listener described by `config`.
///
/// The function proceeds in four steps:
/// 1. Certificate material is chosen by `transport.tls_mode`, defaulting to
///    [`TlsMode::Tls`].
/// 2. Optional transport tuning values are applied only when present.
/// 3. A UDP socket is bound to `config.listen`.
/// 4. That socket is handed to the QUIC server.
///
/// # Errors
///
/// * [`Error::Config`] if the chosen TLS mode is missing a required path.
/// * [`Error::Transport`] if a transport parameter is rejected or the QUIC
///   server cannot be created.
/// * [`Error::Io`] if the UDP socket cannot be bound.
#[cfg(feature = "transport-quic")]
async fn quic_server_from_config(config: &ServiceConfig) -> Result<QuicServer> {
    let transport = &config.transport;
    let to_transport_err = |e: quic::error::Error| Error::Transport(e.to_string());

    let mut quic_config = QuicConfig::new();
    quic_config.enable_zero_rtt = transport.zero_rtt.unwrap_or(false);
    if let Some(alpn) = &transport.alpn_protocols {
        quic_config.alpn_protocols = alpn.clone();
    }

    // Resolve certificate material for the selected TLS mode. Required paths
    // are checked in the order cert, key, then (mTLS only) CA.
    let tls_mode = transport.tls_mode.unwrap_or(TlsMode::Tls);
    match tls_mode {
        TlsMode::Insecure => {
            warn!("TLS is running in insecure mode. Self-signed certificates will be generated.");
            quic_config.enable_self_signed = true;
        }
        TlsMode::Tls => {
            let cert = require_config!(transport.tls_cert.clone(), "transport.tls_cert")?;
            let key = require_config!(transport.tls_key.clone(), "transport.tls_key")?;
            quic_config.tls_cert_key_paths = Some((cert, key));
        }
        TlsMode::MTls => {
            let cert =
                require_config!(transport.tls_cert.clone(), "transport.tls_cert for mTLS")?;
            let key = require_config!(transport.tls_key.clone(), "transport.tls_key for mTLS")?;
            let ca = require_config!(transport.ca_cert.clone(), "transport.ca_cert for mTLS")?;
            quic_config.tls_cert_key_paths = Some((cert, key));
            quic_config.root_ca_path = Some(ca);
        }
    }

    // Only the knobs that were explicitly configured override the defaults.
    // Timeout and keep-alive values are interpreted as milliseconds.
    let mut tuning = QuicTransportConfig::default();
    if let Some(ms) = transport.idle_timeout {
        tuning
            .max_idle_timeout(Duration::from_millis(ms))
            .map_err(to_transport_err)?;
    }
    if let Some(ms) = transport.keep_alive {
        tuning
            .keep_alive_period(Duration::from_millis(ms))
            .map_err(to_transport_err)?;
    }
    if let Some(limit) = transport.max_streams {
        tuning
            .max_open_bidirectional_streams(limit)
            .map_err(to_transport_err)?;
    }
    if let Some(algorithm) = transport.congestion {
        tuning
            .congestion(algorithm, transport.cwnd_init)
            .map_err(to_transport_err)?;
    }
    quic_config.transport_config(tuning);

    info!("Binding UDP socket to {}", config.listen);
    let socket = UdpSocket::bind(config.listen)?;

    let server = QuicServer::new(socket, quic_config).await;
    server.map_err(|e| Error::Transport(e.to_string()))
}