rustyopenai/chat/response/
stream.rs

1use anyhow::Result;
2use tracing::error;
3use futures::{ Stream, StreamExt };
4use std::{ pin::Pin, task::{ Context, Poll } };
5use bytes::Bytes;
6use lazy_static::lazy_static;
7use regex::Regex;
8use super::ChatCompletionChunk;
9
10lazy_static! {
11    static ref STREAM_RESPONSE_CHUNK_RE: Regex = Regex::new(r"^data: \{.*\}\n\n").unwrap();
12    static ref STREAM_RESPONSE_TERMINATION_CHUNK_RE: Regex =
13        Regex::new(r"^data: \[DONE\]\n\n").unwrap();
14}
15
16fn extract_first_chunk(content: &str) -> Result<ExtractedChunkWithRemainingContent> {
17    if let Some(mat) = STREAM_RESPONSE_CHUNK_RE.find(content) {
18        // Matched string
19        let matched_str = mat.as_str();
20
21        // Remaining content
22        let remaining_content = &content[mat.end()..];
23
24        // Extract the JSON content
25        let json_content = matched_str.strip_prefix("data: ").unwrap();
26
27        // Parse the JSON content to ChatCompletionChunk
28        let chunk = serde_json::from_str(json_content)?;
29
30        // Return the extracted response and remaining content
31        Ok(ExtractedChunkWithRemainingContent {
32            chunk: Some(chunk),
33
34            // If there is no remaining content, return None
35            remaining_content: match remaining_content.len() {
36                0 => None,
37                _ => Some(remaining_content.to_string()),
38            },
39        })
40    } else if let Some(_) = STREAM_RESPONSE_TERMINATION_CHUNK_RE.find(content) {
41        // The current chunk is the termination chunk, "data: [DONE]\n\n", from OpenAI
42        Ok(ExtractedChunkWithRemainingContent { chunk: None, remaining_content: None })
43    } else {
44        // There is no response in the json_content,
45        // so return None for the response and the json_content as the remaining content
46        Ok(ExtractedChunkWithRemainingContent {
47            chunk: None,
48            remaining_content: match content.len() {
49                0 => None,
50                _ => Some(content.to_string()),
51            },
52        })
53    }
54}
55
/// Outcome of one parsing attempt made by `extract_first_chunk`.
///
/// Both fields being `None` is how `extract_first_chunk` reports the
/// `data: [DONE]` terminator (and also an empty input).
56#[derive(Debug)]
57pub struct ExtractedChunkWithRemainingContent {
    /// The chunk parsed from the front of the input, if a complete event was found.
58    pub chunk: Option<ChatCompletionChunk>,
    /// Text that was not consumed by this attempt; `None` when nothing is left.
59    pub remaining_content: Option<String>,
60}
61
62pub struct ChatCompletionStream<S> {
63    response_bytes_stream: S,
64    remaining_content: Option<String>,
65}
66
67impl<S> ChatCompletionStream<S> where S: Stream<Item = Result<Bytes, reqwest::Error>> + Unpin {
68    pub fn new(response_bytes_stream: S) -> Self {
69        Self {
70            response_bytes_stream,
71            remaining_content: None,
72        }
73    }
74}
75
76impl<S> Stream
77    for ChatCompletionStream<S>
78    where S: Stream<Item = Result<Bytes, reqwest::Error>> + Unpin
79{
80    type Item = ChatCompletionChunk;
81
82    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
83        loop {
84            match self.response_bytes_stream.poll_next_unpin(cx) {
85                Poll::Ready(Some(Ok(bytes))) => {
86                    // Convert bytes to string
87                    let mut content = std::str::from_utf8(&bytes).unwrap().to_string();
88
89                    // Concatenate the remaining content if there is any
90                    if let Some(remaining_content) = &self.remaining_content {
91                        // Put the remaining content in front of the currently received content
92                        content.insert_str(0, remaining_content);
93                    }
94
95                    // Extract the first response and remaining content
96                    let response_with_remaining_content = extract_first_chunk(&content).unwrap();
97
98                    // Get the response and remaining content
99                    let chunk = response_with_remaining_content.chunk;
100                    let remaining_content = response_with_remaining_content.remaining_content;
101
102                    // Collect the remaining content if there is any
103                    if let Some(remaining_content) = &remaining_content {
104                        self.remaining_content = Some(remaining_content.to_owned());
105                    } else {
106                        self.remaining_content = None;
107                    }
108
109                    // Return Ready if there is one response
110                    if let Some(chunk) = chunk {
111                        return Poll::Ready(Some(chunk));
112                    } else if remaining_content.is_none() {
113                        // If both the chunk and the remaining content are None,
114                        // then it means that the current chunk
115                        // is the termination chunk, "data: [DONE]\n\n", from OpenAI
116                        return Poll::Ready(None);
117                    } else {
118                        continue;
119                        // return Poll::Pending;
120                    }
121                }
122                Poll::Ready(Some(Err(e))) => {
123                    error!("The response is incomplete - {}", e);
124                    return Poll::Ready(None);
125                }
126                Poll::Ready(None) => {
127                    // All the chunks from OpenAI have been received
128                    // But chances are that there are still some remaining content
129                    // that has not been parsed yet
130                    // So we handle the remaining content here
131                    if let Some(remaining_content) = &self.remaining_content {
132                        // Extract the first response and remaining content
133                        let response_with_remaining_content = extract_first_chunk(
134                            &remaining_content
135                        ).unwrap();
136
137                        // Get the response and remaining content
138                        let chunk = response_with_remaining_content.chunk;
139                        let remaining_content = response_with_remaining_content.remaining_content;
140
141                        // Collect the remaining content if there is any
142                        if let Some(remaining_content) = &remaining_content {
143                            self.remaining_content = Some(remaining_content.to_owned());
144                        } else {
145                            self.remaining_content = None;
146                        }
147
148                        // Return Ready if there is one response
149                        if let Some(chunk) = chunk {
150                            return Poll::Ready(Some(chunk));
151                        } else if remaining_content.is_none() {
152                            // If both the chunk and the remaining content are None,
153                            // then it means that the current chunk
154                            // is the termination chunk, "data: [DONE]\n\n", from OpenAI
155                            return Poll::Ready(None);
156                        } else {
157                            continue;
158                        }
159                    }
160
161                    // If there is no remaining content, then return None
162                    return Poll::Ready(None);
163                }
164                Poll::Pending => {
165                    return Poll::Pending;
166                }
167            }
168        }
169    }
170}