Skip to main content

ombrac_server/
service.rs

1use std::io;
2use std::net::UdpSocket;
3use std::sync::Arc;
4use std::time::Duration;
5
6use tokio::sync::broadcast;
7use tokio::task::JoinHandle;
8
9use ombrac::metrics::Metrics;
10use ombrac::protocol::Secret;
11use ombrac_macros::{error, info, warn};
12use ombrac_transport::quic::TransportConfig as QuicTransportConfig;
13use ombrac_transport::quic::error::Error as QuicError;
14use ombrac_transport::quic::server::Config as QuicConfig;
15use ombrac_transport::quic::server::Server as QuicServer;
16
17use crate::config::{ServiceConfig, TlsMode};
18use crate::connection::ConnectionAcceptor;
19
20type BuiltAcceptor = ConnectionAcceptor<QuicServer, Secret>;
21
22#[derive(thiserror::Error, Debug)]
23pub enum Error {
24    #[error(transparent)]
25    Io(#[from] io::Error),
26
27    #[error("{0}")]
28    Config(String),
29
30    #[error(transparent)]
31    Quic(#[from] QuicError),
32}
33
34pub type Result<T> = std::result::Result<T, Error>;
35
/// Turns an optional configuration value into a `Result`.
///
/// Evaluates to `Ok(value)` when the first argument is `Some(value)`. Otherwise it
/// evaluates to `Err(Error::Config(..))` with a message naming the missing field. The
/// field-name expression is only evaluated in the missing case. Call sites append `?`
/// to propagate the error.
macro_rules! require_config {
    ($value:expr, $field:expr) => {
        match $value {
            Some(present) => Ok(present),
            None => Err(Error::Config(format!(
                "'{}' is required but was not provided",
                $field
            ))),
        }
    };
}
46
47/// OmbracServer provides a simple, easy-to-use API for starting and managing
48/// the ombrac server using QUIC transport.
49///
50/// This struct hides all transport-specific implementation details and provides
51/// a clean interface for external users.
52///
53/// # Example
54///
55/// ```no_run
56/// use ombrac_server::{OmbracServer, ServiceConfig};
57/// use std::sync::Arc;
58///
59/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
60/// let config = Arc::new(ServiceConfig {
61///     secret: "my-secret".to_string(),
62///     listen: "0.0.0.0:8080".parse()?,
63///     transport: Default::default(),
64///     connection: Default::default(),
65///     logging: Default::default(),
66/// });
67///
68/// let server = OmbracServer::build(config).await?;
69/// // ... use server ...
70/// server.shutdown().await;
71/// # Ok(())
72/// # }
73/// ```
74pub struct OmbracServer {
75    handle: JoinHandle<Result<()>>,
76    shutdown_tx: broadcast::Sender<()>,
77    metrics: Metrics,
78    // Held to keep the QUIC endpoint alive after the accept loop exits, so
79    // `shutdown_with_drain` can wait for in-flight streams without the
80    // underlying transport being torn down.
81    _acceptor_keepalive: Arc<BuiltAcceptor>,
82}
83
84impl OmbracServer {
85    /// Builds a new server instance from the configuration.
86    ///
87    /// This method:
88    /// 1. Creates a QUIC server from the transport configuration
89    /// 2. Sets up connection validation using the secret
90    /// 3. Spawns the accept loop in a background task
91    /// 4. Returns an OmbracServer handle for lifecycle management
92    ///
93    /// # Arguments
94    ///
95    /// * `config` - The service configuration containing transport, connection, and secret settings
96    ///
97    /// # Returns
98    ///
99    /// A configured `OmbracServer` instance ready to accept connections, or an error
100    /// if configuration is invalid or server setup fails.
101    pub async fn build(config: Arc<ServiceConfig>) -> Result<Self> {
102        // Build QUIC server from config
103        let acceptor = quic_server_from_config(&config).await?;
104
105        // Create secret authenticator from config
106        let secret = *blake3::hash(config.secret.as_bytes()).as_bytes();
107
108        // Create connection acceptor with connection config
109        let connection_config = Arc::new(config.connection.clone());
110        let acceptor = Arc::new(ConnectionAcceptor::with_config(
111            acceptor,
112            secret,
113            connection_config,
114        ));
115        let metrics = acceptor.metrics();
116
117        // Set up shutdown channel
118        let (shutdown_tx, shutdown_rx) = broadcast::channel(1);
119
120        // Spawn accept loop. Hold a strong reference to the acceptor (and thus
121        // the underlying QUIC server) outside the task as well, so the
122        // transport stays alive after the accept loop exits — needed for
123        // shutdown_with_drain to let in-flight streams finish naturally.
124        let acceptor_for_task = Arc::clone(&acceptor);
125        let handle = tokio::spawn(async move {
126            acceptor_for_task
127                .accept_loop(shutdown_rx)
128                .await
129                .map_err(Error::Io)
130        });
131
132        Ok(OmbracServer {
133            handle,
134            shutdown_tx,
135            metrics,
136            _acceptor_keepalive: acceptor,
137        })
138    }
139
140    /// Returns a clone-able handle to runtime metrics for this server.
141    ///
142    /// Callers can snapshot or read individual counters at any time:
143    /// `server.metrics().snapshot()`.
144    pub fn metrics(&self) -> Metrics {
145        self.metrics.clone()
146    }
147
148    /// Gracefully shuts down the server.
149    ///
150    /// This method will:
151    /// 1. Send a shutdown signal to stop accepting new connections
152    /// 2. Wait for the accept loop to finish gracefully
153    /// 3. Wait for existing connections to close
154    ///
155    /// # Example
156    ///
157    /// ```no_run
158    /// # use ombrac_server::OmbracServer;
159    /// # use std::sync::Arc;
160    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
161    /// # let config = Arc::new(ombrac_server::ServiceConfig {
162    /// #     secret: "test".to_string(),
163    /// #     listen: "0.0.0.0:0".parse()?,
164    /// #     transport: Default::default(),
165    /// #     connection: Default::default(),
166    /// #     logging: Default::default(),
167    /// # });
168    /// # let server = OmbracServer::build(config).await?;
169    /// server.shutdown().await;
170    /// # Ok(())
171    /// # }
172    /// ```
173    pub async fn shutdown(self) {
174        let _ = self.shutdown_tx.send(());
175        match self.handle.await {
176            Ok(Ok(_)) => {}
177            Ok(Err(e)) => {
178                error!("the main server task exited with an error: {e}");
179            }
180            Err(e) => {
181                error!("the main server task failed to shut down cleanly: {e}");
182            }
183        }
184    }
185
186    /// Stops accepting new connections, then waits up to `drain_timeout` for
187    /// in-flight streams to close naturally before returning.
188    ///
189    /// Returns `true` if all streams drained within the timeout, `false` if
190    /// the timeout elapsed with active streams still in flight. In the latter
191    /// case, those streams' tasks will keep running until their underlying
192    /// QUIC connection closes (typically via idle timeout) — they are not
193    /// hard-cancelled by this call.
194    ///
195    /// Use this for rolling restarts where you want existing client requests
196    /// to complete before the process exits.
197    pub async fn shutdown_with_drain(self, drain_timeout: Duration) -> bool {
198        let _ = self.shutdown_tx.send(());
199
200        // Poll the stream counter until balanced (no active streams) or timeout.
201        let deadline = tokio::time::Instant::now() + drain_timeout;
202        let mut interval = tokio::time::interval(Duration::from_millis(50));
203        let drained = loop {
204            interval.tick().await;
205            let snap = self.metrics.snapshot();
206            let active_streams = snap.streams_opened.saturating_sub(snap.streams_closed);
207            if active_streams == 0 {
208                break true;
209            }
210            if tokio::time::Instant::now() >= deadline {
211                warn!(
212                    active_streams = active_streams,
213                    timeout_ms = drain_timeout.as_millis() as u64,
214                    "drain timed out with streams still active"
215                );
216                break false;
217            }
218        };
219
220        // Either way, wait for the accept-loop task to clean up.
221        match self.handle.await {
222            Ok(Ok(_)) => {}
223            Ok(Err(e)) => error!("server task exited with error: {e}"),
224            Err(e) => error!("server task failed to shut down cleanly: {e}"),
225        }
226
227        drained
228    }
229}
230
231async fn quic_server_from_config(config: &ServiceConfig) -> Result<QuicServer> {
232    let transport_cfg = &config.transport;
233    let mut quic_config = QuicConfig::new();
234
235    quic_config.enable_zero_rtt = transport_cfg.zero_rtt();
236    quic_config.alpn_protocols = transport_cfg.alpn_protocols();
237
238    match transport_cfg.tls_mode() {
239        TlsMode::Tls => {
240            let cert_path = require_config!(transport_cfg.tls_cert.clone(), "transport.tls_cert")?;
241            let key_path = require_config!(transport_cfg.tls_key.clone(), "transport.tls_key")?;
242            quic_config.tls_cert_key_paths = Some((cert_path, key_path));
243        }
244        TlsMode::MTls => {
245            let cert_path = require_config!(
246                transport_cfg.tls_cert.clone(),
247                "transport.tls_cert for mTLS"
248            )?;
249            let key_path =
250                require_config!(transport_cfg.tls_key.clone(), "transport.tls_key for mTLS")?;
251            quic_config.tls_cert_key_paths = Some((cert_path, key_path));
252            quic_config.root_ca_path = Some(require_config!(
253                transport_cfg.ca_cert.clone(),
254                "transport.ca_cert for mTLS"
255            )?);
256        }
257        TlsMode::Insecure => {
258            warn!(
259                "================================================================"
260            );
261            warn!(
262                "TLS IS IN INSECURE MODE — self-signed certificates will be used."
263            );
264            warn!(
265                "This bypasses any meaningful authentication of the SERVER and"
266            );
267            warn!(
268                "is intended for local development and CI only. DO NOT use this"
269            );
270            warn!(
271                "mode on a network you don't control."
272            );
273            warn!(
274                "================================================================"
275            );
276            quic_config.enable_self_signed = true;
277        }
278    }
279
280    let mut transport_config = QuicTransportConfig::default();
281    let map_transport_err = |e: QuicError| Error::Quic(e);
282    transport_config
283        .max_idle_timeout(Duration::from_millis(transport_cfg.idle_timeout()))
284        .map_err(map_transport_err)?;
285    transport_config
286        .keep_alive_period(Duration::from_millis(transport_cfg.keep_alive()))
287        .map_err(map_transport_err)?;
288    transport_config
289        .max_open_bidirectional_streams(transport_cfg.max_streams())
290        .map_err(map_transport_err)?;
291    transport_config
292        .congestion(transport_cfg.congestion(), transport_cfg.cwnd_init)
293        .map_err(map_transport_err)?;
294    quic_config.transport_config(transport_config);
295
296    info!("binding udp socket to {}", config.listen);
297    let socket = UdpSocket::bind(config.listen)?;
298
299    QuicServer::new(socket, quic_config)
300        .await
301        .map_err(Error::Quic)
302}