1use async_openai::types::responses::{OutputItem, ResponseStreamEvent, Status};
2
3use crate::providers::tool_call_collector::ToolCallCollector;
4use crate::{LlmError, LlmResponse, Result, StopReason};
5use futures::Stream;
6use tokio_stream::StreamExt;
7
8pub fn process_response_stream<T>(stream: T) -> impl Stream<Item = Result<LlmResponse>> + Send
10where
11 T: Stream<Item = Result<ResponseStreamEvent>> + Send + Unpin,
12{
13 async_stream::stream! {
14 let message_id = uuid::Uuid::new_v4().to_string();
15 yield Ok(LlmResponse::Start { message_id });
16
17 let mut tool_collector = ToolCallCollector::<u32>::new();
18 let mut stream = Box::pin(stream);
19 let mut last_stop_reason: Option<StopReason> = None;
20
21 while let Some(result) = stream.next().await {
22 match result {
23 Ok(event) => {
24 for response in process_event(event, &mut tool_collector, &mut last_stop_reason) {
25 yield response;
26 }
27 }
28 Err(e) => {
29 yield Err(LlmError::StreamInterrupted(e.to_string()));
30 break;
31 }
32 }
33 }
34
35 for tc in tool_collector.complete_all() {
37 yield Ok(LlmResponse::ToolRequestComplete { tool_call: tc });
38 }
39
40 yield Ok(LlmResponse::Done {
41 stop_reason: last_stop_reason,
42 });
43 }
44}
45
46fn process_event(
47 event: ResponseStreamEvent,
48 tool_collector: &mut ToolCallCollector<u32>,
49 last_stop_reason: &mut Option<StopReason>,
50) -> Vec<Result<LlmResponse>> {
51 let mut responses = Vec::new();
52
53 match event {
54 ResponseStreamEvent::ResponseOutputTextDelta(e) if !e.delta.is_empty() => {
55 responses.push(Ok(LlmResponse::Text { chunk: e.delta }));
56 }
57 ResponseStreamEvent::ResponseOutputItemAdded(e) => {
58 if let OutputItem::FunctionCall(call) = e.item {
59 let tool_responses = tool_collector.handle_delta(e.output_index, call.id, Some(call.name), None);
60 responses.extend(tool_responses.into_iter().map(Ok));
61 }
62 }
63 ResponseStreamEvent::ResponseFunctionCallArgumentsDelta(e) => {
64 let tool_responses = tool_collector.handle_delta(e.output_index, None, None, Some(e.delta));
65 responses.extend(tool_responses.into_iter().map(Ok));
66 }
67 ResponseStreamEvent::ResponseFunctionCallArgumentsDone(e) => {
68 if let Some(tc) = tool_collector.complete_one(e.output_index) {
69 responses.push(Ok(LlmResponse::ToolRequestComplete { tool_call: tc }));
70 }
71 }
72 ResponseStreamEvent::ResponseReasoningSummaryTextDelta(e) if !e.delta.is_empty() => {
73 responses.push(Ok(LlmResponse::Reasoning { chunk: e.delta }));
74 }
75 ResponseStreamEvent::ResponseOutputItemDone(e) => {
76 if let OutputItem::Reasoning(reasoning) = e.item
77 && let Some(id) = reasoning.id
78 && let Some(encrypted) = reasoning.encrypted_content
79 {
80 responses.push(Ok(LlmResponse::EncryptedReasoning { id, content: encrypted }));
81 }
82 }
83 ResponseStreamEvent::ResponseCompleted(e) => {
84 if let Some(usage) = e.response.usage {
85 responses.push(Ok(LlmResponse::Usage { tokens: usage.into() }));
86 }
87 match e.response.status {
88 Status::Completed => *last_stop_reason = Some(StopReason::EndTurn),
89 Status::Incomplete => *last_stop_reason = Some(StopReason::Length),
90 _ => {}
91 }
92 }
93 ResponseStreamEvent::ResponseError(e) => {
94 responses
95 .push(Err(LlmError::ServerError { status: None, message: format!("Codex API error: {}", e.message) }));
96 }
97 _ => {}
99 }
100
101 responses
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TokenUsage;
    use async_openai::types::responses::{
        FunctionToolCall, ReasoningItem, Response, ResponseCompletedEvent, ResponseErrorEvent,
        ResponseFunctionCallArgumentsDeltaEvent, ResponseFunctionCallArgumentsDoneEvent, ResponseOutputItemAddedEvent,
        ResponseOutputItemDoneEvent, ResponseReasoningSummaryTextDeltaEvent, ResponseTextDeltaEvent, ResponseUsage,
    };

    /// Builds a minimal `Response` through JSON so the fixture only spells out the fields the
    /// stream processor reads (status and, optionally, usage).
    fn make_response(status: &Status, usage: Option<ResponseUsage>) -> Response {
        let status_value = serde_json::to_value(status).unwrap();
        let mut body = serde_json::json!({
            "id": "resp_1",
            "object": "response",
            "status": status_value,
            "output": [],
            "model": "test",
            "created_at": 0
        });
        if let Some(usage) = usage {
            body["usage"] = serde_json::to_value(usage).unwrap();
        }
        serde_json::from_value(body).unwrap()
    }

    /// Usage fixture with zero cached and zero reasoning tokens.
    fn make_usage(input_tokens: u32, output_tokens: u32) -> ResponseUsage {
        make_usage_full(input_tokens, output_tokens, 0, 0)
    }

    /// Usage fixture with every token counter set explicitly; `total_tokens` is input + output.
    fn make_usage_full(
        input_tokens: u32,
        output_tokens: u32,
        cached_tokens: u32,
        reasoning_tokens: u32,
    ) -> ResponseUsage {
        let body = serde_json::json!({
            "input_tokens": input_tokens,
            "input_tokens_details": { "cached_tokens": cached_tokens },
            "output_tokens": output_tokens,
            "output_tokens_details": { "reasoning_tokens": reasoning_tokens },
            "total_tokens": input_tokens + output_tokens
        });
        serde_json::from_value(body).unwrap()
    }

    /// Wraps every event in `Ok` and exposes them as an in-memory stream.
    fn make_stream(events: Vec<ResponseStreamEvent>) -> impl Stream<Item = Result<ResponseStreamEvent>> + Send + Unpin {
        let items: Vec<Result<ResponseStreamEvent>> = events.into_iter().map(Ok).collect();
        tokio_stream::iter(items)
    }

    /// Runs `process_response_stream` over `input` and returns every item it yields, errors
    /// included.
    async fn collect_results(
        input: impl Stream<Item = Result<ResponseStreamEvent>> + Send + Unpin,
    ) -> Vec<Result<LlmResponse>> {
        let mut output = Box::pin(process_response_stream(input));
        let mut results = Vec::new();
        while let Some(item) = output.next().await {
            results.push(item);
        }
        results
    }

    /// Like `collect_results` over successful `events`, but fails the test on any `Err` item.
    async fn collect_ok(events: Vec<ResponseStreamEvent>) -> Vec<LlmResponse> {
        collect_results(make_stream(events)).await.into_iter().map(|item| item.unwrap()).collect()
    }

    /// Feeds a single event through `process_event` with a fresh collector and no stop reason.
    fn process_single(event: ResponseStreamEvent) -> Vec<Result<LlmResponse>> {
        let mut collector = ToolCallCollector::<u32>::new();
        let mut stop_reason = None;
        process_event(event, &mut collector, &mut stop_reason)
    }

    #[tokio::test]
    async fn test_text_stream() {
        let events = vec![
            ResponseStreamEvent::ResponseOutputTextDelta(ResponseTextDeltaEvent {
                output_index: 0,
                content_index: 0,
                delta: "Hello".to_string(),
                sequence_number: 1,
                item_id: "msg_1".to_string(),
                logprobs: None,
            }),
            ResponseStreamEvent::ResponseOutputTextDelta(ResponseTextDeltaEvent {
                output_index: 0,
                content_index: 0,
                delta: " world".to_string(),
                sequence_number: 2,
                item_id: "msg_1".to_string(),
                logprobs: None,
            }),
            ResponseStreamEvent::ResponseCompleted(ResponseCompletedEvent {
                sequence_number: 3,
                response: make_response(&Status::Completed, Some(make_usage(10, 5))),
            }),
        ];

        let out = collect_ok(events).await;

        assert!(matches!(out[0], LlmResponse::Start { .. }));
        assert!(matches!(out[1], LlmResponse::Text { ref chunk } if chunk == "Hello"));
        assert!(matches!(out[2], LlmResponse::Text { ref chunk } if chunk == " world"));
        assert!(matches!(
            out[3],
            LlmResponse::Usage { tokens: TokenUsage { input_tokens: 10, output_tokens: 5, .. } }
        ));
        assert!(matches!(out[4], LlmResponse::Done { stop_reason: Some(StopReason::EndTurn) }));
    }

    #[tokio::test]
    async fn test_tool_call_stream() {
        let events = vec![
            ResponseStreamEvent::ResponseOutputItemAdded(ResponseOutputItemAddedEvent {
                sequence_number: 1,
                output_index: 0,
                item: OutputItem::FunctionCall(FunctionToolCall {
                    id: Some("fc_1".to_string()),
                    call_id: "call_1".to_string(),
                    name: "read_file".to_string(),
                    arguments: String::new(),
                    status: None,
                    namespace: None,
                }),
            }),
            ResponseStreamEvent::ResponseFunctionCallArgumentsDelta(ResponseFunctionCallArgumentsDeltaEvent {
                sequence_number: 2,
                item_id: "fc_1".to_string(),
                output_index: 0,
                delta: r#"{"path":"#.to_string(),
            }),
            ResponseStreamEvent::ResponseFunctionCallArgumentsDelta(ResponseFunctionCallArgumentsDeltaEvent {
                sequence_number: 3,
                item_id: "fc_1".to_string(),
                output_index: 0,
                delta: r#""foo.rs"}"#.to_string(),
            }),
            ResponseStreamEvent::ResponseFunctionCallArgumentsDone(ResponseFunctionCallArgumentsDoneEvent {
                sequence_number: 4,
                item_id: "fc_1".to_string(),
                output_index: 0,
                arguments: r#"{"path":"foo.rs"}"#.to_string(),
                name: None,
            }),
            ResponseStreamEvent::ResponseCompleted(ResponseCompletedEvent {
                sequence_number: 5,
                response: make_response(&Status::Completed, Some(make_usage(20, 10))),
            }),
        ];

        let out = collect_ok(events).await;

        assert!(matches!(out[0], LlmResponse::Start { .. }));
        assert!(matches!(&out[1], LlmResponse::ToolRequestStart { id, name } if id == "fc_1" && name == "read_file"));
        assert!(matches!(out[2], LlmResponse::ToolRequestArg { .. }));
        assert!(matches!(out[3], LlmResponse::ToolRequestArg { .. }));

        let completed = out
            .iter()
            .find_map(|item| match item {
                LlmResponse::ToolRequestComplete { tool_call } => Some(tool_call),
                _ => None,
            })
            .expect("expected a ToolRequestComplete item");
        assert_eq!(completed.id, "fc_1");
        assert_eq!(completed.name, "read_file");
        assert_eq!(completed.arguments, r#"{"path":"foo.rs"}"#);
    }

    #[tokio::test]
    async fn test_error_event_is_retryable_server_error() {
        let events = vec![ResponseStreamEvent::ResponseError(ResponseErrorEvent {
            sequence_number: 1,
            code: None,
            message: "Rate limit exceeded".to_string(),
            param: None,
        })];

        let out = collect_results(make_stream(events)).await;

        assert!(out[0].is_ok());
        let err = out[1].as_ref().expect_err("expected ResponseError to surface as Err");
        assert!(matches!(err, LlmError::ServerError { status: None, .. }), "got {err:?}");
        assert!(err.is_retryable(), "ResponseError must be retryable so the agent can recover");
    }

    #[tokio::test]
    async fn test_reasoning_delta() {
        let events = vec![
            ResponseStreamEvent::ResponseReasoningSummaryTextDelta(ResponseReasoningSummaryTextDeltaEvent {
                sequence_number: 1,
                item_id: "r_1".to_string(),
                output_index: 0,
                summary_index: 0,
                delta: "Thinking about".to_string(),
            }),
            ResponseStreamEvent::ResponseReasoningSummaryTextDelta(ResponseReasoningSummaryTextDeltaEvent {
                sequence_number: 2,
                item_id: "r_1".to_string(),
                output_index: 0,
                summary_index: 0,
                delta: " the problem".to_string(),
            }),
            ResponseStreamEvent::ResponseCompleted(ResponseCompletedEvent {
                sequence_number: 3,
                response: make_response(&Status::Completed, None),
            }),
        ];

        let out = collect_ok(events).await;

        assert!(matches!(out[1], LlmResponse::Reasoning { ref chunk } if chunk == "Thinking about"));
        assert!(matches!(out[2], LlmResponse::Reasoning { ref chunk } if chunk == " the problem"));
    }

    #[tokio::test]
    async fn test_incomplete_status_gives_length_stop_reason() {
        let events = vec![ResponseStreamEvent::ResponseCompleted(ResponseCompletedEvent {
            sequence_number: 1,
            response: make_response(&Status::Incomplete, None),
        })];

        let out = collect_ok(events).await;

        assert!(matches!(out.last().unwrap(), LlmResponse::Done { stop_reason: Some(StopReason::Length) }));
    }

    #[tokio::test]
    async fn test_stream_error_propagation_is_retryable() {
        let upstream: Vec<Result<ResponseStreamEvent>> =
            vec![Err(LlmError::StreamInterrupted("connection lost".to_string()))];

        let out = collect_results(tokio_stream::iter(upstream)).await;

        assert!(out[0].is_ok());
        let err = out[1].as_ref().expect_err("expected upstream Err to surface as Err");
        assert!(matches!(err, LlmError::StreamInterrupted(_)), "got {err:?}");
        assert!(err.is_retryable(), "mid-stream interrupts must be retryable");
    }

    #[test]
    fn test_encrypted_reasoning_from_output_item_done() {
        let event = ResponseStreamEvent::ResponseOutputItemDone(ResponseOutputItemDoneEvent {
            sequence_number: 1,
            output_index: 0,
            item: OutputItem::Reasoning(ReasoningItem {
                id: Some("r_1".to_string()),
                summary: vec![],
                encrypted_content: Some("enc-blob-data".to_string()),
                content: None,
                status: None,
            }),
        });

        let out = process_single(event);

        assert_eq!(out.len(), 1);
        assert!(matches!(&out[0], Ok(LlmResponse::EncryptedReasoning { content, .. }) if content == "enc-blob-data"));
    }

    #[tokio::test]
    async fn test_usage_forwards_reasoning_and_cache_read() {
        let events = vec![ResponseStreamEvent::ResponseCompleted(ResponseCompletedEvent {
            sequence_number: 1,
            response: make_response(&Status::Completed, Some(make_usage_full(120, 80, 50, 30))),
        })];

        let out = collect_ok(events).await;

        let usage = out.iter().find_map(|item| {
            if let LlmResponse::Usage { tokens } = item { Some(*tokens) } else { None }
        });

        let expected = TokenUsage {
            input_tokens: 120,
            output_tokens: 80,
            cache_read_tokens: Some(50),
            reasoning_tokens: Some(30),
            ..TokenUsage::default()
        };
        assert_eq!(usage, Some(expected));
    }

    #[test]
    fn test_output_item_done_without_encrypted_content_is_ignored() {
        let event = ResponseStreamEvent::ResponseOutputItemDone(ResponseOutputItemDoneEvent {
            sequence_number: 1,
            output_index: 0,
            item: OutputItem::Reasoning(ReasoningItem {
                id: Some("r_2".to_string()),
                summary: vec![],
                encrypted_content: None,
                content: None,
                status: None,
            }),
        });

        assert!(process_single(event).is_empty());
    }
}
432}