use std::net::SocketAddr;
use std::sync::Arc;
use anyhow::{Context, Result};
use tokio::net::TcpListener;
use tokio::task::JoinHandle;
use smooth_operator::auth::{AuthVerifier, NoAuthVerifier};
use smooth_operator::tool_provider::ToolProvider;
use crate::config::{ServerConfig, StorageBackend};
use crate::server::{build_state, router};
use crate::state::AppState;
/// Default bind address for the local flavor: IPv4 loopback on port 8787.
pub const DEFAULT_LOCAL_ADDR: &str = "127.0.0.1:8787";
/// Builder for an in-process [`LocalServer`].
///
/// Obtain one via [`LocalServer::builder`] or `LocalServerBuilder::default()`.
#[derive(Clone)]
pub struct LocalServerBuilder {
    /// Address to bind; port `0` requests an OS-assigned ephemeral port.
    addr: SocketAddr,
    /// Copied into the config's `seed_kb` when the state is built.
    seed_kb: bool,
    /// Optional base config; `local_config()` is used when absent.
    config: Option<ServerConfig>,
    /// Custom auth verifier; a `NoAuthVerifier` is installed when absent.
    auth: Option<Arc<dyn AuthVerifier>>,
    /// Host-supplied tool provider, installed via `AppState::with_tools`.
    tool_provider: Option<Arc<dyn ToolProvider>>,
    /// Whether the widget routes are enabled.
    serve_widget: bool,
    /// Token handed to `AppState::with_widget` when the widget is enabled.
    widget_token: Option<String>,
}
// Manual `Debug`: trait objects are summarised and the widget token is redacted,
// but every field is still listed so the builder's state is fully visible.
impl std::fmt::Debug for LocalServerBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LocalServerBuilder")
            .field("addr", &self.addr)
            .field("seed_kb", &self.seed_kb)
            .field("config", &self.config)
            // `AuthVerifier` is a trait object; show only its mode.
            .field("auth", &self.auth.as_ref().map(|a| a.mode()))
            // `ToolProvider` is a trait object; show only whether one is set.
            .field("tool_provider", &self.tool_provider.is_some())
            .field("serve_widget", &self.serve_widget)
            // Never print the secret, but make it visible whether one is configured.
            .field(
                "widget_token",
                &self.widget_token.as_ref().map(|_| "<redacted>"),
            )
            .finish()
    }
}
impl Default for LocalServerBuilder {
fn default() -> Self {
Self {
addr: DEFAULT_LOCAL_ADDR
.parse()
.expect("DEFAULT_LOCAL_ADDR is a valid SocketAddr"),
seed_kb: false,
config: None,
auth: None,
tool_provider: None,
serve_widget: false,
widget_token: None,
}
}
}
impl LocalServerBuilder {
    /// Sets the socket address to bind.
    ///
    /// Use port `0` to let the OS choose an ephemeral port; the resolved address
    /// is reported by [`LocalServer::addr`].
    #[must_use]
    pub fn addr(mut self, addr: SocketAddr) -> Self {
        self.addr = addr;
        self
    }

    /// Sets whether the knowledge base is seeded.
    ///
    /// This always overwrites `seed_kb` on the base config, including one
    /// supplied via [`Self::config`].
    #[must_use]
    pub fn seed_kb(mut self, seed: bool) -> Self {
        self.seed_kb = seed;
        self
    }

    /// Installs a custom auth verifier. Without one, a `NoAuthVerifier` is used.
    #[must_use]
    pub fn auth(mut self, auth: Arc<dyn AuthVerifier>) -> Self {
        self.auth = Some(auth);
        self
    }

    /// Installs a host-supplied tool provider on the server state.
    #[must_use]
    pub fn tools(mut self, provider: Arc<dyn ToolProvider>) -> Self {
        self.tool_provider = Some(provider);
        self
    }

    /// Enables the widget routes, optionally protected by `token`.
    ///
    /// Calling this always turns the widget on; a later call only replaces the token.
    #[must_use]
    pub fn serve_widget(mut self, token: Option<String>) -> Self {
        self.serve_widget = true;
        self.widget_token = token;
        self
    }

    /// Supplies a base `ServerConfig` (otherwise `local_config()` is used).
    ///
    /// When the state is built, `storage` is forced to in-memory and `bind`,
    /// `port`, and `seed_kb` are overwritten from this builder.
    #[must_use]
    pub fn config(mut self, config: ServerConfig) -> Self {
        self.config = Some(config);
        self
    }

