infernum_server/server.rs

//! HTTP server implementation with OpenAI-compatible API endpoints.
//!
//! Provides a production-ready server that interfaces with the Abaddon inference engine
//! for text generation, chat completions, and embeddings.
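//!
//! # Quick start
//!
//! A minimal sketch of starting the server with the default configuration.
//! It assumes `Server` and `ServerConfig` are re-exported at the crate root
//! and that a tokio runtime (with the `macros` feature) is available:
//!
//! ```no_run
//! use infernum_server::{Server, ServerConfig};
//!
//! #[tokio::main]
//! async fn main() -> infernum_core::Result<()> {
//!     // Binds to 0.0.0.0:8080 with CORS enabled and no model loaded.
//!     Server::new(ServerConfig::default()).run().await
//! }
//! ```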

use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Instant;

use axum::extract::State;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response, Sse};
use axum::routing::{get, post};
use axum::{Json, Router};
use futures::stream::StreamExt;
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use tower_http::cors::CorsLayer;
use tower_http::trace::TraceLayer;

use abaddon::{Engine, EngineConfig, InferenceEngine};
use infernum_core::{GenerateRequest, Result, SamplingParams};

use crate::openai::{
    ChatChoice, ChatCompletionRequest, ChatCompletionResponse, ChatMessage, CompletionChoice,
    CompletionRequest, CompletionResponse, EmbeddingData, EmbeddingInput, EmbeddingRequest,
    EmbeddingResponse, EmbeddingUsage, ModelObject, ModelsResponse, Usage,
};

/// Server configuration.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Listen address.
    pub addr: SocketAddr,
    /// Enable CORS.
    pub cors: bool,
    /// Model to load (optional - server can start without a model).
    pub model: Option<String>,
    /// Maximum concurrent requests.
    pub max_concurrent_requests: usize,
}

impl Default for ServerConfig {
    fn default() -> Self {
        Self {
            addr: "0.0.0.0:8080".parse().unwrap(),
            cors: true,
            model: None,
            max_concurrent_requests: 64,
        }
    }
}

impl ServerConfig {
    /// Creates a new server config builder.
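    ///
    /// # Examples
    ///
    /// A minimal sketch; the `infernum_server::ServerConfig` re-export path and
    /// the model identifier are assumptions for illustration:
    ///
    /// ```
    /// use infernum_server::ServerConfig;
    ///
    /// let config = ServerConfig::builder()
    ///     .addr("127.0.0.1:3000".parse().unwrap())
    ///     .cors(false)
    ///     .model("test-model")
    ///     .max_concurrent_requests(32)
    ///     .build();
    /// assert_eq!(config.max_concurrent_requests, 32);
    /// ```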
    pub fn builder() -> ServerConfigBuilder {
        ServerConfigBuilder::default()
    }
}

/// Builder for ServerConfig.
#[derive(Debug, Default)]
pub struct ServerConfigBuilder {
    addr: Option<SocketAddr>,
    cors: Option<bool>,
    model: Option<String>,
    max_concurrent_requests: Option<usize>,
}

impl ServerConfigBuilder {
    /// Sets the listen address.
    pub fn addr(mut self, addr: SocketAddr) -> Self {
        self.addr = Some(addr);
        self
    }

    /// Sets whether CORS is enabled.
    pub fn cors(mut self, enabled: bool) -> Self {
        self.cors = Some(enabled);
        self
    }

    /// Sets the model to load.
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.model = Some(model.into());
        self
    }

    /// Sets the maximum concurrent requests.
    pub fn max_concurrent_requests(mut self, max: usize) -> Self {
        self.max_concurrent_requests = Some(max);
        self
    }

    /// Builds the server config.
    pub fn build(self) -> ServerConfig {
        ServerConfig {
            addr: self.addr.unwrap_or_else(|| "0.0.0.0:8080".parse().unwrap()),
            cors: self.cors.unwrap_or(true),
            model: self.model,
            max_concurrent_requests: self.max_concurrent_requests.unwrap_or(64),
        }
    }
}

/// Shared application state.
pub struct AppState {
    /// The inference engine (None if no model is loaded).
    pub engine: RwLock<Option<Arc<Engine>>>,
    /// Server configuration.
    pub config: ServerConfig,
    /// Server start time.
    pub start_time: Instant,
}

impl AppState {
    /// Creates new app state with the given config.
    pub fn new(config: ServerConfig) -> Self {
        Self {
            engine: RwLock::new(None),
            config,
            start_time: Instant::now(),
        }
    }

    /// Creates new app state with a pre-loaded engine.
    pub fn with_engine(config: ServerConfig, engine: Engine) -> Self {
        Self {
            engine: RwLock::new(Some(Arc::new(engine))),
            config,
            start_time: Instant::now(),
        }
    }
}

/// The HTTP server.
pub struct Server {
    config: ServerConfig,
    state: Arc<AppState>,
}

impl Server {
    /// Creates a new server with the given configuration.
    pub fn new(config: ServerConfig) -> Self {
        let state = Arc::new(AppState::new(config.clone()));
        Self { config, state }
    }

    /// Creates a new server with a pre-loaded engine.
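    ///
    /// # Examples
    ///
    /// A hedged sketch of pre-loading an engine before serving. The
    /// `EngineConfig`/`Engine` calls mirror the ones used elsewhere in this
    /// module; the re-export paths are assumptions:
    ///
    /// ```no_run
    /// use abaddon::{Engine, EngineConfig};
    /// use infernum_server::{Server, ServerConfig};
    ///
    /// # async fn example() -> infernum_core::Result<()> {
    /// let engine_config = EngineConfig::builder()
    ///     .model("test-model")
    ///     .build()
    ///     .map_err(|e| infernum_core::Error::Internal { message: e })?;
    /// let engine = Engine::new(engine_config).await?;
    /// let server = Server::with_engine(ServerConfig::default(), engine);
    /// server.run().await
    /// # }
    /// ```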
    pub fn with_engine(config: ServerConfig, engine: Engine) -> Self {
        let state = Arc::new(AppState::with_engine(config.clone(), engine));
        Self { config, state }
    }

    /// Creates the router.
    fn router(&self) -> Router {
        let mut router = Router::new()
            // Health endpoints
            .route("/health", get(health))
            .route("/ready", get(ready))
            // OpenAI-compatible API endpoints
            .route("/v1/models", get(list_models))
            .route("/v1/chat/completions", post(chat_completions))
            .route("/v1/completions", post(completions))
            // NOTE: /v1/embeddings disabled until embedding models are supported
            // .route("/v1/embeddings", post(embeddings))
            // Internal management endpoints
            .route("/api/models/load", post(load_model))
            .route("/api/models/unload", post(unload_model))
            .route("/api/status", get(server_status))
            .with_state(self.state.clone());

        // Add middleware
        router = router.layer(TraceLayer::new_for_http());

        if self.config.cors {
            router = router.layer(CorsLayer::permissive());
        }

        router
    }

    /// Loads a model into the server.
    pub async fn load_model(&self, model_source: &str) -> Result<()> {
        tracing::info!(model = %model_source, "Loading model");

        let engine_config = EngineConfig::builder()
            .model(model_source)
            .build()
            .map_err(|e| infernum_core::Error::Internal { message: e })?;

        let engine = Engine::new(engine_config).await?;
        let mut engine_guard = self.state.engine.write().await;
        *engine_guard = Some(Arc::new(engine));

        tracing::info!(model = %model_source, "Model loaded successfully");
        Ok(())
    }

    /// Runs the server.
    ///
    /// # Errors
    ///
    /// Returns an error if the server cannot start.
    pub async fn run(self) -> Result<()> {
        // Load model if specified
        if let Some(model) = &self.config.model {
            self.load_model(model).await?;
            tracing::info!(model = %model, "Model loaded and ready for inference");
        } else {
            tracing::warn!("=======================================================");
            tracing::warn!("  SERVER STARTED WITHOUT A MODEL");
            tracing::warn!("  All inference requests will fail until a model is loaded.");
            tracing::warn!("  ");
            tracing::warn!("  To load a model, either:");
            tracing::warn!("    1. Restart with: infernum serve --model <model>");
            tracing::warn!("    2. POST to /api/models/load with {{\"model\": \"<model>\"}}");
            tracing::warn!("=======================================================");
        }

        let router = self.router();

        tracing::info!(addr = %self.config.addr, "Starting Infernum server");
        eprintln!(
            "\n\x1b[32m✓\x1b[0m Server listening on http://{}",
            self.config.addr
        );
        eprintln!("  Press Ctrl+C to stop\n");

        let listener = tokio::net::TcpListener::bind(self.config.addr)
            .await
            .map_err(infernum_core::Error::Io)?;

        // Set up graceful shutdown
        let shutdown_signal = async {
            let ctrl_c = async {
                tokio::signal::ctrl_c()
                    .await
                    .expect("Failed to install Ctrl+C handler");
            };

            #[cfg(unix)]
            let terminate = async {
                tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
                    .expect("Failed to install signal handler")
                    .recv()
                    .await;
            };

            #[cfg(not(unix))]
            let terminate = std::future::pending::<()>();

            tokio::select! {
                () = ctrl_c => {
                    eprintln!("\n\x1b[33m⚡\x1b[0m Received Ctrl+C, shutting down gracefully...");
                },
                () = terminate => {
                    eprintln!("\n\x1b[33m⚡\x1b[0m Received SIGTERM, shutting down gracefully...");
                },
            }
        };

        axum::serve(listener, router)
            .with_graceful_shutdown(shutdown_signal)
            .await
            .map_err(|e| infernum_core::Error::Internal {
                message: e.to_string(),
            })?;

        tracing::info!("Server shutdown complete");
        eprintln!("\x1b[32m✓\x1b[0m Server stopped");

        Ok(())
    }
}

// === Error Response ===

#[derive(Debug, Serialize)]
struct ErrorResponse {
    error: ErrorDetail,
}

#[derive(Debug, Serialize)]
struct ErrorDetail {
    message: String,
    #[serde(rename = "type")]
    error_type: String,
    code: Option<String>,
}

impl ErrorResponse {
    fn new(message: impl Into<String>, error_type: impl Into<String>) -> Self {
        Self {
            error: ErrorDetail {
                message: message.into(),
                error_type: error_type.into(),
                code: None,
            },
        }
    }

    #[allow(dead_code)] // Reserved for future use with specific error codes
    fn with_code(mut self, code: impl Into<String>) -> Self {
        self.error.code = Some(code.into());
        self
    }
}

fn error_response(status: StatusCode, message: &str, error_type: &str) -> Response {
    let body = Json(ErrorResponse::new(message, error_type));
    (status, body).into_response()
}

// === Health Endpoints ===

async fn health() -> &'static str {
    "OK"
}

async fn ready(State(state): State<Arc<AppState>>) -> Response {
    let engine = state.engine.read().await;
    if engine.is_some() {
        (StatusCode::OK, "Ready").into_response()
    } else {
        (StatusCode::SERVICE_UNAVAILABLE, "No model loaded").into_response()
    }
}

#[derive(Debug, Serialize)]
struct ServerStatus {
    status: String,
    uptime_seconds: u64,
    model_loaded: bool,
    model_id: Option<String>,
}

async fn server_status(State(state): State<Arc<AppState>>) -> Json<ServerStatus> {
    let engine = state.engine.read().await;
    let model_id = engine.as_ref().map(|e| e.model_info().id.to_string());

    Json(ServerStatus {
        status: "running".to_string(),
        uptime_seconds: state.start_time.elapsed().as_secs(),
        model_loaded: engine.is_some(),
        model_id,
    })
}

// === Model Management ===

#[derive(Debug, Deserialize)]
struct LoadModelRequest {
    model: String,
}
async fn load_model(
    State(state): State<Arc<AppState>>,
    Json(req): Json<LoadModelRequest>,
) -> Response {
    tracing::info!(model = %req.model, "Loading model via API");

    let engine_config = match EngineConfig::builder().model(&req.model).build() {
        Ok(config) => config,
        Err(e) => {
            return error_response(
                StatusCode::BAD_REQUEST,
                &format!("Invalid model configuration: {}", e),
                "invalid_request_error",
            );
        },
    };

    let engine = match Engine::new(engine_config).await {
        Ok(engine) => engine,
        Err(e) => {
            return error_response(
                StatusCode::INTERNAL_SERVER_ERROR,
                &format!("Failed to load model: {}", e),
                "model_load_error",
            );
        },
    };

    let mut engine_guard = state.engine.write().await;
    *engine_guard = Some(Arc::new(engine));

    (
        StatusCode::OK,
        Json(serde_json::json!({"status": "loaded", "model": req.model})),
    )
        .into_response()
}

async fn unload_model(State(state): State<Arc<AppState>>) -> Response {
    let mut engine_guard = state.engine.write().await;
    *engine_guard = None;
    tracing::info!("Model unloaded");
    (
        StatusCode::OK,
        Json(serde_json::json!({"status": "unloaded"})),
    )
        .into_response()
}

// === OpenAI-Compatible Endpoints ===

async fn list_models(State(state): State<Arc<AppState>>) -> Json<ModelsResponse> {
    let engine = state.engine.read().await;

    let models = match engine.as_ref() {
        Some(engine) => {
            let info = engine.model_info();
            vec![ModelObject {
                id: info.id.to_string(),
                object: "model".to_string(),
                created: chrono::Utc::now().timestamp(),
                owned_by: "infernum".to_string(),
            }]
        },
        None => vec![],
    };

    Json(ModelsResponse {
        object: "list".to_string(),
        data: models,
    })
}

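/// Handler for `POST /v1/chat/completions`.
///
/// Accepts an OpenAI-style chat completion request, converts the messages
/// and sampling parameters into an engine `GenerateRequest`, and returns
/// either a single JSON response or an SSE stream of
/// `chat.completion.chunk` events when `"stream": true` is set.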
async fn chat_completions(
    State(state): State<Arc<AppState>>,
    Json(req): Json<ChatCompletionRequest>,
) -> Response {
    let start = Instant::now();
    let request_id = format!("chatcmpl-{}", uuid::Uuid::new_v4());

    tracing::debug!(request_id = %request_id, model = %req.model, "Chat completion request");

    // Get engine
    let engine_guard = state.engine.read().await;
    let engine = match engine_guard.as_ref() {
        Some(engine) => Arc::clone(engine),
        None => {
            return error_response(
                StatusCode::SERVICE_UNAVAILABLE,
                "No model loaded",
                "model_not_loaded",
            );
        },
    };
    drop(engine_guard); // Release lock early

    // Check for streaming
    let stream = req.stream.unwrap_or(false);

    // Build messages into prompt
    let messages: Vec<infernum_core::Message> = req
        .messages
        .iter()
        .map(|m| {
            let role = match m.role.as_str() {
                "system" => infernum_core::Role::System,
                "user" => infernum_core::Role::User,
                "assistant" => infernum_core::Role::Assistant,
                _ => infernum_core::Role::User,
            };
            infernum_core::Message {
                role,
                content: m.content.clone(),
                name: None,
                tool_call_id: None,
            }
        })
        .collect();

    // Build sampling params
    let mut sampling = SamplingParams::default();
    if let Some(temp) = req.temperature {
        sampling = sampling.with_temperature(temp);
    }
    if let Some(top_p) = req.top_p {
        sampling = sampling.with_top_p(top_p);
    }
    if let Some(max_tokens) = req.max_tokens {
        sampling = sampling.with_max_tokens(max_tokens);
    }
    if let Some(stop) = &req.stop {
        for s in stop {
            sampling = sampling.with_stop(s.clone());
        }
    }

    // Create inference request
    let gen_request = GenerateRequest::new(infernum_core::request::PromptInput::Messages(messages))
        .with_sampling(sampling);

    if stream {
        // Streaming response
        match engine.generate_stream(gen_request).await {
            Ok(token_stream) => {
                let model_name = engine.model_info().id.to_string();
                let sse_stream = token_stream.map(move |chunk_result| {
                    match chunk_result {
                        Ok(chunk) => {
                            let data = serde_json::json!({
                                "id": request_id,
                                "object": "chat.completion.chunk",
                                "created": chrono::Utc::now().timestamp(),
                                "model": model_name,
                                "choices": [{
                                    "index": 0,
                                    "delta": {
                                        "content": chunk
                                            .choices
                                            .first()
                                            .map(|c| c.delta.content.as_deref().unwrap_or(""))
                                            .unwrap_or("")
                                    },
                                    "finish_reason": chunk.choices.first().and_then(|c| {
                                        c.finish_reason
                                            .as_ref()
                                            .map(|r| format!("{:?}", r).to_lowercase())
                                    })
                                }]
                            });
                            Ok::<_, std::convert::Infallible>(
                                axum::response::sse::Event::default()
                                    .data(serde_json::to_string(&data).unwrap()),
                            )
                        }
                        Err(e) => {
                            let data = serde_json::json!({
                                "error": {
                                    "message": e.to_string(),
                                    "type": "server_error"
                                }
                            });
                            Ok(axum::response::sse::Event::default()
                                .data(serde_json::to_string(&data).unwrap()))
                        }
                    }
                });

                Sse::new(sse_stream)
                    .keep_alive(axum::response::sse::KeepAlive::default())
                    .into_response()
            },
            Err(e) => error_response(
                StatusCode::INTERNAL_SERVER_ERROR,
                &e.to_string(),
                "generation_error",
            ),
        }
    } else {
        // Non-streaming response
        match engine.generate(gen_request).await {
            Ok(response) => {
                let choice = response.choices.first();
                let content = choice.map(|c| c.text.clone()).unwrap_or_default();
                let finish_reason = choice
                    .and_then(|c| c.finish_reason.as_ref())
                    .map(|r| format!("{:?}", r).to_lowercase())
                    .unwrap_or_else(|| "stop".to_string());

                let chat_response = ChatCompletionResponse {
                    id: request_id,
                    object: "chat.completion".to_string(),
                    created: chrono::Utc::now().timestamp(),
                    model: engine.model_info().id.to_string(),
                    choices: vec![ChatChoice {
                        index: 0,
                        message: ChatMessage {
                            role: "assistant".to_string(),
                            content,
                            name: None,
                        },
                        finish_reason,
                    }],
                    usage: Usage {
                        prompt_tokens: response.usage.prompt_tokens,
                        completion_tokens: response.usage.completion_tokens,
                        total_tokens: response.usage.total_tokens,
                    },
                };

                tracing::debug!(
                    request_id = %chat_response.id,
                    prompt_tokens = response.usage.prompt_tokens,
                    completion_tokens = response.usage.completion_tokens,
                    latency_ms = start.elapsed().as_millis() as u64,
                    "Chat completion finished"
                );

                Json(chat_response).into_response()
            },
            Err(e) => error_response(
                StatusCode::INTERNAL_SERVER_ERROR,
                &e.to_string(),
                "generation_error",
            ),
        }
    }
}

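/// Handler for `POST /v1/completions`.
///
/// Accepts an OpenAI-style text completion request and returns a single
/// (non-streaming) `text_completion` response.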
async fn completions(
    State(state): State<Arc<AppState>>,
    Json(req): Json<CompletionRequest>,
) -> Response {
    let start = Instant::now();
    let request_id = format!("cmpl-{}", uuid::Uuid::new_v4());

    tracing::debug!(request_id = %request_id, model = %req.model, "Completion request");

    // Get engine
    let engine_guard = state.engine.read().await;
    let engine = match engine_guard.as_ref() {
        Some(engine) => Arc::clone(engine),
        None => {
            return error_response(
                StatusCode::SERVICE_UNAVAILABLE,
                "No model loaded",
                "model_not_loaded",
            );
        },
    };
    drop(engine_guard);

    // Build sampling params
    let mut sampling = SamplingParams::default();
    if let Some(temp) = req.temperature {
        sampling = sampling.with_temperature(temp);
    }
    if let Some(top_p) = req.top_p {
        sampling = sampling.with_top_p(top_p);
    }
    if let Some(max_tokens) = req.max_tokens {
        sampling = sampling.with_max_tokens(max_tokens);
    }
    if let Some(stop) = &req.stop {
        for s in stop {
            sampling = sampling.with_stop(s.clone());
        }
    }

    // Create inference request
    let gen_request = GenerateRequest::new(infernum_core::request::PromptInput::Text(req.prompt))
        .with_sampling(sampling);

    match engine.generate(gen_request).await {
        Ok(response) => {
            let choice = response.choices.first();
            let text = choice.map(|c| c.text.clone()).unwrap_or_default();
            let finish_reason = choice
                .and_then(|c| c.finish_reason.as_ref())
                .map(|r| format!("{:?}", r).to_lowercase())
                .unwrap_or_else(|| "stop".to_string());

            let completion_response = CompletionResponse {
                id: request_id.clone(),
                object: "text_completion".to_string(),
                created: chrono::Utc::now().timestamp(),
                model: engine.model_info().id.to_string(),
                choices: vec![CompletionChoice {
                    text,
                    index: 0,
                    finish_reason,
                    logprobs: None,
                }],
                usage: Usage {
                    prompt_tokens: response.usage.prompt_tokens,
                    completion_tokens: response.usage.completion_tokens,
                    total_tokens: response.usage.total_tokens,
                },
            };

            tracing::debug!(
                request_id = %request_id,
                prompt_tokens = response.usage.prompt_tokens,
                completion_tokens = response.usage.completion_tokens,
                latency_ms = start.elapsed().as_millis() as u64,
                "Completion finished"
            );

            Json(completion_response).into_response()
        },
        Err(e) => error_response(
            StatusCode::INTERNAL_SERVER_ERROR,
            &e.to_string(),
            "generation_error",
        ),
    }
}

// TODO: Re-enable when embedding models are supported
#[allow(dead_code)]
async fn embeddings(
    State(state): State<Arc<AppState>>,
    Json(req): Json<EmbeddingRequest>,
) -> Response {
    let request_id = format!("emb-{}", uuid::Uuid::new_v4());

    tracing::debug!(request_id = %request_id, model = %req.model, "Embedding request");

    // Get engine
    let engine_guard = state.engine.read().await;
    let engine = match engine_guard.as_ref() {
        Some(engine) => Arc::clone(engine),
        None => {
            return error_response(
                StatusCode::SERVICE_UNAVAILABLE,
                "No model loaded",
                "model_not_loaded",
            );
        },
    };
    drop(engine_guard);

    // Get input texts
    let texts: Vec<String> = match &req.input {
        EmbeddingInput::Single(s) => vec![s.clone()],
        EmbeddingInput::Multiple(v) => v.clone(),
    };

    // Generate embeddings for each input
    let mut embeddings = Vec::new();
    let mut total_tokens = 0u32;

    for (idx, text) in texts.iter().enumerate() {
        let embed_request = infernum_core::EmbedRequest::new(text.clone());

        match engine.embed(embed_request).await {
            Ok(response) => {
                // Extract embedding vector from the response
                let embedding_vec = response
                    .data
                    .first()
                    .and_then(|e| e.embedding.as_floats().ok())
                    .unwrap_or_default();

                embeddings.push(EmbeddingData {
                    object: "embedding".to_string(),
                    index: idx as u32,
                    embedding: embedding_vec,
                });
                total_tokens += response.usage.total_tokens;
            },
            Err(e) => {
                return error_response(
                    StatusCode::INTERNAL_SERVER_ERROR,
                    &e.to_string(),
                    "embedding_error",
                );
            },
        }
    }

    let response = EmbeddingResponse {
        object: "list".to_string(),
        data: embeddings,
        model: engine.model_info().id.to_string(),
        usage: EmbeddingUsage {
            prompt_tokens: total_tokens,
            total_tokens,
        },
    };

    Json(response).into_response()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_server_config_builder() {
        let config = ServerConfig::builder()
            .addr("127.0.0.1:3000".parse().unwrap())
            .cors(false)
            .model("test-model")
            .max_concurrent_requests(32)
            .build();

        assert_eq!(config.addr, "127.0.0.1:3000".parse().unwrap());
        assert!(!config.cors);
        assert_eq!(config.model, Some("test-model".to_string()));
        assert_eq!(config.max_concurrent_requests, 32);
    }

    #[test]
    fn test_error_response() {
        let err = ErrorResponse::new("Test error", "test_error").with_code("TEST_CODE");

        assert_eq!(err.error.message, "Test error");
        assert_eq!(err.error.error_type, "test_error");
        assert_eq!(err.error.code, Some("TEST_CODE".to_string()));
    }
}