Skip to main content

ombrac_server/
service.rs

1use std::io;
2use std::net::UdpSocket;
3use std::sync::Arc;
4use std::time::Duration;
5
6use tokio::sync::broadcast;
7use tokio::task::JoinHandle;
8
9use ombrac_macros::{error, info, warn};
10use ombrac_transport::quic::TransportConfig as QuicTransportConfig;
11use ombrac_transport::quic::error::Error as QuicError;
12use ombrac_transport::quic::server::Config as QuicConfig;
13use ombrac_transport::quic::server::Server as QuicServer;
14
15use crate::config::{ServiceConfig, TlsMode};
16use crate::connection::ConnectionAcceptor;
17
18#[derive(thiserror::Error, Debug)]
19pub enum Error {
20    #[error(transparent)]
21    Io(#[from] io::Error),
22
23    #[error("{0}")]
24    Config(String),
25
26    #[error(transparent)]
27    Quic(#[from] QuicError),
28}
29
30pub type Result<T> = std::result::Result<T, Error>;
31
/// Converts an `Option` into a `Result`, mapping `None` to [`Error::Config`].
///
/// The expansion is an expression of type `Result<T, Error>`, so call sites
/// normally append `?`. `$config_opt` is evaluated exactly once. `$field_name`
/// (anything implementing `Display`) is only formatted when the value is
/// missing, producing the message `'<field>' is required but was not provided`.
macro_rules! require_config {
    ($config_opt:expr, $field_name:expr) => {
        match $config_opt {
            ::core::option::Option::Some(present) => ::core::result::Result::Ok(present),
            ::core::option::Option::None => {
                let message = format!("'{}' is required but was not provided", $field_name);
                ::core::result::Result::Err(Error::Config(message))
            }
        }
    };
}
42
43/// OmbracServer provides a simple, easy-to-use API for starting and managing
44/// the ombrac server using QUIC transport.
45///
46/// This struct hides all transport-specific implementation details and provides
47/// a clean interface for external users.
48///
49/// # Example
50///
51/// ```no_run
52/// use ombrac_server::{OmbracServer, ServiceConfig};
53/// use std::sync::Arc;
54///
55/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
56/// let config = Arc::new(ServiceConfig {
57///     secret: "my-secret".to_string(),
58///     listen: "0.0.0.0:8080".parse()?,
59///     transport: Default::default(),
60///     connection: Default::default(),
61///     logging: Default::default(),
62/// });
63///
64/// let server = OmbracServer::build(config).await?;
65/// // ... use server ...
66/// server.shutdown().await;
67/// # Ok(())
68/// # }
69/// ```
70pub struct OmbracServer {
71    handle: JoinHandle<Result<()>>,
72    shutdown_tx: broadcast::Sender<()>,
73}
74
75impl OmbracServer {
76    /// Builds a new server instance from the configuration.
77    ///
78    /// This method:
79    /// 1. Creates a QUIC server from the transport configuration
80    /// 2. Sets up connection validation using the secret
81    /// 3. Spawns the accept loop in a background task
82    /// 4. Returns an OmbracServer handle for lifecycle management
83    ///
84    /// # Arguments
85    ///
86    /// * `config` - The service configuration containing transport, connection, and secret settings
87    ///
88    /// # Returns
89    ///
90    /// A configured `OmbracServer` instance ready to accept connections, or an error
91    /// if configuration is invalid or server setup fails.
92    pub async fn build(config: Arc<ServiceConfig>) -> Result<Self> {
93        // Build QUIC server from config
94        let acceptor = quic_server_from_config(&config).await?;
95
96        // Create secret authenticator from config
97        let secret = *blake3::hash(config.secret.as_bytes()).as_bytes();
98
99        // Create connection acceptor with connection config
100        let connection_config = Arc::new(config.connection.clone());
101        let acceptor = Arc::new(ConnectionAcceptor::with_config(
102            acceptor,
103            secret,
104            connection_config,
105        ));
106
107        // Set up shutdown channel
108        let (shutdown_tx, shutdown_rx) = broadcast::channel(1);
109
110        // Spawn accept loop
111        let handle =
112            tokio::spawn(async move { acceptor.accept_loop(shutdown_rx).await.map_err(Error::Io) });
113
114        Ok(OmbracServer {
115            handle,
116            shutdown_tx,
117        })
118    }
119
120    /// Gracefully shuts down the server.
121    ///
122    /// This method will:
123    /// 1. Send a shutdown signal to stop accepting new connections
124    /// 2. Wait for the accept loop to finish gracefully
125    /// 3. Wait for existing connections to close
126    ///
127    /// # Example
128    ///
129    /// ```no_run
130    /// # use ombrac_server::OmbracServer;
131    /// # use std::sync::Arc;
132    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
133    /// # let config = Arc::new(ombrac_server::ServiceConfig {
134    /// #     secret: "test".to_string(),
135    /// #     listen: "0.0.0.0:0".parse()?,
136    /// #     transport: Default::default(),
137    /// #     connection: Default::default(),
138    /// #     logging: Default::default(),
139    /// # });
140    /// # let server = OmbracServer::build(config).await?;
141    /// server.shutdown().await;
142    /// # Ok(())
143    /// # }
144    /// ```
145    pub async fn shutdown(self) {
146        let _ = self.shutdown_tx.send(());
147        match self.handle.await {
148            Ok(Ok(_)) => {}
149            Ok(Err(e)) => {
150                error!("the main server task exited with an error: {e}");
151            }
152            Err(e) => {
153                error!("the main server task failed to shut down cleanly: {e}");
154            }
155        }
156    }
157}
158
159async fn quic_server_from_config(config: &ServiceConfig) -> Result<QuicServer> {
160    let transport_cfg = &config.transport;
161    let mut quic_config = QuicConfig::new();
162
163    quic_config.enable_zero_rtt = transport_cfg.zero_rtt();
164    quic_config.alpn_protocols = transport_cfg.alpn_protocols();
165
166    match transport_cfg.tls_mode() {
167        TlsMode::Tls => {
168            let cert_path = require_config!(transport_cfg.tls_cert.clone(), "transport.tls_cert")?;
169            let key_path = require_config!(transport_cfg.tls_key.clone(), "transport.tls_key")?;
170            quic_config.tls_cert_key_paths = Some((cert_path, key_path));
171        }
172        TlsMode::MTls => {
173            let cert_path = require_config!(
174                transport_cfg.tls_cert.clone(),
175                "transport.tls_cert for mTLS"
176            )?;
177            let key_path =
178                require_config!(transport_cfg.tls_key.clone(), "transport.tls_key for mTLS")?;
179            quic_config.tls_cert_key_paths = Some((cert_path, key_path));
180            quic_config.root_ca_path = Some(require_config!(
181                transport_cfg.ca_cert.clone(),
182                "transport.ca_cert for mTLS"
183            )?);
184        }
185        TlsMode::Insecure => {
186            warn!("tls is in insecure mode; generating self-signed certificates for local/dev use");
187            quic_config.enable_self_signed = true;
188        }
189    }
190
191    let mut transport_config = QuicTransportConfig::default();
192    let map_transport_err = |e: QuicError| Error::Quic(e);
193    transport_config
194        .max_idle_timeout(Duration::from_millis(transport_cfg.idle_timeout()))
195        .map_err(map_transport_err)?;
196    transport_config
197        .keep_alive_period(Duration::from_millis(transport_cfg.keep_alive()))
198        .map_err(map_transport_err)?;
199    transport_config
200        .max_open_bidirectional_streams(transport_cfg.max_streams())
201        .map_err(map_transport_err)?;
202    transport_config
203        .congestion(transport_cfg.congestion(), transport_cfg.cwnd_init)
204        .map_err(map_transport_err)?;
205    quic_config.transport_config(transport_config);
206
207    info!("binding udp socket to {}", config.listen);
208    let socket = UdpSocket::bind(config.listen)?;
209
210    QuicServer::new(socket, quic_config)
211        .await
212        .map_err(Error::Quic)
213}