use crate::agent::Agent;
use crate::chat::{ChatMessage, Role};
use crate::config::Config;
use crate::error::{HeliosError, Result};
use crate::llm::{LLMClient, LLMProviderType};
use axum::{
extract::State,
http::StatusCode,
response::{
sse::{Event, Sse},
IntoResponse,
},
routing::{delete, get, patch, post, put},
Json, Router,
};
use futures::stream::Stream;
use serde::{Deserialize, Serialize};
use std::convert::Infallible;
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio_stream::wrappers::ReceiverStream;
use tower_http::cors::CorsLayer;
use tower_http::trace::TraceLayer;
use tracing::{error, info};
use uuid::Uuid;
/// Request body for the OpenAI-compatible `POST /v1/chat/completions` endpoint.
///
/// NOTE(review): `rename_all = "snake_case"` is a no-op here — every field
/// name is already snake_case.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct ChatCompletionRequest {
    /// Model identifier requested by the client (echoed back in the response).
    pub model: String,
    /// Conversation history in OpenAI message format.
    pub messages: Vec<OpenAIMessage>,
    /// Sampling temperature; `None` defers to the backend default.
    #[serde(default)]
    pub temperature: Option<f32>,
    /// Upper bound on generated tokens; `None` defers to the backend default.
    #[serde(default)]
    pub max_tokens: Option<u32>,
    /// When `true`, respond with an SSE stream of chunks instead of one JSON body.
    #[serde(default)]
    pub stream: Option<bool>,
    /// Optional stop sequences that terminate generation.
    #[serde(default)]
    pub stop: Option<Vec<String>>,
}
/// A single incoming chat message in OpenAI wire format.
#[derive(Debug, Deserialize)]
pub struct OpenAIMessage {
    /// One of `"system" | "user" | "assistant" | "tool"` — validated in
    /// `chat_completions`; anything else yields HTTP 400.
    pub role: String,
    pub content: String,
    /// Optional participant name.
    /// NOTE(review): `skip_serializing_if` has no effect on a
    /// Deserialize-only struct; it only matters if `Serialize` is ever
    /// derived here — confirm intent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
}
/// Non-streaming response body for `/v1/chat/completions`
/// (mirrors OpenAI's `chat.completion` object).
#[derive(Debug, Serialize)]
pub struct ChatCompletionResponse {
    /// Completion id, formatted as `chatcmpl-<uuid>`.
    pub id: String,
    /// Always `"chat.completion"`.
    pub object: String,
    /// Unix timestamp (seconds) when the completion was created.
    pub created: u64,
    /// Model name echoed from the request.
    pub model: String,
    pub choices: Vec<CompletionChoice>,
    pub usage: Usage,
}
/// One generated choice inside a `ChatCompletionResponse`.
#[derive(Debug, Serialize)]
pub struct CompletionChoice {
    /// Index of this choice; this server always produces a single choice (0).
    pub index: u32,
    pub message: OpenAIMessageResponse,
    /// Why generation stopped; this server always reports `"stop"`.
    pub finish_reason: String,
}
/// Assistant message payload returned inside a completion choice.
#[derive(Debug, Serialize)]
pub struct OpenAIMessageResponse {
    /// Always `"assistant"` for responses produced by this server.
    pub role: String,
    pub content: String,
}
/// Token accounting block. Values are rough byte-based estimates — see
/// `estimate_tokens` — not backend-reported counts.
#[derive(Debug, Serialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    /// Sum of prompt and completion tokens.
    pub total_tokens: u32,
}
/// Metadata for a single model, as returned by `GET /v1/models`.
#[derive(Debug, Serialize)]
pub struct ModelInfo {
    pub id: String,
    /// Always `"model"`.
    pub object: String,
    /// Unix timestamp (seconds); this server fills in the current time.
    pub created: u64,
    /// Always `"helios-engine"` for models listed by this server.
    pub owned_by: String,
}
/// Response body for `GET /v1/models` (OpenAI list envelope).
#[derive(Debug, Serialize)]
pub struct ModelsResponse {
    /// Always `"list"`.
    pub object: String,
    pub data: Vec<ModelInfo>,
}
/// A static custom endpoint loaded from configuration: responds to
/// `method` + `path` with a fixed JSON body and status code.
#[derive(Debug, Clone, Deserialize)]
pub struct CustomEndpoint {
    /// HTTP method name, case-insensitive; unrecognized names are registered
    /// as GET (see `create_router_with_custom_endpoints`).
    pub method: String,
    /// Route path, e.g. `/ping`.
    pub path: String,
    /// Fixed JSON value returned to every request.
    pub response: serde_json::Value,
    /// HTTP status code to respond with; defaults to 200.
    #[serde(default = "default_status_code")]
    pub status_code: u16,
}
/// Serde default for `CustomEndpoint::status_code`: 200 (OK).
fn default_status_code() -> u16 {
    200
}
/// Collection of static custom endpoints, typically deserialized from a TOML
/// file (see `load_custom_endpoints_config`).
#[derive(Debug, Clone, Deserialize)]
pub struct CustomEndpointsConfig {
    pub endpoints: Vec<CustomEndpoint>,
}
impl CustomEndpointsConfig {
    /// Creates an empty configuration with no endpoints registered.
    pub fn new() -> Self {
        Self { endpoints: vec![] }
    }

