helios_engine/serve.rs

//! # Serve Module
//!
//! This module serves fully OpenAI-compatible API endpoints with real-time
//! streaming, allowing users to expose their agents or LLM clients over HTTP
//! with full control of generation parameters.
//!
//! ## Usage
//!
//! ### From CLI
//! ```bash
//! helios-engine serve --port 8000
//! ```
//!
//! ### Programmatically
//! ```no_run
//! use helios_engine::{Config, serve};
//!
//! #[tokio::main]
//! async fn main() -> helios_engine::Result<()> {
//!     let config = Config::from_file("config.toml")?;
//!     serve::start_server(config, "127.0.0.1:8000").await?;
//!     Ok(())
//! }
//! ```

use crate::agent::Agent;
use crate::chat::{ChatMessage, Role};
use crate::config::Config;
use crate::error::{HeliosError, Result};
use crate::llm::{LLMClient, LLMProviderType};
use axum::{
    extract::State,
    http::StatusCode,
    response::{
        sse::{Event, Sse},
        IntoResponse,
    },
    routing::{delete, get, patch, post, put},
    Json, Router,
};
use futures::stream::Stream;
use serde::{Deserialize, Serialize};
use std::convert::Infallible;
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio_stream::wrappers::ReceiverStream;
use tower_http::cors::CorsLayer;
use tower_http::trace::TraceLayer;
use tracing::{error, info};
use uuid::Uuid;

/// OpenAI-compatible chat completion request.
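///
/// # Example
///
/// A minimal sketch of deserializing a request body; the JSON mirrors the
/// OpenAI `/v1/chat/completions` shape and shows only the fields modeled here.
///
/// ```
/// use helios_engine::serve::ChatCompletionRequest;
///
/// let body = r#"{
///     "model": "my-model",
///     "messages": [{"role": "user", "content": "Hello"}],
///     "temperature": 0.7,
///     "stream": false
/// }"#;
/// let request: ChatCompletionRequest = serde_json::from_str(body).unwrap();
/// assert_eq!(request.model, "my-model");
/// ```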
#[derive(Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct ChatCompletionRequest {
    /// The model to use.
    pub model: String,
    /// The messages to send.
    pub messages: Vec<OpenAIMessage>,
    /// The temperature to use.
    #[serde(default)]
    pub temperature: Option<f32>,
    /// The maximum number of tokens to generate.
    #[serde(default)]
    pub max_tokens: Option<u32>,
    /// Whether to stream the response.
    #[serde(default)]
    pub stream: Option<bool>,
    /// Stop sequences.
    #[serde(default)]
    pub stop: Option<Vec<String>>,
}

/// OpenAI-compatible message format.
#[derive(Debug, Deserialize)]
pub struct OpenAIMessage {
    /// The role of the message sender.
    pub role: String,
    /// The content of the message.
    pub content: String,
    /// The name of the message sender (optional).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
}

/// OpenAI-compatible chat completion response.
#[derive(Debug, Serialize)]
pub struct ChatCompletionResponse {
    /// The ID of the completion.
    pub id: String,
    /// The object type.
    pub object: String,
    /// The creation timestamp.
    pub created: u64,
    /// The model used.
    pub model: String,
    /// The choices in the response.
    pub choices: Vec<CompletionChoice>,
    /// Usage statistics.
    pub usage: Usage,
}

/// A choice in a completion response.
#[derive(Debug, Serialize)]
pub struct CompletionChoice {
    /// The index of the choice.
    pub index: u32,
    /// The message in the choice.
    pub message: OpenAIMessageResponse,
    /// The finish reason.
    pub finish_reason: String,
}

/// A message in a completion response.
#[derive(Debug, Serialize)]
pub struct OpenAIMessageResponse {
    /// The role of the message sender.
    pub role: String,
    /// The content of the message.
    pub content: String,
}

/// Usage statistics for a completion.
#[derive(Debug, Serialize)]
pub struct Usage {
    /// The number of prompt tokens.
    pub prompt_tokens: u32,
    /// The number of completion tokens.
    pub completion_tokens: u32,
    /// The total number of tokens.
    pub total_tokens: u32,
}

/// Model information for the models endpoint.
#[derive(Debug, Serialize)]
pub struct ModelInfo {
    /// The ID of the model.
    pub id: String,
    /// The object type.
    pub object: String,
    /// The creation timestamp.
    pub created: u64,
    /// The owner of the model.
    pub owned_by: String,
}

/// Models list response.
#[derive(Debug, Serialize)]
pub struct ModelsResponse {
    /// The object type.
    pub object: String,
    /// The list of models.
    pub data: Vec<ModelInfo>,
}

/// Custom endpoint configuration.
#[derive(Debug, Clone, Deserialize)]
pub struct CustomEndpoint {
    /// The HTTP method (GET, POST, PUT, DELETE, PATCH).
    pub method: String,
    /// The endpoint path.
    pub path: String,
    /// The response body as JSON.
    pub response: serde_json::Value,
    /// Optional status code (defaults to 200).
    #[serde(default = "default_status_code")]
    pub status_code: u16,
}

fn default_status_code() -> u16 {
    200
}

/// Custom endpoints configuration.
#[derive(Debug, Clone, Deserialize)]
pub struct CustomEndpointsConfig {
    /// List of custom endpoints.
    pub endpoints: Vec<CustomEndpoint>,
}

/// Server state containing the LLM client and agent (if any).
#[derive(Clone)]
pub struct ServerState {
    /// The LLM client for direct LLM calls.
    pub llm_client: Option<Arc<LLMClient>>,
    /// The agent (if serving an agent).
    pub agent: Option<Arc<RwLock<Agent>>>,
    /// The model name being served.
    pub model_name: String,
}

impl ServerState {
    /// Creates a new server state with an LLM client.
    pub fn with_llm_client(llm_client: LLMClient, model_name: String) -> Self {
        Self {
            llm_client: Some(Arc::new(llm_client)),
            agent: None,
            model_name,
        }
    }

    /// Creates a new server state with an agent.
    pub fn with_agent(agent: Agent, model_name: String) -> Self {
        Self {
            llm_client: None,
            agent: Some(Arc::new(RwLock::new(agent))),
            model_name,
        }
    }
}

/// Starts the HTTP server with the given configuration.
///
/// # Arguments
///
/// * `config` - The configuration to use for the LLM client.
/// * `address` - The address to bind to (e.g., "127.0.0.1:8000").
///
/// # Returns
///
/// A `Result` that resolves when the server shuts down.
pub async fn start_server(config: Config, address: &str) -> Result<()> {
    let provider_type = if let Some(local_config) = config.local.clone() {
        LLMProviderType::Local(local_config)
    } else {
        LLMProviderType::Remote(config.llm.clone())
    };

    let llm_client = LLMClient::new(provider_type).await?;
    let model_name = config
        .local
        .as_ref()
        .map(|_| "local-model".to_string())
        .unwrap_or_else(|| config.llm.model_name.clone());

    let state = ServerState::with_llm_client(llm_client, model_name);

    let app = create_router(state);

    info!("🚀 Starting Helios Engine server on http://{}", address);
    info!("📡 OpenAI-compatible API endpoints:");
    info!("   POST /v1/chat/completions");
    info!("   GET  /v1/models");

    let listener = tokio::net::TcpListener::bind(address)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Failed to bind to {}: {}", address, e)))?;

    axum::serve(listener, app)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Server error: {}", e)))?;

    Ok(())
}

/// Starts the HTTP server with an agent.
///
/// # Arguments
///
/// * `agent` - The agent to serve.
/// * `model_name` - The model name to expose in the API.
/// * `address` - The address to bind to (e.g., "127.0.0.1:8000").
///
/// # Returns
///
/// A `Result` that resolves when the server shuts down.
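///
/// # Example
///
/// A minimal sketch; constructing the agent itself is omitted, and the import
/// path for `Agent` (a crate-root re-export, like `Config`) is an assumption,
/// so adjust it to the actual crate layout.
///
/// ```ignore
/// use helios_engine::{serve, Agent};
///
/// async fn serve_agent(agent: Agent) -> helios_engine::Result<()> {
///     serve::start_server_with_agent(agent, "my-agent".to_string(), "127.0.0.1:8000").await
/// }
/// ```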
pub async fn start_server_with_agent(
    agent: Agent,
    model_name: String,
    address: &str,
) -> Result<()> {
    let state = ServerState::with_agent(agent, model_name);

    let app = create_router(state);

    info!(
        "🚀 Starting Helios Engine server with agent on http://{}",
        address
    );
    info!("📡 OpenAI-compatible API endpoints:");
    info!("   POST /v1/chat/completions");
    info!("   GET  /v1/models");

    let listener = tokio::net::TcpListener::bind(address)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Failed to bind to {}: {}", address, e)))?;

    axum::serve(listener, app)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Server error: {}", e)))?;

    Ok(())
}

/// Starts the HTTP server with custom endpoints.
///
/// # Arguments
///
/// * `config` - The configuration to use for the LLM client.
/// * `address` - The address to bind to (e.g., "127.0.0.1:8000").
/// * `custom_endpoints` - Optional custom endpoints configuration.
///
/// # Returns
///
/// A `Result` that resolves when the server shuts down.
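///
/// # Example
///
/// A minimal sketch, assuming a `config.toml` and an `endpoints.toml` (in the
/// format accepted by [`load_custom_endpoints_config`]) exist on disk.
///
/// ```no_run
/// use helios_engine::{serve, Config};
///
/// #[tokio::main]
/// async fn main() -> helios_engine::Result<()> {
///     let config = Config::from_file("config.toml")?;
///     let endpoints = serve::load_custom_endpoints_config("endpoints.toml")?;
///     serve::start_server_with_custom_endpoints(config, "127.0.0.1:8000", Some(endpoints)).await?;
///     Ok(())
/// }
/// ```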
pub async fn start_server_with_custom_endpoints(
    config: Config,
    address: &str,
    custom_endpoints: Option<CustomEndpointsConfig>,
) -> Result<()> {
    let provider_type = if let Some(local_config) = config.local.clone() {
        LLMProviderType::Local(local_config)
    } else {
        LLMProviderType::Remote(config.llm.clone())
    };

    let llm_client = LLMClient::new(provider_type).await?;
    let model_name = config
        .local
        .as_ref()
        .map(|_| "local-model".to_string())
        .unwrap_or_else(|| config.llm.model_name.clone());

    let state = ServerState::with_llm_client(llm_client, model_name);

    let app = create_router_with_custom_endpoints(state, custom_endpoints.clone());

    info!("🚀 Starting Helios Engine server on http://{}", address);
    info!("📡 OpenAI-compatible API endpoints:");
    info!("   POST /v1/chat/completions");
    info!("   GET  /v1/models");

    if let Some(config) = &custom_endpoints {
        info!("📡 Custom endpoints:");
        for endpoint in &config.endpoints {
            info!("   {} {}", endpoint.method.to_uppercase(), endpoint.path);
        }
    }

    let listener = tokio::net::TcpListener::bind(address)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Failed to bind to {}: {}", address, e)))?;

    axum::serve(listener, app)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Server error: {}", e)))?;

    Ok(())
}

/// Starts the HTTP server with an agent and custom endpoints.
///
/// # Arguments
///
/// * `agent` - The agent to serve.
/// * `model_name` - The model name to expose in the API.
/// * `address` - The address to bind to (e.g., "127.0.0.1:8000").
/// * `custom_endpoints` - Optional custom endpoints configuration.
///
/// # Returns
///
/// A `Result` that resolves when the server shuts down.
pub async fn start_server_with_agent_and_custom_endpoints(
    agent: Agent,
    model_name: String,
    address: &str,
    custom_endpoints: Option<CustomEndpointsConfig>,
) -> Result<()> {
    let state = ServerState::with_agent(agent, model_name);

    let app = create_router_with_custom_endpoints(state, custom_endpoints.clone());

    info!(
        "🚀 Starting Helios Engine server with agent on http://{}",
        address
    );
    info!("📡 OpenAI-compatible API endpoints:");
    info!("   POST /v1/chat/completions");
    info!("   GET  /v1/models");

    if let Some(config) = &custom_endpoints {
        info!("📡 Custom endpoints:");
        for endpoint in &config.endpoints {
            info!("   {} {}", endpoint.method.to_uppercase(), endpoint.path);
        }
    }

    let listener = tokio::net::TcpListener::bind(address)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Failed to bind to {}: {}", address, e)))?;

    axum::serve(listener, app)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Server error: {}", e)))?;

    Ok(())
}

/// Loads custom endpoints configuration from a TOML file.
///
/// # Arguments
///
/// * `path` - Path to the custom endpoints configuration file.
///
/// # Returns
///
/// A `Result` containing the custom endpoints configuration.
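///
/// # Example
///
/// A sketch of the expected TOML shape, parsed here from an inline string
/// instead of a file; the endpoint values are purely illustrative.
///
/// ```no_run
/// use helios_engine::serve::CustomEndpointsConfig;
///
/// let toml_src = r#"
///     [[endpoints]]
///     method = "GET"
///     path = "/ping"
///     response = { message = "pong" }
///     status_code = 200
/// "#;
/// let config: CustomEndpointsConfig = toml::from_str(toml_src).unwrap();
/// assert_eq!(config.endpoints.len(), 1);
/// ```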
pub fn load_custom_endpoints_config(path: &str) -> Result<CustomEndpointsConfig> {
    let content = std::fs::read_to_string(path).map_err(|e| {
        HeliosError::ConfigError(format!(
            "Failed to read custom endpoints config file '{}': {}",
            path, e
        ))
    })?;

    toml::from_str(&content).map_err(|e| {
        HeliosError::ConfigError(format!(
            "Failed to parse custom endpoints config file '{}': {}",
            path, e
        ))
    })
}

/// Creates the router with all endpoints.
fn create_router(state: ServerState) -> Router {
    Router::new()
        .route("/v1/chat/completions", post(chat_completions))
        .route("/v1/models", get(list_models))
        .route("/health", get(health_check))
        .layer(CorsLayer::permissive())
        .layer(TraceLayer::new_for_http())
        .with_state(state)
}

/// Creates the router with custom endpoints.
fn create_router_with_custom_endpoints(
    state: ServerState,
    custom_endpoints: Option<CustomEndpointsConfig>,
) -> Router {
    let mut router = Router::new()
        .route("/v1/chat/completions", post(chat_completions))
        .route("/v1/models", get(list_models))
        .route("/health", get(health_check));

    // Add custom endpoints if provided
    if let Some(config) = custom_endpoints {
        for endpoint in config.endpoints {
            let response = endpoint.response.clone();
            let status_code = StatusCode::from_u16(endpoint.status_code).unwrap_or(StatusCode::OK);

            let handler = move || async move { (status_code, Json(response)) };

            match endpoint.method.to_uppercase().as_str() {
                "GET" => router = router.route(&endpoint.path, get(handler)),
                "POST" => router = router.route(&endpoint.path, post(handler)),
                "PUT" => router = router.route(&endpoint.path, put(handler)),
                "DELETE" => router = router.route(&endpoint.path, delete(handler)),
                "PATCH" => router = router.route(&endpoint.path, patch(handler)),
                _ => {
                    // Default to GET for unsupported methods
                    router = router.route(&endpoint.path, get(handler));
                }
            }
        }
    }

    router
        .layer(CorsLayer::permissive())
        .layer(TraceLayer::new_for_http())
        .with_state(state)
}

/// Health check endpoint.
async fn health_check() -> Json<serde_json::Value> {
    Json(serde_json::json!({
        "status": "ok",
        "service": "helios-engine"
    }))
}

/// Lists available models.
async fn list_models(State(state): State<ServerState>) -> Json<ModelsResponse> {
    Json(ModelsResponse {
        object: "list".to_string(),
        data: vec![ModelInfo {
            id: state.model_name.clone(),
            object: "model".to_string(),
            created: chrono::Utc::now().timestamp() as u64,
            owned_by: "helios-engine".to_string(),
        }],
    })
}
/// Handles chat completion requests.
async fn chat_completions(
    State(state): State<ServerState>,
    Json(request): Json<ChatCompletionRequest>,
) -> std::result::Result<impl axum::response::IntoResponse, StatusCode> {
    // Convert OpenAI messages to ChatMessage format
    let messages: Result<Vec<ChatMessage>> = request
        .messages
        .into_iter()
        .map(|msg| {
            // Convert OpenAI message format to internal ChatMessage format
            // Maps standard OpenAI roles to our Role enum
            let role = match msg.role.as_str() {
                "system" => Role::System,       // System instructions/prompts
                "user" => Role::User,           // User input messages
                "assistant" => Role::Assistant, // AI assistant responses
                "tool" => Role::Tool,           // Tool/function call results
                _ => {
                    // Reject invalid roles to maintain API compatibility
                    return Err(HeliosError::ConfigError(format!(
                        "Invalid role: {}",
                        msg.role
                    )));
                }
            };
            Ok(ChatMessage {
                role,
                content: msg.content, // The actual message text
                name: msg.name,       // Optional name for tool messages
                tool_calls: None,     // Not used in conversion (OpenAI format differs)
                tool_call_id: None,   // Not used in conversion (OpenAI format differs)
            })
        })
        .collect();

    let messages = messages.map_err(|e| {
        error!("Failed to convert messages: {}", e);
        StatusCode::BAD_REQUEST
    })?;

    let stream = request.stream.unwrap_or(false);

    if stream {
        // Handle streaming response
        return Ok(stream_chat_completion(
            state,
            messages,
            request.model,
            request.temperature,
            request.max_tokens,
            request.stop.clone(),
        )
        .into_response());
    }

    // Handle non-streaming response
    let completion_id = format!("chatcmpl-{}", Uuid::new_v4());
    let created = chrono::Utc::now().timestamp() as u64;

    // Clone messages for token estimation and LLM client usage
    let messages_clone = messages.clone();

    let response_content = if let Some(agent) = &state.agent {
        // Use agent for response with full conversation history
        let mut agent = agent.write().await;

        match agent
            .chat_with_history(
                messages.clone(),
                request.temperature,
                request.max_tokens,
                request.stop.clone(),
            )
            .await
        {
            Ok(content) => content,
            Err(e) => {
                error!("Agent error: {}", e);
                return Err(StatusCode::INTERNAL_SERVER_ERROR);
            }
        }
    } else if let Some(llm_client) = &state.llm_client {
        // Use LLM client directly
        match llm_client
            .chat(
                messages_clone,
                None,
                request.temperature,
                request.max_tokens,
                request.stop.clone(),
            )
            .await
        {
            Ok(msg) => msg.content,
            Err(e) => {
                error!("LLM error: {}", e);
                return Err(StatusCode::INTERNAL_SERVER_ERROR);
            }
        }
    } else {
        return Err(StatusCode::INTERNAL_SERVER_ERROR);
    };

    // Estimate token usage (simplified - in production, use actual tokenizer)
    let prompt_tokens = estimate_tokens(
        &messages
            .iter()
            .map(|m| m.content.as_str())
            .collect::<Vec<_>>()
            .join(" "),
    );
    let completion_tokens = estimate_tokens(&response_content);

    let response = ChatCompletionResponse {
        id: completion_id,
        object: "chat.completion".to_string(),
        created,
        model: request.model,
        choices: vec![CompletionChoice {
            index: 0,
            message: OpenAIMessageResponse {
                role: "assistant".to_string(),
                content: response_content,
            },
            finish_reason: "stop".to_string(),
        }],
        usage: Usage {
            prompt_tokens,
            completion_tokens,
            total_tokens: prompt_tokens + completion_tokens,
        },
    };

    Ok(Json(response).into_response())
}
/// Streams a chat completion response.
fn stream_chat_completion(
    state: ServerState,
    messages: Vec<ChatMessage>,
    model: String,
    temperature: Option<f32>,
    max_tokens: Option<u32>,
    stop: Option<Vec<String>>,
) -> Sse<impl Stream<Item = std::result::Result<Event, Infallible>>> {
    let (tx, rx) = tokio::sync::mpsc::channel(100);
    let completion_id = format!("chatcmpl-{}", Uuid::new_v4());
    let created = chrono::Utc::now().timestamp() as u64;

    tokio::spawn(async move {
        let on_chunk = |chunk: &str| {
            let event = Event::default()
                .json_data(serde_json::json!({
                    "id": completion_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": model,
                    "choices": [{
                        "index": 0,
                        "delta": {
                            "content": chunk
                        },
                        "finish_reason": null
                    }]
                }))
                .unwrap();
            let _ = tx.try_send(Ok(event));
        };

        if let Some(agent) = &state.agent {
            // Use agent for true streaming response with full conversation history
            let mut agent = agent.write().await;

            match agent
                .chat_stream_with_history(messages, temperature, max_tokens, stop.clone(), on_chunk)
                .await
            {
                Ok(_) => {
                    // Streaming completed successfully
                    // The on_chunk callback has already been called for each token
                }
                Err(e) => {
                    error!("Agent streaming error: {}", e);
                }
            }
        } else if let Some(llm_client) = &state.llm_client {
            // Use LLM client streaming
            match llm_client
                .chat_stream(
                    messages,
                    None,
                    temperature,
                    max_tokens,
                    stop.clone(),
                    on_chunk,
                )
                .await
            {
                Ok(_) => {}
                Err(e) => {
                    error!("LLM streaming error: {}", e);
                }
            }
        };

        // Send final event
        let final_event = Event::default()
            .json_data(serde_json::json!({
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": model,
                "choices": [{
                    "index": 0,
                    "delta": {},
                    "finish_reason": "stop"
                }]
            }))
            .unwrap();
        let _ = tx.send(Ok(final_event)).await;
    });

    Sse::new(ReceiverStream::new(rx)).keep_alive(axum::response::sse::KeepAlive::default())
}
/// Estimates the number of tokens in a text (simplified approximation).
/// In production, use an actual tokenizer.
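///
/// # Example
///
/// "hello world" is 11 characters, and 11 / 4 rounds up to 3 tokens.
///
/// ```
/// use helios_engine::serve::estimate_tokens;
///
/// assert_eq!(estimate_tokens("hello world"), 3);
/// ```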
pub fn estimate_tokens(text: &str) -> u32 {
    // Rough approximation: ~4 characters per token
    (text.len() as f32 / 4.0).ceil() as u32
}