use crate::agent::Agent;
use crate::chat::{ChatMessage, Role};
use crate::config::Config;
use crate::error::{HeliosError, Result};
use crate::llm::{LLMClient, LLMProviderType};
use axum::{
    extract::State,
    http::StatusCode,
    response::{
        sse::{Event, Sse},
        IntoResponse,
    },
    routing::{delete, get, patch, post, put},
    Json, Router,
};
use futures::stream::Stream;
use serde::{Deserialize, Serialize};
use std::convert::Infallible;
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio_stream::wrappers::ReceiverStream;
use tower_http::cors::CorsLayer;
use tower_http::trace::TraceLayer;
use tracing::{error, info};
use uuid::Uuid;

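/// Request body for `POST /v1/chat/completions`, mirroring the OpenAI
/// chat-completions schema (only the fields this server handles).
///
/// A minimal sketch of a request this type accepts, checked with
/// `serde_json` (the snippet is illustrative, not a doctest):
///
/// ```ignore
/// let body = r#"{
///     "model": "local-model",
///     "messages": [{"role": "user", "content": "Hello"}],
///     "stream": false
/// }"#;
/// let request: ChatCompletionRequest = serde_json::from_str(body)?;
/// assert_eq!(request.messages.len(), 1);
/// ```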
#[derive(Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct ChatCompletionRequest {
    /// Model identifier, echoed back in the response.
    pub model: String,
    /// Conversation history, oldest message first.
    pub messages: Vec<OpenAIMessage>,
    /// Sampling temperature.
    #[serde(default)]
    pub temperature: Option<f32>,
    /// Maximum number of tokens to generate.
    #[serde(default)]
    pub max_tokens: Option<u32>,
    /// When `true`, respond with an SSE stream of chunks.
    #[serde(default)]
    pub stream: Option<bool>,
    /// Stop sequences that end generation.
    #[serde(default)]
    pub stop: Option<Vec<String>>,
}

/// A single message in the OpenAI wire format.
#[derive(Debug, Deserialize)]
pub struct OpenAIMessage {
    /// One of `system`, `user`, `assistant`, or `tool`.
    pub role: String,
    pub content: String,
    /// Optional participant name.
    pub name: Option<String>,
}

/// Response body for a non-streaming chat completion.
#[derive(Debug, Serialize)]
pub struct ChatCompletionResponse {
    pub id: String,
    /// Always `"chat.completion"`.
    pub object: String,
    /// Unix timestamp (seconds) of creation.
    pub created: u64,
    pub model: String,
    pub choices: Vec<CompletionChoice>,
    pub usage: Usage,
}

/// One generated completion.
#[derive(Debug, Serialize)]
pub struct CompletionChoice {
    pub index: u32,
    pub message: OpenAIMessageResponse,
    pub finish_reason: String,
}

/// Assistant message returned in a choice.
#[derive(Debug, Serialize)]
pub struct OpenAIMessageResponse {
    pub role: String,
    pub content: String,
}

/// Token accounting for a completion.
#[derive(Debug, Serialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

/// Metadata for a single model, as returned by `GET /v1/models`.
#[derive(Debug, Serialize)]
pub struct ModelInfo {
    pub id: String,
    /// Always `"model"`.
    pub object: String,
    pub created: u64,
    pub owned_by: String,
}

/// Response body for `GET /v1/models`.
#[derive(Debug, Serialize)]
pub struct ModelsResponse {
    /// Always `"list"`.
    pub object: String,
    pub data: Vec<ModelInfo>,
}

/// A user-defined static endpoint served alongside the built-in routes.
#[derive(Debug, Clone, Deserialize)]
pub struct CustomEndpoint {
    /// HTTP method (`GET`, `POST`, `PUT`, `DELETE`, or `PATCH`).
    pub method: String,
    /// Route path, e.g. `/status`.
    pub path: String,
    /// Fixed JSON body returned by the endpoint.
    pub response: serde_json::Value,
    /// HTTP status code of the response (defaults to 200).
    #[serde(default = "default_status_code")]
    pub status_code: u16,
}

fn default_status_code() -> u16 {
    200
}

/// Collection of custom endpoints, usually loaded from a TOML file.
#[derive(Debug, Clone, Deserialize)]
pub struct CustomEndpointsConfig {
    pub endpoints: Vec<CustomEndpoint>,
}

/// Shared state for the HTTP handlers. Exactly one of `llm_client` or
/// `agent` is populated, depending on which constructor built the state.
#[derive(Clone)]
pub struct ServerState {
    pub llm_client: Option<Arc<LLMClient>>,
    pub agent: Option<Arc<RwLock<Agent>>>,
    pub model_name: String,
}

impl ServerState {
    /// Creates state backed by a raw LLM client.
    pub fn with_llm_client(llm_client: LLMClient, model_name: String) -> Self {
        Self {
            llm_client: Some(Arc::new(llm_client)),
            agent: None,
            model_name,
        }
    }

    /// Creates state backed by an [`Agent`].
    pub fn with_agent(agent: Agent, model_name: String) -> Self {
        Self {
            llm_client: None,
            agent: Some(Arc::new(RwLock::new(agent))),
            model_name,
        }
    }
}

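/// Starts an OpenAI-compatible server backed by a plain LLM client built
/// from `config` (local if `config.local` is set, otherwise remote).
///
/// A minimal usage sketch; how the `Config` is constructed is up to the
/// caller and is assumed here, not prescribed:
///
/// ```ignore
/// let config = /* a Config built or loaded elsewhere */;
/// start_server(config, "127.0.0.1:8080").await?;
/// ```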
pub async fn start_server(config: Config, address: &str) -> Result<()> {
    let provider_type = if let Some(local_config) = config.local.clone() {
        LLMProviderType::Local(local_config)
    } else {
        LLMProviderType::Remote(config.llm.clone())
    };

    let llm_client = LLMClient::new(provider_type).await?;
    let model_name = config
        .local
        .as_ref()
        .map(|_| "local-model".to_string())
        .unwrap_or_else(|| config.llm.model_name.clone());

    let state = ServerState::with_llm_client(llm_client, model_name);

    let app = create_router(state);

    info!("🚀 Starting Helios Engine server on http://{}", address);
    info!("📡 OpenAI-compatible API endpoints:");
    info!(" POST /v1/chat/completions");
    info!(" GET /v1/models");

    let listener = tokio::net::TcpListener::bind(address)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Failed to bind to {}: {}", address, e)))?;

    axum::serve(listener, app)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Server error: {}", e)))?;

    Ok(())
}

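/// Starts an OpenAI-compatible server that routes chat completions through
/// an [`Agent`] instead of a bare LLM client.
///
/// A minimal sketch; building the `Agent` is assumed to happen elsewhere:
///
/// ```ignore
/// let agent = /* build an Agent elsewhere */;
/// start_server_with_agent(agent, "my-model".to_string(), "127.0.0.1:8080").await?;
/// ```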
pub async fn start_server_with_agent(
    agent: Agent,
    model_name: String,
    address: &str,
) -> Result<()> {
    let state = ServerState::with_agent(agent, model_name);

    let app = create_router(state);

    info!(
        "🚀 Starting Helios Engine server with agent on http://{}",
        address
    );
    info!("📡 OpenAI-compatible API endpoints:");
    info!(" POST /v1/chat/completions");
    info!(" GET /v1/models");

    let listener = tokio::net::TcpListener::bind(address)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Failed to bind to {}: {}", address, e)))?;

    axum::serve(listener, app)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Server error: {}", e)))?;

    Ok(())
}

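/// Like [`start_server`], but also mounts user-defined static endpoints.
///
/// A minimal sketch (the `Config` and the endpoints file are assumed to
/// exist; see [`load_custom_endpoints_config`] for the file format):
///
/// ```ignore
/// let custom = load_custom_endpoints_config("endpoints.toml")?;
/// start_server_with_custom_endpoints(config, "127.0.0.1:8080", Some(custom)).await?;
/// ```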
pub async fn start_server_with_custom_endpoints(
    config: Config,
    address: &str,
    custom_endpoints: Option<CustomEndpointsConfig>,
) -> Result<()> {
    let provider_type = if let Some(local_config) = config.local.clone() {
        LLMProviderType::Local(local_config)
    } else {
        LLMProviderType::Remote(config.llm.clone())
    };

    let llm_client = LLMClient::new(provider_type).await?;
    let model_name = config
        .local
        .as_ref()
        .map(|_| "local-model".to_string())
        .unwrap_or_else(|| config.llm.model_name.clone());

    let state = ServerState::with_llm_client(llm_client, model_name);

    let app = create_router_with_custom_endpoints(state, custom_endpoints.clone());

    info!("🚀 Starting Helios Engine server on http://{}", address);
    info!("📡 OpenAI-compatible API endpoints:");
    info!(" POST /v1/chat/completions");
    info!(" GET /v1/models");

    if let Some(config) = &custom_endpoints {
        info!("📡 Custom endpoints:");
        for endpoint in &config.endpoints {
            info!(" {} {}", endpoint.method.to_uppercase(), endpoint.path);
        }
    }

    let listener = tokio::net::TcpListener::bind(address)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Failed to bind to {}: {}", address, e)))?;

    axum::serve(listener, app)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Server error: {}", e)))?;

    Ok(())
}

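/// Like [`start_server_with_agent`], but also mounts user-defined static
/// endpoints; see [`start_server_with_custom_endpoints`] for how those
/// endpoints behave.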
pub async fn start_server_with_agent_and_custom_endpoints(
    agent: Agent,
    model_name: String,
    address: &str,
    custom_endpoints: Option<CustomEndpointsConfig>,
) -> Result<()> {
    let state = ServerState::with_agent(agent, model_name);

    let app = create_router_with_custom_endpoints(state, custom_endpoints.clone());

    info!(
        "🚀 Starting Helios Engine server with agent on http://{}",
        address
    );
    info!("📡 OpenAI-compatible API endpoints:");
    info!(" POST /v1/chat/completions");
    info!(" GET /v1/models");

    if let Some(config) = &custom_endpoints {
        info!("📡 Custom endpoints:");
        for endpoint in &config.endpoints {
            info!(" {} {}", endpoint.method.to_uppercase(), endpoint.path);
        }
    }

    let listener = tokio::net::TcpListener::bind(address)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Failed to bind to {}: {}", address, e)))?;

    axum::serve(listener, app)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Server error: {}", e)))?;

    Ok(())
}

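/// Loads a [`CustomEndpointsConfig`] from a TOML file.
///
/// A sketch of the expected file shape (paths and values are illustrative):
///
/// ```ignore
/// # endpoints.toml
/// [[endpoints]]
/// method = "GET"
/// path = "/ping"
/// status_code = 200
/// response = { pong = true }
/// ```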
pub fn load_custom_endpoints_config(path: &str) -> Result<CustomEndpointsConfig> {
    let content = std::fs::read_to_string(path).map_err(|e| {
        HeliosError::ConfigError(format!(
            "Failed to read custom endpoints config file '{}': {}",
            path, e
        ))
    })?;

    toml::from_str(&content).map_err(|e| {
        HeliosError::ConfigError(format!(
            "Failed to parse custom endpoints config file '{}': {}",
            path, e
        ))
    })
}

/// Builds the base router with the OpenAI-compatible routes.
fn create_router(state: ServerState) -> Router {
    Router::new()
        .route("/v1/chat/completions", post(chat_completions))
        .route("/v1/models", get(list_models))
        .route("/health", get(health_check))
        .layer(CorsLayer::permissive())
        .layer(TraceLayer::new_for_http())
        .with_state(state)
}

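/// Builds the base router plus any user-defined static routes.
///
/// Each [`CustomEndpoint`] becomes a route that always returns the same
/// JSON body and status code. For instance, a hypothetical endpoint with
/// `method = "GET"`, `path = "/ping"`, and `response = { pong = true }`
/// would behave like:
///
/// ```text
/// $ curl http://127.0.0.1:8080/ping
/// {"pong":true}
/// ```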
fn create_router_with_custom_endpoints(
    state: ServerState,
    custom_endpoints: Option<CustomEndpointsConfig>,
) -> Router {
    let mut router = Router::new()
        .route("/v1/chat/completions", post(chat_completions))
        .route("/v1/models", get(list_models))
        .route("/health", get(health_check));

    if let Some(config) = custom_endpoints {
        for endpoint in config.endpoints {
            let response = endpoint.response.clone();
            let status_code = StatusCode::from_u16(endpoint.status_code).unwrap_or(StatusCode::OK);

            // Each custom endpoint returns the same canned JSON every time.
            let handler = move || async move { (status_code, Json(response)) };

            match endpoint.method.to_uppercase().as_str() {
                "GET" => router = router.route(&endpoint.path, get(handler)),
                "POST" => router = router.route(&endpoint.path, post(handler)),
                "PUT" => router = router.route(&endpoint.path, put(handler)),
                "DELETE" => router = router.route(&endpoint.path, delete(handler)),
                "PATCH" => router = router.route(&endpoint.path, patch(handler)),
                _ => {
                    // Unrecognized methods fall back to GET.
                    router = router.route(&endpoint.path, get(handler));
                }
            }
        }
    }

    router
        .layer(CorsLayer::permissive())
        .layer(TraceLayer::new_for_http())
        .with_state(state)
}

/// Liveness probe for `GET /health`.
async fn health_check() -> Json<serde_json::Value> {
    Json(serde_json::json!({
        "status": "ok",
        "service": "helios-engine"
    }))
}

/// Handler for `GET /v1/models`: reports the single configured model.
async fn list_models(State(state): State<ServerState>) -> Json<ModelsResponse> {
    Json(ModelsResponse {
        object: "list".to_string(),
        data: vec![ModelInfo {
            id: state.model_name.clone(),
            object: "model".to_string(),
            created: chrono::Utc::now().timestamp() as u64,
            owned_by: "helios-engine".to_string(),
        }],
    })
}

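/// Handler for `POST /v1/chat/completions`.
///
/// Dispatches to the agent when one is configured, otherwise to the raw LLM
/// client, and answers either with a single JSON completion or, when
/// `stream: true` is requested, with an SSE stream of
/// `chat.completion.chunk` events. A non-streaming exchange looks roughly
/// like this (values are illustrative):
///
/// ```text
/// $ curl -s http://127.0.0.1:8080/v1/chat/completions \
///     -H 'Content-Type: application/json' \
///     -d '{"model":"local-model","messages":[{"role":"user","content":"Hi"}]}'
/// {"id":"chatcmpl-...","object":"chat.completion","created":1700000000,...}
/// ```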
async fn chat_completions(
    State(state): State<ServerState>,
    Json(request): Json<ChatCompletionRequest>,
) -> std::result::Result<impl axum::response::IntoResponse, StatusCode> {
    // Convert the OpenAI wire messages into the internal chat format,
    // rejecting any unknown role.
    let messages: Result<Vec<ChatMessage>> = request
        .messages
        .into_iter()
        .map(|msg| {
            let role = match msg.role.as_str() {
                "system" => Role::System,
                "user" => Role::User,
                "assistant" => Role::Assistant,
                "tool" => Role::Tool,
                _ => {
                    return Err(HeliosError::ConfigError(format!(
                        "Invalid role: {}",
                        msg.role
                    )));
                }
            };
            Ok(ChatMessage {
                role,
                content: msg.content,
                name: msg.name,
                tool_calls: None,
                tool_call_id: None,
            })
        })
        .collect();

    let messages = messages.map_err(|e| {
        error!("Failed to convert messages: {}", e);
        StatusCode::BAD_REQUEST
    })?;

    let stream = request.stream.unwrap_or(false);

    if stream {
        return Ok(stream_chat_completion(
            state,
            messages,
            request.model,
            request.temperature,
            request.max_tokens,
            request.stop.clone(),
        )
        .into_response());
    }

    let completion_id = format!("chatcmpl-{}", Uuid::new_v4());
    let created = chrono::Utc::now().timestamp() as u64;

    let messages_clone = messages.clone();

    // Prefer the agent when one is configured; otherwise use the raw client.
    let response_content = if let Some(agent) = &state.agent {
        let mut agent = agent.write().await;

        match agent
            .chat_with_history(
                messages.clone(),
                request.temperature,
                request.max_tokens,
                request.stop.clone(),
            )
            .await
        {
            Ok(content) => content,
            Err(e) => {
                error!("Agent error: {}", e);
                return Err(StatusCode::INTERNAL_SERVER_ERROR);
            }
        }
    } else if let Some(llm_client) = &state.llm_client {
        match llm_client
            .chat(
                messages_clone,
                None,
                request.temperature,
                request.max_tokens,
                request.stop.clone(),
            )
            .await
        {
            Ok(msg) => msg.content,
            Err(e) => {
                error!("LLM error: {}", e);
                return Err(StatusCode::INTERNAL_SERVER_ERROR);
            }
        }
    } else {
        // Neither backend is configured; this should be unreachable.
        return Err(StatusCode::INTERNAL_SERVER_ERROR);
    };

    // Token counts are rough estimates; see `estimate_tokens`.
    let prompt_tokens = estimate_tokens(
        &messages
            .iter()
            .map(|m| m.content.as_str())
            .collect::<Vec<_>>()
            .join(" "),
    );
    let completion_tokens = estimate_tokens(&response_content);

    let response = ChatCompletionResponse {
        id: completion_id,
        object: "chat.completion".to_string(),
        created,
        model: request.model,
        choices: vec![CompletionChoice {
            index: 0,
            message: OpenAIMessageResponse {
                role: "assistant".to_string(),
                content: response_content,
            },
            finish_reason: "stop".to_string(),
        }],
        usage: Usage {
            prompt_tokens,
            completion_tokens,
            total_tokens: prompt_tokens + completion_tokens,
        },
    };

    Ok(Json(response).into_response())
}

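/// Streams a chat completion as server-sent events.
///
/// Each generated chunk is emitted as a `chat.completion.chunk` object,
/// followed by a final chunk carrying `finish_reason: "stop"`, e.g.
/// (payloads abbreviated):
///
/// ```text
/// data: {"object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"Hel"},"finish_reason":null}],...}
/// data: {"object":"chat.completion.chunk","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],...}
/// ```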
fn stream_chat_completion(
    state: ServerState,
    messages: Vec<ChatMessage>,
    model: String,
    temperature: Option<f32>,
    max_tokens: Option<u32>,
    stop: Option<Vec<String>>,
) -> Sse<impl Stream<Item = std::result::Result<Event, Infallible>>> {
    let (tx, rx) = tokio::sync::mpsc::channel(100);
    let completion_id = format!("chatcmpl-{}", Uuid::new_v4());
    let created = chrono::Utc::now().timestamp() as u64;

    tokio::spawn(async move {
        // Forward each generated chunk to the SSE channel as an OpenAI-style
        // `chat.completion.chunk` event.
        let on_chunk = |chunk: &str| {
            let event = Event::default()
                .json_data(serde_json::json!({
                    "id": completion_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": model,
                    "choices": [{
                        "index": 0,
                        "delta": {
                            "content": chunk
                        },
                        "finish_reason": null
                    }]
                }))
                .unwrap();
            let _ = tx.try_send(Ok(event));
        };

        if let Some(agent) = &state.agent {
            let mut agent = agent.write().await;

            if let Err(e) = agent
                .chat_stream_with_history(messages, temperature, max_tokens, stop.clone(), on_chunk)
                .await
            {
                error!("Agent streaming error: {}", e);
            }
        } else if let Some(llm_client) = &state.llm_client {
            if let Err(e) = llm_client
                .chat_stream(messages, None, temperature, max_tokens, stop.clone(), on_chunk)
                .await
            {
                error!("LLM streaming error: {}", e);
            }
        }

        // Always close the stream with a final `finish_reason: "stop"` chunk.
        let final_event = Event::default()
            .json_data(serde_json::json!({
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": model,
                "choices": [{
                    "index": 0,
                    "delta": {},
                    "finish_reason": "stop"
                }]
            }))
            .unwrap();
        let _ = tx.send(Ok(final_event)).await;
    });

    Sse::new(ReceiverStream::new(rx)).keep_alive(axum::response::sse::KeepAlive::default())
}

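/// Roughly estimates the token count of `text`.
///
/// Uses the common ~4 characters per token heuristic for English text; this
/// is an approximation, not a real tokenizer. For example, a 10-character
/// string is counted as `ceil(10 / 4) = 3` tokens.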
pub fn estimate_tokens(text: &str) -> u32 {
    (text.len() as f32 / 4.0).ceil() as u32
}