#![allow(missing_docs)]
use crate::error::{IoError, Result};
use std::collections::VecDeque;
use std::sync::{Arc, Condvar, Mutex};
use std::time::{Duration, Instant};
/// Mutex-guarded internals of a `BoundedBuffer`: the pending items, the
/// fixed capacity, and whether the buffer has been closed to new pushes.
struct BoundedState<T> {
    queue: VecDeque<T>,
    capacity: usize,
    closed: bool,
}

impl<T> BoundedState<T> {
    /// Builds an open, empty state able to hold `capacity` items without
    /// reallocating. Capacity clamping is the caller's responsibility
    /// (`BoundedBuffer::new` enforces a minimum of 1).
    fn new(capacity: usize) -> Self {
        let queue = VecDeque::with_capacity(capacity);
        BoundedState {
            queue,
            capacity,
            closed: false,
        }
    }
}
/// A blocking, bounded, thread-safe FIFO buffer.
///
/// Producers wait when the queue is at capacity and consumers wait when it is
/// empty. Closing the buffer wakes every waiter: subsequent pushes fail with
/// an error, while pops keep draining queued items and then report `None`.
pub struct BoundedBuffer<T: Send> {
    // Queue contents, capacity and closed flag, guarded by a single mutex.
    state: Mutex<BoundedState<T>>,
    // Signalled when space is freed (a pop) or the buffer is closed.
    not_full: Condvar,
    // Signalled when an item arrives (a push) or the buffer is closed.
    not_empty: Condvar,
}

impl<T: Send> BoundedBuffer<T> {
    /// Creates a buffer holding at most `capacity` items, wrapped in an
    /// `Arc` so it can be shared across threads.
    ///
    /// `capacity` is clamped to at least 1 so that a zero-capacity buffer
    /// cannot deadlock every producer.
    pub fn new(capacity: usize) -> Arc<Self> {
        Arc::new(Self {
            state: Mutex::new(BoundedState::new(capacity.max(1))),
            not_full: Condvar::new(),
            not_empty: Condvar::new(),
        })
    }

    /// Blocks until `item` has been enqueued.
    ///
    /// # Errors
    /// Fails if the buffer is (or becomes) closed while waiting, or if the
    /// internal mutex/condvar is poisoned. On failure `item` is dropped.
    pub fn push(&self, item: T) -> Result<()> {
        let mut guard = self
            .state
            .lock()
            .map_err(|_| IoError::Other("BoundedBuffer: mutex poisoned".to_string()))?;
        loop {
            if guard.closed {
                return Err(IoError::Other("BoundedBuffer: closed".to_string()));
            }
            if guard.queue.len() < guard.capacity {
                guard.queue.push_back(item);
                self.not_empty.notify_one();
                return Ok(());
            }
            // Full: release the lock and wait for a consumer to make room.
            guard = self
                .not_full
                .wait(guard)
                .map_err(|_| IoError::Other("BoundedBuffer: condvar poisoned".to_string()))?;
        }
    }

    /// Like [`push`](Self::push) but waits at most `timeout`.
    ///
    /// Returns `Ok(true)` when the item was enqueued, or `Ok(false)` when
    /// the deadline passed with the buffer still full (the item is dropped).
    pub fn push_timeout(&self, item: T, timeout: Duration) -> Result<bool> {
        let deadline = Instant::now() + timeout;
        let mut guard = self
            .state
            .lock()
            .map_err(|_| IoError::Other("BoundedBuffer: mutex poisoned".to_string()))?;
        loop {
            if guard.closed {
                return Err(IoError::Other("BoundedBuffer: closed".to_string()));
            }
            if guard.queue.len() < guard.capacity {
                guard.queue.push_back(item);
                self.not_empty.notify_one();
                return Ok(true);
            }
            let remaining = deadline.saturating_duration_since(Instant::now());
            if remaining.is_zero() {
                return Ok(false);
            }
            // The WaitTimeoutResult is deliberately ignored: even when the
            // wait reports a timeout, space may have been freed right at the
            // deadline, so loop back and re-check the predicate before the
            // `remaining.is_zero()` test above decides to give up.
            let (next, _wait_result) = self
                .not_full
                .wait_timeout(guard, remaining)
                .map_err(|_| IoError::Other("BoundedBuffer: condvar poisoned".to_string()))?;
            guard = next;
        }
    }

