Skip to main content

lance_io/object_store/
throttle.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! AIMD-controlled token bucket rate limiter for ObjectStore operations.
5//!
6//! Wraps any [`object_store::ObjectStore`] with per-category token buckets
7//! whose fill rates are dynamically adjusted by AIMD controllers. When cloud
8//! stores return HTTP 429/503, the fill rate decreases multiplicatively. During
9//! sustained success windows, it increases additively.
10//!
11//! Operations are split into four independent categories — **read**, **write**,
12//! **delete**, **list** — each with its own AIMD controller and token bucket.
13//! This prevents a burst of reads from starving writes, and vice versa.
14//!
15//! # Example
16//!
17//! ```ignore
18//! use lance_io::object_store::throttle::{AimdThrottleConfig, AimdThrottledStore};
19//!
20//! let throttled = AimdThrottledStore::new(target, AimdThrottleConfig::default()).unwrap();
21//! ```
22
23use std::collections::HashMap;
24use std::fmt::{Debug, Display, Formatter};
25use std::ops::Range;
26use std::sync::Arc;
27
28use async_trait::async_trait;
29use bytes::Bytes;
30use futures::StreamExt;
31use futures::stream::BoxStream;
32use lance_core::utils::aimd::{AimdConfig, AimdController, RequestOutcome};
33use object_store::path::Path;
34use object_store::{
35    GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta, ObjectStore,
36    PutMultipartOptions, PutOptions, PutPayload, PutResult, Result as OSResult, UploadPart,
37};
38use rand::Rng;
39use tokio::sync::Mutex;
40use tracing::{debug, warn};
41
42/// Check whether an `object_store::Error` represents a throttle response
43/// (HTTP 429 / 503) from a cloud object store.
44///
45/// Regrettably, this information is not fully exposed by the `object_store` crate.
46/// There is no generic mechanism for a custom object store to return a throttle error.
47///
48/// However, the builtin object stores all use RetryError when retries are configured and
49/// throttle errors are returned.  Sadly, RetryError is not a public type, so we have to
50/// infer it from the error message.  This is potentially dangerous because these errors
51/// often include the URI itself and that URI could have any characters in it (e.g. if we
52/// look for 429 then we might match a 429 in a UUID).These error messages currently look like:
53///
54/// ", after ... retries, max_retries: ..., retry_timeout: ..."
55///
56/// So, as a crude heuristic, which should work for the builtin object stores, but won't
57/// work for custom object stores, we simply look for the string "retries, max_retries"
58/// in the error message.
59pub fn is_throttle_error(err: &object_store::Error) -> bool {
60    // Only Generic errors can carry throttle responses
61    if let object_store::Error::Generic { source, .. } = err {
62        source.to_string().contains("retries, max_retries")
63    } else {
64        false
65    }
66}
67
/// Configuration for the AIMD-throttled ObjectStore wrapper.
///
/// Each operation category (read, write, delete, list) has its own AIMD config.
/// Use [`with_aimd`](AimdThrottleConfig::with_aimd) to set all categories at
/// once, or per-category methods like [`with_read_aimd`](AimdThrottleConfig::with_read_aimd)
/// for fine-grained control.
///
/// The burst capacity, retry count, and backoff bounds are single values that
/// [`AimdThrottledStore::new`] hands to every category's throttle.
#[derive(Debug, Clone)]
pub struct AimdThrottleConfig {
    /// AIMD configuration for read operations (get, get_opts, get_range, get_ranges, head).
    pub read: AimdConfig,
    /// AIMD configuration for write operations (put, put_opts, put_multipart, copy, rename, etc.).
    pub write: AimdConfig,
    /// AIMD configuration for delete operations.
    pub delete: AimdConfig,
    /// AIMD configuration for list operations (list_with_delimiter).
    pub list: AimdConfig,
    /// Maximum tokens that can accumulate for bursts. The same value is used
    /// for every category, but each category has its own token bucket.
    pub burst_capacity: u32,
    /// Maximum number of retries for throttle errors within the AIMD layer.
    /// An operation is attempted at most `max_retries + 1` times. A value of
    /// `0` makes [`is_disabled`](AimdThrottleConfig::is_disabled) return `true`.
    pub max_retries: usize,
    /// Minimum backoff in milliseconds between retry attempts. The actual
    /// backoff is picked at random from the inclusive range
    /// `min_backoff_ms..=max_backoff_ms`, so this is expected to be no greater
    /// than `max_backoff_ms`.
    pub min_backoff_ms: u64,
    /// Maximum backoff in milliseconds between retry attempts (inclusive).
    pub max_backoff_ms: u64,
}
93
94impl Default for AimdThrottleConfig {
95    fn default() -> Self {
96        let aimd = AimdConfig::default();
97        Self {
98            read: aimd.clone(),
99            write: aimd.clone(),
100            delete: aimd.clone(),
101            list: aimd,
102            burst_capacity: 100,
103            max_retries: 3,
104            min_backoff_ms: 100,
105            max_backoff_ms: 300,
106        }
107    }
108}
109
110impl AimdThrottleConfig {
111    /// Set the AIMD configuration for all four operation categories at once.
112    pub fn with_aimd(self, aimd: AimdConfig) -> Self {
113        Self {
114            read: aimd.clone(),
115            write: aimd.clone(),
116            delete: aimd.clone(),
117            list: aimd,
118            ..self
119        }
120    }
121
122    /// Set the AIMD configuration for read operations.
123    pub fn with_read_aimd(self, aimd: AimdConfig) -> Self {
124        Self { read: aimd, ..self }
125    }
126
127    /// Set the AIMD configuration for write operations.
128    pub fn with_write_aimd(self, aimd: AimdConfig) -> Self {
129        Self {
130            write: aimd,
131            ..self
132        }
133    }
134
135    /// Set the AIMD configuration for delete operations.
136    pub fn with_delete_aimd(self, aimd: AimdConfig) -> Self {
137        Self {
138            delete: aimd,
139            ..self
140        }
141    }
142
143    /// Set the AIMD configuration for list operations.
144    pub fn with_list_aimd(self, aimd: AimdConfig) -> Self {
145        Self { list: aimd, ..self }
146    }
147
148    /// Returns `true` when the AIMD throttle layer should be bypassed entirely.
149    pub fn is_disabled(&self) -> bool {
150        self.max_retries == 0
151    }
152
153    pub fn with_burst_capacity(self, burst_capacity: u32) -> Self {
154        Self {
155            burst_capacity,
156            ..self
157        }
158    }
159
160    /// Build an `AimdThrottleConfig` from storage options and environment variables.
161    ///
162    /// Storage options take precedence over environment variables, which take
163    /// precedence over defaults. A single AIMD config is applied to all four
164    /// operation categories (read/write/delete/list).
165    ///
166    /// | Setting              | Storage Option Key               | Env Var                          | Default |
167    /// |----------------------|----------------------------------|----------------------------------|---------|
168    /// | Initial rate         | `lance_aimd_initial_rate`        | `LANCE_AIMD_INITIAL_RATE`        | 2000    |
169    /// | Min rate             | `lance_aimd_min_rate`            | `LANCE_AIMD_MIN_RATE`            | 1       |
170    /// | Max rate             | `lance_aimd_max_rate`            | `LANCE_AIMD_MAX_RATE`            | 5000    |
171    /// | Decrease factor      | `lance_aimd_decrease_factor`     | `LANCE_AIMD_DECREASE_FACTOR`     | 0.5     |
172    /// | Additive increment   | `lance_aimd_additive_increment`  | `LANCE_AIMD_ADDITIVE_INCREMENT`  | 300     |
173    /// | Burst capacity       | `lance_aimd_burst_capacity`      | `LANCE_AIMD_BURST_CAPACITY`      | 100     |
174    /// | Max retries          | `lance_aimd_max_retries`         | `LANCE_AIMD_MAX_RETRIES`         | 3       |
175    /// | Min backoff ms       | `lance_aimd_min_backoff_ms`      | `LANCE_AIMD_MIN_BACKOFF_MS`      | 100     |
176    /// | Max backoff ms       | `lance_aimd_max_backoff_ms`      | `LANCE_AIMD_MAX_BACKOFF_MS`      | 300     |
177    pub fn from_storage_options(
178        storage_options: Option<&HashMap<String, String>>,
179    ) -> lance_core::Result<Self> {
180        fn resolve_f64(
181            key: &str,
182            storage_options: Option<&HashMap<String, String>>,
183            default: f64,
184        ) -> lance_core::Result<f64> {
185            let env_key = key.to_ascii_uppercase();
186            if let Some(val) = storage_options.and_then(|opts| opts.get(key)) {
187                val.parse::<f64>().map_err(|_| {
188                    lance_core::Error::invalid_input(format!(
189                        "Invalid value for storage option '{key}': '{val}'"
190                    ))
191                })
192            } else if let Ok(val) = std::env::var(&env_key) {
193                val.parse::<f64>().map_err(|_| {
194                    lance_core::Error::invalid_input(format!(
195                        "Invalid value for env var '{env_key}': '{val}'"
196                    ))
197                })
198            } else {
199                Ok(default)
200            }
201        }
202
203        fn resolve_u32(
204            key: &str,
205            storage_options: Option<&HashMap<String, String>>,
206            default: u32,
207        ) -> lance_core::Result<u32> {
208            let env_key = key.to_ascii_uppercase();
209            if let Some(val) = storage_options.and_then(|opts| opts.get(key)) {
210                val.parse::<u32>().map_err(|_| {
211                    lance_core::Error::invalid_input(format!(
212                        "Invalid value for storage option '{key}': '{val}'"
213                    ))
214                })
215            } else if let Ok(val) = std::env::var(&env_key) {
216                val.parse::<u32>().map_err(|_| {
217                    lance_core::Error::invalid_input(format!(
218                        "Invalid value for env var '{env_key}': '{val}'"
219                    ))
220                })
221            } else {
222                Ok(default)
223            }
224        }
225
226        fn resolve_usize(
227            key: &str,
228            storage_options: Option<&HashMap<String, String>>,
229            default: usize,
230        ) -> lance_core::Result<usize> {
231            let env_key = key.to_ascii_uppercase();
232            if let Some(val) = storage_options.and_then(|opts| opts.get(key)) {
233                val.parse::<usize>().map_err(|_| {
234                    lance_core::Error::invalid_input(format!(
235                        "Invalid value for storage option '{key}': '{val}'"
236                    ))
237                })
238            } else if let Ok(val) = std::env::var(&env_key) {
239                val.parse::<usize>().map_err(|_| {
240                    lance_core::Error::invalid_input(format!(
241                        "Invalid value for env var '{env_key}': '{val}'"
242                    ))
243                })
244            } else {
245                Ok(default)
246            }
247        }
248
249        fn resolve_u64(
250            key: &str,
251            storage_options: Option<&HashMap<String, String>>,
252            default: u64,
253        ) -> lance_core::Result<u64> {
254            let env_key = key.to_ascii_uppercase();
255            if let Some(val) = storage_options.and_then(|opts| opts.get(key)) {
256                val.parse::<u64>().map_err(|_| {
257                    lance_core::Error::invalid_input(format!(
258                        "Invalid value for storage option '{key}': '{val}'"
259                    ))
260                })
261            } else if let Ok(val) = std::env::var(&env_key) {
262                val.parse::<u64>().map_err(|_| {
263                    lance_core::Error::invalid_input(format!(
264                        "Invalid value for env var '{env_key}': '{val}'"
265                    ))
266                })
267            } else {
268                Ok(default)
269            }
270        }
271
272        let initial_rate = resolve_f64("lance_aimd_initial_rate", storage_options, 2000.0)?;
273        let min_rate = resolve_f64("lance_aimd_min_rate", storage_options, 1.0)?;
274        let max_rate = resolve_f64("lance_aimd_max_rate", storage_options, 5000.0)?;
275        let decrease_factor = resolve_f64("lance_aimd_decrease_factor", storage_options, 0.5)?;
276        let additive_increment =
277            resolve_f64("lance_aimd_additive_increment", storage_options, 300.0)?;
278        let burst_capacity = resolve_u32("lance_aimd_burst_capacity", storage_options, 100)?;
279        let max_retries = resolve_usize("lance_aimd_max_retries", storage_options, 3)?;
280        let min_backoff_ms = resolve_u64("lance_aimd_min_backoff_ms", storage_options, 100)?;
281        let max_backoff_ms = resolve_u64("lance_aimd_max_backoff_ms", storage_options, 300)?;
282
283        let aimd = AimdConfig::default()
284            .with_initial_rate(initial_rate)
285            .with_min_rate(min_rate)
286            .with_max_rate(max_rate)
287            .with_decrease_factor(decrease_factor)
288            .with_additive_increment(additive_increment);
289
290        Ok(Self {
291            max_retries,
292            min_backoff_ms,
293            max_backoff_ms,
294            ..Self::default()
295                .with_aimd(aimd)
296                .with_burst_capacity(burst_capacity)
297        })
298    }
299}
300
/// Mutable state of a single token bucket, kept behind the mutex in
/// `OperationThrottle`.
struct TokenBucketState {
    /// Tokens currently available. May be negative: callers reserve a token
    /// before it has been refilled and then sleep for the deficit.
    tokens: f64,
    /// Instant at which `tokens` was last topped up.
    last_refill: std::time::Instant,
    /// Fill rate: tokens accrue as elapsed seconds multiplied by this value.
    rate: f64,
}
306
/// Per-category throttle state: an AIMD controller paired with a token bucket.
struct OperationThrottle {
    /// Receives request outcomes and reports the current rate.
    controller: AimdController,
    /// Token bucket state; its `rate` is overwritten with the rates the
    /// controller returns.
    bucket: Mutex<TokenBucketState>,
    /// Upper bound on accumulated tokens, and the bucket's initial token count.
    burst_capacity: f64,
    /// Maximum number of retries after a throttle error (so at most
    /// `max_retries + 1` attempts per operation).
    max_retries: usize,
    /// Lower bound, in milliseconds, of the random backoff between retries.
    min_backoff_ms: u64,
    /// Upper bound (inclusive), in milliseconds, of the random backoff between retries.
    max_backoff_ms: u64,
}
316
317impl OperationThrottle {
318    fn new(
319        aimd_config: AimdConfig,
320        burst_capacity: f64,
321        max_retries: usize,
322        min_backoff_ms: u64,
323        max_backoff_ms: u64,
324    ) -> lance_core::Result<Self> {
325        let initial_rate = aimd_config.initial_rate;
326        let controller = AimdController::new(aimd_config)?;
327        Ok(Self {
328            controller,
329            bucket: Mutex::new(TokenBucketState {
330                tokens: burst_capacity,
331                last_refill: std::time::Instant::now(),
332                rate: initial_rate,
333            }),
334            burst_capacity,
335            max_retries,
336            min_backoff_ms,
337            max_backoff_ms,
338        })
339    }
340
341    /// Acquire a token from the bucket, sleeping if none are available.
342    ///
343    /// Each caller reserves a token immediately (allowing `tokens` to go
344    /// negative) so that concurrent waiters queue behind each other instead
345    /// of all waking at the same instant (thundering herd).
346    async fn acquire_token(&self) {
347        let sleep_duration = {
348            let mut bucket = self.bucket.lock().await;
349            let now = std::time::Instant::now();
350            let elapsed = now.duration_since(bucket.last_refill).as_secs_f64();
351            bucket.tokens = (bucket.tokens + elapsed * bucket.rate).min(self.burst_capacity);
352            bucket.last_refill = now;
353
354            // Reserve a token (may go negative to queue behind other waiters)
355            bucket.tokens -= 1.0;
356
357            if bucket.tokens >= 0.0 {
358                // Had a token available, no need to sleep
359                return;
360            }
361
362            // Sleep proportional to our position in the queue
363            std::time::Duration::from_secs_f64(-bucket.tokens / bucket.rate)
364        };
365
366        tokio::time::sleep(sleep_duration).await;
367    }
368
369    /// Update the bucket's fill rate from the controller.
370    async fn update_bucket_rate(&self, new_rate: f64) {
371        let mut bucket = self.bucket.lock().await;
372        bucket.rate = new_rate;
373    }
374
375    /// Classify a result and feed it back to the AIMD controller without
376    /// acquiring a token. Uses `try_lock` for the bucket update so that if the
377    /// bucket lock is contended the rate update is deferred to the next
378    /// `throttled()` call.
379    fn observe_outcome<T>(&self, result: &OSResult<T>) {
380        let outcome = match result {
381            Ok(_) => RequestOutcome::Success,
382            Err(err) if is_throttle_error(err) => {
383                debug!("Throttle error detected in stream");
384                RequestOutcome::Throttled
385            }
386            Err(_) => RequestOutcome::Success,
387        };
388        let prev_rate = self.controller.current_rate();
389        let new_rate = self.controller.record_outcome(outcome);
390        if new_rate < prev_rate {
391            warn!(
392                previous_rate = format!("{prev_rate:.1}"),
393                new_rate = format!("{new_rate:.1}"),
394                "AIMD throttle: rate reduced due to throttle errors"
395            );
396        }
397        if let Ok(mut bucket) = self.bucket.try_lock() {
398            bucket.rate = new_rate;
399        }
400    }
401
402    /// Execute an operation with throttling: acquire token, run, classify result.
403    /// On throttle errors, retries up to `max_retries` times with a random
404    /// backoff between `min_backoff_ms` and `max_backoff_ms` between attempts.
405    async fn throttled<T, F, Fut>(&self, f: F) -> OSResult<T>
406    where
407        F: Fn() -> Fut,
408        Fut: std::future::Future<Output = OSResult<T>>,
409    {
410        for attempt in 0..=self.max_retries {
411            self.acquire_token().await;
412            let result = f().await;
413            let outcome = match &result {
414                Ok(_) => RequestOutcome::Success,
415                Err(err) if is_throttle_error(err) => {
416                    debug!("Throttle error detected");
417                    RequestOutcome::Throttled
418                }
419                Err(_) => RequestOutcome::Success, // Non-throttle errors don't indicate capacity problems
420            };
421            let prev_rate = self.controller.current_rate();
422            let new_rate = self.controller.record_outcome(outcome);
423            if new_rate < prev_rate {
424                warn!(
425                    previous_rate = format!("{prev_rate:.1}"),
426                    new_rate = format!("{new_rate:.1}"),
427                    "AIMD throttle: rate reduced due to throttle errors"
428                );
429            }
430            self.update_bucket_rate(new_rate).await;
431
432            match &result {
433                Err(err) if is_throttle_error(err) && attempt < self.max_retries => {
434                    let backoff_ms =
435                        rand::rng().random_range(self.min_backoff_ms..=self.max_backoff_ms);
436                    debug!(
437                        attempt = attempt + 1,
438                        max_retries = self.max_retries,
439                        backoff_ms,
440                        "Retrying after throttle error"
441                    );
442                    tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
443                    continue;
444                }
445                _ => return result,
446            }
447        }
448        unreachable!()
449    }
450}
451
452impl Debug for OperationThrottle {
453    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
454        f.debug_struct("OperationThrottle")
455            .field("controller", &self.controller)
456            .field("burst_capacity", &self.burst_capacity)
457            .finish()
458    }
459}
460
/// A [`MultipartUpload`] wrapper that throttles `put_part`, `complete`, and
/// `abort`, feeding outcomes back to the write AIMD controller.
///
/// `complete` and `abort` are retried on throttle errors; `put_part` is
/// throttled but attempted only once.
struct ThrottledMultipartUpload {
    /// The wrapped upload that performs the actual requests.
    target: Box<dyn MultipartUpload>,
    /// The write-category throttle, shared with the store that created this upload.
    write: Arc<OperationThrottle>,
}
468
469impl Debug for ThrottledMultipartUpload {
470    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
471        f.debug_struct("ThrottledMultipartUpload").finish()
472    }
473}
474
475#[async_trait]
476impl MultipartUpload for ThrottledMultipartUpload {
477    fn put_part(&mut self, data: PutPayload) -> UploadPart {
478        let write = Arc::clone(&self.write);
479        // Call put_part synchronously to preserve part ordering regardless
480        // of which futures are awaited first.
481        let fut = self.target.put_part(data);
482        Box::pin(async move {
483            write.acquire_token().await;
484            let result = fut.await;
485            write.observe_outcome(&result);
486            result
487        })
488    }
489
490    async fn complete(&mut self) -> OSResult<PutResult> {
491        let target = &mut self.target;
492        for attempt in 0..=self.write.max_retries {
493            self.write.acquire_token().await;
494            let result = target.complete().await;
495            self.write.observe_outcome(&result);
496
497            match &result {
498                Err(err) if is_throttle_error(err) && attempt < self.write.max_retries => {
499                    let backoff_ms = rand::rng()
500                        .random_range(self.write.min_backoff_ms..=self.write.max_backoff_ms);
501                    tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
502                    continue;
503                }
504                _ => return result,
505            }
506        }
507        unreachable!()
508    }
509
510    async fn abort(&mut self) -> OSResult<()> {
511        let target = &mut self.target;
512        for attempt in 0..=self.write.max_retries {
513            self.write.acquire_token().await;
514            let result = target.abort().await;
515            self.write.observe_outcome(&result);
516
517            match &result {
518                Err(err) if is_throttle_error(err) && attempt < self.write.max_retries => {
519                    let backoff_ms = rand::rng()
520                        .random_range(self.write.min_backoff_ms..=self.write.max_backoff_ms);
521                    tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
522                    continue;
523                }
524                _ => return result,
525            }
526        }
527        unreachable!()
528    }
529}
530
/// An ObjectStore wrapper that rate-limits operations using per-category token
/// buckets whose fill rates are controlled by AIMD algorithms.
///
/// Operations are split into four independent categories:
/// - **read**: `get`, `get_opts`, `get_range`, `get_ranges`, `head`
/// - **write**: `put`, `put_opts`, `put_multipart`, `put_multipart_opts`, `copy`, `copy_if_not_exists`, `rename`, `rename_if_not_exists`
/// - **delete**: `delete`
/// - **list**: `list_with_delimiter`
///
/// Streaming operations (`list`, `list_with_offset`, `delete_stream`) do not acquire tokens,
/// but observe each yielded item and feed the result back to the AIMD controller so it can
/// adjust the rate for other operations in the same category.
///
/// This is not perfect but probably as close as we can get without moving the throttle into
/// the object_store crate itself.
pub struct AimdThrottledStore {
    /// The wrapped store that performs the actual operations.
    target: Arc<dyn ObjectStore>,
    /// Throttle for the read category.
    read: Arc<OperationThrottle>,
    /// Throttle for the write category; also shared with multipart uploads
    /// created by this store.
    write: Arc<OperationThrottle>,
    /// Throttle for the delete category.
    delete: Arc<OperationThrottle>,
    /// Throttle for the list category.
    list: Arc<OperationThrottle>,
}
553
554impl Debug for AimdThrottledStore {
555    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
556        f.debug_struct("AimdThrottledStore")
557            .field("target", &self.target)
558            .field("read", &self.read)
559            .field("write", &self.write)
560            .field("delete", &self.delete)
561            .field("list", &self.list)
562            .finish()
563    }
564}
565
566impl Display for AimdThrottledStore {
567    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
568        write!(f, "AimdThrottledStore({})", self.target)
569    }
570}
571
572impl AimdThrottledStore {
573    pub fn new(
574        target: Arc<dyn ObjectStore>,
575        config: AimdThrottleConfig,
576    ) -> lance_core::Result<Self> {
577        let burst = config.burst_capacity as f64;
578        let max_retries = config.max_retries;
579        let min_backoff_ms = config.min_backoff_ms;
580        let max_backoff_ms = config.max_backoff_ms;
581        Ok(Self {
582            target,
583            read: Arc::new(OperationThrottle::new(
584                config.read,
585                burst,
586                max_retries,
587                min_backoff_ms,
588                max_backoff_ms,
589            )?),
590            write: Arc::new(OperationThrottle::new(
591                config.write,
592                burst,
593                max_retries,
594                min_backoff_ms,
595                max_backoff_ms,
596            )?),
597            delete: Arc::new(OperationThrottle::new(
598                config.delete,
599                burst,
600                max_retries,
601                min_backoff_ms,
602                max_backoff_ms,
603            )?),
604            list: Arc::new(OperationThrottle::new(
605                config.list,
606                burst,
607                max_retries,
608                min_backoff_ms,
609                max_backoff_ms,
610            )?),
611        })
612    }
613}
614
// Every method forwards to `self.target`. Non-streaming methods go through the
// matching category's `throttled()` (token + retry on throttle errors);
// streaming methods only report each yielded item to the category's controller.
#[async_trait]
// Clippy's `missing_trait_methods` flags trait methods left to their default
// implementation; denying it here presumably ensures no `ObjectStore` method
// silently skips the throttle.
#[deny(clippy::missing_trait_methods)]
impl ObjectStore for AimdThrottledStore {
    /// Write category. The payload is cloned for each attempt.
    async fn put(&self, location: &Path, bytes: PutPayload) -> OSResult<PutResult> {
        self.write
            .throttled(|| self.target.put(location, bytes.clone()))
            .await
    }

