use crate::error::SdkError;
use futures::Stream;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
/// Tuning knobs for a [`QueryStream`].
#[derive(Debug, Clone)]
pub struct StreamConfig {
    /// Capacity of the channel between the producer and the consumer.
    pub buffer_size: usize,
    /// Optional timeout in seconds; `None` means no timeout.
    ///
    /// NOTE(review): `QueryStream::new` does not read this field — callers presumably
    /// forward it to the producer (e.g. `spawn_stub_producer`); confirm.
    pub timeout_secs: Option<u64>,
}

impl Default for StreamConfig {
    /// A 64-slot buffer and no timeout.
    fn default() -> Self {
        StreamConfig {
            timeout_secs: None,
            buffer_size: 64,
        }
    }
}

impl StreamConfig {
    /// Creates a config with the given buffer size and no timeout.
    pub fn new(buffer_size: usize) -> Self {
        Self {
            buffer_size,
            ..Self::default()
        }
    }

    /// Returns this config with the timeout set to `secs` seconds.
    #[must_use]
    pub fn with_timeout(self, secs: u64) -> Self {
        Self {
            timeout_secs: Some(secs),
            ..self
        }
    }
}
/// A single key/value item carried by a [`QueryStream`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Row {
    /// Raw key bytes.
    pub key: Vec<u8>,
    /// Raw value bytes.
    pub value: Vec<u8>,
}

impl Row {
    /// Builds a row from anything convertible into owned byte vectors
    /// (`Vec<u8>`, `&[u8]`, `String`, `&str`, ...).
    pub fn new(key: impl Into<Vec<u8>>, value: impl Into<Vec<u8>>) -> Self {
        let key: Vec<u8> = key.into();
        let value: Vec<u8> = value.into();
        Row { key, value }
    }
}
pub struct RowSender {
tx: mpsc::Sender<Result<Row, SdkError>>,
cancel: CancellationToken,
pub sent: Arc<AtomicUsize>,
}
impl RowSender {
pub async fn send_row(&self, row: Row) -> bool {
if self.cancel.is_cancelled() {
return false;
}
self.sent.fetch_add(1, Ordering::Relaxed);
self.tx.send(Ok(row)).await.is_ok()
}
pub fn is_cancelled(&self) -> bool {
self.cancel.is_cancelled()
}
pub fn cancel_token(&self) -> CancellationToken {
self.cancel.clone()
}
pub async fn send_error(&self, err: SdkError) -> bool {
if self.cancel.is_cancelled() {
return false;
}
self.tx.send(Err(err)).await.is_ok()
}
}
/// Consumer half of a row stream.
///
/// Yields the `Result<Row, SdkError>` items sent through the paired [`RowSender`].
/// Calling [`QueryStream::cancel`] or dropping the stream cancels the token shared with
/// that sender.
pub struct QueryStream {
    rx: mpsc::Receiver<Result<Row, SdkError>>,
    cancel: CancellationToken,
}

impl QueryStream {
    /// Creates a connected stream/sender pair backed by a bounded channel.
    ///
    /// The channel capacity is `config.buffer_size`, clamped to at least 1. Tokio's
    /// `mpsc::channel` panics on a zero capacity, and a `StreamConfig::new(0)` should not
    /// crash the caller. `config.timeout_secs` is not used here.
    pub fn new(config: &StreamConfig) -> (Self, RowSender) {
        let capacity = config.buffer_size.max(1);
        let (tx, rx) = mpsc::channel(capacity);
        let cancel = CancellationToken::new();
        let sender = RowSender {
            tx,
            cancel: cancel.clone(),
            sent: Arc::new(AtomicUsize::new(0)),
        };
        (Self { rx, cancel }, sender)
    }

    /// Cancels the token shared with the paired [`RowSender`].
    ///
    /// The stream itself stays usable. Items already buffered can still be received.
    pub fn cancel(&self) {
        self.cancel.cancel();
    }
}
impl Drop for QueryStream {
fn drop(&mut self) {
self.cancel.cancel();
}
}
impl Stream for QueryStream {
    type Item = Result<Row, SdkError>;

