use crate::{BufferError, DynamicCircularBuffer};
use futures_core::Stream;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::sync::Notify;
/// Boxed wait future held by a parked [`BufferStream`]; it completes once the
/// stream should re-check its buffer.
type WaitFuture = Pin<Box<dyn std::future::Future<Output = ()> + Send>>;

/// An async [`Stream`] over the items of a shared [`DynamicCircularBuffer`].
///
/// Each poll pops one item from the buffer. When the buffer is empty, the stream
/// parks until one of two things happens, whichever comes first:
/// - its [`Notify`] is signalled;
/// - the poll interval (if one is set) elapses.
///
/// It then re-checks the buffer.
///
/// The stream yields `None` in two cases:
/// - after [`close`](Self::close) has been called;
/// - for any poll whose `pop` reports an error other than [`BufferError::Empty`].
pub struct BufferStream<T> {
    /// Buffer the items are popped from.
    buffer: Arc<DynamicCircularBuffer<T>>,
    /// Wake-up signal. Producers sharing it (e.g. a [`BufferSink`] built with the
    /// same `Notify`) signal it after pushing.
    notify: Arc<Notify>,
    /// Optional fallback re-check period while the buffer is empty. `None` means
    /// the stream relies solely on `notify`.
    poll_interval: Option<Duration>,
    /// Set by [`close`](Self::close); once `true`, every poll yields `None`.
    closed: bool,
    /// The in-flight wait while the buffer is empty.
    ///
    /// It is kept across polls, so one `Notify` registration is reused and
    /// refreshed with the latest waker on every poll. This replaces spawning a
    /// task per `Pending` poll.
    ///
    /// The `Mutex` exists only to keep `BufferStream` `Sync`. It is accessed
    /// exclusively through `&mut self` via `Mutex::get_mut`, so it is never locked.
    wait: std::sync::Mutex<Option<WaitFuture>>,
}

impl<T> BufferStream<T> {
    /// Creates a stream with its own private [`Notify`] and a 10 ms poll interval.
    ///
    /// The freshly created `Notify` is only shared if the caller hands it out via
    /// [`notify`](Self::notify). Until then, the interval is what lets the stream
    /// observe items pushed directly into the buffer.
    pub fn new(buffer: Arc<DynamicCircularBuffer<T>>) -> Self {
        Self {
            buffer,
            notify: Arc::new(Notify::new()),
            poll_interval: Some(Duration::from_millis(10)),
            closed: false,
            wait: std::sync::Mutex::new(None),
        }
    }

    /// Creates a stream woken only through `notify` (no poll interval).
    ///
    /// Producers must signal `notify` after pushing. Items pushed without a
    /// signal are not observed until the next wake-up.
    pub fn with_notify(buffer: Arc<DynamicCircularBuffer<T>>, notify: Arc<Notify>) -> Self {
        Self {
            buffer,
            notify,
            poll_interval: None,
            closed: false,
            wait: std::sync::Mutex::new(None),
        }
    }

    /// Sets (or replaces) the fallback poll interval used while the buffer is empty.
    ///
    /// A wait already in flight keeps the interval it was created with. The new
    /// value applies from the next wait onward.
    pub fn with_poll_interval(mut self, interval: Duration) -> Self {
        self.poll_interval = Some(interval);
        self
    }

    /// Returns a handle to this stream's [`Notify`], so producers can wake it.
    pub fn notify(&self) -> Arc<Notify> {
        Arc::clone(&self.notify)
    }

    /// Marks the stream as finished; every subsequent poll yields `None`.
    ///
    /// Items still in the buffer are left untouched. Any pending wait is dropped.
    pub fn close(&mut self) {
        self.closed = true;
        *self
            .wait
            .get_mut()
            .unwrap_or_else(std::sync::PoisonError::into_inner) = None;
    }

    /// Builds the future a parked stream waits on.
    ///
    /// With an interval, it completes on a `notify` signal or when the interval
    /// elapses, whichever is first. Without one, it completes only on a `notify`
    /// signal. The timer path uses `tokio::time::sleep` and therefore needs a
    /// Tokio runtime with timers enabled.
    fn wait_for_data(notify: Arc<Notify>, interval: Option<Duration>) -> WaitFuture {
        match interval {
            Some(interval) => Box::pin(async move {
                tokio::select! {
                    _ = notify.notified() => {}
                    _ = tokio::time::sleep(interval) => {}
                }
            }),
            None => Box::pin(async move { notify.notified().await }),
        }
    }
}

