Skip to main content

lance_io/object_store/
throttle.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! AIMD-controlled token bucket rate limiter for ObjectStore operations.
5//!
6//! Wraps any [`object_store::ObjectStore`] with per-category token buckets
7//! whose fill rates are dynamically adjusted by AIMD controllers. When cloud
8//! stores return HTTP 429/503, the fill rate decreases multiplicatively. During
9//! sustained success windows, it increases additively.
10//!
11//! Operations are split into four independent categories — **read**, **write**,
12//! **delete**, **list** — each with its own AIMD controller and token bucket.
13//! This prevents a burst of reads from starving writes, and vice versa.
14//!
15//! # Example
16//!
17//! ```ignore
18//! use lance_io::object_store::throttle::{AimdThrottleConfig, AimdThrottledStore};
19//!
20//! let throttled = AimdThrottledStore::new(target, AimdThrottleConfig::default()).unwrap();
21//! ```
22
23use std::collections::HashMap;
24use std::fmt::{Debug, Display, Formatter};
25use std::ops::Range;
26use std::sync::Arc;
27
28use async_trait::async_trait;
29use bytes::Bytes;
30use futures::StreamExt;
31use futures::stream::BoxStream;
32use lance_core::utils::aimd::{AimdConfig, AimdController, RequestOutcome};
33use lance_core::utils::tracing::TRACE_OBJECT_STORE_THROTTLE;
34#[cfg(test)]
35use object_store::ObjectStoreExt;
36use object_store::path::Path;
37use object_store::{
38    CopyOptions, GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta, ObjectStore,
39    PutMultipartOptions, PutOptions, PutPayload, PutResult, RenameOptions, Result as OSResult,
40    UploadPart,
41};
42use rand::Rng;
43use tokio::sync::Mutex;
44use tracing::{debug, warn};
45
46/// Check whether an `object_store::Error` represents a throttle response
47/// (HTTP 429 / 503) from a cloud object store.
48///
49/// Regrettably, this information is not fully exposed by the `object_store` crate.
50/// There is no generic mechanism for a custom object store to return a throttle error.
51///
52/// However, the builtin object stores all use RetryError when retries are configured and
53/// throttle errors are returned.  Sadly, RetryError is not a public type, so we have to
54/// infer it from the error message.  This is potentially dangerous because these errors
55/// often include the URI itself and that URI could have any characters in it (e.g. if we
56/// look for 429 then we might match a 429 in a UUID).These error messages currently look like:
57///
58/// ", after ... retries, max_retries: ..., retry_timeout: ..."
59///
60/// So, as a crude heuristic, which should work for the builtin object stores, but won't
61/// work for custom object stores, we simply look for the string "retries, max_retries"
62/// in the error message.
63pub fn is_throttle_error(err: &object_store::Error) -> bool {
64    // Only Generic errors can carry throttle responses
65    if let object_store::Error::Generic { source, .. } = err {
66        let message = source.to_string();
67        let lowercase = message.to_ascii_lowercase();
68        lowercase.contains("retries, max_retries")
69            || lowercase.contains("serverbusy")
70            || lowercase.contains("server busy")
71            || lowercase.contains("egress is over the account limit")
72            || lowercase.contains("http 429")
73            || lowercase.contains("status code: 429")
74            || lowercase.contains("429 too many requests")
75            || lowercase.contains("too many requests")
76            || lowercase.contains("slowdown")
77            || lowercase.contains("please reduce your request rate")
78            || lowercase.contains("rate limit")
79            || lowercase.contains("throttling")
80            || lowercase.contains("throttled")
81    } else {
82        false
83    }
84}
85
86/// Configuration for the AIMD-throttled ObjectStore wrapper.
87///
88/// Each operation category (read, write, delete, list) has its own AIMD config.
89/// Use [`with_aimd`](AimdThrottleConfig::with_aimd) to set all categories at
90/// once, or per-category methods like [`with_read_aimd`](AimdThrottleConfig::with_read_aimd)
91/// for fine-grained control.
92#[derive(Debug, Clone)]
93pub struct AimdThrottleConfig {
94    /// AIMD configuration for read operations (get, get_opts, get_range, get_ranges, head).
95    pub read: AimdConfig,
96    /// AIMD configuration for write operations (put, put_opts, put_multipart, copy, rename, etc.).
97    pub write: AimdConfig,
98    /// AIMD configuration for delete operations.
99    pub delete: AimdConfig,
100    /// AIMD configuration for list operations.
101    pub list: AimdConfig,
102    /// Maximum tokens that can accumulate for bursts (shared across all categories).
103    pub burst_capacity: u32,
104    /// Maximum number of retries for throttle errors within the AIMD layer.
105    pub max_retries: usize,
106    /// Minimum backoff in milliseconds between retry attempts.
107    pub min_backoff_ms: u64,
108    /// Maximum backoff in milliseconds between retry attempts.
109    pub max_backoff_ms: u64,
110}
111
112impl Default for AimdThrottleConfig {
113    fn default() -> Self {
114        let aimd = AimdConfig::default();
115        Self {
116            read: aimd.clone(),
117            write: aimd.clone(),
118            delete: aimd.clone(),
119            list: aimd,
120            burst_capacity: 100,
121            max_retries: 3,
122            min_backoff_ms: 100,
123            max_backoff_ms: 300,
124        }
125    }
126}
127
128impl AimdThrottleConfig {
129    /// Set the AIMD configuration for all four operation categories at once.
130    pub fn with_aimd(self, aimd: AimdConfig) -> Self {
131        Self {
132            read: aimd.clone(),
133            write: aimd.clone(),
134            delete: aimd.clone(),
135            list: aimd,
136            ..self
137        }
138    }
139
140    /// Set the AIMD configuration for read operations.
141    pub fn with_read_aimd(self, aimd: AimdConfig) -> Self {
142        Self { read: aimd, ..self }
143    }
144
145    /// Set the AIMD configuration for write operations.
146    pub fn with_write_aimd(self, aimd: AimdConfig) -> Self {
147        Self {
148            write: aimd,
149            ..self
150        }
151    }
152
153    /// Set the AIMD configuration for delete operations.
154    pub fn with_delete_aimd(self, aimd: AimdConfig) -> Self {
155        Self {
156            delete: aimd,
157            ..self
158        }
159    }
160
161    /// Set the AIMD configuration for list operations.
162    pub fn with_list_aimd(self, aimd: AimdConfig) -> Self {
163        Self { list: aimd, ..self }
164    }
165
166    /// Returns `true` when the AIMD throttle layer should be bypassed entirely.
167    pub fn is_disabled(&self) -> bool {
168        self.max_retries == 0
169    }
170
171    pub fn with_burst_capacity(self, burst_capacity: u32) -> Self {
172        Self {
173            burst_capacity,
174            ..self
175        }
176    }
177
178    /// Build an `AimdThrottleConfig` from storage options and environment variables.
179    ///
180    /// Storage options take precedence over environment variables, which take
181    /// precedence over defaults. A single AIMD config is applied to all four
182    /// operation categories (read/write/delete/list).
183    ///
184    /// | Setting              | Storage Option Key               | Env Var                          | Default |
185    /// |----------------------|----------------------------------|----------------------------------|---------|
186    /// | Initial rate         | `lance_aimd_initial_rate`        | `LANCE_AIMD_INITIAL_RATE`        | 2000    |
187    /// | Min rate             | `lance_aimd_min_rate`            | `LANCE_AIMD_MIN_RATE`            | 1       |
188    /// | Max rate             | `lance_aimd_max_rate`            | `LANCE_AIMD_MAX_RATE`            | 5000    |
189    /// | Decrease factor      | `lance_aimd_decrease_factor`     | `LANCE_AIMD_DECREASE_FACTOR`     | 0.5     |
190    /// | Additive increment   | `lance_aimd_additive_increment`  | `LANCE_AIMD_ADDITIVE_INCREMENT`  | 300     |
191    /// | Burst capacity       | `lance_aimd_burst_capacity`      | `LANCE_AIMD_BURST_CAPACITY`      | 100     |
192    /// | Max retries          | `lance_aimd_max_retries`         | `LANCE_AIMD_MAX_RETRIES`         | 3       |
193    /// | Min backoff ms       | `lance_aimd_min_backoff_ms`      | `LANCE_AIMD_MIN_BACKOFF_MS`      | 100     |
194    /// | Max backoff ms       | `lance_aimd_max_backoff_ms`      | `LANCE_AIMD_MAX_BACKOFF_MS`      | 300     |
195    pub fn from_storage_options(
196        storage_options: Option<&HashMap<String, String>>,
197    ) -> lance_core::Result<Self> {
198        fn resolve_f64(
199            key: &str,
200            storage_options: Option<&HashMap<String, String>>,
201            default: f64,
202        ) -> lance_core::Result<f64> {
203            let env_key = key.to_ascii_uppercase();
204            if let Some(val) = storage_options.and_then(|opts| opts.get(key)) {
205                val.parse::<f64>().map_err(|_| {
206                    lance_core::Error::invalid_input(format!(
207                        "Invalid value for storage option '{key}': '{val}'"
208                    ))
209                })
210            } else if let Ok(val) = std::env::var(&env_key) {
211                val.parse::<f64>().map_err(|_| {
212                    lance_core::Error::invalid_input(format!(
213                        "Invalid value for env var '{env_key}': '{val}'"
214                    ))
215                })
216            } else {
217                Ok(default)
218            }
219        }
220
221        fn resolve_u32(
222            key: &str,
223            storage_options: Option<&HashMap<String, String>>,
224            default: u32,
225        ) -> lance_core::Result<u32> {
226            let env_key = key.to_ascii_uppercase();
227            if let Some(val) = storage_options.and_then(|opts| opts.get(key)) {
228                val.parse::<u32>().map_err(|_| {
229                    lance_core::Error::invalid_input(format!(
230                        "Invalid value for storage option '{key}': '{val}'"
231                    ))
232                })
233            } else if let Ok(val) = std::env::var(&env_key) {
234                val.parse::<u32>().map_err(|_| {
235                    lance_core::Error::invalid_input(format!(
236                        "Invalid value for env var '{env_key}': '{val}'"
237                    ))
238                })
239            } else {
240                Ok(default)
241            }
242        }
243
244        fn resolve_usize(
245            key: &str,
246            storage_options: Option<&HashMap<String, String>>,
247            default: usize,
248        ) -> lance_core::Result<usize> {
249            let env_key = key.to_ascii_uppercase();
250            if let Some(val) = storage_options.and_then(|opts| opts.get(key)) {
251                val.parse::<usize>().map_err(|_| {
252                    lance_core::Error::invalid_input(format!(
253                        "Invalid value for storage option '{key}': '{val}'"
254                    ))
255                })
256            } else if let Ok(val) = std::env::var(&env_key) {
257                val.parse::<usize>().map_err(|_| {
258                    lance_core::Error::invalid_input(format!(
259                        "Invalid value for env var '{env_key}': '{val}'"
260                    ))
261                })
262            } else {
263                Ok(default)
264            }
265        }
266
267        fn resolve_u64(
268            key: &str,
269            storage_options: Option<&HashMap<String, String>>,
270            default: u64,
271        ) -> lance_core::Result<u64> {
272            let env_key = key.to_ascii_uppercase();
273            if let Some(val) = storage_options.and_then(|opts| opts.get(key)) {
274                val.parse::<u64>().map_err(|_| {
275                    lance_core::Error::invalid_input(format!(
276                        "Invalid value for storage option '{key}': '{val}'"
277                    ))
278                })
279            } else if let Ok(val) = std::env::var(&env_key) {
280                val.parse::<u64>().map_err(|_| {
281                    lance_core::Error::invalid_input(format!(
282                        "Invalid value for env var '{env_key}': '{val}'"
283                    ))
284                })
285            } else {
286                Ok(default)
287            }
288        }
289
290        let initial_rate = resolve_f64("lance_aimd_initial_rate", storage_options, 2000.0)?;
291        let min_rate = resolve_f64("lance_aimd_min_rate", storage_options, 1.0)?;
292        let max_rate = resolve_f64("lance_aimd_max_rate", storage_options, 5000.0)?;
293        let decrease_factor = resolve_f64("lance_aimd_decrease_factor", storage_options, 0.5)?;
294        let additive_increment =
295            resolve_f64("lance_aimd_additive_increment", storage_options, 300.0)?;
296        let burst_capacity = resolve_u32("lance_aimd_burst_capacity", storage_options, 100)?;
297        let max_retries = resolve_usize("lance_aimd_max_retries", storage_options, 3)?;
298        let min_backoff_ms = resolve_u64("lance_aimd_min_backoff_ms", storage_options, 100)?;
299        let max_backoff_ms = resolve_u64("lance_aimd_max_backoff_ms", storage_options, 300)?;
300
301        let aimd = AimdConfig::default()
302            .with_initial_rate(initial_rate)
303            .with_min_rate(min_rate)
304            .with_max_rate(max_rate)
305            .with_decrease_factor(decrease_factor)
306            .with_additive_increment(additive_increment);
307
308        Ok(Self {
309            max_retries,
310            min_backoff_ms,
311            max_backoff_ms,
312            ..Self::default()
313                .with_aimd(aimd)
314                .with_burst_capacity(burst_capacity)
315        })
316    }
317}
318
319struct TokenBucketState {
320    tokens: f64,
321    last_refill: tokio::time::Instant,
322    rate: f64,
323}
324
325/// Per-category throttle state: an AIMD controller paired with a token bucket.
326struct OperationThrottle {
327    controller: AimdController,
328    bucket: Mutex<TokenBucketState>,
329    burst_capacity: f64,
330    max_retries: usize,
331    min_backoff_ms: u64,
332    max_backoff_ms: u64,
333}
334
335impl OperationThrottle {
336    fn new(
337        aimd_config: AimdConfig,
338        burst_capacity: f64,
339        max_retries: usize,
340        min_backoff_ms: u64,
341        max_backoff_ms: u64,
342    ) -> lance_core::Result<Self> {
343        let initial_rate = aimd_config.initial_rate;
344        let controller = AimdController::new(aimd_config)?;
345        Ok(Self {
346            controller,
347            bucket: Mutex::new(TokenBucketState {
348                tokens: burst_capacity,
349                last_refill: tokio::time::Instant::now(),
350                rate: initial_rate,
351            }),
352            burst_capacity,
353            max_retries,
354            min_backoff_ms,
355            max_backoff_ms,
356        })
357    }
358
359    /// Acquire a token from the bucket, sleeping if none are available.
360    ///
361    /// Each caller reserves a token immediately (allowing `tokens` to go
362    /// negative) so that concurrent waiters queue behind each other instead
363    /// of all waking at the same instant (thundering herd).
364    async fn acquire_token(&self) {
365        let sleep_duration = {
366            let mut bucket = self.bucket.lock().await;
367            let now = tokio::time::Instant::now();
368            let elapsed = now.duration_since(bucket.last_refill).as_secs_f64();
369            bucket.tokens = (bucket.tokens + elapsed * bucket.rate).min(self.burst_capacity);
370            bucket.last_refill = now;
371
372            // Reserve a token (may go negative to queue behind other waiters)
373            bucket.tokens -= 1.0;
374
375            if bucket.tokens >= 0.0 {
376                // Had a token available, no need to sleep
377                return;
378            }
379
380            // Sleep proportional to our position in the queue
381            std::time::Duration::from_secs_f64(-bucket.tokens / bucket.rate)
382        };
383
384        tokio::time::sleep(sleep_duration).await;
385    }
386
387    /// Update the bucket's fill rate from the controller.
388    async fn update_bucket_rate(&self, new_rate: f64) {
389        let mut bucket = self.bucket.lock().await;
390        bucket.rate = new_rate;
391    }
392
393    /// Classify a result and feed it back to the AIMD controller without
394    /// acquiring a token. Uses `try_lock` for the bucket update so that if the
395    /// bucket lock is contended the rate update is deferred to the next
396    /// `throttled()` call.
397    fn observe_outcome<T>(&self, result: &OSResult<T>) {
398        let outcome = match result {
399            Ok(_) => RequestOutcome::Success,
400            Err(err) if is_throttle_error(err) => {
401                debug!(
402                    target: TRACE_OBJECT_STORE_THROTTLE,
403                    error = %err,
404                    "Throttle error detected in stream"
405                );
406                RequestOutcome::Throttled
407            }
408            Err(_) => RequestOutcome::Success,
409        };
410        let prev_rate = self.controller.current_rate();
411        let new_rate = self.controller.record_outcome(outcome);
412        if new_rate < prev_rate
413            && let Err(err) = result.as_ref()
414        {
415            warn!(
416                target: TRACE_OBJECT_STORE_THROTTLE,
417                previous_rate = format!("{prev_rate:.1}"),
418                new_rate = format!("{new_rate:.1}"),
419                error = %err,
420                "AIMD throttle: rate reduced due to throttle errors"
421            );
422        }
423        if let Ok(mut bucket) = self.bucket.try_lock() {
424            bucket.rate = new_rate;
425        }
426    }
427
428    /// Execute an operation with throttling: acquire token, run, classify result.
429    /// On throttle errors, retries up to `max_retries` times with a random
430    /// backoff between `min_backoff_ms` and `max_backoff_ms` between attempts.
431    async fn throttled<T, F, Fut>(&self, f: F) -> OSResult<T>
432    where
433        F: Fn() -> Fut,
434        Fut: std::future::Future<Output = OSResult<T>>,
435    {
436        for attempt in 0..=self.max_retries {
437            self.acquire_token().await;
438            let result = f().await;
439            let outcome = match &result {
440                Ok(_) => RequestOutcome::Success,
441                Err(err) if is_throttle_error(err) => {
442                    debug!(
443                        target: TRACE_OBJECT_STORE_THROTTLE,
444                        error = %err,
445                        "Throttle error detected"
446                    );
447                    RequestOutcome::Throttled
448                }
449                Err(_) => RequestOutcome::Success, // Non-throttle errors don't indicate capacity problems
450            };
451            let prev_rate = self.controller.current_rate();
452            let new_rate = self.controller.record_outcome(outcome);
453            if new_rate < prev_rate
454                && let Err(err) = result.as_ref()
455            {
456                warn!(
457                    target: TRACE_OBJECT_STORE_THROTTLE,
458                    previous_rate = format!("{prev_rate:.1}"),
459                    new_rate = format!("{new_rate:.1}"),
460                    error = %err,
461                    "AIMD throttle: rate reduced due to throttle errors"
462                );
463            }
464            self.update_bucket_rate(new_rate).await;
465
466            match &result {
467                Err(err) if is_throttle_error(err) && attempt < self.max_retries => {
468                    let backoff_ms =
469                        rand::rng().random_range(self.min_backoff_ms..=self.max_backoff_ms);
470                    debug!(
471                        target: TRACE_OBJECT_STORE_THROTTLE,
472                        attempt = attempt + 1,
473                        max_retries = self.max_retries,
474                        backoff_ms,
475                        error = %err,
476                        "Retrying after throttle error"
477                    );
478                    tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
479                    continue;
480                }
481                _ => return result,
482            }
483        }
484        unreachable!()
485    }
486}
487
488impl Debug for OperationThrottle {
489    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
490        f.debug_struct("OperationThrottle")
491            .field("controller", &self.controller)
492            .field("burst_capacity", &self.burst_capacity)
493            .finish()
494    }
495}
496
497/// A [`MultipartUpload`] wrapper that throttles and retries `put_part`,
498/// `complete`, and `abort`, feeding outcomes back to the write AIMD
499/// controller.
500struct ThrottledMultipartUpload {
501    target: Box<dyn MultipartUpload>,
502    write: Arc<OperationThrottle>,
503}
504
505impl Debug for ThrottledMultipartUpload {
506    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
507        f.debug_struct("ThrottledMultipartUpload").finish()
508    }
509}
510
511#[async_trait]
512impl MultipartUpload for ThrottledMultipartUpload {
513    fn put_part(&mut self, data: PutPayload) -> UploadPart {
514        let write = Arc::clone(&self.write);
515        // Call put_part synchronously to preserve part ordering regardless
516        // of which futures are awaited first.
517        let fut = self.target.put_part(data);
518        Box::pin(async move {
519            write.acquire_token().await;
520            let result = fut.await;
521            write.observe_outcome(&result);
522            result
523        })
524    }
525
526    async fn complete(&mut self) -> OSResult<PutResult> {
527        let target = &mut self.target;
528        for attempt in 0..=self.write.max_retries {
529            self.write.acquire_token().await;
530            let result = target.complete().await;
531            self.write.observe_outcome(&result);
532
533            match &result {
534                Err(err) if is_throttle_error(err) && attempt < self.write.max_retries => {
535                    let backoff_ms = rand::rng()
536                        .random_range(self.write.min_backoff_ms..=self.write.max_backoff_ms);
537                    tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
538                    continue;
539                }
540                _ => return result,
541            }
542        }
543        unreachable!()
544    }
545
546    async fn abort(&mut self) -> OSResult<()> {
547        let target = &mut self.target;
548        for attempt in 0..=self.write.max_retries {
549            self.write.acquire_token().await;
550            let result = target.abort().await;
551            self.write.observe_outcome(&result);
552
553            match &result {
554                Err(err) if is_throttle_error(err) && attempt < self.write.max_retries => {
555                    let backoff_ms = rand::rng()
556                        .random_range(self.write.min_backoff_ms..=self.write.max_backoff_ms);
557                    tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
558                    continue;
559                }
560                _ => return result,
561            }
562        }
563        unreachable!()
564    }
565}
566
567/// An ObjectStore wrapper that rate-limits operations using per-category token
568/// buckets whose fill rates are controlled by AIMD algorithms.
569///
570/// Operations are split into four independent categories:
571/// - **read**: `get`, `get_opts`, `get_range`, `get_ranges`, `head`
572/// - **write**: `put`, `put_opts`, `put_multipart`, `put_multipart_opts`, `copy`, `copy_if_not_exists`, `rename`, `rename_if_not_exists`
573/// - **delete**: `delete`
574/// - **list**: `list`, `list_with_offset`, `list_with_delimiter`
575///
576/// Streaming list operations acquire a token before starting the underlying list stream.
577/// Streaming operations also observe each yielded item and feed the result back to the
578/// AIMD controller so it can adjust the rate for other operations in the same category.
579///
580/// This is not perfect but probably as close as we can get without moving the throttle into
581/// the object_store crate itself.
582pub struct AimdThrottledStore {
583    target: Arc<dyn ObjectStore>,
584    read: Arc<OperationThrottle>,
585    write: Arc<OperationThrottle>,
586    delete: Arc<OperationThrottle>,
587    list: Arc<OperationThrottle>,
588}
589
590impl Debug for AimdThrottledStore {
591    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
592        f.debug_struct("AimdThrottledStore")
593            .field("target", &self.target)
594            .field("read", &self.read)
595            .field("write", &self.write)
596            .field("delete", &self.delete)
597            .field("list", &self.list)
598            .finish()
599    }
600}
601
602impl Display for AimdThrottledStore {
603    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
604        write!(f, "AimdThrottledStore({})", self.target)
605    }
606}
607
608impl AimdThrottledStore {
609    pub fn new(
610        target: Arc<dyn ObjectStore>,
611        config: AimdThrottleConfig,
612    ) -> lance_core::Result<Self> {
613        let burst = config.burst_capacity as f64;
614        let max_retries = config.max_retries;
615        let min_backoff_ms = config.min_backoff_ms;
616        let max_backoff_ms = config.max_backoff_ms;
617        Ok(Self {
618            target,
619            read: Arc::new(OperationThrottle::new(
620                config.read,
621                burst,
622                max_retries,
623                min_backoff_ms,
624                max_backoff_ms,
625            )?),
626            write: Arc::new(OperationThrottle::new(
627                config.write,
628                burst,
629                max_retries,
630                min_backoff_ms,
631                max_backoff_ms,
632            )?),
633            delete: Arc::new(OperationThrottle::new(
634                config.delete,
635                burst,
636                max_retries,
637                min_backoff_ms,
638                max_backoff_ms,
639            )?),
640            list: Arc::new(OperationThrottle::new(
641                config.list,
642                burst,
643                max_retries,
644                min_backoff_ms,
645                max_backoff_ms,
646            )?),
647        })
648    }
649}
650
651#[async_trait]
652#[deny(clippy::missing_trait_methods)]
653impl ObjectStore for AimdThrottledStore {
654    async fn put_opts(
655        &self,
656        location: &Path,
657        bytes: PutPayload,
658        opts: PutOptions,
659    ) -> OSResult<PutResult> {
660        self.write
661            .throttled(|| self.target.put_opts(location, bytes.clone(), opts.clone()))
662            .await
663    }
664
665    async fn put_multipart_opts(
666        &self,
667        location: &Path,
668        opts: PutMultipartOptions,
669    ) -> OSResult<Box<dyn MultipartUpload>> {
670        let target = self
671            .write
672            .throttled(|| self.target.put_multipart_opts(location, opts.clone()))
673            .await?;
674        Ok(Box::new(ThrottledMultipartUpload {
675            target,
676            write: Arc::clone(&self.write),
677        }))
678    }
679
680    async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult<GetResult> {
681        self.read
682            .throttled(|| self.target.get_opts(location, options.clone()))
683            .await
684    }
685
686    async fn get_ranges(&self, location: &Path, ranges: &[Range<u64>]) -> OSResult<Vec<Bytes>> {
687        self.read
688            .throttled(|| self.target.get_ranges(location, ranges))
689            .await
690    }
691
692    fn delete_stream(
693        &self,
694        locations: BoxStream<'static, OSResult<Path>>,
695    ) -> BoxStream<'static, OSResult<Path>> {
696        let delete = Arc::clone(&self.delete);
697        self.target
698            .delete_stream(locations)
699            .map(move |item| {
700                delete.observe_outcome(&item);
701                item
702            })
703            .boxed()
704    }
705
706    fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, OSResult<ObjectMeta>> {
707        let throttle = Arc::clone(&self.list);
708        let throttle_for_start = Arc::clone(&throttle);
709        let target = Arc::clone(&self.target);
710        let prefix = prefix.cloned();
711        futures::stream::once(async move {
712            throttle_for_start.acquire_token().await;
713            target.list(prefix.as_ref())
714        })
715        .flatten()
716        .map(move |item| {
717            throttle.observe_outcome(&item);
718            item
719        })
720        .boxed()
721    }
722
723    fn list_with_offset(
724        &self,
725        prefix: Option<&Path>,
726        offset: &Path,
727    ) -> BoxStream<'static, OSResult<ObjectMeta>> {
728        let throttle = Arc::clone(&self.list);
729        let throttle_for_start = Arc::clone(&throttle);
730        let target = Arc::clone(&self.target);
731        let prefix = prefix.cloned();
732        let offset = offset.clone();
733        futures::stream::once(async move {
734            throttle_for_start.acquire_token().await;
735            target.list_with_offset(prefix.as_ref(), &offset)
736        })
737        .flatten()
738        .map(move |item| {
739            throttle.observe_outcome(&item);
740            item
741        })
742        .boxed()
743    }
744
745    async fn list_with_delimiter(&self, prefix: Option<&Path>) -> OSResult<ListResult> {
746        self.list
747            .throttled(|| self.target.list_with_delimiter(prefix))
748            .await
749    }
750
751    async fn copy_opts(&self, from: &Path, to: &Path, opts: CopyOptions) -> OSResult<()> {
752        self.write
753            .throttled(|| self.target.copy_opts(from, to, opts.clone()))
754            .await
755    }
756
757    async fn rename_opts(&self, from: &Path, to: &Path, opts: RenameOptions) -> OSResult<()> {
758        self.write
759            .throttled(|| self.target.rename_opts(from, to, opts.clone()))
760            .await
761    }
762}
763
764#[cfg(test)]
765mod tests {
766    use super::*;
767    use object_store::memory::InMemory;
768    use rstest::rstest;
769    use std::collections::VecDeque;
770    use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
771
772    const THROTTLE_ERROR_RESPONSE: &str = "request failed, after 3 retries, max_retries: 3, retry_timeout: 30s - Server returned non-2xx status code: 503: x-ms-request-id: azure-request-id";
773
774    fn make_generic_error(msg: &str) -> object_store::Error {
775        object_store::Error::Generic {
776            store: "test",
777            source: msg.into(),
778        }
779    }
780
781    #[rstest]
782    #[case::retry_error("Error after 10 retries, max_retries: 10, retry_timeout: 180s", true)]
783    #[case::retries_in_message(
784        "request failed, after 3 retries, max_retries: 5, retry_timeout: 60s",
785        true
786    )]
787    #[case::not_found("Object not found", false)]
788    #[case::permission_denied("Access denied", false)]
789    #[case::timeout("Connection timed out", false)]
790    #[case::http_429_without_retries("HTTP 429 Too Many Requests", true)]
791    #[case::slowdown_without_retries("SlowDown: Please reduce your request rate", true)]
792    #[case::azure_server_busy("Code: ServerBusy", true)]
793    #[case::azure_egress_limit("Message: Egress is over the account limit", true)]
794    fn test_is_throttle_error(#[case] msg: &str, #[case] expected: bool) {
795        let err = make_generic_error(msg);
796        assert_eq!(
797            is_throttle_error(&err),
798            expected,
799            "is_throttle_error for '{}' should be {}",
800            msg,
801            expected
802        );
803    }
804
805    #[test]
806    fn test_non_generic_errors_are_not_throttle() {
807        let err = object_store::Error::NotFound {
808            path: "test".to_string(),
809            source: "not found".into(),
810        };
811        assert!(!is_throttle_error(&err));
812    }
813
814    #[tokio::test]
815    async fn test_basic_put_get_through_wrapper() {
816        let store = Arc::new(InMemory::new());
817        let config = AimdThrottleConfig::default();
818        let throttled = AimdThrottledStore::new(store, config).unwrap();
819
820        let path = Path::from("test/file.txt");
821        let data = PutPayload::from_static(b"hello world");
822        throttled.put(&path, data).await.unwrap();
823
824        let result = throttled.get(&path).await.unwrap();
825        let bytes = result.bytes().await.unwrap();
826        assert_eq!(bytes.as_ref(), b"hello world");
827    }
828
829    #[tokio::test]
830    async fn test_rate_decreases_on_throttle() {
831        let store = Arc::new(InMemory::new());
832        let config = AimdThrottleConfig::default().with_aimd(
833            AimdConfig::default()
834                .with_initial_rate(100.0)
835                .with_decrease_factor(0.5)
836                .with_window_duration(std::time::Duration::from_millis(10)),
837        );
838        let throttled = AimdThrottledStore::new(store, config).unwrap();
839
840        let initial_rate = throttled.read.controller.current_rate();
841        assert_eq!(initial_rate, 100.0);
842
843        // Simulate a throttle outcome directly
844        throttled
845            .read
846            .controller
847            .record_outcome(RequestOutcome::Throttled);
848
849        // Wait for window to expire and trigger evaluation
850        tokio::time::sleep(std::time::Duration::from_millis(20)).await;
851        throttled
852            .read
853            .controller
854            .record_outcome(RequestOutcome::Success);
855
856        let new_rate = throttled.read.controller.current_rate();
857        assert!(
858            new_rate < initial_rate,
859            "Rate should decrease after throttle: {} < {}",
860            new_rate,
861            initial_rate
862        );
863    }
864
865    #[tokio::test]
866    async fn test_rate_recovers_on_success() {
867        let store = Arc::new(InMemory::new());
868        let config = AimdThrottleConfig::default().with_aimd(
869            AimdConfig::default()
870                .with_initial_rate(100.0)
871                .with_decrease_factor(0.5)
872                .with_additive_increment(10.0)
873                .with_window_duration(std::time::Duration::from_millis(10)),
874        );
875        let throttled = AimdThrottledStore::new(store, config).unwrap();
876
877        // First decrease via throttle
878        throttled
879            .read
880            .controller
881            .record_outcome(RequestOutcome::Throttled);
882        tokio::time::sleep(std::time::Duration::from_millis(20)).await;
883        throttled
884            .read
885            .controller
886            .record_outcome(RequestOutcome::Success);
887        let decreased_rate = throttled.read.controller.current_rate();
888        assert_eq!(decreased_rate, 50.0);
889
890        // Now recover via success
891        tokio::time::sleep(std::time::Duration::from_millis(20)).await;
892        throttled
893            .read
894            .controller
895            .record_outcome(RequestOutcome::Success);
896        let recovered_rate = throttled.read.controller.current_rate();
897        assert_eq!(recovered_rate, 60.0);
898    }
899
900    #[tokio::test]
901    async fn test_as_dyn_object_store() {
902        let store: Arc<dyn ObjectStore> = Arc::new(InMemory::new());
903        let throttled: Arc<dyn ObjectStore> =
904            Arc::new(AimdThrottledStore::new(store, AimdThrottleConfig::default()).unwrap());
905
906        let path = Path::from("test/data.bin");
907        let data = PutPayload::from_static(b"test data");
908        throttled.put(&path, data).await.unwrap();
909
910        let result = throttled.get(&path).await.unwrap();
911        let bytes = result.bytes().await.unwrap();
912        assert_eq!(bytes.as_ref(), b"test data");
913    }
914
915    #[tokio::test]
916    async fn test_token_bucket_delays_when_exhausted() {
917        let store = Arc::new(InMemory::new());
918        // Very low rate and burst capacity to force waiting
919        let config = AimdThrottleConfig::default()
920            .with_burst_capacity(1)
921            .with_aimd(AimdConfig::default().with_initial_rate(10.0));
922        let throttled = Arc::new(AimdThrottledStore::new(store, config).unwrap());
923
924        let path = Path::from("test/file.txt");
925        let data = PutPayload::from_static(b"data");
926        throttled.put(&path, data).await.unwrap();
927
928        // After consuming the burst token, the next request should take ~100ms
929        // (1 token / 10 tokens-per-sec). We verify it takes at least 50ms.
930        let start = std::time::Instant::now();
931        let data2 = PutPayload::from_static(b"data2");
932        throttled.put(&path, data2).await.unwrap();
933        let elapsed = start.elapsed();
934
935        assert!(
936            elapsed >= std::time::Duration::from_millis(50),
937            "Expected delay for token refill, but elapsed was {:?}",
938            elapsed
939        );
940    }
941
942    #[tokio::test]
943    async fn test_list_observes_outcomes() {
944        let store = Arc::new(InMemory::new());
945        let config = AimdThrottleConfig::default();
946        let throttled = AimdThrottledStore::new(store.clone(), config).unwrap();
947
948        let path = Path::from("prefix/file.txt");
949        let data = PutPayload::from_static(b"data");
950        store.put(&path, data).await.unwrap();
951
952        let items: Vec<_> = throttled.list(Some(&Path::from("prefix"))).collect().await;
953        assert_eq!(items.len(), 1);
954        assert!(items[0].is_ok());
955    }
956
957    /// A mock store whose `list` stream yields a configurable sequence of
958    /// Ok / throttle-error items. Used to verify that the AIMD wrapper
959    /// observes errors surfaced inside list streams.
960    struct ThrottlingListMockStore {
961        inner: InMemory,
962        /// Number of throttle errors to inject at the start of each list call.
963        throttle_count: usize,
964    }
965
966    impl Display for ThrottlingListMockStore {
967        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
968            write!(f, "ThrottlingListMockStore")
969        }
970    }
971
972    impl Debug for ThrottlingListMockStore {
973        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
974            f.debug_struct("ThrottlingListMockStore").finish()
975        }
976    }
977
978    #[async_trait]
979    impl ObjectStore for ThrottlingListMockStore {
980        async fn put_opts(
981            &self,
982            location: &Path,
983            bytes: PutPayload,
984            opts: PutOptions,
985        ) -> OSResult<PutResult> {
986            self.inner.put_opts(location, bytes, opts).await
987        }
988        async fn put_multipart_opts(
989            &self,
990            location: &Path,
991            opts: PutMultipartOptions,
992        ) -> OSResult<Box<dyn MultipartUpload>> {
993            self.inner.put_multipart_opts(location, opts).await
994        }
995        async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult<GetResult> {
996            self.inner.get_opts(location, options).await
997        }
998        async fn get_ranges(&self, location: &Path, ranges: &[Range<u64>]) -> OSResult<Vec<Bytes>> {
999            self.inner.get_ranges(location, ranges).await
1000        }
1001        fn delete_stream(
1002            &self,
1003            locations: BoxStream<'static, OSResult<Path>>,
1004        ) -> BoxStream<'static, OSResult<Path>> {
1005            self.inner.delete_stream(locations)
1006        }
1007        fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, OSResult<ObjectMeta>> {
1008            let n = self.throttle_count;
1009            let inner_stream = self.inner.list(prefix);
1010            let errors = futures::stream::iter((0..n).map(|_| {
1011                Err(object_store::Error::Generic {
1012                    store: "ThrottlingListMock",
1013                    source: "request failed, after 3 retries, max_retries: 5, retry_timeout: 60s"
1014                        .into(),
1015                })
1016            }));
1017            errors.chain(inner_stream).boxed()
1018        }
1019        fn list_with_offset(
1020            &self,
1021            prefix: Option<&Path>,
1022            offset: &Path,
1023        ) -> BoxStream<'static, OSResult<ObjectMeta>> {
1024            self.inner.list_with_offset(prefix, offset)
1025        }
1026        async fn list_with_delimiter(&self, prefix: Option<&Path>) -> OSResult<ListResult> {
1027            self.inner.list_with_delimiter(prefix).await
1028        }
1029        async fn copy_opts(&self, from: &Path, to: &Path, opts: CopyOptions) -> OSResult<()> {
1030            self.inner.copy_opts(from, to, opts).await
1031        }
1032    }
1033
1034    #[tokio::test]
1035    async fn test_list_stream_throttle_errors_decrease_rate() {
1036        let mock = Arc::new(ThrottlingListMockStore {
1037            inner: InMemory::new(),
1038            throttle_count: 5,
1039        });
1040
1041        // Seed a file so the real items come through after the errors.
1042        mock.put(
1043            &Path::from("prefix/file.txt"),
1044            PutPayload::from_static(b"data"),
1045        )
1046        .await
1047        .unwrap();
1048
1049        let config = AimdThrottleConfig::default().with_list_aimd(
1050            AimdConfig::default()
1051                .with_initial_rate(100.0)
1052                .with_decrease_factor(0.5)
1053                .with_window_duration(std::time::Duration::from_millis(10)),
1054        );
1055        let throttled = AimdThrottledStore::new(mock as Arc<dyn ObjectStore>, config).unwrap();
1056
1057        let initial_rate = throttled.list.controller.current_rate();
1058        assert_eq!(initial_rate, 100.0);
1059
1060        let items: Vec<_> = throttled.list(Some(&Path::from("prefix"))).collect().await;
1061
1062        // 5 errors + 1 real item
1063        assert_eq!(items.len(), 6);
1064        assert!(items[0].is_err());
1065        assert!(items[5].is_ok());
1066
1067        // Wait for the AIMD window to expire and trigger evaluation.
1068        tokio::time::sleep(std::time::Duration::from_millis(20)).await;
1069        throttled
1070            .list
1071            .controller
1072            .record_outcome(RequestOutcome::Success);
1073
1074        let new_rate = throttled.list.controller.current_rate();
1075        assert!(
1076            new_rate < initial_rate,
1077            "List rate should decrease after stream throttle errors: {} < {}",
1078            new_rate,
1079            initial_rate
1080        );
1081    }
1082
1083    struct CountingListStartStore {
1084        inner: InMemory,
1085        list_calls: AtomicUsize,
1086        offset_calls: AtomicUsize,
1087    }
1088
1089    impl Default for CountingListStartStore {
1090        fn default() -> Self {
1091            Self {
1092                inner: InMemory::new(),
1093                list_calls: AtomicUsize::new(0),
1094                offset_calls: AtomicUsize::new(0),
1095            }
1096        }
1097    }
1098
1099    impl CountingListStartStore {
1100        fn list_calls(&self) -> usize {
1101            self.list_calls.load(Ordering::SeqCst)
1102        }
1103
1104        fn offset_calls(&self) -> usize {
1105            self.offset_calls.load(Ordering::SeqCst)
1106        }
1107    }
1108
1109    impl Display for CountingListStartStore {
1110        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1111            write!(f, "CountingListStartStore")
1112        }
1113    }
1114
1115    impl Debug for CountingListStartStore {
1116        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1117            f.debug_struct("CountingListStartStore").finish()
1118        }
1119    }
1120
1121    #[async_trait]
1122    impl ObjectStore for CountingListStartStore {
1123        async fn put_opts(
1124            &self,
1125            location: &Path,
1126            bytes: PutPayload,
1127            opts: PutOptions,
1128        ) -> OSResult<PutResult> {
1129            self.inner.put_opts(location, bytes, opts).await
1130        }
1131
1132        async fn put_multipart_opts(
1133            &self,
1134            location: &Path,
1135            opts: PutMultipartOptions,
1136        ) -> OSResult<Box<dyn MultipartUpload>> {
1137            self.inner.put_multipart_opts(location, opts).await
1138        }
1139
1140        async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult<GetResult> {
1141            self.inner.get_opts(location, options).await
1142        }
1143
1144        async fn get_ranges(&self, location: &Path, ranges: &[Range<u64>]) -> OSResult<Vec<Bytes>> {
1145            self.inner.get_ranges(location, ranges).await
1146        }
1147
1148        fn delete_stream(
1149            &self,
1150            locations: BoxStream<'static, OSResult<Path>>,
1151        ) -> BoxStream<'static, OSResult<Path>> {
1152            self.inner.delete_stream(locations)
1153        }
1154
1155        fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, OSResult<ObjectMeta>> {
1156            self.list_calls.fetch_add(1, Ordering::SeqCst);
1157            self.inner.list(prefix)
1158        }
1159
1160        fn list_with_offset(
1161            &self,
1162            prefix: Option<&Path>,
1163            offset: &Path,
1164        ) -> BoxStream<'static, OSResult<ObjectMeta>> {
1165            self.offset_calls.fetch_add(1, Ordering::SeqCst);
1166            self.inner.list_with_offset(prefix, offset)
1167        }
1168
1169        async fn list_with_delimiter(&self, prefix: Option<&Path>) -> OSResult<ListResult> {
1170            self.inner.list_with_delimiter(prefix).await
1171        }
1172
1173        async fn copy_opts(&self, from: &Path, to: &Path, opts: CopyOptions) -> OSResult<()> {
1174            self.inner.copy_opts(from, to, opts).await
1175        }
1176    }
1177
1178    fn list_start_throttle_config() -> AimdThrottleConfig {
1179        // Use a low rate (10 tokens/s) so that the token-acquisition sleep is
1180        // 1/10 = 100 ms — well above the 50 ms timeout used in assertions,
1181        // avoiding flakiness from coarse OS timer resolution (e.g. Windows ~16 ms).
1182        AimdThrottleConfig::default()
1183            .with_burst_capacity(0)
1184            .with_list_aimd(AimdConfig::default().with_initial_rate(10.0))
1185    }
1186
1187    #[tokio::test(start_paused = true)]
1188    async fn test_list_acquires_token_before_starting_underlying_stream() {
1189        let store = Arc::new(CountingListStartStore::default());
1190        store
1191            .put(
1192                &Path::from("prefix/file.txt"),
1193                PutPayload::from_static(b"data"),
1194            )
1195            .await
1196            .unwrap();
1197        let throttled = AimdThrottledStore::new(
1198            store.clone() as Arc<dyn ObjectStore>,
1199            list_start_throttle_config(),
1200        )
1201        .unwrap();
1202
1203        let mut stream = throttled.list(Some(&Path::from("prefix")));
1204        assert_eq!(store.list_calls(), 0);
1205        // With rate=10 tokens/s and burst_capacity=0, the token acquisition
1206        // sleeps for 100 ms. A 50 ms timeout must expire before that.
1207        assert!(
1208            tokio::time::timeout(std::time::Duration::from_millis(50), stream.next())
1209                .await
1210                .is_err()
1211        );
1212        assert_eq!(store.list_calls(), 0);
1213
1214        let item = tokio::time::timeout(std::time::Duration::from_millis(300), stream.next())
1215            .await
1216            .unwrap()
1217            .unwrap()
1218            .unwrap();
1219        assert_eq!(item.location, Path::from("prefix/file.txt"));
1220        assert_eq!(store.list_calls(), 1);
1221    }
1222
1223    #[tokio::test(start_paused = true)]
1224    async fn test_list_with_offset_acquires_token_before_starting_underlying_stream() {
1225        let store = Arc::new(CountingListStartStore::default());
1226        store
1227            .put(&Path::from("prefix/b"), PutPayload::from_static(b"data"))
1228            .await
1229            .unwrap();
1230        let throttled = AimdThrottledStore::new(
1231            store.clone() as Arc<dyn ObjectStore>,
1232            list_start_throttle_config(),
1233        )
1234        .unwrap();
1235
1236        let mut stream =
1237            throttled.list_with_offset(Some(&Path::from("prefix")), &Path::from("prefix/a"));
1238        assert_eq!(store.offset_calls(), 0);
1239        // With rate=10 tokens/s and burst_capacity=0, the token acquisition
1240        // sleeps for 100 ms. A 50 ms timeout must expire before that.
1241        assert!(
1242            tokio::time::timeout(std::time::Duration::from_millis(50), stream.next())
1243                .await
1244                .is_err()
1245        );
1246        assert_eq!(store.offset_calls(), 0);
1247
1248        let item = tokio::time::timeout(std::time::Duration::from_millis(300), stream.next())
1249            .await
1250            .unwrap()
1251            .unwrap()
1252            .unwrap();
1253        assert_eq!(item.location, Path::from("prefix/b"));
1254        assert_eq!(store.offset_calls(), 1);
1255    }
1256
1257    #[tokio::test]
1258    async fn test_per_category_independence() {
1259        let store = Arc::new(InMemory::new());
1260        let config = AimdThrottleConfig::default().with_aimd(
1261            AimdConfig::default()
1262                .with_initial_rate(100.0)
1263                .with_decrease_factor(0.5)
1264                .with_window_duration(std::time::Duration::from_millis(10)),
1265        );
1266        let throttled = AimdThrottledStore::new(store, config).unwrap();
1267
1268        // Push the read controller into a throttled state
1269        throttled
1270            .read
1271            .controller
1272            .record_outcome(RequestOutcome::Throttled);
1273        tokio::time::sleep(std::time::Duration::from_millis(20)).await;
1274        throttled
1275            .read
1276            .controller
1277            .record_outcome(RequestOutcome::Success);
1278
1279        let read_rate = throttled.read.controller.current_rate();
1280        let write_rate = throttled.write.controller.current_rate();
1281        let delete_rate = throttled.delete.controller.current_rate();
1282        let list_rate = throttled.list.controller.current_rate();
1283
1284        assert_eq!(read_rate, 50.0, "Read rate should have decreased");
1285        assert_eq!(write_rate, 100.0, "Write rate should be unaffected");
1286        assert_eq!(delete_rate, 100.0, "Delete rate should be unaffected");
1287        assert_eq!(list_rate, 100.0, "List rate should be unaffected");
1288    }
1289
1290    #[tokio::test]
1291    async fn test_per_category_config() {
1292        let store = Arc::new(InMemory::new());
1293        let config = AimdThrottleConfig::default()
1294            .with_read_aimd(AimdConfig::default().with_initial_rate(200.0))
1295            .with_write_aimd(AimdConfig::default().with_initial_rate(100.0))
1296            .with_delete_aimd(AimdConfig::default().with_initial_rate(50.0))
1297            .with_list_aimd(AimdConfig::default().with_initial_rate(25.0));
1298        let throttled = AimdThrottledStore::new(store, config).unwrap();
1299
1300        assert_eq!(throttled.read.controller.current_rate(), 200.0);
1301        assert_eq!(throttled.write.controller.current_rate(), 100.0);
1302        assert_eq!(throttled.delete.controller.current_rate(), 50.0);
1303        assert_eq!(throttled.list.controller.current_rate(), 25.0);
1304    }
1305
1306    /// A mock [`ObjectStore`] that measures request rate over a sliding window
1307    /// and returns 503 errors when the rate exceeds a configurable threshold.
1308    /// Write and metadata-only operations are not rate-limited.
1309    struct RateLimitingMockStore {
1310        inner: InMemory,
1311        /// Timestamps of recent successful (admitted) requests.
1312        timestamps: std::sync::Mutex<VecDeque<std::time::Instant>>,
1313        /// Maximum requests allowed within `window`.
1314        max_per_window: usize,
1315        /// Sliding window duration.
1316        window: std::time::Duration,
1317        success_count: AtomicU64,
1318        throttle_count: AtomicU64,
1319    }
1320
1321    impl RateLimitingMockStore {
1322        fn new(max_per_window: usize, window: std::time::Duration) -> Self {
1323            Self {
1324                inner: InMemory::new(),
1325                timestamps: std::sync::Mutex::new(VecDeque::new()),
1326                max_per_window,
1327                window,
1328                success_count: AtomicU64::new(0),
1329                throttle_count: AtomicU64::new(0),
1330            }
1331        }
1332
1333        /// Returns `true` if the request is admitted, `false` if throttled.
1334        fn check_rate(&self) -> bool {
1335            let mut ts = self.timestamps.lock().unwrap();
1336            let now = std::time::Instant::now();
1337            while let Some(&front) = ts.front() {
1338                if now.duration_since(front) > self.window {
1339                    ts.pop_front();
1340                } else {
1341                    break;
1342                }
1343            }
1344            if ts.len() >= self.max_per_window {
1345                self.throttle_count.fetch_add(1, Ordering::Relaxed);
1346                false
1347            } else {
1348                ts.push_back(now);
1349                self.success_count.fetch_add(1, Ordering::Relaxed);
1350                true
1351            }
1352        }
1353
1354        fn throttle_error() -> object_store::Error {
1355            object_store::Error::Generic {
1356                store: "RateLimitingMock",
1357                source: THROTTLE_ERROR_RESPONSE.into(),
1358            }
1359        }
1360    }
1361
1362    impl Display for RateLimitingMockStore {
1363        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1364            write!(f, "RateLimitingMockStore")
1365        }
1366    }
1367
1368    impl Debug for RateLimitingMockStore {
1369        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1370            f.debug_struct("RateLimitingMockStore").finish()
1371        }
1372    }
1373
1374    #[async_trait]
1375    impl ObjectStore for RateLimitingMockStore {
1376        async fn put_opts(
1377            &self,
1378            location: &Path,
1379            bytes: PutPayload,
1380            opts: PutOptions,
1381        ) -> OSResult<PutResult> {
1382            self.inner.put_opts(location, bytes, opts).await
1383        }
1384
1385        async fn put_multipart_opts(
1386            &self,
1387            location: &Path,
1388            opts: PutMultipartOptions,
1389        ) -> OSResult<Box<dyn MultipartUpload>> {
1390            self.inner.put_multipart_opts(location, opts).await
1391        }
1392
1393        async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult<GetResult> {
1394            if self.check_rate() {
1395                self.inner.get_opts(location, options).await
1396            } else {
1397                Err(Self::throttle_error())
1398            }
1399        }
1400
1401        async fn get_ranges(&self, location: &Path, ranges: &[Range<u64>]) -> OSResult<Vec<Bytes>> {
1402            if self.check_rate() {
1403                self.inner.get_ranges(location, ranges).await
1404            } else {
1405                Err(Self::throttle_error())
1406            }
1407        }
1408
1409        fn delete_stream(
1410            &self,
1411            locations: BoxStream<'static, OSResult<Path>>,
1412        ) -> BoxStream<'static, OSResult<Path>> {
1413            self.inner.delete_stream(locations)
1414        }
1415
1416        fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, OSResult<ObjectMeta>> {
1417            self.inner.list(prefix)
1418        }
1419
1420        fn list_with_offset(
1421            &self,
1422            prefix: Option<&Path>,
1423            offset: &Path,
1424        ) -> BoxStream<'static, OSResult<ObjectMeta>> {
1425            self.inner.list_with_offset(prefix, offset)
1426        }
1427
1428        async fn list_with_delimiter(&self, prefix: Option<&Path>) -> OSResult<ListResult> {
1429            self.inner.list_with_delimiter(prefix).await
1430        }
1431
1432        async fn copy_opts(&self, from: &Path, to: &Path, opts: CopyOptions) -> OSResult<()> {
1433            self.inner.copy_opts(from, to, opts).await
1434        }
1435    }
1436
1437    /// Verify that multiple concurrent readers sharing an AIMD-throttled store
1438    /// converge to the backend's actual capacity.
1439    ///
1440    /// Setup:
1441    /// - Mock backend allows 30 requests per 100ms (= 300 req/s).
1442    /// - 5 reader tasks, each with their own [`AimdThrottledStore`] wrapping
1443    ///   the shared mock.
1444    /// - AIMD: 100ms window, initial rate 100 req/s, decrease 0.5, increase 2.
1445    /// - Readers issue `head()` requests as fast as the throttle allows for 2s.
1446    ///
1447    /// Expected behaviour:
1448    /// - Initial burst (100 burst tokens × 5 readers) overshoots the mock
1449    ///   capacity, causing many 503s. Each reader's AIMD halves its rate.
1450    /// - After the transient, each reader converges to ~60 req/s (300/5).
1451    /// - Over 2 seconds, total successful requests should be in the range
1452    ///   [300, 900] (theoretical max ≈ 600).
1453    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
1454    async fn test_aimd_throttle_under_concurrent_load() {
1455        let mock = Arc::new(RateLimitingMockStore::new(
1456            30,
1457            std::time::Duration::from_millis(100),
1458        ));
1459
1460        // Seed a test file so head() succeeds when admitted.
1461        let path = Path::from("test/data.bin");
1462        mock.put(&path, PutPayload::from_static(b"test data"))
1463            .await
1464            .unwrap();
1465
1466        let aimd = AimdConfig::default()
1467            .with_initial_rate(100.0)
1468            .with_decrease_factor(0.5)
1469            .with_additive_increment(2.0)
1470            .with_window_duration(std::time::Duration::from_millis(100));
1471        let throttle_config = AimdThrottleConfig::default()
1472            .with_aimd(aimd)
1473            .with_burst_capacity(100);
1474
1475        let num_readers = 5;
1476        let test_duration = std::time::Duration::from_secs(2);
1477        let mut handles = Vec::new();
1478
1479        for _ in 0..num_readers {
1480            let store = Arc::new(
1481                AimdThrottledStore::new(
1482                    mock.clone() as Arc<dyn ObjectStore>,
1483                    throttle_config.clone(),
1484                )
1485                .unwrap(),
1486            );
1487            let p = path.clone();
1488            handles.push(tokio::spawn(async move {
1489                let deadline = std::time::Instant::now() + test_duration;
1490                let mut count = 0u64;
1491                while std::time::Instant::now() < deadline {
1492                    let _ = store.head(&p).await;
1493                    count += 1;
1494                }
1495                count
1496            }));
1497        }
1498
1499        let mut total_reader_requests = 0u64;
1500        for handle in handles {
1501            total_reader_requests += handle.await.unwrap();
1502        }
1503
1504        let successes = mock.success_count.load(Ordering::Relaxed);
1505        let throttled = mock.throttle_count.load(Ordering::Relaxed);
1506        let total_mock = successes + throttled;
1507
1508        // Mock-side count >= reader-side count because the AIMD layer retries
1509        // throttle errors internally, causing multiple mock calls per reader call.
1510        assert!(
1511            total_mock >= total_reader_requests,
1512            "Mock-side count ({total_mock}) should be >= reader-side count ({total_reader_requests})"
1513        );
1514
1515        // Mock capacity is 30/100ms = 300 req/s. Over 2s the theoretical max is
1516        // ~600 successful requests. With AIMD ramp-up, expect somewhat fewer.
1517        assert!(
1518            successes >= 300,
1519            "Expected >= 300 successful requests over 2s, got {successes}"
1520        );
1521        assert!(
1522            successes <= 900,
1523            "Expected <= 900 successful requests, got {successes}"
1524        );
1525
1526        // The initial burst exceeds mock capacity, so throttling must occur.
1527        assert!(throttled > 0, "Expected some throttled requests but got 0");
1528
1529        // Without AIMD, raw tokio tasks against InMemory would fire 100k+ req/s.
1530        // AIMD should keep the total well under 5000 over 2s.
1531        assert!(
1532            total_mock <= 5000,
1533            "AIMD should limit total requests, got {total_mock}"
1534        );
1535    }
1536
1537    /// A mock store that returns a configurable number of throttle errors
1538    /// before succeeding on `get` operations. Used to test the retry logic
1539    /// inside `OperationThrottle::throttled()`.
1540    struct RetryTestMockStore {
1541        inner: InMemory,
1542        /// Number of throttle errors remaining before success.
1543        errors_remaining: std::sync::Mutex<usize>,
1544        /// Total number of `get` calls observed.
1545        get_call_count: AtomicU64,
1546    }
1547
1548    impl RetryTestMockStore {
1549        fn new(errors_before_success: usize) -> Self {
1550            Self {
1551                inner: InMemory::new(),
1552                errors_remaining: std::sync::Mutex::new(errors_before_success),
1553                get_call_count: AtomicU64::new(0),
1554            }
1555        }
1556    }
1557
1558    impl Display for RetryTestMockStore {
1559        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1560            write!(f, "RetryTestMockStore")
1561        }
1562    }
1563
1564    impl Debug for RetryTestMockStore {
1565        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1566            f.debug_struct("RetryTestMockStore").finish()
1567        }
1568    }
1569
1570    #[async_trait]
1571    impl ObjectStore for RetryTestMockStore {
1572        async fn put_opts(
1573            &self,
1574            location: &Path,
1575            bytes: PutPayload,
1576            opts: PutOptions,
1577        ) -> OSResult<PutResult> {
1578            self.inner.put_opts(location, bytes, opts).await
1579        }
1580        async fn put_multipart_opts(
1581            &self,
1582            location: &Path,
1583            opts: PutMultipartOptions,
1584        ) -> OSResult<Box<dyn MultipartUpload>> {
1585            self.inner.put_multipart_opts(location, opts).await
1586        }
1587        async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult<GetResult> {
1588            self.get_call_count.fetch_add(1, Ordering::Relaxed);
1589            let should_error = {
1590                let mut remaining = self.errors_remaining.lock().unwrap();
1591                if *remaining > 0 {
1592                    *remaining -= 1;
1593                    true
1594                } else {
1595                    false
1596                }
1597            };
1598            if should_error {
1599                Err(object_store::Error::Generic {
1600                    store: "RetryTestMock",
1601                    source: THROTTLE_ERROR_RESPONSE.into(),
1602                })
1603            } else {
1604                self.inner.get_opts(location, options).await
1605            }
1606        }
1607        async fn get_ranges(&self, location: &Path, ranges: &[Range<u64>]) -> OSResult<Vec<Bytes>> {
1608            self.inner.get_ranges(location, ranges).await
1609        }
1610        fn delete_stream(
1611            &self,
1612            locations: BoxStream<'static, OSResult<Path>>,
1613        ) -> BoxStream<'static, OSResult<Path>> {
1614            self.inner.delete_stream(locations)
1615        }
1616        fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, OSResult<ObjectMeta>> {
1617            self.inner.list(prefix)
1618        }
1619        fn list_with_offset(
1620            &self,
1621            prefix: Option<&Path>,
1622            offset: &Path,
1623        ) -> BoxStream<'static, OSResult<ObjectMeta>> {
1624            self.inner.list_with_offset(prefix, offset)
1625        }
1626        async fn list_with_delimiter(&self, prefix: Option<&Path>) -> OSResult<ListResult> {
1627            self.inner.list_with_delimiter(prefix).await
1628        }
1629        async fn copy_opts(&self, from: &Path, to: &Path, opts: CopyOptions) -> OSResult<()> {
1630            self.inner.copy_opts(from, to, opts).await
1631        }
1632    }
1633
1634    #[tokio::test]
1635    async fn test_throttled_retries_on_throttle_error_then_succeeds() {
1636        // Mock returns 2 throttle errors then succeeds (within MAX_RETRIES=3)
1637        let mock = Arc::new(RetryTestMockStore::new(2));
1638        let path = Path::from("test/retry.txt");
1639        mock.put(&path, PutPayload::from_static(b"retry data"))
1640            .await
1641            .unwrap();
1642
1643        let config = AimdThrottleConfig::default();
1644        let throttled =
1645            AimdThrottledStore::new(mock.clone() as Arc<dyn ObjectStore>, config).unwrap();
1646
1647        let result = throttled.get(&path).await;
1648        assert!(result.is_ok(), "Expected success after retries");
1649
1650        let bytes = result.unwrap().bytes().await.unwrap();
1651        assert_eq!(bytes.as_ref(), b"retry data");
1652
1653        // Should have called get 3 times total: 2 failures + 1 success
1654        assert_eq!(mock.get_call_count.load(Ordering::Relaxed), 3);
1655    }
1656
1657    #[tokio::test]
1658    async fn test_throttled_fails_after_max_retries_exceeded() {
1659        // Mock returns 4 throttle errors (more than MAX_RETRIES=3),
1660        // so all 4 attempts (initial + 3 retries) will fail.
1661        let mock = Arc::new(RetryTestMockStore::new(10));
1662        let path = Path::from("test/fail.txt");
1663        mock.put(&path, PutPayload::from_static(b"fail data"))
1664            .await
1665            .unwrap();
1666
1667        let config = AimdThrottleConfig::default();
1668        let throttled =
1669            AimdThrottledStore::new(mock.clone() as Arc<dyn ObjectStore>, config).unwrap();
1670
1671        let result = throttled.get(&path).await;
1672        assert!(result.is_err(), "Expected error after max retries");
1673        let err = result.unwrap_err();
1674        assert!(is_throttle_error(&err));
1675
1676        let lance_error = lance_core::Error::from(err);
1677        let error_message = lance_error.to_string();
1678        assert!(error_message.contains("x-ms-request-id"));
1679        assert!(error_message.contains("azure-request-id"));
1680
1681        // Should have called get 4 times: initial attempt + 3 retries
1682        assert_eq!(mock.get_call_count.load(Ordering::Relaxed), 4);
1683    }
1684
1685    #[tokio::test]
1686    async fn test_throttled_multipart_reorders_parts() {
1687        let store = Arc::new(InMemory::new()) as Arc<dyn ObjectStore>;
1688        let config = AimdThrottleConfig::default();
1689        let throttled = AimdThrottledStore::new(store.clone(), config).unwrap();
1690
1691        let path = Path::from("test/multipart_ordering.bin");
1692        let mut upload = throttled.put_multipart(&path).await.unwrap();
1693
1694        // Create futures for two parts in order: A then B.
1695        let fut_a = upload.put_part(PutPayload::from_static(b"AAAA"));
1696        let fut_b = upload.put_part(PutPayload::from_static(b"BBBB"));
1697
1698        // Await in REVERSE order. Part ordering should be determined by
1699        // creation order (put_part call order), not by await order.
1700        fut_b.await.unwrap();
1701        fut_a.await.unwrap();
1702
1703        upload.complete().await.unwrap();
1704
1705        let result = store.get(&path).await.unwrap();
1706        let bytes = result.bytes().await.unwrap();
1707
1708        assert_eq!(
1709            bytes.as_ref(),
1710            b"AAAABBBB",
1711            "Parts were reordered! Got {:?} instead of AAAABBBB.",
1712            std::str::from_utf8(&bytes).unwrap_or("<non-utf8>"),
1713        );
1714    }
1715}