pub mod config;
pub mod messages;
use std::{cell::Cell, collections::VecDeque, marker::PhantomData, sync::Arc};
pub use config::{Buffer, BufferConfig};
use log::{debug, error, info};
pub use messages::PrefetchHandle;
use parking_lot::Mutex;
use tokio::{
sync::{
mpsc::{self, UnboundedReceiver, UnboundedSender},
oneshot,
},
task::JoinHandle,
};
use crate::correlated_randomness::{
generator::CorrelationGenerator,
stream::{
buffered::{config::try_lock_config, messages::Command},
errors::CorrelatedStreamError,
futures::Next,
CorrelatedStream,
NextVec,
},
CorrelatedBatch,
};
/// Handle to a buffered stream of correlated-randomness items.
///
/// Construction spawns two background tasks: a generator task that calls
/// `CorrelationGenerator::run_for` on demand, and a dispatcher task that owns the item
/// buffer and answers requests. This handle talks to the dispatcher only through
/// `command_sender`.
pub struct BufferedStream<PB: CorrelatedBatch, E> {
/// Sends `RequestN` / `Prefetch` commands to the dispatcher task. Dropping it closes the
/// command channel, which is the dispatcher's shutdown signal (see `stop`).
command_sender: UnboundedSender<Command<PB, E>>,
/// `Cell<()>` is `!Sync`, so this zero-sized marker makes the handle `!Sync`.
_unsync_marker: PhantomData<Cell<()>>,
/// Configuration shared with the dispatcher task (it reads `capacity` and
/// `refill_threshold` from it).
config: Arc<Mutex<BufferConfig>>,
/// Join handle of the dispatcher task; awaited by `stop`.
dispatcher_handle: JoinHandle<()>,
/// Join handle of the generator task; awaited by `stop`.
generator_handle: JoinHandle<()>,
}
impl<PB, E> Buffer for BufferedStream<PB, E>
where
    PB: CorrelatedBatch,
{
    /// Returns the configuration handle that is shared with the dispatcher task,
    /// so changes made through it are observed by the running stream.
    fn config(&self) -> &Arc<Mutex<BufferConfig>> {
        &self.config
    }
}
impl<PB, E> BufferedStream<PB, E>
where
    PB: CorrelatedBatch,
    E: From<CorrelatedStreamError> + Clone + Send + std::fmt::Debug + 'static,
{
    /// Creates a buffered stream that owns its own copy of `config`.
    ///
    /// Equivalent to [`Self::new_with_shared_config`] with a freshly wrapped config.
    /// Spawns background tasks with `tokio::spawn`, so it must be called from inside a
    /// Tokio runtime.
    pub fn new<G>(generator: G, net: G::Net, config: BufferConfig) -> Self
    where
        G: CorrelationGenerator<PB> + Send + 'static,
        E: From<G::Error>,
    {
        let shared = Arc::new(Mutex::new(config));
        Self::new_with_shared_config(generator, net, shared)
    }

    /// Creates a buffered stream whose configuration is shared with the caller.
    ///
    /// Wires three unbounded channels (commands, generation work, generated items) and
    /// spawns the generator task followed by the dispatcher task. Must be called from
    /// inside a Tokio runtime because it uses `tokio::spawn`.
    pub fn new_with_shared_config<G>(
        generator: G,
        net: G::Net,
        config: Arc<Mutex<BufferConfig>>,
    ) -> Self
    where
        G: CorrelationGenerator<PB> + Send + 'static,
        E: From<G::Error>,
    {
        let (command_sender, command_receiver) = mpsc::unbounded_channel::<Command<PB, E>>();
        let (work_sender, work_receiver) = mpsc::unbounded_channel::<usize>();
        let (item_sender, item_receiver) =
            mpsc::unbounded_channel::<Result<Vec<PB::Item>, E>>();

        let generator_handle =
            tokio::spawn(generator_loop(generator, net, work_receiver, item_sender));
        let dispatcher_handle = tokio::spawn(dispatcher_loop(
            command_receiver,
            work_sender,
            item_receiver,
            Arc::clone(&config),
        ));

        Self {
            command_sender,
            _unsync_marker: PhantomData,
            config,
            dispatcher_handle,
            generator_handle,
        }
    }

    /// Asks the dispatcher to make sure `n_elements` items are buffered.
    ///
    /// The returned handle resolves to `Ok(())` once the prefetch is satisfied, or to an
    /// error. Errors are reported through the handle rather than returned directly:
    /// a failing `max_request_size`, `RequestTooLarge` when `n_elements` exceeds it, or
    /// `StreamClosed` when the dispatcher task is gone.
    pub fn prefetch_n(&self, n_elements: usize) -> PrefetchHandle<E> {
        let (completion, receiver) = oneshot::channel();

        let rejection: Option<E> = match self.max_request_size() {
            Err(e) => Some(e.into()),
            Ok(max_allowed) if n_elements > max_allowed => Some(
                CorrelatedStreamError::RequestTooLarge {
                    requested: n_elements,
                    max_allowed,
                }
                .into(),
            ),
            Ok(_) => None,
        };

        match rejection {
            Some(err) => {
                let _ = completion.send(Err(err));
            }
            None => {
                let command = Command::Prefetch {
                    n_elements,
                    completion,
                };
                // A failed send hands the command back; recover its sender and report closure.
                if let Err(Command::Prefetch { completion, .. }) = self
                    .command_sender
                    .send(command)
                    .map_err(|rejected| rejected.0)
                {
                    let _ = completion.send(Err(CorrelatedStreamError::StreamClosed.into()));
                }
            }
        }

        PrefetchHandle::from(receiver)
    }

    /// Shuts the stream down and waits for both background tasks to finish.
    ///
    /// Dropping the command sender closes the dispatcher's command channel. That makes
    /// the dispatcher exit and drop its work sender, which in turn ends the generator
    /// loop. Join errors are ignored.
    pub async fn stop(self) {
        let Self {
            command_sender,
            dispatcher_handle,
            generator_handle,
            ..
        } = self;
        drop(command_sender);
        for handle in [dispatcher_handle, generator_handle] {
            let _ = handle.await;
        }
    }
}
/// Background task that turns work requests into generated items.
///
/// Each count `n` received on `work_rx` triggers one `run_for(n, &mut net)` call. The
/// outcome (items, or the error converted into `E`) is forwarded on `items_tx`. The task
/// ends when `work_rx` closes, when `items_tx` has no receiver, or right after forwarding
/// the first generation error.
async fn generator_loop<PB, G, E>(
    mut generator: G,
    mut net: G::Net,
    mut work_rx: UnboundedReceiver<usize>,
    items_tx: UnboundedSender<Result<Vec<PB::Item>, E>>,
) where
    PB: CorrelatedBatch,
    G: CorrelationGenerator<PB> + Send,
    E: From<G::Error> + Send,
{
    let log_prefix = format!("<Generator<{}>>", std::any::type_name::<PB>());
    loop {
        let Some(n) = work_rx.recv().await else {
            break;
        };
        debug!("{log_prefix} generating {n} elements");
        let outcome = generator.run_for(n, &mut net).await.map_err(E::from);
        // The error (if any) is still delivered once; only then does the task stop.
        let failed = outcome.is_err();
        let delivered = items_tx.send(outcome).is_ok();
        if failed || !delivered {
            break;
        }
    }
    info!("{log_prefix} exiting");
}
/// Items gathered so far for a queued batch request.
type ItemsCollected<PB> = Vec<<PB as IntoIterator>::Item>;
/// Total number of items a queued batch request asked for.
type TotalNeeded = usize;
/// One-shot sender that delivers a finished (or failed) batch to its requester.
type BatchSender<PB, E> = oneshot::Sender<Result<ItemsCollected<PB>, E>>;
/// Dispatcher task: owns the item buffer and arbitrates between consumers and the
/// generator task.
///
/// * `cmd_rx` — `RequestN` / `Prefetch` commands from the `BufferedStream` handle.
/// * `work_tx` — generation requests (item counts) for the generator task.
/// * `items_rx` — generated items, or the generator's error.
/// * `config` — shared buffer configuration (`capacity`, `refill_threshold`).
///
/// Behaviour:
/// * A `RequestN` is rejected with `RateLimitExceeded` if it would push the outstanding
///   demand past `capacity`. Otherwise it is served straight from the buffer when
///   possible, or queued (FIFO) together with whatever the buffer currently holds.
/// * A `Prefetch` resolves immediately if the buffer already holds `n_elements`.
///   Otherwise its deficit is recorded, and it resolves once that many newly generated
///   items have landed in the buffer.
/// * At most one generation request is in flight at a time. Each one is sized to cover
///   queued batches plus the larger of the refill-threshold gap and the prefetch deficit.
///
/// The loop ends when either channel closes, the config lock cannot be taken, or the
/// generator reports an error. Every still-pending batch and prefetch then receives the
/// shutdown error: the generator/lock error, or `StreamClosed`.
async fn dispatcher_loop<
    PB: CorrelatedBatch,
    E: From<CorrelatedStreamError> + Clone + Send + std::fmt::Debug,
