use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Instant;

use axum::extract::State;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response, Sse};
use axum::routing::{get, post};
use axum::{Json, Router};
use futures::stream::StreamExt;
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use tower_http::cors::CorsLayer;
use tower_http::trace::TraceLayer;

use abaddon::{Engine, EngineConfig, InferenceEngine};
use infernum_core::{GenerateRequest, Result, SamplingParams};

use crate::openai::{
    ChatChoice, ChatCompletionRequest, ChatCompletionResponse, ChatMessage, CompletionChoice,
    CompletionRequest, CompletionResponse, EmbeddingData, EmbeddingInput, EmbeddingRequest,
    EmbeddingResponse, EmbeddingUsage, ModelObject, ModelsResponse, Usage,
};

/// Configuration for the Infernum HTTP server.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Socket address the server binds to.
    pub addr: SocketAddr,
    /// Whether to enable permissive CORS headers.
    pub cors: bool,
    /// Model to load at startup, if any.
    pub model: Option<String>,
    /// Maximum number of concurrent inference requests.
    pub max_concurrent_requests: usize,
}

impl Default for ServerConfig {
    fn default() -> Self {
        Self {
            addr: "0.0.0.0:8080".parse().unwrap(),
            cors: true,
            model: None,
            max_concurrent_requests: 64,
        }
    }
}

impl ServerConfig {
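    /// Creates a builder for assembling a `ServerConfig`.
    ///
    /// A usage sketch (the values shown are illustrative; marked `ignore` so
    /// it is not compiled as a doctest):
    ///
    /// ```ignore
    /// let config = ServerConfig::builder()
    ///     .addr("127.0.0.1:3000".parse().unwrap())
    ///     .cors(false)
    ///     .max_concurrent_requests(32)
    ///     .build();
    /// ```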
    pub fn builder() -> ServerConfigBuilder {
        ServerConfigBuilder::default()
    }
}

/// Builder for [`ServerConfig`]; unset fields fall back to the defaults.
#[derive(Debug, Default)]
pub struct ServerConfigBuilder {
    addr: Option<SocketAddr>,
    cors: Option<bool>,
    model: Option<String>,
    max_concurrent_requests: Option<usize>,
}

impl ServerConfigBuilder {
    /// Sets the socket address to bind to.
    pub fn addr(mut self, addr: SocketAddr) -> Self {
        self.addr = Some(addr);
        self
    }

    /// Enables or disables permissive CORS.
    pub fn cors(mut self, enabled: bool) -> Self {
        self.cors = Some(enabled);
        self
    }

    /// Sets the model to load at startup.
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.model = Some(model.into());
        self
    }

    /// Sets the maximum number of concurrent inference requests.
    pub fn max_concurrent_requests(mut self, max: usize) -> Self {
        self.max_concurrent_requests = Some(max);
        self
    }

    /// Builds the final [`ServerConfig`], applying defaults for unset fields.
    pub fn build(self) -> ServerConfig {
        ServerConfig {
            addr: self.addr.unwrap_or_else(|| "0.0.0.0:8080".parse().unwrap()),
            cors: self.cors.unwrap_or(true),
            model: self.model,
            max_concurrent_requests: self.max_concurrent_requests.unwrap_or(64),
        }
    }
}

/// Shared application state handed to every request handler.
pub struct AppState {
    /// The inference engine; `None` until a model is loaded. Wrapped in an
    /// `RwLock` so a model can be hot-swapped at runtime.
    pub engine: RwLock<Option<Arc<Engine>>>,
    /// The server configuration.
    pub config: ServerConfig,
    /// When the server started; used for uptime reporting.
    pub start_time: Instant,
}

impl AppState {
    /// Creates state with no engine loaded.
    pub fn new(config: ServerConfig) -> Self {
        Self {
            engine: RwLock::new(None),
            config,
            start_time: Instant::now(),
        }
    }

    /// Creates state with an engine already loaded.
    pub fn with_engine(config: ServerConfig, engine: Engine) -> Self {
        Self {
            engine: RwLock::new(Some(Arc::new(engine))),
            config,
            start_time: Instant::now(),
        }
    }
}

/// The Infernum HTTP server.
pub struct Server {
    config: ServerConfig,
    state: Arc<AppState>,
}

impl Server {
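    /// Creates a server with no engine loaded. A model can be loaded later
    /// via [`Server::load_model`] or a `POST` to `/api/models/load`.
    ///
    /// A launch sketch (the model id is illustrative; marked `ignore` so it
    /// is not compiled as a doctest):
    ///
    /// ```ignore
    /// let config = ServerConfig::builder().model("my-model").build();
    /// Server::new(config).run().await?;
    /// ```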
    pub fn new(config: ServerConfig) -> Self {
        let state = Arc::new(AppState::new(config.clone()));
        Self { config, state }
    }

    /// Creates a server with an engine already loaded.
    pub fn with_engine(config: ServerConfig, engine: Engine) -> Self {
        let state = Arc::new(AppState::with_engine(config.clone(), engine));
        Self { config, state }
    }

    /// Builds the axum router with all routes and middleware.
    fn router(&self) -> Router {
        let mut router = Router::new()
            // Health and readiness probes.
            .route("/health", get(health))
            .route("/ready", get(ready))
            // OpenAI-compatible endpoints.
            .route("/v1/models", get(list_models))
            .route("/v1/chat/completions", post(chat_completions))
            .route("/v1/completions", post(completions))
            // Management endpoints.
            .route("/api/models/load", post(load_model))
            .route("/api/models/unload", post(unload_model))
            .route("/api/status", get(server_status))
            .with_state(self.state.clone());

        router = router.layer(TraceLayer::new_for_http());

        // Added after the trace layer so CORS handling runs outermost.
        if self.config.cors {
            router = router.layer(CorsLayer::permissive());
        }

        router
    }

    /// Loads a model and installs it as the active engine, replacing any
    /// previously loaded model.
    pub async fn load_model(&self, model_source: &str) -> Result<()> {
        tracing::info!(model = %model_source, "Loading model");

        let engine_config = EngineConfig::builder()
            .model(model_source)
            .build()
            .map_err(|e| infernum_core::Error::Internal { message: e })?;

        let engine = Engine::new(engine_config).await?;
        let mut engine_guard = self.state.engine.write().await;
        *engine_guard = Some(Arc::new(engine));

        tracing::info!(model = %model_source, "Model loaded successfully");
        Ok(())
    }

    /// Runs the server until Ctrl+C (or SIGTERM on Unix) is received.
    ///
    /// If the configuration names a model, it is loaded before the listener
    /// starts accepting connections.
    pub async fn run(self) -> Result<()> {
        if let Some(model) = &self.config.model {
            self.load_model(model).await?;
            tracing::info!(model = %model, "Model loaded and ready for inference");
        } else {
            tracing::warn!("=======================================================");
            tracing::warn!(" SERVER STARTED WITHOUT A MODEL");
            tracing::warn!(" All inference requests will fail until a model is loaded.");
            tracing::warn!(" ");
            tracing::warn!(" To load a model, either:");
            tracing::warn!(" 1. Restart with: infernum serve --model <model>");
            tracing::warn!(" 2. POST to /api/models/load with {{\"model\": \"<model>\"}}");
            tracing::warn!("=======================================================");
        }

        let router = self.router();

        tracing::info!(addr = %self.config.addr, "Starting Infernum server");
        eprintln!(
            "\n\x1b[32m✓\x1b[0m Server listening on http://{}",
            self.config.addr
        );
        eprintln!(" Press Ctrl+C to stop\n");

        let listener = tokio::net::TcpListener::bind(self.config.addr)
            .await
            .map_err(infernum_core::Error::Io)?;

        // Resolve when either Ctrl+C or (on Unix) SIGTERM arrives.
        let shutdown_signal = async {
            let ctrl_c = async {
                tokio::signal::ctrl_c()
                    .await
                    .expect("Failed to install Ctrl+C handler");
            };

            #[cfg(unix)]
            let terminate = async {
                tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
                    .expect("Failed to install signal handler")
                    .recv()
                    .await;
            };

            #[cfg(not(unix))]
            let terminate = std::future::pending::<()>();

            tokio::select! {
                () = ctrl_c => {
                    eprintln!("\n\x1b[33m⚡\x1b[0m Received Ctrl+C, shutting down gracefully...");
                },
                () = terminate => {
                    eprintln!("\n\x1b[33m⚡\x1b[0m Received SIGTERM, shutting down gracefully...");
                },
            }
        };

        axum::serve(listener, router)
            .with_graceful_shutdown(shutdown_signal)
            .await
            .map_err(|e| infernum_core::Error::Internal {
                message: e.to_string(),
            })?;

        tracing::info!("Server shutdown complete");
        eprintln!("\x1b[32m✓\x1b[0m Server stopped");

        Ok(())
    }
}

/// OpenAI-style error envelope returned by all error responses.
#[derive(Debug, Serialize)]
struct ErrorResponse {
    error: ErrorDetail,
}

#[derive(Debug, Serialize)]
struct ErrorDetail {
    message: String,
    #[serde(rename = "type")]
    error_type: String,
    code: Option<String>,
}

impl ErrorResponse {
    fn new(message: impl Into<String>, error_type: impl Into<String>) -> Self {
        Self {
            error: ErrorDetail {
                message: message.into(),
                error_type: error_type.into(),
                code: None,
            },
        }
    }

    #[allow(dead_code)]
    fn with_code(mut self, code: impl Into<String>) -> Self {
        self.error.code = Some(code.into());
        self
    }
}

fn error_response(status: StatusCode, message: &str, error_type: &str) -> Response {
    let body = Json(ErrorResponse::new(message, error_type));
    (status, body).into_response()
}
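// Error bodies serialize to the OpenAI error envelope, e.g. (illustrative):
//
//     {"error": {"message": "No model loaded", "type": "model_not_loaded", "code": null}}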
/// Liveness probe; always returns 200 once the process is up.
async fn health() -> &'static str {
    "OK"
}

/// Readiness probe; returns 200 only when a model is loaded.
async fn ready(State(state): State<Arc<AppState>>) -> Response {
    let engine = state.engine.read().await;
    if engine.is_some() {
        (StatusCode::OK, "Ready").into_response()
    } else {
        (StatusCode::SERVICE_UNAVAILABLE, "No model loaded").into_response()
    }
}

#[derive(Debug, Serialize)]
struct ServerStatus {
    status: String,
    uptime_seconds: u64,
    model_loaded: bool,
    model_id: Option<String>,
}

/// Reports uptime and which model, if any, is loaded.
async fn server_status(State(state): State<Arc<AppState>>) -> Json<ServerStatus> {
    let engine = state.engine.read().await;
    let model_id = engine.as_ref().map(|e| e.model_info().id.to_string());

    Json(ServerStatus {
        status: "running".to_string(),
        uptime_seconds: state.start_time.elapsed().as_secs(),
        model_loaded: engine.is_some(),
        model_id,
    })
}
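/// Request body for `POST /api/models/load`.
///
/// Example (the model id is illustrative):
///
/// ```json
/// {"model": "Qwen/Qwen2.5-0.5B-Instruct"}
/// ```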
#[derive(Debug, Deserialize)]
struct LoadModelRequest {
    model: String,
}

/// Loads a model at runtime, replacing any currently loaded engine.
async fn load_model(
    State(state): State<Arc<AppState>>,
    Json(req): Json<LoadModelRequest>,
) -> Response {
    tracing::info!(model = %req.model, "Loading model via API");

    let engine_config = match EngineConfig::builder().model(&req.model).build() {
        Ok(config) => config,
        Err(e) => {
            return error_response(
                StatusCode::BAD_REQUEST,
                &format!("Invalid model configuration: {}", e),
                "invalid_request_error",
            );
        },
    };

    let engine = match Engine::new(engine_config).await {
        Ok(engine) => engine,
        Err(e) => {
            return error_response(
                StatusCode::INTERNAL_SERVER_ERROR,
                &format!("Failed to load model: {}", e),
                "model_load_error",
            );
        },
    };

    let mut engine_guard = state.engine.write().await;
    *engine_guard = Some(Arc::new(engine));

    (
        StatusCode::OK,
        Json(serde_json::json!({"status": "loaded", "model": req.model})),
    )
        .into_response()
}

/// Drops the currently loaded engine, freeing its resources.
async fn unload_model(State(state): State<Arc<AppState>>) -> Response {
    let mut engine_guard = state.engine.write().await;
    *engine_guard = None;
    tracing::info!("Model unloaded");
    (
        StatusCode::OK,
        Json(serde_json::json!({"status": "unloaded"})),
    )
        .into_response()
}

/// Lists the loaded model in the OpenAI `/v1/models` format.
async fn list_models(State(state): State<Arc<AppState>>) -> Json<ModelsResponse> {
    let engine = state.engine.read().await;

    let models = match engine.as_ref() {
        Some(engine) => {
            let info = engine.model_info();
            vec![ModelObject {
                id: info.id.to_string(),
                object: "model".to_string(),
                created: chrono::Utc::now().timestamp(),
                owned_by: "infernum".to_string(),
            }]
        },
        None => vec![],
    };

    Json(ModelsResponse {
        object: "list".to_string(),
        data: models,
    })
}
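// `list_models` response shape with one model loaded (timestamp illustrative):
//
//     {"object":"list","data":[{"id":"<model-id>","object":"model","created":1700000000,"owned_by":"infernum"}]}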
/// OpenAI-compatible `POST /v1/chat/completions` handler, supporting both
/// streaming (SSE) and non-streaming responses.
async fn chat_completions(
    State(state): State<Arc<AppState>>,
    Json(req): Json<ChatCompletionRequest>,
) -> Response {
    let start = Instant::now();
    let request_id = format!("chatcmpl-{}", uuid::Uuid::new_v4());

    tracing::debug!(request_id = %request_id, model = %req.model, "Chat completion request");

    // Clone the engine handle and release the read lock before generating, so
    // a long generation does not block model load/unload.
    let engine_guard = state.engine.read().await;
    let engine = match engine_guard.as_ref() {
        Some(engine) => Arc::clone(engine),
        None => {
            return error_response(
                StatusCode::SERVICE_UNAVAILABLE,
                "No model loaded",
                "model_not_loaded",
            );
        },
    };
    drop(engine_guard);

    let stream = req.stream.unwrap_or(false);

    // Map OpenAI roles onto the internal message types; unknown roles default
    // to `User`.
    let messages: Vec<infernum_core::Message> = req
        .messages
        .iter()
        .map(|m| {
            let role = match m.role.as_str() {
                "system" => infernum_core::Role::System,
                "user" => infernum_core::Role::User,
                "assistant" => infernum_core::Role::Assistant,
                _ => infernum_core::Role::User,
            };
            infernum_core::Message {
                role,
                content: m.content.clone(),
                name: None,
                tool_call_id: None,
            }
        })
        .collect();

    let mut sampling = SamplingParams::default();
    if let Some(temp) = req.temperature {
        sampling = sampling.with_temperature(temp);
    }
    if let Some(top_p) = req.top_p {
        sampling = sampling.with_top_p(top_p);
    }
    if let Some(max_tokens) = req.max_tokens {
        sampling = sampling.with_max_tokens(max_tokens);
    }
    if let Some(stop) = &req.stop {
        for s in stop {
            sampling = sampling.with_stop(s.clone());
        }
    }

    let gen_request = GenerateRequest::new(infernum_core::request::PromptInput::Messages(messages))
        .with_sampling(sampling);

    if stream {
        match engine.generate_stream(gen_request).await {
            Ok(token_stream) => {
                let model_name = engine.model_info().id.to_string();
                // Re-encode each engine chunk as an OpenAI-style SSE event.
                let sse_stream = token_stream.map(move |chunk_result| {
                    match chunk_result {
                        Ok(chunk) => {
                            let data = serde_json::json!({
                                "id": request_id,
                                "object": "chat.completion.chunk",
                                "created": chrono::Utc::now().timestamp(),
                                "model": model_name,
                                "choices": [{
                                    "index": 0,
                                    "delta": {
                                        "content": chunk.choices.first().map(|c| c.delta.content.as_deref().unwrap_or("")).unwrap_or("")
                                    },
                                    "finish_reason": chunk.choices.first().and_then(|c| c.finish_reason.as_ref().map(|r| format!("{:?}", r).to_lowercase()))
                                }]
                            });
                            Ok::<_, std::convert::Infallible>(
                                axum::response::sse::Event::default()
                                    .data(serde_json::to_string(&data).unwrap()),
                            )
                        },
                        Err(e) => {
                            let data = serde_json::json!({
                                "error": {
                                    "message": e.to_string(),
                                    "type": "server_error"
                                }
                            });
                            Ok(axum::response::sse::Event::default()
                                .data(serde_json::to_string(&data).unwrap()))
                        },
                    }
                });

                Sse::new(sse_stream)
                    .keep_alive(axum::response::sse::KeepAlive::default())
                    .into_response()
            },
            Err(e) => error_response(
                StatusCode::INTERNAL_SERVER_ERROR,
                &e.to_string(),
                "generation_error",
            ),
        }
    } else {
        match engine.generate(gen_request).await {
            Ok(response) => {
                let choice = response.choices.first();
                let content = choice.map(|c| c.text.clone()).unwrap_or_default();
                let finish_reason = choice
                    .and_then(|c| c.finish_reason.as_ref())
                    .map(|r| format!("{:?}", r).to_lowercase())
                    .unwrap_or_else(|| "stop".to_string());

                let chat_response = ChatCompletionResponse {
                    id: request_id,
                    object: "chat.completion".to_string(),
                    created: chrono::Utc::now().timestamp(),
                    model: engine.model_info().id.to_string(),
                    choices: vec![ChatChoice {
                        index: 0,
                        message: ChatMessage {
                            role: "assistant".to_string(),
                            content,
                            name: None,
                        },
                        finish_reason,
                    }],
                    usage: Usage {
                        prompt_tokens: response.usage.prompt_tokens,
                        completion_tokens: response.usage.completion_tokens,
                        total_tokens: response.usage.total_tokens,
                    },
                };

                tracing::debug!(
                    request_id = %chat_response.id,
                    prompt_tokens = response.usage.prompt_tokens,
                    completion_tokens = response.usage.completion_tokens,
                    latency_ms = start.elapsed().as_millis() as u64,
                    "Chat completion finished"
                );

                Json(chat_response).into_response()
            },
            Err(e) => error_response(
                StatusCode::INTERNAL_SERVER_ERROR,
                &e.to_string(),
                "generation_error",
            ),
        }
    }
}
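// Streamed `chat_completions` responses emit one SSE event per engine chunk
// in the OpenAI "chat.completion.chunk" shape, e.g. (id and timestamp
// illustrative):
//
//     data: {"id":"chatcmpl-...","object":"chat.completion.chunk","created":1700000000,
//            "model":"<model-id>","choices":[{"index":0,"delta":{"content":"Hel"},"finish_reason":null}]}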
/// OpenAI-compatible `POST /v1/completions` handler (non-streaming).
async fn completions(
    State(state): State<Arc<AppState>>,
    Json(req): Json<CompletionRequest>,
) -> Response {
    let start = Instant::now();
    let request_id = format!("cmpl-{}", uuid::Uuid::new_v4());

    tracing::debug!(request_id = %request_id, model = %req.model, "Completion request");

    let engine_guard = state.engine.read().await;
    let engine = match engine_guard.as_ref() {
        Some(engine) => Arc::clone(engine),
        None => {
            return error_response(
                StatusCode::SERVICE_UNAVAILABLE,
                "No model loaded",
                "model_not_loaded",
            );
        },
    };
    drop(engine_guard);

    let mut sampling = SamplingParams::default();
    if let Some(temp) = req.temperature {
        sampling = sampling.with_temperature(temp);
    }
    if let Some(top_p) = req.top_p {
        sampling = sampling.with_top_p(top_p);
    }
    if let Some(max_tokens) = req.max_tokens {
        sampling = sampling.with_max_tokens(max_tokens);
    }
    if let Some(stop) = &req.stop {
        for s in stop {
            sampling = sampling.with_stop(s.clone());
        }
    }

    let gen_request = GenerateRequest::new(infernum_core::request::PromptInput::Text(req.prompt))
        .with_sampling(sampling);

    match engine.generate(gen_request).await {
        Ok(response) => {
            let choice = response.choices.first();
            let text = choice.map(|c| c.text.clone()).unwrap_or_default();
            let finish_reason = choice
                .and_then(|c| c.finish_reason.as_ref())
                .map(|r| format!("{:?}", r).to_lowercase())
                .unwrap_or_else(|| "stop".to_string());

            let completion_response = CompletionResponse {
                id: request_id.clone(),
                object: "text_completion".to_string(),
                created: chrono::Utc::now().timestamp(),
                model: engine.model_info().id.to_string(),
                choices: vec![CompletionChoice {
                    text,
                    index: 0,
                    finish_reason,
                    logprobs: None,
                }],
                usage: Usage {
                    prompt_tokens: response.usage.prompt_tokens,
                    completion_tokens: response.usage.completion_tokens,
                    total_tokens: response.usage.total_tokens,
                },
            };

            tracing::debug!(
                request_id = %request_id,
                prompt_tokens = response.usage.prompt_tokens,
                completion_tokens = response.usage.completion_tokens,
                latency_ms = start.elapsed().as_millis() as u64,
                "Completion finished"
            );

            Json(completion_response).into_response()
        },
        Err(e) => error_response(
            StatusCode::INTERNAL_SERVER_ERROR,
            &e.to_string(),
            "generation_error",
        ),
    }
}

/// OpenAI-compatible embeddings handler. Not currently registered in
/// `router()`, hence the `dead_code` allowance.
#[allow(dead_code)]
async fn embeddings(
    State(state): State<Arc<AppState>>,
    Json(req): Json<EmbeddingRequest>,
) -> Response {
    let request_id = format!("emb-{}", uuid::Uuid::new_v4());

    tracing::debug!(request_id = %request_id, model = %req.model, "Embedding request");

    let engine_guard = state.engine.read().await;
    let engine = match engine_guard.as_ref() {
        Some(engine) => Arc::clone(engine),
        None => {
            return error_response(
                StatusCode::SERVICE_UNAVAILABLE,
                "No model loaded",
                "model_not_loaded",
            );
        },
    };
    drop(engine_guard);

    // Normalize single- and multi-input requests into one list.
    let texts: Vec<String> = match &req.input {
        EmbeddingInput::Single(s) => vec![s.clone()],
        EmbeddingInput::Multiple(v) => v.clone(),
    };

    let mut embeddings = Vec::new();
    let mut total_tokens = 0u32;

    for (idx, text) in texts.iter().enumerate() {
        let embed_request = infernum_core::EmbedRequest::new(text.clone());

        match engine.embed(embed_request).await {
            Ok(response) => {
                let embedding_vec = response
                    .data
                    .first()
                    .and_then(|e| e.embedding.as_floats().ok())
                    .unwrap_or_default();

                embeddings.push(EmbeddingData {
                    object: "embedding".to_string(),
                    index: idx as u32,
                    embedding: embedding_vec,
                });
                total_tokens += response.usage.total_tokens;
            },
            Err(e) => {
                return error_response(
                    StatusCode::INTERNAL_SERVER_ERROR,
                    &e.to_string(),
                    "embedding_error",
                );
            },
        }
    }

    let response = EmbeddingResponse {
        object: "list".to_string(),
        data: embeddings,
        model: engine.model_info().id.to_string(),
        usage: EmbeddingUsage {
            prompt_tokens: total_tokens,
            total_tokens,
        },
    };

    Json(response).into_response()
}
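// To expose the `embeddings` handler above, a route could be added in
// `router()` (a sketch; this route is not currently wired up):
//
//     .route("/v1/embeddings", post(embeddings))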
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_server_config_builder() {
        let config = ServerConfig::builder()
            .addr("127.0.0.1:3000".parse().unwrap())
            .cors(false)
            .model("test-model")
            .max_concurrent_requests(32)
            .build();

        assert_eq!(config.addr, "127.0.0.1:3000".parse().unwrap());
        assert!(!config.cors);
        assert_eq!(config.model, Some("test-model".to_string()));
        assert_eq!(config.max_concurrent_requests, 32);
    }

    #[test]
    fn test_error_response() {
        let err = ErrorResponse::new("Test error", "test_error").with_code("TEST_CODE");

        assert_eq!(err.error.message, "Test error");
        assert_eq!(err.error.error_type, "test_error");
        assert_eq!(err.error.code, Some("TEST_CODE".to_string()));
    }
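
    // An added check (sketch): `Default` should agree with the builder's
    // fallback values used in `ServerConfigBuilder::build`.
    #[test]
    fn test_server_config_default() {
        let config = ServerConfig::default();

        assert_eq!(config.addr, "0.0.0.0:8080".parse().unwrap());
        assert!(config.cors);
        assert_eq!(config.model, None);
        assert_eq!(config.max_concurrent_requests, 64);
    }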
}