use std::net::SocketAddr;
use std::path::PathBuf;
use std::process::ExitCode;
use oxibonsai_core::config::Qwen3Config;
use oxibonsai_runtime::engine::InferenceEngine;
use oxibonsai_runtime::sampling::SamplingParams;
use oxibonsai_runtime::server::{create_router, serve_with_shutdown, shutdown_signal};
use oxibonsai_runtime::tokenizer_bridge::TokenizerBridge;
use oxibonsai_serve::{
args::parse_args_from,
banner,
config::{PartialServerConfig, ServerConfig},
env::parse_process_env,
};
use tracing::{error, info, warn};
/// Process entry point: drives [`run`] on the tokio runtime and maps its outcome to an exit code.
///
/// A failure is reported twice on purpose: through `tracing` (which is only effective once
/// `run` has installed a subscriber) and on stderr, so the message is visible either way.
#[tokio::main]
async fn main() -> ExitCode {
    if let Err(err) = run().await {
        error!(%err, "oxibonsai-serve startup failed");
        eprintln!("error: {err}");
        return ExitCode::FAILURE;
    }
    ExitCode::SUCCESS
}
/// Runs the full server lifecycle: argument parsing, logging setup, configuration loading,
/// engine and tokenizer construction, router assembly, and serving until shutdown.
///
/// Returns `Ok(())` after the listener exits cleanly, and also when `parse_args_from`
/// yields `None` (nothing to run — presumably `--help`/`--version` was already handled by
/// the parser; confirm against `oxibonsai_serve::args`).
///
/// # Errors
///
/// Propagates failures from argument parsing, environment parsing, configuration loading,
/// GGUF model loading, tokenizer loading, bind-address parsing, and the listener itself.
async fn run() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let argv: Vec<String> = std::env::args().collect();
    let Some(cli_args) = parse_args_from(&argv)? else {
        return Ok(());
    };

    // Bootstrap logging from the CLI log level; an unparsable filter falls back to "info".
    let bootstrap_filter = tracing_subscriber::EnvFilter::try_new(&cli_args.log_level)
        .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info"));
    // A `try_init` failure is deliberately ignored so that logging setup never aborts startup.
    let _ = tracing_subscriber::fmt()
        .with_env_filter(bootstrap_filter)
        .with_target(false)
        .compact()
        .try_init();

    // Gather the three configuration sources (TOML file, process env, CLI) and merge them.
    let toml_path: Option<PathBuf> = cli_args.config_path.as_ref().map(PathBuf::from);
    let env_partial = parse_process_env()?;
    let cli_partial: PartialServerConfig = cli_args.to_partial();
    let config = ServerConfig::load(toml_path.as_deref(), Some(env_partial), Some(cli_partial))?;

    banner::print_banner();
    info!(
        "{}",
        banner::startup_message(&config.bind.host, config.bind.port)
    );

    let sampling = SamplingParams {
        temperature: config.sampling.default_temperature,
        top_p: config.sampling.default_top_p,
        ..SamplingParams::default()
    };

    // Engine: load the configured GGUF model, or fall back to the tiny test configuration
    // when no model path was supplied.
    let engine: InferenceEngine<'static> = match config.model.path.as_ref() {
        Some(path) => {
            info!(path = %path.display(), "loading GGUF model");
            match InferenceEngine::from_gguf_path(
                path,
                sampling.clone(),
                config.seed,
                config.limits.max_input_tokens,
            ) {
                Ok(e) => {
                    info!("GGUF model loaded");
                    e
                }
                Err(err) => {
                    error!(
                        path = %path.display(),
                        %err,
                        "failed to load GGUF model"
                    );
                    return Err(format!("failed to load GGUF model: {err}").into());
                }
            }
        }
        None => {
            warn!("no --model path supplied; falling back to tiny_test engine");
            let tiny = Qwen3Config::tiny_test();
            InferenceEngine::new(tiny, sampling, config.seed)
        }
    };

    // Tokenizer: an explicit path must load (failure is fatal); otherwise probe well-known
    // locations next to the model. Finding nothing is only a warning and yields `None`.
    let tokenizer = match config.tokenizer.path.as_ref() {
        Some(path) => match TokenizerBridge::from_file(&path.display().to_string()) {
            Ok(t) => {
                info!(path = %path.display(), "tokenizer loaded");
                Some(t)
            }
            Err(err) => {
                error!(path = %path.display(), %err, "failed to load tokenizer");
                return Err(format!("failed to load tokenizer: {err}").into());
            }
        },
        None => {
            let lookup = match config.model.path.as_ref() {
                Some(model_path) => tokenizer_lookup::resolve_tokenizer_for_model(model_path),
                None => tokenizer_lookup::TokenizerLookup::default(),
            };
            match lookup.found {
                // A tokenizer that was found but cannot be loaded is fatal, same as an
                // explicitly configured one.
                Some(found) => match TokenizerBridge::from_file(&found) {
                    Ok(t) => {
                        info!(path = %found, "auto-detected tokenizer alongside model");
                        Some(t)
                    }
                    Err(err) => {
                        error!(path = %found, %err, "failed to load auto-detected tokenizer");
                        return Err(format!("failed to load tokenizer: {err}").into());
                    }
                },
                None => {
                    warn!(
                        "{}",
                        tokenizer_lookup::missing_tokenizer_warning(&lookup.searched)
                    );
                    None
                }
            }
        }
    };

    // Router: wrap with bearer-token authentication only when a token is configured.
    let base_router = create_router(engine, tokenizer);
    let router = if let Some(ref token) = config.auth.bearer_token {
        let state = middleware::BearerAuthState {
            token: token.clone(),
        };
        info!("bearer-token authentication enabled");
        base_router.layer(axum::middleware::from_fn_with_state(
            state,
            middleware::bearer_auth,
        ))
    } else {
        base_router
    };

    // `SocketAddr` only parses IPv6 in the bracketed `[addr]:port` form, so a bare IPv6
    // host such as `::` or `::1` makes the plain `host:port` spelling unparsable. Retry
    // with brackets in that case; every input that parsed before still parses to the same
    // address, and the reported error remains the one for the plain form.
    // NOTE(review): hostnames (e.g. `localhost`) are still rejected — no DNS resolution here.
    let addr_str = format!("{}:{}", config.bind.host, config.bind.port);
    let addr: SocketAddr = addr_str
        .parse::<SocketAddr>()
        .or_else(|plain_err| {
            format!("[{}]:{}", config.bind.host, config.bind.port)
                .parse::<SocketAddr>()
                .map_err(|_| plain_err)
        })
        .map_err(|e| format!("invalid bind address '{addr_str}': {e}"))?;

    info!(%addr, "starting listener");
    serve_with_shutdown(router, addr, shutdown_signal()).await?;
    info!("oxibonsai-serve exited cleanly");
    Ok(())
}
/// Best-effort discovery of a `tokenizer.json` that belongs to a given model file.
mod tokenizer_lookup {
    use std::path::{Path, PathBuf};

