Skip to main content

llm/providers/codex/
streaming.rs

1use async_openai::types::responses::{OutputItem, ResponseUsage, Status};
2use futures::Stream;
3use serde::Deserialize;
4use tokio_stream::StreamExt;
5
6use crate::providers::tool_call_collector::ToolCallCollector;
7use crate::{LlmError, LlmResponse, Result, StopReason};
8
9#[derive(Debug, Deserialize)]
10#[serde(tag = "type")]
11pub enum CodexStreamEvent {
12    #[serde(rename = "response.output_text.delta")]
13    OutputTextDelta(CodexTextDeltaEvent),
14    #[serde(rename = "response.output_item.added")]
15    OutputItemAdded(CodexOutputItemEvent),
16    #[serde(rename = "response.output_item.done")]
17    OutputItemDone(CodexOutputItemEvent),
18    #[serde(rename = "response.function_call_arguments.delta")]
19    FunctionCallArgumentsDelta(CodexFunctionCallArgumentsDeltaEvent),
20    #[serde(rename = "response.function_call_arguments.done")]
21    FunctionCallArgumentsDone(CodexFunctionCallArgumentsDoneEvent),
22    #[serde(rename = "response.reasoning_summary_text.delta")]
23    ReasoningSummaryTextDelta(CodexTextDeltaEvent),
24    #[serde(rename = "response.completed")]
25    Completed(CodexResponseCompletedEvent),
26    #[serde(rename = "error")]
27    Error(CodexErrorEvent),
28    #[serde(other)]
29    Ignored,
30}
31
32#[derive(Debug, Deserialize)]
33pub struct CodexTextDeltaEvent {
34    pub delta: String,
35}
36
37#[derive(Debug, Deserialize)]
38pub struct CodexOutputItemEvent {
39    pub output_index: u32,
40    pub item: OutputItem,
41}
42
43#[derive(Debug, Deserialize)]
44pub struct CodexFunctionCallArgumentsDeltaEvent {
45    pub output_index: u32,
46    pub delta: String,
47}
48
49#[derive(Debug, Deserialize)]
50pub struct CodexFunctionCallArgumentsDoneEvent {
51    pub output_index: u32,
52}
53
54#[derive(Debug, Deserialize)]
55pub struct CodexResponseCompletedEvent {
56    pub response: CodexResponseCompleted,
57}
58
59#[derive(Debug, Deserialize)]
60pub struct CodexResponseCompleted {
61    #[serde(default)]
62    pub usage: Option<ResponseUsage>,
63    #[serde(default)]
64    pub status: Option<Status>,
65}
66
67#[derive(Debug, Deserialize)]
68pub struct CodexErrorEvent {
69    pub message: String,
70}
71
72/// Process a Codex response stream into `LlmResponse` items.
73pub fn process_response_stream<T>(stream: T) -> impl Stream<Item = Result<LlmResponse>> + Send
74where
75    T: Stream<Item = Result<CodexStreamEvent>> + Send + Unpin,
76{
77    async_stream::stream! {
78        let message_id = uuid::Uuid::new_v4().to_string();
79        yield Ok(LlmResponse::Start { message_id });
80
81        let mut tool_collector = ToolCallCollector::<u32>::new();
82        let mut stream = Box::pin(stream);
83        let mut last_stop_reason: Option<StopReason> = None;
84
85        while let Some(result) = stream.next().await {
86            match result {
87                Ok(event) => {
88                    for response in process_event(event, &mut tool_collector, &mut last_stop_reason) {
89                        yield response;
90                    }
91                }
92                Err(e) => {
93                    yield Err(LlmError::StreamInterrupted(e.to_string()));
94                    break;
95                }
96            }
97        }
98
99        for tc in tool_collector.complete_all() {
100            yield Ok(LlmResponse::ToolRequestComplete { tool_call: tc });
101        }
102
103        yield Ok(LlmResponse::Done {
104            stop_reason: last_stop_reason,
105        });
106    }
107}
108
109fn process_event(
110    event: CodexStreamEvent,
111    tool_collector: &mut ToolCallCollector<u32>,
112    last_stop_reason: &mut Option<StopReason>,
113) -> Vec<Result<LlmResponse>> {
114    let mut responses = Vec::new();
115
116    match event {
117        CodexStreamEvent::OutputTextDelta(e) if !e.delta.is_empty() => {
118            responses.push(Ok(LlmResponse::Text { chunk: e.delta }));
119        }
120        CodexStreamEvent::OutputItemAdded(e) => {
121            if let OutputItem::FunctionCall(call) = e.item {
122                let tool_responses = tool_collector.handle_delta(e.output_index, call.id, Some(call.name), None);
123                responses.extend(tool_responses.into_iter().map(Ok));
124            }
125        }
126        CodexStreamEvent::FunctionCallArgumentsDelta(e) => {
127            let tool_responses = tool_collector.handle_delta(e.output_index, None, None, Some(e.delta));
128            responses.extend(tool_responses.into_iter().map(Ok));
129        }
130        CodexStreamEvent::FunctionCallArgumentsDone(e) => {
131            if let Some(tc) = tool_collector.complete_one(e.output_index) {
132                responses.push(Ok(LlmResponse::ToolRequestComplete { tool_call: tc }));
133            }
134        }
135        CodexStreamEvent::ReasoningSummaryTextDelta(e) if !e.delta.is_empty() => {
136            responses.push(Ok(LlmResponse::Reasoning { chunk: e.delta }));
137        }
138        CodexStreamEvent::OutputItemDone(e) => {
139            if let OutputItem::Reasoning(reasoning) = e.item
140                && let Some(id) = reasoning.id
141                && let Some(encrypted) = reasoning.encrypted_content
142            {
143                responses.push(Ok(LlmResponse::EncryptedReasoning { id, content: encrypted }));
144            }
145        }
146        CodexStreamEvent::Completed(e) => {
147            if let Some(usage) = e.response.usage {
148                responses.push(Ok(LlmResponse::Usage { tokens: usage.into() }));
149            }
150            match e.response.status {
151                Some(Status::Completed) => *last_stop_reason = Some(StopReason::EndTurn),
152                Some(Status::Incomplete) => *last_stop_reason = Some(StopReason::Length),
153                _ => {}
154            }
155        }
156        CodexStreamEvent::Error(e) => {
157            responses
158                .push(Err(LlmError::ServerError { status: None, message: format!("Codex API error: {}", e.message) }));
159        }
160        CodexStreamEvent::Ignored
161        | CodexStreamEvent::OutputTextDelta(_)
162        | CodexStreamEvent::ReasoningSummaryTextDelta(_) => {}
163    }
164
165    responses
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TokenUsage;
    use async_openai::types::responses::{FunctionToolCall, ReasoningItem};
    use serde_json::json;

    /// Drive `process_response_stream` over `stream` and gather every yielded
    /// item, `Err`s included, in order.
    async fn collect_results<S>(stream: S) -> Vec<Result<LlmResponse>>
    where
        S: Stream<Item = Result<CodexStreamEvent>> + Send + Unpin,
    {
        let mut response_stream = Box::pin(process_response_stream(stream));
        let mut results = Vec::new();
        while let Some(result) = response_stream.next().await {
            results.push(result);
        }
        results
    }

    /// Run `events` through `process_response_stream`, panicking if any item is an `Err`.
    async fn collect_responses(events: Vec<CodexStreamEvent>) -> Vec<LlmResponse> {
        collect_results(make_stream(events)).await.into_iter().map(|result| result.unwrap()).collect()
    }

    /// Text deltas are forwarded in order, followed by usage and an `EndTurn` stop.
    #[tokio::test]
    async fn test_text_stream() {
        let responses = collect_responses(vec![
            text_delta("Hello"),
            text_delta(" world"),
            completed(Status::Completed, Some(make_usage(10, 5))),
        ])
        .await;

        assert!(matches!(responses[0], LlmResponse::Start { .. }));
        assert!(matches!(responses[1], LlmResponse::Text { ref chunk } if chunk == "Hello"));
        assert!(matches!(responses[2], LlmResponse::Text { ref chunk } if chunk == " world"));
        assert!(matches!(
            responses[3],
            LlmResponse::Usage { tokens: TokenUsage { input_tokens: 10, output_tokens: 5, .. } }
        ));
        assert!(matches!(responses[4], LlmResponse::Done { stop_reason: Some(StopReason::EndTurn) }));
    }

    /// A function call is announced, its argument fragments are streamed, and the
    /// completed call carries the concatenated arguments.
    #[tokio::test]
    async fn test_tool_call_stream() {
        let responses = collect_responses(vec![
            CodexStreamEvent::OutputItemAdded(CodexOutputItemEvent {
                output_index: 0,
                item: OutputItem::FunctionCall(FunctionToolCall {
                    id: Some("fc_1".to_string()),
                    call_id: "call_1".to_string(),
                    name: "read_file".to_string(),
                    arguments: String::new(),
                    status: None,
                    namespace: None,
                }),
            }),
            function_call_delta(r#"{"path":"#),
            function_call_delta(r#""foo.rs"}"#),
            CodexStreamEvent::FunctionCallArgumentsDone(CodexFunctionCallArgumentsDoneEvent { output_index: 0 }),
            completed(Status::Completed, Some(make_usage(20, 10))),
        ])
        .await;

        assert!(matches!(responses[0], LlmResponse::Start { .. }));
        assert!(
            matches!(&responses[1], LlmResponse::ToolRequestStart { id, name } if id == "fc_1" && name == "read_file")
        );
        assert!(matches!(responses[2], LlmResponse::ToolRequestArg { .. }));
        assert!(matches!(responses[3], LlmResponse::ToolRequestArg { .. }));

        // Fail loudly if the completion is missing instead of skipping the checks below.
        let Some(LlmResponse::ToolRequestComplete { tool_call }) =
            responses.iter().find(|r| matches!(r, LlmResponse::ToolRequestComplete { .. }))
        else {
            panic!("expected a ToolRequestComplete response");
        };
        assert_eq!(tool_call.id, "fc_1");
        assert_eq!(tool_call.name, "read_file");
        assert_eq!(tool_call.arguments, r#"{"path":"foo.rs"}"#);
    }

    /// An `error` event surfaces as a retryable `ServerError` without a status code.
    #[tokio::test]
    async fn test_error_event_is_retryable_server_error() {
        let stream =
            make_stream(vec![CodexStreamEvent::Error(CodexErrorEvent { message: "Rate limit exceeded".to_string() })]);
        let responses = collect_results(stream).await;

        assert!(responses[0].is_ok());
        let err = responses[1].as_ref().expect_err("expected error event to surface as Err");
        assert!(matches!(err, LlmError::ServerError { status: None, .. }), "got {err:?}");
        assert!(err.is_retryable(), "ServerError must be retryable so the agent can recover");
    }

    /// Reasoning summary deltas are forwarded as `Reasoning` chunks, in order.
    #[tokio::test]
    async fn test_reasoning_delta() {
        let responses = collect_responses(vec![
            reasoning_delta("Thinking about"),
            reasoning_delta(" the problem"),
            completed(Status::Completed, None),
        ])
        .await;

        assert!(matches!(responses[1], LlmResponse::Reasoning { ref chunk } if chunk == "Thinking about"));
        assert!(matches!(responses[2], LlmResponse::Reasoning { ref chunk } if chunk == " the problem"));
    }

    /// An `Incomplete` status maps to the `Length` stop reason.
    #[tokio::test]
    async fn test_incomplete_status_gives_length_stop_reason() {
        let responses = collect_responses(vec![completed(Status::Incomplete, None)]).await;

        assert!(matches!(responses.last().unwrap(), LlmResponse::Done { stop_reason: Some(StopReason::Length) }));
    }

    /// An upstream `Err` is re-emitted as a retryable `StreamInterrupted`.
    #[tokio::test]
    async fn test_stream_error_propagation_is_retryable() {
        let events: Vec<Result<CodexStreamEvent>> =
            vec![Err(LlmError::StreamInterrupted("connection lost".to_string()))];

        let responses = collect_results(tokio_stream::iter(events)).await;

        assert!(responses[0].is_ok());
        let err = responses[1].as_ref().expect_err("expected upstream Err to surface as Err");
        assert!(matches!(err, LlmError::StreamInterrupted(_)), "got {err:?}");
        assert!(err.is_retryable(), "mid-stream interrupts must be retryable");
    }

    /// A finished reasoning item with encrypted content yields `EncryptedReasoning`.
    #[test]
    fn test_encrypted_reasoning_from_output_item_done() {
        let event = CodexStreamEvent::OutputItemDone(CodexOutputItemEvent {
            output_index: 0,
            item: reasoning_item(Some("enc-blob-data")),
        });

        let mut tool_collector = ToolCallCollector::<u32>::new();
        let mut stop_reason = None;
        let responses = process_event(event, &mut tool_collector, &mut stop_reason);

        assert_eq!(responses.len(), 1);
        assert!(
            matches!(&responses[0], Ok(LlmResponse::EncryptedReasoning { content, .. }) if content == "enc-blob-data")
        );
    }

    /// Cached-input and reasoning token counts survive the conversion to `TokenUsage`.
    #[tokio::test]
    async fn test_usage_forwards_reasoning_and_cache_read() {
        let responses =
            collect_responses(vec![completed(Status::Completed, Some(make_usage_full(120, 80, 50, 30)))]).await;

        let usage = responses.iter().find_map(|r| match r {
            LlmResponse::Usage { tokens } => Some(*tokens),
            _ => None,
        });

        assert_eq!(
            usage,
            Some(TokenUsage {
                input_tokens: 120,
                output_tokens: 80,
                cache_read_tokens: Some(50),
                reasoning_tokens: Some(30),
                ..TokenUsage::default()
            })
        );
    }

    /// A `response.completed` payload without an `output` field still deserializes
    /// and yields both usage and the stop reason.
    #[tokio::test]
    async fn test_completed_without_output_deserializes_usage_and_stop_reason() {
        let event: CodexStreamEvent = serde_json::from_value(json!({
            "type": "response.completed",
            "sequence_number": 1,
            "response": {
                "id": "resp_1",
                "object": "response",
                "created_at": 1_000_u64,
                "status": "completed",
                "background": false,
                "completed_at": 2_000_u64,
                "error": null,
                "model": "test-model",
                "usage": make_usage_json(100, 20, 0, 10)
            }
        }))
        .unwrap();
        let responses = collect_responses(vec![event]).await;

        assert!(matches!(
            responses.iter().find(|response| matches!(response, LlmResponse::Usage { .. })),
            Some(LlmResponse::Usage {
                tokens: TokenUsage { input_tokens: 100, output_tokens: 20, reasoning_tokens: Some(10), .. }
            })
        ));
        assert!(matches!(responses.last().unwrap(), LlmResponse::Done { stop_reason: Some(StopReason::EndTurn) }));
    }

    /// A finished reasoning item without encrypted content produces nothing.
    #[test]
    fn test_output_item_done_without_encrypted_content_is_ignored() {
        let event =
            CodexStreamEvent::OutputItemDone(CodexOutputItemEvent { output_index: 0, item: reasoning_item(None) });

        let mut tool_collector = ToolCallCollector::<u32>::new();
        let mut stop_reason = None;
        let responses = process_event(event, &mut tool_collector, &mut stop_reason);

        assert!(responses.is_empty());
    }

    // ---- fixtures ----

    /// Build an output-text delta event.
    fn text_delta(delta: &str) -> CodexStreamEvent {
        CodexStreamEvent::OutputTextDelta(CodexTextDeltaEvent { delta: delta.to_string() })
    }

    /// Build a reasoning-summary delta event.
    fn reasoning_delta(delta: &str) -> CodexStreamEvent {
        CodexStreamEvent::ReasoningSummaryTextDelta(CodexTextDeltaEvent { delta: delta.to_string() })
    }

    /// Build a function-call arguments delta for output index 0.
    fn function_call_delta(delta: &str) -> CodexStreamEvent {
        CodexStreamEvent::FunctionCallArgumentsDelta(CodexFunctionCallArgumentsDeltaEvent {
            output_index: 0,
            delta: delta.to_string(),
        })
    }

    /// Build a completion event with the given status and optional usage.
    fn completed(status: Status, usage: Option<ResponseUsage>) -> CodexStreamEvent {
        CodexStreamEvent::Completed(CodexResponseCompletedEvent {
            response: CodexResponseCompleted { usage, status: Some(status) },
        })
    }

    /// Build a reasoning output item with id `r_1` and the given encrypted content.
    fn reasoning_item(encrypted_content: Option<&str>) -> OutputItem {
        OutputItem::Reasoning(ReasoningItem {
            id: Some("r_1".to_string()),
            summary: vec![],
            encrypted_content: encrypted_content.map(ToString::to_string),
            content: None,
            status: None,
        })
    }

    /// Wrap `events` in an always-`Ok` input stream.
    fn make_stream(events: Vec<CodexStreamEvent>) -> impl Stream<Item = Result<CodexStreamEvent>> + Send + Unpin {
        tokio_stream::iter(events.into_iter().map(Ok).collect::<Vec<_>>())
    }

    /// Usage with zero cached and zero reasoning tokens.
    fn make_usage(input_tokens: u32, output_tokens: u32) -> ResponseUsage {
        make_usage_full(input_tokens, output_tokens, 0, 0)
    }

    /// Usage with every count specified, built by deserializing `make_usage_json`.
    fn make_usage_full(
        input_tokens: u32,
        output_tokens: u32,
        cached_tokens: u32,
        reasoning_tokens: u32,
    ) -> ResponseUsage {
        serde_json::from_value(make_usage_json(input_tokens, output_tokens, cached_tokens, reasoning_tokens)).unwrap()
    }

    /// JSON form of a usage object; `total_tokens` is input plus output.
    fn make_usage_json(
        input_tokens: u32,
        output_tokens: u32,
        cached_tokens: u32,
        reasoning_tokens: u32,
    ) -> serde_json::Value {
        json!({
            "input_tokens": input_tokens,
            "input_tokens_details": { "cached_tokens": cached_tokens },
            "output_tokens": output_tokens,
            "output_tokens_details": { "reasoning_tokens": reasoning_tokens },
            "total_tokens": input_tokens + output_tokens
        })
    }
}