use std::net::SocketAddr;
use std::sync::Arc;
use anyhow::{Context, Result};
use tokio::net::TcpListener;
use tokio::task::JoinHandle;
use smooth_operator::adapter::StorageAdapter;
use smooth_operator::auth::{AuthVerifier, NoAuthVerifier};
use smooth_operator::tool_provider::ToolProvider;
use crate::config::{ServerConfig, StorageBackend};
use crate::server::{build_state, router};
use crate::state::AppState;
/// Default bind address for the local flavor: loopback only, port 8787.
/// Parsed into a `SocketAddr` by `LocalServerBuilder::default`.
pub const DEFAULT_LOCAL_ADDR: &str = "127.0.0.1:8787";
/// Builder for a [`LocalServer`]: a single-process smooth-operator server whose
/// config storage backend is forced to `Memory`, configured through optional
/// seams (auth, tools, storage adapter, widget, persona, SPA, extra routes).
///
/// Obtain one via [`LocalServer::builder`] or `Default`, chain the setters, then
/// call [`LocalServerBuilder::spawn`].
#[derive(Clone)]
pub struct LocalServerBuilder {
/// Address handed to `TcpListener::bind` by `spawn`; port 0 requests an ephemeral port.
addr: SocketAddr,
/// Copied into the config's `seed_kb` when building, overriding any supplied config.
seed_kb: bool,
/// Optional base config; when `None`, `local_config()` is used instead.
config: Option<ServerConfig>,
/// Custom auth verifier; a `NoAuthVerifier` is installed when `None`.
auth: Option<Arc<dyn AuthVerifier>>,
/// Host tool provider, installed via `AppState::with_tools` when present.
tool_provider: Option<Arc<dyn ToolProvider>>,
/// Whether widget routes are enabled (only ever set to `true`, by `serve_widget`).
serve_widget: bool,
/// Token passed to `AppState::with_widget` when `serve_widget` is set.
widget_token: Option<String>,
/// When `true`, the built state gets `with_strict_auth(true)`.
strict_auth: bool,
/// Storage adapter override; when `None`, the adapter chosen by `build_state` is kept.
storage: Option<Arc<dyn StorageAdapter>>,
/// Default persona text installed via `AppState::with_default_persona`.
persona: Option<String>,
/// Host router mounted as the app's fallback service.
spa_router: Option<axum::Router>,
/// Host routes merged into the app, wrapped in the admin CORS layer.
extra_routes: Option<axum::Router>,
}
impl std::fmt::Debug for LocalServerBuilder {
    /// Debug-prints every builder setting.
    ///
    /// Trait objects and routers are shown only as presence flags (or the auth
    /// mode), since they are not `Debug`. The widget token is redacted so
    /// secrets do not leak into logs.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LocalServerBuilder")
            .field("addr", &self.addr)
            .field("seed_kb", &self.seed_kb)
            .field("config", &self.config)
            .field("auth", &self.auth.as_ref().map(|a| a.mode()))
            .field("tool_provider", &self.tool_provider.is_some())
            .field("serve_widget", &self.serve_widget)
            // Never print the token itself; only whether one is configured.
            .field(
                "widget_token",
                &self.widget_token.as_ref().map(|_| "<redacted>"),
            )
            .field("strict_auth", &self.strict_auth)
            .field("storage", &self.storage.is_some())
            .field("persona", &self.persona)
            .field("spa_router", &self.spa_router.is_some())
            .field("extra_routes", &self.extra_routes.is_some())
            .finish()
    }
}
impl Default for LocalServerBuilder {
fn default() -> Self {
Self {
addr: DEFAULT_LOCAL_ADDR
.parse()
.expect("DEFAULT_LOCAL_ADDR is a valid SocketAddr"),
seed_kb: false,
config: None,
auth: None,
tool_provider: None,
serve_widget: false,
widget_token: None,
strict_auth: false,
storage: None,
persona: None,
spa_router: None,
extra_routes: None,
}
}
}
impl LocalServerBuilder {
#[must_use]
pub fn addr(mut self, addr: SocketAddr) -> Self {
self.addr = addr;
self
}
#[must_use]
pub fn seed_kb(mut self, seed: bool) -> Self {
self.seed_kb = seed;
self
}
#[must_use]
pub fn auth(mut self, auth: Arc<dyn AuthVerifier>) -> Self {
self.auth = Some(auth);
self
}
#[must_use]
pub fn tools(mut self, provider: Arc<dyn ToolProvider>) -> Self {
self.tool_provider = Some(provider);
self
}
#[must_use]
pub fn serve_widget(mut self, token: Option<String>) -> Self {
self.serve_widget = true;
self.widget_token = token;
self
}
#[must_use]
pub fn persona(mut self, persona: impl Into<String>) -> Self {
self.persona = Some(persona.into());
self
}
#[must_use]
pub fn serve_spa(mut self, spa: axum::Router) -> Self {
self.spa_router = Some(spa);
self
}
#[must_use]
pub fn serve_routes(mut self, routes: axum::Router) -> Self {
self.extra_routes = Some(routes);
self
}
#[must_use]
pub fn strict_auth(mut self, strict: bool) -> Self {
self.strict_auth = strict;
self
}
#[must_use]
pub fn storage(mut self, storage: Arc<dyn StorageAdapter>) -> Self {
self.storage = Some(storage);
self
}
#[must_use]
pub fn config(mut self, config: ServerConfig) -> Self {
self.config = Some(config);
self
}
fn build(&self) -> AppState {
let mut config = self.config.clone().unwrap_or_else(local_config);
config.storage = StorageBackend::Memory;
config.bind = self.addr.ip().to_string();
config.port = self.addr.port();
config.seed_kb = self.seed_kb;
let auth = self
.auth
.clone()
.unwrap_or_else(|| Arc::new(NoAuthVerifier::default()) as Arc<dyn AuthVerifier>);
let mut state = build_state(config).with_auth(auth);
if let Some(storage) = &self.storage {
state = state.with_storage(Arc::clone(storage));
}
if let Some(provider) = &self.tool_provider {
state = state.with_tools(Arc::clone(provider));
}
if self.serve_widget {
state = state.with_widget(self.widget_token.clone());
}
if self.strict_auth {
state = state.with_strict_auth(true);
}
if let Some(persona) = &self.persona {
state = state.with_default_persona(persona.clone());
}
state
}
fn build_app(&self) -> axum::Router {
let mut app = router(self.build());
if let Some(routes) = self.extra_routes.clone() {
app = app.merge(routes.layer(crate::admin::admin_cors()));
}
if let Some(spa) = self.spa_router.clone() {
app = app.fallback_service(spa);
}
app
}
pub async fn spawn(self) -> Result<LocalServer> {
let listener = TcpListener::bind(self.addr)
.await
.with_context(|| format!("binding local smooth-operator server on {}", self.addr))?;
let addr = listener.local_addr().context("local addr")?;
let app = self.build_app();
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
let join = tokio::spawn(async move {
axum::serve(listener, app)
.with_graceful_shutdown(async move {
let _ = shutdown_rx.await;
})
.await
.context("serving local smooth-operator connections")
});
Ok(LocalServer {
addr,
shutdown_tx: Some(shutdown_tx),
join,
})
}
}
/// Handle to a running local server, returned by [`LocalServerBuilder::spawn`].
///
/// Dropping the handle sends the graceful-shutdown signal (see the `Drop` impl);
/// call [`LocalServer::shutdown`] to also wait for the serve task and observe its result.
#[must_use = "the server stops when the handle is dropped; hold it for the server's lifetime"]
pub struct LocalServer {
/// The address actually bound, taken from `TcpListener::local_addr`.
addr: SocketAddr,
/// One-shot graceful-shutdown trigger; `None` once it has been taken and sent.
shutdown_tx: Option<tokio::sync::oneshot::Sender<()>>,
/// Handle to the spawned `axum::serve` task and its result.
join: JoinHandle<Result<()>>,
}
impl LocalServer {
    /// Returns a [`LocalServerBuilder`] with default settings.
    ///
    /// The defaults are: bound to [`DEFAULT_LOCAL_ADDR`], no custom auth, and
    /// no optional seams installed.
    #[must_use]
    pub fn builder() -> LocalServerBuilder {
        LocalServerBuilder::default()
    }

