Skip to main content

laminar_core/streaming/
subscription.rs

1//! Streaming Subscription API.
2//!
3//! A Subscription provides access to records from a Sink. It supports:
4//!
5//! - Non-blocking poll
6//! - Blocking receive with optional timeout
7//! - Iterator interface
8//! - Zero-allocation batch operations
9//!
10//! ## Usage
11//!
12//! ```rust,ignore
13//! let subscription = sink.subscribe();
14//!
15//! // Non-blocking poll
16//! while let Some(batch) = subscription.poll() {
17//!     process(batch);
18//! }
19//!
20//! // Blocking receive
21//! let batch = subscription.recv()?;
22//!
23//! // With timeout
24//! let batch = subscription.recv_timeout(Duration::from_secs(1))?;
25//!
26//! // As iterator
27//! for batch in subscription {
28//!     process(batch);
29//! }
30//! ```
31
32use std::sync::Arc;
33use std::time::{Duration, Instant};
34
35use arrow::array::RecordBatch;
36use arrow::datatypes::SchemaRef;
37
38use super::channel::Consumer;
39use super::error::RecvError;
40use super::sink::SinkInner;
41use super::source::{Record, SourceMessage};
42
/// A subscription to a streaming sink.
///
/// Subscriptions receive records from a Sink and provide them via
/// polling, blocking receive, or iterator interfaces.
///
/// ## Modes
///
/// - **Direct**: First subscriber, reads directly from source channel
/// - **Broadcast**: Additional subscribers, reads from dedicated channel
pub struct Subscription<T: Record> {
    /// Where messages are read from: the sink's own consumer (direct mode)
    /// or a consumer dedicated to this subscriber (broadcast mode).
    inner: SubscriptionInner<T>,
    /// Arrow schema captured at construction and handed out by `schema()`.
    schema: SchemaRef,
}
56
/// Backing message source of a `Subscription`.
///
/// Every read-side method of `Subscription` matches on this enum to pick the
/// consumer it polls.
enum SubscriptionInner<T: Record> {
    /// Direct subscription to sink's consumer.
    Direct(Arc<SinkInner<T>>),
    /// Broadcast subscription with dedicated channel.
    Broadcast(Consumer<SourceMessage<T>>),
}
63
64impl<T: Record> Subscription<T> {
65    /// Creates a direct subscription (first subscriber).
66    pub(crate) fn new_direct(sink_inner: Arc<SinkInner<T>>) -> Self {
67        let schema = sink_inner.schema();
68        Self {
69            inner: SubscriptionInner::Direct(sink_inner),
70            schema,
71        }
72    }
73
74    /// Creates a broadcast subscription (additional subscribers).
75    pub(crate) fn new_broadcast(consumer: Consumer<SourceMessage<T>>, schema: SchemaRef) -> Self {
76        Self {
77            inner: SubscriptionInner::Broadcast(consumer),
78            schema,
79        }
80    }
81
82    /// Polls for the next record batch without blocking.
83    ///
84    /// Returns `Some(RecordBatch)` if data is available, `None` if empty.
85    ///
86    /// Records are automatically converted to Arrow `RecordBatch` format.
87    #[must_use]
88    pub fn poll(&self) -> Option<RecordBatch> {
89        let msg = match &self.inner {
90            SubscriptionInner::Direct(sink) => sink.consumer().poll(),
91            SubscriptionInner::Broadcast(consumer) => consumer.poll(),
92        }?;
93
94        Self::message_to_batch(msg)
95    }
96
97    /// Polls for raw messages (without conversion to `RecordBatch`).
98    ///
99    /// This is useful when you need to handle watermarks separately.
100    #[must_use]
101    pub fn poll_message(&self) -> Option<SubscriptionMessage<T>> {
102        let msg = match &self.inner {
103            SubscriptionInner::Direct(sink) => sink.consumer().poll(),
104            SubscriptionInner::Broadcast(consumer) => consumer.poll(),
105        }?;
106
107        Some(Self::convert_message(msg))
108    }
109
110    /// Receives the next record batch, blocking until available.
111    ///
112    /// # Errors
113    ///
114    /// Returns `RecvError::Disconnected` if the source has been dropped
115    /// and there are no more buffered records.
116    pub fn recv(&self) -> Result<RecordBatch, RecvError> {
117        loop {
118            if let Some(batch) = self.poll() {
119                return Ok(batch);
120            }
121
122            if self.is_disconnected() {
123                return Err(RecvError::Disconnected);
124            }
125
126            // Brief yield before retrying
127            std::hint::spin_loop();
128        }
129    }
130
131    /// Receives the next record batch with a timeout.
132    ///
133    /// # Errors
134    ///
135    /// Returns `RecvError::Timeout` if no record becomes available within the timeout.
136    /// Returns `RecvError::Disconnected` if the source has been dropped.
137    pub fn recv_timeout(&self, timeout: Duration) -> Result<RecordBatch, RecvError> {
138        let deadline = Instant::now() + timeout;
139
140        loop {
141            if let Some(batch) = self.poll() {
142                return Ok(batch);
143            }
144
145            if self.is_disconnected() {
146                return Err(RecvError::Disconnected);
147            }
148
149            if Instant::now() >= deadline {
150                return Err(RecvError::Timeout);
151            }
152
153            std::hint::spin_loop();
154        }
155    }
156
157    /// Polls multiple record batches into a vector.
158    ///
159    /// Returns up to `max_count` batches.
160    ///
161    /// # Performance Warning
162    ///
163    /// **This method allocates a `Vec` on every call.** Do not use on hot paths
164    /// where allocation overhead matters. For zero-allocation consumption, use
165    /// [`poll_each`](Self::poll_each) or [`poll_batch_into`](Self::poll_batch_into).
166    #[cold]
167    #[must_use]
168    pub fn poll_batch(&self, max_count: usize) -> Vec<RecordBatch> {
169        let mut batches = Vec::with_capacity(max_count);
170
171        for _ in 0..max_count {
172            if let Some(batch) = self.poll() {
173                batches.push(batch);
174            } else {
175                break;
176            }
177        }
178
179        batches
180    }
181
182    /// Polls multiple record batches into a pre-allocated vector (zero-allocation).
183    ///
184    /// Appends up to `max_count` batches to the provided vector.
185    /// Returns the number of batches added.
186    ///
187    /// # Example
188    ///
189    /// ```rust,ignore
190    /// let mut buffer = Vec::with_capacity(100);
191    /// loop {
192    ///     buffer.clear();
193    ///     let count = subscription.poll_batch_into(&mut buffer, 100);
194    ///     if count == 0 { break; }
195    ///     for batch in &buffer {
196    ///         process(batch);
197    ///     }
198    /// }
199    /// ```
200    pub fn poll_batch_into(&self, buffer: &mut Vec<RecordBatch>, max_count: usize) -> usize {
201        let mut count = 0;
202
203        for _ in 0..max_count {
204            if let Some(batch) = self.poll() {
205                buffer.push(batch);
206                count += 1;
207            } else {
208                break;
209            }
210        }
211
212        count
213    }
214
215    /// Processes records with a callback (zero-allocation).
216    ///
217    /// The callback receives each `RecordBatch`. Processing stops when:
218    /// - `max_count` batches have been processed
219    /// - No more batches are available
220    /// - The callback returns `false`
221    ///
222    /// Returns the number of batches processed.
223    pub fn poll_each<F>(&self, max_count: usize, mut f: F) -> usize
224    where
225        F: FnMut(RecordBatch) -> bool,
226    {
227        let mut count = 0;
228
229        for _ in 0..max_count {
230            if let Some(batch) = self.poll() {
231                count += 1;
232                if !f(batch) {
233                    break;
234                }
235            } else {
236                break;
237            }
238        }
239
240        count
241    }
242
243    /// Returns true if the source has been dropped and buffer is empty.
244    #[must_use]
245    pub fn is_disconnected(&self) -> bool {
246        match &self.inner {
247            SubscriptionInner::Direct(sink) => sink.is_disconnected(),
248            SubscriptionInner::Broadcast(consumer) => consumer.is_disconnected(),
249        }
250    }
251
252    /// Returns the number of pending items.
253    #[must_use]
254    pub fn pending(&self) -> usize {
255        match &self.inner {
256            SubscriptionInner::Direct(sink) => sink.consumer().len(),
257            SubscriptionInner::Broadcast(consumer) => consumer.len(),
258        }
259    }
260
261    /// Returns the schema for records in this subscription.
262    #[must_use]
263    pub fn schema(&self) -> SchemaRef {
264        Arc::clone(&self.schema)
265    }
266
267    fn message_to_batch(msg: SourceMessage<T>) -> Option<RecordBatch> {
268        match msg {
269            SourceMessage::Record(record) => Some(record.to_record_batch()),
270            SourceMessage::Batch(batch) => Some(batch),
271            SourceMessage::Watermark(_) => {
272                // Skip watermarks in poll(), they're handled separately
273                None
274            }
275        }
276    }
277
278    fn convert_message(msg: SourceMessage<T>) -> SubscriptionMessage<T> {
279        match msg {
280            SourceMessage::Record(record) => SubscriptionMessage::Record(record),
281            SourceMessage::Batch(batch) => SubscriptionMessage::Batch(batch),
282            SourceMessage::Watermark(ts) => SubscriptionMessage::Watermark(ts),
283        }
284    }
285}
286
/// Message types that can be received from a subscription.
///
/// This is what `Subscription::poll_message` returns: unlike
/// `Subscription::poll`, it keeps single records unconverted and exposes
/// watermarks to the caller.
#[derive(Debug)]
pub enum SubscriptionMessage<T> {
    /// A single record.
    Record(T),
    /// A batch of records.
    Batch(RecordBatch),
    /// A watermark timestamp.
    Watermark(i64),
}
297
298impl<T: Record> SubscriptionMessage<T> {
299    /// Returns true if this is a record message.
300    #[must_use]
301    pub fn is_record(&self) -> bool {
302        matches!(self, Self::Record(_))
303    }
304
305    /// Returns true if this is a batch message.
306    #[must_use]
307    pub fn is_batch(&self) -> bool {
308        matches!(self, Self::Batch(_))
309    }
310
311    /// Returns true if this is a watermark message.
312    #[must_use]
313    pub fn is_watermark(&self) -> bool {
314        matches!(self, Self::Watermark(_))
315    }
316
317    /// Converts to a `RecordBatch` if this is a data message.
318    #[must_use]
319    pub fn to_batch(self) -> Option<RecordBatch> {
320        match self {
321            Self::Record(r) => Some(r.to_record_batch()),
322            Self::Batch(b) => Some(b),
323            Self::Watermark(_) => None,
324        }
325    }
326
327    /// Returns the watermark timestamp if this is a watermark message.
328    #[must_use]
329    pub fn watermark(&self) -> Option<i64> {
330        match self {
331            Self::Watermark(ts) => Some(*ts),
332            _ => None,
333        }
334    }
335}
336
337/// Iterator implementation for Subscription.
338///
339/// Iterates over record batches, blocking on each call to `next()`.
340/// Iteration stops when the source is disconnected.
341impl<T: Record> Iterator for Subscription<T> {
342    type Item = RecordBatch;
343
344    fn next(&mut self) -> Option<Self::Item> {
345        self.recv().ok()
346    }
347}
348
349impl<T: Record + std::fmt::Debug> std::fmt::Debug for Subscription<T> {
350    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
351        let mode = match &self.inner {
352            SubscriptionInner::Direct(_) => "Direct",
353            SubscriptionInner::Broadcast(_) => "Broadcast",
354        };
355
356        f.debug_struct("Subscription")
357            .field("mode", &mode)
358            .field("pending", &self.pending())
359            .field("is_disconnected", &self.is_disconnected())
360            .field("schema", &self.schema)
361            .finish()
362    }
363}
364
365#[cfg(test)]
366mod tests {
367    use super::*;
368    use crate::streaming::source::create;
369    use arrow::array::{Float64Array, Int64Array};
370    use arrow::datatypes::{DataType, Field, Schema};
371    use std::sync::Arc;
372
    /// Minimal two-field record used as the payload type throughout these tests.
    #[derive(Clone, Debug)]
    struct TestEvent {
        /// Written to the `id` (Int64) column.
        id: i64,
        /// Written to the `value` (Float64) column.
        value: f64,
    }
378
379    impl Record for TestEvent {
380        fn schema() -> SchemaRef {
381            Arc::new(Schema::new(vec![
382                Field::new("id", DataType::Int64, false),
383                Field::new("value", DataType::Float64, false),
384            ]))
385        }
386
387        fn to_record_batch(&self) -> RecordBatch {
388            RecordBatch::try_new(
389                Self::schema(),
390                vec![
391                    Arc::new(Int64Array::from(vec![self.id])),
392                    Arc::new(Float64Array::from(vec![self.value])),
393                ],
394            )
395            .unwrap()
396        }
397    }
398
399    #[test]
400    fn test_poll_empty() {
401        let (_source, sink) = create::<TestEvent>(16);
402        let sub = sink.subscribe();
403
404        assert!(sub.poll().is_none());
405    }
406
407    #[test]
408    fn test_poll_records() {
409        let (source, sink) = create::<TestEvent>(16);
410        let sub = sink.subscribe();
411
412        source.push(TestEvent { id: 1, value: 1.0 }).unwrap();
413        source.push(TestEvent { id: 2, value: 2.0 }).unwrap();
414
415        let batch1 = sub.poll().unwrap();
416        assert_eq!(batch1.num_rows(), 1);
417
418        let batch2 = sub.poll().unwrap();
419        assert_eq!(batch2.num_rows(), 1);
420
421        assert!(sub.poll().is_none());
422    }
423
424    #[test]
425    fn test_poll_message() {
426        let (source, sink) = create::<TestEvent>(16);
427        let sub = sink.subscribe();
428
429        source.push(TestEvent { id: 1, value: 1.0 }).unwrap();
430
431        let msg = sub.poll_message().unwrap();
432        assert!(msg.is_record());
433    }
434
435    #[test]
436    fn test_recv_timeout() {
437        let (_source, sink) = create::<TestEvent>(16);
438        let sub = sink.subscribe();
439
440        // Should timeout on empty subscription
441        let result = sub.recv_timeout(Duration::from_millis(10));
442        assert!(matches!(result, Err(RecvError::Timeout)));
443    }
444
445    #[test]
446    fn test_recv_timeout_success() {
447        let (source, sink) = create::<TestEvent>(16);
448        let sub = sink.subscribe();
449
450        source.push(TestEvent { id: 1, value: 1.0 }).unwrap();
451
452        let result = sub.recv_timeout(Duration::from_secs(1));
453        assert!(result.is_ok());
454    }
455
456    #[test]
457    fn test_poll_batch() {
458        let (source, sink) = create::<TestEvent>(16);
459        let sub = sink.subscribe();
460
461        source.push(TestEvent { id: 1, value: 1.0 }).unwrap();
462        source.push(TestEvent { id: 2, value: 2.0 }).unwrap();
463        source.push(TestEvent { id: 3, value: 3.0 }).unwrap();
464
465        let batches = sub.poll_batch(10);
466        assert_eq!(batches.len(), 3);
467    }
468
469    #[test]
470    fn test_poll_each() {
471        let (source, sink) = create::<TestEvent>(16);
472        let sub = sink.subscribe();
473
474        source.push(TestEvent { id: 1, value: 1.0 }).unwrap();
475        source.push(TestEvent { id: 2, value: 2.0 }).unwrap();
476
477        let mut total_rows = 0;
478        let count = sub.poll_each(10, |batch| {
479            total_rows += batch.num_rows();
480            true
481        });
482
483        assert_eq!(count, 2);
484        assert_eq!(total_rows, 2);
485    }
486
487    #[test]
488    fn test_poll_each_early_stop() {
489        let (source, sink) = create::<TestEvent>(16);
490        let sub = sink.subscribe();
491
492        source.push(TestEvent { id: 1, value: 1.0 }).unwrap();
493        source.push(TestEvent { id: 2, value: 2.0 }).unwrap();
494        source.push(TestEvent { id: 3, value: 3.0 }).unwrap();
495
496        let mut seen = 0;
497        let count = sub.poll_each(10, |_| {
498            seen += 1;
499            seen < 2 // Stop after 2
500        });
501
502        assert_eq!(count, 2);
503        assert_eq!(seen, 2);
504        assert_eq!(sub.pending(), 1); // One left
505    }
506
507    #[test]
508    fn test_disconnected() {
509        let (source, sink) = create::<TestEvent>(16);
510        let sub = sink.subscribe();
511
512        assert!(!sub.is_disconnected());
513
514        drop(source);
515
516        assert!(sub.is_disconnected());
517    }
518
519    #[test]
520    fn test_pending() {
521        let (source, sink) = create::<TestEvent>(16);
522        let sub = sink.subscribe();
523
524        assert_eq!(sub.pending(), 0);
525
526        source.push(TestEvent { id: 1, value: 1.0 }).unwrap();
527        source.push(TestEvent { id: 2, value: 2.0 }).unwrap();
528
529        assert_eq!(sub.pending(), 2);
530    }
531
532    #[test]
533    fn test_schema() {
534        let (_source, sink) = create::<TestEvent>(16);
535        let sub = sink.subscribe();
536
537        let schema = sub.schema();
538        assert_eq!(schema.fields().len(), 2);
539    }
540
541    #[test]
542    fn test_subscription_message() {
543        let msg = SubscriptionMessage::Record(TestEvent { id: 1, value: 1.0 });
544        assert!(msg.is_record());
545        assert!(!msg.is_batch());
546        assert!(!msg.is_watermark());
547
548        let batch = msg.to_batch().unwrap();
549        assert_eq!(batch.num_rows(), 1);
550
551        let wm = SubscriptionMessage::<TestEvent>::Watermark(1000);
552        assert!(wm.is_watermark());
553        assert_eq!(wm.watermark(), Some(1000));
554    }
555
556    #[test]
557    fn test_iterator() {
558        let (source, sink) = create::<TestEvent>(16);
559        let mut sub = sink.subscribe();
560
561        source.push(TestEvent { id: 1, value: 1.0 }).unwrap();
562        source.push(TestEvent { id: 2, value: 2.0 }).unwrap();
563
564        drop(source);
565
566        let batches: Vec<_> = sub.by_ref().collect();
567        assert_eq!(batches.len(), 2);
568    }
569
570    #[test]
571    fn test_debug_format() {
572        let (_source, sink) = create::<TestEvent>(16);
573        let sub = sink.subscribe();
574
575        let debug = format!("{sub:?}");
576        assert!(debug.contains("Subscription"));
577        assert!(debug.contains("Direct"));
578    }
579}