1use std::io;
2use std::net::UdpSocket;
3use std::sync::Arc;
4use std::time::Duration;
5
6use tokio::sync::broadcast;
7use tokio::task::JoinHandle;
8
9use ombrac_macros::{error, info, warn};
10use ombrac_transport::quic::TransportConfig as QuicTransportConfig;
11use ombrac_transport::quic::error::Error as QuicError;
12use ombrac_transport::quic::server::Config as QuicConfig;
13use ombrac_transport::quic::server::Server as QuicServer;
14
15use crate::config::{ServiceConfig, TlsMode};
16use crate::connection::ConnectionAcceptor;
17
18#[derive(thiserror::Error, Debug)]
19pub enum Error {
20 #[error(transparent)]
21 Io(#[from] io::Error),
22
23 #[error("{0}")]
24 Config(String),
25
26 #[error(transparent)]
27 Quic(#[from] QuicError),
28}
29
30pub type Result<T> = std::result::Result<T, Error>;
31
/// Unwraps an optional configuration value, or produces an [`Error::Config`].
///
/// Expands to a `Result`: `Ok(value)` when the first argument is `Some(value)`, otherwise
/// an `Error::Config` whose message quotes the second argument (the field's display name).
/// The message is only formatted on the `None` path. Callers follow it with `?`.
macro_rules! require_config {
    ($value:expr, $field:expr) => {
        match $value {
            Some(present) => Ok(present),
            None => Err(Error::Config(format!(
                "'{}' is required but was not provided",
                $field
            ))),
        }
    };
}
42
/// Handle to a running ombrac server.
///
/// Produced by [`OmbracServer::build`] and stopped with [`OmbracServer::shutdown`].
pub struct OmbracServer {
    /// Spawned task running the connection accept loop; resolves to that loop's result.
    handle: JoinHandle<Result<()>>,
    /// Sending half of the broadcast channel whose receiver is handed to the accept loop;
    /// `shutdown` sends a unit value on it to request the loop to stop.
    shutdown_tx: broadcast::Sender<()>,
}
74
75impl OmbracServer {
76 pub async fn build(config: Arc<ServiceConfig>) -> Result<Self> {
93 let acceptor = quic_server_from_config(&config).await?;
95
96 let secret = *blake3::hash(config.secret.as_bytes()).as_bytes();
98
99 let connection_config = Arc::new(config.connection.clone());
101 let acceptor = Arc::new(ConnectionAcceptor::with_config(
102 acceptor,
103 secret,
104 connection_config,
105 ));
106
107 let (shutdown_tx, shutdown_rx) = broadcast::channel(1);
109
110 let handle =
112 tokio::spawn(async move { acceptor.accept_loop(shutdown_rx).await.map_err(Error::Io) });
113
114 Ok(OmbracServer {
115 handle,
116 shutdown_tx,
117 })
118 }
119
120 pub async fn shutdown(self) {
146 let _ = self.shutdown_tx.send(());
147 match self.handle.await {
148 Ok(Ok(_)) => {}
149 Ok(Err(e)) => {
150 error!("the main server task exited with an error: {e}");
151 }
152 Err(e) => {
153 error!("the main server task failed to shut down cleanly: {e}");
154 }
155 }
156 }
157}
158
159async fn quic_server_from_config(config: &ServiceConfig) -> Result<QuicServer> {
160 let transport_cfg = &config.transport;
161 let mut quic_config = QuicConfig::new();
162
163 quic_config.enable_zero_rtt = transport_cfg.zero_rtt();
164 quic_config.alpn_protocols = transport_cfg.alpn_protocols();
165
166 match transport_cfg.tls_mode() {
167 TlsMode::Tls => {
168 let cert_path = require_config!(transport_cfg.tls_cert.clone(), "transport.tls_cert")?;
169 let key_path = require_config!(transport_cfg.tls_key.clone(), "transport.tls_key")?;
170 quic_config.tls_cert_key_paths = Some((cert_path, key_path));
171 }
172 TlsMode::MTls => {
173 let cert_path = require_config!(
174 transport_cfg.tls_cert.clone(),
175 "transport.tls_cert for mTLS"
176 )?;
177 let key_path =
178 require_config!(transport_cfg.tls_key.clone(), "transport.tls_key for mTLS")?;
179 quic_config.tls_cert_key_paths = Some((cert_path, key_path));
180 quic_config.root_ca_path = Some(require_config!(
181 transport_cfg.ca_cert.clone(),
182 "transport.ca_cert for mTLS"
183 )?);
184 }
185 TlsMode::Insecure => {
186 warn!("tls is in insecure mode; generating self-signed certificates for local/dev use");
187 quic_config.enable_self_signed = true;
188 }
189 }
190
191 let mut transport_config = QuicTransportConfig::default();
192 let map_transport_err = |e: QuicError| Error::Quic(e);
193 transport_config
194 .max_idle_timeout(Duration::from_millis(transport_cfg.idle_timeout()))
195 .map_err(map_transport_err)?;
196 transport_config
197 .keep_alive_period(Duration::from_millis(transport_cfg.keep_alive()))
198 .map_err(map_transport_err)?;
199 transport_config
200 .max_open_bidirectional_streams(transport_cfg.max_streams())
201 .map_err(map_transport_err)?;
202 transport_config
203 .congestion(transport_cfg.congestion(), transport_cfg.cwnd_init)
204 .map_err(map_transport_err)?;
205 quic_config.transport_config(transport_config);
206
207 info!("binding udp socket to {}", config.listen);
208 let socket = UdpSocket::bind(config.listen)?;
209
210 QuicServer::new(socket, quic_config)
211 .await
212 .map_err(Error::Quic)
213}