use std::net::SocketAddr;
use std::sync::Arc;
use anyhow::{Context, Result};
use tokio::net::TcpListener;
use tokio::task::JoinHandle;
use smooth_operator::auth::NoAuthVerifier;
use crate::config::{ServerConfig, StorageBackend};
use crate::server::{build_state, router};
use crate::state::AppState;
/// Address (`ip:port`) that [`LocalServerBuilder::default`] parses and binds to
/// unless the caller overrides it.
pub const DEFAULT_LOCAL_ADDR: &str = "127.0.0.1:8787";
/// Builder that collects the settings for a [`LocalServer`] before it is spawned.
#[derive(Debug, Clone)]
pub struct LocalServerBuilder {
/// Socket address the listener binds to; its IP and port are also written
/// into the server configuration when the state is built.
addr: SocketAddr,
/// Value copied into the configuration's `seed_kb` field when the state is built.
seed_kb: bool,
/// Optional base configuration; when `None`, `local_config()` provides it.
config: Option<ServerConfig>,
}
impl Default for LocalServerBuilder {
fn default() -> Self {
Self {
addr: DEFAULT_LOCAL_ADDR
.parse()
.expect("DEFAULT_LOCAL_ADDR is a valid SocketAddr"),
seed_kb: false,
config: None,
}
}
}
impl LocalServerBuilder {
    /// Sets the socket address the server binds to.
    ///
    /// The address actually bound is read back from the listener in
    /// [`spawn`](Self::spawn) and reported by the returned handle.
    #[must_use]
    pub fn addr(mut self, addr: SocketAddr) -> Self {
        self.addr = addr;
        self
    }

    /// Sets the value written to the configuration's `seed_kb` field.
    #[must_use]
    pub fn seed_kb(mut self, seed: bool) -> Self {
        self.seed_kb = seed;
        self
    }

    /// Supplies a base [`ServerConfig`] to use instead of [`local_config`].
    ///
    /// The storage backend, bind address, port and `seed_kb` fields of the
    /// supplied configuration are overwritten when the state is built.
    #[must_use]
    pub fn config(mut self, config: ServerConfig) -> Self {
        self.config = Some(config);
        self
    }

    /// Builds the application state for this builder.
    ///
    /// Starts from the supplied configuration (or [`local_config`] when none
    /// was given), then forces in-memory storage, copies the builder's
    /// address and `seed_kb` flag into it, and installs a `NoAuthVerifier`.
    fn build(&self) -> AppState {
        let mut config = self.config.clone().unwrap_or_else(local_config);
        config.storage = StorageBackend::Memory;
        config.bind = self.addr.ip().to_string();
        config.port = self.addr.port();
        config.seed_kb = self.seed_kb;
        build_state(config).with_auth(Arc::new(NoAuthVerifier::default()))
    }

    /// Binds the listener, starts serving on a background task and returns a
    /// handle to the running server.
    ///
    /// The application state is built from the address the listener actually
    /// bound, so the configuration's port matches the real port even when
    /// port `0` was requested.
    ///
    /// # Errors
    ///
    /// Returns an error if the listener cannot be bound to the requested
    /// address or if its local address cannot be read back.
    pub async fn spawn(self) -> Result<LocalServer> {
        let listener = TcpListener::bind(self.addr)
            .await
            .with_context(|| format!("binding local smooth-operator server on {}", self.addr))?;
        let addr = listener.local_addr().context("local addr")?;

        // Rebuild the builder around the resolved address before creating the
        // state; otherwise a request for port 0 would leave `config.port` at 0
        // while the listener is on an OS-assigned port.
        let app = router(Self { addr, ..self }.build());

        let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
        let join = tokio::spawn(async move {
            axum::serve(listener, app)
                .with_graceful_shutdown(async move {
                    // Either an explicit signal or a dropped sender ends the wait.
                    let _ = shutdown_rx.await;
                })
                .await
                .context("serving local smooth-operator connections")
        });

        Ok(LocalServer {
            addr,
            shutdown_tx: Some(shutdown_tx),
            join,
        })
    }
}
/// Handle to a local server running on a background task.
///
/// Dropping the handle sends the shutdown signal to the serve task; call
/// `shutdown` to also wait for the task to finish and observe its result.
#[derive(Debug)]
#[must_use = "the server stops when the handle is dropped; hold it for the server's lifetime"]
pub struct LocalServer {
    /// Address the listener reported after binding.
    addr: SocketAddr,
    /// One-shot sender that triggers graceful shutdown; `None` once the
    /// signal has been sent.
    shutdown_tx: Option<tokio::sync::oneshot::Sender<()>>,
    /// Join handle of the background serve task and its final result.
    join: JoinHandle<Result<()>>,
}
impl LocalServer {
    /// Returns a [`LocalServerBuilder`] with default settings.
    pub fn builder() -> LocalServerBuilder {
        Default::default()
    }

