1use async_openai::types::responses::{OutputItem, ResponseStreamEvent, Status};
2
3use crate::providers::tool_call_collector::ToolCallCollector;
4use crate::{LlmError, LlmResponse, Result, StopReason};
5use futures::Stream;
6use tokio_stream::StreamExt;
7
8pub fn process_response_stream<T>(stream: T) -> impl Stream<Item = Result<LlmResponse>> + Send
10where
11 T: Stream<Item = Result<ResponseStreamEvent>> + Send + Unpin,
12{
13 async_stream::stream! {
14 let message_id = uuid::Uuid::new_v4().to_string();
15 yield Ok(LlmResponse::Start { message_id });
16
17 let mut tool_collector = ToolCallCollector::<u32>::new();
18 let mut stream = Box::pin(stream);
19 let mut last_stop_reason: Option<StopReason> = None;
20
21 while let Some(result) = stream.next().await {
22 match result {
23 Ok(event) => {
24 for response in process_event(event, &mut tool_collector, &mut last_stop_reason) {
25 yield response;
26 }
27 }
28 Err(e) => {
29 yield Err(LlmError::ApiError(e.to_string()));
30 break;
31 }
32 }
33 }
34
35 for tc in tool_collector.complete_all() {
37 yield Ok(LlmResponse::ToolRequestComplete { tool_call: tc });
38 }
39
40 yield Ok(LlmResponse::Done {
41 stop_reason: last_stop_reason,
42 });
43 }
44}
45
46fn process_event(
47 event: ResponseStreamEvent,
48 tool_collector: &mut ToolCallCollector<u32>,
49 last_stop_reason: &mut Option<StopReason>,
50) -> Vec<Result<LlmResponse>> {
51 let mut responses = Vec::new();
52
53 match event {
54 ResponseStreamEvent::ResponseOutputTextDelta(e) if !e.delta.is_empty() => {
55 responses.push(Ok(LlmResponse::Text { chunk: e.delta }));
56 }
57 ResponseStreamEvent::ResponseOutputItemAdded(e) => {
58 if let OutputItem::FunctionCall(call) = e.item {
59 let tool_responses = tool_collector.handle_delta(e.output_index, call.id, Some(call.name), None);
60 responses.extend(tool_responses.into_iter().map(Ok));
61 }
62 }
63 ResponseStreamEvent::ResponseFunctionCallArgumentsDelta(e) => {
64 let tool_responses = tool_collector.handle_delta(e.output_index, None, None, Some(e.delta));
65 responses.extend(tool_responses.into_iter().map(Ok));
66 }
67 ResponseStreamEvent::ResponseFunctionCallArgumentsDone(e) => {
68 if let Some(tc) = tool_collector.complete_one(e.output_index) {
69 responses.push(Ok(LlmResponse::ToolRequestComplete { tool_call: tc }));
70 }
71 }
72 ResponseStreamEvent::ResponseReasoningSummaryTextDelta(e) if !e.delta.is_empty() => {
73 responses.push(Ok(LlmResponse::Reasoning { chunk: e.delta }));
74 }
75 ResponseStreamEvent::ResponseOutputItemDone(e) => {
76 if let OutputItem::Reasoning(reasoning) = e.item
77 && let Some(id) = reasoning.id
78 && let Some(encrypted) = reasoning.encrypted_content
79 {
80 responses.push(Ok(LlmResponse::EncryptedReasoning { id, content: encrypted }));
81 }
82 }
83 ResponseStreamEvent::ResponseCompleted(e) => {
84 if let Some(usage) = e.response.usage {
85 responses.push(Ok(LlmResponse::Usage { tokens: usage.into() }));
86 }
87 match e.response.status {
88 Status::Completed => *last_stop_reason = Some(StopReason::EndTurn),
89 Status::Incomplete => *last_stop_reason = Some(StopReason::Length),
90 _ => {}
91 }
92 }
93 ResponseStreamEvent::ResponseError(e) => {
94 responses.push(Err(LlmError::ApiError(format!("Codex API error: {}", e.message))));
95 }
96 _ => {}
98 }
99
100 responses
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TokenUsage;
    use async_openai::types::responses::{
        FunctionToolCall, ReasoningItem, Response, ResponseCompletedEvent, ResponseErrorEvent,
        ResponseFunctionCallArgumentsDeltaEvent, ResponseFunctionCallArgumentsDoneEvent, ResponseOutputItemAddedEvent,
        ResponseOutputItemDoneEvent, ResponseReasoningSummaryTextDeltaEvent, ResponseTextDeltaEvent, ResponseUsage,
    };
    // Builds a minimal `Response` for a `ResponseCompleted` event through
    // serde JSON, so the tests don't depend on the crate's struct-literal
    // shape (which may gain fields across versions).
    fn make_response(status: &Status, usage: Option<ResponseUsage>) -> Response {
        let status_str = serde_json::to_value(status).unwrap();
        let mut json = serde_json::json!({
            "id": "resp_1",
            "object": "response",
            "status": status_str,
            "output": [],
            "model": "test",
            "created_at": 0
        });
        if let Some(u) = usage {
            json["usage"] = serde_json::to_value(u).unwrap();
        }
        serde_json::from_value(json).unwrap()
    }

    // Usage with only input/output counts; cached and reasoning tokens zero.
    fn make_usage(input_tokens: u32, output_tokens: u32) -> ResponseUsage {
        make_usage_full(input_tokens, output_tokens, 0, 0)
    }

    // Usage with explicit cached-token and reasoning-token details, built via
    // serde JSON for the same forward-compatibility reason as `make_response`.
    fn make_usage_full(
        input_tokens: u32,
        output_tokens: u32,
        cached_tokens: u32,
        reasoning_tokens: u32,
    ) -> ResponseUsage {
        serde_json::from_value(serde_json::json!({
            "input_tokens": input_tokens,
            "input_tokens_details": { "cached_tokens": cached_tokens },
            "output_tokens": output_tokens,
            "output_tokens_details": { "reasoning_tokens": reasoning_tokens },
            "total_tokens": input_tokens + output_tokens
        }))
        .unwrap()
    }

    // Wraps a fixed event list in an always-`Ok` stream.
    fn make_stream(events: Vec<ResponseStreamEvent>) -> impl Stream<Item = Result<ResponseStreamEvent>> + Send + Unpin {
        tokio_stream::iter(events.into_iter().map(Ok).collect::<Vec<_>>())
    }

    // Text deltas come out in order, framed by Start / Usage / Done, with the
    // Completed status mapped to StopReason::EndTurn.
    #[tokio::test]
    async fn test_text_stream() {
        let events = vec![
            ResponseStreamEvent::ResponseOutputTextDelta(ResponseTextDeltaEvent {
                output_index: 0,
                content_index: 0,
                delta: "Hello".to_string(),
                sequence_number: 1,
                item_id: "msg_1".to_string(),
                logprobs: None,
            }),
            ResponseStreamEvent::ResponseOutputTextDelta(ResponseTextDeltaEvent {
                output_index: 0,
                content_index: 0,
                delta: " world".to_string(),
                sequence_number: 2,
                item_id: "msg_1".to_string(),
                logprobs: None,
            }),
            ResponseStreamEvent::ResponseCompleted(ResponseCompletedEvent {
                sequence_number: 3,
                response: make_response(&Status::Completed, Some(make_usage(10, 5))),
            }),
        ];

        let stream = make_stream(events);
        let mut response_stream = Box::pin(process_response_stream(stream));

        let mut responses = Vec::new();
        while let Some(result) = response_stream.next().await {
            responses.push(result.unwrap());
        }

        assert!(matches!(responses[0], LlmResponse::Start { .. }));
        assert!(matches!(responses[1], LlmResponse::Text { ref chunk } if chunk == "Hello"));
        assert!(matches!(responses[2], LlmResponse::Text { ref chunk } if chunk == " world"));
        assert!(matches!(
            responses[3],
            LlmResponse::Usage { tokens: TokenUsage { input_tokens: 10, output_tokens: 5, .. } }
        ));
        assert!(matches!(responses[4], LlmResponse::Done { stop_reason: Some(StopReason::EndTurn) }));
    }

    // A full function-call lifecycle (item added → argument deltas → done)
    // yields ToolRequestStart, per-delta ToolRequestArg items, and one
    // ToolRequestComplete with the concatenated arguments.
    #[tokio::test]
    async fn test_tool_call_stream() {
        let events = vec![
            ResponseStreamEvent::ResponseOutputItemAdded(ResponseOutputItemAddedEvent {
                sequence_number: 1,
                output_index: 0,
                item: OutputItem::FunctionCall(FunctionToolCall {
                    id: Some("fc_1".to_string()),
                    call_id: "call_1".to_string(),
                    name: "read_file".to_string(),
                    arguments: String::new(),
                    status: None,
                    namespace: None,
                }),
            }),
            ResponseStreamEvent::ResponseFunctionCallArgumentsDelta(ResponseFunctionCallArgumentsDeltaEvent {
                sequence_number: 2,
                item_id: "fc_1".to_string(),
                output_index: 0,
                delta: r#"{"path":"#.to_string(),
            }),
            ResponseStreamEvent::ResponseFunctionCallArgumentsDelta(ResponseFunctionCallArgumentsDeltaEvent {
                sequence_number: 3,
                item_id: "fc_1".to_string(),
                output_index: 0,
                delta: r#""foo.rs"}"#.to_string(),
            }),
            ResponseStreamEvent::ResponseFunctionCallArgumentsDone(ResponseFunctionCallArgumentsDoneEvent {
                sequence_number: 4,
                item_id: "fc_1".to_string(),
                output_index: 0,
                arguments: r#"{"path":"foo.rs"}"#.to_string(),
                name: None,
            }),
            ResponseStreamEvent::ResponseCompleted(ResponseCompletedEvent {
                sequence_number: 5,
                response: make_response(&Status::Completed, Some(make_usage(20, 10))),
            }),
        ];

        let stream = make_stream(events);
        let mut response_stream = Box::pin(process_response_stream(stream));

        let mut responses = Vec::new();
        while let Some(result) = response_stream.next().await {
            responses.push(result.unwrap());
        }

        assert!(matches!(responses[0], LlmResponse::Start { .. }));
        assert!(
            matches!(&responses[1], LlmResponse::ToolRequestStart { id, name } if id == "fc_1" && name == "read_file")
        );
        assert!(matches!(responses[2], LlmResponse::ToolRequestArg { .. }));
        assert!(matches!(responses[3], LlmResponse::ToolRequestArg { .. }));

        let tc = responses.iter().find(|r| matches!(r, LlmResponse::ToolRequestComplete { .. }));
        assert!(tc.is_some());
        if let LlmResponse::ToolRequestComplete { tool_call } = tc.unwrap() {
            assert_eq!(tool_call.id, "fc_1");
            assert_eq!(tool_call.name, "read_file");
            assert_eq!(tool_call.arguments, r#"{"path":"foo.rs"}"#);
        }
    }

    // An in-band ResponseError event becomes an Err item after Start.
    #[tokio::test]
    async fn test_error_event() {
        let events = vec![ResponseStreamEvent::ResponseError(ResponseErrorEvent {
            sequence_number: 1,
            code: None,
            message: "Rate limit exceeded".to_string(),
            param: None,
        })];

        let stream = make_stream(events);
        let mut response_stream = Box::pin(process_response_stream(stream));

        let mut responses = Vec::new();
        while let Some(result) = response_stream.next().await {
            responses.push(result);
        }

        assert!(responses[0].is_ok()); assert!(responses[1].is_err()); }

    // Reasoning summary deltas are forwarded as LlmResponse::Reasoning chunks.
    #[tokio::test]
    async fn test_reasoning_delta() {
        let events = vec![
            ResponseStreamEvent::ResponseReasoningSummaryTextDelta(ResponseReasoningSummaryTextDeltaEvent {
                sequence_number: 1,
                item_id: "r_1".to_string(),
                output_index: 0,
                summary_index: 0,
                delta: "Thinking about".to_string(),
            }),
            ResponseStreamEvent::ResponseReasoningSummaryTextDelta(ResponseReasoningSummaryTextDeltaEvent {
                sequence_number: 2,
                item_id: "r_1".to_string(),
                output_index: 0,
                summary_index: 0,
                delta: " the problem".to_string(),
            }),
            ResponseStreamEvent::ResponseCompleted(ResponseCompletedEvent {
                sequence_number: 3,
                response: make_response(&Status::Completed, None),
            }),
        ];

        let stream = make_stream(events);
        let mut response_stream = Box::pin(process_response_stream(stream));

        let mut responses = Vec::new();
        while let Some(result) = response_stream.next().await {
            responses.push(result.unwrap());
        }

        assert!(matches!(responses[1], LlmResponse::Reasoning { ref chunk } if chunk == "Thinking about"));
        assert!(matches!(responses[2], LlmResponse::Reasoning { ref chunk } if chunk == " the problem"));
    }

    // Status::Incomplete maps to StopReason::Length on the final Done item.
    #[tokio::test]
    async fn test_incomplete_status_gives_length_stop_reason() {
        let events = vec![ResponseStreamEvent::ResponseCompleted(ResponseCompletedEvent {
            sequence_number: 1,
            response: make_response(&Status::Incomplete, None),
        })];

        let stream = make_stream(events);
        let mut response_stream = Box::pin(process_response_stream(stream));

        let mut responses = Vec::new();
        while let Some(result) = response_stream.next().await {
            responses.push(result.unwrap());
        }

        assert!(matches!(responses.last().unwrap(), LlmResponse::Done { stop_reason: Some(StopReason::Length) }));
    }

    // A transport-level Err from the underlying stream is re-yielded as an
    // Err item (after the initial Start).
    #[tokio::test]
    async fn test_stream_error_propagation() {
        let events: Vec<Result<ResponseStreamEvent>> = vec![Err(LlmError::ApiError("connection lost".to_string()))];

        let stream = tokio_stream::iter(events);
        let mut response_stream = Box::pin(process_response_stream(stream));

        let mut responses = Vec::new();
        while let Some(result) = response_stream.next().await {
            responses.push(result);
        }

        assert!(responses[0].is_ok()); assert!(responses[1].is_err()); }

    // A reasoning output item carrying encrypted content yields exactly one
    // EncryptedReasoning response (unit test on process_event directly).
    #[test]
    fn test_encrypted_reasoning_from_output_item_done() {
        let event = ResponseStreamEvent::ResponseOutputItemDone(ResponseOutputItemDoneEvent {
            sequence_number: 1,
            output_index: 0,
            item: OutputItem::Reasoning(ReasoningItem {
                id: Some("r_1".to_string()),
                summary: vec![],
                encrypted_content: Some("enc-blob-data".to_string()),
                content: None,
                status: None,
            }),
        });

        let mut tool_collector = ToolCallCollector::<u32>::new();
        let mut stop_reason = None;
        let responses = process_event(event, &mut tool_collector, &mut stop_reason);

        assert_eq!(responses.len(), 1);
        assert!(
            matches!(&responses[0], Ok(LlmResponse::EncryptedReasoning { content, .. }) if content == "enc-blob-data")
        );
    }

    // Cached-token and reasoning-token details survive the `usage.into()`
    // conversion into TokenUsage.
    #[tokio::test]
    async fn test_usage_forwards_reasoning_and_cache_read() {
        let events = vec![ResponseStreamEvent::ResponseCompleted(ResponseCompletedEvent {
            sequence_number: 1,
            response: make_response(&Status::Completed, Some(make_usage_full(120, 80, 50, 30))),
        })];

        let stream = make_stream(events);
        let mut response_stream = Box::pin(process_response_stream(stream));

        let mut responses = Vec::new();
        while let Some(result) = response_stream.next().await {
            responses.push(result.unwrap());
        }

        let usage = responses.iter().find_map(|r| match r {
            LlmResponse::Usage { tokens } => Some(*tokens),
            _ => None,
        });

        assert_eq!(
            usage,
            Some(TokenUsage {
                input_tokens: 120,
                output_tokens: 80,
                cache_read_tokens: Some(50),
                reasoning_tokens: Some(30),
                ..TokenUsage::default()
            })
        );
    }

    // A reasoning item without encrypted content produces no responses.
    #[test]
    fn test_output_item_done_without_encrypted_content_is_ignored() {
        let event = ResponseStreamEvent::ResponseOutputItemDone(ResponseOutputItemDoneEvent {
            sequence_number: 1,
            output_index: 0,
            item: OutputItem::Reasoning(ReasoningItem {
                id: Some("r_2".to_string()),
                summary: vec![],
                encrypted_content: None,
                content: None,
                status: None,
            }),
        });

        let mut tool_collector = ToolCallCollector::<u32>::new();
        let mut stop_reason = None;
        let responses = process_event(event, &mut tool_collector, &mut stop_reason);

        assert!(responses.is_empty());
    }
}
426}