1use std::pin::Pin;
21use std::sync::Arc;
22
23use futures::stream::{Stream, StreamExt};
24use tokio_util::sync::CancellationToken;
25
26use crate::stream::{AssistantMessageEvent, StreamFn, StreamOptions};
27use crate::types::{AgentContext, ModelSpec};
28
29type MapStreamFn = Arc<
32 dyn for<'a> Fn(
33 Pin<Box<dyn Stream<Item = AssistantMessageEvent> + Send + 'a>>,
34 ) -> Pin<Box<dyn Stream<Item = AssistantMessageEvent> + Send + 'a>>
35 + Send
36 + Sync,
37>;
38
39pub struct StreamMiddleware {
46 inner: Arc<dyn StreamFn>,
47 map_stream: MapStreamFn,
48}
49
50impl StreamMiddleware {
51 pub fn new<F>(inner: Arc<dyn StreamFn>, f: F) -> Self
55 where
56 F: for<'a> Fn(
57 Pin<Box<dyn Stream<Item = AssistantMessageEvent> + Send + 'a>>,
58 )
59 -> Pin<Box<dyn Stream<Item = AssistantMessageEvent> + Send + 'a>>
60 + Send
61 + Sync
62 + 'static,
63 {
64 Self {
65 inner,
66 map_stream: Arc::new(f),
67 }
68 }
69
70 pub fn with_logging<F>(inner: Arc<dyn StreamFn>, callback: F) -> Self
74 where
75 F: Fn(&AssistantMessageEvent) + Send + Sync + 'static,
76 {
77 let callback = Arc::new(callback);
78 Self::new(inner, move |stream| {
79 let cb = callback.clone();
80 Box::pin(stream.inspect(move |event| cb(event)))
81 })
82 }
83
84 pub fn with_map<F>(inner: Arc<dyn StreamFn>, f: F) -> Self
86 where
87 F: Fn(AssistantMessageEvent) -> AssistantMessageEvent + Send + Sync + 'static,
88 {
89 let f = Arc::new(f);
90 Self::new(inner, move |stream| {
91 let f = f.clone();
92 Box::pin(stream.map(move |event| f(event)))
93 })
94 }
95
96 pub fn with_filter<F>(inner: Arc<dyn StreamFn>, f: F) -> Self
100 where
101 F: Fn(&AssistantMessageEvent) -> bool + Send + Sync + 'static,
102 {
103 let f = Arc::new(f);
104 Self::new(inner, move |stream| {
105 let f = f.clone();
106 Box::pin(stream.filter(move |event| {
107 let keep = f(event);
108 async move { keep }
109 }))
110 })
111 }
112}
113
114impl StreamFn for StreamMiddleware {
115 fn stream<'a>(
116 &'a self,
117 model: &'a ModelSpec,
118 context: &'a AgentContext,
119 options: &'a StreamOptions,
120 cancellation_token: CancellationToken,
121 ) -> Pin<Box<dyn Stream<Item = AssistantMessageEvent> + Send + 'a>> {
122 let inner_stream = self
123 .inner
124 .stream(model, context, options, cancellation_token);
125 (self.map_stream)(inner_stream)
126 }
127}
128
129impl std::fmt::Debug for StreamMiddleware {
130 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131 f.debug_struct("StreamMiddleware").finish_non_exhaustive()
132 }
133}
134
135const _: () = {
138 const fn assert_send_sync<T: Send + Sync>() {}
139 assert_send_sync::<StreamMiddleware>();
140};
141
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicU32, Ordering};

    use futures::StreamExt;

    use super::*;
    use crate::stream::AssistantMessageEvent;
    use crate::types::{AgentContext, Cost, ModelSpec, StopReason, Usage};

    /// Fake provider that replays a fixed list of events on every call.
    struct TestStreamFn {
        // NOTE(review): only ever locked to clone the events; the Mutex
        // presumably exists to keep the fake `Sync` — confirm before removing.
        events: std::sync::Mutex<Vec<AssistantMessageEvent>>,
    }

    impl TestStreamFn {
        fn new(events: Vec<AssistantMessageEvent>) -> Self {
            Self {
                events: std::sync::Mutex::new(events),
            }
        }
    }

    impl StreamFn for TestStreamFn {
        fn stream<'a>(
            &'a self,
            _model: &'a ModelSpec,
            _context: &'a AgentContext,
            _options: &'a StreamOptions,
            _ct: CancellationToken,
        ) -> Pin<Box<dyn Stream<Item = AssistantMessageEvent> + Send + 'a>> {
            let events = self.events.lock().unwrap().clone();
            Box::pin(futures::stream::iter(events))
        }
    }

    /// A minimal five-event text response.
    fn test_events() -> Vec<AssistantMessageEvent> {
        vec![
            AssistantMessageEvent::Start,
            AssistantMessageEvent::TextStart { content_index: 0 },
            AssistantMessageEvent::TextDelta {
                content_index: 0,
                delta: "hello".into(),
            },
            AssistantMessageEvent::TextEnd { content_index: 0 },
            AssistantMessageEvent::Done {
                stop_reason: StopReason::Stop,
                usage: Usage::default(),
                cost: Cost::default(),
            },
        ]
    }

    fn test_model() -> ModelSpec {
        ModelSpec::new("test", "test-model")
    }

    fn test_context() -> AgentContext {
        AgentContext {
            system_prompt: String::new(),
            messages: vec![],
            tools: vec![],
        }
    }

    /// Drives `stream_fn` to completion with throwaway model/context/options
    /// and a fresh cancellation token, returning every yielded event in order.
    async fn collect_events(stream_fn: &dyn StreamFn) -> Vec<AssistantMessageEvent> {
        let model = test_model();
        let ctx = test_context();
        let opts = StreamOptions::default();
        let stream = stream_fn.stream(&model, &ctx, &opts, CancellationToken::new());
        stream.collect().await
    }

    #[tokio::test]
    async fn logging_middleware_receives_all_events() {
        let inner = Arc::new(TestStreamFn::new(test_events()));
        let count = Arc::new(AtomicU32::new(0));
        let count_clone = Arc::clone(&count);

        let mw = StreamMiddleware::with_logging(inner, move |_event| {
            count_clone.fetch_add(1, Ordering::SeqCst);
        });

        let collected = collect_events(&mw).await;

        assert_eq!(collected.len(), 5);
        assert_eq!(count.load(Ordering::SeqCst), 5);
    }

    #[tokio::test]
    async fn map_middleware_transforms_events() {
        let inner = Arc::new(TestStreamFn::new(test_events()));
        let mw = StreamMiddleware::with_map(inner, |event| match event {
            AssistantMessageEvent::TextDelta {
                content_index,
                delta,
            } => AssistantMessageEvent::TextDelta {
                content_index,
                delta: delta.to_uppercase(),
            },
            other => other,
        });

        let collected = collect_events(&mw).await;

        // Mapping must be one-to-one: no events added or dropped.
        assert_eq!(collected.len(), 5);
        if let AssistantMessageEvent::TextDelta { delta, .. } = &collected[2] {
            assert_eq!(delta, "HELLO");
        } else {
            panic!("expected TextDelta");
        }
    }

    #[tokio::test]
    async fn filter_middleware_drops_thinking_events() {
        let events = vec![
            AssistantMessageEvent::Start,
            AssistantMessageEvent::ThinkingStart { content_index: 0 },
            AssistantMessageEvent::ThinkingDelta {
                content_index: 0,
                delta: "reasoning...".into(),
            },
            AssistantMessageEvent::ThinkingEnd {
                content_index: 0,
                signature: None,
            },
            AssistantMessageEvent::TextStart { content_index: 1 },
            AssistantMessageEvent::TextDelta {
                content_index: 1,
                delta: "result".into(),
            },
            AssistantMessageEvent::TextEnd { content_index: 1 },
            AssistantMessageEvent::Done {
                stop_reason: StopReason::Stop,
                usage: Usage::default(),
                cost: Cost::default(),
            },
        ];
        let inner = Arc::new(TestStreamFn::new(events));
        let mw = StreamMiddleware::with_filter(inner, |event| {
            !matches!(
                event,
                AssistantMessageEvent::ThinkingStart { .. }
                    | AssistantMessageEvent::ThinkingDelta { .. }
                    | AssistantMessageEvent::ThinkingEnd { .. }
            )
        });

        let collected = collect_events(&mw).await;

        // Eight input events minus the three thinking events.
        assert_eq!(collected.len(), 5);
        for event in &collected {
            assert!(!matches!(
                event,
                AssistantMessageEvent::ThinkingStart { .. }
                    | AssistantMessageEvent::ThinkingDelta { .. }
                    | AssistantMessageEvent::ThinkingEnd { .. }
            ));
        }
    }

    #[tokio::test]
    async fn middleware_chains_compose() {
        let inner = Arc::new(TestStreamFn::new(test_events()));
        let count = Arc::new(AtomicU32::new(0));
        let count_clone = Arc::clone(&count);

        // Inner layer counts events; outer layer rewrites text deltas.
        let logged: Arc<dyn StreamFn> =
            Arc::new(StreamMiddleware::with_logging(inner, move |_| {
                count_clone.fetch_add(1, Ordering::SeqCst);
            }));

        let mapped = StreamMiddleware::with_map(logged, |event| match event {
            AssistantMessageEvent::TextDelta {
                content_index,
                delta,
            } => AssistantMessageEvent::TextDelta {
                content_index,
                delta: format!("[{delta}]"),
            },
            other => other,
        });

        let collected = collect_events(&mapped).await;

        assert_eq!(collected.len(), 5);
        assert_eq!(count.load(Ordering::SeqCst), 5);
        if let AssistantMessageEvent::TextDelta { delta, .. } = &collected[2] {
            assert_eq!(delta, "[hello]");
        } else {
            panic!("expected TextDelta");
        }
    }
}
352}