    /// Assembles the `AppState` described by this builder.
    ///
    /// Starts from the supplied config (or `local_config()`), forces in-memory
    /// storage, and copies the bind address and `seed_kb` from the builder. It then
    /// installs the auth verifier (default `NoAuthVerifier`), the tool provider if
    /// one was given, and the widget routes if opted in.
    fn build(&self) -> AppState {
        let mut config = self.config.clone().unwrap_or_else(local_config);
        config.storage = StorageBackend::Memory;
        config.bind = self.addr.ip().to_string();
        config.port = self.addr.port();
        config.seed_kb = self.seed_kb;
        let auth = self
            .auth
            .clone()
            .unwrap_or_else(|| Arc::new(NoAuthVerifier::default()) as Arc<dyn AuthVerifier>);
        let mut state = build_state(config).with_auth(auth);
        if let Some(provider) = &self.tool_provider {
            state = state.with_tools(Arc::clone(provider));
        }
        if self.serve_widget {
            state = state.with_widget(self.widget_token.clone());
        }
        state
    }

    /// Binds the listener and serves the router on a spawned Tokio task.
    ///
    /// Returns as soon as the listener is bound. Errors from the serve loop
    /// surface from [`LocalServer::shutdown`].
    ///
    /// # Errors
    ///
    /// Fails if the address cannot be bound or the bound address cannot be read.
    pub async fn spawn(mut self) -> Result<LocalServer> {
        let listener = TcpListener::bind(self.addr)
            .await
            .with_context(|| format!("binding local smooth-operator server on {}", self.addr))?;
        let addr = listener.local_addr().context("local addr")?;
        // Record the address actually bound (port 0 resolved to the OS-assigned
        // port) so the config inside `AppState` describes where the server really listens.
        self.addr = addr;
        let app = router(self.build());
        let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
        let join = tokio::spawn(async move {
            axum::serve(listener, app)
                .with_graceful_shutdown(async move {
                    // Resolves on an explicit send or when the sender is dropped.
                    let _ = shutdown_rx.await;
                })
                .await
                .context("serving local smooth-operator connections")
        });
        Ok(LocalServer {
            addr,
            shutdown_tx: Some(shutdown_tx),
            join,
        })
    }
}
/// Handle to a running local server created by [`LocalServerBuilder::spawn`].
///
/// Dropping the handle signals graceful shutdown but does not wait for it. Call
/// [`LocalServer::shutdown`] to wait for the serve task and observe its result.
#[must_use = "the server stops when the handle is dropped; hold it for the server's lifetime"]
#[derive(Debug)]
pub struct LocalServer {
    /// Address the listener is bound to, as reported by the OS.
    addr: SocketAddr,
    /// One-shot trigger for the graceful shutdown; `None` once it has been taken.
    shutdown_tx: Option<tokio::sync::oneshot::Sender<()>>,
    /// The spawned serve task.
    join: JoinHandle<Result<()>>,
}
impl LocalServer {
    /// Starts configuring a local server with default settings.
    pub fn builder() -> LocalServerBuilder {
        LocalServerBuilder::default()
    }

    /// The address the listener is bound to (an ephemeral port is already resolved).
    #[must_use]
    pub fn addr(&self) -> SocketAddr {
        self.addr
    }

    /// The WebSocket endpoint URL, `ws://<addr>/ws`.
    #[must_use]
    pub fn ws_url(&self) -> String {
        format!("ws://{}/ws", self.addr)
    }

