Skip to main content

codex_relay/
stream.rs

1use async_stream::stream;
2use axum::response::{
3    sse::{Event, KeepAlive},
4    Sse,
5};
6use eventsource_stream::Eventsource as EventsourceExt;
7use futures_util::StreamExt;
8use serde_json::{json, Value};
9use std::collections::BTreeMap;
10use std::sync::Arc;
11use tracing::{error, warn};
12
13use crate::{
14    session::SessionStore,
15    types::{ChatMessage, ChatRequest, ChatStreamChunk},
16};
17
18pub struct StreamArgs {
19    pub client: reqwest::Client,
20    pub url: String,
21    pub api_key: Arc<String>,
22    pub chat_req: ChatRequest,
23    pub response_id: String,
24    pub sessions: SessionStore,
25    pub prior_messages: Vec<ChatMessage>,
26    /// The fully translated request messages (including replayed history).
27    /// Used to save correct session history so turn-level reasoning can be
28    /// recovered when Codex replays the conversation without previous_response_id.
29    pub request_messages: Vec<ChatMessage>,
30    pub model: String,
31}
32
/// One tool call being reassembled from streamed `tool_calls` deltas.
///
/// [`translate_stream`] keeps one of these per upstream delta `index` and only
/// emits the corresponding `function_call` item once the upstream stream ends.
#[derive(Debug, Default)]
struct ToolCallAccum {
    /// Upstream tool call id (forwarded as the Responses `call_id`); the last
    /// non-empty id seen for this index wins.
    id: String,
    /// Function name, concatenated from the name fragments received.
    name: String,
    /// Argument text (presumably a JSON string — TODO confirm against the
    /// upstream), concatenated from argument fragments as they stream in.
    arguments: String,
}
38
39/// Translate an upstream Chat Completions SSE stream into a Responses API SSE stream.
40///
41/// Text response event sequence:
42///   response.created → response.output_item.added (message) → response.output_text.delta*
43///   → response.output_item.done → response.completed
44///
45/// Tool call response event sequence:
46///   response.created → [accumulate deltas] → response.output_item.added (function_call)
47///   → response.function_call_arguments.delta → response.output_item.done → response.completed
48pub fn translate_stream(
49    args: StreamArgs,
50) -> Sse<impl futures_util::Stream<Item = Result<Event, std::convert::Infallible>>> {
51    let StreamArgs {
52        client,
53        url,
54        api_key,
55        chat_req,
56        response_id,
57        sessions,
58        prior_messages,
59        request_messages,
60        model,
61    } = args;
62    let msg_item_id = format!("msg_{}", uuid::Uuid::new_v4().simple());
63
64    let event_stream = stream! {
65        yield Ok(Event::default()
66            .event("response.created")
67            .data(json!({
68                "type": "response.created",
69                "response": { "id": &response_id, "status": "in_progress", "model": &model }
70            }).to_string()));
71
72        let mut builder = client.post(&url).header("Content-Type", "application/json");
73        if !api_key.is_empty() {
74            builder = builder.bearer_auth(api_key.as_str());
75        }
76
77        let upstream = match builder.json(&chat_req).send().await {
78            Ok(r) if r.status().is_success() => r,
79            Ok(r) => {
80                let status = r.status();
81                let body = r.text().await.unwrap_or_default();
82                error!("upstream {status}: {body}");
83                yield Ok(Event::default().event("response.failed").data(
84                    json!({"type": "response.failed", "response": {"id": &response_id, "status": "failed", "error": {"code": status.as_u16().to_string(), "message": body}}}).to_string()
85                ));
86                return;
87            }
88            Err(e) => {
89                error!("upstream request failed: {e}");
90                yield Ok(Event::default().event("response.failed").data(
91                    json!({"type": "response.failed", "response": {"id": &response_id, "status": "failed", "error": {"code": "connection_error", "message": e.to_string()}}}).to_string()
92                ));
93                return;
94            }
95        };
96
97        let mut accumulated_text = String::new();
98        let mut accumulated_reasoning = String::new();
99        let mut tool_calls: BTreeMap<usize, ToolCallAccum> = BTreeMap::new();
100        let mut emitted_message_item = false;
101        let mut source = upstream.bytes_stream().eventsource();
102
103        while let Some(ev) = source.next().await {
104            match ev {
105                Err(e) => {
106                    warn!("SSE parse error: {e}");
107                    break;
108                }
109                Ok(ev) if ev.data.trim() == "[DONE]" => break,
110                Ok(ev) if ev.data.is_empty() => continue,
111                Ok(ev) => {
112                    match serde_json::from_str::<ChatStreamChunk>(&ev.data) {
113                        Err(e) => warn!("chunk parse error: {e} — data: {}", ev.data),
114                        Ok(chunk) => {
115                            for choice in &chunk.choices {
116                                // Reasoning/thinking content (kimi-k2.6 etc.)
117                                if let Some(rc) = choice.delta.reasoning_content.as_deref() {
118                                    if !rc.is_empty() {
119                                        accumulated_reasoning.push_str(rc);
120                                    }
121                                }
122
123                                // Text content
124                                let content = choice.delta.content.as_deref().unwrap_or("");
125                                if !content.is_empty() {
126                                    if !emitted_message_item {
127                                        yield Ok(Event::default()
128                                            .event("response.output_item.added")
129                                            .data(json!({
130                                                "type": "response.output_item.added",
131                                                "output_index": 0,
132                                                "item": { "type": "message", "id": &msg_item_id, "role": "assistant", "content": [], "status": "in_progress" }
133                                            }).to_string()));
134                                        emitted_message_item = true;
135                                    }
136                                    accumulated_text.push_str(content);
137                                    yield Ok(Event::default()
138                                        .event("response.output_text.delta")
139                                        .data(json!({
140                                            "type": "response.output_text.delta",
141                                            "item_id": &msg_item_id,
142                                            "output_index": 0,
143                                            "content_index": 0,
144                                            "delta": content
145                                        }).to_string()));
146                                }
147
148                                // Tool call deltas — accumulate by index
149                                if let Some(delta_calls) = &choice.delta.tool_calls {
150                                    for dc in delta_calls {
151                                        let entry = tool_calls.entry(dc.index).or_insert(ToolCallAccum {
152                                            id: String::new(),
153                                            name: String::new(),
154                                            arguments: String::new(),
155                                        });
156                                        if let Some(id) = &dc.id {
157                                            if !id.is_empty() { entry.id.clone_from(id); }
158                                        }
159                                        if let Some(func) = &dc.function {
160                                            if let Some(n) = &func.name {
161                                                if !n.is_empty() { entry.name.push_str(n); }
162                                            }
163                                            if let Some(a) = &func.arguments {
164                                                entry.arguments.push_str(a);
165                                            }
166                                        }
167                                    }
168                                }
169                            }
170                        }
171                    }
172                }
173            }
174        }
175
176        // Close message item if one was opened
177        if emitted_message_item {
178            yield Ok(Event::default()
179                .event("response.output_item.done")
180                .data(json!({
181                    "type": "response.output_item.done",
182                    "output_index": 0,
183                    "item": {
184                        "type": "message",
185                        "id": &msg_item_id,
186                        "role": "assistant",
187                        "status": "completed",
188                        "content": [{"type": "output_text", "text": &accumulated_text}]
189                    }
190                }).to_string()));
191        }
192
193        // Emit function_call items for each accumulated tool call
194        let base_index: usize = if emitted_message_item { 1 } else { 0 };
195        let mut fc_items: Vec<Value> = Vec::new();
196
197        for (rel_idx, (_, tc)) in tool_calls.iter().enumerate() {
198            let fc_item_id = format!("fc_{}", uuid::Uuid::new_v4().simple());
199            let output_index = base_index + rel_idx;
200
201            yield Ok(Event::default()
202                .event("response.output_item.added")
203                .data(json!({
204                    "type": "response.output_item.added",
205                    "output_index": output_index,
206                    "item": {
207                        "type": "function_call",
208                        "id": &fc_item_id,
209                        "call_id": &tc.id,
210                        "name": &tc.name,
211                        "arguments": "",
212                        "status": "in_progress"
213                    }
214                }).to_string()));
215
216            if !tc.arguments.is_empty() {
217                yield Ok(Event::default()
218                    .event("response.function_call_arguments.delta")
219                    .data(json!({
220                        "type": "response.function_call_arguments.delta",
221                        "item_id": &fc_item_id,
222                        "output_index": output_index,
223                        "delta": &tc.arguments
224                    }).to_string()));
225            }
226
227            yield Ok(Event::default()
228                .event("response.output_item.done")
229                .data(json!({
230                    "type": "response.output_item.done",
231                    "output_index": output_index,
232                    "item": {
233                        "type": "function_call",
234                        "id": &fc_item_id,
235                        "call_id": &tc.id,
236                        "name": &tc.name,
237                        "arguments": &tc.arguments,
238                        "status": "completed"
239                    }
240                }).to_string()));
241
242            fc_items.push(json!({
243                "type": "function_call",
244                "id": fc_item_id,
245                "call_id": &tc.id,
246                "name": &tc.name,
247                "arguments": &tc.arguments,
248                "status": "completed"
249            }));
250        }
251
252        // Persist turn to session store
253        // Store reasoning_content per call_id so translate.rs can inject it
254        // back when Codex replays function_call items in the next request.
255        for tc in tool_calls.values() {
256            if !tc.id.is_empty() {
257                sessions.store_reasoning(tc.id.clone(), accumulated_reasoning.clone());
258            }
259        }
260
261        let assistant_tool_calls: Option<Vec<Value>> = if tool_calls.is_empty() {
262            None
263        } else {
264            Some(tool_calls.values().map(|tc| json!({
265                "id": &tc.id,
266                "type": "function",
267                "function": { "name": &tc.name, "arguments": &tc.arguments }
268            })).collect())
269        };
270        let assistant_msg = ChatMessage {
271            role: "assistant".into(),
272            content: if accumulated_text.is_empty() { None } else { Some(accumulated_text.clone()) },
273            reasoning_content: if accumulated_reasoning.is_empty() { None } else { Some(accumulated_reasoning.clone()) },
274            tool_calls: assistant_tool_calls,
275            tool_call_id: None,
276            name: None,
277        };
278
279        // Index reasoning by turn fingerprint so it can be recovered when
280        // Codex replays the full conversation in input[] without previous_response_id.
281        if !accumulated_reasoning.is_empty() {
282            sessions.store_turn_reasoning(&request_messages, &assistant_msg, accumulated_reasoning.clone());
283        }
284
285        let mut messages = prior_messages;
286        messages.push(assistant_msg);
287        sessions.save_with_id(response_id.clone(), messages);
288
289        // Build output array for response.completed
290        let mut output_items: Vec<Value> = Vec::new();
291        if emitted_message_item {
292            output_items.push(json!({
293                "type": "message",
294                "id": &msg_item_id,
295                "role": "assistant",
296                "status": "completed",
297                "content": [{"type": "output_text", "text": &accumulated_text}]
298            }));
299        }
300        output_items.extend(fc_items);
301
302        yield Ok(Event::default()
303            .event("response.completed")
304            .data(json!({
305                "type": "response.completed",
306                "response": {
307                    "id": &response_id,
308                    "status": "completed",
309                    "model": &model,
310                    "output": output_items
311                }
312            }).to_string()));
313    };
314
315    Sse::new(event_stream).keep_alive(KeepAlive::default())
316}