use std::sync::Arc;
use std::time::Duration;
use arrow::array::RecordBatch;
use arrow::datatypes::SchemaRef;
use tokio::sync::broadcast;
use super::error::RecvError;
use super::source::{Record, SourceMessage};
/// Receiving end of a broadcast stream that yields Arrow `RecordBatch`es.
///
/// Every message taken from the underlying channel is converted to a
/// `RecordBatch` before it is handed to the caller.
pub struct Subscription<T: Record> {
/// Broadcast receiver that delivers `SourceMessage<T>` values.
rx: broadcast::Receiver<SourceMessage<T>>,
/// Arrow schema handed out by `schema()`.
schema: SchemaRef,
/// Set to `true` once a receive call has observed the channel as closed.
closed: bool,
}
impl<T: Record> Subscription<T> {
    /// Creates a subscription from a broadcast receiver and the schema of the
    /// batches it will produce. The subscription starts out as not disconnected.
    pub(crate) fn new(rx: broadcast::Receiver<SourceMessage<T>>, schema: SchemaRef) -> Self {
        Self {
            rx,
            schema,
            closed: false,
        }
    }

    /// Tries to receive the next batch without blocking.
    ///
    /// Returns `Some(batch)` when a message is ready. Returns `None` when the
    /// channel is currently empty or has been closed; in the closed case the
    /// subscription is also marked as disconnected, so `is_disconnected()` can
    /// be used to tell the two apart.
    ///
    /// If this receiver lagged behind, the lag notification is discarded and
    /// the receive is retried.
    pub fn poll(&mut self) -> Option<RecordBatch> {
        loop {
            match self.rx.try_recv() {
                Ok(msg) => return Some(to_batch(msg)),
                Err(broadcast::error::TryRecvError::Empty) => return None,
                Err(broadcast::error::TryRecvError::Closed) => {
                    self.closed = true;
                    return None;
                }
                // Missed messages cannot be recovered; retry to pick up
                // whatever the channel still holds.
                Err(broadcast::error::TryRecvError::Lagged(_)) => {}
            }
        }
    }

    /// Waits asynchronously for the next batch.
    ///
    /// Lag notifications are discarded and the receive is retried.
    ///
    /// # Errors
    ///
    /// Returns `RecvError::Disconnected` once the channel is closed; the
    /// subscription is marked as disconnected at that point.
    pub async fn recv_async(&mut self) -> Result<RecordBatch, RecvError> {
        loop {
            match self.rx.recv().await {
                Ok(msg) => return Ok(to_batch(msg)),
                Err(broadcast::error::RecvError::Closed) => {
                    self.closed = true;
                    return Err(RecvError::Disconnected);
                }
                // Missed messages cannot be recovered; keep waiting.
                Err(broadcast::error::RecvError::Lagged(_)) => {}
            }
        }
    }

    /// Blocks the current thread until the next batch arrives.
    ///
    /// Lag notifications are discarded and the receive is retried.
    ///
    /// # Errors
    ///
    /// Returns `RecvError::Disconnected` once the channel is closed; the
    /// subscription is marked as disconnected at that point.
    ///
    /// # Panics
    ///
    /// Per Tokio's documentation, `blocking_recv` panics when called from
    /// within an asynchronous execution context. Use `recv_async` there.
    pub fn recv(&mut self) -> Result<RecordBatch, RecvError> {
        loop {
            match self.rx.blocking_recv() {
                Ok(msg) => return Ok(to_batch(msg)),
                Err(broadcast::error::RecvError::Closed) => {
                    self.closed = true;
                    return Err(RecvError::Disconnected);
                }
                // Missed messages cannot be recovered; keep waiting.
                Err(broadcast::error::RecvError::Lagged(_)) => {}
            }
        }
    }

    /// Blocks the current thread until the next batch arrives or `timeout`
    /// elapses, whichever comes first.
    ///
    /// When the calling thread already has a Tokio runtime context, the wait
    /// runs on that runtime's handle. Otherwise a temporary current-thread
    /// runtime with only the timer enabled is created for this call, so the
    /// method is usable from plain threads just like `recv`.
    ///
    /// # Errors
    ///
    /// * `RecvError::Timeout` if nothing arrived within `timeout`.
    /// * `RecvError::Disconnected` if the channel is closed.
    ///
    /// # Panics
    ///
    /// Per Tokio's documentation, `Handle::block_on` panics when called from
    /// within an asynchronous execution context. Use `recv_async` wrapped in
    /// `tokio::time::timeout` there.
    ///
    /// Tokio also documents that `Handle::block_on` cannot drive the timer of
    /// a current-thread runtime by itself. With such an ambient runtime the
    /// timeout only fires while another thread is driving that runtime.
    pub fn recv_timeout(&mut self, timeout: Duration) -> Result<RecordBatch, RecvError> {
        // Built lazily so the timer is registered with whichever runtime ends
        // up polling the future, not with the caller's ambient context.
        let bounded_recv = async { tokio::time::timeout(timeout, self.recv_async()).await };

        let outcome = match tokio::runtime::Handle::try_current() {
            Ok(handle) => handle.block_on(bounded_recv),
            // No runtime on this thread: `Handle::current()` would panic here,
            // so drive the wait on a throwaway timer-only runtime instead.
            Err(_) => tokio::runtime::Builder::new_current_thread()
                .enable_time()
                .build()
                .expect("building a timer-only current-thread runtime should not fail")
                .block_on(bounded_recv),
        };

        match outcome {
            Ok(received) => received,
            Err(_elapsed) => Err(RecvError::Timeout),
        }
    }

