ares/api/handlers/
chat.rs

1use crate::{
2    agents::{registry::AgentRegistry, router::RouterAgent, Agent},
3    api::handlers::user_agents::resolve_agent,
4    auth::middleware::AuthUser,
5    types::{
6        AgentContext, AgentType, AppError, ChatRequest, ChatResponse, MessageRole, Result,
7        UserMemory,
8    },
9    utils::toml_config::AgentConfig,
10    AppState,
11};
12use axum::{extract::State, Json};
13use uuid::Uuid;
14
15/// Chat with the AI assistant
16#[utoipa::path(
17    post,
18    path = "/api/chat",
19    request_body = ChatRequest,
20    responses(
21        (status = 200, description = "Chat response", body = ChatResponse),
22        (status = 400, description = "Invalid input"),
23        (status = 401, description = "Unauthorized")
24    ),
25    tag = "chat",
26    security(("bearer" = []))
27)]
28pub async fn chat(
29    State(state): State<AppState>,
30    AuthUser(claims): AuthUser,
31    Json(payload): Json<ChatRequest>,
32) -> Result<Json<ChatResponse>> {
33    // Get or create conversation
34    let context_id = payload
35        .context_id
36        .unwrap_or_else(|| Uuid::new_v4().to_string());
37
38    // Check if conversation exists, create if not
39    if !state.turso.conversation_exists(&context_id).await? {
40        state
41            .turso
42            .create_conversation(&context_id, &claims.sub, None)
43            .await?;
44    }
45    let history = state.turso.get_conversation_history(&context_id).await?;
46
47    // Load user memory
48    let memory_facts = state.turso.get_user_memory(&claims.sub).await?;
49    let preferences = state.turso.get_user_preferences(&claims.sub).await?;
50    let user_memory = if !memory_facts.is_empty() || !preferences.is_empty() {
51        Some(UserMemory {
52            user_id: claims.sub.clone(),
53            preferences,
54            facts: memory_facts,
55        })
56    } else {
57        None
58    };
59
60    // Build agent context
61    let agent_context = AgentContext {
62        user_id: claims.sub.clone(),
63        session_id: context_id.clone(),
64        conversation_history: history.clone(),
65        user_memory,
66    };
67
68    // Route to appropriate agent
69    let agent_type = if let Some(at) = payload.agent_type {
70        at
71    } else {
72        // Get router model from config, or use default
73        let config = state.config_manager.config();
74        let router_model = config
75            .get_agent("router")
76            .map(|a| a.model.as_str())
77            .unwrap_or("fast");
78
79        let router_llm = match state
80            .provider_registry
81            .create_client_for_model(router_model)
82            .await
83        {
84            Ok(client) => client,
85            Err(_) => state.llm_factory.create_default().await?,
86        };
87
88        let router = RouterAgent::new(router_llm);
89        router.route(&payload.message, &agent_context).await?
90    };
91
92    // Execute agent
93    let response = execute_agent(agent_type, &payload.message, &agent_context, &state).await?;
94
95    // Store messages in conversation
96    let msg_id = Uuid::new_v4().to_string();
97    state
98        .turso
99        .add_message(&msg_id, &context_id, MessageRole::User, &payload.message)
100        .await?;
101
102    let resp_id = Uuid::new_v4().to_string();
103    state
104        .turso
105        .add_message(
106            &resp_id,
107            &context_id,
108            MessageRole::Assistant,
109            &response.response,
110        )
111        .await?;
112
113    Ok(Json(response))
114}
115
116async fn execute_agent(
117    agent_type: AgentType,
118    message: &str,
119    context: &AgentContext,
120    state: &AppState,
121) -> Result<ChatResponse> {
122    // Get agent name from type
123    let agent_name = AgentRegistry::type_to_name(agent_type);
124
125    if agent_type == AgentType::Router {
126        return Err(AppError::InvalidInput(
127            "Router agent cannot be called directly".to_string(),
128        ));
129    }
130
131    // Resolve agent using the 3-tier hierarchy (User -> Community -> System)
132    let (user_agent, source) = resolve_agent(state, &context.user_id, agent_name).await?;
133
134    // Convert UserAgent to AgentConfig for the registry
135    let config = AgentConfig {
136        model: user_agent.model.clone(),
137        system_prompt: user_agent.system_prompt.clone(),
138        tools: user_agent.tools_vec(),
139        max_tool_iterations: user_agent.max_tool_iterations as usize,
140        parallel_tools: user_agent.parallel_tools,
141        extra: std::collections::HashMap::new(),
142    };
143
144    // Create agent from registry using the resolved config
145    let agent = state
146        .agent_registry
147        .create_agent_from_config(agent_name, &config)
148        .await?;
149
150    // Execute the agent
151    let response = agent.execute(message, context).await?;
152
153    Ok(ChatResponse {
154        response,
155        agent: format!("{:?} ({})", agent_type, source),
156        context_id: context.session_id.clone(),
157        sources: None,
158    })
159}
160
161/// Get user memory
162pub async fn get_user_memory(
163    State(state): State<AppState>,
164    AuthUser(claims): AuthUser,
165) -> Result<Json<UserMemory>> {
166    let facts = state.turso.get_user_memory(&claims.sub).await?;
167    let preferences = state.turso.get_user_preferences(&claims.sub).await?;
168
169    Ok(Json(UserMemory {
170        user_id: claims.sub,
171        preferences,
172        facts,
173    }))
174}
175
176/// Streaming chat response event
177#[derive(serde::Serialize)]
178pub struct StreamEvent {
179    /// Event type: "start", "token", "done", "error"
180    pub event: String,
181    /// Token content (for "token" events)
182    #[serde(skip_serializing_if = "Option::is_none")]
183    pub content: Option<String>,
184    /// Agent type that handled the request (for "start" and "done" events)
185    #[serde(skip_serializing_if = "Option::is_none")]
186    pub agent: Option<String>,
187    /// Context ID for the conversation
188    #[serde(skip_serializing_if = "Option::is_none")]
189    pub context_id: Option<String>,
190    /// Error message (for "error" events)
191    #[serde(skip_serializing_if = "Option::is_none")]
192    pub error: Option<String>,
193}
194
/// Stream a chat response using Server-Sent Events
///
/// Mirrors [`chat`] but emits the reply incrementally as SSE events whose
/// data is a JSON-serialized [`StreamEvent`]: one "start" event, zero or
/// more "token" events, then either a "done" event or a terminal "error"
/// event (after which the stream returns). Conversation setup and message
/// persistence are deliberately best-effort (`unwrap_or*` / `let _ =`) so a
/// storage failure does not abort the stream.
#[utoipa::path(
    post,
    path = "/api/chat/stream",
    request_body = ChatRequest,
    responses(
        (status = 200, description = "Streaming chat response"),
        (status = 400, description = "Invalid input"),
        (status = 401, description = "Unauthorized")
    ),
    tag = "chat",
    security(("bearer" = []))
)]
pub async fn chat_stream(
    State(state): State<AppState>,
    AuthUser(claims): AuthUser,
    Json(payload): Json<ChatRequest>,
) -> axum::response::Sse<
    impl futures::Stream<
        Item = std::result::Result<axum::response::sse::Event, std::convert::Infallible>,
    >,
