openai_api_rs/v1/chat_completion/
chat_completion_stream.rs

1use crate::v1::chat_completion::{Reasoning, Tool, ToolCall, ToolChoiceType};
2use crate::{
3    impl_builder_methods,
4    v1::chat_completion::{serialize_tool_choice, ChatCompletionMessage},
5};
6
7use futures_util::Stream;
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use std::collections::HashMap;
11use std::pin::Pin;
12use std::task::{Context, Poll};
13
14#[derive(Debug, Serialize, Deserialize, Clone)]
15pub struct ChatCompletionStreamRequest {
16    pub model: String,
17    pub messages: Vec<ChatCompletionMessage>,
18    #[serde(skip_serializing_if = "Option::is_none")]
19    pub temperature: Option<f64>,
20    #[serde(skip_serializing_if = "Option::is_none")]
21    pub top_p: Option<f64>,
22    #[serde(skip_serializing_if = "Option::is_none")]
23    pub n: Option<i64>,
24    #[serde(skip_serializing_if = "Option::is_none")]
25    pub response_format: Option<Value>,
26    #[serde(skip_serializing_if = "Option::is_none")]
27    pub stop: Option<Vec<String>>,
28    #[serde(skip_serializing_if = "Option::is_none")]
29    pub max_tokens: Option<i64>,
30    #[serde(skip_serializing_if = "Option::is_none")]
31    pub presence_penalty: Option<f64>,
32    #[serde(skip_serializing_if = "Option::is_none")]
33    pub frequency_penalty: Option<f64>,
34    #[serde(skip_serializing_if = "Option::is_none")]
35    pub logit_bias: Option<HashMap<String, i32>>,
36    #[serde(skip_serializing_if = "Option::is_none")]
37    pub user: Option<String>,
38    #[serde(skip_serializing_if = "Option::is_none")]
39    pub seed: Option<i64>,
40    #[serde(skip_serializing_if = "Option::is_none")]
41    pub tools: Option<Vec<Tool>>,
42    #[serde(skip_serializing_if = "Option::is_none")]
43    pub parallel_tool_calls: Option<bool>,
44    #[serde(skip_serializing_if = "Option::is_none")]
45    #[serde(serialize_with = "serialize_tool_choice")]
46    pub tool_choice: Option<ToolChoiceType>,
47    #[serde(skip_serializing_if = "Option::is_none")]
48    pub reasoning: Option<Reasoning>,
49    /// Optional list of transforms to apply to the chat completion request.
50    ///
51    /// Transforms allow modifying the request before it's sent to the API,
52    /// enabling features like prompt rewriting, content filtering, or other
53    /// preprocessing steps. When None, no transforms are applied.
54    #[serde(skip_serializing_if = "Option::is_none")]
55    pub transforms: Option<Vec<String>>,
56}
57
58impl ChatCompletionStreamRequest {
59    pub fn new(model: String, messages: Vec<ChatCompletionMessage>) -> Self {
60        Self {
61            model,
62            messages,
63            temperature: None,
64            top_p: None,
65            n: None,
66            response_format: None,
67            stop: None,
68            max_tokens: None,
69            presence_penalty: None,
70            frequency_penalty: None,
71            logit_bias: None,
72            user: None,
73            seed: None,
74            tools: None,
75            parallel_tool_calls: None,
76            tool_choice: None,
77            reasoning: None,
78            transforms: None,
79        }
80    }
81}
82
// Generates one chainable setter per optional field listed below. As exercised
// by the tests in this file, `request.transforms(v)` consumes the request and
// returns it with `transforms` set to `Some(v)`; the remaining setters are
// presumably analogous (the macro itself is defined elsewhere in the crate).
// Each entry pairs a field name with the unwrapped type of that `Option` field.
impl_builder_methods!(
    ChatCompletionStreamRequest,
    temperature: f64,
    top_p: f64,
    n: i64,
    response_format: Value,
    stop: Vec<String>,
    max_tokens: i64,
    presence_penalty: f64,
    frequency_penalty: f64,
    logit_bias: HashMap<String, i32>,
    user: String,
    seed: i64,
    tools: Vec<Tool>,
    parallel_tool_calls: bool,
    tool_choice: ToolChoiceType,
    reasoning: Reasoning,
    transforms: Vec<String>
);
102
/// One item yielded by [`ChatCompletionStream`].
#[derive(Debug, Clone)]
pub enum ChatCompletionStreamResponse {
    /// Text taken from the `content` field of the first choice's `delta`.
    Content(String),
    /// Tool calls successfully deserialized from the `tool_calls` array of the
    /// first choice's `delta`; only produced when at least one entry parsed.
    ToolCall(Vec<ToolCall>),
    /// Emitted when an event's data payload is exactly `[DONE]`.
    Done,
}
109
110pub struct ChatCompletionStream<S: Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Unpin> {
111    pub response: S,
112    pub buffer: String,
113    pub first_chunk: bool,
114}
115
116impl<S> ChatCompletionStream<S>
117where
118    S: Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Unpin,
119{
120    fn find_event_delimiter(buffer: &str) -> Option<(usize, usize)> {
121        let carriage_idx = buffer.find("\r\n\r\n");
122        let newline_idx = buffer.find("\n\n");
123
124        match (carriage_idx, newline_idx) {
125            (Some(r_idx), Some(n_idx)) => {
126                if r_idx <= n_idx {
127                    Some((r_idx, 4))
128                } else {
129                    Some((n_idx, 2))
130                }
131            }
132            (Some(r_idx), None) => Some((r_idx, 4)),
133            (None, Some(n_idx)) => Some((n_idx, 2)),
134            (None, None) => None,
135        }
136    }
137
138    fn next_response_from_buffer(&mut self) -> Option<ChatCompletionStreamResponse> {
139        while let Some((idx, delimiter_len)) = Self::find_event_delimiter(&self.buffer) {
140            let event = self.buffer[..idx].to_owned();
141            self.buffer = self.buffer[idx + delimiter_len..].to_owned();
142
143            let mut data_payload = String::new();
144            for line in event.lines() {
145                let trimmed_line = line.trim_end_matches('\r');
146                if let Some(content) = trimmed_line
147                    .strip_prefix("data: ")
148                    .or_else(|| trimmed_line.strip_prefix("data:"))
149                {
150                    if !content.is_empty() {
151                        if !data_payload.is_empty() {
152                            data_payload.push('\n');
153                        }
154                        data_payload.push_str(content);
155                    }
156                }
157            }
158
159            if data_payload.is_empty() {
160                continue;
161            }
162
163            if data_payload == "[DONE]" {
164                return Some(ChatCompletionStreamResponse::Done);
165            }
166
167            match serde_json::from_str::<Value>(&data_payload) {
168                Ok(json) => {
169                    if let Some(delta) = json
170                        .get("choices")
171                        .and_then(|choices| choices.get(0))
172                        .and_then(|choice| choice.get("delta"))
173                    {
174                        if let Some(tool_call_response) = delta
175                            .get("tool_calls")
176                            .and_then(|tool_calls| tool_calls.as_array())
177                            .map(|tool_calls_array| {
178                                tool_calls_array
179                                    .iter()
180                                    .filter_map(|v| serde_json::from_value(v.clone()).ok())
181                                    .collect::<Vec<ToolCall>>()
182                            })
183                            .filter(|tool_calls_vec| !tool_calls_vec.is_empty())
184                            .map(ChatCompletionStreamResponse::ToolCall)
185                        {
186                            return Some(tool_call_response);
187                        }
188
189                        if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
190                            let output = content.replace("\\n", "\n");
191                            return Some(ChatCompletionStreamResponse::Content(output));
192                        }
193                    }
194                }
195                Err(error) => {
196                    eprintln!("Failed to parse SSE chunk as JSON: {}", error);
197                }
198            }
199        }
200
201        None
202    }
203}
204
205impl<S: Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Unpin> Stream
206    for ChatCompletionStream<S>
207{
208    type Item = ChatCompletionStreamResponse;
209
210    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
211        loop {
212            if let Some(response) = self.next_response_from_buffer() {
213                return Poll::Ready(Some(response));
214            }
215
216            match Pin::new(&mut self.as_mut().response).poll_next(cx) {
217                Poll::Ready(Some(Ok(chunk))) => {
218                    let chunk_str = String::from_utf8_lossy(&chunk).to_string();
219
220                    if self.first_chunk {
221                        self.first_chunk = false;
222                    }
223                    self.buffer.push_str(&chunk_str);
224                }
225                Poll::Ready(Some(Err(error))) => {
226                    eprintln!("Error in stream: {:?}", error);
227                    return Poll::Ready(None);
228                }
229                Poll::Ready(None) => {
230                    return Poll::Ready(None);
231                }
232                Poll::Pending => {
233                    return Poll::Pending;
234                }
235            }
236        }
237    }
238}
239
#[cfg(test)]
mod tests {
    use super::*;
    use crate::v1::chat_completion::{ReasoningEffort, ReasoningMode};
    use serde_json::json;

