helios_engine/serve.rs

//! # Serve Module
//!
//! This module provides fully OpenAI-compatible API endpoints with real-time
//! streaming, allowing users to expose their agents or LLM clients over HTTP
//! with full control over generation parameters.
//!
//! ## Usage
//!
//! ### From CLI
//! ```bash
//! helios-engine serve --port 8000
//! ```
//!
//! ### Programmatically
//! ```no_run
//! use helios_engine::{Config, serve};
//!
//! #[tokio::main]
//! async fn main() -> helios_engine::Result<()> {
//!     let config = Config::from_file("config.toml")?;
//!     serve::start_server(config, "127.0.0.1:8000").await?;
//!     Ok(())
//! }
//! ```
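//!
//! ### Calling the server
//!
//! Once the server is running, any OpenAI-compatible HTTP client can talk to it.
//! As a rough sketch (assuming the address from the example above; the model name
//! is a placeholder and should match what `GET /v1/models` reports):
//!
//! ```bash
//! curl http://127.0.0.1:8000/v1/chat/completions \
//!   -H "Content-Type: application/json" \
//!   -d '{"model": "local-model", "messages": [{"role": "user", "content": "Hello!"}]}'
//! ```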

use crate::agent::Agent;
use crate::chat::{ChatMessage, Role};
use crate::config::Config;
use crate::error::{HeliosError, Result};
use crate::llm::{LLMClient, LLMProviderType};
use axum::{
    extract::State,
    http::StatusCode,
    response::{
        sse::{Event, Sse},
        IntoResponse,
    },
    routing::{delete, get, patch, post, put},
    Json, Router,
};
use futures::stream::Stream;
use serde::{Deserialize, Serialize};
use std::convert::Infallible;
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio_stream::wrappers::ReceiverStream;
use tower_http::cors::CorsLayer;
use tower_http::trace::TraceLayer;
use tracing::{error, info};
use uuid::Uuid;

/// OpenAI-compatible chat completion request.
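///
/// As a minimal sketch, a JSON body in the OpenAI `/v1/chat/completions` shape
/// (the model name here is a placeholder) deserializes into this struct:
///
/// ```no_run
/// use helios_engine::serve::ChatCompletionRequest;
///
/// let body = r#"{
///     "model": "my-model",
///     "messages": [{"role": "user", "content": "Hello!"}],
///     "temperature": 0.7,
///     "stream": false
/// }"#;
/// let request: ChatCompletionRequest = serde_json::from_str(body).unwrap();
/// ```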
#[derive(Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct ChatCompletionRequest {
    /// The model to use.
    pub model: String,
    /// The messages to send.
    pub messages: Vec<OpenAIMessage>,
    /// The temperature to use.
    #[serde(default)]
    pub temperature: Option<f32>,
    /// The maximum number of tokens to generate.
    #[serde(default)]
    pub max_tokens: Option<u32>,
    /// Whether to stream the response.
    #[serde(default)]
    pub stream: Option<bool>,
    /// Stop sequences.
    #[serde(default)]
    pub stop: Option<Vec<String>>,
}

/// OpenAI-compatible message format.
#[derive(Debug, Deserialize)]
pub struct OpenAIMessage {
    /// The role of the message sender.
    pub role: String,
    /// The content of the message.
    pub content: String,
    /// The name of the message sender (optional).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
}

/// OpenAI-compatible chat completion response.
#[derive(Debug, Serialize)]
pub struct ChatCompletionResponse {
    /// The ID of the completion.
    pub id: String,
    /// The object type.
    pub object: String,
    /// The creation timestamp.
    pub created: u64,
    /// The model used.
    pub model: String,
    /// The choices in the response.
    pub choices: Vec<CompletionChoice>,
    /// Usage statistics.
    pub usage: Usage,
}

/// A choice in a completion response.
#[derive(Debug, Serialize)]
pub struct CompletionChoice {
    /// The index of the choice.
    pub index: u32,
    /// The message in the choice.
    pub message: OpenAIMessageResponse,
    /// The finish reason.
    pub finish_reason: String,
}

/// A message in a completion response.
#[derive(Debug, Serialize)]
pub struct OpenAIMessageResponse {
    /// The role of the message sender.
    pub role: String,
    /// The content of the message.
    pub content: String,
}

/// Usage statistics for a completion.
#[derive(Debug, Serialize)]
pub struct Usage {
    /// The number of prompt tokens.
    pub prompt_tokens: u32,
    /// The number of completion tokens.
    pub completion_tokens: u32,
    /// The total number of tokens.
    pub total_tokens: u32,
}

/// Model information for the models endpoint.
#[derive(Debug, Serialize)]
pub struct ModelInfo {
    /// The ID of the model.
    pub id: String,
    /// The object type.
    pub object: String,
    /// The creation timestamp.
    pub created: u64,
    /// The owner of the model.
    pub owned_by: String,
}

/// Models list response.
#[derive(Debug, Serialize)]
pub struct ModelsResponse {
    /// The object type.
    pub object: String,
    /// The list of models.
    pub data: Vec<ModelInfo>,
}

/// Custom endpoint configuration.
#[derive(Debug, Clone, Deserialize)]
pub struct CustomEndpoint {
    /// The HTTP method (GET, POST, PUT, DELETE, PATCH).
    pub method: String,
    /// The endpoint path.
    pub path: String,
    /// The response body as JSON.
    pub response: serde_json::Value,
    /// Optional status code (defaults to 200).
    #[serde(default = "default_status_code")]
    pub status_code: u16,
}

fn default_status_code() -> u16 {
    200
}

/// Custom endpoints configuration.
#[derive(Debug, Clone, Deserialize)]
pub struct CustomEndpointsConfig {
    /// List of custom endpoints.
    pub endpoints: Vec<CustomEndpoint>,
}

/// Server state containing the LLM client and agent (if any).
#[derive(Clone)]
pub struct ServerState {
    /// The LLM client for direct LLM calls.
    pub llm_client: Option<Arc<LLMClient>>,
    /// The agent (if serving an agent).
    pub agent: Option<Arc<RwLock<Agent>>>,
    /// The model name being served.
    pub model_name: String,
}

impl ServerState {
    /// Creates a new server state with an LLM client.
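    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `LLMClient`, `LLMProviderType`, and the `llm`
    /// section of [`Config`] are publicly reachable in the same way they are used
    /// internally by [`start_server`]:
    ///
    /// ```no_run
    /// use helios_engine::Config;
    /// use helios_engine::llm::{LLMClient, LLMProviderType};
    /// use helios_engine::serve::ServerState;
    ///
    /// # async fn run() -> helios_engine::Result<()> {
    /// let config = Config::from_file("config.toml")?;
    /// let client = LLMClient::new(LLMProviderType::Remote(config.llm.clone())).await?;
    /// let state = ServerState::with_llm_client(client, config.llm.model_name.clone());
    /// # Ok(())
    /// # }
    /// ```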
    pub fn with_llm_client(llm_client: LLMClient, model_name: String) -> Self {
        Self {
            llm_client: Some(Arc::new(llm_client)),
            agent: None,
            model_name,
        }
    }

    /// Creates a new server state with an agent.
    pub fn with_agent(agent: Agent, model_name: String) -> Self {
        Self {
            llm_client: None,
            agent: Some(Arc::new(RwLock::new(agent))),
            model_name,
        }
    }
}

/// Starts the HTTP server with the given configuration.
///
/// # Arguments
///
/// * `config` - The configuration to use for the LLM client.
/// * `address` - The address to bind to (e.g., "127.0.0.1:8000").
///
/// # Returns
///
/// A `Result` that resolves when the server shuts down.
pub async fn start_server(config: Config, address: &str) -> Result<()> {
    #[cfg(feature = "local")]
    let provider_type = if let Some(local_config) = config.local.clone() {
        LLMProviderType::Local(local_config)
    } else {
        LLMProviderType::Remote(config.llm.clone())
    };

    #[cfg(not(feature = "local"))]
    let provider_type = LLMProviderType::Remote(config.llm.clone());

    let llm_client = LLMClient::new(provider_type).await?;

    #[cfg(feature = "local")]
    let model_name = config
        .local
        .as_ref()
        .map(|_| "local-model".to_string())
        .unwrap_or_else(|| config.llm.model_name.clone());

    #[cfg(not(feature = "local"))]
    let model_name = config.llm.model_name.clone();

    let state = ServerState::with_llm_client(llm_client, model_name);

    let app = create_router(state);

    info!("🚀 Starting Helios Engine server on http://{}", address);
    info!("📡 OpenAI-compatible API endpoints:");
    info!("   POST /v1/chat/completions");
    info!("   GET  /v1/models");

    let listener = tokio::net::TcpListener::bind(address)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Failed to bind to {}: {}", address, e)))?;

    axum::serve(listener, app)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Server error: {}", e)))?;

    Ok(())
}

/// Starts the HTTP server with an agent.
///
/// # Arguments
///
/// * `agent` - The agent to serve.
/// * `model_name` - The model name to expose in the API.
/// * `address` - The address to bind to (e.g., "127.0.0.1:8000").
///
/// # Returns
///
/// A `Result` that resolves when the server shuts down.
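///
/// # Example
///
/// A minimal sketch, assuming an [`Agent`] has already been constructed elsewhere
/// (the `helios_engine::agent::Agent` path in the hidden signature is an assumption
/// about how the crate exposes it):
///
/// ```no_run
/// use helios_engine::serve;
///
/// # async fn run(agent: helios_engine::agent::Agent) -> helios_engine::Result<()> {
/// serve::start_server_with_agent(agent, "my-agent".to_string(), "127.0.0.1:8000").await?;
/// # Ok(())
/// # }
/// ```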
pub async fn start_server_with_agent(
    agent: Agent,
    model_name: String,
    address: &str,
) -> Result<()> {
    let state = ServerState::with_agent(agent, model_name);

    let app = create_router(state);

    info!(
        "🚀 Starting Helios Engine server with agent on http://{}",
        address
    );
    info!("📡 OpenAI-compatible API endpoints:");
    info!("   POST /v1/chat/completions");
    info!("   GET  /v1/models");

    let listener = tokio::net::TcpListener::bind(address)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Failed to bind to {}: {}", address, e)))?;

    axum::serve(listener, app)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Server error: {}", e)))?;

    Ok(())
}

/// Starts the HTTP server with custom endpoints.
///
/// # Arguments
///
/// * `config` - The configuration to use for the LLM client.
/// * `address` - The address to bind to (e.g., "127.0.0.1:8000").
/// * `custom_endpoints` - Optional custom endpoints configuration.
///
/// # Returns
///
/// A `Result` that resolves when the server shuts down.
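///
/// # Example
///
/// A minimal sketch wiring a config file together with an optional custom
/// endpoints file (both file names are placeholders):
///
/// ```no_run
/// use helios_engine::{Config, serve};
///
/// # async fn run() -> helios_engine::Result<()> {
/// let config = Config::from_file("config.toml")?;
/// let endpoints = serve::load_custom_endpoints_config("endpoints.toml")?;
/// serve::start_server_with_custom_endpoints(config, "127.0.0.1:8000", Some(endpoints)).await?;
/// # Ok(())
/// # }
/// ```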
pub async fn start_server_with_custom_endpoints(
    config: Config,
    address: &str,
    custom_endpoints: Option<CustomEndpointsConfig>,
) -> Result<()> {
    #[cfg(feature = "local")]
    let provider_type = if let Some(local_config) = config.local.clone() {
        LLMProviderType::Local(local_config)
    } else {
        LLMProviderType::Remote(config.llm.clone())
    };

    #[cfg(not(feature = "local"))]
    let provider_type = LLMProviderType::Remote(config.llm.clone());

    let llm_client = LLMClient::new(provider_type).await?;

    #[cfg(feature = "local")]
    let model_name = config
        .local
        .as_ref()
        .map(|_| "local-model".to_string())
        .unwrap_or_else(|| config.llm.model_name.clone());

    #[cfg(not(feature = "local"))]
    let model_name = config.llm.model_name.clone();

    let state = ServerState::with_llm_client(llm_client, model_name);

    let app = create_router_with_custom_endpoints(state, custom_endpoints.clone());

    info!("🚀 Starting Helios Engine server on http://{}", address);
    info!("📡 OpenAI-compatible API endpoints:");
    info!("   POST /v1/chat/completions");
    info!("   GET  /v1/models");

    if let Some(config) = &custom_endpoints {
        info!("📡 Custom endpoints:");
        for endpoint in &config.endpoints {
            info!("   {} {}", endpoint.method.to_uppercase(), endpoint.path);
        }
    }

    let listener = tokio::net::TcpListener::bind(address)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Failed to bind to {}: {}", address, e)))?;

    axum::serve(listener, app)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Server error: {}", e)))?;

    Ok(())
}

/// Starts the HTTP server with an agent and custom endpoints.
///
/// # Arguments
///
/// * `agent` - The agent to serve.
/// * `model_name` - The model name to expose in the API.
/// * `address` - The address to bind to (e.g., "127.0.0.1:8000").
/// * `custom_endpoints` - Optional custom endpoints configuration.
///
/// # Returns
///
/// A `Result` that resolves when the server shuts down.
pub async fn start_server_with_agent_and_custom_endpoints(
    agent: Agent,
    model_name: String,
    address: &str,
    custom_endpoints: Option<CustomEndpointsConfig>,
) -> Result<()> {
    let state = ServerState::with_agent(agent, model_name);

    let app = create_router_with_custom_endpoints(state, custom_endpoints.clone());

    info!(
        "🚀 Starting Helios Engine server with agent on http://{}",
        address
    );
    info!("📡 OpenAI-compatible API endpoints:");
    info!("   POST /v1/chat/completions");
    info!("   GET  /v1/models");

    if let Some(config) = &custom_endpoints {
        info!("📡 Custom endpoints:");
        for endpoint in &config.endpoints {
            info!("   {} {}", endpoint.method.to_uppercase(), endpoint.path);
        }
    }

    let listener = tokio::net::TcpListener::bind(address)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Failed to bind to {}: {}", address, e)))?;

    axum::serve(listener, app)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Server error: {}", e)))?;

    Ok(())
}

/// Loads custom endpoints configuration from a TOML file.
///
/// # Arguments
///
/// * `path` - Path to the custom endpoints configuration file.
///
/// # Returns
///
/// A `Result` containing the custom endpoints configuration.
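///
/// # File format
///
/// A minimal sketch of the expected TOML shape (field names follow
/// [`CustomEndpoint`]; `status_code` may be omitted and defaults to 200).
/// Parsing the same text directly mirrors what this function does after
/// reading the file:
///
/// ```no_run
/// use helios_engine::serve::CustomEndpointsConfig;
///
/// let toml_src = r#"
///     [[endpoints]]
///     method = "GET"
///     path = "/status"
///     response = { status = "ok" }
/// "#;
/// let config: CustomEndpointsConfig = toml::from_str(toml_src).unwrap();
/// assert_eq!(config.endpoints[0].status_code, 200);
/// ```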
pub fn load_custom_endpoints_config(path: &str) -> Result<CustomEndpointsConfig> {
    let content = std::fs::read_to_string(path).map_err(|e| {
        HeliosError::ConfigError(format!(
            "Failed to read custom endpoints config file '{}': {}",
            path, e
        ))
    })?;

    toml::from_str(&content).map_err(|e| {
        HeliosError::ConfigError(format!(
            "Failed to parse custom endpoints config file '{}': {}",
            path, e
        ))
    })
}

/// Creates the router with all endpoints.
fn create_router(state: ServerState) -> Router {
    Router::new()
        .route("/v1/chat/completions", post(chat_completions))
        .route("/v1/models", get(list_models))
        .route("/health", get(health_check))
        .layer(CorsLayer::permissive())
        .layer(TraceLayer::new_for_http())
        .with_state(state)
}

/// Creates the router with custom endpoints.
fn create_router_with_custom_endpoints(
    state: ServerState,
    custom_endpoints: Option<CustomEndpointsConfig>,
) -> Router {
    let mut router = Router::new()
        .route("/v1/chat/completions", post(chat_completions))
        .route("/v1/models", get(list_models))
        .route("/health", get(health_check));

    // Add custom endpoints if provided
    if let Some(config) = custom_endpoints {
        for endpoint in config.endpoints {
            let response = endpoint.response.clone();
            let status_code = StatusCode::from_u16(endpoint.status_code).unwrap_or(StatusCode::OK);

            let handler = move || async move { (status_code, Json(response)) };

            match endpoint.method.to_uppercase().as_str() {
                "GET" => router = router.route(&endpoint.path, get(handler)),
                "POST" => router = router.route(&endpoint.path, post(handler)),
                "PUT" => router = router.route(&endpoint.path, put(handler)),
                "DELETE" => router = router.route(&endpoint.path, delete(handler)),
                "PATCH" => router = router.route(&endpoint.path, patch(handler)),
                _ => {
                    // Default to GET for unsupported methods
                    router = router.route(&endpoint.path, get(handler));
                }
            }
        }
    }

    router
        .layer(CorsLayer::permissive())
        .layer(TraceLayer::new_for_http())
        .with_state(state)
}

/// Health check endpoint.
async fn health_check() -> Json<serde_json::Value> {
    Json(serde_json::json!({
        "status": "ok",
        "service": "helios-engine"
    }))
}

/// Lists available models.
async fn list_models(State(state): State<ServerState>) -> Json<ModelsResponse> {
    Json(ModelsResponse {
        object: "list".to_string(),
        data: vec![ModelInfo {
            id: state.model_name.clone(),
            object: "model".to_string(),
            created: chrono::Utc::now().timestamp() as u64,
            owned_by: "helios-engine".to_string(),
        }],
    })
}

/// Handles chat completion requests.
async fn chat_completions(
    State(state): State<ServerState>,
    Json(request): Json<ChatCompletionRequest>,
) -> std::result::Result<impl axum::response::IntoResponse, StatusCode> {
    // Convert OpenAI messages to ChatMessage format
    let messages: Result<Vec<ChatMessage>> = request
        .messages
        .into_iter()
        .map(|msg| {
            // Convert OpenAI message format to internal ChatMessage format
            // Maps standard OpenAI roles to our Role enum
            let role = match msg.role.as_str() {
                "system" => Role::System,       // System instructions/prompts
                "user" => Role::User,           // User input messages
                "assistant" => Role::Assistant, // AI assistant responses
                "tool" => Role::Tool,           // Tool/function call results
                _ => {
                    // Reject invalid roles to maintain API compatibility
                    return Err(HeliosError::ConfigError(format!(
                        "Invalid role: {}",
                        msg.role
                    )));
                }
            };
            Ok(ChatMessage {
                role,
                content: msg.content, // The actual message text
                name: msg.name,       // Optional name for tool messages
                tool_calls: None,     // Not used in conversion (OpenAI format differs)
                tool_call_id: None,   // Not used in conversion (OpenAI format differs)
            })
        })
        .collect();

    let messages = messages.map_err(|e| {
        error!("Failed to convert messages: {}", e);
        StatusCode::BAD_REQUEST
    })?;

    let stream = request.stream.unwrap_or(false);

    if stream {
        // Handle streaming response
        return Ok(stream_chat_completion(
            state,
            messages,
            request.model,
            request.temperature,
            request.max_tokens,
            request.stop.clone(),
        )
        .into_response());
    }

    // Handle non-streaming response
    let completion_id = format!("chatcmpl-{}", Uuid::new_v4());
    let created = chrono::Utc::now().timestamp() as u64;

    // Clone messages for token estimation and LLM client usage
    let messages_clone = messages.clone();

    let response_content = if let Some(agent) = &state.agent {
        // Use agent for response with full conversation history
        let mut agent = agent.write().await;

        match agent
            .chat_with_history(
                messages.clone(),
                request.temperature,
                request.max_tokens,
                request.stop.clone(),
            )
            .await
        {
            Ok(content) => content,
            Err(e) => {
                error!("Agent error: {}", e);
                return Err(StatusCode::INTERNAL_SERVER_ERROR);
            }
        }
    } else if let Some(llm_client) = &state.llm_client {
        // Use LLM client directly
        match llm_client
            .chat(
                messages_clone,
                None,
                request.temperature,
                request.max_tokens,
                request.stop.clone(),
            )
            .await
        {
            Ok(msg) => msg.content,
            Err(e) => {
                error!("LLM error: {}", e);
                return Err(StatusCode::INTERNAL_SERVER_ERROR);
            }
        }
    } else {
        return Err(StatusCode::INTERNAL_SERVER_ERROR);
    };

    // Estimate token usage (simplified - in production, use actual tokenizer)
    let prompt_tokens = estimate_tokens(
        &messages
            .iter()
            .map(|m| m.content.as_str())
            .collect::<Vec<_>>()
            .join(" "),
    );
    let completion_tokens = estimate_tokens(&response_content);

    let response = ChatCompletionResponse {
        id: completion_id,
        object: "chat.completion".to_string(),
        created,
        model: request.model,
        choices: vec![CompletionChoice {
            index: 0,
            message: OpenAIMessageResponse {
                role: "assistant".to_string(),
                content: response_content,
            },
            finish_reason: "stop".to_string(),
        }],
        usage: Usage {
            prompt_tokens,
            completion_tokens,
            total_tokens: prompt_tokens + completion_tokens,
        },
    };

    Ok(Json(response).into_response())
}

/// Streams a chat completion response.
fn stream_chat_completion(
    state: ServerState,
    messages: Vec<ChatMessage>,
    model: String,
    temperature: Option<f32>,
    max_tokens: Option<u32>,
    stop: Option<Vec<String>>,
) -> Sse<impl Stream<Item = std::result::Result<Event, Infallible>>> {
    let (tx, rx) = tokio::sync::mpsc::channel(100);
    let completion_id = format!("chatcmpl-{}", Uuid::new_v4());
    let created = chrono::Utc::now().timestamp() as u64;

    tokio::spawn(async move {
        let on_chunk = |chunk: &str| {
            let event = Event::default()
                .json_data(serde_json::json!({
                    "id": completion_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": model,
                    "choices": [{
                        "index": 0,
                        "delta": {
                            "content": chunk
                        },
                        "finish_reason": null
                    }]
                }))
                .unwrap();
            let _ = tx.try_send(Ok(event));
        };

        if let Some(agent) = &state.agent {
            // Use agent for true streaming response with full conversation history
            let mut agent = agent.write().await;

            match agent
                .chat_stream_with_history(messages, temperature, max_tokens, stop.clone(), on_chunk)
                .await
            {
                Ok(_) => {
                    // Streaming completed successfully
                    // The on_chunk callback has already been called for each token
                }
                Err(e) => {
                    error!("Agent streaming error: {}", e);
                }
            }
        } else if let Some(llm_client) = &state.llm_client {
            // Use LLM client streaming
            match llm_client
                .chat_stream(
                    messages,
                    None,
                    temperature,
                    max_tokens,
                    stop.clone(),
                    on_chunk,
                )
                .await
            {
                Ok(_) => {}
                Err(e) => {
                    error!("LLM streaming error: {}", e);
                }
            }
        };

        // Send final event
        let final_event = Event::default()
            .json_data(serde_json::json!({
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": model,
                "choices": [{
                    "index": 0,
                    "delta": {},
                    "finish_reason": "stop"
                }]
            }))
            .unwrap();
        let _ = tx.send(Ok(final_event)).await;
    });

    Sse::new(ReceiverStream::new(rx)).keep_alive(axum::response::sse::KeepAlive::default())
}

/// Estimates the number of tokens in a text (simplified approximation).
/// In production, use an actual tokenizer.
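///
/// # Example
///
/// Worked example of the four-characters-per-token approximation:
///
/// ```
/// use helios_engine::serve::estimate_tokens;
///
/// // "adequate" is 8 characters, so 8 / 4 rounds up to 2 tokens.
/// assert_eq!(estimate_tokens("adequate"), 2);
/// ```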
pub fn estimate_tokens(text: &str) -> u32 {
    // Rough approximation: ~4 characters per token
    (text.len() as f32 / 4.0).ceil() as u32
}