use std::collections::HashMap;
use std::fmt::{Debug, Display, Formatter};
use std::ops::Range;
use std::sync::Arc;
use async_trait::async_trait;
use bytes::Bytes;
use futures::StreamExt;
use futures::stream::BoxStream;
use lance_core::utils::aimd::{AimdConfig, AimdController, RequestOutcome};
use object_store::path::Path;
use object_store::{
GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta, ObjectStore,
PutMultipartOptions, PutOptions, PutPayload, PutResult, Result as OSResult, UploadPart,
};
use rand::Rng;
use tokio::sync::Mutex;
use tracing::{debug, warn};
pub fn is_throttle_error(err: &object_store::Error) -> bool {
if let object_store::Error::Generic { source, .. } = err {
source.to_string().contains("retries, max_retries")
} else {
false
}
}
/// Configuration for [`AimdThrottledStore`]: independent AIMD settings per
/// operation category, plus token-bucket and retry tuning shared by all
/// categories.
#[derive(Debug, Clone)]
pub struct AimdThrottleConfig {
    /// AIMD settings for read operations.
    pub read: AimdConfig,
    /// AIMD settings for write operations.
    pub write: AimdConfig,
    /// AIMD settings for delete operations.
    pub delete: AimdConfig,
    /// AIMD settings for list operations.
    pub list: AimdConfig,
    /// Maximum token-bucket balance, allowing short bursts above the steady rate.
    pub burst_capacity: u32,
    /// Retry budget for throttle errors; 0 marks the throttle as disabled
    /// (see `is_disabled`).
    pub max_retries: usize,
    /// Lower bound of the jittered retry backoff, in milliseconds.
    pub min_backoff_ms: u64,
    /// Upper bound of the jittered retry backoff, in milliseconds.
    pub max_backoff_ms: u64,
}
impl Default for AimdThrottleConfig {
    /// Default configuration: one shared default AIMD config for all four
    /// operation categories, a 100-token burst, 3 retries, and 100-300ms
    /// jittered backoff.
    fn default() -> Self {
        let shared = AimdConfig::default();
        Self {
            burst_capacity: 100,
            max_retries: 3,
            min_backoff_ms: 100,
            max_backoff_ms: 300,
            read: shared.clone(),
            write: shared.clone(),
            delete: shared.clone(),
            list: shared,
        }
    }
}
impl AimdThrottleConfig {
pub fn with_aimd(self, aimd: AimdConfig) -> Self {
Self {
read: aimd.clone(),
write: aimd.clone(),
delete: aimd.clone(),
list: aimd,
..self
}
}
pub fn with_read_aimd(self, aimd: AimdConfig) -> Self {
Self { read: aimd, ..self }
}
pub fn with_write_aimd(self, aimd: AimdConfig) -> Self {
Self {
write: aimd,
..self
}
}
pub fn with_delete_aimd(self, aimd: AimdConfig) -> Self {
Self {
delete: aimd,
..self
}
}
pub fn with_list_aimd(self, aimd: AimdConfig) -> Self {
Self { list: aimd, ..self }
}
pub fn is_disabled(&self) -> bool {
self.max_retries == 0
}
pub fn with_burst_capacity(self, burst_capacity: u32) -> Self {
Self {
burst_capacity,
..self
}
}
pub fn from_storage_options(
storage_options: Option<&HashMap<String, String>>,
) -> lance_core::Result<Self> {
fn resolve_f64(
key: &str,
storage_options: Option<&HashMap<String, String>>,
default: f64,
) -> lance_core::Result<f64> {
let env_key = key.to_ascii_uppercase();
if let Some(val) = storage_options.and_then(|opts| opts.get(key)) {
val.parse::<f64>().map_err(|_| {
lance_core::Error::invalid_input(format!(
"Invalid value for storage option '{key}': '{val}'"
))
})
} else if let Ok(val) = std::env::var(&env_key) {
val.parse::<f64>().map_err(|_| {
lance_core::Error::invalid_input(format!(
"Invalid value for env var '{env_key}': '{val}'"
))
})
} else {
Ok(default)
}
}
fn resolve_u32(
key: &str,
storage_options: Option<&HashMap<String, String>>,
default: u32,
) -> lance_core::Result<u32> {
let env_key = key.to_ascii_uppercase();
if let Some(val) = storage_options.and_then(|opts| opts.get(key)) {
val.parse::<u32>().map_err(|_| {
lance_core::Error::invalid_input(format!(
"Invalid value for storage option '{key}': '{val}'"
))
})
} else if let Ok(val) = std::env::var(&env_key) {
val.parse::<u32>().map_err(|_| {
lance_core::Error::invalid_input(format!(
"Invalid value for env var '{env_key}': '{val}'"
))
})
} else {
Ok(default)
}
}
fn resolve_usize(
key: &str,
storage_options: Option<&HashMap<String, String>>,
default: usize,
) -> lance_core::Result<usize> {
let env_key = key.to_ascii_uppercase();
if let Some(val) = storage_options.and_then(|opts| opts.get(key)) {
val.parse::<usize>().map_err(|_| {
lance_core::Error::invalid_input(format!(
"Invalid value for storage option '{key}': '{val}'"
))
})
} else if let Ok(val) = std::env::var(&env_key) {
val.parse::<usize>().map_err(|_| {
lance_core::Error::invalid_input(format!(
"Invalid value for env var '{env_key}': '{val}'"
))
})
} else {
Ok(default)
}
}
fn resolve_u64(
key: &str,
storage_options: Option<&HashMap<String, String>>,
default: u64,
) -> lance_core::Result<u64> {
let env_key = key.to_ascii_uppercase();
if let Some(val) = storage_options.and_then(|opts| opts.get(key)) {
val.parse::<u64>().map_err(|_| {
lance_core::Error::invalid_input(format!(
"Invalid value for storage option '{key}': '{val}'"
))
})
} else if let Ok(val) = std::env::var(&env_key) {
val.parse::<u64>().map_err(|_| {
lance_core::Error::invalid_input(format!(
"Invalid value for env var '{env_key}': '{val}'"
))
})
} else {
Ok(default)
}
}
let initial_rate = resolve_f64("lance_aimd_initial_rate", storage_options, 2000.0)?;
let min_rate = resolve_f64("lance_aimd_min_rate", storage_options, 1.0)?;
let max_rate = resolve_f64("lance_aimd_max_rate", storage_options, 5000.0)?;
let decrease_factor = resolve_f64("lance_aimd_decrease_factor", storage_options, 0.5)?;
let additive_increment =
resolve_f64("lance_aimd_additive_increment", storage_options, 300.0)?;
let burst_capacity = resolve_u32("lance_aimd_burst_capacity", storage_options, 100)?;
let max_retries = resolve_usize("lance_aimd_max_retries", storage_options, 3)?;
let min_backoff_ms = resolve_u64("lance_aimd_min_backoff_ms", storage_options, 100)?;
let max_backoff_ms = resolve_u64("lance_aimd_max_backoff_ms", storage_options, 300)?;
let aimd = AimdConfig::default()
.with_initial_rate(initial_rate)
.with_min_rate(min_rate)
.with_max_rate(max_rate)
.with_decrease_factor(decrease_factor)
.with_additive_increment(additive_increment);
Ok(Self {
max_retries,
min_backoff_ms,
max_backoff_ms,
..Self::default()
.with_aimd(aimd)
.with_burst_capacity(burst_capacity)
})
}
}
/// Mutable token-bucket state, guarded by a mutex inside [`OperationThrottle`].
struct TokenBucketState {
    /// Current token balance; goes negative (debt) when demand exceeds the rate.
    tokens: f64,
    /// Instant of the last refill computation.
    last_refill: std::time::Instant,
    /// Current refill rate in tokens (requests) per second.
    rate: f64,
}
/// Rate limiter for one operation category: an AIMD controller that adapts the
/// target rate plus a token bucket that enforces it, with retry/backoff tuning.
struct OperationThrottle {
    /// Adjusts the target rate based on observed request outcomes.
    controller: AimdController,
    /// Token-bucket state enforcing the controller's current rate.
    bucket: Mutex<TokenBucketState>,
    /// Maximum token balance (burst allowance).
    burst_capacity: f64,
    /// Retry budget for throttle errors.
    max_retries: usize,
    /// Lower bound of the jittered retry backoff (milliseconds).
    min_backoff_ms: u64,
    /// Upper bound of the jittered retry backoff (milliseconds).
    max_backoff_ms: u64,
}
impl OperationThrottle {
fn new(
aimd_config: AimdConfig,
burst_capacity: f64,
max_retries: usize,
min_backoff_ms: u64,
max_backoff_ms: u64,
) -> lance_core::Result<Self> {
let initial_rate = aimd_config.initial_rate;
let controller = AimdController::new(aimd_config)?;
Ok(Self {
controller,
bucket: Mutex::new(TokenBucketState {
tokens: burst_capacity,
last_refill: std::time::Instant::now(),
rate: initial_rate,
}),
burst_capacity,
max_retries,
min_backoff_ms,
max_backoff_ms,
})
}
async fn acquire_token(&self) {
let sleep_duration = {
let mut bucket = self.bucket.lock().await;
let now = std::time::Instant::now();
let elapsed = now.duration_since(bucket.last_refill).as_secs_f64();
bucket.tokens = (bucket.tokens + elapsed * bucket.rate).min(self.burst_capacity);
bucket.last_refill = now;
bucket.tokens -= 1.0;
if bucket.tokens >= 0.0 {
return;
}
std::time::Duration::from_secs_f64(-bucket.tokens / bucket.rate)
};
tokio::time::sleep(sleep_duration).await;
}
async fn update_bucket_rate(&self, new_rate: f64) {
let mut bucket = self.bucket.lock().await;
bucket.rate = new_rate;
}
fn observe_outcome<T>(&self, result: &OSResult<T>) {
let outcome = match result {
Ok(_) => RequestOutcome::Success,
Err(err) if is_throttle_error(err) => {
debug!("Throttle error detected in stream");
RequestOutcome::Throttled
}
Err(_) => RequestOutcome::Success,
};
let prev_rate = self.controller.current_rate();
let new_rate = self.controller.record_outcome(outcome);
if new_rate < prev_rate {
warn!(
previous_rate = format!("{prev_rate:.1}"),
new_rate = format!("{new_rate:.1}"),
"AIMD throttle: rate reduced due to throttle errors"
);
}
if let Ok(mut bucket) = self.bucket.try_lock() {
bucket.rate = new_rate;
}
}
async fn throttled<T, F, Fut>(&self, f: F) -> OSResult<T>
where
F: Fn() -> Fut,
Fut: std::future::Future<Output = OSResult<T>>,
{
for attempt in 0..=self.max_retries {
self.acquire_token().await;
let result = f().await;
let outcome = match &result {
Ok(_) => RequestOutcome::Success,
Err(err) if is_throttle_error(err) => {
debug!("Throttle error detected");
RequestOutcome::Throttled
}
Err(_) => RequestOutcome::Success, };
let prev_rate = self.controller.current_rate();
let new_rate = self.controller.record_outcome(outcome);
if new_rate < prev_rate {
warn!(
previous_rate = format!("{prev_rate:.1}"),
new_rate = format!("{new_rate:.1}"),
"AIMD throttle: rate reduced due to throttle errors"
);
}
self.update_bucket_rate(new_rate).await;
match &result {
Err(err) if is_throttle_error(err) && attempt < self.max_retries => {
let backoff_ms =
rand::rng().random_range(self.min_backoff_ms..=self.max_backoff_ms);
debug!(
attempt = attempt + 1,
max_retries = self.max_retries,
backoff_ms,
"Retrying after throttle error"
);
tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
continue;
}
_ => return result,
}
}
unreachable!()
}
}
impl Debug for OperationThrottle {
    /// Manual `Debug`: only the controller and burst capacity are shown; the
    /// mutex-guarded bucket and retry tuning are omitted.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("OperationThrottle");
        builder.field("controller", &self.controller);
        builder.field("burst_capacity", &self.burst_capacity);
        builder.finish()
    }
}
/// Wraps a [`MultipartUpload`] so that part uploads, completion, and abort all
/// pass through the store's write throttle.
struct ThrottledMultipartUpload {
    /// The wrapped upload.
    target: Box<dyn MultipartUpload>,
    /// The owning store's write-category throttle.
    write: Arc<OperationThrottle>,
}
impl Debug for ThrottledMultipartUpload {
    /// Opaque `Debug` representation; fields are intentionally omitted.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("ThrottledMultipartUpload");
        builder.finish()
    }
}
#[async_trait]
impl MultipartUpload for ThrottledMultipartUpload {
    /// Upload one part through the write throttle. Parts are not retried
    /// here (unlike `complete`/`abort`); the outcome is recorded so the AIMD
    /// rate adapts, and any error is returned as-is.
    fn put_part(&mut self, data: PutPayload) -> UploadPart {
        let write = Arc::clone(&self.write);
        // NOTE(review): the inner future is created before a token is
        // acquired; assumes the target's `put_part` does no work until the
        // returned future is polled — confirm for non-lazy implementations.
        let fut = self.target.put_part(data);
        Box::pin(async move {
            write.acquire_token().await;
            let result = fut.await;
            write.observe_outcome(&result);
            result
        })
    }
    /// Complete the upload, retrying throttle errors with jittered backoff.
    ///
    /// The retry loop mirrors `OperationThrottle::throttled` but is inlined
    /// because `complete` needs `&mut self.target`, which the closure-based
    /// helper cannot express.
    async fn complete(&mut self) -> OSResult<PutResult> {
        let target = &mut self.target;
        for attempt in 0..=self.write.max_retries {
            self.write.acquire_token().await;
            let result = target.complete().await;
            self.write.observe_outcome(&result);
            match &result {
                Err(err) if is_throttle_error(err) && attempt < self.write.max_retries => {
                    let backoff_ms = rand::rng()
                        .random_range(self.write.min_backoff_ms..=self.write.max_backoff_ms);
                    tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
                    continue;
                }
                // Success, non-throttle error, or retries exhausted.
                _ => return result,
            }
        }
        // The final attempt always returns above.
        unreachable!()
    }
    /// Abort the upload, retrying throttle errors with jittered backoff
    /// (same loop shape as `complete`).
    async fn abort(&mut self) -> OSResult<()> {
        let target = &mut self.target;
        for attempt in 0..=self.write.max_retries {
            self.write.acquire_token().await;
            let result = target.abort().await;
            self.write.observe_outcome(&result);
            match &result {
                Err(err) if is_throttle_error(err) && attempt < self.write.max_retries => {
                    let backoff_ms = rand::rng()
                        .random_range(self.write.min_backoff_ms..=self.write.max_backoff_ms);
                    tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
                    continue;
                }
                _ => return result,
            }
        }
        unreachable!()
    }
}
/// An [`ObjectStore`] wrapper that applies adaptive (AIMD) client-side rate
/// limiting, with independent throttles for reads, writes, deletes, and lists.
pub struct AimdThrottledStore {
    /// The wrapped store.
    target: Arc<dyn ObjectStore>,
    /// Throttle for read operations.
    read: Arc<OperationThrottle>,
    /// Throttle for write operations.
    write: Arc<OperationThrottle>,
    /// Throttle for delete operations.
    delete: Arc<OperationThrottle>,
    /// Throttle for list operations.
    list: Arc<OperationThrottle>,
}
impl Debug for AimdThrottledStore {
    /// Debug output includes the wrapped store and all four category throttles.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("AimdThrottledStore");
        builder.field("target", &self.target);
        builder.field("read", &self.read);
        builder.field("write", &self.write);
        builder.field("delete", &self.delete);
        builder.field("list", &self.list);
        builder.finish()
    }
}
impl Display for AimdThrottledStore {
    /// Renders as `AimdThrottledStore(<inner store's Display>)`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("AimdThrottledStore({})", self.target))
    }
}
impl AimdThrottledStore {
pub fn new(
target: Arc<dyn ObjectStore>,
config: AimdThrottleConfig,
) -> lance_core::Result<Self> {
let burst = config.burst_capacity as f64;
let max_retries = config.max_retries;
let min_backoff_ms = config.min_backoff_ms;
let max_backoff_ms = config.max_backoff_ms;
Ok(Self {
target,
read: Arc::new(OperationThrottle::new(
config.read,
burst,
max_retries,
min_backoff_ms,
max_backoff_ms,
)?),
write: Arc::new(OperationThrottle::new(
config.write,
burst,
max_retries,
min_backoff_ms,
max_backoff_ms,
)?),
delete: Arc::new(OperationThrottle::new(
config.delete,
burst,
max_retries,
min_backoff_ms,
max_backoff_ms,
)?),
list: Arc::new(OperationThrottle::new(
config.list,
burst,
max_retries,
min_backoff_ms,
max_backoff_ms,
)?),
})
}
}
// Every trait method is implemented explicitly (enforced by the deny lint
// below) so that new object_store methods cannot silently bypass throttling.
#[async_trait]
#[deny(clippy::missing_trait_methods)]
impl ObjectStore for AimdThrottledStore {
    async fn put(&self, location: &Path, bytes: PutPayload) -> OSResult<PutResult> {
        self.write
            .throttled(|| self.target.put(location, bytes.clone()))
            .await
    }
    async fn put_opts(
        &self,
        location: &Path,
        bytes: PutPayload,
        opts: PutOptions,
    ) -> OSResult<PutResult> {
        self.write
            .throttled(|| self.target.put_opts(location, bytes.clone(), opts.clone()))
            .await
    }
    // Multipart uploads are wrapped so that each part, complete, and abort
    // also passes through the write throttle.
    async fn put_multipart(&self, location: &Path) -> OSResult<Box<dyn MultipartUpload>> {
        let target = self
            .write
            .throttled(|| self.target.put_multipart(location))
            .await?;
        Ok(Box::new(ThrottledMultipartUpload {
            target,
            write: Arc::clone(&self.write),
        }))
    }
    async fn put_multipart_opts(
        &self,
        location: &Path,
        opts: PutMultipartOptions,
    ) -> OSResult<Box<dyn MultipartUpload>> {
        let target = self
            .write
            .throttled(|| self.target.put_multipart_opts(location, opts.clone()))
            .await?;
        Ok(Box::new(ThrottledMultipartUpload {
            target,
            write: Arc::clone(&self.write),
        }))
    }
    async fn get(&self, location: &Path) -> OSResult<GetResult> {
        self.read.throttled(|| self.target.get(location)).await
    }
    async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult<GetResult> {
        self.read
            .throttled(|| self.target.get_opts(location, options.clone()))
            .await
    }
    async fn get_range(&self, location: &Path, range: Range<u64>) -> OSResult<Bytes> {
        self.read
            .throttled(|| self.target.get_range(location, range.clone()))
            .await
    }
    async fn get_ranges(&self, location: &Path, ranges: &[Range<u64>]) -> OSResult<Vec<Bytes>> {
        self.read
            .throttled(|| self.target.get_ranges(location, ranges))
            .await
    }
    async fn head(&self, location: &Path) -> OSResult<ObjectMeta> {
        self.read.throttled(|| self.target.head(location)).await
    }
    async fn delete(&self, location: &Path) -> OSResult<()> {
        self.delete.throttled(|| self.target.delete(location)).await
    }
    // Streaming APIs cannot use the retrying `throttled` helper. Outcomes are
    // observed per item so the AIMD rate still adapts, but individual items
    // are not token-gated or retried.
    fn delete_stream<'a>(
        &'a self,
        locations: BoxStream<'a, OSResult<Path>>,
    ) -> BoxStream<'a, OSResult<Path>> {
        self.target
            .delete_stream(locations)
            .map(|item| {
                self.delete.observe_outcome(&item);
                item
            })
            .boxed()
    }
    fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, OSResult<ObjectMeta>> {
        let throttle = Arc::clone(&self.list);
        self.target
            .list(prefix)
            .map(move |item| {
                throttle.observe_outcome(&item);
                item
            })
            .boxed()
    }
    fn list_with_offset(
        &self,
        prefix: Option<&Path>,
        offset: &Path,
    ) -> BoxStream<'static, OSResult<ObjectMeta>> {
        let throttle = Arc::clone(&self.list);
        self.target
            .list_with_offset(prefix, offset)
            .map(move |item| {
                throttle.observe_outcome(&item);
                item
            })
            .boxed()
    }
    async fn list_with_delimiter(&self, prefix: Option<&Path>) -> OSResult<ListResult> {
        self.list
            .throttled(|| self.target.list_with_delimiter(prefix))
            .await
    }
    // Copy/rename operations count against the write budget.
    async fn copy(&self, from: &Path, to: &Path) -> OSResult<()> {
        self.write.throttled(|| self.target.copy(from, to)).await
    }
    async fn rename(&self, from: &Path, to: &Path) -> OSResult<()> {
        self.write.throttled(|| self.target.rename(from, to)).await
    }
    async fn rename_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> {
        self.write
            .throttled(|| self.target.rename_if_not_exists(from, to))
            .await
    }
    async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> {
        self.write
            .throttled(|| self.target.copy_if_not_exists(from, to))
            .await
    }
}
#[cfg(test)]
mod tests {
use super::*;
use object_store::memory::InMemory;
use rstest::rstest;
use std::collections::VecDeque;
use std::sync::atomic::{AtomicU64, Ordering};
/// Build a `Generic` object_store error with the given message, for testing
/// throttle-error classification.
fn make_generic_error(msg: &str) -> object_store::Error {
    object_store::Error::Generic {
        store: "test",
        source: msg.into(),
    }
}
/// Table-driven check: only messages containing "retries, max_retries" (the
/// retry-exhaustion signature) classify as throttle errors.
#[rstest]
#[case::retry_error("Error after 10 retries, max_retries: 10, retry_timeout: 180s", true)]
#[case::retries_in_message(
    "request failed, after 3 retries, max_retries: 5, retry_timeout: 60s",
    true
)]
#[case::not_found("Object not found", false)]
#[case::permission_denied("Access denied", false)]
#[case::timeout("Connection timed out", false)]
#[case::http_429_without_retries("HTTP 429 Too Many Requests", false)]
#[case::slowdown_without_retries("SlowDown: Please reduce your request rate", false)]
fn test_is_throttle_error(#[case] msg: &str, #[case] expected: bool) {
    let err = make_generic_error(msg);
    assert_eq!(
        is_throttle_error(&err),
        expected,
        "is_throttle_error for '{}' should be {}",
        msg,
        expected
    );
}
/// Non-`Generic` error variants are never classified as throttle errors.
#[test]
fn test_non_generic_errors_are_not_throttle() {
    let err = object_store::Error::NotFound {
        path: "test".to_string(),
        source: "not found".into(),
    };
    assert!(!is_throttle_error(&err));
}
/// Smoke test: put/get round-trip through the wrapper with default config.
#[tokio::test]
async fn test_basic_put_get_through_wrapper() {
    let store = Arc::new(InMemory::new());
    let config = AimdThrottleConfig::default();
    let throttled = AimdThrottledStore::new(store, config).unwrap();
    let path = Path::from("test/file.txt");
    let data = PutPayload::from_static(b"hello world");
    throttled.put(&path, data).await.unwrap();
    let result = throttled.get(&path).await.unwrap();
    let bytes = result.bytes().await.unwrap();
    assert_eq!(bytes.as_ref(), b"hello world");
}
/// A throttled outcome followed by a success (after the AIMD window elapses)
/// reduces the read rate below its initial value.
#[tokio::test]
async fn test_rate_decreases_on_throttle() {
    let store = Arc::new(InMemory::new());
    let config = AimdThrottleConfig::default().with_aimd(
        AimdConfig::default()
            .with_initial_rate(100.0)
            .with_decrease_factor(0.5)
            .with_window_duration(std::time::Duration::from_millis(10)),
    );
    let throttled = AimdThrottledStore::new(store, config).unwrap();
    let initial_rate = throttled.read.controller.current_rate();
    assert_eq!(initial_rate, 100.0);
    throttled
        .read
        .controller
        .record_outcome(RequestOutcome::Throttled);
    // Let the 10ms AIMD window elapse before recording the next outcome.
    tokio::time::sleep(std::time::Duration::from_millis(20)).await;
    throttled
        .read
        .controller
        .record_outcome(RequestOutcome::Success);
    let new_rate = throttled.read.controller.current_rate();
    assert!(
        new_rate < initial_rate,
        "Rate should decrease after throttle: {} < {}",
        new_rate,
        initial_rate
    );
}
/// Rate halves (100 -> 50) after a throttle, then additively recovers
/// (50 -> 60) after a subsequent success window.
#[tokio::test]
async fn test_rate_recovers_on_success() {
    let store = Arc::new(InMemory::new());
    let config = AimdThrottleConfig::default().with_aimd(
        AimdConfig::default()
            .with_initial_rate(100.0)
            .with_decrease_factor(0.5)
            .with_additive_increment(10.0)
            .with_window_duration(std::time::Duration::from_millis(10)),
    );
    let throttled = AimdThrottledStore::new(store, config).unwrap();
    throttled
        .read
        .controller
        .record_outcome(RequestOutcome::Throttled);
    tokio::time::sleep(std::time::Duration::from_millis(20)).await;
    throttled
        .read
        .controller
        .record_outcome(RequestOutcome::Success);
    let decreased_rate = throttled.read.controller.current_rate();
    assert_eq!(decreased_rate, 50.0);
    tokio::time::sleep(std::time::Duration::from_millis(20)).await;
    throttled
        .read
        .controller
        .record_outcome(RequestOutcome::Success);
    let recovered_rate = throttled.read.controller.current_rate();
    assert_eq!(recovered_rate, 60.0);
}
/// The wrapper is usable behind `Arc<dyn ObjectStore>` like any other store.
#[tokio::test]
async fn test_as_dyn_object_store() {
    let store: Arc<dyn ObjectStore> = Arc::new(InMemory::new());
    let throttled: Arc<dyn ObjectStore> =
        Arc::new(AimdThrottledStore::new(store, AimdThrottleConfig::default()).unwrap());
    let path = Path::from("test/data.bin");
    let data = PutPayload::from_static(b"test data");
    throttled.put(&path, data).await.unwrap();
    let result = throttled.get(&path).await.unwrap();
    let bytes = result.bytes().await.unwrap();
    assert_eq!(bytes.as_ref(), b"test data");
}
/// With burst capacity 1 and a 10 req/s rate, the second request must wait
/// for the bucket to refill (~100ms; at least 50ms asserted for slack).
#[tokio::test]
async fn test_token_bucket_delays_when_exhausted() {
    let store = Arc::new(InMemory::new());
    let config = AimdThrottleConfig::default()
        .with_burst_capacity(1)
        .with_aimd(AimdConfig::default().with_initial_rate(10.0));
    let throttled = Arc::new(AimdThrottledStore::new(store, config).unwrap());
    let path = Path::from("test/file.txt");
    let data = PutPayload::from_static(b"data");
    throttled.put(&path, data).await.unwrap();
    let start = std::time::Instant::now();
    let data2 = PutPayload::from_static(b"data2");
    throttled.put(&path, data2).await.unwrap();
    let elapsed = start.elapsed();
    assert!(
        elapsed >= std::time::Duration::from_millis(50),
        "Expected delay for token refill, but elapsed was {:?}",
        elapsed
    );
}
/// Listing through the wrapper passes items through unchanged.
#[tokio::test]
async fn test_list_observes_outcomes() {
    let store = Arc::new(InMemory::new());
    let config = AimdThrottleConfig::default();
    let throttled = AimdThrottledStore::new(store.clone(), config).unwrap();
    let path = Path::from("prefix/file.txt");
    let data = PutPayload::from_static(b"data");
    store.put(&path, data).await.unwrap();
    let items: Vec<_> = throttled.list(Some(&Path::from("prefix"))).collect().await;
    assert_eq!(items.len(), 1);
    assert!(items[0].is_ok());
}
/// Mock store whose `list` stream emits synthetic throttle errors before the
/// real listing; all other operations delegate to the in-memory store.
struct ThrottlingListMockStore {
    /// Backing in-memory store.
    inner: InMemory,
    /// Number of synthetic throttle errors to prepend to each `list` stream.
    throttle_count: usize,
}
/// Display is required by the `ObjectStore` trait bound.
impl Display for ThrottlingListMockStore {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "ThrottlingListMockStore")
    }
}
/// Opaque Debug representation.
impl Debug for ThrottlingListMockStore {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ThrottlingListMockStore").finish()
    }
}
// All methods delegate to the inner store except `list`, which prepends
// `throttle_count` synthetic retry-exhaustion errors to the stream.
#[async_trait]
impl ObjectStore for ThrottlingListMockStore {
    async fn put(&self, location: &Path, bytes: PutPayload) -> OSResult<PutResult> {
        self.inner.put(location, bytes).await
    }
    async fn put_opts(
        &self,
        location: &Path,
        bytes: PutPayload,
        opts: PutOptions,
    ) -> OSResult<PutResult> {
        self.inner.put_opts(location, bytes, opts).await
    }
    async fn put_multipart(&self, location: &Path) -> OSResult<Box<dyn MultipartUpload>> {
        self.inner.put_multipart(location).await
    }
    async fn put_multipart_opts(
        &self,
        location: &Path,
        opts: PutMultipartOptions,
    ) -> OSResult<Box<dyn MultipartUpload>> {
        self.inner.put_multipart_opts(location, opts).await
    }
    async fn get(&self, location: &Path) -> OSResult<GetResult> {
        self.inner.get(location).await
    }
    async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult<GetResult> {
        self.inner.get_opts(location, options).await
    }
    async fn get_range(&self, location: &Path, range: Range<u64>) -> OSResult<Bytes> {
        self.inner.get_range(location, range).await
    }
    async fn get_ranges(&self, location: &Path, ranges: &[Range<u64>]) -> OSResult<Vec<Bytes>> {
        self.inner.get_ranges(location, ranges).await
    }
    async fn head(&self, location: &Path) -> OSResult<ObjectMeta> {
        self.inner.head(location).await
    }
    async fn delete(&self, location: &Path) -> OSResult<()> {
        self.inner.delete(location).await
    }
    fn delete_stream<'a>(
        &'a self,
        locations: BoxStream<'a, OSResult<Path>>,
    ) -> BoxStream<'a, OSResult<Path>> {
        self.inner.delete_stream(locations)
    }
    fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, OSResult<ObjectMeta>> {
        let n = self.throttle_count;
        let inner_stream = self.inner.list(prefix);
        // Emit `n` synthetic throttle errors first, then the real listing.
        let errors = futures::stream::iter((0..n).map(|_| {
            Err(object_store::Error::Generic {
                store: "ThrottlingListMock",
                source: "request failed, after 3 retries, max_retries: 5, retry_timeout: 60s"
                    .into(),
            })
        }));
        errors.chain(inner_stream).boxed()
    }
    fn list_with_offset(
        &self,
        prefix: Option<&Path>,
        offset: &Path,
    ) -> BoxStream<'static, OSResult<ObjectMeta>> {
        self.inner.list_with_offset(prefix, offset)
    }
    async fn list_with_delimiter(&self, prefix: Option<&Path>) -> OSResult<ListResult> {
        self.inner.list_with_delimiter(prefix).await
    }
    async fn copy(&self, from: &Path, to: &Path) -> OSResult<()> {
        self.inner.copy(from, to).await
    }
    async fn rename(&self, from: &Path, to: &Path) -> OSResult<()> {
        self.inner.rename(from, to).await
    }
    async fn rename_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> {
        self.inner.rename_if_not_exists(from, to).await
    }
    async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> {
        self.inner.copy_if_not_exists(from, to).await
    }
}
/// Throttle errors surfaced inside a `list` stream are observed and reduce
/// the list-category AIMD rate, while still yielding all stream items.
#[tokio::test]
async fn test_list_stream_throttle_errors_decrease_rate() {
    let mock = Arc::new(ThrottlingListMockStore {
        inner: InMemory::new(),
        throttle_count: 5,
    });
    mock.put(
        &Path::from("prefix/file.txt"),
        PutPayload::from_static(b"data"),
    )
    .await
    .unwrap();
    let config = AimdThrottleConfig::default().with_list_aimd(
        AimdConfig::default()
            .with_initial_rate(100.0)
            .with_decrease_factor(0.5)
            .with_window_duration(std::time::Duration::from_millis(10)),
    );
    let throttled = AimdThrottledStore::new(mock as Arc<dyn ObjectStore>, config).unwrap();
    let initial_rate = throttled.list.controller.current_rate();
    assert_eq!(initial_rate, 100.0);
    // 5 synthetic errors + 1 real object = 6 stream items.
    let items: Vec<_> = throttled.list(Some(&Path::from("prefix"))).collect().await;
    assert_eq!(items.len(), 6);
    assert!(items[0].is_err());
    assert!(items[5].is_ok());
    tokio::time::sleep(std::time::Duration::from_millis(20)).await;
    throttled
        .list
        .controller
        .record_outcome(RequestOutcome::Success);
    let new_rate = throttled.list.controller.current_rate();
    assert!(
        new_rate < initial_rate,
        "List rate should decrease after stream throttle errors: {} < {}",
        new_rate,
        initial_rate
    );
}
/// A throttle recorded against the read category must not affect the write,
/// delete, or list rates.
#[tokio::test]
async fn test_per_category_independence() {
    let store = Arc::new(InMemory::new());
    let config = AimdThrottleConfig::default().with_aimd(
        AimdConfig::default()
            .with_initial_rate(100.0)
            .with_decrease_factor(0.5)
            .with_window_duration(std::time::Duration::from_millis(10)),
    );
    let throttled = AimdThrottledStore::new(store, config).unwrap();
    throttled
        .read
        .controller
        .record_outcome(RequestOutcome::Throttled);
    tokio::time::sleep(std::time::Duration::from_millis(20)).await;
    throttled
        .read
        .controller
        .record_outcome(RequestOutcome::Success);
    let read_rate = throttled.read.controller.current_rate();
    let write_rate = throttled.write.controller.current_rate();
    let delete_rate = throttled.delete.controller.current_rate();
    let list_rate = throttled.list.controller.current_rate();
    assert_eq!(read_rate, 50.0, "Read rate should have decreased");
    assert_eq!(write_rate, 100.0, "Write rate should be unaffected");
    assert_eq!(delete_rate, 100.0, "Delete rate should be unaffected");
    assert_eq!(list_rate, 100.0, "List rate should be unaffected");
}
/// Each category picks up its own configured initial rate.
#[tokio::test]
async fn test_per_category_config() {
    let store = Arc::new(InMemory::new());
    let config = AimdThrottleConfig::default()
        .with_read_aimd(AimdConfig::default().with_initial_rate(200.0))
        .with_write_aimd(AimdConfig::default().with_initial_rate(100.0))
        .with_delete_aimd(AimdConfig::default().with_initial_rate(50.0))
        .with_list_aimd(AimdConfig::default().with_initial_rate(25.0));
    let throttled = AimdThrottledStore::new(store, config).unwrap();
    assert_eq!(throttled.read.controller.current_rate(), 200.0);
    assert_eq!(throttled.write.controller.current_rate(), 100.0);
    assert_eq!(throttled.delete.controller.current_rate(), 50.0);
    assert_eq!(throttled.list.controller.current_rate(), 25.0);
}
/// Mock store that enforces a sliding-window rate limit on read operations
/// and counts successes vs. throttles.
struct RateLimitingMockStore {
    /// Backing in-memory store.
    inner: InMemory,
    /// Timestamps of recent accepted requests (the sliding window).
    timestamps: std::sync::Mutex<VecDeque<std::time::Instant>>,
    /// Maximum accepted requests per window.
    max_per_window: usize,
    /// Sliding-window duration.
    window: std::time::Duration,
    /// Count of accepted (successful) requests.
    success_count: AtomicU64,
    /// Count of rejected (throttled) requests.
    throttle_count: AtomicU64,
}
impl RateLimitingMockStore {
    /// Create a limiter that accepts at most `max_per_window` requests per `window`.
    fn new(max_per_window: usize, window: std::time::Duration) -> Self {
        Self {
            inner: InMemory::new(),
            timestamps: std::sync::Mutex::new(VecDeque::new()),
            max_per_window,
            window,
            success_count: AtomicU64::new(0),
            throttle_count: AtomicU64::new(0),
        }
    }
    /// Returns `true` if a request is admitted under the sliding window,
    /// recording it; otherwise counts a throttle and returns `false`.
    fn check_rate(&self) -> bool {
        let mut ts = self.timestamps.lock().unwrap();
        let now = std::time::Instant::now();
        // Evict timestamps that have aged out of the window.
        while let Some(&front) = ts.front() {
            if now.duration_since(front) > self.window {
                ts.pop_front();
            } else {
                break;
            }
        }
        if ts.len() >= self.max_per_window {
            self.throttle_count.fetch_add(1, Ordering::Relaxed);
            false
        } else {
            ts.push_back(now);
            self.success_count.fetch_add(1, Ordering::Relaxed);
            true
        }
    }
    /// The error returned for rejected requests, shaped like a real
    /// retry-exhaustion error so `is_throttle_error` matches it.
    fn throttle_error() -> object_store::Error {
        object_store::Error::Generic {
            store: "RateLimitingMock",
            source: "request failed, after 10 retries, max_retries: 10, retry_timeout: 180s"
                .into(),
        }
    }
}
/// Display is required by the `ObjectStore` trait bound.
impl Display for RateLimitingMockStore {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "RateLimitingMockStore")
    }
}
/// Opaque Debug representation.
impl Debug for RateLimitingMockStore {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RateLimitingMockStore").finish()
    }
}
// Read paths (get/get_opts/get_range/get_ranges/head) are gated through
// `check_rate` and return a throttle error when the window is full; all other
// operations delegate directly to the inner store.
#[async_trait]
impl ObjectStore for RateLimitingMockStore {
    async fn put(&self, location: &Path, bytes: PutPayload) -> OSResult<PutResult> {
        self.inner.put(location, bytes).await
    }
    async fn put_opts(
        &self,
        location: &Path,
        bytes: PutPayload,
        opts: PutOptions,
    ) -> OSResult<PutResult> {
        self.inner.put_opts(location, bytes, opts).await
    }
    async fn put_multipart(&self, location: &Path) -> OSResult<Box<dyn MultipartUpload>> {
        self.inner.put_multipart(location).await
    }
    async fn put_multipart_opts(
        &self,
        location: &Path,
        opts: PutMultipartOptions,
    ) -> OSResult<Box<dyn MultipartUpload>> {
        self.inner.put_multipart_opts(location, opts).await
    }
    async fn get(&self, location: &Path) -> OSResult<GetResult> {
        if self.check_rate() {
            self.inner.get(location).await
        } else {
            Err(Self::throttle_error())
        }
    }
    async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult<GetResult> {
        if self.check_rate() {
            self.inner.get_opts(location, options).await
        } else {
            Err(Self::throttle_error())
        }
    }
    async fn get_range(&self, location: &Path, range: Range<u64>) -> OSResult<Bytes> {
        if self.check_rate() {
            self.inner.get_range(location, range).await
        } else {
            Err(Self::throttle_error())
        }
    }
    async fn get_ranges(&self, location: &Path, ranges: &[Range<u64>]) -> OSResult<Vec<Bytes>> {
        if self.check_rate() {
            self.inner.get_ranges(location, ranges).await
        } else {
            Err(Self::throttle_error())
        }
    }
    async fn head(&self, location: &Path) -> OSResult<ObjectMeta> {
        if self.check_rate() {
            self.inner.head(location).await
        } else {
            Err(Self::throttle_error())
        }
    }
    async fn delete(&self, location: &Path) -> OSResult<()> {
        self.inner.delete(location).await
    }
    fn delete_stream<'a>(
        &'a self,
        locations: BoxStream<'a, OSResult<Path>>,
    ) -> BoxStream<'a, OSResult<Path>> {
        self.inner.delete_stream(locations)
    }
    fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, OSResult<ObjectMeta>> {
        self.inner.list(prefix)
    }
    fn list_with_offset(
        &self,
        prefix: Option<&Path>,
        offset: &Path,
    ) -> BoxStream<'static, OSResult<ObjectMeta>> {
        self.inner.list_with_offset(prefix, offset)
    }
    async fn list_with_delimiter(&self, prefix: Option<&Path>) -> OSResult<ListResult> {
        self.inner.list_with_delimiter(prefix).await
    }
    async fn copy(&self, from: &Path, to: &Path) -> OSResult<()> {
        self.inner.copy(from, to).await
    }
    async fn rename(&self, from: &Path, to: &Path) -> OSResult<()> {
        self.inner.rename(from, to).await
    }
    async fn rename_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> {
        self.inner.rename_if_not_exists(from, to).await
    }
    async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> {
        self.inner.copy_if_not_exists(from, to).await
    }
}
/// End-to-end load test: 5 concurrent readers (each with its own throttled
/// wrapper) hammer a mock that admits 30 requests per 100ms. Asserts the AIMD
/// throttle keeps the success count within a plausible band, produces some
/// throttle errors, and bounds the total request volume.
#[tokio::test(flavor = "multi_thread", worker_threads = 8)]
async fn test_aimd_throttle_under_concurrent_load() {
    let mock = Arc::new(RateLimitingMockStore::new(
        30,
        std::time::Duration::from_millis(100),
    ));
    let path = Path::from("test/data.bin");
    mock.put(&path, PutPayload::from_static(b"test data"))
        .await
        .unwrap();
    let aimd = AimdConfig::default()
        .with_initial_rate(100.0)
        .with_decrease_factor(0.5)
        .with_additive_increment(2.0)
        .with_window_duration(std::time::Duration::from_millis(100));
    let throttle_config = AimdThrottleConfig::default()
        .with_aimd(aimd)
        .with_burst_capacity(100);
    let num_readers = 5;
    let test_duration = std::time::Duration::from_secs(2);
    let mut handles = Vec::new();
    for _ in 0..num_readers {
        let store = Arc::new(
            AimdThrottledStore::new(
                mock.clone() as Arc<dyn ObjectStore>,
                throttle_config.clone(),
            )
            .unwrap(),
        );
        let p = path.clone();
        handles.push(tokio::spawn(async move {
            let deadline = std::time::Instant::now() + test_duration;
            let mut count = 0u64;
            while std::time::Instant::now() < deadline {
                let _ = store.head(&p).await;
                count += 1;
            }
            count
        }));
    }
    let mut total_reader_requests = 0u64;
    for handle in handles {
        total_reader_requests += handle.await.unwrap();
    }
    let successes = mock.success_count.load(Ordering::Relaxed);
    let throttled = mock.throttle_count.load(Ordering::Relaxed);
    let total_mock = successes + throttled;
    // Retries mean the mock sees at least as many requests as the readers issued.
    assert!(
        total_mock >= total_reader_requests,
        "Mock-side count ({total_mock}) should be >= reader-side count ({total_reader_requests})"
    );
    assert!(
        successes >= 300,
        "Expected >= 300 successful requests over 2s, got {successes}"
    );
    assert!(
        successes <= 900,
        "Expected <= 900 successful requests, got {successes}"
    );
    assert!(throttled > 0, "Expected some throttled requests but got 0");
    assert!(
        total_mock <= 5000,
        "AIMD should limit total requests, got {total_mock}"
    );
}
/// Mock store whose first N `get` calls fail with a throttle-shaped error
/// before delegating to an in-memory backing store.
struct RetryTestMockStore {
    // Backing store that serves requests once the injected errors run out.
    inner: InMemory,
    // How many more `get` calls should fail before succeeding.
    errors_remaining: std::sync::Mutex<usize>,
    // Total `get` calls observed, including the failing ones.
    get_call_count: AtomicU64,
}
impl RetryTestMockStore {
    /// Builds a mock whose first `errors_before_success` `get` calls fail
    /// with a throttle-style error; subsequent calls hit the backing store.
    fn new(errors_before_success: usize) -> Self {
        Self {
            get_call_count: AtomicU64::new(0),
            errors_remaining: std::sync::Mutex::new(errors_before_success),
            inner: InMemory::new(),
        }
    }
}
impl Display for RetryTestMockStore {
    /// Human-readable name; matches the `Debug` rendering.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str("RetryTestMockStore")
    }
}
impl Debug for RetryTestMockStore {
    /// Renders exactly as a field-less struct would: `RetryTestMockStore`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "RetryTestMockStore")
    }
}
#[async_trait]
impl ObjectStore for RetryTestMockStore {
    // Every method except `get` delegates straight to the in-memory store;
    // only `get` injects failures.
    async fn put(&self, location: &Path, bytes: PutPayload) -> OSResult<PutResult> {
        self.inner.put(location, bytes).await
    }
    async fn put_opts(
        &self,
        location: &Path,
        bytes: PutPayload,
        opts: PutOptions,
    ) -> OSResult<PutResult> {
        self.inner.put_opts(location, bytes, opts).await
    }
    async fn put_multipart(&self, location: &Path) -> OSResult<Box<dyn MultipartUpload>> {
        self.inner.put_multipart(location).await
    }
    async fn put_multipart_opts(
        &self,
        location: &Path,
        opts: PutMultipartOptions,
    ) -> OSResult<Box<dyn MultipartUpload>> {
        self.inner.put_multipart_opts(location, opts).await
    }
    /// Fails with a throttle-shaped error while `errors_remaining > 0`
    /// (decrementing the budget each time); afterwards serves real data.
    async fn get(&self, location: &Path) -> OSResult<GetResult> {
        self.get_call_count.fetch_add(1, Ordering::Relaxed);
        // Consume one injected error if any remain. The std Mutex guard is
        // dropped at the end of this block, before any await point.
        let should_error = {
            let mut remaining = self.errors_remaining.lock().unwrap();
            if *remaining > 0 {
                *remaining -= 1;
                true
            } else {
                false
            }
        };
        if should_error {
            // Message deliberately contains "retries, max_retries" so that
            // `is_throttle_error` classifies it as a throttle error.
            Err(object_store::Error::Generic {
                store: "RetryTestMock",
                source: "request failed, after 3 retries, max_retries: 3, retry_timeout: 30s"
                    .into(),
            })
        } else {
            self.inner.get(location).await
        }
    }
    // NOTE(review): `get_opts` bypasses the error injection and the call
    // counter; fine as long as AimdThrottledStore routes plain reads through
    // `get` — confirm against the throttled store's implementation.
    async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult<GetResult> {
        self.inner.get_opts(location, options).await
    }
    async fn get_range(&self, location: &Path, range: Range<u64>) -> OSResult<Bytes> {
        self.inner.get_range(location, range).await
    }
    async fn get_ranges(&self, location: &Path, ranges: &[Range<u64>]) -> OSResult<Vec<Bytes>> {
        self.inner.get_ranges(location, ranges).await
    }
    async fn head(&self, location: &Path) -> OSResult<ObjectMeta> {
        self.inner.head(location).await
    }
    async fn delete(&self, location: &Path) -> OSResult<()> {
        self.inner.delete(location).await
    }
    fn delete_stream<'a>(
        &'a self,
        locations: BoxStream<'a, OSResult<Path>>,
    ) -> BoxStream<'a, OSResult<Path>> {
        self.inner.delete_stream(locations)
    }
    fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, OSResult<ObjectMeta>> {
        self.inner.list(prefix)
    }
    fn list_with_offset(
        &self,
        prefix: Option<&Path>,
        offset: &Path,
    ) -> BoxStream<'static, OSResult<ObjectMeta>> {
        self.inner.list_with_offset(prefix, offset)
    }
    async fn list_with_delimiter(&self, prefix: Option<&Path>) -> OSResult<ListResult> {
        self.inner.list_with_delimiter(prefix).await
    }
    async fn copy(&self, from: &Path, to: &Path) -> OSResult<()> {
        self.inner.copy(from, to).await
    }
    async fn rename(&self, from: &Path, to: &Path) -> OSResult<()> {
        self.inner.rename(from, to).await
    }
    async fn rename_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> {
        self.inner.rename_if_not_exists(from, to).await
    }
    async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> {
        self.inner.copy_if_not_exists(from, to).await
    }
}
/// Two injected throttle errors followed by a success: the throttled store
/// must retry through them and return the stored payload.
#[tokio::test]
async fn test_throttled_retries_on_throttle_error_then_succeeds() {
    // Fail the first two `get` calls, then let the backing store answer.
    let backing = Arc::new(RetryTestMockStore::new(2));
    let location = Path::from("test/retry.txt");
    backing
        .put(&location, PutPayload::from_static(b"retry data"))
        .await
        .unwrap();
    let store = AimdThrottledStore::new(
        backing.clone() as Arc<dyn ObjectStore>,
        AimdThrottleConfig::default(),
    )
    .unwrap();
    let outcome = store.get(&location).await;
    assert!(outcome.is_ok(), "Expected success after retries");
    let payload = outcome.unwrap().bytes().await.unwrap();
    assert_eq!(payload.as_ref(), b"retry data");
    // Two failed attempts plus the final successful one.
    assert_eq!(backing.get_call_count.load(Ordering::Relaxed), 3);
}
/// With more injected failures than the retry budget, the throttled store
/// must give up and surface the throttle error to the caller.
#[tokio::test]
async fn test_throttled_fails_after_max_retries_exceeded() {
    // Ten pending failures easily exceeds the default max_retries of 3.
    let backing = Arc::new(RetryTestMockStore::new(10));
    let location = Path::from("test/fail.txt");
    backing
        .put(&location, PutPayload::from_static(b"fail data"))
        .await
        .unwrap();
    let store = AimdThrottledStore::new(
        backing.clone() as Arc<dyn ObjectStore>,
        AimdThrottleConfig::default(),
    )
    .unwrap();
    let outcome = store.get(&location).await;
    assert!(outcome.is_err(), "Expected error after max retries");
    assert!(is_throttle_error(&outcome.unwrap_err()));
    // Initial attempt + 3 retries = 4 calls against the mock.
    assert_eq!(backing.get_call_count.load(Ordering::Relaxed), 4);
}
/// Regression test: completing multipart part futures out of order must NOT
/// change the byte order of the final object — part A was submitted first and
/// must land first even though part B's future is awaited first.
#[tokio::test]
async fn test_throttled_multipart_reorders_parts() {
    let store = Arc::new(InMemory::new()) as Arc<dyn ObjectStore>;
    let config = AimdThrottleConfig::default();
    let throttled = AimdThrottledStore::new(store.clone(), config).unwrap();
    let path = Path::from("test/multipart_ordering.bin");
    let mut upload = throttled.put_multipart(&path).await.unwrap();
    // Create both part futures before awaiting either, then await them in
    // reverse submission order to exercise the ordering guarantee.
    let fut_a = upload.put_part(PutPayload::from_static(b"AAAA"));
    let fut_b = upload.put_part(PutPayload::from_static(b"BBBB"));
    fut_b.await.unwrap();
    fut_a.await.unwrap();
    upload.complete().await.unwrap();
    // Read back through the untouched inner store: submission order must win.
    let result = store.get(&path).await.unwrap();
    let bytes = result.bytes().await.unwrap();
    assert_eq!(
        bytes.as_ref(),
        b"AAAABBBB",
        "Parts were reordered! Got {:?} instead of AAAABBBB.",
        std::str::from_utf8(&bytes).unwrap_or("<non-utf8>"),
    );
}
}