use std::num::NonZeroUsize;
use std::sync::Arc;
use std::time::Duration;
use futures::stream::{self, StreamExt, TryStreamExt};
use snapdir_core::store::StoreError;
use tokio::sync::Mutex;
/// Upper bound applied to the auto-detected concurrency in [`TransferConfig::default`].
const DEFAULT_CONCURRENCY_CAP: usize = 16;

/// Tuning knobs for bulk transfers.
#[derive(Debug, Clone)]
pub struct TransferConfig {
    /// Maximum number of operations allowed in flight at once (never zero by construction).
    pub concurrency: NonZeroUsize,
    /// Optional throughput ceiling in bytes per second; `None` means no limit is configured.
    pub max_bytes_per_sec: Option<u64>,
}

impl TransferConfig {
    /// Builds a config from raw values, silently promoting a `concurrency` of `0` to `1`.
    #[must_use]
    pub fn new(concurrency: usize, max_bytes_per_sec: Option<u64>) -> Self {
        // `NonZeroUsize::new` only fails for 0, which is exactly the case we map to 1.
        let concurrency = NonZeroUsize::new(concurrency).unwrap_or(NonZeroUsize::MIN);
        Self {
            concurrency,
            max_bytes_per_sec,
        }
    }
}

impl Default for TransferConfig {
    /// Uses the host's available parallelism (or 1 if it cannot be queried), capped at
    /// `DEFAULT_CONCURRENCY_CAP`, with no throughput ceiling.
    fn default() -> Self {
        let detected = match std::thread::available_parallelism() {
            // Already >= 1 because it is a `NonZeroUsize`; only the upper cap is needed.
            Ok(parallelism) => parallelism.get().min(DEFAULT_CONCURRENCY_CAP),
            Err(_) => 1,
        };
        Self {
            concurrency: NonZeroUsize::new(detected).unwrap_or(NonZeroUsize::MIN),
            max_bytes_per_sec: None,
        }
    }
}
/// Mutable token-bucket state, guarded by the limiter's mutex.
#[derive(Debug)]
struct Bucket {
    /// Tokens (bytes) currently available; kept at or below `Inner::capacity`.
    tokens: f64,
    /// When `tokens` was last topped up.
    last_refill: tokio::time::Instant,
}

/// Configuration shared by every clone of a [`RateLimiter`].
#[derive(Debug)]
struct Inner {
    /// Refill rate in tokens per second; `0.0` for an unlimited limiter.
    rate: f64,
    /// Largest number of tokens the bucket can hold (set equal to `rate`, i.e. one
    /// second's worth); `0.0` for an unlimited limiter.
    capacity: f64,
    /// `None` means unlimited: `acquire` returns immediately.
    bucket: Option<Mutex<Bucket>>,
}

/// Token-bucket throughput limiter.
///
/// Clones share one bucket through an `Arc`, so the limit applies to the combined
/// traffic of all clones.
#[derive(Debug, Clone)]
pub struct RateLimiter {
    inner: Arc<Inner>,
}

impl RateLimiter {
    /// Creates a limiter that refills `max_bytes_per_sec` tokens per second. Its burst
    /// capacity is one second's worth of tokens, and the bucket starts full.
    ///
    /// `None` or `Some(0)` produce an unlimited limiter.
    #[must_use]
    pub fn new(max_bytes_per_sec: Option<u64>) -> Self {
        let inner = match max_bytes_per_sec {
            Some(rate) if rate > 0 => {
                // Rounding only occurs for rates above 2^53, far beyond realistic values.
                #[allow(clippy::cast_precision_loss)]
                let rate = rate as f64;
                Inner {
                    rate,
                    capacity: rate,
                    bucket: Some(Mutex::new(Bucket {
                        tokens: rate,
                        last_refill: tokio::time::Instant::now(),
                    })),
                }
            }
            _ => Inner {
                rate: 0.0,
                capacity: 0.0,
                bucket: None,
            },
        };
        Self {
            inner: Arc::new(inner),
        }
    }

    /// Waits until `n` tokens have been granted, sleeping as needed to respect the rate.
    ///
    /// Returns immediately for an unlimited limiter or when `n == 0`. Requests larger than
    /// the bucket's capacity are granted in capacity-sized slices. They are therefore
    /// throttled at the configured rate instead of waiting forever, because the bucket can
    /// never hold more than `capacity` tokens at once.
    ///
    /// If the returned future is dropped partway through, slices already granted remain
    /// consumed.
    pub async fn acquire(&self, n: u64) {
        let Some(bucket) = self.inner.bucket.as_ref() else {
            return;
        };
        // `capacity` was converted from a `u64` rate >= 1, so it is integral and this cast
        // is exact (saturating at `u64::MAX`). Each slice converted back to `f64` is then
        // <= `capacity`, so it is always satisfiable. Integer bookkeeping for `remaining`
        // also guarantees the loop terminates even for huge `n`, where `f64` subtraction
        // could stall. `max(1)` is a defensive floor ensuring progress.
        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
        let slice_limit = (self.inner.capacity as u64).max(1);
        let mut remaining = n;
        while remaining > 0 {
            let slice = remaining.min(slice_limit);
            #[allow(clippy::cast_precision_loss)]
            let need = slice as f64;
            self.take(bucket, need).await;
            remaining -= slice;
        }
    }

