1use std::io;
2use std::net::UdpSocket;
3use std::sync::Arc;
4use std::time::Duration;
5
6use tokio::sync::broadcast;
7use tokio::task::JoinHandle;
8
9use ombrac::metrics::Metrics;
10use ombrac::protocol::Secret;
11use ombrac_macros::{error, info, warn};
12use ombrac_transport::quic::TransportConfig as QuicTransportConfig;
13use ombrac_transport::quic::error::Error as QuicError;
14use ombrac_transport::quic::server::Config as QuicConfig;
15use ombrac_transport::quic::server::Server as QuicServer;
16
17use crate::config::{ServiceConfig, TlsMode};
18use crate::connection::ConnectionAcceptor;
19
/// The concrete connection acceptor run by [`OmbracServer`]: a QUIC transport paired with the
/// protocol `Secret` (derived from `config.secret` in `OmbracServer::build`).
type BuiltAcceptor = ConnectionAcceptor<QuicServer, Secret>;
21
22#[derive(thiserror::Error, Debug)]
23pub enum Error {
24 #[error(transparent)]
25 Io(#[from] io::Error),
26
27 #[error("{0}")]
28 Config(String),
29
30 #[error(transparent)]
31 Quic(#[from] QuicError),
32}
33
34pub type Result<T> = std::result::Result<T, Error>;
35
/// Unwraps an optional configuration value, or yields `Err(Error::Config(..))` naming the
/// missing field.
///
/// The first argument is evaluated exactly once. The field-name expression is only evaluated
/// when the value is absent.
macro_rules! require_config {
    ($value:expr, $field:expr) => {
        match $value {
            Some(present) => Ok(present),
            None => Err(Error::Config(format!(
                "'{}' is required but was not provided",
                $field
            ))),
        }
    };
}
46
/// A running ombrac server: the spawned accept-loop task together with the handles needed to
/// stop it and to observe its metrics.
///
/// Stop it with `shutdown` or `shutdown_with_drain`. Dropping it without calling either
/// detaches the task (tokio `JoinHandle` drop semantics) and drops the shutdown sender.
/// NOTE(review): how the accept loop reacts to the closed channel is not visible here — confirm.
pub struct OmbracServer {
    /// The spawned task running `ConnectionAcceptor::accept_loop`; resolves with its result.
    handle: JoinHandle<Result<()>>,
    /// Sender half of the broadcast channel the accept loop listens on for shutdown.
    shutdown_tx: broadcast::Sender<()>,
    /// Metrics handle obtained from the acceptor when the server was built.
    metrics: Metrics,
    /// Extra strong reference keeping the acceptor alive for the server's lifetime; the
    /// spawned task holds its own clone.
    // Field order is drop order; keep it stable unless drop ordering is re-checked.
    _acceptor_keepalive: Arc<BuiltAcceptor>,
}
83
84impl OmbracServer {
85 pub async fn build(config: Arc<ServiceConfig>) -> Result<Self> {
102 let acceptor = quic_server_from_config(&config).await?;
104
105 let secret = *blake3::hash(config.secret.as_bytes()).as_bytes();
107
108 let connection_config = Arc::new(config.connection.clone());
110 let acceptor = Arc::new(ConnectionAcceptor::with_config(
111 acceptor,
112 secret,
113 connection_config,
114 ));
115 let metrics = acceptor.metrics();
116
117 let (shutdown_tx, shutdown_rx) = broadcast::channel(1);
119
120 let acceptor_for_task = Arc::clone(&acceptor);
125 let handle = tokio::spawn(async move {
126 acceptor_for_task
127 .accept_loop(shutdown_rx)
128 .await
129 .map_err(Error::Io)
130 });
131
132 Ok(OmbracServer {
133 handle,
134 shutdown_tx,
135 metrics,
136 _acceptor_keepalive: acceptor,
137 })
138 }
139
140 pub fn metrics(&self) -> Metrics {
145 self.metrics.clone()
146 }
147
148 pub async fn shutdown(self) {
174 let _ = self.shutdown_tx.send(());
175 match self.handle.await {
176 Ok(Ok(_)) => {}
177 Ok(Err(e)) => {
178 error!("the main server task exited with an error: {e}");
179 }
180 Err(e) => {
181 error!("the main server task failed to shut down cleanly: {e}");
182 }
183 }
184 }
185
186 pub async fn shutdown_with_drain(self, drain_timeout: Duration) -> bool {
198 let _ = self.shutdown_tx.send(());
199
200 let deadline = tokio::time::Instant::now() + drain_timeout;
202 let mut interval = tokio::time::interval(Duration::from_millis(50));
203 let drained = loop {
204 interval.tick().await;
205 let snap = self.metrics.snapshot();
206 let active_streams = snap.streams_opened.saturating_sub(snap.streams_closed);
207 if active_streams == 0 {
208 break true;
209 }
210 if tokio::time::Instant::now() >= deadline {
211 warn!(
212 active_streams = active_streams,
213 timeout_ms = drain_timeout.as_millis() as u64,
214 "drain timed out with streams still active"
215 );
216 break false;
217 }
218 };
219
220 match self.handle.await {
222 Ok(Ok(_)) => {}
223 Ok(Err(e)) => error!("server task exited with error: {e}"),
224 Err(e) => error!("server task failed to shut down cleanly: {e}"),
225 }
226
227 drained
228 }
229}
230
231async fn quic_server_from_config(config: &ServiceConfig) -> Result<QuicServer> {
232 let transport_cfg = &config.transport;
233 let mut quic_config = QuicConfig::new();
234
235 quic_config.enable_zero_rtt = transport_cfg.zero_rtt();
236 quic_config.alpn_protocols = transport_cfg.alpn_protocols();
237
238 match transport_cfg.tls_mode() {
239 TlsMode::Tls => {
240 let cert_path = require_config!(transport_cfg.tls_cert.clone(), "transport.tls_cert")?;
241 let key_path = require_config!(transport_cfg.tls_key.clone(), "transport.tls_key")?;
242 quic_config.tls_cert_key_paths = Some((cert_path, key_path));
243 }
244 TlsMode::MTls => {
245 let cert_path = require_config!(
246 transport_cfg.tls_cert.clone(),
247 "transport.tls_cert for mTLS"
248 )?;
249 let key_path =
250 require_config!(transport_cfg.tls_key.clone(), "transport.tls_key for mTLS")?;
251 quic_config.tls_cert_key_paths = Some((cert_path, key_path));
252 quic_config.root_ca_path = Some(require_config!(
253 transport_cfg.ca_cert.clone(),
254 "transport.ca_cert for mTLS"
255 )?);
256 }
257 TlsMode::Insecure => {
258 warn!(
259 "================================================================"
260 );
261 warn!(
262 "TLS IS IN INSECURE MODE — self-signed certificates will be used."
263 );
264 warn!(
265 "This bypasses any meaningful authentication of the SERVER and"
266 );
267 warn!(
268 "is intended for local development and CI only. DO NOT use this"
269 );
270 warn!(
271 "mode on a network you don't control."
272 );
273 warn!(
274 "================================================================"
275 );
276 quic_config.enable_self_signed = true;
277 }
278 }
279
280 let mut transport_config = QuicTransportConfig::default();
281 let map_transport_err = |e: QuicError| Error::Quic(e);
282 transport_config
283 .max_idle_timeout(Duration::from_millis(transport_cfg.idle_timeout()))
284 .map_err(map_transport_err)?;
285 transport_config
286 .keep_alive_period(Duration::from_millis(transport_cfg.keep_alive()))
287 .map_err(map_transport_err)?;
288 transport_config
289 .max_open_bidirectional_streams(transport_cfg.max_streams())
290 .map_err(map_transport_err)?;
291 transport_config
292 .congestion(transport_cfg.congestion(), transport_cfg.cwnd_init)
293 .map_err(map_transport_err)?;
294 quic_config.transport_config(transport_config);
295
296 info!("binding udp socket to {}", config.listen);
297 let socket = UdpSocket::bind(config.listen)?;
298
299 QuicServer::new(socket, quic_config)
300 .await
301 .map_err(Error::Quic)
302}