sh_layer1/streaming/
mod.rs1pub mod http;
12pub mod providers;
13pub mod sse;
14pub mod websocket;
15
16pub use http::{HttpAdapter, HttpConfig, HttpRequest, HttpResponseStream, SseStream};
18pub use providers::{
19 ContentBlockType, ContentDelta, StreamEvent, StreamProvider, StreamState, StreamUsage,
20};
21pub use sse::{SseEvent, SseParser};
22pub use websocket::{WebSocketAdapter, WebSocketConfig, WebSocketMessage, WebSocketMessageStream};
23
24use anyhow::Result;
25use futures::Stream;
26use reqwest::Response;
27use std::collections::VecDeque;
28use std::pin::Pin;
29use std::sync::atomic::{AtomicBool, Ordering};
30use std::sync::Arc;
31use std::task::{Context, Poll};
32
33pub use providers::{AnthropicStreamEvent, OllamaStreamChunk, OpenAiStreamChunk};
35
36pub struct StreamHandler;
38
39impl StreamHandler {
40 pub fn create_sse_stream(
42 source: impl Stream<Item = Result<String>> + Send + 'static,
43 ) -> impl Stream<Item = Result<String>> {
44 use futures::StreamExt;
45
46 source.map(|item| match item {
47 Ok(data) => Ok(format!("data: {}\n\n", data)),
48 Err(e) => Err(e),
49 })
50 }
51}
52
/// Stream adapter that reports end-of-stream once a shared abort flag has been raised.
pub struct AbortableStream<S> {
    inner: S,
    abort_flag: Arc<AtomicBool>,
}

impl<S> AbortableStream<S> {
    /// Wraps `inner`. The flag is checked on every poll; once it reads `true`, the adapter
    /// stops yielding items.
    pub fn new(inner: S, abort_flag: Arc<AtomicBool>) -> Self {
        AbortableStream { abort_flag, inner }
    }

