openai_api_rs/v1/responses/
responses_stream.rs1use super::responses::CreateResponseRequest;
2use futures_util::Stream;
3use serde_json::Value;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
7pub type CreateResponseStreamRequest = CreateResponseRequest;
8
9#[derive(Debug, Clone)]
10pub struct ResponseStreamEvent {
11 pub event: Option<String>,
12 pub data: Value,
13}
14
15#[derive(Debug, Clone)]
16pub enum ResponseStreamResponse {
17 Event(ResponseStreamEvent),
18 Done,
19}
20
/// Adapts a stream of raw response-body chunks into [`ResponseStreamResponse`] items.
///
/// The `Stream` implementation requires `S` to yield
/// `Result<bytes::Bytes, reqwest::Error>` and to be `Unpin`. The struct itself places
/// no bound on `S`, so the type can be named and constructed generically.
pub struct ResponseStream<S> {
    /// Underlying byte stream the events are read from.
    pub response: S,
    /// Decoded text received so far that has not yet been consumed as a complete event.
    pub buffer: String,
    /// Cleared once the first chunk has been received from `response`.
    pub first_chunk: bool,
}
26
27impl<S> ResponseStream<S>
28where
29 S: Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Unpin,
30{
31 fn find_event_delimiter(buffer: &str) -> Option<(usize, usize)> {
32 let carriage_idx = buffer.find("\r\n\r\n");
33 let newline_idx = buffer.find("\n\n");
34
35 match (carriage_idx, newline_idx) {
36 (Some(r_idx), Some(n_idx)) => {
37 if r_idx <= n_idx {
38 Some((r_idx, 4))
39 } else {
40 Some((n_idx, 2))
41 }
42 }
43 (Some(r_idx), None) => Some((r_idx, 4)),
44 (None, Some(n_idx)) => Some((n_idx, 2)),
45 (None, None) => None,
46 }
47 }
48
49 fn next_response_from_buffer(&mut self) -> Option<ResponseStreamResponse> {
50 while let Some((idx, delimiter_len)) = Self::find_event_delimiter(&self.buffer) {
51 let event_block = self.buffer[..idx].to_owned();
52 self.buffer = self.buffer[idx + delimiter_len..].to_owned();
53
54 let mut event_name = None;
55 let mut data_payload = String::new();
56
57 for line in event_block.lines() {
58 let trimmed_line = line.trim_end_matches('\r');
59
60 if let Some(event) = trimmed_line
61 .strip_prefix("event: ")
62 .or_else(|| trimmed_line.strip_prefix("event:"))
63 {
64 let name = event.trim();
65 if !name.is_empty() {
66 event_name = Some(name.to_string());
67 }
68 } else if let Some(content) = trimmed_line
69 .strip_prefix("data: ")
70 .or_else(|| trimmed_line.strip_prefix("data:"))
71 {
72 if !content.is_empty() {
73 if !data_payload.is_empty() {
74 data_payload.push('\n');
75 }
76 data_payload.push_str(content);
77 }
78 }
79 }
80
81 if data_payload.is_empty() {
82 continue;
83 }
84
85 if data_payload.trim() == "[DONE]" {
86 return Some(ResponseStreamResponse::Done);
87 }
88
89 let parsed = serde_json::from_str::<Value>(&data_payload)
90 .unwrap_or_else(|_| Value::String(data_payload.clone()));
91
92 return Some(ResponseStreamResponse::Event(ResponseStreamEvent {
93 event: event_name,
94 data: parsed,
95 }));
96 }
97
98 None
99 }
100}
101
102impl<S: Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Unpin> Stream for ResponseStream<S> {
103 type Item = ResponseStreamResponse;
104
105 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
106 loop {
107 if let Some(response) = self.next_response_from_buffer() {
108 return Poll::Ready(Some(response));
109 }
110
111 match Pin::new(&mut self.as_mut().response).poll_next(cx) {
112 Poll::Ready(Some(Ok(chunk))) => {
113 let chunk_str = String::from_utf8_lossy(&chunk).to_string();
114 if self.first_chunk {
115 self.first_chunk = false;
116 }
117 self.buffer.push_str(&chunk_str);
118 }
119 Poll::Ready(Some(Err(error))) => {
120 eprintln!("Error in stream: {:?}", error);
121 return Poll::Ready(None);
122 }
123 Poll::Ready(None) => {
124 return Poll::Ready(None);
125 }
126 Poll::Pending => {
127 return Poll::Pending;
128 }
129 }
130 }
131 }
132}