> {
    use axum::response::sse::{Event, Sse};

    // Get or create conversation
    let context_id = payload
        .context_id
        .clone()
        .unwrap_or_else(|| Uuid::new_v4().to_string());

    // Clone values we need for the async stream: the generator below outlives
    // this function body, so it must own everything it touches.
    let state_clone = state.clone();
    let claims_clone = claims.clone();
    let message = payload.message.clone();
    let agent_type_req = payload.agent_type;
    let context_id_clone = context_id.clone();

    let stream = async_stream::stream! {
        // Setup conversation; failures are swallowed so streaming can proceed.
        if !state_clone.turso.conversation_exists(&context_id_clone).await.unwrap_or(false) {
            let _ = state_clone
                .turso
                .create_conversation(&context_id_clone, &claims_clone.sub, None)
                .await;
        }

        // Best-effort: an unreadable history degrades to an empty one.
        let history = state_clone.turso.get_conversation_history(&context_id_clone).await.unwrap_or_default();

        // Load user memory (same shape as in `chat`, but non-fatal on error)
        let memory_facts = state_clone.turso.get_user_memory(&claims_clone.sub).await.unwrap_or_default();
        let preferences = state_clone.turso.get_user_preferences(&claims_clone.sub).await.unwrap_or_default();
        let user_memory = if !memory_facts.is_empty() || !preferences.is_empty() {
            Some(UserMemory {
                user_id: claims_clone.sub.clone(),
                preferences,
                facts: memory_facts,
            })
        } else {
            None
        };

        // Build agent context
        let agent_context = AgentContext {
            user_id: claims_clone.sub.clone(),
            session_id: context_id_clone.clone(),
            conversation_history: history,
            user_memory,
        };

        // Route to appropriate agent: explicit agent_type wins, otherwise the
        // router agent decides. Any routing failure ends the stream with an
        // "error" event.
        let agent_type = if let Some(at) = agent_type_req {
            at
        } else {
            let config = state_clone.config_manager.config();
            let router_model = config
                .get_agent("router")
                .map(|a| a.model.as_str())
                .unwrap_or("fast");

            let router_llm = match state_clone
                .provider_registry
                .create_client_for_model(router_model)
                .await
            {
                Ok(client) => client,
                // Fall back to the default client; if even that fails, the
                // stream terminates with an "error" event.
                Err(_) => match state_clone.llm_factory.create_default().await {
                    Ok(c) => c,
                    Err(e) => {
                        let event = StreamEvent {
                            event: "error".to_string(),
                            content: None,
                            agent: None,
                            context_id: Some(context_id_clone.clone()),
                            error: Some(format!("Failed to create LLM client: {}", e)),
                        };
                        yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
                        return;
                    }
                },
            };

            let router = RouterAgent::new(router_llm);
            match router.route(&message, &agent_context).await {
                Ok(t) => t,
                Err(e) => {
                    let event = StreamEvent {
                        event: "error".to_string(),
                        content: None,
                        agent: None,
                        context_id: Some(context_id_clone.clone()),
                        error: Some(format!("Router failed: {}", e)),
                    };
                    yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
                    return;
                }
            }
        };

        // Send start event.
        // NOTE(review): the start event hard-codes "(system)" as the source,
        // while the done event below reports the actual resolved `source` —
        // the tier is not known yet at this point; confirm this is intended.
        let agent_name = AgentRegistry::type_to_name(agent_type);
        let start_event = StreamEvent {
            event: "start".to_string(),
            content: None,
            agent: Some(format!("{:?} (system)", agent_type)),
            context_id: Some(context_id_clone.clone()),
            error: None,
        };
        yield Ok(Event::default().data(serde_json::to_string(&start_event).unwrap_or_default()));

        // Resolve agent using the User -> Community -> System hierarchy.
        let (user_agent, source) = match crate::api::handlers::user_agents::resolve_agent(
            &state_clone,
            &claims_clone.sub,
            agent_name,
        ).await {
            Ok(r) => r,
            Err(e) => {
                let event = StreamEvent {
                    event: "error".to_string(),
                    content: None,
                    agent: None,
                    context_id: Some(context_id_clone.clone()),
                    error: Some(format!("Failed to resolve agent: {}", e)),
                };
                yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
                return;
            }
        };

        // Get LLM client for streaming (configured model, then default).
        // NOTE(review): this path calls the LLM directly rather than going
        // through the agent registry, so agent tools are bypassed when
        // streaming — confirm that is the intended trade-off.
        let llm = match state_clone
            .provider_registry
            .create_client_for_model(&user_agent.model)
            .await
        {
            Ok(c) => c,
            Err(_) => match state_clone.llm_factory.create_default().await {
                Ok(c) => c,
                Err(e) => {
                    let event = StreamEvent {
                        event: "error".to_string(),
                        content: None,
                        agent: None,
                        context_id: Some(context_id_clone.clone()),
                        error: Some(format!("Failed to create LLM: {}", e)),
                    };
                    yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
                    return;
                }
            },
        };

        // Build the prompt with system message and the current user message.
        // (Conversation history is not included in this prompt.)
        let system_prompt = user_agent.system_prompt.unwrap_or_else(|| "You are a helpful assistant.".to_string());
        let full_prompt = format!(
            "{}\n\nUser: {}\nAssistant:",
            system_prompt,
            message
        );

        // Stream tokens, forwarding each as a "token" event while also
        // accumulating the full response for persistence below.
        use futures::StreamExt;
        let mut full_response = String::new();
        match llm.stream(&full_prompt).await {
            Ok(mut token_stream) => {
                while let Some(token_result) = token_stream.next().await {
                    match token_result {
                        Ok(token) => {
                            full_response.push_str(&token);
                            let event = StreamEvent {
                                event: "token".to_string(),
                                content: Some(token),
                                agent: None,
                                context_id: None,
                                error: None,
                            };
                            yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
                        }
                        Err(e) => {
                            // Mid-stream failure: report and stop without
                            // persisting the partial response.
                            let event = StreamEvent {
                                event: "error".to_string(),
                                content: None,
                                agent: None,
                                context_id: Some(context_id_clone.clone()),
                                error: Some(format!("Stream error: {}", e)),
                            };
                            yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
                            return;
                        }
                    }
                }
            }
            Err(e) => {
                let event = StreamEvent {
                    event: "error".to_string(),
                    content: None,
                    agent: None,
                    context_id: Some(context_id_clone.clone()),
                    error: Some(format!("Failed to start stream: {}", e)),
                };
                yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
                return;
            }
        }

        // Store both messages in the conversation (best-effort).
        let msg_id = Uuid::new_v4().to_string();
        let _ = state_clone
            .turso
            .add_message(&msg_id, &context_id_clone, MessageRole::User, &message)
            .await;

        let resp_id = Uuid::new_v4().to_string();
        let _ = state_clone
            .turso
            .add_message(&resp_id, &context_id_clone, MessageRole::Assistant, &full_response)
            .await;

        // Send done event with the agent tier actually used.
        let done_event = StreamEvent {
            event: "done".to_string(),
            content: None,
            agent: Some(format!("{:?} ({})", agent_type, source)),
            context_id: Some(context_id_clone),
            error: None,
        };
        yield Ok(Event::default().data(serde_json::to_string(&done_event).unwrap_or_default()));
    };

    // Periodic keep-alive comments prevent idle proxies from closing the SSE
    // connection while the LLM is thinking.
    Sse::new(stream).keep_alive(
        axum::response::sse::KeepAlive::new()
            .interval(std::time::Duration::from_secs(15))
            .text("keep-alive"),
    )
}