Skip to main content

laminar_core/subscription/
stream.rs

1//! Async `Stream` subscriptions — [`ChangeEventStream`] and
2//! [`ChangeEventResultStream`].
3//!
4//! Wraps the broadcast channel from the [`SubscriptionRegistry`] in a
5//! `tokio_stream`-compatible async `Stream`, enabling idiomatic consumption
//! with combinators like `.filter()`, `.map()`, and `.take()`; mapping each
//! event to a future additionally enables `futures`' `.buffer_unordered()`.
8//!
9//! # API Styles
10//!
11//! - [`ChangeEventStream`] — `Stream<Item = ChangeEvent>`, silently skips lag.
12//! - [`ChangeEventResultStream`] — `Stream<Item = Result<ChangeEvent, _>>`,
13//!   surfaces lag errors for explicit handling.
14//!
15//! # Usage
16//!
17//! ```rust,ignore
18//! use tokio_stream::StreamExt;
19//!
20//! let mut stream = subscribe_stream(registry, "trades".into(), 0, config);
21//!
22//! while let Some(event) = stream.next().await {
23//!     process(event);
24//! }
25//!
26//! // With combinators
27//! let inserts = subscribe_stream(registry, "trades".into(), 0, config)
28//!     .filter(|e| e.event_type() == EventType::Insert)
29//!     .take(100);
30//! tokio::pin!(inserts);
31//! while let Some(event) = inserts.next().await {
32//!     process(event);
33//! }
34//! ```
35//!
36//! # Implementation Note
37//!
38//! Uses [`BroadcastStream`](tokio_stream::wrappers::BroadcastStream) internally
39//! for correct async wakeup semantics. A naive manual `poll_next` with
40//! `try_recv` + `cx.waker().wake_by_ref()` causes a busy-spin loop at 100%
41//! CPU. `BroadcastStream` integrates with tokio's async machinery — it only
42//! wakes the task when new data is actually available.
43
44use std::pin::Pin;
45use std::sync::Arc;
46use std::task::{Context, Poll};
47
48use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
49use tokio_stream::wrappers::BroadcastStream;
50use tokio_stream::Stream;
51
52use crate::subscription::event::ChangeEvent;
53use crate::subscription::handle::PushSubscriptionError;
54use crate::subscription::registry::{
55    SubscriptionConfig, SubscriptionId, SubscriptionMetrics, SubscriptionRegistry,
56};
57
58// ---------------------------------------------------------------------------
59// ChangeEventStream
60// ---------------------------------------------------------------------------
61
62/// Async stream wrapper for push subscriptions.
63///
64/// Implements `Stream<Item = ChangeEvent>`, silently skipping lagged events
65/// (with a `tracing::debug!` log). The stream terminates when the source is
66/// closed or the subscription is cancelled.
67///
68/// All fields are `Unpin` (including `BroadcastStream`), so the struct is
69/// `Unpin` and works directly with `tokio::select!` without explicit pinning.
70///
71/// Dropping the stream automatically cancels the subscription in the registry.
72pub struct ChangeEventStream {
73    /// Subscription ID for lifecycle management.
74    id: SubscriptionId,
75    /// Registry reference for pause/resume/cancel.
76    registry: Arc<SubscriptionRegistry>,
77    /// Query or source name for diagnostics.
78    query: String,
79    /// Inner `BroadcastStream` that handles proper async wakeup.
80    inner: BroadcastStream<ChangeEvent>,
81    /// Whether the stream has terminated.
82    terminated: bool,
83}
84
85impl ChangeEventStream {
86    /// Returns the subscription ID.
87    #[must_use]
88    pub fn id(&self) -> SubscriptionId {
89        self.id
90    }
91
92    /// Returns the query or source name for this subscription.
93    #[must_use]
94    pub fn query(&self) -> &str {
95        &self.query
96    }
97
98    /// Returns `true` if the stream has terminated.
99    #[must_use]
100    pub fn is_terminated(&self) -> bool {
101        self.terminated
102    }
103
104    /// Pauses the underlying subscription.
105    ///
106    /// While paused, events are buffered or dropped per the backpressure
107    /// configuration. Returns `true` if the subscription was active and is
108    /// now paused.
109    #[must_use]
110    pub fn pause(&self) -> bool {
111        self.registry.pause(self.id)
112    }
113
114    /// Resumes the underlying subscription.
115    ///
116    /// Returns `true` if the subscription was paused and is now active.
117    #[must_use]
118    pub fn resume(&self) -> bool {
119        self.registry.resume(self.id)
120    }
121
122    /// Cancels the subscription and terminates the stream.
123    ///
124    /// Subsequent calls to `poll_next` / `next()` return `None`.
125    pub fn cancel(&mut self) {
126        if !self.terminated {
127            self.terminated = true;
128            self.registry.cancel(self.id);
129        }
130    }
131
132    /// Returns subscription metrics from the registry.
133    #[must_use]
134    pub fn metrics(&self) -> Option<SubscriptionMetrics> {
135        self.registry.metrics(self.id)
136    }
137}
138
139impl Stream for ChangeEventStream {
140    type Item = ChangeEvent;
141
142    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
143        // SAFETY: All fields are Unpin (BroadcastStream stores a ReusableBoxFuture
144        // which is Unpin), so Pin::get_mut is safe.
145        let this = self.get_mut();
146
147        if this.terminated {
148            return Poll::Ready(None);
149        }
150
151        // Delegate to BroadcastStream, looping on lag errors.
152        loop {
153            match Pin::new(&mut this.inner).poll_next(cx) {
154                Poll::Ready(Some(Ok(event))) => return Poll::Ready(Some(event)),
155                Poll::Ready(Some(Err(_lagged))) => {
156                    // BroadcastStreamRecvError — silently skip lagged events.
157                    tracing::debug!("stream subscription lagged, skipping missed events");
158                }
159                Poll::Ready(None) => {
160                    this.terminated = true;
161                    return Poll::Ready(None);
162                }
163                Poll::Pending => return Poll::Pending,
164            }
165        }
166    }
167}
168
169impl Drop for ChangeEventStream {
170    fn drop(&mut self) {
171        if !self.terminated {
172            self.registry.cancel(self.id);
173        }
174    }
175}
176
177// ---------------------------------------------------------------------------
178// ChangeEventResultStream
179// ---------------------------------------------------------------------------
180
181/// Async stream that also yields lag errors.
182///
183/// Implements `Stream<Item = Result<ChangeEvent, PushSubscriptionError>>`,
184/// allowing explicit handling of lag and error conditions. Use this when
185/// you need to react to missed events rather than silently skipping them.
186///
187/// # Usage
188///
189/// ```rust,ignore
190/// use tokio_stream::StreamExt;
191///
192/// let mut stream = subscribe_stream_with_errors(registry, "trades".into(), 0, config);
193///
194/// while let Some(result) = stream.next().await {
195///     match result {
196///         Ok(event) => process(event),
197///         Err(PushSubscriptionError::Lagged(n)) => {
198///             eprintln!("Missed {n} events");
199///         }
200///         Err(e) => break,
201///     }
202/// }
203/// ```
204pub struct ChangeEventResultStream {
205    /// Subscription ID.
206    id: SubscriptionId,
207    /// Registry reference.
208    registry: Arc<SubscriptionRegistry>,
209    /// Inner `BroadcastStream`.
210    inner: BroadcastStream<ChangeEvent>,
211    /// Whether the stream has terminated.
212    terminated: bool,
213}
214
215impl ChangeEventResultStream {
216    /// Returns the subscription ID.
217    #[must_use]
218    pub fn id(&self) -> SubscriptionId {
219        self.id
220    }
221
222    /// Returns `true` if the stream has terminated.
223    #[must_use]
224    pub fn is_terminated(&self) -> bool {
225        self.terminated
226    }
227
228    /// Cancels the subscription and terminates the stream.
229    pub fn cancel(&mut self) {
230        if !self.terminated {
231            self.terminated = true;
232            self.registry.cancel(self.id);
233        }
234    }
235}
236
237impl Stream for ChangeEventResultStream {
238    type Item = Result<ChangeEvent, PushSubscriptionError>;
239
240    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
241        let this = self.get_mut();
242
243        if this.terminated {
244            return Poll::Ready(None);
245        }
246
247        match Pin::new(&mut this.inner).poll_next(cx) {
248            Poll::Ready(Some(Ok(event))) => Poll::Ready(Some(Ok(event))),
249            Poll::Ready(Some(Err(BroadcastStreamRecvError::Lagged(n)))) => {
250                Poll::Ready(Some(Err(PushSubscriptionError::Lagged(n))))
251            }
252            Poll::Ready(None) => {
253                this.terminated = true;
254                Poll::Ready(None)
255            }
256            Poll::Pending => Poll::Pending,
257        }
258    }
259}
260
261impl Drop for ChangeEventResultStream {
262    fn drop(&mut self) {
263        if !self.terminated {
264            self.registry.cancel(self.id);
265        }
266    }
267}
268
269// ---------------------------------------------------------------------------
270// Factory Functions
271// ---------------------------------------------------------------------------
272
273/// Creates an async `Stream` subscription.
274///
275/// Returns a [`ChangeEventStream`] that yields [`ChangeEvent`]s. Lagged
276/// events are silently skipped; the stream terminates when the source is
277/// closed or the subscription is cancelled.
278///
279/// # Arguments
280///
281/// * `registry` — Subscription registry.
282/// * `source_name` — Name of the source MV or query.
283/// * `source_id` — Ring 0 source identifier.
284/// * `config` — Subscription configuration.
285pub fn subscribe_stream(
286    registry: Arc<SubscriptionRegistry>,
287    source_name: String,
288    source_id: u32,
289    config: SubscriptionConfig,
290) -> ChangeEventStream {
291    let (id, receiver) = registry.create(source_name.clone(), source_id, config);
292    ChangeEventStream {
293        id,
294        registry,
295        query: source_name,
296        inner: BroadcastStream::new(receiver),
297        terminated: false,
298    }
299}
300
301/// Creates an async `Stream` that also yields errors.
302///
303/// Returns a [`ChangeEventResultStream`] that yields
304/// `Result<ChangeEvent, PushSubscriptionError>`, allowing explicit handling
305/// of lagged events.
306///
307/// # Arguments
308///
309/// * `registry` — Subscription registry.
310/// * `source_name` — Name of the source MV or query.
311/// * `source_id` — Ring 0 source identifier.
312/// * `config` — Subscription configuration.
313pub fn subscribe_stream_with_errors(
314    registry: Arc<SubscriptionRegistry>,
315    source_name: String,
316    source_id: u32,
317    config: SubscriptionConfig,
318) -> ChangeEventResultStream {
319    let (id, receiver) = registry.create(source_name, source_id, config);
320    ChangeEventResultStream {
321        id,
322        registry,
323        inner: BroadcastStream::new(receiver),
324        terminated: false,
325    }
326}
327
328// ===========================================================================
329// Tests
330// ===========================================================================
331
332#[cfg(test)]
333#[allow(clippy::cast_possible_wrap)]
334#[allow(clippy::cast_sign_loss)]
335#[allow(clippy::field_reassign_with_default)]
336#[allow(clippy::ignored_unit_patterns)]
337mod tests {
338    use super::*;
339    use std::sync::Arc;
340
341    use arrow_array::Int64Array;
342    use arrow_schema::{DataType, Field, Schema};
343    use tokio_stream::StreamExt;
344
345    use crate::subscription::event::EventType;
346    use crate::subscription::registry::SubscriptionState;
347
348    fn make_batch(n: usize) -> arrow_array::RecordBatch {
349        let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int64, false)]));
350        let values: Vec<i64> = (0..n as i64).collect();
351        let array = Int64Array::from(values);
352        arrow_array::RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap()
353    }
354
355    /// Helper: create a stream + get senders for pushing events.
356    fn make_stream(
357        name: &str,
358    ) -> (
359        Arc<SubscriptionRegistry>,
360        ChangeEventStream,
361        Vec<tokio::sync::broadcast::Sender<ChangeEvent>>,
362    ) {
363        let registry = Arc::new(SubscriptionRegistry::new());
364        let stream = subscribe_stream(
365            Arc::clone(&registry),
366            name.into(),
367            0,
368            SubscriptionConfig::default(),
369        );
370        let senders = registry.get_senders_for_source(0);
371        (registry, stream, senders)
372    }
373
374    /// Helper: send N insert events to the first sender.
375    fn send_events(senders: &[tokio::sync::broadcast::Sender<ChangeEvent>], count: usize) {
376        for i in 0..count {
377            let batch = Arc::new(make_batch(1));
378            senders[0]
379                .send(ChangeEvent::insert(batch, i as i64 * 1000, i as u64))
380                .unwrap();
381        }
382    }
383
384    // --- Basic stream consumption ---
385
386    #[tokio::test]
387    async fn test_stream_receives_events() {
388        let (_reg, mut stream, senders) = make_stream("trades");
389
390        send_events(&senders, 5);
391
392        for i in 0..5i64 {
393            let event = stream.next().await.unwrap();
394            assert_eq!(event.timestamp(), i * 1000);
395            assert_eq!(event.sequence(), Some(i as u64));
396        }
397    }
398
399    #[tokio::test]
400    async fn test_stream_terminates_on_close() {
401        let (reg, mut stream, senders) = make_stream("trades");
402
403        // Send one event then close
404        send_events(&senders, 1);
405        let event = stream.next().await.unwrap();
406        assert_eq!(event.timestamp(), 0);
407
408        // Close the channel — drop cloned senders AND cancel the registry
409        // entry (which drops the entry's broadcast::Sender).
410        drop(senders);
411        reg.cancel(stream.id());
412
413        let result = stream.next().await;
414        assert!(result.is_none());
415        assert!(stream.is_terminated());
416    }
417
418    #[tokio::test]
419    async fn test_stream_cancel() {
420        let (reg, mut stream, _senders) = make_stream("trades");
421        assert_eq!(reg.subscription_count(), 1);
422
423        stream.cancel();
424
425        assert!(stream.is_terminated());
426        assert_eq!(reg.subscription_count(), 0);
427
428        // Subsequent next() returns None
429        let result = stream.next().await;
430        assert!(result.is_none());
431    }
432
433    #[tokio::test]
434    async fn test_stream_drop_cancels() {
435        let registry = Arc::new(SubscriptionRegistry::new());
436        {
437            let _stream = subscribe_stream(
438                Arc::clone(&registry),
439                "trades".into(),
440                0,
441                SubscriptionConfig::default(),
442            );
443            assert_eq!(registry.subscription_count(), 1);
444        }
445        // Dropped — should be cancelled
446        assert_eq!(registry.subscription_count(), 0);
447    }
448
449    // --- Combinator tests ---
450
451    #[tokio::test]
452    async fn test_stream_filter_combinator() {
453        let (reg, stream, senders) = make_stream("trades");
454        let id = stream.id();
455
456        // Send mixed events: inserts and watermarks
457        let batch = Arc::new(make_batch(1));
458        senders[0]
459            .send(ChangeEvent::insert(Arc::clone(&batch), 1000, 1))
460            .unwrap();
461        senders[0].send(ChangeEvent::watermark(2000)).unwrap();
462        senders[0]
463            .send(ChangeEvent::insert(Arc::clone(&batch), 3000, 3))
464            .unwrap();
465
466        // Close the channel so .collect() terminates
467        drop(senders);
468        reg.cancel(id);
469
470        // Filter to inserts only
471        let inserts: Vec<_> = stream
472            .filter(|e| e.event_type() == EventType::Insert)
473            .collect()
474            .await;
475
476        assert_eq!(inserts.len(), 2);
477        assert_eq!(inserts[0].timestamp(), 1000);
478        assert_eq!(inserts[1].timestamp(), 3000);
479    }
480
481    #[tokio::test]
482    async fn test_stream_map_combinator() {
483        let (reg, stream, senders) = make_stream("trades");
484        let id = stream.id();
485
486        send_events(&senders, 3);
487
488        // Close the channel so .collect() terminates
489        drop(senders);
490        reg.cancel(id);
491
492        // Map to timestamps
493        let timestamps: Vec<i64> = stream.map(|e| e.timestamp()).collect().await;
494        assert_eq!(timestamps, vec![0, 1000, 2000]);
495    }
496
497    #[tokio::test]
498    async fn test_stream_take_combinator() {
499        let (_reg, stream, senders) = make_stream("trades");
500
501        send_events(&senders, 10);
502
503        // Take only 3
504        let events: Vec<_> = stream.take(3).collect().await;
505        assert_eq!(events.len(), 3);
506        assert_eq!(events[0].timestamp(), 0);
507        assert_eq!(events[1].timestamp(), 1000);
508        assert_eq!(events[2].timestamp(), 2000);
509    }
510
511    // --- select! compatibility ---
512
513    #[tokio::test]
514    async fn test_stream_with_select() {
515        let (_reg, mut stream, senders) = make_stream("trades");
516
517        send_events(&senders, 1);
518
519        let result = tokio::select! {
520            event = stream.next() => event,
521            _ = tokio::time::sleep(std::time::Duration::from_secs(5)) => {
522                panic!("timeout — event should be immediate");
523            }
524        };
525
526        let event = result.unwrap();
527        assert_eq!(event.timestamp(), 0);
528    }
529
530    // --- Result stream ---
531
532    #[tokio::test]
533    async fn test_result_stream_yields_errors() {
534        let registry = Arc::new(SubscriptionRegistry::new());
535        let mut cfg = SubscriptionConfig::default();
536        cfg.buffer_size = 4;
537
538        let mut stream =
539            subscribe_stream_with_errors(Arc::clone(&registry), "trades".into(), 0, cfg);
540
541        let senders = registry.get_senders_for_source(0);
542
543        // Overflow to cause lag
544        for i in 0..20i64 {
545            let batch = Arc::new(make_batch(1));
546            let _ = senders[0].send(ChangeEvent::insert(batch, i * 100, i as u64));
547        }
548
549        // Close the channel so the loop terminates
550        drop(senders);
551        registry.cancel(stream.id());
552
553        // Collect results — should include at least one Lagged error
554        let mut had_error = false;
555        let mut had_ok = false;
556
557        while let Some(result) = stream.next().await {
558            match result {
559                Ok(_) => had_ok = true,
560                Err(PushSubscriptionError::Lagged(n)) => {
561                    assert!(n > 0);
562                    had_error = true;
563                }
564                Err(e) => panic!("unexpected error: {e}"),
565            }
566        }
567
568        assert!(had_error, "expected at least one lag error");
569        assert!(had_ok, "expected at least one successful event");
570    }
571
572    #[tokio::test]
573    async fn test_result_stream_terminates_on_close() {
574        let registry = Arc::new(SubscriptionRegistry::new());
575        let mut stream = subscribe_stream_with_errors(
576            Arc::clone(&registry),
577            "trades".into(),
578            0,
579            SubscriptionConfig::default(),
580        );
581
582        let senders = registry.get_senders_for_source(0);
583        let batch = Arc::new(make_batch(1));
584        senders[0]
585            .send(ChangeEvent::insert(batch, 1000, 1))
586            .unwrap();
587
588        let result = stream.next().await.unwrap().unwrap();
589        assert_eq!(result.timestamp(), 1000);
590
591        // Close the channel — drop clones AND cancel entry
592        drop(senders);
593        registry.cancel(stream.id());
594
595        assert!(stream.next().await.is_none());
596        assert!(stream.is_terminated());
597    }
598
599    // --- Lifecycle ---
600
601    #[tokio::test]
602    async fn test_stream_pause_resume() {
603        let (reg, stream, _senders) = make_stream("trades");
604
605        assert!(stream.pause());
606        assert_eq!(reg.state(stream.id()), Some(SubscriptionState::Paused));
607
608        assert!(!stream.pause()); // already paused
609
610        assert!(stream.resume());
611        assert_eq!(reg.state(stream.id()), Some(SubscriptionState::Active));
612
613        assert!(!stream.resume()); // already active
614    }
615
616    // --- Accessors ---
617
618    #[tokio::test]
619    async fn test_stream_accessors() {
620        let (reg, stream, _senders) = make_stream("trades");
621
622        assert_eq!(stream.query(), "trades");
623        assert!(!stream.is_terminated());
624
625        let m = stream.metrics().unwrap();
626        assert_eq!(m.id, stream.id());
627        assert_eq!(m.source_name, "trades");
628        assert_eq!(m.state, SubscriptionState::Active);
629
630        drop(reg);
631    }
632
633    // --- Multiple consumers ---
634
635    #[tokio::test]
636    async fn test_stream_multiple_consumers() {
637        let registry = Arc::new(SubscriptionRegistry::new());
638
639        let mut s1 = subscribe_stream(
640            Arc::clone(&registry),
641            "trades".into(),
642            0,
643            SubscriptionConfig::default(),
644        );
645        let mut s2 = subscribe_stream(
646            Arc::clone(&registry),
647            "trades".into(),
648            0,
649            SubscriptionConfig::default(),
650        );
651
652        let senders = registry.get_senders_for_source(0);
653        let batch = Arc::new(make_batch(1));
654        let event = ChangeEvent::insert(batch, 5000, 10);
655        for sender in &senders {
656            sender.send(event.clone()).unwrap();
657        }
658
659        let e1 = s1.next().await.unwrap();
660        let e2 = s2.next().await.unwrap();
661        assert_eq!(e1.timestamp(), 5000);
662        assert_eq!(e2.timestamp(), 5000);
663    }
664}