    /// Request with no messages; sufficient for the (de)serialization checks.
    fn empty_request() -> ChatCompletionStreamRequest {
        ChatCompletionStreamRequest::new("gpt-4".to_string(), vec![])
    }

    /// Two-element transform list shared by the `transforms` tests.
    fn sample_transforms() -> Vec<String> {
        vec!["transform1".to_string(), "transform2".to_string()]
    }

    #[test]
    fn test_reasoning_effort_serialization() {
        let reasoning = Reasoning {
            mode: Some(ReasoningMode::Effort {
                effort: ReasoningEffort::High,
            }),
            exclude: Some(false),
            enabled: None,
        };

        // The effort mode is flattened next to `exclude`; `enabled` is absent.
        assert_eq!(
            serde_json::to_value(&reasoning).unwrap(),
            json!({ "effort": "high", "exclude": false })
        );
    }

    #[test]
    fn test_reasoning_max_tokens_serialization() {
        let reasoning = Reasoning {
            mode: Some(ReasoningMode::MaxTokens { max_tokens: 2000 }),
            exclude: None,
            enabled: Some(true),
        };

        // The token budget is flattened next to `enabled`; `exclude` is absent.
        assert_eq!(
            serde_json::to_value(&reasoning).unwrap(),
            json!({ "max_tokens": 2000, "enabled": true })
        );
    }

