1use std::future::Future;
2use std::io;
3use std::marker::PhantomData;
4use std::net::UdpSocket;
5use std::sync::Arc;
6use std::time::Duration;
7
8use tokio::sync::broadcast;
9use tokio::task::JoinHandle;
10
11use ombrac_macros::{error, info, warn};
12#[cfg(feature = "transport-quic")]
13use ombrac_transport::quic::{
14 self, Connection as QuicConnection, TransportConfig as QuicTransportConfig,
15 server::{Config as QuicConfig, Server as QuicServer},
16};
17use ombrac_transport::{Acceptor, Connection};
18
19use crate::config::ServiceConfig;
20use crate::server::Server;
21
22#[cfg(feature = "transport-quic")]
23use crate::config::TlsMode;
24
25#[derive(thiserror::Error, Debug)]
26pub enum Error {
27 #[error("Configuration error: {0}")]
28 Config(String),
29
30 #[error("{0}")]
31 Io(#[from] io::Error),
32
33 #[error("Transport layer error: {0}")]
34 Transport(String),
35}
36
37pub type Result<T> = std::result::Result<T, Error>;
38
// Turns an `Option` holding a required configuration value into a `Result`.
//
// Expands to `Ok(value)` when `$config_opt` is `Some(value)`. Otherwise it expands to
// `Err(Error::Config(..))` with a message naming `$field_name`, so call sites can apply `?`
// directly. `$field_name` is only evaluated on the missing-value path.
macro_rules! require_config {
    ($config_opt:expr, $field_name:expr) => {
        match $config_opt {
            Some(value) => Ok(value),
            None => Err(Error::Config(format!(
                "'{}' is required but was not provided",
                $field_name
            ))),
        }
    };
}
49
50pub trait ServiceBuilder {
51 type Acceptor: Acceptor<Connection = Self::Connection>;
52 type Connection: Connection;
53
54 fn build(
55 config: &Arc<ServiceConfig>,
56 ) -> impl Future<Output = Result<Arc<Server<Self::Acceptor>>>> + Send;
57}
58
/// [`ServiceBuilder`] that serves over QUIC.
#[cfg(feature = "transport-quic")]
pub struct QuicServiceBuilder;

#[cfg(feature = "transport-quic")]
impl ServiceBuilder for QuicServiceBuilder {
    type Acceptor = QuicServer;
    type Connection = QuicConnection;

    /// Creates a QUIC acceptor from `config` and wraps it in a [`Server`].
    ///
    /// The server secret is the 32-byte BLAKE3 digest of `config.secret`.
    async fn build(config: &Arc<ServiceConfig>) -> Result<Arc<Server<Self::Acceptor>>> {
        // Hashing is pure, so deriving the key before creating the acceptor has no
        // observable effect on setup.
        let key_digest = blake3::hash(config.secret.as_bytes());
        let quic_acceptor = quic_server_from_config(config).await?;

        Ok(Arc::new(Server::new(quic_acceptor, *key_digest.as_bytes())))
    }
}
74
/// A running service: the spawned task that drives a server's accept loop, together with
/// the sending half of the channel used to ask that loop to stop.
///
/// `T` is the transport's acceptor and `C` the connection type it yields. Neither is stored
/// directly; both appear only through `PhantomData`.
pub struct Service<T, C>
where
    T: Acceptor<Connection = C>,
    C: Connection,
{
    /// Join handle of the accept-loop task, awaited by `shutdown` to collect its result.
    // NOTE: field order fixes drop order (handle before sender); keep it stable.
    handle: JoinHandle<Result<()>>,
    /// Sending side of the shutdown signal delivered to the accept loop.
    shutdown_tx: broadcast::Sender<()>,
    /// Zero-sized markers that tie the otherwise-unused type parameters to the struct.
    _acceptor: PhantomData<T>,
    _connection: PhantomData<C>,
}
85
86impl<T, C> Service<T, C>
87where
88 T: Acceptor<Connection = C> + Send + Sync + 'static,
89 C: Connection + Send + Sync + 'static,
90{
91 pub async fn build<Builder>(config: Arc<ServiceConfig>) -> Result<Self>
92 where
93 Builder: ServiceBuilder<Acceptor = T, Connection = C>,
94 {
95 let server = Builder::build(&config).await?;
96 let (shutdown_tx, shutdown_rx) = broadcast::channel(1);
97
98 let handle =
99 tokio::spawn(async move { server.accept_loop(shutdown_rx).await.map_err(Error::Io) });
100
101 Ok(Service {
102 handle,
103 shutdown_tx,
104 _acceptor: PhantomData,
105 _connection: PhantomData,
106 })
107 }
108
109 pub async fn shutdown(self) {
110 let _ = self.shutdown_tx.send(());
111 match self.handle.await {
112 Ok(Ok(_)) => {}
113 Ok(Err(_e)) => {
114 error!("The main server task exited with an error: {_e}");
115 }
116 Err(_e) => {
117 error!("The main server task failed to shut down cleanly: {_e}");
118 }
119 }
120 warn!("Service shutdown complete");
121 }
122}
123
/// Builds a QUIC server from `config`: protocol options, TLS material, transport tuning,
/// and finally a UDP socket bound to `config.listen`.
///
/// # Errors
///
/// - [`Error::Config`] when the selected TLS mode is missing a required path.
/// - [`Error::Transport`] when a transport setting is rejected or the server fails to start.
/// - [`Error::Io`] when the UDP socket cannot be bound.
#[cfg(feature = "transport-quic")]
async fn quic_server_from_config(config: &ServiceConfig) -> Result<QuicServer> {
    let mut quic_config = QuicConfig::new();
    quic_config.enable_zero_rtt = config.transport.zero_rtt.unwrap_or(false);
    if let Some(protocols) = config.transport.alpn_protocols.as_ref() {
        quic_config.alpn_protocols = protocols.clone();
    }

    // TLS is validated before transport tuning, so missing TLS paths are reported first.
    apply_tls_settings(config, &mut quic_config)?;

    let tuned = tuned_transport_config(config)?;
    quic_config.transport_config(tuned);

    info!("Binding UDP socket to {}", config.listen);
    let socket = UdpSocket::bind(config.listen)?;

    QuicServer::new(socket, quic_config)
        .await
        .map_err(|e| Error::Transport(e.to_string()))
}

/// Fills in the TLS-related fields of `quic_config` according to `config.transport.tls_mode`.
/// When no mode is configured, `TlsMode::Tls` is used.
///
/// `Tls` requires a certificate and key. `MTls` additionally requires a CA certificate.
/// `Insecure` logs a warning and enables self-signed certificate generation.
#[cfg(feature = "transport-quic")]
fn apply_tls_settings(config: &ServiceConfig, quic_config: &mut QuicConfig) -> Result<()> {
    let transport = &config.transport;

    match transport.tls_mode.unwrap_or(TlsMode::Tls) {
        TlsMode::Insecure => {
            warn!("TLS is running in insecure mode. Self-signed certificates will be generated.");
            quic_config.enable_self_signed = true;
        }
        TlsMode::Tls => {
            let cert = require_config!(transport.tls_cert.clone(), "transport.tls_cert")?;
            let key = require_config!(transport.tls_key.clone(), "transport.tls_key")?;
            quic_config.tls_cert_key_paths = Some((cert, key));
        }
        TlsMode::MTls => {
            let cert =
                require_config!(transport.tls_cert.clone(), "transport.tls_cert for mTLS")?;
            let key = require_config!(transport.tls_key.clone(), "transport.tls_key for mTLS")?;
            quic_config.tls_cert_key_paths = Some((cert, key));

            let ca = require_config!(transport.ca_cert.clone(), "transport.ca_cert for mTLS")?;
            quic_config.root_ca_path = Some(ca);
        }
    }

    Ok(())
}

/// Builds QUIC transport tuning from the optional knobs in `config.transport`.
///
/// Any knob left unset keeps the `QuicTransportConfig::default()` value. `idle_timeout` and
/// `keep_alive` are interpreted as milliseconds.
#[cfg(feature = "transport-quic")]
fn tuned_transport_config(config: &ServiceConfig) -> Result<QuicTransportConfig> {
    fn transport_error(e: quic::error::Error) -> Error {
        Error::Transport(e.to_string())
    }

    let transport = &config.transport;
    let mut tuned = QuicTransportConfig::default();

    if let Some(millis) = transport.idle_timeout {
        tuned
            .max_idle_timeout(Duration::from_millis(millis))
            .map_err(transport_error)?;
    }
    if let Some(millis) = transport.keep_alive {
        tuned
            .keep_alive_period(Duration::from_millis(millis))
            .map_err(transport_error)?;
    }
    if let Some(limit) = transport.max_streams {
        tuned
            .max_open_bidirectional_streams(limit)
            .map_err(transport_error)?;
    }
    if let Some(controller) = transport.congestion {
        tuned
            .congestion(controller, transport.cwnd_init)
            .map_err(transport_error)?;
    }

    Ok(tuned)
}