Skip to main content

ares/api/handlers/
chat.rs

1use crate::{
2    agents::{registry::AgentRegistry, router::RouterAgent, Agent},
3    api::handlers::user_agents::resolve_agent,
4    auth::middleware::AuthUser,
5    types::{
6        AgentContext, AgentType, AppError, ChatRequest, ChatResponse, MessageRole, Result,
7        UserMemory,
8    },
9    utils::toml_config::AgentConfig,
10    AppState,
11};
12use axum::{extract::State, Json};
13use uuid::Uuid;
14
15/// Chat with the AI assistant
16#[utoipa::path(
17    post,
18    path = "/api/chat",
19    request_body = ChatRequest,
20    responses(
21        (status = 200, description = "Chat response", body = ChatResponse),
22        (status = 400, description = "Invalid input"),
23        (status = 401, description = "Unauthorized")
24    ),
25    tag = "chat",
26    security(("bearer" = []))
27)]
28pub async fn chat(
29    State(state): State<AppState>,
30    AuthUser(claims): AuthUser,
31    Json(payload): Json<ChatRequest>,
32) -> Result<Json<ChatResponse>> {
33    // Get or create conversation
34    let context_id = payload
35        .context_id
36        .unwrap_or_else(|| Uuid::new_v4().to_string());
37
38    // Check if conversation exists, create if not
39    if !state.turso.conversation_exists(&context_id).await? {
40        state
41            .turso
42            .create_conversation(&context_id, &claims.sub, None)
43            .await?;
44    }
45    let history = state.turso.get_conversation_history(&context_id).await?;
46
47    // Load user memory
48    let memory_facts = state.turso.get_user_memory(&claims.sub).await?;
49    let preferences = state.turso.get_user_preferences(&claims.sub).await?;
50    let user_memory = if !memory_facts.is_empty() || !preferences.is_empty() {
51        Some(UserMemory {
52            user_id: claims.sub.clone(),
53            preferences,
54            facts: memory_facts,
55        })
56    } else {
57        None
58    };
59
60    // Build agent context
61    let agent_context = AgentContext {
62        user_id: claims.sub.clone(),
63        session_id: context_id.clone(),
64        conversation_history: history.clone(),
65        user_memory,
66    };
67
68    // Route to appropriate agent
69    let agent_type = if let Some(at) = payload.agent_type {
70        at
71    } else {
72        // Get router model from config, or use default
73        let config = state.config_manager.config();
74        let router_model = config
75            .get_agent("router")
76            .map(|a| a.model.as_str())
77            .unwrap_or("fast");
78
79        let router_llm = match state
80            .provider_registry
81            .create_client_for_model(router_model)
82            .await
83        {
84            Ok(client) => client,
85            Err(_) => state.llm_factory.create_default().await?,
86        };
87
88        let router = RouterAgent::new(router_llm);
89        router.route(&payload.message, &agent_context).await?
90    };
91
92    // Execute agent
93    let response = execute_agent(agent_type, &payload.message, &agent_context, &state).await?;
94
95    // Store messages in conversation
96    let msg_id = Uuid::new_v4().to_string();
97    state
98        .turso
99        .add_message(&msg_id, &context_id, MessageRole::User, &payload.message)
100        .await?;
101
102    let resp_id = Uuid::new_v4().to_string();
103    state
104        .turso
105        .add_message(
106            &resp_id,
107            &context_id,
108            MessageRole::Assistant,
109            &response.response,
110        )
111        .await?;
112
113    Ok(Json(response))
114}
115
116async fn execute_agent(
117    agent_type: AgentType,
118    message: &str,
119    context: &AgentContext,
120    state: &AppState,
121) -> Result<ChatResponse> {
122    // Get agent name from type
123    let agent_name = AgentRegistry::type_to_name(&agent_type);
124
125    if agent_type == AgentType::Router {
126        return Err(AppError::InvalidInput(
127            "Router agent cannot be called directly".to_string(),
128        ));
129    }
130
131    // Resolve agent using the 3-tier hierarchy (User -> Community -> System)
132    let (user_agent, source) = resolve_agent(state, &context.user_id, agent_name).await?;
133
134    // Convert UserAgent to AgentConfig for the registry
135    let config = AgentConfig {
136        model: user_agent.model.clone(),
137        system_prompt: user_agent.system_prompt.clone(),
138        tools: user_agent.tools_vec(),
139        max_tool_iterations: user_agent.max_tool_iterations as usize,
140        parallel_tools: user_agent.parallel_tools,
141        extra: std::collections::HashMap::new(),
142    };
143
144    // Create agent from registry using the resolved config
145    let agent = state
146        .agent_registry
147        .create_agent_from_config(agent_name, &config)
148        .await?;
149
150    // Execute the agent
151    let response = agent.execute(message, context).await?;
152
153    Ok(ChatResponse {
154        response,
155        agent: format!("{:?} ({})", agent_type, source),
156        context_id: context.session_id.clone(),
157        sources: None,
158    })
159}
160
161/// Get user memory
162#[utoipa::path(
163    get,
164    path = "/api/memory",
165    responses(
166        (status = 200, description = "User memory retrieved successfully"),
167        (status = 401, description = "Unauthorized")
168    ),
169    tag = "chat",
170    security(("bearer" = []))
171)]
172pub async fn get_user_memory(
173    State(state): State<AppState>,
174    AuthUser(claims): AuthUser,
175) -> Result<Json<UserMemory>> {
176    let facts = state.turso.get_user_memory(&claims.sub).await?;
177    let preferences = state.turso.get_user_preferences(&claims.sub).await?;
178
179    Ok(Json(UserMemory {
180        user_id: claims.sub,
181        preferences,
182        facts,
183    }))
184}
185
186/// Streaming chat response event
187#[derive(serde::Serialize)]
188pub struct StreamEvent {
189    /// Event type: "start", "token", "done", "error"
190    pub event: String,
191    /// Token content (for "token" events)
192    #[serde(skip_serializing_if = "Option::is_none")]
193    pub content: Option<String>,
194    /// Agent type that handled the request (for "start" and "done" events)
195    #[serde(skip_serializing_if = "Option::is_none")]
196    pub agent: Option<String>,
197    /// Context ID for the conversation
198    #[serde(skip_serializing_if = "Option::is_none")]
199    pub context_id: Option<String>,
200    /// Error message (for "error" events)
201    #[serde(skip_serializing_if = "Option::is_none")]
202    pub error: Option<String>,
203}
204
/// Stream a chat response using Server-Sent Events
///
/// Event sequence on the happy path: "start" -> zero or more "token" ->
/// "done". Any failure emits a single "error" event and ends the stream.
/// Unlike the non-streaming `chat` handler, database failures during setup
/// and persistence are logged and tolerated rather than aborting the stream.
#[utoipa::path(
    post,
    path = "/api/chat/stream",
    request_body = ChatRequest,
    responses(
        (status = 200, description = "Streaming chat response"),
        (status = 400, description = "Invalid input"),
        (status = 401, description = "Unauthorized")
    ),
    tag = "chat",
    security(("bearer" = []))
)]
pub async fn chat_stream(
    State(state): State<AppState>,
    AuthUser(claims): AuthUser,
    Json(payload): Json<ChatRequest>,
) -> axum::response::Sse<
    impl futures::Stream<
        Item = std::result::Result<axum::response::sse::Event, std::convert::Infallible>,
    >,
