use std::num::NonZeroUsize;
use std::sync::Arc;
use std::time::Duration;
use futures::stream::{self, StreamExt, TryStreamExt};
use snapdir_core::store::StoreError;
use tokio::sync::Mutex;
use crate::adaptive::{AdaptiveGate, OpResult};
/// Upper bound applied to the auto-detected parallelism when building
/// `TransferConfig::default()`.
const DEFAULT_CONCURRENCY_CAP: usize = 16;
/// Selects whether adaptive control is applied to a transfer.
///
/// The default variant is [`AdaptivePolicy::Off`].
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum AdaptivePolicy {
/// Adaptive control is disabled.
#[default]
Off,
/// Adaptive control is enabled with the given tuning parameters.
On {
// NOTE(review): the meaning of `fraction` is not visible in this file —
// confirm its valid range and semantics against the adaptive module.
fraction: f64,
// NOTE(review): presumably the upper bound on concurrency the adaptive
// controller may grow to — confirm against the adaptive module.
ceiling: usize,
},
}
/// Tunables for a transfer: concurrency, optional byte-rate budget and the
/// adaptive policy.
#[derive(Debug, Clone)]
pub struct TransferConfig {
/// Concurrency setting; the type guarantees it is never zero.
pub concurrency: NonZeroUsize,
/// Optional bytes-per-second budget.
// NOTE(review): presumably forwarded to a rate limiter, where `None` and
// `Some(0)` both mean "unlimited" — confirm at the call site.
pub max_bytes_per_sec: Option<u64>,
/// Adaptive policy for the transfer.
pub adaptive: AdaptivePolicy,
}
impl TransferConfig {
    /// Builds a configuration with the adaptive policy switched off.
    ///
    /// A `concurrency` of `0` is raised to `1`, so the stored value is always
    /// non-zero. `max_bytes_per_sec` is stored unchanged.
    #[must_use]
    pub fn new(concurrency: usize, max_bytes_per_sec: Option<u64>) -> Self {
        // `NonZeroUsize::new(0)` is `None`, which falls back to the minimum (1).
        let concurrency = NonZeroUsize::new(concurrency).unwrap_or(NonZeroUsize::MIN);
        Self {
            concurrency,
            max_bytes_per_sec,
            adaptive: AdaptivePolicy::Off,
        }
    }

