1use std::future::Future;
2use std::io;
3use std::marker::PhantomData;
4use std::net::UdpSocket;
5use std::sync::Arc;
6use std::time::Duration;
7
8use tokio::sync::broadcast;
9use tokio::task::JoinHandle;
10
11use ombrac_macros::{error, info, warn};
12#[cfg(feature = "transport-quic")]
13use ombrac_transport::quic::{
14 self, Connection as QuicConnection, TransportConfig as QuicTransportConfig,
15 server::{Config as QuicConfig, Server as QuicServer},
16};
17use ombrac_transport::{Acceptor, Connection};
18
19use crate::config::ServiceConfig;
20use crate::connection::HandshakeValidator;
21use crate::server::Server;
22
23#[cfg(feature = "transport-quic")]
24use crate::config::TlsMode;
25
26#[derive(thiserror::Error, Debug)]
27pub enum Error {
28 #[error("Configuration error: {0}")]
29 Config(String),
30
31 #[error("{0}")]
32 Io(#[from] io::Error),
33
34 #[error("Transport layer error: {0}")]
35 Transport(String),
36}
37
38pub type Result<T> = std::result::Result<T, Error>;
39
/// Turns a required configuration `Option` into a `Result`.
///
/// `Some(v)` becomes `Ok(v)`. `None` becomes `Err(Error::Config(..))` with a
/// message naming the missing field. The message is only formatted on the
/// `None` path. Callers normally follow the expansion with `?`.
macro_rules! require_config {
    ($value:expr, $name:expr) => {
        match $value {
            Some(present) => Ok(present),
            None => Err(Error::Config(format!(
                "'{}' is required but was not provided",
                $name
            ))),
        }
    };
}
50
51pub trait ServiceBuilder {
52 type Acceptor: Acceptor<Connection = Self::Connection>;
53 type Connection: Connection;
54 type Validator: HandshakeValidator + 'static;
55
56 fn build(
57 config: &Arc<ServiceConfig>,
58 ) -> impl Future<Output = Result<Arc<Server<Self::Acceptor, Self::Validator>>>> + Send;
59}
60
/// [`ServiceBuilder`] for the QUIC transport.
///
/// The server's handshake validator is the 32-byte BLAKE3 digest of
/// `config.secret`.
#[cfg(feature = "transport-quic")]
pub struct QuicServiceBuilder;

#[cfg(feature = "transport-quic")]
impl ServiceBuilder for QuicServiceBuilder {
    type Acceptor = QuicServer;
    type Connection = QuicConnection;
    type Validator = ombrac::protocol::Secret;

    async fn build(
        config: &Arc<ServiceConfig>,
    ) -> Result<Arc<Server<Self::Acceptor, Self::Validator>>> {
        // Bind and configure the QUIC endpoint first; any failure aborts the build.
        let quic_server = quic_server_from_config(config).await?;

        // The raw configured secret is never handed to the server, only its digest.
        let digest = blake3::hash(config.secret.as_bytes());

        Ok(Arc::new(Server::new(quic_server, *digest.as_bytes())))
    }
}
79
80pub struct Service<T, C>
81where
82 T: Acceptor<Connection = C>,
83 C: Connection,
84{
85 handle: JoinHandle<Result<()>>,
86 shutdown_tx: broadcast::Sender<()>,
87 _acceptor: PhantomData<T>,
88 _connection: PhantomData<C>,
89}
90
91impl<T, C> Service<T, C>
92where
93 T: Acceptor<Connection = C> + Send + Sync + 'static,
94 C: Connection + Send + Sync + 'static,
95{
96 pub async fn build<Builder, V>(config: Arc<ServiceConfig>) -> Result<Self>
97 where
98 Builder: ServiceBuilder<Acceptor = T, Connection = C, Validator = V>,
99 V: HandshakeValidator + Send + Sync + 'static,
100 {
101 let server = Builder::build(&config).await?;
102 let (shutdown_tx, shutdown_rx) = broadcast::channel(1);
103
104 let handle =
105 tokio::spawn(async move { server.accept_loop(shutdown_rx).await.map_err(Error::Io) });
106
107 Ok(Service {
108 handle,
109 shutdown_tx,
110 _acceptor: PhantomData,
111 _connection: PhantomData,
112 })
113 }
114
115 pub async fn shutdown(self) {
116 let _ = self.shutdown_tx.send(());
117 match self.handle.await {
118 Ok(Ok(_)) => {}
119 Ok(Err(_e)) => {
120 error!("The main server task exited with an error: {_e}");
121 }
122 Err(_e) => {
123 error!("The main server task failed to shut down cleanly: {_e}");
124 }
125 }
126 warn!("Service shutdown complete");
127 }
128}
129
/// Creates a [`QuicServer`] bound to `config.listen`, configured from `config.transport`.
///
/// # Errors
/// - [`Error::Config`] when the selected TLS mode is missing a required path.
/// - [`Error::Transport`] when a transport parameter is rejected or the QUIC
///   server fails to start.
/// - [`Error::Io`] when the UDP socket cannot be bound.
#[cfg(feature = "transport-quic")]
async fn quic_server_from_config(config: &ServiceConfig) -> Result<QuicServer> {
    let settings = &config.transport;

    let mut quic_config = QuicConfig::new();
    quic_config.enable_zero_rtt = settings.zero_rtt.unwrap_or(false);
    if let Some(alpn) = settings.alpn_protocols.as_ref() {
        quic_config.alpn_protocols = alpn.clone();
    }

    configure_quic_tls(&mut quic_config, config)?;
    quic_config.transport_config(quic_transport_params(config)?);

    info!("Binding UDP socket to {}", config.listen);
    let socket = UdpSocket::bind(config.listen)?;

    match QuicServer::new(socket, quic_config).await {
        Ok(server) => Ok(server),
        Err(e) => Err(Error::Transport(e.to_string())),
    }
}

/// Applies the certificate settings for `config.transport.tls_mode`
/// (default [`TlsMode::Tls`]) to `quic_config`.
///
/// Required paths are checked in the order cert, key, then (mTLS only) CA.
#[cfg(feature = "transport-quic")]
fn configure_quic_tls(quic_config: &mut QuicConfig, config: &ServiceConfig) -> Result<()> {
    let settings = &config.transport;
    match settings.tls_mode.unwrap_or(TlsMode::Tls) {
        TlsMode::Insecure => {
            warn!("TLS is running in insecure mode. Self-signed certificates will be generated.");
            quic_config.enable_self_signed = true;
        }
        TlsMode::Tls => {
            let cert = require_config!(settings.tls_cert.clone(), "transport.tls_cert")?;
            let key = require_config!(settings.tls_key.clone(), "transport.tls_key")?;
            quic_config.tls_cert_key_paths = Some((cert, key));
        }
        TlsMode::MTls => {
            let cert = require_config!(settings.tls_cert.clone(), "transport.tls_cert for mTLS")?;
            let key = require_config!(settings.tls_key.clone(), "transport.tls_key for mTLS")?;
            let ca = require_config!(settings.ca_cert.clone(), "transport.ca_cert for mTLS")?;
            quic_config.tls_cert_key_paths = Some((cert, key));
            quic_config.root_ca_path = Some(ca);
        }
    }
    Ok(())
}

/// Builds the QUIC transport parameters from the optional `config.transport` fields.
///
/// `idle_timeout` and `keep_alive` are interpreted as milliseconds.
#[cfg(feature = "transport-quic")]
fn quic_transport_params(config: &ServiceConfig) -> Result<QuicTransportConfig> {
    let settings = &config.transport;
    let to_error = |e: quic::error::Error| Error::Transport(e.to_string());
    let mut params = QuicTransportConfig::default();

    if let Some(ms) = settings.idle_timeout {
        params.max_idle_timeout(Duration::from_millis(ms)).map_err(to_error)?;
    }
    if let Some(ms) = settings.keep_alive {
        params.keep_alive_period(Duration::from_millis(ms)).map_err(to_error)?;
    }
    if let Some(limit) = settings.max_streams {
        params.max_open_bidirectional_streams(limit).map_err(to_error)?;
    }
    if let Some(algorithm) = settings.congestion {
        params.congestion(algorithm, settings.cwnd_init).map_err(to_error)?;
    }

    Ok(params)
}