1use async_openai::types::responses::{OutputItem, ResponseStreamEvent, Status};
2
3use crate::providers::tool_call_collector::ToolCallCollector;
4use crate::{LlmError, LlmResponse, Result, StopReason};
5use futures::Stream;
6use tokio_stream::StreamExt;
7
8pub fn process_response_stream<T>(stream: T) -> impl Stream<Item = Result<LlmResponse>> + Send
10where
11 T: Stream<Item = Result<ResponseStreamEvent>> + Send + Unpin,
12{
13 async_stream::stream! {
14 let message_id = uuid::Uuid::new_v4().to_string();
15 yield Ok(LlmResponse::Start { message_id });
16
17 let mut tool_collector = ToolCallCollector::<u32>::new();
18 let mut stream = Box::pin(stream);
19 let mut last_stop_reason: Option<StopReason> = None;
20
21 while let Some(result) = stream.next().await {
22 match result {
23 Ok(event) => {
24 for response in process_event(event, &mut tool_collector, &mut last_stop_reason) {
25 yield response;
26 }
27 }
28 Err(e) => {
29 yield Err(LlmError::ApiError(e.to_string()));
30 break;
31 }
32 }
33 }
34
35 for tc in tool_collector.complete_all() {
37 yield Ok(LlmResponse::ToolRequestComplete { tool_call: tc });
38 }
39
40 yield Ok(LlmResponse::Done {
41 stop_reason: last_stop_reason,
42 });
43 }
44}
45
46fn process_event(
47 event: ResponseStreamEvent,
48 tool_collector: &mut ToolCallCollector<u32>,
49 last_stop_reason: &mut Option<StopReason>,
50) -> Vec<Result<LlmResponse>> {
51 let mut responses = Vec::new();
52
53 match event {
54 ResponseStreamEvent::ResponseOutputTextDelta(e) if !e.delta.is_empty() => {
55 responses.push(Ok(LlmResponse::Text { chunk: e.delta }));
56 }
57 ResponseStreamEvent::ResponseOutputItemAdded(e) => {
58 if let OutputItem::FunctionCall(call) = e.item {
59 let tool_responses = tool_collector.handle_delta(e.output_index, call.id, Some(call.name), None);
60 responses.extend(tool_responses.into_iter().map(Ok));
61 }
62 }
63 ResponseStreamEvent::ResponseFunctionCallArgumentsDelta(e) => {
64 let tool_responses = tool_collector.handle_delta(e.output_index, None, None, Some(e.delta));
65 responses.extend(tool_responses.into_iter().map(Ok));
66 }
67 ResponseStreamEvent::ResponseFunctionCallArgumentsDone(e) => {
68 if let Some(tc) = tool_collector.complete_one(e.output_index) {
69 responses.push(Ok(LlmResponse::ToolRequestComplete { tool_call: tc }));
70 }
71 }
72 ResponseStreamEvent::ResponseReasoningSummaryTextDelta(e) if !e.delta.is_empty() => {
73 responses.push(Ok(LlmResponse::Reasoning { chunk: e.delta }));
74 }
75 ResponseStreamEvent::ResponseOutputItemDone(e) => {
76 if let OutputItem::Reasoning(reasoning) = e.item
77 && let Some(encrypted) = reasoning.encrypted_content
78 {
79 responses.push(Ok(LlmResponse::EncryptedReasoning { id: reasoning.id, content: encrypted }));
80 }
81 }
82 ResponseStreamEvent::ResponseCompleted(e) => {
83 if let Some(usage) = e.response.usage {
84 responses.push(Ok(LlmResponse::Usage { tokens: usage.into() }));
85 }
86 match e.response.status {
87 Status::Completed => *last_stop_reason = Some(StopReason::EndTurn),
88 Status::Incomplete => *last_stop_reason = Some(StopReason::Length),
89 _ => {}
90 }
91 }
92 ResponseStreamEvent::ResponseError(e) => {
93 responses.push(Err(LlmError::ApiError(format!("Codex API error: {}", e.message))));
94 }
95 _ => {}
97 }
98
99 responses
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TokenUsage;
    use async_openai::types::responses::{
        FunctionToolCall, ReasoningItem, Response, ResponseCompletedEvent, ResponseErrorEvent,
        ResponseFunctionCallArgumentsDeltaEvent, ResponseFunctionCallArgumentsDoneEvent, ResponseOutputItemAddedEvent,
        ResponseOutputItemDoneEvent, ResponseReasoningSummaryTextDeltaEvent, ResponseTextDeltaEvent, ResponseUsage,
    };

    /// Builds a minimal `Response` through JSON so the tests don't depend on
    /// constructing the struct's many fields directly.
    fn make_response(status: &Status, usage: Option<ResponseUsage>) -> Response {
        let status_str = serde_json::to_value(status).unwrap();
        let mut json = serde_json::json!({
            "id": "resp_1",
            "object": "response",
            "status": status_str,
            "output": [],
            "model": "test",
            "created_at": 0
        });
        if let Some(u) = usage {
            json["usage"] = serde_json::to_value(u).unwrap();
        }
        serde_json::from_value(json).unwrap()
    }

    /// Usage with zero cached/reasoning tokens.
    fn make_usage(input_tokens: u32, output_tokens: u32) -> ResponseUsage {
        make_usage_full(input_tokens, output_tokens, 0, 0)
    }

    /// Usage with explicit cached and reasoning token detail counts.
    fn make_usage_full(
        input_tokens: u32,
        output_tokens: u32,
        cached_tokens: u32,
        reasoning_tokens: u32,
    ) -> ResponseUsage {
        serde_json::from_value(serde_json::json!({
            "input_tokens": input_tokens,
            "input_tokens_details": { "cached_tokens": cached_tokens },
            "output_tokens": output_tokens,
            "output_tokens_details": { "reasoning_tokens": reasoning_tokens },
            "total_tokens": input_tokens + output_tokens
        }))
        .unwrap()
    }

    /// Wraps a list of events in an all-Ok stream.
    fn make_stream(events: Vec<ResponseStreamEvent>) -> impl Stream<Item = Result<ResponseStreamEvent>> + Send + Unpin {
        tokio_stream::iter(events.into_iter().map(Ok).collect::<Vec<_>>())
    }

    #[tokio::test]
    async fn test_text_stream() {
        let events = vec![
            ResponseStreamEvent::ResponseOutputTextDelta(ResponseTextDeltaEvent {
                output_index: 0,
                content_index: 0,
                delta: "Hello".to_string(),
                sequence_number: 1,
                item_id: "msg_1".to_string(),
                logprobs: None,
            }),
            ResponseStreamEvent::ResponseOutputTextDelta(ResponseTextDeltaEvent {
                output_index: 0,
                content_index: 0,
                delta: " world".to_string(),
                sequence_number: 2,
                item_id: "msg_1".to_string(),
                logprobs: None,
            }),
            ResponseStreamEvent::ResponseCompleted(ResponseCompletedEvent {
                sequence_number: 3,
                response: make_response(&Status::Completed, Some(make_usage(10, 5))),
            }),
        ];

        let stream = make_stream(events);
        let mut response_stream = Box::pin(process_response_stream(stream));

        let mut responses = Vec::new();
        while let Some(result) = response_stream.next().await {
            responses.push(result.unwrap());
        }

        assert!(matches!(responses[0], LlmResponse::Start { .. }));
        assert!(matches!(responses[1], LlmResponse::Text { ref chunk } if chunk == "Hello"));
        assert!(matches!(responses[2], LlmResponse::Text { ref chunk } if chunk == " world"));
        assert!(matches!(
            responses[3],
            LlmResponse::Usage { tokens: TokenUsage { input_tokens: 10, output_tokens: 5, .. } }
        ));
        assert!(matches!(responses[4], LlmResponse::Done { stop_reason: Some(StopReason::EndTurn) }));
    }

    #[tokio::test]
    async fn test_tool_call_stream() {
        let events = vec![
            ResponseStreamEvent::ResponseOutputItemAdded(ResponseOutputItemAddedEvent {
                sequence_number: 1,
                output_index: 0,
                item: OutputItem::FunctionCall(FunctionToolCall {
                    id: Some("fc_1".to_string()),
                    call_id: "call_1".to_string(),
                    name: "read_file".to_string(),
                    arguments: String::new(),
                    status: None,
                    namespace: None,
                }),
            }),
            ResponseStreamEvent::ResponseFunctionCallArgumentsDelta(ResponseFunctionCallArgumentsDeltaEvent {
                sequence_number: 2,
                item_id: "fc_1".to_string(),
                output_index: 0,
                delta: r#"{"path":"#.to_string(),
            }),
            ResponseStreamEvent::ResponseFunctionCallArgumentsDelta(ResponseFunctionCallArgumentsDeltaEvent {
                sequence_number: 3,
                item_id: "fc_1".to_string(),
                output_index: 0,
                delta: r#""foo.rs"}"#.to_string(),
            }),
            ResponseStreamEvent::ResponseFunctionCallArgumentsDone(ResponseFunctionCallArgumentsDoneEvent {
                sequence_number: 4,
                item_id: "fc_1".to_string(),
                output_index: 0,
                arguments: r#"{"path":"foo.rs"}"#.to_string(),
                name: None,
            }),
            ResponseStreamEvent::ResponseCompleted(ResponseCompletedEvent {
                sequence_number: 5,
                response: make_response(&Status::Completed, Some(make_usage(20, 10))),
            }),
        ];

        let stream = make_stream(events);
        let mut response_stream = Box::pin(process_response_stream(stream));

        let mut responses = Vec::new();
        while let Some(result) = response_stream.next().await {
            responses.push(result.unwrap());
        }

        assert!(matches!(responses[0], LlmResponse::Start { .. }));
        assert!(
            matches!(&responses[1], LlmResponse::ToolRequestStart { id, name } if id == "fc_1" && name == "read_file")
        );
        assert!(matches!(responses[2], LlmResponse::ToolRequestArg { .. }));
        assert!(matches!(responses[3], LlmResponse::ToolRequestArg { .. }));

        let tc = responses.iter().find(|r| matches!(r, LlmResponse::ToolRequestComplete { .. }));
        assert!(tc.is_some());
        if let LlmResponse::ToolRequestComplete { tool_call } = tc.unwrap() {
            assert_eq!(tool_call.id, "fc_1");
            assert_eq!(tool_call.name, "read_file");
            assert_eq!(tool_call.arguments, r#"{"path":"foo.rs"}"#);
        }
    }

    #[tokio::test]
    async fn test_error_event() {
        let events = vec![ResponseStreamEvent::ResponseError(ResponseErrorEvent {
            sequence_number: 1,
            code: None,
            message: "Rate limit exceeded".to_string(),
            param: None,
        })];

        let stream = make_stream(events);
        let mut response_stream = Box::pin(process_response_stream(stream));

        let mut responses = Vec::new();
        while let Some(result) = response_stream.next().await {
            responses.push(result);
        }

        assert!(responses[0].is_ok());
        assert!(responses[1].is_err());
    }

    #[tokio::test]
    async fn test_reasoning_delta() {
        let events = vec![
            ResponseStreamEvent::ResponseReasoningSummaryTextDelta(ResponseReasoningSummaryTextDeltaEvent {
                sequence_number: 1,
                item_id: "r_1".to_string(),
                output_index: 0,
                summary_index: 0,
                delta: "Thinking about".to_string(),
            }),
            ResponseStreamEvent::ResponseReasoningSummaryTextDelta(ResponseReasoningSummaryTextDeltaEvent {
                sequence_number: 2,
                item_id: "r_1".to_string(),
                output_index: 0,
                summary_index: 0,
                delta: " the problem".to_string(),
            }),
            ResponseStreamEvent::ResponseCompleted(ResponseCompletedEvent {
                sequence_number: 3,
                response: make_response(&Status::Completed, None),
            }),
        ];

        let stream = make_stream(events);
        let mut response_stream = Box::pin(process_response_stream(stream));

        let mut responses = Vec::new();
        while let Some(result) = response_stream.next().await {
            responses.push(result.unwrap());
        }

        assert!(matches!(responses[1], LlmResponse::Reasoning { ref chunk } if chunk == "Thinking about"));
        assert!(matches!(responses[2], LlmResponse::Reasoning { ref chunk } if chunk == " the problem"));
    }

    #[tokio::test]
    async fn test_incomplete_status_gives_length_stop_reason() {
        let events = vec![ResponseStreamEvent::ResponseCompleted(ResponseCompletedEvent {
            sequence_number: 1,
            response: make_response(&Status::Incomplete, None),
        })];

        let stream = make_stream(events);
        let mut response_stream = Box::pin(process_response_stream(stream));

        let mut responses = Vec::new();
        while let Some(result) = response_stream.next().await {
            responses.push(result.unwrap());
        }

        assert!(matches!(responses.last().unwrap(), LlmResponse::Done { stop_reason: Some(StopReason::Length) }));
    }

    #[tokio::test]
    async fn test_stream_error_propagation() {
        let events: Vec<Result<ResponseStreamEvent>> = vec![Err(LlmError::ApiError("connection lost".to_string()))];

        let stream = tokio_stream::iter(events);
        let mut response_stream = Box::pin(process_response_stream(stream));

        let mut responses = Vec::new();
        while let Some(result) = response_stream.next().await {
            responses.push(result);
        }

        assert!(responses[0].is_ok());
        assert!(responses[1].is_err());
    }

    #[test]
    fn test_encrypted_reasoning_from_output_item_done() {
        let event = ResponseStreamEvent::ResponseOutputItemDone(ResponseOutputItemDoneEvent {
            sequence_number: 1,
            output_index: 0,
            item: OutputItem::Reasoning(ReasoningItem {
                id: "r_1".to_string(),
                summary: vec![],
                encrypted_content: Some("enc-blob-data".to_string()),
                content: None,
                status: None,
            }),
        });

        let mut tool_collector = ToolCallCollector::<u32>::new();
        let mut stop_reason = None;
        let responses = process_event(event, &mut tool_collector, &mut stop_reason);

        assert_eq!(responses.len(), 1);
        assert!(
            matches!(&responses[0], Ok(LlmResponse::EncryptedReasoning { content, .. }) if content == "enc-blob-data")
        );
    }

    #[tokio::test]
    async fn test_usage_forwards_reasoning_and_cache_read() {
        let events = vec![ResponseStreamEvent::ResponseCompleted(ResponseCompletedEvent {
            sequence_number: 1,
            response: make_response(&Status::Completed, Some(make_usage_full(120, 80, 50, 30))),
        })];

        let stream = make_stream(events);
        let mut response_stream = Box::pin(process_response_stream(stream));

        let mut responses = Vec::new();
        while let Some(result) = response_stream.next().await {
            responses.push(result.unwrap());
        }

        let usage = responses.iter().find_map(|r| match r {
            LlmResponse::Usage { tokens } => Some(*tokens),
            _ => None,
        });

        assert_eq!(
            usage,
            Some(TokenUsage {
                input_tokens: 120,
                output_tokens: 80,
                cache_read_tokens: Some(50),
                reasoning_tokens: Some(30),
                ..TokenUsage::default()
            })
        );
    }

    #[test]
    fn test_output_item_done_without_encrypted_content_is_ignored() {
        let event = ResponseStreamEvent::ResponseOutputItemDone(ResponseOutputItemDoneEvent {
            sequence_number: 1,
            output_index: 0,
            item: OutputItem::Reasoning(ReasoningItem {
                id: "r_2".to_string(),
                summary: vec![],
                encrypted_content: None,
                content: None,
                status: None,
            }),
        });

        let mut tool_collector = ToolCallCollector::<u32>::new();
        let mut stop_reason = None;
        let responses = process_event(event, &mut tool_collector, &mut stop_reason);

        assert!(responses.is_empty());
    }
}