    /// Returns the configuration with its adaptive policy replaced by `policy`.
    #[must_use]
    pub fn with_adaptive(self, policy: AdaptivePolicy) -> Self {
        Self {
            adaptive: policy,
            ..self
        }
    }
}
impl Default for TransferConfig {
    /// Uses the machine's available parallelism, clamped to
    /// `1..=DEFAULT_CONCURRENCY_CAP`, with no byte-rate limit and the adaptive
    /// policy off. Falls back to a concurrency of 1 when parallelism cannot be
    /// detected.
    fn default() -> Self {
        let cores = match std::thread::available_parallelism() {
            Ok(count) => count.get(),
            Err(_) => 1,
        };
        let capped = cores.clamp(1, DEFAULT_CONCURRENCY_CAP);
        let concurrency = NonZeroUsize::new(capped).unwrap_or(NonZeroUsize::MIN);
        Self {
            concurrency,
            max_bytes_per_sec: None,
            adaptive: AdaptivePolicy::Off,
        }
    }
}
/// Maps a store error to the signal fed back to the adaptive gate.
///
/// * `StoreError::Io` is classified by `classify_io_kind`.
/// * `StoreError::Backend` is `Throttle` when its lower-cased message (plus
///   the lower-cased text of its source, if any) contains a transient marker,
///   otherwise `HardErr`.
/// * Every other variant is `HardErr`.
#[must_use]
pub fn classify_error(err: &StoreError) -> OpResult {
    let haystack = match err {
        StoreError::Io(io_err) => return classify_io_kind(io_err),
        StoreError::Backend { message, source } => {
            let lowered = message.to_ascii_lowercase();
            match source {
                // Message and cause are searched together, separated by a space.
                Some(cause) => {
                    let cause = cause.to_string().to_ascii_lowercase();
                    format!("{lowered} {cause}")
                }
                None => lowered,
            }
        }
        _ => return OpResult::HardErr,
    };

    if text_is_transient(&haystack) {
        OpResult::Throttle
    } else {
        OpResult::HardErr
    }
}
/// Classifies an I/O error: `WouldBlock`, `StorageFull` and any error whose
/// rendered text mentions "too many open files" are `Throttle`; everything
/// else is `HardErr`.
fn classify_io_kind(err: &std::io::Error) -> OpResult {
    use std::io::ErrorKind;

    if matches!(err.kind(), ErrorKind::WouldBlock | ErrorKind::StorageFull) {
        return OpResult::Throttle;
    }

    // Descriptor exhaustion is detected from the message text rather than
    // from the error kind.
    let rendered = err.to_string().to_ascii_lowercase();
    if rendered.contains("too many open files") {
        OpResult::Throttle
    } else {
        OpResult::HardErr
    }
}
/// Returns `true` when `text` contains any known transient-failure marker.
///
/// Matching is a plain, case-sensitive substring search; callers are expected
/// to pass already lower-cased text.
fn text_is_transient(text: &str) -> bool {
    const MARKERS: &[&str] = &[
        "slowdown",
        "slow down",
        "429",
        "too many requests",
        "503",
        "service unavailable",
        "serviceunavailable",
        "resource_exhausted",
        "resource exhausted",
        "throttl",
        "request timeout",
        "requesttimeout",
        "timed out",
        "timeout",
        "connection reset",
        "connection closed",
        "connection refused",
        "broken pipe",
        "too many open files",
    ];

    for marker in MARKERS.iter().copied() {
        if text.contains(marker) {
            return true;
        }
    }
    false
}
/// Runs `op` over every item, with each invocation holding a permit from
/// `gate` for its whole duration.
///
/// Up to `gate.ceiling()` futures (at least one) are polled at a time; each
/// waits on `gate.acquire()` before calling `op`. Results are collected in
/// completion order, and the first error ends the run and is returned.
pub async fn run_adaptive<I, T, F, Fut>(
    items: I,
    gate: &AdaptiveGate,
    op: F,
) -> Result<Vec<T>, StoreError>
where
    I: IntoIterator,
    F: Fn(I::Item) -> Fut,
    Fut: std::future::Future<Output = Result<T, StoreError>>,
{
    let max_buffered = gate.ceiling().max(1);
    // Borrow once so every per-item future can share the same operation.
    let op = &op;

    let gated = stream::iter(items).map(move |item| async move {
        // The permit is released when this future finishes.
        let _permit = gate.acquire().await;
        op(item).await
    });

    gated.buffer_unordered(max_buffered).try_collect().await
}
/// Token-bucket state protected by the limiter's async mutex.
#[derive(Debug)]
struct Bucket {
/// Refill rate in tokens (bytes) per second; a value `<= 0.0` means unlimited.
rate: f64,
/// Maximum number of tokens the bucket holds after a refill.
capacity: f64,
/// Tokens currently available to be consumed.
tokens: f64,
/// Instant at which `tokens` was last topped up.
last_refill: tokio::time::Instant,
}
/// Shared state behind every clone of a `RateLimiter`.
#[derive(Debug)]
struct Inner {
/// Bucket state behind an async mutex.
bucket: Mutex<Bucket>,
}
/// Async token-bucket byte-rate limiter.
///
/// Clones share the same underlying bucket through an `Arc`.
#[derive(Debug, Clone)]
pub struct RateLimiter {
/// Reference-counted shared bucket state.
inner: Arc<Inner>,
}
impl RateLimiter {
    /// Creates a limiter that admits at most `max_bytes_per_sec` bytes per second.
    ///
    /// `None` and `Some(0)` both produce an unlimited limiter whose
    /// [`acquire`](Self::acquire) returns immediately. A limited bucket starts
    /// full, with a capacity equal to one second's budget.
    #[must_use]
    pub fn new(max_bytes_per_sec: Option<u64>) -> Self {
        #[allow(clippy::cast_precision_loss)]
        let (rate, capacity, tokens) = match max_bytes_per_sec {
            Some(r) if r > 0 => {
                let r = r as f64;
                (r, r, r)
            }
            // A zero rate is the internal encoding of "unlimited".
            _ => (0.0, 0.0, 0.0),
        };
        Self {
            inner: Arc::new(Inner {
                bucket: Mutex::new(Bucket {
                    rate,
                    capacity,
                    tokens,
                    last_refill: tokio::time::Instant::now(),
                }),
            }),
        }
    }

