use std::path::PathBuf;
use tokio::sync::{mpsc, oneshot};
use tokio_util::sync::CancellationToken;
use tonic::transport::server::Router;
use super::{
ContainerType, ServerConfig, ServerInfo, SocketCleanup, create_listener_stream,
init_panic_hook, shutdown_signal,
};
pub trait ServerExtras<T> {
fn transform_inner<F>(self, f: F) -> Self
where
F: FnOnce(Server<T>) -> Server<T>,
Self: Sized;
fn inner_ref(&self) -> &Server<T>;
fn with_socket_file(self, file: impl Into<PathBuf>) -> Self
where
Self: Sized,
{
self.transform_inner(|inner| inner.with_socket_file(file))
}
fn socket_file<'a>(&'a self) -> &'a std::path::Path
where
T: 'a,
{
self.inner_ref().socket_file()
}
fn with_max_message_size(self, message_size: usize) -> Self
where
Self: Sized,
{
self.transform_inner(|inner| inner.with_max_message_size(message_size))
}
fn max_message_size(&self) -> usize {
self.inner_ref().max_message_size()
}
fn with_server_info_file(self, file: impl Into<PathBuf>) -> Self
where
Self: Sized,
{
self.transform_inner(|inner| inner.with_server_info_file(file))
}
fn server_info_file<'a>(&'a self) -> &'a std::path::Path
where
T: 'a,
{
self.inner_ref().server_info_file()
}
}
/// Shared start-up logic used by [`Server`].
///
/// Holds the listener configuration ([`ServerConfig`]), the [`ContainerType`] passed to
/// [`ServerInfo::new`], and whether [`init_panic_hook`] is called before serving.
#[derive(Debug)]
pub(crate) struct ServerStarter {
    config: ServerConfig,
    container_type: ContainerType,
    init_panic_hook: bool,
}

// Not every builder/accessor below is used by every caller of this type; silence the
// dead-code warnings for the unused ones.
#[allow(dead_code)]
impl ServerStarter {
    /// Creates a starter for the given socket and server-info paths, with the panic hook
    /// enabled.
    ///
    /// The [`SocketCleanup`] guard for the paths is created only when the server actually
    /// starts (see [`Self::start_server`]), so reconfiguring or dropping an unstarted
    /// starter does not trigger cleanup.
    fn new(
        container_type: ContainerType,
        default_sock_addr: &str,
        default_server_info_file: &str,
    ) -> Self {
        Self {
            config: ServerConfig::new(default_sock_addr, default_server_info_file),
            container_type,
            init_panic_hook: true,
        }
    }

    /// Enables or disables calling [`init_panic_hook`] at start-up.
    fn with_panic_hook(mut self, init_panic_hook: bool) -> Self {
        self.init_panic_hook = init_panic_hook;
        self
    }

    /// Overrides the socket file path.
    fn with_socket_file(mut self, file: impl Into<PathBuf>) -> Self {
        let file_path = file.into();
        self.config = self.config.with_socket_file(&file_path);
        self
    }

    /// Configured socket file path.
    fn socket_file(&self) -> &std::path::Path {
        self.config.socket_file()
    }

    /// Overrides the maximum message size.
    fn with_max_message_size(mut self, message_size: usize) -> Self {
        self.config = self.config.with_max_message_size(message_size);
        self
    }

    /// Configured maximum message size.
    fn max_message_size(&self) -> usize {
        self.config.max_message_size()
    }

    /// Overrides the server-info file path.
    fn with_server_info_file(mut self, file: impl Into<PathBuf>) -> Self {
        let file_path = file.into();
        self.config = self.config.with_server_info_file(&file_path);
        self
    }

    /// Configured server-info file path.
    fn server_info_file(&self) -> &std::path::Path {
        self.config.server_info_file()
    }

