Skip to main content

oxibonsai_runtime/
server.rs

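//! OpenAI-compatible chat completions server.
//!
//! Provides an Axum-based HTTP server with the following endpoints:
//!
//! | Method | Path | Description |
//! |--------|------|-------------|
//! | POST | `/v1/chat/completions` | Chat completion (streaming and non-streaming) |
//! | POST | `/v1/chat/completions/extended` | Extended chat completion (handler in `crate::api_extensions`) |
//! | POST | `/v1/completions` | Text completion (handler in `crate::completions`) |
//! | GET | `/v1/models` | List available models |
//! | GET | `/health` | Liveness probe |
//! | GET | `/metrics` | Prometheus text exposition |
//!
//! The routes built by `crate::embeddings::create_embeddings_router` are
//! merged into the same router.
//!
//! Use [`create_router`] or [`create_router_with_metrics`] to build
//! the Axum router, then serve it with `axum::serve`.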
14
15use axum::extract::State;
16use axum::http::{HeaderMap, HeaderValue, StatusCode};
17use axum::response::{
18    sse::{Event, Sse},
19    IntoResponse, Json, Response,
20};
21use axum::Router;
22use serde::{Deserialize, Serialize};
23use std::convert::Infallible;
24use std::sync::Arc;
25use tokio::sync::Mutex;
26use tokio_stream::StreamExt;
27
28use crate::engine::InferenceEngine;
29use crate::metrics::InferenceMetrics;
30use crate::request_id::RequestId;
31use crate::tokenizer_bridge::TokenizerBridge;
32
/// Header name used for end-to-end request correlation.
///
/// [`resolve_request_id`] reads this header from the incoming request and
/// falls back to a freshly generated id when it is absent or cannot be
/// parsed; [`request_id_header_map`] writes the resolved id back under the
/// same name, rendered via `RequestId::as_uuid` (so a client that sent the
/// dash-less hex form does not necessarily get its exact bytes echoed).
pub const REQUEST_ID_HEADER: &str = "x-request-id";
37
38/// Resolve a [`RequestId`] from an incoming request header, falling back to
39/// a freshly generated id when none is supplied or when the supplied value
40/// is malformed (in either case we still want a usable id to thread through
41/// tracing spans and the response).
42///
43/// Accepts both the 32-hex form (no dashes) and the 36-char UUID form
44/// (`8-4-4-4-12`).
45pub fn resolve_request_id(headers: &HeaderMap) -> RequestId {
46    if let Some(v) = headers.get(REQUEST_ID_HEADER) {
47        if let Ok(s) = v.to_str() {
48            if let Some(id) = RequestId::from_uuid(s).or_else(|| RequestId::from_hex(s)) {
49                return id;
50            }
51        }
52    }
53    RequestId::new()
54}
55
56/// Build response headers for a [`RequestId`]. Returns a `HeaderMap` with the
57/// `X-Request-ID` set to the canonical 36-char UUID form.
58pub fn request_id_header_map(id: RequestId) -> HeaderMap {
59    let mut headers = HeaderMap::new();
60    if let Ok(value) = HeaderValue::from_str(&id.as_uuid()) {
61        headers.insert(REQUEST_ID_HEADER, value);
62    }
63    headers
64}
65
/// Server state shared by all request handlers.
///
/// Built once in [`create_router_with_metrics`] and handed to the router
/// wrapped in an `Arc`.
pub struct AppState {
    /// The inference engine, behind an async mutex so that only one
    /// handler drives it at a time.
    engine: Mutex<InferenceEngine<'static>>,
    /// Optional tokenizer. When `None`, the chat handlers fall back to a
    /// placeholder prompt token and render output tokens as raw ids.
    tokenizer: Option<TokenizerBridge>,
    /// Metrics shared with whoever supplied them at router construction;
    /// also rendered by the `/metrics` endpoint.
    metrics: Arc<InferenceMetrics>,
}
72
73impl AppState {
74    /// Acquire a mutable guard over the inference engine.
75    pub async fn engine_lock(&self) -> tokio::sync::MutexGuard<'_, InferenceEngine<'static>> {
76        self.engine.lock().await
77    }
78
79    /// Access the optional tokenizer.
80    pub fn tokenizer(&self) -> Option<&TokenizerBridge> {
81        self.tokenizer.as_ref()
82    }
83
84    /// Access the shared metrics instance.
85    pub fn metrics(&self) -> &Arc<InferenceMetrics> {
86        &self.metrics
87    }
88}
89
/// Chat message (OpenAI-compatible).
///
/// `content` is `Option<String>` so that it can be absent when `tool_calls`
/// is set (the model produced a tool call instead of a text reply).
///
/// All three optional fields default to `None` when missing from the
/// incoming JSON and are omitted from the serialized JSON when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Role of the message sender: `"system"`, `"user"`, `"assistant"`, `"tool"`.
    pub role: String,
    /// Text content of the message.  `None` when the assistant returns tool calls.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    /// Tool calls produced by the model (assistant role only).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<crate::api_types::ToolCallResult>>,
    /// ID of the tool call being responded to (tool role only).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}
108
109impl ChatMessage {
110    /// Construct a plain text assistant or user message.
111    pub fn text(role: impl Into<String>, content: impl Into<String>) -> Self {
112        Self {
113            role: role.into(),
114            content: Some(content.into()),
115            tool_calls: None,
116            tool_call_id: None,
117        }
118    }
119}
120
/// Chat completion request body.
///
/// Only `messages` is required; every other field has a serde default.
/// Note that the `chat_completions` handler in this file reads only
/// `messages`, `max_tokens` and `stream`.
#[derive(Debug, Deserialize)]
pub struct ChatCompletionRequest {
    /// Conversation history.
    pub messages: Vec<ChatMessage>,
    /// Maximum tokens to generate (defaults to `default_max_tokens()`, i.e. 256).
    #[serde(default = "default_max_tokens")]
    pub max_tokens: usize,
    /// Sampling temperature (defaults to `default_temperature()`, i.e. 0.7).
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    /// Whether to stream the response as SSE (defaults to `false`).
    #[serde(default)]
    pub stream: bool,
    /// Tools available to the model.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<crate::api_types::ToolDefinition>>,
    /// Tool choice: `"auto"`, `"none"`, or a specific function selector.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<serde_json::Value>,
}
142
/// Serde default for `ChatCompletionRequest::max_tokens`.
fn default_max_tokens() -> usize {
    const DEFAULT_MAX_TOKENS: usize = 256;
    DEFAULT_MAX_TOKENS
}
/// Serde default for `ChatCompletionRequest::temperature`.
fn default_temperature() -> f32 {
    const DEFAULT_TEMPERATURE: f32 = 0.7;
    DEFAULT_TEMPERATURE
}
149
/// Chat completion response (non-streaming).
#[derive(Debug, Serialize)]
pub struct ChatCompletionResponse {
    /// Completion identifier; the non-streaming handler in this file uses
    /// the form `chatcmpl-<id>`.
    pub id: String,
    /// Object type tag; the non-streaming handler sets `"chat.completion"`.
    pub object: String,
    /// Generated choices; the non-streaming handler always returns exactly one.
    pub choices: Vec<ChatChoice>,
    /// Token accounting for this request.
    pub usage: Usage,
}
158
/// Token usage info.
#[derive(Debug, Serialize)]
pub struct Usage {
    /// Number of tokens in the prompt.
    pub prompt_tokens: usize,
    /// Number of tokens generated.
    pub completion_tokens: usize,
    /// Sum of `prompt_tokens` and `completion_tokens` (as filled in by the
    /// non-streaming handler in this file).
    pub total_tokens: usize,
}
166
/// A choice in the completion response.
#[derive(Debug, Serialize)]
pub struct ChatChoice {
    /// Position of this choice in the `choices` array.
    pub index: usize,
    /// The generated message.
    pub message: ChatMessage,
    /// Why generation ended; the non-streaming handler in this file always
    /// reports `"stop"`.
    pub finish_reason: String,
}
174
/// SSE streaming chunk (OpenAI-compatible).
///
/// One of these is serialized to JSON for each SSE `data:` event emitted by
/// the streaming handler.
#[derive(Serialize)]
struct ChatCompletionChunk {
    /// Completion identifier, shared by every chunk of one stream.
    id: String,
    /// Object type tag; the streaming handler sets `"chat.completion.chunk"`.
    object: String,
    /// Creation time in seconds since the Unix epoch.
    created: u64,
    /// Model name reported to the client.
    model: String,
    /// Chunk choices; the streaming handler always emits exactly one.
    choices: Vec<ChunkChoice>,
}
184
/// A choice in the SSE streaming chunk.
#[derive(Serialize)]
struct ChunkChoice {
    /// Position of this choice in the `choices` array.
    index: usize,
    /// Incremental payload carried by this chunk.
    delta: ChunkDelta,
    /// `None` (serialized as `null`) on intermediate chunks; set on the
    /// final chunk of a stream.
    finish_reason: Option<String>,
}
192
/// Delta content in a streaming chunk.
///
/// Fields that are `None` are omitted from the serialized JSON.
#[derive(Serialize)]
struct ChunkDelta {
    /// Message role; the streaming handler sets it only on the first chunk.
    #[serde(skip_serializing_if = "Option::is_none")]
    role: Option<String>,
    /// Newly generated text carried by this chunk.
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<String>,
}
201
202/// Create the Axum router.
203pub fn create_router(
204    engine: InferenceEngine<'static>,
205    tokenizer: Option<TokenizerBridge>,
206) -> Router {
207    create_router_with_metrics(engine, tokenizer, Arc::new(InferenceMetrics::new()))
208}
209
210/// Create the Axum router with a shared metrics instance.
211pub fn create_router_with_metrics(
212    engine: InferenceEngine<'static>,
213    tokenizer: Option<TokenizerBridge>,
214    metrics: Arc<InferenceMetrics>,
215) -> Router {
216    let state = Arc::new(AppState {
217        engine: Mutex::new(engine),
218        tokenizer,
219        metrics,
220    });
221
222    // The embeddings router carries its own Arc<EmbeddingAppState>; merge it
223    // before attaching the main AppState so the states don't conflict.
224    let embeddings_router = crate::embeddings::create_embeddings_router(512);
225
226    Router::new()
227        .route(
228            "/v1/chat/completions",
229            axum::routing::post(chat_completions),
230        )
231        .route(
232            "/v1/chat/completions/extended",
233            axum::routing::post(crate::api_extensions::extended_chat_completions),
234        )
235        .route(
236            "/v1/completions",
237            axum::routing::post(crate::completions::create_completion),
238        )
239        .route("/v1/models", axum::routing::get(list_models))
240        .route("/health", axum::routing::get(health))
241        .route("/metrics", axum::routing::get(prometheus_metrics))
242        .with_state(state)
243        .merge(embeddings_router)
244}
245
/// Liveness probe: always answers with the plain-text body `ok`.
async fn health() -> &'static str {
    const LIVENESS_BODY: &str = "ok";
    LIVENESS_BODY
}
249
250/// Prometheus metrics endpoint.
251async fn prometheus_metrics(State(state): State<Arc<AppState>>) -> impl IntoResponse {
252    let body = state.metrics.render_prometheus();
253    (
254        StatusCode::OK,
255        [("content-type", "text/plain; version=0.0.4; charset=utf-8")],
256        body,
257    )
258}
259
260async fn list_models() -> Json<serde_json::Value> {
261    Json(serde_json::json!({
262        "object": "list",
263        "data": [{
264            "id": "bonsai-8b",
265            "object": "model",
266            "owned_by": "oxibonsai"
267        }]
268    }))
269}
270
271#[tracing::instrument(skip(state, headers, body), fields(request_id))]
272async fn chat_completions(
273    State(state): State<Arc<AppState>>,
274    headers: HeaderMap,
275    Json(body): Json<ChatCompletionRequest>,
276) -> Result<Response, StatusCode> {
277    let request_id = resolve_request_id(&headers);
278    tracing::Span::current().record("request_id", tracing::field::display(&request_id));
279
280    let request_start = std::time::Instant::now();
281    state.metrics.requests_total.inc();
282    state.metrics.active_requests.inc();
283
284    // Build prompt from messages
285    let prompt_text = build_prompt(&body.messages);
286
287    // Tokenize
288    let prompt_tokens = if let Some(tok) = &state.tokenizer {
289        tok.encode(&prompt_text).map_err(|_| {
290            state.metrics.errors_total.inc();
291            state.metrics.active_requests.dec();
292            StatusCode::INTERNAL_SERVER_ERROR
293        })?
294    } else {
295        // Fallback: single start token
296        vec![151644]
297    };
298
299    state
300        .metrics
301        .prompt_tokens_total
302        .inc_by(prompt_tokens.len() as u64);
303
304    let result = if body.stream {
305        // ── SSE streaming mode ──
306        chat_completions_stream(
307            Arc::clone(&state),
308            prompt_tokens,
309            body.max_tokens,
310            request_id,
311        )
312        .await
313    } else {
314        // ── Non-streaming mode ──
315        chat_completions_non_stream(
316            Arc::clone(&state),
317            prompt_tokens,
318            body.max_tokens,
319            request_id,
320        )
321        .await
322    };
323
324    let elapsed = request_start.elapsed().as_secs_f64();
325    state.metrics.request_duration_seconds.observe(elapsed);
326    state.metrics.active_requests.dec();
327
328    if result.is_err() {
329        state.metrics.errors_total.inc();
330    }
331
332    result
333}
334
335/// Non-streaming chat completion handler.
336async fn chat_completions_non_stream(
337    state: Arc<AppState>,
338    prompt_tokens: Vec<u32>,
339    max_tokens: usize,
340    request_id: RequestId,
341) -> Result<Response, StatusCode> {
342    let prompt_len = prompt_tokens.len();
343
344    let mut engine = state.engine.lock().await;
345    let output_tokens = engine.generate(&prompt_tokens, max_tokens).map_err(|e| {
346        tracing::error!(error = %e, "generation failed");
347        StatusCode::INTERNAL_SERVER_ERROR
348    })?;
349
350    let completion_len = output_tokens.len();
351
352    // Record token metrics
353    state
354        .metrics
355        .tokens_generated_total
356        .inc_by(completion_len as u64);
357
358    // Decode
359    let content = if let Some(tok) = &state.tokenizer {
360        tok.decode(&output_tokens)
361            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
362    } else {
363        format!("{output_tokens:?}")
364    };
365
366    let response = ChatCompletionResponse {
367        id: format!("chatcmpl-{}", rand_id()),
368        object: "chat.completion".to_string(),
369        choices: vec![ChatChoice {
370            index: 0,
371            message: ChatMessage {
372                role: "assistant".to_string(),
373                content: Some(content),
374                tool_calls: None,
375                tool_call_id: None,
376            },
377            finish_reason: "stop".to_string(),
378        }],
379        usage: Usage {
380            prompt_tokens: prompt_len,
381            completion_tokens: completion_len,
382            total_tokens: prompt_len + completion_len,
383        },
384    };
385
386    let headers = request_id_header_map(request_id);
387    Ok((headers, Json(response)).into_response())
388}
389
390/// SSE streaming chat completion handler.
391async fn chat_completions_stream(
392    state: Arc<AppState>,
393    prompt_tokens: Vec<u32>,
394    max_tokens: usize,
395    request_id: RequestId,
396) -> Result<Response, StatusCode> {
397    let completion_id = format!("chatcmpl-{}", rand_id());
398    let created = std::time::SystemTime::now()
399        .duration_since(std::time::UNIX_EPOCH)
400        .unwrap_or_default()
401        .as_secs();
402
403    let (token_tx, token_rx) = tokio::sync::mpsc::unbounded_channel::<u32>();
404
405    // Spawn generation task that locks the engine and streams tokens
406    let gen_state = Arc::clone(&state);
407    tokio::task::spawn_blocking(move || {
408        let rt = tokio::runtime::Handle::current();
409        let mut engine = rt.block_on(gen_state.engine.lock());
410        let _result = engine.generate_streaming(&prompt_tokens, max_tokens, &token_tx);
411        // token_tx is dropped here, closing the channel
412    });
413
414    // Build SSE stream from the token receiver
415    let id_for_stream = completion_id;
416    let state_for_stream = Arc::clone(&state);
417
418    // First, send a role delta
419    let role_chunk = ChatCompletionChunk {
420        id: id_for_stream.clone(),
421        object: "chat.completion.chunk".to_string(),
422        created,
423        model: "bonsai-8b".to_string(),
424        choices: vec![ChunkChoice {
425            index: 0,
426            delta: ChunkDelta {
427                role: Some("assistant".to_string()),
428                content: None,
429            },
430            finish_reason: None,
431        }],
432    };
433
434    let role_event = match serde_json::to_string(&role_chunk) {
435        Ok(json) => json,
436        Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
437    };
438
439    let id_clone = id_for_stream.clone();
440
441    // Convert token receiver into a stream of SSE events
442    let token_stream = tokio_stream::wrappers::UnboundedReceiverStream::new(token_rx);
443
444    // Per-request streaming-decode state.  BPE tokens may straddle UTF-8
445    // codepoint boundaries (CJK, emoji), so we buffer through HF's
446    // step_decode_stream and only emit a chunk when a complete UTF-8 piece is
447    // ready.  Mid-codepoint tokens yield `Ok(None)` and are filtered out.
448    let mut stream_state = state_for_stream
449        .tokenizer
450        .as_ref()
451        .map(|t| t.new_decode_stream(true));
452
453    let content_stream = token_stream.filter_map(move |token_id| {
454        let text = match (&state_for_stream.tokenizer, stream_state.as_mut()) {
455            (Some(tok), Some(state)) => match tok.step_decode(state, token_id) {
456                Ok(Some(txt)) => txt,
457                Ok(None) => return None,
458                Err(_) => format!("[{token_id}]"),
459            },
460            _ => format!("[{token_id}]"),
461        };
462
463        let chunk = ChatCompletionChunk {
464            id: id_clone.clone(),
465            object: "chat.completion.chunk".to_string(),
466            created,
467            model: "bonsai-8b".to_string(),
468            choices: vec![ChunkChoice {
469                index: 0,
470                delta: ChunkDelta {
471                    role: None,
472                    content: Some(text),
473                },
474                finish_reason: None,
475            }],
476        };
477
478        Some(serde_json::to_string(&chunk).unwrap_or_default())
479    });
480
481    // Build finish chunk
482    let finish_chunk = ChatCompletionChunk {
483        id: id_for_stream,
484        object: "chat.completion.chunk".to_string(),
485        created,
486        model: "bonsai-8b".to_string(),
487        choices: vec![ChunkChoice {
488            index: 0,
489            delta: ChunkDelta {
490                role: None,
491                content: None,
492            },
493            finish_reason: Some("stop".to_string()),
494        }],
495    };
496    let finish_json = serde_json::to_string(&finish_chunk).unwrap_or_default();
497
498    // Prepend role event, append finish event and [DONE]
499    let role_stream = tokio_stream::once(role_event);
500
501    let full_stream = role_stream
502        .chain(content_stream)
503        .chain(tokio_stream::once(finish_json))
504        .map(|json_str| -> Result<Event, Infallible> { Ok(Event::default().data(json_str)) })
505        .chain(tokio_stream::once(Ok(Event::default().data("[DONE]"))));
506
507    let headers = request_id_header_map(request_id);
508    Ok((headers, Sse::new(full_stream)).into_response())
509}
510
511/// Build a simple prompt from chat messages.
512///
513/// Messages with `content = None` (e.g. tool-call turns) are skipped.
514fn build_prompt(messages: &[ChatMessage]) -> String {
515    let mut prompt = String::new();
516    for msg in messages {
517        let text = match msg.content.as_deref() {
518            Some(t) => t,
519            None => continue,
520        };
521        match msg.role.as_str() {
522            "system" => {
523                prompt.push_str("<|im_start|>system\n");
524                prompt.push_str(text);
525                prompt.push_str("<|im_end|>\n");
526            }
527            "user" => {
528                prompt.push_str("<|im_start|>user\n");
529                prompt.push_str(text);
530                prompt.push_str("<|im_end|>\n");
531            }
532            "assistant" => {
533                prompt.push_str("<|im_start|>assistant\n");
534                prompt.push_str(text);
535                prompt.push_str("<|im_end|>\n");
536            }
537            _ => {
538                prompt.push_str(text);
539                prompt.push('\n');
540            }
541        }
542    }
543    // Signal model to respond as assistant
544    prompt.push_str("<|im_start|>assistant\n");
545    prompt
546}
547
/// Generate a short, process-unique hexadecimal ID for completion responses.
///
/// The ID is the current Unix time in nanoseconds (hex) followed by an
/// eight-digit hex sequence number. The sequence number guarantees that two
/// IDs produced by this process differ even when the clock returns the same
/// value twice (concurrent requests, coarse timer resolution) or when the
/// clock is before the epoch and the timestamp falls back to zero.
fn rand_id() -> String {
    use std::sync::atomic::{AtomicU32, Ordering};

    // Only uniqueness of the returned values matters, so `Relaxed` suffices.
    static SEQUENCE: AtomicU32 = AtomicU32::new(0);

    let ts = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos();
    let seq = SEQUENCE.fetch_add(1, Ordering::Relaxed);

    // Fixed-width sequence keeps the concatenation unambiguous.
    format!("{ts:x}{seq:08x}")
}
556
557// ─── Graceful shutdown ─────────────────────────────────────────────────
558
559/// Start server with graceful shutdown support.
560///
561/// Binds to `addr`, serves `router`, and shuts down cleanly when
562/// `shutdown_signal` completes. In-flight requests are given time
563/// to finish before the server exits.
564pub async fn serve_with_shutdown(
565    router: Router,
566    addr: std::net::SocketAddr,
567    shutdown_signal: impl std::future::Future<Output = ()> + Send + 'static,
568) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
569    let listener = tokio::net::TcpListener::bind(addr).await?;
570    tracing::info!(%addr, "server listening");
571
572    axum::serve(listener, router)
573        .with_graceful_shutdown(shutdown_signal)
574        .await?;
575
576    tracing::info!("server shut down gracefully");
577    Ok(())
578}
579
580/// Create a shutdown signal that responds to SIGTERM and SIGINT (Ctrl+C).
581///
582/// Completes when either signal is received, allowing the server to
583/// begin its graceful shutdown procedure.
584pub async fn shutdown_signal() {
585    let ctrl_c = async {
586        tokio::signal::ctrl_c()
587            .await
588            .expect("failed to install Ctrl+C handler");
589    };
590
591    #[cfg(unix)]
592    let terminate = async {
593        tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
594            .expect("failed to install SIGTERM handler")
595            .recv()
596            .await;
597    };
598
599    #[cfg(not(unix))]
600    let terminate = std::future::pending::<()>();
601
602    tokio::select! {
603        () = ctrl_c => {
604            tracing::info!("received Ctrl+C, initiating shutdown");
605        }
606        () = terminate => {
607            tracing::info!("received SIGTERM, initiating shutdown");
608        }
609    }
610}
611
612/// Create the full server setup: router + graceful shutdown future.
613///
614/// Returns a future that runs the server until a shutdown signal is received.
615pub async fn create_server(
616    engine: InferenceEngine<'static>,
617    tokenizer: Option<TokenizerBridge>,
618    addr: std::net::SocketAddr,
619) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
620    let metrics = Arc::new(InferenceMetrics::new());
621    let router = create_router_with_metrics(engine, tokenizer, metrics);
622    serve_with_shutdown(router, addr, shutdown_signal()).await
623}
624
625// ─── Request queue depth tracking ──────────────────────────────────────
626
/// Server configuration with request-management limits.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Maximum number of queued requests before rejecting new ones.
    pub max_queue_depth: usize,
    /// Request timeout in seconds.
    pub request_timeout_seconds: u64,
    /// Address to bind to.
    pub bind_addr: std::net::SocketAddr,
}

impl Default for ServerConfig {
    /// Defaults: queue depth 128, 60-second timeout, bind to `127.0.0.1:8080`.
    fn default() -> Self {
        use std::net::{IpAddr, Ipv4Addr, SocketAddr};

        let loopback = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
        Self {
            bind_addr: SocketAddr::new(loopback, 8080),
            request_timeout_seconds: 60,
            max_queue_depth: 128,
        }
    }
}
647
/// Request queue depth tracker.
///
/// Thread-safe counter for tracking how many requests are currently
/// queued or in-flight. Used to implement backpressure.
///
/// The counter never exceeds `max_depth` and never drops below zero.
#[derive(Debug)]
pub struct QueueDepthTracker {
    /// Number of slots currently held.
    current: std::sync::atomic::AtomicUsize,
    /// Upper bound on `current`; always at least 1.
    max_depth: usize,
}

impl QueueDepthTracker {
    /// Create a new tracker with the given maximum depth.
    ///
    /// A `max_depth` of 0 is raised to 1 so that at least one request can
    /// always be admitted.
    pub fn new(max_depth: usize) -> Self {
        Self {
            current: std::sync::atomic::AtomicUsize::new(0),
            max_depth: max_depth.max(1),
        }
    }

    /// Try to acquire a slot. Returns `true` if successful, `false` if queue is full.
    ///
    /// The compare-and-swap is retried until it either succeeds or the queue
    /// is observed to be full, so a concurrent acquire/release on another
    /// thread can never cause a spurious `false` while capacity remains.
    pub fn try_acquire(&self) -> bool {
        use std::sync::atomic::Ordering;

        let max_depth = self.max_depth;
        self.current
            .fetch_update(Ordering::AcqRel, Ordering::Relaxed, |depth| {
                if depth < max_depth {
                    Some(depth + 1)
                } else {
                    None
                }
            })
            .is_ok()
    }

    /// Release a slot.
    ///
    /// Releasing when no slot is held is a no-op: the counter saturates at
    /// zero instead of wrapping around to `usize::MAX`, which would make the
    /// tracker report "full" forever.
    pub fn release(&self) {
        use std::sync::atomic::Ordering;

        // `Err` only means the depth was already zero; nothing to undo.
        let _ = self
            .current
            .fetch_update(Ordering::Release, Ordering::Relaxed, |depth| {
                depth.checked_sub(1)
            });
    }

    /// Current queue depth.
    pub fn depth(&self) -> usize {
        self.current.load(std::sync::atomic::Ordering::Relaxed)
    }

    /// Maximum allowed depth.
    pub fn max_depth(&self) -> usize {
        self.max_depth
    }

    /// Whether the queue has capacity for more requests.
    ///
    /// This is a snapshot; use [`QueueDepthTracker::try_acquire`] to actually
    /// reserve a slot.
    pub fn has_capacity(&self) -> bool {
        self.depth() < self.max_depth
    }
}
704
#[cfg(test)]
mod tests {
    use super::*;

    /// Engine built from the stock Bonsai-8B config and default sampling
    /// parameters, as used by the router construction tests.
    fn test_engine() -> InferenceEngine<'static> {
        let config = oxibonsai_core::config::Qwen3Config::bonsai_8b();
        let params = crate::sampling::SamplingParams::default();
        InferenceEngine::new(config, params, 42)
    }

    // ── build_prompt tests ──

    #[test]
    fn build_prompt_simple() {
        let prompt = build_prompt(&[ChatMessage::text("user", "Hello")]);
        assert!(prompt.contains("<|im_start|>user\nHello<|im_end|>"));
        assert!(prompt.ends_with("<|im_start|>assistant\n"));
    }

    #[test]
    fn build_prompt_system_and_user() {
        let conversation = [
            ChatMessage::text("system", "You are a helpful assistant."),
            ChatMessage::text("user", "Hi"),
        ];
        let prompt = build_prompt(&conversation);
        assert!(prompt.contains("<|im_start|>system\nYou are a helpful assistant.<|im_end|>"));
        assert!(prompt.contains("<|im_start|>user\nHi<|im_end|>"));
    }

    #[test]
    fn build_prompt_multi_turn() {
        let conversation = [
            ChatMessage::text("user", "What is 2+2?"),
            ChatMessage::text("assistant", "4"),
            ChatMessage::text("user", "And 3+3?"),
        ];
        let prompt = build_prompt(&conversation);
        assert!(prompt.contains("<|im_start|>assistant\n4<|im_end|>"));
        assert!(prompt.contains("And 3+3?"));
    }

    // ── Small helper tests ──

    #[test]
    fn rand_id_is_nonempty() {
        assert!(!rand_id().is_empty());
    }

    #[test]
    fn default_max_tokens_value() {
        assert_eq!(default_max_tokens(), 256);
    }

    #[test]
    fn default_temperature_value() {
        let delta = (default_temperature() - 0.7).abs();
        assert!(delta < f32::EPSILON);
    }

    // ── Router construction tests ──

    #[test]
    fn create_router_builds_without_tokenizer() {
        let _router = create_router(test_engine(), None);
    }

    #[test]
    fn create_router_with_shared_metrics() {
        let metrics = Arc::new(InferenceMetrics::new());
        let _router = create_router_with_metrics(test_engine(), None, Arc::clone(&metrics));
        // Metrics should be accessible from outside
        assert_eq!(metrics.requests_total.get(), 0);
    }

    // ── ServerConfig tests ──

    #[test]
    fn server_config_default() {
        let config = ServerConfig::default();
        let expected_addr = std::net::SocketAddr::from(([127, 0, 0, 1], 8080));
        assert_eq!(config.max_queue_depth, 128);
        assert_eq!(config.request_timeout_seconds, 60);
        assert_eq!(config.bind_addr, expected_addr);
    }

    // ── QueueDepthTracker tests ──

    #[test]
    fn queue_depth_tracker_basic() {
        let tracker = QueueDepthTracker::new(3);
        assert_eq!(tracker.depth(), 0);
        assert_eq!(tracker.max_depth(), 3);
        assert!(tracker.has_capacity());

        // Fill the tracker one slot at a time.
        for expected_depth in 1..=3 {
            assert!(tracker.try_acquire());
            assert_eq!(tracker.depth(), expected_depth);
        }
        assert!(!tracker.has_capacity());

        // Should fail when full
        assert!(!tracker.try_acquire());

        tracker.release();
        assert_eq!(tracker.depth(), 2);
        assert!(tracker.has_capacity());
        assert!(tracker.try_acquire());
    }

    #[test]
    fn queue_depth_tracker_min_capacity() {
        let tracker = QueueDepthTracker::new(0);
        assert_eq!(tracker.max_depth(), 1);
        assert!(tracker.try_acquire());
        assert!(!tracker.try_acquire());
    }
}