    /// Replaces the rate while the limiter is in use.
    ///
    /// * Unlimited -> limited: the bucket starts full at the new rate.
    /// * Limited -> limited: the current tokens are kept but clamped to the
    ///   new capacity.
    /// * `None` or `Some(0)`: the limiter becomes unlimited.
    pub async fn set_rate(&self, bytes_per_sec: Option<u64>) {
        let mut state = self.inner.bucket.lock().await;
        let was_unlimited = state.rate <= 0.0;
        #[allow(clippy::cast_precision_loss)]
        match bytes_per_sec {
            Some(r) if r > 0 => {
                let r = r as f64;
                state.rate = r;
                state.capacity = r;
                if was_unlimited {
                    state.tokens = r;
                } else {
                    state.tokens = state.tokens.min(r);
                }
                state.last_refill = tokio::time::Instant::now();
            }
            _ => {
                state.rate = 0.0;
                state.capacity = 0.0;
                state.tokens = 0.0;
            }
        }
    }

    /// Waits until `n` bytes of budget have been consumed from the bucket.
    ///
    /// Returns immediately when `n == 0` or the limiter is unlimited (also if
    /// it becomes unlimited while waiting).
    ///
    /// A request larger than the bucket capacity is drained in capacity-sized
    /// chunks. Refills are clamped to `capacity`, so such a request could
    /// otherwise never be satisfied and the call would wait forever. Chunking
    /// also keeps every individual sleep at or below `capacity / rate`.
    ///
    /// If the returned future is dropped part-way through an oversized
    /// request, the chunks already consumed are not refunded.
    pub async fn acquire(&self, n: u64) {
        if n == 0 {
            return;
        }
        #[allow(clippy::cast_precision_loss)]
        let mut remaining = n as f64;
        loop {
            let wait = {
                let mut state = self.inner.bucket.lock().await;
                if state.rate <= 0.0 {
                    return;
                }
                // Top the bucket up for the time elapsed since the last refill.
                let now = tokio::time::Instant::now();
                let elapsed = now.duration_since(state.last_refill).as_secs_f64();
                state.tokens = (state.tokens + elapsed * state.rate).min(state.capacity);
                state.last_refill = now;

                // Never ask for more than the bucket can ever hold at once.
                let chunk = remaining.min(state.capacity);
                if state.tokens >= chunk {
                    state.tokens -= chunk;
                    remaining -= chunk;
                    if remaining <= 0.0 {
                        return;
                    }
                    // More to go: sleep until the next chunk has accrued.
                    let next = remaining.min(state.capacity);
                    (next - state.tokens).max(0.0) / state.rate
                } else {
                    (chunk - state.tokens) / state.rate
                }
            };
            // The lock is released before sleeping so other callers can proceed.
            tokio::time::sleep(Duration::from_secs_f64(wait)).await;
        }
    }
}
/// Token-bucket state protected by the blocking limiter's mutex.
#[derive(Debug)]
struct BlockingBucket {
    /// Refill rate in tokens (bytes) per second; a value `<= 0.0` means unlimited.
    rate: f64,
    /// Maximum number of tokens the bucket holds after a refill.
    capacity: f64,
    /// Tokens currently available to be consumed.
    tokens: f64,
    /// Instant at which `tokens` was last topped up.
    last_refill: std::time::Instant,
}

/// Shared state behind every clone of a [`BlockingRateLimiter`].
#[derive(Debug)]
struct BlockingInner {
    /// Bucket state behind a standard-library mutex.
    bucket: std::sync::Mutex<BlockingBucket>,
}

