openai_api_rs/v1/chat_completion/
chat_completion_stream.rs1use crate::v1::chat_completion::{Reasoning, Tool, ToolCall, ToolChoiceType};
2use crate::{
3 impl_builder_methods,
4 v1::chat_completion::{serialize_tool_choice, ChatCompletionMessage},
5};
6
7use futures_util::Stream;
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use std::collections::HashMap;
11use std::pin::Pin;
12use std::task::{Context, Poll};
13
/// Request payload for a streaming chat-completion call.
///
/// Mirrors the non-streaming request shape: required `model` and `messages`,
/// plus optional tuning fields that are omitted from the serialized JSON when
/// `None` (via `skip_serializing_if`).
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ChatCompletionStreamRequest {
    /// ID of the model to use (e.g. "gpt-4").
    pub model: String,
    /// Conversation history sent to the model.
    pub messages: Vec<ChatCompletionMessage>,
    /// Sampling temperature; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    /// Nucleus-sampling probability mass.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,
    /// Number of completions to generate.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<i64>,
    /// Free-form response format object (kept as raw JSON `Value`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<Value>,
    /// Stop sequences that end generation.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    /// Maximum number of tokens to generate.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<i64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f64>,
    /// Per-token logit bias, keyed by token id as a string.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<HashMap<String, i32>>,
    /// End-user identifier for abuse monitoring.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    /// Seed for best-effort deterministic sampling.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<i64>,
    /// Tools (e.g. functions) the model may call.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    /// Tool-choice strategy; uses a custom serializer to emit either a
    /// keyword string or a named-tool object.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(serialize_with = "serialize_tool_choice")]
    pub tool_choice: Option<ToolChoiceType>,
    /// Reasoning configuration.
    /// NOTE(review): appears to be an OpenRouter-style extension — confirm
    /// which providers accept it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<Reasoning>,
    /// Prompt transforms.
    /// NOTE(review): provider-specific (OpenRouter-style) field — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub transforms: Option<Vec<String>>,
}
57
58impl ChatCompletionStreamRequest {
59 pub fn new(model: String, messages: Vec<ChatCompletionMessage>) -> Self {
60 Self {
61 model,
62 messages,
63 temperature: None,
64 top_p: None,
65 n: None,
66 response_format: None,
67 stop: None,
68 max_tokens: None,
69 presence_penalty: None,
70 frequency_penalty: None,
71 logit_bias: None,
72 user: None,
73 seed: None,
74 tools: None,
75 parallel_tool_calls: None,
76 tool_choice: None,
77 reasoning: None,
78 transforms: None,
79 }
80 }
81}
82
// Generates chainable builder-style setters (e.g. `.temperature(0.7)`) on
// `ChatCompletionStreamRequest` for each listed optional field; each setter
// wraps its argument in `Some(..)` and returns `self` by value.
impl_builder_methods!(
    ChatCompletionStreamRequest,
    temperature: f64,
    top_p: f64,
    n: i64,
    response_format: Value,
    stop: Vec<String>,
    max_tokens: i64,
    presence_penalty: f64,
    frequency_penalty: f64,
    logit_bias: HashMap<String, i32>,
    user: String,
    seed: i64,
    tools: Vec<Tool>,
    parallel_tool_calls: bool,
    tool_choice: ToolChoiceType,
    reasoning: Reasoning,
    transforms: Vec<String>
);
102
/// One parsed event from a chat-completion SSE stream.
#[derive(Debug, Clone)]
pub enum ChatCompletionStreamResponse {
    /// A delta of assistant text content.
    Content(String),
    /// Incremental tool-call entries from the choice's `delta.tool_calls`.
    ToolCall(Vec<ToolCall>),
    /// The `[DONE]` sentinel; no further events follow.
    Done,
}
109
/// Adapter that turns a raw HTTP byte stream (server-sent events) into a
/// stream of [`ChatCompletionStreamResponse`] values.
pub struct ChatCompletionStream<S: Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Unpin> {
    /// Underlying chunked byte stream from the HTTP response.
    pub response: S,
    /// Accumulates partial SSE text until a complete event delimiter
    /// (`\n\n` or `\r\n\r\n`) has arrived.
    pub buffer: String,
    /// True until the first chunk has been received.
    /// NOTE(review): flipped in `poll_next` but never read elsewhere in this
    /// file — possibly vestigial; confirm against external users before
    /// removing.
    pub first_chunk: bool,
}
115
116impl<S> ChatCompletionStream<S>
117where
118 S: Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Unpin,
119{
120 fn find_event_delimiter(buffer: &str) -> Option<(usize, usize)> {
121 let carriage_idx = buffer.find("\r\n\r\n");
122 let newline_idx = buffer.find("\n\n");
123
124 match (carriage_idx, newline_idx) {
125 (Some(r_idx), Some(n_idx)) => {
126 if r_idx <= n_idx {
127 Some((r_idx, 4))
128 } else {
129 Some((n_idx, 2))
130 }
131 }
132 (Some(r_idx), None) => Some((r_idx, 4)),
133 (None, Some(n_idx)) => Some((n_idx, 2)),
134 (None, None) => None,
135 }
136 }
137
138 fn next_response_from_buffer(&mut self) -> Option<ChatCompletionStreamResponse> {
139 while let Some((idx, delimiter_len)) = Self::find_event_delimiter(&self.buffer) {
140 let event = self.buffer[..idx].to_owned();
141 self.buffer = self.buffer[idx + delimiter_len..].to_owned();
142
143 let mut data_payload = String::new();
144 for line in event.lines() {
145 let trimmed_line = line.trim_end_matches('\r');
146 if let Some(content) = trimmed_line
147 .strip_prefix("data: ")
148 .or_else(|| trimmed_line.strip_prefix("data:"))
149 {
150 if !content.is_empty() {
151 if !data_payload.is_empty() {
152 data_payload.push('\n');
153 }
154 data_payload.push_str(content);
155 }
156 }
157 }
158
159 if data_payload.is_empty() {
160 continue;
161 }
162
163 if data_payload == "[DONE]" {
164 return Some(ChatCompletionStreamResponse::Done);
165 }
166
167 match serde_json::from_str::<Value>(&data_payload) {
168 Ok(json) => {
169 if let Some(delta) = json
170 .get("choices")
171 .and_then(|choices| choices.get(0))
172 .and_then(|choice| choice.get("delta"))
173 {
174 if let Some(tool_call_response) = delta
175 .get("tool_calls")
176 .and_then(|tool_calls| tool_calls.as_array())
177 .map(|tool_calls_array| {
178 tool_calls_array
179 .iter()
180 .filter_map(|v| serde_json::from_value(v.clone()).ok())
181 .collect::<Vec<ToolCall>>()
182 })
183 .filter(|tool_calls_vec| !tool_calls_vec.is_empty())
184 .map(ChatCompletionStreamResponse::ToolCall)
185 {
186 return Some(tool_call_response);
187 }
188
189 if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
190 let output = content.replace("\\n", "\n");
191 return Some(ChatCompletionStreamResponse::Content(output));
192 }
193 }
194 }
195 Err(error) => {
196 eprintln!("Failed to parse SSE chunk as JSON: {}", error);
197 }
198 }
199 }
200
201 None
202 }
203}
204
205impl<S: Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Unpin> Stream
206 for ChatCompletionStream<S>
207{
208 type Item = ChatCompletionStreamResponse;
209
210 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
211 loop {
212 if let Some(response) = self.next_response_from_buffer() {
213 return Poll::Ready(Some(response));
214 }
215
216 match Pin::new(&mut self.as_mut().response).poll_next(cx) {
217 Poll::Ready(Some(Ok(chunk))) => {
218 let chunk_str = String::from_utf8_lossy(&chunk).to_string();
219
220 if self.first_chunk {
221 self.first_chunk = false;
222 }
223 self.buffer.push_str(&chunk_str);
224 }
225 Poll::Ready(Some(Err(error))) => {
226 eprintln!("Error in stream: {:?}", error);
227 return Poll::Ready(None);
228 }
229 Poll::Ready(None) => {
230 return Poll::Ready(None);
231 }
232 Poll::Pending => {
233 return Poll::Pending;
234 }
235 }
236 }
237 }
238}
239
#[cfg(test)]
mod tests {
    use crate::v1::chat_completion::{ReasoningEffort, ReasoningMode};

