Skip to main content

llm/providers/codex/
streaming.rs

1use async_openai::types::responses::{OutputItem, ResponseStreamEvent, Status};
2
3use crate::providers::tool_call_collector::ToolCallCollector;
4use crate::{LlmError, LlmResponse, Result, StopReason};
5use futures::Stream;
6use tokio_stream::StreamExt;
7
8/// Process a typed `ResponseStreamEvent` stream into `LlmResponse` items.
9pub fn process_response_stream<T>(stream: T) -> impl Stream<Item = Result<LlmResponse>> + Send
10where
11    T: Stream<Item = Result<ResponseStreamEvent>> + Send + Unpin,
12{
13    async_stream::stream! {
14        let message_id = uuid::Uuid::new_v4().to_string();
15        yield Ok(LlmResponse::Start { message_id });
16
17        let mut tool_collector = ToolCallCollector::<u32>::new();
18        let mut stream = Box::pin(stream);
19        let mut last_stop_reason: Option<StopReason> = None;
20
21        while let Some(result) = stream.next().await {
22            match result {
23                Ok(event) => {
24                    for response in process_event(event, &mut tool_collector, &mut last_stop_reason) {
25                        yield response;
26                    }
27                }
28                Err(e) => {
29                    yield Err(LlmError::StreamInterrupted(e.to_string()));
30                    break;
31                }
32            }
33        }
34
35        // Complete any pending tool calls
36        for tc in tool_collector.complete_all() {
37            yield Ok(LlmResponse::ToolRequestComplete { tool_call: tc });
38        }
39
40        yield Ok(LlmResponse::Done {
41            stop_reason: last_stop_reason,
42        });
43    }
44}
45
46fn process_event(
47    event: ResponseStreamEvent,
48    tool_collector: &mut ToolCallCollector<u32>,
49    last_stop_reason: &mut Option<StopReason>,
50) -> Vec<Result<LlmResponse>> {
51    let mut responses = Vec::new();
52
53    match event {
54        ResponseStreamEvent::ResponseOutputTextDelta(e) if !e.delta.is_empty() => {
55            responses.push(Ok(LlmResponse::Text { chunk: e.delta }));
56        }
57        ResponseStreamEvent::ResponseOutputItemAdded(e) => {
58            if let OutputItem::FunctionCall(call) = e.item {
59                let tool_responses = tool_collector.handle_delta(e.output_index, call.id, Some(call.name), None);
60                responses.extend(tool_responses.into_iter().map(Ok));
61            }
62        }
63        ResponseStreamEvent::ResponseFunctionCallArgumentsDelta(e) => {
64            let tool_responses = tool_collector.handle_delta(e.output_index, None, None, Some(e.delta));
65            responses.extend(tool_responses.into_iter().map(Ok));
66        }
67        ResponseStreamEvent::ResponseFunctionCallArgumentsDone(e) => {
68            if let Some(tc) = tool_collector.complete_one(e.output_index) {
69                responses.push(Ok(LlmResponse::ToolRequestComplete { tool_call: tc }));
70            }
71        }
72        ResponseStreamEvent::ResponseReasoningSummaryTextDelta(e) if !e.delta.is_empty() => {
73            responses.push(Ok(LlmResponse::Reasoning { chunk: e.delta }));
74        }
75        ResponseStreamEvent::ResponseOutputItemDone(e) => {
76            if let OutputItem::Reasoning(reasoning) = e.item
77                && let Some(id) = reasoning.id
78                && let Some(encrypted) = reasoning.encrypted_content
79            {
80                responses.push(Ok(LlmResponse::EncryptedReasoning { id, content: encrypted }));
81            }
82        }
83        ResponseStreamEvent::ResponseCompleted(e) => {
84            if let Some(usage) = e.response.usage {
85                responses.push(Ok(LlmResponse::Usage { tokens: usage.into() }));
86            }
87            match e.response.status {
88                Status::Completed => *last_stop_reason = Some(StopReason::EndTurn),
89                Status::Incomplete => *last_stop_reason = Some(StopReason::Length),
90                _ => {}
91            }
92        }
93        ResponseStreamEvent::ResponseError(e) => {
94            responses
95                .push(Err(LlmError::ServerError { status: None, message: format!("Codex API error: {}", e.message) }));
96        }
97        // Events we don't need to act on
98        _ => {}
99    }
100
101    responses
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TokenUsage;
    use async_openai::types::responses::{
        FunctionToolCall, ReasoningItem, Response, ResponseCompletedEvent, ResponseErrorEvent,
        ResponseFunctionCallArgumentsDeltaEvent, ResponseFunctionCallArgumentsDoneEvent, ResponseOutputItemAddedEvent,
        ResponseOutputItemDoneEvent, ResponseReasoningSummaryTextDeltaEvent, ResponseTextDeltaEvent, ResponseUsage,
    };

    // ---- fixtures ---------------------------------------------------------------------------

    /// Deserialize a minimal `Response` with the given `status` and optional `usage`.
    ///
    /// Going through JSON avoids spelling out every field of the upstream struct.
    fn make_response(status: &Status, usage: Option<ResponseUsage>) -> Response {
        let mut body = serde_json::json!({
            "id": "resp_1",
            "object": "response",
            "status": serde_json::to_value(status).unwrap(),
            "output": [],
            "model": "test",
            "created_at": 0
        });
        if let Some(usage) = usage {
            body["usage"] = serde_json::to_value(usage).unwrap();
        }
        serde_json::from_value(body).unwrap()
    }

    /// Usage fixture with zero cached and zero reasoning tokens.
    fn make_usage(input_tokens: u32, output_tokens: u32) -> ResponseUsage {
        make_usage_full(input_tokens, output_tokens, 0, 0)
    }

    /// Usage fixture with every detail field populated; `total_tokens` is input + output.
    fn make_usage_full(
        input_tokens: u32,
        output_tokens: u32,
        cached_tokens: u32,
        reasoning_tokens: u32,
    ) -> ResponseUsage {
        let total_tokens = input_tokens + output_tokens;
        let body = serde_json::json!({
            "input_tokens": input_tokens,
            "input_tokens_details": { "cached_tokens": cached_tokens },
            "output_tokens": output_tokens,
            "output_tokens_details": { "reasoning_tokens": reasoning_tokens },
            "total_tokens": total_tokens
        });
        serde_json::from_value(body).unwrap()
    }

    /// Wrap each event in `Ok` and expose the list as a stream.
    fn make_stream(events: Vec<ResponseStreamEvent>) -> impl Stream<Item = Result<ResponseStreamEvent>> + Send + Unpin {
        let items: Vec<Result<ResponseStreamEvent>> = events.into_iter().map(Ok).collect();
        tokio_stream::iter(items)
    }

    /// Run `process_response_stream` to exhaustion, keeping every item (errors included).
    async fn drain<S>(input: S) -> Vec<Result<LlmResponse>>
    where
        S: Stream<Item = Result<ResponseStreamEvent>> + Send + Unpin,
    {
        let mut output = Box::pin(process_response_stream(input));
        let mut collected = Vec::new();
        while let Some(item) = output.next().await {
            collected.push(item);
        }
        collected
    }

    /// Drain a stream built from `events`, panicking if any produced item is an `Err`.
    async fn drain_ok(events: Vec<ResponseStreamEvent>) -> Vec<LlmResponse> {
        drain(make_stream(events)).await.into_iter().map(|item| item.unwrap()).collect()
    }

    /// Feed one event through `process_event` using a fresh collector and stop reason.
    fn run_single(event: ResponseStreamEvent) -> Vec<Result<LlmResponse>> {
        let mut tool_collector = ToolCallCollector::<u32>::new();
        let mut stop_reason = None;
        process_event(event, &mut tool_collector, &mut stop_reason)
    }

    // ---- stream-level tests -----------------------------------------------------------------

    #[tokio::test]
    async fn test_text_stream() {
        let responses = drain_ok(vec![
            ResponseStreamEvent::ResponseOutputTextDelta(ResponseTextDeltaEvent {
                output_index: 0,
                content_index: 0,
                delta: "Hello".to_string(),
                sequence_number: 1,
                item_id: "msg_1".to_string(),
                logprobs: None,
            }),
            ResponseStreamEvent::ResponseOutputTextDelta(ResponseTextDeltaEvent {
                output_index: 0,
                content_index: 0,
                delta: " world".to_string(),
                sequence_number: 2,
                item_id: "msg_1".to_string(),
                logprobs: None,
            }),
            ResponseStreamEvent::ResponseCompleted(ResponseCompletedEvent {
                sequence_number: 3,
                response: make_response(&Status::Completed, Some(make_usage(10, 5))),
            }),
        ])
        .await;

        assert!(matches!(responses[0], LlmResponse::Start { .. }));
        assert!(matches!(responses[1], LlmResponse::Text { ref chunk } if chunk == "Hello"));
        assert!(matches!(responses[2], LlmResponse::Text { ref chunk } if chunk == " world"));
        assert!(matches!(
            responses[3],
            LlmResponse::Usage { tokens: TokenUsage { input_tokens: 10, output_tokens: 5, .. } }
        ));
        assert!(matches!(responses[4], LlmResponse::Done { stop_reason: Some(StopReason::EndTurn) }));
    }

    #[tokio::test]
    async fn test_tool_call_stream() {
        let responses = drain_ok(vec![
            ResponseStreamEvent::ResponseOutputItemAdded(ResponseOutputItemAddedEvent {
                sequence_number: 1,
                output_index: 0,
                item: OutputItem::FunctionCall(FunctionToolCall {
                    id: Some("fc_1".to_string()),
                    call_id: "call_1".to_string(),
                    name: "read_file".to_string(),
                    arguments: String::new(),
                    status: None,
                    namespace: None,
                }),
            }),
            ResponseStreamEvent::ResponseFunctionCallArgumentsDelta(ResponseFunctionCallArgumentsDeltaEvent {
                sequence_number: 2,
                item_id: "fc_1".to_string(),
                output_index: 0,
                delta: r#"{"path":"#.to_string(),
            }),
            ResponseStreamEvent::ResponseFunctionCallArgumentsDelta(ResponseFunctionCallArgumentsDeltaEvent {
                sequence_number: 3,
                item_id: "fc_1".to_string(),
                output_index: 0,
                delta: r#""foo.rs"}"#.to_string(),
            }),
            ResponseStreamEvent::ResponseFunctionCallArgumentsDone(ResponseFunctionCallArgumentsDoneEvent {
                sequence_number: 4,
                item_id: "fc_1".to_string(),
                output_index: 0,
                arguments: r#"{"path":"foo.rs"}"#.to_string(),
                name: None,
            }),
            ResponseStreamEvent::ResponseCompleted(ResponseCompletedEvent {
                sequence_number: 5,
                response: make_response(&Status::Completed, Some(make_usage(20, 10))),
            }),
        ])
        .await;

        assert!(matches!(responses[0], LlmResponse::Start { .. }));
        assert!(
            matches!(&responses[1], LlmResponse::ToolRequestStart { id, name } if id == "fc_1" && name == "read_file")
        );
        assert!(matches!(responses[2], LlmResponse::ToolRequestArg { .. }));
        assert!(matches!(responses[3], LlmResponse::ToolRequestArg { .. }));

        let Some(LlmResponse::ToolRequestComplete { tool_call }) =
            responses.iter().find(|r| matches!(r, LlmResponse::ToolRequestComplete { .. }))
        else {
            panic!("expected a ToolRequestComplete item");
        };
        assert_eq!(tool_call.id, "fc_1");
        assert_eq!(tool_call.name, "read_file");
        assert_eq!(tool_call.arguments, r#"{"path":"foo.rs"}"#);
    }

    #[tokio::test]
    async fn test_error_event_is_retryable_server_error() {
        let responses = drain(make_stream(vec![ResponseStreamEvent::ResponseError(ResponseErrorEvent {
            sequence_number: 1,
            code: None,
            message: "Rate limit exceeded".to_string(),
            param: None,
        })]))
        .await;

        assert!(responses[0].is_ok());
        let err = responses[1].as_ref().expect_err("expected ResponseError to surface as Err");
        assert!(matches!(err, LlmError::ServerError { status: None, .. }), "got {err:?}");
        assert!(err.is_retryable(), "ResponseError must be retryable so the agent can recover");
    }

    #[tokio::test]
    async fn test_reasoning_delta() {
        let responses = drain_ok(vec![
            ResponseStreamEvent::ResponseReasoningSummaryTextDelta(ResponseReasoningSummaryTextDeltaEvent {
                sequence_number: 1,
                item_id: "r_1".to_string(),
                output_index: 0,
                summary_index: 0,
                delta: "Thinking about".to_string(),
            }),
            ResponseStreamEvent::ResponseReasoningSummaryTextDelta(ResponseReasoningSummaryTextDeltaEvent {
                sequence_number: 2,
                item_id: "r_1".to_string(),
                output_index: 0,
                summary_index: 0,
                delta: " the problem".to_string(),
            }),
            ResponseStreamEvent::ResponseCompleted(ResponseCompletedEvent {
                sequence_number: 3,
                response: make_response(&Status::Completed, None),
            }),
        ])
        .await;

        assert!(matches!(responses[1], LlmResponse::Reasoning { ref chunk } if chunk == "Thinking about"));
        assert!(matches!(responses[2], LlmResponse::Reasoning { ref chunk } if chunk == " the problem"));
    }

    #[tokio::test]
    async fn test_incomplete_status_gives_length_stop_reason() {
        let responses = drain_ok(vec![ResponseStreamEvent::ResponseCompleted(ResponseCompletedEvent {
            sequence_number: 1,
            response: make_response(&Status::Incomplete, None),
        })])
        .await;

        assert!(matches!(responses.last().unwrap(), LlmResponse::Done { stop_reason: Some(StopReason::Length) }));
    }

    #[tokio::test]
    async fn test_stream_error_propagation_is_retryable() {
        let upstream: Vec<Result<ResponseStreamEvent>> =
            vec![Err(LlmError::StreamInterrupted("connection lost".to_string()))];

        let responses = drain(tokio_stream::iter(upstream)).await;

        assert!(responses[0].is_ok());
        let err = responses[1].as_ref().expect_err("expected upstream Err to surface as Err");
        assert!(matches!(err, LlmError::StreamInterrupted(_)), "got {err:?}");
        assert!(err.is_retryable(), "mid-stream interrupts must be retryable");
    }

    #[tokio::test]
    async fn test_usage_forwards_reasoning_and_cache_read() {
        let responses = drain_ok(vec![ResponseStreamEvent::ResponseCompleted(ResponseCompletedEvent {
            sequence_number: 1,
            response: make_response(&Status::Completed, Some(make_usage_full(120, 80, 50, 30))),
        })])
        .await;

        let usage = responses.iter().find_map(|r| match r {
            LlmResponse::Usage { tokens } => Some(*tokens),
            _ => None,
        });

        assert_eq!(
            usage,
            Some(TokenUsage {
                input_tokens: 120,
                output_tokens: 80,
                cache_read_tokens: Some(50),
                reasoning_tokens: Some(30),
                ..TokenUsage::default()
            })
        );
    }

    // ---- single-event tests -----------------------------------------------------------------

    #[test]
    fn test_encrypted_reasoning_from_output_item_done() {
        let responses = run_single(ResponseStreamEvent::ResponseOutputItemDone(ResponseOutputItemDoneEvent {
            sequence_number: 1,
            output_index: 0,
            item: OutputItem::Reasoning(ReasoningItem {
                id: Some("r_1".to_string()),
                summary: vec![],
                encrypted_content: Some("enc-blob-data".to_string()),
                content: None,
                status: None,
            }),
        }));

        assert_eq!(responses.len(), 1);
        assert!(
            matches!(&responses[0], Ok(LlmResponse::EncryptedReasoning { content, .. }) if content == "enc-blob-data")
        );
    }

    #[test]
    fn test_output_item_done_without_encrypted_content_is_ignored() {
        let responses = run_single(ResponseStreamEvent::ResponseOutputItemDone(ResponseOutputItemDoneEvent {
            sequence_number: 1,
            output_index: 0,
            item: OutputItem::Reasoning(ReasoningItem {
                id: Some("r_2".to_string()),
                summary: vec![],
                encrypted_content: None,
                content: None,
                status: None,
            }),
        }));

        assert!(responses.is_empty());
    }
}
432}