1use std::sync::Arc;
33use std::time::{Duration, Instant};
34
35use arrow::array::RecordBatch;
36use arrow::datatypes::SchemaRef;
37
38use super::channel::Consumer;
39use super::error::RecvError;
40use super::sink::SinkInner;
41use super::source::{Record, SourceMessage};
42
43pub struct Subscription<T: Record> {
53 inner: SubscriptionInner<T>,
54 schema: SchemaRef,
55}
56
57enum SubscriptionInner<T: Record> {
58 Direct(Arc<SinkInner<T>>),
60 Broadcast(Consumer<SourceMessage<T>>),
62}
63
64impl<T: Record> Subscription<T> {
65 pub(crate) fn new_direct(sink_inner: Arc<SinkInner<T>>) -> Self {
67 let schema = sink_inner.schema();
68 Self {
69 inner: SubscriptionInner::Direct(sink_inner),
70 schema,
71 }
72 }
73
74 pub(crate) fn new_broadcast(consumer: Consumer<SourceMessage<T>>, schema: SchemaRef) -> Self {
76 Self {
77 inner: SubscriptionInner::Broadcast(consumer),
78 schema,
79 }
80 }
81
82 #[must_use]
88 pub fn poll(&self) -> Option<RecordBatch> {
89 let msg = match &self.inner {
90 SubscriptionInner::Direct(sink) => sink.consumer().poll(),
91 SubscriptionInner::Broadcast(consumer) => consumer.poll(),
92 }?;
93
94 Self::message_to_batch(msg)
95 }
96
97 #[must_use]
101 pub fn poll_message(&self) -> Option<SubscriptionMessage<T>> {
102 let msg = match &self.inner {
103 SubscriptionInner::Direct(sink) => sink.consumer().poll(),
104 SubscriptionInner::Broadcast(consumer) => consumer.poll(),
105 }?;
106
107 Some(Self::convert_message(msg))
108 }
109
110 pub fn recv(&self) -> Result<RecordBatch, RecvError> {
117 loop {
118 if let Some(batch) = self.poll() {
119 return Ok(batch);
120 }
121
122 if self.is_disconnected() {
123 return Err(RecvError::Disconnected);
124 }
125
126 std::hint::spin_loop();
128 }
129 }
130
131 pub fn recv_timeout(&self, timeout: Duration) -> Result<RecordBatch, RecvError> {
138 let deadline = Instant::now() + timeout;
139
140 loop {
141 if let Some(batch) = self.poll() {
142 return Ok(batch);
143 }
144
145 if self.is_disconnected() {
146 return Err(RecvError::Disconnected);
147 }
148
149 if Instant::now() >= deadline {
150 return Err(RecvError::Timeout);
151 }
152
153 std::hint::spin_loop();
154 }
155 }
156
157 #[cold]
167 #[must_use]
168 pub fn poll_batch(&self, max_count: usize) -> Vec<RecordBatch> {
169 let mut batches = Vec::with_capacity(max_count);
170
171 for _ in 0..max_count {
172 if let Some(batch) = self.poll() {
173 batches.push(batch);
174 } else {
175 break;
176 }
177 }
178
179 batches
180 }
181
182 pub fn poll_batch_into(&self, buffer: &mut Vec<RecordBatch>, max_count: usize) -> usize {
201 let mut count = 0;
202
203 for _ in 0..max_count {
204 if let Some(batch) = self.poll() {
205 buffer.push(batch);
206 count += 1;
207 } else {
208 break;
209 }
210 }
211
212 count
213 }
214
215 pub fn poll_each<F>(&self, max_count: usize, mut f: F) -> usize
224 where
225 F: FnMut(RecordBatch) -> bool,
226 {
227 let mut count = 0;
228
229 for _ in 0..max_count {
230 if let Some(batch) = self.poll() {
231 count += 1;
232 if !f(batch) {
233 break;
234 }
235 } else {
236 break;
237 }
238 }
239
240 count
241 }
242
243 #[must_use]
245 pub fn is_disconnected(&self) -> bool {
246 match &self.inner {
247 SubscriptionInner::Direct(sink) => sink.is_disconnected(),
248 SubscriptionInner::Broadcast(consumer) => consumer.is_disconnected(),
249 }
250 }
251
252 #[must_use]
254 pub fn pending(&self) -> usize {
255 match &self.inner {
256 SubscriptionInner::Direct(sink) => sink.consumer().len(),
257 SubscriptionInner::Broadcast(consumer) => consumer.len(),
258 }
259 }
260
261 #[must_use]
263 pub fn schema(&self) -> SchemaRef {
264 Arc::clone(&self.schema)
265 }
266
267 fn message_to_batch(msg: SourceMessage<T>) -> Option<RecordBatch> {
268 match msg {
269 SourceMessage::Record(record) => Some(record.to_record_batch()),
270 SourceMessage::Batch(batch) => Some(batch),
271 SourceMessage::Watermark(_) => {
272 None
274 }
275 }
276 }
277
278 fn convert_message(msg: SourceMessage<T>) -> SubscriptionMessage<T> {
279 match msg {
280 SourceMessage::Record(record) => SubscriptionMessage::Record(record),
281 SourceMessage::Batch(batch) => SubscriptionMessage::Batch(batch),
282 SourceMessage::Watermark(ts) => SubscriptionMessage::Watermark(ts),
283 }
284 }
285}
286
287#[derive(Debug)]
289pub enum SubscriptionMessage<T> {
290 Record(T),
292 Batch(RecordBatch),
294 Watermark(i64),
296}
297
298impl<T: Record> SubscriptionMessage<T> {
299 #[must_use]
301 pub fn is_record(&self) -> bool {
302 matches!(self, Self::Record(_))
303 }
304
305 #[must_use]
307 pub fn is_batch(&self) -> bool {
308 matches!(self, Self::Batch(_))
309 }
310
311 #[must_use]
313 pub fn is_watermark(&self) -> bool {
314 matches!(self, Self::Watermark(_))
315 }
316
317 #[must_use]
319 pub fn to_batch(self) -> Option<RecordBatch> {
320 match self {
321 Self::Record(r) => Some(r.to_record_batch()),
322 Self::Batch(b) => Some(b),
323 Self::Watermark(_) => None,
324 }
325 }
326
327 #[must_use]
329 pub fn watermark(&self) -> Option<i64> {
330 match self {
331 Self::Watermark(ts) => Some(*ts),
332 _ => None,
333 }
334 }
335}
336
337impl<T: Record> Iterator for Subscription<T> {
342 type Item = RecordBatch;
343
344 fn next(&mut self) -> Option<Self::Item> {
345 self.recv().ok()
346 }
347}
348
349impl<T: Record + std::fmt::Debug> std::fmt::Debug for Subscription<T> {
350 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
351 let mode = match &self.inner {
352 SubscriptionInner::Direct(_) => "Direct",
353 SubscriptionInner::Broadcast(_) => "Broadcast",
354 };
355
356 f.debug_struct("Subscription")
357 .field("mode", &mode)
358 .field("pending", &self.pending())
359 .field("is_disconnected", &self.is_disconnected())
360 .field("schema", &self.schema)
361 .finish()
362 }
363}
364
#[cfg(test)]
mod tests {
    use super::*;
    use crate::streaming::source::create;
    use arrow::array::{Float64Array, Int64Array};
    use arrow::datatypes::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Minimal two-column record used throughout these tests.
    #[derive(Clone, Debug)]
    struct TestEvent {
        id: i64,
        value: f64,
    }

