use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use axum::Router;
use axum::extract::State;
use axum::http::{HeaderValue, Request, StatusCode, header};
use axum::middleware::{self, Next};
use axum::response::{IntoResponse, Response};
use axum::routing::get;
use tracing::{info, warn};
use crate::runner::Runner;
use crate::server::routes;
/// Shared state handed to the router built by [`build_router`].
///
/// Every field sits behind an `Arc`, so cloning only bumps reference counts.
#[derive(Clone)]
pub struct AppState {
/// Runners keyed by a string name (the `Debug` impl labels these keys `agents`).
pub runners: Arc<HashMap<String, Arc<Runner>>>,
/// Bearer token requests must present; `None` means no auth middleware is installed.
pub auth_token: Option<Arc<String>>,
/// Origins allowed by the CORS layer. Empty means no CORS layer is installed;
/// an entry equal to `"*"` allows any origin.
pub allow_origins: Arc<Vec<String>>,
}
impl AppState {
    /// Creates state with no bearer token and an empty CORS origin list.
    pub fn unauthenticated(runners: Arc<HashMap<String, Arc<Runner>>>) -> Self {
        let allow_origins = Arc::new(Vec::new());
        Self {
            runners,
            auth_token: None,
            allow_origins,
        }
    }

    /// Creates state whose `auth_token` is set to `token`; the CORS origin
    /// list starts out empty.
    pub fn with_bearer_token(
        runners: Arc<HashMap<String, Arc<Runner>>>,
        token: impl Into<String>,
    ) -> Self {
        let mut state = Self::unauthenticated(runners);
        state.auth_token = Some(Arc::new(token.into()));
        state
    }

