1use crate::error::{LiteLLMError, Result};
2use crate::http::MAX_SSE_BUFFER_SIZE;
3use crate::types::Usage;
4use bytes::Bytes;
5use futures_util::stream::{Stream, StreamExt, TryStreamExt};
6use serde_json::Value;
7use std::pin::Pin;
8use tokio::io::{AsyncBufReadExt, BufReader};
9use tokio_util::io::StreamReader;
10
11#[derive(Debug, Clone)]
12pub struct ChatStreamChunk {
13 pub content: String,
14 pub raw: Option<Value>,
15 pub usage: Option<Usage>,
16}
17
18pub type ChatStream = Pin<Box<dyn Stream<Item = Result<ChatStreamChunk>> + Send>>;
19
/// A single Server-Sent Event assembled from consecutive lines of the wire stream.
#[derive(Clone, Debug)]
struct SseEvent {
    /// Value of the most recent `event:` field seen for this event, if there was one.
    event: Option<String>,
    /// Values of the event's `data:` fields, joined with `\n` when there are several.
    data: String,
}
25
26type SseEventStream = Pin<Box<dyn Stream<Item = Result<SseEvent>> + Send>>;
27
28fn sse_event_stream<S>(stream: S) -> SseEventStream
29where
30 S: Stream<Item = std::result::Result<Bytes, reqwest::Error>> + Send + Unpin + 'static,
31{
32 let s = async_stream::try_stream! {
33 let stream = stream.map_err(std::io::Error::other);
34 let reader = StreamReader::new(stream);
35 let mut lines = BufReader::new(reader).lines();
36
37 let mut event_name: Option<String> = None;
38 let mut data_buf = String::new();
39
40 while let Some(line) = lines.next_line().await.map_err(LiteLLMError::from)? {
41 if line.is_empty() {
42 if !data_buf.is_empty() {
43 let data = std::mem::take(&mut data_buf);
44 let event = event_name.take();
45 yield SseEvent { event, data };
46 } else {
47 event_name = None;
48 }
49 continue;
50 }
51
52 if line.starts_with(':') {
53 continue;
54 }
55
56 let (field, value) = if let Some((field, value)) = line.split_once(':') {
57 (field, value.strip_prefix(' ').unwrap_or(value))
58 } else {
59 (line.as_str(), "")
60 };
61
62 match field {
63 "event" => {
64 event_name = Some(value.to_string());
65 }
66 "data" => {
67 if !data_buf.is_empty() {
68 data_buf.push('\n');
69 }
70 data_buf.push_str(value);
71 if data_buf.len() > MAX_SSE_BUFFER_SIZE {
72 Err(LiteLLMError::http(format!(
73 "SSE data buffer exceeded maximum size of {} bytes",
74 MAX_SSE_BUFFER_SIZE
75 )))?;
76 }
77 }
78 _ => {}
79 }
80 }
81
82 if !data_buf.is_empty() {
83 let data = std::mem::take(&mut data_buf);
84 let event = event_name.take();
85 yield SseEvent { event, data };
86 }
87 };
88 Box::pin(s)
89}
90
91pub fn parse_sse_stream<S>(stream: S) -> ChatStream
96where
97 S: Stream<Item = std::result::Result<Bytes, reqwest::Error>> + Send + Unpin + 'static,
98{
99 let s = async_stream::try_stream! {
100 let mut events = sse_event_stream(stream);
101 while let Some(event) = events.next().await {
102 let event = event?;
103 let data = event.data.trim();
104 if data == "[DONE]" {
105 return;
106 }
107 let value: Value = serde_json::from_str(data)
108 .map_err(|e| LiteLLMError::Parse(e.to_string()))?;
109 let usage = parse_usage(&value);
110 let content = value
111 .pointer("/choices/0/delta/content")
112 .and_then(|v| v.as_str())
113 .unwrap_or("")
114 .to_string();
115 yield ChatStreamChunk {
116 content,
117 raw: Some(value),
118 usage,
119 };
120 }
121 };
122 Box::pin(s)
123}
124
125pub fn parse_anthropic_sse_stream<S>(stream: S) -> ChatStream
130where
131 S: Stream<Item = std::result::Result<Bytes, reqwest::Error>> + Send + Unpin + 'static,
132{
133 let s = async_stream::try_stream! {
134 let mut events = sse_event_stream(stream);
135 while let Some(event) = events.next().await {
136 let event = event?;
137 let data = event.data.trim();
138 if data == "[DONE]" {
139 return;
140 }
141 let value: Value = serde_json::from_str(data)
142 .map_err(|e| LiteLLMError::Parse(e.to_string()))?;
143 let usage = parse_usage(&value);
144 if event.event.as_deref() == Some("content_block_delta") {
145 let content = value
146 .pointer("/delta/text")
147 .and_then(|v| v.as_str())
148 .unwrap_or("")
149 .to_string();
150 if !content.is_empty() {
151 yield ChatStreamChunk {
152 content,
153 raw: Some(value),
154 usage,
155 };
156 }
157 }
158 }
159 };
160 Box::pin(s)
161}
162
163fn parse_usage(value: &Value) -> Option<Usage> {
164 let usage = value.get("usage")?.as_object()?;
165 let prompt_tokens = usage.get("prompt_tokens").and_then(|v| v.as_u64());
166 let completion_tokens = usage.get("completion_tokens").and_then(|v| v.as_u64());
167 let total_tokens = usage.get("total_tokens").and_then(|v| v.as_u64());
168 let cost_usd = usage
169 .get("cost")
170 .and_then(|v| v.as_f64())
171 .or_else(|| usage.get("cost").and_then(|v| v.as_str())?.parse().ok())
172 .or_else(|| usage.get("cost_usd").and_then(|v| v.as_f64()))
173 .or_else(|| usage.get("total_cost").and_then(|v| v.as_f64()));
174 Some(Usage {
175 prompt_tokens,
176 completion_tokens,
177 thoughts_tokens: None,
178 total_tokens,
179 cost_usd,
180 })
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use futures_util::stream;

    /// Builds an in-memory byte stream that yields each of `parts` as its own `Bytes` item.
    fn byte_stream(
        parts: Vec<&'static str>,
    ) -> impl Stream<Item = std::result::Result<Bytes, reqwest::Error>> + Send + Unpin + 'static
    {
        let items: Vec<std::result::Result<Bytes, reqwest::Error>> = parts
            .into_iter()
            .map(|part| Ok(Bytes::from(part)))
            .collect();
        stream::iter(items)
    }

    /// Pulls the next chunk and returns its text, panicking if the stream has
    /// ended or produced an error.
    async fn next_content(chunks: &mut ChatStream) -> String {
        chunks.next().await.unwrap().unwrap().content
    }

    /// Two content events followed by `[DONE]` produce two chunks and then end.
    #[tokio::test]
    async fn parse_sse_basic() {
        let body = concat!(
            "data: {\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n",
            "data: {\"choices\":[{\"delta\":{\"content\":\" World\"}}]}\n\n",
            "data: [DONE]\n\n",
        );
        let mut chat_stream = parse_sse_stream(byte_stream(vec![body]));

        assert_eq!(next_content(&mut chat_stream).await, "Hello");
        assert_eq!(next_content(&mut chat_stream).await, " World");
        assert!(chat_stream.next().await.is_none());
    }

    /// `content_block_delta` events surface their `/delta/text` as chunk content.
    #[tokio::test]
    async fn parse_anthropic_sse_basic() {
        let body = concat!(
            "event: content_block_delta\n",
            "data: {\"delta\":{\"text\":\"Hello\"}}\n\n",
            "event: content_block_delta\n",
            "data: {\"delta\":{\"text\":\" World\"}}\n\n",
        );
        let mut chat_stream = parse_anthropic_sse_stream(byte_stream(vec![body]));

        assert_eq!(next_content(&mut chat_stream).await, "Hello");
        assert_eq!(next_content(&mut chat_stream).await, " World");
    }

    /// An event whose JSON is split across two transport chunks is reassembled.
    #[tokio::test]
    async fn parse_sse_handles_split_chunks() {
        let head = "data: {\"choices\":[{\"delta\":{\"con";
        let tail = "tent\":\"Split\"}}]}\n\ndata: [DONE]\n\n";
        let mut chat_stream = parse_sse_stream(byte_stream(vec![head, tail]));

        assert_eq!(next_content(&mut chat_stream).await, "Split");
        assert!(chat_stream.next().await.is_none());
    }
}