use crate::agent::Agent;
use crate::chat::{ChatMessage, Role};
use crate::config::Config;
use crate::error::{HeliosError, Result};
use crate::llm::{LLMClient, LLMProviderType};
use axum::{
    extract::State,
    http::StatusCode,
    response::{
        sse::{Event, Sse},
        IntoResponse,
    },
    routing::{delete, get, patch, post, put},
    Json, Router,
};
use futures::stream::Stream;
use serde::{Deserialize, Serialize};
use std::convert::Infallible;
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio_stream::wrappers::ReceiverStream;
use tower_http::cors::CorsLayer;
use tower_http::trace::TraceLayer;
use tracing::{error, info};
use uuid::Uuid;

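/// Request body for the OpenAI-compatible `POST /v1/chat/completions` route.
///
/// Only `model` and `messages` are required; the remaining fields default to
/// `None`. A minimal request body, as a sketch:
///
/// ```json
/// {
///     "model": "my-model",
///     "messages": [{ "role": "user", "content": "Hello" }],
///     "stream": false
/// }
/// ```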
#[derive(Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct ChatCompletionRequest {
    pub model: String,
    pub messages: Vec<OpenAIMessage>,
    #[serde(default)]
    pub temperature: Option<f32>,
    #[serde(default)]
    pub max_tokens: Option<u32>,
    #[serde(default)]
    pub stream: Option<bool>,
    #[serde(default)]
    pub stop: Option<Vec<String>>,
}

#[derive(Debug, Deserialize)]
pub struct OpenAIMessage {
    pub role: String,
    pub content: String,
    #[serde(default)]
    pub name: Option<String>,
}

#[derive(Debug, Serialize)]
pub struct ChatCompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<CompletionChoice>,
    pub usage: Usage,
}

#[derive(Debug, Serialize)]
pub struct CompletionChoice {
    pub index: u32,
    pub message: OpenAIMessageResponse,
    pub finish_reason: String,
}

#[derive(Debug, Serialize)]
pub struct OpenAIMessageResponse {
    pub role: String,
    pub content: String,
}

#[derive(Debug, Serialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

#[derive(Debug, Serialize)]
pub struct ModelInfo {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub owned_by: String,
}

#[derive(Debug, Serialize)]
pub struct ModelsResponse {
    pub object: String,
    pub data: Vec<ModelInfo>,
}

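/// A declarative custom endpoint: an HTTP method, a path, a canned JSON
/// response, and a status code (defaults to 200 via `default_status_code`).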
#[derive(Debug, Clone, Deserialize)]
pub struct CustomEndpoint {
    pub method: String,
    pub path: String,
    pub response: serde_json::Value,
    #[serde(default = "default_status_code")]
    pub status_code: u16,
}

fn default_status_code() -> u16 {
    200
}

#[derive(Debug, Clone, Deserialize)]
pub struct CustomEndpointsConfig {
    pub endpoints: Vec<CustomEndpoint>,
}

impl CustomEndpointsConfig {
    pub fn new() -> Self {
        Self {
            endpoints: Vec::new(),
        }
    }

    pub fn add_endpoint(mut self, endpoint: CustomEndpoint) -> Self {
        self.endpoints.push(endpoint);
        self
    }
}

impl Default for CustomEndpointsConfig {
    fn default() -> Self {
        Self::new()
    }
}

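/// Shared state for all request handlers: either a bare [`LLMClient`] or a
/// tool-capable [`Agent`] (behind an `RwLock`, since agent calls need
/// mutable access), plus the model name reported by `/v1/models`.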
#[derive(Clone)]
pub struct ServerState {
    pub llm_client: Option<Arc<LLMClient>>,
    pub agent: Option<Arc<RwLock<Agent>>>,
    pub model_name: String,
}

impl ServerState {
    pub fn with_llm_client(llm_client: LLMClient, model_name: String) -> Self {
        Self {
            llm_client: Some(Arc::new(llm_client)),
            agent: None,
            model_name,
        }
    }

    pub fn with_agent(agent: Agent, model_name: String) -> Self {
        Self {
            llm_client: None,
            agent: Some(Arc::new(RwLock::new(agent))),
            model_name,
        }
    }
}

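/// Starts an OpenAI-compatible server backed by an [`LLMClient`] built from
/// `config`, binding to `address` and serving until the process exits.
///
/// A usage sketch (assumes some way of obtaining a `Config`; the loader
/// shown here is illustrative, not part of this module):
///
/// ```ignore
/// let config = Config::load("config.toml")?;
/// start_server(config, "127.0.0.1:8000").await?;
/// ```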
pub async fn start_server(config: Config, address: &str) -> Result<()> {
    #[cfg(feature = "local")]
    let provider_type = if let Some(local_config) = config.local.clone() {
        LLMProviderType::Local(local_config)
    } else {
        LLMProviderType::Remote(config.llm.clone())
    };

    #[cfg(not(feature = "local"))]
    let provider_type = LLMProviderType::Remote(config.llm.clone());

    let llm_client = LLMClient::new(provider_type).await?;

    #[cfg(feature = "local")]
    let model_name = config
        .local
        .as_ref()
        .map(|_| "local-model".to_string())
        .unwrap_or_else(|| config.llm.model_name.clone());

    #[cfg(not(feature = "local"))]
    let model_name = config.llm.model_name.clone();

    let state = ServerState::with_llm_client(llm_client, model_name);

    let app = create_router(state);

    info!("🚀 Starting Helios Engine server on http://{}", address);
    info!("📡 OpenAI-compatible API endpoints:");
    info!("   POST /v1/chat/completions");
    info!("   GET  /v1/models");

    let listener = tokio::net::TcpListener::bind(address)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Failed to bind to {}: {}", address, e)))?;

    axum::serve(listener, app)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Server error: {}", e)))?;

    Ok(())
}

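/// Starts the server with a tool-capable [`Agent`] answering chat requests
/// instead of a bare LLM client. `model_name` is what `/v1/models` reports.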
pub async fn start_server_with_agent(
    agent: Agent,
    model_name: String,
    address: &str,
) -> Result<()> {
    let state = ServerState::with_agent(agent, model_name);

    let app = create_router(state);

    info!(
        "🚀 Starting Helios Engine server with agent on http://{}",
        address
    );
    info!("📡 OpenAI-compatible API endpoints:");
    info!("   POST /v1/chat/completions");
    info!("   GET  /v1/models");

    let listener = tokio::net::TcpListener::bind(address)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Failed to bind to {}: {}", address, e)))?;

    axum::serve(listener, app)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Server error: {}", e)))?;

    Ok(())
}

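/// Like [`start_server`], but also mounts the static endpoints described by
/// `custom_endpoints` (see [`load_custom_endpoints_config`]).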
pub async fn start_server_with_custom_endpoints(
    config: Config,
    address: &str,
    custom_endpoints: Option<CustomEndpointsConfig>,
) -> Result<()> {
    #[cfg(feature = "local")]
    let provider_type = if let Some(local_config) = config.local.clone() {
        LLMProviderType::Local(local_config)
    } else {
        LLMProviderType::Remote(config.llm.clone())
    };

    #[cfg(not(feature = "local"))]
    let provider_type = LLMProviderType::Remote(config.llm.clone());

    let llm_client = LLMClient::new(provider_type).await?;

    #[cfg(feature = "local")]
    let model_name = config
        .local
        .as_ref()
        .map(|_| "local-model".to_string())
        .unwrap_or_else(|| config.llm.model_name.clone());

    #[cfg(not(feature = "local"))]
    let model_name = config.llm.model_name.clone();

    let state = ServerState::with_llm_client(llm_client, model_name);

    let app = create_router_with_custom_endpoints(state, custom_endpoints.clone());

    info!("🚀 Starting Helios Engine server on http://{}", address);
    info!("📡 OpenAI-compatible API endpoints:");
    info!("   POST /v1/chat/completions");
    info!("   GET  /v1/models");

    if let Some(config) = &custom_endpoints {
        info!("📡 Custom endpoints:");
        for endpoint in &config.endpoints {
            info!("   {} {}", endpoint.method.to_uppercase(), endpoint.path);
        }
    }

    let listener = tokio::net::TcpListener::bind(address)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Failed to bind to {}: {}", address, e)))?;

    axum::serve(listener, app)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Server error: {}", e)))?;

    Ok(())
}

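/// Like [`start_server_with_agent`], but also mounts static custom endpoints
/// alongside the OpenAI-compatible routes.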
pub async fn start_server_with_agent_and_custom_endpoints(
    agent: Agent,
    model_name: String,
    address: &str,
    custom_endpoints: Option<CustomEndpointsConfig>,
) -> Result<()> {
    let state = ServerState::with_agent(agent, model_name);

    let app = create_router_with_custom_endpoints(state, custom_endpoints.clone());

    info!(
        "🚀 Starting Helios Engine server with agent on http://{}",
        address
    );
    info!("📡 OpenAI-compatible API endpoints:");
    info!("   POST /v1/chat/completions");
    info!("   GET  /v1/models");

    if let Some(config) = &custom_endpoints {
        info!("📡 Custom endpoints:");
        for endpoint in &config.endpoints {
            info!("   {} {}", endpoint.method.to_uppercase(), endpoint.path);
        }
    }

    let listener = tokio::net::TcpListener::bind(address)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Failed to bind to {}: {}", address, e)))?;

    axum::serve(listener, app)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Server error: {}", e)))?;

    Ok(())
}

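/// Fluent builder for an agent-backed server, as an alternative to the
/// free-standing `start_server_*` functions.
///
/// A usage sketch (assumes an already-constructed `Agent`):
///
/// ```ignore
/// ServerBuilder::with_agent(agent, "my-model")
///     .address("0.0.0.0:8080")
///     .serve()
///     .await?;
/// ```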
pub struct ServerBuilder {
    agent: Option<Agent>,
    model_name: String,
    address: String,
    endpoints: Vec<crate::endpoint_builder::CustomEndpoint>,
}

impl ServerBuilder {
    pub fn with_agent(agent: Agent, model_name: impl Into<String>) -> Self {
        Self {
            agent: Some(agent),
            model_name: model_name.into(),
            address: "127.0.0.1:8000".to_string(),
            endpoints: Vec::new(),
        }
    }

    pub fn address(mut self, address: impl Into<String>) -> Self {
        self.address = address.into();
        self
    }

    pub fn endpoint(mut self, endpoint: crate::endpoint_builder::CustomEndpoint) -> Self {
        self.endpoints.push(endpoint);
        self
    }

    pub fn endpoints(mut self, endpoints: Vec<crate::endpoint_builder::CustomEndpoint>) -> Self {
        self.endpoints.extend(endpoints);
        self
    }

    pub fn with_endpoints(mut self, endpoints: &[crate::endpoint_builder::CustomEndpoint]) -> Self {
        self.endpoints.extend_from_slice(endpoints);
        self
    }

    pub async fn serve(self) -> Result<()> {
        let agent = self.agent.expect("Agent must be set");
        let state = ServerState::with_agent(agent, self.model_name.clone());

        let app = create_router_with_new_endpoints(state, self.endpoints);

        info!(
            "🚀 Starting Helios Engine server with agent on http://{}",
            self.address
        );
        info!("📡 OpenAI-compatible API endpoints:");
        info!("   POST /v1/chat/completions");
        info!("   GET  /v1/models");

        let listener = tokio::net::TcpListener::bind(&self.address)
            .await
            .map_err(|e| {
                HeliosError::ConfigError(format!("Failed to bind to {}: {}", self.address, e))
            })?;

        axum::serve(listener, app)
            .await
            .map_err(|e| HeliosError::ConfigError(format!("Server error: {}", e)))?;

        Ok(())
    }
}

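/// Loads a [`CustomEndpointsConfig`] from a TOML file.
///
/// A sketch of the expected shape, inferred from the [`CustomEndpoint`]
/// fields (`status_code` may be omitted and defaults to 200):
///
/// ```toml
/// [[endpoints]]
/// method = "get"
/// path = "/status"
/// response = { status = "ok" }
/// ```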
pub fn load_custom_endpoints_config(path: &str) -> Result<CustomEndpointsConfig> {
    let content = std::fs::read_to_string(path).map_err(|e| {
        HeliosError::ConfigError(format!(
            "Failed to read custom endpoints config file '{}': {}",
            path, e
        ))
    })?;

    toml::from_str(&content).map_err(|e| {
        HeliosError::ConfigError(format!(
            "Failed to parse custom endpoints config file '{}': {}",
            path, e
        ))
    })
}

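/// Builds the base router: the OpenAI-compatible routes plus `/health`, with
/// permissive CORS and HTTP tracing layers applied.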
fn create_router(state: ServerState) -> Router {
    Router::new()
        .route("/v1/chat/completions", post(chat_completions))
        .route("/v1/models", get(list_models))
        .route("/health", get(health_check))
        .layer(CorsLayer::permissive())
        .layer(TraceLayer::new_for_http())
        .with_state(state)
}

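/// Base router plus any config-driven static endpoints.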
fn create_router_with_custom_endpoints(
    state: ServerState,
    custom_endpoints: Option<CustomEndpointsConfig>,
) -> Router {
    let mut router = Router::new()
        .route("/v1/chat/completions", post(chat_completions))
        .route("/v1/models", get(list_models))
        .route("/health", get(health_check));

    if let Some(config) = custom_endpoints {
        for endpoint in config.endpoints {
            let response = endpoint.response.clone();
            let status_code = StatusCode::from_u16(endpoint.status_code).unwrap_or(StatusCode::OK);

            // The canned response is cloned into the handler closure and
            // returned verbatim for every request.
            let handler = move || async move { (status_code, Json(response)) };

            match endpoint.method.to_uppercase().as_str() {
                "GET" => router = router.route(&endpoint.path, get(handler)),
                "POST" => router = router.route(&endpoint.path, post(handler)),
                "PUT" => router = router.route(&endpoint.path, put(handler)),
                "DELETE" => router = router.route(&endpoint.path, delete(handler)),
                "PATCH" => router = router.route(&endpoint.path, patch(handler)),
                _ => {
                    // Unrecognized methods fall back to GET.
                    router = router.route(&endpoint.path, get(handler));
                }
            }
        }
    }

    router
        .layer(CorsLayer::permissive())
        .layer(TraceLayer::new_for_http())
        .with_state(state)
}

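/// Base router plus programmatic endpoints built with the
/// `crate::endpoint_builder` API; each described route is logged as it is
/// mounted.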
fn create_router_with_new_endpoints(
    state: ServerState,
    endpoints: Vec<crate::endpoint_builder::CustomEndpoint>,
) -> Router {
    use crate::endpoint_builder::HttpMethod;

    let mut router = Router::new()
        .route("/v1/chat/completions", post(chat_completions))
        .route("/v1/models", get(list_models))
        .route("/health", get(health_check));

    for endpoint in endpoints {
        let handler_fn = endpoint.handler.clone();

        let handler = move || {
            let handler_fn = handler_fn.clone();
            async move {
                let response = handler_fn(None);
                response.into_response()
            }
        };

        match endpoint.method {
            HttpMethod::Get => router = router.route(&endpoint.path, get(handler)),
            HttpMethod::Post => router = router.route(&endpoint.path, post(handler)),
            HttpMethod::Put => router = router.route(&endpoint.path, put(handler)),
            HttpMethod::Delete => router = router.route(&endpoint.path, delete(handler)),
            HttpMethod::Patch => router = router.route(&endpoint.path, patch(handler)),
        }

        if let Some(desc) = &endpoint.description {
            info!(
                "   {} {} - {}",
                match endpoint.method {
                    HttpMethod::Get => "GET",
                    HttpMethod::Post => "POST",
                    HttpMethod::Put => "PUT",
                    HttpMethod::Delete => "DELETE",
                    HttpMethod::Patch => "PATCH",
                },
                endpoint.path,
                desc
            );
        }
    }

    router
        .layer(CorsLayer::permissive())
        .layer(TraceLayer::new_for_http())
        .with_state(state)
}

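/// Liveness probe returning a static JSON payload.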
async fn health_check() -> Json<serde_json::Value> {
    Json(serde_json::json!({
        "status": "ok",
        "service": "helios-engine"
    }))
}

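/// Handler for `GET /v1/models`; reports the single configured model.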
async fn list_models(State(state): State<ServerState>) -> Json<ModelsResponse> {
    Json(ModelsResponse {
        object: "list".to_string(),
        data: vec![ModelInfo {
            id: state.model_name.clone(),
            object: "model".to_string(),
            created: chrono::Utc::now().timestamp() as u64,
            owned_by: "helios-engine".to_string(),
        }],
    })
}

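/// Handler for `POST /v1/chat/completions`. Converts the OpenAI-style
/// request into internal [`ChatMessage`]s, dispatches to the agent or the
/// bare LLM client, and returns either a JSON completion or an SSE stream
/// depending on the `stream` flag.
///
/// A client-side sketch:
///
/// ```text
/// curl http://127.0.0.1:8000/v1/chat/completions \
///   -H 'Content-Type: application/json' \
///   -d '{"model": "my-model", "messages": [{"role": "user", "content": "Hi"}]}'
/// ```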
async fn chat_completions(
    State(state): State<ServerState>,
    Json(request): Json<ChatCompletionRequest>,
) -> std::result::Result<impl axum::response::IntoResponse, StatusCode> {
    let messages: Result<Vec<ChatMessage>> = request
        .messages
        .into_iter()
        .map(|msg| {
            let role = match msg.role.as_str() {
                "system" => Role::System,
                "user" => Role::User,
                "assistant" => Role::Assistant,
                "tool" => Role::Tool,
                _ => {
                    return Err(HeliosError::ConfigError(format!(
                        "Invalid role: {}",
                        msg.role
                    )));
                }
            };
            Ok(ChatMessage {
                role,
                content: msg.content,
                name: msg.name,
                tool_calls: None,
                tool_call_id: None,
            })
        })
        .collect();

    let messages = messages.map_err(|e| {
        error!("Failed to convert messages: {}", e);
        StatusCode::BAD_REQUEST
    })?;

    let stream = request.stream.unwrap_or(false);

    if stream {
        return Ok(stream_chat_completion(
            state,
            messages,
            request.model,
            request.temperature,
            request.max_tokens,
            request.stop.clone(),
        )
        .into_response());
    }

    let completion_id = format!("chatcmpl-{}", Uuid::new_v4());
    let created = chrono::Utc::now().timestamp() as u64;

    let messages_clone = messages.clone();

    let response_content = if let Some(agent) = &state.agent {
        let mut agent = agent.write().await;

        match agent
            .chat_with_history(
                messages.clone(),
                request.temperature,
                request.max_tokens,
                request.stop.clone(),
            )
            .await
        {
            Ok(content) => content,
            Err(e) => {
                error!("Agent error: {}", e);
                return Err(StatusCode::INTERNAL_SERVER_ERROR);
            }
        }
    } else if let Some(llm_client) = &state.llm_client {
        match llm_client
            .chat(
                messages_clone,
                None,
                request.temperature,
                request.max_tokens,
                request.stop.clone(),
            )
            .await
        {
            Ok(msg) => msg.content,
            Err(e) => {
                error!("LLM error: {}", e);
                return Err(StatusCode::INTERNAL_SERVER_ERROR);
            }
        }
    } else {
        return Err(StatusCode::INTERNAL_SERVER_ERROR);
    };

    let prompt_tokens = estimate_tokens(
        &messages
            .iter()
            .map(|m| m.content.as_str())
            .collect::<Vec<_>>()
            .join(" "),
    );
    let completion_tokens = estimate_tokens(&response_content);

    let response = ChatCompletionResponse {
        id: completion_id,
        object: "chat.completion".to_string(),
        created,
        model: request.model,
        choices: vec![CompletionChoice {
            index: 0,
            message: OpenAIMessageResponse {
                role: "assistant".to_string(),
                content: response_content,
            },
            finish_reason: "stop".to_string(),
        }],
        usage: Usage {
            prompt_tokens,
            completion_tokens,
            total_tokens: prompt_tokens + completion_tokens,
        },
    };

    Ok(Json(response).into_response())
}

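/// Streams a completion as OpenAI-style `chat.completion.chunk` SSE events:
/// one delta event per generated chunk, then a final event carrying
/// `finish_reason: "stop"`.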
fn stream_chat_completion(
    state: ServerState,
    messages: Vec<ChatMessage>,
    model: String,
    temperature: Option<f32>,
    max_tokens: Option<u32>,
    stop: Option<Vec<String>>,
) -> Sse<impl Stream<Item = std::result::Result<Event, Infallible>>> {
    let (tx, rx) = tokio::sync::mpsc::channel(100);
    let completion_id = format!("chatcmpl-{}", Uuid::new_v4());
    let created = chrono::Utc::now().timestamp() as u64;

    tokio::spawn(async move {
        let on_chunk = |chunk: &str| {
            let event = Event::default()
                .json_data(serde_json::json!({
                    "id": completion_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": model,
                    "choices": [{
                        "index": 0,
                        "delta": {
                            "content": chunk
                        },
                        "finish_reason": null
                    }]
                }))
                .unwrap();
            // `try_send` drops the chunk rather than blocking inside this
            // synchronous callback if the channel is full.
            let _ = tx.try_send(Ok(event));
        };

        if let Some(agent) = &state.agent {
            let mut agent = agent.write().await;

            match agent
                .chat_stream_with_history(messages, temperature, max_tokens, stop.clone(), on_chunk)
                .await
            {
                Ok(_) => {}
                Err(e) => {
                    error!("Agent streaming error: {}", e);
                }
            }
        } else if let Some(llm_client) = &state.llm_client {
            match llm_client
                .chat_stream(
                    messages,
                    None,
                    temperature,
                    max_tokens,
                    stop.clone(),
                    on_chunk,
                )
                .await
            {
                Ok(_) => {}
                Err(e) => {
                    error!("LLM streaming error: {}", e);
                }
            }
        }

        let final_event = Event::default()
            .json_data(serde_json::json!({
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": model,
                "choices": [{
                    "index": 0,
                    "delta": {},
                    "finish_reason": "stop"
                }]
            }))
            .unwrap();
        let _ = tx.send(Ok(final_event)).await;
    });

    Sse::new(ReceiverStream::new(rx)).keep_alive(axum::response::sse::KeepAlive::default())
}

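/// Rough token-count heuristic used for the `usage` block: about one token
/// per four characters. Not a real tokenizer; for example, a 10-character
/// string is counted as ceil(10 / 4) = 3 tokens.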
pub fn estimate_tokens(text: &str) -> u32 {
    (text.len() as f32 / 4.0).ceil() as u32
}