use std::sync::atomic::{AtomicI64, AtomicU64, Ordering};
use std::sync::{Arc, OnceLock};
use std::time::Duration;
use arrow::array::RecordBatch;
use arrow::datatypes::SchemaRef;
use super::channel::{channel_with_config, Producer};
use super::config::SourceConfig;
use super::error::{StreamingError, TryPushError};
use super::sink::Sink;
/// A typed event that can be pushed through a [`Source`].
///
/// Implementors supply the Arrow schema for the type and a conversion from one
/// value to a [`RecordBatch`]; the remaining methods have defaults.
pub trait Record: Clone + Send + Sized + 'static {
    /// Returns the Arrow schema describing records of this type.
    fn schema() -> SchemaRef;

    /// Converts this record into a [`RecordBatch`].
    fn to_record_batch(&self) -> RecordBatch;

    /// Returns the event time carried by this record, if any.
    ///
    /// The default implementation reports no event time.
    fn event_time(&self) -> Option<i64> {
        None
    }

    /// Converts a sequence of records into a single [`RecordBatch`].
    ///
    /// The default implementation converts every record individually and
    /// concatenates the results. An empty input yields an empty batch with
    /// [`Record::schema`]. If concatenation fails, the error is discarded and
    /// an empty batch is returned instead.
    fn to_record_batch_from_iter<I>(records: I) -> RecordBatch
    where
        I: IntoIterator<Item = Self>,
    {
        let per_record: Vec<RecordBatch> = records
            .into_iter()
            .map(|record| Self::to_record_batch(&record))
            .collect();

        if per_record.is_empty() {
            return RecordBatch::new_empty(Self::schema());
        }

        match arrow::compute::concat_batches(&Self::schema(), &per_record) {
            Ok(merged) => merged,
            // Best-effort fallback: a concat failure produces an empty batch.
            Err(_) => RecordBatch::new_empty(Self::schema()),
        }
    }
}
/// Payload sent through the source channel: either one typed record or an
/// Arrow `RecordBatch`.
#[derive(Clone)]
pub(crate) enum SourceMessage<T> {
/// A single typed record (sent by `Source::push` and `Source::try_push`).
Record(T),
/// An Arrow batch (sent by `Source::push_arrow`).
Batch(RecordBatch),
}
/// Shared high-water mark of the timestamps reported so far.
///
/// The value starts at `i64::MIN` and only ever increases. It is stored behind
/// an `Arc` so that several handles can read and advance the same counter.
struct SourceWatermark {
    current: Arc<AtomicI64>,
}

impl SourceWatermark {
    /// Creates a watermark initialised to `i64::MIN`.
    fn new() -> Self {
        Self::from_arc(Arc::new(AtomicI64::new(i64::MIN)))
    }

    /// Wraps an existing shared counter instead of allocating a new one.
    fn from_arc(arc: Arc<AtomicI64>) -> Self {
        Self { current: arc }
    }

