1use crate::{
2 agents::{registry::AgentRegistry, router::RouterAgent, Agent},
3 api::handlers::user_agents::resolve_agent,
4 auth::middleware::AuthUser,
5 db::agent_runs,
6 memory::estimate_tokens,
7 types::{
8 AgentContext, AgentType, AppError, ChatRequest, ChatResponse, MessageRole, Result,
9 UserMemory,
10 },
11 utils::toml_config::AgentConfig,
12 AppState,
13};
14use axum::{extract::State, response::Response, Extension, Json};
15use uuid::Uuid;
16
17#[utoipa::path(
19 post,
20 path = "/api/chat",
21 request_body = ChatRequest,
22 responses(
23 (status = 200, description = "Chat response", body = ChatResponse),
24 (status = 400, description = "Invalid input"),
25 (status = 401, description = "Unauthorized")
26 ),
27 tag = "chat",
28 security(("bearer" = []))
29)]
30pub async fn chat(
31 State(state): State<AppState>,
32 AuthUser(claims): AuthUser,
33 tenant_ctx: Option<Extension<crate::models::TenantContext>>,
34 Json(payload): Json<ChatRequest>,
35) -> Result<Response> {
36 let context_id = payload
38 .context_id
39 .unwrap_or_else(|| Uuid::new_v4().to_string());
40
41 if !state.db.conversation_exists(&context_id).await? {
43 state
44 .db
45 .create_conversation(&context_id, &claims.sub, None)
46 .await?;
47 }
48 let history = state.db.get_conversation_history(&context_id).await?;
49 let history_input_tokens: usize = history.iter().map(|m| estimate_tokens(&m.content)).sum();
51
52 let memory_facts = state.db.get_user_memory(&claims.sub).await?;
54 let preferences = state.db.get_user_preferences(&claims.sub).await?;
55 let user_memory = if !memory_facts.is_empty() || !preferences.is_empty() {
56 Some(UserMemory {
57 user_id: claims.sub.clone(),
58 preferences,
59 facts: memory_facts,
60 })
61 } else {
62 None
63 };
64
65 let agent_context = AgentContext {
67 user_id: claims.sub.clone(),
68 session_id: context_id.clone(),
69 conversation_history: history.clone(),
70 user_memory,
71 };
72
73 let agent_type = if let Some(at) = payload.agent_type {
75 at
76 } else {
77 let config = state.config_manager.config();
79 let router_model = config
80 .get_agent("router")
81 .map(|a| a.model.as_str())
82 .unwrap_or("fast");
83
84 let router_llm = match state
85 .provider_registry
86 .create_client_for_model(router_model)
87 .await
88 {
89 Ok(client) => client,
90 Err(_) => state.llm_factory.create_default().await?,
91 };
92
93 let router = RouterAgent::new(router_llm);
94 router.route(&payload.message, &agent_context).await?
95 };
96
97 let agent_name_for_run = AgentRegistry::type_to_name(&agent_type).to_string();
99 let start = std::time::Instant::now();
100 let (response, usage) = execute_agent(agent_type, &payload.message, &agent_context, &state).await?;
101 let duration_ms = start.elapsed().as_millis() as i64;
102
103 let msg_id = Uuid::new_v4().to_string();
105 state
106 .db
107 .add_message(&msg_id, &context_id, MessageRole::User, &payload.message)
108 .await?;
109
110 let resp_id = Uuid::new_v4().to_string();
111 state
112 .db
113 .add_message(
114 &resp_id,
115 &context_id,
116 MessageRole::Assistant,
117 &response.response,
118 )
119 .await?;
120
121 let (input_tokens, output_tokens) = if let Some(u) = usage {
123 (u.prompt_tokens, u.completion_tokens)
124 } else {
125 (
126 (history_input_tokens + estimate_tokens(&payload.message)) as u32,
127 estimate_tokens(&response.response) as u32,
128 )
129 };
130
131 {
133 let pool = state.tenant_db.pool().clone();
134 let agent_name = agent_name_for_run;
135 let user_id = claims.sub.clone();
136 let tenant_id_for_run = tenant_ctx
137 .map(|Extension(tc)| tc.tenant_id.clone())
138 .unwrap_or_else(|| "system".to_string());
139 let itok = input_tokens as i64;
140 let otok = output_tokens as i64;
141 tokio::spawn(async move {
142 let _ = agent_runs::insert_agent_run(
143 &pool,
144 &tenant_id_for_run,
145 &agent_name,
146 Some(&user_id),
147 "completed",
148 itok,
149 otok,
150 duration_ms,
151 None,
152 "unknown",
153 "unknown",
154 false,
155 )
156 .await;
157 });
158 }
159
160 let body = Json(response);
161 let mut response = body.into_response();
162 response.headers_mut().insert(
163 axum::http::HeaderName::from_static("x-input-tokens"),
164 axum::http::HeaderValue::from(input_tokens),
165 );
166 response.headers_mut().insert(
167 axum::http::HeaderName::from_static("x-output-tokens"),
168 axum::http::HeaderValue::from(output_tokens),
169 );
170
171 Ok(response)
172}
173
174async fn execute_agent(
175 agent_type: AgentType,
176 message: &str,
177 context: &AgentContext,
178 state: &AppState,
179) -> Result<(ChatResponse, Option<crate::llm::client::TokenUsage>)> {
180 let agent_name = AgentRegistry::type_to_name(&agent_type);
182
183 if agent_type == AgentType::Router {
184 return Err(AppError::InvalidInput(
185 "Router agent cannot be called directly".to_string(),
186 ));
187 }
188
189 let (user_agent, source) =
191 resolve_agent(state, &context.user_id, agent_name.to_string()).await?;
192
193 let config = AgentConfig {
195 model: user_agent.model.clone(),
196 system_prompt: user_agent.system_prompt.clone(),
197 tools: user_agent.tools_vec(),
198 max_tool_iterations: user_agent.max_tool_iterations as usize,
199 parallel_tools: user_agent.parallel_tools,
200 extra: std::collections::HashMap::new(),
201 };
202
203 let agent = state
205 .agent_registry
206 .create_agent_from_config(agent_name, &config)
207 .await?;
208
209 let agent_resp = agent.execute(message, context).await?;
211
212 Ok((ChatResponse {
213 response: agent_resp.content,
214 agent: format!("{:?} ({})", agent_type, source),
215 context_id: context.session_id.clone(),
216 sources: None,
217 }, agent_resp.usage))
218}
219
220#[utoipa::path(
222 get,
223 path = "/api/memory",
224 responses(
225 (status = 200, description = "User memory retrieved successfully"),
226 (status = 401, description = "Unauthorized")
227 ),
228 tag = "chat",
229 security(("bearer" = []))
230)]
231pub async fn get_user_memory(
232 State(state): State<AppState>,
233 AuthUser(claims): AuthUser,
234) -> Result<Json<UserMemory>> {
235 let facts = state.db.get_user_memory(&claims.sub).await?;
236 let preferences = state.db.get_user_preferences(&claims.sub).await?;
237
238 Ok(Json(UserMemory {
239 user_id: claims.sub,
240 preferences,
241 facts,
242 }))
243}
244
245#[derive(serde::Serialize)]
247pub struct StreamEvent {
248 pub event: String,
250 #[serde(skip_serializing_if = "Option::is_none")]
252 pub content: Option<String>,
253 #[serde(skip_serializing_if = "Option::is_none")]
255 pub agent: Option<String>,
256 #[serde(skip_serializing_if = "Option::is_none")]
258 pub context_id: Option<String>,
259 #[serde(skip_serializing_if = "Option::is_none")]
261 pub error: Option<String>,
262}
263
264#[utoipa::path(
266 post,
267 path = "/api/chat/stream",
268 request_body = ChatRequest,
269 responses(
270 (status = 200, description = "Streaming chat response"),
271 (status = 400, description = "Invalid input"),
272 (status = 401, description = "Unauthorized")
273 ),
274 tag = "chat",
275 security(("bearer" = []))
276)]
277pub async fn chat_stream(
278 State(state): State<AppState>,
279 AuthUser(claims): AuthUser,
280 Json(payload): Json<ChatRequest>,
281) -> axum::response::Sse<
282 impl futures::Stream<
283 Item = std::result::Result<axum::response::sse::Event, std::convert::Infallible>,
284 >,
285> {
286 use axum::response::sse::{Event, Sse};
287
288 let context_id = payload
290 .context_id
291 .clone()
292 .unwrap_or_else(|| Uuid::new_v4().to_string());
293
294 let state_clone = state.clone();
296 let claims_clone = claims.clone();
297 let message = payload.message.clone();
298 let agent_type_req = payload.agent_type;
299 let context_id_clone = context_id.clone();
300
301 let stream = async_stream::stream! {
302 if !state_clone.db.conversation_exists(&context_id_clone).await.unwrap_or(false) {
304 if let Err(e) = state_clone
305 .db
306 .create_conversation(&context_id_clone, &claims_clone.sub, None)
307 .await {
308 tracing::warn!("Failed to create conversation {}: {}", context_id_clone, e);
309 }
310 }
311
312 let history = state_clone.db.get_conversation_history(&context_id_clone).await.unwrap_or_else(|e| {
313 tracing::warn!("Failed to get conversation history for {}: {}", context_id_clone, e);
314 vec![]
315 });
316
317 let memory_facts = state_clone.db.get_user_memory(&claims_clone.sub).await.unwrap_or_else(|e| {
319 tracing::warn!("Failed to get user memory for {}: {}", claims_clone.sub, e);
320 vec![]
321 });
322 let preferences = state_clone.db.get_user_preferences(&claims_clone.sub).await.unwrap_or_else(|e| {
323 tracing::warn!("Failed to get user preferences for {}: {}", claims_clone.sub, e);
324 vec![]
325 });
326 let user_memory = if !memory_facts.is_empty() || !preferences.is_empty() {
327 Some(UserMemory {
328 user_id: claims_clone.sub.clone(),
329 preferences,
330 facts: memory_facts,
331 })
332 } else {
333 None
334 };
335
336 let agent_context = AgentContext {
338 user_id: claims_clone.sub.clone(),
339 session_id: context_id_clone.clone(),
340 conversation_history: history,
341 user_memory,
342 };
343
344 let agent_type = if let Some(at) = agent_type_req {
346 at
347 } else {
348 let config = state_clone.config_manager.config();
349 let router_model = config
350 .get_agent("router")
351 .map(|a| a.model.as_str())
352 .unwrap_or("fast");
353
354 let router_llm = match state_clone
355 .provider_registry
356 .create_client_for_model(router_model)
357 .await
358 {
359 Ok(client) => client,
360 Err(_) => match state_clone.llm_factory.create_default().await {
361 Ok(c) => c,
362 Err(e) => {
363 let event = StreamEvent {
364 event: "error".to_string(),
365 content: None,
366 agent: None,
367 context_id: Some(context_id_clone.clone()),
368 error: Some(format!("Failed to create LLM client: {}", e)),
369 };
370 yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
371 return;
372 }
373 },
374 };
375
376 let router = RouterAgent::new(router_llm);
377 match router.route(&message, &agent_context).await {
378 Ok(t) => t,
379 Err(e) => {
380 let event = StreamEvent {
381 event: "error".to_string(),
382 content: None,
383 agent: None,
384 context_id: Some(context_id_clone.clone()),
385 error: Some(format!("Router failed: {}", e)),
386 };
387 yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
388 return;
389 }
390 }
391 };
392
393 let agent_name = AgentRegistry::type_to_name(&agent_type);
395 let start_event = StreamEvent {
396 event: "start".to_string(),
397 content: None,
398 agent: Some(format!("{} (system)", agent_type)),
399 context_id: Some(context_id_clone.clone()),
400 error: None,
401 };
402 yield Ok(Event::default().data(serde_json::to_string(&start_event).unwrap_or_default()));
403
404 let (user_agent, source) = match crate::api::handlers::user_agents::resolve_agent(
406 &state_clone,
407 &claims_clone.sub,
408 agent_name.to_string(),
409 ).await {
410 Ok(r) => r,
411 Err(e) => {
412 let event = StreamEvent {
413 event: "error".to_string(),
414 content: None,
415 agent: None,
416 context_id: Some(context_id_clone.clone()),
417 error: Some(format!("Failed to resolve agent: {}", e)),
418 };
419 yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
420 return;
421 }
422 };
423
424 let llm = match state_clone
426 .provider_registry
427 .create_client_for_model(&user_agent.model)
428 .await
429 {
430 Ok(c) => c,
431 Err(_) => match state_clone.llm_factory.create_default().await {
432 Ok(c) => c,
433 Err(e) => {
434 let event = StreamEvent {
435 event: "error".to_string(),
436 content: None,
437 agent: None,
438 context_id: Some(context_id_clone.clone()),
439 error: Some(format!("Failed to create LLM: {}", e)),
440 };
441 yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
442 return;
443 }
444 },
445 };
446
447 let system_prompt = user_agent.system_prompt.unwrap_or_else(|| "You are a helpful assistant.".to_string());
449 let full_prompt = format!(
450 "{}\n\nUser: {}\nAssistant:",
451 system_prompt,
452 message
453 );
454
455 use futures::StreamExt;
457 let mut full_response = String::new();
458 match llm.stream(&full_prompt).await {
459 Ok(mut token_stream) => {
460 while let Some(token_result) = token_stream.next().await {
461 match token_result {
462 Ok(token) => {
463 full_response.push_str(&token);
464 let event = StreamEvent {
465 event: "token".to_string(),
466 content: Some(token),
467 agent: None,
468 context_id: None,
469 error: None,
470 };
471 yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
472 }
473 Err(e) => {
474 let event = StreamEvent {
475 event: "error".to_string(),
476 content: None,
477 agent: None,
478 context_id: Some(context_id_clone.clone()),
479 error: Some(format!("Stream error: {}", e)),
480 };
481 yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
482 return;
483 }
484 }
485 }
486 }
487 Err(e) => {
488 let event = StreamEvent {
489 event: "error".to_string(),
490 content: None,
491 agent: None,
492 context_id: Some(context_id_clone.clone()),
493 error: Some(format!("Failed to start stream: {}", e)),
494 };
495 yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
496 return;
497 }
498 }
499
500 let msg_id = Uuid::new_v4().to_string();
502 if let Err(e) = state_clone
503 .db
504 .add_message(&msg_id, &context_id_clone, MessageRole::User, &message)
505 .await {
506 tracing::error!("Failed to store user message in conversation {}: {}", context_id_clone, e);
507 }
508
509 let resp_id = Uuid::new_v4().to_string();
510 if let Err(e) = state_clone
511 .db
512 .add_message(&resp_id, &context_id_clone, MessageRole::Assistant, &full_response)
513 .await {
514 tracing::error!("Failed to store assistant message in conversation {}: {}", context_id_clone, e);
515 }
516
517 {
519 let pool = state_clone.tenant_db.pool().clone();
520 let tid = claims_clone.sub.clone();
521 let aname = agent_name.to_string();
522 let itok = crate::memory::estimate_tokens(&message) as i64;
523 let otok = crate::memory::estimate_tokens(&full_response) as i64;
524 let model = user_agent.model.clone();
525 tokio::spawn(async move {
526 let _ = crate::db::agent_runs::insert_agent_run(
527 &pool, &tid, &aname, Some(&tid), "completed",
528 itok, otok, 0, None, &model, "unknown", true,
529 )
530 .await;
531 });
532 }
533
534 let done_event = StreamEvent {
536 event: "done".to_string(),
537 content: None,
538 agent: Some(format!("{:?} ({})", agent_type, source)),
539 context_id: Some(context_id_clone),
540 error: None,
541 };
542 yield Ok(Event::default().data(serde_json::to_string(&done_event).unwrap_or_default()));
543 };
544
545 Sse::new(stream).keep_alive(
546 axum::response::sse::KeepAlive::new()
547 .interval(std::time::Duration::from_secs(15))
548 .text("keep-alive"),
549 )
550}
551use axum::response::IntoResponse;