    /// Write category. The payload and options are cloned for each attempt.
    async fn put_opts(
        &self,
        location: &Path,
        bytes: PutPayload,
        opts: PutOptions,
    ) -> OSResult<PutResult> {
        self.write
            .throttled(|| self.target.put_opts(location, bytes.clone(), opts.clone()))
            .await
    }

    /// Write category. Starting the upload is throttled, and the returned
    /// upload is wrapped so its later calls use the same write throttle.
    async fn put_multipart(&self, location: &Path) -> OSResult<Box<dyn MultipartUpload>> {
        let target = self
            .write
            .throttled(|| self.target.put_multipart(location))
            .await?;
        Ok(Box::new(ThrottledMultipartUpload {
            target,
            write: Arc::clone(&self.write),
        }))
    }

    /// Write category. Same as `put_multipart`, with the options cloned for
    /// each attempt.
    async fn put_multipart_opts(
        &self,
        location: &Path,
        opts: PutMultipartOptions,
    ) -> OSResult<Box<dyn MultipartUpload>> {
        let target = self
            .write
            .throttled(|| self.target.put_multipart_opts(location, opts.clone()))
            .await?;
        Ok(Box::new(ThrottledMultipartUpload {
            target,
            write: Arc::clone(&self.write),
        }))
    }