    /// Raises the watermark to `timestamp` when it is greater than the current
    /// value; smaller or equal timestamps leave the watermark untouched.
    fn update(&self, timestamp: i64) {
        // `fetch_update` performs the load + `compare_exchange_weak` retry
        // loop; returning `None` from the closure ends it without a store.
        let _ = self
            .current
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |seen| {
                (timestamp > seen).then_some(timestamp)
            });
    }

    /// Returns the current watermark value.
    fn get(&self) -> i64 {
        self.current.load(Ordering::Acquire)
    }

    /// Returns another handle to the shared counter.
    fn arc(&self) -> Arc<AtomicI64> {
        Arc::clone(&self.current)
    }
}
/// State held by one `Source` handle.
struct SourceInner<T: Record> {
/// Sending half of the channel created in `Source::new`.
producer: Producer<SourceMessage<T>>,
/// Highest timestamp reported so far; the underlying counter is shared with clones.
watermark: SourceWatermark,
/// Arrow schema obtained from `T::schema()` at construction.
schema: SchemaRef,
/// Optional name taken from the `SourceConfig`.
name: Option<String>,
/// Counter bumped after each successful push; shared with clones.
sequence: Arc<AtomicU64>,
/// Set-once event-time column name (see `Source::set_event_time_column`).
event_time_column: OnceLock<String>,
/// Set-once out-of-orderness bound (see `Source::set_max_out_of_orderness`).
max_out_of_orderness: OnceLock<Duration>,
}
/// Producer-side handle for pushing `Record`s or Arrow batches into a channel.
///
/// Created together with its `Sink` by `create` / `create_with_config`.
/// Cloning yields another handle that shares the watermark and sequence counters.
pub struct Source<T: Record> {
/// Per-handle state; see `SourceInner`.
inner: Arc<SourceInner<T>>,
}
impl<T: Record> Source<T> {
    /// Builds a connected [`Source`] / [`Sink`] pair from `config`.
    ///
    /// The channel is configured from `config.channel`, the schema comes from
    /// `T::schema()`, the watermark starts at `i64::MIN` and the sequence
    /// counter starts at zero.
    pub(crate) fn new(config: SourceConfig) -> (Self, Sink<T>) {
        let (producer, consumer) = channel_with_config::<SourceMessage<T>>(&config.channel);
        let schema = T::schema();
        let inner = SourceInner {
            producer,
            watermark: SourceWatermark::new(),
            schema: Arc::clone(&schema),
            name: config.name,
            sequence: Arc::new(AtomicU64::new(0)),
            event_time_column: OnceLock::new(),
            max_out_of_orderness: OnceLock::new(),
        };
        let source = Self {
            inner: Arc::new(inner),
        };
        (source, Sink::new(consumer, schema))
    }

    /// Bookkeeping for a message the channel has accepted: advances the
    /// watermark when an event time is present and bumps the sequence counter.
    fn record_accepted(&self, event_time: Option<i64>) {
        if let Some(ts) = event_time {
            self.inner.watermark.update(ts);
        }
        self.inner.sequence.fetch_add(1, Ordering::Relaxed);
    }

    /// Pushes one record into the channel.
    ///
    /// On success the watermark is advanced to the record's event time (if it
    /// has one and it is newer) and the sequence counter is incremented. A
    /// rejected record leaves both untouched.
    ///
    /// # Errors
    ///
    /// Returns [`StreamingError::ChannelFull`] when the producer rejects the
    /// message; the record is dropped in that case.
    pub fn push(&self, record: T) -> Result<(), StreamingError> {
        // Read the event time first: the record is moved into the channel.
        let event_time = record.event_time();
        self.inner
            .producer
            .push(SourceMessage::Record(record))
            .map_err(|_| StreamingError::ChannelFull)?;
        // Only an accepted record may move the watermark / sequence.
        self.record_accepted(event_time);
        Ok(())
    }

    /// Pushes one record, handing it back to the caller on failure.
    ///
    /// Behaves like [`Source::push`], except that a rejected record is
    /// returned inside the error so it can be retried.
    ///
    /// # Errors
    ///
    /// Returns a [`TryPushError`] carrying the original record and
    /// [`StreamingError::ChannelFull`] when the producer rejects the message.
    pub fn try_push(&self, record: T) -> Result<(), TryPushError<T>> {
        let event_time = record.event_time();
        self.inner
            .producer
            .push(SourceMessage::Record(record))
            .map_err(|msg| match msg {
                SourceMessage::Record(r) => TryPushError {
                    value: r,
                    error: StreamingError::ChannelFull,
                },
                SourceMessage::Batch(_) => unreachable!("only Record is pushed here"),
            })?;
        self.record_accepted(event_time);
        Ok(())
    }

    /// Clones and pushes every record of `records` in order.
    ///
    /// Returns the number of records accepted; pushing stops at the first
    /// failure.
    pub fn push_batch(&self, records: &[T]) -> usize {
        self.push_batch_drain(records.iter().cloned())
    }

    /// Pushes records taken from `records` one by one.
    ///
    /// Returns the number of records accepted. Pushing stops at the first
    /// failure; the record that failed is dropped and later ones are not
    /// pulled from the iterator.
    pub fn push_batch_drain<I>(&self, records: I) -> usize
    where
        I: IntoIterator<Item = T>,
    {
        records
            .into_iter()
            .map(|record| self.push(record))
            .take_while(Result::is_ok)
            .count()
    }