    /// Returns `true` once a receive call on this subscription has observed
    /// the channel as closed.
    ///
    /// This only reports what earlier receive calls saw; it does not probe the
    /// channel itself.
    #[must_use]
    pub fn is_disconnected(&self) -> bool {
        self.closed
    }

    /// Returns a shared handle to the schema this subscription was created
    /// with.
    #[must_use]
    pub fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
/// Converts a channel message into a `RecordBatch`.
///
/// A message that already carries a batch is passed through untouched; a
/// single record is converted with `Record::to_record_batch`.
fn to_batch<T: Record>(msg: SourceMessage<T>) -> RecordBatch {
    let record = match msg {
        // Already in batch form: nothing to convert.
        SourceMessage::Batch(batch) => return batch,
        SourceMessage::Record(record) => record,
    };
    record.to_record_batch()
}
/// Debug output shows the disconnect flag and the schema only.
///
/// The broadcast receiver is left out (hence `finish_non_exhaustive`), so no
/// `Debug` bound on `T` is required.
impl<T: Record> std::fmt::Debug for Subscription<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Subscription")
            .field("closed", &self.closed)
            .field("schema", &self.schema)
            .finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::streaming::source::create;
    use arrow::array::{Float64Array, Int64Array};
    use arrow::datatypes::{DataType, Field, Schema};

    /// Minimal two-column record used to exercise subscriptions.
    #[derive(Clone, Debug)]
    struct TestEvent {
        id: i64,
        value: f64,
    }

    impl Record for TestEvent {
        fn schema() -> SchemaRef {
            let fields = vec![
                Field::new("id", DataType::Int64, false),
                Field::new("value", DataType::Float64, false),
            ];
            Arc::new(Schema::new(fields))
        }

        fn to_record_batch(&self) -> RecordBatch {
            // One row per event: each column holds exactly one value.
            let ids = Int64Array::from(vec![self.id]);
            let values = Float64Array::from(vec![self.value]);
            RecordBatch::try_new(Self::schema(), vec![Arc::new(ids), Arc::new(values)]).unwrap()
        }
    }

    /// The single event pushed by the receive tests.
    fn sample_event() -> TestEvent {
        TestEvent { id: 1, value: 1.0 }
    }

    /// Polling a subscription with nothing pushed yields `None`.
    #[tokio::test]
    async fn test_poll_empty() {
        // `_source` stays bound so the channel remains open while polling.
        let (_source, sink) = create::<TestEvent>(16);
        let mut subscription = sink.subscribe();

        let polled = subscription.poll();

        assert!(polled.is_none());
    }

    /// A pushed record arrives as a one-row batch.
    #[tokio::test]
    async fn test_single_subscriber_async() {
        let (source, sink) = create::<TestEvent>(16);
        let mut subscription = sink.subscribe();

        source.push(sample_event()).unwrap();

        let received = subscription.recv_async().await.unwrap();
        assert_eq!(received.num_rows(), 1);
    }

    /// Every subscriber gets its own copy of a pushed record.
    #[tokio::test]
    async fn test_multiple_subscribers_all_receive() {
        let (source, sink) = create::<TestEvent>(16);
        let mut first = sink.subscribe();
        let mut second = sink.subscribe();

        source.push(sample_event()).unwrap();

        let first_batch = first.recv_async().await.unwrap();
        let second_batch = second.recv_async().await.unwrap();
        assert_eq!(first_batch.num_rows(), 1);
        assert_eq!(second_batch.num_rows(), 1);
    }

    /// Dropping both ends makes the next receive fail and flags the
    /// subscription as disconnected.
    #[tokio::test]
    async fn test_disconnected_after_source_and_sink_drop() {
        let (source, sink) = create::<TestEvent>(16);
        let mut subscription = sink.subscribe();

        drop(source);
        drop(sink);
        tokio::time::sleep(Duration::from_millis(50)).await;

        let outcome = subscription.recv_async().await;
        assert!(outcome.is_err());
        assert!(subscription.is_disconnected());
    }

    /// The subscription exposes the two-field schema of `TestEvent`.
    #[tokio::test]
    async fn test_schema() {
        let (_source, sink) = create::<TestEvent>(16);
        let subscription = sink.subscribe();

        let field_count = subscription.schema().fields().len();

        assert_eq!(field_count, 2);
    }
}