hehe_server/routes/chat.rs

1use axum::{
2    extract::State,
3    response::sse::{Event, KeepAlive, Sse},
4    Json,
5};
6use futures::stream::Stream;
7use hehe_agent::AgentEvent;
8use hehe_core::Id;
9use serde::{Deserialize, Serialize};
10use std::convert::Infallible;
11use std::str::FromStr;
12use tokio_stream::StreamExt;
13
14use crate::error::{Result, ServerError};
15use crate::state::AppState;
16
17#[derive(Deserialize)]
18pub struct ChatRequest {
19    pub session_id: Option<String>,
20    pub message: String,
21    #[serde(default)]
22    pub stream: bool,
23}
24
25#[derive(Serialize)]
26pub struct ChatResponse {
27    pub session_id: String,
28    pub response: String,
29    pub tool_calls: Vec<ToolCallInfo>,
30    pub iterations: usize,
31}
32
33#[derive(Serialize)]
34pub struct ToolCallInfo {
35    pub id: String,
36    pub name: String,
37    pub output: String,
38    pub is_error: bool,
39}
40
41pub async fn chat(
42    State(state): State<AppState>,
43    Json(request): Json<ChatRequest>,
44) -> Result<Json<ChatResponse>> {
45    let session_id = request.session_id.and_then(|s| Id::from_str(&s).ok());
46    let session = state.get_or_create_session(session_id).await;
47
48    let response = state
49        .agent
50        .process(&session, &request.message)
51        .await
52        .map_err(ServerError::from)?;
53
54    Ok(Json(ChatResponse {
55        session_id: session.id().to_string(),
56        response: response.text,
57        tool_calls: response
58            .tool_calls
59            .into_iter()
60            .map(|tc| ToolCallInfo {
61                id: tc.id,
62                name: tc.name,
63                output: tc.output,
64                is_error: tc.is_error,
65            })
66            .collect(),
67        iterations: response.iterations,
68    }))
69}
70
71pub async fn chat_stream(
72    State(state): State<AppState>,
73    Json(request): Json<ChatRequest>,
74) -> Sse<impl Stream<Item = std::result::Result<Event, Infallible>>> {
75    let session_id = request.session_id.and_then(|s| Id::from_str(&s).ok());
76    let session = state.get_or_create_session(session_id).await;
77    let message = request.message;
78
79    let event_stream = state.agent.chat_stream(&session, &message);
80
81    let sse_stream = event_stream.map(|event| {
82        let data = match &event {
83            AgentEvent::MessageStart { session_id } => {
84                serde_json::json!({
85                    "type": "message_start",
86                    "session_id": session_id.to_string()
87                })
88            }
89            AgentEvent::TextDelta { delta } => {
90                serde_json::json!({
91                    "type": "text_delta",
92                    "delta": delta
93                })
94            }
95            AgentEvent::TextComplete { text } => {
96                serde_json::json!({
97                    "type": "text_complete",
98                    "text": text
99                })
100            }
101            AgentEvent::ToolUseStart { id, name, input } => {
102                serde_json::json!({
103                    "type": "tool_use_start",
104                    "id": id,
105                    "name": name,
106                    "input": input
107                })
108            }
109            AgentEvent::ToolUseEnd { id, output, is_error } => {
110                serde_json::json!({
111                    "type": "tool_use_end",
112                    "id": id,
113                    "output": output,
114                    "is_error": is_error
115                })
116            }
117            AgentEvent::Thinking { content } => {
118                serde_json::json!({
119                    "type": "thinking",
120                    "content": content
121                })
122            }
123            AgentEvent::MessageEnd { session_id } => {
124                serde_json::json!({
125                    "type": "message_end",
126                    "session_id": session_id.to_string()
127                })
128            }
129            AgentEvent::Error { message } => {
130                serde_json::json!({
131                    "type": "error",
132                    "message": message
133                })
134            }
135        };
136
137        Ok(Event::default().data(data.to_string()))
138    });
139
140    Sse::new(sse_stream).keep_alive(KeepAlive::default())
141}