    /// Read category.
    async fn get(&self, location: &Path) -> OSResult<GetResult> {
        self.read.throttled(|| self.target.get(location)).await
    }

    /// Read category. The options are cloned for each attempt.
    async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult<GetResult> {
        self.read
            .throttled(|| self.target.get_opts(location, options.clone()))
            .await
    }

    /// Read category. The range is cloned for each attempt.
    async fn get_range(&self, location: &Path, range: Range<u64>) -> OSResult<Bytes> {
        self.read
            .throttled(|| self.target.get_range(location, range.clone()))
            .await
    }

    /// Read category. The whole call takes a single token regardless of how
    /// many ranges are requested.
    async fn get_ranges(&self, location: &Path, ranges: &[Range<u64>]) -> OSResult<Vec<Bytes>> {
        self.read
            .throttled(|| self.target.get_ranges(location, ranges))
            .await
    }

    /// Read category.
    async fn head(&self, location: &Path) -> OSResult<ObjectMeta> {
        self.read.throttled(|| self.target.head(location)).await
    }

    /// Delete category.
    async fn delete(&self, location: &Path) -> OSResult<()> {
        self.delete.throttled(|| self.target.delete(location)).await
    }

    /// Streaming delete: no tokens are acquired and nothing is retried; each
    /// yielded item is reported to the delete throttle and passed through.
    fn delete_stream<'a>(
        &'a self,
        locations: BoxStream<'a, OSResult<Path>>,
    ) -> BoxStream<'a, OSResult<Path>> {
        self.target
            .delete_stream(locations)
            .map(|item| {
                self.delete.observe_outcome(&item);
                item
            })
            .boxed()
    }

    /// Streaming list: no tokens are acquired and nothing is retried; each
    /// yielded item is reported to the list throttle and passed through.
    fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, OSResult<ObjectMeta>> {
        // The stream is `'static`, so it owns its own handle to the throttle.
        let throttle = Arc::clone(&self.list);
        self.target
            .list(prefix)
            .map(move |item| {
                throttle.observe_outcome(&item);
                item
            })
            .boxed()
    }

    /// Streaming list from an offset: handled the same way as `list`.
    fn list_with_offset(
        &self,
        prefix: Option<&Path>,
        offset: &Path,
    ) -> BoxStream<'static, OSResult<ObjectMeta>> {
        // The stream is `'static`, so it owns its own handle to the throttle.
        let throttle = Arc::clone(&self.list);
        self.target
            .list_with_offset(prefix, offset)
            .map(move |item| {
                throttle.observe_outcome(&item);
                item
            })
            .boxed()
    }

    /// List category.
    async fn list_with_delimiter(&self, prefix: Option<&Path>) -> OSResult<ListResult> {
        self.list
            .throttled(|| self.target.list_with_delimiter(prefix))
            .await
    }

    /// Write category.
    async fn copy(&self, from: &Path, to: &Path) -> OSResult<()> {
        self.write.throttled(|| self.target.copy(from, to)).await
    }

    /// Write category.
    async fn rename(&self, from: &Path, to: &Path) -> OSResult<()> {
        self.write.throttled(|| self.target.rename(from, to)).await
    }

    /// Write category.
    async fn rename_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> {
        self.write
            .throttled(|| self.target.rename_if_not_exists(from, to))
            .await
    }

    /// Write category.
    async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> {
        self.write
            .throttled(|| self.target.copy_if_not_exists(from, to))
            .await
    }
}
756
757#[cfg(test)]
758mod tests {
759    use super::*;
760    use object_store::memory::InMemory;
761    use rstest::rstest;
762    use std::collections::VecDeque;
763    use std::sync::atomic::{AtomicU64, Ordering};
764
765    fn make_generic_error(msg: &str) -> object_store::Error {
766        object_store::Error::Generic {
767            store: "test",
768            source: msg.into(),
769        }
770    }
771
772    #[rstest]
773    #[case::retry_error("Error after 10 retries, max_retries: 10, retry_timeout: 180s", true)]
774    #[case::retries_in_message(
775        "request failed, after 3 retries, max_retries: 5, retry_timeout: 60s",
776        true
777    )]
778    #[case::not_found("Object not found", false)]
779    #[case::permission_denied("Access denied", false)]
780    #[case::timeout("Connection timed out", false)]
781    #[case::http_429_without_retries("HTTP 429 Too Many Requests", false)]
782    #[case::slowdown_without_retries("SlowDown: Please reduce your request rate", false)]
783    fn test_is_throttle_error(#[case] msg: &str, #[case] expected: bool) {
784        let err = make_generic_error(msg);
785        assert_eq!(
786            is_throttle_error(&err),
787            expected,
788            "is_throttle_error for '{}' should be {}",
789            msg,
790            expected
791        );
792    }
793
794    #[test]
795    fn test_non_generic_errors_are_not_throttle() {
796        let err = object_store::Error::NotFound {
797            path: "test".to_string(),
798            source: "not found".into(),
799        };
800        assert!(!is_throttle_error(&err));
801    }
802
803    #[tokio::test]
804    async fn test_basic_put_get_through_wrapper() {
805        let store = Arc::new(InMemory::new());
806        let config = AimdThrottleConfig::default();
807        let throttled = AimdThrottledStore::new(store, config).unwrap();
808
809        let path = Path::from("test/file.txt");
810        let data = PutPayload::from_static(b"hello world");
811        throttled.put(&path, data).await.unwrap();
812
813        let result = throttled.get(&path).await.unwrap();
814        let bytes = result.bytes().await.unwrap();
815        assert_eq!(bytes.as_ref(), b"hello world");
816    }
817
818    #[tokio::test]
819    async fn test_rate_decreases_on_throttle() {
820        let store = Arc::new(InMemory::new());
821        let config = AimdThrottleConfig::default().with_aimd(
822            AimdConfig::default()
823                .with_initial_rate(100.0)
824                .with_decrease_factor(0.5)
825                .with_window_duration(std::time::Duration::from_millis(10)),
826        );
827        let throttled = AimdThrottledStore::new(store, config).unwrap();
828
829        let initial_rate = throttled.read.controller.current_rate();
830        assert_eq!(initial_rate, 100.0);
831
832        // Simulate a throttle outcome directly
833        throttled
834            .read
835            .controller
836            .record_outcome(RequestOutcome::Throttled);
837
838        // Wait for window to expire and trigger evaluation
839        tokio::time::sleep(std::time::Duration::from_millis(20)).await;
840        throttled
841            .read
842            .controller
843            .record_outcome(RequestOutcome::Success);
844
845        let new_rate = throttled.read.controller.current_rate();
846        assert!(
847            new_rate < initial_rate,
848            "Rate should decrease after throttle: {} < {}",
849            new_rate,
850            initial_rate
851        );
852    }
853
854    #[tokio::test]
855    async fn test_rate_recovers_on_success() {
856        let store = Arc::new(InMemory::new());
857        let config = AimdThrottleConfig::default().with_aimd(
858            AimdConfig::default()
859                .with_initial_rate(100.0)
860                .with_decrease_factor(0.5)
861                .with_additive_increment(10.0)
862                .with_window_duration(std::time::Duration::from_millis(10)),
863        );
864        let throttled = AimdThrottledStore::new(store, config).unwrap();
865
866        // First decrease via throttle
867        throttled
868            .read
869            .controller
870            .record_outcome(RequestOutcome::Throttled);
871        tokio::time::sleep(std::time::Duration::from_millis(20)).await;
872        throttled
873            .read
874            .controller
875            .record_outcome(RequestOutcome::Success);
876        let decreased_rate = throttled.read.controller.current_rate();
877        assert_eq!(decreased_rate, 50.0);
878
879        // Now recover via success
880        tokio::time::sleep(std::time::Duration::from_millis(20)).await;
881        throttled
882            .read
883            .controller
884            .record_outcome(RequestOutcome::Success);
885        let recovered_rate = throttled.read.controller.current_rate();
886        assert_eq!(recovered_rate, 60.0);
887    }
888
889    #[tokio::test]
890    async fn test_as_dyn_object_store() {
891        let store: Arc<dyn ObjectStore> = Arc::new(InMemory::new());
892        let throttled: Arc<dyn ObjectStore> =
893            Arc::new(AimdThrottledStore::new(store, AimdThrottleConfig::default()).unwrap());
894
895        let path = Path::from("test/data.bin");
896        let data = PutPayload::from_static(b"test data");
897        throttled.put(&path, data).await.unwrap();
898
899        let result = throttled.get(&path).await.unwrap();
900        let bytes = result.bytes().await.unwrap();
901        assert_eq!(bytes.as_ref(), b"test data");
902    }
903
904    #[tokio::test]
905    async fn test_token_bucket_delays_when_exhausted() {
906        let store = Arc::new(InMemory::new());
907        // Very low rate and burst capacity to force waiting
908        let config = AimdThrottleConfig::default()
909            .with_burst_capacity(1)
910            .with_aimd(AimdConfig::default().with_initial_rate(10.0));
911        let throttled = Arc::new(AimdThrottledStore::new(store, config).unwrap());
912
913        let path = Path::from("test/file.txt");
914        let data = PutPayload::from_static(b"data");
915        throttled.put(&path, data).await.unwrap();
916
917        // After consuming the burst token, the next request should take ~100ms
918        // (1 token / 10 tokens-per-sec). We verify it takes at least 50ms.
919        let start = std::time::Instant::now();
920        let data2 = PutPayload::from_static(b"data2");
921        throttled.put(&path, data2).await.unwrap();
922        let elapsed = start.elapsed();
923
924        assert!(
925            elapsed >= std::time::Duration::from_millis(50),
926            "Expected delay for token refill, but elapsed was {:?}",
927            elapsed
928        );
929    }
930
931    #[tokio::test]
932    async fn test_list_observes_outcomes() {
933        let store = Arc::new(InMemory::new());
934        let config = AimdThrottleConfig::default();
935        let throttled = AimdThrottledStore::new(store.clone(), config).unwrap();
936
937        let path = Path::from("prefix/file.txt");
938        let data = PutPayload::from_static(b"data");
939        store.put(&path, data).await.unwrap();
940
941        let items: Vec<_> = throttled.list(Some(&Path::from("prefix"))).collect().await;
942        assert_eq!(items.len(), 1);
943        assert!(items[0].is_ok());
944    }
945
    /// A mock store whose `list` stream yields `throttle_count` throttle-error
    /// items followed by the real entries of the wrapped in-memory store.
    /// Used to verify that the AIMD wrapper observes errors surfaced inside
    /// list streams.
    struct ThrottlingListMockStore {
        /// Backing in-memory store that holds the actual objects.
        inner: InMemory,
        /// Number of throttle errors to inject at the start of each list call.
        throttle_count: usize,
    }
