Skip to main content

litellm_rs/
stream.rs

1use crate::error::{LiteLLMError, Result};
2use crate::http::MAX_SSE_BUFFER_SIZE;
3use crate::types::Usage;
4use bytes::Bytes;
5use futures_util::stream::{Stream, StreamExt, TryStreamExt};
6use serde_json::Value;
7use std::pin::Pin;
8use tokio::io::{AsyncBufReadExt, BufReader};
9use tokio_util::io::StreamReader;
10
/// One item produced by the chat stream parsers in this file.
#[derive(Debug, Clone)]
pub struct ChatStreamChunk {
    /// Text delta extracted from the event payload. The OpenAI-compatible parser
    /// uses an empty string when the payload has no string at
    /// `/choices/0/delta/content`.
    pub content: String,
    /// The JSON payload this chunk was built from; the parsers in this file
    /// always set it to `Some`.
    pub raw: Option<Value>,
    /// Usage figures read from the payload's top-level `usage` object, if any.
    pub usage: Option<Usage>,
}
17
/// Boxed, pinned, `Send` stream of `Result<ChatStreamChunk>` items, as returned
/// by the parsers in this file.
pub type ChatStream = Pin<Box<dyn Stream<Item = Result<ChatStreamChunk>> + Send>>;
19
/// A single decoded Server-Sent Event.
#[derive(Debug, Clone)]
struct SseEvent {
    /// Value of the most recent `event:` field seen for this event, if any.
    event: Option<String>,
    /// Values of the event's `data:` fields, joined with `\n`.
    data: String,
}
25
/// Boxed, pinned, `Send` stream of decoded `SseEvent`s (or the error that ended decoding).
type SseEventStream = Pin<Box<dyn Stream<Item = Result<SseEvent>> + Send>>;
27
28fn sse_event_stream<S>(stream: S) -> SseEventStream
29where
30    S: Stream<Item = std::result::Result<Bytes, reqwest::Error>> + Send + Unpin + 'static,
31{
32    let s = async_stream::try_stream! {
33        let stream = stream.map_err(std::io::Error::other);
34        let reader = StreamReader::new(stream);
35        let mut lines = BufReader::new(reader).lines();
36
37        let mut event_name: Option<String> = None;
38        let mut data_buf = String::new();
39
40        while let Some(line) = lines.next_line().await.map_err(LiteLLMError::from)? {
41            if line.is_empty() {
42                if !data_buf.is_empty() {
43                    let data = std::mem::take(&mut data_buf);
44                    let event = event_name.take();
45                    yield SseEvent { event, data };
46                } else {
47                    event_name = None;
48                }
49                continue;
50            }
51
52            if line.starts_with(':') {
53                continue;
54            }
55
56            let (field, value) = if let Some((field, value)) = line.split_once(':') {
57                (field, value.strip_prefix(' ').unwrap_or(value))
58            } else {
59                (line.as_str(), "")
60            };
61
62            match field {
63                "event" => {
64                    event_name = Some(value.to_string());
65                }
66                "data" => {
67                    if !data_buf.is_empty() {
68                        data_buf.push('\n');
69                    }
70                    data_buf.push_str(value);
71                    if data_buf.len() > MAX_SSE_BUFFER_SIZE {
72                        Err(LiteLLMError::http(format!(
73                            "SSE data buffer exceeded maximum size of {} bytes",
74                            MAX_SSE_BUFFER_SIZE
75                        )))?;
76                    }
77                }
78                _ => {}
79            }
80        }
81
82        if !data_buf.is_empty() {
83            let data = std::mem::take(&mut data_buf);
84            let event = event_name.take();
85            yield SseEvent { event, data };
86        }
87    };
88    Box::pin(s)
89}
90
91/// Parse an OpenAI-compatible SSE stream into chat chunks.
92///
93/// This function includes protection against unbounded memory growth by limiting
94/// the internal buffer size to `MAX_SSE_BUFFER_SIZE`.
95pub fn parse_sse_stream<S>(stream: S) -> ChatStream
96where
97    S: Stream<Item = std::result::Result<Bytes, reqwest::Error>> + Send + Unpin + 'static,
98{
99    let s = async_stream::try_stream! {
100        let mut events = sse_event_stream(stream);
101        while let Some(event) = events.next().await {
102            let event = event?;
103            let data = event.data.trim();
104            if data == "[DONE]" {
105                return;
106            }
107            let value: Value = serde_json::from_str(data)
108                .map_err(|e| LiteLLMError::Parse(e.to_string()))?;
109            let usage = parse_usage(&value);
110            let content = value
111                .pointer("/choices/0/delta/content")
112                .and_then(|v| v.as_str())
113                .unwrap_or("")
114                .to_string();
115            yield ChatStreamChunk {
116                content,
117                raw: Some(value),
118                usage,
119            };
120        }
121    };
122    Box::pin(s)
123}
124
125/// Parse an Anthropic SSE stream into chat chunks.
126///
127/// This function includes protection against unbounded memory growth by limiting
128/// the internal buffer size to `MAX_SSE_BUFFER_SIZE`.
129pub fn parse_anthropic_sse_stream<S>(stream: S) -> ChatStream
130where
131    S: Stream<Item = std::result::Result<Bytes, reqwest::Error>> + Send + Unpin + 'static,
132{
133    let s = async_stream::try_stream! {
134        let mut events = sse_event_stream(stream);
135        while let Some(event) = events.next().await {
136            let event = event?;
137            let data = event.data.trim();
138            if data == "[DONE]" {
139                return;
140            }
141            let value: Value = serde_json::from_str(data)
142                .map_err(|e| LiteLLMError::Parse(e.to_string()))?;
143            let usage = parse_usage(&value);
144            if event.event.as_deref() == Some("content_block_delta") {
145                let content = value
146                    .pointer("/delta/text")
147                    .and_then(|v| v.as_str())
148                    .unwrap_or("")
149                    .to_string();
150                if !content.is_empty() {
151                    yield ChatStreamChunk {
152                        content,
153                        raw: Some(value),
154                        usage,
155                    };
156                }
157            }
158        }
159    };
160    Box::pin(s)
161}
162
163fn parse_usage(value: &Value) -> Option<Usage> {
164    let usage = value.get("usage")?.as_object()?;
165    let prompt_tokens = usage.get("prompt_tokens").and_then(|v| v.as_u64());
166    let completion_tokens = usage.get("completion_tokens").and_then(|v| v.as_u64());
167    let total_tokens = usage.get("total_tokens").and_then(|v| v.as_u64());
168    let cost_usd = usage
169        .get("cost")
170        .and_then(|v| v.as_f64())
171        .or_else(|| usage.get("cost").and_then(|v| v.as_str())?.parse().ok())
172        .or_else(|| usage.get("cost_usd").and_then(|v| v.as_f64()))
173        .or_else(|| usage.get("total_cost").and_then(|v| v.as_f64()));
174    Some(Usage {
175        prompt_tokens,
176        completion_tokens,
177        thoughts_tokens: None,
178        total_tokens,
179        cost_usd,
180    })
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use futures_util::stream;

    /// Two content events followed by `[DONE]` yield two chunks, then the stream ends.
    #[tokio::test]
    async fn parse_sse_basic() {
        let data = "data: {\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n\
                    data: {\"choices\":[{\"delta\":{\"content\":\" World\"}}]}\n\n\
                    data: [DONE]\n\n";
        let bytes_stream = stream::iter(vec![Ok(Bytes::from(data))]);
        let mut chat_stream = parse_sse_stream(bytes_stream);

        let chunk1 = chat_stream.next().await.unwrap().unwrap();
        assert_eq!(chunk1.content, "Hello");

        let chunk2 = chat_stream.next().await.unwrap().unwrap();
        assert_eq!(chunk2.content, " World");

        assert!(chat_stream.next().await.is_none());
    }

    /// Two `content_block_delta` events yield two chunks, then the stream ends.
    #[tokio::test]
    async fn parse_anthropic_sse_basic() {
        let data = "event: content_block_delta\n\
                    data: {\"delta\":{\"text\":\"Hello\"}}\n\n\
                    event: content_block_delta\n\
                    data: {\"delta\":{\"text\":\" World\"}}\n\n";
        let bytes_stream = stream::iter(vec![Ok(Bytes::from(data))]);
        let mut chat_stream = parse_anthropic_sse_stream(bytes_stream);

        let chunk1 = chat_stream.next().await.unwrap().unwrap();
        assert_eq!(chunk1.content, "Hello");

        let chunk2 = chat_stream.next().await.unwrap().unwrap();
        assert_eq!(chunk2.content, " World");

        // Same end-of-stream check as the other tests: nothing follows the last event.
        assert!(chat_stream.next().await.is_none());
    }

    /// An event split across two network chunks is reassembled before parsing.
    #[tokio::test]
    async fn parse_sse_handles_split_chunks() {
        // Simulate data coming in multiple network chunks
        let chunk1 = "data: {\"choices\":[{\"delta\":{\"con";
        let chunk2 = "tent\":\"Split\"}}]}\n\ndata: [DONE]\n\n";
        let bytes_stream = stream::iter(vec![Ok(Bytes::from(chunk1)), Ok(Bytes::from(chunk2))]);
        let mut chat_stream = parse_sse_stream(bytes_stream);

        let chunk = chat_stream.next().await.unwrap().unwrap();
        assert_eq!(chunk.content, "Split");

        assert!(chat_stream.next().await.is_none());
    }
}
235}