openai_api_rs/v1/responses/
responses_stream.rs

1use super::responses::CreateResponseRequest;
2use futures_util::Stream;
3use serde_json::Value;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
7pub type CreateResponseStreamRequest = CreateResponseRequest;
8
9#[derive(Debug, Clone)]
10pub struct ResponseStreamEvent {
11    pub event: Option<String>,
12    pub data: Value,
13}
14
15#[derive(Debug, Clone)]
16pub enum ResponseStreamResponse {
17    Event(ResponseStreamEvent),
18    Done,
19}
20
/// Adapter that turns a byte stream into parsed Server-Sent Events.
///
/// The `Stream` bounds on `S` live on the `impl` blocks that need them rather than on
/// the struct itself, so code that merely names `ResponseStream<S>` generically does
/// not have to repeat them.
pub struct ResponseStream<S> {
    /// Underlying source of raw bytes (presumably the HTTP response body — the
    /// `Stream` impl expects `Result<bytes::Bytes, reqwest::Error>` items).
    pub response: S,
    /// Decoded text received so far that does not yet form a complete event.
    pub buffer: String,
    /// Cleared to `false` once a chunk has been received from `response`.
    pub first_chunk: bool,
}
26
27impl<S> ResponseStream<S>
28where
29    S: Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Unpin,
30{
31    fn find_event_delimiter(buffer: &str) -> Option<(usize, usize)> {
32        let carriage_idx = buffer.find("\r\n\r\n");
33        let newline_idx = buffer.find("\n\n");
34
35        match (carriage_idx, newline_idx) {
36            (Some(r_idx), Some(n_idx)) => {
37                if r_idx <= n_idx {
38                    Some((r_idx, 4))
39                } else {
40                    Some((n_idx, 2))
41                }
42            }
43            (Some(r_idx), None) => Some((r_idx, 4)),
44            (None, Some(n_idx)) => Some((n_idx, 2)),
45            (None, None) => None,
46        }
47    }
48
49    fn next_response_from_buffer(&mut self) -> Option<ResponseStreamResponse> {
50        while let Some((idx, delimiter_len)) = Self::find_event_delimiter(&self.buffer) {
51            let event_block = self.buffer[..idx].to_owned();
52            self.buffer = self.buffer[idx + delimiter_len..].to_owned();
53
54            let mut event_name = None;
55            let mut data_payload = String::new();
56
57            for line in event_block.lines() {
58                let trimmed_line = line.trim_end_matches('\r');
59
60                if let Some(event) = trimmed_line
61                    .strip_prefix("event: ")
62                    .or_else(|| trimmed_line.strip_prefix("event:"))
63                {
64                    let name = event.trim();
65                    if !name.is_empty() {
66                        event_name = Some(name.to_string());
67                    }
68                } else if let Some(content) = trimmed_line
69                    .strip_prefix("data: ")
70                    .or_else(|| trimmed_line.strip_prefix("data:"))
71                {
72                    if !content.is_empty() {
73                        if !data_payload.is_empty() {
74                            data_payload.push('\n');
75                        }
76                        data_payload.push_str(content);
77                    }
78                }
79            }
80
81            if data_payload.is_empty() {
82                continue;
83            }
84
85            if data_payload.trim() == "[DONE]" {
86                return Some(ResponseStreamResponse::Done);
87            }
88
89            let parsed = serde_json::from_str::<Value>(&data_payload)
90                .unwrap_or_else(|_| Value::String(data_payload.clone()));
91
92            return Some(ResponseStreamResponse::Event(ResponseStreamEvent {
93                event: event_name,
94                data: parsed,
95            }));
96        }
97
98        None
99    }
100}
101
102impl<S: Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Unpin> Stream for ResponseStream<S> {
103    type Item = ResponseStreamResponse;
104
105    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
106        loop {
107            if let Some(response) = self.next_response_from_buffer() {
108                return Poll::Ready(Some(response));
109            }
110
111            match Pin::new(&mut self.as_mut().response).poll_next(cx) {
112                Poll::Ready(Some(Ok(chunk))) => {
113                    let chunk_str = String::from_utf8_lossy(&chunk).to_string();
114                    if self.first_chunk {
115                        self.first_chunk = false;
116                    }
117                    self.buffer.push_str(&chunk_str);
118                }
119                Poll::Ready(Some(Err(error))) => {
120                    eprintln!("Error in stream: {:?}", error);
121                    return Poll::Ready(None);
122                }
123                Poll::Ready(None) => {
124                    return Poll::Ready(None);
125                }
126                Poll::Pending => {
127                    return Poll::Pending;
128                }
129            }
130        }
131    }
132}