use std::sync::atomic::{AtomicI64, AtomicU64, Ordering};
use std::sync::Arc;

use arrow::array::RecordBatch;
use arrow::datatypes::SchemaRef;

use super::channel::{channel_with_config, ChannelMode, Producer};
use super::config::SourceConfig;
use super::error::{StreamingError, TryPushError};
use super::sink::Sink;

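/// A typed record that can be pushed into a [`Source`].
///
/// Implementors declare their Arrow schema, how a single record converts into
/// an Arrow [`RecordBatch`], and optionally an event-time timestamp that the
/// source uses to advance its watermark.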
pub trait Record: Send + Sized + 'static {
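    /// Arrow schema shared by every record of this type.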
    fn schema() -> SchemaRef;

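    /// Converts this record into a single-row [`RecordBatch`].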
    fn to_record_batch(&self) -> RecordBatch;

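    /// Event-time timestamp for this record, if any.
    ///
    /// Returning `Some` lets the source advance its watermark automatically;
    /// the default implementation returns `None`.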
    fn event_time(&self) -> Option<i64> {
        None
    }
}

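/// Message variants flowing from a [`Source`] to its [`Sink`]: a single typed
/// record, a pre-built Arrow batch, or a watermark timestamp.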
pub(crate) enum SourceMessage<T> {
    Record(T),

    Batch(RecordBatch),

    Watermark(i64),
}

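/// Monotonically increasing watermark shared between a source and its clones.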
struct SourceWatermark {
    current: Arc<AtomicI64>,
}

impl SourceWatermark {
    fn new() -> Self {
        Self {
            current: Arc::new(AtomicI64::new(i64::MIN)),
        }
    }

    fn from_arc(arc: Arc<AtomicI64>) -> Self {
        Self { current: arc }
    }

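    /// Advances the watermark to `timestamp` if it is larger than the current
    /// value, using a CAS loop so concurrent updates never move it backwards.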
    fn update(&self, timestamp: i64) {
        let mut current = self.current.load(Ordering::Acquire);
        while timestamp > current {
            match self.current.compare_exchange_weak(
                current,
                timestamp,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => break,
                Err(actual) => current = actual,
            }
        }
    }

    fn get(&self) -> i64 {
        self.current.load(Ordering::Acquire)
    }

    fn arc(&self) -> Arc<AtomicI64> {
        Arc::clone(&self.current)
    }
}

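/// Shared state behind a [`Source`]: the producer half of the channel, the
/// shared watermark, the Arrow schema, an optional name, and a monotonically
/// increasing sequence counter of accepted messages.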
struct SourceInner<T: Record> {
    producer: Producer<SourceMessage<T>>,

    watermark: SourceWatermark,

    schema: SchemaRef,

    name: Option<String>,

    sequence: Arc<AtomicU64>,
}

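/// Producer handle of a source/sink pair.
///
/// Records pushed here are forwarded to the paired [`Sink`]. Cloning a
/// `Source` upgrades the underlying channel to MPSC so multiple producers can
/// feed the same sink.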
pub struct Source<T: Record> {
    inner: Arc<SourceInner<T>>,
}

impl<T: Record> Source<T> {
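    /// Builds a source/sink pair from the given configuration.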
    pub(crate) fn new(config: SourceConfig) -> (Self, Sink<T>) {
        let channel_config = config.channel;
        let (producer, consumer) = channel_with_config::<SourceMessage<T>>(channel_config.clone());

        let schema = T::schema();

        let inner = Arc::new(SourceInner {
            producer,
            watermark: SourceWatermark::new(),
            schema: schema.clone(),
            name: config.name,
            sequence: Arc::new(AtomicU64::new(0)),
        });

        let source = Self { inner };
        let sink = Sink::new(consumer, schema, channel_config);

        (source, sink)
    }

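    /// Pushes a single record, updating the watermark from its event time.
    ///
    /// Returns [`StreamingError::ChannelFull`] if the channel cannot accept
    /// the record. The watermark is advanced before the push, so it may move
    /// forward even when the record itself is rejected.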
    pub fn push(&self, record: T) -> Result<(), StreamingError> {
        if let Some(event_time) = record.event_time() {
            self.inner.watermark.update(event_time);
        }

        self.inner
            .producer
            .push(SourceMessage::Record(record))
            .map_err(|_| StreamingError::ChannelFull)?;

        self.inner.sequence.fetch_add(1, Ordering::Relaxed);
        Ok(())
    }

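    /// Non-blocking push that hands the record back on failure.
    ///
    /// On a full channel the rejected record is returned inside
    /// [`TryPushError`] so the caller can retry or drop it.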
    pub fn try_push(&self, record: T) -> Result<(), TryPushError<T>> {
        if let Some(event_time) = record.event_time() {
            self.inner.watermark.update(event_time);
        }

        self.inner
            .producer
            .try_push(SourceMessage::Record(record))
            .map_err(|e| match e.into_inner() {
                SourceMessage::Record(r) => TryPushError {
                    value: r,
                    error: StreamingError::ChannelFull,
                },
                _ => unreachable!("pushed a record, got something else back"),
            })?;

        self.inner.sequence.fetch_add(1, Ordering::Relaxed);
        Ok(())
    }

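    /// Pushes records from a slice until the channel is full, cloning each
    /// one. Returns the number of records accepted.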
    pub fn push_batch(&self, records: &[T]) -> usize
    where
        T: Clone,
    {
        let mut count = 0;
        for record in records {
            if self.try_push(record.clone()).is_err() {
                break;
            }
            count += 1;
        }
        count
    }

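    /// Consumes an iterator of records, pushing by value until the channel is
    /// full. Returns the number of records accepted; the record that failed
    /// and any remaining items are dropped.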
    pub fn push_batch_drain<I>(&self, records: I) -> usize
    where
        I: IntoIterator<Item = T>,
    {
        let mut count = 0;
        for record in records {
            if let Some(event_time) = record.event_time() {
                self.inner.watermark.update(event_time);
            }

            if self
                .inner
                .producer
                .try_push(SourceMessage::Record(record))
                .is_err()
            {
                break;
            }
            self.inner.sequence.fetch_add(1, Ordering::Relaxed);
            count += 1;
        }
        count
    }

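    /// Pushes a pre-built Arrow [`RecordBatch`] directly.
    ///
    /// Unless the source schema has no fields, the batch schema must match
    /// it exactly; otherwise [`StreamingError::SchemaMismatch`] is returned.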
    pub fn push_arrow(&self, batch: RecordBatch) -> Result<(), StreamingError> {
        if !self.inner.schema.fields().is_empty() && batch.schema() != self.inner.schema {
            return Err(StreamingError::SchemaMismatch {
                expected: self
                    .inner
                    .schema
                    .fields()
                    .iter()
                    .map(|f| f.name().clone())
                    .collect(),
                actual: batch
                    .schema()
                    .fields()
                    .iter()
                    .map(|f| f.name().clone())
                    .collect(),
            });
        }

        self.inner
            .producer
            .push(SourceMessage::Batch(batch))
            .map_err(|_| StreamingError::ChannelFull)?;

        self.inner.sequence.fetch_add(1, Ordering::Relaxed);
        Ok(())
    }

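    /// Manually advances the watermark and forwards it downstream.
    ///
    /// The local watermark is always updated; the downstream watermark
    /// message is sent best-effort and silently dropped if the channel is
    /// full.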
    pub fn watermark(&self, timestamp: i64) {
        self.inner.watermark.update(timestamp);

        let _ = self
            .inner
            .producer
            .try_push(SourceMessage::Watermark(timestamp));
    }

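    /// Current watermark, or `i64::MIN` if none has been observed yet.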
    #[must_use]
    pub fn current_watermark(&self) -> i64 {
        self.inner.watermark.get()
    }

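    /// Arrow schema of the records produced by this source.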
    #[must_use]
    pub fn schema(&self) -> SchemaRef {
        Arc::clone(&self.inner.schema)
    }

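    /// Optional name given to this source at construction time.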
    #[must_use]
    pub fn name(&self) -> Option<&str> {
        self.inner.name.as_deref()
    }

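    /// Whether the underlying channel has been upgraded to MPSC.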
    #[must_use]
    pub fn is_mpsc(&self) -> bool {
        self.inner.producer.is_mpsc()
    }

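    /// Current channel mode.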
    #[must_use]
    pub fn mode(&self) -> ChannelMode {
        self.inner.producer.mode()
    }

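    /// Whether the channel is closed (for example, the paired sink was dropped).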
    #[must_use]
    pub fn is_closed(&self) -> bool {
        self.inner.producer.is_closed()
    }

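    /// Number of messages currently buffered in the channel.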
    #[must_use]
    pub fn pending(&self) -> usize {
        self.inner.producer.len()
    }

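    /// Total capacity of the underlying channel.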
    #[must_use]
    pub fn capacity(&self) -> usize {
        self.inner.producer.capacity()
    }

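    /// Number of messages accepted by this source so far.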
    #[must_use]
    pub fn sequence(&self) -> u64 {
        self.inner.sequence.load(Ordering::Acquire)
    }

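    /// Shared handle to the sequence counter.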
    #[must_use]
    pub fn sequence_counter(&self) -> Arc<AtomicU64> {
        Arc::clone(&self.inner.sequence)
    }

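    /// Shared handle to the watermark atomic.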
    #[must_use]
    pub fn watermark_atomic(&self) -> Arc<AtomicI64> {
        self.inner.watermark.arc()
    }
}

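/// Cloning a `Source` creates another producer for the same sink: the channel
/// producer is cloned (upgrading the channel to MPSC), while the watermark and
/// sequence counter remain shared across clones.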
impl<T: Record> Clone for Source<T> {
    fn clone(&self) -> Self {
        let producer = self.inner.producer.clone();

        Self {
            inner: Arc::new(SourceInner {
                producer,
                watermark: SourceWatermark::from_arc(self.inner.watermark.arc()),
                schema: Arc::clone(&self.inner.schema),
                name: self.inner.name.clone(),
                sequence: Arc::clone(&self.inner.sequence),
            }),
        }
    }
}

impl<T: Record + std::fmt::Debug> std::fmt::Debug for Source<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Source")
            .field("name", &self.inner.name)
            .field("mode", &self.mode())
            .field("pending", &self.pending())
            .field("capacity", &self.capacity())
            .field("watermark", &self.current_watermark())
            .finish()
    }
}

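/// Convenience constructor: builds a source/sink pair with the given buffer
/// size. A minimal usage sketch (the `MyEvent` type and its fields are
/// illustrative, not part of this module):
///
/// ```ignore
/// // Assumes `MyEvent` implements `Record`.
/// let (source, sink) = create::<MyEvent>(1024);
/// let _ = source.push(MyEvent { id: 1, value: 1.0, timestamp: 1_000 });
/// ```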
#[must_use]
pub fn create<T: Record>(buffer_size: usize) -> (Source<T>, Sink<T>) {
    Source::new(SourceConfig::with_buffer_size(buffer_size))
}

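/// Convenience constructor: builds a source/sink pair from a full
/// [`SourceConfig`], for example to set a name or tune the channel.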
#[must_use]
pub fn create_with_config<T: Record>(config: SourceConfig) -> (Source<T>, Sink<T>) {
    Source::new(config)
}

#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Float64Array, Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use std::sync::Arc;

    #[derive(Clone, Debug)]
    struct TestEvent {
        id: i64,
        value: f64,
        timestamp: i64,
    }

    impl Record for TestEvent {
        fn schema() -> SchemaRef {
            Arc::new(Schema::new(vec![
                Field::new("id", DataType::Int64, false),
                Field::new("value", DataType::Float64, false),
                Field::new("timestamp", DataType::Int64, false),
            ]))
        }

        fn to_record_batch(&self) -> RecordBatch {
            RecordBatch::try_new(
                Self::schema(),
                vec![
                    Arc::new(Int64Array::from(vec![self.id])),
                    Arc::new(Float64Array::from(vec![self.value])),
                    Arc::new(Int64Array::from(vec![self.timestamp])),
                ],
            )
            .unwrap()
        }

        fn event_time(&self) -> Option<i64> {
            Some(self.timestamp)
        }
    }

    #[test]
    fn test_create_source_sink() {
        let (source, _sink) = create::<TestEvent>(1024);

        assert!(!source.is_mpsc());
        assert!(!source.is_closed());
        assert_eq!(source.pending(), 0);
    }

    #[test]
    fn test_push_single() {
        let (source, _sink) = create::<TestEvent>(16);

        let event = TestEvent {
            id: 1,
            value: 42.0,
            timestamp: 1000,
        };

        assert!(source.push(event).is_ok());
        assert_eq!(source.pending(), 1);
    }

    #[test]
    fn test_try_push() {
        let (source, _sink) = create::<TestEvent>(16);

        let event = TestEvent {
            id: 1,
            value: 42.0,
            timestamp: 1000,
        };

        assert!(source.try_push(event).is_ok());
    }

    #[test]
    fn test_push_batch() {
        let (source, _sink) = create::<TestEvent>(16);

        let events = vec![
            TestEvent {
                id: 1,
                value: 1.0,
                timestamp: 1000,
            },
            TestEvent {
                id: 2,
                value: 2.0,
                timestamp: 2000,
            },
            TestEvent {
                id: 3,
                value: 3.0,
                timestamp: 3000,
            },
        ];

        let count = source.push_batch(&events);
        assert_eq!(count, 3);
        assert_eq!(source.pending(), 3);
    }

    #[test]
    fn test_push_arrow() {
        let (source, _sink) = create::<TestEvent>(16);

        let batch = RecordBatch::try_new(
            TestEvent::schema(),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3])),
                Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0])),
                Arc::new(Int64Array::from(vec![1000, 2000, 3000])),
            ],
        )
        .unwrap();

        assert!(source.push_arrow(batch).is_ok());
    }

    #[test]
    fn test_push_arrow_schema_mismatch() {
        let (source, _sink) = create::<TestEvent>(16);

        let wrong_schema = Arc::new(Schema::new(vec![Field::new(
            "wrong",
            DataType::Utf8,
            false,
        )]));

        let batch = RecordBatch::try_new(
            wrong_schema,
            vec![Arc::new(StringArray::from(vec!["test"]))],
        )
        .unwrap();

        let result = source.push_arrow(batch);
        assert!(matches!(result, Err(StreamingError::SchemaMismatch { .. })));
    }

    #[test]
    fn test_watermark() {
        let (source, _sink) = create::<TestEvent>(16);

        assert_eq!(source.current_watermark(), i64::MIN);

        source.watermark(1000);
        assert_eq!(source.current_watermark(), 1000);

        source.watermark(2000);
        assert_eq!(source.current_watermark(), 2000);

        source.watermark(1500);
        assert_eq!(source.current_watermark(), 2000);
    }

    #[test]
    fn test_watermark_from_event_time() {
        let (source, _sink) = create::<TestEvent>(16);

        let event = TestEvent {
            id: 1,
            value: 42.0,
            timestamp: 5000,
        };

        source.push(event).unwrap();

        assert_eq!(source.current_watermark(), 5000);
    }

    #[test]
    fn test_clone_upgrades_to_mpsc() {
        let (source, _sink) = create::<TestEvent>(16);

        assert!(!source.is_mpsc());
        assert_eq!(source.mode(), ChannelMode::Spsc);

        let source2 = source.clone();

        assert!(source.is_mpsc());
        assert!(source2.is_mpsc());
    }

    #[test]
    fn test_closed_on_sink_drop() {
        let (source, sink) = create::<TestEvent>(16);

        assert!(!source.is_closed());

        drop(sink);

        assert!(source.is_closed());
    }

    #[test]
    fn test_schema() {
        let (source, _sink) = create::<TestEvent>(16);

        let schema = source.schema();
        assert_eq!(schema.fields().len(), 3);
        assert_eq!(schema.field(0).name(), "id");
        assert_eq!(schema.field(1).name(), "value");
        assert_eq!(schema.field(2).name(), "timestamp");
    }

    #[test]
    fn test_named_source() {
        let config = SourceConfig::named("my_source");
        let (source, _sink) = create_with_config::<TestEvent>(config);

        assert_eq!(source.name(), Some("my_source"));
    }

    #[test]
    fn test_debug_format() {
        let (source, _sink) = create::<TestEvent>(16);

        let debug = format!("{source:?}");
        assert!(debug.contains("Source"));
        assert!(debug.contains("Spsc"));
    }
}