1use crate::{
2 agents::{registry::AgentRegistry, router::RouterAgent, Agent},
3 api::handlers::user_agents::resolve_agent,
4 auth::middleware::AuthUser,
5 types::{
6 AgentContext, AgentType, AppError, ChatRequest, ChatResponse, MessageRole, Result,
7 UserMemory,
8 },
9 utils::toml_config::AgentConfig,
10 AppState,
11};
12use axum::{extract::State, Json};
13use uuid::Uuid;
14
15#[utoipa::path(
17 post,
18 path = "/api/chat",
19 request_body = ChatRequest,
20 responses(
21 (status = 200, description = "Chat response", body = ChatResponse),
22 (status = 400, description = "Invalid input"),
23 (status = 401, description = "Unauthorized")
24 ),
25 tag = "chat",
26 security(("bearer" = []))
27)]
28pub async fn chat(
29 State(state): State<AppState>,
30 AuthUser(claims): AuthUser,
31 Json(payload): Json<ChatRequest>,
32) -> Result<Json<ChatResponse>> {
33 let context_id = payload
35 .context_id
36 .unwrap_or_else(|| Uuid::new_v4().to_string());
37
38 if !state.turso.conversation_exists(&context_id).await? {
40 state
41 .turso
42 .create_conversation(&context_id, &claims.sub, None)
43 .await?;
44 }
45 let history = state.turso.get_conversation_history(&context_id).await?;
46
47 let memory_facts = state.turso.get_user_memory(&claims.sub).await?;
49 let preferences = state.turso.get_user_preferences(&claims.sub).await?;
50 let user_memory = if !memory_facts.is_empty() || !preferences.is_empty() {
51 Some(UserMemory {
52 user_id: claims.sub.clone(),
53 preferences,
54 facts: memory_facts,
55 })
56 } else {
57 None
58 };
59
60 let agent_context = AgentContext {
62 user_id: claims.sub.clone(),
63 session_id: context_id.clone(),
64 conversation_history: history.clone(),
65 user_memory,
66 };
67
68 let agent_type = if let Some(at) = payload.agent_type {
70 at
71 } else {
72 let config = state.config_manager.config();
74 let router_model = config
75 .get_agent("router")
76 .map(|a| a.model.as_str())
77 .unwrap_or("fast");
78
79 let router_llm = match state
80 .provider_registry
81 .create_client_for_model(router_model)
82 .await
83 {
84 Ok(client) => client,
85 Err(_) => state.llm_factory.create_default().await?,
86 };
87
88 let router = RouterAgent::new(router_llm);
89 router.route(&payload.message, &agent_context).await?
90 };
91
92 let response = execute_agent(agent_type, &payload.message, &agent_context, &state).await?;
94
95 let msg_id = Uuid::new_v4().to_string();
97 state
98 .turso
99 .add_message(&msg_id, &context_id, MessageRole::User, &payload.message)
100 .await?;
101
102 let resp_id = Uuid::new_v4().to_string();
103 state
104 .turso
105 .add_message(
106 &resp_id,
107 &context_id,
108 MessageRole::Assistant,
109 &response.response,
110 )
111 .await?;
112
113 Ok(Json(response))
114}
115
116async fn execute_agent(
117 agent_type: AgentType,
118 message: &str,
119 context: &AgentContext,
120 state: &AppState,
121) -> Result<ChatResponse> {
122 let agent_name = AgentRegistry::type_to_name(&agent_type);
124
125 if agent_type == AgentType::Router {
126 return Err(AppError::InvalidInput(
127 "Router agent cannot be called directly".to_string(),
128 ));
129 }
130
131 let (user_agent, source) = resolve_agent(state, &context.user_id, agent_name).await?;
133
134 let config = AgentConfig {
136 model: user_agent.model.clone(),
137 system_prompt: user_agent.system_prompt.clone(),
138 tools: user_agent.tools_vec(),
139 max_tool_iterations: user_agent.max_tool_iterations as usize,
140 parallel_tools: user_agent.parallel_tools,
141 extra: std::collections::HashMap::new(),
142 };
143
144 let agent = state
146 .agent_registry
147 .create_agent_from_config(agent_name, &config)
148 .await?;
149
150 let response = agent.execute(message, context).await?;
152
153 Ok(ChatResponse {
154 response,
155 agent: format!("{:?} ({})", agent_type, source),
156 context_id: context.session_id.clone(),
157 sources: None,
158 })
159}
160
161#[utoipa::path(
163 get,
164 path = "/api/memory",
165 responses(
166 (status = 200, description = "User memory retrieved successfully"),
167 (status = 401, description = "Unauthorized")
168 ),
169 tag = "chat",
170 security(("bearer" = []))
171)]
172pub async fn get_user_memory(
173 State(state): State<AppState>,
174 AuthUser(claims): AuthUser,
175) -> Result<Json<UserMemory>> {
176 let facts = state.turso.get_user_memory(&claims.sub).await?;
177 let preferences = state.turso.get_user_preferences(&claims.sub).await?;
178
179 Ok(Json(UserMemory {
180 user_id: claims.sub,
181 preferences,
182 facts,
183 }))
184}
185
186#[derive(serde::Serialize)]
188pub struct StreamEvent {
189 pub event: String,
191 #[serde(skip_serializing_if = "Option::is_none")]
193 pub content: Option<String>,
194 #[serde(skip_serializing_if = "Option::is_none")]
196 pub agent: Option<String>,
197 #[serde(skip_serializing_if = "Option::is_none")]
199 pub context_id: Option<String>,
200 #[serde(skip_serializing_if = "Option::is_none")]
202 pub error: Option<String>,
203}
204
205#[utoipa::path(
207 post,
208 path = "/api/chat/stream",
209 request_body = ChatRequest,
210 responses(
211 (status = 200, description = "Streaming chat response"),
212 (status = 400, description = "Invalid input"),
213 (status = 401, description = "Unauthorized")
214 ),
215 tag = "chat",
216 security(("bearer" = []))
217)]
218pub async fn chat_stream(
219 State(state): State<AppState>,
220 AuthUser(claims): AuthUser,
221 Json(payload): Json<ChatRequest>,
222) -> axum::response::Sse<
223 impl futures::Stream<
224 Item = std::result::Result<axum::response::sse::Event, std::convert::Infallible>,
225 >,
226> {
227 use axum::response::sse::{Event, Sse};
228
229 let context_id = payload
231 .context_id
232 .clone()
233 .unwrap_or_else(|| Uuid::new_v4().to_string());
234
235 let state_clone = state.clone();
237 let claims_clone = claims.clone();
238 let message = payload.message.clone();
239 let agent_type_req = payload.agent_type;
240 let context_id_clone = context_id.clone();
241
242 let stream = async_stream::stream! {
243 if !state_clone.turso.conversation_exists(&context_id_clone).await.unwrap_or(false) {
245 if let Err(e) = state_clone
246 .turso
247 .create_conversation(&context_id_clone, &claims_clone.sub, None)
248 .await {
249 tracing::warn!("Failed to create conversation {}: {}", context_id_clone, e);
250 }
251 }
252
253 let history = state_clone.turso.get_conversation_history(&context_id_clone).await.unwrap_or_else(|e| {
254 tracing::warn!("Failed to get conversation history for {}: {}", context_id_clone, e);
255 vec![]
256 });
257
258 let memory_facts = state_clone.turso.get_user_memory(&claims_clone.sub).await.unwrap_or_else(|e| {
260 tracing::warn!("Failed to get user memory for {}: {}", claims_clone.sub, e);
261 vec![]
262 });
263 let preferences = state_clone.turso.get_user_preferences(&claims_clone.sub).await.unwrap_or_else(|e| {
264 tracing::warn!("Failed to get user preferences for {}: {}", claims_clone.sub, e);
265 vec![]
266 });
267 let user_memory = if !memory_facts.is_empty() || !preferences.is_empty() {
268 Some(UserMemory {
269 user_id: claims_clone.sub.clone(),
270 preferences,
271 facts: memory_facts,
272 })
273 } else {
274 None
275 };
276
277 let agent_context = AgentContext {
279 user_id: claims_clone.sub.clone(),
280 session_id: context_id_clone.clone(),
281 conversation_history: history,
282 user_memory,
283 };
284
285 let agent_type = if let Some(at) = agent_type_req {
287 at
288 } else {
289 let config = state_clone.config_manager.config();
290 let router_model = config
291 .get_agent("router")
292 .map(|a| a.model.as_str())
293 .unwrap_or("fast");
294
295 let router_llm = match state_clone
296 .provider_registry
297 .create_client_for_model(router_model)
298 .await
299 {
300 Ok(client) => client,
301 Err(_) => match state_clone.llm_factory.create_default().await {
302 Ok(c) => c,
303 Err(e) => {
304 let event = StreamEvent {
305 event: "error".to_string(),
306 content: None,
307 agent: None,
308 context_id: Some(context_id_clone.clone()),
309 error: Some(format!("Failed to create LLM client: {}", e)),
310 };
311 yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
312 return;
313 }
314 },
315 };
316
317 let router = RouterAgent::new(router_llm);
318 match router.route(&message, &agent_context).await {
319 Ok(t) => t,
320 Err(e) => {
321 let event = StreamEvent {
322 event: "error".to_string(),
323 content: None,
324 agent: None,
325 context_id: Some(context_id_clone.clone()),
326 error: Some(format!("Router failed: {}", e)),
327 };
328 yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
329 return;
330 }
331 }
332 };
333
334 let agent_name = AgentRegistry::type_to_name(&agent_type);
336 let start_event = StreamEvent {
337 event: "start".to_string(),
338 content: None,
339 agent: Some(format!("{} (system)", agent_type)),
340 context_id: Some(context_id_clone.clone()),
341 error: None,
342 };
343 yield Ok(Event::default().data(serde_json::to_string(&start_event).unwrap_or_default()));
344
345 let (user_agent, source) = match crate::api::handlers::user_agents::resolve_agent(
347 &state_clone,
348 &claims_clone.sub,
349 agent_name,
350 ).await {
351 Ok(r) => r,
352 Err(e) => {
353 let event = StreamEvent {
354 event: "error".to_string(),
355 content: None,
356 agent: None,
357 context_id: Some(context_id_clone.clone()),
358 error: Some(format!("Failed to resolve agent: {}", e)),
359 };
360 yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
361 return;
362 }
363 };
364
365 let llm = match state_clone
367 .provider_registry
368 .create_client_for_model(&user_agent.model)
369 .await
370 {
371 Ok(c) => c,
372 Err(_) => match state_clone.llm_factory.create_default().await {
373 Ok(c) => c,
374 Err(e) => {
375 let event = StreamEvent {
376 event: "error".to_string(),
377 content: None,
378 agent: None,
379 context_id: Some(context_id_clone.clone()),
380 error: Some(format!("Failed to create LLM: {}", e)),
381 };
382 yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
383 return;
384 }
385 },
386 };
387
388 let system_prompt = user_agent.system_prompt.unwrap_or_else(|| "You are a helpful assistant.".to_string());
390 let full_prompt = format!(
391 "{}\n\nUser: {}\nAssistant:",
392 system_prompt,
393 message
394 );
395
396 use futures::StreamExt;
398 let mut full_response = String::new();
399 match llm.stream(&full_prompt).await {
400 Ok(mut token_stream) => {
401 while let Some(token_result) = token_stream.next().await {
402 match token_result {
403 Ok(token) => {
404 full_response.push_str(&token);
405 let event = StreamEvent {
406 event: "token".to_string(),
407 content: Some(token),
408 agent: None,
409 context_id: None,
410 error: None,
411 };
412 yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
413 }
414 Err(e) => {
415 let event = StreamEvent {
416 event: "error".to_string(),
417 content: None,
418 agent: None,
419 context_id: Some(context_id_clone.clone()),
420 error: Some(format!("Stream error: {}", e)),
421 };
422 yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
423 return;
424 }
425 }
426 }
427 }
428 Err(e) => {
429 let event = StreamEvent {
430 event: "error".to_string(),
431 content: None,
432 agent: None,
433 context_id: Some(context_id_clone.clone()),
434 error: Some(format!("Failed to start stream: {}", e)),
435 };
436 yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
437 return;
438 }
439 }
440
441 let msg_id = Uuid::new_v4().to_string();
443 if let Err(e) = state_clone
444 .turso
445 .add_message(&msg_id, &context_id_clone, MessageRole::User, &message)
446 .await {
447 tracing::error!("Failed to store user message in conversation {}: {}", context_id_clone, e);
448 }
449
450 let resp_id = Uuid::new_v4().to_string();
451 if let Err(e) = state_clone
452 .turso
453 .add_message(&resp_id, &context_id_clone, MessageRole::Assistant, &full_response)
454 .await {
455 tracing::error!("Failed to store assistant message in conversation {}: {}", context_id_clone, e);
456 }
457
458 let done_event = StreamEvent {
460 event: "done".to_string(),
461 content: None,
462 agent: Some(format!("{:?} ({})", agent_type, source)),
463 context_id: Some(context_id_clone),
464 error: None,
465 };
466 yield Ok(Event::default().data(serde_json::to_string(&done_event).unwrap_or_default()));
467 };
468
469 Sse::new(stream).keep_alive(
470 axum::response::sse::KeepAlive::new()
471 .interval(std::time::Duration::from_secs(15))
472 .text("keep-alive"),
473 )
474}