use crate::agent::Agent;
use crate::chat::{ChatMessage, Role};
use crate::config::Config;
use crate::error::{HeliosError, Result};
use crate::llm::{LLMClient, LLMProviderType};
use axum::{
    extract::State,
    http::StatusCode,
    response::{
        sse::{Event, Sse},
        IntoResponse,
    },
    routing::{delete, get, patch, post, put},
    Json, Router,
};
use futures::stream::Stream;
use serde::{Deserialize, Serialize};
use std::convert::Infallible;
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio_stream::wrappers::ReceiverStream;
use tower_http::cors::CorsLayer;
use tower_http::trace::TraceLayer;
use tracing::{error, info};
use uuid::Uuid;

/// An OpenAI-compatible `/v1/chat/completions` request body.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct ChatCompletionRequest {
    pub model: String,
    pub messages: Vec<OpenAIMessage>,
    #[serde(default)]
    pub temperature: Option<f32>,
    #[serde(default)]
    pub max_tokens: Option<u32>,
    #[serde(default)]
    pub stream: Option<bool>,
    #[serde(default)]
    pub stop: Option<Vec<String>>,
}

/// A single message in an OpenAI-compatible request.
#[derive(Debug, Deserialize)]
pub struct OpenAIMessage {
    pub role: String,
    pub content: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
}

/// An OpenAI-compatible chat completion response.
#[derive(Debug, Serialize)]
pub struct ChatCompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<CompletionChoice>,
    pub usage: Usage,
}

/// A single completion choice returned to the client.
#[derive(Debug, Serialize)]
pub struct CompletionChoice {
    pub index: u32,
    pub message: OpenAIMessageResponse,
    pub finish_reason: String,
}

/// The assistant message carried inside a completion choice.
#[derive(Debug, Serialize)]
pub struct OpenAIMessageResponse {
    pub role: String,
    pub content: String,
}

/// Token usage statistics (estimated, see [`estimate_tokens`]).
#[derive(Debug, Serialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

/// Metadata for a model listed by `/v1/models`.
#[derive(Debug, Serialize)]
pub struct ModelInfo {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub owned_by: String,
}

/// Response body for the `/v1/models` endpoint.
#[derive(Debug, Serialize)]
pub struct ModelsResponse {
    pub object: String,
    pub data: Vec<ModelInfo>,
}

/// A single user-defined endpoint served alongside the built-in routes.
#[derive(Debug, Clone, Deserialize)]
pub struct CustomEndpoint {
    pub method: String,
    pub path: String,
    pub response: serde_json::Value,
    #[serde(default = "default_status_code")]
    pub status_code: u16,
}

fn default_status_code() -> u16 {
    200
}

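/// A collection of custom endpoints, typically loaded from a TOML file via
/// [`load_custom_endpoints_config`].
///
/// The snippet below is an illustrative sketch of the expected file layout;
/// the file name and values are examples only, not part of this crate:
///
/// ```toml
/// [[endpoints]]
/// method = "GET"
/// path = "/status"
/// status_code = 200
/// response = { ok = true, service = "example" }
///
/// [[endpoints]]
/// method = "POST"
/// path = "/echo"
/// response = { message = "received" }
/// ```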
#[derive(Debug, Clone, Deserialize)]
pub struct CustomEndpointsConfig {
    pub endpoints: Vec<CustomEndpoint>,
}

/// Shared state handed to every request handler.
///
/// Exactly one of `llm_client` or `agent` is expected to be set, depending on
/// which `start_server*` constructor was used.
#[derive(Clone)]
pub struct ServerState {
    pub llm_client: Option<Arc<LLMClient>>,
    pub agent: Option<Arc<RwLock<Agent>>>,
    pub model_name: String,
}

impl ServerState {
    /// Builds a state that answers requests with a plain LLM client.
    pub fn with_llm_client(llm_client: LLMClient, model_name: String) -> Self {
        Self {
            llm_client: Some(Arc::new(llm_client)),
            agent: None,
            model_name,
        }
    }

    /// Builds a state that answers requests through an [`Agent`].
    pub fn with_agent(agent: Agent, model_name: String) -> Self {
        Self {
            llm_client: None,
            agent: Some(Arc::new(RwLock::new(agent))),
            model_name,
        }
    }
}

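/// Starts an OpenAI-compatible HTTP server backed by the LLM provider
/// described in `config` (a local model when `config.local` is set, otherwise
/// the configured remote provider).
///
/// The example below is a minimal sketch and is not compiled as a doc test;
/// it assumes a `Config` value has already been loaded elsewhere:
///
/// ```ignore
/// // `config` is a Config loaded by the caller.
/// start_server(config, "127.0.0.1:8080").await?;
/// ```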
pub async fn start_server(config: Config, address: &str) -> Result<()> {
    let provider_type = if let Some(local_config) = config.local.clone() {
        LLMProviderType::Local(local_config)
    } else {
        LLMProviderType::Remote(config.llm.clone())
    };

    let llm_client = LLMClient::new(provider_type).await?;
    let model_name = config
        .local
        .as_ref()
        .map(|_| "local-model".to_string())
        .unwrap_or_else(|| config.llm.model_name.clone());

    let state = ServerState::with_llm_client(llm_client, model_name);

    let app = create_router(state);

    info!("🚀 Starting Helios Engine server on http://{}", address);
    info!("📡 OpenAI-compatible API endpoints:");
    info!("   POST /v1/chat/completions");
    info!("   GET  /v1/models");

    let listener = tokio::net::TcpListener::bind(address)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Failed to bind to {}: {}", address, e)))?;

    axum::serve(listener, app)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Server error: {}", e)))?;

    Ok(())
}

/// Starts an OpenAI-compatible HTTP server that routes every chat completion
/// through the given [`Agent`] instead of a bare LLM client.
pub async fn start_server_with_agent(
    agent: Agent,
    model_name: String,
    address: &str,
) -> Result<()> {
    let state = ServerState::with_agent(agent, model_name);

    let app = create_router(state);

    info!(
        "🚀 Starting Helios Engine server with agent on http://{}",
        address
    );
    info!("📡 OpenAI-compatible API endpoints:");
    info!("   POST /v1/chat/completions");
    info!("   GET  /v1/models");

    let listener = tokio::net::TcpListener::bind(address)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Failed to bind to {}: {}", address, e)))?;

    axum::serve(listener, app)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Server error: {}", e)))?;

    Ok(())
}

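/// Starts the OpenAI-compatible server and additionally registers the static
/// custom endpoints described by `custom_endpoints`.
///
/// A minimal sketch (not compiled as a doc test); the file name is an
/// example only:
///
/// ```ignore
/// let endpoints = load_custom_endpoints_config("endpoints.toml")?;
/// start_server_with_custom_endpoints(config, "127.0.0.1:8080", Some(endpoints)).await?;
/// ```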
pub async fn start_server_with_custom_endpoints(
    config: Config,
    address: &str,
    custom_endpoints: Option<CustomEndpointsConfig>,
) -> Result<()> {
    let provider_type = if let Some(local_config) = config.local.clone() {
        LLMProviderType::Local(local_config)
    } else {
        LLMProviderType::Remote(config.llm.clone())
    };

    let llm_client = LLMClient::new(provider_type).await?;
    let model_name = config
        .local
        .as_ref()
        .map(|_| "local-model".to_string())
        .unwrap_or_else(|| config.llm.model_name.clone());

    let state = ServerState::with_llm_client(llm_client, model_name);

    let app = create_router_with_custom_endpoints(state, custom_endpoints.clone());

    info!("🚀 Starting Helios Engine server on http://{}", address);
    info!("📡 OpenAI-compatible API endpoints:");
    info!("   POST /v1/chat/completions");
    info!("   GET  /v1/models");

    if let Some(config) = &custom_endpoints {
        info!("📡 Custom endpoints:");
        for endpoint in &config.endpoints {
            info!("   {} {}", endpoint.method.to_uppercase(), endpoint.path);
        }
    }

    let listener = tokio::net::TcpListener::bind(address)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Failed to bind to {}: {}", address, e)))?;

    axum::serve(listener, app)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Server error: {}", e)))?;

    Ok(())
}

/// Starts the agent-backed server and additionally registers the static
/// custom endpoints described by `custom_endpoints`.
pub async fn start_server_with_agent_and_custom_endpoints(
    agent: Agent,
    model_name: String,
    address: &str,
    custom_endpoints: Option<CustomEndpointsConfig>,
) -> Result<()> {
    let state = ServerState::with_agent(agent, model_name);

    let app = create_router_with_custom_endpoints(state, custom_endpoints.clone());

    info!(
        "🚀 Starting Helios Engine server with agent on http://{}",
        address
    );
    info!("📡 OpenAI-compatible API endpoints:");
    info!("   POST /v1/chat/completions");
    info!("   GET  /v1/models");

    if let Some(config) = &custom_endpoints {
        info!("📡 Custom endpoints:");
        for endpoint in &config.endpoints {
            info!("   {} {}", endpoint.method.to_uppercase(), endpoint.path);
        }
    }

    let listener = tokio::net::TcpListener::bind(address)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Failed to bind to {}: {}", address, e)))?;

    axum::serve(listener, app)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Server error: {}", e)))?;

    Ok(())
}

/// Loads a [`CustomEndpointsConfig`] from a TOML file at `path`.
///
/// Returns a [`HeliosError::ConfigError`] if the file cannot be read or parsed.
pub fn load_custom_endpoints_config(path: &str) -> Result<CustomEndpointsConfig> {
    let content = std::fs::read_to_string(path).map_err(|e| {
        HeliosError::ConfigError(format!(
            "Failed to read custom endpoints config file '{}': {}",
            path, e
        ))
    })?;

    toml::from_str(&content).map_err(|e| {
        HeliosError::ConfigError(format!(
            "Failed to parse custom endpoints config file '{}': {}",
            path, e
        ))
    })
}

/// Builds the default router: the OpenAI-compatible routes plus a health check.
fn create_router(state: ServerState) -> Router {
    Router::new()
        .route("/v1/chat/completions", post(chat_completions))
        .route("/v1/models", get(list_models))
        .route("/health", get(health_check))
        .layer(CorsLayer::permissive())
        .layer(TraceLayer::new_for_http())
        .with_state(state)
}

/// Builds the default router and appends any user-defined static endpoints.
fn create_router_with_custom_endpoints(
    state: ServerState,
    custom_endpoints: Option<CustomEndpointsConfig>,
) -> Router {
    let mut router = Router::new()
        .route("/v1/chat/completions", post(chat_completions))
        .route("/v1/models", get(list_models))
        .route("/health", get(health_check));

    if let Some(config) = custom_endpoints {
        for endpoint in config.endpoints {
            let response = endpoint.response.clone();
            let status_code = StatusCode::from_u16(endpoint.status_code).unwrap_or(StatusCode::OK);

            // Each custom endpoint simply returns its configured JSON payload
            // with the configured status code.
            let handler = move || async move { (status_code, Json(response)) };

            match endpoint.method.to_uppercase().as_str() {
                "GET" => router = router.route(&endpoint.path, get(handler)),
                "POST" => router = router.route(&endpoint.path, post(handler)),
                "PUT" => router = router.route(&endpoint.path, put(handler)),
                "DELETE" => router = router.route(&endpoint.path, delete(handler)),
                "PATCH" => router = router.route(&endpoint.path, patch(handler)),
                _ => {
                    // Unknown methods fall back to GET.
                    router = router.route(&endpoint.path, get(handler));
                }
            }
        }
    }

    router
        .layer(CorsLayer::permissive())
        .layer(TraceLayer::new_for_http())
        .with_state(state)
}

async fn health_check() -> Json<serde_json::Value> {
    Json(serde_json::json!({
        "status": "ok",
        "service": "helios-engine"
    }))
}

async fn list_models(State(state): State<ServerState>) -> Json<ModelsResponse> {
    Json(ModelsResponse {
        object: "list".to_string(),
        data: vec![ModelInfo {
            id: state.model_name.clone(),
            object: "model".to_string(),
            created: chrono::Utc::now().timestamp() as u64,
            owned_by: "helios-engine".to_string(),
        }],
    })
}

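/// Handles `POST /v1/chat/completions`, in both blocking and streaming (SSE) modes.
///
/// An illustrative request body (values are examples only); the accepted
/// fields mirror [`ChatCompletionRequest`]:
///
/// ```json
/// {
///   "model": "local-model",
///   "messages": [
///     { "role": "system", "content": "You are a helpful assistant." },
///     { "role": "user", "content": "Hello!" }
///   ],
///   "temperature": 0.7,
///   "max_tokens": 256,
///   "stream": false
/// }
/// ```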
async fn chat_completions(
    State(state): State<ServerState>,
    Json(request): Json<ChatCompletionRequest>,
) -> std::result::Result<impl axum::response::IntoResponse, StatusCode> {
    // Convert the OpenAI-style messages into internal chat messages.
    let messages: Result<Vec<ChatMessage>> = request
        .messages
        .into_iter()
        .map(|msg| {
            let role = match msg.role.as_str() {
                "system" => Role::System,
                "user" => Role::User,
                "assistant" => Role::Assistant,
                "tool" => Role::Tool,
                _ => {
                    return Err(HeliosError::ConfigError(format!(
                        "Invalid role: {}",
                        msg.role
                    )));
                }
            };
            Ok(ChatMessage {
                role,
                content: msg.content,
                name: msg.name,
                tool_calls: None,
                tool_call_id: None,
            })
        })
        .collect();

    let messages = messages.map_err(|e| {
        error!("Failed to convert messages: {}", e);
        StatusCode::BAD_REQUEST
    })?;

    let stream = request.stream.unwrap_or(false);

    if stream {
        return Ok(stream_chat_completion(
            state,
            messages,
            request.model,
            request.temperature,
            request.max_tokens,
            request.stop.clone(),
        )
        .into_response());
    }

    // Non-streaming path: run the full completion, then answer with a single
    // OpenAI-style JSON body.
    let completion_id = format!("chatcmpl-{}", Uuid::new_v4());
    let created = chrono::Utc::now().timestamp() as u64;

    let messages_clone = messages.clone();

    let response_content = if let Some(agent) = &state.agent {
        let mut agent = agent.write().await;

        match agent
            .chat_with_history(
                messages.clone(),
                request.temperature,
                request.max_tokens,
                request.stop.clone(),
            )
            .await
        {
            Ok(content) => content,
            Err(e) => {
                error!("Agent error: {}", e);
                return Err(StatusCode::INTERNAL_SERVER_ERROR);
            }
        }
    } else if let Some(llm_client) = &state.llm_client {
        match llm_client
            .chat(
                messages_clone,
                None,
                request.temperature,
                request.max_tokens,
                request.stop.clone(),
            )
            .await
        {
            Ok(msg) => msg.content,
            Err(e) => {
                error!("LLM error: {}", e);
                return Err(StatusCode::INTERNAL_SERVER_ERROR);
            }
        }
    } else {
        return Err(StatusCode::INTERNAL_SERVER_ERROR);
    };

    let prompt_tokens = estimate_tokens(
        &messages
            .iter()
            .map(|m| m.content.as_str())
            .collect::<Vec<_>>()
            .join(" "),
    );
    let completion_tokens = estimate_tokens(&response_content);

    let response = ChatCompletionResponse {
        id: completion_id,
        object: "chat.completion".to_string(),
        created,
        model: request.model,
        choices: vec![CompletionChoice {
            index: 0,
            message: OpenAIMessageResponse {
                role: "assistant".to_string(),
                content: response_content,
            },
            finish_reason: "stop".to_string(),
        }],
        usage: Usage {
            prompt_tokens,
            completion_tokens,
            total_tokens: prompt_tokens + completion_tokens,
        },
    };

    Ok(Json(response).into_response())
}

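/// Streams a chat completion back to the client as Server-Sent Events.
///
/// Each SSE event carries an OpenAI-style `chat.completion.chunk` payload; the
/// sketch below shows the shape of a delta chunk followed by the final chunk
/// (field values are illustrative):
///
/// ```json
/// {"id":"chatcmpl-...","object":"chat.completion.chunk","created":1700000000,
///  "model":"local-model","choices":[{"index":0,"delta":{"content":"Hel"},"finish_reason":null}]}
///
/// {"id":"chatcmpl-...","object":"chat.completion.chunk","created":1700000000,
///  "model":"local-model","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}
/// ```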
fn stream_chat_completion(
    state: ServerState,
    messages: Vec<ChatMessage>,
    model: String,
    temperature: Option<f32>,
    max_tokens: Option<u32>,
    stop: Option<Vec<String>>,
) -> Sse<impl Stream<Item = std::result::Result<Event, Infallible>>> {
    let (tx, rx) = tokio::sync::mpsc::channel(100);
    let completion_id = format!("chatcmpl-{}", Uuid::new_v4());
    let created = chrono::Utc::now().timestamp() as u64;

    tokio::spawn(async move {
        // Forward each generated chunk to the SSE channel as an OpenAI-style
        // delta event. `try_send` drops a chunk if the channel is full.
        let on_chunk = |chunk: &str| {
            let event = Event::default()
                .json_data(serde_json::json!({
                    "id": completion_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": model,
                    "choices": [{
                        "index": 0,
                        "delta": {
                            "content": chunk
                        },
                        "finish_reason": null
                    }]
                }))
                .unwrap();
            let _ = tx.try_send(Ok(event));
        };

        if let Some(agent) = &state.agent {
            let mut agent = agent.write().await;

            match agent
                .chat_stream_with_history(messages, temperature, max_tokens, stop.clone(), on_chunk)
                .await
            {
                Ok(_) => {}
                Err(e) => {
                    error!("Agent streaming error: {}", e);
                }
            }
        } else if let Some(llm_client) = &state.llm_client {
            match llm_client
                .chat_stream(
                    messages,
                    None,
                    temperature,
                    max_tokens,
                    stop.clone(),
                    on_chunk,
                )
                .await
            {
                Ok(_) => {}
                Err(e) => {
                    error!("LLM streaming error: {}", e);
                }
            }
        }

        // Always terminate the stream with an empty delta and a "stop" finish reason.
        let final_event = Event::default()
            .json_data(serde_json::json!({
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": model,
                "choices": [{
                    "index": 0,
                    "delta": {},
                    "finish_reason": "stop"
                }]
            }))
            .unwrap();
        let _ = tx.send(Ok(final_event)).await;
    });

    Sse::new(ReceiverStream::new(rx)).keep_alive(axum::response::sse::KeepAlive::default())
}

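/// Estimates the number of tokens in `text`.
///
/// This is a rough heuristic, not a real tokenizer: it assumes about four
/// characters per token and rounds up. For example, the 11-character string
/// "hello world" is counted as 3 tokens:
///
/// ```ignore
/// assert_eq!(estimate_tokens("hello world"), 3);
/// ```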
pub fn estimate_tokens(text: &str) -> u32 {
    // Rough heuristic: assume ~4 characters per token and round up.
    (text.len() as f32 / 4.0).ceil() as u32
}