use super::{Config, Error, Queue};
use crate::{Context, Persistable};
use commonware_codec::CodecShared;
use commonware_utils::{channel::mpsc, sync::AsyncMutex};
use std::{ops::Range, sync::Arc};
use tracing::debug;
/// Producer handle for a shared [`Queue`].
///
/// Cloning a `Writer` yields another handle to the same queue and the same
/// notification channel, so several tasks may write concurrently.
pub struct Writer<E, V>
where
    E: Context,
    V: CodecShared,
{
    /// Queue shared with every other writer and with the [`Reader`].
    queue: Arc<AsyncMutex<Queue<E, V>>>,
    /// Used to wake the reader after items are added.
    notify: mpsc::Sender<()>,
}
impl<E: Context, V: CodecShared> Clone for Writer<E, V> {
fn clone(&self) -> Self {
Self {
queue: self.queue.clone(),
notify: self.notify.clone(),
}
}
}
impl<E: Context, V: CodecShared> Writer<E, V> {
pub async fn enqueue(&self, item: V) -> Result<u64, Error> {
let pos = self.queue.lock().await.enqueue(item).await?;
let _ = self.notify.try_send(());
debug!(position = pos, "writer: enqueued item");
Ok(pos)
}
pub async fn enqueue_bulk(
&self,
items: impl IntoIterator<Item = V>,
) -> Result<Range<u64>, Error> {
let mut queue = self.queue.lock().await;
let start = queue.size().await;
for item in items {
queue.append(item).await?;
}
let end = queue.size().await;
if end > start {
queue.commit().await?;
}
drop(queue);
if start < end {
let _ = self.notify.try_send(());
}
debug!(start, end, "writer: enqueued bulk");
Ok(start..end)
}
pub async fn append(&self, item: V) -> Result<u64, Error> {
let pos = self.queue.lock().await.append(item).await?;
let _ = self.notify.try_send(());
debug!(position = pos, "writer: appended item");
Ok(pos)
}
pub async fn commit(&self) -> Result<(), Error> {
self.queue.lock().await.commit().await
}
pub async fn sync(&self) -> Result<(), Error> {
self.queue.lock().await.sync().await
}
pub async fn size(&self) -> u64 {
self.queue.lock().await.size().await
}
}
/// Consumer half of a shared [`Queue`].
///
/// `recv` and `try_recv` take `&mut self`. The acknowledgement and inspection
/// methods only need `&self`.
pub struct Reader<E, V>
where
    E: Context,
    V: CodecShared,
{
    /// Queue shared with every [`Writer`].
    queue: Arc<AsyncMutex<Queue<E, V>>>,
    /// Receives the wake-ups that writers send after adding items.
    notify: mpsc::Receiver<()>,
}
impl<E: Context, V: CodecShared> Reader<E, V> {
    /// Waits for the next item and returns it together with its position.
    ///
    /// The queue is checked before sleeping, so items added while nobody was
    /// waiting are returned immediately. A stale wake-up just triggers another
    /// check.
    ///
    /// Resolves to `Ok(None)` only after the notification channel has closed
    /// (every [`Writer`] dropped) and one final dequeue finds nothing.
    ///
    /// # Errors
    ///
    /// Propagates any error from the queue's `dequeue`.
    pub async fn recv(&mut self) -> Result<Option<(u64, V)>, Error> {
        loop {
            // The guard is a temporary of this statement, so the lock is
            // released before waiting on the channel below.
            let next = self.queue.lock().await.dequeue().await?;
            if next.is_some() {
                return Ok(next);
            }
            let woken = self.notify.recv().await;
            if woken.is_none() {
                // Every sender is gone: take one last look, then stop.
                return self.queue.lock().await.dequeue().await;
            }
        }
    }

    /// Returns the next item if one is available, without waiting for writers.
    ///
    /// Any pending wake-up is consumed first, since this call inspects the
    /// queue directly.
    ///
    /// # Errors
    ///
    /// Propagates any error from the queue's `dequeue`.
    pub async fn try_recv(&mut self) -> Result<Option<(u64, V)>, Error> {
        // Outcome ignored: an empty or closed channel is irrelevant here.
        let _ = self.notify.try_recv();
        self.queue.lock().await.dequeue().await
    }

    /// Acknowledges the item at `position` via the queue's `ack`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the queue's `ack`.
    pub async fn ack(&self, position: u64) -> Result<(), Error> {
        self.queue.lock().await.ack(position).await
    }

    /// Acknowledges items up to `up_to` via the queue's `ack_up_to`.
    ///
    /// See that method for whether the bound is inclusive.
    ///
    /// # Errors
    ///
    /// Propagates any error from the queue's `ack_up_to`.
    pub async fn ack_up_to(&self, up_to: u64) -> Result<(), Error> {
        self.queue.lock().await.ack_up_to(up_to).await
    }

    /// Returns the queue's current acknowledgement floor.
    pub async fn ack_floor(&self) -> u64 {
        self.queue.lock().await.ack_floor()
    }

