use axum::extract::State;
use axum::http::{HeaderMap, HeaderValue, StatusCode};
use axum::response::{
    sse::{Event, Sse},
    IntoResponse, Json, Response,
};
use axum::Router;
use crate::engine::InferenceEngine;
use crate::metrics::InferenceMetrics;
use crate::request_id::RequestId;
use crate::tokenizer_bridge::TokenizerBridge;
use serde::{Deserialize, Serialize};
use std::convert::Infallible;
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio_stream::StreamExt;
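/// Header used to propagate a caller-supplied request id through the server.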
pub const REQUEST_ID_HEADER: &str = "x-request-id";
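/// Resolves the request id for a call: honors a well-formed `x-request-id`
/// header (UUID form first, then hex), otherwise generates a fresh id.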
pub fn resolve_request_id(headers: &HeaderMap) -> RequestId {
if let Some(v) = headers.get(REQUEST_ID_HEADER) {
if let Ok(s) = v.to_str() {
if let Some(id) = RequestId::from_uuid(s).or_else(|| RequestId::from_hex(s)) {
return id;
}
}
}
RequestId::new()
}
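/// Builds a response `HeaderMap` echoing the request id back to the caller.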
pub fn request_id_header_map(id: RequestId) -> HeaderMap {
let mut headers = HeaderMap::new();
if let Ok(value) = HeaderValue::from_str(&id.as_uuid()) {
headers.insert(REQUEST_ID_HEADER, value);
}
headers
}
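/// Shared server state. The engine sits behind an async `Mutex`, so requests
/// are admitted concurrently but generation itself is serialized.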
pub struct AppState {
engine: Mutex<InferenceEngine<'static>>,
tokenizer: Option<TokenizerBridge>,
metrics: Arc<InferenceMetrics>,
}
impl AppState {
pub async fn engine_lock(&self) -> tokio::sync::MutexGuard<'_, InferenceEngine<'static>> {
self.engine.lock().await
}
pub fn tokenizer(&self) -> Option<&TokenizerBridge> {
self.tokenizer.as_ref()
}
pub fn metrics(&self) -> &Arc<InferenceMetrics> {
&self.metrics
}
}
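/// One chat turn. `content` is optional because assistant turns may carry
/// only `tool_calls`, and tool-result turns reference a `tool_call_id`.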
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
pub role: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<crate::api_types::ToolCallResult>>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
}
impl ChatMessage {
pub fn text(role: impl Into<String>, content: impl Into<String>) -> Self {
Self {
role: role.into(),
content: Some(content.into()),
tool_calls: None,
tool_call_id: None,
}
}
}
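/// OpenAI-style chat completion request body. `max_tokens` defaults to 256
/// and `temperature` to 0.7 when omitted. A minimal body looks like:
///
/// ```json
/// { "messages": [{ "role": "user", "content": "Hi" }], "stream": true }
/// ```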
#[derive(Debug, Deserialize)]
pub struct ChatCompletionRequest {
pub messages: Vec<ChatMessage>,
#[serde(default = "default_max_tokens")]
pub max_tokens: usize,
#[serde(default = "default_temperature")]
pub temperature: f32,
#[serde(default)]
pub stream: bool,
#[serde(default)]
pub tools: Option<Vec<crate::api_types::ToolDefinition>>,
#[serde(default)]
pub tool_choice: Option<serde_json::Value>,
}
fn default_max_tokens() -> usize {
256
}
fn default_temperature() -> f32 {
0.7
}
#[derive(Debug, Serialize)]
pub struct ChatCompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChatChoice>,
    pub usage: Usage,
}
#[derive(Debug, Serialize)]
pub struct Usage {
pub prompt_tokens: usize,
pub completion_tokens: usize,
pub total_tokens: usize,
}
#[derive(Debug, Serialize)]
pub struct ChatChoice {
pub index: usize,
pub message: ChatMessage,
pub finish_reason: String,
}
#[derive(Serialize)]
struct ChatCompletionChunk {
id: String,
object: String,
created: u64,
model: String,
choices: Vec<ChunkChoice>,
}
#[derive(Serialize)]
struct ChunkChoice {
index: usize,
delta: ChunkDelta,
finish_reason: Option<String>,
}
#[derive(Serialize)]
struct ChunkDelta {
#[serde(skip_serializing_if = "Option::is_none")]
role: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
content: Option<String>,
}
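/// Builds the API router with a fresh metrics registry; use
/// `create_router_with_metrics` to share one registry across routers.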
pub fn create_router(
engine: InferenceEngine<'static>,
tokenizer: Option<TokenizerBridge>,
) -> Router {
create_router_with_metrics(engine, tokenizer, Arc::new(InferenceMetrics::new()))
}
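/// Builds the API router: chat completions (plain and extended), legacy
/// completions, model listing, health, Prometheus metrics, and the merged
/// embeddings routes.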
pub fn create_router_with_metrics(
engine: InferenceEngine<'static>,
tokenizer: Option<TokenizerBridge>,
metrics: Arc<InferenceMetrics>,
) -> Router {
let state = Arc::new(AppState {
engine: Mutex::new(engine),
tokenizer,
metrics,
});
let embeddings_router = crate::embeddings::create_embeddings_router(512);
Router::new()
.route(
"/v1/chat/completions",
axum::routing::post(chat_completions),
)
.route(
"/v1/chat/completions/extended",
axum::routing::post(crate::api_extensions::extended_chat_completions),
)
.route(
"/v1/completions",
axum::routing::post(crate::completions::create_completion),
)
.route("/v1/models", axum::routing::get(list_models))
.route("/health", axum::routing::get(health))
.route("/metrics", axum::routing::get(prometheus_metrics))
.with_state(state)
.merge(embeddings_router)
}
async fn health() -> &'static str {
"ok"
}
async fn prometheus_metrics(State(state): State<Arc<AppState>>) -> impl IntoResponse {
let body = state.metrics.render_prometheus();
(
StatusCode::OK,
[("content-type", "text/plain; version=0.0.4; charset=utf-8")],
body,
)
}
async fn list_models() -> Json<serde_json::Value> {
Json(serde_json::json!({
"object": "list",
"data": [{
"id": "bonsai-8b",
"object": "model",
"owned_by": "oxibonsai"
}]
}))
}
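/// POST `/v1/chat/completions`. Dispatches to the streaming or buffered path
/// based on `stream`, recording request, token, and latency metrics either way.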
#[tracing::instrument(skip(state, headers, body), fields(request_id))]
async fn chat_completions(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(body): Json<ChatCompletionRequest>,
) -> Result<Response, StatusCode> {
let request_id = resolve_request_id(&headers);
tracing::Span::current().record("request_id", tracing::field::display(&request_id));
let request_start = std::time::Instant::now();
state.metrics.requests_total.inc();
state.metrics.active_requests.inc();
let prompt_text = build_prompt(&body.messages);
    let prompt_tokens = if let Some(tok) = &state.tokenizer {
        tok.encode(&prompt_text).map_err(|_| {
            // Encode failures bail out before the dispatch below, so settle
            // the error and active-request metrics here.
            state.metrics.errors_total.inc();
            state.metrics.active_requests.dec();
            StatusCode::INTERNAL_SERVER_ERROR
        })?
    } else {
        // No tokenizer configured: fall back to a single Qwen `<|im_start|>`
        // token id (151644) so generation still has a prompt to start from.
        vec![151644]
    };
state
.metrics
.prompt_tokens_total
.inc_by(prompt_tokens.len() as u64);
let result = if body.stream {
chat_completions_stream(
Arc::clone(&state),
prompt_tokens,
body.max_tokens,
request_id,
)
.await
} else {
chat_completions_non_stream(
Arc::clone(&state),
prompt_tokens,
body.max_tokens,
request_id,
)
.await
};
let elapsed = request_start.elapsed().as_secs_f64();
state.metrics.request_duration_seconds.observe(elapsed);
state.metrics.active_requests.dec();
if result.is_err() {
state.metrics.errors_total.inc();
}
result
}
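/// Buffered path: generates the full completion under the engine lock, then
/// returns one `chat.completion` JSON body with token usage.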
async fn chat_completions_non_stream(
state: Arc<AppState>,
prompt_tokens: Vec<u32>,
max_tokens: usize,
request_id: RequestId,
) -> Result<Response, StatusCode> {
let prompt_len = prompt_tokens.len();
let mut engine = state.engine.lock().await;
let output_tokens = engine.generate(&prompt_tokens, max_tokens).map_err(|e| {
tracing::error!(error = %e, "generation failed");
StatusCode::INTERNAL_SERVER_ERROR
})?;
let completion_len = output_tokens.len();
state
.metrics
.tokens_generated_total
.inc_by(completion_len as u64);
let content = if let Some(tok) = &state.tokenizer {
tok.decode(&output_tokens)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
} else {
format!("{output_tokens:?}")
};
    let response = ChatCompletionResponse {
        id: format!("chatcmpl-{}", rand_id()),
        object: "chat.completion".to_string(),
        created: std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs(),
        model: "bonsai-8b".to_string(),
choices: vec![ChatChoice {
index: 0,
message: ChatMessage {
role: "assistant".to_string(),
content: Some(content),
tool_calls: None,
tool_call_id: None,
},
finish_reason: "stop".to_string(),
}],
usage: Usage {
prompt_tokens: prompt_len,
completion_tokens: completion_len,
total_tokens: prompt_len + completion_len,
},
};
let headers = request_id_header_map(request_id);
Ok((headers, Json(response)).into_response())
}
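/// Streaming path, emitting OpenAI-style SSE chunks: first a role delta, then
/// one content delta per decoded fragment, a `finish_reason: "stop"` chunk,
/// and finally the literal `[DONE]` sentinel.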
async fn chat_completions_stream(
state: Arc<AppState>,
prompt_tokens: Vec<u32>,
max_tokens: usize,
request_id: RequestId,
) -> Result<Response, StatusCode> {
let completion_id = format!("chatcmpl-{}", rand_id());
let created = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let (token_tx, token_rx) = tokio::sync::mpsc::unbounded_channel::<u32>();
let gen_state = Arc::clone(&state);
    tokio::task::spawn_blocking(move || {
        // Generation is CPU-bound, so run it on the blocking pool. Blocking
        // tasks inherit the runtime context, which lets us briefly re-enter
        // the runtime to take the engine lock before generating.
        let rt = tokio::runtime::Handle::current();
        let mut engine = rt.block_on(gen_state.engine.lock());
        if let Err(e) = engine.generate_streaming(&prompt_tokens, max_tokens, &token_tx) {
            // The receiver may already be gone if the client disconnected.
            tracing::error!(error = %e, "streaming generation failed");
        }
    });
let id_for_stream = completion_id;
let state_for_stream = Arc::clone(&state);
let role_chunk = ChatCompletionChunk {
id: id_for_stream.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: "bonsai-8b".to_string(),
choices: vec![ChunkChoice {
index: 0,
delta: ChunkDelta {
role: Some("assistant".to_string()),
content: None,
},
finish_reason: None,
}],
};
    let role_event = serde_json::to_string(&role_chunk)
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let id_clone = id_for_stream.clone();
let token_stream = tokio_stream::wrappers::UnboundedReceiverStream::new(token_rx);
let mut stream_state = state_for_stream
.tokenizer
.as_ref()
.map(|t| t.new_decode_stream(true));
    let content_stream = token_stream.filter_map(move |token_id| {
        // Incremental decode: `Ok(None)` means no printable text yet (e.g. a
        // partial UTF-8 sequence is still buffering), so emit nothing. With no
        // tokenizer, or on a decode error, fall back to the raw token id.
        let text = match (&state_for_stream.tokenizer, stream_state.as_mut()) {
(Some(tok), Some(state)) => match tok.step_decode(state, token_id) {
Ok(Some(txt)) => txt,
Ok(None) => return None,
Err(_) => format!("[{token_id}]"),
},
_ => format!("[{token_id}]"),
};
let chunk = ChatCompletionChunk {
id: id_clone.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: "bonsai-8b".to_string(),
choices: vec![ChunkChoice {
index: 0,
delta: ChunkDelta {
role: None,
content: Some(text),
},
finish_reason: None,
}],
};
        // Serializing a plain text chunk should not fail; degrade to an empty
        // data line rather than aborting the stream if it somehow does.
        Some(serde_json::to_string(&chunk).unwrap_or_default())
});
let finish_chunk = ChatCompletionChunk {
id: id_for_stream,
object: "chat.completion.chunk".to_string(),
created,
model: "bonsai-8b".to_string(),
choices: vec![ChunkChoice {
index: 0,
delta: ChunkDelta {
role: None,
content: None,
},
finish_reason: Some("stop".to_string()),
}],
};
let finish_json = serde_json::to_string(&finish_chunk).unwrap_or_default();
let role_stream = tokio_stream::once(role_event);
let full_stream = role_stream
.chain(content_stream)
.chain(tokio_stream::once(finish_json))
.map(|json_str| -> Result<Event, Infallible> { Ok(Event::default().data(json_str)) })
.chain(tokio_stream::once(Ok(Event::default().data("[DONE]"))));
let headers = request_id_header_map(request_id);
Ok((headers, Sse::new(full_stream)).into_response())
}
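/// Renders messages into a ChatML-style prompt and leaves an open assistant
/// turn for the model to complete. Turns without text content (e.g. pure
/// tool-call messages) are skipped; unknown roles are appended as plain text.
/// For a single user turn of "Hi", the result is:
///
/// ```text
/// <|im_start|>user
/// Hi<|im_end|>
/// <|im_start|>assistant
/// ```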
fn build_prompt(messages: &[ChatMessage]) -> String {
let mut prompt = String::new();
for msg in messages {
let text = match msg.content.as_deref() {
Some(t) => t,
None => continue,
};
match msg.role.as_str() {
"system" => {
prompt.push_str("<|im_start|>system\n");
prompt.push_str(text);
prompt.push_str("<|im_end|>\n");
}
"user" => {
prompt.push_str("<|im_start|>user\n");
prompt.push_str(text);
prompt.push_str("<|im_end|>\n");
}
"assistant" => {
prompt.push_str("<|im_start|>assistant\n");
prompt.push_str(text);
prompt.push_str("<|im_end|>\n");
}
_ => {
prompt.push_str(text);
prompt.push('\n');
}
}
}
prompt.push_str("<|im_start|>assistant\n");
prompt
}
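/// Derives a completion id from the current time in nanoseconds, hex-encoded.
/// Note: time-based rather than random, so concurrent requests landing in the
/// same nanosecond would collide.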
fn rand_id() -> String {
let ts = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_nanos();
format!("{ts:x}")
}
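/// Binds `addr` and serves `router` until `shutdown_signal` resolves, then
/// drains in-flight connections. A minimal sketch (assuming a built router):
///
/// ```ignore
/// let addr: std::net::SocketAddr = "127.0.0.1:8080".parse()?;
/// serve_with_shutdown(router, addr, shutdown_signal()).await?;
/// ```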
pub async fn serve_with_shutdown(
router: Router,
addr: std::net::SocketAddr,
shutdown_signal: impl std::future::Future<Output = ()> + Send + 'static,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let listener = tokio::net::TcpListener::bind(addr).await?;
tracing::info!(%addr, "server listening");
axum::serve(listener, router)
.with_graceful_shutdown(shutdown_signal)
.await?;
tracing::info!("server shut down gracefully");
Ok(())
}
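/// Resolves on Ctrl+C on all platforms, or SIGTERM on Unix, whichever fires
/// first.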
pub async fn shutdown_signal() {
let ctrl_c = async {
tokio::signal::ctrl_c()
.await
.expect("failed to install Ctrl+C handler");
};
#[cfg(unix)]
let terminate = async {
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
.expect("failed to install SIGTERM handler")
.recv()
.await;
};
#[cfg(not(unix))]
let terminate = std::future::pending::<()>();
tokio::select! {
() = ctrl_c => {
tracing::info!("received Ctrl+C, initiating shutdown");
}
() = terminate => {
tracing::info!("received SIGTERM, initiating shutdown");
}
}
}
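/// Convenience entry point: fresh metrics, default router, and signal-driven
/// graceful shutdown.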
pub async fn create_server(
engine: InferenceEngine<'static>,
tokenizer: Option<TokenizerBridge>,
addr: std::net::SocketAddr,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let metrics = Arc::new(InferenceMetrics::new());
let router = create_router_with_metrics(engine, tokenizer, metrics);
serve_with_shutdown(router, addr, shutdown_signal()).await
}
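/// Server tuning knobs. Note: not yet consumed by `create_server` in this
/// module; callers are expected to wire it up themselves.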
#[derive(Debug, Clone)]
pub struct ServerConfig {
pub max_queue_depth: usize,
pub request_timeout_seconds: u64,
pub bind_addr: std::net::SocketAddr,
}
impl Default for ServerConfig {
fn default() -> Self {
Self {
max_queue_depth: 128,
request_timeout_seconds: 60,
bind_addr: std::net::SocketAddr::from(([127, 0, 0, 1], 8080)),
}
}
}
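/// Lock-free admission gate capping concurrent in-flight requests;
/// `max_depth` is clamped to at least 1. Typical use (a sketch):
///
/// ```ignore
/// let gate = QueueDepthTracker::new(config.max_queue_depth);
/// if !gate.try_acquire() {
///     return Err(StatusCode::SERVICE_UNAVAILABLE);
/// }
/// // ... handle the request ...
/// gate.release();
/// ```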
pub struct QueueDepthTracker {
current: std::sync::atomic::AtomicUsize,
max_depth: usize,
}
impl QueueDepthTracker {
pub fn new(max_depth: usize) -> Self {
Self {
current: std::sync::atomic::AtomicUsize::new(0),
max_depth: max_depth.max(1),
}
}
    pub fn try_acquire(&self) -> bool {
        // CAS loop: retry on contention so a concurrent acquire or release
        // cannot spuriously reject a request while capacity remains.
        let mut current = self.current.load(std::sync::atomic::Ordering::Relaxed);
        loop {
            if current >= self.max_depth {
                return false;
            }
            match self.current.compare_exchange_weak(
                current,
                current + 1,
                std::sync::atomic::Ordering::AcqRel,
                std::sync::atomic::Ordering::Relaxed,
            ) {
                Ok(_) => return true,
                Err(observed) => current = observed,
            }
        }
    }
pub fn release(&self) {
self.current
.fetch_sub(1, std::sync::atomic::Ordering::Release);
}
pub fn depth(&self) -> usize {
self.current.load(std::sync::atomic::Ordering::Relaxed)
}
pub fn max_depth(&self) -> usize {
self.max_depth
}
pub fn has_capacity(&self) -> bool {
self.depth() < self.max_depth
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn build_prompt_simple() {
let msgs = vec![ChatMessage {
role: "user".to_string(),
content: Some("Hello".to_string()),
tool_calls: None,
tool_call_id: None,
}];
let p = build_prompt(&msgs);
assert!(p.contains("<|im_start|>user\nHello<|im_end|>"));
assert!(p.ends_with("<|im_start|>assistant\n"));
}
#[test]
fn build_prompt_system_and_user() {
let msgs = vec![
ChatMessage {
role: "system".to_string(),
content: Some("You are a helpful assistant.".to_string()),
tool_calls: None,
tool_call_id: None,
},
ChatMessage {
role: "user".to_string(),
content: Some("Hi".to_string()),
tool_calls: None,
tool_call_id: None,
},
];
let p = build_prompt(&msgs);
assert!(p.contains("<|im_start|>system\nYou are a helpful assistant.<|im_end|>"));
assert!(p.contains("<|im_start|>user\nHi<|im_end|>"));
}
#[test]
fn build_prompt_multi_turn() {
let msgs = vec![
ChatMessage {
role: "user".to_string(),
content: Some("What is 2+2?".to_string()),
tool_calls: None,
tool_call_id: None,
},
ChatMessage {
role: "assistant".to_string(),
content: Some("4".to_string()),
tool_calls: None,
tool_call_id: None,
},
ChatMessage {
role: "user".to_string(),
content: Some("And 3+3?".to_string()),
tool_calls: None,
tool_call_id: None,
},
];
let p = build_prompt(&msgs);
assert!(p.contains("<|im_start|>assistant\n4<|im_end|>"));
assert!(p.contains("And 3+3?"));
}
#[test]
fn rand_id_is_nonempty() {
let id = rand_id();
assert!(!id.is_empty());
}
#[test]
fn default_max_tokens_value() {
assert_eq!(default_max_tokens(), 256);
}
#[test]
fn default_temperature_value() {
assert!((default_temperature() - 0.7).abs() < f32::EPSILON);
}
#[test]
fn create_router_builds_without_tokenizer() {
let config = oxibonsai_core::config::Qwen3Config::bonsai_8b();
let params = crate::sampling::SamplingParams::default();
let engine = InferenceEngine::new(config, params, 42);
let _router = create_router(engine, None);
}
#[test]
fn create_router_with_shared_metrics() {
let config = oxibonsai_core::config::Qwen3Config::bonsai_8b();
let params = crate::sampling::SamplingParams::default();
let engine = InferenceEngine::new(config, params, 42);
let metrics = Arc::new(InferenceMetrics::new());
let _router = create_router_with_metrics(engine, None, Arc::clone(&metrics));
assert_eq!(metrics.requests_total.get(), 0);
}
#[test]
fn server_config_default() {
let config = ServerConfig::default();
assert_eq!(config.max_queue_depth, 128);
assert_eq!(config.request_timeout_seconds, 60);
assert_eq!(
config.bind_addr,
std::net::SocketAddr::from(([127, 0, 0, 1], 8080))
);
}
#[test]
fn queue_depth_tracker_basic() {
let tracker = QueueDepthTracker::new(3);
assert_eq!(tracker.depth(), 0);
assert_eq!(tracker.max_depth(), 3);
assert!(tracker.has_capacity());
assert!(tracker.try_acquire());
assert_eq!(tracker.depth(), 1);
assert!(tracker.try_acquire());
assert_eq!(tracker.depth(), 2);
assert!(tracker.try_acquire());
assert_eq!(tracker.depth(), 3);
assert!(!tracker.has_capacity());
assert!(!tracker.try_acquire());
tracker.release();
assert_eq!(tracker.depth(), 2);
assert!(tracker.has_capacity());
assert!(tracker.try_acquire());
}
#[test]
fn queue_depth_tracker_min_capacity() {
let tracker = QueueDepthTracker::new(0);
assert_eq!(tracker.max_depth(), 1);
assert!(tracker.try_acquire());
assert!(!tracker.try_acquire());
}
}