//! llm/providers/codex/streaming.rs
//!
//! Streaming adapter for the Codex provider: converts a typed
//! `ResponseStreamEvent` stream into `LlmResponse` items.
1use async_openai::types::responses::{OutputItem, ResponseStreamEvent, Status};
2
3use crate::providers::tool_call_collector::ToolCallCollector;
4use crate::{LlmError, LlmResponse, Result, StopReason};
5use futures::Stream;
6use tokio_stream::StreamExt;
7
8/// Process a typed `ResponseStreamEvent` stream into `LlmResponse` items.
9pub fn process_response_stream<T>(stream: T) -> impl Stream<Item = Result<LlmResponse>> + Send
10where
11    T: Stream<Item = Result<ResponseStreamEvent>> + Send + Unpin,
12{
13    async_stream::stream! {
14        let message_id = uuid::Uuid::new_v4().to_string();
15        yield Ok(LlmResponse::Start { message_id });
16
17        let mut tool_collector = ToolCallCollector::<u32>::new();
18        let mut stream = Box::pin(stream);
19        let mut last_stop_reason: Option<StopReason> = None;
20
21        while let Some(result) = stream.next().await {
22            match result {
23                Ok(event) => {
24                    for response in process_event(event, &mut tool_collector, &mut last_stop_reason) {
25                        yield response;
26                    }
27                }
28                Err(e) => {
29                    yield Err(LlmError::ApiError(e.to_string()));
30                    break;
31                }
32            }
33        }
34
35        // Complete any pending tool calls
36        for tc in tool_collector.complete_all() {
37            yield Ok(LlmResponse::ToolRequestComplete { tool_call: tc });
38        }
39
40        yield Ok(LlmResponse::Done {
41            stop_reason: last_stop_reason,
42        });
43    }
44}
45
46fn process_event(
47    event: ResponseStreamEvent,
48    tool_collector: &mut ToolCallCollector<u32>,
49    last_stop_reason: &mut Option<StopReason>,
50) -> Vec<Result<LlmResponse>> {
51    let mut responses = Vec::new();
52
53    match event {
54        ResponseStreamEvent::ResponseOutputTextDelta(e) if !e.delta.is_empty() => {
55            responses.push(Ok(LlmResponse::Text { chunk: e.delta }));
56        }
57        ResponseStreamEvent::ResponseOutputItemAdded(e) => {
58            if let OutputItem::FunctionCall(call) = e.item {
59                let tool_responses = tool_collector.handle_delta(e.output_index, call.id, Some(call.name), None);
60                responses.extend(tool_responses.into_iter().map(Ok));
61            }
62        }
63        ResponseStreamEvent::ResponseFunctionCallArgumentsDelta(e) => {
64            let tool_responses = tool_collector.handle_delta(e.output_index, None, None, Some(e.delta));
65            responses.extend(tool_responses.into_iter().map(Ok));
66        }
67        ResponseStreamEvent::ResponseFunctionCallArgumentsDone(e) => {
68            if let Some(tc) = tool_collector.complete_one(e.output_index) {
69                responses.push(Ok(LlmResponse::ToolRequestComplete { tool_call: tc }));
70            }
71        }
72        ResponseStreamEvent::ResponseReasoningSummaryTextDelta(e) if !e.delta.is_empty() => {
73            responses.push(Ok(LlmResponse::Reasoning { chunk: e.delta }));
74        }
75        ResponseStreamEvent::ResponseOutputItemDone(e) => {
76            if let OutputItem::Reasoning(reasoning) = e.item
77                && let Some(encrypted) = reasoning.encrypted_content
78            {
79                responses.push(Ok(LlmResponse::EncryptedReasoning { id: reasoning.id, content: encrypted }));
80            }
81        }
82        ResponseStreamEvent::ResponseCompleted(e) => {
83            if let Some(usage) = e.response.usage {
84                responses.push(Ok(LlmResponse::Usage { tokens: usage.into() }));
85            }
86            match e.response.status {
87                Status::Completed => *last_stop_reason = Some(StopReason::EndTurn),
88                Status::Incomplete => *last_stop_reason = Some(StopReason::Length),
89                _ => {}
90            }
91        }
92        ResponseStreamEvent::ResponseError(e) => {
93            responses.push(Err(LlmError::ApiError(format!("Codex API error: {}", e.message))));
94        }
95        // Events we don't need to act on
96        _ => {}
97    }
98
99    responses
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TokenUsage;
    use async_openai::types::responses::{
        FunctionToolCall, ReasoningItem, Response, ResponseCompletedEvent, ResponseErrorEvent,
        ResponseFunctionCallArgumentsDeltaEvent, ResponseFunctionCallArgumentsDoneEvent, ResponseOutputItemAddedEvent,
        ResponseOutputItemDoneEvent, ResponseReasoningSummaryTextDeltaEvent, ResponseTextDeltaEvent, ResponseUsage,
    };

    /// Build a minimal `Response` with given status and optional usage via JSON deserialization.
    fn make_response(status: &Status, usage: Option<ResponseUsage>) -> Response {
        let status_str = serde_json::to_value(status).unwrap();
        let mut json = serde_json::json!({
            "id": "resp_1",
            "object": "response",
            "status": status_str,
            "output": [],
            "model": "test",
            "created_at": 0
        });
        if let Some(u) = usage {
            json["usage"] = serde_json::to_value(u).unwrap();
        }
        serde_json::from_value(json).unwrap()
    }

    /// Usage with no cached or reasoning token details.
    fn make_usage(input_tokens: u32, output_tokens: u32) -> ResponseUsage {
        make_usage_full(input_tokens, output_tokens, 0, 0)
    }

    /// Usage including cached-input and reasoning-output token details.
    fn make_usage_full(
        input_tokens: u32,
        output_tokens: u32,
        cached_tokens: u32,
        reasoning_tokens: u32,
    ) -> ResponseUsage {
        serde_json::from_value(serde_json::json!({
            "input_tokens": input_tokens,
            "input_tokens_details": { "cached_tokens": cached_tokens },
            "output_tokens": output_tokens,
            "output_tokens_details": { "reasoning_tokens": reasoning_tokens },
            "total_tokens": input_tokens + output_tokens
        }))
        .unwrap()
    }

    /// Wrap a list of events in an all-`Ok` input stream.
    fn make_stream(events: Vec<ResponseStreamEvent>) -> impl Stream<Item = Result<ResponseStreamEvent>> + Send + Unpin {
        tokio_stream::iter(events.into_iter().map(Ok).collect::<Vec<_>>())
    }

    /// Drive `process_response_stream` to completion and collect every yielded item.
    async fn drain(
        stream: impl Stream<Item = Result<ResponseStreamEvent>> + Send + Unpin,
    ) -> Vec<Result<LlmResponse>> {
        let mut response_stream = Box::pin(process_response_stream(stream));
        let mut responses = Vec::new();
        while let Some(result) = response_stream.next().await {
            responses.push(result);
        }
        responses
    }

    /// Like `drain`, but unwrap every item; panics if the stream yields an error.
    async fn drain_ok(events: Vec<ResponseStreamEvent>) -> Vec<LlmResponse> {
        drain(make_stream(events)).await.into_iter().map(|r| r.unwrap()).collect()
    }

    #[tokio::test]
    async fn test_text_stream() {
        let events = vec![
            ResponseStreamEvent::ResponseOutputTextDelta(ResponseTextDeltaEvent {
                output_index: 0,
                content_index: 0,
                delta: "Hello".to_string(),
                sequence_number: 1,
                item_id: "msg_1".to_string(),
                logprobs: None,
            }),
            ResponseStreamEvent::ResponseOutputTextDelta(ResponseTextDeltaEvent {
                output_index: 0,
                content_index: 0,
                delta: " world".to_string(),
                sequence_number: 2,
                item_id: "msg_1".to_string(),
                logprobs: None,
            }),
            ResponseStreamEvent::ResponseCompleted(ResponseCompletedEvent {
                sequence_number: 3,
                response: make_response(&Status::Completed, Some(make_usage(10, 5))),
            }),
        ];

        let responses = drain_ok(events).await;

        assert!(matches!(responses[0], LlmResponse::Start { .. }));
        assert!(matches!(responses[1], LlmResponse::Text { ref chunk } if chunk == "Hello"));
        assert!(matches!(responses[2], LlmResponse::Text { ref chunk } if chunk == " world"));
        assert!(matches!(
            responses[3],
            LlmResponse::Usage { tokens: TokenUsage { input_tokens: 10, output_tokens: 5, .. } }
        ));
        assert!(matches!(responses[4], LlmResponse::Done { stop_reason: Some(StopReason::EndTurn) }));
    }

    #[tokio::test]
    async fn test_tool_call_stream() {
        let events = vec![
            ResponseStreamEvent::ResponseOutputItemAdded(ResponseOutputItemAddedEvent {
                sequence_number: 1,
                output_index: 0,
                item: OutputItem::FunctionCall(FunctionToolCall {
                    id: Some("fc_1".to_string()),
                    call_id: "call_1".to_string(),
                    name: "read_file".to_string(),
                    arguments: String::new(),
                    status: None,
                    namespace: None,
                }),
            }),
            ResponseStreamEvent::ResponseFunctionCallArgumentsDelta(ResponseFunctionCallArgumentsDeltaEvent {
                sequence_number: 2,
                item_id: "fc_1".to_string(),
                output_index: 0,
                delta: r#"{"path":"#.to_string(),
            }),
            ResponseStreamEvent::ResponseFunctionCallArgumentsDelta(ResponseFunctionCallArgumentsDeltaEvent {
                sequence_number: 3,
                item_id: "fc_1".to_string(),
                output_index: 0,
                delta: r#""foo.rs"}"#.to_string(),
            }),
            ResponseStreamEvent::ResponseFunctionCallArgumentsDone(ResponseFunctionCallArgumentsDoneEvent {
                sequence_number: 4,
                item_id: "fc_1".to_string(),
                output_index: 0,
                arguments: r#"{"path":"foo.rs"}"#.to_string(),
                name: None,
            }),
            ResponseStreamEvent::ResponseCompleted(ResponseCompletedEvent {
                sequence_number: 5,
                response: make_response(&Status::Completed, Some(make_usage(20, 10))),
            }),
        ];

        let responses = drain_ok(events).await;

        assert!(matches!(responses[0], LlmResponse::Start { .. }));
        assert!(
            matches!(&responses[1], LlmResponse::ToolRequestStart { id, name } if id == "fc_1" && name == "read_file")
        );
        assert!(matches!(responses[2], LlmResponse::ToolRequestArg { .. }));
        assert!(matches!(responses[3], LlmResponse::ToolRequestArg { .. }));

        let tc = responses.iter().find(|r| matches!(r, LlmResponse::ToolRequestComplete { .. }));
        assert!(tc.is_some());
        if let LlmResponse::ToolRequestComplete { tool_call } = tc.unwrap() {
            assert_eq!(tool_call.id, "fc_1");
            assert_eq!(tool_call.name, "read_file");
            assert_eq!(tool_call.arguments, r#"{"path":"foo.rs"}"#);
        }
    }

    #[tokio::test]
    async fn test_error_event() {
        let events = vec![ResponseStreamEvent::ResponseError(ResponseErrorEvent {
            sequence_number: 1,
            code: None,
            message: "Rate limit exceeded".to_string(),
            param: None,
        })];

        let responses = drain(make_stream(events)).await;

        assert!(responses[0].is_ok()); // Start
        assert!(responses[1].is_err()); // Error
    }

    #[tokio::test]
    async fn test_reasoning_delta() {
        let events = vec![
            ResponseStreamEvent::ResponseReasoningSummaryTextDelta(ResponseReasoningSummaryTextDeltaEvent {
                sequence_number: 1,
                item_id: "r_1".to_string(),
                output_index: 0,
                summary_index: 0,
                delta: "Thinking about".to_string(),
            }),
            ResponseStreamEvent::ResponseReasoningSummaryTextDelta(ResponseReasoningSummaryTextDeltaEvent {
                sequence_number: 2,
                item_id: "r_1".to_string(),
                output_index: 0,
                summary_index: 0,
                delta: " the problem".to_string(),
            }),
            ResponseStreamEvent::ResponseCompleted(ResponseCompletedEvent {
                sequence_number: 3,
                response: make_response(&Status::Completed, None),
            }),
        ];

        let responses = drain_ok(events).await;

        assert!(matches!(responses[1], LlmResponse::Reasoning { ref chunk } if chunk == "Thinking about"));
        assert!(matches!(responses[2], LlmResponse::Reasoning { ref chunk } if chunk == " the problem"));
    }

    #[tokio::test]
    async fn test_incomplete_status_gives_length_stop_reason() {
        let events = vec![ResponseStreamEvent::ResponseCompleted(ResponseCompletedEvent {
            sequence_number: 1,
            response: make_response(&Status::Incomplete, None),
        })];

        let responses = drain_ok(events).await;

        assert!(matches!(responses.last().unwrap(), LlmResponse::Done { stop_reason: Some(StopReason::Length) }));
    }

    #[tokio::test]
    async fn test_stream_error_propagation() {
        // An underlying transport error (not an API error event) must surface as Err.
        let events: Vec<Result<ResponseStreamEvent>> = vec![Err(LlmError::ApiError("connection lost".to_string()))];

        let responses = drain(tokio_stream::iter(events)).await;

        assert!(responses[0].is_ok()); // Start
        assert!(responses[1].is_err()); // Stream error
    }

    #[test]
    fn test_encrypted_reasoning_from_output_item_done() {
        let event = ResponseStreamEvent::ResponseOutputItemDone(ResponseOutputItemDoneEvent {
            sequence_number: 1,
            output_index: 0,
            item: OutputItem::Reasoning(ReasoningItem {
                id: "r_1".to_string(),
                summary: vec![],
                encrypted_content: Some("enc-blob-data".to_string()),
                content: None,
                status: None,
            }),
        });

        let mut tool_collector = ToolCallCollector::<u32>::new();
        let mut stop_reason = None;
        let responses = process_event(event, &mut tool_collector, &mut stop_reason);

        assert_eq!(responses.len(), 1);
        assert!(
            matches!(&responses[0], Ok(LlmResponse::EncryptedReasoning { content, .. }) if content == "enc-blob-data")
        );
    }

    #[tokio::test]
    async fn test_usage_forwards_reasoning_and_cache_read() {
        let events = vec![ResponseStreamEvent::ResponseCompleted(ResponseCompletedEvent {
            sequence_number: 1,
            response: make_response(&Status::Completed, Some(make_usage_full(120, 80, 50, 30))),
        })];

        let responses = drain_ok(events).await;

        let usage = responses.iter().find_map(|r| match r {
            LlmResponse::Usage { tokens } => Some(*tokens),
            _ => None,
        });

        assert_eq!(
            usage,
            Some(TokenUsage {
                input_tokens: 120,
                output_tokens: 80,
                cache_read_tokens: Some(50),
                reasoning_tokens: Some(30),
                ..TokenUsage::default()
            })
        );
    }

    #[test]
    fn test_output_item_done_without_encrypted_content_is_ignored() {
        let event = ResponseStreamEvent::ResponseOutputItemDone(ResponseOutputItemDoneEvent {
            sequence_number: 1,
            output_index: 0,
            item: OutputItem::Reasoning(ReasoningItem {
                id: "r_2".to_string(),
                summary: vec![],
                encrypted_content: None,
                content: None,
                status: None,
            }),
        });

        let mut tool_collector = ToolCallCollector::<u32>::new();
        let mut stop_reason = None;
        let responses = process_event(event, &mut tool_collector, &mut stop_reason);

        assert!(responses.is_empty());
    }
}
425}