    /// Returns the socket address stored on this handle.
    #[must_use]
    pub fn addr(&self) -> SocketAddr {
        self.addr
    }

    /// Returns the `ws://<addr>/ws` URL for this server's address.
    #[must_use]
    pub fn ws_url(&self) -> String {
        let addr = self.addr;
        format!("ws://{addr}/ws")
    }

    /// Sends the shutdown signal (if it has not been sent yet) and waits for
    /// the serve task to finish.
    ///
    /// # Errors
    ///
    /// Returns the serve task's own error, or an error describing the join
    /// failure if the task could not be joined.
    pub async fn shutdown(mut self) -> Result<()> {
        if let Some(signal) = self.shutdown_tx.take() {
            // A send failure only means the receiver is already gone.
            let _ = signal.send(());
        }

        // Await through `&mut` because the handle implements `Drop`, so the
        // join handle cannot be moved out of `self`.
        let joined = (&mut self.join).await;
        joined.map_err(|join_err| anyhow::anyhow!("local server task failed: {join_err}"))?
    }
}
impl Drop for LocalServer {
    /// Sends the shutdown signal when the handle goes away, unless it was
    /// already sent. Does not wait for the serve task to finish.
    fn drop(&mut self) {
        let Some(signal) = self.shutdown_tx.take() else {
            return;
        };
        // A send failure only means the receiver is already gone.
        let _ = signal.send(());
    }
}
#[must_use]
pub fn local_config() -> ServerConfig {
let mut config = ServerConfig::from_env();
config.storage = StorageBackend::Memory;
config
}
/// Parses `addr`, spawns a local server on it with default builder settings,
/// announces the endpoint on stdout and via `tracing`, then waits for the
/// serve task to finish.
///
/// # Errors
///
/// Returns an error if `addr` is not a valid socket address, if the server
/// cannot be spawned, or if the serve task ends with an error.
pub async fn serve_local(addr: &str) -> Result<()> {
    let bind_addr = addr
        .parse::<SocketAddr>()
        .with_context(|| format!("parsing local bind address '{addr}'"))?;

    let server = LocalServer::builder().addr(bind_addr).spawn().await?;

    // Report the address from the handle rather than the requested one.
    let local = server.addr();
    println!("smooth-operator-server (local flavor) listening on ws://{local}/ws");
    tracing::info!(%local, endpoint = "/ws", "smooth-operator-server (local flavor) listening");

    server.run_to_completion().await
}
impl LocalServer {
    /// Waits for the serve task to finish without sending a shutdown signal
    /// first, and returns the task's result.
    ///
    /// A join failure is converted into an error describing it.
    async fn run_to_completion(mut self) -> Result<()> {
        // Await through `&mut` because the handle implements `Drop`, so the
        // join handle cannot be moved out of `self`.
        let joined = (&mut self.join).await;
        joined.map_err(|join_err| anyhow::anyhow!("local server task failed: {join_err}"))?
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Spawning on port 0 must yield a handle that reports a non-zero port
    /// and can be shut down cleanly.
    #[tokio::test]
    async fn spawn_binds_ephemeral_and_reports_real_addr() {
        let ephemeral: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let spawned = LocalServer::builder().addr(ephemeral).spawn().await;
        let server = spawned.expect("spawn local server");

        let addr = server.addr();
        assert_ne!(addr.port(), 0, "ephemeral port must be resolved: {addr}");
        assert!(server.ws_url().starts_with("ws://127.0.0.1:"));

        server.shutdown().await.expect("clean shutdown");
    }

    /// Even when the supplied configuration asks for Postgres, the built
    /// state must use in-memory storage and report the "none" auth mode.
    #[tokio::test]
    async fn build_uses_in_memory_storage_and_no_auth() {
        let mut postgres_config = local_config();
        postgres_config.storage = StorageBackend::Postgres;

        let state = LocalServerBuilder::default()
            .config(postgres_config)
            .build();

        assert_eq!(state.config.storage, StorageBackend::Memory);
        assert_eq!(state.auth.mode(), "none");
    }
}