rustyopenai/chat/response/
stream.rs1use anyhow::Result;
2use tracing::error;
3use futures::{ Stream, StreamExt };
4use std::{ pin::Pin, task::{ Context, Poll } };
5use bytes::Bytes;
6use lazy_static::lazy_static;
7use regex::Regex;
8use super::ChatCompletionChunk;
9
// Lazily compiled patterns used by `extract_first_chunk`. Both are anchored with
// `^`, so they only ever match at the very beginning of the text they are given.
lazy_static! {
    // A `data: {...}` event followed by a blank line (`\n\n`). `.` does not match
    // `\n` in the regex crate's default mode, so the braces have to open and close
    // on the same line.
    static ref STREAM_RESPONSE_CHUNK_RE: Regex = Regex::new(r"^data: \{.*\}\n\n").unwrap();
    // The literal `data: [DONE]` event followed by a blank line; `extract_first_chunk`
    // answers it with neither a chunk nor remaining content.
    static ref STREAM_RESPONSE_TERMINATION_CHUNK_RE: Regex =
        Regex::new(r"^data: \[DONE\]\n\n").unwrap();
}
15
16fn extract_first_chunk(content: &str) -> Result<ExtractedChunkWithRemainingContent> {
17    if let Some(mat) = STREAM_RESPONSE_CHUNK_RE.find(content) {
18        let matched_str = mat.as_str();
20
21        let remaining_content = &content[mat.end()..];
23
24        let json_content = matched_str.strip_prefix("data: ").unwrap();
26
27        let chunk = serde_json::from_str(json_content)?;
29
30        Ok(ExtractedChunkWithRemainingContent {
32            chunk: Some(chunk),
33
34            remaining_content: match remaining_content.len() {
36                0 => None,
37                _ => Some(remaining_content.to_string()),
38            },
39        })
40    } else if let Some(_) = STREAM_RESPONSE_TERMINATION_CHUNK_RE.find(content) {
41        Ok(ExtractedChunkWithRemainingContent { chunk: None, remaining_content: None })
43    } else {
44        Ok(ExtractedChunkWithRemainingContent {
47            chunk: None,
48            remaining_content: match content.len() {
49                0 => None,
50                _ => Some(content.to_string()),
51            },
52        })
53    }
54}
55
/// Result of looking for one event at the start of a piece of response text,
/// as produced by `extract_first_chunk`.
#[derive(Debug)]
pub struct ExtractedChunkWithRemainingContent {
    /// The chunk parsed from a leading `data: {...}` event, or `None` when the
    /// text did not start with one.
    pub chunk: Option<ChatCompletionChunk>,
    /// The text that was not consumed, or `None` when nothing is left (or the
    /// `data: [DONE]` event was found).
    pub remaining_content: Option<String>,
}
61
62pub struct ChatCompletionStream<S> {
63    response_bytes_stream: S,
64    remaining_content: Option<String>,
65}
66
67impl<S> ChatCompletionStream<S> where S: Stream<Item = Result<Bytes, reqwest::Error>> + Unpin {
68    pub fn new(response_bytes_stream: S) -> Self {
69        Self {
70            response_bytes_stream,
71            remaining_content: None,
72        }
73    }
74}
75
76impl<S> Stream
77    for ChatCompletionStream<S>
78    where S: Stream<Item = Result<Bytes, reqwest::Error>> + Unpin
79{
80    type Item = ChatCompletionChunk;
81
82    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
83        loop {
84            match self.response_bytes_stream.poll_next_unpin(cx) {
85                Poll::Ready(Some(Ok(bytes))) => {
86                    let mut content = std::str::from_utf8(&bytes).unwrap().to_string();
88
89                    if let Some(remaining_content) = &self.remaining_content {
91                        content.insert_str(0, remaining_content);
93                    }
94
95                    let response_with_remaining_content = extract_first_chunk(&content).unwrap();
97
98                    let chunk = response_with_remaining_content.chunk;
100                    let remaining_content = response_with_remaining_content.remaining_content;
101
102                    if let Some(remaining_content) = &remaining_content {
104                        self.remaining_content = Some(remaining_content.to_owned());
105                    } else {
106                        self.remaining_content = None;
107                    }
108
109                    if let Some(chunk) = chunk {
111                        return Poll::Ready(Some(chunk));
112                    } else if remaining_content.is_none() {
113                        return Poll::Ready(None);
117                    } else {
118                        continue;
119                        }
121                }
122                Poll::Ready(Some(Err(e))) => {
123                    error!("The response is incomplete - {}", e);
124                    return Poll::Ready(None);
125                }
126                Poll::Ready(None) => {
127                    if let Some(remaining_content) = &self.remaining_content {
132                        let response_with_remaining_content = extract_first_chunk(
134                            &remaining_content
135                        ).unwrap();
136
137                        let chunk = response_with_remaining_content.chunk;
139                        let remaining_content = response_with_remaining_content.remaining_content;
140
141                        if let Some(remaining_content) = &remaining_content {
143                            self.remaining_content = Some(remaining_content.to_owned());
144                        } else {
145                            self.remaining_content = None;
146                        }
147
148                        if let Some(chunk) = chunk {
150                            return Poll::Ready(Some(chunk));
151                        } else if remaining_content.is_none() {
152                            return Poll::Ready(None);
156                        } else {
157                            continue;
158                        }
159                    }
160
161                    return Poll::Ready(None);
163                }
164                Poll::Pending => {
165                    return Poll::Pending;
166                }
167            }
168        }
169    }
170}