    /// Delegates directly to the underlying `mpsc::Receiver::poll_recv`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // All fields are `Unpin`, so the pinned reference can be unwrapped safely.
        let this = self.get_mut();
        this.rx.poll_recv(cx)
    }
}
/// Spawns a synthetic producer that streams up to `total_rows` rows into `sender`.
///
/// Row `i` has the key `"{query_collection}:row:{i}"` (UTF-8 bytes) and the value `i`
/// encoded as 8 little-endian bytes. The task stops early in three cases:
/// - the token is cancelled;
/// - the receiver is gone;
/// - `timeout_secs` (if any) has elapsed.
///
/// The deadline also bounds a send that is waiting on a full channel. A timeout too
/// large to represent as an `Instant` is treated as no timeout instead of panicking.
#[doc(hidden)]
pub fn spawn_stub_producer(
    query_collection: String,
    total_rows: usize,
    sender: RowSender,
    timeout_secs: Option<u64>,
) -> tokio::task::JoinHandle<()> {
    tokio::spawn(async move {
        // `checked_add` avoids the overflow panic of `Instant + Duration` for huge timeouts.
        let deadline = timeout_secs
            .and_then(|secs| tokio::time::Instant::now().checked_add(Duration::from_secs(secs)));
        for i in 0..total_rows {
            if sender.is_cancelled() {
                break;
            }
            if deadline.is_some_and(|dl| tokio::time::Instant::now() >= dl) {
                break;
            }
            let key = format!("{query_collection}:row:{i}").into_bytes();
            // usize -> u64 is lossless on every platform with pointers of at most 64 bits.
            let value = (i as u64).to_le_bytes().to_vec();
            let row = Row::new(key, value);
            let delivered = match deadline {
                // Without this bound a slow consumer could keep the producer blocked in
                // `send_row` long after the deadline has passed.
                Some(dl) => tokio::time::timeout_at(dl, sender.send_row(row))
                    .await
                    .unwrap_or(false),
                None => sender.send_row(row).await,
            };
            if !delivered {
                break;
            }
        }
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    /// The default config has a 64-slot buffer and no timeout.
    #[tokio::test]
    async fn test_stream_config_defaults() {
        let StreamConfig {
            buffer_size,
            timeout_secs,
        } = StreamConfig::default();
        assert_eq!(buffer_size, 64);
        assert!(timeout_secs.is_none());
    }

    /// `Row::new` stores the given key and value bytes unchanged.
    #[tokio::test]
    async fn test_row_construction() {
        let Row { key, value } = Row::new(b"key".to_vec(), b"value".to_vec());
        assert_eq!(key, b"key");
        assert_eq!(value, b"value");
    }

    /// A finite producer yields exactly its rows, all `Ok`, and then the stream ends.
    #[tokio::test]
    async fn test_stream_collects_rows() {
        let (stream, sender) = QueryStream::new(&StreamConfig::new(16));
        let _producer = spawn_stub_producer(String::from("test"), 5, sender, None);
        let collected = stream.collect::<Vec<_>>().await;
        assert_eq!(collected.len(), 5);
        assert!(collected.iter().all(Result::is_ok));
    }

    /// Dropping the stream mid-way must make the producer task finish promptly.
    #[tokio::test]
    async fn test_stream_cancellation_stops_producer() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicBool, Ordering};
        use tokio::time::{Duration, Instant, sleep};

        let (mut stream, sender) = QueryStream::new(&StreamConfig::new(4));
        let done = Arc::new(AtomicBool::new(false));
        let _watcher = {
            let done = Arc::clone(&done);
            tokio::spawn(async move {
                // Only completion matters here, so the JoinError (if any) is ignored.
                let _ = spawn_stub_producer("cancel_test".to_string(), 1_000, sender, None).await;
                done.store(true, Ordering::Release);
            })
        };

        // Pull a couple of rows so the producer is actively running, then drop the stream.
        for _ in 0..2 {
            let _ = stream.next().await;
        }
        drop(stream);

        let give_up_at = Instant::now() + Duration::from_secs(1);
        while !done.load(Ordering::Acquire) {
            assert!(
                Instant::now() < give_up_at,
                "producer task did not stop within 1 second after stream was dropped"
            );
            sleep(Duration::from_millis(10)).await;
        }
    }
}