    /// Builder-style setter that replaces the CORS origin list with `origins`.
    #[must_use]
    pub fn with_allow_origins(mut self, origins: impl IntoIterator<Item = String>) -> Self {
        let collected: Vec<String> = origins.into_iter().collect();
        self.allow_origins = Arc::new(collected);
        self
    }
}
/// Manual `Debug` so the bearer token value is never printed.
///
/// Shows the runner names (sorted, so output is stable across runs), whether
/// a token is configured (`Some("<set>")` / `None`), and the CORS origin list.
impl std::fmt::Debug for AppState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // HashMap iteration order is unspecified; sort for deterministic output.
        let mut agents: Vec<&String> = self.runners.keys().collect();
        agents.sort_unstable();
        f.debug_struct("AppState")
            .field("agents", &agents)
            // Redact the secret: only reveal whether one is configured.
            .field("auth_token", &self.auth_token.as_ref().map(|_| "<set>"))
            .field("allow_origins", &self.allow_origins)
            .finish()
    }
}
pub fn build_router(state: AppState) -> Router {
let mut inner = Router::new()
.route("/list-agents", get(routes::list_agents))
.merge(crate::server::adk_web::router());
if !state.allow_origins.is_empty() {
use tower_http::cors::{AllowOrigin, Any, CorsLayer};
let origins: Vec<HeaderValue> = state
.allow_origins
.iter()
.filter_map(|o| o.parse().ok())
.collect();
let allow_origin = if state.allow_origins.iter().any(|o| o == "*") {
AllowOrigin::any()
} else {
AllowOrigin::list(origins)
};
inner = inner.layer(
CorsLayer::new()
.allow_origin(allow_origin)
.allow_methods(Any)
.allow_headers(Any),
);
}
if state.auth_token.is_some() {
let token_state = state.clone();
inner
.route_layer(middleware::from_fn_with_state(token_state, require_bearer))
.with_state(state)
} else {
inner.with_state(state)
}
}
/// Middleware that enforces `Authorization: Bearer <token>` when a token is configured.
///
/// * No token in `state.auth_token`: the request is passed through untouched.
/// * Header present, scheme is `Bearer` (matched case-insensitively, as HTTP
///   auth schemes are), and the credentials equal the configured token: the
///   request is forwarded to `next`.
/// * Anything else: `401 Unauthorized` with body `unauthorized` and a
///   `WWW-Authenticate: Bearer realm="adk-rs"` challenge.
///
/// The token comparison goes through `constant_time_eq`.
async fn require_bearer(
    State(state): State<AppState>,
    req: Request<axum::body::Body>,
    next: Next,
) -> Response {
    let Some(expected) = state.auth_token.as_ref() else {
        return next.run(req).await;
    };

    // Single header lookup: split "<scheme> <credentials>" on the first space
    // and compare the scheme without regard to ASCII case.
    let presented = req
        .headers()
        .get(header::AUTHORIZATION)
        .and_then(|value| value.to_str().ok())
        .and_then(|raw| raw.split_once(' '))
        .filter(|(scheme, _)| scheme.eq_ignore_ascii_case("bearer"))
        .map(|(_, credentials)| credentials);

    let ok = presented.is_some_and(|tok| constant_time_eq(expected.as_bytes(), tok.as_bytes()));

    if ok {
        next.run(req).await
    } else {
        let mut resp = (StatusCode::UNAUTHORIZED, "unauthorized").into_response();
        resp.headers_mut().insert(
            header::WWW_AUTHENTICATE,
            HeaderValue::from_static("Bearer realm=\"adk-rs\""),
        );
        resp
    }
}
/// Compares two byte slices for equality without stopping at the first
/// differing byte.
///
/// Slices of different lengths return `false` immediately. For equal lengths,
/// the XOR of every byte pair is OR-ed into one accumulator and the result is
/// checked once at the end.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let accumulated = a
        .iter()
        .zip(b)
        .fold(0u8, |acc, (lhs, rhs)| acc | (lhs ^ rhs));
    accumulated == 0
}
/// Options accepted by [`serve_with`].
///
/// The derived `Default` leaves every flag `false`.
#[derive(Debug, Clone, Default)]
pub struct ServeOptions {
/// When `true`, the bind-policy check permits a non-loopback address even
/// though no auth token is configured. When `false`, that combination is
/// rejected with a configuration error.
pub dangerously_allow_unauthenticated_remote: bool,
}
/// Serves `state` on `addr` using `ServeOptions::default()`.
///
/// Delegates to [`serve_with`]; see it for the error conditions.
pub async fn serve(addr: SocketAddr, state: AppState) -> crate::error::Result<()> {
    let defaults = ServeOptions::default();
    serve_with(addr, state, defaults).await
}
/// Binds `addr` and serves the router built from `state` until the server stops.
///
/// Before binding, the bind policy is checked via `validate_bind_policy`; a
/// warning is logged whenever `addr` is not a loopback address.
///
/// # Errors
///
/// * The error from `validate_bind_policy` when the policy rejects the bind.
/// * `Error::other("bind …")` when the TCP listener cannot be bound.
/// * `Error::other("serve: …")` when `axum::serve` returns an error.
pub async fn serve_with(
    addr: SocketAddr,
    state: AppState,
    opts: ServeOptions,
) -> crate::error::Result<()> {
    let has_auth = state.auth_token.is_some();
    validate_bind_policy(addr, has_auth, &opts)?;

    if !addr.ip().is_loopback() {
        warn!(
            "adk-server bound on non-loopback {addr}: anyone reachable on this network can drive your agents{} — proceed only if this is what you intended",
            if has_auth {
                " (bearer token required)"
            } else {
                " AND NO AUTHENTICATION IS ENFORCED"
            }
        );
    }

    let app = build_router(state);
    let listener = tokio::net::TcpListener::bind(addr)
        .await
        .map_err(|e| crate::error::Error::other(format!("bind {addr}: {e}")))?;

    // Log the address the OS actually assigned: when the caller asks for
    // port 0 the requested `addr` would print a useless `:0`.
    let bound = listener.local_addr().unwrap_or(addr);
    info!("adk-server listening on http://{bound}");

    axum::serve(listener, app)
        .await
        .map_err(|e| crate::error::Error::other(format!("serve: {e}")))
}
fn validate_bind_policy(
addr: SocketAddr,
has_auth: bool,
opts: &ServeOptions,
) -> crate::error::Result<()> {
if !addr.ip().is_loopback() && !has_auth && !opts.dangerously_allow_unauthenticated_remote {
return Err(crate::error::Error::config(format!(
"refusing to bind dev server on non-loopback address {addr} without auth — \
set an auth token via AppState::with_bearer_token(...) or pass \
ServeOptions::dangerously_allow_unauthenticated_remote=true to opt out"
)));
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    /// Builds an `AppState` with no runners, guarded by `token` when one is given.
    fn empty_state(token: Option<&str>) -> AppState {
        let runners = Arc::new(HashMap::<String, Arc<Runner>>::new());
        if let Some(secret) = token {
            AppState::with_bearer_token(runners, secret)
        } else {
            AppState::unauthenticated(runners)
        }
    }

    /// `0.0.0.0:0` — a non-loopback address with an OS-chosen port.
    fn unspecified_addr() -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)
    }

    /// Sends `GET /list-agents` through `app`, optionally with an
    /// `Authorization` header, and returns the response.
    async fn get_list_agents(app: Router, authorization: Option<&str>) -> Response {
        use axum::body::Body;
        use axum::http::Method;
        use tower::ServiceExt;

        let mut request = Request::builder().method(Method::GET).uri("/list-agents");
        if let Some(value) = authorization {
            request = request.header(header::AUTHORIZATION, value);
        }
        app.oneshot(request.body(Body::empty()).unwrap())
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn serve_refuses_non_loopback_without_auth_or_override() {
        let outcome = serve(unspecified_addr(), empty_state(None)).await;
        let msg = outcome.unwrap_err().to_string();
        assert!(
            msg.contains("non-loopback"),
            "expected non-loopback error, got: {msg}"
        );
    }

    #[tokio::test]
    async fn serve_allows_non_loopback_when_auth_set() {
        let opts = ServeOptions::default();
        validate_bind_policy(unspecified_addr(), true, &opts).unwrap();
    }

    #[tokio::test]
    async fn bearer_required_when_token_set() {
        use axum::body::to_bytes;

        let app = build_router(empty_state(Some("topsecret")));

        // No Authorization header: rejected with a Bearer challenge.
        let missing = get_list_agents(app.clone(), None).await;
        assert_eq!(missing.status(), StatusCode::UNAUTHORIZED);
        let challenge = missing
            .headers()
            .get(header::WWW_AUTHENTICATE)
            .and_then(|v| v.to_str().ok())
            .unwrap_or("");
        assert!(challenge.contains("Bearer"));
        let _ = to_bytes(missing.into_body(), usize::MAX).await.unwrap();

        // Wrong token: rejected.
        let wrong = get_list_agents(app.clone(), Some("Bearer wrong")).await;
        assert_eq!(wrong.status(), StatusCode::UNAUTHORIZED);

        // Correct token: accepted.
        let accepted = get_list_agents(app, Some("Bearer topsecret")).await;
        assert_eq!(accepted.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn no_auth_required_when_token_absent() {
        let app = build_router(empty_state(None));
        let resp = get_list_agents(app, None).await;
        assert_eq!(resp.status(), StatusCode::OK);
    }
}