use std::net::SocketAddr;
use std::sync::Arc;
use anyhow::{Context, Result};
use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
use axum::extract::{Query, State};
use axum::response::{Html, IntoResponse, Response};
use axum::routing::get;
use axum::Router;
use futures_util::{SinkExt, StreamExt};
use smooth_operator::access_control::AccessContext;
use tokio::net::TcpListener;
use tokio_util::sync::CancellationToken;
use smooth_operator_adapter_memory::InMemoryStorageAdapter;
use smooth_operator_core::{Document, DocumentType};
use crate::config::ServerConfig;
use crate::handler;
use crate::state::AppState;
/// Builds the axum [`Router`] for the server.
///
/// The WebSocket endpoint (`/ws`), the `/health` probe and the routes from
/// `crate::admin::router()` are always mounted. When `state.serve_widget` is
/// set, the embedded widget page (`/`) and its script bundle
/// (`/chat-widget.iife.js`) are mounted too.
pub fn router(state: AppState) -> Router {
    let base = Router::new()
        .route("/ws", get(ws_upgrade))
        .route("/health", get(health))
        .merge(crate::admin::router());
    let routes = if state.serve_widget {
        base.route("/", get(widget_index))
            .route("/chat-widget.iife.js", get(widget_bundle))
    } else {
        base
    };
    routes.with_state(state)
}
/// Serves the widget demo page with the configured widget token injected.
///
/// The `__SMOOTH_LOCAL_TOKEN__` placeholder in `assets/widget-index.html` is
/// replaced by the token encoded as a JSON string literal. The literal is `""`
/// when no token is configured or when encoding fails.
async fn widget_index(State(state): State<AppState>) -> Html<String> {
    let token = state.widget_token.as_deref().unwrap_or("");
    let token_literal = match serde_json::to_string(token) {
        Ok(json) => json,
        Err(_) => "\"\"".to_string(),
    };
    let page = include_str!("../assets/widget-index.html");
    Html(page.replace("__SMOOTH_LOCAL_TOKEN__", &token_literal))
}
/// Serves the embedded chat widget bundle with a JavaScript content type.
async fn widget_bundle() -> impl IntoResponse {
    const BUNDLE: &str = include_str!("../assets/chat-widget.iife.js");
    let content_type = (
        axum::http::header::CONTENT_TYPE,
        "application/javascript; charset=utf-8",
    );
    ([content_type], BUNDLE)
}
/// Liveness probe for `GET /health`; always responds with the body `ok`.
async fn health() -> &'static str {
    const BODY: &str = "ok";
    BODY
}
/// Document set that [`seed_knowledge`] tags the reference documents with.
const SEED_DOCUMENT_SET: &str = "policies";
/// Organization id under which the seeded document set is recorded when
/// `seed_kb` is enabled.
pub const SEED_ORG_ID: &str = "reference-org";
/// Builds an [`AppState`] backed by a fresh [`InMemoryStorageAdapter`].
///
/// When `config.seed_kb` is set, the reference policy documents are ingested
/// via [`seed_knowledge`], and the seeded document set is recorded once for
/// [`SEED_ORG_ID`].
#[must_use]
pub fn build_state(config: ServerConfig) -> AppState {
    let seed = config.seed_kb;
    let storage = Arc::new(InMemoryStorageAdapter::new());
    let state = AppState::new(storage.clone(), config);
    if seed {
        seed_knowledge(storage.as_ref());
        // Register the seeded set exactly once for the reference org.
        state.record_document_set(SEED_ORG_ID, SEED_DOCUMENT_SET);
    }
    state
}
pub fn build_state_from_env(config: ServerConfig) -> Result<AppState> {
let verifier = smooth_operator::auth::AuthConfig::from_env()
.map_err(|e| anyhow::anyhow!("auth configuration error: {e}"))?;
let state = install_widget_auth_from_env(build_state(config));
Ok(state.with_auth(Arc::from(verifier)))
}
fn install_widget_auth_from_env(state: AppState) -> AppState {
let Ok(url) = std::env::var("WIDGET_AUTH_URL") else {
return state;
};
let url = url.trim();
if url.is_empty() {
return state;
}
let mut provider = smooth_operator::widget_auth::HttpWidgetAuth::new(url);
if let Ok(bearer) = std::env::var("WIDGET_AUTH_BEARER") {
let bearer = bearer.trim();
if !bearer.is_empty() {
provider = provider.with_bearer(bearer);
}
}
if let Some(secs) = std::env::var("WIDGET_AUTH_TTL_SECS")
.ok()
.and_then(|s| s.trim().parse::<u64>().ok())
{
provider = provider.with_ttl(std::time::Duration::from_secs(secs));
}
state.with_widget_auth(Arc::new(provider))
}
/// Builds an [`AppState`] for the storage backend chosen by `config.storage`,
/// then adds the backplane, widget auth and auth verifier from the
/// environment.
///
/// Backends:
/// - `Memory`: delegates to [`build_state`], which seeds the knowledge base when
///   `config.seed_kb` is set.
/// - `Postgres` (cargo feature `postgres`): connects to `SMOOTH_AGENT_DATABASE_URL`,
///   falling back to `DATABASE_URL`, with an embedder built from `config`. It
///   wires the adapter's connector-config, settings and indexing stores.
/// - `Dynamodb` (cargo feature `dynamodb`): connects via
///   `DynamoDbAdapter::from_env(None)`, calls `create_table()`, and wires the same
///   three stores.
///
/// If the selected backend's cargo feature is not compiled in, this fails with
/// an explanatory error instead of falling back to memory.
///
/// # Errors
///
/// Returns an error if:
/// - the auth configuration is invalid,
/// - the selected backend is unavailable in this build,
/// - connecting to the backend or creating the table fails, or
/// - `install_backplane_from_env` fails.
pub async fn build_state_from_env_async(config: ServerConfig) -> Result<AppState> {
    use crate::config::StorageBackend;
    // Only needed to name `Arc<dyn StorageAdapter>` in the cloud backend arms.
    #[cfg(any(feature = "postgres", feature = "dynamodb"))]
    use smooth_operator::adapter::StorageAdapter;
    // Auth is resolved before any backend work, so a bad auth config is
    // reported without opening storage connections.
    let verifier = smooth_operator::auth::AuthConfig::from_env()
        .map_err(|e| anyhow::anyhow!("auth configuration error: {e}"))?;
    let state = match config.storage {
        StorageBackend::Memory => build_state(config),
        #[cfg(feature = "postgres")]
        StorageBackend::Postgres => {
            use smooth_operator_adapter_postgres::PostgresAdapter;
            let embedder = crate::embedder::build_embedder(
                &crate::embedder::EmbedderConfig::from_server_config(&config),
            );
            // The agent-specific variable takes precedence over the generic one.
            let conn_str = std::env::var("SMOOTH_AGENT_DATABASE_URL")
                .or_else(|_| std::env::var("DATABASE_URL"))
                .map_err(|_| {
                    anyhow::anyhow!(
                        "Postgres backend selected but neither SMOOTH_AGENT_DATABASE_URL \
                         nor DATABASE_URL is set"
                    )
                })?;
            let adapter = Arc::new(
                PostgresAdapter::connect_with_embedder(&conn_str, embedder)
                    .await
                    .map_err(|e| anyhow::anyhow!("connecting Postgres storage backend: {e}"))?,
            );
            // Sub-stores are taken from the concrete adapter before it is
            // type-erased into `Arc<dyn StorageAdapter>`.
            let connectors = Arc::new(adapter.connector_config_store());
            let settings = Arc::new(adapter.settings_store());
            let indexing = Arc::new(adapter.indexing_store());
            let storage: Arc<dyn StorageAdapter> = adapter;
            AppState::new(storage, config)
                .with_connector_configs(connectors)
                .with_settings(settings)
                .with_indexing(indexing)
        }
        #[cfg(feature = "dynamodb")]
        StorageBackend::Dynamodb => {
            use smooth_operator_adapter_dynamodb::DynamoDbAdapter;
            let adapter = Arc::new(
                DynamoDbAdapter::from_env(None)
                    .await
                    .map_err(|e| anyhow::anyhow!("connecting DynamoDB storage backend: {e}"))?,
            );
            // NOTE(review): runs on every startup; presumably a no-op when the
            // table already exists — confirm against DynamoDbAdapter::create_table.
            adapter
                .create_table()
                .await
                .map_err(|e| anyhow::anyhow!("creating DynamoDB table: {e}"))?;
            // Same sub-store wiring as the Postgres arm.
            let connectors = Arc::new(adapter.connector_config_store());
            let settings = Arc::new(adapter.settings_store());
            let indexing = Arc::new(adapter.indexing_store());
            let storage: Arc<dyn StorageAdapter> = adapter;
            AppState::new(storage, config)
                .with_connector_configs(connectors)
                .with_settings(settings)
                .with_indexing(indexing)
        }
        // Lean/local builds: reject the cloud backends explicitly.
        #[cfg(not(feature = "postgres"))]
        StorageBackend::Postgres => {
            anyhow::bail!(
                "SMOOTH_AGENT_STORAGE=postgres requires building with --features postgres \
                 (this is a lean/local build); use SMOOTH_AGENT_STORAGE=memory or rebuild \
                 with the 'cloud'/'postgres' feature"
            );
        }
        #[cfg(not(feature = "dynamodb"))]
        StorageBackend::Dynamodb => {
            anyhow::bail!(
                "SMOOTH_AGENT_STORAGE=dynamodb requires building with --features dynamodb \
                 (this is a lean/local build); use SMOOTH_AGENT_STORAGE=memory or rebuild \
                 with the 'cloud'/'dynamodb' feature"
            );
        }
    };
    // Environment-driven add-ons are applied to every backend, including Memory.
    let state = install_backplane_from_env(state).await?;
    let state = install_widget_auth_from_env(state);
    Ok(state.with_auth(Arc::from(verifier)))
}
/// Selects the cross-instance backplane from `SMOOTH_AGENT_BACKPLANE`.
///
/// The variable is trimmed and lowercased:
/// - empty, `memory` or `inmemory`: `state` is returned unchanged.
/// - `redis` or `valkey`: connects to `SMOOTH_AGENT_BACKPLANE_URL`, falling back
///   to `SMOOTH_AGENT_REDIS_URL`.
/// - `nats`: connects to `SMOOTH_AGENT_BACKPLANE_URL`, falling back to
///   `SMOOTH_AGENT_NATS_URL`.
///
/// # Errors
///
/// Fails for an unknown backplane name, a missing URL, a connection failure,
/// or a backplane whose cargo feature is not compiled into this build.
async fn install_backplane_from_env(state: AppState) -> Result<AppState> {
    let raw = std::env::var("SMOOTH_AGENT_BACKPLANE").unwrap_or_default();
    let kind = raw.trim().to_lowercase();
    // Resolves the connection URL. The generic variable wins over the
    // backend-specific one named by `specific`.
    let backplane_url = |specific: &str| -> Result<String> {
        match std::env::var("SMOOTH_AGENT_BACKPLANE_URL") {
            Ok(value) => Ok(value),
            Err(_) => std::env::var(specific).map_err(|_| {
                anyhow::anyhow!(
                    "{kind} backplane selected but neither SMOOTH_AGENT_BACKPLANE_URL nor {specific} is set"
                )
            }),
        }
    };
    match kind.as_str() {
        "" | "memory" | "inmemory" => Ok(state),
        "redis" | "valkey" => {
            #[cfg(feature = "redis")]
            {
                use smooth_operator_adapter_backplane_redis::RedisBackplane;
                let redis_url = backplane_url("SMOOTH_AGENT_REDIS_URL")?;
                let backplane = RedisBackplane::connect(&redis_url)
                    .await
                    .map_err(|e| anyhow::anyhow!("connecting Redis backplane: {e}"))?;
                Ok(state.with_backplane(Arc::new(backplane)))
            }
            #[cfg(not(feature = "redis"))]
            {
                // Touch the closure so lean builds don't warn about it.
                let _ = backplane_url;
                anyhow::bail!(
                    "SMOOTH_AGENT_BACKPLANE={kind} requires building with --features redis \
                     (this is a lean/local build); use SMOOTH_AGENT_BACKPLANE=memory or rebuild \
                     with the 'cloud'/'redis' feature"
                )
            }
        }
        "nats" => {
            #[cfg(feature = "nats")]
            {
                use smooth_operator_adapter_backplane_nats::NatsBackplane;
                let nats_url = backplane_url("SMOOTH_AGENT_NATS_URL")?;
                let backplane = NatsBackplane::connect(&nats_url)
                    .await
                    .map_err(|e| anyhow::anyhow!("connecting NATS backplane: {e}"))?;
                Ok(state.with_backplane(Arc::new(backplane)))
            }
            #[cfg(not(feature = "nats"))]
            {
                let _ = backplane_url;
                anyhow::bail!(
                    "SMOOTH_AGENT_BACKPLANE=nats requires building with --features nats \
                     (this is a lean/local build); use SMOOTH_AGENT_BACKPLANE=memory or rebuild \
                     with the 'cloud'/'nats' feature"
                )
            }
        }
        other => anyhow::bail!(
            "unknown SMOOTH_AGENT_BACKPLANE '{other}' (expected: memory | redis | valkey | nats)"
        ),
    }
}
pub fn seed_knowledge(storage: &InMemoryStorageAdapter) {
let kb = smooth_operator::adapter::StorageAdapter::knowledge(storage);
let _ = kb.ingest(smooth_operator::with_document_set(
Document::new(
"SmooAI's return window is exactly 17 days from delivery. Returns after 17 days are not accepted.",
"policies/returns.md",
DocumentType::Documentation,
),
[SEED_DOCUMENT_SET],
));
let _ = kb.ingest(smooth_operator::with_document_set(
Document::new(
"SmooAI standard shipping takes 5 to 7 business days. Expedited shipping takes 2 business days.",
"policies/shipping.md",
DocumentType::Documentation,
),
[SEED_DOCUMENT_SET],
));
}
pub async fn bind(config: ServerConfig) -> Result<(TcpListener, Router, CancellationToken)> {
let ip: std::net::IpAddr = config
.bind
.parse()
.unwrap_or(std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST));
let addr = SocketAddr::new(ip, config.port);
let state = build_state_from_env_async(config).await?;
let shutdown = state.shutdown.clone();
let app = router(state);
let listener = TcpListener::bind(addr)
.await
.with_context(|| format!("binding WebSocket server on {addr}"))?;
Ok((listener, app, shutdown))
}
pub async fn serve_state(state: AppState) -> Result<()> {
let ip: std::net::IpAddr = state
.config
.bind
.parse()
.unwrap_or(std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST));
let addr = SocketAddr::new(ip, state.config.port);
let listener = TcpListener::bind(addr)
.await
.with_context(|| format!("binding WebSocket server on {addr}"))?;
serve_state_on(state, listener).await
}
/// Serves the router for `state` on an already-bound `listener`.
///
/// Logs the listening address through `tracing` and stdout, then serves until
/// `axum::serve` returns. Unlike [`run`], no shutdown-signal handling is
/// installed here.
///
/// # Errors
///
/// Fails if the listener's local address cannot be read or serving fails.
pub async fn serve_state_on(state: AppState, listener: TcpListener) -> Result<()> {
    let (has_llm, model, gateway) = {
        let cfg = &state.config;
        (cfg.has_llm(), cfg.model.clone(), cfg.gateway_url.clone())
    };
    let local = listener.local_addr().context("local addr")?;
    let app = router(state);

    tracing::info!(
        %local,
        endpoint = "/ws",
        %model,
        %gateway,
        llm_enabled = has_llm,
        "smooth-operator-server listening"
    );
    println!(
        "smooth-operator-server listening on ws://{local}/ws (model={model}, llm_enabled={has_llm})"
    );

    axum::serve(listener, app)
        .await
        .context("serving WebSocket connections")
}
/// Runs the server from `config` until a shutdown signal arrives.
///
/// Binds via [`bind`], logs the listening address through `tracing` and
/// stdout, then serves with axum's graceful shutdown hooked to
/// [`wait_for_shutdown_signal`]. When the signal fires, the state's shutdown
/// token is cancelled so open WebSocket connection loops stop reading.
///
/// # Errors
///
/// Fails if binding or state construction fails, the local address cannot
/// be read, or serving fails.
pub async fn run(config: ServerConfig) -> Result<()> {
    // Capture what the banner needs before `config` is consumed by `bind`.
    let llm_enabled = config.has_llm();
    let model = config.model.clone();
    let gateway = config.gateway_url.clone();

    let (listener, app, shutdown) = bind(config).await?;
    let local = listener.local_addr().context("local addr")?;

    tracing::info!(
        %local,
        endpoint = "/ws",
        %model,
        %gateway,
        llm_enabled = llm_enabled,
        "smooth-operator-server listening"
    );
    println!(
        "smooth-operator-server listening on ws://{local}/ws (model={model}, llm_enabled={llm_enabled})"
    );

    let shutdown_trigger = async move {
        wait_for_shutdown_signal().await;
        tracing::info!("shutdown signal received; draining in-flight WebSocket turns");
        shutdown.cancel();
    };
    axum::serve(listener, app)
        .with_graceful_shutdown(shutdown_trigger)
        .await
        .context("serving WebSocket connections")
}
/// Resolves when the process receives Ctrl-C, or SIGTERM on Unix.
///
/// If the SIGTERM handler cannot be installed, a warning is logged and only
/// Ctrl-C is awaited. An error from `ctrl_c` also resolves the wait.
async fn wait_for_shutdown_signal() {
    #[cfg(unix)]
    {
        use tokio::signal::unix::{signal, SignalKind};
        match signal(SignalKind::terminate()) {
            Ok(mut sigterm) => {
                tokio::select! {
                    _ = sigterm.recv() => {}
                    _ = tokio::signal::ctrl_c() => {}
                }
            }
            Err(e) => {
                tracing::warn!(error = %e, "failed to install SIGTERM handler; ctrl_c only");
                let _ = tokio::signal::ctrl_c().await;
            }
        }
    }
    #[cfg(not(unix))]
    {
        let _ = tokio::signal::ctrl_c().await;
    }
}
/// Query-string parameters accepted on the `/ws` upgrade request.
#[derive(Debug, serde::Deserialize, Default)]
struct WsQuery {
    /// Optional auth token. `resolve_ws_access` treats a missing or
    /// whitespace-only value as "no token".
    #[serde(default)]
    token: Option<String>,
}
/// Per-connection authentication result produced by `resolve_ws_access`.
struct ConnectionAuth {
    /// Access context passed to the frame handler for this connection.
    access: AccessContext,
    /// Verified principal's org id. `None` for anonymous connections, i.e.
    /// no token or a token that failed verification.
    org_id: Option<String>,
}
/// Resolves the caller's access from the optional `token` query parameter.
///
/// A missing or blank token yields an anonymous context. A token that fails
/// `state.auth.verify` is logged and also downgraded to anonymous; the
/// connection itself is not rejected.
fn resolve_ws_access(state: &AppState, query: &WsQuery) -> ConnectionAuth {
    let anonymous = || ConnectionAuth {
        access: AccessContext::anonymous(),
        org_id: None,
    };
    let token = match query.token.as_deref().map(str::trim) {
        Some(t) if !t.is_empty() => t,
        _ => return anonymous(),
    };
    match state.auth.verify(token) {
        Err(e) => {
            tracing::warn!(
                auth_mode = state.auth.mode(),
                error = %e,
                "ws token failed verification; serving org-public knowledge only (anonymous)"
            );
            anonymous()
        }
        Ok(principal) => ConnectionAuth {
            access: principal.access_context(),
            org_id: Some(principal.org_id),
        },
    }
}
/// Axum handler for `GET /ws`.
///
/// Authenticates the caller from the query string, captures the `Origin`
/// header when it is valid visible ASCII, and hands the upgraded socket to
/// `connection_loop`.
async fn ws_upgrade(
    ws: WebSocketUpgrade,
    State(state): State<AppState>,
    Query(query): Query<WsQuery>,
    headers: axum::http::HeaderMap,
) -> Response {
    let auth = resolve_ws_access(&state, &query);
    let origin: Option<String> = match headers.get(axum::http::header::ORIGIN) {
        Some(value) => value.to_str().ok().map(ToOwned::to_owned),
        None => None,
    };
    ws.on_upgrade(move |socket| {
        connection_loop(socket, state, auth.access, auth.org_id, origin)
    })
}
/// Drives one upgraded WebSocket connection until one of these happens:
/// the client sends Close, the socket reports an error or end-of-stream, or
/// `state.shutdown` is cancelled.
///
/// Outbound traffic from the frame handler and from the backplane goes through
/// one unbounded mpsc channel into a dedicated writer task. That task
/// serializes each JSON value into a text frame.
///
/// Inbound text frames are passed to `handler::handle_frame` together with:
/// - the connection's access context,
/// - the connection id,
/// - the `Origin` header, and
/// - the authenticated org.
async fn connection_loop(
    socket: WebSocket,
    state: AppState,
    access: AccessContext,
    auth_org: Option<String>,
    origin: Option<String>,
) {
    // Split so the writer task can own the sink while this task reads.
    let (mut ws_tx, mut ws_rx) = socket.split();
    // NOTE(review): unbounded — a slow client lets queued events grow without
    // limit; consider a bounded channel if that matters.
    let (sink_tx, mut sink_rx) = tokio::sync::mpsc::unbounded_channel::<serde_json::Value>();
    // Per-connection id: used to attach/detach on the backplane and passed to
    // the handler.
    let conn_id = uuid::Uuid::new_v4().to_string();
    // Events the backplane delivers for this connection are pushed into the
    // same outbound channel. Send errors (writer already gone) are ignored.
    let sink_for_backplane = sink_tx.clone();
    state
        .backplane
        .attach(
            &conn_id,
            std::sync::Arc::new(move |event| {
                let _ = sink_for_backplane.send(event);
            }),
        )
        .await;
    // Writer: forwards queued events to the socket. It stops when every sender
    // has been dropped or a socket write fails. Values that fail to serialize
    // are skipped.
    let writer = tokio::spawn(async move {
        while let Some(event) = sink_rx.recv().await {
            let text = match serde_json::to_string(&event) {
                Ok(t) => t,
                Err(_) => continue,
            };
            if ws_tx.send(Message::Text(text.into())).await.is_err() {
                break;
            }
        }
    });
    loop {
        tokio::select! {
            // `biased` checks shutdown first on every iteration. Frames are
            // handled inline, so cancellation is only observed between frames;
            // an in-flight `handle_frame` call runs to completion.
            biased;
            () = state.shutdown.cancelled() => {
                break;
            }
            frame = ws_rx.next() => {
                match frame {
                    Some(Ok(Message::Text(text))) => {
                        handler::handle_frame(
                            &state,
                            &access,
                            &conn_id,
                            origin.as_deref(),
                            auth_org.as_deref(),
                            text.as_str(),
                            &sink_tx,
                        )
                        .await;
                    }
                    // Only JSON text frames are accepted; answer binary frames
                    // with a validation error and keep the connection open.
                    Some(Ok(Message::Binary(_))) => {
                        let _ = sink_tx.send(crate::protocol::error(
                            None,
                            "VALIDATION_ERROR",
                            "binary frames are not supported; send JSON text frames",
                        ));
                    }
                    Some(Ok(Message::Close(_))) => break,
                    // Ping/Pong frames are ignored here.
                    Some(Ok(_)) => {}
                    Some(Err(_)) => break,
                    None => break,
                }
            }
        }
    }
    // Teardown order:
    // 1. Detach from the backplane (which presumably releases its sender clone).
    // 2. Drop our own sender.
    // 3. Wait for the writer to flush what is already queued.
    // No explicit Close frame is sent by this code.
    // NOTE(review): `writer.await` only returns once every sender clone is
    // dropped; if `handle_frame` or the backplane retains one, this can block.
    state.backplane.detach(&conn_id).await;
    drop(sink_tx);
    let _ = writer.await;
}
#[cfg(test)]
mod tests {
    use super::*;
    use smooth_operator::adapter::StorageAdapter;

    /// The seeded knowledge base must surface the 17-day return-window fact.
    #[test]
    fn seeded_kb_returns_17_day_fact() {
        let adapter = InMemoryStorageAdapter::new();
        seed_knowledge(&adapter);

        let results = adapter
            .knowledge()
            .query("return window policy", 3)
            .expect("query");
        let found = results.iter().any(|hit| hit.chunk.contains("17"));

        assert!(found, "expected seeded 17-day fact, got: {results:?}");
    }

    /// Without a gateway key, the built state must report the LLM as disabled.
    #[tokio::test]
    async fn build_state_without_key_has_no_llm() {
        let config = ServerConfig {
            bind: "127.0.0.1".into(),
            port: 0,
            gateway_url: "https://example.test/v1".into(),
            gateway_key: None,
            model: "m".into(),
            seed_kb: true,
            max_iterations: 4,
            max_tokens: 128,
            storage: crate::config::StorageBackend::Memory,
            widget_auth_strict: false,
            confirm_tools: Vec::new(),
        };

        let built = build_state(config);

        assert!(!built.config.has_llm());
    }
}