impl<T: Send + Sync + 'static> Stream for BufferStream<T> {
    type Item = T;

    /// Pops the next item, or parks the stream until it is worth checking again.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        if this.closed {
            return Poll::Ready(None);
        }
        // The mutex is never locked (only `get_mut` is used), so it cannot really
        // be poisoned; `into_inner` just keeps this infallible.
        let slot = this
            .wait
            .get_mut()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        match this.buffer.pop() {
            Ok(item) => {
                // Data is available; an outstanding wait is no longer needed.
                *slot = None;
                Poll::Ready(Some(item))
            }
            Err(BufferError::Empty) => {
                let notify = &this.notify;
                let interval = this.poll_interval;
                let wait = slot
                    .get_or_insert_with(|| Self::wait_for_data(Arc::clone(notify), interval));
                // Polling the stored wait registers the *current* waker. A stream
                // that moved to another task is therefore still woken.
                if std::future::Future::poll(wait.as_mut(), cx).is_pending() {
                    return Poll::Pending;
                }
                // The wait finished (signalled or interval elapsed). Drop it and
                // ask to be polled again, so the buffer is re-checked on a fresh
                // turn. Yielding to the executor here, rather than looping, keeps a
                // zero interval from spinning inside one poll.
                *slot = None;
                cx.waker().wake_by_ref();
                Poll::Pending
            }
            // Any other buffer error ends this poll with `None`.
            Err(_) => {
                *slot = None;
                Poll::Ready(None)
            }
        }
    }
}
/// Async producer handle that pushes items into a shared [`DynamicCircularBuffer`].
///
/// When built with a [`Notify`], the sink signals it after successful pushes so
/// that waiting consumers can wake up.
pub struct BufferSink<T> {
    /// Destination buffer for pushed items.
    buffer: Arc<DynamicCircularBuffer<T>>,
    /// Optional wake-up signal for consumers; `None` means pushes are silent.
    notify: Option<Arc<Notify>>,
}
impl<T> BufferSink<T> {
    /// Creates a sink that pushes into `buffer` without signalling anyone.
    pub fn new(buffer: Arc<DynamicCircularBuffer<T>>) -> Self {
        Self { notify: None, buffer }
    }

    /// Creates a sink that signals `notify` after each successful push.
    pub fn with_notify(buffer: Arc<DynamicCircularBuffer<T>>, notify: Arc<Notify>) -> Self {
        let notify = Some(notify);
        Self { buffer, notify }
    }
}
impl<T: Send + Sync + 'static> BufferSink<T> {
pub async fn send(&self, item: T) -> Result<(), BufferError> {
match self.buffer.push(item) {
Ok(()) => {
if let Some(ref notify) = self.notify {
notify.notify_one();
}
Ok(())
}
Err(e) => Err(e),
}
}
pub async fn send_batch(&self, items: Vec<T>) -> Result<(), BufferError> {
self.buffer.push_batch(items)?;
if let Some(ref notify) = self.notify {
notify.notify_one();
}
Ok(())
}
}
/// Extension methods that turn an `Arc`-shared buffer into async adapters.
pub trait BufferStreamExt<T> {
    /// Consumes the `Arc` handle and returns a [`BufferStream`] reading from it.
    fn into_stream(self: Arc<Self>) -> BufferStream<T>;

    /// Consumes the `Arc` handle and returns a [`BufferSink`] writing into it.
    fn into_sink(self: Arc<Self>) -> BufferSink<T>;

