1use std::sync::Arc;
30
31use tokio::sync::broadcast;
32
33use crate::subscription::event::ChangeEvent;
34use crate::subscription::registry::{SubscriptionId, SubscriptionMetrics, SubscriptionRegistry};
35
36#[derive(Debug, thiserror::Error)]
42pub enum PushSubscriptionError {
43 #[error("subscription closed")]
45 Closed,
46 #[error("lagged behind by {0} events")]
48 Lagged(u64),
49 #[error("subscription cancelled")]
51 Cancelled,
52 #[error("internal error: {0}")]
54 Internal(String),
55}
56
/// Handle to a single push subscription created through a [`SubscriptionRegistry`].
///
/// Events are received from a tokio broadcast receiver. Control operations
/// (pause, resume, cancel, metrics) are forwarded to the shared registry by id.
/// Dropping the handle cancels the subscription unless it was already
/// cancelled through [`PushSubscription::cancel`].
pub struct PushSubscription {
    /// Registry-assigned identifier passed to every registry call.
    id: SubscriptionId,
    /// Receiving end of the broadcast channel that events are pushed into.
    receiver: broadcast::Receiver<ChangeEvent>,
    /// Shared registry that tracks state and metrics for this subscription.
    registry: Arc<SubscriptionRegistry>,
    /// Query text supplied at construction; returned by `query()`.
    query: String,
    /// Set by `cancel()` so this handle forwards cancellation to the
    /// registry at most once (also consulted by `Drop`).
    cancelled: bool,
}
87
88impl PushSubscription {
89 pub fn new(
95 id: SubscriptionId,
96 receiver: broadcast::Receiver<ChangeEvent>,
97 registry: Arc<SubscriptionRegistry>,
98 query: String,
99 ) -> Self {
100 Self {
101 id,
102 receiver,
103 registry,
104 query,
105 cancelled: false,
106 }
107 }
108
109 pub async fn recv(&mut self) -> Result<ChangeEvent, PushSubscriptionError> {
121 if self.cancelled {
122 return Err(PushSubscriptionError::Cancelled);
123 }
124
125 match self.receiver.recv().await {
126 Ok(event) => Ok(event),
127 Err(broadcast::error::RecvError::Lagged(n)) => Err(PushSubscriptionError::Lagged(n)),
128 Err(broadcast::error::RecvError::Closed) => Err(PushSubscriptionError::Closed),
129 }
130 }
131
132 pub fn try_recv(&mut self) -> Option<Result<ChangeEvent, PushSubscriptionError>> {
137 if self.cancelled {
138 return Some(Err(PushSubscriptionError::Cancelled));
139 }
140
141 match self.receiver.try_recv() {
142 Ok(event) => Some(Ok(event)),
143 Err(broadcast::error::TryRecvError::Lagged(n)) => {
144 Some(Err(PushSubscriptionError::Lagged(n)))
145 }
146 Err(broadcast::error::TryRecvError::Closed) => Some(Err(PushSubscriptionError::Closed)),
147 Err(broadcast::error::TryRecvError::Empty) => None,
148 }
149 }
150
151 #[must_use]
157 pub fn pause(&self) -> bool {
158 self.registry.pause(self.id)
159 }
160
161 #[must_use]
165 pub fn resume(&self) -> bool {
166 self.registry.resume(self.id)
167 }
168
169 pub fn cancel(&mut self) {
174 if !self.cancelled {
175 self.cancelled = true;
176 self.registry.cancel(self.id);
177 }
178 }
179
180 #[must_use]
182 pub fn id(&self) -> SubscriptionId {
183 self.id
184 }
185
186 #[must_use]
188 pub fn query(&self) -> &str {
189 &self.query
190 }
191
192 #[must_use]
194 pub fn is_cancelled(&self) -> bool {
195 self.cancelled
196 }
197
198 #[must_use]
200 pub fn metrics(&self) -> Option<SubscriptionMetrics> {
201 self.registry.metrics(self.id)
202 }
203}
204
205impl Drop for PushSubscription {
206 fn drop(&mut self) {
207 if !self.cancelled {
208 self.registry.cancel(self.id);
209 }
210 }
211}
212
#[cfg(test)]
// `make_batch` and the lagged test cast small non-negative counters to `i64`.
#[allow(clippy::cast_possible_wrap)]
// The lagged test adjusts one field of `SubscriptionConfig::default()` in place.
#[allow(clippy::field_reassign_with_default)]
// The `tokio::select!` timeout arm binds `sleep`'s unit output with `_`.
#[allow(clippy::ignored_unit_patterns)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use arrow_array::Int64Array;
    use arrow_schema::{DataType, Field, Schema};

    use crate::subscription::registry::{SubscriptionConfig, SubscriptionState};

    /// Builds a single-column (`v: Int64`) batch holding `0..n`.
    fn make_batch(n: usize) -> arrow_array::RecordBatch {
        let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int64, false)]));
        let values: Vec<i64> = (0..n as i64).collect();
        let array = Int64Array::from(values);
        arrow_array::RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap()
    }

    /// Creates a fresh registry with one subscription on source 0.
    fn make_sub(query: &str) -> (Arc<SubscriptionRegistry>, PushSubscription) {
        let registry = Arc::new(SubscriptionRegistry::new());
        let (id, rx) = registry.create(query.into(), 0, SubscriptionConfig::default());
        let sub = PushSubscription::new(id, rx, Arc::clone(&registry), query.into());
        (registry, sub)
    }

    /// Like `make_sub`, but also returns the broadcast sender feeding the subscription.
    fn make_sub_with_sender(
        query: &str,
    ) -> (
        Arc<SubscriptionRegistry>,
        PushSubscription,
        broadcast::Sender<ChangeEvent>,
    ) {
        let registry = Arc::new(SubscriptionRegistry::new());
        let (id, rx) = registry.create(query.into(), 0, SubscriptionConfig::default());
        let senders = registry.get_senders_for_source(0);
        // Only one subscription exists on source 0, so its sender is the first one.
        let sender = senders.into_iter().next().unwrap();
        let sub = PushSubscription::new(id, rx, Arc::clone(&registry), query.into());
        (registry, sub, sender)
    }

    #[tokio::test]
    async fn test_push_subscription_recv() {
        let (_reg, mut sub, sender) = make_sub_with_sender("SELECT * FROM trades");

        let batch = Arc::new(make_batch(5));
        let event = ChangeEvent::insert(batch, 1000, 1);
        sender.send(event).unwrap();

        let received = sub.recv().await.unwrap();
        assert_eq!(received.timestamp(), 1000);
        assert_eq!(received.sequence(), Some(1));
        assert_eq!(received.row_count(), 5);
    }

    #[test]
    fn test_push_subscription_try_recv() {
        let (_reg, mut sub, sender) = make_sub_with_sender("trades");

        // Nothing has been sent yet.
        assert!(sub.try_recv().is_none());

        let batch = Arc::new(make_batch(3));
        sender.send(ChangeEvent::insert(batch, 2000, 2)).unwrap();

        let result = sub.try_recv().unwrap().unwrap();
        assert_eq!(result.timestamp(), 2000);
        assert_eq!(result.sequence(), Some(2));
    }

    #[test]
    fn test_push_subscription_try_recv_empty() {
        let (_reg, mut sub) = make_sub("trades");
        assert!(sub.try_recv().is_none());
    }

    #[test]
    fn test_push_subscription_pause_resume() {
        let (reg, sub) = make_sub("trades");

        assert_eq!(reg.state(sub.id()), Some(SubscriptionState::Active));

        assert!(sub.pause());
        assert_eq!(reg.state(sub.id()), Some(SubscriptionState::Paused));

        // Pausing an already-paused subscription reports no change.
        assert!(!sub.pause());

        assert!(sub.resume());
        assert_eq!(reg.state(sub.id()), Some(SubscriptionState::Active));

        // Resuming an already-active subscription reports no change.
        assert!(!sub.resume());
    }

    #[tokio::test]
    async fn test_push_subscription_cancel() {
        let (reg, mut sub, sender) = make_sub_with_sender("trades");

        sub.cancel();
        assert!(sub.is_cancelled());

        let err = sub.recv().await.unwrap_err();
        assert!(matches!(err, PushSubscriptionError::Cancelled));

        let err = sub.try_recv().unwrap().unwrap_err();
        assert!(matches!(err, PushSubscriptionError::Cancelled));

        assert_eq!(reg.subscription_count(), 0);

        // A second cancel must be a harmless no-op.
        sub.cancel();

        drop(sender);
    }

    #[test]
    fn test_push_subscription_drop_cancels() {
        let registry = Arc::new(SubscriptionRegistry::new());
        let (id, rx) = registry.create("trades".into(), 0, SubscriptionConfig::default());
        assert_eq!(registry.subscription_count(), 1);

        {
            let _sub = PushSubscription::new(id, rx, Arc::clone(&registry), "trades".into());
            assert_eq!(registry.subscription_count(), 1);
        }
        // Dropping the handle deregistered the subscription.
        assert_eq!(registry.subscription_count(), 0);
    }

    #[tokio::test]
    async fn test_push_subscription_lagged() {
        let registry = Arc::new(SubscriptionRegistry::new());
        let mut cfg = SubscriptionConfig::default();
        cfg.buffer_size = 4;
        let (id, rx) = registry.create("trades".into(), 0, cfg);

        let senders = registry.get_senders_for_source(0);
        let sender = senders.into_iter().next().unwrap();

        let mut sub = PushSubscription::new(id, rx, Arc::clone(&registry), "trades".into());

        // Overflow the 4-slot buffer so the receiver lags.
        for i in 0..10u64 {
            let batch = Arc::new(make_batch(1));
            sender
                .send(ChangeEvent::insert(batch, i as i64, i))
                .unwrap();
        }

        let err = sub.recv().await.unwrap_err();
        match err {
            PushSubscriptionError::Lagged(n) => {
                assert!(n > 0, "expected non-zero lag count, got {n}");
            }
            other => panic!("expected Lagged, got {other:?}"),
        }

        // After a lag the receiver resumes from the oldest retained event.
        let event = sub.recv().await.unwrap();
        assert!(event.sequence().unwrap() > 0);
    }

    #[tokio::test]
    async fn test_push_subscription_closed() {
        let registry = Arc::new(SubscriptionRegistry::new());
        let (id, rx) = registry.create("trades".into(), 0, SubscriptionConfig::default());

        let mut sub = PushSubscription::new(id, rx, Arc::clone(&registry), "trades".into());

        // Cancelling via the registry (not the handle) closes the channel.
        registry.cancel(id);

        let err = sub.recv().await.unwrap_err();
        assert!(matches!(err, PushSubscriptionError::Closed));
    }

    #[test]
    fn test_push_subscription_id_and_query() {
        let (_reg, sub) = make_sub("SELECT * FROM trades");
        assert_eq!(sub.id(), SubscriptionId(1));
        assert_eq!(sub.query(), "SELECT * FROM trades");
        assert!(!sub.is_cancelled());
    }

    #[test]
    fn test_push_subscription_metrics() {
        let (reg, sub) = make_sub("trades");
        let m = sub.metrics().unwrap();
        assert_eq!(m.id, sub.id());
        assert_eq!(m.source_name, "trades");
        assert_eq!(m.state, SubscriptionState::Active);
        assert_eq!(m.events_delivered, 0);

        reg.record_delivery(sub.id(), 5);
        let m = sub.metrics().unwrap();
        assert_eq!(m.events_delivered, 5);
    }

    #[tokio::test]
    async fn test_push_subscription_with_select() {
        let (_reg, mut sub, sender) = make_sub_with_sender("trades");

        let batch = Arc::new(make_batch(1));
        sender.send(ChangeEvent::insert(batch, 9000, 42)).unwrap();

        let result = tokio::select! {
            event = sub.recv() => event,
            _ = tokio::time::sleep(std::time::Duration::from_secs(5)) => {
                panic!("timeout — event should be immediate");
            }
        };

        let event = result.unwrap();
        assert_eq!(event.timestamp(), 9000);
        assert_eq!(event.sequence(), Some(42));
    }

    #[tokio::test]
    async fn test_push_subscription_multiple_subscribers() {
        let registry = Arc::new(SubscriptionRegistry::new());

        let (id1, rx1) = registry.create("trades".into(), 0, SubscriptionConfig::default());
        let (id2, rx2) = registry.create("trades".into(), 0, SubscriptionConfig::default());

        let mut sub1 = PushSubscription::new(id1, rx1, Arc::clone(&registry), "trades".into());
        let mut sub2 = PushSubscription::new(id2, rx2, Arc::clone(&registry), "trades".into());

        let senders = registry.get_senders_for_source(0);
        let batch = Arc::new(make_batch(3));
        let event = ChangeEvent::insert(batch, 5000, 10);
        for sender in &senders {
            sender.send(event.clone()).unwrap();
        }

        let e1 = sub1.recv().await.unwrap();
        let e2 = sub2.recv().await.unwrap();
        assert_eq!(e1.timestamp(), 5000);
        assert_eq!(e2.timestamp(), 5000);
    }

    #[test]
    fn test_end_to_end_push_subscribe() {
        use crate::subscription::dispatcher::{
            DispatcherConfig, NotificationDataSource, SubscriptionDispatcher,
        };
        use crate::subscription::event::EventType;
        use crate::subscription::notification::{NotificationHub, NotificationRing};

        /// Resolves every notification into an insert batch of the notified size.
        struct TestDataSource;

        impl NotificationDataSource for TestDataSource {
            fn resolve(
                &self,
                notif: &crate::subscription::event::NotificationRef,
            ) -> Option<ChangeEvent> {
                let batch = Arc::new(make_batch(notif.row_count as usize));
                Some(ChangeEvent::insert(batch, notif.timestamp, notif.sequence))
            }
        }

        let mut hub = NotificationHub::new(4, 64);
        let source_id = hub.register_source().unwrap();

        let registry = Arc::new(SubscriptionRegistry::new());

        let (sub_id, rx) =
            registry.create("mv_trades".into(), source_id, SubscriptionConfig::default());
        let mut sub = PushSubscription::new(sub_id, rx, Arc::clone(&registry), "mv_trades".into());

        hub.notify_source(source_id, EventType::Insert, 10, 7000, 0);

        let ring = Arc::new(NotificationRing::new(64));
        hub.drain_notifications(|n| {
            ring.push(n);
        });

        let ds = Arc::new(TestDataSource) as Arc<dyn NotificationDataSource>;
        // Keep the shutdown sender alive for the duration of the test.
        let (_tx, shutdown_rx) = tokio::sync::watch::channel(false);
        let dispatcher = SubscriptionDispatcher::new(
            vec![ring],
            Arc::clone(&registry),
            ds,
            DispatcherConfig::default(),
            shutdown_rx,
        );

        let mut buf = Vec::new();
        dispatcher.drain_and_dispatch(&mut buf);

        let event = sub.try_recv().unwrap().unwrap();
        assert_eq!(event.timestamp(), 7000);
        assert_eq!(event.row_count(), 10);
    }

    #[test]
    fn test_push_subscription_error_display() {
        let e = PushSubscriptionError::Closed;
        assert_eq!(format!("{e}"), "subscription closed");

        let e = PushSubscriptionError::Lagged(42);
        assert_eq!(format!("{e}"), "lagged behind by 42 events");

        let e = PushSubscriptionError::Cancelled;
        assert_eq!(format!("{e}"), "subscription cancelled");

        let e = PushSubscriptionError::Internal("oops".into());
        assert_eq!(format!("{e}"), "internal error: oops");
    }
}