    #[test]
    fn test_reasoning_deserialization() {
        let parsed: Reasoning =
            serde_json::from_str(r#"{"effort": "medium", "exclude": true}"#).unwrap();

        if let Some(ReasoningMode::Effort { effort }) = parsed.mode {
            assert_eq!(effort, ReasoningEffort::Medium);
        } else {
            panic!("Expected effort mode");
        }
        assert_eq!(parsed.exclude, Some(true));
    }

    #[test]
    fn test_chat_completion_request_with_reasoning() {
        let request = ChatCompletionStreamRequest {
            reasoning: Some(Reasoning {
                mode: Some(ReasoningMode::Effort {
                    effort: ReasoningEffort::Low,
                }),
                exclude: None,
                enabled: None,
            }),
            ..empty_request()
        };

        let body = serde_json::to_value(&request).unwrap();
        assert_eq!(body["reasoning"]["effort"], "low");
    }

    #[test]
    fn test_transforms_none_serialization() {
        let body = serde_json::to_value(&empty_request()).unwrap();
        let fields = body.as_object().unwrap();
        // An unset `transforms` must not appear in the JSON output at all.
        assert!(!fields.contains_key("transforms"));
    }

    #[test]
    fn test_transforms_some_serialization() {
        let request = ChatCompletionStreamRequest {
            transforms: Some(sample_transforms()),
            ..empty_request()
        };

        let body = serde_json::to_value(&request).unwrap();
        // A set `transforms` is emitted as a plain JSON array of strings.
        assert_eq!(body["transforms"], json!(["transform1", "transform2"]));
    }

    #[test]
    fn test_transforms_some_deserialization() {
        let raw =
            r#"{"model": "gpt-4", "messages": [], "transforms": ["transform1", "transform2"]}"#;
        let request: ChatCompletionStreamRequest = serde_json::from_str(raw).unwrap();
        // A present array is read back as `Some(vec)`.
        assert_eq!(request.transforms, Some(sample_transforms()));
    }

    #[test]
    fn test_transforms_none_deserialization() {
        let raw = r#"{"model": "gpt-4", "messages": []}"#;
        let request: ChatCompletionStreamRequest = serde_json::from_str(raw).unwrap();
        // A missing key is read back as `None`.
        assert_eq!(request.transforms, None);
    }

    #[test]
    fn test_transforms_builder_method() {
        let expected = sample_transforms();
        let request = empty_request().transforms(expected.clone());
        // The generated builder method wraps its argument in `Some`.
        assert_eq!(request.transforms, Some(expected));
    }
}