    /// Appends one endpoint and returns the updated configuration
    /// (consuming builder style).
    pub fn add_endpoint(mut self, endpoint: CustomEndpoint) -> Self {
        self.endpoints.push(endpoint);
        self
    }
}
impl Default for CustomEndpointsConfig {
    /// Equivalent to `CustomEndpointsConfig::new()`: an empty endpoint list.
    fn default() -> Self {
        Self::new()
    }
}
/// Shared state handed to every request handler.
///
/// Exactly one of `llm_client` / `agent` is expected to be `Some` (the two
/// constructors enforce this). Handlers prefer the agent when both are set;
/// if neither is set, chat requests fail with HTTP 500.
#[derive(Clone)]
pub struct ServerState {
    /// Direct LLM backend, used when no agent is configured.
    pub llm_client: Option<Arc<LLMClient>>,
    /// Agent backend; wrapped in `RwLock` because chat calls take `&mut` access.
    pub agent: Option<Arc<RwLock<Agent>>>,
    /// Model name advertised by `GET /v1/models`.
    pub model_name: String,
}
impl ServerState {
    /// Builds state that serves requests directly from an `LLMClient`
    /// (no agent configured).
    pub fn with_llm_client(llm_client: LLMClient, model_name: String) -> Self {
        Self {
            llm_client: Some(Arc::new(llm_client)),
            agent: None,
            model_name,
        }
    }