/// Thread-blocking token-bucket byte-rate limiter.
///
/// Clones share the same underlying bucket through an `Arc`.
#[derive(Debug, Clone)]
pub struct BlockingRateLimiter {
    /// Reference-counted shared bucket state.
    inner: Arc<BlockingInner>,
}

impl BlockingRateLimiter {
    /// Creates a limiter that admits at most `max_bytes_per_sec` bytes per second.
    ///
    /// `None` and `Some(0)` both produce an unlimited limiter whose
    /// [`acquire_blocking`](Self::acquire_blocking) returns immediately. A
    /// limited bucket starts full, with a capacity equal to one second's budget.
    #[must_use]
    pub fn new(max_bytes_per_sec: Option<u64>) -> Self {
        #[allow(clippy::cast_precision_loss)]
        let (rate, capacity, tokens) = match max_bytes_per_sec {
            Some(r) if r > 0 => {
                let r = r as f64;
                (r, r, r)
            }
            // A zero rate is the internal encoding of "unlimited".
            _ => (0.0, 0.0, 0.0),
        };
        Self {
            inner: Arc::new(BlockingInner {
                bucket: std::sync::Mutex::new(BlockingBucket {
                    rate,
                    capacity,
                    tokens,
                    last_refill: std::time::Instant::now(),
                }),
            }),
        }
    }

    /// Replaces the rate while the limiter is in use.
    ///
    /// * Unlimited -> limited: the bucket starts full at the new rate.
    /// * Limited -> limited: the current tokens are kept but clamped to the
    ///   new capacity.
    /// * `None` or `Some(0)`: the limiter becomes unlimited.
    ///
    /// A poisoned mutex is recovered rather than propagated as a panic.
    pub fn set_rate(&self, bytes_per_sec: Option<u64>) {
        let mut state = self
            .inner
            .bucket
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let was_unlimited = state.rate <= 0.0;
        #[allow(clippy::cast_precision_loss)]
        match bytes_per_sec {
            Some(r) if r > 0 => {
                let r = r as f64;
                state.rate = r;
                state.capacity = r;
                if was_unlimited {
                    state.tokens = r;
                } else {
                    state.tokens = state.tokens.min(r);
                }
                state.last_refill = std::time::Instant::now();
            }
            _ => {
                state.rate = 0.0;
                state.capacity = 0.0;
                state.tokens = 0.0;
            }
        }
    }