    /// The address the server is actually bound to.
    ///
    /// If the builder requested port `0`, this holds the OS-assigned port.
    #[must_use]
    pub fn addr(&self) -> SocketAddr {
        self.addr
    }

    /// The WebSocket endpoint URL, `ws://<addr>/ws`.
    #[must_use]
    pub fn ws_url(&self) -> String {
        format!("ws://{}/ws", self.addr)
    }

    /// Signals graceful shutdown and waits for the serve task to finish.
    ///
    /// # Errors
    /// Returns the serve task's own error. Also returns an error if the task
    /// panicked or was cancelled.
    pub async fn shutdown(mut self) -> Result<()> {
        if let Some(tx) = self.shutdown_tx.take() {
            // A send error means the task already exited; its result is still
            // collected below.
            let _ = tx.send(());
        }
        // `join` is borrowed rather than moved because `LocalServer` implements `Drop`.
        (&mut self.join)
            .await
            .map_err(|join_err| anyhow::anyhow!("local server task failed: {join_err}"))?
    }
}
impl Drop for LocalServer {
    /// Sends the graceful-shutdown signal unless `shutdown` already sent it.
    ///
    /// The serve task is not awaited here.
    fn drop(&mut self) {
        // `take` guarantees at most one send. A send error only means the
        // receiving serve task has already finished, so it is ignored.
        if let Some(sender) = self.shutdown_tx.take() {
            sender.send(()).ok();
        }
    }
}
#[must_use]
pub fn local_config() -> ServerConfig {
let mut config = ServerConfig::from_env();
config.storage = StorageBackend::Memory;
config
}
/// Runs the local-flavor server on `addr` until its serve task ends.
///
/// The WebSocket URL is printed to stdout and also logged via `tracing`. No
/// shutdown signal is ever sent, so this normally runs for the life of the process.
///
/// # Errors
/// Fails in any of these cases:
/// - `addr` is not a valid [`SocketAddr`];
/// - binding fails;
/// - the serve task errors, panics, or is cancelled.
pub async fn serve_local(addr: &str) -> Result<()> {
    let bind_addr = addr
        .parse::<SocketAddr>()
        .with_context(|| format!("parsing local bind address '{addr}'"))?;
    let server = LocalServer::builder().addr(bind_addr).spawn().await?;

    // `local` is also the tracing field name, so keep this binding's name.
    let local = server.addr();
    println!("smooth-operator-server (local flavor) listening on ws://{local}/ws");
    tracing::info!(%local, endpoint = "/ws", "smooth-operator-server (local flavor) listening");

    server.run_to_completion().await
}
impl LocalServer {
    /// Awaits the serve task without ever signalling shutdown.
    ///
    /// The shutdown sender stays held in `self` for the whole wait. This returns
    /// only when the server stops on its own (e.g. a serve error) or when the
    /// task panics or is cancelled.
    async fn run_to_completion(mut self) -> Result<()> {
        (&mut self.join)
            .await
            .map_err(|join_err| anyhow::anyhow!("local server task failed: {join_err}"))?
    }
}
// Unit tests for the LocalServerBuilder seams and the LocalServer lifecycle.
#[cfg(test)]
mod tests {
use super::*;
// Binding port 0 must surface the OS-assigned port via addr()/ws_url(), and
// shutdown() must complete cleanly.
#[tokio::test]
async fn spawn_binds_ephemeral_and_reports_real_addr() {
let server = LocalServer::builder()
.addr("127.0.0.1:0".parse().unwrap())
.spawn()
.await
.expect("spawn local server");
let addr = server.addr();
assert_ne!(addr.port(), 0, "ephemeral port must be resolved: {addr}");
assert!(server.ws_url().starts_with("ws://127.0.0.1:"));
server.shutdown().await.expect("clean shutdown");
}
// A caller-supplied Postgres config is still forced to in-memory storage, and
// with no auth seam the default verifier reports mode "none".
#[tokio::test]
async fn build_uses_in_memory_storage_and_no_auth() {
let state = LocalServerBuilder::default()
.config(ServerConfig {
storage: StorageBackend::Postgres,
..local_config()
})
.build();
assert_eq!(state.config.storage, StorageBackend::Memory);
assert_eq!(state.auth.mode(), "none");
}
// The injected adapter is installed as-is (pointer identity); without the seam,
// build() ends up with a different adapter.
// NOTE(review): the name says "durable" but an in-memory adapter stands in here.
#[test]
fn storage_seam_installs_a_durable_adapter() {
use smooth_operator_adapter_memory::InMemoryStorageAdapter;
let injected: Arc<dyn StorageAdapter> = Arc::new(InMemoryStorageAdapter::new());
let state = LocalServerBuilder::default()
.storage(Arc::clone(&injected))
.build();
assert!(
Arc::ptr_eq(&state.storage, &injected),
"the injected storage adapter must be installed"
);
let default_state = LocalServerBuilder::default().build();
assert!(!Arc::ptr_eq(&default_state.storage, &injected));
}
// A custom verifier replaces the default NoAuthVerifier.
#[test]
fn auth_seam_installs_a_custom_verifier() {
use smooth_operator::auth::LocalTokenVerifier;
let state = LocalServerBuilder::default()
.auth(Arc::new(LocalTokenVerifier::new("s3cret")))
.build();
assert_eq!(
state.auth.mode(),
"local-token",
"custom verifier overrides the default"
);
}
// A host ToolProvider (here one that returns no tools) is threaded into AppState.
#[test]
fn tools_seam_installs_a_provider() {
use async_trait::async_trait;
use smooth_operator::tool_provider::{ToolProvider, ToolProviderContext};
use smooth_operator_core::Tool;
struct EmptyProvider;
#[async_trait]
impl ToolProvider for EmptyProvider {
async fn tools_for(&self, _ctx: &ToolProviderContext) -> Vec<Arc<dyn Tool>> {
Vec::new()
}
}
let state = LocalServerBuilder::default()
.tools(Arc::new(EmptyProvider))
.build();
assert!(state.tool_provider.is_some(), "host ToolProvider installed");
}
// serve_widget sets both the flag and the token; building the router with the
// widget enabled must not panic.
#[test]
fn serve_widget_opts_into_the_widget_routes_with_token() {
let state = LocalServerBuilder::default()
.serve_widget(Some("tok-123".into()))
.build();
assert!(state.serve_widget, "widget routes opted in");
assert_eq!(state.widget_token.as_deref(), Some("tok-123"));
let _ = crate::server::router(state);
}
// Without serve_widget, neither the flag nor a token reaches AppState.
#[test]
fn no_widget_by_default() {
let state = LocalServerBuilder::default().build();
assert!(
!state.serve_widget,
"widget off by default (K8s/Lambda never serve it)"
);
assert_eq!(state.widget_token, None);
}
// No persona unless set; persona() stores the given text verbatim.
#[test]
fn persona_seam_installs_default_persona() {
assert_eq!(
LocalServerBuilder::default().build().default_persona,
None,
"no default persona unless set"
);
let state = LocalServerBuilder::default()
.persona("You are Big Smooth.")
.build();
assert_eq!(
state.default_persona.as_deref(),
Some("You are Big Smooth.")
);
}
// The SPA answers unmatched paths ("/"), while explicit operator routes such as
// "/health" still take precedence over the fallback.
#[tokio::test]
async fn serve_spa_mounts_host_router_as_fallback() {
use http_body_util::BodyExt;
use tower::ServiceExt;
let spa = axum::Router::new().fallback(axum::routing::get(|| async { "SPA-ROOT" }));
let app = LocalServerBuilder::default().serve_spa(spa).build_app();
let res = app
.clone()
.oneshot(
axum::http::Request::builder()
.uri("/")
.body(axum::body::Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_eq!(res.status(), axum::http::StatusCode::OK);
let body = res.into_body().collect().await.unwrap().to_bytes();
assert_eq!(
&body[..],
b"SPA-ROOT",
"the host SPA is served as the fallback"
);
let res = app
.oneshot(
axum::http::Request::builder()
.uri("/health")
.body(axum::body::Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_eq!(res.status(), axum::http::StatusCode::OK);
let body = res.into_body().collect().await.unwrap().to_bytes();
assert_eq!(
&body[..],
b"ok",
"explicit operator routes win over the SPA"
);
}
// The default builder carries no SPA router.
#[test]
fn no_spa_by_default() {
assert!(
LocalServerBuilder::default().spa_router.is_none(),
"no SPA mounted unless the host installs one"
);
}
// Host routes merged via serve_routes respond, and the operator's /health route
// keeps working after the merge.
#[tokio::test]
async fn serve_routes_merges_host_routes_alongside_operator_routes() {
use http_body_util::BodyExt;
use tower::ServiceExt;
let routes =
axum::Router::new().route("/search", axum::routing::get(|| async { "SEARCH-OK" }));
let app = LocalServerBuilder::default()
.serve_routes(routes)
.build_app();
let res = app
.clone()
.oneshot(
axum::http::Request::builder()
.uri("/search?q=foo")
.body(axum::body::Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_eq!(res.status(), axum::http::StatusCode::OK);
let body = res.into_body().collect().await.unwrap().to_bytes();
assert_eq!(&body[..], b"SEARCH-OK", "merged host route responds");
let res = app
.oneshot(
axum::http::Request::builder()
.uri("/health")
.body(axum::body::Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_eq!(res.status(), axum::http::StatusCode::OK);
let body = res.into_body().collect().await.unwrap().to_bytes();
assert_eq!(&body[..], b"ok", "operator routes survive the merge");
}
// The default builder carries no extra host routes.
#[test]
fn no_extra_routes_by_default() {
assert!(
LocalServerBuilder::default().extra_routes.is_none(),
"no host routes merged unless the host installs them"
);
}
// Strict auth is off by default and reaches AppState when opted into.
#[test]
fn strict_auth_off_by_default_and_opt_in() {
assert!(
!LocalServerBuilder::default().build().strict_auth,
"lenient/anonymous by default"
);
assert!(
LocalServerBuilder::default()
.strict_auth(true)
.build()
.strict_auth,
"opt-in threads to AppState"
);
}
}