    /// Signals graceful shutdown and waits for the serve task to finish.
    ///
    /// # Errors
    ///
    /// Returns the serve task's own error. Also returns an error if the task
    /// could not be joined, i.e. it panicked or was cancelled.
    pub async fn shutdown(mut self) -> Result<()> {
        if let Some(sender) = self.shutdown_tx.take() {
            // A failed send just means the task has already exited.
            let _ = sender.send(());
        }
        // `LocalServer` implements `Drop`, so `join` cannot be moved out; await it by reference.
        (&mut self.join)
            .await
            .map_err(|join_err| anyhow::anyhow!("local server task failed: {join_err}"))?
    }
}
impl Drop for LocalServer {
    /// Fires the graceful-shutdown signal if it has not been sent yet.
    /// Does not wait for the serve task to finish.
    fn drop(&mut self) {
        let pending = self.shutdown_tx.take();
        if let Some(sender) = pending {
            // The task may already have exited and dropped its receiver; that's fine.
            let _ = sender.send(());
        }
    }
}
#[must_use]
pub fn local_config() -> ServerConfig {
let mut config = ServerConfig::from_env();
config.storage = StorageBackend::Memory;
config
}
/// Runs the local-flavor server on `addr` until the serve task ends.
///
/// # Errors
///
/// Fails if `addr` is not a valid socket address, if binding fails, or if the
/// server task fails.
pub async fn serve_local(addr: &str) -> Result<()> {
    let bind_addr = addr
        .parse::<SocketAddr>()
        .with_context(|| format!("parsing local bind address '{addr}'"))?;
    let server = LocalServer::builder().addr(bind_addr).spawn().await?;

    let local = server.addr();
    println!("smooth-operator-server (local flavor) listening on ws://{local}/ws");
    tracing::info!(%local, endpoint = "/ws", "smooth-operator-server (local flavor) listening");

    server.run_to_completion().await
}
impl LocalServer {
    /// Waits for the serve task to end without requesting shutdown.
    ///
    /// The shutdown sender is still held, so this returns only when the task
    /// exits by itself. If this future is dropped first, `Drop` sends the
    /// shutdown signal.
    async fn run_to_completion(mut self) -> Result<()> {
        (&mut self.join)
            .await
            .map_err(|join_err| anyhow::anyhow!("local server task failed: {join_err}"))?
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Binding port 0 must resolve to a real port, and shutdown must be clean.
    #[tokio::test]
    async fn spawn_binds_ephemeral_and_reports_real_addr() {
        let requested: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let server = LocalServer::builder()
            .addr(requested)
            .spawn()
            .await
            .expect("spawn local server");

        let addr = server.addr();
        assert_ne!(addr.port(), 0, "ephemeral port must be resolved: {addr}");
        let url = server.ws_url();
        assert!(url.starts_with("ws://127.0.0.1:"));

        server.shutdown().await.expect("clean shutdown");
    }

    /// A caller-supplied Postgres config is still forced onto in-memory storage.
    #[tokio::test]
    async fn build_uses_in_memory_storage_and_no_auth() {
        let postgres_config = ServerConfig {
            storage: StorageBackend::Postgres,
            ..local_config()
        };
        let state = LocalServerBuilder::default().config(postgres_config).build();

        assert_eq!(state.config.storage, StorageBackend::Memory);
        assert_eq!(state.auth.mode(), "none");
    }

    #[test]
    fn auth_seam_installs_a_custom_verifier() {
        use smooth_operator::auth::LocalTokenVerifier;

        let verifier: Arc<dyn AuthVerifier> = Arc::new(LocalTokenVerifier::new("s3cret"));
        let state = LocalServerBuilder::default().auth(verifier).build();

        assert_eq!(
            state.auth.mode(),
            "local-token",
            "custom verifier overrides the default"
        );
    }

    #[test]
    fn tools_seam_installs_a_provider() {
        use async_trait::async_trait;
        use smooth_operator::tool_provider::{ToolProvider, ToolProviderContext};
        use smooth_operator_core::Tool;

        /// Contributes no tools; only its presence on the state matters here.
        struct EmptyProvider;

        #[async_trait]
        impl ToolProvider for EmptyProvider {
            async fn tools_for(&self, _ctx: &ToolProviderContext) -> Vec<Arc<dyn Tool>> {
                vec![]
            }
        }

        let provider: Arc<dyn ToolProvider> = Arc::new(EmptyProvider);
        let state = LocalServerBuilder::default().tools(provider).build();

        assert!(state.tool_provider.is_some(), "host ToolProvider installed");
    }

    #[test]
    fn serve_widget_opts_into_the_widget_routes_with_token() {
        let token = String::from("tok-123");
        let state = LocalServerBuilder::default()
            .serve_widget(Some(token))
            .build();

        assert!(state.serve_widget, "widget routes opted in");
        assert_eq!(state.widget_token.as_deref(), Some("tok-123"));
        // Building the router with the widget enabled must not panic.
        let _ = router(state);
    }

    #[test]
    fn no_widget_by_default() {
        let state = LocalServerBuilder::default().build();

        assert!(
            !state.serve_widget,
            "widget off by default (K8s/Lambda never serve it)"
        );
        assert!(state.widget_token.is_none());
    }
}