use crate::{
agents::{registry::AgentRegistry, router::RouterAgent, Agent},
api::handlers::user_agents::resolve_agent,
auth::middleware::AuthUser,
db::agent_runs,
memory::estimate_tokens,
types::{
AgentContext, AgentType, AppError, ChatRequest, ChatResponse, MessageRole, Result,
UserMemory,
},
utils::toml_config::AgentConfig,
AppState,
};
use axum::{extract::State, response::Response, Extension, Json};
use uuid::Uuid;
#[utoipa::path(
post,
path = "/api/chat",
request_body = ChatRequest,
responses(
(status = 200, description = "Chat response", body = ChatResponse),
(status = 400, description = "Invalid input"),
(status = 401, description = "Unauthorized")
),
tag = "chat",
security(("bearer" = []))
)]
pub async fn chat(
State(state): State<AppState>,
AuthUser(claims): AuthUser,
tenant_ctx: Option<Extension<crate::models::TenantContext>>,
Json(payload): Json<ChatRequest>,
) -> Result<Response> {
let context_id = payload
.context_id
.unwrap_or_else(|| Uuid::new_v4().to_string());
if !state.db.conversation_exists(&context_id).await? {
state
.db
.create_conversation(&context_id, &claims.sub, None)
.await?;
}
let history = state.db.get_conversation_history(&context_id).await?;
let history_input_tokens: usize = history.iter().map(|m| estimate_tokens(&m.content)).sum();
let memory_facts = state.db.get_user_memory(&claims.sub).await?;
let preferences = state.db.get_user_preferences(&claims.sub).await?;
let user_memory = if !memory_facts.is_empty() || !preferences.is_empty() {
Some(UserMemory {
user_id: claims.sub.clone(),
preferences,
facts: memory_facts,
})
} else {
None
};
let agent_context = AgentContext {
user_id: claims.sub.clone(),
session_id: context_id.clone(),
conversation_history: history.clone(),
user_memory,
};
let agent_type = if let Some(at) = payload.agent_type {
at
} else {
let config = state.config_manager.config();
let router_model = config
.get_agent("router")
.map(|a| a.model.as_str())
.unwrap_or("fast");
let router_llm = match state
.provider_registry
.create_client_for_model(router_model)
.await
{
Ok(client) => client,
Err(_) => state.llm_factory.create_default().await?,
};
let router = RouterAgent::new(router_llm);
router.route(&payload.message, &agent_context).await?
};
let agent_name_for_run = AgentRegistry::type_to_name(&agent_type).to_string();
let start = std::time::Instant::now();
let (response, usage) = execute_agent(agent_type, &payload.message, &agent_context, &state).await?;
let duration_ms = start.elapsed().as_millis() as i64;
let msg_id = Uuid::new_v4().to_string();
state
.db
.add_message(&msg_id, &context_id, MessageRole::User, &payload.message)
.await?;
let resp_id = Uuid::new_v4().to_string();
state
.db
.add_message(
&resp_id,
&context_id,
MessageRole::Assistant,
&response.response,
)
.await?;
let (input_tokens, output_tokens) = if let Some(u) = usage {
(u.prompt_tokens, u.completion_tokens)
} else {
(
(history_input_tokens + estimate_tokens(&payload.message)) as u32,
estimate_tokens(&response.response) as u32,
)
};
{
let pool = state.tenant_db.pool().clone();
let agent_name = agent_name_for_run;
let user_id = claims.sub.clone();
let tenant_id_for_run = tenant_ctx
.map(|Extension(tc)| tc.tenant_id.clone())
.unwrap_or_else(|| "system".to_string());
let itok = input_tokens as i64;
let otok = output_tokens as i64;
tokio::spawn(async move {
let _ = agent_runs::insert_agent_run(
&pool,
&tenant_id_for_run,
&agent_name,
Some(&user_id),
"completed",
itok,
otok,
duration_ms,
None,
"unknown",
"unknown",
false,
)
.await;
});
}
let body = Json(response);
let mut response = body.into_response();
response.headers_mut().insert(
axum::http::HeaderName::from_static("x-input-tokens"),
axum::http::HeaderValue::from(input_tokens),
);
response.headers_mut().insert(
axum::http::HeaderName::from_static("x-output-tokens"),
axum::http::HeaderValue::from(output_tokens),
);
Ok(response)
}
/// Resolves, builds and runs the agent for `agent_type`, returning the chat
/// response together with the token usage the agent reported (if any).
///
/// # Errors
///
/// Returns `AppError::InvalidInput` when `agent_type` is the router, and
/// propagates any failure from agent resolution, construction or execution.
async fn execute_agent(
    agent_type: AgentType,
    message: &str,
    context: &AgentContext,
    state: &AppState,
) -> Result<(ChatResponse, Option<crate::llm::client::TokenUsage>)> {
    // The router only selects agents; it is never a valid execution target.
    if agent_type == AgentType::Router {
        let reason = "Router agent cannot be called directly".to_string();
        return Err(AppError::InvalidInput(reason));
    }

    let name = AgentRegistry::type_to_name(&agent_type);
    let (resolved, source) = resolve_agent(state, &context.user_id, name.to_string()).await?;

    // Translate the resolved agent definition into the registry's config shape.
    let agent_config = AgentConfig {
        model: resolved.model.clone(),
        system_prompt: resolved.system_prompt.clone(),
        tools: resolved.tools_vec(),
        max_tool_iterations: resolved.max_tool_iterations as usize,
        parallel_tools: resolved.parallel_tools,
        extra: std::collections::HashMap::new(),
    };

    let runner = state
        .agent_registry
        .create_agent_from_config(name, &agent_config)
        .await?;
    let outcome = runner.execute(message, context).await?;

    let chat_response = ChatResponse {
        response: outcome.content,
        // Label is the agent type's Debug form followed by where it was resolved from.
        agent: format!("{:?} ({})", agent_type, source),
        context_id: context.session_id.clone(),
        sources: None,
    };
    Ok((chat_response, outcome.usage))
}
// Returns the authenticated user's stored memory: the facts and preferences
// loaded from the database, tagged with the user's id (`claims.sub`).
// Database errors are propagated to the caller.
//
// Plain `//` comments are used because `utoipa::path` would turn `///` doc
// comments into the operation's OpenAPI summary/description.
#[utoipa::path(
    get,
    path = "/api/memory",
    responses(
        (status = 200, description = "User memory retrieved successfully"),
        (status = 401, description = "Unauthorized")
    ),
    tag = "chat",
    security(("bearer" = []))
)]
pub async fn get_user_memory(
    State(state): State<AppState>,
    AuthUser(claims): AuthUser,
) -> Result<Json<UserMemory>> {
    let user_id = claims.sub;

    // Facts are fetched before preferences, one query after the other.
    let facts = state.db.get_user_memory(&user_id).await?;
    let preferences = state.db.get_user_preferences(&user_id).await?;

    let memory = UserMemory {
        user_id,
        preferences,
        facts,
    };
    Ok(Json(memory))
}
/// JSON payload carried in the `data:` field of each server-sent event emitted
/// by `chat_stream`. Optional fields that are `None` are left out of the
/// serialized JSON entirely.
#[derive(serde::Serialize)]
pub struct StreamEvent {
/// Event kind; `chat_stream` emits `"start"`, `"token"`, `"done"` and `"error"`.
pub event: String,
/// Text of one streamed token; populated on `"token"` events.
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
/// Label of the agent handling the request; populated on `"start"` and `"done"` events.
#[serde(skip_serializing_if = "Option::is_none")]
pub agent: Option<String>,
/// Conversation id; populated on `"start"`, `"done"` and `"error"` events.
#[serde(skip_serializing_if = "Option::is_none")]
pub context_id: Option<String>,
/// Human-readable failure description; populated on `"error"` events.
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<String>,
}
#[utoipa::path(
post,
path = "/api/chat/stream",
request_body = ChatRequest,
responses(
(status = 200, description = "Streaming chat response"),
(status = 400, description = "Invalid input"),
(status = 401, description = "Unauthorized")
),
tag = "chat",
security(("bearer" = []))
)]
pub async fn chat_stream(
State(state): State<AppState>,
AuthUser(claims): AuthUser,
Json(payload): Json<ChatRequest>,
) -> axum::response::Sse<
impl futures::Stream<
Item = std::result::Result<axum::response::sse::Event, std::convert::Infallible>,
>,
> {
use axum::response::sse::{Event, Sse};
let context_id = payload
.context_id
.clone()
.unwrap_or_else(|| Uuid::new_v4().to_string());
let state_clone = state.clone();
let claims_clone = claims.clone();
let message = payload.message.clone();
let agent_type_req = payload.agent_type;
let context_id_clone = context_id.clone();
let stream = async_stream::stream! {
if !state_clone.db.conversation_exists(&context_id_clone).await.unwrap_or(false) {
if let Err(e) = state_clone
.db
.create_conversation(&context_id_clone, &claims_clone.sub, None)
.await {
tracing::warn!("Failed to create conversation {}: {}", context_id_clone, e);
}
}
let history = state_clone.db.get_conversation_history(&context_id_clone).await.unwrap_or_else(|e| {
tracing::warn!("Failed to get conversation history for {}: {}", context_id_clone, e);
vec![]
});
let memory_facts = state_clone.db.get_user_memory(&claims_clone.sub).await.unwrap_or_else(|e| {
tracing::warn!("Failed to get user memory for {}: {}", claims_clone.sub, e);
vec![]
});
let preferences = state_clone.db.get_user_preferences(&claims_clone.sub).await.unwrap_or_else(|e| {
tracing::warn!("Failed to get user preferences for {}: {}", claims_clone.sub, e);
vec![]
});
let user_memory = if !memory_facts.is_empty() || !preferences.is_empty() {
Some(UserMemory {
user_id: claims_clone.sub.clone(),
preferences,
facts: memory_facts,
})
} else {
None
};
let agent_context = AgentContext {
user_id: claims_clone.sub.clone(),
session_id: context_id_clone.clone(),
conversation_history: history,
user_memory,
};
let agent_type = if let Some(at) = agent_type_req {
at
} else {
let config = state_clone.config_manager.config();
let router_model = config
.get_agent("router")
.map(|a| a.model.as_str())
.unwrap_or("fast");
let router_llm = match state_clone
.provider_registry
.create_client_for_model(router_model)
.await
{
Ok(client) => client,
Err(_) => match state_clone.llm_factory.create_default().await {
Ok(c) => c,
Err(e) => {
let event = StreamEvent {
event: "error".to_string(),
content: None,
agent: None,
context_id: Some(context_id_clone.clone()),
error: Some(format!("Failed to create LLM client: {}", e)),
};
yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
return;
}
},
};
let router = RouterAgent::new(router_llm);
match router.route(&message, &agent_context).await {
Ok(t) => t,
Err(e) => {
let event = StreamEvent {
event: "error".to_string(),
content: None,
agent: None,
context_id: Some(context_id_clone.clone()),
error: Some(format!("Router failed: {}", e)),
};
yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
return;
}
}
};
let agent_name = AgentRegistry::type_to_name(&agent_type);
let start_event = StreamEvent {
event: "start".to_string(),
content: None,
agent: Some(format!("{} (system)", agent_type)),
context_id: Some(context_id_clone.clone()),
error: None,
};
yield Ok(Event::default().data(serde_json::to_string(&start_event).unwrap_or_default()));
let (user_agent, source) = match crate::api::handlers::user_agents::resolve_agent(
&state_clone,
&claims_clone.sub,
agent_name.to_string(),
).await {
Ok(r) => r,
Err(e) => {
let event = StreamEvent {
event: "error".to_string(),
content: None,
agent: None,
context_id: Some(context_id_clone.clone()),
error: Some(format!("Failed to resolve agent: {}", e)),
};
yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
return;
}
};
let llm = match state_clone
.provider_registry
.create_client_for_model(&user_agent.model)
.await
{
Ok(c) => c,
Err(_) => match state_clone.llm_factory.create_default().await {
Ok(c) => c,
Err(e) => {
let event = StreamEvent {
event: "error".to_string(),
content: None,
agent: None,
context_id: Some(context_id_clone.clone()),
error: Some(format!("Failed to create LLM: {}", e)),
};
yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
return;
}
},
};
let system_prompt = user_agent.system_prompt.unwrap_or_else(|| "You are a helpful assistant.".to_string());
let full_prompt = format!(
"{}\n\nUser: {}\nAssistant:",
system_prompt,
message
);
use futures::StreamExt;
let mut full_response = String::new();
match llm.stream(&full_prompt).await {
Ok(mut token_stream) => {
while let Some(token_result) = token_stream.next().await {
match token_result {
Ok(token) => {
full_response.push_str(&token);
let event = StreamEvent {
event: "token".to_string(),
content: Some(token),
agent: None,
context_id: None,
error: None,
};
yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
}
Err(e) => {
let event = StreamEvent {
event: "error".to_string(),
content: None,
agent: None,
context_id: Some(context_id_clone.clone()),
error: Some(format!("Stream error: {}", e)),
};
yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
return;
}
}
}
}
Err(e) => {
let event = StreamEvent {
event: "error".to_string(),
content: None,
agent: None,
context_id: Some(context_id_clone.clone()),
error: Some(format!("Failed to start stream: {}", e)),
};
yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
return;
}
}
let msg_id = Uuid::new_v4().to_string();
if let Err(e) = state_clone
.db
.add_message(&msg_id, &context_id_clone, MessageRole::User, &message)
.await {
tracing::error!("Failed to store user message in conversation {}: {}", context_id_clone, e);
}
let resp_id = Uuid::new_v4().to_string();
if let Err(e) = state_clone
.db
.add_message(&resp_id, &context_id_clone, MessageRole::Assistant, &full_response)
.await {
tracing::error!("Failed to store assistant message in conversation {}: {}", context_id_clone, e);
}
{
let pool = state_clone.tenant_db.pool().clone();
let tid = claims_clone.sub.clone();
let aname = agent_name.to_string();
let itok = crate::memory::estimate_tokens(&message) as i64;
let otok = crate::memory::estimate_tokens(&full_response) as i64;
let model = user_agent.model.clone();
tokio::spawn(async move {
let _ = crate::db::agent_runs::insert_agent_run(
&pool, &tid, &aname, Some(&tid), "completed",
itok, otok, 0, None, &model, "unknown", true,
)
.await;
});
}
let done_event = StreamEvent {
event: "done".to_string(),
content: None,
agent: Some(format!("{:?} ({})", agent_type, source)),
context_id: Some(context_id_clone),
error: None,
};
yield Ok(Event::default().data(serde_json::to_string(&done_event).unwrap_or_default()));
};
Sse::new(stream).keep_alive(
axum::response::sse::KeepAlive::new()
.interval(std::time::Duration::from_secs(15))
.text("keep-alive"),
)
}
use axum::response::IntoResponse;