1use axum::extract::State;
16use axum::http::{HeaderMap, HeaderValue, StatusCode};
17use axum::response::{
18 sse::{Event, Sse},
19 IntoResponse, Json, Response,
20};
21use axum::Router;
22use serde::{Deserialize, Serialize};
23use std::convert::Infallible;
24use std::sync::Arc;
25use tokio::sync::Mutex;
26use tokio_stream::StreamExt;
27
28use crate::engine::InferenceEngine;
29use crate::metrics::InferenceMetrics;
30use crate::request_id::RequestId;
31use crate::tokenizer_bridge::TokenizerBridge;
32
/// Name of the HTTP header that carries a request's correlation ID.
///
/// [`resolve_request_id`] reads it from incoming requests and
/// [`request_id_header_map`] writes it onto responses.
pub const REQUEST_ID_HEADER: &str = "x-request-id";
37
38pub fn resolve_request_id(headers: &HeaderMap) -> RequestId {
46 if let Some(v) = headers.get(REQUEST_ID_HEADER) {
47 if let Ok(s) = v.to_str() {
48 if let Some(id) = RequestId::from_uuid(s).or_else(|| RequestId::from_hex(s)) {
49 return id;
50 }
51 }
52 }
53 RequestId::new()
54}
55
56pub fn request_id_header_map(id: RequestId) -> HeaderMap {
59 let mut headers = HeaderMap::new();
60 if let Ok(value) = HeaderValue::from_str(&id.as_uuid()) {
61 headers.insert(REQUEST_ID_HEADER, value);
62 }
63 headers
64}
65
66pub struct AppState {
68 engine: Mutex<InferenceEngine<'static>>,
69 tokenizer: Option<TokenizerBridge>,
70 metrics: Arc<InferenceMetrics>,
71}
72
73impl AppState {
74 pub async fn engine_lock(&self) -> tokio::sync::MutexGuard<'_, InferenceEngine<'static>> {
76 self.engine.lock().await
77 }
78
79 pub fn tokenizer(&self) -> Option<&TokenizerBridge> {
81 self.tokenizer.as_ref()
82 }
83
84 pub fn metrics(&self) -> &Arc<InferenceMetrics> {
86 &self.metrics
87 }
88}
89
90#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct ChatMessage {
96 pub role: String,
98 #[serde(default, skip_serializing_if = "Option::is_none")]
100 pub content: Option<String>,
101 #[serde(default, skip_serializing_if = "Option::is_none")]
103 pub tool_calls: Option<Vec<crate::api_types::ToolCallResult>>,
104 #[serde(default, skip_serializing_if = "Option::is_none")]
106 pub tool_call_id: Option<String>,
107}
108
109impl ChatMessage {
110 pub fn text(role: impl Into<String>, content: impl Into<String>) -> Self {
112 Self {
113 role: role.into(),
114 content: Some(content.into()),
115 tool_calls: None,
116 tool_call_id: None,
117 }
118 }
119}
120
121#[derive(Debug, Deserialize)]
123pub struct ChatCompletionRequest {
124 pub messages: Vec<ChatMessage>,
126 #[serde(default = "default_max_tokens")]
128 pub max_tokens: usize,
129 #[serde(default = "default_temperature")]
131 pub temperature: f32,
132 #[serde(default)]
134 pub stream: bool,
135 #[serde(default, skip_serializing_if = "Option::is_none")]
137 pub tools: Option<Vec<crate::api_types::ToolDefinition>>,
138 #[serde(default, skip_serializing_if = "Option::is_none")]
140 pub tool_choice: Option<serde_json::Value>,
141}
142
/// Serde default for [`ChatCompletionRequest::max_tokens`].
fn default_max_tokens() -> usize {
    const DEFAULT_MAX_TOKENS: usize = 256;
    DEFAULT_MAX_TOKENS
}
/// Serde default for [`ChatCompletionRequest::temperature`].
fn default_temperature() -> f32 {
    const DEFAULT_TEMPERATURE: f32 = 0.7;
    DEFAULT_TEMPERATURE
}
149
150#[derive(Debug, Serialize)]
152pub struct ChatCompletionResponse {
153 pub id: String,
154 pub object: String,
155 pub choices: Vec<ChatChoice>,
156 pub usage: Usage,
157}
158
159#[derive(Debug, Serialize)]
161pub struct Usage {
162 pub prompt_tokens: usize,
163 pub completion_tokens: usize,
164 pub total_tokens: usize,
165}
166
167#[derive(Debug, Serialize)]
169pub struct ChatChoice {
170 pub index: usize,
171 pub message: ChatMessage,
172 pub finish_reason: String,
173}
174
175#[derive(Serialize)]
177struct ChatCompletionChunk {
178 id: String,
179 object: String,
180 created: u64,
181 model: String,
182 choices: Vec<ChunkChoice>,
183}
184
185#[derive(Serialize)]
187struct ChunkChoice {
188 index: usize,
189 delta: ChunkDelta,
190 finish_reason: Option<String>,
191}
192
193#[derive(Serialize)]
195struct ChunkDelta {
196 #[serde(skip_serializing_if = "Option::is_none")]
197 role: Option<String>,
198 #[serde(skip_serializing_if = "Option::is_none")]
199 content: Option<String>,
200}
201
202pub fn create_router(
204 engine: InferenceEngine<'static>,
205 tokenizer: Option<TokenizerBridge>,
206) -> Router {
207 create_router_with_metrics(engine, tokenizer, Arc::new(InferenceMetrics::new()))
208}
209
210pub fn create_router_with_metrics(
212 engine: InferenceEngine<'static>,
213 tokenizer: Option<TokenizerBridge>,
214 metrics: Arc<InferenceMetrics>,
215) -> Router {
216 let state = Arc::new(AppState {
217 engine: Mutex::new(engine),
218 tokenizer,
219 metrics,
220 });
221
222 let embeddings_router = crate::embeddings::create_embeddings_router(512);
225
226 Router::new()
227 .route(
228 "/v1/chat/completions",
229 axum::routing::post(chat_completions),
230 )
231 .route(
232 "/v1/chat/completions/extended",
233 axum::routing::post(crate::api_extensions::extended_chat_completions),
234 )
235 .route(
236 "/v1/completions",
237 axum::routing::post(crate::completions::create_completion),
238 )
239 .route("/v1/models", axum::routing::get(list_models))
240 .route("/health", axum::routing::get(health))
241 .route("/metrics", axum::routing::get(prometheus_metrics))
242 .with_state(state)
243 .merge(embeddings_router)
244}
245
/// Liveness probe for `GET /health`; always answers with the body `ok`.
async fn health() -> &'static str {
    const HEALTHY: &str = "ok";
    HEALTHY
}
249
250async fn prometheus_metrics(State(state): State<Arc<AppState>>) -> impl IntoResponse {
252 let body = state.metrics.render_prometheus();
253 (
254 StatusCode::OK,
255 [("content-type", "text/plain; version=0.0.4; charset=utf-8")],
256 body,
257 )
258}
259
260async fn list_models() -> Json<serde_json::Value> {
261 Json(serde_json::json!({
262 "object": "list",
263 "data": [{
264 "id": "bonsai-8b",
265 "object": "model",
266 "owned_by": "oxibonsai"
267 }]
268 }))
269}
270
271#[tracing::instrument(skip(state, headers, body), fields(request_id))]
272async fn chat_completions(
273 State(state): State<Arc<AppState>>,
274 headers: HeaderMap,
275 Json(body): Json<ChatCompletionRequest>,
276) -> Result<Response, StatusCode> {
277 let request_id = resolve_request_id(&headers);
278 tracing::Span::current().record("request_id", tracing::field::display(&request_id));
279
280 let request_start = std::time::Instant::now();
281 state.metrics.requests_total.inc();
282 state.metrics.active_requests.inc();
283
284 let prompt_text = build_prompt(&body.messages);
286
287 let prompt_tokens = if let Some(tok) = &state.tokenizer {
289 tok.encode(&prompt_text).map_err(|_| {
290 state.metrics.errors_total.inc();
291 state.metrics.active_requests.dec();
292 StatusCode::INTERNAL_SERVER_ERROR
293 })?
294 } else {
295 vec![151644]
297 };
298
299 state
300 .metrics
301 .prompt_tokens_total
302 .inc_by(prompt_tokens.len() as u64);
303
304 let result = if body.stream {
305 chat_completions_stream(
307 Arc::clone(&state),
308 prompt_tokens,
309 body.max_tokens,
310 request_id,
311 )
312 .await
313 } else {
314 chat_completions_non_stream(
316 Arc::clone(&state),
317 prompt_tokens,
318 body.max_tokens,
319 request_id,
320 )
321 .await
322 };
323
324 let elapsed = request_start.elapsed().as_secs_f64();
325 state.metrics.request_duration_seconds.observe(elapsed);
326 state.metrics.active_requests.dec();
327
328 if result.is_err() {
329 state.metrics.errors_total.inc();
330 }
331
332 result
333}
334
335async fn chat_completions_non_stream(
337 state: Arc<AppState>,
338 prompt_tokens: Vec<u32>,
339 max_tokens: usize,
340 request_id: RequestId,
341) -> Result<Response, StatusCode> {
342 let prompt_len = prompt_tokens.len();
343
344 let mut engine = state.engine.lock().await;
345 let output_tokens = engine.generate(&prompt_tokens, max_tokens).map_err(|e| {
346 tracing::error!(error = %e, "generation failed");
347 StatusCode::INTERNAL_SERVER_ERROR
348 })?;
349
350 let completion_len = output_tokens.len();
351
352 state
354 .metrics
355 .tokens_generated_total
356 .inc_by(completion_len as u64);
357
358 let content = if let Some(tok) = &state.tokenizer {
360 tok.decode(&output_tokens)
361 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
362 } else {
363 format!("{output_tokens:?}")
364 };
365
366 let response = ChatCompletionResponse {
367 id: format!("chatcmpl-{}", rand_id()),
368 object: "chat.completion".to_string(),
369 choices: vec![ChatChoice {
370 index: 0,
371 message: ChatMessage {
372 role: "assistant".to_string(),
373 content: Some(content),
374 tool_calls: None,
375 tool_call_id: None,
376 },
377 finish_reason: "stop".to_string(),
378 }],
379 usage: Usage {
380 prompt_tokens: prompt_len,
381 completion_tokens: completion_len,
382 total_tokens: prompt_len + completion_len,
383 },
384 };
385
386 let headers = request_id_header_map(request_id);
387 Ok((headers, Json(response)).into_response())
388}
389
390async fn chat_completions_stream(
392 state: Arc<AppState>,
393 prompt_tokens: Vec<u32>,
394 max_tokens: usize,
395 request_id: RequestId,
396) -> Result<Response, StatusCode> {
397 let completion_id = format!("chatcmpl-{}", rand_id());
398 let created = std::time::SystemTime::now()
399 .duration_since(std::time::UNIX_EPOCH)
400 .unwrap_or_default()
401 .as_secs();
402
403 let (token_tx, token_rx) = tokio::sync::mpsc::unbounded_channel::<u32>();
404
405 let gen_state = Arc::clone(&state);
407 tokio::task::spawn_blocking(move || {
408 let rt = tokio::runtime::Handle::current();
409 let mut engine = rt.block_on(gen_state.engine.lock());
410 let _result = engine.generate_streaming(&prompt_tokens, max_tokens, &token_tx);
411 });
413
414 let id_for_stream = completion_id;
416 let state_for_stream = Arc::clone(&state);
417
418 let role_chunk = ChatCompletionChunk {
420 id: id_for_stream.clone(),
421 object: "chat.completion.chunk".to_string(),
422 created,
423 model: "bonsai-8b".to_string(),
424 choices: vec![ChunkChoice {
425 index: 0,
426 delta: ChunkDelta {
427 role: Some("assistant".to_string()),
428 content: None,
429 },
430 finish_reason: None,
431 }],
432 };
433
434 let role_event = match serde_json::to_string(&role_chunk) {
435 Ok(json) => json,
436 Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
437 };
438
439 let id_clone = id_for_stream.clone();
440
441 let token_stream = tokio_stream::wrappers::UnboundedReceiverStream::new(token_rx);
443
444 let mut stream_state = state_for_stream
449 .tokenizer
450 .as_ref()
451 .map(|t| t.new_decode_stream(true));
452
453 let content_stream = token_stream.filter_map(move |token_id| {
454 let text = match (&state_for_stream.tokenizer, stream_state.as_mut()) {
455 (Some(tok), Some(state)) => match tok.step_decode(state, token_id) {
456 Ok(Some(txt)) => txt,
457 Ok(None) => return None,
458 Err(_) => format!("[{token_id}]"),
459 },
460 _ => format!("[{token_id}]"),
461 };
462
463 let chunk = ChatCompletionChunk {
464 id: id_clone.clone(),
465 object: "chat.completion.chunk".to_string(),
466 created,
467 model: "bonsai-8b".to_string(),
468 choices: vec![ChunkChoice {
469 index: 0,
470 delta: ChunkDelta {
471 role: None,
472 content: Some(text),
473 },
474 finish_reason: None,
475 }],
476 };
477
478 Some(serde_json::to_string(&chunk).unwrap_or_default())
479 });
480
481 let finish_chunk = ChatCompletionChunk {
483 id: id_for_stream,
484 object: "chat.completion.chunk".to_string(),
485 created,
486 model: "bonsai-8b".to_string(),
487 choices: vec![ChunkChoice {
488 index: 0,
489 delta: ChunkDelta {
490 role: None,
491 content: None,
492 },
493 finish_reason: Some("stop".to_string()),
494 }],
495 };
496 let finish_json = serde_json::to_string(&finish_chunk).unwrap_or_default();
497
498 let role_stream = tokio_stream::once(role_event);
500
501 let full_stream = role_stream
502 .chain(content_stream)
503 .chain(tokio_stream::once(finish_json))
504 .map(|json_str| -> Result<Event, Infallible> { Ok(Event::default().data(json_str)) })
505 .chain(tokio_stream::once(Ok(Event::default().data("[DONE]"))));
506
507 let headers = request_id_header_map(request_id);
508 Ok((headers, Sse::new(full_stream)).into_response())
509}
510
511fn build_prompt(messages: &[ChatMessage]) -> String {
515 let mut prompt = String::new();
516 for msg in messages {
517 let text = match msg.content.as_deref() {
518 Some(t) => t,
519 None => continue,
520 };
521 match msg.role.as_str() {
522 "system" => {
523 prompt.push_str("<|im_start|>system\n");
524 prompt.push_str(text);
525 prompt.push_str("<|im_end|>\n");
526 }
527 "user" => {
528 prompt.push_str("<|im_start|>user\n");
529 prompt.push_str(text);
530 prompt.push_str("<|im_end|>\n");
531 }
532 "assistant" => {
533 prompt.push_str("<|im_start|>assistant\n");
534 prompt.push_str(text);
535 prompt.push_str("<|im_end|>\n");
536 }
537 _ => {
538 prompt.push_str(text);
539 prompt.push('\n');
540 }
541 }
542 }
543 prompt.push_str("<|im_start|>assistant\n");
545 prompt
546}
547
/// Produces a lowercase hexadecimal identifier that is unique within this
/// process.
///
/// The ID is the current Unix time in nanoseconds followed by an 8-digit
/// zero-padded sequence number from a process-wide counter. The counter keeps
/// IDs distinct when two calls observe the same clock reading (concurrent
/// requests, or a coarse system clock). The fixed-width suffix keeps the
/// concatenation unambiguous. IDs could only repeat if the 32-bit counter
/// wrapped around within a single clock reading.
///
/// If the system clock is set before the Unix epoch the timestamp part falls
/// back to `0`; the sequence number still keeps IDs distinct.
fn rand_id() -> String {
    use std::sync::atomic::{AtomicU32, Ordering};

    static SEQUENCE: AtomicU32 = AtomicU32::new(0);

    let ts = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos();
    // Relaxed is enough: only the uniqueness of the fetched value matters.
    let seq = SEQUENCE.fetch_add(1, Ordering::Relaxed);

    format!("{ts:x}{seq:08x}")
}
556
557pub async fn serve_with_shutdown(
565 router: Router,
566 addr: std::net::SocketAddr,
567 shutdown_signal: impl std::future::Future<Output = ()> + Send + 'static,
568) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
569 let listener = tokio::net::TcpListener::bind(addr).await?;
570 tracing::info!(%addr, "server listening");
571
572 axum::serve(listener, router)
573 .with_graceful_shutdown(shutdown_signal)
574 .await?;
575
576 tracing::info!("server shut down gracefully");
577 Ok(())
578}
579
580pub async fn shutdown_signal() {
585 let ctrl_c = async {
586 tokio::signal::ctrl_c()
587 .await
588 .expect("failed to install Ctrl+C handler");
589 };
590
591 #[cfg(unix)]
592 let terminate = async {
593 tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
594 .expect("failed to install SIGTERM handler")
595 .recv()
596 .await;
597 };
598
599 #[cfg(not(unix))]
600 let terminate = std::future::pending::<()>();
601
602 tokio::select! {
603 () = ctrl_c => {
604 tracing::info!("received Ctrl+C, initiating shutdown");
605 }
606 () = terminate => {
607 tracing::info!("received SIGTERM, initiating shutdown");
608 }
609 }
610}
611
612pub async fn create_server(
616 engine: InferenceEngine<'static>,
617 tokenizer: Option<TokenizerBridge>,
618 addr: std::net::SocketAddr,
619) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
620 let metrics = Arc::new(InferenceMetrics::new());
621 let router = create_router_with_metrics(engine, tokenizer, metrics);
622 serve_with_shutdown(router, addr, shutdown_signal()).await
623}
624
/// Tunable server settings.
///
/// Nothing else in this file consumes these values; they are plain data for
/// callers.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Upper bound on queued requests (default `128`).
    pub max_queue_depth: usize,
    /// Per-request timeout in seconds (default `60`).
    pub request_timeout_seconds: u64,
    /// Socket address to listen on (default `127.0.0.1:8080`).
    pub bind_addr: std::net::SocketAddr,
}