954
955    impl Display for ThrottlingListMockStore {
956        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
957            write!(f, "ThrottlingListMockStore")
958        }
959    }
960
961    impl Debug for ThrottlingListMockStore {
962        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
963            f.debug_struct("ThrottlingListMockStore").finish()
964        }
965    }
966
967    #[async_trait]
968    impl ObjectStore for ThrottlingListMockStore {
969        async fn put(&self, location: &Path, bytes: PutPayload) -> OSResult<PutResult> {
970            self.inner.put(location, bytes).await
971        }
972        async fn put_opts(
973            &self,
974            location: &Path,
975            bytes: PutPayload,
976            opts: PutOptions,
977        ) -> OSResult<PutResult> {
978            self.inner.put_opts(location, bytes, opts).await
979        }
980        async fn put_multipart(&self, location: &Path) -> OSResult<Box<dyn MultipartUpload>> {
981            self.inner.put_multipart(location).await
982        }
983        async fn put_multipart_opts(
984            &self,
985            location: &Path,
986            opts: PutMultipartOptions,
987        ) -> OSResult<Box<dyn MultipartUpload>> {
988            self.inner.put_multipart_opts(location, opts).await
989        }
990        async fn get(&self, location: &Path) -> OSResult<GetResult> {
991            self.inner.get(location).await
992        }
993        async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult<GetResult> {
994            self.inner.get_opts(location, options).await
995        }
996        async fn get_range(&self, location: &Path, range: Range<u64>) -> OSResult<Bytes> {
997            self.inner.get_range(location, range).await
998        }
999        async fn get_ranges(&self, location: &Path, ranges: &[Range<u64>]) -> OSResult<Vec<Bytes>> {
1000            self.inner.get_ranges(location, ranges).await
1001        }
1002        async fn head(&self, location: &Path) -> OSResult<ObjectMeta> {
1003            self.inner.head(location).await
1004        }
1005        async fn delete(&self, location: &Path) -> OSResult<()> {
1006            self.inner.delete(location).await
1007        }
1008        fn delete_stream<'a>(
1009            &'a self,
1010            locations: BoxStream<'a, OSResult<Path>>,
1011        ) -> BoxStream<'a, OSResult<Path>> {
1012            self.inner.delete_stream(locations)
1013        }
1014        fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, OSResult<ObjectMeta>> {
1015            let n = self.throttle_count;
1016            let inner_stream = self.inner.list(prefix);
1017            let errors = futures::stream::iter((0..n).map(|_| {
1018                Err(object_store::Error::Generic {
1019                    store: "ThrottlingListMock",
1020                    source: "request failed, after 3 retries, max_retries: 5, retry_timeout: 60s"
1021                        .into(),
1022                })
1023            }));
1024            errors.chain(inner_stream).boxed()
1025        }
1026        fn list_with_offset(
1027            &self,
1028            prefix: Option<&Path>,
1029            offset: &Path,
1030        ) -> BoxStream<'static, OSResult<ObjectMeta>> {
1031            self.inner.list_with_offset(prefix, offset)
1032        }
1033        async fn list_with_delimiter(&self, prefix: Option<&Path>) -> OSResult<ListResult> {
1034            self.inner.list_with_delimiter(prefix).await
1035        }
1036        async fn copy(&self, from: &Path, to: &Path) -> OSResult<()> {
1037            self.inner.copy(from, to).await
1038        }
1039        async fn rename(&self, from: &Path, to: &Path) -> OSResult<()> {
1040            self.inner.rename(from, to).await
1041        }
1042        async fn rename_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> {
1043            self.inner.rename_if_not_exists(from, to).await
1044        }
1045        async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> {
1046            self.inner.copy_if_not_exists(from, to).await
1047        }
1048    }
1049
1050    #[tokio::test]
1051    async fn test_list_stream_throttle_errors_decrease_rate() {
1052        let mock = Arc::new(ThrottlingListMockStore {
1053            inner: InMemory::new(),
1054            throttle_count: 5,
1055        });
1056
1057        // Seed a file so the real items come through after the errors.
1058        mock.put(
1059            &Path::from("prefix/file.txt"),
1060            PutPayload::from_static(b"data"),
1061        )
1062        .await
1063        .unwrap();
1064
1065        let config = AimdThrottleConfig::default().with_list_aimd(
1066            AimdConfig::default()
1067                .with_initial_rate(100.0)
1068                .with_decrease_factor(0.5)
1069                .with_window_duration(std::time::Duration::from_millis(10)),
1070        );
1071        let throttled = AimdThrottledStore::new(mock as Arc<dyn ObjectStore>, config).unwrap();
1072
1073        let initial_rate = throttled.list.controller.current_rate();
1074        assert_eq!(initial_rate, 100.0);
1075
1076        let items: Vec<_> = throttled.list(Some(&Path::from("prefix"))).collect().await;
1077
1078        // 5 errors + 1 real item
1079        assert_eq!(items.len(), 6);
1080        assert!(items[0].is_err());
1081        assert!(items[5].is_ok());
1082
1083        // Wait for the AIMD window to expire and trigger evaluation.
1084        tokio::time::sleep(std::time::Duration::from_millis(20)).await;
1085        throttled
1086            .list
1087            .controller
1088            .record_outcome(RequestOutcome::Success);
1089
1090        let new_rate = throttled.list.controller.current_rate();
1091        assert!(
1092            new_rate < initial_rate,
1093            "List rate should decrease after stream throttle errors: {} < {}",
1094            new_rate,
1095            initial_rate
1096        );
1097    }
1098
1099    #[tokio::test]
1100    async fn test_per_category_independence() {
1101        let store = Arc::new(InMemory::new());
1102        let config = AimdThrottleConfig::default().with_aimd(
1103            AimdConfig::default()
1104                .with_initial_rate(100.0)
1105                .with_decrease_factor(0.5)
1106                .with_window_duration(std::time::Duration::from_millis(10)),
1107        );
1108        let throttled = AimdThrottledStore::new(store, config).unwrap();
1109
1110        // Push the read controller into a throttled state
1111        throttled
1112            .read
1113            .controller
1114            .record_outcome(RequestOutcome::Throttled);
1115        tokio::time::sleep(std::time::Duration::from_millis(20)).await;
1116        throttled
1117            .read
1118            .controller
1119            .record_outcome(RequestOutcome::Success);
1120
1121        let read_rate = throttled.read.controller.current_rate();
1122        let write_rate = throttled.write.controller.current_rate();
1123        let delete_rate = throttled.delete.controller.current_rate();
1124        let list_rate = throttled.list.controller.current_rate();
1125
1126        assert_eq!(read_rate, 50.0, "Read rate should have decreased");
1127        assert_eq!(write_rate, 100.0, "Write rate should be unaffected");
1128        assert_eq!(delete_rate, 100.0, "Delete rate should be unaffected");
1129        assert_eq!(list_rate, 100.0, "List rate should be unaffected");
1130    }
1131
1132    #[tokio::test]
1133    async fn test_per_category_config() {
1134        let store = Arc::new(InMemory::new());
1135        let config = AimdThrottleConfig::default()
1136            .with_read_aimd(AimdConfig::default().with_initial_rate(200.0))
1137            .with_write_aimd(AimdConfig::default().with_initial_rate(100.0))
1138            .with_delete_aimd(AimdConfig::default().with_initial_rate(50.0))
1139            .with_list_aimd(AimdConfig::default().with_initial_rate(25.0));
1140        let throttled = AimdThrottledStore::new(store, config).unwrap();
1141
1142        assert_eq!(throttled.read.controller.current_rate(), 200.0);
1143        assert_eq!(throttled.write.controller.current_rate(), 100.0);
1144        assert_eq!(throttled.delete.controller.current_rate(), 50.0);
1145        assert_eq!(throttled.list.controller.current_rate(), 25.0);
1146    }
1147
    /// A mock [`ObjectStore`] that measures request rate over a sliding window
    /// and returns throttle-style errors when the rate exceeds a configurable
    /// threshold.
    ///
    /// Only the read-side calls `get`, `get_opts`, `get_range`, `get_ranges`
    /// and `head` are rate-limited. Writes, deletes, listings and copy/rename
    /// operations are forwarded to the in-memory store without any check.
    struct RateLimitingMockStore {
        /// Backing in-memory store that holds the actual objects.
        inner: InMemory,
        /// Timestamps of recent successful (admitted) requests.
        timestamps: std::sync::Mutex<VecDeque<std::time::Instant>>,
        /// Maximum requests allowed within `window`.
        max_per_window: usize,
        /// Sliding window duration.
        window: std::time::Duration,
        /// Count of rate-checked requests that were admitted.
        success_count: AtomicU64,
        /// Count of rate-checked requests that were rejected.
        throttle_count: AtomicU64,
    }