    /// Returns the queue's current read position.
    pub async fn read_position(&self) -> u64 {
        self.queue.lock().await.read_position()
    }

    /// Reports whether the queue considers itself empty.
    pub async fn is_empty(&self) -> bool {
        self.queue.lock().await.is_empty().await
    }

    /// Calls the queue's synchronous `reset`.
    ///
    /// NOTE(review): presumably rewinds reading so unacknowledged items are
    /// delivered again; confirm against `Queue::reset`.
    pub async fn reset(&self) {
        self.queue.lock().await.reset();
    }
}
/// Opens the queue described by `cfg` and splits it into a [`Writer`] and a
/// [`Reader`] that share it.
///
/// The two halves are linked by a notification channel of capacity 1. Writers
/// use it to wake a reader blocked in [`Reader::recv`].
///
/// # Errors
///
/// Returns any error produced by [`Queue::init`].
pub async fn init<E: Context, V: CodecShared>(
    context: E,
    cfg: Config<V::Cfg>,
) -> Result<(Writer<E, V>, Reader<E, V>), Error> {
    let opened = Queue::init(context, cfg).await?;
    let shared = Arc::new(AsyncMutex::new(opened));
    // A single slot is enough: one pending wake-up covers any number of
    // items, because the reader re-checks the queue after waking.
    let (tx, rx) = mpsc::channel(1);
    Ok((
        Writer {
            queue: Arc::clone(&shared),
            notify: tx,
        },
        Reader {
            queue: shared,
            notify: rx,
        },
    ))
}
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_codec::RangeCfg;
    use commonware_macros::{select, test_traced};
    use commonware_runtime::{
        buffer::paged::CacheRef, deterministic, BufferPooler, Clock, Metrics, Runner, Spawner,
    };
    use commonware_utils::{NZUsize, NZU16, NZU64};
    use std::num::{NonZeroU16, NonZeroUsize};

    // Page size handed to the test page cache.
    const PAGE_SIZE: NonZeroU16 = NZU16!(1024);
    // Capacity handed to the test page cache (presumably a page count).
    const PAGE_CACHE_SIZE: NonZeroUsize = NZUsize!(10);

    // Builds a queue `Config` for `partition`. It uses 10 items per section,
    // no compression, a 4096-byte write buffer, and a page cache drawn from
    // `pooler`. The codec config `(0..)` presumably accepts items of any
    // length.
    fn test_config(partition: &str, pooler: &impl BufferPooler) -> Config<(RangeCfg<usize>, ())> {
        Config {
            partition: partition.into(),
            items_per_section: NZU64!(10),
            compression: None,
            codec_config: ((0..).into(), ()),
            page_cache: CacheRef::from_pooler(pooler, PAGE_SIZE, PAGE_CACHE_SIZE),
            write_buffer: NZUsize!(4096),
        }
    }

    // A single enqueued item is received at position 0; acking it empties the
    // queue.
    #[test_traced]
    fn test_shared_basic() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let cfg = test_config("test_shared_basic", &context);
            let (writer, mut reader) = init(context, cfg).await.unwrap();
            let pos = writer.enqueue(b"hello".to_vec()).await.unwrap();
            assert_eq!(pos, 0);
            let (recv_pos, item) = reader.recv().await.unwrap().unwrap();
            assert_eq!(recv_pos, 0);
            assert_eq!(item, b"hello".to_vec());
            reader.ack(recv_pos).await.unwrap();
            assert!(reader.is_empty().await);
        });
    }

    // Appended items are readable before `commit`. Acks may arrive out of
    // order (position 0 is acked last).
    #[test_traced]
    fn test_shared_append_commit() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let cfg = test_config("test_shared_append_commit", &context);
            let (writer, mut reader) = init(context, cfg).await.unwrap();
            for i in 0..5u8 {
                let pos = writer.append(vec![i]).await.unwrap();
                assert_eq!(pos, i as u64);
            }
            // Received before any commit has happened.
            let (pos, item) = reader.recv().await.unwrap().unwrap();
            assert_eq!(pos, 0);
            assert_eq!(item, vec![0]);
            writer.commit().await.unwrap();
            for i in 1..5 {
                let (pos, item) = reader.recv().await.unwrap().unwrap();
                assert_eq!(pos, i);
                assert_eq!(item, vec![i as u8]);
                reader.ack(pos).await.unwrap();
            }
            reader.ack(0).await.unwrap();
            assert!(reader.is_empty().await);
        });
    }

    // `enqueue_bulk` returns the contiguous range of positions it assigned,
    // and the items are received in that order.
    #[test_traced]
    fn test_shared_enqueue_bulk() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let cfg = test_config("test_shared_bulk", &context);
            let (writer, mut reader) = init(context, cfg).await.unwrap();
            let range = writer
                .enqueue_bulk((0..5u8).map(|i| vec![i]))
                .await
                .unwrap();
            assert_eq!(range, 0..5);
            for i in 0..5 {
                let (pos, item) = reader.recv().await.unwrap().unwrap();
                assert_eq!(pos, i);
                assert_eq!(item, vec![i as u8]);
                reader.ack(pos).await.unwrap();
            }
            assert!(reader.is_empty().await);
        });
    }

    // A writer task and the reader run concurrently. The reader must see
    // every item, in enqueue order.
    #[test_traced]
    fn test_shared_concurrent() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let cfg = test_config("test_shared_concurrent", &context);
            let (writer, mut reader) = init(context.clone(), cfg).await.unwrap();
            // The writer is returned from the task so it outlives the reads.
            let writer_handle = context.with_label("writer").spawn(|_ctx| async move {
                for i in 0..10u8 {
                    writer.enqueue(vec![i]).await.unwrap();
                }
                writer
            });
            let mut received = Vec::new();
            for _ in 0..10 {
                let (pos, item) = reader.recv().await.unwrap().unwrap();
                received.push((pos, item.clone()));
                reader.ack(pos).await.unwrap();
            }
            for (i, (pos, item)) in received.iter().enumerate() {
                assert_eq!(*pos, i as u64);
                assert_eq!(*item, vec![i as u8]);
            }
            let _ = writer_handle.await.unwrap();
        });
    }

    // `recv()` can be used as a `select!` branch alongside a timeout.
    #[test_traced]
    fn test_shared_select() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let cfg = test_config("test_shared_select", &context);
            let (writer, mut reader) = init(context.clone(), cfg).await.unwrap();
            writer.enqueue(b"test".to_vec()).await.unwrap();
            let result = select! {
                item = reader.recv() => item,
                _ = context.sleep(std::time::Duration::from_secs(1)) => {
                    panic!("timeout")
                },
            };
            let (pos, item) = result.unwrap().unwrap();
            assert_eq!(pos, 0);
            assert_eq!(item, b"test".to_vec());
            reader.ack(pos).await.unwrap();
        });
    }

    // After the only writer is dropped, `recv` still drains queued items and
    // then returns `None` instead of blocking forever.
    #[test_traced]
    fn test_shared_writer_dropped() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let cfg = test_config("test_shared_writer_dropped", &context);
            let (writer, mut reader) = init(context.clone(), cfg).await.unwrap();
            writer.enqueue(b"item1".to_vec()).await.unwrap();
            writer.enqueue(b"item2".to_vec()).await.unwrap();
            // Keep an extra handle to check that nothing else holds the queue
            // once both halves are dropped.
            let queue = writer.queue.clone();
            drop(writer);
            let (pos1, _) = reader.recv().await.unwrap().unwrap();
            reader.ack(pos1).await.unwrap();
            let (pos2, _) = reader.recv().await.unwrap().unwrap();
            reader.ack(pos2).await.unwrap();
            let result = reader.recv().await.unwrap();
            assert!(result.is_none());
            drop(reader);
            let _ = Arc::try_unwrap(queue)
                .unwrap_or_else(|_| panic!("queue should have a single reference"))
                .into_inner();
        });
    }

    // `try_recv` returns `None` on an empty queue and the item once one has
    // been enqueued.
    #[test_traced]
    fn test_shared_try_recv() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let cfg = test_config("test_shared_try_recv", &context);
            let (writer, mut reader) = init(context, cfg).await.unwrap();
            let result = reader.try_recv().await.unwrap();
            assert!(result.is_none());
            writer.enqueue(b"item".to_vec()).await.unwrap();
            let (pos, item) = reader.try_recv().await.unwrap().unwrap();
            assert_eq!(pos, 0);
            assert_eq!(item, b"item".to_vec());
            reader.ack(pos).await.unwrap();
        });
    }

    // Two cloned writers enqueue from separate tasks and every item is
    // received. Ordering between writers is not asserted, hence the sort.
    #[test_traced]
    fn test_shared_multiple_writers() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let cfg = test_config("test_shared_multi_writer", &context);
            let (writer, mut reader) = init(context.clone(), cfg).await.unwrap();
            let writer2 = writer.clone();
            let handle1 = context.with_label("writer1").spawn(|_ctx| async move {
                for i in 0..5u8 {
                    writer.enqueue(vec![i]).await.unwrap();
                }
                writer
            });
            let handle2 = context.with_label("writer2").spawn(|_ctx| async move {
                for i in 5..10u8 {
                    writer2.enqueue(vec![i]).await.unwrap();
                }
            });
            let mut received = Vec::new();
            for _ in 0..10 {
                let (pos, item) = reader.recv().await.unwrap().unwrap();
                received.push(item[0]);
                reader.ack(pos).await.unwrap();
            }
            received.sort();
            assert_eq!(received, (0..10u8).collect::<Vec<_>>());
            let _ = handle1.await.unwrap();
            handle2.await.unwrap();
        });
    }
}