use std::io;
use std::net::UdpSocket;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::broadcast;
use tokio::task::JoinHandle;
use ombrac_macros::{error, info, warn};
use ombrac_transport::quic::TransportConfig as QuicTransportConfig;
use ombrac_transport::quic::error::Error as QuicError;
use ombrac_transport::quic::server::Config as QuicConfig;
use ombrac_transport::quic::server::Server as QuicServer;
use crate::config::{ServiceConfig, TlsMode};
use crate::connection::ConnectionAcceptor;
/// Errors surfaced while building or running the server.
#[derive(thiserror::Error, Debug)]
pub enum Error {
/// An I/O failure; `#[from]` lets `?` convert an `io::Error` into this variant.
#[error(transparent)]
Io(#[from] io::Error),
/// A configuration problem; the contained message is displayed verbatim.
#[error("{0}")]
Config(String),
/// A failure reported by the QUIC transport; `#[from]` lets `?` convert a `QuicError`.
#[error(transparent)]
Quic(#[from] QuicError),
}
/// Shorthand for a `std::result::Result` whose error type is this module's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
/// Turns an optional configuration value into a `Result`.
///
/// Takes an `Option` expression and a displayable field name. Expands to
/// `Ok(value)` when the option is `Some(value)`, and to an `Err` holding
/// [`Error::Config`] with a message naming the field when it is `None`.
/// The field name is only formatted on the `None` path.
macro_rules! require_config {
    ($value:expr, $name:expr) => {
        match $value {
            Some(present) => Ok(present),
            None => Err(Error::Config(format!(
                "'{}' is required but was not provided",
                $name
            ))),
        }
    };
}
/// Handle to a running server: the spawned accept-loop task plus the channel used to stop it.
pub struct OmbracServer {
// Join handle of the task spawned in `build`; it resolves to the accept loop's result.
handle: JoinHandle<Result<()>>,
// Sender side of the broadcast channel whose receiver is passed to the accept loop;
// `shutdown` sends `()` on it.
shutdown_tx: broadcast::Sender<()>,
}
impl OmbracServer {
pub async fn build(config: Arc<ServiceConfig>) -> Result<Self> {
let acceptor = quic_server_from_config(&config).await?;
let secret = *blake3::hash(config.secret.as_bytes()).as_bytes();
let connection_config = Arc::new(config.connection.clone());
let acceptor = Arc::new(ConnectionAcceptor::with_config(
acceptor,
secret,
connection_config,
));
let (shutdown_tx, shutdown_rx) = broadcast::channel(1);
let handle =
tokio::spawn(async move { acceptor.accept_loop(shutdown_rx).await.map_err(Error::Io) });
Ok(OmbracServer {
handle,
shutdown_tx,
})
}
pub async fn shutdown(self) {
let _ = self.shutdown_tx.send(());
match self.handle.await {
Ok(Ok(_)) => {}
Ok(Err(e)) => {
error!("the main server task exited with an error: {e}");
}
Err(e) => {
error!("the main server task failed to shut down cleanly: {e}");
}
}
}
}
async fn quic_server_from_config(config: &ServiceConfig) -> Result<QuicServer> {
let transport_cfg = &config.transport;
let mut quic_config = QuicConfig::new();
quic_config.enable_zero_rtt = transport_cfg.zero_rtt();
quic_config.alpn_protocols = transport_cfg.alpn_protocols();
match transport_cfg.tls_mode() {
TlsMode::Tls => {
let cert_path = require_config!(transport_cfg.tls_cert.clone(), "transport.tls_cert")?;
let key_path = require_config!(transport_cfg.tls_key.clone(), "transport.tls_key")?;
quic_config.tls_cert_key_paths = Some((cert_path, key_path));
}
TlsMode::MTls => {
let cert_path = require_config!(
transport_cfg.tls_cert.clone(),
"transport.tls_cert for mTLS"
)?;
let key_path =
require_config!(transport_cfg.tls_key.clone(), "transport.tls_key for mTLS")?;
quic_config.tls_cert_key_paths = Some((cert_path, key_path));
quic_config.root_ca_path = Some(require_config!(
transport_cfg.ca_cert.clone(),
"transport.ca_cert for mTLS"
)?);
}
TlsMode::Insecure => {
warn!("tls is in insecure mode; generating self-signed certificates for local/dev use");
quic_config.enable_self_signed = true;
}
}
let mut transport_config = QuicTransportConfig::default();
let map_transport_err = |e: QuicError| Error::Quic(e);
transport_config
.max_idle_timeout(Duration::from_millis(transport_cfg.idle_timeout()))
.map_err(map_transport_err)?;
transport_config
.keep_alive_period(Duration::from_millis(transport_cfg.keep_alive()))
.map_err(map_transport_err)?;
transport_config
.max_open_bidirectional_streams(transport_cfg.max_streams())
.map_err(map_transport_err)?;
transport_config
.congestion(transport_cfg.congestion(), transport_cfg.cwnd_init)
.map_err(map_transport_err)?;
quic_config.transport_config(transport_config);
info!("binding udp socket to {}", config.listen);
let socket = UdpSocket::bind(config.listen)?;
QuicServer::new(socket, quic_config)
.await
.map_err(Error::Quic)
}