    impl Record for TestEvent {
        fn schema() -> SchemaRef {
            let fields = vec![
                Field::new("id", DataType::Int64, false),
                Field::new("value", DataType::Float64, false),
            ];
            Arc::new(Schema::new(fields))
        }

        fn to_record_batch(&self) -> RecordBatch {
            RecordBatch::try_new(
                Self::schema(),
                vec![
                    Arc::new(Int64Array::from(vec![self.id])),
                    Arc::new(Float64Array::from(vec![self.value])),
                ],
            )
            .unwrap()
        }
    }

    /// Builds an event whose `value` mirrors its (small, exactly representable) `id`.
    fn event(id: i64) -> TestEvent {
        TestEvent {
            id,
            value: id as f64,
        }
    }

    #[test]
    fn test_poll_empty() {
        let (_source, sink) = create::<TestEvent>(16);
        let subscription = sink.subscribe();

        assert!(subscription.poll().is_none());
    }

    #[test]
    fn test_poll_records() {
        let (source, sink) = create::<TestEvent>(16);
        let subscription = sink.subscribe();

        for id in 1..=2_i64 {
            source.push(event(id)).unwrap();
        }

        for _ in 0..2 {
            let batch = subscription.poll().unwrap();
            assert_eq!(batch.num_rows(), 1);
        }

        assert!(subscription.poll().is_none());
    }