    /// Blocks the calling thread until `n` bytes of budget have been consumed.
    ///
    /// Returns immediately when `n == 0` or the limiter is unlimited (also if
    /// it becomes unlimited while waiting).
    ///
    /// A request larger than the bucket capacity is drained in capacity-sized
    /// chunks. Refills are clamped to `capacity`, so such a request could
    /// otherwise never be satisfied and the call would block forever. Chunking
    /// also keeps every individual sleep at or below `capacity / rate`.
    pub fn acquire_blocking(&self, n: u64) {
        if n == 0 {
            return;
        }
        #[allow(clippy::cast_precision_loss)]
        let mut remaining = n as f64;
        loop {
            let wait = {
                let mut state = self
                    .inner
                    .bucket
                    .lock()
                    .unwrap_or_else(std::sync::PoisonError::into_inner);
                if state.rate <= 0.0 {
                    return;
                }
                // Top the bucket up for the time elapsed since the last refill.
                let now = std::time::Instant::now();
                let elapsed = now.duration_since(state.last_refill).as_secs_f64();
                state.tokens = (state.tokens + elapsed * state.rate).min(state.capacity);
                state.last_refill = now;

                // Never ask for more than the bucket can ever hold at once.
                let chunk = remaining.min(state.capacity);
                if state.tokens >= chunk {
                    state.tokens -= chunk;
                    remaining -= chunk;
                    if remaining <= 0.0 {
                        return;
                    }
                    // More to go: sleep until the next chunk has accrued.
                    let next = remaining.min(state.capacity);
                    (next - state.tokens).max(0.0) / state.rate
                } else {
                    (chunk - state.tokens) / state.rate
                }
            };
            // The guard is dropped before sleeping so other threads can proceed.
            std::thread::sleep(Duration::from_secs_f64(wait));
        }
    }
}
/// Runs `op` over every item with at most `concurrency` futures in flight.
///
/// Results are collected in completion order, and the first error ends the
/// run and is returned.
pub async fn run_concurrent<I, T, F, Fut>(
    items: I,
    concurrency: NonZeroUsize,
    op: F,
) -> Result<Vec<T>, StoreError>
where
    I: IntoIterator,
    F: Fn(I::Item) -> Fut,
    Fut: std::future::Future<Output = Result<T, StoreError>>,
{
    let pending = stream::iter(items).map(op);
    let limit = concurrency.get();
    pending.buffer_unordered(limit).try_collect().await
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicUsize, Ordering};
/// Builds a single-threaded tokio runtime with the time driver enabled.
fn runtime() -> tokio::runtime::Runtime {
    let mut builder = tokio::runtime::Builder::new_current_thread();
    builder.enable_time();
    builder.build().expect("build tokio runtime")
}
/// The default config stays within `1..=DEFAULT_CONCURRENCY_CAP`, and `new`
/// clamps zero to one while storing other inputs as given.
#[test]
fn transfer_config_default_caps_concurrency() {
    let defaults = TransferConfig::default();
    let detected = defaults.concurrency.get();
    assert!(detected >= 1, "concurrency must be >= 1");
    assert!(
        detected <= DEFAULT_CONCURRENCY_CAP,
        "default concurrency must be capped at {DEFAULT_CONCURRENCY_CAP}, got {}",
        detected
    );
    assert_eq!(defaults.max_bytes_per_sec, None);

    let clamped = TransferConfig::new(0, None);
    assert_eq!(clamped.concurrency.get(), 1);

    let explicit = TransferConfig::new(7, Some(99));
    assert_eq!(explicit.concurrency.get(), 7);
    assert_eq!(explicit.max_bytes_per_sec, Some(99));
}
/// Runs `items` sleeping jobs through `run_concurrent` and returns the highest
/// number that were observed running at the same time.
fn max_in_flight_for(concurrency: usize, items: usize) -> usize {
    let active = Arc::new(AtomicUsize::new(0));
    let peak = Arc::new(AtomicUsize::new(0));

    let job = {
        let active = Arc::clone(&active);
        let peak = Arc::clone(&peak);
        move |_item: usize| {
            let active = Arc::clone(&active);
            let peak = Arc::clone(&peak);
            async move {
                let running = active.fetch_add(1, Ordering::SeqCst) + 1;
                peak.fetch_max(running, Ordering::SeqCst);
                // Hold the slot long enough for the other jobs to overlap.
                tokio::time::sleep(Duration::from_millis(20)).await;
                active.fetch_sub(1, Ordering::SeqCst);
                Ok::<_, StoreError>(())
            }
        }
    };

    let rt = runtime();
    let outcome = rt.block_on(async move {
        let limit = NonZeroUsize::new(concurrency).unwrap();
        run_concurrent(0..items, limit, job).await
    });
    assert!(outcome.is_ok());
    peak.load(Ordering::SeqCst)
}
/// Peak parallelism equals the limit, or the item count when that is smaller.
#[test]
fn transfer_config_run_concurrent_max_in_flight() {
    // (concurrency, items, expected peak)
    let cases = [(4, 12, 4), (1, 5, 1), (8, 3, 3)];
    for (concurrency, items, expected) in cases {
        assert_eq!(max_in_flight_for(concurrency, items), expected);
    }
}
/// An error from one operation is surfaced as the result of the whole run.
#[test]
fn transfer_config_run_concurrent_propagates_error() {
    let failing_op = |item: i32| async move {
        if item != 5 {
            tokio::time::sleep(Duration::from_millis(5)).await;
            return Ok(());
        }
        Err(StoreError::Backend {
            message: "boom".to_owned(),
            source: None,
        })
    };

    let rt = runtime();
    let outcome: Result<Vec<()>, StoreError> = rt.block_on(async {
        let limit = NonZeroUsize::new(3).unwrap();
        run_concurrent(0..10, limit, failing_op).await
    });

    let err = outcome.expect_err("must surface the failing op's error");
    assert!(
        matches!(err, StoreError::Backend { ref message, .. } if message == "boom"),
        "unexpected error: {err:?}"
    );
}
/// Unlimited limiters (`None` and `Some(0)`) never block; a 1000 B/s limiter
/// needs about a second to admit twice its budget.
#[test]
fn sync_snapshot_blocking_rate_limiter() {
    use std::time::Instant;

    let instant_bound = Duration::from_millis(200);

    let no_limit = BlockingRateLimiter::new(None);
    let began = Instant::now();
    no_limit.acquire_blocking(1_000_000);
    assert!(
        began.elapsed() < instant_bound,
        "unlimited acquire_blocking should not block"
    );

    let zero_limit = BlockingRateLimiter::new(Some(0));
    let began = Instant::now();
    zero_limit.acquire_blocking(1_000_000);
    assert!(
        began.elapsed() < instant_bound,
        "Some(0) acquire_blocking should not block"
    );

    let capped = BlockingRateLimiter::new(Some(1000));
    let began = Instant::now();
    // The first call drains the initial bucket; the second waits for a refill.
    for _ in 0..2 {
        capped.acquire_blocking(1000);
    }
    let elapsed = began.elapsed();
    assert!(
        elapsed >= Duration::from_millis(900),
        "throttled acquire_blocking should take ~1s, took {elapsed:?}"
    );
}
/// `set_rate` takes effect on a live async limiter: unlimited -> 1000 B/s ->
/// unlimited.
#[test]
fn transfer_config_rate_limiter_set_rate_live() {
    let instant_bound = Duration::from_millis(200);

    runtime().block_on(async {
        let limiter = RateLimiter::new(None);

        let began = tokio::time::Instant::now();
        limiter.acquire(1_000_000).await;
        assert!(
            began.elapsed() < instant_bound,
            "unlimited acquire should not block before set_rate"
        );

        limiter.set_rate(Some(1000)).await;
        let began = tokio::time::Instant::now();
        // The first call drains the fresh bucket; the second waits for a refill.
        for _ in 0..2 {
            limiter.acquire(1000).await;
        }
        let elapsed = began.elapsed();
        assert!(
            elapsed >= Duration::from_millis(900),
            "after set_rate(Some(1000)) a 2x-budget acquire should take ~1s, took {elapsed:?}"
        );

        limiter.set_rate(None).await;
        let began = tokio::time::Instant::now();
        limiter.acquire(1_000_000).await;
        assert!(
            began.elapsed() < instant_bound,
            "after set_rate(None) acquire should no longer block"
        );
    });
}
/// `set_rate` takes effect on a live blocking limiter: unlimited -> 1000 B/s
/// -> unlimited via `Some(0)`.
#[test]
fn sync_snapshot_blocking_rate_limiter_set_rate_live() {
    use std::time::Instant;

    let instant_bound = Duration::from_millis(200);
    let limiter = BlockingRateLimiter::new(None);

    let began = Instant::now();
    limiter.acquire_blocking(1_000_000);
    assert!(
        began.elapsed() < instant_bound,
        "unlimited acquire_blocking should not block before set_rate"
    );

    limiter.set_rate(Some(1000));
    let began = Instant::now();
    // The first call drains the fresh bucket; the second waits for a refill.
    for _ in 0..2 {
        limiter.acquire_blocking(1000);
    }
    let elapsed = began.elapsed();
    assert!(
        elapsed >= Duration::from_millis(900),
        "after set_rate(Some(1000)) a 2x-budget acquire should take ~1s, took {elapsed:?}"
    );

    limiter.set_rate(Some(0));
    let began = Instant::now();
    limiter.acquire_blocking(1_000_000);
    assert!(
        began.elapsed() < instant_bound,
        "after set_rate(Some(0)) acquire_blocking should no longer block"
    );
}
/// `classify_error` returns `Throttle` for transient backend/I/O failures and
/// `HardErr` for everything else.
#[test]
fn classify_error_throttle_vs_hard() {
    use crate::adaptive::OpResult;

    // Backend messages containing a transient marker (matched case-insensitively).
    let transient_msgs = [
        "S3 PUT object failed: SlowDown",
        "got HTTP 503 Service Unavailable",
        "rate limited: 429 Too Many Requests",
        "RESOURCE_EXHAUSTED quota",
        "request timeout while uploading",
        "connection reset by peer",
        "os error: too many open files",
    ];
    for msg in transient_msgs {
        let err = StoreError::Backend {
            message: msg.to_owned(),
            source: None,
        };
        assert_eq!(
            classify_error(&err),
            OpResult::Throttle,
            "expected Throttle for {msg:?}"
        );
    }

    // Variants other than `Io` and `Backend` are always hard errors.
    let not_found = StoreError::ObjectNotFound {
        checksum: "abc".to_owned(),
    };
    assert_eq!(classify_error(&not_found), OpResult::HardErr);
    let integrity = StoreError::Integrity {
        address: "x".to_owned(),
        expected: "a".to_owned(),
        actual: "b".to_owned(),
    };
    assert_eq!(classify_error(&integrity), OpResult::HardErr);

    // A backend error without a transient marker is a hard error.
    let other = StoreError::Backend {
        message: "permission denied".to_owned(),
        source: None,
    };
    assert_eq!(classify_error(&other), OpResult::HardErr);

    // I/O errors: descriptor exhaustion (by message) and `WouldBlock` throttle;
    // other kinds do not.
    let emfile = StoreError::Io(std::io::Error::other("too many open files (os error 24)"));
    assert_eq!(classify_error(&emfile), OpResult::Throttle);
    let would_block = StoreError::Io(std::io::Error::from(std::io::ErrorKind::WouldBlock));
    assert_eq!(classify_error(&would_block), OpResult::Throttle);
    let io_notfound = StoreError::Io(std::io::Error::from(std::io::ErrorKind::NotFound));
    assert_eq!(classify_error(&io_notfound), OpResult::HardErr);
}
/// With a gate built as `AdaptiveGate::new(2, 8)`, `run_adaptive` never has
/// more than two operations running at once.
#[test]
fn run_adaptive_respects_gate_limit() {
    use crate::adaptive::AdaptiveGate;
    use std::sync::atomic::{AtomicUsize, Ordering};

    let gate = AdaptiveGate::new(2, 8);
    let active = Arc::new(AtomicUsize::new(0));
    let peak = Arc::new(AtomicUsize::new(0));

    let rt = runtime();
    let outcome: Result<Vec<()>, StoreError> = {
        let active = Arc::clone(&active);
        let peak = Arc::clone(&peak);
        rt.block_on(async move {
            run_adaptive(0..20, &gate, move |_item| {
                let active = Arc::clone(&active);
                let peak = Arc::clone(&peak);
                async move {
                    let running = active.fetch_add(1, Ordering::SeqCst) + 1;
                    peak.fetch_max(running, Ordering::SeqCst);
                    tokio::time::sleep(Duration::from_millis(15)).await;
                    active.fetch_sub(1, Ordering::SeqCst);
                    Ok(())
                }
            })
            .await
        })
    };

    assert!(outcome.is_ok());
    assert!(
        peak.load(Ordering::SeqCst) <= 2,
        "effective concurrency must be gated to the limit, got {}",
        peak.load(Ordering::SeqCst)
    );
}
/// An unlimited async limiter never waits; a 1000 B/s limiter needs about a
/// second to admit twice its budget.
#[test]
fn transfer_config_rate_limiter() {
    runtime().block_on(async {
        let no_limit = RateLimiter::new(None);
        let began = tokio::time::Instant::now();
        no_limit.acquire(1_000_000).await;
        assert!(
            began.elapsed() < Duration::from_millis(200),
            "unlimited acquire should not block"
        );

        let capped = RateLimiter::new(Some(1000));
        let began = tokio::time::Instant::now();
        // The first call drains the initial bucket; the second waits for a refill.
        for _ in 0..2 {
            capped.acquire(1000).await;
        }
        let elapsed = began.elapsed();
        assert!(
            elapsed >= Duration::from_millis(900),
            "throttled acquire should take ~1s, took {elapsed:?}"
        );
    });
}
}