use crate::agent::Agent;
use crate::chat::{ChatMessage, Role};
use crate::config::Config;
use crate::error::{HeliosError, Result};
use crate::llm::{LLMClient, LLMProviderType};
use axum::{
    extract::State,
    http::StatusCode,
    response::{
        sse::{Event, Sse},
        IntoResponse,
    },
    routing::{delete, get, patch, post, put},
    Json, Router,
};
use futures::stream::Stream;
use serde::{Deserialize, Serialize};
use std::convert::Infallible;
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio_stream::wrappers::ReceiverStream;
use tower_http::cors::CorsLayer;
use tower_http::trace::TraceLayer;
use tracing::{error, info, warn};
use uuid::Uuid;

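/// An OpenAI-compatible `POST /v1/chat/completions` request body.
///
/// Unknown fields are ignored during deserialization. An accepted payload
/// looks like this (values are illustrative):
///
/// ```json
/// {
///     "model": "my-model",
///     "messages": [{ "role": "user", "content": "Hello!" }],
///     "temperature": 0.7,
///     "stream": false
/// }
/// ```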
#[derive(Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct ChatCompletionRequest {
    pub model: String,
    pub messages: Vec<OpenAIMessage>,
    #[serde(default)]
    pub temperature: Option<f32>,
    #[serde(default)]
    pub max_tokens: Option<u32>,
    #[serde(default)]
    pub stream: Option<bool>,
    #[serde(default)]
    pub stop: Option<Vec<String>>,
}

/// A single message in an OpenAI-style request.
#[derive(Debug, Deserialize)]
pub struct OpenAIMessage {
    pub role: String,
    pub content: String,
    #[serde(default)]
    pub name: Option<String>,
}

/// An OpenAI-compatible chat completion response.
#[derive(Debug, Serialize)]
pub struct ChatCompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<CompletionChoice>,
    pub usage: Usage,
}

/// One generated choice within a completion response.
#[derive(Debug, Serialize)]
pub struct CompletionChoice {
    pub index: u32,
    pub message: OpenAIMessageResponse,
    pub finish_reason: String,
}

/// The assistant message returned in a choice.
#[derive(Debug, Serialize)]
pub struct OpenAIMessageResponse {
    pub role: String,
    pub content: String,
}

/// Token usage accounting (estimated; see [`estimate_tokens`]).
#[derive(Debug, Serialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

/// Metadata for a single model, as returned by `GET /v1/models`.
#[derive(Debug, Serialize)]
pub struct ModelInfo {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub owned_by: String,
}

/// The list envelope returned by `GET /v1/models`.
#[derive(Debug, Serialize)]
pub struct ModelsResponse {
    pub object: String,
    pub data: Vec<ModelInfo>,
}

/// A statically configured endpoint: a method/path pair that always returns
/// the given JSON `response` with `status_code`.
#[derive(Debug, Clone, Deserialize)]
pub struct CustomEndpoint {
    pub method: String,
    pub path: String,
    pub response: serde_json::Value,
    #[serde(default = "default_status_code")]
    pub status_code: u16,
}

fn default_status_code() -> u16 {
    200
}

/// A set of custom endpoints, typically loaded from TOML via
/// [`load_custom_endpoints_config`].
#[derive(Debug, Clone, Deserialize)]
pub struct CustomEndpointsConfig {
    pub endpoints: Vec<CustomEndpoint>,
}

/// Shared state for the HTTP handlers. Exactly one of `llm_client` or `agent`
/// is populated, depending on which constructor was used.
#[derive(Clone)]
pub struct ServerState {
    pub llm_client: Option<Arc<LLMClient>>,
    pub agent: Option<Arc<RwLock<Agent>>>,
    pub model_name: String,
}

impl ServerState {
    /// Builds state that serves completions directly from an [`LLMClient`].
    pub fn with_llm_client(llm_client: LLMClient, model_name: String) -> Self {
        Self {
            llm_client: Some(Arc::new(llm_client)),
            agent: None,
            model_name,
        }
    }

    /// Builds state that routes completions through an [`Agent`].
    pub fn with_agent(agent: Agent, model_name: String) -> Self {
        Self {
            llm_client: None,
            agent: Some(Arc::new(RwLock::new(agent))),
            model_name,
        }
    }
}

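/// Starts an OpenAI-compatible server backed directly by an [`LLMClient`].
///
/// With the `local` feature enabled and `config.local` present, a local model
/// is used (reported as `"local-model"`); otherwise the remote provider from
/// `config.llm` is used. The server binds to `address` and exposes
/// `POST /v1/chat/completions`, `GET /v1/models`, and `GET /health`.
///
/// Minimal usage sketch (import paths are illustrative and depend on this
/// crate's actual name and re-exports):
///
/// ```ignore
/// // Sketch only; the `helios_engine` paths below are hypothetical.
/// use helios_engine::{config::Config, server::start_server};
///
/// async fn run(config: Config) -> helios_engine::error::Result<()> {
///     start_server(config, "127.0.0.1:8080").await
/// }
/// ```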
pub async fn start_server(config: Config, address: &str) -> Result<()> {
    #[cfg(feature = "local")]
    let provider_type = if let Some(local_config) = config.local.clone() {
        LLMProviderType::Local(local_config)
    } else {
        LLMProviderType::Remote(config.llm.clone())
    };

    #[cfg(not(feature = "local"))]
    let provider_type = LLMProviderType::Remote(config.llm.clone());

    let llm_client = LLMClient::new(provider_type).await?;

    #[cfg(feature = "local")]
    let model_name = config
        .local
        .as_ref()
        .map(|_| "local-model".to_string())
        .unwrap_or_else(|| config.llm.model_name.clone());

    #[cfg(not(feature = "local"))]
    let model_name = config.llm.model_name.clone();

    let state = ServerState::with_llm_client(llm_client, model_name);

    let app = create_router(state);

    info!("🚀 Starting Helios Engine server on http://{}", address);
    info!("📡 OpenAI-compatible API endpoints:");
    info!(" POST /v1/chat/completions");
    info!(" GET /v1/models");

    let listener = tokio::net::TcpListener::bind(address)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Failed to bind to {}: {}", address, e)))?;

    axum::serve(listener, app)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Server error: {}", e)))?;

    Ok(())
}

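/// Starts the OpenAI-compatible server with an [`Agent`] backend instead of a
/// bare LLM client: each completion request is routed through the agent's
/// `chat_with_history` / `chat_stream_with_history` pipeline. `model_name` is
/// the identifier reported by `GET /v1/models`.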
pub async fn start_server_with_agent(
    agent: Agent,
    model_name: String,
    address: &str,
) -> Result<()> {
    let state = ServerState::with_agent(agent, model_name);

    let app = create_router(state);

    info!(
        "🚀 Starting Helios Engine server with agent on http://{}",
        address
    );
    info!("📡 OpenAI-compatible API endpoints:");
    info!(" POST /v1/chat/completions");
    info!(" GET /v1/models");

    let listener = tokio::net::TcpListener::bind(address)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Failed to bind to {}: {}", address, e)))?;

    axum::serve(listener, app)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Server error: {}", e)))?;

    Ok(())
}

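/// Like [`start_server`], but also mounts the static endpoints described by
/// `custom_endpoints` (see [`load_custom_endpoints_config`]); each one simply
/// returns its configured JSON body and status code.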
pub async fn start_server_with_custom_endpoints(
    config: Config,
    address: &str,
    custom_endpoints: Option<CustomEndpointsConfig>,
) -> Result<()> {
    #[cfg(feature = "local")]
    let provider_type = if let Some(local_config) = config.local.clone() {
        LLMProviderType::Local(local_config)
    } else {
        LLMProviderType::Remote(config.llm.clone())
    };

    #[cfg(not(feature = "local"))]
    let provider_type = LLMProviderType::Remote(config.llm.clone());

    let llm_client = LLMClient::new(provider_type).await?;

    #[cfg(feature = "local")]
    let model_name = config
        .local
        .as_ref()
        .map(|_| "local-model".to_string())
        .unwrap_or_else(|| config.llm.model_name.clone());

    #[cfg(not(feature = "local"))]
    let model_name = config.llm.model_name.clone();

    let state = ServerState::with_llm_client(llm_client, model_name);

    let app = create_router_with_custom_endpoints(state, custom_endpoints.clone());

    info!("🚀 Starting Helios Engine server on http://{}", address);
    info!("📡 OpenAI-compatible API endpoints:");
    info!(" POST /v1/chat/completions");
    info!(" GET /v1/models");

    if let Some(config) = &custom_endpoints {
        info!("📡 Custom endpoints:");
        for endpoint in &config.endpoints {
            info!(" {} {}", endpoint.method.to_uppercase(), endpoint.path);
        }
    }

    let listener = tokio::net::TcpListener::bind(address)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Failed to bind to {}: {}", address, e)))?;

    axum::serve(listener, app)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Server error: {}", e)))?;

    Ok(())
}

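/// Combines [`start_server_with_agent`] and
/// [`start_server_with_custom_endpoints`]: agent-backed completions plus
/// static custom endpoints on one router.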
pub async fn start_server_with_agent_and_custom_endpoints(
    agent: Agent,
    model_name: String,
    address: &str,
    custom_endpoints: Option<CustomEndpointsConfig>,
) -> Result<()> {
    let state = ServerState::with_agent(agent, model_name);

    let app = create_router_with_custom_endpoints(state, custom_endpoints.clone());

    info!(
        "🚀 Starting Helios Engine server with agent on http://{}",
        address
    );
    info!("📡 OpenAI-compatible API endpoints:");
    info!(" POST /v1/chat/completions");
    info!(" GET /v1/models");

    if let Some(config) = &custom_endpoints {
        info!("📡 Custom endpoints:");
        for endpoint in &config.endpoints {
            info!(" {} {}", endpoint.method.to_uppercase(), endpoint.path);
        }
    }

    let listener = tokio::net::TcpListener::bind(address)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Failed to bind to {}: {}", address, e)))?;

    axum::serve(listener, app)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Server error: {}", e)))?;

    Ok(())
}

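/// Loads a [`CustomEndpointsConfig`] from a TOML file at `path`.
///
/// Expected shape (field names follow the serde derives above; the values
/// here are illustrative):
///
/// ```toml
/// [[endpoints]]
/// method = "get"
/// path = "/version"
/// status_code = 200
/// response = { version = "1.0.0" }
/// ```
///
/// `status_code` may be omitted and defaults to `200`.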
pub fn load_custom_endpoints_config(path: &str) -> Result<CustomEndpointsConfig> {
    let content = std::fs::read_to_string(path).map_err(|e| {
        HeliosError::ConfigError(format!(
            "Failed to read custom endpoints config file '{}': {}",
            path, e
        ))
    })?;

    toml::from_str(&content).map_err(|e| {
        HeliosError::ConfigError(format!(
            "Failed to parse custom endpoints config file '{}': {}",
            path, e
        ))
    })
}

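/// Builds the base router: the OpenAI-compatible routes plus `GET /health`,
/// with permissive CORS and HTTP tracing layers.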
fn create_router(state: ServerState) -> Router {
    Router::new()
        .route("/v1/chat/completions", post(chat_completions))
        .route("/v1/models", get(list_models))
        .route("/health", get(health_check))
        .layer(CorsLayer::permissive())
        .layer(TraceLayer::new_for_http())
        .with_state(state)
}

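/// Builds the base router and mounts each configured custom endpoint. The
/// status code and JSON body are captured by value, so a custom endpoint
/// always returns the same static response; an unrecognized HTTP method falls
/// back to `GET` with a warning.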
fn create_router_with_custom_endpoints(
    state: ServerState,
    custom_endpoints: Option<CustomEndpointsConfig>,
) -> Router {
    let mut router = Router::new()
        .route("/v1/chat/completions", post(chat_completions))
        .route("/v1/models", get(list_models))
        .route("/health", get(health_check));

    if let Some(config) = custom_endpoints {
        for endpoint in config.endpoints {
            let response = endpoint.response.clone();
            let status_code = StatusCode::from_u16(endpoint.status_code).unwrap_or(StatusCode::OK);

            let handler = move || async move { (status_code, Json(response)) };

            match endpoint.method.to_uppercase().as_str() {
                "GET" => router = router.route(&endpoint.path, get(handler)),
                "POST" => router = router.route(&endpoint.path, post(handler)),
                "PUT" => router = router.route(&endpoint.path, put(handler)),
                "DELETE" => router = router.route(&endpoint.path, delete(handler)),
                "PATCH" => router = router.route(&endpoint.path, patch(handler)),
                _ => {
                    warn!(
                        "Unsupported method '{}' for custom endpoint '{}'; defaulting to GET",
                        endpoint.method, endpoint.path
                    );
                    router = router.route(&endpoint.path, get(handler));
                }
            }
        }
    }

    router
        .layer(CorsLayer::permissive())
        .layer(TraceLayer::new_for_http())
        .with_state(state)
}

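/// `GET /health`: liveness probe returning a static
/// `{"status": "ok", "service": "helios-engine"}` payload.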
async fn health_check() -> Json<serde_json::Value> {
    Json(serde_json::json!({
        "status": "ok",
        "service": "helios-engine"
    }))
}

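/// `GET /v1/models`: reports the single configured model in OpenAI's list
/// format, with `created` set to the current timestamp.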
async fn list_models(State(state): State<ServerState>) -> Json<ModelsResponse> {
    Json(ModelsResponse {
        object: "list".to_string(),
        data: vec![ModelInfo {
            id: state.model_name.clone(),
            object: "model".to_string(),
            created: chrono::Utc::now().timestamp() as u64,
            owned_by: "helios-engine".to_string(),
        }],
    })
}

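/// `POST /v1/chat/completions`: converts the OpenAI-style request into
/// internal [`ChatMessage`]s, dispatches to the agent (if configured) or the
/// raw LLM client, and returns either a JSON completion or, when
/// `"stream": true`, an SSE stream of `chat.completion.chunk` events. An
/// unknown `role` yields `400 Bad Request`; backend failures yield `500`.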
async fn chat_completions(
    State(state): State<ServerState>,
    Json(request): Json<ChatCompletionRequest>,
) -> std::result::Result<impl axum::response::IntoResponse, StatusCode> {
    let messages: Result<Vec<ChatMessage>> = request
        .messages
        .into_iter()
        .map(|msg| {
            let role = match msg.role.as_str() {
                "system" => Role::System,
                "user" => Role::User,
                "assistant" => Role::Assistant,
                "tool" => Role::Tool,
                _ => {
                    return Err(HeliosError::ConfigError(format!(
                        "Invalid role: {}",
                        msg.role
                    )));
                }
            };
            Ok(ChatMessage {
                role,
                content: msg.content,
                name: msg.name,
                tool_calls: None,
                tool_call_id: None,
            })
        })
        .collect();

    let messages = messages.map_err(|e| {
        error!("Failed to convert messages: {}", e);
        StatusCode::BAD_REQUEST
    })?;

    let stream = request.stream.unwrap_or(false);

    if stream {
        return Ok(stream_chat_completion(
            state,
            messages,
            request.model,
            request.temperature,
            request.max_tokens,
            request.stop.clone(),
        )
        .into_response());
    }

    let completion_id = format!("chatcmpl-{}", Uuid::new_v4());
    let created = chrono::Utc::now().timestamp() as u64;

    let messages_clone = messages.clone();

    let response_content = if let Some(agent) = &state.agent {
        let mut agent = agent.write().await;

        match agent
            .chat_with_history(
                messages.clone(),
                request.temperature,
                request.max_tokens,
                request.stop.clone(),
            )
            .await
        {
            Ok(content) => content,
            Err(e) => {
                error!("Agent error: {}", e);
                return Err(StatusCode::INTERNAL_SERVER_ERROR);
            }
        }
    } else if let Some(llm_client) = &state.llm_client {
        match llm_client
            .chat(
                messages_clone,
                None,
                request.temperature,
                request.max_tokens,
                request.stop.clone(),
            )
            .await
        {
            Ok(msg) => msg.content,
            Err(e) => {
                error!("LLM error: {}", e);
                return Err(StatusCode::INTERNAL_SERVER_ERROR);
            }
        }
    } else {
        return Err(StatusCode::INTERNAL_SERVER_ERROR);
    };

    let prompt_tokens = estimate_tokens(
        &messages
            .iter()
            .map(|m| m.content.as_str())
            .collect::<Vec<_>>()
            .join(" "),
    );
    let completion_tokens = estimate_tokens(&response_content);

    let response = ChatCompletionResponse {
        id: completion_id,
        object: "chat.completion".to_string(),
        created,
        model: request.model,
        choices: vec![CompletionChoice {
            index: 0,
            message: OpenAIMessageResponse {
                role: "assistant".to_string(),
                content: response_content,
            },
            finish_reason: "stop".to_string(),
        }],
        usage: Usage {
            prompt_tokens,
            completion_tokens,
            total_tokens: prompt_tokens + completion_tokens,
        },
    };

    Ok(Json(response).into_response())
}

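/// Streams a completion as Server-Sent Events in OpenAI's chunk format: one
/// `chat.completion.chunk` event per generated piece of text, followed by a
/// final chunk whose `finish_reason` is `"stop"`. An event looks roughly like
/// this (illustrative wire format):
///
/// ```text
/// data: {"id":"chatcmpl-…","object":"chat.completion.chunk","created":1700000000,
///        "model":"my-model","choices":[{"index":0,"delta":{"content":"Hel"},
///        "finish_reason":null}]}
/// ```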
fn stream_chat_completion(
    state: ServerState,
    messages: Vec<ChatMessage>,
    model: String,
    temperature: Option<f32>,
    max_tokens: Option<u32>,
    stop: Option<Vec<String>>,
) -> Sse<impl Stream<Item = std::result::Result<Event, Infallible>>> {
    let (tx, rx) = tokio::sync::mpsc::channel(100);
    let completion_id = format!("chatcmpl-{}", Uuid::new_v4());
    let created = chrono::Utc::now().timestamp() as u64;

    tokio::spawn(async move {
        let on_chunk = |chunk: &str| {
            let event = Event::default()
                .json_data(serde_json::json!({
                    "id": completion_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": model,
                    "choices": [{
                        "index": 0,
                        "delta": {
                            "content": chunk
                        },
                        "finish_reason": null
                    }]
                }))
                .unwrap();
            // Best-effort send: if the channel buffer is full, the chunk is dropped.
            let _ = tx.try_send(Ok(event));
        };

        if let Some(agent) = &state.agent {
            let mut agent = agent.write().await;

            match agent
                .chat_stream_with_history(messages, temperature, max_tokens, stop.clone(), on_chunk)
                .await
            {
                Ok(_) => {}
                Err(e) => {
                    error!("Agent streaming error: {}", e);
                }
            }
        } else if let Some(llm_client) = &state.llm_client {
            match llm_client
                .chat_stream(
                    messages,
                    None,
                    temperature,
                    max_tokens,
                    stop.clone(),
                    on_chunk,
                )
                .await
            {
                Ok(_) => {}
                Err(e) => {
                    error!("LLM streaming error: {}", e);
                }
            }
        }

        let final_event = Event::default()
            .json_data(serde_json::json!({
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": model,
                "choices": [{
                    "index": 0,
                    "delta": {},
                    "finish_reason": "stop"
                }]
            }))
            .unwrap();
        let _ = tx.send(Ok(final_event)).await;
    });

    Sse::new(ReceiverStream::new(rx)).keep_alive(axum::response::sse::KeepAlive::default())
}

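/// Rough token estimate used for the `usage` block: assumes ~4 characters per
/// token (a common rule of thumb for English text) and rounds up. Not
/// suitable for exact accounting.
///
/// ```text
/// estimate_tokens("hello world") == 3   // 11 chars / 4 = 2.75 → ceil → 3
/// ```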
pub fn estimate_tokens(text: &str) -> u32 {
    (text.len() as f32 / 4.0).ceil() as u32
}