openai_ergonomic/streaming.rs

//! Streaming support for chat completions.
//!
//! This module provides streaming functionality for `OpenAI` chat completions using
//! Server-Sent Events (SSE). The streaming API delivers the response incrementally
//! as it is generated, enabling real-time user experiences.
//!
//! # Example
//!
//! ```rust,no_run
//! use openai_ergonomic::Client;
//! use futures::StreamExt;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let client = Client::from_env()?.build();
//!
//!     let builder = client.chat().user("Tell me a story");
//!     let mut stream = client.send_chat_stream(builder).await?;
//!
//!     while let Some(chunk) = stream.next().await {
//!         let chunk = chunk?;
//!         if let Some(content) = chunk.content() {
//!             print!("{content}");
//!         }
//!     }
//!
//!     Ok(())
//! }
//! ```
//!
//! # Interceptor Support
//!
//! Streaming responses work with interceptors automatically. When interceptors are
//! configured, the chain receives the following hooks:
//! - `before_request`: called before streaming starts
//! - `on_stream_chunk`: called for each chunk as it arrives
//! - `on_stream_end`: called when streaming completes
//!
//! See the `langfuse_streaming` example for a complete demonstration.
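//!
//! A minimal sketch of a chunk-logging interceptor (illustrative only: the
//! exact `Interceptor` trait signatures and the registration API live in the
//! `interceptor` module; only the context fields used below appear in this
//! file):
//!
//! ```rust,ignore
//! use openai_ergonomic::interceptor::{StreamChunkContext, StreamEndContext};
//!
//! struct ChunkLogger;
//!
//! impl Interceptor<()> for ChunkLogger {
//!     fn on_stream_chunk(&self, ctx: &StreamChunkContext<'_, ()>) {
//!         println!("chunk #{} for model {}", ctx.chunk_index, ctx.model);
//!     }
//!
//!     fn on_stream_end(&self, ctx: &StreamEndContext<'_, ()>) {
//!         println!("{} chunks in {:?}", ctx.total_chunks, ctx.duration);
//!     }
//! }
//! ```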

use crate::interceptor::{StreamChunkContext, StreamEndContext};
use crate::{Error, Result};
use bytes::Bytes;
use futures::stream::Stream;
use futures::StreamExt;
use openai_client_base::models::{
    ChatCompletionStreamResponseDelta, CreateChatCompletionStreamResponse,
};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Instant;

/// Type alias for a boxed stream of chat completion chunks.
///
/// This allows returning either `ChatCompletionStream` or `InterceptedStream`
/// from the same method, depending on whether interceptors are enabled.
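///
/// For example, a helper that erases the concrete stream type (a sketch;
/// `Box::pin` performs the unsize coercion):
///
/// ```rust,ignore
/// fn erase(stream: ChatCompletionStream) -> BoxedChatStream {
///     Box::pin(stream)
/// }
/// ```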
pub type BoxedChatStream = Pin<Box<dyn Stream<Item = Result<ChatCompletionChunk>> + Send>>;

/// A streaming chunk from a chat completion response.
///
/// Each chunk represents a delta update from the model as it generates the response.
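///
/// A sketch of the accessors in use (`stream` stands for any stream of
/// chunks, such as one returned by `send_chat_stream`):
///
/// ```rust,ignore
/// while let Some(chunk) = stream.next().await {
///     let chunk = chunk?;
///     if let Some(role) = chunk.role() {
///         println!("role: {role}"); // usually only on the first chunk
///     }
///     if let Some(text) = chunk.content() {
///         print!("{text}");
///     }
/// }
/// ```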
#[derive(Debug, Clone)]
pub struct ChatCompletionChunk {
    /// The underlying stream response.
    response: CreateChatCompletionStreamResponse,
}

impl ChatCompletionChunk {
    /// Create a new chunk from a stream response.
    #[must_use]
    pub fn new(response: CreateChatCompletionStreamResponse) -> Self {
        Self { response }
    }

    /// Get the content delta from this chunk, if any.
    ///
    /// Returns the text content that was generated in this chunk.
    #[must_use]
    pub fn content(&self) -> Option<&str> {
        self.response
            .choices
            .first()
            .and_then(|choice| choice.delta.content.as_ref().and_then(|c| c.as_deref()))
    }

    /// Get the role from this chunk, if any.
    ///
    /// This is typically only present in the first chunk.
    #[must_use]
    pub fn role(&self) -> Option<&str> {
        use openai_client_base::models::chat_completion_stream_response_delta::Role;

        self.response
            .choices
            .first()
            .and_then(|choice| choice.delta.role.as_ref())
            .map(|role| match role {
                Role::System => "system",
                Role::User => "user",
                Role::Assistant => "assistant",
                Role::Tool => "tool",
                Role::Developer => "developer",
            })
    }

    /// Get tool calls from this chunk, if any.
    #[must_use]
    pub fn tool_calls(
        &self,
    ) -> Option<&Vec<openai_client_base::models::ChatCompletionMessageToolCallChunk>> {
        self.response
            .choices
            .first()
            .and_then(|choice| choice.delta.tool_calls.as_ref())
    }

    /// Get the finish reason, if any.
    ///
    /// This indicates why generation stopped. The API only sets it on the last
    /// chunk, but note that the SSE parser coerces a `null` `finish_reason` to
    /// `"stop"` (see `parse_sse_line`), so intermediate chunks that arrive with
    /// an explicit `null` also report `Some("stop")`.
    #[must_use]
    pub fn finish_reason(&self) -> Option<&str> {
        use openai_client_base::models::create_chat_completion_stream_response_choices_inner::FinishReason;

        self.response
            .choices
            .first()
            .map(|choice| match &choice.finish_reason {
                FinishReason::Stop => "stop",
                FinishReason::Length => "length",
                FinishReason::ToolCalls => "tool_calls",
                FinishReason::ContentFilter => "content_filter",
                FinishReason::FunctionCall => "function_call",
            })
    }

    /// Check if this is the last chunk in the stream.
    ///
    /// See the caveat on [`Self::finish_reason`]: the `null` coercion
    /// workaround in the SSE parser can make this return `true` for
    /// intermediate chunks as well.
    #[must_use]
    pub fn is_final(&self) -> bool {
        self.finish_reason().is_some()
    }

    /// Get the underlying raw response.
    #[must_use]
    pub fn raw_response(&self) -> &CreateChatCompletionStreamResponse {
        &self.response
    }

    /// Get the delta object directly.
    #[must_use]
    pub fn delta(&self) -> Option<&ChatCompletionStreamResponseDelta> {
        self.response
            .choices
            .first()
            .map(|choice| choice.delta.as_ref())
    }
}

/// A stream of chat completion chunks.
///
/// This stream yields `ChatCompletionChunk` items as the model generates the response.
/// The stream ends when the model finishes generating or encounters an error.
pub struct ChatCompletionStream {
    inner: Pin<Box<dyn Stream<Item = Result<ChatCompletionChunk>> + Send>>,
}

impl ChatCompletionStream {
    /// Create a new stream from a byte stream response.
    ///
    /// Parses the Server-Sent Events (SSE) format and yields chat completion chunks.
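    ///
    /// The client normally constructs this for you; a direct sketch using a
    /// raw `reqwest` request against the `OpenAI` API (the `api_key` binding
    /// is assumed):
    ///
    /// ```rust,ignore
    /// let response = reqwest::Client::new()
    ///     .post("https://api.openai.com/v1/chat/completions")
    ///     .bearer_auth(api_key)
    ///     .json(&serde_json::json!({
    ///         "model": "gpt-4",
    ///         "messages": [{"role": "user", "content": "Hello"}],
    ///         "stream": true,
    ///     }))
    ///     .send()
    ///     .await?;
    /// let stream = ChatCompletionStream::new(response);
    /// ```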
    pub fn new(response: reqwest::Response) -> Self {
        let byte_stream = response.bytes_stream();
        let stream = parse_sse_stream(byte_stream);
        Self {
            inner: Box::pin(stream),
        }
    }

    /// Collect all remaining content from the stream into a single string.
    ///
    /// This is a convenience method that reads all chunks and concatenates their content.
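    ///
    /// For example (assuming `response` is a streaming `reqwest::Response`):
    ///
    /// ```rust,ignore
    /// let text = ChatCompletionStream::new(response).collect_content().await?;
    /// println!("{text}");
    /// ```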
    pub async fn collect_content(mut self) -> Result<String> {
        let mut content = String::new();
        while let Some(chunk) = self.next().await {
            let chunk = chunk?;
            if let Some(text) = chunk.content() {
                content.push_str(text);
            }
        }
        Ok(content)
    }
}

impl Stream for ChatCompletionStream {
    type Item = Result<ChatCompletionChunk>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.inner.as_mut().poll_next(cx)
    }
}

/// Wrapper for a stream that calls interceptor hooks.
///
/// This wrapper intercepts chunks as they flow through the stream and calls
/// the appropriate interceptor methods for observability and telemetry.
///
/// This is the primary stream type returned by streaming methods when
/// interceptors are enabled. It provides the same interface as `ChatCompletionStream`.
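///
/// The client wires this up when interceptors are configured; roughly (a
/// sketch using this module's constructor, with an illustrative operation
/// label):
///
/// ```rust,ignore
/// let stream = InterceptedStream::new(
///     chat_stream,                       // ChatCompletionStream
///     Arc::clone(&interceptor_chain),    // Arc<InterceptorChain<T>>
///     "chat.completions".to_string(),    // operation label
///     "gpt-4".to_string(),               // model
///     request_json,                      // serialized request body
///     state,                             // per-request interceptor state
/// );
/// ```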
pub struct InterceptedStream<T = ()> {
    inner: Pin<Box<dyn Stream<Item = Result<ChatCompletionChunk>> + Send>>,
    interceptors: Arc<crate::interceptor::InterceptorChain<T>>,
    operation: String,
    model: String,
    request_json: String,
    state: Arc<T>,
    chunk_index: usize,
    start_time: Instant,
    total_input_tokens: Option<i64>,
    total_output_tokens: Option<i64>,
}

impl<T: Send + Sync + 'static> InterceptedStream<T> {
    /// Create a new intercepted stream.
    pub fn new(
        inner: ChatCompletionStream,
        interceptors: Arc<crate::interceptor::InterceptorChain<T>>,
        operation: String,
        model: String,
        request_json: String,
        state: T,
    ) -> Self {
        Self {
            inner: inner.inner,
            interceptors,
            operation,
            model,
            request_json,
            state: Arc::new(state),
            chunk_index: 0,
            start_time: Instant::now(),
            total_input_tokens: None,
            total_output_tokens: None,
        }
    }

    /// Collect all remaining content from the stream into a single string.
    ///
    /// This is a convenience method that reads all chunks and concatenates their content.
    /// Interceptor hooks are still called for each chunk.
    pub async fn collect_content(mut self) -> Result<String> {
        let mut content = String::new();
        while let Some(chunk) = self.next().await {
            let chunk = chunk?;
            if let Some(text) = chunk.content() {
                content.push_str(text);
            }
        }
        Ok(content)
    }
}

impl<T: Send + Sync + 'static> Stream for InterceptedStream<T> {
    type Item = Result<ChatCompletionChunk>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = &mut *self;

        match this.inner.as_mut().poll_next(cx) {
            Poll::Ready(Some(Ok(chunk))) => {
                // Serialize the chunk for the interceptor.
                let chunk_json = serde_json::to_string(chunk.raw_response())
                    .unwrap_or_else(|_| "{}".to_string());

                // Update token counts if usage is reported on this chunk.
                if let Some(usage) = &chunk.raw_response().usage {
                    this.total_input_tokens = Some(i64::from(usage.prompt_tokens));
                    this.total_output_tokens = Some(i64::from(usage.completion_tokens));
                }

                // Call the on_stream_chunk hook (spawned so polling is not blocked).
                let interceptors = Arc::clone(&this.interceptors);
                let operation = this.operation.clone();
                let model = this.model.clone();
                let request_json = this.request_json.clone();
                let chunk_index = this.chunk_index;
                let state = Arc::clone(&this.state);

                tokio::spawn(async move {
                    let ctx = StreamChunkContext {
                        operation: &operation,
                        model: &model,
                        request_json: &request_json,
                        chunk_json: &chunk_json,
                        chunk_index,
                        state: &*state,
                    };
                    let _ = interceptors.on_stream_chunk(&ctx).await;
                });

                this.chunk_index += 1;
                Poll::Ready(Some(Ok(chunk)))
            }
            Poll::Ready(Some(Err(e))) => {
                // Note: the error interceptor is skipped for streaming errors due
                // to lifetime constraints on the Error type. The error is still
                // propagated to the caller.
                // TODO: Consider adding error string serialization support.
                Poll::Ready(Some(Err(e)))
            }
            Poll::Ready(None) => {
                // Stream ended - call the on_stream_end hook.
                let interceptors = Arc::clone(&this.interceptors);
                let operation = this.operation.clone();
                let model = this.model.clone();
                let request_json = this.request_json.clone();
                let chunk_index = this.chunk_index;
                let duration = this.start_time.elapsed();
                let input_tokens = this.total_input_tokens;
                let output_tokens = this.total_output_tokens;
                let state = Arc::clone(&this.state);

                tokio::spawn(async move {
                    let ctx = StreamEndContext {
                        operation: &operation,
                        model: &model,
                        request_json: &request_json,
                        total_chunks: chunk_index,
                        duration,
                        input_tokens,
                        output_tokens,
                        state: &*state,
                    };
                    let _ = interceptors.on_stream_end(&ctx).await;
                });

                Poll::Ready(None)
            }
            Poll::Pending => Poll::Pending,
        }
    }
}

/// Parse an SSE (Server-Sent Events) byte stream into chat completion chunks.
fn parse_sse_stream(
    byte_stream: impl Stream<Item = reqwest::Result<Bytes>> + Send + 'static,
) -> impl Stream<Item = Result<ChatCompletionChunk>> + Send {
    let mut buffer = Vec::new();

    byte_stream
        .map(move |result| {
            let bytes = result.map_err(|e| Error::StreamConnection {
                message: format!("Stream connection error: {e}"),
            })?;

            buffer.extend_from_slice(&bytes);

            // Process complete lines from the buffer; a partial trailing line
            // stays buffered until the next read.
            let mut chunks = Vec::new();
            while let Some(newline_pos) = buffer.iter().position(|&b| b == b'\n') {
                let line_bytes = buffer.drain(..=newline_pos).collect::<Vec<u8>>();
                let line = String::from_utf8_lossy(&line_bytes).trim().to_string();

                if let Some(chunk) = parse_sse_line(&line)? {
                    chunks.push(chunk);
                }
            }

            Ok(chunks)
        })
        .flat_map(|result: Result<Vec<ChatCompletionChunk>>| match result {
            Ok(chunks) => futures::stream::iter(chunks.into_iter().map(Ok)).left_stream(),
            Err(e) => futures::stream::once(async move { Err(e) }).right_stream(),
        })
}

/// Parse a single SSE line into a chat completion chunk.
fn parse_sse_line(line: &str) -> Result<Option<ChatCompletionChunk>> {
    // Skip empty lines and comments.
    if line.is_empty() || line.starts_with(':') {
        return Ok(None);
    }

    // Handle the SSE data field: "data: {json}".
    if let Some(data) = line.strip_prefix("data:").map(str::trim) {
        // Check for the [DONE] marker that terminates the stream.
        if data == "[DONE]" {
            return Ok(None);
        }

        // Parse into a Value first so a null finish_reason can be patched up.
        let mut value: serde_json::Value =
            serde_json::from_str(data).map_err(|e| Error::StreamParsing {
                message: format!("Failed to parse chunk JSON: {e}"),
                chunk: data.to_string(),
            })?;
        // Workaround: the base model's finish_reason is not optional, so a
        // null value fails to deserialize. Coerce null to "stop" so parsing
        // succeeds; see the caveat on ChatCompletionChunk::finish_reason.
        if let Some(choices) = value.get_mut("choices").and_then(|c| c.as_array_mut()) {
            for choice in choices {
                if let Some(finish_reason) = choice.get("finish_reason") {
                    if finish_reason.is_null() {
                        // Substitute a placeholder value for null.
                        choice["finish_reason"] = serde_json::json!("stop");
                    }
                }
            }
        }

        let response: CreateChatCompletionStreamResponse =
            serde_json::from_value(value).map_err(|e| Error::StreamParsing {
                message: format!("Failed to deserialize chunk: {e}"),
                chunk: data.to_string(),
            })?;

        return Ok(Some(ChatCompletionChunk::new(response)));
    }

    // Ignore other SSE fields (event:, id:, retry:).
    Ok(None)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_sse_line_with_content() {
        let line = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}"#;

        let result = parse_sse_line(line).unwrap();
        assert!(result.is_some());

        let chunk = result.unwrap();
        assert_eq!(chunk.content(), Some("Hello"));
        assert_eq!(chunk.role(), Some("assistant"));
    }

    #[test]
    fn test_parse_sse_line_done_marker() {
        let line = "data: [DONE]";
        let result = parse_sse_line(line).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_parse_sse_line_empty() {
        let line = "";
        let result = parse_sse_line(line).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_parse_sse_line_comment() {
        let line = ": this is a comment";
        let result = parse_sse_line(line).unwrap();
        assert!(result.is_none());
    }
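
    // A final chunk carries a non-null finish_reason; here the delta is empty.
    #[test]
    fn test_parse_sse_line_final_chunk() {
        let line = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}"#;

        let chunk = parse_sse_line(line).unwrap().unwrap();
        assert_eq!(chunk.content(), None);
        assert_eq!(chunk.finish_reason(), Some("stop"));
        assert!(chunk.is_final());
    }

    // Non-data SSE fields (event:, id:, retry:) are ignored.
    #[test]
    fn test_parse_sse_line_other_fields() {
        for line in ["event: message", "id: 42", "retry: 100"] {
            assert!(parse_sse_line(line).unwrap().is_none());
        }
    }

    // Documents the null-coercion workaround: a null finish_reason is rewritten
    // to "stop" before deserialization, so even intermediate chunks report a
    // finish reason (see the caveat on ChatCompletionChunk::finish_reason).
    #[test]
    fn test_parse_sse_line_null_finish_reason_coerced() {
        let line = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4","choices":[{"index":0,"delta":{"content":"Hi"},"finish_reason":null}]}"#;

        let chunk = parse_sse_line(line).unwrap().unwrap();
        assert_eq!(chunk.finish_reason(), Some("stop"));
    }

    // End-to-end through the SSE parser: one network read containing two lines
    // yields one chunk and swallows the [DONE] marker. Driven with the futures
    // executor so no tokio runtime is needed (assumes the default "executor"
    // feature of the `futures` crate).
    #[test]
    fn test_parse_sse_stream_basic() {
        let data = concat!(
            r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}"#,
            "\n",
            "data: [DONE]\n",
        );
        let byte_stream =
            futures::stream::iter(vec![Ok::<Bytes, reqwest::Error>(Bytes::from(data))]);

        let results: Vec<_> =
            futures::executor::block_on(parse_sse_stream(byte_stream).collect::<Vec<_>>());
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].as_ref().unwrap().content(), Some("Hello"));
    }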
465}