Skip to main content

ares/api/handlers/
chat.rs

1use crate::{
2    agents::{registry::AgentRegistry, router::RouterAgent, Agent},
3    api::handlers::user_agents::resolve_agent,
4    auth::middleware::AuthUser,
5    db::agent_runs,
6    memory::estimate_tokens,
7    types::{
8        AgentContext, AgentType, AppError, ChatRequest, ChatResponse, MessageRole, Result,
9        UserMemory,
10    },
11    utils::toml_config::AgentConfig,
12    AppState,
13};
14use axum::{extract::State, response::Response, Extension, Json};
15use uuid::Uuid;
16
17/// Chat with the AI assistant
18#[utoipa::path(
19    post,
20    path = "/api/chat",
21    request_body = ChatRequest,
22    responses(
23        (status = 200, description = "Chat response", body = ChatResponse),
24        (status = 400, description = "Invalid input"),
25        (status = 401, description = "Unauthorized")
26    ),
27    tag = "chat",
28    security(("bearer" = []))
29)]
30pub async fn chat(
31    State(state): State<AppState>,
32    AuthUser(claims): AuthUser,
33    tenant_ctx: Option<Extension<crate::models::TenantContext>>,
34    Json(payload): Json<ChatRequest>,
35) -> Result<Response> {
36    // Get or create conversation
37    let context_id = payload
38        .context_id
39        .unwrap_or_else(|| Uuid::new_v4().to_string());
40
41    // Check if conversation exists, create if not
42    if !state.db.conversation_exists(&context_id).await? {
43        state
44            .db
45            .create_conversation(&context_id, &claims.sub, None)
46            .await?;
47    }
48    let history = state.db.get_conversation_history(&context_id).await?;
49    // Compute history token estimate in the same pass (before clone into AgentContext)
50    let history_input_tokens: usize = history.iter().map(|m| estimate_tokens(&m.content)).sum();
51
52    // Load user memory
53    let memory_facts = state.db.get_user_memory(&claims.sub).await?;
54    let preferences = state.db.get_user_preferences(&claims.sub).await?;
55    let user_memory = if !memory_facts.is_empty() || !preferences.is_empty() {
56        Some(UserMemory {
57            user_id: claims.sub.clone(),
58            preferences,
59            facts: memory_facts,
60        })
61    } else {
62        None
63    };
64
65    // Build agent context
66    let agent_context = AgentContext {
67        user_id: claims.sub.clone(),
68        session_id: context_id.clone(),
69        conversation_history: history.clone(),
70        user_memory,
71    };
72
73    // Route to appropriate agent
74    let agent_type = if let Some(at) = payload.agent_type {
75        at
76    } else {
77        // Get router model from config, or use default
78        let config = state.config_manager.config();
79        let router_model = config
80            .get_agent("router")
81            .map(|a| a.model.as_str())
82            .unwrap_or("fast");
83
84        let router_llm = match state
85            .provider_registry
86            .create_client_for_model(router_model)
87            .await
88        {
89            Ok(client) => client,
90            Err(_) => state.llm_factory.create_default().await?,
91        };
92
93        let router = RouterAgent::new(router_llm);
94        router.route(&payload.message, &agent_context).await?
95    };
96
97    // Execute agent with timing
98    let agent_name_for_run = AgentRegistry::type_to_name(&agent_type).to_string();
99    let start = std::time::Instant::now();
100    let (response, usage) = execute_agent(agent_type, &payload.message, &agent_context, &state).await?;
101    let duration_ms = start.elapsed().as_millis() as i64;
102
103    // Store messages in conversation
104    let msg_id = Uuid::new_v4().to_string();
105    state
106        .db
107        .add_message(&msg_id, &context_id, MessageRole::User, &payload.message)
108        .await?;
109
110    let resp_id = Uuid::new_v4().to_string();
111    state
112        .db
113        .add_message(
114            &resp_id,
115            &context_id,
116            MessageRole::Assistant,
117            &response.response,
118        )
119        .await?;
120
121    // Use actual LLM token counts; fall back to heuristic estimates if unavailable
122    let (input_tokens, output_tokens) = if let Some(u) = usage {
123        (u.prompt_tokens, u.completion_tokens)
124    } else {
125        (
126            (history_input_tokens + estimate_tokens(&payload.message)) as u32,
127            estimate_tokens(&response.response) as u32,
128        )
129    };
130
131    // Record agent run (fire-and-forget)
132    {
133        let pool = state.tenant_db.pool().clone();
134        let agent_name = agent_name_for_run;
135        let user_id = claims.sub.clone();
136        let tenant_id_for_run = tenant_ctx
137            .map(|Extension(tc)| tc.tenant_id.clone())
138            .unwrap_or_else(|| "system".to_string());
139        let itok = input_tokens as i64;
140        let otok = output_tokens as i64;
141        tokio::spawn(async move {
142            let _ = agent_runs::insert_agent_run(
143                &pool,
144                &tenant_id_for_run,
145                &agent_name,
146                Some(&user_id),
147                "completed",
148                itok,
149                otok,
150                duration_ms,
151                None,
152                "unknown",
153                "unknown",
154                false,
155            )
156            .await;
157        });
158    }
159
160    let body = Json(response);
161    let mut response = body.into_response();
162    response.headers_mut().insert(
163        axum::http::HeaderName::from_static("x-input-tokens"),
164        axum::http::HeaderValue::from(input_tokens),
165    );
166    response.headers_mut().insert(
167        axum::http::HeaderName::from_static("x-output-tokens"),
168        axum::http::HeaderValue::from(output_tokens),
169    );
170
171    Ok(response)
172}
173
174async fn execute_agent(
175    agent_type: AgentType,
176    message: &str,
177    context: &AgentContext,
178    state: &AppState,
179) -> Result<(ChatResponse, Option<crate::llm::client::TokenUsage>)> {
180    // Get agent name from type
181    let agent_name = AgentRegistry::type_to_name(&agent_type);
182
183    if agent_type == AgentType::Router {
184        return Err(AppError::InvalidInput(
185            "Router agent cannot be called directly".to_string(),
186        ));
187    }
188
189    // Resolve agent using the 3-tier hierarchy (User -> Community -> System)
190    let (user_agent, source) =
191        resolve_agent(state, &context.user_id, agent_name.to_string()).await?;
192
193    // Convert UserAgent to AgentConfig for the registry
194    let config = AgentConfig {
195        model: user_agent.model.clone(),
196        system_prompt: user_agent.system_prompt.clone(),
197        tools: user_agent.tools_vec(),
198        max_tool_iterations: user_agent.max_tool_iterations as usize,
199        parallel_tools: user_agent.parallel_tools,
200        extra: std::collections::HashMap::new(),
201    };
202
203    // Create agent from registry using the resolved config
204    let agent = state
205        .agent_registry
206        .create_agent_from_config(agent_name, &config)
207        .await?;
208
209    // Execute the agent
210    let agent_resp = agent.execute(message, context).await?;
211
212    Ok((ChatResponse {
213        response: agent_resp.content,
214        agent: format!("{:?} ({})", agent_type, source),
215        context_id: context.session_id.clone(),
216        sources: None,
217    }, agent_resp.usage))
218}
219
220/// Get user memory
221#[utoipa::path(
222    get,
223    path = "/api/memory",
224    responses(
225        (status = 200, description = "User memory retrieved successfully"),
226        (status = 401, description = "Unauthorized")
227    ),
228    tag = "chat",
229    security(("bearer" = []))
230)]
231pub async fn get_user_memory(
232    State(state): State<AppState>,
233    AuthUser(claims): AuthUser,
234) -> Result<Json<UserMemory>> {
235    let facts = state.db.get_user_memory(&claims.sub).await?;
236    let preferences = state.db.get_user_preferences(&claims.sub).await?;
237
238    Ok(Json(UserMemory {
239        user_id: claims.sub,
240        preferences,
241        facts,
242    }))
243}
244
245/// Streaming chat response event
246#[derive(serde::Serialize)]
247pub struct StreamEvent {
248    /// Event type: "start", "token", "done", "error"
249    pub event: String,
250    /// Token content (for "token" events)
251    #[serde(skip_serializing_if = "Option::is_none")]
252    pub content: Option<String>,
253    /// Agent type that handled the request (for "start" and "done" events)
254    #[serde(skip_serializing_if = "Option::is_none")]
255    pub agent: Option<String>,
256    /// Context ID for the conversation
257    #[serde(skip_serializing_if = "Option::is_none")]
258    pub context_id: Option<String>,
259    /// Error message (for "error" events)
260    #[serde(skip_serializing_if = "Option::is_none")]
261    pub error: Option<String>,
262}
263
264/// Stream a chat response using Server-Sent Events
265#[utoipa::path(
266    post,
267    path = "/api/chat/stream",
268    request_body = ChatRequest,
269    responses(
270        (status = 200, description = "Streaming chat response"),
271        (status = 400, description = "Invalid input"),
272        (status = 401, description = "Unauthorized")
273    ),
274    tag = "chat",
275    security(("bearer" = []))
276)]
277pub async fn chat_stream(
278    State(state): State<AppState>,
279    AuthUser(claims): AuthUser,
280    Json(payload): Json<ChatRequest>,
281) -> axum::response::Sse<
282    impl futures::Stream<
283        Item = std::result::Result<axum::response::sse::Event, std::convert::Infallible>,
284    >,
285> {
286    use axum::response::sse::{Event, Sse};
287
288    // Get or create conversation
289    let context_id = payload
290        .context_id
291        .clone()
292        .unwrap_or_else(|| Uuid::new_v4().to_string());
293
294    // Clone values we need for the async stream
295    let state_clone = state.clone();
296    let claims_clone = claims.clone();
297    let message = payload.message.clone();
298    let agent_type_req = payload.agent_type;
299    let context_id_clone = context_id.clone();
300
301    let stream = async_stream::stream! {
302        // Setup conversation
303        if !state_clone.db.conversation_exists(&context_id_clone).await.unwrap_or(false) {
304            if let Err(e) = state_clone
305                .db
306                .create_conversation(&context_id_clone, &claims_clone.sub, None)
307                .await {
308                tracing::warn!("Failed to create conversation {}: {}", context_id_clone, e);
309            }
310        }
311
312        let history = state_clone.db.get_conversation_history(&context_id_clone).await.unwrap_or_else(|e| {
313            tracing::warn!("Failed to get conversation history for {}: {}", context_id_clone, e);
314            vec![]
315        });
316
317        // Load user memory
318        let memory_facts = state_clone.db.get_user_memory(&claims_clone.sub).await.unwrap_or_else(|e| {
319            tracing::warn!("Failed to get user memory for {}: {}", claims_clone.sub, e);
320            vec![]
321        });
322        let preferences = state_clone.db.get_user_preferences(&claims_clone.sub).await.unwrap_or_else(|e| {
323            tracing::warn!("Failed to get user preferences for {}: {}", claims_clone.sub, e);
324            vec![]
325        });
326        let user_memory = if !memory_facts.is_empty() || !preferences.is_empty() {
327            Some(UserMemory {
328                user_id: claims_clone.sub.clone(),
329                preferences,
330                facts: memory_facts,
331            })
332        } else {
333            None
334        };
335
336        // Build agent context
337        let agent_context = AgentContext {
338            user_id: claims_clone.sub.clone(),
339            session_id: context_id_clone.clone(),
340            conversation_history: history,
341            user_memory,
342        };
343
344        // Route to appropriate agent
345        let agent_type = if let Some(at) = agent_type_req {
346            at
347        } else {
348            let config = state_clone.config_manager.config();
349            let router_model = config
350                .get_agent("router")
351                .map(|a| a.model.as_str())
352                .unwrap_or("fast");
353
354            let router_llm = match state_clone
355                .provider_registry
356                .create_client_for_model(router_model)
357                .await
358            {
359                Ok(client) => client,
360                Err(_) => match state_clone.llm_factory.create_default().await {
361                    Ok(c) => c,
362                    Err(e) => {
363                        let event = StreamEvent {
364                            event: "error".to_string(),
365                            content: None,
366                            agent: None,
367                            context_id: Some(context_id_clone.clone()),
368                            error: Some(format!("Failed to create LLM client: {}", e)),
369                        };
370                        yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
371                        return;
372                    }
373                },
374            };
375
376            let router = RouterAgent::new(router_llm);
377            match router.route(&message, &agent_context).await {
378                Ok(t) => t,
379                Err(e) => {
380                    let event = StreamEvent {
381                        event: "error".to_string(),
382                        content: None,
383                        agent: None,
384                        context_id: Some(context_id_clone.clone()),
385                        error: Some(format!("Router failed: {}", e)),
386                    };
387                    yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
388                    return;
389                }
390            }
391        };
392
393        // Send start event
394        let agent_name = AgentRegistry::type_to_name(&agent_type);
395        let start_event = StreamEvent {
396            event: "start".to_string(),
397            content: None,
398            agent: Some(format!("{} (system)", agent_type)),
399            context_id: Some(context_id_clone.clone()),
400            error: None,
401        };
402        yield Ok(Event::default().data(serde_json::to_string(&start_event).unwrap_or_default()));
403
404        // Resolve agent using hierarchy
405        let (user_agent, source) = match crate::api::handlers::user_agents::resolve_agent(
406            &state_clone,
407            &claims_clone.sub,
408            agent_name.to_string(),
409        ).await {
410            Ok(r) => r,
411            Err(e) => {
412                let event = StreamEvent {
413                    event: "error".to_string(),
414                    content: None,
415                    agent: None,
416                    context_id: Some(context_id_clone.clone()),
417                    error: Some(format!("Failed to resolve agent: {}", e)),
418                };
419                yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
420                return;
421            }
422        };
423
424        // Get LLM client for streaming
425        let llm = match state_clone
426            .provider_registry
427            .create_client_for_model(&user_agent.model)
428            .await
429        {
430            Ok(c) => c,
431            Err(_) => match state_clone.llm_factory.create_default().await {
432                Ok(c) => c,
433                Err(e) => {
434                    let event = StreamEvent {
435                        event: "error".to_string(),
436                        content: None,
437                        agent: None,
438                        context_id: Some(context_id_clone.clone()),
439                        error: Some(format!("Failed to create LLM: {}", e)),
440                    };
441                    yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
442                    return;
443                }
444            },
445        };
446
447        // Build the prompt with system message and history
448        let system_prompt = user_agent.system_prompt.unwrap_or_else(|| "You are a helpful assistant.".to_string());
449        let full_prompt = format!(
450            "{}\n\nUser: {}\nAssistant:",
451            system_prompt,
452            message
453        );
454
455        // Stream tokens
456        use futures::StreamExt;
457        let mut full_response = String::new();
458        match llm.stream(&full_prompt).await {
459            Ok(mut token_stream) => {
460                while let Some(token_result) = token_stream.next().await {
461                    match token_result {
462                        Ok(token) => {
463                            full_response.push_str(&token);
464                            let event = StreamEvent {
465                                event: "token".to_string(),
466                                content: Some(token),
467                                agent: None,
468                                context_id: None,
469                                error: None,
470                            };
471                            yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
472                        }
473                        Err(e) => {
474                            let event = StreamEvent {
475                                event: "error".to_string(),
476                                content: None,
477                                agent: None,
478                                context_id: Some(context_id_clone.clone()),
479                                error: Some(format!("Stream error: {}", e)),
480                            };
481                            yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
482                            return;
483                        }
484                    }
485                }
486            }
487            Err(e) => {
488                let event = StreamEvent {
489                    event: "error".to_string(),
490                    content: None,
491                    agent: None,
492                    context_id: Some(context_id_clone.clone()),
493                    error: Some(format!("Failed to start stream: {}", e)),
494                };
495                yield Ok(Event::default().data(serde_json::to_string(&event).unwrap_or_default()));
496                return;
497            }
498        }
499
500        // Store messages in conversation
501        let msg_id = Uuid::new_v4().to_string();
502        if let Err(e) = state_clone
503            .db
504            .add_message(&msg_id, &context_id_clone, MessageRole::User, &message)
505            .await {
506            tracing::error!("Failed to store user message in conversation {}: {}", context_id_clone, e);
507        }
508
509        let resp_id = Uuid::new_v4().to_string();
510        if let Err(e) = state_clone
511            .db
512            .add_message(&resp_id, &context_id_clone, MessageRole::Assistant, &full_response)
513            .await {
514            tracing::error!("Failed to store assistant message in conversation {}: {}", context_id_clone, e);
515        }
516
517        // Record agent run for billing (streaming calls were previously invisible)
518        {
519            let pool = state_clone.tenant_db.pool().clone();
520            let tid = claims_clone.sub.clone();
521            let aname = agent_name.to_string();
522            let itok = crate::memory::estimate_tokens(&message) as i64;
523            let otok = crate::memory::estimate_tokens(&full_response) as i64;
524            let model = user_agent.model.clone();
525            tokio::spawn(async move {
526                let _ = crate::db::agent_runs::insert_agent_run(
527                    &pool, &tid, &aname, Some(&tid), "completed",
528                    itok, otok, 0, None, &model, "unknown", true,
529                )
530                .await;
531            });
532        }
533
534        // Send done event
535        let done_event = StreamEvent {
536            event: "done".to_string(),
537            content: None,
538            agent: Some(format!("{:?} ({})", agent_type, source)),
539            context_id: Some(context_id_clone),
540            error: None,
541        };
542        yield Ok(Event::default().data(serde_json::to_string(&done_event).unwrap_or_default()));
543    };
544
545    Sse::new(stream).keep_alive(
546        axum::response::sse::KeepAlive::new()
547            .interval(std::time::Duration::from_secs(15))
548            .text("keep-alive"),
549    )
550}
551use axum::response::IntoResponse;