1use crate::{
2 agents::{registry::AgentRegistry, router::RouterAgent, Agent},
3 api::handlers::user_agents::resolve_agent,
4 auth::middleware::AuthUser,
5 types::{
6 AgentContext, AgentType, AppError, ChatRequest, ChatResponse, MessageRole, Result,
7 UserMemory,
8 },
9 utils::toml_config::AgentConfig,
10 AppState,
11};
12use axum::{extract::State, Json};
13use uuid::Uuid;
14
15#[utoipa::path(
17 post,
18 path = "/api/chat",
19 request_body = ChatRequest,
20 responses(
21 (status = 200, description = "Chat response", body = ChatResponse),
22 (status = 400, description = "Invalid input"),
23 (status = 401, description = "Unauthorized")
24 ),
25 tag = "chat",
26 security(("bearer" = []))
27)]
28pub async fn chat(
29 State(state): State<AppState>,
30 AuthUser(claims): AuthUser,
31 Json(payload): Json<ChatRequest>,
32) -> Result<Json<ChatResponse>> {
33 let context_id = payload
35 .context_id
36 .unwrap_or_else(|| Uuid::new_v4().to_string());
37
38 if !state.turso.conversation_exists(&context_id).await? {
40 state
41 .turso
42 .create_conversation(&context_id, &claims.sub, None)
43 .await?;
44 }
45 let history = state.turso.get_conversation_history(&context_id).await?;
46
47 let memory_facts = state.turso.get_user_memory(&claims.sub).await?;
49 let preferences = state.turso.get_user_preferences(&claims.sub).await?;
50 let user_memory = if !memory_facts.is_empty() || !preferences.is_empty() {
51 Some(UserMemory {
52 user_id: claims.sub.clone(),
53 preferences,
54 facts: memory_facts,
55 })
56 } else {
57 None
58 };
59
60 let agent_context = AgentContext {
62 user_id: claims.sub.clone(),
63 session_id: context_id.clone(),
64 conversation_history: history.clone(),
65 user_memory,
66 };
67
68 let agent_type = if let Some(at) = payload.agent_type {
70 at
71 } else {
72 let config = state.config_manager.config();
74 let router_model = config
75 .get_agent("router")
76 .map(|a| a.model.as_str())
77 .unwrap_or("fast");
78
79 let router_llm = match state
80 .provider_registry
81 .create_client_for_model(router_model)
82 .await
83 {
84 Ok(client) => client,
85 Err(_) => state.llm_factory.create_default().await?,
86 };
87
88 let router = RouterAgent::new(router_llm);
89 router.route(&payload.message, &agent_context).await?
90 };
91
92 let response = execute_agent(agent_type, &payload.message, &agent_context, &state).await?;
94
95 let msg_id = Uuid::new_v4().to_string();
97 state
98 .turso
99 .add_message(&msg_id, &context_id, MessageRole::User, &payload.message)
100 .await?;
101
102 let resp_id = Uuid::new_v4().to_string();
103 state
104 .turso
105 .add_message(
106 &resp_id,
107 &context_id,
108 MessageRole::Assistant,
109 &response.response,
110 )
111 .await?;
112
113 Ok(Json(response))
114}
115
116async fn execute_agent(
117 agent_type: AgentType,
118 message: &str,
119 context: &AgentContext,
120 state: &AppState,
121) -> Result<ChatResponse> {
122 let agent_name = AgentRegistry::type_to_name(agent_type);
124
125 if agent_type == AgentType::Router {
126 return Err(AppError::InvalidInput(
127 "Router agent cannot be called directly".to_string(),
128 ));
129 }
130
131 let (user_agent, source) = resolve_agent(state, &context.user_id, agent_name).await?;
133
134 let config = AgentConfig {
136 model: user_agent.model.clone(),
137 system_prompt: user_agent.system_prompt.clone(),
138 tools: user_agent.tools_vec(),
139 max_tool_iterations: user_agent.max_tool_iterations as usize,
140 parallel_tools: user_agent.parallel_tools,
141 extra: std::collections::HashMap::new(),
142 };
143
144 let agent = state
146 .agent_registry
147 .create_agent_from_config(agent_name, &config)
148 .await?;
149
150 let response = agent.execute(message, context).await?;
152
153 Ok(ChatResponse {
154 response,
155 agent: format!("{:?} ({})", agent_type, source),
156 context_id: context.session_id.clone(),
157 sources: None,
158 })
159}
160
161pub async fn get_user_memory(
163 State(state): State<AppState>,
164 AuthUser(claims): AuthUser,
165) -> Result<Json<UserMemory>> {
166 let facts = state.turso.get_user_memory(&claims.sub).await?;
167 let preferences = state.turso.get_user_preferences(&claims.sub).await?;
168
169 Ok(Json(UserMemory {
170 user_id: claims.sub,
171 preferences,
172 facts,
173 }))
174}
175
176#[derive(serde::Serialize)]
178pub struct StreamEvent {
179 pub event: String,
181 #[serde(skip_serializing_if = "Option::is_none")]
183 pub content: Option<String>,
184 #[serde(skip_serializing_if = "Option::is_none")]
186 pub agent: Option<String>,
187 #[serde(skip_serializing_if = "Option::is_none")]
189 pub context_id: Option<String>,
190 #[serde(skip_serializing_if = "Option::is_none")]
192 pub error: Option<String>,
193}
194
195#[utoipa::path(
197 post,
198 path = "/api/chat/stream",
199 request_body = ChatRequest,
200 responses(
201 (status = 200, description = "Streaming chat response"),
202 (status = 400, description = "Invalid input"),
203 (status = 401, description = "Unauthorized")
204 ),
205 tag = "chat",
206 security(("bearer" = []))
207)]
208pub async fn chat_stream(
209 State(state): State<AppState>,
210 AuthUser(claims): AuthUser,
211 Json(payload): Json<ChatRequest>,
212) -> axum::response::Sse<
213 impl futures::Stream<
214 Item = std::result::Result<axum::response::sse::Event, std::convert::Infallible>,
215 >,
216> {
217 use axum::response::sse::{Event, Sse};
218
219 let context_id = payload
221 .context_id
222 .clone()
223 .unwrap_or_else(|| Uuid::new_v4().to_string());
224
225 let state_clone = state.clone();
227 let claims_clone = claims.clone();
228 let message = payload.message.clone();
229 let agent_type_req = payload.agent_type;
230 let context_id_clone = context_id.clone();
231
232 let stream = async_stream::stream! {
233 if !state_clone.turso.conversation_exists(&context_id_clone).await.unwrap_or(false) {
235 let _ = state_clone
236 .turso
237 .create_conversation(&context_id_clone, &claims_clone.sub, None)
238 .await;
239 }
240
241 let history = state_clone.turso.get_conversation_history(&context_id_clone).await.unwrap_or_default();
242
243 let memory_facts = state_clone.turso.get_user_memory(&claims_clone.sub).await.unwrap_or_default();
245 let preferences = state_clone.turso.get_user_preferences(&claims_clone.sub).await.unwrap_or_default();
246 let user_memory = if !memory_facts.is_empty() || !preferences.is_empty() {
247 Some(UserMemory {
248 user_id: claims_clone.sub.clone(),
249 preferences,
250 facts: memory_facts,
251 })
252 } else {
253 None
254 };
255
256 let agent_context = AgentContext {
258 user_id: claims_clone.sub.clone(),
259 session_id: context_id_clone.clone(),
260 conversation_history: history,
261 user_memory,
262 };
263
264 let agent_type = if let Some(at) = agent_type_req {
266 at
267 } else {
268 let config = state_clone.config_manager.config();
269 let router_model = config
270 .get_agent("router")
271 .map(|a| a.model.as_str())
272 .unwrap_or("fast");
273
274 let router_llm = match state_clone
275 .provider_registry
276 .create_client_for_model(router_model)
277 .await
278 {
279 Ok(client) => client,
280 Err(_) => match state_clone.llm_factory.create_default().await {
281 Ok(c) => c,
282 Err(e) => {
283 let event = StreamEvent {
284 event: "error".to_string(),
285 content: None,
286 agent: None,
287 context_id: Some(context_id_clone.clone()),
288 error: Some(format!("Failed to create LLM client: {}", e)),
289 };
290 yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
291 return;
292 }
293 },
294 };
295
296 let router = RouterAgent::new(router_llm);
297 match router.route(&message, &agent_context).await {
298 Ok(t) => t,
299 Err(e) => {
300 let event = StreamEvent {
301 event: "error".to_string(),
302 content: None,
303 agent: None,
304 context_id: Some(context_id_clone.clone()),
305 error: Some(format!("Router failed: {}", e)),
306 };
307 yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
308 return;
309 }
310 }
311 };
312
313 let agent_name = AgentRegistry::type_to_name(agent_type);
315 let start_event = StreamEvent {
316 event: "start".to_string(),
317 content: None,
318 agent: Some(format!("{:?} (system)", agent_type)),
319 context_id: Some(context_id_clone.clone()),
320 error: None,
321 };
322 yield Ok(Event::default().data(serde_json::to_string(&start_event).unwrap_or_default()));
323
324 let (user_agent, source) = match crate::api::handlers::user_agents::resolve_agent(
326 &state_clone,
327 &claims_clone.sub,
328 agent_name,
329 ).await {
330 Ok(r) => r,
331 Err(e) => {
332 let event = StreamEvent {
333 event: "error".to_string(),
334 content: None,
335 agent: None,
336 context_id: Some(context_id_clone.clone()),
337 error: Some(format!("Failed to resolve agent: {}", e)),
338 };
339 yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
340 return;
341 }
342 };
343
344 let llm = match state_clone
346 .provider_registry
347 .create_client_for_model(&user_agent.model)
348 .await
349 {
350 Ok(c) => c,
351 Err(_) => match state_clone.llm_factory.create_default().await {
352 Ok(c) => c,
353 Err(e) => {
354 let event = StreamEvent {
355 event: "error".to_string(),
356 content: None,
357 agent: None,
358 context_id: Some(context_id_clone.clone()),
359 error: Some(format!("Failed to create LLM: {}", e)),
360 };
361 yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
362 return;
363 }
364 },
365 };
366
367 let system_prompt = user_agent.system_prompt.unwrap_or_else(|| "You are a helpful assistant.".to_string());
369 let full_prompt = format!(
370 "{}\n\nUser: {}\nAssistant:",
371 system_prompt,
372 message
373 );
374
375 use futures::StreamExt;
377 let mut full_response = String::new();
378 match llm.stream(&full_prompt).await {
379 Ok(mut token_stream) => {
380 while let Some(token_result) = token_stream.next().await {
381 match token_result {
382 Ok(token) => {
383 full_response.push_str(&token);
384 let event = StreamEvent {
385 event: "token".to_string(),
386 content: Some(token),
387 agent: None,
388 context_id: None,
389 error: None,
390 };
391 yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
392 }
393 Err(e) => {
394 let event = StreamEvent {
395 event: "error".to_string(),
396 content: None,
397 agent: None,
398 context_id: Some(context_id_clone.clone()),
399 error: Some(format!("Stream error: {}", e)),
400 };
401 yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
402 return;
403 }
404 }
405 }
406 }
407 Err(e) => {
408 let event = StreamEvent {
409 event: "error".to_string(),
410 content: None,
411 agent: None,
412 context_id: Some(context_id_clone.clone()),
413 error: Some(format!("Failed to start stream: {}", e)),
414 };
415 yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
416 return;
417 }
418 }
419
420 let msg_id = Uuid::new_v4().to_string();
422 let _ = state_clone
423 .turso
424 .add_message(&msg_id, &context_id_clone, MessageRole::User, &message)
425 .await;
426
427 let resp_id = Uuid::new_v4().to_string();
428 let _ = state_clone
429 .turso
430 .add_message(&resp_id, &context_id_clone, MessageRole::Assistant, &full_response)
431 .await;
432
433 let done_event = StreamEvent {
435 event: "done".to_string(),
436 content: None,
437 agent: Some(format!("{:?} ({})", agent_type, source)),
438 context_id: Some(context_id_clone),
439 error: None,
440 };
441 yield Ok(Event::default().data(serde_json::to_string(&done_event).unwrap_or_default()));
442 };
443
444 Sse::new(stream).keep_alive(
445 axum::response::sse::KeepAlive::new()
446 .interval(std::time::Duration::from_secs(15))
447 .text("keep-alive"),
448 )
449}