    /// Builds state that routes all chat requests through an `Agent`
    /// (no direct LLM client configured).
    pub fn with_agent(agent: Agent, model_name: String) -> Self {
        Self {
            agent: Some(Arc::new(RwLock::new(agent))),
            llm_client: None,
            model_name,
        }
    }
}
/// Starts the OpenAI-compatible HTTP server backed by a bare `LLMClient`
/// built from `config`, binding to `address` (e.g. `"127.0.0.1:8000"`).
///
/// Runs until the server terminates. Binding and serving failures are
/// surfaced as `HeliosError::ConfigError`.
pub async fn start_server(config: Config, address: &str) -> Result<()> {
    // With the "local" feature enabled, prefer a local provider when one is
    // configured; otherwise fall back to the remote LLM settings.
    #[cfg(feature = "local")]
    let provider_type = if let Some(local_config) = config.local.clone() {
        LLMProviderType::Local(local_config)
    } else {
        LLMProviderType::Remote(config.llm.clone())
    };
    #[cfg(not(feature = "local"))]
    let provider_type = LLMProviderType::Remote(config.llm.clone());
    let llm_client = LLMClient::new(provider_type).await?;
    // Advertised model name: a fixed placeholder when running a local model,
    // otherwise the configured remote model name.
    #[cfg(feature = "local")]
    let model_name = config
        .local
        .as_ref()
        .map(|_| "local-model".to_string())
        .unwrap_or_else(|| config.llm.model_name.clone());
    #[cfg(not(feature = "local"))]
    let model_name = config.llm.model_name.clone();
    let state = ServerState::with_llm_client(llm_client, model_name);
    let app = create_router(state);
    info!("🚀 Starting Helios Engine server on http://{}", address);
    info!("📡 OpenAI-compatible API endpoints:");
    info!(" POST /v1/chat/completions");
    info!(" GET /v1/models");
    let listener = tokio::net::TcpListener::bind(address)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Failed to bind to {}: {}", address, e)))?;
    // Blocks here for the lifetime of the server.
    axum::serve(listener, app)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Server error: {}", e)))?;
    Ok(())
}
pub async fn start_server_with_agent(
agent: Agent,
model_name: String,
address: &str,
) -> Result<()> {
let state = ServerState::with_agent(agent, model_name);
let app = create_router(state);
info!(
"🚀 Starting Helios Engine server with agent on http://{}",
address
);
info!("📡 OpenAI-compatible API endpoints:");
info!(" POST /v1/chat/completions");
info!(" GET /v1/models");
let listener = tokio::net::TcpListener::bind(address)
.await
.map_err(|e| HeliosError::ConfigError(format!("Failed to bind to {}: {}", address, e)))?;
axum::serve(listener, app)
.await
.map_err(|e| HeliosError::ConfigError(format!("Server error: {}", e)))?;
Ok(())
}
/// Like `start_server`, but additionally registers static custom endpoints
/// (fixed JSON responses) alongside the OpenAI-compatible routes.
pub async fn start_server_with_custom_endpoints(
    config: Config,
    address: &str,
    custom_endpoints: Option<CustomEndpointsConfig>,
) -> Result<()> {
    // Provider selection mirrors `start_server`: local provider when the
    // "local" feature is enabled and configured, otherwise remote.
    #[cfg(feature = "local")]
    let provider_type = if let Some(local_config) = config.local.clone() {
        LLMProviderType::Local(local_config)
    } else {
        LLMProviderType::Remote(config.llm.clone())
    };
    #[cfg(not(feature = "local"))]
    let provider_type = LLMProviderType::Remote(config.llm.clone());
    let llm_client = LLMClient::new(provider_type).await?;
    #[cfg(feature = "local")]
    let model_name = config
        .local
        .as_ref()
        .map(|_| "local-model".to_string())
        .unwrap_or_else(|| config.llm.model_name.clone());
    #[cfg(not(feature = "local"))]
    let model_name = config.llm.model_name.clone();
    let state = ServerState::with_llm_client(llm_client, model_name);
    let app = create_router_with_custom_endpoints(state, custom_endpoints.clone());
    info!("🚀 Starting Helios Engine server on http://{}", address);
    info!("📡 OpenAI-compatible API endpoints:");
    info!(" POST /v1/chat/completions");
    info!(" GET /v1/models");
    // Announce each configured custom route at startup.
    if let Some(config) = &custom_endpoints {
        info!("📡 Custom endpoints:");
        for endpoint in &config.endpoints {
            info!(" {} {}", endpoint.method.to_uppercase(), endpoint.path);
        }
    }
    let listener = tokio::net::TcpListener::bind(address)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Failed to bind to {}: {}", address, e)))?;
    axum::serve(listener, app)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Server error: {}", e)))?;
    Ok(())
}
/// Like `start_server_with_agent`, but additionally registers static custom
/// endpoints (fixed JSON responses) alongside the OpenAI-compatible routes.
pub async fn start_server_with_agent_and_custom_endpoints(
    agent: Agent,
    model_name: String,
    address: &str,
    custom_endpoints: Option<CustomEndpointsConfig>,
) -> Result<()> {
    let state = ServerState::with_agent(agent, model_name);
    let app = create_router_with_custom_endpoints(state, custom_endpoints.clone());
    info!(
        "🚀 Starting Helios Engine server with agent on http://{}",
        address
    );
    info!("📡 OpenAI-compatible API endpoints:");
    info!(" POST /v1/chat/completions");
    info!(" GET /v1/models");
    // Announce each configured custom route at startup.
    if let Some(config) = &custom_endpoints {
        info!("📡 Custom endpoints:");
        for endpoint in &config.endpoints {
            info!(" {} {}", endpoint.method.to_uppercase(), endpoint.path);
        }
    }
    let listener = tokio::net::TcpListener::bind(address)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Failed to bind to {}: {}", address, e)))?;
    axum::serve(listener, app)
        .await
        .map_err(|e| HeliosError::ConfigError(format!("Server error: {}", e)))?;
    Ok(())
}
/// Fluent builder for an agent-backed server with handler-based custom
/// endpoints (see `crate::endpoint_builder`).
pub struct ServerBuilder {
    // Agent backend; set by `with_agent`, required before `serve()`.
    agent: Option<Agent>,
    // Model name advertised by `GET /v1/models`.
    model_name: String,
    // Socket address to bind; defaults to `127.0.0.1:8000`.
    address: String,
    // Extra routes registered alongside the OpenAI-compatible ones.
    endpoints: Vec<crate::endpoint_builder::CustomEndpoint>,
}
impl ServerBuilder {
    /// Creates a builder with the given agent and advertised model name.
    /// The bind address defaults to `127.0.0.1:8000`.
    pub fn with_agent(agent: Agent, model_name: impl Into<String>) -> Self {
        Self {
            agent: Some(agent),
            model_name: model_name.into(),
            address: "127.0.0.1:8000".to_string(),
            endpoints: Vec::new(),
        }
    }

