//! Streaming support for the Codex provider (`llm/providers/codex/streaming.rs`).
//!
//! Converts a typed `ResponseStreamEvent` stream into `LlmResponse` items.
use async_openai::types::responses::{OutputItem, ResponseStreamEvent, Status};
use futures::Stream;
use tokio_stream::StreamExt;

use crate::providers::tool_call_collector::ToolCallCollector;
use crate::{LlmError, LlmResponse, Result, StopReason};

8/// Process a typed `ResponseStreamEvent` stream into `LlmResponse` items.
9pub fn process_response_stream<T>(stream: T) -> impl Stream<Item = Result<LlmResponse>> + Send
10where
11    T: Stream<Item = Result<ResponseStreamEvent>> + Send + Unpin,
12{
13    async_stream::stream! {
14        let message_id = uuid::Uuid::new_v4().to_string();
15        yield Ok(LlmResponse::Start { message_id });
16
17        let mut tool_collector = ToolCallCollector::<u32>::new();
18        let mut stream = Box::pin(stream);
19        let mut last_stop_reason: Option<StopReason> = None;
20
21        while let Some(result) = stream.next().await {
22            match result {
23                Ok(event) => {
24                    for response in process_event(event, &mut tool_collector, &mut last_stop_reason) {
25                        yield response;
26                    }
27                }
28                Err(e) => {
29                    yield Err(LlmError::ApiError(e.to_string()));
30                    break;
31                }
32            }
33        }
34
35        // Complete any pending tool calls
36        for tc in tool_collector.complete_all() {
37            yield Ok(LlmResponse::ToolRequestComplete { tool_call: tc });
38        }
39
40        yield Ok(LlmResponse::Done {
41            stop_reason: last_stop_reason,
42        });
43    }
44}
45
46fn process_event(
47    event: ResponseStreamEvent,
48    tool_collector: &mut ToolCallCollector<u32>,
49    last_stop_reason: &mut Option<StopReason>,
50) -> Vec<Result<LlmResponse>> {
51    let mut responses = Vec::new();
52
53    match event {
54        ResponseStreamEvent::ResponseOutputTextDelta(e) => {
55            if !e.delta.is_empty() {
56                responses.push(Ok(LlmResponse::Text { chunk: e.delta }));
57            }
58        }
59        ResponseStreamEvent::ResponseOutputItemAdded(e) => {
60            if let OutputItem::FunctionCall(call) = e.item {
61                let tool_responses = tool_collector.handle_delta(e.output_index, call.id, Some(call.name), None);
62                responses.extend(tool_responses.into_iter().map(Ok));
63            }
64        }
65        ResponseStreamEvent::ResponseFunctionCallArgumentsDelta(e) => {
66            let tool_responses = tool_collector.handle_delta(e.output_index, None, None, Some(e.delta));
67            responses.extend(tool_responses.into_iter().map(Ok));
68        }
69        ResponseStreamEvent::ResponseFunctionCallArgumentsDone(e) => {
70            if let Some(tc) = tool_collector.complete_one(e.output_index) {
71                responses.push(Ok(LlmResponse::ToolRequestComplete { tool_call: tc }));
72            }
73        }
74        ResponseStreamEvent::ResponseReasoningSummaryTextDelta(e) => {
75            if !e.delta.is_empty() {
76                responses.push(Ok(LlmResponse::Reasoning { chunk: e.delta }));
77            }
78        }
79        ResponseStreamEvent::ResponseOutputItemDone(e) => {
80            if let OutputItem::Reasoning(reasoning) = e.item
81                && let Some(encrypted) = reasoning.encrypted_content
82            {
83                responses.push(Ok(LlmResponse::EncryptedReasoning { id: reasoning.id, content: encrypted }));
84            }
85        }
86        ResponseStreamEvent::ResponseCompleted(e) => {
87            if let Some(usage) = e.response.usage {
88                responses.push(Ok(LlmResponse::Usage { tokens: usage.into() }));
89            }
90            match e.response.status {
91                Status::Completed => *last_stop_reason = Some(StopReason::EndTurn),
92                Status::Incomplete => *last_stop_reason = Some(StopReason::Length),
93                _ => {}
94            }
95        }
96        ResponseStreamEvent::ResponseError(e) => {
97            responses.push(Err(LlmError::ApiError(format!("Codex API error: {}", e.message))));
98        }
99        // Events we don't need to act on
100        _ => {}
101    }
102
103    responses
104}
105
106#[cfg(test)]
107mod tests {
108    use super::*;
109    use crate::TokenUsage;
110    use async_openai::types::responses::{
111        FunctionToolCall, ReasoningItem, Response, ResponseCompletedEvent, ResponseErrorEvent,
112        ResponseFunctionCallArgumentsDeltaEvent, ResponseFunctionCallArgumentsDoneEvent, ResponseOutputItemAddedEvent,
113        ResponseOutputItemDoneEvent, ResponseReasoningSummaryTextDeltaEvent, ResponseTextDeltaEvent, ResponseUsage,
114    };
115    /// Build a minimal `Response` with given status and optional usage via JSON deserialization.
116    fn make_response(status: &Status, usage: Option<ResponseUsage>) -> Response {
117        let status_str = serde_json::to_value(status).unwrap();
118        let mut json = serde_json::json!({
119            "id": "resp_1",
120            "object": "response",
121            "status": status_str,
122            "output": [],
123            "model": "test",
124            "created_at": 0
125        });
126        if let Some(u) = usage {
127            json["usage"] = serde_json::to_value(u).unwrap();
128        }
129        serde_json::from_value(json).unwrap()
130    }
131
132    fn make_usage(input_tokens: u32, output_tokens: u32) -> ResponseUsage {
133        make_usage_full(input_tokens, output_tokens, 0, 0)
134    }
135
136    fn make_usage_full(
137        input_tokens: u32,
138        output_tokens: u32,
139        cached_tokens: u32,
140        reasoning_tokens: u32,
141    ) -> ResponseUsage {
142        serde_json::from_value(serde_json::json!({
143            "input_tokens": input_tokens,
144            "input_tokens_details": { "cached_tokens": cached_tokens },
145            "output_tokens": output_tokens,
146            "output_tokens_details": { "reasoning_tokens": reasoning_tokens },
147            "total_tokens": input_tokens + output_tokens
148        }))
149        .unwrap()
150    }
151
152    fn make_stream(events: Vec<ResponseStreamEvent>) -> impl Stream<Item = Result<ResponseStreamEvent>> + Send + Unpin {
153        tokio_stream::iter(events.into_iter().map(Ok).collect::<Vec<_>>())
154    }
155
156    #[tokio::test]
157    async fn test_text_stream() {
158        let events = vec![
159            ResponseStreamEvent::ResponseOutputTextDelta(ResponseTextDeltaEvent {
160                output_index: 0,
161                content_index: 0,
162                delta: "Hello".to_string(),
163                sequence_number: 1,
164                item_id: "msg_1".to_string(),
165                logprobs: None,
166            }),
167            ResponseStreamEvent::ResponseOutputTextDelta(ResponseTextDeltaEvent {
168                output_index: 0,
169                content_index: 0,
170                delta: " world".to_string(),
171                sequence_number: 2,
172                item_id: "msg_1".to_string(),
173                logprobs: None,
174            }),
175            ResponseStreamEvent::ResponseCompleted(ResponseCompletedEvent {
176                sequence_number: 3,
177                response: make_response(&Status::Completed, Some(make_usage(10, 5))),
178            }),
179        ];
180
181        let stream = make_stream(events);
182        let mut response_stream = Box::pin(process_response_stream(stream));
183
184        let mut responses = Vec::new();
185        while let Some(result) = response_stream.next().await {
186            responses.push(result.unwrap());
187        }
188
189        assert!(matches!(responses[0], LlmResponse::Start { .. }));
190        assert!(matches!(responses[1], LlmResponse::Text { ref chunk } if chunk == "Hello"));
191        assert!(matches!(responses[2], LlmResponse::Text { ref chunk } if chunk == " world"));
192        assert!(matches!(
193            responses[3],
194            LlmResponse::Usage { tokens: TokenUsage { input_tokens: 10, output_tokens: 5, .. } }
195        ));
196        assert!(matches!(responses[4], LlmResponse::Done { stop_reason: Some(StopReason::EndTurn) }));
197    }
198
199    #[tokio::test]
200    async fn test_tool_call_stream() {
201        let events = vec![
202            ResponseStreamEvent::ResponseOutputItemAdded(ResponseOutputItemAddedEvent {
203                sequence_number: 1,
204                output_index: 0,
205                item: OutputItem::FunctionCall(FunctionToolCall {
206                    id: Some("fc_1".to_string()),
207                    call_id: "call_1".to_string(),
208                    name: "read_file".to_string(),
209                    arguments: String::new(),
210                    status: None,
211                    namespace: None,
212                }),
213            }),
214            ResponseStreamEvent::ResponseFunctionCallArgumentsDelta(ResponseFunctionCallArgumentsDeltaEvent {
215                sequence_number: 2,
216                item_id: "fc_1".to_string(),
217                output_index: 0,
218                delta: r#"{"path":"#.to_string(),
219            }),
220            ResponseStreamEvent::ResponseFunctionCallArgumentsDelta(ResponseFunctionCallArgumentsDeltaEvent {
221                sequence_number: 3,
222                item_id: "fc_1".to_string(),
223                output_index: 0,
224                delta: r#""foo.rs"}"#.to_string(),
225            }),
226            ResponseStreamEvent::ResponseFunctionCallArgumentsDone(ResponseFunctionCallArgumentsDoneEvent {
227                sequence_number: 4,
228                item_id: "fc_1".to_string(),
229                output_index: 0,
230                arguments: r#"{"path":"foo.rs"}"#.to_string(),
231                name: None,
232            }),
233            ResponseStreamEvent::ResponseCompleted(ResponseCompletedEvent {
234                sequence_number: 5,
235                response: make_response(&Status::Completed, Some(make_usage(20, 10))),
236            }),
237        ];
238
239        let stream = make_stream(events);
240        let mut response_stream = Box::pin(process_response_stream(stream));
241
242        let mut responses = Vec::new();
243        while let Some(result) = response_stream.next().await {
244            responses.push(result.unwrap());
245        }
246
247        assert!(matches!(responses[0], LlmResponse::Start { .. }));
248        assert!(
249            matches!(&responses[1], LlmResponse::ToolRequestStart { id, name } if id == "fc_1" && name == "read_file")
250        );
251        assert!(matches!(responses[2], LlmResponse::ToolRequestArg { .. }));
252        assert!(matches!(responses[3], LlmResponse::ToolRequestArg { .. }));
253
254        let tc = responses.iter().find(|r| matches!(r, LlmResponse::ToolRequestComplete { .. }));
255        assert!(tc.is_some());
256        if let LlmResponse::ToolRequestComplete { tool_call } = tc.unwrap() {
257            assert_eq!(tool_call.id, "fc_1");
258            assert_eq!(tool_call.name, "read_file");
259            assert_eq!(tool_call.arguments, r#"{"path":"foo.rs"}"#);
260        }
261    }
262
263    #[tokio::test]
264    async fn test_error_event() {
265        let events = vec![ResponseStreamEvent::ResponseError(ResponseErrorEvent {
266            sequence_number: 1,
267            code: None,
268            message: "Rate limit exceeded".to_string(),
269            param: None,
270        })];
271
272        let stream = make_stream(events);
273        let mut response_stream = Box::pin(process_response_stream(stream));
274
275        let mut responses = Vec::new();
276        while let Some(result) = response_stream.next().await {
277            responses.push(result);
278        }
279
280        assert!(responses[0].is_ok()); // Start
281        assert!(responses[1].is_err()); // Error
282    }
283
284    #[tokio::test]
285    async fn test_reasoning_delta() {
286        let events = vec![
287            ResponseStreamEvent::ResponseReasoningSummaryTextDelta(ResponseReasoningSummaryTextDeltaEvent {
288                sequence_number: 1,
289                item_id: "r_1".to_string(),
290                output_index: 0,
291                summary_index: 0,
292                delta: "Thinking about".to_string(),
293            }),
294            ResponseStreamEvent::ResponseReasoningSummaryTextDelta(ResponseReasoningSummaryTextDeltaEvent {
295                sequence_number: 2,
296                item_id: "r_1".to_string(),
297                output_index: 0,
298                summary_index: 0,
299                delta: " the problem".to_string(),
300            }),
301            ResponseStreamEvent::ResponseCompleted(ResponseCompletedEvent {
302                sequence_number: 3,
303                response: make_response(&Status::Completed, None),
304            }),
305        ];
306
307        let stream = make_stream(events);
308        let mut response_stream = Box::pin(process_response_stream(stream));
309
310        let mut responses = Vec::new();
311        while let Some(result) = response_stream.next().await {
312            responses.push(result.unwrap());
313        }
314
315        assert!(matches!(responses[1], LlmResponse::Reasoning { ref chunk } if chunk == "Thinking about"));
316        assert!(matches!(responses[2], LlmResponse::Reasoning { ref chunk } if chunk == " the problem"));
317    }
318
319    #[tokio::test]
320    async fn test_incomplete_status_gives_length_stop_reason() {
321        let events = vec![ResponseStreamEvent::ResponseCompleted(ResponseCompletedEvent {
322            sequence_number: 1,
323            response: make_response(&Status::Incomplete, None),
324        })];
325
326        let stream = make_stream(events);
327        let mut response_stream = Box::pin(process_response_stream(stream));
328
329        let mut responses = Vec::new();
330        while let Some(result) = response_stream.next().await {
331            responses.push(result.unwrap());
332        }
333
334        assert!(matches!(responses.last().unwrap(), LlmResponse::Done { stop_reason: Some(StopReason::Length) }));
335    }
336
337    #[tokio::test]
338    async fn test_stream_error_propagation() {
339        let events: Vec<Result<ResponseStreamEvent>> = vec![Err(LlmError::ApiError("connection lost".to_string()))];
340
341        let stream = tokio_stream::iter(events);
342        let mut response_stream = Box::pin(process_response_stream(stream));
343
344        let mut responses = Vec::new();
345        while let Some(result) = response_stream.next().await {
346            responses.push(result);
347        }
348
349        assert!(responses[0].is_ok()); // Start
350        assert!(responses[1].is_err()); // Stream error
351    }
352
353    #[test]
354    fn test_encrypted_reasoning_from_output_item_done() {
355        let event = ResponseStreamEvent::ResponseOutputItemDone(ResponseOutputItemDoneEvent {
356            sequence_number: 1,
357            output_index: 0,
358            item: OutputItem::Reasoning(ReasoningItem {
359                id: "r_1".to_string(),
360                summary: vec![],
361                encrypted_content: Some("enc-blob-data".to_string()),
362                content: None,
363                status: None,
364            }),
365        });
366
367        let mut tool_collector = ToolCallCollector::<u32>::new();
368        let mut stop_reason = None;
369        let responses = process_event(event, &mut tool_collector, &mut stop_reason);
370
371        assert_eq!(responses.len(), 1);
372        assert!(
373            matches!(&responses[0], Ok(LlmResponse::EncryptedReasoning { content, .. }) if content == "enc-blob-data")
374        );
375    }
376
377    #[tokio::test]
378    async fn test_usage_forwards_reasoning_and_cache_read() {
379        let events = vec![ResponseStreamEvent::ResponseCompleted(ResponseCompletedEvent {
380            sequence_number: 1,
381            response: make_response(&Status::Completed, Some(make_usage_full(120, 80, 50, 30))),
382        })];
383
384        let stream = make_stream(events);
385        let mut response_stream = Box::pin(process_response_stream(stream));
386
387        let mut responses = Vec::new();
388        while let Some(result) = response_stream.next().await {
389            responses.push(result.unwrap());
390        }
391
392        let usage = responses.iter().find_map(|r| match r {
393            LlmResponse::Usage { tokens } => Some(*tokens),
394            _ => None,
395        });
396
397        assert_eq!(
398            usage,
399            Some(TokenUsage {
400                input_tokens: 120,
401                output_tokens: 80,
402                cache_read_tokens: Some(50),
403                reasoning_tokens: Some(30),
404                ..TokenUsage::default()
405            })
406        );
407    }
408
409    #[test]
410    fn test_output_item_done_without_encrypted_content_is_ignored() {
411        let event = ResponseStreamEvent::ResponseOutputItemDone(ResponseOutputItemDoneEvent {
412            sequence_number: 1,
413            output_index: 0,
414            item: OutputItem::Reasoning(ReasoningItem {
415                id: "r_2".to_string(),
416                summary: vec![],
417                encrypted_content: None,
418                content: None,
419                status: None,
420            }),
421        });
422
423        let mut tool_collector = ToolCallCollector::<u32>::new();
424        let mut stop_reason = None;
425        let responses = process_event(event, &mut tool_collector, &mut stop_reason);
426
427        assert!(responses.is_empty());
428    }
429}