    /// Pushes a pre-built Arrow batch into the channel.
    ///
    /// The batch schema must equal the source schema unless the source schema
    /// has no fields, in which case the check is skipped. The watermark is not
    /// touched; the sequence counter is incremented once per accepted batch.
    ///
    /// # Errors
    ///
    /// Returns [`StreamingError::SchemaMismatch`] (listing the expected and
    /// actual field names) when the schemas differ, or
    /// [`StreamingError::ChannelFull`] when the producer rejects the message.
    pub fn push_arrow(&self, batch: RecordBatch) -> Result<(), StreamingError> {
        let expected_schema = &self.inner.schema;
        if !expected_schema.fields().is_empty() && batch.schema() != *expected_schema {
            return Err(StreamingError::SchemaMismatch {
                expected: expected_schema
                    .fields()
                    .iter()
                    .map(|f| f.name().clone())
                    .collect(),
                actual: batch
                    .schema()
                    .fields()
                    .iter()
                    .map(|f| f.name().clone())
                    .collect(),
            });
        }
        self.inner
            .producer
            .push(SourceMessage::Batch(batch))
            .map_err(|_| StreamingError::ChannelFull)?;
        self.record_accepted(None);
        Ok(())
    }

    /// Manually advances the watermark to `timestamp`.
    ///
    /// Timestamps that are not greater than the current watermark are ignored.
    pub fn watermark(&self, timestamp: i64) {
        self.inner.watermark.update(timestamp);
    }

    /// Returns the current watermark (`i64::MIN` until one has been set).
    #[must_use]
    pub fn current_watermark(&self) -> i64 {
        self.inner.watermark.get()
    }

    /// Returns the Arrow schema of this source.
    #[must_use]
    pub fn schema(&self) -> SchemaRef {
        Arc::clone(&self.inner.schema)
    }

    /// Returns the configured name of this source, if any.
    #[must_use]
    pub fn name(&self) -> Option<&str> {
        self.inner.name.as_deref()
    }

    /// Reports whether the underlying producer considers the channel closed.
    #[must_use]
    pub fn is_closed(&self) -> bool {
        self.inner.producer.is_closed()
    }

    /// Returns the length reported by the underlying producer.
    #[must_use]
    pub fn pending(&self) -> usize {
        self.inner.producer.len()
    }

    /// Returns the capacity reported by the underlying producer.
    #[must_use]
    pub fn capacity(&self) -> usize {
        self.inner.producer.capacity()
    }

    /// Returns the number of successful pushes counted so far (shared with
    /// clones of this source).
    ///
    /// NOTE(review): increments use `Ordering::Relaxed`, so this `Acquire`
    /// load does not by itself order the read against the channel contents;
    /// confirm whether any caller relies on such ordering.
    #[must_use]
    pub fn sequence(&self) -> u64 {
        self.inner.sequence.load(Ordering::Acquire)
    }

    /// Returns a handle to the shared sequence counter.
    #[must_use]
    pub fn sequence_counter(&self) -> Arc<AtomicU64> {
        Arc::clone(&self.inner.sequence)
    }

    /// Returns a handle to the shared watermark counter.
    #[must_use]
    pub fn watermark_atomic(&self) -> Arc<AtomicI64> {
        self.inner.watermark.arc()
    }

    /// Sets the event-time column name for this handle.
    ///
    /// The value can be set only once; later calls are silently ignored.
    pub fn set_event_time_column(&self, column: &str) {
        let _ = self.inner.event_time_column.set(column.to_owned());
    }

    /// Returns the event-time column name, if one has been set.
    #[must_use]
    pub fn event_time_column(&self) -> Option<String> {
        self.inner.event_time_column.get().cloned()
    }

    /// Sets the maximum out-of-orderness for this handle.
    ///
    /// The value can be set only once; later calls are silently ignored.
    pub fn set_max_out_of_orderness(&self, dur: Duration) {
        let _ = self.inner.max_out_of_orderness.set(dur);
    }