    /// Blocks until an item is available, returning `Ok(None)` once the
    /// buffer is closed and fully drained.
    pub fn pop(&self) -> Result<Option<T>> {
        let mut guard = self
            .state
            .lock()
            .map_err(|_| IoError::Other("BoundedBuffer: mutex poisoned".to_string()))?;
        loop {
            if let Some(item) = guard.queue.pop_front() {
                self.not_full.notify_one();
                return Ok(Some(item));
            }
            // `closed` is checked only after the queue is found empty, so
            // queued items are still drained after `close`.
            if guard.closed {
                return Ok(None);
            }
            guard = self
                .not_empty
                .wait(guard)
                .map_err(|_| IoError::Other("BoundedBuffer: condvar poisoned".to_string()))?;
        }
    }

    /// Like [`pop`](Self::pop) but waits at most `timeout`. `Ok(None)` means
    /// either "closed and drained" or "timed out while empty".
    pub fn pop_timeout(&self, timeout: Duration) -> Result<Option<T>> {
        let deadline = Instant::now() + timeout;
        let mut guard = self
            .state
            .lock()
            .map_err(|_| IoError::Other("BoundedBuffer: mutex poisoned".to_string()))?;
        loop {
            if let Some(item) = guard.queue.pop_front() {
                self.not_full.notify_one();
                return Ok(Some(item));
            }
            if guard.closed {
                return Ok(None);
            }
            let remaining = deadline.saturating_duration_since(Instant::now());
            if remaining.is_zero() {
                return Ok(None);
            }
            // As in `push_timeout`: loop back and re-check the queue even
            // after a reported timeout, so an item delivered right at the
            // deadline is returned instead of spuriously reporting empty.
            let (next, _wait_result) = self
                .not_empty
                .wait_timeout(guard, remaining)
                .map_err(|_| IoError::Other("BoundedBuffer: condvar poisoned".to_string()))?;
            guard = next;
        }
    }

    /// Non-blocking pop: returns `Ok(None)` immediately when empty.
    pub fn try_pop(&self) -> Result<Option<T>> {
        let mut guard = self
            .state
            .lock()
            .map_err(|_| IoError::Other("BoundedBuffer: mutex poisoned".to_string()))?;
        if let Some(item) = guard.queue.pop_front() {
            self.not_full.notify_one();
            Ok(Some(item))
        } else {
            Ok(None)
        }
    }

    /// Number of items currently queued (0 if the mutex is poisoned).
    pub fn len(&self) -> usize {
        self.state.lock().map(|g| g.queue.len()).unwrap_or(0)
    }

    /// True when no items are queued.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Marks the buffer closed and wakes every blocked producer and
    /// consumer. Items already queued remain available to `pop`/`try_pop`.
    pub fn close(&self) -> Result<()> {
        let mut guard = self
            .state
            .lock()
            .map_err(|_| IoError::Other("BoundedBuffer: mutex poisoned".to_string()))?;
        guard.closed = true;
        self.not_empty.notify_all();
        self.not_full.notify_all();
        Ok(())
    }

    /// True once `close` has been called (conservatively also true if the
    /// mutex is poisoned).
    pub fn is_closed(&self) -> bool {
        self.state.lock().map(|g| g.closed).unwrap_or(true)
    }

