pub mod codec;
pub mod endpoints;
pub mod error;
pub mod identity;
pub mod models;
pub mod response;
use crate::api::identity::AccountIdentityResolver;
use anyhow::Result;
use arc_swap::ArcSwap;
use axum::http::{HeaderMap, HeaderName, Request, Response};
use axum::middleware::{self, Next};
use axum::response::IntoResponse;
use axum::response::Response as AxumResponse;
use axum::routing::{get, post};
use axum::Json;
use axum::Router;
use parking_lot::Mutex;
use std::collections::{BTreeMap, BTreeSet};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::OnceLock;
use std::time::Duration;
use tokn_accounts::registry::Registry as ProviderRegistry;
use tokn_accounts::routing::RouteResolver;
use tokn_accounts::{AccountInventory, AccountPool, AccountPoolRuleset};
use tokn_config::ProxyProviderMode;
use tokn_config::RouteMode;
use tokn_config::{AgentId, Config, ModelFamily, ProfileConfig};
use tokn_core::account::AccountConfig;
use tokn_core::event::EventBus;
const PIPELINE_RETRY_POLICY: tokn_requests::RetryPolicy =
tokn_requests::RetryPolicy::new(2, Duration::from_millis(100));
use tower_http::request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer};
use tower_http::trace::TraceLayer;
use tracing::{Level, Span};
type ReloadFuture = Pin<Box<dyn Future<Output = std::result::Result<ReloadReport, String>> + Send>>;
#[derive(Clone)]
pub struct AdminReloader {
reload: Arc<dyn Fn() -> ReloadFuture + Send + Sync>,
}
impl AdminReloader {
pub fn new<F, Fut>(reload: F) -> Self
where
F: Fn() -> Fut + Send + Sync + 'static,
Fut: Future<Output = std::result::Result<ReloadReport, String>> + Send + 'static,
{
Self {
reload: Arc::new(move || Box::pin(reload())),
}
}
async fn reload(&self) -> std::result::Result<ReloadReport, String> {
(self.reload)().await
}
}
#[derive(Clone, Debug, serde::Serialize)]
pub struct ReloadReport {
pub status: &'static str,
pub generation: u64,
pub accounts: usize,
pub route_mode: &'static str,
}
#[derive(Clone)]
pub struct LiveAppState {
current: Arc<ArcSwap<AppState>>,
admin_reloader: Arc<OnceLock<AdminReloader>>,
}
impl LiveAppState {
pub fn new(state: AppState) -> Self {
Self {
current: Arc::new(ArcSwap::from_pointee(state)),
admin_reloader: Arc::new(OnceLock::new()),
}
}
pub fn current(&self) -> AppState {
self.current.load_full().as_ref().clone()
}
pub fn swap(&self, state: AppState) {
self.current.store(Arc::new(state));
}
pub fn set_admin_reloader(&self, reloader: AdminReloader) -> std::result::Result<(), AdminReloader> {
self.admin_reloader.set(reloader)
}
async fn reload(&self) -> std::result::Result<ReloadReport, String> {
let Some(reloader) = self.admin_reloader.get() else {
return Err("admin config reload is not configured".into());
};
reloader.reload().await
}
}
#[derive(Clone)]
pub struct AppState {
pub inventory: Arc<AccountInventory>,
pub pool: Arc<AccountPool>,
pub provider_registry: Arc<ProviderRegistry>,
pub identity: Arc<AccountIdentityResolver>,
pub route: Arc<RouteResolver>,
pub default_policy: Arc<RequestPolicyRuntime>,
pub profiles: Arc<BTreeMap<String, Arc<RequestPolicyRuntime>>>,
pub http: reqwest::Client,
pub events: Arc<EventBus>,
pub body_max_bytes: usize,
pub proxy_provider_modes: Arc<std::collections::BTreeMap<String, ProxyProviderMode>>,
pub request_pipeline: Arc<tokn_requests::Pipeline>,
pub passthrough_pipeline: Arc<tokn_requests::Pipeline>,
pub proxy_passthrough_pipeline: Arc<tokn_requests::Pipeline>,
pub proxy_switch_pipeline: Arc<tokn_requests::Pipeline>,
}
#[derive(Clone)]
pub struct RequestPolicyRuntime {
pub mode: RouteMode,
pub agent_id: Option<AgentId>,
pub ruleset: AccountPoolRuleset,
pub pool: Arc<AccountPool>,
pub route: Arc<RouteResolver>,
pub model_families: Vec<ModelFamily>,
pub request_pipeline: Arc<tokn_requests::Pipeline>,
pub passthrough_pipeline: Arc<tokn_requests::Pipeline>,
}
#[derive(Clone)]
struct PolicyBuildDeps {
cfg: Config,
inventory: Arc<AccountInventory>,
http: reqwest::Client,
events: Arc<EventBus>,
}
pub const REQUEST_ID_HEADER: &str = "x-request-id";
pub const SESSION_ID_HEADER: &str = "x-session-id";
pub(crate) fn is_router_owned_header(name: &axum::http::HeaderName) -> bool {
let name = name.as_str();
name.starts_with("x-tokn-router-") || name == "x-route-mode" || name == "x-behave-as"
}
pub(crate) fn first_header<'a>(headers: &'a HeaderMap, names: &[&str]) -> Option<&'a str> {
names.iter().find_map(|name| {
headers
.get(*name)
.and_then(|v| v.to_str().ok())
.map(str::trim)
.filter(|s| !s.is_empty())
})
}
tokio::task_local! {
static REQUEST_TRACKING: Mutex<RequestTracking>;
}
#[derive(Default)]
struct RequestTracking {
account: Option<Arc<str>>,
upstream_url: Option<Arc<str>>,
}
#[allow(dead_code)]
pub(crate) fn record_upstream_url(url: &str) {
let _ = REQUEST_TRACKING.try_with(|state| {
state.lock().upstream_url = Some(Arc::from(url));
});
}
fn tracking_snapshot() -> (String, String) {
REQUEST_TRACKING
.try_with(|state| {
let g = state.lock();
(
g.account.as_deref().unwrap_or("-").to_string(),
g.upstream_url.as_deref().unwrap_or("-").to_string(),
)
})
.unwrap_or_else(|_| ("-".into(), "-".into()))
}
async fn track_request(req: Request<axum::body::Body>, next: Next) -> Response<axum::body::Body> {
REQUEST_TRACKING
.scope(Mutex::new(RequestTracking::default()), next.run(req))
.await
}
pub fn router(state: AppState) -> Router {
router_live(LiveAppState::new(state))
}
pub fn router_live(state: LiveAppState) -> Router {
let request_id_header = HeaderName::from_static(REQUEST_ID_HEADER);
let trace = TraceLayer::new_for_http()
.make_span_with(|req: &Request<_>| {
let request_id = req
.headers()
.get(REQUEST_ID_HEADER)
.and_then(|v| v.to_str().ok())
.unwrap_or("-");
tracing::info_span!(
"http",
method = %req.method(),
uri = %req.uri(),
request_id = %request_id,
account = tracing::field::Empty,
upstream_url = tracing::field::Empty,
status = tracing::field::Empty,
latency_ms = tracing::field::Empty,
)
})
.on_request(|req: &Request<_>, _span: &Span| {
let len = req
.headers()
.get(axum::http::header::CONTENT_LENGTH)
.and_then(|v| v.to_str().ok())
.unwrap_or("-");
tracing::debug!(content_length = %len, "request started");
})
.on_response(|resp: &Response<_>, latency: Duration, span: &Span| {
let status = resp.status();
let ms = latency.as_millis() as u64;
let (account, upstream_url) = tracking_snapshot();
span.record("status", status.as_u16());
span.record("latency_ms", ms);
span.record("account", account.as_str());
span.record("upstream_url", upstream_url.as_str());
if status.is_server_error() || status.is_client_error() {
tracing::event!(Level::WARN, status = %status, latency_ms = ms, account = %account, upstream_url = %upstream_url, "request finished with error");
} else {
tracing::event!(Level::INFO, status = %status, latency_ms = ms, account = %account, upstream_url = %upstream_url, "request finished");
}
})
.on_failure(
|err: tower_http::classify::ServerErrorsFailureClass, latency: Duration, _span: &Span| {
let (account, upstream_url) = tracking_snapshot();
tracing::warn!(error = %err, latency_ms = latency.as_millis() as u64, account = %account, upstream_url = %upstream_url, "request failed");
},
);
let profile_routes = Router::new()
.route("/{profile}/v1/models", get(models::list_models_with_profile))
.route(
"/{profile}/v1/chat/completions",
post(endpoints::chat_completions_with_profile),
)
.route("/{profile}/v1/responses", post(endpoints::responses_with_profile))
.route("/{profile}/v1/messages", post(endpoints::messages_with_profile));
Router::new()
.route("/v1/models", get(models::list_models))
.route("/v1/chat/completions", post(endpoints::chat_completions))
.route("/v1/responses", post(endpoints::responses))
.route("/v1/messages", post(endpoints::messages))
.merge(profile_routes)
.route("/admin/config/reload", post(admin_config_reload))
.route("/healthz", get(health))
.with_state(state)
.layer(PropagateRequestIdLayer::new(request_id_header.clone()))
.layer(trace)
.layer(SetRequestIdLayer::new(request_id_header, MakeRequestUuid))
.layer(middleware::from_fn(track_request))
}
async fn health() -> &'static str {
"ok"
}
async fn admin_config_reload(axum::extract::State(state): axum::extract::State<LiveAppState>) -> AxumResponse {
match state.reload().await {
Ok(report) => Json(report).into_response(),
Err(message) if message == "admin config reload is not configured" => (
axum::http::StatusCode::NOT_FOUND,
Json(serde_json::json!({
"error": {
"message": message,
"type": "not_found",
"code": 404,
"request_id": serde_json::Value::Null,
}
})),
)
.into_response(),
Err(message) => (
axum::http::StatusCode::UNPROCESSABLE_ENTITY,
Json(serde_json::json!({
"error": {
"message": message,
"type": "reload_failed",
"code": 422,
"request_id": serde_json::Value::Null,
}
})),
)
.into_response(),
}
}
pub fn build_state(cfg: &Config, accounts: &[AccountConfig], events: Arc<EventBus>) -> Result<AppState> {
cfg.validate()?;
let provider_registry = Arc::new(ProviderRegistry::builtin());
let _ = validate_proxy_provider_modes(cfg, provider_registry.as_ref());
validate_policy_providers(cfg, provider_registry.as_ref())?;
validate_policy_accounts(cfg, accounts)?;
let identity = Arc::new(AccountIdentityResolver::from_accounts(accounts));
let default_mode = effective_default_mode(cfg);
let inventory = if accounts.is_empty() && matches!(default_mode, RouteMode::Passthrough) {
AccountInventory::empty()
} else {
let registry = provider_registry.clone();
AccountInventory::from_accounts_with(accounts, move |account| registry.build(account))?
};
let pool = AccountPool::from_inventory(&inventory, cfg, &AccountPoolRuleset::all())?;
let default_families = effective_default_families(cfg);
let route = Arc::new(RouteResolver::new(default_mode, &default_families));
let http = tokn_core::util::http::build_client(&cfg.proxy.to_http_options())?;
let body_max_bytes = if cfg.db.enabled { cfg.db.body_max_bytes } else { 0 };
let default_policy = build_policy_runtime(
"default",
default_mode,
cfg.defaults.agent_id.clone(),
cfg.defaults.providers.clone(),
cfg.defaults.accounts.clone(),
default_families,
PolicyBuildDeps {
cfg: cfg.clone(),
inventory: inventory.clone(),
http: http.clone(),
events: events.clone(),
},
);
let mut profiles = BTreeMap::new();
for (name, profile) in &cfg.profiles {
let runtime = build_profile_runtime(
name,
profile,
default_policy.as_ref(),
PolicyBuildDeps {
cfg: cfg.clone(),
inventory: inventory.clone(),
http: http.clone(),
events: events.clone(),
},
);
profiles.insert(name.clone(), runtime);
}
let proxy_passthrough_pipeline = build_proxy_passthrough_pipeline(http.clone(), events.clone());
let proxy_switch_pipeline = build_proxy_switch_pipeline(pool.clone(), http.clone(), events.clone());
Ok(AppState {
inventory,
pool,
provider_registry,
identity,
route,
request_pipeline: default_policy.request_pipeline.clone(),
passthrough_pipeline: default_policy.passthrough_pipeline.clone(),
default_policy,
profiles: Arc::new(profiles),
http,
events,
body_max_bytes,
proxy_provider_modes: Arc::new(cfg.proxy_mode.provider_modes.clone()),
proxy_passthrough_pipeline,
proxy_switch_pipeline,
})
}
fn effective_default_mode(cfg: &Config) -> RouteMode {
if cfg.defaults.mode == RouteMode::Route && cfg.server.route_mode != RouteMode::Route {
cfg.server.route_mode
} else {
cfg.defaults.mode
}
}
fn effective_default_families(cfg: &Config) -> Vec<ModelFamily> {
if cfg.defaults.model_families.is_empty() {
cfg.model_families.clone()
} else {
cfg.defaults.model_families.clone()
}
}
fn build_profile_runtime(
name: &str,
profile: &ProfileConfig,
default_policy: &RequestPolicyRuntime,
deps: PolicyBuildDeps,
) -> Arc<RequestPolicyRuntime> {
let mode = profile.mode.unwrap_or(default_policy.mode);
let agent_id = profile.agent_id.clone().or_else(|| default_policy.agent_id.clone());
let providers = profile.providers.clone().or_else(|| {
default_policy
.ruleset
.providers
.as_ref()
.map(|providers| providers.iter().cloned().collect())
});
let accounts = profile.accounts.clone().or_else(|| {
default_policy
.ruleset
.accounts
.as_ref()
.map(|accounts| accounts.iter().cloned().collect())
});
let families = profile
.model_families
.clone()
.unwrap_or_else(|| default_policy.model_families.clone());
build_policy_runtime(name, mode, agent_id, providers, accounts, families, deps)
}
fn build_policy_runtime(
name: &str,
mode: RouteMode,
agent_id: Option<AgentId>,
providers: Option<Vec<String>>,
accounts: Option<Vec<String>>,
families: Vec<ModelFamily>,
deps: PolicyBuildDeps,
) -> Arc<RequestPolicyRuntime> {
let route = Arc::new(RouteResolver::new(mode, &families));
let ruleset = AccountPoolRuleset::from_filters(providers, accounts);
let pool = AccountPool::from_inventory(&deps.inventory, &deps.cfg, &ruleset)
.expect("building account pool from validated inventory must not fail");
Arc::new(RequestPolicyRuntime {
mode,
agent_id,
ruleset,
pool: pool.clone(),
route: route.clone(),
model_families: families,
request_pipeline: build_request_pipeline(
format!("router-{name}"),
pool.clone(),
route.clone(),
deps.http.clone(),
deps.events.clone(),
),
passthrough_pipeline: build_passthrough_pipeline(
format!("router-{name}-passthrough"),
pool,
route,
deps.http,
deps.events,
),
})
}
fn validate_proxy_provider_modes(cfg: &Config, provider_registry: &ProviderRegistry) -> Result<()> {
let mut unresolved = Vec::new();
for provider_id in cfg.proxy_mode.provider_modes.keys() {
if provider_registry.resolve(provider_id).is_none() {
tracing::warn!(
provider_id = %provider_id,
"ignoring unresolved [proxy_mode].provider_modes entry"
);
unresolved.push(provider_id.clone());
}
}
if unresolved.is_empty() {
return Ok(());
}
anyhow::bail!(
"[proxy_mode].provider_modes contains unresolved provider ids: {}",
unresolved.join(", ")
);
}
fn validate_policy_providers(cfg: &Config, provider_registry: &ProviderRegistry) -> Result<()> {
let mut unresolved = Vec::new();
for provider_id in cfg
.defaults
.providers
.iter()
.flat_map(|providers| providers.iter().map(String::as_str))
{
if provider_registry.resolve(provider_id).is_none() {
unresolved.push(format!("defaults.providers:{provider_id}"));
}
}
for (profile_name, profile) in &cfg.profiles {
for provider_id in profile
.providers
.iter()
.flat_map(|providers| providers.iter().map(String::as_str))
{
if provider_registry.resolve(provider_id).is_none() {
unresolved.push(format!("profiles.{profile_name}.providers:{provider_id}"));
}
}
}
if unresolved.is_empty() {
return Ok(());
}
anyhow::bail!(
"profile/default provider filters contain unknown provider ids: {}",
unresolved.join(", ")
)
}
fn validate_policy_accounts(cfg: &Config, accounts: &[AccountConfig]) -> Result<()> {
let known = accounts
.iter()
.map(|account| account.id.as_str())
.collect::<BTreeSet<_>>();
let mut unresolved = Vec::new();
for account_id in cfg
.defaults
.accounts
.iter()
.flat_map(|accounts| accounts.iter().map(String::as_str))
{
if !known.contains(account_id) {
unresolved.push(format!("defaults.accounts:{account_id}"));
}
}
for (profile_name, profile) in &cfg.profiles {
for account_id in profile
.accounts
.iter()
.flat_map(|accounts| accounts.iter().map(String::as_str))
{
if !known.contains(account_id) {
unresolved.push(format!("profiles.{profile_name}.accounts:{account_id}"));
}
}
}
if unresolved.is_empty() {
return Ok(());
}
anyhow::bail!(
"profile/default account filters contain unknown account ids: {}",
unresolved.join(", ")
)
}
fn build_request_pipeline(
name: impl Into<smol_str::SmolStr>,
pool: Arc<AccountPool>,
route: Arc<RouteResolver>,
http: reqwest::Client,
events: Arc<EventBus>,
) -> Arc<tokn_requests::Pipeline> {
use tokn_requests::stages::{
DefaultBuildHeaders, DefaultConvertRequest, DefaultConvertResponse, DefaultExtract, DefaultSend,
PoolAccountSelector, PoolResolve,
};
let selector = Arc::new(PoolAccountSelector::new(pool, route));
let profile = tokn_requests::Profile::full(
name,
Arc::new(DefaultExtract),
Arc::new(PoolResolve::new(selector)),
Arc::new(DefaultBuildHeaders::with_provider_defaults()),
Arc::new(DefaultConvertRequest),
Arc::new(DefaultSend::new(http)),
Arc::new(DefaultConvertResponse::new()),
);
Arc::new(tokn_requests::Pipeline::new_with_retry(
Arc::new(profile),
events,
PIPELINE_RETRY_POLICY,
))
}
fn build_passthrough_pipeline(
name: impl Into<smol_str::SmolStr>,
pool: Arc<AccountPool>,
route: Arc<RouteResolver>,
http: reqwest::Client,
events: Arc<EventBus>,
) -> Arc<tokn_requests::Pipeline> {
use tokn_requests::stages::{
DefaultSend, PassthroughBuildHeaders, PassthroughConvertRequest, PassthroughConvertResponse, PassthroughExtract,
PoolAccountSelector, PoolResolve,
};
let selector = Arc::new(PoolAccountSelector::new(pool, route));
let profile = tokn_requests::Profile::full(
name,
Arc::new(PassthroughExtract),
Arc::new(PoolResolve::new(selector)),
Arc::new(PassthroughBuildHeaders::new()),
Arc::new(PassthroughConvertRequest),
Arc::new(DefaultSend::new(http)),
Arc::new(PassthroughConvertResponse::new()),
);
Arc::new(tokn_requests::Pipeline::new_with_retry(
Arc::new(profile),
events,
PIPELINE_RETRY_POLICY,
))
}
fn build_proxy_passthrough_pipeline(http: reqwest::Client, events: Arc<EventBus>) -> Arc<tokn_requests::Pipeline> {
use tokn_requests::stages::{
PassthroughBuildHeaders, PassthroughConvertRequest, PassthroughConvertResponse, PassthroughExtract, ProxyResolve,
ProxySend,
};
let profile = tokn_requests::Profile::full(
"router-proxy-passthrough",
Arc::new(PassthroughExtract),
Arc::new(ProxyResolve),
Arc::new(PassthroughBuildHeaders::preserve_host()),
Arc::new(PassthroughConvertRequest),
Arc::new(ProxySend::new(http)),
Arc::new(PassthroughConvertResponse::new()),
);
Arc::new(tokn_requests::Pipeline::new_with_retry(
Arc::new(profile),
events,
PIPELINE_RETRY_POLICY,
))
}
fn build_proxy_switch_pipeline(
pool: Arc<AccountPool>,
http: reqwest::Client,
events: Arc<EventBus>,
) -> Arc<tokn_requests::Pipeline> {
use tokn_requests::stages::{
PassthroughBuildHeaders, PassthroughConvertRequest, PassthroughConvertResponse, PassthroughExtract,
ProxyProviderResolve, ProxySend,
};
let profile = tokn_requests::Profile::full(
"router-proxy-switch",
Arc::new(PassthroughExtract),
Arc::new(ProxyProviderResolve::new(pool)),
Arc::new(PassthroughBuildHeaders::preserve_host_with_router_auth()),
Arc::new(PassthroughConvertRequest),
Arc::new(ProxySend::new(http)),
Arc::new(PassthroughConvertResponse::new()),
);
Arc::new(tokn_requests::Pipeline::new(Arc::new(profile), events))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::{Account as AccountCfg, Config, ProfileConfig};
use crate::util::secret::Secret;
use axum::body::{to_bytes, Body};
use axum::http::{Method, Request, StatusCode};
use axum::routing::get;
use bytes::Bytes;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokn_headers::inbound::{PROJECT_ID_HEADERS, REQUEST_ID_HEADERS, SESSION_ID_HEADERS};
use tower::ServiceExt;
fn zai_account() -> AccountCfg {
zai_account_with_id("acct")
}
fn zai_account_with_id(id: &str) -> AccountCfg {
zai_account_with_id_and_base(id, None)
}
fn zai_account_with_id_and_base(id: &str, base_url: Option<String>) -> AccountCfg {
AccountCfg {
id: id.into(),
provider: "zai-coding-plan".into(),
enabled: true,
tier: tokn_core::account::AccountTier::Active,
tags: Vec::new(),
label: None,
base_url,
headers: Default::default(),
auth_type: None,
username: None,
api_key: Some(Secret::new("sk-test".into())),
api_key_expires_at: None,
access_token: None,
access_token_expires_at: None,
id_token: None,
refresh_token: None,
provider_account_id: None,
extra: Default::default(),
refresh_url: None,
last_refresh: None,
settings: toml::Table::new(),
}
}
fn core_account(cfg: AccountCfg) -> AccountConfig {
let s = toml::to_string(&cfg).expect("serialize account");
toml::from_str(&s).expect("parse core account")
}
struct ControlledUpstream {
base_url: String,
arrived: tokio::sync::oneshot::Receiver<()>,
release: tokio::sync::oneshot::Sender<()>,
request: tokio::sync::oneshot::Receiver<Vec<u8>>,
task: tokio::task::JoinHandle<()>,
}
async fn controlled_chat_upstream(label: &'static str) -> ControlledUpstream {
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let base_url = format!("http://{addr}");
let (arrived_tx, arrived) = tokio::sync::oneshot::channel();
let (release, release_rx) = tokio::sync::oneshot::channel();
let (request_tx, request) = tokio::sync::oneshot::channel();
let task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let req = read_http_request(&mut stream).await;
let _ = arrived_tx.send(());
let _ = release_rx.await;
let body = chat_completion_body(label);
let resp = format!(
"HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\n\r\n",
body.len()
);
stream.write_all(resp.as_bytes()).await.unwrap();
stream.write_all(body.as_bytes()).await.unwrap();
stream.flush().await.unwrap();
let _ = request_tx.send(req);
});
ControlledUpstream {
base_url,
arrived,
release,
request,
task,
}
}
async fn read_http_request(stream: &mut tokio::net::TcpStream) -> Vec<u8> {
let mut buf = Vec::new();
let mut chunk = [0_u8; 4096];
loop {
let n = stream.read(&mut chunk).await.unwrap();
assert!(n > 0, "connection closed before request was complete");
buf.extend_from_slice(&chunk[..n]);
let Some(header_end) = find_header_end(&buf) else {
continue;
};
let headers = String::from_utf8_lossy(&buf[..header_end]);
let content_len = headers
.lines()
.find_map(|line| {
line
.strip_prefix("Content-Length:")
.or_else(|| line.strip_prefix("content-length:"))
.and_then(|v| v.trim().parse::<usize>().ok())
})
.unwrap_or(0);
if buf.len() >= header_end + 4 + content_len {
return buf;
}
}
}
fn find_header_end(buf: &[u8]) -> Option<usize> {
buf.windows(4).position(|window| window == b"\r\n\r\n")
}
fn chat_completion_body(label: &str) -> String {
serde_json::json!({
"id": format!("chatcmpl-{label}"),
"object": "chat.completion",
"created": 0,
"model": "glm-4.7",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": label,
},
"finish_reason": "stop",
}],
"usage": {
"prompt_tokens": 1,
"completion_tokens": 1,
"total_tokens": 2,
},
})
.to_string()
}
fn routed_state_for_upstream(base_url: String) -> AppState {
let mut cfg = Config::default();
cfg.defaults.mode = RouteMode::Route;
build_state(
&cfg,
&[core_account(zai_account_with_id_and_base("routed", Some(base_url)))],
Arc::new(EventBus::noop()),
)
.expect("routed test state should build")
}
fn chat_request(request_id: &str) -> Request<Body> {
Request::builder()
.method(Method::POST)
.uri("/v1/chat/completions")
.header("content-type", "application/json")
.header("x-request-id", request_id)
.body(Body::from(Bytes::from_static(
br#"{"model":"glm-4.7","messages":[{"role":"user","content":"hi"}],"stream":false}"#,
)))
.unwrap()
}
fn test_router() -> Router {
let header = HeaderName::from_static(REQUEST_ID_HEADER);
Router::new()
.route("/probe", get(|| async { "ok" }))
.layer(PropagateRequestIdLayer::new(header.clone()))
.layer(SetRequestIdLayer::new(header, MakeRequestUuid))
}
#[tokio::test]
async fn inbound_request_id_passes_through() {
let app = test_router();
let req = Request::builder()
.uri("/probe")
.header(REQUEST_ID_HEADER, "client-supplied-123")
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let echoed = resp
.headers()
.get(REQUEST_ID_HEADER)
.expect("response missing x-request-id")
.to_str()
.unwrap();
assert_eq!(echoed, "client-supplied-123");
}
#[tokio::test]
async fn missing_request_id_is_generated() {
let app = test_router();
let req = Request::builder().uri("/probe").body(Body::empty()).unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let id = resp
.headers()
.get(REQUEST_ID_HEADER)
.expect("response missing generated x-request-id")
.to_str()
.unwrap();
assert!(uuid::Uuid::parse_str(id).is_ok(), "not a uuid: {id}");
}
#[test]
fn first_header_uses_priority_and_ignores_empty_values() {
let mut headers = HeaderMap::new();
headers.insert("x-session-id", " ".parse().unwrap());
headers.insert("x-client-session-id", " client-session ".parse().unwrap());
headers.insert("x-opencode-session", "opencode-session".parse().unwrap());
assert_eq!(first_header(&headers, SESSION_ID_HEADERS), Some("client-session"));
}
#[test]
fn build_state_allows_empty_accounts_in_passthrough_mode() {
let mut cfg = Config::default();
cfg.server.route_mode = RouteMode::Passthrough;
let bus = EventBus::new(16);
let state = build_state(&cfg, &[], Arc::new(bus)).expect("passthrough mode should allow empty accounts");
assert_eq!(state.pool.len(), 0);
}
#[test]
fn build_state_rejects_empty_accounts_in_non_passthrough_mode() {
let mut cfg = Config::default();
cfg.server.route_mode = RouteMode::Route;
let bus = EventBus::new(16);
let res = build_state(&cfg, &[], Arc::new(bus));
assert!(res.is_err(), "non-passthrough mode should require accounts");
let err = res.err().expect("checked above");
assert!(err.to_string().contains("no accounts configured"));
}
fn passthrough_state(
body_max_bytes: usize,
default_mode: RouteMode,
proxy_provider_mode: ProxyProviderMode,
profile_name: &str,
) -> AppState {
let mut cfg = Config::default();
cfg.server.route_mode = RouteMode::Passthrough;
cfg.defaults.mode = default_mode;
cfg.db.enabled = true;
cfg.db.body_max_bytes = body_max_bytes;
cfg
.proxy_mode
.provider_modes
.insert("openai".into(), proxy_provider_mode);
cfg.profiles.insert(profile_name.into(), ProfileConfig::default());
build_state(&cfg, &[zai_account()], Arc::new(EventBus::noop())).expect("test state should build")
}
#[tokio::test]
async fn admin_config_reload_swaps_live_state() {
let initial = passthrough_state(1, RouteMode::Passthrough, ProxyProviderMode::Passthrough, "old-profile");
let replacement = passthrough_state(2, RouteMode::Fuzzy, ProxyProviderMode::Switch, "new-profile");
let live = LiveAppState::new(initial);
let live_for_reload = live.clone();
assert!(live
.set_admin_reloader(AdminReloader::new(move || {
let live = live_for_reload.clone();
let replacement = replacement.clone();
async move {
live.swap(replacement);
Ok(ReloadReport {
status: "reloaded",
generation: 1,
accounts: 0,
route_mode: "passthrough",
})
}
}))
.is_ok());
let app = router_live(live.clone());
let resp = app
.oneshot(
Request::builder()
.method("POST")
.uri("/admin/config/reload")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let current = live.current();
assert_eq!(current.body_max_bytes, 2);
assert_eq!(current.default_policy.mode, RouteMode::Fuzzy);
assert_eq!(current.route.resolve_mode(None).unwrap(), RouteMode::Fuzzy);
assert_eq!(
current.proxy_provider_modes.get("openai"),
Some(&ProxyProviderMode::Switch)
);
assert!(current.profiles.contains_key("new-profile"));
assert!(!current.profiles.contains_key("old-profile"));
let body = to_bytes(resp.into_body(), usize::MAX).await.unwrap();
let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
assert_eq!(json["status"], "reloaded");
assert_eq!(json["generation"], 1);
assert_eq!(json["accounts"], 0);
assert_eq!(json["route_mode"], "passthrough");
}
#[tokio::test]
async fn admin_config_reload_failure_keeps_live_state() {
let live = LiveAppState::new(passthrough_state(
1,
RouteMode::Passthrough,
ProxyProviderMode::Passthrough,
"old-profile",
));
assert!(live
.set_admin_reloader(AdminReloader::new(|| async { Err("invalid config".into()) }))
.is_ok());
let app = router_live(live.clone());
let resp = app
.oneshot(
Request::builder()
.method("POST")
.uri("/admin/config/reload")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::UNPROCESSABLE_ENTITY);
let current = live.current();
assert_eq!(current.body_max_bytes, 1);
assert_eq!(current.default_policy.mode, RouteMode::Passthrough);
assert_eq!(current.route.resolve_mode(None).unwrap(), RouteMode::Passthrough);
assert_eq!(
current.proxy_provider_modes.get("openai"),
Some(&ProxyProviderMode::Passthrough)
);
assert!(current.profiles.contains_key("old-profile"));
let body = to_bytes(resp.into_body(), usize::MAX).await.unwrap();
let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
assert_eq!(json["error"]["type"], "reload_failed");
assert_eq!(json["error"]["message"], "invalid config");
}
#[tokio::test]
async fn routed_requests_keep_in_flight_policy_and_new_requests_use_reloaded_policy() {
let old_upstream = controlled_chat_upstream("old").await;
let new_upstream = controlled_chat_upstream("new").await;
let live = LiveAppState::new(routed_state_for_upstream(old_upstream.base_url.clone()));
let replacement = routed_state_for_upstream(new_upstream.base_url.clone());
let live_for_reload = live.clone();
assert!(live
.set_admin_reloader(AdminReloader::new(move || {
let live = live_for_reload.clone();
let replacement = replacement.clone();
async move {
live.swap(replacement);
Ok(ReloadReport {
status: "reloaded",
generation: 1,
accounts: 1,
route_mode: "route",
})
}
}))
.is_ok());
let app = router_live(live);
let old_request = {
let app = app.clone();
tokio::spawn(async move { app.oneshot(chat_request("old-request")).await.unwrap() })
};
tokio::time::timeout(std::time::Duration::from_secs(2), old_upstream.arrived)
.await
.expect("old request should reach old upstream before reload")
.unwrap();
let reload = app
.clone()
.oneshot(
Request::builder()
.method(Method::POST)
.uri("/admin/config/reload")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_eq!(reload.status(), StatusCode::OK);
let _ = to_bytes(reload.into_body(), usize::MAX).await.unwrap();
old_upstream.release.send(()).unwrap();
let old_response = old_request.await.unwrap();
assert_eq!(old_response.status(), StatusCode::OK);
let old_body = to_bytes(old_response.into_body(), usize::MAX).await.unwrap();
assert!(
String::from_utf8_lossy(&old_body).contains("chatcmpl-old"),
"old in-flight request should finish with old upstream response"
);
old_upstream.task.await.unwrap();
let old_raw_request = String::from_utf8_lossy(&old_upstream.request.await.unwrap()).to_string();
assert!(
old_raw_request.contains(r#""model":"glm-4.7""#),
"old upstream should receive the pre-reload request"
);
let new_request = {
let app = app.clone();
tokio::spawn(async move { app.oneshot(chat_request("new-request")).await.unwrap() })
};
tokio::time::timeout(std::time::Duration::from_secs(2), new_upstream.arrived)
.await
.expect("new request should reach reloaded upstream")
.unwrap();
new_upstream.release.send(()).unwrap();
let new_response = new_request.await.unwrap();
assert_eq!(new_response.status(), StatusCode::OK);
let new_body = to_bytes(new_response.into_body(), usize::MAX).await.unwrap();
assert!(
String::from_utf8_lossy(&new_body).contains("chatcmpl-new"),
"new request should finish with reloaded upstream response"
);
new_upstream.task.await.unwrap();
let new_raw_request = String::from_utf8_lossy(&new_upstream.request.await.unwrap()).to_string();
assert!(
new_raw_request.contains(r#""model":"glm-4.7""#),
"new upstream should receive the post-reload request"
);
}
#[test]
fn build_state_allows_unknown_proxy_provider_mode_provider() {
let mut cfg = Config::default();
cfg.server.route_mode = RouteMode::Passthrough;
cfg
.proxy_mode
.provider_modes
.insert("made-up-provider".into(), ProxyProviderMode::Switch);
let bus = EventBus::new(16);
let res = build_state(&cfg, &[], Arc::new(bus));
assert!(
res.is_ok(),
"unknown provider ids should only warn and not fail state construction"
);
}
#[test]
fn validate_proxy_provider_modes_returns_error_for_unknown_provider() {
let mut cfg = Config::default();
cfg
.proxy_mode
.provider_modes
.insert("made-up-provider".into(), ProxyProviderMode::Switch);
let registry = ProviderRegistry::builtin();
let res = validate_proxy_provider_modes(&cfg, ®istry);
let err = res.expect_err("helper should still return an error for outside callers");
assert!(err.to_string().contains("unresolved provider ids"));
assert!(err.to_string().contains("made-up-provider"));
}
#[tokio::test]
async fn route_mode_not_implemented_returns_json_error_body() {
let cfg = Config::default();
let accounts = vec![zai_account()];
let state = build_state(&cfg, &accounts, Arc::new(EventBus::noop())).unwrap();
let app = router(state);
let req = Request::builder()
.method("POST")
.uri("/v1/responses")
.header("content-type", "application/json")
.header("x-route-mode", "route")
.body(Body::from(Bytes::from_static(br#"{"model":"unknown","input":"hi"}"#)))
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::NOT_IMPLEMENTED);
assert_eq!(
resp
.headers()
.get(axum::http::header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok()),
Some("application/json")
);
let body = to_bytes(resp.into_body(), usize::MAX).await.unwrap();
let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
let message = json["error"]["message"].as_str().unwrap();
assert!(!message.is_empty());
assert!(message.contains("responses"));
assert!(message.contains("unknown"));
assert_eq!(json["error"]["type"], "not_implemented_error");
assert_eq!(json["error"]["code"], 501);
}
#[tokio::test]
async fn profile_provider_filter_excludes_other_provider_accounts() {
let mut cfg = Config::default();
cfg.profiles.insert(
"openai-only".into(),
ProfileConfig {
providers: Some(vec!["openai".into()]),
..Default::default()
},
);
let accounts = vec![zai_account()];
let state = build_state(&cfg, &accounts, Arc::new(EventBus::noop())).unwrap();
assert_eq!(state.profiles.get("openai-only").unwrap().pool.len(), 0);
let app = router(state);
let req = Request::builder()
.method("POST")
.uri("/openai-only/v1/responses")
.header("content-type", "application/json")
.body(Body::from(Bytes::from_static(br#"{"model":"glm-4.6","input":"hi"}"#)))
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::NOT_IMPLEMENTED);
}
#[tokio::test]
async fn profile_models_uses_prefiltered_policy_pool() {
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let upstream_url = format!("http://{addr}");
let (req_tx, req_rx) = tokio::sync::oneshot::channel::<Vec<u8>>();
let server = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut buf = vec![0_u8; 4096];
let n = stream.read(&mut buf).await.unwrap();
buf.truncate(n);
let body = br#"{"object":"list","data":[{"id":"policy-only-model","object":"model"}]}"#;
let resp = format!(
"HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\n\r\n",
body.len()
);
stream.write_all(resp.as_bytes()).await.unwrap();
stream.write_all(body).await.unwrap();
stream.flush().await.unwrap();
let _ = req_tx.send(buf);
});
let mut cfg = Config::default();
cfg.profiles.insert(
"work".into(),
ProfileConfig {
accounts: Some(vec!["local".into()]),
..Default::default()
},
);
let accounts = vec![
zai_account_with_id_and_base("local", Some(upstream_url)),
zai_account_with_id("excluded"),
];
let state = build_state(&cfg, &accounts, Arc::new(EventBus::noop())).unwrap();
assert_eq!(state.pool.len(), 2);
assert_eq!(state.profiles.get("work").unwrap().pool.len(), 1);
let app = router(state);
let req = Request::builder()
.method("GET")
.uri("/work/v1/models")
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let body = to_bytes(resp.into_body(), usize::MAX).await.unwrap();
let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
let ids = json["data"]
.as_array()
.unwrap()
.iter()
.filter_map(|model| model["id"].as_str())
.collect::<Vec<_>>();
assert_eq!(ids, vec!["policy-only-model"]);
let upstream_req = req_rx.await.unwrap();
let upstream_req = String::from_utf8_lossy(&upstream_req);
assert!(upstream_req.starts_with("GET /models "));
server.await.unwrap();
}
#[test]
fn profile_without_providers_inherits_unrestricted_defaults() {
let mut cfg = Config::default();
cfg.profiles.insert("work".into(), ProfileConfig::default());
let accounts = vec![zai_account()];
let state = build_state(&cfg, &accounts, Arc::new(EventBus::noop())).unwrap();
let work = state.profiles.get("work").expect("work profile");
assert!(state.default_policy.ruleset.providers.is_none());
assert!(work.ruleset.providers.is_none());
assert_eq!(work.pool.len(), 1);
}
#[test]
fn profile_providers_replace_default_providers() {
let mut cfg = Config::default();
cfg.defaults.providers = Some(vec!["openai".into()]);
cfg.profiles.insert(
"zai-only".into(),
ProfileConfig {
providers: Some(vec!["zai-coding-plan".into()]),
..Default::default()
},
);
let accounts = vec![zai_account()];
let state = build_state(&cfg, &accounts, Arc::new(EventBus::noop())).unwrap();
let providers = state
.profiles
.get("zai-only")
.and_then(|profile| profile.ruleset.providers.as_ref())
.expect("profile should have provider filter");
assert!(providers.contains("zai-coding-plan"));
assert!(!providers.contains("openai"));
assert_eq!(state.profiles.get("zai-only").unwrap().pool.len(), 1);
}
#[test]
fn profile_accounts_replace_default_accounts() {
let mut cfg = Config::default();
cfg.defaults.accounts = Some(vec!["other".into()]);
cfg.profiles.insert(
"agent".into(),
ProfileConfig {
accounts: Some(vec!["acct".into()]),
..Default::default()
},
);
let accounts = vec![zai_account(), zai_account_with_id("other")];
let state = build_state(&cfg, &accounts, Arc::new(EventBus::noop())).unwrap();
let profile_accounts = state
.profiles
.get("agent")
.and_then(|profile| profile.ruleset.accounts.as_ref())
.expect("profile should have account filter");
assert!(profile_accounts.contains("acct"));
assert!(!profile_accounts.contains("other"));
let account_ids = state
.profiles
.get("agent")
.unwrap()
.pool
.all()
.iter()
.map(|account| account.id())
.collect::<Vec<_>>();
assert_eq!(account_ids, vec!["acct".to_string()]);
}
#[test]
fn profile_model_families_replace_default_families() {
let mut cfg = Config::default();
cfg.defaults.model_families = vec![ModelFamily {
name: "glm-family".into(),
members: vec!["glm-4.7".into()],
}];
cfg.profiles.insert(
"work".into(),
ProfileConfig {
mode: Some(RouteMode::Fuzzy),
model_families: Some(vec![ModelFamily {
name: "glm-family".into(),
members: vec!["glm-5.1".into()],
}]),
..Default::default()
},
);
let accounts = vec![zai_account()];
let state = build_state(&cfg, &accounts, Arc::new(EventBus::noop())).unwrap();
let work = state.profiles.get("work").expect("work profile");
assert_eq!(work.model_families.len(), 1);
assert_eq!(work.model_families[0].members, vec!["glm-5.1".to_string()]);
let resolved = work.route.resolve("glm-family", None).unwrap();
assert_eq!(
resolved.selector,
tokn_accounts::routing::RouteSelector::Fuzzy {
candidates: vec!["glm-5.1".into()]
}
);
}
#[test]
fn build_state_rejects_unknown_policy_provider_filters() {
let mut cfg = Config::default();
cfg.defaults.providers = Some(vec!["made-up-provider".into()]);
let err = match build_state(&cfg, &[], Arc::new(EventBus::noop())) {
Ok(_) => panic!("unknown default provider must fail"),
Err(err) => err,
};
assert!(err.to_string().contains("defaults.providers:made-up-provider"));
let mut cfg = Config::default();
cfg.profiles.insert(
"work".into(),
ProfileConfig {
providers: Some(vec!["made-up-provider".into()]),
..Default::default()
},
);
let err = match build_state(&cfg, &[], Arc::new(EventBus::noop())) {
Ok(_) => panic!("unknown profile provider must fail"),
Err(err) => err,
};
assert!(err.to_string().contains("profiles.work.providers:made-up-provider"));
}
#[test]
fn build_state_rejects_unknown_policy_account_filters() {
let accounts = vec![zai_account()];
let mut cfg = Config::default();
cfg.defaults.accounts = Some(vec!["missing".into()]);
let err = match build_state(&cfg, &accounts, Arc::new(EventBus::noop())) {
Ok(_) => panic!("unknown default account must fail"),
Err(err) => err,
};
assert!(err.to_string().contains("defaults.accounts:missing"));
let mut cfg = Config::default();
cfg.profiles.insert(
"work".into(),
ProfileConfig {
accounts: Some(vec!["missing".into()]),
..Default::default()
},
);
let err = match build_state(&cfg, &accounts, Arc::new(EventBus::noop())) {
Ok(_) => panic!("unknown profile account must fail"),
Err(err) => err,
};
assert!(err.to_string().contains("profiles.work.accounts:missing"));
}
#[tokio::test]
async fn unknown_profile_path_returns_bad_request() {
let cfg = Config::default();
let accounts = vec![zai_account()];
let state = build_state(&cfg, &accounts, Arc::new(EventBus::noop())).unwrap();
let app = router(state);
let req = Request::builder()
.method("POST")
.uri("/missing/v1/responses")
.header("content-type", "application/json")
.body(Body::from(Bytes::from_static(br#"{"model":"glm-4.6","input":"hi"}"#)))
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
}
#[tokio::test]
async fn old_mode_path_is_not_magic_without_matching_profile() {
let cfg = Config::default();
let accounts = vec![zai_account()];
let state = build_state(&cfg, &accounts, Arc::new(EventBus::noop())).unwrap();
let app = router(state);
let req = Request::builder()
.method("POST")
.uri("/fuzzy/v1/responses")
.header("content-type", "application/json")
.body(Body::from(Bytes::from_static(br#"{"model":"glm-4.6","input":"hi"}"#)))
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
}
#[test]
fn is_router_owned_header_does_not_include_request_session_project_id_headers() {
use axum::http::HeaderName;
for header in REQUEST_ID_HEADERS.iter() {
let name = HeaderName::try_from(*header).unwrap();
assert!(!is_router_owned_header(&name), "{header} should NOT be router-owned");
}
for header in SESSION_ID_HEADERS.iter() {
let name = HeaderName::try_from(*header).unwrap();
assert!(!is_router_owned_header(&name), "{header} should NOT be router-owned");
}
for header in PROJECT_ID_HEADERS.iter() {
let name = HeaderName::try_from(*header).unwrap();
assert!(!is_router_owned_header(&name), "{header} should NOT be router-owned");
}
}
}