    /// Waits until at least `need` tokens are in `bucket`, then removes them.
    ///
    /// `need` must not exceed `self.inner.capacity`, or this would never return.
    async fn take(&self, bucket: &Mutex<Bucket>, need: f64) {
        loop {
            let wait_secs = {
                let mut state = bucket.lock().await;
                let now = tokio::time::Instant::now();
                let elapsed = now.duration_since(state.last_refill).as_secs_f64();
                state.tokens =
                    (state.tokens + elapsed * self.inner.rate).min(self.inner.capacity);
                state.last_refill = now;
                if state.tokens >= need {
                    state.tokens -= need;
                    return;
                }
                // Time until the deficit refills. The lock is released before sleeping,
                // so other callers can proceed in the meantime.
                (need - state.tokens) / self.inner.rate
            };
            tokio::time::sleep(Duration::from_secs_f64(wait_secs)).await;
        }
    }
}
/// Runs `op` on every item of `items`, keeping at most `concurrency` futures in flight.
///
/// Results are collected in **completion order**, not input order, because this uses
/// `buffer_unordered`. Callers that need to pair results with inputs should have `op`
/// return the item (or a key for it) alongside its result.
///
/// # Errors
///
/// Returns the first `Err` produced by an `op` future, in completion order. At that
/// point the remaining in-flight futures are dropped and no further items are started.
pub async fn run_concurrent<I, T, F, Fut>(
    items: I,
    concurrency: NonZeroUsize,
    op: F,
) -> Result<Vec<T>, StoreError>
where
    I: IntoIterator,
    // `FnMut` is all `StreamExt::map` requires. Accepting it (rather than only `Fn`) lets
    // callers pass closures that update local state as each operation is created.
    F: FnMut(I::Item) -> Fut,
    Fut: std::future::Future<Output = Result<T, StoreError>>,
{
    stream::iter(items)
        .map(op)
        .buffer_unordered(concurrency.get())
        .try_collect()
        .await
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Single-threaded runtime with the timer driver enabled; all tests run on one.
    fn runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_time()
            .build()
            .expect("build tokio runtime")
    }

    #[test]
    fn transfer_config_default_caps_concurrency() {
        let defaults = TransferConfig::default();
        let detected = defaults.concurrency.get();
        assert!(detected >= 1, "concurrency must be >= 1");
        assert!(
            detected <= DEFAULT_CONCURRENCY_CAP,
            "default concurrency must be capped at {DEFAULT_CONCURRENCY_CAP}, got {}",
            detected
        );
        assert_eq!(defaults.max_bytes_per_sec, None);

        assert_eq!(TransferConfig::new(0, None).concurrency.get(), 1);
        let explicit = TransferConfig::new(7, Some(99));
        assert_eq!(explicit.concurrency.get(), 7);
        assert_eq!(explicit.max_bytes_per_sec, Some(99));
    }

    /// Pushes `items` short sleeping tasks through `run_concurrent` with the given limit
    /// and reports the highest number observed running simultaneously.
    fn max_in_flight_for(concurrency: usize, items: usize) -> usize {
        let active = Arc::new(AtomicUsize::new(0));
        let peak = Arc::new(AtomicUsize::new(0));
        let limit = NonZeroUsize::new(concurrency).unwrap();
        let rt = runtime();
        let outcome = rt.block_on({
            let active = Arc::clone(&active);
            let peak = Arc::clone(&peak);
            async move {
                run_concurrent(0..items, limit, move |_| {
                    let active = Arc::clone(&active);
                    let peak = Arc::clone(&peak);
                    async move {
                        let now_running = active.fetch_add(1, Ordering::SeqCst) + 1;
                        peak.fetch_max(now_running, Ordering::SeqCst);
                        tokio::time::sleep(Duration::from_millis(20)).await;
                        active.fetch_sub(1, Ordering::SeqCst);
                        Ok::<_, StoreError>(())
                    }
                })
                .await
            }
        });
        assert!(outcome.is_ok());
        peak.load(Ordering::SeqCst)
    }

    #[test]
    fn transfer_config_run_concurrent_max_in_flight() {
        // (concurrency limit, item count, expected peak in flight)
        for (limit, items, expected) in [(4, 12, 4), (1, 5, 1), (8, 3, 3)] {
            assert_eq!(max_in_flight_for(limit, items), expected);
        }
    }

    #[test]
    fn transfer_config_run_concurrent_propagates_error() {
        let rt = runtime();
        let outcome: Result<Vec<()>, StoreError> = rt.block_on(run_concurrent(
            0..10,
            NonZeroUsize::new(3).unwrap(),
            |item| async move {
                if item != 5 {
                    tokio::time::sleep(Duration::from_millis(5)).await;
                    Ok(())
                } else {
                    Err(StoreError::Backend {
                        message: "boom".to_owned(),
                        source: None,
                    })
                }
            },
        ));
        let err = outcome.expect_err("must surface the failing op's error");
        let is_boom =
            matches!(err, StoreError::Backend { ref message, .. } if message == "boom");
        assert!(is_boom, "unexpected error: {err:?}");
    }

    #[test]
    fn transfer_config_rate_limiter() {
        let rt = runtime();
        rt.block_on(async {
            let unlimited = RateLimiter::new(None);
            let start = tokio::time::Instant::now();
            unlimited.acquire(1_000_000).await;
            assert!(
                start.elapsed() < Duration::from_millis(200),
                "unlimited acquire should not block"
            );

            // At 1000 tokens/s the first acquire drains the initially full bucket, so the
            // second has to wait for roughly a full second of refill.
            let limiter = RateLimiter::new(Some(1000));
            let start = tokio::time::Instant::now();
            for _ in 0..2 {
                limiter.acquire(1000).await;
            }
            let elapsed = start.elapsed();
            assert!(
                elapsed >= Duration::from_millis(900),
                "throttled acquire should take ~1s, took {elapsed:?}"
            );
        });
    }
}