1use async_stream::stream;
2use axum::response::{
3 sse::{Event, KeepAlive},
4 Sse,
5};
6use eventsource_stream::Eventsource as EventsourceExt;
7use futures_util::StreamExt;
8use serde_json::{json, Value};
9use std::collections::BTreeMap;
10use std::sync::Arc;
11use tracing::{error, warn};
12
13use crate::{
14 session::SessionStore,
15 types::{ChatMessage, ChatRequest, ChatStreamChunk},
16};
17
18pub struct StreamArgs {
19 pub client: reqwest::Client,
20 pub url: String,
21 pub api_key: Arc<String>,
22 pub chat_req: ChatRequest,
23 pub response_id: String,
24 pub sessions: SessionStore,
25 pub prior_messages: Vec<ChatMessage>,
26 pub request_messages: Vec<ChatMessage>,
30 pub model: String,
31}
32
/// Accumulates one tool call whose pieces arrive spread over several
/// streamed deltas.
///
/// `Default` yields the all-empty state a call starts in; `Debug` makes the
/// accumulator loggable.
#[derive(Debug, Default)]
struct ToolCallAccum {
    /// Call id; replaced whenever a delta carries a non-empty id.
    id: String,
    /// Function name; non-empty name fragments are appended in arrival order.
    name: String,
    /// Argument text; fragments are appended in arrival order.
    arguments: String,
}
38
39pub fn translate_stream(
49 args: StreamArgs,
50) -> Sse<impl futures_util::Stream<Item = Result<Event, std::convert::Infallible>>> {
51 let StreamArgs {
52 client,
53 url,
54 api_key,
55 chat_req,
56 response_id,
57 sessions,
58 prior_messages,
59 request_messages,
60 model,
61 } = args;
62 let msg_item_id = format!("msg_{}", uuid::Uuid::new_v4().simple());
63
64 let event_stream = stream! {
65 yield Ok(Event::default()
66 .event("response.created")
67 .data(json!({
68 "type": "response.created",
69 "response": { "id": &response_id, "status": "in_progress", "model": &model }
70 }).to_string()));
71
72 let mut builder = client.post(&url).header("Content-Type", "application/json");
73 if !api_key.is_empty() {
74 builder = builder.bearer_auth(api_key.as_str());
75 }
76
77 let upstream = match builder.json(&chat_req).send().await {
78 Ok(r) if r.status().is_success() => r,
79 Ok(r) => {
80 let status = r.status();
81 let body = r.text().await.unwrap_or_default();
82 error!("upstream {status}: {body}");
83 yield Ok(Event::default().event("response.failed").data(
84 json!({"type": "response.failed", "response": {"id": &response_id, "status": "failed", "error": {"code": status.as_u16().to_string(), "message": body}}}).to_string()
85 ));
86 return;
87 }
88 Err(e) => {
89 error!("upstream request failed: {e}");
90 yield Ok(Event::default().event("response.failed").data(
91 json!({"type": "response.failed", "response": {"id": &response_id, "status": "failed", "error": {"code": "connection_error", "message": e.to_string()}}}).to_string()
92 ));
93 return;
94 }
95 };
96
97 let mut accumulated_text = String::new();
98 let mut accumulated_reasoning = String::new();
99 let mut tool_calls: BTreeMap<usize, ToolCallAccum> = BTreeMap::new();
100 let mut emitted_message_item = false;
101 let mut source = upstream.bytes_stream().eventsource();
102
103 while let Some(ev) = source.next().await {
104 match ev {
105 Err(e) => {
106 warn!("SSE parse error: {e}");
107 break;
108 }
109 Ok(ev) if ev.data.trim() == "[DONE]" => break,
110 Ok(ev) if ev.data.is_empty() => continue,
111 Ok(ev) => {
112 match serde_json::from_str::<ChatStreamChunk>(&ev.data) {
113 Err(e) => warn!("chunk parse error: {e} — data: {}", ev.data),
114 Ok(chunk) => {
115 for choice in &chunk.choices {
116 if let Some(rc) = choice.delta.reasoning_content.as_deref() {
118 if !rc.is_empty() {
119 accumulated_reasoning.push_str(rc);
120 }
121 }
122
123 let content = choice.delta.content.as_deref().unwrap_or("");
125 if !content.is_empty() {
126 if !emitted_message_item {
127 yield Ok(Event::default()
128 .event("response.output_item.added")
129 .data(json!({
130 "type": "response.output_item.added",
131 "output_index": 0,
132 "item": { "type": "message", "id": &msg_item_id, "role": "assistant", "content": [], "status": "in_progress" }
133 }).to_string()));
134 emitted_message_item = true;
135 }
136 accumulated_text.push_str(content);
137 yield Ok(Event::default()
138 .event("response.output_text.delta")
139 .data(json!({
140 "type": "response.output_text.delta",
141 "item_id": &msg_item_id,
142 "output_index": 0,
143 "content_index": 0,
144 "delta": content
145 }).to_string()));
146 }
147
148 if let Some(delta_calls) = &choice.delta.tool_calls {
150 for dc in delta_calls {
151 let entry = tool_calls.entry(dc.index).or_insert(ToolCallAccum {
152 id: String::new(),
153 name: String::new(),
154 arguments: String::new(),
155 });
156 if let Some(id) = &dc.id {
157 if !id.is_empty() { entry.id.clone_from(id); }
158 }
159 if let Some(func) = &dc.function {
160 if let Some(n) = &func.name {
161 if !n.is_empty() { entry.name.push_str(n); }
162 }
163 if let Some(a) = &func.arguments {
164 entry.arguments.push_str(a);
165 }
166 }
167 }
168 }
169 }
170 }
171 }
172 }
173 }
174 }
175
176 if emitted_message_item {
178 yield Ok(Event::default()
179 .event("response.output_item.done")
180 .data(json!({
181 "type": "response.output_item.done",
182 "output_index": 0,
183 "item": {
184 "type": "message",
185 "id": &msg_item_id,
186 "role": "assistant",
187 "status": "completed",
188 "content": [{"type": "output_text", "text": &accumulated_text}]
189 }
190 }).to_string()));
191 }
192
193 let base_index: usize = if emitted_message_item { 1 } else { 0 };
195 let mut fc_items: Vec<Value> = Vec::new();
196
197 for (rel_idx, (_, tc)) in tool_calls.iter().enumerate() {
198 let fc_item_id = format!("fc_{}", uuid::Uuid::new_v4().simple());
199 let output_index = base_index + rel_idx;
200
201 yield Ok(Event::default()
202 .event("response.output_item.added")
203 .data(json!({
204 "type": "response.output_item.added",
205 "output_index": output_index,
206 "item": {
207 "type": "function_call",
208 "id": &fc_item_id,
209 "call_id": &tc.id,
210 "name": &tc.name,
211 "arguments": "",
212 "status": "in_progress"
213 }
214 }).to_string()));
215
216 if !tc.arguments.is_empty() {
217 yield Ok(Event::default()
218 .event("response.function_call_arguments.delta")
219 .data(json!({
220 "type": "response.function_call_arguments.delta",
221 "item_id": &fc_item_id,
222 "output_index": output_index,
223 "delta": &tc.arguments
224 }).to_string()));
225 }
226
227 yield Ok(Event::default()
228 .event("response.output_item.done")
229 .data(json!({
230 "type": "response.output_item.done",
231 "output_index": output_index,
232 "item": {
233 "type": "function_call",
234 "id": &fc_item_id,
235 "call_id": &tc.id,
236 "name": &tc.name,
237 "arguments": &tc.arguments,
238 "status": "completed"
239 }
240 }).to_string()));
241
242 fc_items.push(json!({
243 "type": "function_call",
244 "id": fc_item_id,
245 "call_id": &tc.id,
246 "name": &tc.name,
247 "arguments": &tc.arguments,
248 "status": "completed"
249 }));
250 }
251
252 for tc in tool_calls.values() {
256 if !tc.id.is_empty() {
257 sessions.store_reasoning(tc.id.clone(), accumulated_reasoning.clone());
258 }
259 }
260
261 let assistant_tool_calls: Option<Vec<Value>> = if tool_calls.is_empty() {
262 None
263 } else {
264 Some(tool_calls.values().map(|tc| json!({
265 "id": &tc.id,
266 "type": "function",
267 "function": { "name": &tc.name, "arguments": &tc.arguments }
268 })).collect())
269 };
270 let assistant_msg = ChatMessage {
271 role: "assistant".into(),
272 content: if accumulated_text.is_empty() { None } else { Some(accumulated_text.clone()) },
273 reasoning_content: if accumulated_reasoning.is_empty() { None } else { Some(accumulated_reasoning.clone()) },
274 tool_calls: assistant_tool_calls,
275 tool_call_id: None,
276 name: None,
277 };
278
279 if !accumulated_reasoning.is_empty() {
282 sessions.store_turn_reasoning(&request_messages, &assistant_msg, accumulated_reasoning.clone());
283 }
284
285 let mut messages = prior_messages;
286 messages.push(assistant_msg);
287 sessions.save_with_id(response_id.clone(), messages);
288
289 let mut output_items: Vec<Value> = Vec::new();
291 if emitted_message_item {
292 output_items.push(json!({
293 "type": "message",
294 "id": &msg_item_id,
295 "role": "assistant",
296 "status": "completed",
297 "content": [{"type": "output_text", "text": &accumulated_text}]
298 }));
299 }
300 output_items.extend(fc_items);
301
302 yield Ok(Event::default()
303 .event("response.completed")
304 .data(json!({
305 "type": "response.completed",
306 "response": {
307 "id": &response_id,
308 "status": "completed",
309 "model": &model,
310 "output": output_items
311 }
312 }).to_string()));
313 };
314
315 Sse::new(event_stream).keep_alive(KeepAlive::default())
316}