helios_engine/serve.rs

//! # Serve Module
//!
//! This module serves OpenAI-compatible API endpoints with real-time streaming
//! and generation-parameter control, allowing users to expose their agents or
//! LLM clients over HTTP.
//!
//! ## Usage
//!
//! ### From CLI
//! ```bash
//! helios-engine serve --port 8000
//! ```
//!
//! ### Programmatically
//! ```no_run
//! use helios_engine::{Config, serve};
//!
//! #[tokio::main]
//! async fn main() -> helios_engine::Result<()> {
//!     let config = Config::from_file("config.toml")?;
//!     serve::start_server(config, "127.0.0.1:8000").await?;
//!     Ok(())
//! }
//! ```
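//!
//! ### Calling the API
//!
//! Once the server is running, the chat endpoint accepts standard OpenAI-style
//! requests. An illustrative call (the model name is whatever `/v1/models` reports):
//! ```bash
//! curl http://127.0.0.1:8000/v1/chat/completions \
//!   -H "Content-Type: application/json" \
//!   -d '{"model": "my-model", "messages": [{"role": "user", "content": "Hello"}]}'
//! ```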

use crate::agent::Agent;
use crate::chat::{ChatMessage, Role};
use crate::config::Config;
use crate::error::{HeliosError, Result};
use crate::llm::{LLMClient, LLMProviderType};
use axum::{
    extract::State,
    http::StatusCode,
    response::{
        sse::{Event, Sse},
        IntoResponse,
    },
    routing::{delete, get, patch, post, put},
    Json, Router,
};
use futures::stream::Stream;
use serde::{Deserialize, Serialize};
use std::convert::Infallible;
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio_stream::wrappers::ReceiverStream;
use tower_http::cors::CorsLayer;
use tower_http::trace::TraceLayer;
use tracing::{error, info};
use uuid::Uuid;

/// OpenAI-compatible chat completion request.
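///
/// All generation parameters are optional; unknown roles in `messages` are
/// rejected by the handler. An illustrative request body:
///
/// ```json
/// {
///   "model": "my-model",
///   "messages": [{"role": "user", "content": "Hello"}],
///   "temperature": 0.7,
///   "max_tokens": 256,
///   "stream": false,
///   "stop": ["\n\n"]
/// }
/// ```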
#[derive(Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct ChatCompletionRequest {
    /// The model to use.
    pub model: String,
    /// The messages to send.
    pub messages: Vec<OpenAIMessage>,
    /// The temperature to use.
    #[serde(default)]
    pub temperature: Option<f32>,
    /// The maximum number of tokens to generate.
    #[serde(default)]
    pub max_tokens: Option<u32>,
    /// Whether to stream the response.
    #[serde(default)]
    pub stream: Option<bool>,
    /// Stop sequences.
    #[serde(default)]
    pub stop: Option<Vec<String>>,
}

/// OpenAI-compatible message format.
#[derive(Debug, Deserialize)]
pub struct OpenAIMessage {
    /// The role of the message sender.
    pub role: String,
    /// The content of the message.
    pub content: String,
    /// The name of the message sender (optional).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
}

/// OpenAI-compatible chat completion response.
#[derive(Debug, Serialize)]
pub struct ChatCompletionResponse {
    /// The ID of the completion.
    pub id: String,
    /// The object type.
    pub object: String,
    /// The creation timestamp.
    pub created: u64,
    /// The model used.
    pub model: String,
    /// The choices in the response.
    pub choices: Vec<CompletionChoice>,
    /// Usage statistics.
    pub usage: Usage,
}

/// A choice in a completion response.
#[derive(Debug, Serialize)]
pub struct CompletionChoice {
    /// The index of the choice.
    pub index: u32,
    /// The message in the choice.
    pub message: OpenAIMessageResponse,
    /// The finish reason.
    pub finish_reason: String,
}

/// A message in a completion response.
#[derive(Debug, Serialize)]
pub struct OpenAIMessageResponse {
    /// The role of the message sender.
    pub role: String,
    /// The content of the message.
    pub content: String,
}

/// Usage statistics for a completion.
#[derive(Debug, Serialize)]
pub struct Usage {
    /// The number of prompt tokens.
    pub prompt_tokens: u32,
    /// The number of completion tokens.
    pub completion_tokens: u32,
    /// The total number of tokens.
    pub total_tokens: u32,
}

/// Model information for the models endpoint.
#[derive(Debug, Serialize)]
pub struct ModelInfo {
    /// The ID of the model.
    pub id: String,
    /// The object type.
    pub object: String,
    /// The creation timestamp.
    pub created: u64,
    /// The owner of the model.
    pub owned_by: String,
}

/// Models list response.
#[derive(Debug, Serialize)]
pub struct ModelsResponse {
    /// The object type.
    pub object: String,
    /// The list of models.
    pub data: Vec<ModelInfo>,
}

/// Legacy custom endpoint configuration for backward compatibility.
/// Use the new `endpoint_builder` module for a better API.
#[derive(Debug, Clone, Deserialize)]
pub struct CustomEndpoint {
    /// The HTTP method (GET, POST, PUT, DELETE, PATCH).
    pub method: String,
    /// The endpoint path.
    pub path: String,
    /// The response body as JSON.
    pub response: serde_json::Value,
    /// Optional status code (defaults to 200).
    #[serde(default = "default_status_code")]
    pub status_code: u16,
}

fn default_status_code() -> u16 {
    200
}

/// Legacy custom endpoints configuration for backward compatibility.
/// Use the new `Endpoints` builder for a better API.
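///
/// # Example
///
/// A minimal sketch of building the configuration in code (the
/// `helios_engine::serve` import path is an assumption):
///
/// ```no_run
/// use helios_engine::serve::{CustomEndpoint, CustomEndpointsConfig};
///
/// let endpoints = CustomEndpointsConfig::new().add_endpoint(CustomEndpoint {
///     method: "GET".to_string(),
///     path: "/api/status".to_string(),
///     response: serde_json::json!({"status": "ok"}),
///     status_code: 200,
/// });
/// ```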
#[derive(Debug, Clone, Deserialize)]
pub struct CustomEndpointsConfig {
    /// List of custom endpoints.
    pub endpoints: Vec<CustomEndpoint>,
}

impl CustomEndpointsConfig {
    /// Creates a new empty custom endpoints configuration.
    pub fn new() -> Self {
        Self {
            endpoints: Vec::new(),
        }
    }

    /// Adds a custom endpoint to the configuration.
    pub fn add_endpoint(mut self, endpoint: CustomEndpoint) -> Self {
        self.endpoints.push(endpoint);
        self
    }
}

impl Default for CustomEndpointsConfig {
    fn default() -> Self {
        Self::new()
    }
}

/// Server state containing the LLM client and agent (if any).
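///
/// # Example
///
/// A sketch of wrapping an existing client (the `llm` and `serve` module paths
/// are assumptions based on this crate's internal imports):
///
/// ```no_run
/// # use helios_engine::Config;
/// # use helios_engine::llm::{LLMClient, LLMProviderType};
/// # use helios_engine::serve::ServerState;
/// # async fn example(config: Config) -> helios_engine::Result<()> {
/// let client = LLMClient::new(LLMProviderType::Remote(config.llm.clone())).await?;
/// let state = ServerState::with_llm_client(client, config.llm.model_name.clone());
/// # Ok(())
/// # }
/// ```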
#[derive(Clone)]
pub struct ServerState {
    /// The LLM client for direct LLM calls.
    pub llm_client: Option<Arc<LLMClient>>,
    /// The agent (if serving an agent).
    pub agent: Option<Arc<RwLock<Agent>>>,
    /// The model name being served.
    pub model_name: String,
}

impl ServerState {
    /// Creates a new server state with an LLM client.
    pub fn with_llm_client(llm_client: LLMClient, model_name: String) -> Self {
        Self {
            llm_client: Some(Arc::new(llm_client)),
            agent: None,
            model_name,
        }
    }

    /// Creates a new server state with an agent.
    pub fn with_agent(agent: Agent, model_name: String) -> Self {
        Self {
            llm_client: None,
            agent: Some(Arc::new(RwLock::new(agent))),
            model_name,
        }
    }
}

/// Starts the HTTP server with the given configuration.
///
/// # Arguments
///
/// * `config` - The configuration to use for the LLM client.
/// * `address` - The address to bind to (e.g., "127.0.0.1:8000").
///
/// # Returns
///
/// A `Result` that resolves when the server shuts down.
pub async fn start_server(config: Config, address: &str) -> Result<()> {
    #[cfg(feature = "local")]
    let provider_type = if let Some(local_config) = config.local.clone() {
        LLMProviderType::Local(local_config)
    } else {
        LLMProviderType::Remote(config.llm.clone())
    };

    #[cfg(not(feature = "local"))]
    let provider_type = LLMProviderType::Remote(config.llm.clone());

    let llm_client = LLMClient::new(provider_type).await?;

    #[cfg(feature = "local")]
    let model_name = config
        .local
        .as_ref()
        .map(|_| "local-model".to_string())
        .unwrap_or_else(|| config.llm.model_name.clone());

    #[cfg(not(feature = "local"))]
    let model_name = config.llm.model_name.clone();

    let state = ServerState::with_llm_client(llm_client, model_name);

    let app = create_router(state);

    info!("🚀 Starting Helios Engine server on http://{}", address);
    info!("📡 OpenAI-compatible API endpoints:");
    info!("   POST /v1/chat/completions");
    info!("   GET  /v1/models");

    let listener = tokio::net::TcpListener::bind(address)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Failed to bind to {}: {}", address, e)))?;

    axum::serve(listener, app)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Server error: {}", e)))?;

    Ok(())
}

/// Starts the HTTP server with an agent.
///
/// # Arguments
///
/// * `agent` - The agent to serve.
/// * `model_name` - The model name to expose in the API.
/// * `address` - The address to bind to (e.g., "127.0.0.1:8000").
///
/// # Returns
///
/// A `Result` that resolves when the server shuts down.
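///
/// # Example
///
/// A minimal sketch (assumes an `Agent` has already been constructed elsewhere):
///
/// ```no_run
/// # use helios_engine::{Agent, serve};
/// # async fn example(agent: Agent) -> helios_engine::Result<()> {
/// serve::start_server_with_agent(agent, "my-model".to_string(), "127.0.0.1:8000").await?;
/// # Ok(())
/// # }
/// ```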
pub async fn start_server_with_agent(
    agent: Agent,
    model_name: String,
    address: &str,
) -> Result<()> {
    let state = ServerState::with_agent(agent, model_name);

    let app = create_router(state);

    info!(
        "🚀 Starting Helios Engine server with agent on http://{}",
        address
    );
    info!("📡 OpenAI-compatible API endpoints:");
    info!("   POST /v1/chat/completions");
    info!("   GET  /v1/models");

    let listener = tokio::net::TcpListener::bind(address)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Failed to bind to {}: {}", address, e)))?;

    axum::serve(listener, app)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Server error: {}", e)))?;

    Ok(())
}

/// Starts the HTTP server with custom endpoints.
///
/// # Arguments
///
/// * `config` - The configuration to use for the LLM client.
/// * `address` - The address to bind to (e.g., "127.0.0.1:8000").
/// * `custom_endpoints` - Optional custom endpoints configuration.
///
/// # Returns
///
/// A `Result` that resolves when the server shuts down.
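///
/// # Example
///
/// An illustrative sketch; `"endpoints.toml"` is a hypothetical file path:
///
/// ```no_run
/// # use helios_engine::{Config, serve};
/// # async fn example() -> helios_engine::Result<()> {
/// let config = Config::from_file("config.toml")?;
/// let endpoints = serve::load_custom_endpoints_config("endpoints.toml")?;
/// serve::start_server_with_custom_endpoints(config, "127.0.0.1:8000", Some(endpoints)).await?;
/// # Ok(())
/// # }
/// ```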
pub async fn start_server_with_custom_endpoints(
    config: Config,
    address: &str,
    custom_endpoints: Option<CustomEndpointsConfig>,
) -> Result<()> {
    #[cfg(feature = "local")]
    let provider_type = if let Some(local_config) = config.local.clone() {
        LLMProviderType::Local(local_config)
    } else {
        LLMProviderType::Remote(config.llm.clone())
    };

    #[cfg(not(feature = "local"))]
    let provider_type = LLMProviderType::Remote(config.llm.clone());

    let llm_client = LLMClient::new(provider_type).await?;

    #[cfg(feature = "local")]
    let model_name = config
        .local
        .as_ref()
        .map(|_| "local-model".to_string())
        .unwrap_or_else(|| config.llm.model_name.clone());

    #[cfg(not(feature = "local"))]
    let model_name = config.llm.model_name.clone();

    let state = ServerState::with_llm_client(llm_client, model_name);

    let app = create_router_with_custom_endpoints(state, custom_endpoints.clone());

    info!("🚀 Starting Helios Engine server on http://{}", address);
    info!("📡 OpenAI-compatible API endpoints:");
    info!("   POST /v1/chat/completions");
    info!("   GET  /v1/models");

    if let Some(config) = &custom_endpoints {
        info!("📡 Custom endpoints:");
        for endpoint in &config.endpoints {
            info!("   {} {}", endpoint.method.to_uppercase(), endpoint.path);
        }
    }

    let listener = tokio::net::TcpListener::bind(address)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Failed to bind to {}: {}", address, e)))?;

    axum::serve(listener, app)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Server error: {}", e)))?;

    Ok(())
}

/// Starts the HTTP server with an agent and custom endpoints.
///
/// # Arguments
///
/// * `agent` - The agent to serve.
/// * `model_name` - The model name to expose in the API.
/// * `address` - The address to bind to (e.g., "127.0.0.1:8000").
/// * `custom_endpoints` - Optional custom endpoints configuration.
///
/// # Returns
///
/// A `Result` that resolves when the server shuts down.
pub async fn start_server_with_agent_and_custom_endpoints(
    agent: Agent,
    model_name: String,
    address: &str,
    custom_endpoints: Option<CustomEndpointsConfig>,
) -> Result<()> {
    let state = ServerState::with_agent(agent, model_name);

    let app = create_router_with_custom_endpoints(state, custom_endpoints.clone());

    info!(
        "🚀 Starting Helios Engine server with agent on http://{}",
        address
    );
    info!("📡 OpenAI-compatible API endpoints:");
    info!("   POST /v1/chat/completions");
    info!("   GET  /v1/models");

    if let Some(config) = &custom_endpoints {
        info!("📡 Custom endpoints:");
        for endpoint in &config.endpoints {
            info!("   {} {}", endpoint.method.to_uppercase(), endpoint.path);
        }
    }

    let listener = tokio::net::TcpListener::bind(address)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Failed to bind to {}: {}", address, e)))?;

    axum::serve(listener, app)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Server error: {}", e)))?;

    Ok(())
}

/// Builder for creating a server with an agent and endpoints.
/// This provides a more ergonomic API for server configuration.
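///
/// # Example
///
/// A basic sketch of the builder flow (assumes an `Agent` constructed elsewhere):
///
/// ```no_run
/// # use helios_engine::{Agent, ServerBuilder};
/// # async fn example(agent: Agent) -> helios_engine::Result<()> {
/// ServerBuilder::with_agent(agent, "my-model")
///     .address("0.0.0.0:9000")
///     .serve()
///     .await?;
/// # Ok(())
/// # }
/// ```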
pub struct ServerBuilder {
    agent: Option<Agent>,
    model_name: String,
    address: String,
    endpoints: Vec<crate::endpoint_builder::CustomEndpoint>,
}

impl ServerBuilder {
    /// Creates a new server builder with an agent.
    pub fn with_agent(agent: Agent, model_name: impl Into<String>) -> Self {
        Self {
            agent: Some(agent),
            model_name: model_name.into(),
            address: "127.0.0.1:8000".to_string(),
            endpoints: Vec::new(),
        }
    }

    /// Sets the server address (default: "127.0.0.1:8000").
    pub fn address(mut self, address: impl Into<String>) -> Self {
        self.address = address.into();
        self
    }

    /// Adds a custom endpoint to the server.
    pub fn endpoint(mut self, endpoint: crate::endpoint_builder::CustomEndpoint) -> Self {
        self.endpoints.push(endpoint);
        self
    }

    /// Adds multiple custom endpoints to the server.
    /// This is the preferred way to add multiple endpoints at once.
    ///
    /// # Example
    /// ```no_run
    /// # use helios_engine::{Agent, ServerBuilder, get};
    /// # async fn example(agent: Agent) -> helios_engine::Result<()> {
    /// let endpoints = vec![
    ///     get("/api/v1", serde_json::json!({"version": "1.0"})),
    ///     get("/api/status", serde_json::json!({"status": "ok"})),
    /// ];
    ///
    /// ServerBuilder::with_agent(agent, "model")
    ///     .endpoints(endpoints)
    ///     .serve()
    ///     .await?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn endpoints(mut self, endpoints: Vec<crate::endpoint_builder::CustomEndpoint>) -> Self {
        self.endpoints.extend(endpoints);
        self
    }

    /// Alternative syntax: adds multiple endpoints using a slice.
    /// This allows you to pass endpoints inline with array syntax.
    ///
    /// # Example
    /// ```no_run
    /// # use helios_engine::{Agent, ServerBuilder, get};
    /// # async fn example(agent: Agent) -> helios_engine::Result<()> {
    /// ServerBuilder::with_agent(agent, "model")
    ///     .with_endpoints(&[
    ///         get("/api/v1", serde_json::json!({"version": "1.0"})),
    ///         get("/api/status", serde_json::json!({"status": "ok"})),
    ///     ])
    ///     .serve()
    ///     .await?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn with_endpoints(mut self, endpoints: &[crate::endpoint_builder::CustomEndpoint]) -> Self {
        self.endpoints.extend_from_slice(endpoints);
        self
    }

    /// Starts the server.
    pub async fn serve(self) -> Result<()> {
        let agent = self.agent.expect("Agent must be set");
        let state = ServerState::with_agent(agent, self.model_name.clone());

        let app = create_router_with_new_endpoints(state, self.endpoints);

        info!(
            "🚀 Starting Helios Engine server with agent on http://{}",
            self.address
        );
        info!("📡 OpenAI-compatible API endpoints:");
        info!("   POST /v1/chat/completions");
        info!("   GET  /v1/models");

        let listener = tokio::net::TcpListener::bind(&self.address)
            .await
            .map_err(|e| {
                HeliosError::ConfigError(format!("Failed to bind to {}: {}", self.address, e))
            })?;

        axum::serve(listener, app)
            .await
            .map_err(|e| HeliosError::ConfigError(format!("Server error: {}", e)))?;

        Ok(())
    }
}

/// Loads custom endpoints configuration from a TOML file.
///
/// # Arguments
///
/// * `path` - Path to the custom endpoints configuration file.
///
/// # Returns
///
/// A `Result` containing the custom endpoints configuration.
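///
/// # Example
///
/// An illustrative configuration file (field names follow [`CustomEndpoint`]):
///
/// ```toml
/// [[endpoints]]
/// method = "GET"
/// path = "/api/status"
/// status_code = 200
/// response = { status = "ok" }
/// ```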
pub fn load_custom_endpoints_config(path: &str) -> Result<CustomEndpointsConfig> {
    let content = std::fs::read_to_string(path).map_err(|e| {
        HeliosError::ConfigError(format!(
            "Failed to read custom endpoints config file '{}': {}",
            path, e
        ))
    })?;

    toml::from_str(&content).map_err(|e| {
        HeliosError::ConfigError(format!(
            "Failed to parse custom endpoints config file '{}': {}",
            path, e
        ))
    })
}

/// Creates the router with all endpoints.
fn create_router(state: ServerState) -> Router {
    Router::new()
        .route("/v1/chat/completions", post(chat_completions))
        .route("/v1/models", get(list_models))
        .route("/health", get(health_check))
        .layer(CorsLayer::permissive())
        .layer(TraceLayer::new_for_http())
        .with_state(state)
}

/// Creates the router with custom endpoints.
fn create_router_with_custom_endpoints(
    state: ServerState,
    custom_endpoints: Option<CustomEndpointsConfig>,
) -> Router {
    let mut router = Router::new()
        .route("/v1/chat/completions", post(chat_completions))
        .route("/v1/models", get(list_models))
        .route("/health", get(health_check));

    // Add custom endpoints if provided
    if let Some(config) = custom_endpoints {
        for endpoint in config.endpoints {
            let response = endpoint.response.clone();
            let status_code = StatusCode::from_u16(endpoint.status_code).unwrap_or(StatusCode::OK);

            let handler = move || async move { (status_code, Json(response)) };

            match endpoint.method.to_uppercase().as_str() {
                "GET" => router = router.route(&endpoint.path, get(handler)),
                "POST" => router = router.route(&endpoint.path, post(handler)),
                "PUT" => router = router.route(&endpoint.path, put(handler)),
                "DELETE" => router = router.route(&endpoint.path, delete(handler)),
                "PATCH" => router = router.route(&endpoint.path, patch(handler)),
                _ => {
                    // Default to GET for unsupported methods
                    router = router.route(&endpoint.path, get(handler));
                }
            }
        }
    }

    router
        .layer(CorsLayer::permissive())
        .layer(TraceLayer::new_for_http())
        .with_state(state)
}

/// Creates the router with new-style custom endpoints.
fn create_router_with_new_endpoints(
    state: ServerState,
    endpoints: Vec<crate::endpoint_builder::CustomEndpoint>,
) -> Router {
    use crate::endpoint_builder::HttpMethod;

    let mut router = Router::new()
        .route("/v1/chat/completions", post(chat_completions))
        .route("/v1/models", get(list_models))
        .route("/health", get(health_check));

    // Add new-style custom endpoints
    for endpoint in endpoints {
        let handler_fn = endpoint.handler.clone();

        let handler = move || {
            let handler_fn = handler_fn.clone();
            async move {
                let response = handler_fn(None);
                response.into_response()
            }
        };

        match endpoint.method {
            HttpMethod::Get => router = router.route(&endpoint.path, get(handler)),
            HttpMethod::Post => router = router.route(&endpoint.path, post(handler)),
            HttpMethod::Put => router = router.route(&endpoint.path, put(handler)),
            HttpMethod::Delete => router = router.route(&endpoint.path, delete(handler)),
            HttpMethod::Patch => router = router.route(&endpoint.path, patch(handler)),
        }

        if let Some(desc) = &endpoint.description {
            info!(
                "   {} {} - {}",
                match endpoint.method {
                    HttpMethod::Get => "GET",
                    HttpMethod::Post => "POST",
                    HttpMethod::Put => "PUT",
                    HttpMethod::Delete => "DELETE",
                    HttpMethod::Patch => "PATCH",
                },
                endpoint.path,
                desc
            );
        }
    }

    router
        .layer(CorsLayer::permissive())
        .layer(TraceLayer::new_for_http())
        .with_state(state)
}

/// Health check endpoint.
async fn health_check() -> Json<serde_json::Value> {
    Json(serde_json::json!({
        "status": "ok",
        "service": "helios-engine"
    }))
}

/// Lists available models.
async fn list_models(State(state): State<ServerState>) -> Json<ModelsResponse> {
    Json(ModelsResponse {
        object: "list".to_string(),
        data: vec![ModelInfo {
            id: state.model_name.clone(),
            object: "model".to_string(),
            created: chrono::Utc::now().timestamp() as u64,
            owned_by: "helios-engine".to_string(),
        }],
    })
}

/// Handles chat completion requests.
async fn chat_completions(
    State(state): State<ServerState>,
    Json(request): Json<ChatCompletionRequest>,
) -> std::result::Result<impl axum::response::IntoResponse, StatusCode> {
    // Convert OpenAI messages to ChatMessage format
    let messages: Result<Vec<ChatMessage>> = request
        .messages
        .into_iter()
        .map(|msg| {
            // Convert OpenAI message format to internal ChatMessage format
            // Maps standard OpenAI roles to our Role enum
            let role = match msg.role.as_str() {
                "system" => Role::System,       // System instructions/prompts
                "user" => Role::User,           // User input messages
                "assistant" => Role::Assistant, // AI assistant responses
                "tool" => Role::Tool,           // Tool/function call results
                _ => {
                    // Reject invalid roles to maintain API compatibility
                    return Err(HeliosError::ConfigError(format!(
                        "Invalid role: {}",
                        msg.role
                    )));
                }
            };
            Ok(ChatMessage {
                role,
                content: msg.content, // The actual message text
                name: msg.name,       // Optional name for tool messages
                tool_calls: None,     // Not used in conversion (OpenAI format differs)
                tool_call_id: None,   // Not used in conversion (OpenAI format differs)
            })
        })
        .collect();

    let messages = messages.map_err(|e| {
        error!("Failed to convert messages: {}", e);
        StatusCode::BAD_REQUEST
    })?;

    let stream = request.stream.unwrap_or(false);

    if stream {
        // Handle streaming response
        return Ok(stream_chat_completion(
            state,
            messages,
            request.model,
            request.temperature,
            request.max_tokens,
            request.stop.clone(),
        )
        .into_response());
    }

    // Handle non-streaming response
    let completion_id = format!("chatcmpl-{}", Uuid::new_v4());
    let created = chrono::Utc::now().timestamp() as u64;

    // Clone messages for token estimation and LLM client usage
    let messages_clone = messages.clone();

    let response_content = if let Some(agent) = &state.agent {
        // Use agent for response with full conversation history
        let mut agent = agent.write().await;

        match agent
            .chat_with_history(
                messages.clone(),
                request.temperature,
                request.max_tokens,
                request.stop.clone(),
            )
            .await
        {
            Ok(content) => content,
            Err(e) => {
                error!("Agent error: {}", e);
                return Err(StatusCode::INTERNAL_SERVER_ERROR);
            }
        }
    } else if let Some(llm_client) = &state.llm_client {
        // Use LLM client directly
        match llm_client
            .chat(
                messages_clone,
                None,
                request.temperature,
                request.max_tokens,
                request.stop.clone(),
            )
            .await
        {
            Ok(msg) => msg.content,
            Err(e) => {
                error!("LLM error: {}", e);
                return Err(StatusCode::INTERNAL_SERVER_ERROR);
            }
        }
    } else {
        return Err(StatusCode::INTERNAL_SERVER_ERROR);
    };

    // Estimate token usage (simplified - in production, use actual tokenizer)
    let prompt_tokens = estimate_tokens(
        &messages
            .iter()
            .map(|m| m.content.as_str())
            .collect::<Vec<_>>()
            .join(" "),
    );
    let completion_tokens = estimate_tokens(&response_content);

    let response = ChatCompletionResponse {
        id: completion_id,
        object: "chat.completion".to_string(),
        created,
        model: request.model,
        choices: vec![CompletionChoice {
            index: 0,
            message: OpenAIMessageResponse {
                role: "assistant".to_string(),
                content: response_content,
            },
            finish_reason: "stop".to_string(),
        }],
        usage: Usage {
            prompt_tokens,
            completion_tokens,
            total_tokens: prompt_tokens + completion_tokens,
        },
    };

    Ok(Json(response).into_response())
}

/// Streams a chat completion response.
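///
/// Chunks are emitted as OpenAI-style `chat.completion.chunk` SSE events; an
/// abridged, illustrative wire excerpt:
///
/// ```text
/// data: {"id":"chatcmpl-...","object":"chat.completion.chunk","created":1700000000,"model":"my-model","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}
///
/// data: {"id":"chatcmpl-...","object":"chat.completion.chunk","created":1700000000,"model":"my-model","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}
/// ```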
fn stream_chat_completion(
    state: ServerState,
    messages: Vec<ChatMessage>,
    model: String,
    temperature: Option<f32>,
    max_tokens: Option<u32>,
    stop: Option<Vec<String>>,
) -> Sse<impl Stream<Item = std::result::Result<Event, Infallible>>> {
    let (tx, rx) = tokio::sync::mpsc::channel(100);
    let completion_id = format!("chatcmpl-{}", Uuid::new_v4());
    let created = chrono::Utc::now().timestamp() as u64;

    tokio::spawn(async move {
        let on_chunk = |chunk: &str| {
            let event = Event::default()
                .json_data(serde_json::json!({
                    "id": completion_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": model,
                    "choices": [{
                        "index": 0,
                        "delta": {
                            "content": chunk
                        },
                        "finish_reason": null
                    }]
                }))
                .unwrap();
            let _ = tx.try_send(Ok(event));
        };

        if let Some(agent) = &state.agent {
            // Use agent for true streaming response with full conversation history
            let mut agent = agent.write().await;

            match agent
                .chat_stream_with_history(messages, temperature, max_tokens, stop.clone(), on_chunk)
                .await
            {
                Ok(_) => {
                    // Streaming completed successfully
                    // The on_chunk callback has already been called for each token
                }
                Err(e) => {
                    error!("Agent streaming error: {}", e);
                }
            }
        } else if let Some(llm_client) = &state.llm_client {
            // Use LLM client streaming
            match llm_client
                .chat_stream(
                    messages,
                    None,
                    temperature,
                    max_tokens,
                    stop.clone(),
                    on_chunk,
                )
                .await
            {
                Ok(_) => {}
                Err(e) => {
                    error!("LLM streaming error: {}", e);
                }
            }
        };

        // Send final event
        let final_event = Event::default()
            .json_data(serde_json::json!({
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": model,
                "choices": [{
                    "index": 0,
                    "delta": {},
                    "finish_reason": "stop"
                }]
            }))
            .unwrap();
        let _ = tx.send(Ok(final_event)).await;
    });

    Sse::new(ReceiverStream::new(rx)).keep_alive(axum::response::sse::KeepAlive::default())
}

/// Estimates the number of tokens in a text (simplified approximation).
/// In production, use an actual tokenizer.
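///
/// # Example
///
/// A quick illustration of the 4-characters-per-token heuristic (the
/// `helios_engine::serve` path is an assumption):
///
/// ```no_run
/// # use helios_engine::serve::estimate_tokens;
/// // "Hello, world!" is 13 characters, so ceil(13 / 4) = 4 estimated tokens.
/// assert_eq!(estimate_tokens("Hello, world!"), 4);
/// ```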
pub fn estimate_tokens(text: &str) -> u32 {
    // Rough approximation: ~4 characters per token
    (text.len() as f32 / 4.0).ceil() as u32
}