> {
    use axum::response::sse::{Event, Sse};

    // Get or create conversation
    let context_id = payload
        .context_id
        .clone()
        .unwrap_or_else(|| Uuid::new_v4().to_string());

    // Clone values we need for the async stream
    // (the stream! block is 'static, so it cannot borrow from the handler).
    let state_clone = state.clone();
    let claims_clone = claims.clone();
    let message = payload.message.clone();
    let agent_type_req = payload.agent_type;
    let context_id_clone = context_id.clone();

    let stream = async_stream::stream! {
        // Setup conversation. Failures here are best-effort: a missing
        // conversation row only affects persistence, not the reply itself.
        if !state_clone.turso.conversation_exists(&context_id_clone).await.unwrap_or(false) {
            if let Err(e) = state_clone
                .turso
                .create_conversation(&context_id_clone, &claims_clone.sub, None)
                .await {
                tracing::warn!("Failed to create conversation {}: {}", context_id_clone, e);
            }
        }

        // History/memory lookups degrade to empty on error so streaming can
        // still proceed with a fresh context.
        let history = state_clone.turso.get_conversation_history(&context_id_clone).await.unwrap_or_else(|e| {
            tracing::warn!("Failed to get conversation history for {}: {}", context_id_clone, e);
            vec![]
        });

        // Load user memory
        let memory_facts = state_clone.turso.get_user_memory(&claims_clone.sub).await.unwrap_or_else(|e| {
            tracing::warn!("Failed to get user memory for {}: {}", claims_clone.sub, e);
            vec![]
        });
        let preferences = state_clone.turso.get_user_preferences(&claims_clone.sub).await.unwrap_or_else(|e| {
            tracing::warn!("Failed to get user preferences for {}: {}", claims_clone.sub, e);
            vec![]
        });
        // Only attach memory when at least one store is non-empty.
        let user_memory = if !memory_facts.is_empty() || !preferences.is_empty() {
            Some(UserMemory {
                user_id: claims_clone.sub.clone(),
                preferences,
                facts: memory_facts,
            })
        } else {
            None
        };

        // Build agent context
        let agent_context = AgentContext {
            user_id: claims_clone.sub.clone(),
            session_id: context_id_clone.clone(),
            conversation_history: history,
            user_memory,
        };

        // Route to appropriate agent: honour an explicit request, otherwise
        // run the router agent. Routing failures end the stream with an
        // "error" event (no "start" has been emitted yet at this point).
        let agent_type = if let Some(at) = agent_type_req {
            at
        } else {
            let config = state_clone.config_manager.config();
            let router_model = config
                .get_agent("router")
                .map(|a| a.model.as_str())
                .unwrap_or("fast");

            // Prefer a model-specific client; fall back to the default LLM.
            let router_llm = match state_clone
                .provider_registry
                .create_client_for_model(router_model)
                .await
            {
                Ok(client) => client,
                Err(_) => match state_clone.llm_factory.create_default().await {
                    Ok(c) => c,
                    Err(e) => {
                        let event = StreamEvent {
                            event: "error".to_string(),
                            content: None,
                            agent: None,
                            context_id: Some(context_id_clone.clone()),
                            error: Some(format!("Failed to create LLM client: {}", e)),
                        };
                        yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
                        return;
                    }
                },
            };

            let router = RouterAgent::new(router_llm);
            match router.route(&message, &agent_context).await {
                Ok(t) => t,
                Err(e) => {
                    let event = StreamEvent {
                        event: "error".to_string(),
                        content: None,
                        agent: None,
                        context_id: Some(context_id_clone.clone()),
                        error: Some(format!("Router failed: {}", e)),
                    };
                    yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
                    return;
                }
            }
        };

        // Send start event.
        // NOTE(review): the start event hardcodes "(system)" because it is
        // emitted before agent resolution; the done event reports the real
        // source. It also formats agent_type with Display ("{}") while the
        // done event (and execute_agent) use Debug ("{:?}") — confirm the two
        // representations are meant to differ.
        let agent_name = AgentRegistry::type_to_name(&agent_type);
        let start_event = StreamEvent {
            event: "start".to_string(),
            content: None,
            agent: Some(format!("{} (system)", agent_type)),
            context_id: Some(context_id_clone.clone()),
            error: None,
        };
        yield Ok(Event::default().data(serde_json::to_string(&start_event).unwrap_or_default()));

        // Resolve agent using hierarchy (User -> Community -> System).
        let (user_agent, source) = match crate::api::handlers::user_agents::resolve_agent(
            &state_clone,
            &claims_clone.sub,
            agent_name,
        ).await {
            Ok(r) => r,
            Err(e) => {
                let event = StreamEvent {
                    event: "error".to_string(),
                    content: None,
                    agent: None,
                    context_id: Some(context_id_clone.clone()),
                    error: Some(format!("Failed to resolve agent: {}", e)),
                };
                yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
                return;
            }
        };

        // Get LLM client for streaming; same provider-then-default fallback
        // as the router client above.
        let llm = match state_clone
            .provider_registry
            .create_client_for_model(&user_agent.model)
            .await
        {
            Ok(c) => c,
            Err(_) => match state_clone.llm_factory.create_default().await {
                Ok(c) => c,
                Err(e) => {
                    let event = StreamEvent {
                        event: "error".to_string(),
                        content: None,
                        agent: None,
                        context_id: Some(context_id_clone.clone()),
                        error: Some(format!("Failed to create LLM: {}", e)),
                    };
                    yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
                    return;
                }
            },
        };

        // Build the prompt with system message and history.
        // NOTE(review): despite the comment, conversation history and tools
        // are NOT included here (unlike the non-streaming path, which runs
        // the full agent) — the streaming path sends only system prompt +
        // current message. Confirm this is intentional.
        let system_prompt = user_agent.system_prompt.unwrap_or_else(|| "You are a helpful assistant.".to_string());
        let full_prompt = format!(
            "{}\n\nUser: {}\nAssistant:",
            system_prompt,
            message
        );

        // Stream tokens, forwarding each as a "token" event while also
        // accumulating the full text for persistence afterwards.
        use futures::StreamExt;
        let mut full_response = String::new();
        match llm.stream(&full_prompt).await {
            Ok(mut token_stream) => {
                while let Some(token_result) = token_stream.next().await {
                    match token_result {
                        Ok(token) => {
                            full_response.push_str(&token);
                            let event = StreamEvent {
                                event: "token".to_string(),
                                content: Some(token),
                                agent: None,
                                context_id: None,
                                error: None,
                            };
                            yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
                        }
                        Err(e) => {
                            // Mid-stream failure: emit "error" and stop; the
                            // partial response is NOT persisted.
                            let event = StreamEvent {
                                event: "error".to_string(),
                                content: None,
                                agent: None,
                                context_id: Some(context_id_clone.clone()),
                                error: Some(format!("Stream error: {}", e)),
                            };
                            yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
                            return;
                        }
                    }
                }
            }
            Err(e) => {
                let event = StreamEvent {
                    event: "error".to_string(),
                    content: None,
                    agent: None,
                    context_id: Some(context_id_clone.clone()),
                    error: Some(format!("Failed to start stream: {}", e)),
                };
                yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
                return;
            }
        }

        // Store messages in conversation; persistence failures are logged,
        // not surfaced to the client (the reply already streamed).
        let msg_id = Uuid::new_v4().to_string();
        if let Err(e) = state_clone
            .turso
            .add_message(&msg_id, &context_id_clone, MessageRole::User, &message)
            .await {
            tracing::error!("Failed to store user message in conversation {}: {}", context_id_clone, e);
        }

        let resp_id = Uuid::new_v4().to_string();
        if let Err(e) = state_clone
            .turso
            .add_message(&resp_id, &context_id_clone, MessageRole::Assistant, &full_response)
            .await {
            tracing::error!("Failed to store assistant message in conversation {}: {}", context_id_clone, e);
        }

        // Send done event with the agent's real resolution source.
        let done_event = StreamEvent {
            event: "done".to_string(),
            content: None,
            agent: Some(format!("{:?} ({})", agent_type, source)),
            context_id: Some(context_id_clone),
            error: None,
        };
        yield Ok(Event::default().data(serde_json::to_string(&done_event).unwrap_or_default()));
    };

    // Keep-alive comments every 15s so proxies don't close idle connections.
    Sse::new(stream).keep_alive(
        axum::response::sse::KeepAlive::new()
            .interval(std::time::Duration::from_secs(15))
            .text("keep-alive"),
    )
}