// hehe_server/routes/chat.rs
use std::convert::Infallible;
use std::str::FromStr;

use axum::{
    extract::State,
    response::sse::{Event, KeepAlive, Sse},
    Json,
};
use futures::stream::Stream;
use hehe_agent::AgentEvent;
use hehe_core::Id;
use serde::{Deserialize, Serialize};
use tokio_stream::StreamExt;

use crate::error::{Result, ServerError};
use crate::state::AppState;
16
17#[derive(Deserialize)]
18pub struct ChatRequest {
19 pub session_id: Option<String>,
20 pub message: String,
21 #[serde(default)]
22 pub stream: bool,
23}
24
25#[derive(Serialize)]
26pub struct ChatResponse {
27 pub session_id: String,
28 pub response: String,
29 pub tool_calls: Vec<ToolCallInfo>,
30 pub iterations: usize,
31}
32
33#[derive(Serialize)]
34pub struct ToolCallInfo {
35 pub id: String,
36 pub name: String,
37 pub output: String,
38 pub is_error: bool,
39}
40
41pub async fn chat(
42 State(state): State<AppState>,
43 Json(request): Json<ChatRequest>,
44) -> Result<Json<ChatResponse>> {
45 let session_id = request.session_id.and_then(|s| Id::from_str(&s).ok());
46 let session = state.get_or_create_session(session_id).await;
47
48 let response = state
49 .agent
50 .process(&session, &request.message)
51 .await
52 .map_err(ServerError::from)?;
53
54 Ok(Json(ChatResponse {
55 session_id: session.id().to_string(),
56 response: response.text,
57 tool_calls: response
58 .tool_calls
59 .into_iter()
60 .map(|tc| ToolCallInfo {
61 id: tc.id,
62 name: tc.name,
63 output: tc.output,
64 is_error: tc.is_error,
65 })
66 .collect(),
67 iterations: response.iterations,
68 }))
69}
70
71pub async fn chat_stream(
72 State(state): State<AppState>,
73 Json(request): Json<ChatRequest>,
74) -> Sse<impl Stream<Item = std::result::Result<Event, Infallible>>> {
75 let session_id = request.session_id.and_then(|s| Id::from_str(&s).ok());
76 let session = state.get_or_create_session(session_id).await;
77 let message = request.message;
78
79 let event_stream = state.agent.chat_stream(&session, &message);
80
81 let sse_stream = event_stream.map(|event| {
82 let data = match &event {
83 AgentEvent::MessageStart { session_id } => {
84 serde_json::json!({
85 "type": "message_start",
86 "session_id": session_id.to_string()
87 })
88 }
89 AgentEvent::TextDelta { delta } => {
90 serde_json::json!({
91 "type": "text_delta",
92 "delta": delta
93 })
94 }
95 AgentEvent::TextComplete { text } => {
96 serde_json::json!({
97 "type": "text_complete",
98 "text": text
99 })
100 }
101 AgentEvent::ToolUseStart { id, name, input } => {
102 serde_json::json!({
103 "type": "tool_use_start",
104 "id": id,
105 "name": name,
106 "input": input
107 })
108 }
109 AgentEvent::ToolUseEnd { id, output, is_error } => {
110 serde_json::json!({
111 "type": "tool_use_end",
112 "id": id,
113 "output": output,
114 "is_error": is_error
115 })
116 }
117 AgentEvent::Thinking { content } => {
118 serde_json::json!({
119 "type": "thinking",
120 "content": content
121 })
122 }
123 AgentEvent::MessageEnd { session_id } => {
124 serde_json::json!({
125 "type": "message_end",
126 "session_id": session_id.to_string()
127 })
128 }
129 AgentEvent::Error { message } => {
130 serde_json::json!({
131 "type": "error",
132 "message": message
133 })
134 }
135 };
136
137 Ok(Event::default().data(data.to_string()))
138 });
139
140 Sse::new(sse_stream).keep_alive(KeepAlive::default())
141}