    /// Binds the listener and serves the router produced by `service_builder` until shutdown.
    ///
    /// `service_builder` receives the sender half of an internal shutdown channel and a
    /// [`CancellationToken`]. The matching receiver, the token and the optional external
    /// `shutdown_rx` are combined by [`shutdown_signal`] into the server's shutdown future.
    ///
    /// # Errors
    /// Returns an error if the listener stream cannot be created or if serving fails.
    async fn start_server<F>(
        self,
        shutdown_rx: Option<oneshot::Receiver<()>>,
        service_builder: F,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
    where
        F: FnOnce(mpsc::Sender<()>, CancellationToken) -> Router,
    {
        if self.init_panic_hook {
            init_panic_hook();
        }

        // Arm the cleanup guard here, for the paths actually being served, rather than at
        // construction time. With an eager guard, the builder replacing it (or the starter
        // being dropped unstarted) would run cleanup against paths this server never bound.
        // NOTE(review): assumes SocketCleanup removes its files on Drop — confirm.
        // The guard lives until this function returns: after the server stops, or after a
        // start-up error.
        let _cleanup = SocketCleanup::new(
            self.config.socket_file().to_path_buf(),
            self.config.server_info_file().to_path_buf(),
        );

        let info = ServerInfo::new(self.container_type);
        let listener = create_listener_stream(
            self.config.socket_file(),
            self.config.server_info_file(),
            info,
        )?;

        let (internal_shutdown_tx, internal_shutdown_rx) = mpsc::channel(1);
        let cln_token = CancellationToken::new();
        let router = service_builder(internal_shutdown_tx, cln_token.clone());
        let shutdown = shutdown_signal(internal_shutdown_rx, shutdown_rx, cln_token);

        router
            .serve_with_incoming_shutdown(listener, shutdown)
            .await?;
        Ok(())
    }
}
/// A tonic gRPC server hosting a user-supplied service handler `svc`.
///
/// Configure it with the `with_*` builder methods, then run it with [`Server::start`] or
/// [`Server::start_with_shutdown`].
#[derive(Debug)]
pub struct Server<T> {
    /// Listener configuration and the shared start-up logic.
    starter: ServerStarter,
    /// The user's handler; moved into the service builder when the server starts.
    svc: T,
}

// Not every constructor/accessor below is used by every caller of this type; silence the
// dead-code warnings for the unused ones.
#[allow(dead_code)]
impl<T> Server<T> {
    /// Creates a server for `svc` that listens on `default_sock_addr` and uses
    /// `default_server_info_file` as its server-info file.
    pub(crate) fn new(
        svc: T,
        container_type: ContainerType,
        default_sock_addr: &str,
        default_server_info_file: &str,
    ) -> Self {
        let starter =
            ServerStarter::new(container_type, default_sock_addr, default_server_info_file);
        Self { starter, svc }
    }

    /// Identical to [`Server::new`]. It is kept as a separately named constructor for call
    /// sites that pass non-default paths.
    pub(crate) fn new_with_custom_paths(
        svc: T,
        container_type: ContainerType,
        sock_addr: &str,
        server_info_file: &str,
    ) -> Self {
        Self::new(svc, container_type, sock_addr, server_info_file)
    }

    /// Overrides the socket file path.
    pub(crate) fn with_socket_file(mut self, file: impl Into<PathBuf>) -> Self {
        self.starter = self.starter.with_socket_file(file);
        self
    }

    /// Configured socket file path.
    pub(crate) fn socket_file(&self) -> &std::path::Path {
        self.starter.socket_file()
    }

    /// Overrides the maximum message size handed to the service builder.
    pub(crate) fn with_max_message_size(mut self, message_size: usize) -> Self {
        self.starter = self.starter.with_max_message_size(message_size);
        self
    }

    /// Configured maximum message size.
    pub(crate) fn max_message_size(&self) -> usize {
        self.starter.max_message_size()
    }

    /// Overrides the server-info file path.
    pub(crate) fn with_server_info_file(mut self, file: impl Into<PathBuf>) -> Self {
        self.starter = self.starter.with_server_info_file(file);
        self
    }

    /// Configured server-info file path.
    pub(crate) fn server_info_file(&self) -> &std::path::Path {
        self.starter.server_info_file()
    }