    /// The (clamped) maximum number of items; 0 if the mutex is poisoned.
    pub fn capacity(&self) -> usize {
        self.state.lock().map(|g| g.capacity).unwrap_or(0)
    }
}
/// Token-bucket rate limiter: permits are granted at `rate` per second, with
/// a burst capacity of one second's worth of tokens.
pub struct ThrottleTransform {
    // Permits granted per second; also the cap on accumulated tokens.
    rate: f64,
    // Current token balance, guarded by a mutex for cross-thread use.
    state: Mutex<ThrottleState>,
    // Human-readable name (builder-set via `with_name`; default "throttle").
    label: String,
}
/// Mutable token-bucket state protected by `ThrottleTransform::state`.
struct ThrottleState {
    // Fractional tokens currently available (capped at `rate`).
    tokens: f64,
    // When tokens were last replenished; elapsed time since then earns tokens.
    last_refill: Instant,
}
impl ThrottleTransform {
    /// Creates a limiter granting `rate` permits per second; the bucket
    /// starts full, so one second's worth of permits is available at once.
    ///
    /// # Panics
    /// Panics if `rate` is not strictly positive.
    pub fn new(rate: f64) -> Self {
        assert!(rate > 0.0, "ThrottleTransform: rate must be > 0");
        Self {
            rate,
            state: Mutex::new(ThrottleState {
                tokens: rate,
                last_refill: Instant::now(),
            }),
            label: "throttle".to_string(),
        }
    }

    /// Sets a human-readable name (builder style).
    pub fn with_name(mut self, name: impl Into<String>) -> Self {
        self.label = name.into();
        self
    }

    /// Returns the name set via [`with_name`](Self::with_name).
    ///
    /// Added for parity with `BatchCollector::name`; it also makes `label`
    /// readable, so the field is no longer write-only dead state.
    pub fn name(&self) -> &str {
        &self.label
    }

    /// Blocks until one token is available, then consumes it.
    ///
    /// The lock is released while sleeping so concurrent callers are not
    /// serialized behind a sleeper.
    ///
    /// # Errors
    /// Fails only if the internal mutex is poisoned.
    pub fn acquire(&self) -> Result<()> {
        let mut state = self
            .state
            .lock()
            .map_err(|_| IoError::Other("ThrottleTransform: mutex poisoned".to_string()))?;
        // Refill tokens earned since the last refill, capped at one second's
        // burst (`rate`).
        let now = Instant::now();
        let elapsed = now.duration_since(state.last_refill).as_secs_f64();
        state.tokens = (state.tokens + elapsed * self.rate).min(self.rate);
        state.last_refill = now;
        if state.tokens >= 1.0 {
            state.tokens -= 1.0;
            return Ok(());
        }
        // Not enough tokens: sleep (without holding the lock) for the time
        // needed to earn the deficit, then refill once more and take a token.
        let deficit = 1.0 - state.tokens;
        let sleep_secs = deficit / self.rate;
        let sleep_dur = Duration::from_secs_f64(sleep_secs);
        drop(state);
        std::thread::sleep(sleep_dur);
        let mut state = self
            .state
            .lock()
            .map_err(|_| IoError::Other("ThrottleTransform: mutex poisoned".to_string()))?;
        let now2 = Instant::now();
        let elapsed2 = now2.duration_since(state.last_refill).as_secs_f64();
        state.tokens = (state.tokens + elapsed2 * self.rate).min(self.rate);
        state.last_refill = now2;
        // Clamp at zero: under contention another caller may have taken the
        // token we slept for; proceeding slightly early keeps this a
        // best-effort limiter rather than sleeping again.
        state.tokens = (state.tokens - 1.0).max(0.0);
        Ok(())
    }