    /// Consumes the `Arc` handle and returns a stream and a sink over the same
    /// buffer, intended to be wired to a common wake-up signal.
    fn stream_sink_pair(self: Arc<Self>) -> (BufferStream<T>, BufferSink<T>);
}
impl<T: Send + Sync + 'static> BufferStreamExt<T> for DynamicCircularBuffer<T> {
    /// Uses [`BufferStream::new`]: a private `Notify` plus the default poll interval.
    fn into_stream(self: Arc<Self>) -> BufferStream<T> {
        BufferStream::new(self)
    }

    /// Uses [`BufferSink::new`]: the sink signals nobody after pushing.
    fn into_sink(self: Arc<Self>) -> BufferSink<T> {
        BufferSink::new(self)
    }

    /// Builds a stream and a sink that share one freshly created [`Notify`].
    /// Sends through the sink therefore wake the stream.
    fn stream_sink_pair(self: Arc<Self>) -> (BufferStream<T>, BufferSink<T>) {
        let shared = Arc::new(Notify::new());
        // Tuple fields are evaluated left to right. The stream takes its clone of
        // `self` before the sink consumes the original.
        (
            BufferStream::with_notify(Arc::clone(&self), Arc::clone(&shared)),
            BufferSink::with_notify(self, shared),
        )
    }
}
/// A channel-like wrapper around a shared [`DynamicCircularBuffer`].
///
/// Senders push into the buffer and signal `notify`. Receivers pop and wait on
/// `notify` while the buffer is empty.
pub struct BufferChannel<T> {
    /// Storage shared by every handle of this channel.
    buffer: Arc<DynamicCircularBuffer<T>>,
    /// Wake-up signal shared by every handle of this channel.
    notify: Arc<Notify>,
}
impl<T: Send + Sync + 'static> BufferChannel<T> {
    /// Wraps `buffer` in a channel with a fresh [`Notify`].
    pub fn new(buffer: Arc<DynamicCircularBuffer<T>>) -> Self {
        let notify = Arc::new(Notify::new());
        Self { buffer, notify }
    }

    /// Pushes `item` and, on success, wakes one waiting receiver.
    ///
    /// # Errors
    /// Returns the buffer's `push` error unchanged. No wake-up is sent then.
    pub async fn send(&self, item: T) -> Result<(), BufferError> {
        let pushed = self.buffer.push(item);
        if pushed.is_ok() {
            self.notify.notify_one();
        }
        pushed
    }

    /// Pops the next item, waiting on the channel's [`Notify`] while the buffer is empty.
    ///
    /// # Errors
    /// Returns any `pop` error other than [`BufferError::Empty`].
    pub async fn recv(&self) -> Result<T, BufferError> {
        loop {
            let outcome = self.buffer.pop();
            if !matches!(outcome, Err(BufferError::Empty)) {
                return outcome;
            }
            self.notify.notified().await;
        }
    }

    /// Pops without waiting; an empty buffer is reported as the `pop` error.
    pub fn try_recv(&self) -> Result<T, BufferError> {
        self.buffer.pop()
    }

    /// Like [`recv`](Self::recv), but gives up after `timeout`.
    ///
    /// # Errors
    /// Returns [`BufferError::Timeout`] when the deadline passes, or any error `recv` returns.
    pub async fn recv_timeout(&self, timeout: Duration) -> Result<T, BufferError> {
        match tokio::time::timeout(timeout, self.recv()).await {
            Ok(received) => received,
            Err(_elapsed) => Err(BufferError::Timeout(timeout)),
        }
    }

    /// Number of items currently stored in the underlying buffer.
    pub fn len(&self) -> usize {
        self.buffer.len()
    }

    /// Whether the underlying buffer currently holds no items.
    pub fn is_empty(&self) -> bool {
        self.buffer.is_empty()
    }

    /// Returns a [`BufferStream`] over this channel's buffer, woken by this channel's `Notify`.
    pub fn stream(&self) -> BufferStream<T> {
        let (buffer, notify) = (Arc::clone(&self.buffer), Arc::clone(&self.notify));
        BufferStream::with_notify(buffer, notify)
    }

    /// Returns a [`BufferSink`] that pushes into this channel's buffer and signals its `Notify`.
    pub fn sink(&self) -> BufferSink<T> {
        let (buffer, notify) = (Arc::clone(&self.buffer), Arc::clone(&self.notify));
        BufferSink::with_notify(buffer, notify)
    }
}
impl<T: Send + Sync + 'static> Clone for BufferChannel<T> {
fn clone(&self) -> Self {
Self {
buffer: self.buffer.clone(),
notify: self.notify.clone(),
}
}
}
// Unit tests for the async stream/sink/channel adapters. Each test builds a
// buffer from `Config::default()` and runs under `#[tokio::test]`.
// NOTE(review): the gate below compiles these tests only when the `async`
// feature is *disabled*, even though every test relies on tokio. If this module
// is itself only built with `feature = "async"`, these tests can never run.
// Confirm whether a plain `#[cfg(test)]` was intended.
#[cfg(all(test, not(feature = "async")))]
mod tests {
    use super::*;
    use crate::Config;
    // Provides `.next()` on `BufferStream`.
    use tokio_stream::StreamExt;

    // Items pushed before the stream exists are yielded in push order. The
    // 100 ms timeout only guards against hanging. Its result is ignored and the
    // final assertion decides pass/fail.
    #[tokio::test]
    async fn test_stream_basic() {
        let buffer = Arc::new(DynamicCircularBuffer::<i32>::new(Config::default()).unwrap());
        buffer.push(1).unwrap();
        buffer.push(2).unwrap();
        buffer.push(3).unwrap();
        let mut stream = BufferStream::new(buffer);
        // Shorten the fallback re-check period to keep the test fast.
        stream.poll_interval = Some(Duration::from_millis(1));
        let mut items = Vec::new();
        let timeout = tokio::time::timeout(Duration::from_millis(100), async {
            while let Some(item) = stream.next().await {
                items.push(item);
                if items.len() >= 3 {
                    break;
                }
            }
        });
        let _ = timeout.await;
        assert_eq!(items, vec![1, 2, 3]);
    }

