1use async_openai::types::responses::{OutputItem, ResponseStreamEvent, Status};
2
3use crate::providers::tool_call_collector::ToolCallCollector;
4use crate::{LlmError, LlmResponse, Result, StopReason};
5use futures::Stream;
6use tokio_stream::StreamExt;
7
/// Adapts a raw OpenAI Responses-API event stream into a stream of
/// [`LlmResponse`] values.
///
/// The output always begins with `LlmResponse::Start` (carrying a freshly
/// generated message id) and always ends with `LlmResponse::Done` carrying
/// the last stop reason observed from a `ResponseCompleted` event (or `None`
/// if none was seen). Tool calls still being accumulated when the upstream
/// ends are flushed as `ToolRequestComplete` items before `Done`.
pub fn process_response_stream<T>(stream: T) -> impl Stream<Item = Result<LlmResponse>> + Send
where
    T: Stream<Item = Result<ResponseStreamEvent>> + Send + Unpin,
{
    async_stream::stream! {
        // Fresh id for this logical assistant message.
        let message_id = uuid::Uuid::new_v4().to_string();
        yield Ok(LlmResponse::Start { message_id });

        // Accumulates streamed function-call fragments, keyed by output index.
        let mut tool_collector = ToolCallCollector::<u32>::new();
        let mut stream = Box::pin(stream);
        // Updated by process_event on ResponseCompleted; forwarded in Done.
        let mut last_stop_reason: Option<StopReason> = None;

        while let Some(result) = stream.next().await {
            match result {
                Ok(event) => {
                    for response in process_event(event, &mut tool_collector, &mut last_stop_reason) {
                        yield response;
                    }
                }
                Err(e) => {
                    // Surface the transport error and stop consuming; the
                    // trailing flush and Done below still run.
                    yield Err(LlmError::ApiError(e.to_string()));
                    break;
                }
            }
        }

        // Flush any tool calls that never received an ArgumentsDone event.
        for tc in tool_collector.complete_all() {
            yield Ok(LlmResponse::ToolRequestComplete { tool_call: tc });
        }

        yield Ok(LlmResponse::Done {
            stop_reason: last_stop_reason,
        });
    }
}
45
46fn process_event(
47 event: ResponseStreamEvent,
48 tool_collector: &mut ToolCallCollector<u32>,
49 last_stop_reason: &mut Option<StopReason>,
50) -> Vec<Result<LlmResponse>> {
51 let mut responses = Vec::new();
52
53 match event {
54 ResponseStreamEvent::ResponseOutputTextDelta(e) => {
55 if !e.delta.is_empty() {
56 responses.push(Ok(LlmResponse::Text { chunk: e.delta }));
57 }
58 }
59 ResponseStreamEvent::ResponseOutputItemAdded(e) => {
60 if let OutputItem::FunctionCall(call) = e.item {
61 let tool_responses = tool_collector.handle_delta(e.output_index, call.id, Some(call.name), None);
62 responses.extend(tool_responses.into_iter().map(Ok));
63 }
64 }
65 ResponseStreamEvent::ResponseFunctionCallArgumentsDelta(e) => {
66 let tool_responses = tool_collector.handle_delta(e.output_index, None, None, Some(e.delta));
67 responses.extend(tool_responses.into_iter().map(Ok));
68 }
69 ResponseStreamEvent::ResponseFunctionCallArgumentsDone(e) => {
70 if let Some(tc) = tool_collector.complete_one(e.output_index) {
71 responses.push(Ok(LlmResponse::ToolRequestComplete { tool_call: tc }));
72 }
73 }
74 ResponseStreamEvent::ResponseReasoningSummaryTextDelta(e) => {
75 if !e.delta.is_empty() {
76 responses.push(Ok(LlmResponse::Reasoning { chunk: e.delta }));
77 }
78 }
79 ResponseStreamEvent::ResponseOutputItemDone(e) => {
80 if let OutputItem::Reasoning(reasoning) = e.item
81 && let Some(encrypted) = reasoning.encrypted_content
82 {
83 responses.push(Ok(LlmResponse::EncryptedReasoning { id: reasoning.id, content: encrypted }));
84 }
85 }
86 ResponseStreamEvent::ResponseCompleted(e) => {
87 if let Some(usage) = e.response.usage {
88 responses.push(Ok(LlmResponse::Usage { tokens: usage.into() }));
89 }
90 match e.response.status {
91 Status::Completed => *last_stop_reason = Some(StopReason::EndTurn),
92 Status::Incomplete => *last_stop_reason = Some(StopReason::Length),
93 _ => {}
94 }
95 }
96 ResponseStreamEvent::ResponseError(e) => {
97 responses.push(Err(LlmError::ApiError(format!("Codex API error: {}", e.message))));
98 }
99 _ => {}
101 }
102
103 responses
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TokenUsage;
    use async_openai::types::responses::{
        FunctionToolCall, ReasoningItem, Response, ResponseCompletedEvent, ResponseErrorEvent,
        ResponseFunctionCallArgumentsDeltaEvent, ResponseFunctionCallArgumentsDoneEvent, ResponseOutputItemAddedEvent,
        ResponseOutputItemDoneEvent, ResponseReasoningSummaryTextDeltaEvent, ResponseTextDeltaEvent, ResponseUsage,
    };

    /// Builds a minimal `Response` via a JSON round-trip so the tests don't
    /// depend on the struct's full field list.
    fn make_response(status: &Status, usage: Option<ResponseUsage>) -> Response {
        let status_value = serde_json::to_value(status).unwrap();
        let mut json = serde_json::json!({
            "id": "resp_1",
            "object": "response",
            "status": status_value,
            "output": [],
            "model": "test",
            "created_at": 0
        });
        if let Some(u) = usage {
            json["usage"] = serde_json::to_value(u).unwrap();
        }
        serde_json::from_value(json).unwrap()
    }

    /// Usage with no cached or reasoning tokens.
    fn make_usage(input_tokens: u32, output_tokens: u32) -> ResponseUsage {
        make_usage_full(input_tokens, output_tokens, 0, 0)
    }

    /// Usage including cached-input and reasoning token details.
    fn make_usage_full(
        input_tokens: u32,
        output_tokens: u32,
        cached_tokens: u32,
        reasoning_tokens: u32,
    ) -> ResponseUsage {
        serde_json::from_value(serde_json::json!({
            "input_tokens": input_tokens,
            "input_tokens_details": { "cached_tokens": cached_tokens },
            "output_tokens": output_tokens,
            "output_tokens_details": { "reasoning_tokens": reasoning_tokens },
            "total_tokens": input_tokens + output_tokens
        }))
        .unwrap()
    }

    /// Wraps a list of events as an all-`Ok` input stream.
    fn make_stream(events: Vec<ResponseStreamEvent>) -> impl Stream<Item = Result<ResponseStreamEvent>> + Send + Unpin {
        tokio_stream::iter(events.into_iter().map(Ok).collect::<Vec<_>>())
    }

    #[tokio::test]
    async fn test_text_stream() {
        let events = vec![
            ResponseStreamEvent::ResponseOutputTextDelta(ResponseTextDeltaEvent {
                output_index: 0,
                content_index: 0,
                delta: "Hello".to_string(),
                sequence_number: 1,
                item_id: "msg_1".to_string(),
                logprobs: None,
            }),
            ResponseStreamEvent::ResponseOutputTextDelta(ResponseTextDeltaEvent {
                output_index: 0,
                content_index: 0,
                delta: " world".to_string(),
                sequence_number: 2,
                item_id: "msg_1".to_string(),
                logprobs: None,
            }),
            ResponseStreamEvent::ResponseCompleted(ResponseCompletedEvent {
                sequence_number: 3,
                response: make_response(&Status::Completed, Some(make_usage(10, 5))),
            }),
        ];

        let stream = make_stream(events);
        let mut response_stream = Box::pin(process_response_stream(stream));

        let mut responses = Vec::new();
        while let Some(result) = response_stream.next().await {
            responses.push(result.unwrap());
        }

        assert!(matches!(responses[0], LlmResponse::Start { .. }));
        assert!(matches!(responses[1], LlmResponse::Text { ref chunk } if chunk == "Hello"));
        assert!(matches!(responses[2], LlmResponse::Text { ref chunk } if chunk == " world"));
        assert!(matches!(
            responses[3],
            LlmResponse::Usage { tokens: TokenUsage { input_tokens: 10, output_tokens: 5, .. } }
        ));
        assert!(matches!(responses[4], LlmResponse::Done { stop_reason: Some(StopReason::EndTurn) }));
    }

    #[tokio::test]
    async fn test_tool_call_stream() {
        let events = vec![
            ResponseStreamEvent::ResponseOutputItemAdded(ResponseOutputItemAddedEvent {
                sequence_number: 1,
                output_index: 0,
                item: OutputItem::FunctionCall(FunctionToolCall {
                    id: Some("fc_1".to_string()),
                    call_id: "call_1".to_string(),
                    name: "read_file".to_string(),
                    arguments: String::new(),
                    status: None,
                    namespace: None,
                }),
            }),
            ResponseStreamEvent::ResponseFunctionCallArgumentsDelta(ResponseFunctionCallArgumentsDeltaEvent {
                sequence_number: 2,
                item_id: "fc_1".to_string(),
                output_index: 0,
                delta: r#"{"path":"#.to_string(),
            }),
            ResponseStreamEvent::ResponseFunctionCallArgumentsDelta(ResponseFunctionCallArgumentsDeltaEvent {
                sequence_number: 3,
                item_id: "fc_1".to_string(),
                output_index: 0,
                delta: r#""foo.rs"}"#.to_string(),
            }),
            ResponseStreamEvent::ResponseFunctionCallArgumentsDone(ResponseFunctionCallArgumentsDoneEvent {
                sequence_number: 4,
                item_id: "fc_1".to_string(),
                output_index: 0,
                arguments: r#"{"path":"foo.rs"}"#.to_string(),
                name: None,
            }),
            ResponseStreamEvent::ResponseCompleted(ResponseCompletedEvent {
                sequence_number: 5,
                response: make_response(&Status::Completed, Some(make_usage(20, 10))),
            }),
        ];

        let stream = make_stream(events);
        let mut response_stream = Box::pin(process_response_stream(stream));

        let mut responses = Vec::new();
        while let Some(result) = response_stream.next().await {
            responses.push(result.unwrap());
        }

        assert!(matches!(responses[0], LlmResponse::Start { .. }));
        assert!(
            matches!(&responses[1], LlmResponse::ToolRequestStart { id, name } if id == "fc_1" && name == "read_file")
        );
        assert!(matches!(responses[2], LlmResponse::ToolRequestArg { .. }));
        assert!(matches!(responses[3], LlmResponse::ToolRequestArg { .. }));

        let tc = responses.iter().find(|r| matches!(r, LlmResponse::ToolRequestComplete { .. }));
        assert!(tc.is_some());
        if let LlmResponse::ToolRequestComplete { tool_call } = tc.unwrap() {
            assert_eq!(tool_call.id, "fc_1");
            assert_eq!(tool_call.name, "read_file");
            assert_eq!(tool_call.arguments, r#"{"path":"foo.rs"}"#);
        }
    }

    #[tokio::test]
    async fn test_error_event() {
        let events = vec![ResponseStreamEvent::ResponseError(ResponseErrorEvent {
            sequence_number: 1,
            code: None,
            message: "Rate limit exceeded".to_string(),
            param: None,
        })];

        let stream = make_stream(events);
        let mut response_stream = Box::pin(process_response_stream(stream));

        let mut responses = Vec::new();
        while let Some(result) = response_stream.next().await {
            responses.push(result);
        }

        // Start is always emitted first; the server error follows as an Err.
        assert!(responses[0].is_ok());
        assert!(responses[1].is_err());
    }

    #[tokio::test]
    async fn test_reasoning_delta() {
        let events = vec![
            ResponseStreamEvent::ResponseReasoningSummaryTextDelta(ResponseReasoningSummaryTextDeltaEvent {
                sequence_number: 1,
                item_id: "r_1".to_string(),
                output_index: 0,
                summary_index: 0,
                delta: "Thinking about".to_string(),
            }),
            ResponseStreamEvent::ResponseReasoningSummaryTextDelta(ResponseReasoningSummaryTextDeltaEvent {
                sequence_number: 2,
                item_id: "r_1".to_string(),
                output_index: 0,
                summary_index: 0,
                delta: " the problem".to_string(),
            }),
            ResponseStreamEvent::ResponseCompleted(ResponseCompletedEvent {
                sequence_number: 3,
                response: make_response(&Status::Completed, None),
            }),
        ];

        let stream = make_stream(events);
        let mut response_stream = Box::pin(process_response_stream(stream));

        let mut responses = Vec::new();
        while let Some(result) = response_stream.next().await {
            responses.push(result.unwrap());
        }

        assert!(matches!(responses[1], LlmResponse::Reasoning { ref chunk } if chunk == "Thinking about"));
        assert!(matches!(responses[2], LlmResponse::Reasoning { ref chunk } if chunk == " the problem"));
    }

    #[tokio::test]
    async fn test_incomplete_status_gives_length_stop_reason() {
        let events = vec![ResponseStreamEvent::ResponseCompleted(ResponseCompletedEvent {
            sequence_number: 1,
            response: make_response(&Status::Incomplete, None),
        })];

        let stream = make_stream(events);
        let mut response_stream = Box::pin(process_response_stream(stream));

        let mut responses = Vec::new();
        while let Some(result) = response_stream.next().await {
            responses.push(result.unwrap());
        }

        assert!(matches!(responses.last().unwrap(), LlmResponse::Done { stop_reason: Some(StopReason::Length) }));
    }

    #[tokio::test]
    async fn test_stream_error_propagation() {
        let events: Vec<Result<ResponseStreamEvent>> = vec![Err(LlmError::ApiError("connection lost".to_string()))];

        let stream = tokio_stream::iter(events);
        let mut response_stream = Box::pin(process_response_stream(stream));

        let mut responses = Vec::new();
        while let Some(result) = response_stream.next().await {
            responses.push(result);
        }

        // Start first, then the transport error is surfaced as an Err.
        assert!(responses[0].is_ok());
        assert!(responses[1].is_err());
    }

    #[test]
    fn test_encrypted_reasoning_from_output_item_done() {
        let event = ResponseStreamEvent::ResponseOutputItemDone(ResponseOutputItemDoneEvent {
            sequence_number: 1,
            output_index: 0,
            item: OutputItem::Reasoning(ReasoningItem {
                id: "r_1".to_string(),
                summary: vec![],
                encrypted_content: Some("enc-blob-data".to_string()),
                content: None,
                status: None,
            }),
        });

        let mut tool_collector = ToolCallCollector::<u32>::new();
        let mut stop_reason = None;
        let responses = process_event(event, &mut tool_collector, &mut stop_reason);

        assert_eq!(responses.len(), 1);
        assert!(
            matches!(&responses[0], Ok(LlmResponse::EncryptedReasoning { content, .. }) if content == "enc-blob-data")
        );
    }

    #[tokio::test]
    async fn test_usage_forwards_reasoning_and_cache_read() {
        let events = vec![ResponseStreamEvent::ResponseCompleted(ResponseCompletedEvent {
            sequence_number: 1,
            response: make_response(&Status::Completed, Some(make_usage_full(120, 80, 50, 30))),
        })];

        let stream = make_stream(events);
        let mut response_stream = Box::pin(process_response_stream(stream));

        let mut responses = Vec::new();
        while let Some(result) = response_stream.next().await {
            responses.push(result.unwrap());
        }

        let usage = responses.iter().find_map(|r| match r {
            LlmResponse::Usage { tokens } => Some(*tokens),
            _ => None,
        });

        assert_eq!(
            usage,
            Some(TokenUsage {
                input_tokens: 120,
                output_tokens: 80,
                cache_read_tokens: Some(50),
                reasoning_tokens: Some(30),
                ..TokenUsage::default()
            })
        );
    }

    #[test]
    fn test_output_item_done_without_encrypted_content_is_ignored() {
        let event = ResponseStreamEvent::ResponseOutputItemDone(ResponseOutputItemDoneEvent {
            sequence_number: 1,
            output_index: 0,
            item: OutputItem::Reasoning(ReasoningItem {
                id: "r_2".to_string(),
                summary: vec![],
                encrypted_content: None,
                content: None,
                status: None,
            }),
        });

        let mut tool_collector = ToolCallCollector::<u32>::new();
        let mut stop_reason = None;
        let responses = process_event(event, &mut tool_collector, &mut stop_reason);

        assert!(responses.is_empty());
    }
}