>(
    mut cmd_rx: UnboundedReceiver<Command<PB, E>>,
    work_tx: UnboundedSender<usize>,
    mut items_rx: UnboundedReceiver<Result<Vec<PB::Item>, E>>,
    config: Arc<Mutex<BufferConfig>>,
) {
    let log_prefix = format!("<Dispatcher<{}>>", std::any::type_name::<PB>());
    let initial_cap = try_lock_config(&config).map(|c| c.capacity()).unwrap_or(0);
    let mut buffer: VecDeque<PB::Item> = VecDeque::with_capacity(initial_cap);
    // Items that must still land in `buffer` to satisfy the queued prefetches.
    let mut prefetch_demand: usize = 0;
    let mut prefetch_completions: VecDeque<(usize, oneshot::Sender<Result<(), E>>)> =
        VecDeque::new();
    let mut gen_in_flight = false;
    let mut pending_batches: VecDeque<(ItemsCollected<PB>, TotalNeeded, BatchSender<PB, E>)> =
        VecDeque::new();
    let mut shutdown_err: Option<E> = None;

    // Locks the config or leaves the loop with a shutdown error. The one-argument form
    // also reports the error to the command's own completion sender, which would
    // otherwise be dropped without a reply.
    macro_rules! lock_cfg {
        () => {
            match try_lock_config(&config) {
                Ok(g) => g,
                Err(e) => {
                    error!("{log_prefix} config lock timeout, shutting down");
                    shutdown_err = Some(e.into());
                    break;
                }
            }
        };
        ($completion:ident) => {
            match try_lock_config(&config) {
                Ok(g) => g,
                Err(e) => {
                    error!("{log_prefix} config lock timeout, shutting down");
                    let err: E = e.into();
                    let _ = $completion.send(Err(err.clone()));
                    shutdown_err = Some(err);
                    break;
                }
            }
        };
    }
    // Demand not covered by the current buffer; used for capacity-based rate limiting.
    macro_rules! outstanding_demand {
        () => {{
            let batch_shortfall: usize = pending_batches
                .iter()
                .map(|(collected, needed, _)| needed.saturating_sub(collected.len()))
                .sum();
            (batch_shortfall + prefetch_demand).saturating_sub(buffer.len())
        }};
    }
    // Issues one generation request (if none is in flight) sized for queued batches plus
    // max(refill-threshold gap, prefetch demand).
    macro_rules! maybe_generate {
        () => {
            if !gen_in_flight {
                let cfg = lock_cfg!();
                let batch_shortfall: usize = pending_batches
                    .iter()
                    .map(|(collected, needed, _)| {
                        needed.saturating_sub(collected.len() + buffer.len())
                    })
                    .sum();
                let buf_after = buffer.len().saturating_sub(batch_shortfall);
                let need_buffer = cfg
                    .refill_threshold()
                    .saturating_sub(buf_after)
                    .max(prefetch_demand);
                drop(cfg);
                let need = batch_shortfall + need_buffer;
                if need > 0 {
                    debug!("{log_prefix} requesting generation of {need} items");
                    gen_in_flight = true;
                    let _ = work_tx.send(need);
                }
            }
        };
    }

    loop {
        tokio::select! {
            cmd = cmd_rx.recv() => {
                let Some(cmd) = cmd else {
                    info!("{log_prefix} command channel closed, shutting down");
                    break;
                };
                match cmd {
                    Command::RequestN { n_elements, completion } => {
                        debug!("{log_prefix} batch request for {n_elements} items");
                        let cap = lock_cfg!(completion).capacity();
                        if outstanding_demand!() + n_elements > cap {
                            let _ = completion.send(Err(CorrelatedStreamError::RateLimitExceeded.into()));
                        } else if buffer.len() >= n_elements {
                            let items: Vec<_> = buffer.drain(..n_elements).collect();
                            let _ = completion.send(Ok(items));
                        } else {
                            // Take everything buffered now. This keeps the invariant that
                            // the buffer is empty while batches are queued.
                            let collected: Vec<_> = buffer.drain(..).collect();
                            pending_batches.push_back((collected, n_elements, completion));
                        }
                    }
                    Command::Prefetch { n_elements, completion } => {
                        debug!("{log_prefix} prefetch {n_elements} items");
                        if buffer.len() >= n_elements {
                            let _ = completion.send(Ok(()));
                        } else {
                            let cap = lock_cfg!(completion).capacity();
                            if outstanding_demand!() + n_elements > cap {
                                let _ = completion.send(Err(CorrelatedStreamError::RateLimitExceeded.into()));
                            } else {
                                let deficit = n_elements - buffer.len();
                                prefetch_demand += deficit;
                                prefetch_completions.push_back((deficit, completion));
                            }
                        }
                    }
                }
                maybe_generate!();
            }
            result = items_rx.recv() => {
                let Some(result) = result else {
                    info!("{log_prefix} generator channel closed, shutting down");
                    break;
                };
                match result {
                    Ok(items) => {
                        gen_in_flight = false;
                        let generated = items.len();
                        debug!("{log_prefix} received {generated} items (pending_batches: {}, buffer: {})",
                            pending_batches.iter().map(|(c, n, _)| n - c.len()).sum::<usize>(),
                            buffer.len());
                        let mut iter = items.into_iter();
                        // Queued batches are served first, in FIFO order.
                        while let Some((collected, needed, _)) = pending_batches.front_mut() {
                            let shortfall = *needed - collected.len();
                            collected.extend(iter.by_ref().take(shortfall));
                            if collected.len() < *needed {
                                break;
                            }
                            let (collected, _, tx) = pending_batches
                                .pop_front()
                                .expect("front_mut just returned Some");
                            let _ = tx.send(Ok(collected));
                        }
                        // Only items that actually reach the buffer count toward prefetch
                        // deficits. `maybe_generate!` sizes work as batch shortfall *plus*
                        // prefetch demand, so crediting batch-bound items here would
                        // resolve a prefetch while the buffer is still empty.
                        let buffered = iter.len();
                        buffer.extend(iter);
                        prefetch_demand = prefetch_demand.saturating_sub(buffered);
                        let mut credit = buffered;
                        while credit > 0 {
                            let Some((deficit, _)) = prefetch_completions.front_mut() else { break; };
                            let used = (*deficit).min(credit);
                            *deficit -= used;
                            credit -= used;
                            if *deficit > 0 {
                                break;
                            }
                            let (_, tx) = prefetch_completions
                                .pop_front()
                                .expect("front_mut just returned Some");
                            let _ = tx.send(Ok(()));
                        }
                        maybe_generate!();
                    }
                    Err(e) => {
                        error!("{log_prefix} generation error, shutting down: {e:?}");
                        // The shared drain below forwards this error to every waiter.
                        shutdown_err = Some(e);
                        break;
                    }
                }
            }
        }
    }

    let final_err: E = shutdown_err.unwrap_or_else(|| CorrelatedStreamError::StreamClosed.into());
    for (_, _, tx) in pending_batches.drain(..) {
        let _ = tx.send(Err(final_err.clone()));
    }
    for (_, tx) in prefetch_completions.drain(..) {
        let _ = tx.send(Err(final_err.clone()));
    }
}
impl<PB, E> CorrelatedStream<PB::Item> for BufferedStream<PB, E>
where
    PB: CorrelatedBatch,
    E: From<CorrelatedStreamError> + Clone + Send + std::fmt::Debug + 'static,
{
    type Error = E;

    /// Requests `n_elements` items from the dispatcher task.
    ///
    /// A request for zero items returns an empty `NextVec` without contacting the
    /// dispatcher.
    ///
    /// # Errors
    /// * any error reported by `max_request_size`;
    /// * `RequestTooLarge` when `n_elements` exceeds that maximum;
    /// * `SendError` when the command cannot be sent (the dispatcher task is gone).
    ///
    /// Rejections decided by the dispatcher (e.g. `RateLimitExceeded`) and generation
    /// failures arrive through the returned future instead.
    fn next_n(&self, n_elements: usize) -> Result<NextVec<PB::Item, E>, CorrelatedStreamError> {
        if n_elements == 0 {
            return Ok(NextVec::default());
        }

        let max_allowed = self.max_request_size()?;
        if n_elements > max_allowed {
            return Err(CorrelatedStreamError::RequestTooLarge {
                requested: n_elements,
                max_allowed,
            });
        }

        let (completion, receiver) = oneshot::channel();
        let request = Command::RequestN {
            n_elements,
            completion,
        };
        if let Err(closed) = self.command_sender.send(request) {
            return Err(CorrelatedStreamError::SendError(closed.to_string()));
        }

        Ok(NextVec {
            future: Next(receiver),
            size: n_elements,
        })
    }
}
// Async tests for `BufferedStream`, driven by a deterministic in-memory mock generator
// (`MockGen`) whose latency and failure point are configurable.
#[cfg(test)]
mod tests {
use std::{
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
time::Duration,
};
use rand::{rngs::StdRng, SeedableRng};
use typenum::U2;
use crate::{
algebra::elliptic_curve::{Curve25519Ristretto, ScalarField},
correlated_randomness::{
generator::CorrelationGenerator,
singlets::{Singlet, Singlets},
stream::{
buffered::{Buffer, BufferConfig, BufferedStream},
errors::CorrelatedStreamError,
CorrelatedStream,
},
},
random::Random,
utils::TryFuture,
};
type Fq = ScalarField<Curve25519Ristretto>;
type TestPB = Singlets<Fq, U2>;
type TestItem = Singlet<Fq>;
type TestErr = CorrelatedStreamError;
// Knobs for the mock generator.
#[derive(Clone)]
struct MockGenConfig {
// Sleep applied inside every `run_for` call before producing items.
delay: Duration,
// If set, a `run_for(n)` call fails with `StreamClosed` when
// `items_produced + n` would exceed this threshold.
fail_after: Option<usize>,
}
// Default: no delay, never fails.
impl Default for MockGenConfig {
fn default() -> Self {
Self {
delay: Duration::from_millis(0),
fail_after: None,
}
}
}
// Mock generator producing random `Singlet`s from a fixed-seed RNG.
// The shared counters let tests observe how much work the stream requested.
struct MockGen {
rng: StdRng,
cfg: MockGenConfig,
// Total items produced across all `run_for` calls.
items_produced: Arc<AtomicUsize>,
// Number of `run_for` calls (including failing ones).
batches: Arc<AtomicUsize>,
}
impl MockGen {
// Returns the generator plus clones of its `items_produced` and `batches` counters.
fn new(cfg: MockGenConfig) -> (Self, Arc<AtomicUsize>, Arc<AtomicUsize>) {
let items_produced = Arc::new(AtomicUsize::new(0));
let batches = Arc::new(AtomicUsize::new(0));
let gen = Self {
rng: StdRng::from_seed([0u8; 32]),
cfg,
items_produced: items_produced.clone(),
batches: batches.clone(),
};
(gen, items_produced, batches)
}
}
impl CorrelationGenerator<TestPB> for MockGen {
type Net = ();
type Error = TestErr;
// Not exercised: the buffered stream's generator task only calls `run_for`.
fn run(&mut self, _net: &mut ()) -> impl TryFuture<Ok = TestPB, Error = Self::Error> {
async move { Err(CorrelatedStreamError::StreamClosed) }
}
// Counts the call, optionally sleeps, optionally fails, then yields `n` random items.
fn run_for(
&mut self,
n: usize,
_net: &mut (),
) -> impl TryFuture<Ok = Vec<TestItem>, Error = Self::Error> {
async move {
self.batches.fetch_add(1, Ordering::SeqCst);
if self.cfg.delay > Duration::ZERO {
tokio::time::sleep(self.cfg.delay).await;
}
if let Some(threshold) = self.cfg.fail_after {
if self.items_produced.load(Ordering::SeqCst) + n > threshold {
return Err(CorrelatedStreamError::StreamClosed);
}
}
let items: Vec<TestItem> = (0..n)
.map(|_| {
Singlet::<Fq>::random_n::<Vec<_>>(&mut self.rng, 1)
.into_iter()
.next()
.unwrap()
})
.collect();
self.items_produced.fetch_add(n, Ordering::SeqCst);
Ok(items)
}
}
}
// Builds a stream over a fresh `MockGen`; returns it with the (produced, batches) counters.
fn make_stream(
cfg: MockGenConfig,
buf_cfg: BufferConfig,
) -> (
BufferedStream<TestPB, TestErr>,
Arc<AtomicUsize>,
Arc<AtomicUsize>,
) {
let (gen, produced, batches) = MockGen::new(cfg);
let stream = BufferedStream::<TestPB, TestErr>::new(gen, (), buf_cfg);
(stream, produced, batches)
}
// A simple request is fulfilled with exactly the requested number of items.
#[tokio::test]
async fn next_n_resolves_batch_future() {
let (stream, _, _) = make_stream(MockGenConfig::default(), BufferConfig::eager(16));
let fut = stream.next_n(7).expect("request accepted");
let items = fut.await.expect("batch resolves");
assert_eq!(items.len(), 7);
}
// Requests above the maximum request size are rejected synchronously.
#[tokio::test]
async fn request_too_large_rejected() {
// Assumes `eager_with(8, 4)` sets a max request size of 4 (matches `max_allowed: 4` below).
let (stream, _, _) = make_stream(
MockGenConfig::default(),
BufferConfig::eager_with(8, 4), );
match stream.next_n(5) {
Err(CorrelatedStreamError::RequestTooLarge {
requested: 5,
max_allowed: 4,
}) => {}
Ok(_) => panic!("must reject n>max"),
Err(e) => panic!("unexpected error: {e:?}"),
}
}
// With capacity 4, a first request still waiting on the slow generator occupies the
// whole capacity, so the dispatcher rejects a second request with `RateLimitExceeded`.
#[tokio::test]
async fn rate_limit_when_exceeding_capacity() {
let cfg = MockGenConfig {
delay: Duration::from_millis(200),
..Default::default()
};
let (stream, _, _) = make_stream(cfg, BufferConfig::eager(4));
let _f1 = stream.next_n(4).unwrap();
// Give the dispatcher time to queue the first request before sending the second.
tokio::time::sleep(Duration::from_millis(20)).await;
let f2 = stream.next_n(4).unwrap();
let results = futures::future::join_all(f2).await;
assert!(
results
.iter()
.all(|r| matches!(r, Err(CorrelatedStreamError::RateLimitExceeded))),
"expected all four to be rate-limited, got {results:?}"
);
}
// After a prefetch resolves, a request of the same size is served from the buffer
// without waiting for the 100 ms generator delay.
#[tokio::test]
async fn prefetch_completes_and_serves_subsequent_requests_quickly() {
let cfg = MockGenConfig {
delay: Duration::from_millis(100),
..Default::default()
};
let (stream, produced, _) = make_stream(cfg, BufferConfig::eager(32));
let handle = stream.prefetch_n(10);
handle.await.expect("prefetch completes");
assert!(produced.load(Ordering::SeqCst) >= 10);
let start = std::time::Instant::now();
let items = stream.next_n(10).unwrap().await.expect("served");
assert_eq!(items.len(), 10);
assert!(
start.elapsed() < Duration::from_millis(80),
"request should be served from prefetched buffer (took {:?})",
start.elapsed()
);
}
// Prefetches that are already covered, or only partly covered, by the buffer must
// still resolve. Presumably `lazy(16, 0)` means no proactive refill — confirm.
#[tokio::test]
async fn sequential_prefetches_all_resolve() {
let (stream, _, _) = make_stream(MockGenConfig::default(), BufferConfig::lazy(16, 0));
for i in 0..2 {
tokio::time::timeout(Duration::from_secs(1), stream.prefetch_n(4))
.await
.unwrap_or_else(|_| panic!("prefetch {i} timed out"))
.expect("prefetch completes");
}
tokio::time::timeout(Duration::from_secs(1), stream.prefetch_n(8))
.await
.expect("partially covered prefetch timed out")
.expect("prefetch completes");
}
// A failing generator delivers its error to every waiting consumer.
#[tokio::test]
async fn generator_error_propagates_to_pending_consumers() {
// `fail_after: Some(0)` makes the very first `run_for` call fail.
let cfg = MockGenConfig {
delay: Duration::from_millis(20),
fail_after: Some(0), };
let (stream, _, _) = make_stream(cfg, BufferConfig::eager(16));
let futs = stream.next_n(4).unwrap();
let results = futures::future::join_all(futs).await;
assert!(
results
.iter()
.all(|r| matches!(r, Err(CorrelatedStreamError::StreamClosed))),
"all consumers should receive the generator error"
);
}
// Two concurrent requests both resolve with the right sizes.
// NOTE(review): despite the name, item order is not checked here.
#[tokio::test]
async fn fifo_order_across_two_batches() {
let cfg = MockGenConfig {
delay: Duration::from_millis(40),
..Default::default()
};
let (stream, _, batches) = make_stream(cfg, BufferConfig::eager(32));
let f1 = stream.next_n(3).unwrap();
let f2 = stream.next_n(3).unwrap();
let (a, b) = tokio::join!(f1, f2);
let a = a.expect("first batch resolves");
let b = b.expect("second batch resolves");
assert_eq!(a.len(), 3);
assert_eq!(b.len(), 3);
assert!(batches.load(Ordering::SeqCst) >= 1);
}
// Capacity updates through the `Buffer` trait are visible through the same handle.
#[tokio::test]
async fn buffer_config_setters_are_visible() {
let (stream, _, _) = make_stream(MockGenConfig::default(), BufferConfig::eager(16));
assert_eq!(stream.capacity().unwrap(), 16);
stream.set_capacity(32).unwrap();
assert_eq!(stream.capacity().unwrap(), 32);
}
// NOTE(review): duplicates `request_too_large_rejected`.
#[tokio::test]
async fn next_n_rejects_too_large() {
let (stream, _, _) = make_stream(
MockGenConfig::default(),
BufferConfig::eager_with(8, 4), );
match stream.next_n(5) {
Err(CorrelatedStreamError::RequestTooLarge {
requested: 5,
max_allowed: 4,
}) => {}
Ok(_) => panic!("must reject n > max"),
Err(e) => panic!("unexpected error: {e:?}"),
}
}
// NOTE(review): same scenario as `rate_limit_when_exceeding_capacity`, awaiting the
// `NextVec` as a single future instead of via `join_all`.
#[tokio::test]
async fn next_n_rate_limit_when_exceeding_capacity() {
let cfg = MockGenConfig {
delay: Duration::from_millis(200),
..Default::default()
};
let (stream, _, _) = make_stream(cfg, BufferConfig::eager(4));
let _f1 = stream.next_n(4).unwrap();
tokio::time::sleep(Duration::from_millis(20)).await;
let f2 = stream.next_n(4).unwrap();
assert!(
matches!(f2.await, Err(CorrelatedStreamError::RateLimitExceeded)),
"second next_n should be rate-limited"
);
}
// With capacity 8, two in-flight requests of 4 together fit exactly and are both served.
#[tokio::test]
async fn next_n_admitted_when_shortfall_fits_capacity() {
let cfg = MockGenConfig {
delay: Duration::from_millis(50),
..Default::default()
};
let (stream, _, _) = make_stream(cfg, BufferConfig::lazy(8, 0));
let f1 = stream.next_n(4).unwrap();
tokio::time::sleep(Duration::from_millis(20)).await;
let f2 = stream.next_n(4).unwrap();
let (a, b) = tokio::join!(f1, f2);
assert_eq!(a.expect("first batch resolves").len(), 4);
assert_eq!(b.expect("second batch resolves").len(), 4);
}
// NOTE(review): same scenario as `generator_error_propagates_to_pending_consumers`,
// awaiting the `NextVec` as a single future.
#[tokio::test]
async fn next_n_error_propagates() {
let cfg = MockGenConfig {
delay: Duration::from_millis(20), fail_after: Some(0),
};
let (stream, _, _) = make_stream(cfg, BufferConfig::eager(16));
let result = stream.next_n(4).unwrap().await;
assert!(
matches!(result, Err(CorrelatedStreamError::StreamClosed)),
"expected generator error to propagate through BatchFuture, got {result:?}"
);
}
// After serving a request, the dispatcher tops the buffer back up toward the refill
// threshold. Presumably `lazy(16, 8)` sets a refill threshold of 8 — confirm.
#[tokio::test]
async fn refill_threshold_drives_proactive_generation() {
let (stream, produced, _) =
make_stream(MockGenConfig::default(), BufferConfig::lazy(16, 8));
let _ = stream.next_n(4).unwrap().await.expect("served");
tokio::time::sleep(Duration::from_millis(50)).await;
let total = produced.load(Ordering::SeqCst);
assert!(total >= 8, "expected >= 8 items generated, got {total}");
}
}