aster_server/routes/reply.rs

1use crate::state::AppState;
2use aster::agents::{AgentEvent, SessionConfig};
3use aster::conversation::message::{Message, MessageContent, TokenState};
4use aster::conversation::Conversation;
5use aster::session::SessionManager;
6use axum::{
7    extract::{DefaultBodyLimit, State},
8    http::{self, StatusCode},
9    response::IntoResponse,
10    routing::post,
11    Json, Router,
12};
13use bytes::Bytes;
14use futures::{stream::StreamExt, Stream};
15use rmcp::model::ServerNotification;
16use serde::{Deserialize, Serialize};
17use std::{
18    convert::Infallible,
19    pin::Pin,
20    sync::Arc,
21    task::{Context, Poll},
22    time::Duration,
23};
24use tokio::sync::mpsc;
25use tokio::time::timeout;
26use tokio_stream::wrappers::ReceiverStream;
27use tokio_util::sync::CancellationToken;
28
29fn track_tool_telemetry(content: &MessageContent, all_messages: &[Message]) {
30    match content {
31        MessageContent::ToolRequest(tool_request) => {
32            if let Ok(tool_call) = &tool_request.tool_call {
33                tracing::info!(monotonic_counter.aster.tool_calls = 1,
34                    tool_name = %tool_call.name,
35                    "Tool call started"
36                );
37            }
38        }
39        MessageContent::ToolResponse(tool_response) => {
40            let tool_name = all_messages
41                .iter()
42                .rev()
43                .find_map(|msg| {
44                    msg.content.iter().find_map(|c| {
45                        if let MessageContent::ToolRequest(req) = c {
46                            if req.id == tool_response.id {
47                                if let Ok(tool_call) = &req.tool_call {
48                                    Some(tool_call.name.clone())
49                                } else {
50                                    None
51                                }
52                            } else {
53                                None
54                            }
55                        } else {
56                            None
57                        }
58                    })
59                })
60                .unwrap_or_else(|| "unknown".to_string().into());
61
62            let success = tool_response.tool_result.is_ok();
63            let result_status = if success { "success" } else { "error" };
64
65            tracing::info!(
66                counter.aster.tool_completions = 1,
67                tool_name = %tool_name,
68                result = %result_status,
69                "Tool call completed"
70            );
71        }
72        _ => {}
73    }
74}
75
/// JSON request body for `POST /reply`.
#[derive(Debug, Deserialize, Serialize, utoipa::ToSchema)]
pub struct ChatRequest {
    /// The new message to hand to the session's agent.
    user_message: Message,
    /// Optional client-side conversation history. When present, the handler tries to
    /// replace the session's stored conversation with it (a failure is only logged) and
    /// uses it as the working history; when absent, the stored conversation is used.
    #[serde(default)]
    conversation_so_far: Option<Vec<Message>>,
    /// Identifier of the session whose agent should handle this request.
    session_id: String,
    /// Name of the recipe this chat was started from, if any. In the reply handler it is
    /// only used for recipe-run telemetry.
    recipe_name: Option<String>,
    /// Version of that recipe for telemetry; reported as `"unknown"` when omitted.
    recipe_version: Option<String>,
}
85
86pub struct SseResponse {
87    rx: ReceiverStream<String>,
88}
89
90impl SseResponse {
91    fn new(rx: ReceiverStream<String>) -> Self {
92        Self { rx }
93    }
94}
95
96impl Stream for SseResponse {
97    type Item = Result<Bytes, Infallible>;
98
99    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
100        Pin::new(&mut self.rx)
101            .poll_next(cx)
102            .map(|opt| opt.map(|s| Ok(Bytes::from(s))))
103    }
104}
105
106impl IntoResponse for SseResponse {
107    fn into_response(self) -> axum::response::Response {
108        let stream = self;
109        let body = axum::body::Body::from_stream(stream);
110
111        http::Response::builder()
112            .header("Content-Type", "text/event-stream")
113            .header("Cache-Control", "no-cache")
114            .header("Connection", "keep-alive")
115            .body(body)
116            .unwrap()
117    }
118}
119
/// One Server-Sent Event emitted by `POST /reply`.
///
/// Serialized with an internal `"type"` tag naming the variant, e.g. `{"type":"Ping"}`
/// or `{"type":"Error","error":"..."}`.
#[derive(Debug, Serialize, utoipa::ToSchema)]
#[serde(tag = "type")]
pub enum MessageEvent {
    /// A message produced by the agent, plus the session's current token counters.
    Message {
        message: Message,
        token_state: TokenState,
    },
    /// A human-readable error description for the client.
    Error {
        error: String,
    },
    /// Last event of a completed reply run, with the session's final token counters.
    /// The reply handler currently always sends `reason: "stop"`.
    Finish {
        reason: String,
        token_state: TokenState,
    },
    /// The agent reported a change of model and/or mode.
    ModelChange {
        model: String,
        mode: String,
    },
    /// An MCP server notification forwarded from the agent.
    /// NOTE(review): `request_id` is passed through from the agent event — presumably
    /// the id of the request the notification belongs to; confirm.
    Notification {
        request_id: String,
        // Exposed as a free-form object in the OpenAPI schema.
        #[schema(value_type = Object)]
        message: ServerNotification,
    },
    /// The agent replaced the conversation history; carries the complete new conversation.
    UpdateConversation {
        conversation: Conversation,
    },
    /// Keep-alive heartbeat with no payload.
    Ping,
}
148
149async fn get_token_state(session_id: &str) -> TokenState {
150    SessionManager::get_session(session_id, false)
151        .await
152        .map(|session| TokenState {
153            input_tokens: session.input_tokens.unwrap_or(0),
154            output_tokens: session.output_tokens.unwrap_or(0),
155            total_tokens: session.total_tokens.unwrap_or(0),
156            accumulated_input_tokens: session.accumulated_input_tokens.unwrap_or(0),
157            accumulated_output_tokens: session.accumulated_output_tokens.unwrap_or(0),
158            accumulated_total_tokens: session.accumulated_total_tokens.unwrap_or(0),
159        })
160        .inspect_err(|e| {
161            tracing::warn!(
162                "Failed to fetch session token state for {}: {}",
163                session_id,
164                e
165            );
166        })
167        .unwrap_or_default()
168}
169
170async fn stream_event(
171    event: MessageEvent,
172    tx: &mpsc::Sender<String>,
173    cancel_token: &CancellationToken,
174) {
175    let json = serde_json::to_string(&event).unwrap_or_else(|e| {
176        format!(
177            r#"{{"type":"Error","error":"Failed to serialize event: {}"}}"#,
178            e
179        )
180    });
181
182    if tx.send(format!("data: {}\n\n", json)).await.is_err() {
183        tracing::info!("client hung up");
184        cancel_token.cancel();
185    }
186}
187
188#[allow(clippy::too_many_lines)]
189#[utoipa::path(
190    post,
191    path = "/reply",
192    request_body = ChatRequest,
193    responses(
194        (status = 200, description = "Streaming response initiated",
195         body = MessageEvent,
196         content_type = "text/event-stream"),
197        (status = 424, description = "Agent not initialized"),
198        (status = 500, description = "Internal server error")
199    )
200)]
201pub async fn reply(
202    State(state): State<Arc<AppState>>,
203    Json(request): Json<ChatRequest>,
204) -> Result<SseResponse, StatusCode> {
205    let session_start = std::time::Instant::now();
206
207    tracing::info!(
208        counter.aster.session_starts = 1,
209        session_type = "app",
210        interface = "ui",
211        "Session started"
212    );
213
214    let session_id = request.session_id.clone();
215
216    if let Some(recipe_name) = request.recipe_name.clone() {
217        if state.mark_recipe_run_if_absent(&session_id).await {
218            let recipe_version = request
219                .recipe_version
220                .clone()
221                .unwrap_or_else(|| "unknown".to_string());
222
223            tracing::info!(
224                counter.aster.recipe_runs = 1,
225                recipe_name = %recipe_name,
226                recipe_version = %recipe_version,
227                session_type = "app",
228                interface = "ui",
229                "Recipe execution started"
230            );
231        }
232    }
233
234    let (tx, rx) = mpsc::channel(100);
235    let stream = ReceiverStream::new(rx);
236    let cancel_token = CancellationToken::new();
237
238    let user_message = request.user_message;
239    let conversation_so_far = request.conversation_so_far;
240
241    let task_cancel = cancel_token.clone();
242    let task_tx = tx.clone();
243
244    drop(tokio::spawn(async move {
245        let agent = match state.get_agent(session_id.clone()).await {
246            Ok(agent) => agent,
247            Err(e) => {
248                tracing::error!("Failed to get session agent: {}", e);
249                let _ = stream_event(
250                    MessageEvent::Error {
251                        error: format!("Failed to get session agent: {}", e),
252                    },
253                    &task_tx,
254                    &task_cancel,
255                )
256                .await;
257                return;
258            }
259        };
260
261        let session = match SessionManager::get_session(&session_id, true).await {
262            Ok(metadata) => metadata,
263            Err(e) => {
264                tracing::error!("Failed to read session for {}: {}", session_id, e);
265                let _ = stream_event(
266                    MessageEvent::Error {
267                        error: format!("Failed to read session: {}", e),
268                    },
269                    &task_tx,
270                    &cancel_token,
271                )
272                .await;
273                return;
274            }
275        };
276
277        let session_config = SessionConfig {
278            id: session_id.clone(),
279            schedule_id: session.schedule_id.clone(),
280            max_turns: None,
281            retry_config: None,
282            system_prompt: None,
283        };
284
285        let mut all_messages = match conversation_so_far {
286            Some(history) => {
287                let conv = Conversation::new_unvalidated(history);
288                if let Err(e) = SessionManager::replace_conversation(&session_id, &conv).await {
289                    tracing::warn!(
290                        "Failed to replace session conversation for {}: {}",
291                        session_id,
292                        e
293                    );
294                }
295                conv
296            }
297            None => session.conversation.unwrap_or_default(),
298        };
299        all_messages.push(user_message.clone());
300
301        let mut stream = match agent
302            .reply(
303                user_message.clone(),
304                session_config,
305                Some(task_cancel.clone()),
306            )
307            .await
308        {
309            Ok(stream) => stream,
310            Err(e) => {
311                tracing::error!("Failed to start reply stream: {:?}", e);
312                stream_event(
313                    MessageEvent::Error {
314                        error: e.to_string(),
315                    },
316                    &task_tx,
317                    &cancel_token,
318                )
319                .await;
320                return;
321            }
322        };
323
324        let mut heartbeat_interval = tokio::time::interval(Duration::from_millis(500));
325        loop {
326            tokio::select! {
327                _ = task_cancel.cancelled() => {
328                    tracing::info!("Agent task cancelled");
329                    break;
330                }
331                _ = heartbeat_interval.tick() => {
332                    stream_event(MessageEvent::Ping, &tx, &cancel_token).await;
333                }
334                response = timeout(Duration::from_millis(500), stream.next()) => {
335                    match response {
336                        Ok(Some(Ok(AgentEvent::Message(message)))) => {
337                            for content in &message.content {
338                                track_tool_telemetry(content, all_messages.messages());
339                            }
340
341                            all_messages.push(message.clone());
342
343                            let token_state = get_token_state(&session_id).await;
344
345                            stream_event(MessageEvent::Message { message, token_state }, &tx, &cancel_token).await;
346                        }
347                        Ok(Some(Ok(AgentEvent::HistoryReplaced(new_messages)))) => {
348                            all_messages = new_messages.clone();
349                            stream_event(MessageEvent::UpdateConversation {conversation: new_messages}, &tx, &cancel_token).await;
350
351                        }
352                        Ok(Some(Ok(AgentEvent::ModelChange { model, mode }))) => {
353                            stream_event(MessageEvent::ModelChange { model, mode }, &tx, &cancel_token).await;
354                        }
355                        Ok(Some(Ok(AgentEvent::McpNotification((request_id, n))))) => {
356                            stream_event(MessageEvent::Notification{
357                                request_id: request_id.clone(),
358                                message: n,
359                            }, &tx, &cancel_token).await;
360                        }
361
362                        Ok(Some(Err(e))) => {
363                            tracing::error!("Error processing message: {}", e);
364                            stream_event(
365                                MessageEvent::Error {
366                                    error: e.to_string(),
367                                },
368                                &tx,
369                                &cancel_token,
370                            ).await;
371                            break;
372                        }
373                        Ok(None) => {
374                            break;
375                        }
376                        Err(_) => {
377                            if tx.is_closed() {
378                                break;
379                            }
380                            continue;
381                        }
382                    }
383                }
384            }
385        }
386
387        let session_duration = session_start.elapsed();
388
389        if let Ok(session) = SessionManager::get_session(&session_id, true).await {
390            let total_tokens = session.total_tokens.unwrap_or(0);
391            tracing::info!(
392                counter.aster.session_completions = 1,
393                session_type = "app",
394                interface = "ui",
395                exit_type = "normal",
396                duration_ms = session_duration.as_millis() as u64,
397                total_tokens = total_tokens,
398                message_count = session.message_count,
399                "Session completed"
400            );
401
402            tracing::info!(
403                counter.aster.session_duration_ms = session_duration.as_millis() as u64,
404                session_type = "app",
405                interface = "ui",
406                "Session duration"
407            );
408
409            if total_tokens > 0 {
410                tracing::info!(
411                    counter.aster.session_tokens = total_tokens,
412                    session_type = "app",
413                    interface = "ui",
414                    "Session tokens"
415                );
416            }
417        } else {
418            tracing::info!(
419                counter.aster.session_completions = 1,
420                session_type = "app",
421                interface = "ui",
422                exit_type = "normal",
423                duration_ms = session_duration.as_millis() as u64,
424                total_tokens = 0u64,
425                message_count = all_messages.len(),
426                "Session completed"
427            );
428
429            tracing::info!(
430                counter.aster.session_duration_ms = session_duration.as_millis() as u64,
431                session_type = "app",
432                interface = "ui",
433                "Session duration"
434            );
435        }
436
437        let final_token_state = get_token_state(&session_id).await;
438
439        let _ = stream_event(
440            MessageEvent::Finish {
441                reason: "stop".to_string(),
442                token_state: final_token_state,
443            },
444            &task_tx,
445            &cancel_token,
446        )
447        .await;
448    }));
449    Ok(SseResponse::new(stream))
450}
451
452pub fn routes(state: Arc<AppState>) -> Router {
453    Router::new()
454        .route(
455            "/reply",
456            post(reply).layer(DefaultBodyLimit::max(50 * 1024 * 1024)),
457        )
458        .with_state(state)
459}
460
#[cfg(test)]
mod tests {
    use super::*;

    mod integration_tests {
        use super::*;
        use aster::conversation::message::Message;
        use axum::{body::Body, http::Request};
        use tower::ServiceExt;

        /// Smoke test: a well-formed `POST /reply` is accepted with `200 OK`.
        /// The streamed SSE body itself is not consumed.
        #[tokio::test(flavor = "multi_thread")]
        async fn test_reply_endpoint() {
            let app = routes(AppState::new().await.unwrap());

            let payload = ChatRequest {
                user_message: Message::user().with_text("test message"),
                conversation_so_far: None,
                session_id: "test-session".to_string(),
                recipe_name: None,
                recipe_version: None,
            };
            let json_body = serde_json::to_string(&payload).unwrap();

            let request = Request::builder()
                .method("POST")
                .uri("/reply")
                .header("content-type", "application/json")
                .header("x-secret-key", "test-secret")
                .body(Body::from(json_body))
                .unwrap();

            let response = app.oneshot(request).await.unwrap();
            assert_eq!(response.status(), StatusCode::OK);
        }
    }
}