1162
1163    impl RateLimitingMockStore {
1164        fn new(max_per_window: usize, window: std::time::Duration) -> Self {
1165            Self {
1166                inner: InMemory::new(),
1167                timestamps: std::sync::Mutex::new(VecDeque::new()),
1168                max_per_window,
1169                window,
1170                success_count: AtomicU64::new(0),
1171                throttle_count: AtomicU64::new(0),
1172            }
1173        }
1174
1175        /// Returns `true` if the request is admitted, `false` if throttled.
1176        fn check_rate(&self) -> bool {
1177            let mut ts = self.timestamps.lock().unwrap();
1178            let now = std::time::Instant::now();
1179            while let Some(&front) = ts.front() {
1180                if now.duration_since(front) > self.window {
1181                    ts.pop_front();
1182                } else {
1183                    break;
1184                }
1185            }
1186            if ts.len() >= self.max_per_window {
1187                self.throttle_count.fetch_add(1, Ordering::Relaxed);
1188                false
1189            } else {
1190                ts.push_back(now);
1191                self.success_count.fetch_add(1, Ordering::Relaxed);
1192                true
1193            }
1194        }
1195
1196        fn throttle_error() -> object_store::Error {
1197            object_store::Error::Generic {
1198                store: "RateLimitingMock",
1199                source: "request failed, after 10 retries, max_retries: 10, retry_timeout: 180s"
1200                    .into(),
1201            }
1202        }
1203    }
1204
1205    impl Display for RateLimitingMockStore {
1206        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1207            write!(f, "RateLimitingMockStore")
1208        }
1209    }
1210
1211    impl Debug for RateLimitingMockStore {
1212        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1213            f.debug_struct("RateLimitingMockStore").finish()
1214        }
1215    }
1216
1217    #[async_trait]
1218    impl ObjectStore for RateLimitingMockStore {
1219        async fn put(&self, location: &Path, bytes: PutPayload) -> OSResult<PutResult> {
1220            self.inner.put(location, bytes).await
1221        }
1222
1223        async fn put_opts(
1224            &self,
1225            location: &Path,
1226            bytes: PutPayload,
1227            opts: PutOptions,
1228        ) -> OSResult<PutResult> {
1229            self.inner.put_opts(location, bytes, opts).await
1230        }
1231
1232        async fn put_multipart(&self, location: &Path) -> OSResult<Box<dyn MultipartUpload>> {
1233            self.inner.put_multipart(location).await
1234        }
1235
1236        async fn put_multipart_opts(
1237            &self,
1238            location: &Path,
1239            opts: PutMultipartOptions,
1240        ) -> OSResult<Box<dyn MultipartUpload>> {
1241            self.inner.put_multipart_opts(location, opts).await
1242        }
1243
1244        async fn get(&self, location: &Path) -> OSResult<GetResult> {
1245            if self.check_rate() {
1246                self.inner.get(location).await
1247            } else {
1248                Err(Self::throttle_error())
1249            }
1250        }
1251
1252        async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult<GetResult> {
1253            if self.check_rate() {
1254                self.inner.get_opts(location, options).await
1255            } else {
1256                Err(Self::throttle_error())
1257            }
1258        }
1259
1260        async fn get_range(&self, location: &Path, range: Range<u64>) -> OSResult<Bytes> {
1261            if self.check_rate() {
1262                self.inner.get_range(location, range).await
1263            } else {
1264                Err(Self::throttle_error())
1265            }
1266        }
1267
1268        async fn get_ranges(&self, location: &Path, ranges: &[Range<u64>]) -> OSResult<Vec<Bytes>> {
1269            if self.check_rate() {
1270                self.inner.get_ranges(location, ranges).await
1271            } else {
1272                Err(Self::throttle_error())
1273            }
1274        }
1275
1276        async fn head(&self, location: &Path) -> OSResult<ObjectMeta> {
1277            if self.check_rate() {
1278                self.inner.head(location).await
1279            } else {
1280                Err(Self::throttle_error())
1281            }
1282        }
1283
1284        async fn delete(&self, location: &Path) -> OSResult<()> {
1285            self.inner.delete(location).await
1286        }
1287
1288        fn delete_stream<'a>(
1289            &'a self,
1290            locations: BoxStream<'a, OSResult<Path>>,
1291        ) -> BoxStream<'a, OSResult<Path>> {
1292            self.inner.delete_stream(locations)
1293        }
1294
1295        fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, OSResult<ObjectMeta>> {
1296            self.inner.list(prefix)
1297        }
1298
1299        fn list_with_offset(
1300            &self,
1301            prefix: Option<&Path>,
1302            offset: &Path,
1303        ) -> BoxStream<'static, OSResult<ObjectMeta>> {
1304            self.inner.list_with_offset(prefix, offset)
1305        }
1306
1307        async fn list_with_delimiter(&self, prefix: Option<&Path>) -> OSResult<ListResult> {
1308            self.inner.list_with_delimiter(prefix).await
1309        }
1310
1311        async fn copy(&self, from: &Path, to: &Path) -> OSResult<()> {
1312            self.inner.copy(from, to).await
1313        }
1314
1315        async fn rename(&self, from: &Path, to: &Path) -> OSResult<()> {
1316            self.inner.rename(from, to).await
1317        }
1318
1319        async fn rename_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> {
1320            self.inner.rename_if_not_exists(from, to).await
1321        }
1322
1323        async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> {
1324            self.inner.copy_if_not_exists(from, to).await
1325        }
1326    }
1327
1328    /// Verify that multiple concurrent readers sharing an AIMD-throttled store
1329    /// converge to the backend's actual capacity.
1330    ///
1331    /// Setup:
1332    /// - Mock backend allows 30 requests per 100ms (= 300 req/s).
1333    /// - 5 reader tasks, each with their own [`AimdThrottledStore`] wrapping
1334    ///   the shared mock.
1335    /// - AIMD: 100ms window, initial rate 100 req/s, decrease 0.5, increase 2.
1336    /// - Readers issue `head()` requests as fast as the throttle allows for 2s.
1337    ///
1338    /// Expected behaviour:
1339    /// - Initial burst (100 burst tokens × 5 readers) overshoots the mock
1340    ///   capacity, causing many 503s. Each reader's AIMD halves its rate.
1341    /// - After the transient, each reader converges to ~60 req/s (300/5).
1342    /// - Over 2 seconds, total successful requests should be in the range
1343    ///   [300, 900] (theoretical max ≈ 600).
1344    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
1345    async fn test_aimd_throttle_under_concurrent_load() {
1346        let mock = Arc::new(RateLimitingMockStore::new(
1347            30,
1348            std::time::Duration::from_millis(100),
1349        ));
1350
1351        // Seed a test file so head() succeeds when admitted.
1352        let path = Path::from("test/data.bin");
1353        mock.put(&path, PutPayload::from_static(b"test data"))
1354            .await
1355            .unwrap();
1356
1357        let aimd = AimdConfig::default()
1358            .with_initial_rate(100.0)
1359            .with_decrease_factor(0.5)
1360            .with_additive_increment(2.0)
1361            .with_window_duration(std::time::Duration::from_millis(100));
1362        let throttle_config = AimdThrottleConfig::default()
1363            .with_aimd(aimd)
1364            .with_burst_capacity(100);
1365
1366        let num_readers = 5;
1367        let test_duration = std::time::Duration::from_secs(2);
1368        let mut handles = Vec::new();
1369
1370        for _ in 0..num_readers {
1371            let store = Arc::new(
1372                AimdThrottledStore::new(
1373                    mock.clone() as Arc<dyn ObjectStore>,
1374                    throttle_config.clone(),
1375                )
1376                .unwrap(),
1377            );
1378            let p = path.clone();
1379            handles.push(tokio::spawn(async move {
1380                let deadline = std::time::Instant::now() + test_duration;
1381                let mut count = 0u64;
1382                while std::time::Instant::now() < deadline {
1383                    let _ = store.head(&p).await;
1384                    count += 1;
1385                }
1386                count
1387            }));
1388        }
1389
1390        let mut total_reader_requests = 0u64;
1391        for handle in handles {
1392            total_reader_requests += handle.await.unwrap();
1393        }
1394
1395        let successes = mock.success_count.load(Ordering::Relaxed);
1396        let throttled = mock.throttle_count.load(Ordering::Relaxed);
1397        let total_mock = successes + throttled;
1398
1399        // Mock-side count >= reader-side count because the AIMD layer retries
1400        // throttle errors internally, causing multiple mock calls per reader call.
1401        assert!(
1402            total_mock >= total_reader_requests,
1403            "Mock-side count ({total_mock}) should be >= reader-side count ({total_reader_requests})"
1404        );
1405
1406        // Mock capacity is 30/100ms = 300 req/s. Over 2s the theoretical max is
1407        // ~600 successful requests. With AIMD ramp-up, expect somewhat fewer.
1408        assert!(
1409            successes >= 300,
1410            "Expected >= 300 successful requests over 2s, got {successes}"
1411        );
1412        assert!(
1413            successes <= 900,
1414            "Expected <= 900 successful requests, got {successes}"
1415        );
1416
1417        // The initial burst exceeds mock capacity, so throttling must occur.
1418        assert!(throttled > 0, "Expected some throttled requests but got 0");
1419
1420        // Without AIMD, raw tokio tasks against InMemory would fire 100k+ req/s.
1421        // AIMD should keep the total well under 5000 over 2s.
1422        assert!(
1423            total_mock <= 5000,
1424            "AIMD should limit total requests, got {total_mock}"
1425        );
1426    }
1427
    /// A mock store that returns a configurable number of throttle errors
    /// before succeeding on `get` operations. Used to test the retry logic
    /// inside `OperationThrottle::throttled()`.
    ///
    /// Only `get` injects errors; every other operation is forwarded to the
    /// in-memory store unchanged.
    struct RetryTestMockStore {
        /// Backing in-memory store that holds the actual objects.
        inner: InMemory,
        /// Number of throttle errors remaining before success.
        errors_remaining: std::sync::Mutex<usize>,
        /// Total number of `get` calls observed.
        get_call_count: AtomicU64,
    }