    /// Outcome of a tokenizer search.
    #[derive(Debug, Default)]
    pub struct TokenizerLookup {
        /// First candidate that exists on disk, rendered as a (lossy) string.
        pub found: Option<String>,
        /// Every candidate path considered, in probe order.
        pub searched: Vec<PathBuf>,
    }

    /// Reports whether `tag` is a recognised precision/quantisation tag.
    ///
    /// Accepted forms are the float tags `F16`, `BF16`, `F32`, or `Q` followed by a
    /// non-empty run of ASCII digits and then zero or more `_`-separated, non-empty,
    /// ASCII-alphanumeric segments (for example `Q8_0` or `Q4_K_M`).
    fn is_precision_tag(tag: &str) -> bool {
        if matches!(tag, "F16" | "BF16" | "F32") {
            return true;
        }
        let Some(body) = tag.strip_prefix('Q') else {
            return false;
        };
        let mut segments = body.split('_');
        let leading_is_number = matches!(
            segments.next(),
            Some(digits) if !digits.is_empty() && digits.chars().all(|c| c.is_ascii_digit())
        );
        leading_is_number
            && segments.all(|seg| !seg.is_empty() && seg.chars().all(|c| c.is_ascii_alphanumeric()))
    }

    /// Drops a trailing `-<tag>` from `basename` when `<tag>` is a precision tag;
    /// otherwise returns `basename` untouched.
    fn strip_quant_suffix(basename: &str) -> &str {
        match basename.rsplit_once('-') {
            Some((head, tag)) if is_precision_tag(tag) => head,
            _ => basename,
        }
    }

    /// Lists the locations probed for a tokenizer, in priority order and without duplicates.
    ///
    /// Order: the model's directory, its parent (`..`), sibling directories named after the
    /// model stem (plain, `-unpacked`, `-ONNX`), and finally the nearest ancestor directory
    /// called `models`.
    fn tokenizer_candidates(model_path: &Path) -> Vec<PathBuf> {
        const TOKENIZER_FILE: &str = "tokenizer.json";

        // A bare file name has an empty parent; treat that as the current directory.
        let model_dir = match model_path.parent() {
            Some(dir) if !dir.as_os_str().is_empty() => dir.to_path_buf(),
            _ => PathBuf::from("."),
        };

        let mut ordered = vec![
            model_dir.join(TOKENIZER_FILE),
            model_dir.join("..").join(TOKENIZER_FILE),
        ];

        if let Some(stem) = model_path.file_stem().and_then(|s| s.to_str()) {
            let base = strip_quant_suffix(stem);
            for decoration in ["", "-unpacked", "-ONNX"] {
                let dir_name = format!("{base}{decoration}");
                ordered.push(model_dir.join(dir_name).join(TOKENIZER_FILE));
            }
        }

        // Only the closest enclosing `models` directory is considered.
        let models_dir = model_path
            .ancestors()
            .skip(1)
            .find(|dir| dir.file_name().and_then(|n| n.to_str()) == Some("models"));
        if let Some(dir) = models_dir {
            ordered.push(dir.join(TOKENIZER_FILE));
        }

        // Keep the first occurrence of each path, preserving probe order.
        let mut unique: Vec<PathBuf> = Vec::with_capacity(ordered.len());
        for candidate in ordered {
            if !unique.contains(&candidate) {
                unique.push(candidate);
            }
        }
        unique
    }