impl Default for ServerConfig {
    /// Loopback on port 8080, a queue depth of 128 and a 60-second timeout.
    fn default() -> Self {
        let loopback = std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST);
        Self {
            max_queue_depth: 128,
            request_timeout_seconds: 60,
            bind_addr: std::net::SocketAddr::new(loopback, 8080),
        }
    }
}
647
/// Lock-free counter that caps the number of slots held at once.
///
/// Callers take a slot with [`try_acquire`](Self::try_acquire) and give it
/// back with [`release`](Self::release).
pub struct QueueDepthTracker {
    /// Number of slots currently held.
    current: std::sync::atomic::AtomicUsize,
    /// Maximum number of slots that may be held at once; always at least 1.
    max_depth: usize,
}

impl QueueDepthTracker {
    /// Creates a tracker allowing up to `max_depth` held slots.
    ///
    /// A `max_depth` of `0` is raised to `1` so the tracker can always admit
    /// at least one holder.
    pub fn new(max_depth: usize) -> Self {
        Self {
            current: std::sync::atomic::AtomicUsize::new(0),
            max_depth: max_depth.max(1),
        }
    }

    /// Tries to take one slot; returns `true` on success.
    ///
    /// Returns `false` only when the tracker is actually full. Losing a race
    /// against another thread retries instead of being reported as "full".
    pub fn try_acquire(&self) -> bool {
        use std::sync::atomic::Ordering;

        self.current
            .fetch_update(Ordering::AcqRel, Ordering::Relaxed, |depth| {
                (depth < self.max_depth).then_some(depth + 1)
            })
            .is_ok()
    }

