Skip to main content

laminar_core/subscription/
callback.rs

1//! Callback-based subscriptions — [`SubscriptionCallback`] trait and
2//! [`CallbackSubscriptionHandle`].
3//!
4//! Provides a callback-based subscription API where users register a callback
5//! function or trait object that is invoked for every change event. The callback
6//! runs on a dedicated tokio task, wrapping the channel-based broadcast receiver
7//! from the [`SubscriptionRegistry`] internally.
8//!
9//! # API Styles
10//!
11//! - **Trait-based**: Implement [`SubscriptionCallback`] for full control over
12//!   change, error, and completion events.
13//! - **Closure-based**: Use [`subscribe_fn`] for simple cases where only
14//!   `on_change` is needed.
15//!
16//! # Panic Safety
17//!
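//! Panics in the callback's [`on_change`](SubscriptionCallback::on_change) are
//! caught via [`std::panic::catch_unwind`] and forwarded to
//! [`on_error`](SubscriptionCallback::on_error) as
//! [`PushSubscriptionError::Internal`]. Catching only works when the crate is
//! built with `panic = "unwind"` (the default); under `panic = "abort"` a
//! panicking callback aborts the process.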
22//!
23//! # Lifecycle
24//!
//! Dropping a [`CallbackSubscriptionHandle`] automatically cancels the
//! subscription and aborts the callback task, so keep the handle alive for as
//! long as events should be delivered.
27
28use std::panic::AssertUnwindSafe;
29use std::sync::Arc;
30
31use tokio::sync::broadcast;
32
33use crate::subscription::event::ChangeEvent;
34use crate::subscription::handle::PushSubscriptionError;
35use crate::subscription::registry::{
36    SubscriptionConfig, SubscriptionId, SubscriptionMetrics, SubscriptionRegistry,
37};
38
39// ---------------------------------------------------------------------------
40// SubscriptionCallback
41// ---------------------------------------------------------------------------
42
43/// Callback trait for push-based subscriptions.
44///
45/// Implement this trait to receive change events via callback. The callback
46/// runs on a dedicated tokio task and is invoked for every event pushed by
47/// the Ring 1 dispatcher.
48///
49/// # Example
50///
51/// ```rust,ignore
52/// struct MyHandler;
53///
54/// impl SubscriptionCallback for MyHandler {
55///     fn on_change(&self, event: ChangeEvent) {
56///         match event {
57///             ChangeEvent::Insert { data, .. } => println!("{} rows", data.num_rows()),
58///             _ => {}
59///         }
60///     }
61/// }
62/// ```
63pub trait SubscriptionCallback: Send + Sync + 'static {
64    /// Called for each change event.
65    fn on_change(&self, event: ChangeEvent);
66
67    /// Called when an error occurs (e.g., lagged behind, internal error).
68    ///
69    /// Default implementation logs the error via `tracing::warn!`.
70    fn on_error(&self, error: PushSubscriptionError) {
71        tracing::warn!("subscription callback error: {}", error);
72    }
73
74    /// Called when the subscription is closed (source dropped or cancelled).
75    ///
76    /// Default implementation is a no-op.
77    fn on_complete(&self) {}
78}
79
80// ---------------------------------------------------------------------------
81// FnCallback (private adapter)
82// ---------------------------------------------------------------------------
83
84/// Adapter that wraps a closure into a [`SubscriptionCallback`].
85struct FnCallback<F>(F);
86
87impl<F: Fn(ChangeEvent) + Send + Sync + 'static> SubscriptionCallback for FnCallback<F> {
88    fn on_change(&self, event: ChangeEvent) {
89        (self.0)(event);
90    }
91}
92
93// ---------------------------------------------------------------------------
94// CallbackSubscriptionHandle
95// ---------------------------------------------------------------------------
96
97/// Handle for a callback-based subscription.
98///
99/// Provides lifecycle management (pause/resume/cancel) for the callback task.
100/// The handle and the callback task share the same `SubscriptionEntry` in
101/// the registry (via [`SubscriptionId`]), so `pause()` / `cancel()` on the
102/// handle directly affects the task's event delivery.
103///
104/// Dropping the handle automatically cancels the subscription and aborts the
105/// callback task.
106pub struct CallbackSubscriptionHandle {
107    /// Subscription ID (shared with the callback task).
108    id: SubscriptionId,
109    /// Registry reference for lifecycle management.
110    registry: Arc<SubscriptionRegistry>,
111    /// Join handle for the callback runner task.
112    task: Option<tokio::task::JoinHandle<()>>,
113    /// Whether the subscription has been explicitly cancelled.
114    cancelled: bool,
115}
116
117impl CallbackSubscriptionHandle {
118    /// Pauses the subscription.
119    ///
120    /// While paused, events are buffered or dropped per the backpressure
121    /// configuration. Returns `true` if the subscription was active and is
122    /// now paused.
123    #[must_use]
124    pub fn pause(&self) -> bool {
125        self.registry.pause(self.id)
126    }
127
128    /// Resumes a paused subscription.
129    ///
130    /// Returns `true` if the subscription was paused and is now active.
131    #[must_use]
132    pub fn resume(&self) -> bool {
133        self.registry.resume(self.id)
134    }
135
136    /// Cancels the subscription and aborts the callback task.
137    ///
138    /// The subscription is removed from the registry (dropping the broadcast
139    /// sender) and the task is aborted as a safety net.
140    pub fn cancel(&mut self) {
141        if !self.cancelled {
142            self.cancelled = true;
143            self.registry.cancel(self.id);
144            if let Some(task) = self.task.take() {
145                task.abort();
146            }
147        }
148    }
149
150    /// Returns the subscription ID.
151    #[must_use]
152    pub fn id(&self) -> SubscriptionId {
153        self.id
154    }
155
156    /// Returns subscription metrics from the registry.
157    #[must_use]
158    pub fn metrics(&self) -> Option<SubscriptionMetrics> {
159        self.registry.metrics(self.id)
160    }
161
162    /// Returns `true` if the subscription has been cancelled.
163    #[must_use]
164    pub fn is_cancelled(&self) -> bool {
165        self.cancelled
166    }
167}
168
169impl Drop for CallbackSubscriptionHandle {
170    fn drop(&mut self) {
171        if !self.cancelled {
172            self.registry.cancel(self.id);
173            if let Some(task) = self.task.take() {
174                task.abort();
175            }
176        }
177    }
178}
179
180// ---------------------------------------------------------------------------
181// Factory Functions
182// ---------------------------------------------------------------------------
183
184/// Creates a callback-based subscription.
185///
186/// Registers a subscription in the registry, then spawns a tokio task that
187/// calls `callback.on_change()` for every event. Panics in the callback are
188/// caught and forwarded to `callback.on_error()`.
189///
190/// When the broadcast sender is dropped (e.g., via cancel or registry
191/// cleanup), the task calls `callback.on_complete()` and exits.
192///
193/// # Arguments
194///
195/// * `registry` — Subscription registry for lifecycle management.
196/// * `source_name` — Name of the source MV or query.
197/// * `source_id` — Ring 0 source identifier.
198/// * `config` — Subscription configuration.
199/// * `callback` — Implementation of [`SubscriptionCallback`].
200pub fn subscribe_callback<C: SubscriptionCallback>(
201    registry: Arc<SubscriptionRegistry>,
202    source_name: String,
203    source_id: u32,
204    config: SubscriptionConfig,
205    callback: C,
206) -> CallbackSubscriptionHandle {
207    let (id, receiver) = registry.create(source_name, source_id, config);
208    let callback = Arc::new(callback);
209
210    let task = tokio::spawn(callback_runner(receiver, callback));
211
212    CallbackSubscriptionHandle {
213        id,
214        registry,
215        task: Some(task),
216        cancelled: false,
217    }
218}
219
220/// Creates a closure-based subscription (convenience wrapper).
221///
222/// Equivalent to [`subscribe_callback`] with a closure wrapped in an internal
223/// adapter that implements [`SubscriptionCallback`].
224///
225/// # Example
226///
227/// ```rust,ignore
228/// let handle = subscribe_fn(registry, "trades".into(), 0, config, |event| {
229///     println!("Got: {:?}", event.event_type());
230/// });
231/// ```
232pub fn subscribe_fn<F>(
233    registry: Arc<SubscriptionRegistry>,
234    source_name: String,
235    source_id: u32,
236    config: SubscriptionConfig,
237    f: F,
238) -> CallbackSubscriptionHandle
239where
240    F: Fn(ChangeEvent) + Send + Sync + 'static,
241{
242    subscribe_callback(registry, source_name, source_id, config, FnCallback(f))
243}
244
245// ---------------------------------------------------------------------------
246// Callback Runner (internal)
247// ---------------------------------------------------------------------------
248
249/// Internal task that receives events from the broadcast channel and calls
250/// the callback. Panics in `on_change` are caught and forwarded to `on_error`.
251async fn callback_runner<C: SubscriptionCallback>(
252    mut receiver: broadcast::Receiver<ChangeEvent>,
253    callback: Arc<C>,
254) {
255    loop {
256        match receiver.recv().await {
257            Ok(event) => {
258                let cb = Arc::clone(&callback);
259                let result = std::panic::catch_unwind(AssertUnwindSafe(|| cb.on_change(event)));
260                if let Err(panic) = result {
261                    let msg = if let Some(s) = panic.downcast_ref::<&str>() {
262                        format!("callback panicked: {s}")
263                    } else if let Some(s) = panic.downcast_ref::<String>() {
264                        format!("callback panicked: {s}")
265                    } else {
266                        "callback panicked".to_string()
267                    };
268                    callback.on_error(PushSubscriptionError::Internal(msg));
269                }
270            }
271            Err(broadcast::error::RecvError::Lagged(n)) => {
272                callback.on_error(PushSubscriptionError::Lagged(n));
273                // Continue receiving after lag
274            }
275            Err(broadcast::error::RecvError::Closed) => {
276                callback.on_complete();
277                break;
278            }
279        }
280    }
281}
282
283// ===========================================================================
284// Tests
285// ===========================================================================
286
// Unit tests for callback subscriptions.
//
// `#[tokio::test]` defaults to a current-thread runtime, so the spawned
// callback runner only makes progress while a test is awaiting. The
// `tokio::time::sleep` calls below exist to yield to that task; the order of
// sends relative to those awaits is significant (see the lag test).
#[cfg(test)]
// `i as u64` / `n as i64` casts on small, non-negative test indices.
#[allow(clippy::cast_sign_loss)]
#[allow(clippy::cast_possible_wrap)]
// `cfg.buffer_size = 4` after `SubscriptionConfig::default()`.
#[allow(clippy::field_reassign_with_default)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    use arrow_array::Int64Array;
    use arrow_schema::{DataType, Field, Schema};

    use crate::subscription::registry::SubscriptionState;

    /// Builds a single-column (`v: Int64`) batch holding `0..n`.
    fn make_batch(n: usize) -> arrow_array::RecordBatch {
        let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int64, false)]));
        let values: Vec<i64> = (0..n as i64).collect();
        let array = Int64Array::from(values);
        arrow_array::RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap()
    }

    // --- Test callback implementation ---

    /// Records event timestamps, formatted errors, and whether `on_complete`
    /// ran. The `Arc`s are shared so a test can keep a clone after moving the
    /// callback into `subscribe_callback`.
    #[derive(Clone)]
    struct TestCallback {
        events: Arc<Mutex<Vec<i64>>>,
        errors: Arc<Mutex<Vec<String>>>,
        completed: Arc<Mutex<bool>>,
    }

    impl TestCallback {
        fn new() -> Self {
            Self {
                events: Arc::new(Mutex::new(Vec::new())),
                errors: Arc::new(Mutex::new(Vec::new())),
                completed: Arc::new(Mutex::new(false)),
            }
        }
    }

    impl SubscriptionCallback for TestCallback {
        fn on_change(&self, event: ChangeEvent) {
            self.events.lock().unwrap().push(event.timestamp());
        }

        fn on_error(&self, error: PushSubscriptionError) {
            self.errors.lock().unwrap().push(format!("{error}"));
        }

        fn on_complete(&self) {
            *self.completed.lock().unwrap() = true;
        }
    }

    // --- Tests ---

    #[tokio::test]
    async fn test_callback_receives_events() {
        let registry = Arc::new(SubscriptionRegistry::new());
        let cb = TestCallback::new();
        let events = Arc::clone(&cb.events);

        let _handle = subscribe_callback(
            Arc::clone(&registry),
            "trades".into(),
            0,
            SubscriptionConfig::default(),
            cb,
        );

        let senders = registry.get_senders_for_source(0);
        assert_eq!(senders.len(), 1);

        // The second argument of `ChangeEvent::insert` is what
        // `event.timestamp()` reports (per the assertion below).
        for i in 0..5i64 {
            let batch = Arc::new(make_batch(1));
            senders[0]
                .send(ChangeEvent::insert(batch, i * 1000, i as u64))
                .unwrap();
        }

        // Yield so the runner task can drain the channel.
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        let received = events.lock().unwrap();
        assert_eq!(received.len(), 5);
        assert_eq!(*received, vec![0, 1000, 2000, 3000, 4000]);
    }

    #[tokio::test]
    async fn test_callback_on_error_lagged() {
        let registry = Arc::new(SubscriptionRegistry::new());
        let mut cfg = SubscriptionConfig::default();
        // Presumably becomes the broadcast channel capacity — TODO confirm in
        // the registry.
        cfg.buffer_size = 4;
        let cb = TestCallback::new();
        let errors = Arc::clone(&cb.errors);
        let events = Arc::clone(&cb.events);

        let _handle = subscribe_callback(Arc::clone(&registry), "trades".into(), 0, cfg, cb);

        let senders = registry.get_senders_for_source(0);

        // Overflow the buffer to cause lag. All 20 sends happen before the
        // test yields, so the runner has not received anything yet when the
        // buffer overflows. Send results are ignored on purpose.
        for i in 0..20i64 {
            let batch = Arc::new(make_batch(1));
            let _ = senders[0].send(ChangeEvent::insert(batch, i * 100, i as u64));
        }

        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        let errs = errors.lock().unwrap();
        assert!(!errs.is_empty(), "expected at least one lag error");
        assert!(errs[0].contains("lagged behind"));

        // Should still receive events after lag recovery
        let evts = events.lock().unwrap();
        assert!(!evts.is_empty(), "should receive events after lag");
    }

    #[tokio::test]
    async fn test_callback_on_complete() {
        let registry = Arc::new(SubscriptionRegistry::new());
        let cb = TestCallback::new();
        let completed = Arc::clone(&cb.completed);

        let handle = subscribe_callback(
            Arc::clone(&registry),
            "trades".into(),
            0,
            SubscriptionConfig::default(),
            cb,
        );

        // Cancel from registry side — drops sender → task gets Closed → on_complete
        // (cancelling through the handle would also abort the task).
        registry.cancel(handle.id());

        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        assert!(*completed.lock().unwrap());
    }

    #[tokio::test]
    async fn test_callback_panic_caught() {
        // Panics in every `on_change`; the default panic hook will still print
        // the panic message to stderr even though the runner catches it.
        struct PanickingCallback {
            errors: Arc<Mutex<Vec<String>>>,
        }

        impl SubscriptionCallback for PanickingCallback {
            fn on_change(&self, _event: ChangeEvent) {
                panic!("deliberate test panic");
            }

            fn on_error(&self, error: PushSubscriptionError) {
                self.errors.lock().unwrap().push(format!("{error}"));
            }
        }

        let registry = Arc::new(SubscriptionRegistry::new());
        let errors: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));

        let _handle = subscribe_callback(
            Arc::clone(&registry),
            "trades".into(),
            0,
            SubscriptionConfig::default(),
            PanickingCallback {
                errors: Arc::clone(&errors),
            },
        );

        let senders = registry.get_senders_for_source(0);
        let batch = Arc::new(make_batch(1));
        senders[0]
            .send(ChangeEvent::insert(batch, 1000, 1))
            .unwrap();

        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        // Exactly one error, carrying the original panic message.
        let errs = errors.lock().unwrap();
        assert_eq!(errs.len(), 1);
        assert!(errs[0].contains("callback panicked"));
        assert!(errs[0].contains("deliberate test panic"));
    }

    #[tokio::test]
    async fn test_callback_handle_pause_resume() {
        let registry = Arc::new(SubscriptionRegistry::new());
        let cb = TestCallback::new();

        let handle = subscribe_callback(
            Arc::clone(&registry),
            "trades".into(),
            0,
            SubscriptionConfig::default(),
            cb,
        );

        assert!(handle.pause());
        assert_eq!(registry.state(handle.id()), Some(SubscriptionState::Paused));

        // Already paused
        assert!(!handle.pause());

        assert!(handle.resume());
        assert_eq!(registry.state(handle.id()), Some(SubscriptionState::Active));

        // Already active
        assert!(!handle.resume());
    }

    #[tokio::test]
    async fn test_callback_handle_cancel() {
        let registry = Arc::new(SubscriptionRegistry::new());
        let cb = TestCallback::new();

        let mut handle = subscribe_callback(
            Arc::clone(&registry),
            "trades".into(),
            0,
            SubscriptionConfig::default(),
            cb,
        );

        assert_eq!(registry.subscription_count(), 1);
        assert!(!handle.is_cancelled());

        handle.cancel();

        assert!(handle.is_cancelled());
        assert_eq!(registry.subscription_count(), 0);

        // Idempotent
        handle.cancel();
        assert_eq!(registry.subscription_count(), 0);
    }

    #[tokio::test]
    async fn test_callback_handle_drop_cancels() {
        let registry = Arc::new(SubscriptionRegistry::new());
        let cb = TestCallback::new();

        // `_handle` (not `_`) keeps the handle alive until the end of this scope.
        {
            let _handle = subscribe_callback(
                Arc::clone(&registry),
                "trades".into(),
                0,
                SubscriptionConfig::default(),
                cb,
            );
            assert_eq!(registry.subscription_count(), 1);
        }
        // Dropped — should be cancelled
        assert_eq!(registry.subscription_count(), 0);
    }

    #[tokio::test]
    async fn test_subscribe_fn() {
        let registry = Arc::new(SubscriptionRegistry::new());
        let events: Arc<Mutex<Vec<i64>>> = Arc::new(Mutex::new(Vec::new()));
        let events_clone = Arc::clone(&events);

        let _handle = subscribe_fn(
            Arc::clone(&registry),
            "trades".into(),
            0,
            SubscriptionConfig::default(),
            move |event| {
                events_clone.lock().unwrap().push(event.timestamp());
            },
        );

        let senders = registry.get_senders_for_source(0);
        let batch = Arc::new(make_batch(1));
        senders[0]
            .send(ChangeEvent::insert(batch, 5000, 1))
            .unwrap();

        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        let received = events.lock().unwrap();
        assert_eq!(*received, vec![5000]);
    }

    #[tokio::test]
    async fn test_callback_ordering() {
        let registry = Arc::new(SubscriptionRegistry::new());
        let cb = TestCallback::new();
        let events = Arc::clone(&cb.events);

        let _handle = subscribe_callback(
            Arc::clone(&registry),
            "trades".into(),
            0,
            SubscriptionConfig::default(),
            cb,
        );

        let senders = registry.get_senders_for_source(0);

        for i in 0..10i64 {
            let batch = Arc::new(make_batch(1));
            senders[0]
                .send(ChangeEvent::insert(batch, i, i as u64))
                .unwrap();
        }

        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        // Events must reach the callback in send order.
        let received = events.lock().unwrap();
        assert_eq!(received.len(), 10);
        let expected: Vec<i64> = (0..10).collect();
        assert_eq!(*received, expected);
    }

    #[tokio::test]
    async fn test_callback_handle_metrics() {
        let registry = Arc::new(SubscriptionRegistry::new());
        let cb = TestCallback::new();

        let handle = subscribe_callback(
            Arc::clone(&registry),
            "trades".into(),
            0,
            SubscriptionConfig::default(),
            cb,
        );

        let m = handle.metrics().unwrap();
        assert_eq!(m.id, handle.id());
        assert_eq!(m.source_name, "trades");
        assert_eq!(m.state, SubscriptionState::Active);
    }
}