    /// Reports whether the abort flag is currently raised.
    pub fn is_aborted(&self) -> bool {
        let flag: &AtomicBool = &self.abort_flag;
        flag.load(Ordering::Relaxed)
    }
}
70
71impl<S, T> Stream for AbortableStream<S>
72where
73 S: Stream<Item = Result<T>> + Unpin,
74{
75 type Item = Result<T>;
76
77 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
78 if self.abort_flag.load(Ordering::Relaxed) {
79 return Poll::Ready(None);
80 }
81 Pin::new(&mut self.inner).poll_next(cx)
82 }
83}
84
85pub struct MessageStream {
87 response: Response,
88 parser: SseParser,
89 pending: VecDeque<StreamEvent>,
90 done: bool,
91 state: StreamState,
92 provider: StreamProvider,
93}
94
95impl MessageStream {
96 pub fn new(response: Response, provider: StreamProvider, model: String) -> Self {
98 let parser = SseParser::new().with_context(
99 match provider {
100 StreamProvider::Anthropic | StreamProvider::AnthropicCompatible => "Anthropic",
101 StreamProvider::OpenAI | StreamProvider::OpenAICompatible => "OpenAI",
102 StreamProvider::Gemini => "Gemini",
103 StreamProvider::AzureOpenAI => "AzureOpenAI",
104 StreamProvider::Bedrock => "Bedrock",
105 StreamProvider::Ollama => "Ollama",
106 },
107 &model,
108 );
109 Self {
110 response,
111 parser,
112 pending: VecDeque::new(),
113 done: false,
114 state: StreamState::new(model),
115 provider,
116 }
117 }
118
119 pub async fn next_event(&mut self) -> Result<Option<StreamEvent>> {
121 loop {
122 if let Some(event) = self.pending.pop_front() {
123 return Ok(Some(event));
124 }
125
126 if self.done {
127 let _remaining = self.parser.finish()?;
128 for event in self.state.finish() {
129 self.pending.push_back(event);
130 }
131 if let Some(event) = self.pending.pop_front() {
132 return Ok(Some(event));
133 }
134 return Ok(None);
135 }
136
137 match self.response.chunk().await? {
138 Some(chunk) => {
139 let sse_events = self.parser.push(&chunk)?;
140 for sse_event in sse_events {
141 let events = self.parse_sse_event(&sse_event)?;
142 self.pending.extend(events);
143 }
144 }
145 None => {
146 self.done = true;
147 }
148 }
149 }
150 }
151
152 fn parse_sse_event(
153 &mut self,
154 event: &crate::streaming::sse::SseEvent,
155 ) -> Result<Vec<StreamEvent>> {
156 use crate::streaming::providers::*;
157
158 match self.provider {
159 StreamProvider::Anthropic | StreamProvider::AnthropicCompatible => {
160 let anthropic_event: AnthropicStreamEvent = serde_json::from_str(&event.data)?;
161 Ok(self.state.ingest_anthropic(anthropic_event))
162 }
163 StreamProvider::OpenAI | StreamProvider::OpenAICompatible => {
164 let openai_chunk: OpenAiStreamChunk = serde_json::from_str(&event.data)?;
165 Ok(self.state.ingest_openai(openai_chunk))
166 }
167 StreamProvider::Gemini => {
168 let openai_chunk: OpenAiStreamChunk = serde_json::from_str(&event.data)?;
169 Ok(self.state.ingest_openai(openai_chunk))
170 }
171 StreamProvider::AzureOpenAI => {
172 let openai_chunk: OpenAiStreamChunk = serde_json::from_str(&event.data)?;
173 Ok(self.state.ingest_openai(openai_chunk))
174 }
175 StreamProvider::Bedrock => {
176 let anthropic_event: AnthropicStreamEvent = serde_json::from_str(&event.data)?;
177 Ok(self.state.ingest_anthropic(anthropic_event))
178 }
179 StreamProvider::Ollama => {
180 let ollama_chunk: OllamaStreamChunk = serde_json::from_str(&event.data)?;
181 Ok(self.state.ingest_ollama(ollama_chunk))
182 }
183 }
184 }
185
186 pub async fn collect_text(&mut self) -> Result<String> {
188 let mut text = String::new();
189 while let Some(event) = self.next_event().await? {
190 if let StreamEvent::ContentBlockDelta {
191 delta: ContentDelta::Text(t),
192 ..
193 } = event
194 {
195 text.push_str(&t);
196 }
197 }
198 Ok(text)
199 }
200}
201
202pub type OnChunkCallback = Box<dyn Fn(&str) + Send + Sync>;
204
205pub struct CallbackStream {
207 inner: MessageStream,
208 on_chunk: Option<OnChunkCallback>,
209 abort_flag: Arc<AtomicBool>,
210}
211
212impl CallbackStream {
213 pub fn new(inner: MessageStream, on_chunk: Option<OnChunkCallback>) -> Self {
215 Self {
216 inner,
217 on_chunk,
218 abort_flag: Arc::new(AtomicBool::new(false)),
219 }
220 }
221
222 pub fn abort_flag(&self) -> Arc<AtomicBool> {
224 Arc::clone(&self.abort_flag)
225 }
226
227 pub fn abort(&self) {
229 self.abort_flag.store(true, Ordering::Relaxed);
230 }
231
232 pub async fn next_event(&mut self) -> Result<Option<StreamEvent>> {
234 if self.abort_flag.load(Ordering::Relaxed) {
235 return Ok(None);
236 }
237
238 let event = self.inner.next_event().await?;
239
240 if let Some(ref callback) = self.on_chunk {
242 if let Some(StreamEvent::ContentBlockDelta {
243 delta: ContentDelta::Text(t),
244 ..
245 }) = event.as_ref()
246 {
247 callback(t);
248 }
249 }
250
251 Ok(event)
252 }
253}
254
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn abortable_stream_respects_abort_flag() {
        use futures::stream;

        let abort_flag = Arc::new(AtomicBool::new(true));
        let inner = stream::iter(vec![Ok("test".to_string())]);
        let mut stream = AbortableStream::new(inner, abort_flag);

        let result = futures::executor::block_on_stream(&mut stream).next();
        assert!(
            result.is_none(),
            "aborted stream should return None immediately"
        );
    }

    /// Items flow through while the flag is lowered; raising it ends the stream even though
    /// the inner stream still has items left.
    #[test]
    fn abortable_stream_passes_items_through_until_aborted() {
        use futures::stream;

        let abort_flag = Arc::new(AtomicBool::new(false));
        let items: Vec<Result<String>> = vec![Ok("first".to_string()), Ok("second".to_string())];
        let mut stream = AbortableStream::new(stream::iter(items), Arc::clone(&abort_flag));
        let mut blocking = futures::executor::block_on_stream(&mut stream);

        let first = blocking.next().expect("first item should be yielded");
        assert_eq!(first.expect("first item should be Ok"), "first");

        abort_flag.store(true, Ordering::Relaxed);
        assert!(
            blocking.next().is_none(),
            "stream should end once the abort flag is raised"
        );
    }
}
273}