1
2
3use crate::{
4 agents::{registry::AgentRegistry, router::RouterAgent, Agent},
5 api::handlers::user_agents::resolve_agent,
6 auth::middleware::AuthUser,
7 db::agent_runs,
8 memory::estimate_tokens,
9 types::{
10 AgentContext, AgentType, AppError, ChatRequest, ChatResponse, MessageRole, Result,
11 UserMemory,
12 },
13 utils::toml_config::AgentConfig,
14 AppState,
15};
16use axum::{extract::State, response::Response, Extension, Json};
17use uuid::Uuid;
18
19#[utoipa::path(
21 post,
22 path = "/api/chat",
23 request_body = ChatRequest,
24 responses(
25 (status = 200, description = "Chat response", body = ChatResponse),
26 (status = 400, description = "Invalid input"),
27 (status = 401, description = "Unauthorized")
28 ),
29 tag = "chat",
30 security(("bearer" = []))
31)]
32pub async fn chat(
33 State(state): State<AppState>,
34 AuthUser(claims): AuthUser,
35 tenant_ctx: Option<Extension<crate::models::TenantContext>>,
36 Json(payload): Json<ChatRequest>,
37) -> Result<Response> {
38 let context_id = payload
40 .context_id
41 .unwrap_or_else(|| Uuid::new_v4().to_string());
42
43 if !state.db.conversation_exists(&context_id).await? {
45 state
46 .db
47 .create_conversation(&context_id, &claims.sub, None)
48 .await?;
49 }
50 let history = state.db.get_conversation_history(&context_id).await?;
51 let history_input_tokens: usize = history.iter().map(|m| estimate_tokens(&m.content)).sum();
53
54 let memory_facts = state.db.get_user_memory(&claims.sub).await?;
56 let preferences = state.db.get_user_preferences(&claims.sub).await?;
57 let user_memory = if !memory_facts.is_empty() || !preferences.is_empty() {
58 Some(UserMemory {
59 user_id: claims.sub.clone(),
60 preferences,
61 facts: memory_facts,
62 })
63 } else {
64 None
65 };
66
67 let agent_context = AgentContext {
69 user_id: claims.sub.clone(),
70 session_id: context_id.clone(),
71 conversation_history: history.clone(),
72 user_memory,
73 };
74
75 let agent_type = if let Some(at) = payload.agent_type {
77 at
78 } else {
79 let config = state.config_manager.config();
81 let router_model = config
82 .get_agent("router")
83 .map(|a| a.model.as_str())
84 .unwrap_or("fast");
85
86 let router_llm = match state
87 .provider_registry
88 .create_client_for_model(router_model)
89 .await
90 {
91 Ok(client) => client,
92 Err(_) => state.llm_factory.create_default().await?,
93 };
94
95 let router = RouterAgent::new(router_llm);
96 router.route(&payload.message, &agent_context).await?
97 };
98
99 let agent_name_for_run = AgentRegistry::type_to_name(&agent_type).to_string();
101 let start = std::time::Instant::now();
102 let response = execute_agent(agent_type, &payload.message, &agent_context, &state).await?;
103 let duration_ms = start.elapsed().as_millis() as i64;
104
105 let msg_id = Uuid::new_v4().to_string();
107 state
108 .db
109 .add_message(&msg_id, &context_id, MessageRole::User, &payload.message)
110 .await?;
111
112 let resp_id = Uuid::new_v4().to_string();
113 state
114 .db
115 .add_message(
116 &resp_id,
117 &context_id,
118 MessageRole::Assistant,
119 &response.response,
120 )
121 .await?;
122
123 let input_tokens = (history_input_tokens + estimate_tokens(&payload.message)) as u32;
127 let output_tokens = estimate_tokens(&response.response) as u32;
128
129 {
131 let pool = state.tenant_db.pool().clone();
132 let agent_name = agent_name_for_run;
133 let user_id = claims.sub.clone();
134 let tenant_id_for_run = tenant_ctx
135 .map(|Extension(tc)| tc.tenant_id.clone())
136 .unwrap_or_else(|| "system".to_string());
137 let itok = input_tokens as i64;
138 let otok = output_tokens as i64;
139 tokio::spawn(async move {
140 let _ = agent_runs::insert_agent_run(
141 &pool, &tenant_id_for_run, &agent_name, Some(&user_id),
142 "completed", itok, otok, duration_ms, None,
143 ).await;
144 });
145 }
146
147 let body = Json(response);
148 let mut response = body.into_response();
149 response.headers_mut().insert(
150 axum::http::HeaderName::from_static("x-input-tokens"),
151 axum::http::HeaderValue::from(input_tokens),
152 );
153 response.headers_mut().insert(
154 axum::http::HeaderName::from_static("x-output-tokens"),
155 axum::http::HeaderValue::from(output_tokens),
156 );
157
158 Ok(response)
159}
160
161async fn execute_agent(
162 agent_type: AgentType,
163 message: &str,
164 context: &AgentContext,
165 state: &AppState,
166) -> Result<ChatResponse> {
167 let agent_name = AgentRegistry::type_to_name(&agent_type);
169
170 if agent_type == AgentType::Router {
171 return Err(AppError::InvalidInput(
172 "Router agent cannot be called directly".to_string(),
173 ));
174 }
175
176 let (user_agent, source) = resolve_agent(state, &context.user_id, agent_name.to_string()).await?;
178
179 let config = AgentConfig {
181 model: user_agent.model.clone(),
182 system_prompt: user_agent.system_prompt.clone(),
183 tools: user_agent.tools_vec(),
184 max_tool_iterations: user_agent.max_tool_iterations as usize,
185 parallel_tools: user_agent.parallel_tools,
186 extra: std::collections::HashMap::new(),
187 };
188
189 let agent = state
191 .agent_registry
192 .create_agent_from_config(agent_name, &config)
193 .await?;
194
195 let response = agent.execute(message, context).await?;
197
198 Ok(ChatResponse {
199 response,
200 agent: format!("{:?} ({})", agent_type, source),
201 context_id: context.session_id.clone(),
202 sources: None,
203 })
204}
205
206#[utoipa::path(
208 get,
209 path = "/api/memory",
210 responses(
211 (status = 200, description = "User memory retrieved successfully"),
212 (status = 401, description = "Unauthorized")
213 ),
214 tag = "chat",
215 security(("bearer" = []))
216)]
217pub async fn get_user_memory(
218 State(state): State<AppState>,
219 AuthUser(claims): AuthUser,
220) -> Result<Json<UserMemory>> {
221 let facts = state.db.get_user_memory(&claims.sub).await?;
222 let preferences = state.db.get_user_preferences(&claims.sub).await?;
223
224 Ok(Json(UserMemory {
225 user_id: claims.sub,
226 preferences,
227 facts,
228 }))
229}
230
231#[derive(serde::Serialize)]
233pub struct StreamEvent {
234 pub event: String,
236 #[serde(skip_serializing_if = "Option::is_none")]
238 pub content: Option<String>,
239 #[serde(skip_serializing_if = "Option::is_none")]
241 pub agent: Option<String>,
242 #[serde(skip_serializing_if = "Option::is_none")]
244 pub context_id: Option<String>,
245 #[serde(skip_serializing_if = "Option::is_none")]
247 pub error: Option<String>,
248}
249
250#[utoipa::path(
252 post,
253 path = "/api/chat/stream",
254 request_body = ChatRequest,
255 responses(
256 (status = 200, description = "Streaming chat response"),
257 (status = 400, description = "Invalid input"),
258 (status = 401, description = "Unauthorized")
259 ),
260 tag = "chat",
261 security(("bearer" = []))
262)]
263pub async fn chat_stream(
264 State(state): State<AppState>,
265 AuthUser(claims): AuthUser,
266 Json(payload): Json<ChatRequest>,
267) -> axum::response::Sse<
268 impl futures::Stream<
269 Item = std::result::Result<axum::response::sse::Event, std::convert::Infallible>,
270 >,
271> {
272 use axum::response::sse::{Event, Sse};
273
274 let context_id = payload
276 .context_id
277 .clone()
278 .unwrap_or_else(|| Uuid::new_v4().to_string());
279
280 let state_clone = state.clone();
282 let claims_clone = claims.clone();
283 let message = payload.message.clone();
284 let agent_type_req = payload.agent_type;
285 let context_id_clone = context_id.clone();
286
287 let stream = async_stream::stream! {
288 if !state_clone.db.conversation_exists(&context_id_clone).await.unwrap_or(false) {
290 if let Err(e) = state_clone
291 .db
292 .create_conversation(&context_id_clone, &claims_clone.sub, None)
293 .await {
294 tracing::warn!("Failed to create conversation {}: {}", context_id_clone, e);
295 }
296 }
297
298 let history = state_clone.db.get_conversation_history(&context_id_clone).await.unwrap_or_else(|e| {
299 tracing::warn!("Failed to get conversation history for {}: {}", context_id_clone, e);
300 vec![]
301 });
302
303 let memory_facts = state_clone.db.get_user_memory(&claims_clone.sub).await.unwrap_or_else(|e| {
305 tracing::warn!("Failed to get user memory for {}: {}", claims_clone.sub, e);
306 vec![]
307 });
308 let preferences = state_clone.db.get_user_preferences(&claims_clone.sub).await.unwrap_or_else(|e| {
309 tracing::warn!("Failed to get user preferences for {}: {}", claims_clone.sub, e);
310 vec![]
311 });
312 let user_memory = if !memory_facts.is_empty() || !preferences.is_empty() {
313 Some(UserMemory {
314 user_id: claims_clone.sub.clone(),
315 preferences,
316 facts: memory_facts,
317 })
318 } else {
319 None
320 };
321
322 let agent_context = AgentContext {
324 user_id: claims_clone.sub.clone(),
325 session_id: context_id_clone.clone(),
326 conversation_history: history,
327 user_memory,
328 };
329
330 let agent_type = if let Some(at) = agent_type_req {
332 at
333 } else {
334 let config = state_clone.config_manager.config();
335 let router_model = config
336 .get_agent("router")
337 .map(|a| a.model.as_str())
338 .unwrap_or("fast");
339
340 let router_llm = match state_clone
341 .provider_registry
342 .create_client_for_model(router_model)
343 .await
344 {
345 Ok(client) => client,
346 Err(_) => match state_clone.llm_factory.create_default().await {
347 Ok(c) => c,
348 Err(e) => {
349 let event = StreamEvent {
350 event: "error".to_string(),
351 content: None,
352 agent: None,
353 context_id: Some(context_id_clone.clone()),
354 error: Some(format!("Failed to create LLM client: {}", e)),
355 };
356 yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
357 return;
358 }
359 },
360 };
361
362 let router = RouterAgent::new(router_llm);
363 match router.route(&message, &agent_context).await {
364 Ok(t) => t,
365 Err(e) => {
366 let event = StreamEvent {
367 event: "error".to_string(),
368 content: None,
369 agent: None,
370 context_id: Some(context_id_clone.clone()),
371 error: Some(format!("Router failed: {}", e)),
372 };
373 yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
374 return;
375 }
376 }
377 };
378
379 let agent_name = AgentRegistry::type_to_name(&agent_type);
381 let start_event = StreamEvent {
382 event: "start".to_string(),
383 content: None,
384 agent: Some(format!("{} (system)", agent_type)),
385 context_id: Some(context_id_clone.clone()),
386 error: None,
387 };
388 yield Ok(Event::default().data(serde_json::to_string(&start_event).unwrap_or_default()));
389
390 let (user_agent, source) = match crate::api::handlers::user_agents::resolve_agent(
392 &state_clone,
393 &claims_clone.sub,
394 agent_name.to_string(),
395 ).await {
396 Ok(r) => r,
397 Err(e) => {
398 let event = StreamEvent {
399 event: "error".to_string(),
400 content: None,
401 agent: None,
402 context_id: Some(context_id_clone.clone()),
403 error: Some(format!("Failed to resolve agent: {}", e)),
404 };
405 yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
406 return;
407 }
408 };
409
410 let llm = match state_clone
412 .provider_registry
413 .create_client_for_model(&user_agent.model)
414 .await
415 {
416 Ok(c) => c,
417 Err(_) => match state_clone.llm_factory.create_default().await {
418 Ok(c) => c,
419 Err(e) => {
420 let event = StreamEvent {
421 event: "error".to_string(),
422 content: None,
423 agent: None,
424 context_id: Some(context_id_clone.clone()),
425 error: Some(format!("Failed to create LLM: {}", e)),
426 };
427 yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
428 return;
429 }
430 },
431 };
432
433 let system_prompt = user_agent.system_prompt.unwrap_or_else(|| "You are a helpful assistant.".to_string());
435 let full_prompt = format!(
436 "{}\n\nUser: {}\nAssistant:",
437 system_prompt,
438 message
439 );
440
441 use futures::StreamExt;
443 let mut full_response = String::new();
444 match llm.stream(&full_prompt).await {
445 Ok(mut token_stream) => {
446 while let Some(token_result) = token_stream.next().await {
447 match token_result {
448 Ok(token) => {
449 full_response.push_str(&token);
450 let event = StreamEvent {
451 event: "token".to_string(),
452 content: Some(token),
453 agent: None,
454 context_id: None,
455 error: None,
456 };
457 yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
458 }
459 Err(e) => {
460 let event = StreamEvent {
461 event: "error".to_string(),
462 content: None,
463 agent: None,
464 context_id: Some(context_id_clone.clone()),
465 error: Some(format!("Stream error: {}", e)),
466 };
467 yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
468 return;
469 }
470 }
471 }
472 }
473 Err(e) => {
474 let event = StreamEvent {
475 event: "error".to_string(),
476 content: None,
477 agent: None,
478 context_id: Some(context_id_clone.clone()),
479 error: Some(format!("Failed to start stream: {}", e)),
480 };
481 yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
482 return;
483 }
484 }
485
486 let msg_id = Uuid::new_v4().to_string();
488 if let Err(e) = state_clone
489 .db
490 .add_message(&msg_id, &context_id_clone, MessageRole::User, &message)
491 .await {
492 tracing::error!("Failed to store user message in conversation {}: {}", context_id_clone, e);
493 }
494
495 let resp_id = Uuid::new_v4().to_string();
496 if let Err(e) = state_clone
497 .db
498 .add_message(&resp_id, &context_id_clone, MessageRole::Assistant, &full_response)
499 .await {
500 tracing::error!("Failed to store assistant message in conversation {}: {}", context_id_clone, e);
501 }
502
503 let done_event = StreamEvent {
505 event: "done".to_string(),
506 content: None,
507 agent: Some(format!("{:?} ({})", agent_type, source)),
508 context_id: Some(context_id_clone),
509 error: None,
510 };
511 yield Ok(Event::default().data(serde_json::to_string(&done_event).unwrap_or_default()));
512 };
513
514 Sse::new(stream).keep_alive(
515 axum::response::sse::KeepAlive::new()
516 .interval(std::time::Duration::from_secs(15))
517 .text("keep-alive"),
518 )
519}
520use axum::response::IntoResponse;