1438
1439    impl RetryTestMockStore {
1440        fn new(errors_before_success: usize) -> Self {
1441            Self {
1442                inner: InMemory::new(),
1443                errors_remaining: std::sync::Mutex::new(errors_before_success),
1444                get_call_count: AtomicU64::new(0),
1445            }
1446        }
1447    }
1448
1449    impl Display for RetryTestMockStore {
1450        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1451            write!(f, "RetryTestMockStore")
1452        }
1453    }
1454
1455    impl Debug for RetryTestMockStore {
1456        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1457            f.debug_struct("RetryTestMockStore").finish()
1458        }
1459    }
1460
1461    #[async_trait]
1462    impl ObjectStore for RetryTestMockStore {
1463        async fn put(&self, location: &Path, bytes: PutPayload) -> OSResult<PutResult> {
1464            self.inner.put(location, bytes).await
1465        }
1466        async fn put_opts(
1467            &self,
1468            location: &Path,
1469            bytes: PutPayload,
1470            opts: PutOptions,
1471        ) -> OSResult<PutResult> {
1472            self.inner.put_opts(location, bytes, opts).await
1473        }
1474        async fn put_multipart(&self, location: &Path) -> OSResult<Box<dyn MultipartUpload>> {
1475            self.inner.put_multipart(location).await
1476        }
1477        async fn put_multipart_opts(
1478            &self,
1479            location: &Path,
1480            opts: PutMultipartOptions,
1481        ) -> OSResult<Box<dyn MultipartUpload>> {
1482            self.inner.put_multipart_opts(location, opts).await
1483        }
1484        async fn get(&self, location: &Path) -> OSResult<GetResult> {
1485            self.get_call_count.fetch_add(1, Ordering::Relaxed);
1486            let should_error = {
1487                let mut remaining = self.errors_remaining.lock().unwrap();
1488                if *remaining > 0 {
1489                    *remaining -= 1;
1490                    true
1491                } else {
1492                    false
1493                }
1494            };
1495            if should_error {
1496                Err(object_store::Error::Generic {
1497                    store: "RetryTestMock",
1498                    source: "request failed, after 3 retries, max_retries: 3, retry_timeout: 30s"
1499                        .into(),
1500                })
1501            } else {
1502                self.inner.get(location).await
1503            }
1504        }
1505        async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult<GetResult> {
1506            self.inner.get_opts(location, options).await
1507        }
1508        async fn get_range(&self, location: &Path, range: Range<u64>) -> OSResult<Bytes> {
1509            self.inner.get_range(location, range).await
1510        }
1511        async fn get_ranges(&self, location: &Path, ranges: &[Range<u64>]) -> OSResult<Vec<Bytes>> {
1512            self.inner.get_ranges(location, ranges).await
1513        }
1514        async fn head(&self, location: &Path) -> OSResult<ObjectMeta> {
1515            self.inner.head(location).await
1516        }
1517        async fn delete(&self, location: &Path) -> OSResult<()> {
1518            self.inner.delete(location).await
1519        }
1520        fn delete_stream<'a>(
1521            &'a self,
1522            locations: BoxStream<'a, OSResult<Path>>,
1523        ) -> BoxStream<'a, OSResult<Path>> {
1524            self.inner.delete_stream(locations)
1525        }
1526        fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, OSResult<ObjectMeta>> {
1527            self.inner.list(prefix)
1528        }
1529        fn list_with_offset(
1530            &self,
1531            prefix: Option<&Path>,
1532            offset: &Path,
1533        ) -> BoxStream<'static, OSResult<ObjectMeta>> {
1534            self.inner.list_with_offset(prefix, offset)
1535        }
1536        async fn list_with_delimiter(&self, prefix: Option<&Path>) -> OSResult<ListResult> {
1537            self.inner.list_with_delimiter(prefix).await
1538        }
1539        async fn copy(&self, from: &Path, to: &Path) -> OSResult<()> {
1540            self.inner.copy(from, to).await
1541        }
1542        async fn rename(&self, from: &Path, to: &Path) -> OSResult<()> {
1543            self.inner.rename(from, to).await
1544        }
1545        async fn rename_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> {
1546            self.inner.rename_if_not_exists(from, to).await
1547        }
1548        async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> {
1549            self.inner.copy_if_not_exists(from, to).await
1550        }
1551    }
1552
1553    #[tokio::test]
1554    async fn test_throttled_retries_on_throttle_error_then_succeeds() {
1555        // Mock returns 2 throttle errors then succeeds (within MAX_RETRIES=3)
1556        let mock = Arc::new(RetryTestMockStore::new(2));
1557        let path = Path::from("test/retry.txt");
1558        mock.put(&path, PutPayload::from_static(b"retry data"))
1559            .await
1560            .unwrap();
1561
1562        let config = AimdThrottleConfig::default();
1563        let throttled =
1564            AimdThrottledStore::new(mock.clone() as Arc<dyn ObjectStore>, config).unwrap();
1565
1566        let result = throttled.get(&path).await;
1567        assert!(result.is_ok(), "Expected success after retries");
1568
1569        let bytes = result.unwrap().bytes().await.unwrap();
1570        assert_eq!(bytes.as_ref(), b"retry data");
1571
1572        // Should have called get 3 times total: 2 failures + 1 success
1573        assert_eq!(mock.get_call_count.load(Ordering::Relaxed), 3);
1574    }
1575
1576    #[tokio::test]
1577    async fn test_throttled_fails_after_max_retries_exceeded() {
1578        // Mock returns 4 throttle errors (more than MAX_RETRIES=3),
1579        // so all 4 attempts (initial + 3 retries) will fail.
1580        let mock = Arc::new(RetryTestMockStore::new(10));
1581        let path = Path::from("test/fail.txt");
1582        mock.put(&path, PutPayload::from_static(b"fail data"))
1583            .await
1584            .unwrap();
1585
1586        let config = AimdThrottleConfig::default();
1587        let throttled =
1588            AimdThrottledStore::new(mock.clone() as Arc<dyn ObjectStore>, config).unwrap();
1589
1590        let result = throttled.get(&path).await;
1591        assert!(result.is_err(), "Expected error after max retries");
1592        assert!(is_throttle_error(&result.unwrap_err()));
1593
1594        // Should have called get 4 times: initial attempt + 3 retries
1595        assert_eq!(mock.get_call_count.load(Ordering::Relaxed), 4);
1596    }
1597
1598    #[tokio::test]
1599    async fn test_throttled_multipart_reorders_parts() {
1600        let store = Arc::new(InMemory::new()) as Arc<dyn ObjectStore>;
1601        let config = AimdThrottleConfig::default();
1602        let throttled = AimdThrottledStore::new(store.clone(), config).unwrap();
1603
1604        let path = Path::from("test/multipart_ordering.bin");
1605        let mut upload = throttled.put_multipart(&path).await.unwrap();
1606
1607        // Create futures for two parts in order: A then B.
1608        let fut_a = upload.put_part(PutPayload::from_static(b"AAAA"));
1609        let fut_b = upload.put_part(PutPayload::from_static(b"BBBB"));
1610
1611        // Await in REVERSE order. Part ordering should be determined by
1612        // creation order (put_part call order), not by await order.
1613        fut_b.await.unwrap();
1614        fut_a.await.unwrap();
1615
1616        upload.complete().await.unwrap();
1617
1618        let result = store.get(&path).await.unwrap();
1619        let bytes = result.bytes().await.unwrap();
1620
1621        assert_eq!(
1622            bytes.as_ref(),
1623            b"AAAABBBB",
1624            "Parts were reordered! Got {:?} instead of AAAABBBB.",
1625            std::str::from_utf8(&bytes).unwrap_or("<non-utf8>"),
1626        );
1627    }
1628}