Skip to main content

swink_agent/
stream_middleware.rs

//! Middleware wrapper for [`StreamFn`] that intercepts the output stream.
//!
//! Mirrors the [`ToolMiddleware`](crate::ToolMiddleware) pattern, but for the
//! streaming boundary: it wraps an `Arc<dyn StreamFn>` and transforms the output
//! stream of [`AssistantMessageEvent`] values.
//!
//! # Example
//!
//! ```
//! use std::sync::Arc;
//! use swink_agent::{StreamMiddleware, AssistantMessageEvent};
//! # use swink_agent::StreamFn;
//! # fn example(stream_fn: Arc<dyn StreamFn>) {
//! let logged = StreamMiddleware::with_logging(stream_fn, |event| {
//!     println!("event: {event:?}");
//! });
//! # }
//! ```
19
20use std::pin::Pin;
21use std::sync::Arc;
22
23use futures::stream::{Stream, StreamExt};
24use tokio_util::sync::CancellationToken;
25
26use crate::stream::{AssistantMessageEvent, StreamFn, StreamOptions};
27use crate::types::{AgentContext, ModelSpec};
28
// ─── Type alias for the stream transformation closure ───────────────────────
30
31type MapStreamFn = Arc<
32    dyn for<'a> Fn(
33            Pin<Box<dyn Stream<Item = AssistantMessageEvent> + Send + 'a>>,
34        ) -> Pin<Box<dyn Stream<Item = AssistantMessageEvent> + Send + 'a>>
35        + Send
36        + Sync,
37>;
38
// ─── StreamMiddleware ───────────────────────────────────────────────────────
40
41/// Intercepts the output stream from a wrapped [`StreamFn`].
42///
43/// The inner `StreamFn` is called normally, then `map_stream` transforms
44/// the resulting event stream before it reaches the consumer.
45pub struct StreamMiddleware {
46    inner: Arc<dyn StreamFn>,
47    map_stream: MapStreamFn,
48}
49
50impl StreamMiddleware {
51    /// Create a new middleware with a full stream transformation.
52    ///
53    /// The closure receives the inner stream and returns a transformed stream.
54    pub fn new<F>(inner: Arc<dyn StreamFn>, f: F) -> Self
55    where
56        F: for<'a> Fn(
57                Pin<Box<dyn Stream<Item = AssistantMessageEvent> + Send + 'a>>,
58            )
59                -> Pin<Box<dyn Stream<Item = AssistantMessageEvent> + Send + 'a>>
60            + Send
61            + Sync
62            + 'static,
63    {
64        Self {
65            inner,
66            map_stream: Arc::new(f),
67        }
68    }
69
70    /// Create a middleware that inspects each event via a logging callback.
71    ///
72    /// Events pass through unmodified; the callback is called for each event.
73    pub fn with_logging<F>(inner: Arc<dyn StreamFn>, callback: F) -> Self
74    where
75        F: Fn(&AssistantMessageEvent) + Send + Sync + 'static,
76    {
77        let callback = Arc::new(callback);
78        Self::new(inner, move |stream| {
79            let cb = callback.clone();
80            Box::pin(stream.inspect(move |event| cb(event)))
81        })
82    }
83
84    /// Create a middleware that maps each event through a transformation.
85    pub fn with_map<F>(inner: Arc<dyn StreamFn>, f: F) -> Self
86    where
87        F: Fn(AssistantMessageEvent) -> AssistantMessageEvent + Send + Sync + 'static,
88    {
89        let f = Arc::new(f);
90        Self::new(inner, move |stream| {
91            let f = f.clone();
92            Box::pin(stream.map(move |event| f(event)))
93        })
94    }
95
96    /// Create a middleware that filters events based on a predicate.
97    ///
98    /// Events for which the predicate returns `false` are dropped from the stream.
99    pub fn with_filter<F>(inner: Arc<dyn StreamFn>, f: F) -> Self
100    where
101        F: Fn(&AssistantMessageEvent) -> bool + Send + Sync + 'static,
102    {
103        let f = Arc::new(f);
104        Self::new(inner, move |stream| {
105            let f = f.clone();
106            Box::pin(stream.filter(move |event| {
107                let keep = f(event);
108                async move { keep }
109            }))
110        })
111    }
112}
113
114impl StreamFn for StreamMiddleware {
115    fn stream<'a>(
116        &'a self,
117        model: &'a ModelSpec,
118        context: &'a AgentContext,
119        options: &'a StreamOptions,
120        cancellation_token: CancellationToken,
121    ) -> Pin<Box<dyn Stream<Item = AssistantMessageEvent> + Send + 'a>> {
122        let inner_stream = self
123            .inner
124            .stream(model, context, options, cancellation_token);
125        (self.map_stream)(inner_stream)
126    }
127}
128
129impl std::fmt::Debug for StreamMiddleware {
130    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131        f.debug_struct("StreamMiddleware").finish_non_exhaustive()
132    }
133}
134
// ─── Compile-time Send + Sync assertion ─────────────────────────────────────
136
137const _: () = {
138    const fn assert_send_sync<T: Send + Sync>() {}
139    assert_send_sync::<StreamMiddleware>();
140};
141
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicU32, Ordering};

    use futures::StreamExt;

    use super::*;
    use crate::stream::AssistantMessageEvent;
    use crate::types::{AgentContext, Cost, ModelSpec, StopReason, Usage};

    /// Minimal `StreamFn` that replays a fixed list of events on every call.
    struct TestStreamFn {
        events: std::sync::Mutex<Vec<AssistantMessageEvent>>,
    }

    impl TestStreamFn {
        fn new(events: Vec<AssistantMessageEvent>) -> Self {
            Self {
                events: std::sync::Mutex::new(events),
            }
        }
    }

    impl StreamFn for TestStreamFn {
        fn stream<'a>(
            &'a self,
            _model: &'a ModelSpec,
            _context: &'a AgentContext,
            _options: &'a StreamOptions,
            _ct: CancellationToken,
        ) -> Pin<Box<dyn Stream<Item = AssistantMessageEvent> + Send + 'a>> {
            let events = self.events.lock().unwrap().clone();
            Box::pin(futures::stream::iter(events))
        }
    }

    /// `StreamFn` that emits a single `TextDelta` whose text is whether the
    /// token it received was already cancelled (`"true"` / `"false"`).
    /// Used to check that the middleware forwards the token.
    struct TokenProbe;

    impl StreamFn for TokenProbe {
        fn stream<'a>(
            &'a self,
            _model: &'a ModelSpec,
            _context: &'a AgentContext,
            _options: &'a StreamOptions,
            ct: CancellationToken,
        ) -> Pin<Box<dyn Stream<Item = AssistantMessageEvent> + Send + 'a>> {
            let event = AssistantMessageEvent::TextDelta {
                content_index: 0,
                delta: ct.is_cancelled().to_string(),
            };
            Box::pin(futures::stream::iter(vec![event]))
        }
    }

    fn test_events() -> Vec<AssistantMessageEvent> {
        vec![
            AssistantMessageEvent::Start,
            AssistantMessageEvent::TextStart { content_index: 0 },
            AssistantMessageEvent::TextDelta {
                content_index: 0,
                delta: "hello".into(),
            },
            AssistantMessageEvent::TextEnd { content_index: 0 },
            AssistantMessageEvent::Done {
                stop_reason: StopReason::Stop,
                usage: Usage::default(),
                cost: Cost::default(),
            },
        ]
    }

    fn test_model() -> ModelSpec {
        ModelSpec::new("test", "test-model")
    }

    fn test_context() -> AgentContext {
        AgentContext {
            system_prompt: String::new(),
            messages: vec![],
            tools: vec![],
        }
    }

    /// Drives `stream_fn` to completion with throwaway model, context and
    /// options, and returns every event it produced.
    async fn run(stream_fn: &dyn StreamFn, ct: CancellationToken) -> Vec<AssistantMessageEvent> {
        let model = test_model();
        let ctx = test_context();
        let opts = StreamOptions::default();
        stream_fn.stream(&model, &ctx, &opts, ct).collect().await
    }

    #[tokio::test]
    async fn logging_middleware_receives_all_events() {
        let inner = Arc::new(TestStreamFn::new(test_events()));
        let count = Arc::new(AtomicU32::new(0));
        let count_clone = count.clone();

        let mw = StreamMiddleware::with_logging(inner, move |_event| {
            count_clone.fetch_add(1, Ordering::SeqCst);
        });

        let collected = run(&mw, CancellationToken::new()).await;

        assert_eq!(collected.len(), 5);
        assert_eq!(count.load(Ordering::SeqCst), 5);
    }

    #[tokio::test]
    async fn map_middleware_transforms_events() {
        let inner = Arc::new(TestStreamFn::new(test_events()));
        let mw = StreamMiddleware::with_map(inner, |event| match event {
            AssistantMessageEvent::TextDelta {
                content_index,
                delta,
            } => AssistantMessageEvent::TextDelta {
                content_index,
                delta: delta.to_uppercase(),
            },
            other => other,
        });

        let collected = run(&mw, CancellationToken::new()).await;

        // Mapping is one-to-one: no events are added or lost.
        assert_eq!(collected.len(), 5);
        assert!(matches!(collected[0], AssistantMessageEvent::Start));
        if let AssistantMessageEvent::TextDelta { delta, .. } = &collected[2] {
            assert_eq!(delta, "HELLO");
        } else {
            panic!("expected TextDelta");
        }
    }

    #[tokio::test]
    async fn filter_middleware_drops_thinking_events() {
        let events = vec![
            AssistantMessageEvent::Start,
            AssistantMessageEvent::ThinkingStart { content_index: 0 },
            AssistantMessageEvent::ThinkingDelta {
                content_index: 0,
                delta: "reasoning...".into(),
            },
            AssistantMessageEvent::ThinkingEnd {
                content_index: 0,
                signature: None,
            },
            AssistantMessageEvent::TextStart { content_index: 1 },
            AssistantMessageEvent::TextDelta {
                content_index: 1,
                delta: "result".into(),
            },
            AssistantMessageEvent::TextEnd { content_index: 1 },
            AssistantMessageEvent::Done {
                stop_reason: StopReason::Stop,
                usage: Usage::default(),
                cost: Cost::default(),
            },
        ];
        let inner = Arc::new(TestStreamFn::new(events));
        let mw = StreamMiddleware::with_filter(inner, |event| {
            !matches!(
                event,
                AssistantMessageEvent::ThinkingStart { .. }
                    | AssistantMessageEvent::ThinkingDelta { .. }
                    | AssistantMessageEvent::ThinkingEnd { .. }
            )
        });

        let collected = run(&mw, CancellationToken::new()).await;

        // Start + TextStart + TextDelta + TextEnd + Done = 5
        assert_eq!(collected.len(), 5);
        // No thinking events
        for event in &collected {
            assert!(!matches!(
                event,
                AssistantMessageEvent::ThinkingStart { .. }
                    | AssistantMessageEvent::ThinkingDelta { .. }
                    | AssistantMessageEvent::ThinkingEnd { .. }
            ));
        }
    }

    #[tokio::test]
    async fn new_applies_full_stream_transformation() {
        let inner = Arc::new(TestStreamFn::new(test_events()));
        // Truncating the stream proves the closure controls the whole stream,
        // not just individual events.
        let mw = StreamMiddleware::new(inner, |stream| Box::pin(stream.take(2)));

        let collected = run(&mw, CancellationToken::new()).await;

        assert_eq!(collected.len(), 2);
        assert!(matches!(collected[0], AssistantMessageEvent::Start));
        assert!(matches!(
            collected[1],
            AssistantMessageEvent::TextStart { content_index: 0 }
        ));
    }

    #[tokio::test]
    async fn cancellation_token_is_forwarded_to_inner() {
        let mw = StreamMiddleware::with_map(Arc::new(TokenProbe), |event| event);
        let ct = CancellationToken::new();
        ct.cancel();

        let collected = run(&mw, ct).await;

        assert_eq!(collected.len(), 1);
        if let AssistantMessageEvent::TextDelta { delta, .. } = &collected[0] {
            assert_eq!(delta, "true");
        } else {
            panic!("expected TextDelta");
        }
    }

    #[tokio::test]
    async fn middleware_chains_compose() {
        let inner = Arc::new(TestStreamFn::new(test_events()));
        let count = Arc::new(AtomicU32::new(0));
        let count_clone = count.clone();

        // First layer: log
        let logged: Arc<dyn StreamFn> =
            Arc::new(StreamMiddleware::with_logging(inner, move |_| {
                count_clone.fetch_add(1, Ordering::SeqCst);
            }));

        // Second layer: map
        let mapped = StreamMiddleware::with_map(logged, |event| match event {
            AssistantMessageEvent::TextDelta {
                content_index,
                delta,
            } => AssistantMessageEvent::TextDelta {
                content_index,
                delta: format!("[{delta}]"),
            },
            other => other,
        });

        let collected = run(&mapped, CancellationToken::new()).await;

        // Logging saw all 5 events, and all 5 reached the consumer.
        assert_eq!(count.load(Ordering::SeqCst), 5);
        assert_eq!(collected.len(), 5);
        // Map transformed the delta
        if let AssistantMessageEvent::TextDelta { delta, .. } = &collected[2] {
            assert_eq!(delta, "[hello]");
        } else {
            panic!("expected TextDelta");
        }
    }
}
352}