1use crate::state::AppState;
2use aster::agents::{AgentEvent, SessionConfig};
3use aster::conversation::message::{Message, MessageContent, TokenState};
4use aster::conversation::Conversation;
5use aster::session::SessionManager;
6use axum::{
7 extract::{DefaultBodyLimit, State},
8 http::{self, StatusCode},
9 response::IntoResponse,
10 routing::post,
11 Json, Router,
12};
13use bytes::Bytes;
14use futures::{stream::StreamExt, Stream};
15use rmcp::model::ServerNotification;
16use serde::{Deserialize, Serialize};
17use std::{
18 convert::Infallible,
19 pin::Pin,
20 sync::Arc,
21 task::{Context, Poll},
22 time::Duration,
23};
24use tokio::sync::mpsc;
25use tokio::time::timeout;
26use tokio_stream::wrappers::ReceiverStream;
27use tokio_util::sync::CancellationToken;
28
29fn track_tool_telemetry(content: &MessageContent, all_messages: &[Message]) {
30 match content {
31 MessageContent::ToolRequest(tool_request) => {
32 if let Ok(tool_call) = &tool_request.tool_call {
33 tracing::info!(monotonic_counter.aster.tool_calls = 1,
34 tool_name = %tool_call.name,
35 "Tool call started"
36 );
37 }
38 }
39 MessageContent::ToolResponse(tool_response) => {
40 let tool_name = all_messages
41 .iter()
42 .rev()
43 .find_map(|msg| {
44 msg.content.iter().find_map(|c| {
45 if let MessageContent::ToolRequest(req) = c {
46 if req.id == tool_response.id {
47 if let Ok(tool_call) = &req.tool_call {
48 Some(tool_call.name.clone())
49 } else {
50 None
51 }
52 } else {
53 None
54 }
55 } else {
56 None
57 }
58 })
59 })
60 .unwrap_or_else(|| "unknown".to_string().into());
61
62 let success = tool_response.tool_result.is_ok();
63 let result_status = if success { "success" } else { "error" };
64
65 tracing::info!(
66 counter.aster.tool_completions = 1,
67 tool_name = %tool_name,
68 result = %result_status,
69 "Tool call completed"
70 );
71 }
72 _ => {}
73 }
74}
75
76#[derive(Debug, Deserialize, Serialize, utoipa::ToSchema)]
77pub struct ChatRequest {
78 user_message: Message,
79 #[serde(default)]
80 conversation_so_far: Option<Vec<Message>>,
81 session_id: String,
82 recipe_name: Option<String>,
83 recipe_version: Option<String>,
84}
85
86pub struct SseResponse {
87 rx: ReceiverStream<String>,
88}
89
90impl SseResponse {
91 fn new(rx: ReceiverStream<String>) -> Self {
92 Self { rx }
93 }
94}
95
96impl Stream for SseResponse {
97 type Item = Result<Bytes, Infallible>;
98
99 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
100 Pin::new(&mut self.rx)
101 .poll_next(cx)
102 .map(|opt| opt.map(|s| Ok(Bytes::from(s))))
103 }
104}
105
106impl IntoResponse for SseResponse {
107 fn into_response(self) -> axum::response::Response {
108 let stream = self;
109 let body = axum::body::Body::from_stream(stream);
110
111 http::Response::builder()
112 .header("Content-Type", "text/event-stream")
113 .header("Cache-Control", "no-cache")
114 .header("Connection", "keep-alive")
115 .body(body)
116 .unwrap()
117 }
118}
119
120#[derive(Debug, Serialize, utoipa::ToSchema)]
121#[serde(tag = "type")]
122pub enum MessageEvent {
123 Message {
124 message: Message,
125 token_state: TokenState,
126 },
127 Error {
128 error: String,
129 },
130 Finish {
131 reason: String,
132 token_state: TokenState,
133 },
134 ModelChange {
135 model: String,
136 mode: String,
137 },
138 Notification {
139 request_id: String,
140 #[schema(value_type = Object)]
141 message: ServerNotification,
142 },
143 UpdateConversation {
144 conversation: Conversation,
145 },
146 Ping,
147}
148
149async fn get_token_state(session_id: &str) -> TokenState {
150 SessionManager::get_session(session_id, false)
151 .await
152 .map(|session| TokenState {
153 input_tokens: session.input_tokens.unwrap_or(0),
154 output_tokens: session.output_tokens.unwrap_or(0),
155 total_tokens: session.total_tokens.unwrap_or(0),
156 accumulated_input_tokens: session.accumulated_input_tokens.unwrap_or(0),
157 accumulated_output_tokens: session.accumulated_output_tokens.unwrap_or(0),
158 accumulated_total_tokens: session.accumulated_total_tokens.unwrap_or(0),
159 })
160 .inspect_err(|e| {
161 tracing::warn!(
162 "Failed to fetch session token state for {}: {}",
163 session_id,
164 e
165 );
166 })
167 .unwrap_or_default()
168}
169
170async fn stream_event(
171 event: MessageEvent,
172 tx: &mpsc::Sender<String>,
173 cancel_token: &CancellationToken,
174) {
175 let json = serde_json::to_string(&event).unwrap_or_else(|e| {
176 format!(
177 r#"{{"type":"Error","error":"Failed to serialize event: {}"}}"#,
178 e
179 )
180 });
181
182 if tx.send(format!("data: {}\n\n", json)).await.is_err() {
183 tracing::info!("client hung up");
184 cancel_token.cancel();
185 }
186}
187
188#[allow(clippy::too_many_lines)]
189#[utoipa::path(
190 post,
191 path = "/reply",
192 request_body = ChatRequest,
193 responses(
194 (status = 200, description = "Streaming response initiated",
195 body = MessageEvent,
196 content_type = "text/event-stream"),
197 (status = 424, description = "Agent not initialized"),
198 (status = 500, description = "Internal server error")
199 )
200)]
201pub async fn reply(
202 State(state): State<Arc<AppState>>,
203 Json(request): Json<ChatRequest>,
204) -> Result<SseResponse, StatusCode> {
205 let session_start = std::time::Instant::now();
206
207 tracing::info!(
208 counter.aster.session_starts = 1,
209 session_type = "app",
210 interface = "ui",
211 "Session started"
212 );
213
214 let session_id = request.session_id.clone();
215
216 if let Some(recipe_name) = request.recipe_name.clone() {
217 if state.mark_recipe_run_if_absent(&session_id).await {
218 let recipe_version = request
219 .recipe_version
220 .clone()
221 .unwrap_or_else(|| "unknown".to_string());
222
223 tracing::info!(
224 counter.aster.recipe_runs = 1,
225 recipe_name = %recipe_name,
226 recipe_version = %recipe_version,
227 session_type = "app",
228 interface = "ui",
229 "Recipe execution started"
230 );
231 }
232 }
233
234 let (tx, rx) = mpsc::channel(100);
235 let stream = ReceiverStream::new(rx);
236 let cancel_token = CancellationToken::new();
237
238 let user_message = request.user_message;
239 let conversation_so_far = request.conversation_so_far;
240
241 let task_cancel = cancel_token.clone();
242 let task_tx = tx.clone();
243
244 drop(tokio::spawn(async move {
245 let agent = match state.get_agent(session_id.clone()).await {
246 Ok(agent) => agent,
247 Err(e) => {
248 tracing::error!("Failed to get session agent: {}", e);
249 let _ = stream_event(
250 MessageEvent::Error {
251 error: format!("Failed to get session agent: {}", e),
252 },
253 &task_tx,
254 &task_cancel,
255 )
256 .await;
257 return;
258 }
259 };
260
261 let session = match SessionManager::get_session(&session_id, true).await {
262 Ok(metadata) => metadata,
263 Err(e) => {
264 tracing::error!("Failed to read session for {}: {}", session_id, e);
265 let _ = stream_event(
266 MessageEvent::Error {
267 error: format!("Failed to read session: {}", e),
268 },
269 &task_tx,
270 &cancel_token,
271 )
272 .await;
273 return;
274 }
275 };
276
277 let session_config = SessionConfig {
278 id: session_id.clone(),
279 schedule_id: session.schedule_id.clone(),
280 max_turns: None,
281 retry_config: None,
282 system_prompt: None,
283 };
284
285 let mut all_messages = match conversation_so_far {
286 Some(history) => {
287 let conv = Conversation::new_unvalidated(history);
288 if let Err(e) = SessionManager::replace_conversation(&session_id, &conv).await {
289 tracing::warn!(
290 "Failed to replace session conversation for {}: {}",
291 session_id,
292 e
293 );
294 }
295 conv
296 }
297 None => session.conversation.unwrap_or_default(),
298 };
299 all_messages.push(user_message.clone());
300
301 let mut stream = match agent
302 .reply(
303 user_message.clone(),
304 session_config,
305 Some(task_cancel.clone()),
306 )
307 .await
308 {
309 Ok(stream) => stream,
310 Err(e) => {
311 tracing::error!("Failed to start reply stream: {:?}", e);
312 stream_event(
313 MessageEvent::Error {
314 error: e.to_string(),
315 },
316 &task_tx,
317 &cancel_token,
318 )
319 .await;
320 return;
321 }
322 };
323
324 let mut heartbeat_interval = tokio::time::interval(Duration::from_millis(500));
325 loop {
326 tokio::select! {
327 _ = task_cancel.cancelled() => {
328 tracing::info!("Agent task cancelled");
329 break;
330 }
331 _ = heartbeat_interval.tick() => {
332 stream_event(MessageEvent::Ping, &tx, &cancel_token).await;
333 }
334 response = timeout(Duration::from_millis(500), stream.next()) => {
335 match response {
336 Ok(Some(Ok(AgentEvent::Message(message)))) => {
337 for content in &message.content {
338 track_tool_telemetry(content, all_messages.messages());
339 }
340
341 all_messages.push(message.clone());
342
343 let token_state = get_token_state(&session_id).await;
344
345 stream_event(MessageEvent::Message { message, token_state }, &tx, &cancel_token).await;
346 }
347 Ok(Some(Ok(AgentEvent::HistoryReplaced(new_messages)))) => {
348 all_messages = new_messages.clone();
349 stream_event(MessageEvent::UpdateConversation {conversation: new_messages}, &tx, &cancel_token).await;
350
351 }
352 Ok(Some(Ok(AgentEvent::ModelChange { model, mode }))) => {
353 stream_event(MessageEvent::ModelChange { model, mode }, &tx, &cancel_token).await;
354 }
355 Ok(Some(Ok(AgentEvent::McpNotification((request_id, n))))) => {
356 stream_event(MessageEvent::Notification{
357 request_id: request_id.clone(),
358 message: n,
359 }, &tx, &cancel_token).await;
360 }
361
362 Ok(Some(Err(e))) => {
363 tracing::error!("Error processing message: {}", e);
364 stream_event(
365 MessageEvent::Error {
366 error: e.to_string(),
367 },
368 &tx,
369 &cancel_token,
370 ).await;
371 break;
372 }
373 Ok(None) => {
374 break;
375 }
376 Err(_) => {
377 if tx.is_closed() {
378 break;
379 }
380 continue;
381 }
382 }
383 }
384 }
385 }
386
387 let session_duration = session_start.elapsed();
388
389 if let Ok(session) = SessionManager::get_session(&session_id, true).await {
390 let total_tokens = session.total_tokens.unwrap_or(0);
391 tracing::info!(
392 counter.aster.session_completions = 1,
393 session_type = "app",
394 interface = "ui",
395 exit_type = "normal",
396 duration_ms = session_duration.as_millis() as u64,
397 total_tokens = total_tokens,
398 message_count = session.message_count,
399 "Session completed"
400 );
401
402 tracing::info!(
403 counter.aster.session_duration_ms = session_duration.as_millis() as u64,
404 session_type = "app",
405 interface = "ui",
406 "Session duration"
407 );
408
409 if total_tokens > 0 {
410 tracing::info!(
411 counter.aster.session_tokens = total_tokens,
412 session_type = "app",
413 interface = "ui",
414 "Session tokens"
415 );
416 }
417 } else {
418 tracing::info!(
419 counter.aster.session_completions = 1,
420 session_type = "app",
421 interface = "ui",
422 exit_type = "normal",
423 duration_ms = session_duration.as_millis() as u64,
424 total_tokens = 0u64,
425 message_count = all_messages.len(),
426 "Session completed"
427 );
428
429 tracing::info!(
430 counter.aster.session_duration_ms = session_duration.as_millis() as u64,
431 session_type = "app",
432 interface = "ui",
433 "Session duration"
434 );
435 }
436
437 let final_token_state = get_token_state(&session_id).await;
438
439 let _ = stream_event(
440 MessageEvent::Finish {
441 reason: "stop".to_string(),
442 token_state: final_token_state,
443 },
444 &task_tx,
445 &cancel_token,
446 )
447 .await;
448 }));
449 Ok(SseResponse::new(stream))
450}
451
452pub fn routes(state: Arc<AppState>) -> Router {
453 Router::new()
454 .route(
455 "/reply",
456 post(reply).layer(DefaultBodyLimit::max(50 * 1024 * 1024)),
457 )
458 .with_state(state)
459}
460
#[cfg(test)]
mod tests {
    use super::*;

    mod integration_tests {
        use super::*;
        use aster::conversation::message::Message;
        use axum::{body::Body, http::Request};
        use tower::ServiceExt;

        /// A well-formed POST to `/reply` is accepted with `200 OK`: the handler
        /// returns the streaming response before the agent task has produced anything.
        #[tokio::test(flavor = "multi_thread")]
        async fn test_reply_endpoint() {
            let state = AppState::new().await.unwrap();
            let app = routes(state);

            let payload = ChatRequest {
                user_message: Message::user().with_text("test message"),
                conversation_so_far: None,
                session_id: "test-session".to_string(),
                recipe_name: None,
                recipe_version: None,
            };
            let body = serde_json::to_string(&payload).unwrap();

            let request = Request::builder()
                .uri("/reply")
                .method("POST")
                .header("content-type", "application/json")
                .header("x-secret-key", "test-secret")
                .body(Body::from(body))
                .unwrap();

            let response = app.oneshot(request).await.unwrap();

            assert_eq!(response.status(), StatusCode::OK);
        }
    }
}