laminar_core/subscription/
callback.rs1use std::panic::AssertUnwindSafe;
29use std::sync::Arc;
30
31use tokio::sync::broadcast;
32
33use crate::subscription::event::ChangeEvent;
34use crate::subscription::handle::PushSubscriptionError;
35use crate::subscription::registry::{
36 SubscriptionConfig, SubscriptionId, SubscriptionMetrics, SubscriptionRegistry,
37};
38
39pub trait SubscriptionCallback: Send + Sync + 'static {
64 fn on_change(&self, event: ChangeEvent);
66
67 fn on_error(&self, error: PushSubscriptionError) {
71 tracing::warn!("subscription callback error: {}", error);
72 }
73
74 fn on_complete(&self) {}
78}
79
80struct FnCallback<F>(F);
86
87impl<F: Fn(ChangeEvent) + Send + Sync + 'static> SubscriptionCallback for FnCallback<F> {
88 fn on_change(&self, event: ChangeEvent) {
89 (self.0)(event);
90 }
91}
92
/// Owning handle for a callback-driven subscription returned by [`subscribe_callback`]
/// and [`subscribe_fn`].
///
/// The handle can pause, resume and cancel the subscription and read its metrics.
/// Dropping it without calling `cancel` still cancels the subscription in the registry
/// and aborts the delivery task.
pub struct CallbackSubscriptionHandle {
    /// Identifier of this subscription inside `registry`.
    id: SubscriptionId,
    /// Registry that owns the subscription; pause/resume/cancel/metrics go through it.
    registry: Arc<SubscriptionRegistry>,
    /// Delivery task running the callback; taken (set to `None`) when it is aborted.
    task: Option<tokio::task::JoinHandle<()>>,
    /// Set once `cancel` has run, so repeated cancellation is a no-op.
    cancelled: bool,
}
116
117impl CallbackSubscriptionHandle {
118 #[must_use]
124 pub fn pause(&self) -> bool {
125 self.registry.pause(self.id)
126 }
127
128 #[must_use]
132 pub fn resume(&self) -> bool {
133 self.registry.resume(self.id)
134 }
135
136 pub fn cancel(&mut self) {
141 if !self.cancelled {
142 self.cancelled = true;
143 self.registry.cancel(self.id);
144 if let Some(task) = self.task.take() {
145 task.abort();
146 }
147 }
148 }
149
150 #[must_use]
152 pub fn id(&self) -> SubscriptionId {
153 self.id
154 }
155
156 #[must_use]
158 pub fn metrics(&self) -> Option<SubscriptionMetrics> {
159 self.registry.metrics(self.id)
160 }
161
162 #[must_use]
164 pub fn is_cancelled(&self) -> bool {
165 self.cancelled
166 }
167}
168
169impl Drop for CallbackSubscriptionHandle {
170 fn drop(&mut self) {
171 if !self.cancelled {
172 self.registry.cancel(self.id);
173 if let Some(task) = self.task.take() {
174 task.abort();
175 }
176 }
177 }
178}
179
180pub fn subscribe_callback<C: SubscriptionCallback>(
201 registry: Arc<SubscriptionRegistry>,
202 source_name: String,
203 source_id: u32,
204 config: SubscriptionConfig,
205 callback: C,
206) -> CallbackSubscriptionHandle {
207 let (id, receiver) = registry.create(source_name, source_id, config);
208 let callback = Arc::new(callback);
209
210 let task = tokio::spawn(callback_runner(receiver, callback));
211
212 CallbackSubscriptionHandle {
213 id,
214 registry,
215 task: Some(task),
216 cancelled: false,
217 }
218}
219
220pub fn subscribe_fn<F>(
233 registry: Arc<SubscriptionRegistry>,
234 source_name: String,
235 source_id: u32,
236 config: SubscriptionConfig,
237 f: F,
238) -> CallbackSubscriptionHandle
239where
240 F: Fn(ChangeEvent) + Send + Sync + 'static,
241{
242 subscribe_callback(registry, source_name, source_id, config, FnCallback(f))
243}
244
245async fn callback_runner<C: SubscriptionCallback>(
252 mut receiver: broadcast::Receiver<ChangeEvent>,
253 callback: Arc<C>,
254) {
255 loop {
256 match receiver.recv().await {
257 Ok(event) => {
258 let cb = Arc::clone(&callback);
259 let result = std::panic::catch_unwind(AssertUnwindSafe(|| cb.on_change(event)));
260 if let Err(panic) = result {
261 let msg = if let Some(s) = panic.downcast_ref::<&str>() {
262 format!("callback panicked: {s}")
263 } else if let Some(s) = panic.downcast_ref::<String>() {
264 format!("callback panicked: {s}")
265 } else {
266 "callback panicked".to_string()
267 };
268 callback.on_error(PushSubscriptionError::Internal(msg));
269 }
270 }
271 Err(broadcast::error::RecvError::Lagged(n)) => {
272 callback.on_error(PushSubscriptionError::Lagged(n));
273 }
275 Err(broadcast::error::RecvError::Closed) => {
276 callback.on_complete();
277 break;
278 }
279 }
280 }
281}
282
#[cfg(test)]
// Fixtures cast small, non-negative loop counters between `usize`/`i64`/`u64`, and the
// lag test tweaks one `SubscriptionConfig` field after `Default::default()`.
#[allow(clippy::cast_sign_loss)]
#[allow(clippy::cast_possible_wrap)]
#[allow(clippy::field_reassign_with_default)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    use arrow_array::Int64Array;
    use arrow_schema::{DataType, Field, Schema};

    use crate::subscription::registry::SubscriptionState;

    /// Builds a single-column, non-nullable `Int64` batch containing `0..n`.
    fn make_batch(n: usize) -> arrow_array::RecordBatch {
        let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int64, false)]));
        let values: Vec<i64> = (0..n as i64).collect();
        let array = Int64Array::from(values);
        arrow_array::RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap()
    }

    /// Callback that records event timestamps, error messages and completion into
    /// shared state that the test inspects after the runner has processed events.
    #[derive(Clone)]
    struct TestCallback {
        events: Arc<Mutex<Vec<i64>>>,
        errors: Arc<Mutex<Vec<String>>>,
        completed: Arc<Mutex<bool>>,
    }

    impl TestCallback {
        fn new() -> Self {
            Self {
                events: Arc::new(Mutex::new(Vec::new())),
                errors: Arc::new(Mutex::new(Vec::new())),
                completed: Arc::new(Mutex::new(false)),
            }
        }
    }

    impl SubscriptionCallback for TestCallback {
        fn on_change(&self, event: ChangeEvent) {
            self.events.lock().unwrap().push(event.timestamp());
        }

        fn on_error(&self, error: PushSubscriptionError) {
            self.errors.lock().unwrap().push(format!("{error}"));
        }

        fn on_complete(&self) {
            *self.completed.lock().unwrap() = true;
        }
    }

    #[tokio::test]
    async fn test_callback_receives_events() {
        let registry = Arc::new(SubscriptionRegistry::new());
        let cb = TestCallback::new();
        let events = Arc::clone(&cb.events);

        let _handle = subscribe_callback(
            Arc::clone(&registry),
            "trades".into(),
            0,
            SubscriptionConfig::default(),
            cb,
        );

        let senders = registry.get_senders_for_source(0);
        assert_eq!(senders.len(), 1);

        for i in 0..5i64 {
            let batch = Arc::new(make_batch(1));
            senders[0]
                .send(ChangeEvent::insert(batch, i * 1000, i as u64))
                .unwrap();
        }

        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        let received = events.lock().unwrap();
        assert_eq!(received.len(), 5);
        assert_eq!(*received, vec![0, 1000, 2000, 3000, 4000]);
    }

    #[tokio::test]
    async fn test_callback_on_error_lagged() {
        let registry = Arc::new(SubscriptionRegistry::new());
        let mut cfg = SubscriptionConfig::default();
        cfg.buffer_size = 4;
        let cb = TestCallback::new();
        let errors = Arc::clone(&cb.errors);
        let events = Arc::clone(&cb.events);

        let _handle = subscribe_callback(Arc::clone(&registry), "trades".into(), 0, cfg, cb);

        let senders = registry.get_senders_for_source(0);

        // Overfill the 4-slot buffer before the runner task gets a chance to poll,
        // so the receiver is guaranteed to observe `Lagged`.
        for i in 0..20i64 {
            let batch = Arc::new(make_batch(1));
            let _ = senders[0].send(ChangeEvent::insert(batch, i * 100, i as u64));
        }

        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        let errs = errors.lock().unwrap();
        assert!(!errs.is_empty(), "expected at least one lag error");
        assert!(errs[0].contains("lagged behind"));

        let evts = events.lock().unwrap();
        // After the lag the runner keeps going and delivers the retained events.
        assert!(!evts.is_empty(), "should receive events after lag");
    }

    #[tokio::test]
    async fn test_callback_on_complete() {
        let registry = Arc::new(SubscriptionRegistry::new());
        let cb = TestCallback::new();
        let completed = Arc::clone(&cb.completed);

        let handle = subscribe_callback(
            Arc::clone(&registry),
            "trades".into(),
            0,
            SubscriptionConfig::default(),
            cb,
        );

        // Cancel through the registry rather than the handle: the handle would abort the
        // runner task, whereas this lets the runner observe the channel closing.
        registry.cancel(handle.id());

        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        assert!(*completed.lock().unwrap());
    }

    #[tokio::test]
    async fn test_callback_panic_caught() {
        struct PanickingCallback {
            errors: Arc<Mutex<Vec<String>>>,
        }

        impl SubscriptionCallback for PanickingCallback {
            fn on_change(&self, _event: ChangeEvent) {
                panic!("deliberate test panic");
            }

            fn on_error(&self, error: PushSubscriptionError) {
                self.errors.lock().unwrap().push(format!("{error}"));
            }
        }

        let registry = Arc::new(SubscriptionRegistry::new());
        let errors: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));

        let _handle = subscribe_callback(
            Arc::clone(&registry),
            "trades".into(),
            0,
            SubscriptionConfig::default(),
            PanickingCallback {
                errors: Arc::clone(&errors),
            },
        );

        let senders = registry.get_senders_for_source(0);
        let batch = Arc::new(make_batch(1));
        senders[0]
            .send(ChangeEvent::insert(batch, 1000, 1))
            .unwrap();

        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        let errs = errors.lock().unwrap();
        assert_eq!(errs.len(), 1);
        assert!(errs[0].contains("callback panicked"));
        assert!(errs[0].contains("deliberate test panic"));
    }

    #[tokio::test]
    async fn test_callback_handle_pause_resume() {
        let registry = Arc::new(SubscriptionRegistry::new());
        let cb = TestCallback::new();

        let handle = subscribe_callback(
            Arc::clone(&registry),
            "trades".into(),
            0,
            SubscriptionConfig::default(),
            cb,
        );

        assert!(handle.pause());
        assert_eq!(registry.state(handle.id()), Some(SubscriptionState::Paused));

        // Pausing an already-paused subscription reports no change.
        assert!(!handle.pause());

        assert!(handle.resume());
        assert_eq!(registry.state(handle.id()), Some(SubscriptionState::Active));

        // Resuming an already-active subscription reports no change.
        assert!(!handle.resume());
    }

    #[tokio::test]
    async fn test_callback_handle_cancel() {
        let registry = Arc::new(SubscriptionRegistry::new());
        let cb = TestCallback::new();

        let mut handle = subscribe_callback(
            Arc::clone(&registry),
            "trades".into(),
            0,
            SubscriptionConfig::default(),
            cb,
        );

        assert_eq!(registry.subscription_count(), 1);
        assert!(!handle.is_cancelled());

        handle.cancel();

        assert!(handle.is_cancelled());
        assert_eq!(registry.subscription_count(), 0);

        // A second cancel is a no-op.
        handle.cancel();
        assert_eq!(registry.subscription_count(), 0);
    }

    #[tokio::test]
    async fn test_callback_handle_drop_cancels() {
        let registry = Arc::new(SubscriptionRegistry::new());
        let cb = TestCallback::new();

        {
            let _handle = subscribe_callback(
                Arc::clone(&registry),
                "trades".into(),
                0,
                SubscriptionConfig::default(),
                cb,
            );
            assert_eq!(registry.subscription_count(), 1);
        }
        // Dropping the handle at the end of the scope must cancel the subscription.
        assert_eq!(registry.subscription_count(), 0);
    }

    #[tokio::test]
    async fn test_subscribe_fn() {
        let registry = Arc::new(SubscriptionRegistry::new());
        let events: Arc<Mutex<Vec<i64>>> = Arc::new(Mutex::new(Vec::new()));
        let events_clone = Arc::clone(&events);

        let _handle = subscribe_fn(
            Arc::clone(&registry),
            "trades".into(),
            0,
            SubscriptionConfig::default(),
            move |event| {
                events_clone.lock().unwrap().push(event.timestamp());
            },
        );

        let senders = registry.get_senders_for_source(0);
        let batch = Arc::new(make_batch(1));
        senders[0]
            .send(ChangeEvent::insert(batch, 5000, 1))
            .unwrap();

        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        let received = events.lock().unwrap();
        assert_eq!(*received, vec![5000]);
    }

    #[tokio::test]
    async fn test_callback_ordering() {
        let registry = Arc::new(SubscriptionRegistry::new());
        let cb = TestCallback::new();
        let events = Arc::clone(&cb.events);

        let _handle = subscribe_callback(
            Arc::clone(&registry),
            "trades".into(),
            0,
            SubscriptionConfig::default(),
            cb,
        );

        let senders = registry.get_senders_for_source(0);

        for i in 0..10i64 {
            let batch = Arc::new(make_batch(1));
            senders[0]
                .send(ChangeEvent::insert(batch, i, i as u64))
                .unwrap();
        }

        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        let received = events.lock().unwrap();
        assert_eq!(received.len(), 10);
        let expected: Vec<i64> = (0..10).collect();
        assert_eq!(*received, expected);
    }

    #[tokio::test]
    async fn test_callback_handle_metrics() {
        let registry = Arc::new(SubscriptionRegistry::new());
        let cb = TestCallback::new();

        let handle = subscribe_callback(
            Arc::clone(&registry),
            "trades".into(),
            0,
            SubscriptionConfig::default(),
            cb,
        );

        let m = handle.metrics().unwrap();
        assert_eq!(m.id, handle.id());
        assert_eq!(m.source_name, "trades");
        assert_eq!(m.state, SubscriptionState::Active);
    }
}