    /// Runs the server and also honours an external shutdown trigger, `shutdown_rx`.
    ///
    /// `service_builder` is called once with `(svc, max_message_size, shutdown_tx, token)`
    /// and must return the tonic [`Router`] to serve.
    ///
    /// # Errors
    /// Returns an error if the listener cannot be created or if serving fails.
    pub async fn start_with_shutdown<F>(
        self,
        shutdown_rx: oneshot::Receiver<()>,
        service_builder: F,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
    where
        F: FnOnce(T, usize, mpsc::Sender<()>, CancellationToken) -> Router + Send + 'static,
        T: Send + Sync + 'static,
    {
        self.serve(Some(shutdown_rx), service_builder).await
    }

    /// Runs the server without an external shutdown trigger.
    ///
    /// `service_builder` is called once with `(svc, max_message_size, shutdown_tx, token)`
    /// and must return the tonic [`Router`] to serve.
    ///
    /// # Errors
    /// Returns an error if the listener cannot be created or if serving fails.
    pub async fn start<F>(
        self,
        service_builder: F,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
    where
        F: FnOnce(T, usize, mpsc::Sender<()>, CancellationToken) -> Router + Send + 'static,
        T: Send + Sync + 'static,
    {
        self.serve(None, service_builder).await
    }

    /// Shared body of [`Server::start`] and [`Server::start_with_shutdown`]. It moves the
    /// handler and the configured message size into the service builder and starts serving.
    async fn serve<F>(
        self,
        shutdown_rx: Option<oneshot::Receiver<()>>,
        service_builder: F,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
    where
        F: FnOnce(T, usize, mpsc::Sender<()>, CancellationToken) -> Router,
    {
        let Self { starter, svc } = self;
        let max_message_size = starter.max_message_size();
        starter
            .start_server(shutdown_rx, move |shutdown_tx, cln_token| {
                service_builder(svc, max_message_size, shutdown_tx, cln_token)
            })
            .await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Minimal wrapper used to exercise the default methods of [`ServerExtras`].
    struct Wrapper(Server<()>);

    impl ServerExtras<()> for Wrapper {
        fn transform_inner<F>(self, f: F) -> Self
        where
            F: FnOnce(Server<()>) -> Server<()>,
        {
            Wrapper(f(self.0))
        }

        fn inner_ref(&self) -> &Server<()> {
            &self.0
        }
    }

    #[test]
    fn test_server_starter_creation() {
        let starter = ServerStarter::new(ContainerType::Map, "/tmp/test.sock", "/tmp/test-info");
        assert_eq!(
            starter.socket_file(),
            std::path::Path::new("/tmp/test.sock")
        );
        assert_eq!(
            starter.server_info_file(),
            std::path::Path::new("/tmp/test-info")
        );
        assert_eq!(starter.max_message_size(), 64 * 1024 * 1024);
    }

    #[test]
    fn test_server_starter_configuration() {
        let tmp_dir = TempDir::new().unwrap();
        let sock_file = tmp_dir.path().join("custom.sock");
        let info_file = tmp_dir.path().join("custom-info");
        let starter = ServerStarter::new(ContainerType::Map, "/tmp/test.sock", "/tmp/test-info")
            .with_socket_file(&sock_file)
            .with_server_info_file(&info_file)
            .with_max_message_size(1024)
            .with_panic_hook(false);
        assert_eq!(starter.socket_file(), sock_file);
        assert_eq!(starter.server_info_file(), info_file);
        assert_eq!(starter.max_message_size(), 1024);
        assert!(!starter.init_panic_hook);
    }

    #[test]
    fn test_create_server_config() {
        let starter = ServerStarter::new(
            ContainerType::Reduce,
            "/var/run/numaflow/reduce.sock",
            "/var/run/numaflow/reducer-server-info",
        );
        assert_eq!(
            starter.socket_file(),
            std::path::Path::new("/var/run/numaflow/reduce.sock")
        );
        assert_eq!(
            starter.server_info_file(),
            std::path::Path::new("/var/run/numaflow/reducer-server-info")
        );
    }

    #[test]
    fn test_server_builder_delegates_to_starter() {
        let tmp_dir = TempDir::new().unwrap();
        let sock_file = tmp_dir.path().join("server.sock");
        let info_file = tmp_dir.path().join("server-info");
        let server = Server::new((), ContainerType::Map, "/tmp/test.sock", "/tmp/test-info")
            .with_socket_file(&sock_file)
            .with_server_info_file(&info_file)
            .with_max_message_size(2048);
        assert_eq!(server.socket_file(), sock_file);
        assert_eq!(server.server_info_file(), info_file);
        assert_eq!(server.max_message_size(), 2048);
    }

    #[test]
    fn test_server_extras_default_methods_forward_to_inner() {
        let tmp_dir = TempDir::new().unwrap();
        let sock_file = tmp_dir.path().join("extras.sock");
        let info_file = tmp_dir.path().join("extras-info");
        let wrapper = Wrapper(Server::new(
            (),
            ContainerType::Map,
            "/tmp/test.sock",
            "/tmp/test-info",
        ))
        .with_socket_file(&sock_file)
        .with_server_info_file(&info_file)
        .with_max_message_size(4096);
        assert_eq!(wrapper.socket_file(), sock_file);
        assert_eq!(wrapper.server_info_file(), info_file);
        assert_eq!(wrapper.max_message_size(), 4096);
    }
}