    use super::*;
    use serde_json::json;

    #[test]
    fn test_reasoning_effort_serialization() {
        // Effort-mode reasoning should flatten into an `effort` key and keep
        // `exclude`; the unset `enabled` must be omitted.
        let reasoning = Reasoning {
            mode: Some(ReasoningMode::Effort {
                effort: ReasoningEffort::High,
            }),
            exclude: Some(false),
            enabled: None,
        };

        let actual = serde_json::to_value(&reasoning).unwrap();
        assert_eq!(actual, json!({ "effort": "high", "exclude": false }));
    }

    #[test]
    fn test_reasoning_max_tokens_serialization() {
        // Max-tokens-mode reasoning should flatten into a `max_tokens` key.
        let reasoning = Reasoning {
            mode: Some(ReasoningMode::MaxTokens { max_tokens: 2000 }),
            exclude: None,
            enabled: Some(true),
        };

        let actual = serde_json::to_value(&reasoning).unwrap();
        assert_eq!(actual, json!({ "max_tokens": 2000, "enabled": true }));
    }

    #[test]
    fn test_reasoning_deserialization() {
        let reasoning: Reasoning =
            serde_json::from_str(r#"{"effort": "medium", "exclude": true}"#).unwrap();

        if let Some(ReasoningMode::Effort { effort }) = reasoning.mode {
            assert_eq!(effort, ReasoningEffort::Medium);
        } else {
            panic!("Expected effort mode");
        }
        assert_eq!(reasoning.exclude, Some(true));
    }

    #[test]
    fn test_chat_completion_request_with_reasoning() {
        let mut req = ChatCompletionStreamRequest::new("gpt-4".to_string(), vec![]);
        req.reasoning = Some(Reasoning {
            mode: Some(ReasoningMode::Effort {
                effort: ReasoningEffort::Low,
            }),
            exclude: None,
            enabled: None,
        });

        let value = serde_json::to_value(&req).unwrap();
        assert_eq!(value["reasoning"]["effort"], "low");
    }

    #[test]
    fn test_transforms_none_serialization() {
        // An unset `transforms` must not appear in the serialized object.
        let req = ChatCompletionStreamRequest::new("gpt-4".to_string(), vec![]);
        let value = serde_json::to_value(&req).unwrap();
        let object = value.as_object().unwrap();
        assert!(!object.contains_key("transforms"));
    }

    #[test]
    fn test_transforms_some_serialization() {
        let mut req = ChatCompletionStreamRequest::new("gpt-4".to_string(), vec![]);
        req.transforms = Some(vec!["transform1".to_string(), "transform2".to_string()]);

        let value = serde_json::to_value(&req).unwrap();
        assert_eq!(value["transforms"], json!(["transform1", "transform2"]));
    }

    #[test]
    fn test_transforms_some_deserialization() {
        let payload =
            r#"{"model": "gpt-4", "messages": [], "transforms": ["transform1", "transform2"]}"#;
        let req: ChatCompletionStreamRequest = serde_json::from_str(payload).unwrap();

        let expected = vec!["transform1".to_string(), "transform2".to_string()];
        assert_eq!(req.transforms, Some(expected));
    }

    #[test]
    fn test_transforms_none_deserialization() {
        let payload = r#"{"model": "gpt-4", "messages": []}"#;
        let req: ChatCompletionStreamRequest = serde_json::from_str(payload).unwrap();
        assert_eq!(req.transforms, None);
    }

    #[test]
    fn test_transforms_builder_method() {
        // The macro-generated setter should wrap its argument in `Some`.
        let transforms = vec!["transform1".to_string(), "transform2".to_string()];
        let req = ChatCompletionStreamRequest::new("gpt-4".to_string(), vec![])
            .transforms(transforms.clone());
        assert_eq!(req.transforms, Some(transforms));
    }
}