//! Chat API handlers (`ares/api/handlers/chat.rs`): non-streaming chat,
//! user-memory retrieval, and SSE streaming chat.

use crate::{
    agents::{registry::AgentRegistry, router::RouterAgent, Agent},
    api::handlers::user_agents::resolve_agent,
    auth::middleware::AuthUser,
    db::agent_runs,
    memory::estimate_tokens,
    types::{
        AgentContext, AgentType, AppError, ChatRequest, ChatResponse, MessageRole, Result,
        UserMemory,
    },
    utils::toml_config::AgentConfig,
    AppState,
};
use axum::{extract::State, response::Response, Extension, Json};
use uuid::Uuid;

19/// Chat with the AI assistant
20#[utoipa::path(
21    post,
22    path = "/api/chat",
23    request_body = ChatRequest,
24    responses(
25        (status = 200, description = "Chat response", body = ChatResponse),
26        (status = 400, description = "Invalid input"),
27        (status = 401, description = "Unauthorized")
28    ),
29    tag = "chat",
30    security(("bearer" = []))
31)]
32pub async fn chat(
33    State(state): State<AppState>,
34    AuthUser(claims): AuthUser,
35    tenant_ctx: Option<Extension<crate::models::TenantContext>>,
36    Json(payload): Json<ChatRequest>,
37) -> Result<Response> {
38    // Get or create conversation
39    let context_id = payload
40        .context_id
41        .unwrap_or_else(|| Uuid::new_v4().to_string());
42
43    // Check if conversation exists, create if not
44    if !state.db.conversation_exists(&context_id).await? {
45        state
46            .db
47            .create_conversation(&context_id, &claims.sub, None)
48            .await?;
49    }
50    let history = state.db.get_conversation_history(&context_id).await?;
51    // Compute history token estimate in the same pass (before clone into AgentContext)
52    let history_input_tokens: usize = history.iter().map(|m| estimate_tokens(&m.content)).sum();
53
54    // Load user memory
55    let memory_facts = state.db.get_user_memory(&claims.sub).await?;
56    let preferences = state.db.get_user_preferences(&claims.sub).await?;
57    let user_memory = if !memory_facts.is_empty() || !preferences.is_empty() {
58        Some(UserMemory {
59            user_id: claims.sub.clone(),
60            preferences,
61            facts: memory_facts,
62        })
63    } else {
64        None
65    };
66
67    // Build agent context
68    let agent_context = AgentContext {
69        user_id: claims.sub.clone(),
70        session_id: context_id.clone(),
71        conversation_history: history.clone(),
72        user_memory,
73    };
74
75    // Route to appropriate agent
76    let agent_type = if let Some(at) = payload.agent_type {
77        at
78    } else {
79        // Get router model from config, or use default
80        let config = state.config_manager.config();
81        let router_model = config
82            .get_agent("router")
83            .map(|a| a.model.as_str())
84            .unwrap_or("fast");
85
86        let router_llm = match state
87            .provider_registry
88            .create_client_for_model(router_model)
89            .await
90        {
91            Ok(client) => client,
92            Err(_) => state.llm_factory.create_default().await?,
93        };
94
95        let router = RouterAgent::new(router_llm);
96        router.route(&payload.message, &agent_context).await?
97    };
98
99    // Execute agent with timing
100    let agent_name_for_run = AgentRegistry::type_to_name(&agent_type).to_string();
101    let start = std::time::Instant::now();
102    let response = execute_agent(agent_type, &payload.message, &agent_context, &state).await?;
103    let duration_ms = start.elapsed().as_millis() as i64;
104
105    // Store messages in conversation
106    let msg_id = Uuid::new_v4().to_string();
107    state
108        .db
109        .add_message(&msg_id, &context_id, MessageRole::User, &payload.message)
110        .await?;
111
112    let resp_id = Uuid::new_v4().to_string();
113    state
114        .db
115        .add_message(
116            &resp_id,
117            &context_id,
118            MessageRole::Assistant,
119            &response.response,
120        )
121        .await?;
122
123    // Estimate token counts using the shared heuristic (~4 chars/token).
124    // Input includes full context: conversation history + current message.
125    // Real counts require Agent::execute() → TokenUsage (tracked as future work).
126    let input_tokens = (history_input_tokens + estimate_tokens(&payload.message)) as u32;
127    let output_tokens = estimate_tokens(&response.response) as u32;
128
129    // Record agent run (fire-and-forget)
130    {
131        let pool = state.tenant_db.pool().clone();
132        let agent_name = agent_name_for_run;
133        let user_id = claims.sub.clone();
134        let tenant_id_for_run = tenant_ctx
135            .map(|Extension(tc)| tc.tenant_id.clone())
136            .unwrap_or_else(|| "system".to_string());
137        let itok = input_tokens as i64;
138        let otok = output_tokens as i64;
139        tokio::spawn(async move {
140            let _ = agent_runs::insert_agent_run(
141                &pool, &tenant_id_for_run, &agent_name, Some(&user_id),
142                "completed", itok, otok, duration_ms, None,
143            ).await;
144        });
145    }
146
147    let body = Json(response);
148    let mut response = body.into_response();
149    response.headers_mut().insert(
150        axum::http::HeaderName::from_static("x-input-tokens"),
151        axum::http::HeaderValue::from(input_tokens),
152    );
153    response.headers_mut().insert(
154        axum::http::HeaderName::from_static("x-output-tokens"),
155        axum::http::HeaderValue::from(output_tokens),
156    );
157
158    Ok(response)
159}
160
161async fn execute_agent(
162    agent_type: AgentType,
163    message: &str,
164    context: &AgentContext,
165    state: &AppState,
166) -> Result<ChatResponse> {
167    // Get agent name from type
168    let agent_name = AgentRegistry::type_to_name(&agent_type);
169
170    if agent_type == AgentType::Router {
171        return Err(AppError::InvalidInput(
172            "Router agent cannot be called directly".to_string(),
173        ));
174    }
175
176    // Resolve agent using the 3-tier hierarchy (User -> Community -> System)
177    let (user_agent, source) = resolve_agent(state, &context.user_id, agent_name.to_string()).await?;
178
179    // Convert UserAgent to AgentConfig for the registry
180    let config = AgentConfig {
181        model: user_agent.model.clone(),
182        system_prompt: user_agent.system_prompt.clone(),
183        tools: user_agent.tools_vec(),
184        max_tool_iterations: user_agent.max_tool_iterations as usize,
185        parallel_tools: user_agent.parallel_tools,
186        extra: std::collections::HashMap::new(),
187    };
188
189    // Create agent from registry using the resolved config
190    let agent = state
191        .agent_registry
192        .create_agent_from_config(agent_name, &config)
193        .await?;
194
195    // Execute the agent
196    let response = agent.execute(message, context).await?;
197
198    Ok(ChatResponse {
199        response,
200        agent: format!("{:?} ({})", agent_type, source),
201        context_id: context.session_id.clone(),
202        sources: None,
203    })
204}
205
206/// Get user memory
207#[utoipa::path(
208    get,
209    path = "/api/memory",
210    responses(
211        (status = 200, description = "User memory retrieved successfully"),
212        (status = 401, description = "Unauthorized")
213    ),
214    tag = "chat",
215    security(("bearer" = []))
216)]
217pub async fn get_user_memory(
218    State(state): State<AppState>,
219    AuthUser(claims): AuthUser,
220) -> Result<Json<UserMemory>> {
221    let facts = state.db.get_user_memory(&claims.sub).await?;
222    let preferences = state.db.get_user_preferences(&claims.sub).await?;
223
224    Ok(Json(UserMemory {
225        user_id: claims.sub,
226        preferences,
227        facts,
228    }))
229}
230
231/// Streaming chat response event
232#[derive(serde::Serialize)]
233pub struct StreamEvent {
234    /// Event type: "start", "token", "done", "error"
235    pub event: String,
236    /// Token content (for "token" events)
237    #[serde(skip_serializing_if = "Option::is_none")]
238    pub content: Option<String>,
239    /// Agent type that handled the request (for "start" and "done" events)
240    #[serde(skip_serializing_if = "Option::is_none")]
241    pub agent: Option<String>,
242    /// Context ID for the conversation
243    #[serde(skip_serializing_if = "Option::is_none")]
244    pub context_id: Option<String>,
245    /// Error message (for "error" events)
246    #[serde(skip_serializing_if = "Option::is_none")]
247    pub error: Option<String>,
248}
249
250/// Stream a chat response using Server-Sent Events
251#[utoipa::path(
252    post,
253    path = "/api/chat/stream",
254    request_body = ChatRequest,
255    responses(
256        (status = 200, description = "Streaming chat response"),
257        (status = 400, description = "Invalid input"),
258        (status = 401, description = "Unauthorized")
259    ),
260    tag = "chat",
261    security(("bearer" = []))
262)]
263pub async fn chat_stream(
264    State(state): State<AppState>,
265    AuthUser(claims): AuthUser,
266    Json(payload): Json<ChatRequest>,
267) -> axum::response::Sse<
268    impl futures::Stream<
269        Item = std::result::Result<axum::response::sse::Event, std::convert::Infallible>,
270    >,
271> {
272    use axum::response::sse::{Event, Sse};
273
274    // Get or create conversation
275    let context_id = payload
276        .context_id
277        .clone()
278        .unwrap_or_else(|| Uuid::new_v4().to_string());
279
280    // Clone values we need for the async stream
281    let state_clone = state.clone();
282    let claims_clone = claims.clone();
283    let message = payload.message.clone();
284    let agent_type_req = payload.agent_type;
285    let context_id_clone = context_id.clone();
286
287    let stream = async_stream::stream! {
288        // Setup conversation
289        if !state_clone.db.conversation_exists(&context_id_clone).await.unwrap_or(false) {
290            if let Err(e) = state_clone
291                .db
292                .create_conversation(&context_id_clone, &claims_clone.sub, None)
293                .await {
294                tracing::warn!("Failed to create conversation {}: {}", context_id_clone, e);
295            }
296        }
297
298        let history = state_clone.db.get_conversation_history(&context_id_clone).await.unwrap_or_else(|e| {
299            tracing::warn!("Failed to get conversation history for {}: {}", context_id_clone, e);
300            vec![]
301        });
302
303        // Load user memory
304        let memory_facts = state_clone.db.get_user_memory(&claims_clone.sub).await.unwrap_or_else(|e| {
305            tracing::warn!("Failed to get user memory for {}: {}", claims_clone.sub, e);
306            vec![]
307        });
308        let preferences = state_clone.db.get_user_preferences(&claims_clone.sub).await.unwrap_or_else(|e| {
309            tracing::warn!("Failed to get user preferences for {}: {}", claims_clone.sub, e);
310            vec![]
311        });
312        let user_memory = if !memory_facts.is_empty() || !preferences.is_empty() {
313            Some(UserMemory {
314                user_id: claims_clone.sub.clone(),
315                preferences,
316                facts: memory_facts,
317            })
318        } else {
319            None
320        };
321
322        // Build agent context
323        let agent_context = AgentContext {
324            user_id: claims_clone.sub.clone(),
325            session_id: context_id_clone.clone(),
326            conversation_history: history,
327            user_memory,
328        };
329
330        // Route to appropriate agent
331        let agent_type = if let Some(at) = agent_type_req {
332            at
333        } else {
334            let config = state_clone.config_manager.config();
335            let router_model = config
336                .get_agent("router")
337                .map(|a| a.model.as_str())
338                .unwrap_or("fast");
339
340            let router_llm = match state_clone
341                .provider_registry
342                .create_client_for_model(router_model)
343                .await
344            {
345                Ok(client) => client,
346                Err(_) => match state_clone.llm_factory.create_default().await {
347                    Ok(c) => c,
348                    Err(e) => {
349                        let event = StreamEvent {
350                            event: "error".to_string(),
351                            content: None,
352                            agent: None,
353                            context_id: Some(context_id_clone.clone()),
354                            error: Some(format!("Failed to create LLM client: {}", e)),
355                        };
356                        yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
357                        return;
358                    }
359                },
360            };
361
362            let router = RouterAgent::new(router_llm);
363            match router.route(&message, &agent_context).await {
364                Ok(t) => t,
365                Err(e) => {
366                    let event = StreamEvent {
367                        event: "error".to_string(),
368                        content: None,
369                        agent: None,
370                        context_id: Some(context_id_clone.clone()),
371                        error: Some(format!("Router failed: {}", e)),
372                    };
373                    yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
374                    return;
375                }
376            }
377        };
378
379        // Send start event
380        let agent_name = AgentRegistry::type_to_name(&agent_type);
381        let start_event = StreamEvent {
382            event: "start".to_string(),
383            content: None,
384            agent: Some(format!("{} (system)", agent_type)),
385            context_id: Some(context_id_clone.clone()),
386            error: None,
387        };
388        yield Ok(Event::default().data(serde_json::to_string(&start_event).unwrap_or_default()));
389
390        // Resolve agent using hierarchy
391        let (user_agent, source) = match crate::api::handlers::user_agents::resolve_agent(
392            &state_clone,
393            &claims_clone.sub,
394            agent_name.to_string(),
395        ).await {
396            Ok(r) => r,
397            Err(e) => {
398                let event = StreamEvent {
399                    event: "error".to_string(),
400                    content: None,
401                    agent: None,
402                    context_id: Some(context_id_clone.clone()),
403                    error: Some(format!("Failed to resolve agent: {}", e)),
404                };
405                yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
406                return;
407            }
408        };
409
410        // Get LLM client for streaming
411        let llm = match state_clone
412            .provider_registry
413            .create_client_for_model(&user_agent.model)
414            .await
415        {
416            Ok(c) => c,
417            Err(_) => match state_clone.llm_factory.create_default().await {
418                Ok(c) => c,
419                Err(e) => {
420                    let event = StreamEvent {
421                        event: "error".to_string(),
422                        content: None,
423                        agent: None,
424                        context_id: Some(context_id_clone.clone()),
425                        error: Some(format!("Failed to create LLM: {}", e)),
426                    };
427                    yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
428                    return;
429                }
430            },
431        };
432
433        // Build the prompt with system message and history
434        let system_prompt = user_agent.system_prompt.unwrap_or_else(|| "You are a helpful assistant.".to_string());
435        let full_prompt = format!(
436            "{}\n\nUser: {}\nAssistant:",
437            system_prompt,
438            message
439        );
440
441        // Stream tokens
442        use futures::StreamExt;
443        let mut full_response = String::new();
444        match llm.stream(&full_prompt).await {
445            Ok(mut token_stream) => {
446                while let Some(token_result) = token_stream.next().await {
447                    match token_result {
448                        Ok(token) => {
449                            full_response.push_str(&token);
450                            let event = StreamEvent {
451                                event: "token".to_string(),
452                                content: Some(token),
453                                agent: None,
454                                context_id: None,
455                                error: None,
456                            };
457                            yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
458                        }
459                        Err(e) => {
460                            let event = StreamEvent {
461                                event: "error".to_string(),
462                                content: None,
463                                agent: None,
464                                context_id: Some(context_id_clone.clone()),
465                                error: Some(format!("Stream error: {}", e)),
466                            };
467                            yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
468                            return;
469                        }
470                    }
471                }
472            }
473            Err(e) => {
474                let event = StreamEvent {
475                    event: "error".to_string(),
476                    content: None,
477                    agent: None,
478                    context_id: Some(context_id_clone.clone()),
479                    error: Some(format!("Failed to start stream: {}", e)),
480                };
481                yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
482                return;
483            }
484        }
485
486        // Store messages in conversation
487        let msg_id = Uuid::new_v4().to_string();
488        if let Err(e) = state_clone
489            .db
490            .add_message(&msg_id, &context_id_clone, MessageRole::User, &message)
491            .await {
492            tracing::error!("Failed to store user message in conversation {}: {}", context_id_clone, e);
493        }
494
495        let resp_id = Uuid::new_v4().to_string();
496        if let Err(e) = state_clone
497            .db
498            .add_message(&resp_id, &context_id_clone, MessageRole::Assistant, &full_response)
499            .await {
500            tracing::error!("Failed to store assistant message in conversation {}: {}", context_id_clone, e);
501        }
502
503        // Send done event
504        let done_event = StreamEvent {
505            event: "done".to_string(),
506            content: None,
507            agent: Some(format!("{:?} ({})", agent_type, source)),
508            context_id: Some(context_id_clone),
509            error: None,
510        };
511        yield Ok(Event::default().data(serde_json::to_string(&done_event).unwrap_or_default()));
512    };
513
514    Sse::new(stream).keep_alive(
515        axum::response::sse::KeepAlive::new()
516            .interval(std::time::Duration::from_secs(15))
517            .text("keep-alive"),
518    )
519}
520use axum::response::IntoResponse;