//! # Serve Module
//!
//! This module provides functionality to serve OpenAI-compatible API endpoints,
//! allowing users to expose their agents or LLM clients via HTTP.
//!
//! ## Usage
//!
//! ### From CLI
//! ```bash
//! helios-engine serve --port 8000
//! ```
//!
//! ### Programmatically
//! ```no_run
//! use helios_engine::{Config, serve};
//!
//! #[tokio::main]
//! async fn main() -> helios_engine::Result<()> {
//!     let config = Config::from_file("config.toml")?;
//!     serve::start_server(config, "127.0.0.1:8000").await?;
//!     Ok(())
//! }
//! ```
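//!
//! ### With custom endpoints
//! A minimal sketch of serving extra static routes alongside the OpenAI-compatible
//! ones; the `endpoints.toml` path below is illustrative:
//! ```no_run
//! use helios_engine::{Config, serve};
//!
//! #[tokio::main]
//! async fn main() -> helios_engine::Result<()> {
//!     let config = Config::from_file("config.toml")?;
//!     let endpoints = serve::load_custom_endpoints_config("endpoints.toml")?;
//!     serve::start_server_with_custom_endpoints(config, "127.0.0.1:8000", Some(endpoints))
//!         .await?;
//!     Ok(())
//! }
//! ```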

use crate::agent::Agent;
use crate::chat::{ChatMessage, Role};
use crate::config::Config;
use crate::error::{HeliosError, Result};
use crate::llm::{LLMClient, LLMProviderType};
use axum::{
    extract::State,
    http::StatusCode,
    response::{
        sse::{Event, Sse},
        IntoResponse,
    },
    routing::{delete, get, patch, post, put},
    Json, Router,
};
use futures::stream::Stream;
use serde::{Deserialize, Serialize};
use std::convert::Infallible;
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio_stream::wrappers::ReceiverStream;
use tower_http::cors::CorsLayer;
use tower_http::trace::TraceLayer;
use tracing::{error, info};
use uuid::Uuid;

/// OpenAI-compatible chat completion request.
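///
/// Example request body (field names follow the struct below; values are
/// illustrative only):
///
/// ```json
/// {
///   "model": "local-model",
///   "messages": [{ "role": "user", "content": "Hello!" }],
///   "stream": false
/// }
/// ```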
#[derive(Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct ChatCompletionRequest {
    /// The model to use.
    pub model: String,
    /// The messages to send.
    pub messages: Vec<OpenAIMessage>,
    /// The temperature to use.
    #[serde(default)]
    pub temperature: Option<f32>,
    /// The maximum number of tokens to generate.
    #[serde(default)]
    pub max_tokens: Option<u32>,
    /// Whether to stream the response.
    #[serde(default)]
    pub stream: Option<bool>,
    /// Stop sequences.
    #[serde(default)]
    pub stop: Option<Vec<String>>,
}

/// OpenAI-compatible message format.
#[derive(Debug, Deserialize)]
pub struct OpenAIMessage {
    /// The role of the message sender.
    pub role: String,
    /// The content of the message.
    pub content: String,
    /// The name of the message sender (optional).
    #[serde(default)]
    pub name: Option<String>,
}

/// OpenAI-compatible chat completion response.
#[derive(Debug, Serialize)]
pub struct ChatCompletionResponse {
    /// The ID of the completion.
    pub id: String,
    /// The object type.
    pub object: String,
    /// The creation timestamp.
    pub created: u64,
    /// The model used.
    pub model: String,
    /// The choices in the response.
    pub choices: Vec<CompletionChoice>,
    /// Usage statistics.
    pub usage: Usage,
}

/// A choice in a completion response.
#[derive(Debug, Serialize)]
pub struct CompletionChoice {
    /// The index of the choice.
    pub index: u32,
    /// The message in the choice.
    pub message: OpenAIMessageResponse,
    /// The finish reason.
    pub finish_reason: String,
}

/// A message in a completion response.
#[derive(Debug, Serialize)]
pub struct OpenAIMessageResponse {
    /// The role of the message sender.
    pub role: String,
    /// The content of the message.
    pub content: String,
}

/// Usage statistics for a completion.
#[derive(Debug, Serialize)]
pub struct Usage {
    /// The number of prompt tokens.
    pub prompt_tokens: u32,
    /// The number of completion tokens.
    pub completion_tokens: u32,
    /// The total number of tokens.
    pub total_tokens: u32,
}

/// Model information for the models endpoint.
#[derive(Debug, Serialize)]
pub struct ModelInfo {
    /// The ID of the model.
    pub id: String,
    /// The object type.
    pub object: String,
    /// The creation timestamp.
    pub created: u64,
    /// The owner of the model.
    pub owned_by: String,
}

/// Models list response.
#[derive(Debug, Serialize)]
pub struct ModelsResponse {
    /// The object type.
    pub object: String,
    /// The list of models.
    pub data: Vec<ModelInfo>,
}

/// Custom endpoint configuration.
#[derive(Debug, Clone, Deserialize)]
pub struct CustomEndpoint {
    /// The HTTP method (GET, POST, PUT, DELETE, PATCH).
    pub method: String,
    /// The endpoint path.
    pub path: String,
    /// The response body as JSON.
    pub response: serde_json::Value,
    /// Optional status code (defaults to 200).
    #[serde(default = "default_status_code")]
    pub status_code: u16,
}

fn default_status_code() -> u16 {
    200
}

/// Custom endpoints configuration.
#[derive(Debug, Clone, Deserialize)]
pub struct CustomEndpointsConfig {
    /// List of custom endpoints.
    pub endpoints: Vec<CustomEndpoint>,
}

/// Server state containing the LLM client and agent (if any).
#[derive(Clone)]
pub struct ServerState {
    /// The LLM client for direct LLM calls.
    pub llm_client: Option<Arc<LLMClient>>,
    /// The agent (if serving an agent).
    pub agent: Option<Arc<RwLock<Agent>>>,
    /// The model name being served.
    pub model_name: String,
}

impl ServerState {
    /// Creates a new server state with an LLM client.
    pub fn with_llm_client(llm_client: LLMClient, model_name: String) -> Self {
        Self {
            llm_client: Some(Arc::new(llm_client)),
            agent: None,
            model_name,
        }
    }

    /// Creates a new server state with an agent.
    pub fn with_agent(agent: Agent, model_name: String) -> Self {
        Self {
            llm_client: None,
            agent: Some(Arc::new(RwLock::new(agent))),
            model_name,
        }
    }
}

/// Starts the HTTP server with the given configuration.
///
/// # Arguments
///
/// * `config` - The configuration to use for the LLM client.
/// * `address` - The address to bind to (e.g., "127.0.0.1:8000").
///
/// # Returns
///
/// A `Result` that resolves when the server shuts down.
pub async fn start_server(config: Config, address: &str) -> Result<()> {
    let provider_type = if let Some(local_config) = config.local.clone() {
        LLMProviderType::Local(local_config)
    } else {
        LLMProviderType::Remote(config.llm.clone())
    };

    let llm_client = LLMClient::new(provider_type).await?;
    let model_name = config
        .local
        .as_ref()
        .map(|_| "local-model".to_string())
        .unwrap_or_else(|| config.llm.model_name.clone());

    let state = ServerState::with_llm_client(llm_client, model_name);

    let app = create_router(state);

    info!("🚀 Starting Helios Engine server on http://{}", address);
    info!("📡 OpenAI-compatible API endpoints:");
    info!("   POST /v1/chat/completions");
    info!("   GET  /v1/models");

    let listener = tokio::net::TcpListener::bind(address)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Failed to bind to {}: {}", address, e)))?;

    axum::serve(listener, app)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Server error: {}", e)))?;

    Ok(())
}

/// Starts the HTTP server with an agent.
///
/// # Arguments
///
/// * `agent` - The agent to serve.
/// * `model_name` - The model name to expose in the API.
/// * `address` - The address to bind to (e.g., "127.0.0.1:8000").
///
/// # Returns
///
/// A `Result` that resolves when the server shuts down.
pub async fn start_server_with_agent(
    agent: Agent,
    model_name: String,
    address: &str,
) -> Result<()> {
    let state = ServerState::with_agent(agent, model_name);

    let app = create_router(state);

    info!(
        "🚀 Starting Helios Engine server with agent on http://{}",
        address
    );
    info!("📡 OpenAI-compatible API endpoints:");
    info!("   POST /v1/chat/completions");
    info!("   GET  /v1/models");

    let listener = tokio::net::TcpListener::bind(address)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Failed to bind to {}: {}", address, e)))?;

    axum::serve(listener, app)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Server error: {}", e)))?;

    Ok(())
}

/// Starts the HTTP server with custom endpoints.
///
/// # Arguments
///
/// * `config` - The configuration to use for the LLM client.
/// * `address` - The address to bind to (e.g., "127.0.0.1:8000").
/// * `custom_endpoints` - Optional custom endpoints configuration.
///
/// # Returns
///
/// A `Result` that resolves when the server shuts down.
pub async fn start_server_with_custom_endpoints(
    config: Config,
    address: &str,
    custom_endpoints: Option<CustomEndpointsConfig>,
) -> Result<()> {
    let provider_type = if let Some(local_config) = config.local.clone() {
        LLMProviderType::Local(local_config)
    } else {
        LLMProviderType::Remote(config.llm.clone())
    };

    let llm_client = LLMClient::new(provider_type).await?;
    let model_name = config
        .local
        .as_ref()
        .map(|_| "local-model".to_string())
        .unwrap_or_else(|| config.llm.model_name.clone());

    let state = ServerState::with_llm_client(llm_client, model_name);

    let app = create_router_with_custom_endpoints(state, custom_endpoints.clone());

    info!("🚀 Starting Helios Engine server on http://{}", address);
    info!("📡 OpenAI-compatible API endpoints:");
    info!("   POST /v1/chat/completions");
    info!("   GET  /v1/models");

    if let Some(config) = &custom_endpoints {
        info!("📡 Custom endpoints:");
        for endpoint in &config.endpoints {
            info!("   {} {}", endpoint.method.to_uppercase(), endpoint.path);
        }
    }

    let listener = tokio::net::TcpListener::bind(address)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Failed to bind to {}: {}", address, e)))?;

    axum::serve(listener, app)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Server error: {}", e)))?;

    Ok(())
}

/// Starts the HTTP server with an agent and custom endpoints.
///
/// # Arguments
///
/// * `agent` - The agent to serve.
/// * `model_name` - The model name to expose in the API.
/// * `address` - The address to bind to (e.g., "127.0.0.1:8000").
/// * `custom_endpoints` - Optional custom endpoints configuration.
///
/// # Returns
///
/// A `Result` that resolves when the server shuts down.
pub async fn start_server_with_agent_and_custom_endpoints(
    agent: Agent,
    model_name: String,
    address: &str,
    custom_endpoints: Option<CustomEndpointsConfig>,
) -> Result<()> {
    let state = ServerState::with_agent(agent, model_name);

    let app = create_router_with_custom_endpoints(state, custom_endpoints.clone());

    info!(
        "🚀 Starting Helios Engine server with agent on http://{}",
        address
    );
    info!("📡 OpenAI-compatible API endpoints:");
    info!("   POST /v1/chat/completions");
    info!("   GET  /v1/models");

    if let Some(config) = &custom_endpoints {
        info!("📡 Custom endpoints:");
        for endpoint in &config.endpoints {
            info!("   {} {}", endpoint.method.to_uppercase(), endpoint.path);
        }
    }

    let listener = tokio::net::TcpListener::bind(address)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Failed to bind to {}: {}", address, e)))?;

    axum::serve(listener, app)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Server error: {}", e)))?;

    Ok(())
}

/// Loads custom endpoints configuration from a TOML file.
///
/// # Arguments
///
/// * `path` - Path to the custom endpoints configuration file.
///
/// # Returns
///
/// A `Result` containing the custom endpoints configuration.
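///
/// # File format
///
/// The TOML shape mirrors [`CustomEndpointsConfig`]; the values below are
/// illustrative only:
///
/// ```toml
/// [[endpoints]]
/// method = "GET"
/// path = "/status"
/// status_code = 200
/// response = { status = "ok" }
/// ```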
pub fn load_custom_endpoints_config(path: &str) -> Result<CustomEndpointsConfig> {
    let content = std::fs::read_to_string(path).map_err(|e| {
        HeliosError::ConfigError(format!(
            "Failed to read custom endpoints config file '{}': {}",
            path, e
        ))
    })?;

    toml::from_str(&content).map_err(|e| {
        HeliosError::ConfigError(format!(
            "Failed to parse custom endpoints config file '{}': {}",
            path, e
        ))
    })
}

/// Creates the router with all endpoints.
fn create_router(state: ServerState) -> Router {
    Router::new()
        .route("/v1/chat/completions", post(chat_completions))
        .route("/v1/models", get(list_models))
        .route("/health", get(health_check))
        .layer(CorsLayer::permissive())
        .layer(TraceLayer::new_for_http())
        .with_state(state)
}

/// Creates the router with custom endpoints.
fn create_router_with_custom_endpoints(
    state: ServerState,
    custom_endpoints: Option<CustomEndpointsConfig>,
) -> Router {
    let mut router = Router::new()
        .route("/v1/chat/completions", post(chat_completions))
        .route("/v1/models", get(list_models))
        .route("/health", get(health_check));

    // Add custom endpoints if provided
    if let Some(config) = custom_endpoints {
        for endpoint in config.endpoints {
            let response = endpoint.response.clone();
            let status_code = StatusCode::from_u16(endpoint.status_code).unwrap_or(StatusCode::OK);

            let handler = move || async move { (status_code, Json(response)) };

            match endpoint.method.to_uppercase().as_str() {
                "GET" => router = router.route(&endpoint.path, get(handler)),
                "POST" => router = router.route(&endpoint.path, post(handler)),
                "PUT" => router = router.route(&endpoint.path, put(handler)),
                "DELETE" => router = router.route(&endpoint.path, delete(handler)),
                "PATCH" => router = router.route(&endpoint.path, patch(handler)),
                _ => {
                    // Default to GET for unsupported methods
                    router = router.route(&endpoint.path, get(handler));
                }
            }
        }
    }

    router
        .layer(CorsLayer::permissive())
        .layer(TraceLayer::new_for_http())
        .with_state(state)
}

/// Health check endpoint.
async fn health_check() -> Json<serde_json::Value> {
    Json(serde_json::json!({
        "status": "ok",
        "service": "helios-engine"
    }))
}

/// Lists available models.
async fn list_models(State(state): State<ServerState>) -> Json<ModelsResponse> {
    Json(ModelsResponse {
        object: "list".to_string(),
        data: vec![ModelInfo {
            id: state.model_name.clone(),
            object: "model".to_string(),
            created: chrono::Utc::now().timestamp() as u64,
            owned_by: "helios-engine".to_string(),
        }],
    })
}

/// Handles chat completion requests.
async fn chat_completions(
    State(state): State<ServerState>,
    Json(request): Json<ChatCompletionRequest>,
) -> std::result::Result<impl axum::response::IntoResponse, StatusCode> {
    // Convert OpenAI messages to ChatMessage format
    let messages: Result<Vec<ChatMessage>> = request
        .messages
        .into_iter()
        .map(|msg| {
            // Convert OpenAI message format to internal ChatMessage format
            // Maps standard OpenAI roles to our Role enum
            let role = match msg.role.as_str() {
                "system" => Role::System,       // System instructions/prompts
                "user" => Role::User,           // User input messages
                "assistant" => Role::Assistant, // AI assistant responses
                "tool" => Role::Tool,           // Tool/function call results
                _ => {
                    // Reject invalid roles to maintain API compatibility
                    return Err(HeliosError::ConfigError(format!(
                        "Invalid role: {}",
                        msg.role
                    )));
                }
            };
            Ok(ChatMessage {
                role,
                content: msg.content, // The actual message text
                name: msg.name,       // Optional name for tool messages
                tool_calls: None,     // Not used in conversion (OpenAI format differs)
                tool_call_id: None,   // Not used in conversion (OpenAI format differs)
            })
        })
        .collect();

    let messages = messages.map_err(|e| {
        error!("Failed to convert messages: {}", e);
        StatusCode::BAD_REQUEST
    })?;

    let stream = request.stream.unwrap_or(false);

    if stream {
        // Handle streaming response
        return Ok(stream_chat_completion(
            state,
            messages,
            request.model,
            request.temperature,
            request.max_tokens,
            request.stop.clone(),
        )
        .into_response());
    }

    // Handle non-streaming response
    let completion_id = format!("chatcmpl-{}", Uuid::new_v4());
    let created = chrono::Utc::now().timestamp() as u64;

    // Clone the messages for the LLM client path; the originals are reused below for token estimation
    let messages_clone = messages.clone();

    let response_content = if let Some(agent) = &state.agent {
        // Use agent for response with full conversation history
        let mut agent = agent.write().await;

        match agent
            .chat_with_history(
                messages.clone(),
                request.temperature,
                request.max_tokens,
                request.stop.clone(),
            )
            .await
        {
            Ok(content) => content,
            Err(e) => {
                error!("Agent error: {}", e);
                return Err(StatusCode::INTERNAL_SERVER_ERROR);
            }
        }
    } else if let Some(llm_client) = &state.llm_client {
        // Use LLM client directly
        match llm_client
            .chat(
                messages_clone,
                None,
                request.temperature,
                request.max_tokens,
                request.stop.clone(),
            )
            .await
        {
            Ok(msg) => msg.content,
            Err(e) => {
                error!("LLM error: {}", e);
                return Err(StatusCode::INTERNAL_SERVER_ERROR);
            }
        }
    } else {
        return Err(StatusCode::INTERNAL_SERVER_ERROR);
    };

    // Estimate token usage (simplified - in production, use an actual tokenizer)
    let prompt_tokens = estimate_tokens(
        &messages
            .iter()
            .map(|m| m.content.as_str())
            .collect::<Vec<_>>()
            .join(" "),
    );
    let completion_tokens = estimate_tokens(&response_content);

    let response = ChatCompletionResponse {
        id: completion_id,
        object: "chat.completion".to_string(),
        created,
        model: request.model,
        choices: vec![CompletionChoice {
            index: 0,
            message: OpenAIMessageResponse {
                role: "assistant".to_string(),
                content: response_content,
            },
            finish_reason: "stop".to_string(),
        }],
        usage: Usage {
            prompt_tokens,
            completion_tokens,
            total_tokens: prompt_tokens + completion_tokens,
        },
    };

    Ok(Json(response).into_response())
}

/// Streams a chat completion response.
fn stream_chat_completion(
    state: ServerState,
    messages: Vec<ChatMessage>,
    model: String,
    temperature: Option<f32>,
    max_tokens: Option<u32>,
    stop: Option<Vec<String>>,
) -> Sse<impl Stream<Item = std::result::Result<Event, Infallible>>> {
    let (tx, rx) = tokio::sync::mpsc::channel(100);
    let completion_id = format!("chatcmpl-{}", Uuid::new_v4());
    let created = chrono::Utc::now().timestamp() as u64;

    tokio::spawn(async move {
        let on_chunk = |chunk: &str| {
            let event = Event::default()
                .json_data(serde_json::json!({
                    "id": completion_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": model,
                    "choices": [{
                        "index": 0,
                        "delta": {
                            "content": chunk
                        },
                        "finish_reason": null
                    }]
                }))
                .unwrap();
            let _ = tx.try_send(Ok(event));
        };

        if let Some(agent) = &state.agent {
            // Use agent for true streaming response with full conversation history
            let mut agent = agent.write().await;

            match agent
                .chat_stream_with_history(messages, temperature, max_tokens, stop.clone(), on_chunk)
                .await
            {
                Ok(_) => {
                    // Streaming completed successfully
                    // The on_chunk callback has already been called for each token
                }
                Err(e) => {
                    error!("Agent streaming error: {}", e);
                }
            }
        } else if let Some(llm_client) = &state.llm_client {
            // Use LLM client streaming
            match llm_client
                .chat_stream(
                    messages,
                    None,
                    temperature,
                    max_tokens,
                    stop.clone(),
                    on_chunk,
                )
                .await
            {
                Ok(_) => {}
                Err(e) => {
                    error!("LLM streaming error: {}", e);
                }
            }
        };

        // Send final event
        let final_event = Event::default()
            .json_data(serde_json::json!({
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": model,
                "choices": [{
                    "index": 0,
                    "delta": {},
                    "finish_reason": "stop"
                }]
            }))
            .unwrap();
        let _ = tx.send(Ok(final_event)).await;
    });

    Sse::new(ReceiverStream::new(rx)).keep_alive(axum::response::sse::KeepAlive::default())
}

/// Estimates the number of tokens in a text (simplified approximation).
/// In production, use an actual tokenizer.
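///
/// ```
/// // Illustrative check of the ~4-characters-per-token heuristic:
/// // "Hello Helios" is 12 characters, so it counts as 3 tokens.
/// assert_eq!(helios_engine::serve::estimate_tokens("Hello Helios"), 3);
/// ```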
pub fn estimate_tokens(text: &str) -> u32 {
    // Rough approximation: ~4 characters per token
    (text.len() as f32 / 4.0).ceil() as u32
}