    /// Passes each item through [`acquire`](Self::acquire) and returns the
    /// items unchanged, so the batch leaves at no more than `rate` items per
    /// second.
    pub fn throttle_batch<T>(&self, items: Vec<T>) -> Result<Vec<T>> {
        let mut out = Vec::with_capacity(items.len());
        for item in items {
            self.acquire()?;
            out.push(item);
        }
        Ok(out)
    }
}
/// Accumulates pushed items and hands them out in batches of up to
/// `batch_size`, flushing a partial batch when `timeout` elapses.
pub struct BatchCollector<T: Send> {
    // Staging area shared between pushers and the batch collector.
    buffer: Arc<BoundedBuffer<T>>,
    // Maximum number of items returned by a single `collect_batch` call.
    batch_size: usize,
    // Upper bound on how long `collect_batch` waits to fill a batch.
    timeout: Duration,
    // Human-readable name (builder-set via `with_name`).
    label: String,
}
impl<T: Send + 'static> BatchCollector<T> {
    /// Builds a collector emitting batches of up to `batch_size` items and
    /// waiting at most `timeout` per batch. The staging buffer always holds
    /// at least `batch_size` items, regardless of `internal_capacity`.
    pub fn new(batch_size: usize, timeout: Duration, internal_capacity: usize) -> Self {
        let staging = internal_capacity.max(batch_size);
        Self {
            buffer: BoundedBuffer::new(staging),
            batch_size,
            timeout,
            label: String::from("batch_collector"),
        }
    }

    /// Builder-style rename of the collector.
    pub fn with_name(mut self, name: impl Into<String>) -> Self {
        self.label = name.into();
        self
    }

    /// Blocks until `item` is staged; errors if the collector is closed.
    pub fn push(&self, item: T) -> Result<()> {
        self.buffer.push(item)
    }

    /// Stages `item`, waiting at most `timeout`; `Ok(false)` on timeout.
    pub fn push_timeout(&self, item: T, timeout: Duration) -> Result<bool> {
        self.buffer.push_timeout(item, timeout)
    }

    /// Gathers the next batch: up to `batch_size` items, or fewer if the
    /// per-batch timeout elapses or the collector is closed and drained.
    pub fn collect_batch(&self) -> Result<Vec<T>> {
        let deadline = Instant::now() + self.timeout;
        let mut batch = Vec::with_capacity(self.batch_size);
        loop {
            if batch.len() == self.batch_size {
                break;
            }
            let left = deadline.saturating_duration_since(Instant::now());
            if left.is_zero() {
                break;
            }
            if let Some(item) = self.buffer.pop_timeout(left)? {
                batch.push(item);
            } else {
                // Timed out while empty, or closed and fully drained.
                break;
            }
        }
        Ok(batch)
    }

    /// Closes the staging buffer; already-queued items stay collectable.
    pub fn close(&self) -> Result<()> {
        self.buffer.close()
    }

    /// Items currently staged but not yet collected.
    pub fn pending(&self) -> usize {
        self.buffer.len()
    }

    /// Maximum items per batch.
    pub fn batch_size(&self) -> usize {
        self.batch_size
    }

    /// Per-batch collection timeout.
    pub fn timeout(&self) -> Duration {
        self.timeout
    }

    /// The collector's configured name.
    pub fn name(&self) -> &str {
        &self.label
    }
}
/// Counters describing one `FlowController` run.
#[derive(Debug, Clone, Default)]
pub struct FlowStats {
    /// Items successfully pushed into the staging buffer.
    pub produced: usize,
    /// Items delivered to the consumer.
    pub consumed: usize,
    /// Times the producer found the buffer full and could not push at once.
    pub producer_stalls: usize,
    /// Wall-clock duration of the whole run.
    pub elapsed: Duration,
}
/// Drives a producer/consumer pair through a shared `BoundedBuffer` on a
/// single thread; the buffer may also be handed to other threads via
/// `FlowController::buffer`.
pub struct FlowController<T: Send + Clone + 'static> {
    // Staging buffer between the producer and consumer closures.
    buffer: Arc<BoundedBuffer<T>>,
}
impl<T: Send + Clone + 'static> FlowController<T> {
    /// Creates a controller whose staging buffer holds `capacity` items
    /// (clamped to at least 1 by `BoundedBuffer::new`).
    pub fn new(capacity: usize) -> Self {
        Self {
            buffer: BoundedBuffer::new(capacity),
        }
    }

    /// Shares the staging buffer, e.g. with a consumer on another thread.
    pub fn buffer(&self) -> Arc<BoundedBuffer<T>> {
        Arc::clone(&self.buffer)
    }

    /// Runs `producer` to exhaustion, then closes the buffer and drains the
    /// remainder into `consumer`, all on the calling thread.
    ///
    /// Because no other thread drains the buffer here, sleeping on a full
    /// buffer could never succeed; instead, a full buffer is relieved by
    /// consuming one queued item in place (counted once per stalled item in
    /// `producer_stalls`), so the run completes for any `capacity >= 1`
    /// regardless of how many items the producer yields.
    ///
    /// # Errors
    /// Propagates buffer errors (e.g. poisoned lock) and consumer errors.
    pub fn run<P, C>(&self, mut producer: P, mut consumer: C) -> Result<FlowStats>
    where
        P: FnMut() -> Option<T>,
        C: FnMut(T) -> Result<()>,
    {
        let start = Instant::now();
        let mut stats = FlowStats::default();
        while let Some(item) = producer() {
            let mut stalled = false;
            // Non-blocking push (zero timeout); on a full buffer, free one
            // slot by consuming a queued item, then retry.
            while !self.buffer.push_timeout(item.clone(), Duration::ZERO)? {
                if !stalled {
                    stats.producer_stalls += 1;
                    stalled = true;
                }
                match self.buffer.try_pop()? {
                    Some(queued) => {
                        consumer(queued)?;
                        stats.consumed += 1;
                    }
                    None => {
                        // Full yet nothing poppable should be impossible for
                        // capacity >= 1; fail loudly rather than spin.
                        return Err(IoError::Other(
                            "FlowController::run: buffer full but nothing to drain".to_string(),
                        ));
                    }
                }
            }
            stats.produced += 1;
        }
        self.buffer.close()?;
        // Drain whatever is still queued; `pop` yields None once the buffer
        // is closed and empty.
        while let Some(item) = self.buffer.pop()? {
            consumer(item)?;
            stats.consumed += 1;
        }
        stats.elapsed = start.elapsed();
        Ok(stats)
    }

    /// Like [`run`](Self::run) but aborts with an error once `deadline`
    /// (measured from the call) has elapsed, in either phase.
    pub fn run_with_timeout<P, C>(
        &self,
        deadline: Duration,
        mut producer: P,
        mut consumer: C,
    ) -> Result<FlowStats>
    where
        P: FnMut() -> Option<T>,
        C: FnMut(T) -> Result<()>,
    {
        let start = Instant::now();
        let wall_deadline = start + deadline;
        let mut stats = FlowStats::default();
        while let Some(item) = producer() {
            let mut stalled = false;
            loop {
                if wall_deadline
                    .saturating_duration_since(Instant::now())
                    .is_zero()
                {
                    // Best-effort close so any sharers of the buffer wake up.
                    let _ = self.buffer.close();
                    return Err(IoError::Other(
                        "FlowController::run_with_timeout: deadline exceeded during produce phase"
                            .to_string(),
                    ));
                }
                if self.buffer.push_timeout(item.clone(), Duration::ZERO)? {
                    break;
                }
                if !stalled {
                    stats.producer_stalls += 1;
                    stalled = true;
                }
                // Same single-threaded relief strategy as `run`: consume one
                // queued item to make room instead of sleeping.
                match self.buffer.try_pop()? {
                    Some(queued) => {
                        consumer(queued)?;
                        stats.consumed += 1;
                    }
                    None => {
                        let _ = self.buffer.close();
                        return Err(IoError::Other(
                            "FlowController::run_with_timeout: buffer full but nothing to drain"
                                .to_string(),
                        ));
                    }
                }
            }
            stats.produced += 1;
        }
        self.buffer.close()?;
        loop {
            let remaining = wall_deadline.saturating_duration_since(Instant::now());
            if remaining.is_zero() {
                return Err(IoError::Other(
                    "FlowController::run_with_timeout: deadline exceeded during consume phase"
                        .to_string(),
                ));
            }
            // Bound each wait so deadline expiry is noticed within a second.
            let pop_timeout = remaining.min(Duration::from_secs(1));
            match self.buffer.pop_timeout(pop_timeout)? {
                Some(item) => {
                    consumer(item)?;
                    stats.consumed += 1;
                }
                None => break,
            }
        }
        stats.elapsed = start.elapsed();
        Ok(stats)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc as StdArc;
    use std::thread;

    // FIFO round-trip, length accounting, and end-of-stream after close.
    #[test]
    fn test_bounded_buffer_basic() {
        let buf = BoundedBuffer::<i32>::new(4);
        buf.push(1).unwrap();
        buf.push(2).unwrap();
        assert_eq!(buf.len(), 2);
        assert_eq!(buf.pop().unwrap(), Some(1));
        assert_eq!(buf.pop().unwrap(), Some(2));
        buf.close().unwrap();
        // Closed and drained: pop signals end-of-stream with None.
        assert_eq!(buf.pop().unwrap(), None);
    }

    // A timed push against a full buffer must report false, not block.
    #[test]
    fn test_bounded_buffer_capacity_blocking() {
        let buf = BoundedBuffer::<i32>::new(2);
        buf.push(10).unwrap();
        buf.push(20).unwrap();
        let result = buf.push_timeout(30, Duration::from_millis(20)).unwrap();
        assert!(!result, "should time out when buffer is full");
    }

    // Cross-thread handoff: every produced item is consumed exactly once,
    // and close() lets the consumer loop terminate.
    #[test]
    fn test_bounded_buffer_producer_consumer() {
        let buf = BoundedBuffer::<i32>::new(8);
        let buf_prod = StdArc::clone(&buf);
        let buf_cons = StdArc::clone(&buf);
        let produced = StdArc::new(AtomicUsize::new(0));
        let consumed = StdArc::new(AtomicUsize::new(0));
        let prod_count = StdArc::clone(&produced);
        let cons_count = StdArc::clone(&consumed);
        let producer = thread::spawn(move || {
            for i in 0..50i32 {
                buf_prod.push(i).unwrap();
                prod_count.fetch_add(1, Ordering::Relaxed);
            }
            buf_prod.close().unwrap();
        });
        let consumer = thread::spawn(move || {
            // pop() yields Ok(None) once the buffer is closed and empty.
            while let Ok(Some(_)) = buf_cons.pop() {
                cons_count.fetch_add(1, Ordering::Relaxed);
            }
        });
        producer.join().unwrap();
        consumer.join().unwrap();
        assert_eq!(produced.load(Ordering::Relaxed), 50);
        assert_eq!(consumed.load(Ordering::Relaxed), 50);
    }

    // Non-blocking pop on both empty and non-empty buffers.
    #[test]
    fn test_bounded_buffer_try_pop() {
        let buf = BoundedBuffer::<i32>::new(4);
        assert_eq!(buf.try_pop().unwrap(), None);
        buf.push(42).unwrap();
        assert_eq!(buf.try_pop().unwrap(), Some(42));
        assert_eq!(buf.try_pop().unwrap(), None);
    }

    // At a high rate the batch passes through unchanged and in order.
    #[test]
    fn test_throttle_high_rate() {
        let throttle = ThrottleTransform::new(1_000.0);
        let items = vec![1i32, 2, 3, 4, 5];
        let result = throttle.throttle_batch(items).unwrap();
        assert_eq!(result, vec![1, 2, 3, 4, 5]);
    }

    // acquire() never errors in normal (unpoisoned) use.
    #[test]
    fn test_throttle_acquire_does_not_error() {
        let throttle = ThrottleTransform::new(100.0);
        for _ in 0..5 {
            throttle.acquire().unwrap();
        }
    }

    // Exactly batch_size items available => a full batch is returned.
    #[test]
    fn test_batch_collector_collects_full_batch() {
        let collector = BatchCollector::<i32>::new(4, Duration::from_millis(200), 16);
        for i in 0..4 {
            collector.push(i).unwrap();
        }
        let batch = collector.collect_batch().unwrap();
        assert_eq!(batch.len(), 4);
    }

    // Fewer items than batch_size: the timeout flushes a partial batch.
    #[test]
    fn test_batch_collector_timeout_flush() {
        let collector = BatchCollector::<i32>::new(10, Duration::from_millis(30), 20);
        collector.push(1).unwrap();
        collector.push(2).unwrap();
        let batch = collector.collect_batch().unwrap();
        assert_eq!(batch.len(), 2);
    }

    // After close(), queued items still drain; later collects come back empty.
    #[test]
    fn test_batch_collector_close_drains() {
        let collector = BatchCollector::<i32>::new(5, Duration::from_millis(200), 20);
        collector.push(1).unwrap();
        collector.push(2).unwrap();
        collector.close().unwrap();
        let batch = collector.collect_batch().unwrap();
        let _rest = collector.collect_batch().unwrap();
        let empty = collector.collect_batch().unwrap();
        assert_eq!(empty.len(), 0);
    }

    // Single-threaded run: all 10 items flow producer -> buffer -> consumer.
    #[test]
    fn test_flow_controller_basic() {
        let fc = FlowController::<i32>::new(16);
        let mut counter = 0i32;
        let mut results = Vec::new();
        let t0 = Instant::now();
        let stats = fc
            .run(
                || {
                    if counter < 10 {
                        counter += 1;
                        Some(counter)
                    } else {
                        None
                    }
                },
                |item| {
                    results.push(item);
                    Ok(())
                },
            )
            .expect("FlowController::run failed");
        let elapsed = t0.elapsed();
        // Guard against a hang: the run should finish almost instantly.
        assert!(
            elapsed < Duration::from_secs(5),
            "test_flow_controller_basic took {elapsed:?}; expected < 5 s (possible deadlock)"
        );
        assert_eq!(stats.produced, 10);
        assert_eq!(stats.consumed, 10);
        assert_eq!(results.len(), 10);
    }

    // Same flow via the deadline-aware variant; 5 s is ample headroom.
    #[test]
    fn test_flow_controller_with_timeout() {
        let fc = FlowController::<i32>::new(16);
        let mut counter = 0i32;
        let mut results = Vec::new();
        let stats = fc
            .run_with_timeout(
                Duration::from_secs(5),
                || {
                    if counter < 10 {
                        counter += 1;
                        Some(counter)
                    } else {
                        None
                    }
                },
                |item| {
                    results.push(item);
                    Ok(())
                },
            )
            .expect("run_with_timeout failed");
        assert_eq!(stats.produced, 10);
        assert_eq!(stats.consumed, 10);
        assert_eq!(results.len(), 10);
        assert!(
            stats.elapsed < Duration::from_secs(5),
            "run_with_timeout elapsed {:?}; expected < 5 s",
            stats.elapsed
        );
    }

    // Tiny capacity + 1 ms deadline: either the deadline error fires or the
    // run finishes quickly — both outcomes are accepted.
    #[test]
    fn test_flow_controller_timeout_triggers() {
        let fc = FlowController::<i32>::new(1);
        let mut counter = 0i32;
        let result = fc.run_with_timeout(
            Duration::from_millis(1),
            || {
                if counter < 3 {
                    counter += 1;
                    Some(counter)
                } else {
                    None
                }
            },
            |_item| Ok(()),
        );
        match result {
            Err(_) => { }
            Ok(stats) => {
                assert!(
                    stats.elapsed < Duration::from_secs(5),
                    "run completed but took too long: {:?}",
                    stats.elapsed
                );
            }
        }
    }
}