//! Streaming support for the Codex provider.
//!
//! Path: `llm/providers/codex/streaming.rs`.
1use async_openai::types::responses::{OutputItem, ResponseStreamEvent, Status};
2
3use crate::providers::tool_call_collector::ToolCallCollector;
4use crate::{LlmError, LlmResponse, Result, StopReason};
5use futures::Stream;
6use tokio_stream::StreamExt;
7
8/// Process a typed `ResponseStreamEvent` stream into `LlmResponse` items.
9pub fn process_response_stream<T>(stream: T) -> impl Stream<Item = Result<LlmResponse>> + Send
10where
11    T: Stream<Item = Result<ResponseStreamEvent>> + Send + Unpin,
12{
13    async_stream::stream! {
14        let message_id = uuid::Uuid::new_v4().to_string();
15        yield Ok(LlmResponse::Start { message_id });
16
17        let mut tool_collector = ToolCallCollector::<u32>::new();
18        let mut stream = Box::pin(stream);
19        let mut last_stop_reason: Option<StopReason> = None;
20
21        while let Some(result) = stream.next().await {
22            match result {
23                Ok(event) => {
24                    for response in process_event(event, &mut tool_collector, &mut last_stop_reason) {
25                        yield response;
26                    }
27                }
28                Err(e) => {
29                    yield Err(LlmError::ApiError(e.to_string()));
30                    break;
31                }
32            }
33        }
34
35        // Complete any pending tool calls
36        for tc in tool_collector.complete_all() {
37            yield Ok(LlmResponse::ToolRequestComplete { tool_call: tc });
38        }
39
40        yield Ok(LlmResponse::Done {
41            stop_reason: last_stop_reason,
42        });
43    }
44}
45
46fn process_event(
47    event: ResponseStreamEvent,
48    tool_collector: &mut ToolCallCollector<u32>,
49    last_stop_reason: &mut Option<StopReason>,
50) -> Vec<Result<LlmResponse>> {
51    let mut responses = Vec::new();
52
53    match event {
54        ResponseStreamEvent::ResponseOutputTextDelta(e) if !e.delta.is_empty() => {
55            responses.push(Ok(LlmResponse::Text { chunk: e.delta }));
56        }
57        ResponseStreamEvent::ResponseOutputItemAdded(e) => {
58            if let OutputItem::FunctionCall(call) = e.item {
59                let tool_responses = tool_collector.handle_delta(e.output_index, call.id, Some(call.name), None);
60                responses.extend(tool_responses.into_iter().map(Ok));
61            }
62        }
63        ResponseStreamEvent::ResponseFunctionCallArgumentsDelta(e) => {
64            let tool_responses = tool_collector.handle_delta(e.output_index, None, None, Some(e.delta));
65            responses.extend(tool_responses.into_iter().map(Ok));
66        }
67        ResponseStreamEvent::ResponseFunctionCallArgumentsDone(e) => {
68            if let Some(tc) = tool_collector.complete_one(e.output_index) {
69                responses.push(Ok(LlmResponse::ToolRequestComplete { tool_call: tc }));
70            }
71        }
72        ResponseStreamEvent::ResponseReasoningSummaryTextDelta(e) if !e.delta.is_empty() => {
73            responses.push(Ok(LlmResponse::Reasoning { chunk: e.delta }));
74        }
75        ResponseStreamEvent::ResponseOutputItemDone(e) => {
76            if let OutputItem::Reasoning(reasoning) = e.item
77                && let Some(id) = reasoning.id
78                && let Some(encrypted) = reasoning.encrypted_content
79            {
80                responses.push(Ok(LlmResponse::EncryptedReasoning { id, content: encrypted }));
81            }
82        }
83        ResponseStreamEvent::ResponseCompleted(e) => {
84            if let Some(usage) = e.response.usage {
85                responses.push(Ok(LlmResponse::Usage { tokens: usage.into() }));
86            }
87            match e.response.status {
88                Status::Completed => *last_stop_reason = Some(StopReason::EndTurn),
89                Status::Incomplete => *last_stop_reason = Some(StopReason::Length),
90                _ => {}
91            }
92        }
93        ResponseStreamEvent::ResponseError(e) => {
94            responses.push(Err(LlmError::ApiError(format!("Codex API error: {}", e.message))));
95        }
96        // Events we don't need to act on
97        _ => {}
98    }
99
100    responses
101}
102
/// Unit tests for translating Codex `ResponseStreamEvent`s into `LlmResponse`s.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TokenUsage;
    use async_openai::types::responses::{
        FunctionToolCall, ReasoningItem, Response, ResponseCompletedEvent, ResponseErrorEvent,
        ResponseFunctionCallArgumentsDeltaEvent, ResponseFunctionCallArgumentsDoneEvent, ResponseOutputItemAddedEvent,
        ResponseOutputItemDoneEvent, ResponseReasoningSummaryTextDeltaEvent, ResponseTextDeltaEvent, ResponseUsage,
    };
    /// Build a minimal `Response` with given status and optional usage via JSON deserialization.
    // Going through JSON avoids naming every field of the (large) external `Response` struct.
    fn make_response(status: &Status, usage: Option<ResponseUsage>) -> Response {
        let status_str = serde_json::to_value(status).unwrap();
        let mut json = serde_json::json!({
            "id": "resp_1",
            "object": "response",
            "status": status_str,
            "output": [],
            "model": "test",
            "created_at": 0
        });
        if let Some(u) = usage {
            json["usage"] = serde_json::to_value(u).unwrap();
        }
        serde_json::from_value(json).unwrap()
    }

    /// Usage with no cached or reasoning tokens.
    fn make_usage(input_tokens: u32, output_tokens: u32) -> ResponseUsage {
        make_usage_full(input_tokens, output_tokens, 0, 0)
    }

    /// Fully-specified usage, built via JSON for the same reason as `make_response`.
    fn make_usage_full(
        input_tokens: u32,
        output_tokens: u32,
        cached_tokens: u32,
        reasoning_tokens: u32,
    ) -> ResponseUsage {
        serde_json::from_value(serde_json::json!({
            "input_tokens": input_tokens,
            "input_tokens_details": { "cached_tokens": cached_tokens },
            "output_tokens": output_tokens,
            "output_tokens_details": { "reasoning_tokens": reasoning_tokens },
            "total_tokens": input_tokens + output_tokens
        }))
        .unwrap()
    }

    /// Wrap a list of events in an infallible in-memory event stream.
    fn make_stream(events: Vec<ResponseStreamEvent>) -> impl Stream<Item = Result<ResponseStreamEvent>> + Send + Unpin {
        tokio_stream::iter(events.into_iter().map(Ok).collect::<Vec<_>>())
    }

    /// Text deltas stream through in order, followed by usage and an EndTurn stop reason.
    #[tokio::test]
    async fn test_text_stream() {
        let events = vec![
            ResponseStreamEvent::ResponseOutputTextDelta(ResponseTextDeltaEvent {
                output_index: 0,
                content_index: 0,
                delta: "Hello".to_string(),
                sequence_number: 1,
                item_id: "msg_1".to_string(),
                logprobs: None,
            }),
            ResponseStreamEvent::ResponseOutputTextDelta(ResponseTextDeltaEvent {
                output_index: 0,
                content_index: 0,
                delta: " world".to_string(),
                sequence_number: 2,
                item_id: "msg_1".to_string(),
                logprobs: None,
            }),
            ResponseStreamEvent::ResponseCompleted(ResponseCompletedEvent {
                sequence_number: 3,
                response: make_response(&Status::Completed, Some(make_usage(10, 5))),
            }),
        ];

        let stream = make_stream(events);
        let mut response_stream = Box::pin(process_response_stream(stream));

        let mut responses = Vec::new();
        while let Some(result) = response_stream.next().await {
            responses.push(result.unwrap());
        }

        assert!(matches!(responses[0], LlmResponse::Start { .. }));
        assert!(matches!(responses[1], LlmResponse::Text { ref chunk } if chunk == "Hello"));
        assert!(matches!(responses[2], LlmResponse::Text { ref chunk } if chunk == " world"));
        assert!(matches!(
            responses[3],
            LlmResponse::Usage { tokens: TokenUsage { input_tokens: 10, output_tokens: 5, .. } }
        ));
        assert!(matches!(responses[4], LlmResponse::Done { stop_reason: Some(StopReason::EndTurn) }));
    }

    /// Tool call lifecycle: item-added starts it, argument deltas accumulate,
    /// arguments-done completes it with the full argument string.
    #[tokio::test]
    async fn test_tool_call_stream() {
        let events = vec![
            ResponseStreamEvent::ResponseOutputItemAdded(ResponseOutputItemAddedEvent {
                sequence_number: 1,
                output_index: 0,
                item: OutputItem::FunctionCall(FunctionToolCall {
                    id: Some("fc_1".to_string()),
                    call_id: "call_1".to_string(),
                    name: "read_file".to_string(),
                    arguments: String::new(),
                    status: None,
                    namespace: None,
                }),
            }),
            ResponseStreamEvent::ResponseFunctionCallArgumentsDelta(ResponseFunctionCallArgumentsDeltaEvent {
                sequence_number: 2,
                item_id: "fc_1".to_string(),
                output_index: 0,
                delta: r#"{"path":"#.to_string(),
            }),
            ResponseStreamEvent::ResponseFunctionCallArgumentsDelta(ResponseFunctionCallArgumentsDeltaEvent {
                sequence_number: 3,
                item_id: "fc_1".to_string(),
                output_index: 0,
                delta: r#""foo.rs"}"#.to_string(),
            }),
            ResponseStreamEvent::ResponseFunctionCallArgumentsDone(ResponseFunctionCallArgumentsDoneEvent {
                sequence_number: 4,
                item_id: "fc_1".to_string(),
                output_index: 0,
                arguments: r#"{"path":"foo.rs"}"#.to_string(),
                name: None,
            }),
            ResponseStreamEvent::ResponseCompleted(ResponseCompletedEvent {
                sequence_number: 5,
                response: make_response(&Status::Completed, Some(make_usage(20, 10))),
            }),
        ];

        let stream = make_stream(events);
        let mut response_stream = Box::pin(process_response_stream(stream));

        let mut responses = Vec::new();
        while let Some(result) = response_stream.next().await {
            responses.push(result.unwrap());
        }

        assert!(matches!(responses[0], LlmResponse::Start { .. }));
        assert!(
            matches!(&responses[1], LlmResponse::ToolRequestStart { id, name } if id == "fc_1" && name == "read_file")
        );
        assert!(matches!(responses[2], LlmResponse::ToolRequestArg { .. }));
        assert!(matches!(responses[3], LlmResponse::ToolRequestArg { .. }));

        let tc = responses.iter().find(|r| matches!(r, LlmResponse::ToolRequestComplete { .. }));
        assert!(tc.is_some());
        if let LlmResponse::ToolRequestComplete { tool_call } = tc.unwrap() {
            assert_eq!(tool_call.id, "fc_1");
            assert_eq!(tool_call.name, "read_file");
            assert_eq!(tool_call.arguments, r#"{"path":"foo.rs"}"#);
        }
    }

    /// A `ResponseError` event is surfaced as an `Err` item after `Start`.
    #[tokio::test]
    async fn test_error_event() {
        let events = vec![ResponseStreamEvent::ResponseError(ResponseErrorEvent {
            sequence_number: 1,
            code: None,
            message: "Rate limit exceeded".to_string(),
            param: None,
        })];

        let stream = make_stream(events);
        let mut response_stream = Box::pin(process_response_stream(stream));

        let mut responses = Vec::new();
        while let Some(result) = response_stream.next().await {
            responses.push(result);
        }

        assert!(responses[0].is_ok()); // Start
        assert!(responses[1].is_err()); // Error
    }

    /// Reasoning-summary deltas are forwarded as `Reasoning` chunks in order.
    #[tokio::test]
    async fn test_reasoning_delta() {
        let events = vec![
            ResponseStreamEvent::ResponseReasoningSummaryTextDelta(ResponseReasoningSummaryTextDeltaEvent {
                sequence_number: 1,
                item_id: "r_1".to_string(),
                output_index: 0,
                summary_index: 0,
                delta: "Thinking about".to_string(),
            }),
            ResponseStreamEvent::ResponseReasoningSummaryTextDelta(ResponseReasoningSummaryTextDeltaEvent {
                sequence_number: 2,
                item_id: "r_1".to_string(),
                output_index: 0,
                summary_index: 0,
                delta: " the problem".to_string(),
            }),
            ResponseStreamEvent::ResponseCompleted(ResponseCompletedEvent {
                sequence_number: 3,
                response: make_response(&Status::Completed, None),
            }),
        ];

        let stream = make_stream(events);
        let mut response_stream = Box::pin(process_response_stream(stream));

        let mut responses = Vec::new();
        while let Some(result) = response_stream.next().await {
            responses.push(result.unwrap());
        }

        assert!(matches!(responses[1], LlmResponse::Reasoning { ref chunk } if chunk == "Thinking about"));
        assert!(matches!(responses[2], LlmResponse::Reasoning { ref chunk } if chunk == " the problem"));
    }

    /// `Status::Incomplete` maps to `StopReason::Length` on the final `Done`.
    #[tokio::test]
    async fn test_incomplete_status_gives_length_stop_reason() {
        let events = vec![ResponseStreamEvent::ResponseCompleted(ResponseCompletedEvent {
            sequence_number: 1,
            response: make_response(&Status::Incomplete, None),
        })];

        let stream = make_stream(events);
        let mut response_stream = Box::pin(process_response_stream(stream));

        let mut responses = Vec::new();
        while let Some(result) = response_stream.next().await {
            responses.push(result.unwrap());
        }

        assert!(matches!(responses.last().unwrap(), LlmResponse::Done { stop_reason: Some(StopReason::Length) }));
    }

    /// A transport-level `Err` from the underlying stream is propagated as-is.
    #[tokio::test]
    async fn test_stream_error_propagation() {
        let events: Vec<Result<ResponseStreamEvent>> = vec![Err(LlmError::ApiError("connection lost".to_string()))];

        let stream = tokio_stream::iter(events);
        let mut response_stream = Box::pin(process_response_stream(stream));

        let mut responses = Vec::new();
        while let Some(result) = response_stream.next().await {
            responses.push(result);
        }

        assert!(responses[0].is_ok()); // Start
        assert!(responses[1].is_err()); // Stream error
    }

    /// A done reasoning item with both an id and encrypted content yields
    /// exactly one `EncryptedReasoning` response.
    #[test]
    fn test_encrypted_reasoning_from_output_item_done() {
        let event = ResponseStreamEvent::ResponseOutputItemDone(ResponseOutputItemDoneEvent {
            sequence_number: 1,
            output_index: 0,
            item: OutputItem::Reasoning(ReasoningItem {
                id: Some("r_1".to_string()),
                summary: vec![],
                encrypted_content: Some("enc-blob-data".to_string()),
                content: None,
                status: None,
            }),
        });

        let mut tool_collector = ToolCallCollector::<u32>::new();
        let mut stop_reason = None;
        let responses = process_event(event, &mut tool_collector, &mut stop_reason);

        assert_eq!(responses.len(), 1);
        assert!(
            matches!(&responses[0], Ok(LlmResponse::EncryptedReasoning { content, .. }) if content == "enc-blob-data")
        );
    }

    /// Cached and reasoning token counts survive the `usage.into()` conversion.
    #[tokio::test]
    async fn test_usage_forwards_reasoning_and_cache_read() {
        let events = vec![ResponseStreamEvent::ResponseCompleted(ResponseCompletedEvent {
            sequence_number: 1,
            response: make_response(&Status::Completed, Some(make_usage_full(120, 80, 50, 30))),
        })];

        let stream = make_stream(events);
        let mut response_stream = Box::pin(process_response_stream(stream));

        let mut responses = Vec::new();
        while let Some(result) = response_stream.next().await {
            responses.push(result.unwrap());
        }

        let usage = responses.iter().find_map(|r| match r {
            LlmResponse::Usage { tokens } => Some(*tokens),
            _ => None,
        });

        assert_eq!(
            usage,
            Some(TokenUsage {
                input_tokens: 120,
                output_tokens: 80,
                cache_read_tokens: Some(50),
                reasoning_tokens: Some(30),
                ..TokenUsage::default()
            })
        );
    }

    /// A done reasoning item without encrypted content produces no responses.
    #[test]
    fn test_output_item_done_without_encrypted_content_is_ignored() {
        let event = ResponseStreamEvent::ResponseOutputItemDone(ResponseOutputItemDoneEvent {
            sequence_number: 1,
            output_index: 0,
            item: OutputItem::Reasoning(ReasoningItem {
                id: Some("r_2".to_string()),
                summary: vec![],
                encrypted_content: None,
                content: None,
                status: None,
            }),
        });

        let mut tool_collector = ToolCallCollector::<u32>::new();
        let mut stop_reason = None;
        let responses = process_event(event, &mut tool_collector, &mut stop_reason);

        assert!(responses.is_empty());
    }
}