// openai_ergonomic/streaming.rs

use crate::interceptor::{StreamChunkContext, StreamEndContext};
use crate::{Error, Result};
use bytes::Bytes;
use futures::stream::Stream;
use futures::StreamExt;
use openai_client_base::models::{
    ChatCompletionStreamResponseDelta, CreateChatCompletionStreamResponse,
};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Instant;

/// Boxed, type-erased stream of chat completion chunks.
pub type BoxedChatStream = Pin<Box<dyn Stream<Item = Result<ChatCompletionChunk>> + Send>>;

/// A single chunk from a streaming chat completion response.
///
/// Wraps the raw [`CreateChatCompletionStreamResponse`] and exposes ergonomic
/// accessors for the fields most callers need (delta content, role, tool
/// calls, and finish reason).
#[derive(Debug, Clone)]
pub struct ChatCompletionChunk {
    /// The raw response payload for this chunk.
    response: CreateChatCompletionStreamResponse,
}

impl ChatCompletionChunk {
    /// Wraps a raw stream response in a `ChatCompletionChunk`.
    #[must_use]
    pub fn new(response: CreateChatCompletionStreamResponse) -> Self {
        Self { response }
    }

    /// Returns the content delta of the first choice, if any.
    #[must_use]
    pub fn content(&self) -> Option<&str> {
        self.response
            .choices
            .first()
            .and_then(|choice| choice.delta.content.as_ref().and_then(|c| c.as_deref()))
    }

    /// Returns the role of the first choice's delta as a static string
    /// (`"system"`, `"user"`, `"assistant"`, `"tool"`, or `"developer"`).
    #[must_use]
    pub fn role(&self) -> Option<&str> {
        use openai_client_base::models::chat_completion_stream_response_delta::Role;

        self.response
            .choices
            .first()
            .and_then(|choice| choice.delta.role.as_ref())
            .map(|role| match role {
                Role::System => "system",
                Role::User => "user",
                Role::Assistant => "assistant",
                Role::Tool => "tool",
                Role::Developer => "developer",
            })
    }

    /// Returns the tool call deltas of the first choice, if any.
    #[must_use]
    pub fn tool_calls(
        &self,
    ) -> Option<&Vec<openai_client_base::models::ChatCompletionMessageToolCallChunk>> {
        self.response
            .choices
            .first()
            .and_then(|choice| choice.delta.tool_calls.as_ref())
    }

    /// Returns the finish reason of the first choice as a static string
    /// (`"stop"`, `"length"`, `"tool_calls"`, `"content_filter"`, or
    /// `"function_call"`).
    #[must_use]
    pub fn finish_reason(&self) -> Option<&str> {
        use openai_client_base::models::create_chat_completion_stream_response_choices_inner::FinishReason;

        self.response.choices.first().map(|choice| match &choice.finish_reason {
            FinishReason::Stop => "stop",
            FinishReason::Length => "length",
            FinishReason::ToolCalls => "tool_calls",
            FinishReason::ContentFilter => "content_filter",
            FinishReason::FunctionCall => "function_call",
        })
    }

    /// Returns `true` if this chunk carries a finish reason.
    #[must_use]
    pub fn is_final(&self) -> bool {
        self.finish_reason().is_some()
    }

    /// Returns the underlying raw stream response.
    #[must_use]
    pub fn raw_response(&self) -> &CreateChatCompletionStreamResponse {
        &self.response
    }

    /// Returns the delta of the first choice, if any.
    #[must_use]
    pub fn delta(&self) -> Option<&ChatCompletionStreamResponseDelta> {
        self.response
            .choices
            .first()
            .map(|choice| choice.delta.as_ref())
    }
}

/// A stream of chat completion chunks parsed from a server-sent events (SSE)
/// response body.
pub struct ChatCompletionStream {
    inner: Pin<Box<dyn Stream<Item = Result<ChatCompletionChunk>> + Send>>,
}

impl ChatCompletionStream {
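    /// Creates a stream by parsing the SSE body of a streaming HTTP response.
    ///
    /// # Example
    ///
    /// Illustrative sketch only; the endpoint URL, `request_json`, and
    /// `api_key` are placeholders, and in practice the response is usually
    /// obtained through the crate's client rather than raw `reqwest`:
    ///
    /// ```ignore
    /// let response = reqwest::Client::new()
    ///     .post("https://api.openai.com/v1/chat/completions")
    ///     .bearer_auth(api_key)
    ///     .header("Content-Type", "application/json")
    ///     .body(request_json)
    ///     .send()
    ///     .await?;
    /// let mut stream = ChatCompletionStream::new(response);
    /// ```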
    pub fn new(response: reqwest::Response) -> Self {
        let byte_stream = response.bytes_stream();
        let stream = parse_sse_stream(byte_stream);
        Self {
            inner: Box::pin(stream),
        }
    }

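    /// Drains the stream and concatenates every content delta into a single
    /// `String`, returning the first error encountered.
    ///
    /// # Example
    ///
    /// Illustrative sketch only; `stream` is assumed to be a
    /// `ChatCompletionStream` obtained from a streaming request:
    ///
    /// ```ignore
    /// // Either collect everything at once...
    /// let text = stream.collect_content().await?;
    ///
    /// // ...or iterate chunk by chunk for incremental output.
    /// while let Some(chunk) = stream.next().await {
    ///     let chunk = chunk?;
    ///     if let Some(delta) = chunk.content() {
    ///         print!("{delta}");
    ///     }
    /// }
    /// ```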
    pub async fn collect_content(mut self) -> Result<String> {
        let mut content = String::new();
        while let Some(chunk) = self.next().await {
            let chunk = chunk?;
            if let Some(text) = chunk.content() {
                content.push_str(text);
            }
        }
        Ok(content)
    }
}

impl Stream for ChatCompletionStream {
    type Item = Result<ChatCompletionChunk>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.inner.as_mut().poll_next(cx)
    }
}

/// A chat completion stream that also reports progress to an
/// [`crate::interceptor::InterceptorChain`].
///
/// Each chunk (and the end of the stream) is forwarded to the interceptors on
/// a spawned task so that observation never blocks the consumer.
pub struct InterceptedStream<T = ()> {
    inner: Pin<Box<dyn Stream<Item = Result<ChatCompletionChunk>> + Send>>,
    interceptors: Arc<crate::interceptor::InterceptorChain<T>>,
    operation: String,
    model: String,
    request_json: String,
    state: Arc<T>,
    chunk_index: usize,
    start_time: Instant,
    total_input_tokens: Option<i64>,
    total_output_tokens: Option<i64>,
}

impl<T: Send + Sync + 'static> InterceptedStream<T> {
    /// Wraps an existing [`ChatCompletionStream`] with interceptor reporting.
    pub fn new(
        inner: ChatCompletionStream,
        interceptors: Arc<crate::interceptor::InterceptorChain<T>>,
        operation: String,
        model: String,
        request_json: String,
        state: T,
    ) -> Self {
        Self {
            inner: inner.inner,
            interceptors,
            operation,
            model,
            request_json,
            state: Arc::new(state),
            chunk_index: 0,
            start_time: Instant::now(),
            total_input_tokens: None,
            total_output_tokens: None,
        }
    }

    /// Drains the stream and concatenates every content delta into a single
    /// `String`, returning the first error encountered.
    pub async fn collect_content(mut self) -> Result<String> {
        let mut content = String::new();
        while let Some(chunk) = self.next().await {
            let chunk = chunk?;
            if let Some(text) = chunk.content() {
                content.push_str(text);
            }
        }
        Ok(content)
    }
}

impl<T: Send + Sync + 'static> Stream for InterceptedStream<T> {
    type Item = Result<ChatCompletionChunk>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = &mut *self;

        match this.inner.as_mut().poll_next(cx) {
            Poll::Ready(Some(Ok(chunk))) => {
                // Serialize the raw chunk for the interceptors; serialization
                // failures degrade to an empty JSON object rather than erroring.
                let chunk_json = serde_json::to_string(chunk.raw_response())
                    .unwrap_or_else(|_| "{}".to_string());

                // Track usage totals when the provider includes them on a chunk.
                if let Some(usage) = &chunk.raw_response().usage {
                    this.total_input_tokens = Some(i64::from(usage.prompt_tokens));
                    this.total_output_tokens = Some(i64::from(usage.completion_tokens));
                }

                let interceptors = Arc::clone(&this.interceptors);
                let operation = this.operation.clone();
                let model = this.model.clone();
                let request_json = this.request_json.clone();
                let chunk_index = this.chunk_index;
                let state = Arc::clone(&this.state);

                // Notify the interceptors on a separate task so polling stays
                // non-blocking; interceptor errors are intentionally ignored.
                tokio::spawn(async move {
                    let ctx = StreamChunkContext {
                        operation: &operation,
                        model: &model,
                        request_json: &request_json,
                        chunk_json: &chunk_json,
                        chunk_index,
                        state: &*state,
                    };
                    let _ = interceptors.on_stream_chunk(&ctx).await;
                });

                this.chunk_index += 1;
                Poll::Ready(Some(Ok(chunk)))
            }
            // Errors are forwarded to the consumer unchanged.
            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
            Poll::Ready(None) => {
                // The stream is exhausted: report final stats to the interceptors.
                let interceptors = Arc::clone(&this.interceptors);
                let operation = this.operation.clone();
                let model = this.model.clone();
                let request_json = this.request_json.clone();
                let chunk_index = this.chunk_index;
                let duration = this.start_time.elapsed();
                let input_tokens = this.total_input_tokens;
                let output_tokens = this.total_output_tokens;
                let state = Arc::clone(&this.state);

                tokio::spawn(async move {
                    let ctx = StreamEndContext {
                        operation: &operation,
                        model: &model,
                        request_json: &request_json,
                        total_chunks: chunk_index,
                        duration,
                        input_tokens,
                        output_tokens,
                        state: &*state,
                    };
                    let _ = interceptors.on_stream_end(&ctx).await;
                });

                Poll::Ready(None)
            }
            Poll::Pending => Poll::Pending,
        }
    }
}

/// Parses a byte stream of server-sent events into chat completion chunks.
///
/// Bytes are buffered until a full line is available; each `data:` line is
/// parsed into a [`ChatCompletionChunk`], while comments, blank lines, and the
/// `[DONE]` sentinel are skipped.
fn parse_sse_stream(
    byte_stream: impl Stream<Item = reqwest::Result<Bytes>> + Send + 'static,
) -> impl Stream<Item = Result<ChatCompletionChunk>> + Send {
    let mut buffer = Vec::new();

    byte_stream
        .map(move |result| {
            let bytes = result.map_err(|e| Error::StreamConnection {
                message: format!("Stream connection error: {e}"),
            })?;

            buffer.extend_from_slice(&bytes);

            // Drain complete lines from the buffer; a partial trailing line
            // stays buffered until the next network read completes it.
            let mut chunks = Vec::new();
            while let Some(newline_pos) = buffer.iter().position(|&b| b == b'\n') {
                let line_bytes = buffer.drain(..=newline_pos).collect::<Vec<u8>>();
                let line = String::from_utf8_lossy(&line_bytes).trim().to_string();

                if let Some(chunk) = parse_sse_line(&line)? {
                    chunks.push(chunk);
                }
            }

            Ok(chunks)
        })
        .flat_map(|result: Result<Vec<ChatCompletionChunk>>| match result {
            Ok(chunks) => futures::stream::iter(chunks.into_iter().map(Ok)).left_stream(),
            Err(e) => futures::stream::once(async move { Err(e) }).right_stream(),
        })
}

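/// Parses a single SSE line into an optional [`ChatCompletionChunk`].
///
/// The lines this parser expects look roughly like the following (the JSON
/// payload shown here is abbreviated for illustration):
///
/// ```text
/// : a comment line, ignored
/// data: {"id":"chatcmpl-123","object":"chat.completion.chunk","choices":[...]}
/// data: [DONE]
/// ```
///
/// Blank lines, comment lines, and the `[DONE]` sentinel all yield `Ok(None)`.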
fn parse_sse_line(line: &str) -> Result<Option<ChatCompletionChunk>> {
    // Skip blank lines and SSE comment lines (which start with ':').
    if line.is_empty() || line.starts_with(':') {
        return Ok(None);
    }

    if let Some(data) = line.strip_prefix("data:").map(str::trim) {
        // The terminal sentinel carries no payload.
        if data == "[DONE]" {
            return Ok(None);
        }

        let mut value: serde_json::Value =
            serde_json::from_str(data).map_err(|e| Error::StreamParsing {
                message: format!("Failed to parse chunk JSON: {e}"),
                chunk: data.to_string(),
            })?;

        // Intermediate chunks report `finish_reason: null`, but the generated
        // model's finish_reason is not optional; normalize nulls to "stop" so
        // deserialization succeeds.
        if let Some(choices) = value.get_mut("choices").and_then(|c| c.as_array_mut()) {
            for choice in choices {
                if let Some(finish_reason) = choice.get("finish_reason") {
                    if finish_reason.is_null() {
                        choice["finish_reason"] = serde_json::json!("stop");
                    }
                }
            }
        }

        let response: CreateChatCompletionStreamResponse =
            serde_json::from_value(value).map_err(|e| Error::StreamParsing {
                message: format!("Failed to deserialize chunk: {e}"),
                chunk: data.to_string(),
            })?;

        return Ok(Some(ChatCompletionChunk::new(response)));
    }

    // Any other SSE field (e.g. `event:` or `id:`) is ignored.
    Ok(None)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_sse_line_with_content() {
        let line = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}"#;

        let result = parse_sse_line(line).unwrap();
        assert!(result.is_some());

        let chunk = result.unwrap();
        assert_eq!(chunk.content(), Some("Hello"));
        assert_eq!(chunk.role(), Some("assistant"));
    }

    #[test]
    fn test_parse_sse_line_done_marker() {
        let line = "data: [DONE]";
        let result = parse_sse_line(line).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_parse_sse_line_empty() {
        let line = "";
        let result = parse_sse_line(line).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_parse_sse_line_comment() {
        let line = ": this is a comment";
        let result = parse_sse_line(line).unwrap();
        assert!(result.is_none());
    }
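
    // Added sketch: exercises the finish_reason accessors. Assumes the
    // generated FinishReason enum deserializes the wire value "stop", which
    // the null-normalization path in parse_sse_line already relies on.
    #[test]
    fn test_parse_sse_line_with_finish_reason() {
        let line = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":"stop"}]}"#;

        let chunk = parse_sse_line(line).unwrap().expect("expected a chunk");
        assert_eq!(chunk.finish_reason(), Some("stop"));
        assert!(chunk.is_final());
    }

    // Added sketch: drives parse_sse_stream end to end over an in-memory byte
    // stream. Assumes the `futures` crate's default `executor` feature is
    // enabled so `block_on` is available.
    #[test]
    fn test_parse_sse_stream_splits_buffered_lines() {
        let body = concat!(
            r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}"#,
            "\n\ndata: [DONE]\n"
        );
        let byte_stream = futures::stream::iter(vec![Ok::<Bytes, reqwest::Error>(
            Bytes::from_static(body.as_bytes()),
        )]);

        let chunks: Vec<_> =
            futures::executor::block_on(parse_sse_stream(byte_stream).collect::<Vec<_>>());
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].as_ref().unwrap().content(), Some("Hello"));
    }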
}