    #[test]
    fn test_poll_message() {
        let (source, sink) = create::<TestEvent>(16);
        let subscription = sink.subscribe();

        source.push(event(1)).unwrap();

        let message = subscription.poll_message().unwrap();
        assert!(message.is_record());
    }

    #[test]
    fn test_recv_timeout() {
        let (_source, sink) = create::<TestEvent>(16);
        let subscription = sink.subscribe();

        let outcome = subscription.recv_timeout(Duration::from_millis(10));
        assert!(matches!(outcome, Err(RecvError::Timeout)));
    }

    #[test]
    fn test_recv_timeout_success() {
        let (source, sink) = create::<TestEvent>(16);
        let subscription = sink.subscribe();

        source.push(event(1)).unwrap();

        assert!(subscription.recv_timeout(Duration::from_secs(1)).is_ok());
    }

    #[test]
    fn test_poll_batch() {
        let (source, sink) = create::<TestEvent>(16);
        let subscription = sink.subscribe();

        for id in 1..=3_i64 {
            source.push(event(id)).unwrap();
        }

        assert_eq!(subscription.poll_batch(10).len(), 3);
    }

    #[test]
    fn test_poll_each() {
        let (source, sink) = create::<TestEvent>(16);
        let subscription = sink.subscribe();

        for id in 1..=2_i64 {
            source.push(event(id)).unwrap();
        }

        let mut rows_seen = 0;
        let consumed = subscription.poll_each(10, |batch| {
            rows_seen += batch.num_rows();
            true
        });

        assert_eq!(consumed, 2);
        assert_eq!(rows_seen, 2);
    }

    #[test]
    fn test_poll_each_early_stop() {
        let (source, sink) = create::<TestEvent>(16);
        let subscription = sink.subscribe();

        for id in 1..=3_i64 {
            source.push(event(id)).unwrap();
        }

        let mut invocations = 0;
        let consumed = subscription.poll_each(10, |_| {
            invocations += 1;
            // Ask to stop after the second batch.
            invocations < 2
        });

        assert_eq!(consumed, 2);
        assert_eq!(invocations, 2);
        // The third batch is left unread.
        assert_eq!(subscription.pending(), 1);
    }

    #[test]
    fn test_disconnected() {
        let (source, sink) = create::<TestEvent>(16);
        let subscription = sink.subscribe();

        assert!(!subscription.is_disconnected());

        drop(source);

        assert!(subscription.is_disconnected());
    }

    #[test]
    fn test_pending() {
        let (source, sink) = create::<TestEvent>(16);
        let subscription = sink.subscribe();

        assert_eq!(subscription.pending(), 0);

        for id in 1..=2_i64 {
            source.push(event(id)).unwrap();
        }

        assert_eq!(subscription.pending(), 2);
    }

    #[test]
    fn test_schema() {
        let (_source, sink) = create::<TestEvent>(16);
        let subscription = sink.subscribe();

        assert_eq!(subscription.schema().fields().len(), 2);
    }

    #[test]
    fn test_subscription_message() {
        let record = SubscriptionMessage::Record(event(1));
        assert!(record.is_record());
        assert!(!record.is_batch());
        assert!(!record.is_watermark());
        assert_eq!(record.to_batch().unwrap().num_rows(), 1);

        let watermark = SubscriptionMessage::<TestEvent>::Watermark(1000);
        assert!(watermark.is_watermark());
        assert_eq!(watermark.watermark(), Some(1000));
    }

    #[test]
    fn test_iterator() {
        let (source, sink) = create::<TestEvent>(16);
        let mut subscription = sink.subscribe();

        for id in 1..=2_i64 {
            source.push(event(id)).unwrap();
        }

        drop(source);

        let collected: Vec<_> = subscription.by_ref().collect();
        assert_eq!(collected.len(), 2);
    }

    #[test]
    fn test_debug_format() {
        let (_source, sink) = create::<TestEvent>(16);
        let subscription = sink.subscribe();

        let rendered = format!("{subscription:?}");
        assert!(rendered.contains("Subscription"));
        assert!(rendered.contains("Direct"));
    }
}