    /// Gives back one slot.
    ///
    /// Calling this with no slot held is a no-op: the counter saturates at
    /// zero rather than wrapping around, which would otherwise make the
    /// tracker look permanently full.
    pub fn release(&self) {
        use std::sync::atomic::Ordering;

        // `checked_sub` yields `None` at zero, which leaves the counter as is.
        let _ = self
            .current
            .fetch_update(Ordering::Release, Ordering::Relaxed, |depth| {
                depth.checked_sub(1)
            });
    }

    /// Number of slots currently held (a snapshot; may be stale under
    /// concurrency).
    pub fn depth(&self) -> usize {
        self.current.load(std::sync::atomic::Ordering::Relaxed)
    }

    /// Configured capacity.
    pub fn max_depth(&self) -> usize {
        self.max_depth
    }

    /// Whether a slot appeared to be free at the time of the call.
    ///
    /// This is advisory only; use [`try_acquire`](Self::try_acquire) to
    /// actually reserve a slot.
    pub fn has_capacity(&self) -> bool {
        self.depth() < self.max_depth
    }
}
704
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an engine from the bonsai-8b config, default sampling parameters
    /// and the fixed third argument 42.
    fn test_engine() -> InferenceEngine<'static> {
        let config = oxibonsai_core::config::Qwen3Config::bonsai_8b();
        let params = crate::sampling::SamplingParams::default();
        InferenceEngine::new(config, params, 42)
    }

    // A single user message is framed and followed by the assistant header.
    #[test]
    fn build_prompt_simple() {
        let conversation = [ChatMessage::text("user", "Hello")];
        let prompt = build_prompt(&conversation);
        assert!(prompt.contains("<|im_start|>user\nHello<|im_end|>"));
        assert!(prompt.ends_with("<|im_start|>assistant\n"));
    }

    // System and user messages each get their own frame.
    #[test]
    fn build_prompt_system_and_user() {
        let conversation = [
            ChatMessage::text("system", "You are a helpful assistant."),
            ChatMessage::text("user", "Hi"),
        ];
        let prompt = build_prompt(&conversation);
        assert!(prompt.contains("<|im_start|>system\nYou are a helpful assistant.<|im_end|>"));
        assert!(prompt.contains("<|im_start|>user\nHi<|im_end|>"));
    }

    // Earlier assistant turns are replayed inside the prompt.
    #[test]
    fn build_prompt_multi_turn() {
        let conversation = [
            ChatMessage::text("user", "What is 2+2?"),
            ChatMessage::text("assistant", "4"),
            ChatMessage::text("user", "And 3+3?"),
        ];
        let prompt = build_prompt(&conversation);
        assert!(prompt.contains("<|im_start|>assistant\n4<|im_end|>"));
        assert!(prompt.contains("And 3+3?"));
    }

    #[test]
    fn rand_id_is_nonempty() {
        assert!(!rand_id().is_empty());
    }

    #[test]
    fn default_max_tokens_value() {
        assert_eq!(default_max_tokens(), 256);
    }

    #[test]
    fn default_temperature_value() {
        let delta = (default_temperature() - 0.7).abs();
        assert!(delta < f32::EPSILON);
    }

    // Router construction must work without a tokenizer.
    #[test]
    fn create_router_builds_without_tokenizer() {
        let _router = create_router(test_engine(), None);
    }

    // Building the router must not touch the request counter.
    #[test]
    fn create_router_with_shared_metrics() {
        let metrics = Arc::new(InferenceMetrics::new());
        let _router = create_router_with_metrics(test_engine(), None, Arc::clone(&metrics));
        assert_eq!(metrics.requests_total.get(), 0);
    }

    #[test]
    fn server_config_default() {
        let defaults = ServerConfig::default();
        let expected_addr = std::net::SocketAddr::from(([127, 0, 0, 1], 8080));
        assert_eq!(defaults.max_queue_depth, 128);
        assert_eq!(defaults.request_timeout_seconds, 60);
        assert_eq!(defaults.bind_addr, expected_addr);
    }

    // Fill to capacity, get rejected, release one slot, re-acquire.
    #[test]
    fn queue_depth_tracker_basic() {
        let tracker = QueueDepthTracker::new(3);
        assert_eq!(tracker.depth(), 0);
        assert_eq!(tracker.max_depth(), 3);
        assert!(tracker.has_capacity());

        for expected_depth in 1..=3 {
            assert!(tracker.try_acquire());
            assert_eq!(tracker.depth(), expected_depth);
        }
        assert!(!tracker.has_capacity());

        assert!(!tracker.try_acquire());

        tracker.release();
        assert_eq!(tracker.depth(), 2);
        assert!(tracker.has_capacity());
        assert!(tracker.try_acquire());
    }

    // A requested capacity of zero is raised to one.
    #[test]
    fn queue_depth_tracker_min_capacity() {
        let tracker = QueueDepthTracker::new(0);
        assert_eq!(tracker.max_depth(), 1);
        assert!(tracker.try_acquire());
        assert!(!tracker.try_acquire());
    }
}