    /// Probes the candidate locations for `model_path` and reports the first one that exists,
    /// together with the full list of paths that were considered.
    pub fn resolve_tokenizer_for_model(model_path: &Path) -> TokenizerLookup {
        let searched = tokenizer_candidates(model_path);
        let found = searched
            .iter()
            .find(|candidate| candidate.exists())
            .map(|hit| hit.to_string_lossy().into_owned());
        TokenizerLookup { found, searched }
    }

    /// Builds the multi-line warning shown when no tokenizer could be located, listing the
    /// searched paths and the ways to fix the situation.
    pub fn missing_tokenizer_warning(searched: &[PathBuf]) -> String {
        let listing: String = if searched.is_empty() {
            String::from(" (no candidate paths — model path was not provided)\n")
        } else {
            searched
                .iter()
                .map(|path| format!(" - {}\n", path.display()))
                .collect()
        };
        [
            "no tokenizer found. Searched:\n",
            listing.as_str(),
            "To fix:\n",
            " - Pass --tokenizer <path/to/tokenizer.json>, OR\n",
            " - Run ./scripts/download_tokenizer.sh to fetch the Qwen3 tokenizer to models/tokenizer.json\n",
            "Continuing with raw token IDs in output.",
        ]
        .concat()
    }
}
/// Optional bearer-token authentication layer for the HTTP router.
mod middleware {
    use axum::body::Body;
    use axum::extract::State;
    use axum::http::{header, Request, StatusCode};
    use axum::middleware::Next;
    use axum::response::{IntoResponse, Response};
    use axum::Json;

    /// State handed to [`bearer_auth`]: the token every protected request must present.
    #[derive(Debug, Clone)]
    pub struct BearerAuthState {
        pub token: String,
    }

    /// Rejects requests that do not carry `Authorization: Bearer <token>` with the
    /// configured token.
    ///
    /// `/health` and `/metrics` are passed through without any check. Every other path
    /// gets a `401 Unauthorized` JSON error when the header is missing, is not valid
    /// visible ASCII, does not use the `Bearer` scheme, or carries the wrong token.
    pub async fn bearer_auth(
        State(state): State<BearerAuthState>,
        req: Request<Body>,
        next: Next,
    ) -> Response {
        let path = req.uri().path();
        if path == "/health" || path == "/metrics" {
            return next.run(req).await;
        }

        let presented = req
            .headers()
            .get(header::AUTHORIZATION)
            .and_then(|value| value.to_str().ok())
            .and_then(bearer_credentials);
        let Some(presented) = presented else {
            return unauthorized("missing or malformed Authorization header").into_response();
        };

        // Compare without an early exit so response timing does not reveal how many
        // leading bytes of a guessed token were correct.
        if !constant_time_eq(presented.as_bytes(), state.token.as_bytes()) {
            return unauthorized("invalid bearer token").into_response();
        }
        next.run(req).await
    }

    /// Extracts the credentials from an `Authorization` header value that uses the
    /// `Bearer` scheme, trimming surrounding whitespace.
    ///
    /// The scheme name is matched case-insensitively, as HTTP requires for auth schemes,
    /// so `bearer x` and `BEARER x` are accepted alongside `Bearer x`.
    fn bearer_credentials(header_value: &str) -> Option<&str> {
        let (scheme, credentials) = header_value.split_once(' ')?;
        scheme
            .eq_ignore_ascii_case("Bearer")
            .then(|| credentials.trim())
    }

    /// Byte-wise equality that always inspects every byte of equal-length inputs.
    ///
    /// Only the length is allowed to short-circuit. This is a best-effort mitigation:
    /// nothing here formally prevents the optimizer from reintroducing an early exit.
    fn constant_time_eq(left: &[u8], right: &[u8]) -> bool {
        if left.len() != right.len() {
            return false;
        }
        left.iter()
            .zip(right)
            .fold(0u8, |diff, (a, b)| diff | (a ^ b))
            == 0
    }

    /// Builds the `401 Unauthorized` response pair with a JSON error body carrying `msg`.
    fn unauthorized(msg: &str) -> (StatusCode, Json<serde_json::Value>) {
        (
            StatusCode::UNAUTHORIZED,
            Json(serde_json::json!({
                "error": {
                    "message": msg,
                    "type": "auth_error",
                    "param": null,
                    "code": null,
                }
            })),
        )
    }
}