    // A sink without a `Notify` still pushes. The items are popped back in send order.
    #[tokio::test]
    async fn test_sink_basic() {
        let buffer = Arc::new(DynamicCircularBuffer::<i32>::new(Config::default()).unwrap());
        let sink = BufferSink::new(buffer.clone());
        sink.send(1).await.unwrap();
        sink.send(2).await.unwrap();
        sink.send(3).await.unwrap();
        assert_eq!(buffer.len(), 3);
        assert_eq!(buffer.pop().unwrap(), 1);
        assert_eq!(buffer.pop().unwrap(), 2);
        assert_eq!(buffer.pop().unwrap(), 3);
    }

    // A producer task sends 0..5 through the paired sink with 5 ms gaps. The
    // paired stream (1 ms fallback interval) must receive all five within the
    // 200 ms guard.
    #[tokio::test]
    async fn test_stream_sink_pair() {
        let buffer = Arc::new(DynamicCircularBuffer::<i32>::new(Config::default()).unwrap());
        let (mut stream, sink) = buffer.stream_sink_pair();
        stream.poll_interval = Some(Duration::from_millis(1));
        let sink_task = tokio::spawn(async move {
            for i in 0..5 {
                sink.send(i).await.unwrap();
                tokio::time::sleep(Duration::from_millis(5)).await;
            }
        });
        let mut items = Vec::new();
        let timeout = tokio::time::timeout(Duration::from_millis(200), async {
            while let Some(item) = stream.next().await {
                items.push(item);
                if items.len() >= 5 {
                    break;
                }
            }
        });
        let _ = timeout.await;
        let _ = sink_task.await;
        assert_eq!(items.len(), 5);
    }

    // Sent items are received in order.
    #[tokio::test]
    async fn test_channel_basic() {
        let buffer = Arc::new(DynamicCircularBuffer::<i32>::new(Config::default()).unwrap());
        let channel = BufferChannel::new(buffer);
        channel.send(1).await.unwrap();
        channel.send(2).await.unwrap();
        assert_eq!(channel.recv().await.unwrap(), 1);
        assert_eq!(channel.recv().await.unwrap(), 2);
    }

    // `try_recv` reports `Empty` on an empty channel and returns a sent item otherwise.
    #[tokio::test]
    async fn test_channel_try_recv() {
        let buffer = Arc::new(DynamicCircularBuffer::<i32>::new(Config::default()).unwrap());
        let channel = BufferChannel::new(buffer);
        assert!(matches!(channel.try_recv(), Err(BufferError::Empty)));
        channel.send(42).await.unwrap();
        assert_eq!(channel.try_recv().unwrap(), 42);
    }

    // Receiving from an empty channel with a deadline yields `BufferError::Timeout`.
    #[tokio::test]
    async fn test_channel_timeout() {
        let buffer = Arc::new(DynamicCircularBuffer::<i32>::new(Config::default()).unwrap());
        let channel = BufferChannel::new(buffer);
        let result = channel.recv_timeout(Duration::from_millis(10)).await;
        assert!(matches!(result, Err(BufferError::Timeout(_))));
    }

    // One task sends 0..100 while another receives 100 items on a clone. The
    // received sum must equal 0 + 1 + ... + 99 = 4950.
    #[tokio::test]
    async fn test_channel_concurrent() {
        let buffer = Arc::new(DynamicCircularBuffer::<i32>::new(Config::default()).unwrap());
        let channel = BufferChannel::new(buffer);
        let sender = channel.clone();
        let receiver = channel.clone();
        let send_task = tokio::spawn(async move {
            for i in 0..100 {
                sender.send(i).await.unwrap();
            }
        });
        let recv_task = tokio::spawn(async move {
            let mut sum = 0;
            for _ in 0..100 {
                let item = receiver.recv().await.unwrap();
                sum += item;
            }
            sum
        });
        send_task.await.unwrap();
        let sum = recv_task.await.unwrap();
        assert_eq!(sum, 4950);
    }

    // `into_stream` and `into_sink` both operate on the same shared buffer.
    #[tokio::test]
    async fn test_extension_trait() {
        let buffer = Arc::new(DynamicCircularBuffer::<i32>::new(Config::default()).unwrap());
        let stream = buffer.clone().into_stream();
        assert!(stream.buffer.is_empty());
        let sink = buffer.clone().into_sink();
        sink.send(42).await.unwrap();
        assert_eq!(buffer.len(), 1);
    }
}