1use bytes::Bytes;
42use futures_util::Stream;
43use pin_project_lite::pin_project;
44use std::pin::Pin;
45use std::task::{Context, Poll};
46
47use crate::error::{Error, Result};
48use crate::models::response::{Response, StreamChunk};
49use crate::models::tool::ToolCall;
50
pin_project! {
    /// Stream adapter that consumes a raw byte stream of server-sent events
    /// and yields parsed [`StreamChunk`] values.
    pub struct ResponseStream {
        // Underlying byte stream; `new` builds it by converting the source
        // stream's `reqwest::Error` items into the crate's `Error`.
        #[pin]
        inner: Pin<Box<dyn Stream<Item = Result<Bytes>> + Send>>,
        // Decoded text that has been received but not yet consumed as a
        // complete event.
        buffer: String,
        // Received bytes that have not been moved into `buffer` yet, e.g. the
        // leading bytes of a UTF-8 sequence that was split across two chunks.
        raw_buffer: Vec<u8>,
        // Set once a chunk with `done == true` was yielded or `inner` ended;
        // from then on polling only returns `None`.
        done: bool,
    }
}
61
62impl ResponseStream {
63 pub fn new<S>(stream: S) -> Self
65 where
66 S: Stream<Item = std::result::Result<Bytes, reqwest::Error>> + Send + 'static,
67 {
68 use futures_util::StreamExt;
69
70 let mapped = stream.map(|result| result.map_err(Error::from));
71 Self {
72 inner: Box::pin(mapped),
73 buffer: String::new(),
74 raw_buffer: Vec::new(),
75 done: false,
76 }
77 }
78
79 fn parse_event(data: &str) -> Result<Option<StreamChunk>> {
81 let data = data.trim();
82
83 if data.is_empty() || data.starts_with(':') {
85 return Ok(None);
86 }
87
88 if data == "[DONE]" {
90 return Ok(Some(StreamChunk {
91 delta: None,
92 reasoning_delta: None,
93 tool_calls: vec![],
94 done: true,
95 response: None,
96 }));
97 }
98
99 let event: StreamEvent = serde_json::from_str(data)?;
101
102 match event {
103 StreamEvent::ResponseDelta(delta) => {
104 let text_delta = delta.delta.and_then(|d| {
105 d.content.and_then(|parts| {
106 let text = parts
107 .into_iter()
108 .filter_map(|part| {
109 if let DeltaContentPart::Text { text } = part {
110 Some(text)
111 } else {
112 None
113 }
114 })
115 .collect::<String>();
116 if text.is_empty() {
117 None
118 } else {
119 Some(text)
120 }
121 })
122 });
123
124 Ok(Some(StreamChunk {
125 delta: text_delta,
126 reasoning_delta: None,
127 tool_calls: vec![],
128 done: false,
129 response: None,
130 }))
131 }
132 StreamEvent::ResponseDone(done) => Ok(Some(StreamChunk {
133 delta: None,
134 reasoning_delta: None,
135 tool_calls: vec![],
136 done: true,
137 response: Some(done.response),
138 })),
139 StreamEvent::ResponseToolCallDelta(delta) => {
140 let tool_call = delta.delta.map(|d| ToolCall {
141 id: delta.tool_call_id.unwrap_or_default(),
142 call_type: Some("function".to_string()),
143 function: d.function,
144 });
145
146 Ok(Some(StreamChunk {
147 delta: None,
148 reasoning_delta: None,
149 tool_calls: tool_call.into_iter().collect(),
150 done: false,
151 response: None,
152 }))
153 }
154 _ => Ok(None),
155 }
156 }
157
158 fn parse_sse_event(event_block: &str) -> Result<Option<StreamChunk>> {
160 let mut first_data_line: Option<&str> = None;
161 let mut merged_payload: Option<String> = None;
162
163 for line in event_block.lines() {
164 if line.is_empty() || line.starts_with(':') {
165 continue;
166 }
167
168 if line == "data" {
169 if let Some(payload) = merged_payload.as_mut() {
170 payload.push('\n');
171 } else if let Some(first) = first_data_line {
172 let mut payload = String::with_capacity(first.len() + 1);
173 payload.push_str(first);
174 payload.push('\n');
175 merged_payload = Some(payload);
176 } else {
177 first_data_line = Some("");
178 }
179 } else if let Some(data) = line.strip_prefix("data:") {
180 let payload_line = data.strip_prefix(' ').unwrap_or(data);
182 if let Some(payload) = merged_payload.as_mut() {
183 payload.push('\n');
184 payload.push_str(payload_line);
185 } else if let Some(first) = first_data_line {
186 let mut payload = String::with_capacity(first.len() + 1 + payload_line.len());
187 payload.push_str(first);
188 payload.push('\n');
189 payload.push_str(payload_line);
190 merged_payload = Some(payload);
191 } else {
192 first_data_line = Some(payload_line);
193 }
194 }
195 }
196
197 if let Some(payload) = merged_payload {
198 return Self::parse_event(&payload);
199 }
200
201 if let Some(payload) = first_data_line {
202 return Self::parse_event(payload);
203 }
204
205 Ok(None)
206 }
207
208 fn find_event_separator(buffer: &str) -> Option<(usize, usize)> {
209 let bytes = buffer.as_bytes();
210 let mut i = 0usize;
211 while i + 1 < bytes.len() {
212 if bytes[i] == b'\n' && bytes[i + 1] == b'\n' {
213 return Some((i, 2));
214 }
215
216 if bytes[i] == b'\r' && bytes[i + 1] == b'\r' {
217 return Some((i, 2));
218 }
219
220 if i + 3 < bytes.len()
221 && bytes[i] == b'\r'
222 && bytes[i + 1] == b'\n'
223 && bytes[i + 2] == b'\r'
224 && bytes[i + 3] == b'\n'
225 {
226 return Some((i, 4));
227 }
228
229 if i + 2 < bytes.len() {
230 if bytes[i] == b'\r' && bytes[i + 1] == b'\n' && bytes[i + 2] == b'\n' {
232 return Some((i, 3));
233 }
234 if bytes[i] == b'\n' && bytes[i + 1] == b'\r' && bytes[i + 2] == b'\n' {
235 return Some((i, 3));
236 }
237 }
238
239 i += 1;
240 }
241 None
242 }
243}
244
245impl Stream for ResponseStream {
246 type Item = Result<StreamChunk>;
247
248 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
249 let mut this = self.project();
250
251 if *this.done {
252 return Poll::Ready(None);
253 }
254
255 loop {
256 if let Some((pos, sep_len)) = Self::find_event_separator(this.buffer) {
258 let parsed = {
259 let event_str = &this.buffer[..pos];
260 Self::parse_sse_event(event_str)
261 };
262 this.buffer.drain(..pos + sep_len);
263
264 match parsed {
265 Ok(Some(chunk)) => {
266 if chunk.done {
267 *this.done = true;
268 }
269 return Poll::Ready(Some(Ok(chunk)));
270 }
271 Ok(None) => continue,
272 Err(e) => return Poll::Ready(Some(Err(e))),
273 }
274 }
275
276 match this.inner.as_mut().poll_next(cx) {
278 Poll::Ready(Some(Ok(bytes))) => {
279 this.raw_buffer.extend_from_slice(&bytes);
280 match std::str::from_utf8(this.raw_buffer) {
282 Ok(text) => {
283 this.buffer.push_str(text);
284 this.raw_buffer.clear();
285 }
286 Err(e) => {
287 let valid_up_to = e.valid_up_to();
288 if valid_up_to > 0 {
289 let valid = std::str::from_utf8(&this.raw_buffer[..valid_up_to])
291 .expect("valid_up_to guarantees valid UTF-8");
292 this.buffer.push_str(valid);
293 this.raw_buffer.drain(..valid_up_to);
294 }
295 }
297 }
298 }
299 Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
300 Poll::Ready(None) => {
301 *this.done = true;
302 return Poll::Ready(None);
303 }
304 Poll::Pending => return Poll::Pending,
305 }
306 }
307 }
308}
309
/// JSON event carried in an SSE `data` payload.
///
/// The JSON `type` field selects the variant. Only the first three variants
/// are turned into chunks by `ResponseStream::parse_event`; the rest are
/// recognised and then ignored.
#[derive(Debug, serde::Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum StreamEvent {
    /// `response.output_item.delta`: optional content parts whose text is
    /// surfaced as the chunk's `delta`.
    #[serde(rename = "response.output_item.delta")]
    ResponseDelta(ResponseDeltaEvent),
    /// `response.done`: carries the final `Response` and marks the chunk as
    /// done.
    #[serde(rename = "response.done")]
    ResponseDone(ResponseDoneEvent),
    /// `response.function_call_arguments.delta`: optional tool-call data
    /// surfaced through the chunk's `tool_calls`.
    #[serde(rename = "response.function_call_arguments.delta")]
    ResponseToolCallDelta(ToolCallDeltaEvent),
    /// `response.created`: no fields are read from it.
    #[serde(rename = "response.created")]
    ResponseCreated {},
    /// `response.output_item.added`: no fields are read from it.
    #[serde(rename = "response.output_item.added")]
    OutputItemAdded {},
    /// `response.output_item.done`: no fields are read from it.
    #[serde(rename = "response.output_item.done")]
    OutputItemDone {},
    /// `response.content_part.added`: no fields are read from it.
    #[serde(rename = "response.content_part.added")]
    ContentPartAdded {},
    /// `response.content_part.done`: no fields are read from it.
    #[serde(rename = "response.content_part.done")]
    ContentPartDone {},
    /// Fallback for any `type` value not listed above.
    #[serde(other)]
    Unknown,
}
342
/// Body of a `response.output_item.delta` event.
#[derive(Debug, serde::Deserialize)]
struct ResponseDeltaEvent {
    /// Incremental content; may be absent from the event.
    delta: Option<DeltaContent>,
}
347
/// The `delta` object of a `response.output_item.delta` event.
#[derive(Debug, serde::Deserialize)]
struct DeltaContent {
    /// Content parts in the order received; may be absent.
    content: Option<Vec<DeltaContentPart>>,
}
352
/// One element of a delta's `content` array, selected by its JSON `type`
/// field.
#[derive(Debug, serde::Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum DeltaContentPart {
    /// A `"type": "text"` part; its `text` is concatenated into the chunk's
    /// `delta` by `ResponseStream::parse_event`.
    Text {
        text: String,
    },
    /// Any other part type; skipped when the text delta is assembled.
    #[serde(other)]
    Other,
}
362
/// Body of a `response.done` event.
#[derive(Debug, serde::Deserialize)]
struct ResponseDoneEvent {
    /// The response object, forwarded unchanged in the final chunk's
    /// `response` field.
    response: Response,
}
367
/// Body of a `response.function_call_arguments.delta` event.
#[derive(Debug, serde::Deserialize)]
struct ToolCallDeltaEvent {
    /// Identifier of the tool call; `ResponseStream::parse_event` substitutes
    /// an empty string when it is missing.
    tool_call_id: Option<String>,
    /// Tool-call payload; when absent no tool call is emitted for the event.
    delta: Option<ToolCallDelta>,
}
373
/// The `delta` object of a `response.function_call_arguments.delta` event.
#[derive(Debug, serde::Deserialize)]
struct ToolCallDelta {
    /// Function-call data, copied as-is into the emitted `ToolCall`.
    function: Option<crate::models::tool::FunctionCall>,
}
378
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use futures_util::{stream, StreamExt};

    /// JSON of a delta event whose only content part is the text `Hello`.
    const HELLO_DELTA: &str = r#"{"type":"response.output_item.delta","delta":{"content":[{"type":"text","text":"Hello"}]}}"#;

    /// Runs `parse_event` and unwraps down to the chunk it produced.
    fn expect_chunk(data: &str) -> StreamChunk {
        ResponseStream::parse_event(data)
            .expect("payload should parse")
            .expect("payload should produce a chunk")
    }

    /// Runs `parse_sse_event` and unwraps down to the chunk it produced.
    fn expect_sse_chunk(event: &str) -> StreamChunk {
        ResponseStream::parse_sse_event(event)
            .expect("event should parse")
            .expect("event should produce a chunk")
    }

    /// Asserts that `parse_event` accepts `data` without producing a chunk.
    fn assert_ignored(data: &str) {
        let outcome = ResponseStream::parse_event(data).expect("payload should parse");
        assert!(outcome.is_none());
    }

    /// Feeds a `Hello` delta followed by `[DONE]` through the stream in one
    /// network chunk, terminating the two events with the given separators,
    /// and checks that exactly those two chunks come out.
    async fn assert_hello_then_done(after_delta: &str, after_done: &str) {
        let payload = format!(
            "data: {}{}data: [DONE]{}",
            HELLO_DELTA, after_delta, after_done
        );
        let source = stream::iter(vec![Ok::<Bytes, reqwest::Error>(Bytes::from(payload))]);
        let mut events = ResponseStream::new(source);

        let hello = events.next().await.unwrap().unwrap();
        assert_eq!(hello.delta.as_deref(), Some("Hello"));
        assert!(!hello.done);

        let last = events.next().await.unwrap().unwrap();
        assert!(last.done);

        assert!(events.next().await.is_none());
    }

    // --- parse_event: events that produce chunks ---------------------------

    #[test]
    fn parse_event_response_delta_text() {
        let parsed = expect_chunk(HELLO_DELTA);
        assert!(!parsed.done);
        assert_eq!(parsed.delta.as_deref(), Some("Hello"));
        assert!(parsed.response.is_none());
    }

    #[test]
    fn parse_event_response_delta_multiple_text_parts_concatenates() {
        let parsed = expect_chunk(
            r#"{"type":"response.output_item.delta","delta":{"content":[{"type":"text","text":"Hel"},{"type":"text","text":"lo"}]}}"#,
        );
        assert!(!parsed.done);
        assert_eq!(parsed.delta.as_deref(), Some("Hello"));
    }

    #[test]
    fn parse_event_response_delta_no_content() {
        let parsed = expect_chunk(r#"{"type":"response.output_item.delta","delta":{}}"#);
        assert!(!parsed.done);
        assert!(parsed.delta.is_none());
    }

    #[test]
    fn parse_event_response_done() {
        let parsed = expect_chunk(
            r#"{"type":"response.done","response":{"id":"resp_1","model":"grok-4","output":[{"type":"message","role":"assistant","content":[{"type":"text","text":"Done"}]}],"usage":{"prompt_tokens":5,"completion_tokens":10,"total_tokens":15}}}"#,
        );
        assert!(parsed.done);
        assert!(parsed.response.is_some());
        let final_response = parsed.response.unwrap();
        assert_eq!(final_response.id, "resp_1");
        assert_eq!(final_response.output_text().unwrap(), "Done");
    }

    #[test]
    fn parse_event_tool_call_delta() {
        let parsed = expect_chunk(
            r#"{"type":"response.function_call_arguments.delta","tool_call_id":"call_1","delta":{"function":{"name":"get_weather","arguments":"{\"city\":"}}}"#,
        );
        assert!(!parsed.done);
        assert_eq!(parsed.tool_calls.len(), 1);

        let call = &parsed.tool_calls[0];
        assert_eq!(call.id, "call_1");
        assert_eq!(call.call_type.as_deref(), Some("function"));
        assert_eq!(call.function.as_ref().unwrap().name, "get_weather");
    }

    // --- parse_event: the [DONE] sentinel -----------------------------------

    #[test]
    fn parse_event_done_marker() {
        let parsed = expect_chunk("[DONE]");
        assert!(parsed.done);
        assert!(parsed.delta.is_none());
        assert!(parsed.response.is_none());
        assert!(parsed.tool_calls.is_empty());
    }

    #[test]
    fn parse_event_done_marker_trims_whitespace() {
        assert!(expect_chunk(" [DONE]\n").done);
    }

    // --- parse_event: payloads that produce nothing -------------------------

    #[test]
    fn parse_event_empty_string() {
        assert_ignored("");
    }

    #[test]
    fn parse_event_comment_line() {
        assert_ignored(": this is a comment");
    }

    #[test]
    fn parse_event_comment_colon_only() {
        assert_ignored(":");
    }

    #[test]
    fn parse_event_unknown_type_returns_none() {
        assert_ignored(r#"{"type":"response.created"}"#);
    }

    #[test]
    fn parse_event_content_part_added_returns_none() {
        assert_ignored(r#"{"type":"response.content_part.added"}"#);
    }

    // --- parse_event: malformed payloads ------------------------------------

    #[test]
    fn parse_event_invalid_json() {
        assert!(ResponseStream::parse_event("{not valid json}").is_err());
    }

    #[test]
    fn parse_event_completely_broken() {
        assert!(ResponseStream::parse_event("just random text").is_err());
    }

    // --- parse_event: tool-call edge cases ----------------------------------

    #[test]
    fn parse_event_tool_call_delta_no_id() {
        let parsed = expect_chunk(
            r#"{"type":"response.function_call_arguments.delta","delta":{"function":{"name":"fn1","arguments":"{}"}}}"#,
        );
        assert_eq!(parsed.tool_calls.len(), 1);
        assert_eq!(parsed.tool_calls[0].id, "");
    }

    #[test]
    fn parse_event_tool_call_delta_no_delta() {
        let parsed = expect_chunk(
            r#"{"type":"response.function_call_arguments.delta","tool_call_id":"call_2"}"#,
        );
        assert!(parsed.tool_calls.is_empty());
    }

    // --- parse_sse_event ----------------------------------------------------

    #[test]
    fn parse_sse_event_data_without_space() {
        let parsed = expect_sse_chunk(
            r#"data:{"type":"response.output_item.delta","delta":{"content":[{"type":"text","text":"Hello"}]}}"#,
        );
        assert_eq!(parsed.delta.as_deref(), Some("Hello"));
    }

    #[test]
    fn parse_sse_event_multiline_data_concatenates_with_newline() {
        let block = concat!(
            "data: {\"type\":\"response.output_item.delta\",\n",
            "data: \"delta\":{\"content\":[{\"type\":\"text\",\"text\":\"Hello\"}]}}"
        );
        let parsed = expect_sse_chunk(block);
        assert_eq!(parsed.delta.as_deref(), Some("Hello"));
    }

    #[test]
    fn parse_sse_event_accepts_bare_data_line_before_done() {
        assert!(expect_sse_chunk("data\ndata: [DONE]").done);
    }

    // --- end-to-end stream behaviour ----------------------------------------

    #[tokio::test]
    async fn stream_handles_crlf_event_separators() {
        assert_hello_then_done("\r\n\r\n", "\r\n\r\n").await;
    }

    #[tokio::test]
    async fn stream_handles_cr_only_event_separators() {
        assert_hello_then_done("\r\r", "\r\r").await;
    }

    #[tokio::test]
    async fn stream_handles_mixed_event_separators() {
        assert_hello_then_done("\r\n\n", "\n\r\n").await;
    }
}