    /// Returns the maximum out-of-orderness, if one has been set.
    #[must_use]
    pub fn max_out_of_orderness(&self) -> Option<Duration> {
        self.inner.max_out_of_orderness.get().copied()
    }
}
impl<T: Record> Clone for Source<T> {
fn clone(&self) -> Self {
let producer = self.inner.producer.clone();
let event_time_col = self.inner.event_time_column.get().cloned();
let event_time_column = OnceLock::new();
if let Some(col) = event_time_col {
let _ = event_time_column.set(col);
}
let max_ooo = self.inner.max_out_of_orderness.get().copied();
let max_out_of_orderness = OnceLock::new();
if let Some(dur) = max_ooo {
let _ = max_out_of_orderness.set(dur);
}
Self {
inner: Arc::new(SourceInner {
producer,
watermark: SourceWatermark::from_arc(self.inner.watermark.arc()),
schema: Arc::clone(&self.inner.schema),
name: self.inner.name.clone(),
sequence: Arc::clone(&self.inner.sequence),
event_time_column,
max_out_of_orderness,
}),
}
}
}
impl<T: Record + std::fmt::Debug> std::fmt::Debug for Source<T> {
    /// Formats the source as `Source { name, pending, capacity, watermark }`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Source");
        out.field("name", &self.inner.name);
        out.field("pending", &self.pending());
        out.field("capacity", &self.capacity());
        out.field("watermark", &self.current_watermark());
        out.finish()
    }
}
/// Creates a connected [`Source`] / [`Sink`] pair using a configuration built
/// from `buffer_size` via `SourceConfig::with_buffer_size`.
#[must_use]
pub fn create<T: Record>(buffer_size: usize) -> (Source<T>, Sink<T>) {
    let config = SourceConfig::with_buffer_size(buffer_size);
    Source::<T>::new(config)
}
/// Creates a connected [`Source`] / [`Sink`] pair from an explicit
/// [`SourceConfig`].
#[must_use]
pub fn create_with_config<T: Record>(config: SourceConfig) -> (Source<T>, Sink<T>) {
    Source::<T>::new(config)
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Float64Array, Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Minimal record type used by the tests below.
    #[derive(Clone, Debug)]
    struct TestEvent {
        id: i64,
        value: f64,
        timestamp: i64,
    }

    impl Record for TestEvent {
        fn schema() -> SchemaRef {
            let fields = vec![
                Field::new("id", DataType::Int64, false),
                Field::new("value", DataType::Float64, false),
                Field::new("timestamp", DataType::Int64, false),
            ];
            Arc::new(Schema::new(fields))
        }

        fn to_record_batch(&self) -> RecordBatch {
            RecordBatch::try_new(
                Self::schema(),
                vec![
                    Arc::new(Int64Array::from(vec![self.id])),
                    Arc::new(Float64Array::from(vec![self.value])),
                    Arc::new(Int64Array::from(vec![self.timestamp])),
                ],
            )
            .unwrap()
        }

        fn event_time(&self) -> Option<i64> {
            Some(self.timestamp)
        }
    }

    /// Shorthand constructor for a `TestEvent`.
    fn make_event(id: i64, value: f64, timestamp: i64) -> TestEvent {
        TestEvent {
            id,
            value,
            timestamp,
        }
    }

    /// Three-row batch matching `TestEvent::schema()`.
    fn sample_batch() -> RecordBatch {
        RecordBatch::try_new(
            TestEvent::schema(),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3])),
                Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0])),
                Arc::new(Int64Array::from(vec![1000, 2000, 3000])),
            ],
        )
        .unwrap()
    }

    // A fresh source is open and has nothing queued.
    #[tokio::test]
    async fn test_create_source_sink() {
        let (source, _sink) = create::<TestEvent>(1024);
        assert!(!source.is_closed());
        assert_eq!(source.pending(), 0);
    }

    // One pushed record shows up as one pending message.
    #[tokio::test]
    async fn test_push_single() {
        let (source, _sink) = create::<TestEvent>(16);
        let outcome = source.push(make_event(1, 42.0, 1000));
        assert!(outcome.is_ok());
        assert_eq!(source.pending(), 1);
    }

    // `try_push` succeeds while the channel has room.
    #[tokio::test]
    async fn test_try_push() {
        let (source, _sink) = create::<TestEvent>(16);
        let outcome = source.try_push(make_event(1, 42.0, 1000));
        assert!(outcome.is_ok());
    }

    // `push_batch` reports how many records were accepted.
    #[tokio::test]
    async fn test_push_batch() {
        let (source, _sink) = create::<TestEvent>(16);
        let events = vec![
            make_event(1, 1.0, 1000),
            make_event(2, 2.0, 2000),
            make_event(3, 3.0, 3000),
        ];
        let accepted = source.push_batch(&events);
        assert_eq!(accepted, 3);
        assert_eq!(source.pending(), 3);
    }

    // A batch with the matching schema is accepted.
    #[tokio::test]
    async fn test_push_arrow() {
        let (source, _sink) = create::<TestEvent>(16);
        assert!(source.push_arrow(sample_batch()).is_ok());
    }

    // A batch with a different schema is rejected with `SchemaMismatch`.
    #[tokio::test]
    async fn test_push_arrow_schema_mismatch() {
        let (source, _sink) = create::<TestEvent>(16);
        let wrong_schema = Arc::new(Schema::new(vec![Field::new(
            "wrong",
            DataType::Utf8,
            false,
        )]));
        let batch = RecordBatch::try_new(
            wrong_schema,
            vec![Arc::new(StringArray::from(vec!["test"]))],
        )
        .unwrap();
        let result = source.push_arrow(batch);
        assert!(matches!(result, Err(StreamingError::SchemaMismatch { .. })));
    }

    // The watermark starts at `i64::MIN` and never moves backwards.
    #[tokio::test]
    async fn test_watermark() {
        let (source, _sink) = create::<TestEvent>(16);
        assert_eq!(source.current_watermark(), i64::MIN);
        for (reported, expected) in [(1000, 1000), (2000, 2000), (1500, 2000)] {
            source.watermark(reported);
            assert_eq!(source.current_watermark(), expected);
        }
    }

    // Pushing a record advances the watermark to its event time.
    #[tokio::test]
    async fn test_watermark_from_event_time() {
        let (source, _sink) = create::<TestEvent>(16);
        source.push(make_event(1, 42.0, 5000)).unwrap();
        assert_eq!(source.current_watermark(), 5000);
    }

    // Records pushed through a clone reach the same subscriber.
    #[tokio::test]
    async fn test_clone_multi_producer() {
        let (source, sink) = create::<TestEvent>(16);
        let source2 = source.clone();
        let mut sub = sink.subscribe();
        source.push(make_event(1, 1.0, 1000)).unwrap();
        source2.push(make_event(2, 2.0, 2000)).unwrap();
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        assert!(sub.poll().is_some());
        assert!(sub.poll().is_some());
    }

    // The source exposes the schema of its record type.
    #[tokio::test]
    async fn test_schema() {
        let (source, _sink) = create::<TestEvent>(16);
        let schema = source.schema();
        let names: Vec<&str> = schema
            .fields()
            .iter()
            .map(|field| field.name().as_str())
            .collect();
        assert_eq!(names, ["id", "value", "timestamp"]);
    }

    // A configured name is reported back.
    #[tokio::test]
    async fn test_named_source() {
        let (source, _sink) = create_with_config::<TestEvent>(SourceConfig::named("my_source"));
        assert_eq!(source.name(), Some("my_source"));
    }

    // The `Debug` output mentions the type name.
    #[tokio::test]
    async fn test_debug_format() {
        let (source, _sink) = create::<TestEvent>(16);
        let rendered = format!("{source:?}");
        assert!(rendered.contains("Source"));
    }

    // The event-time column is unset by default and readable once set.
    #[tokio::test]
    async fn test_set_event_time_column() {
        let (source, _sink) = create::<TestEvent>(16);
        assert!(source.event_time_column().is_none());
        source.set_event_time_column("timestamp");
        assert_eq!(source.event_time_column(), Some("timestamp".to_string()));
    }

    // A clone carries over the event-time column set before cloning.
    #[tokio::test]
    async fn test_event_time_column_preserved_on_clone() {
        let (source, _sink) = create::<TestEvent>(16);
        source.set_event_time_column("ts");
        let cloned = source.clone();
        assert_eq!(cloned.event_time_column(), Some("ts".to_string()));
    }
}