    /// Overrides the socket address to bind.
    pub fn address(mut self, address: impl Into<String>) -> Self {
        self.address = address.into();
        self
    }

    /// Registers a single custom endpoint.
    pub fn endpoint(mut self, endpoint: crate::endpoint_builder::CustomEndpoint) -> Self {
        self.endpoints.push(endpoint);
        self
    }

    /// Registers several custom endpoints at once, consuming the vector.
    pub fn endpoints(mut self, endpoints: Vec<crate::endpoint_builder::CustomEndpoint>) -> Self {
        self.endpoints.extend(endpoints);
        self
    }

    /// Registers several custom endpoints by cloning from a slice.
    pub fn with_endpoints(mut self, endpoints: &[crate::endpoint_builder::CustomEndpoint]) -> Self {
        self.endpoints.extend_from_slice(endpoints);
        self
    }

    /// Consumes the builder and runs the server until it terminates.
    ///
    /// # Errors
    /// Returns `HeliosError::ConfigError` if no agent was provided, if the
    /// address cannot be bound, or if the server fails while running.
    pub async fn serve(self) -> Result<()> {
        // Previously a panic via `expect` — since this function already
        // returns `Result`, a missing agent is reported as an error instead.
        let agent = self.agent.ok_or_else(|| {
            HeliosError::ConfigError("ServerBuilder: agent must be set before serve()".to_string())
        })?;
        // `model_name` is moved (not cloned): it is not used again below.
        let state = ServerState::with_agent(agent, self.model_name);
        let app = create_router_with_new_endpoints(state, self.endpoints);
        info!(
            "🚀 Starting Helios Engine server with agent on http://{}",
            self.address
        );
        info!("📡 OpenAI-compatible API endpoints:");
        info!(" POST /v1/chat/completions");
        info!(" GET /v1/models");
        let listener = tokio::net::TcpListener::bind(&self.address)
            .await
            .map_err(|e| {
                HeliosError::ConfigError(format!("Failed to bind to {}: {}", self.address, e))
            })?;
        axum::serve(listener, app)
            .await
            .map_err(|e| HeliosError::ConfigError(format!("Server error: {}", e)))?;
        Ok(())
    }
}
/// Loads a `CustomEndpointsConfig` from the TOML file at `path`.
///
/// # Errors
/// Returns `HeliosError::ConfigError` when the file cannot be read or when
/// its contents fail to parse as the expected TOML structure.
pub fn load_custom_endpoints_config(path: &str) -> Result<CustomEndpointsConfig> {
    let content = std::fs::read_to_string(path).map_err(|e| {
        HeliosError::ConfigError(format!(
            "Failed to read custom endpoints config file '{}': {}",
            path, e
        ))
    })?;
    let parsed = toml::from_str(&content).map_err(|e| {
        HeliosError::ConfigError(format!(
            "Failed to parse custom endpoints config file '{}': {}",
            path, e
        ))
    })?;
    Ok(parsed)
}
/// Builds the base router: the OpenAI-compatible chat/model routes plus a
/// health probe, wrapped with permissive CORS and HTTP tracing layers.
fn create_router(state: ServerState) -> Router {
    let routes = Router::new()
        .route("/health", get(health_check))
        .route("/v1/models", get(list_models))
        .route("/v1/chat/completions", post(chat_completions));
    routes
        .layer(CorsLayer::permissive())
        .layer(TraceLayer::new_for_http())
        .with_state(state)
}
/// Builds the base router and additionally registers the static JSON
/// endpoints described by `custom_endpoints` (if any).
fn create_router_with_custom_endpoints(
    state: ServerState,
    custom_endpoints: Option<CustomEndpointsConfig>,
) -> Router {
    let mut router = Router::new()
        .route("/v1/chat/completions", post(chat_completions))
        .route("/v1/models", get(list_models))
        .route("/health", get(health_check));
    if let Some(config) = custom_endpoints {
        for endpoint in config.endpoints {
            // Each closure captures its own copy of the fixed response body
            // and status code.
            let response = endpoint.response.clone();
            // Invalid status-code values silently fall back to 200 OK.
            let status_code = StatusCode::from_u16(endpoint.status_code).unwrap_or(StatusCode::OK);
            let handler = move || async move { (status_code, Json(response)) };
            match endpoint.method.to_uppercase().as_str() {
                "GET" => router = router.route(&endpoint.path, get(handler)),
                "POST" => router = router.route(&endpoint.path, post(handler)),
                "PUT" => router = router.route(&endpoint.path, put(handler)),
                "DELETE" => router = router.route(&endpoint.path, delete(handler)),
                "PATCH" => router = router.route(&endpoint.path, patch(handler)),
                _ => {
                    // Unrecognized method names are registered as GET.
                    router = router.route(&endpoint.path, get(handler));
                }
            }
        }
    }
    router
        .layer(CorsLayer::permissive())
        .layer(TraceLayer::new_for_http())
        .with_state(state)
}
/// Builds the base router plus handler-backed custom endpoints from the
/// `endpoint_builder` API, logging each registered custom route that carries
/// a description.
fn create_router_with_new_endpoints(
    state: ServerState,
    endpoints: Vec<crate::endpoint_builder::CustomEndpoint>,
) -> Router {
    use crate::endpoint_builder::HttpMethod;
    let mut router = Router::new()
        .route("/v1/chat/completions", post(chat_completions))
        .route("/v1/models", get(list_models))
        .route("/health", get(health_check));
    for endpoint in endpoints {
        let handler_fn = endpoint.handler.clone();
        // NOTE(review): the custom handler is always invoked with `None` —
        // request data (body/params) is not forwarded here; confirm this is
        // the intended contract of `endpoint_builder` handlers.
        let handler = move || {
            let handler_fn = handler_fn.clone();
            async move {
                let response = handler_fn(None);
                response.into_response()
            }
        };
        match endpoint.method {
            HttpMethod::Get => router = router.route(&endpoint.path, get(handler)),
            HttpMethod::Post => router = router.route(&endpoint.path, post(handler)),
            HttpMethod::Put => router = router.route(&endpoint.path, put(handler)),
            HttpMethod::Delete => router = router.route(&endpoint.path, delete(handler)),
            HttpMethod::Patch => router = router.route(&endpoint.path, patch(handler)),
        }
        // Described routes are logged as they are registered.
        if let Some(desc) = &endpoint.description {
            info!(
                " {} {} - {}",
                match endpoint.method {
                    HttpMethod::Get => "GET",
                    HttpMethod::Post => "POST",
                    HttpMethod::Put => "PUT",
                    HttpMethod::Delete => "DELETE",
                    HttpMethod::Patch => "PATCH",
                },
                endpoint.path,
                desc
            );
        }
    }
    router
        .layer(CorsLayer::permissive())
        .layer(TraceLayer::new_for_http())
        .with_state(state)
}
/// Liveness probe: `GET /health` returns
/// `{"status":"ok","service":"helios-engine"}`.
async fn health_check() -> Json<serde_json::Value> {
    let body = serde_json::json!({
        "status": "ok",
        "service": "helios-engine"
    });
    Json(body)
}
/// `GET /v1/models`: reports the single model this server is configured
/// with, in OpenAI list format. `created` is the time of the request.
async fn list_models(State(state): State<ServerState>) -> Json<ModelsResponse> {
    let model = ModelInfo {
        id: state.model_name.clone(),
        object: "model".to_string(),
        created: chrono::Utc::now().timestamp() as u64,
        owned_by: "helios-engine".to_string(),
    };
    Json(ModelsResponse {
        object: "list".to_string(),
        data: vec![model],
    })
}
/// `POST /v1/chat/completions`: the OpenAI-compatible chat endpoint.
///
/// Validates and converts the incoming messages, then dispatches to the
/// configured agent (preferred) or raw LLM client. Honors `stream: true` by
/// switching to an SSE response via `stream_chat_completion`.
///
/// # Errors
/// * `400 Bad Request` — a message carries an unknown `role`.
/// * `500 Internal Server Error` — backend failure, or `ServerState` holds
///   neither an agent nor an LLM client.
async fn chat_completions(
    State(state): State<ServerState>,
    Json(request): Json<ChatCompletionRequest>,
) -> std::result::Result<impl axum::response::IntoResponse, StatusCode> {
    // Convert OpenAI-format messages into internal `ChatMessage`s, rejecting
    // unknown roles. Collecting into `Result` short-circuits on first error.
    let messages: Result<Vec<ChatMessage>> = request
        .messages
        .into_iter()
        .map(|msg| {
            let role = match msg.role.as_str() {
                "system" => Role::System,
                "user" => Role::User,
                "assistant" => Role::Assistant,
                "tool" => Role::Tool,
                _ => {
                    return Err(HeliosError::ConfigError(format!(
                        "Invalid role: {}",
                        msg.role
                    )));
                }
            };
            Ok(ChatMessage {
                role,
                content: msg.content,
                name: msg.name,
                tool_calls: None,
                tool_call_id: None,
            })
        })
        .collect();
    let messages = messages.map_err(|e| {
        error!("Failed to convert messages: {}", e);
        StatusCode::BAD_REQUEST
    })?;
    if request.stream.unwrap_or(false) {
        return Ok(stream_chat_completion(
            state,
            messages,
            request.model,
            request.temperature,
            request.max_tokens,
            request.stop,
        )
        .into_response());
    }
    let completion_id = format!("chatcmpl-{}", Uuid::new_v4());
    let created = chrono::Utc::now().timestamp() as u64;
    // Estimate prompt tokens *before* dispatch so `messages` can be moved
    // into the backend call; the original cloned the whole conversation
    // twice (once unconditionally, once more on the agent path).
    let prompt_tokens = estimate_tokens(
        &messages
            .iter()
            .map(|m| m.content.as_str())
            .collect::<Vec<_>>()
            .join(" "),
    );
    let response_content = if let Some(agent) = &state.agent {
        // Agent path: takes the write lock for exclusive access.
        let mut agent = agent.write().await;
        match agent
            .chat_with_history(
                messages,
                request.temperature,
                request.max_tokens,
                request.stop,
            )
            .await
        {
            Ok(content) => content,
            Err(e) => {
                error!("Agent error: {}", e);
                return Err(StatusCode::INTERNAL_SERVER_ERROR);
            }
        }
    } else if let Some(llm_client) = &state.llm_client {
        match llm_client
            .chat(
                messages,
                None,
                request.temperature,
                request.max_tokens,
                request.stop,
            )
            .await
        {
            Ok(msg) => msg.content,
            Err(e) => {
                error!("LLM error: {}", e);
                return Err(StatusCode::INTERNAL_SERVER_ERROR);
            }
        }
    } else {
        // Misconfigured state: no backend available at all.
        return Err(StatusCode::INTERNAL_SERVER_ERROR);
    };
    let completion_tokens = estimate_tokens(&response_content);
    let response = ChatCompletionResponse {
        id: completion_id,
        object: "chat.completion".to_string(),
        created,
        model: request.model,
        choices: vec![CompletionChoice {
            index: 0,
            message: OpenAIMessageResponse {
                role: "assistant".to_string(),
                content: response_content,
            },
            finish_reason: "stop".to_string(),
        }],
        usage: Usage {
            prompt_tokens,
            completion_tokens,
            total_tokens: prompt_tokens + completion_tokens,
        },
    };
    Ok(Json(response).into_response())
}
/// Runs a streaming chat completion as a server-sent-events response.
///
/// Spawns a background task that forwards each generated text chunk as an
/// OpenAI `chat.completion.chunk` event, followed by a terminal event with
/// `finish_reason: "stop"`. Backend errors are logged; the stream still
/// terminates with the final event either way.
fn stream_chat_completion(
    state: ServerState,
    messages: Vec<ChatMessage>,
    model: String,
    temperature: Option<f32>,
    max_tokens: Option<u32>,
    stop: Option<Vec<String>>,
) -> Sse<impl Stream<Item = std::result::Result<Event, Infallible>>> {
    // Bounded channel between the generator task and the SSE response body.
    let (tx, rx) = tokio::sync::mpsc::channel(100);
    let completion_id = format!("chatcmpl-{}", Uuid::new_v4());
    let created = chrono::Utc::now().timestamp() as u64;
    tokio::spawn(async move {
        // Invoked synchronously by the backend for every generated chunk.
        let on_chunk = |chunk: &str| {
            let event = Event::default()
                .json_data(serde_json::json!({
                    "id": completion_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": model,
                    "choices": [{
                        "index": 0,
                        "delta": {
                            "content": chunk
                        },
                        "finish_reason": null
                    }]
                }))
                .unwrap();
            // NOTE(review): `try_send` drops the chunk when the channel is
            // full (slow client) — confirm lossy delivery is acceptable here.
            let _ = tx.try_send(Ok(event));
        };
        if let Some(agent) = &state.agent {
            let mut agent = agent.write().await;
            match agent
                .chat_stream_with_history(messages, temperature, max_tokens, stop.clone(), on_chunk)
                .await
            {
                Ok(_) => {
                }
                Err(e) => {
                    // Error is logged only; the terminal "stop" event below is
                    // still emitted so the client's stream closes cleanly.
                    error!("Agent streaming error: {}", e);
                }
            }
        } else if let Some(llm_client) = &state.llm_client {
            match llm_client
                .chat_stream(
                    messages,
                    None,
                    temperature,
                    max_tokens,
                    stop.clone(),
                    on_chunk,
                )
                .await
            {
                Ok(_) => {}
                Err(e) => {
                    error!("LLM streaming error: {}", e);
                }
            }
        };
        // Terminal chunk: empty delta with finish_reason "stop".
        let final_event = Event::default()
            .json_data(serde_json::json!({
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": model,
                "choices": [{
                    "index": 0,
                    "delta": {},
                    "finish_reason": "stop"
                }]
            }))
            .unwrap();
        let _ = tx.send(Ok(final_event)).await;
    });
    Sse::new(ReceiverStream::new(rx)).keep_alive(axum::response::sse::KeepAlive::default())
}
/// Rough token-count estimate: one token per 4 bytes of text, rounded up.
///
/// Uses integer ceiling division instead of `f32` math, which stays exact
/// for arbitrarily long inputs (`f32` loses integer precision above 2^24).
pub fn estimate_tokens(text: &str) -> u32 {
    ((text.len() + 3) / 4) as u32
}