Skip to main content

lance_io/object_store/
throttle.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! AIMD-controlled token bucket rate limiter for ObjectStore operations.
5//!
6//! Wraps any [`object_store::ObjectStore`] with per-category token buckets
7//! whose fill rates are dynamically adjusted by AIMD controllers. When cloud
8//! stores return HTTP 429/503, the fill rate decreases multiplicatively. During
9//! sustained success windows, it increases additively.
10//!
11//! Operations are split into four independent categories — **read**, **write**,
12//! **delete**, **list** — each with its own AIMD controller and token bucket.
13//! This prevents a burst of reads from starving writes, and vice versa.
14//!
15//! # Example
16//!
17//! ```ignore
18//! use lance_io::object_store::throttle::{AimdThrottleConfig, AimdThrottledStore};
19//!
20//! let throttled = AimdThrottledStore::new(target, AimdThrottleConfig::default()).unwrap();
21//! ```
22
23use std::collections::HashMap;
24use std::fmt::{Debug, Display, Formatter};
25use std::ops::Range;
26use std::sync::Arc;
27
28use async_trait::async_trait;
29use bytes::Bytes;
30use futures::StreamExt;
31use futures::stream::BoxStream;
32use lance_core::utils::aimd::{AimdConfig, AimdController, RequestOutcome};
33use lance_core::utils::tracing::TRACE_OBJECT_STORE_THROTTLE;
34#[cfg(test)]
35use object_store::ObjectStoreExt;
36use object_store::path::Path;
37use object_store::{
38    CopyOptions, GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta, ObjectStore,
39    PutMultipartOptions, PutOptions, PutPayload, PutResult, RenameOptions, Result as OSResult,
40    UploadPart,
41};
42use rand::Rng;
43use tokio::sync::Mutex;
44use tracing::{debug, warn};
45
46/// Check whether an `object_store::Error` represents a throttle response
47/// (HTTP 429 / 503) from a cloud object store.
48///
49/// Regrettably, this information is not fully exposed by the `object_store` crate.
50/// There is no generic mechanism for a custom object store to return a throttle error.
51///
52/// However, the builtin object stores all use RetryError when retries are configured and
53/// throttle errors are returned.  Sadly, RetryError is not a public type, so we have to
54/// infer it from the error message.  This is potentially dangerous because these errors
55/// often include the URI itself and that URI could have any characters in it (e.g. if we
56/// look for 429 then we might match a 429 in a UUID).These error messages currently look like:
57///
58/// ", after ... retries, max_retries: ..., retry_timeout: ..."
59///
60/// So, as a crude heuristic, which should work for the builtin object stores, but won't
61/// work for custom object stores, we simply look for the string "retries, max_retries"
62/// in the error message.
63pub fn is_throttle_error(err: &object_store::Error) -> bool {
64    // Only Generic errors can carry throttle responses
65    if let object_store::Error::Generic { source, .. } = err {
66        source.to_string().contains("retries, max_retries")
67    } else {
68        false
69    }
70}
71
/// Configuration for the AIMD-throttled ObjectStore wrapper.
///
/// Each operation category (read, write, delete, list) has its own AIMD config.
/// Use [`with_aimd`](AimdThrottleConfig::with_aimd) to set all categories at
/// once, or per-category methods like [`with_read_aimd`](AimdThrottleConfig::with_read_aimd)
/// for fine-grained control.
///
/// The remaining fields (burst capacity, retry count, backoff window) are not
/// per-category: the same values are handed to every category's throttle.
#[derive(Debug, Clone)]
pub struct AimdThrottleConfig {
    /// AIMD configuration for read operations (get, get_opts, get_range, get_ranges, head).
    pub read: AimdConfig,
    /// AIMD configuration for write operations (put, put_opts, put_multipart, copy, rename, etc.).
    pub write: AimdConfig,
    /// AIMD configuration for delete operations.
    pub delete: AimdConfig,
    /// AIMD configuration for list operations (list_with_delimiter).
    pub list: AimdConfig,
    /// Maximum tokens that can accumulate for bursts (shared across all categories).
    pub burst_capacity: u32,
    /// Maximum number of retries for throttle errors within the AIMD layer.
    ///
    /// A value of `0` makes [`is_disabled`](AimdThrottleConfig::is_disabled) return `true`.
    pub max_retries: usize,
    /// Minimum backoff in milliseconds between retry attempts.
    ///
    /// Lower bound of the window from which a random retry delay is drawn.
    pub min_backoff_ms: u64,
    /// Maximum backoff in milliseconds between retry attempts.
    ///
    /// Upper bound (inclusive) of the window from which a random retry delay is drawn.
    pub max_backoff_ms: u64,
}
97
98impl Default for AimdThrottleConfig {
99    fn default() -> Self {
100        let aimd = AimdConfig::default();
101        Self {
102            read: aimd.clone(),
103            write: aimd.clone(),
104            delete: aimd.clone(),
105            list: aimd,
106            burst_capacity: 100,
107            max_retries: 3,
108            min_backoff_ms: 100,
109            max_backoff_ms: 300,
110        }
111    }
112}
113
114impl AimdThrottleConfig {
115    /// Set the AIMD configuration for all four operation categories at once.
116    pub fn with_aimd(self, aimd: AimdConfig) -> Self {
117        Self {
118            read: aimd.clone(),
119            write: aimd.clone(),
120            delete: aimd.clone(),
121            list: aimd,
122            ..self
123        }
124    }
125
126    /// Set the AIMD configuration for read operations.
127    pub fn with_read_aimd(self, aimd: AimdConfig) -> Self {
128        Self { read: aimd, ..self }
129    }
130
131    /// Set the AIMD configuration for write operations.
132    pub fn with_write_aimd(self, aimd: AimdConfig) -> Self {
133        Self {
134            write: aimd,
135            ..self
136        }
137    }
138
139    /// Set the AIMD configuration for delete operations.
140    pub fn with_delete_aimd(self, aimd: AimdConfig) -> Self {
141        Self {
142            delete: aimd,
143            ..self
144        }
145    }
146
147    /// Set the AIMD configuration for list operations.
148    pub fn with_list_aimd(self, aimd: AimdConfig) -> Self {
149        Self { list: aimd, ..self }
150    }
151
152    /// Returns `true` when the AIMD throttle layer should be bypassed entirely.
153    pub fn is_disabled(&self) -> bool {
154        self.max_retries == 0
155    }
156
157    pub fn with_burst_capacity(self, burst_capacity: u32) -> Self {
158        Self {
159            burst_capacity,
160            ..self
161        }
162    }
163
164    /// Build an `AimdThrottleConfig` from storage options and environment variables.
165    ///
166    /// Storage options take precedence over environment variables, which take
167    /// precedence over defaults. A single AIMD config is applied to all four
168    /// operation categories (read/write/delete/list).
169    ///
170    /// | Setting              | Storage Option Key               | Env Var                          | Default |
171    /// |----------------------|----------------------------------|----------------------------------|---------|
172    /// | Initial rate         | `lance_aimd_initial_rate`        | `LANCE_AIMD_INITIAL_RATE`        | 2000    |
173    /// | Min rate             | `lance_aimd_min_rate`            | `LANCE_AIMD_MIN_RATE`            | 1       |
174    /// | Max rate             | `lance_aimd_max_rate`            | `LANCE_AIMD_MAX_RATE`            | 5000    |
175    /// | Decrease factor      | `lance_aimd_decrease_factor`     | `LANCE_AIMD_DECREASE_FACTOR`     | 0.5     |
176    /// | Additive increment   | `lance_aimd_additive_increment`  | `LANCE_AIMD_ADDITIVE_INCREMENT`  | 300     |
177    /// | Burst capacity       | `lance_aimd_burst_capacity`      | `LANCE_AIMD_BURST_CAPACITY`      | 100     |
178    /// | Max retries          | `lance_aimd_max_retries`         | `LANCE_AIMD_MAX_RETRIES`         | 3       |
179    /// | Min backoff ms       | `lance_aimd_min_backoff_ms`      | `LANCE_AIMD_MIN_BACKOFF_MS`      | 100     |
180    /// | Max backoff ms       | `lance_aimd_max_backoff_ms`      | `LANCE_AIMD_MAX_BACKOFF_MS`      | 300     |
181    pub fn from_storage_options(
182        storage_options: Option<&HashMap<String, String>>,
183    ) -> lance_core::Result<Self> {
184        fn resolve_f64(
185            key: &str,
186            storage_options: Option<&HashMap<String, String>>,
187            default: f64,
188        ) -> lance_core::Result<f64> {
189            let env_key = key.to_ascii_uppercase();
190            if let Some(val) = storage_options.and_then(|opts| opts.get(key)) {
191                val.parse::<f64>().map_err(|_| {
192                    lance_core::Error::invalid_input(format!(
193                        "Invalid value for storage option '{key}': '{val}'"
194                    ))
195                })
196            } else if let Ok(val) = std::env::var(&env_key) {
197                val.parse::<f64>().map_err(|_| {
198                    lance_core::Error::invalid_input(format!(
199                        "Invalid value for env var '{env_key}': '{val}'"
200                    ))
201                })
202            } else {
203                Ok(default)
204            }
205        }
206
207        fn resolve_u32(
208            key: &str,
209            storage_options: Option<&HashMap<String, String>>,
210            default: u32,
211        ) -> lance_core::Result<u32> {
212            let env_key = key.to_ascii_uppercase();
213            if let Some(val) = storage_options.and_then(|opts| opts.get(key)) {
214                val.parse::<u32>().map_err(|_| {
215                    lance_core::Error::invalid_input(format!(
216                        "Invalid value for storage option '{key}': '{val}'"
217                    ))
218                })
219            } else if let Ok(val) = std::env::var(&env_key) {
220                val.parse::<u32>().map_err(|_| {
221                    lance_core::Error::invalid_input(format!(
222                        "Invalid value for env var '{env_key}': '{val}'"
223                    ))
224                })
225            } else {
226                Ok(default)
227            }
228        }
229
230        fn resolve_usize(
231            key: &str,
232            storage_options: Option<&HashMap<String, String>>,
233            default: usize,
234        ) -> lance_core::Result<usize> {
235            let env_key = key.to_ascii_uppercase();
236            if let Some(val) = storage_options.and_then(|opts| opts.get(key)) {
237                val.parse::<usize>().map_err(|_| {
238                    lance_core::Error::invalid_input(format!(
239                        "Invalid value for storage option '{key}': '{val}'"
240                    ))
241                })
242            } else if let Ok(val) = std::env::var(&env_key) {
243                val.parse::<usize>().map_err(|_| {
244                    lance_core::Error::invalid_input(format!(
245                        "Invalid value for env var '{env_key}': '{val}'"
246                    ))
247                })
248            } else {
249                Ok(default)
250            }
251        }
252
253        fn resolve_u64(
254            key: &str,
255            storage_options: Option<&HashMap<String, String>>,
256            default: u64,
257        ) -> lance_core::Result<u64> {
258            let env_key = key.to_ascii_uppercase();
259            if let Some(val) = storage_options.and_then(|opts| opts.get(key)) {
260                val.parse::<u64>().map_err(|_| {
261                    lance_core::Error::invalid_input(format!(
262                        "Invalid value for storage option '{key}': '{val}'"
263                    ))
264                })
265            } else if let Ok(val) = std::env::var(&env_key) {
266                val.parse::<u64>().map_err(|_| {
267                    lance_core::Error::invalid_input(format!(
268                        "Invalid value for env var '{env_key}': '{val}'"
269                    ))
270                })
271            } else {
272                Ok(default)
273            }
274        }
275
276        let initial_rate = resolve_f64("lance_aimd_initial_rate", storage_options, 2000.0)?;
277        let min_rate = resolve_f64("lance_aimd_min_rate", storage_options, 1.0)?;
278        let max_rate = resolve_f64("lance_aimd_max_rate", storage_options, 5000.0)?;
279        let decrease_factor = resolve_f64("lance_aimd_decrease_factor", storage_options, 0.5)?;
280        let additive_increment =
281            resolve_f64("lance_aimd_additive_increment", storage_options, 300.0)?;
282        let burst_capacity = resolve_u32("lance_aimd_burst_capacity", storage_options, 100)?;
283        let max_retries = resolve_usize("lance_aimd_max_retries", storage_options, 3)?;
284        let min_backoff_ms = resolve_u64("lance_aimd_min_backoff_ms", storage_options, 100)?;
285        let max_backoff_ms = resolve_u64("lance_aimd_max_backoff_ms", storage_options, 300)?;
286
287        let aimd = AimdConfig::default()
288            .with_initial_rate(initial_rate)
289            .with_min_rate(min_rate)
290            .with_max_rate(max_rate)
291            .with_decrease_factor(decrease_factor)
292            .with_additive_increment(additive_increment);
293
294        Ok(Self {
295            max_retries,
296            min_backoff_ms,
297            max_backoff_ms,
298            ..Self::default()
299                .with_aimd(aimd)
300                .with_burst_capacity(burst_capacity)
301        })
302    }
303}
304
/// Mutable token-bucket state, kept behind the mutex in `OperationThrottle`.
struct TokenBucketState {
    /// Tokens currently available. May be negative: callers reserve a token
    /// even when none is left and then sleep off the deficit.
    tokens: f64,
    /// Instant at which `tokens` was last topped up.
    last_refill: std::time::Instant,
    /// Fill rate applied on refill, in tokens per second of elapsed time.
    rate: f64,
}
310
/// Per-category throttle state: an AIMD controller paired with a token bucket.
struct OperationThrottle {
    /// AIMD controller that receives request outcomes and reports the current rate.
    controller: AimdController,
    /// Token bucket whose fill rate is kept in step with the controller's rate.
    bucket: Mutex<TokenBucketState>,
    /// Upper bound on the number of tokens the bucket may hold after a refill.
    burst_capacity: f64,
    /// Number of additional attempts made after a throttle error.
    max_retries: usize,
    /// Lower bound (milliseconds) of the random backoff between retry attempts.
    min_backoff_ms: u64,
    /// Upper bound (milliseconds, inclusive) of the random backoff between retry attempts.
    max_backoff_ms: u64,
}
320
321impl OperationThrottle {
322    fn new(
323        aimd_config: AimdConfig,
324        burst_capacity: f64,
325        max_retries: usize,
326        min_backoff_ms: u64,
327        max_backoff_ms: u64,
328    ) -> lance_core::Result<Self> {
329        let initial_rate = aimd_config.initial_rate;
330        let controller = AimdController::new(aimd_config)?;
331        Ok(Self {
332            controller,
333            bucket: Mutex::new(TokenBucketState {
334                tokens: burst_capacity,
335                last_refill: std::time::Instant::now(),
336                rate: initial_rate,
337            }),
338            burst_capacity,
339            max_retries,
340            min_backoff_ms,
341            max_backoff_ms,
342        })
343    }
344
345    /// Acquire a token from the bucket, sleeping if none are available.
346    ///
347    /// Each caller reserves a token immediately (allowing `tokens` to go
348    /// negative) so that concurrent waiters queue behind each other instead
349    /// of all waking at the same instant (thundering herd).
350    async fn acquire_token(&self) {
351        let sleep_duration = {
352            let mut bucket = self.bucket.lock().await;
353            let now = std::time::Instant::now();
354            let elapsed = now.duration_since(bucket.last_refill).as_secs_f64();
355            bucket.tokens = (bucket.tokens + elapsed * bucket.rate).min(self.burst_capacity);
356            bucket.last_refill = now;
357
358            // Reserve a token (may go negative to queue behind other waiters)
359            bucket.tokens -= 1.0;
360
361            if bucket.tokens >= 0.0 {
362                // Had a token available, no need to sleep
363                return;
364            }
365
366            // Sleep proportional to our position in the queue
367            std::time::Duration::from_secs_f64(-bucket.tokens / bucket.rate)
368        };
369
370        tokio::time::sleep(sleep_duration).await;
371    }
372
373    /// Update the bucket's fill rate from the controller.
374    async fn update_bucket_rate(&self, new_rate: f64) {
375        let mut bucket = self.bucket.lock().await;
376        bucket.rate = new_rate;
377    }
378
379    /// Classify a result and feed it back to the AIMD controller without
380    /// acquiring a token. Uses `try_lock` for the bucket update so that if the
381    /// bucket lock is contended the rate update is deferred to the next
382    /// `throttled()` call.
383    fn observe_outcome<T>(&self, result: &OSResult<T>) {
384        let outcome = match result {
385            Ok(_) => RequestOutcome::Success,
386            Err(err) if is_throttle_error(err) => {
387                debug!(
388                    target: TRACE_OBJECT_STORE_THROTTLE,
389                    error = %err,
390                    "Throttle error detected in stream"
391                );
392                RequestOutcome::Throttled
393            }
394            Err(_) => RequestOutcome::Success,
395        };
396        let prev_rate = self.controller.current_rate();
397        let new_rate = self.controller.record_outcome(outcome);
398        if new_rate < prev_rate
399            && let Err(err) = result.as_ref()
400        {
401            warn!(
402                target: TRACE_OBJECT_STORE_THROTTLE,
403                previous_rate = format!("{prev_rate:.1}"),
404                new_rate = format!("{new_rate:.1}"),
405                error = %err,
406                "AIMD throttle: rate reduced due to throttle errors"
407            );
408        }
409        if let Ok(mut bucket) = self.bucket.try_lock() {
410            bucket.rate = new_rate;
411        }
412    }
413
414    /// Execute an operation with throttling: acquire token, run, classify result.
415    /// On throttle errors, retries up to `max_retries` times with a random
416    /// backoff between `min_backoff_ms` and `max_backoff_ms` between attempts.
417    async fn throttled<T, F, Fut>(&self, f: F) -> OSResult<T>
418    where
419        F: Fn() -> Fut,
420        Fut: std::future::Future<Output = OSResult<T>>,
421    {
422        for attempt in 0..=self.max_retries {
423            self.acquire_token().await;
424            let result = f().await;
425            let outcome = match &result {
426                Ok(_) => RequestOutcome::Success,
427                Err(err) if is_throttle_error(err) => {
428                    debug!(
429                        target: TRACE_OBJECT_STORE_THROTTLE,
430                        error = %err,
431                        "Throttle error detected"
432                    );
433                    RequestOutcome::Throttled
434                }
435                Err(_) => RequestOutcome::Success, // Non-throttle errors don't indicate capacity problems
436            };
437            let prev_rate = self.controller.current_rate();
438            let new_rate = self.controller.record_outcome(outcome);
439            if new_rate < prev_rate
440                && let Err(err) = result.as_ref()
441            {
442                warn!(
443                    target: TRACE_OBJECT_STORE_THROTTLE,
444                    previous_rate = format!("{prev_rate:.1}"),
445                    new_rate = format!("{new_rate:.1}"),
446                    error = %err,
447                    "AIMD throttle: rate reduced due to throttle errors"
448                );
449            }
450            self.update_bucket_rate(new_rate).await;
451
452            match &result {
453                Err(err) if is_throttle_error(err) && attempt < self.max_retries => {
454                    let backoff_ms =
455                        rand::rng().random_range(self.min_backoff_ms..=self.max_backoff_ms);
456                    debug!(
457                        target: TRACE_OBJECT_STORE_THROTTLE,
458                        attempt = attempt + 1,
459                        max_retries = self.max_retries,
460                        backoff_ms,
461                        error = %err,
462                        "Retrying after throttle error"
463                    );
464                    tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
465                    continue;
466                }
467                _ => return result,
468            }
469        }
470        unreachable!()
471    }
472}
473
474impl Debug for OperationThrottle {
475    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
476        f.debug_struct("OperationThrottle")
477            .field("controller", &self.controller)
478            .field("burst_capacity", &self.burst_capacity)
479            .finish()
480    }
481}
482
/// A [`MultipartUpload`] wrapper that throttles `put_part`, `complete`, and
/// `abort`, feeding outcomes back to the write AIMD controller.
///
/// `complete` and `abort` are additionally retried on throttle errors.
/// `put_part` is not: it acquires a token and reports its outcome, but the
/// underlying part upload is started exactly once.
struct ThrottledMultipartUpload {
    /// The wrapped upload that performs the actual work.
    target: Box<dyn MultipartUpload>,
    /// Write-category throttle shared with the store that created this upload.
    write: Arc<OperationThrottle>,
}
490
491impl Debug for ThrottledMultipartUpload {
492    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
493        f.debug_struct("ThrottledMultipartUpload").finish()
494    }
495}
496
497#[async_trait]
498impl MultipartUpload for ThrottledMultipartUpload {
499    fn put_part(&mut self, data: PutPayload) -> UploadPart {
500        let write = Arc::clone(&self.write);
501        // Call put_part synchronously to preserve part ordering regardless
502        // of which futures are awaited first.
503        let fut = self.target.put_part(data);
504        Box::pin(async move {
505            write.acquire_token().await;
506            let result = fut.await;
507            write.observe_outcome(&result);
508            result
509        })
510    }
511
512    async fn complete(&mut self) -> OSResult<PutResult> {
513        let target = &mut self.target;
514        for attempt in 0..=self.write.max_retries {
515            self.write.acquire_token().await;
516            let result = target.complete().await;
517            self.write.observe_outcome(&result);
518
519            match &result {
520                Err(err) if is_throttle_error(err) && attempt < self.write.max_retries => {
521                    let backoff_ms = rand::rng()
522                        .random_range(self.write.min_backoff_ms..=self.write.max_backoff_ms);
523                    tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
524                    continue;
525                }
526                _ => return result,
527            }
528        }
529        unreachable!()
530    }
531
532    async fn abort(&mut self) -> OSResult<()> {
533        let target = &mut self.target;
534        for attempt in 0..=self.write.max_retries {
535            self.write.acquire_token().await;
536            let result = target.abort().await;
537            self.write.observe_outcome(&result);
538
539            match &result {
540                Err(err) if is_throttle_error(err) && attempt < self.write.max_retries => {
541                    let backoff_ms = rand::rng()
542                        .random_range(self.write.min_backoff_ms..=self.write.max_backoff_ms);
543                    tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
544                    continue;
545                }
546                _ => return result,
547            }
548        }
549        unreachable!()
550    }
551}
552
/// An ObjectStore wrapper that rate-limits operations using per-category token
/// buckets whose fill rates are controlled by AIMD algorithms.
///
/// Operations are split into four independent categories:
/// - **read**: `get`, `get_opts`, `get_range`, `get_ranges`, `head`
/// - **write**: `put`, `put_opts`, `put_multipart`, `put_multipart_opts`, `copy`, `copy_if_not_exists`, `rename`, `rename_if_not_exists`
/// - **delete**: `delete`
/// - **list**: `list_with_delimiter`
///
/// Streaming operations (`list`, `list_with_offset`, `delete_stream`) do not acquire tokens,
/// but observe each yielded item and feed the result back to the AIMD controller so it can
/// adjust the rate for other operations in the same category.
///
/// This is not perfect but probably as close as we can get without moving the throttle into
/// the object_store crate itself.
pub struct AimdThrottledStore {
    /// The wrapped store that actually performs every operation.
    target: Arc<dyn ObjectStore>,
    /// Throttle for the read category.
    read: Arc<OperationThrottle>,
    /// Throttle for the write category; also shared with multipart uploads
    /// created by this store.
    write: Arc<OperationThrottle>,
    /// Throttle for the delete category.
    // NOTE(review): in this file's `ObjectStore` impl the delete throttle is only
    // touched by `delete_stream`, which observes outcomes without acquiring
    // tokens — confirm that `delete` is really token-limited as listed above.
    delete: Arc<OperationThrottle>,
    /// Throttle for the list category.
    list: Arc<OperationThrottle>,
}
575
576impl Debug for AimdThrottledStore {
577    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
578        f.debug_struct("AimdThrottledStore")
579            .field("target", &self.target)
580            .field("read", &self.read)
581            .field("write", &self.write)
582            .field("delete", &self.delete)
583            .field("list", &self.list)
584            .finish()
585    }
586}
587
588impl Display for AimdThrottledStore {
589    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
590        write!(f, "AimdThrottledStore({})", self.target)
591    }
592}
593
594impl AimdThrottledStore {
595    pub fn new(
596        target: Arc<dyn ObjectStore>,
597        config: AimdThrottleConfig,
598    ) -> lance_core::Result<Self> {
599        let burst = config.burst_capacity as f64;
600        let max_retries = config.max_retries;
601        let min_backoff_ms = config.min_backoff_ms;
602        let max_backoff_ms = config.max_backoff_ms;
603        Ok(Self {
604            target,
605            read: Arc::new(OperationThrottle::new(
606                config.read,
607                burst,
608                max_retries,
609                min_backoff_ms,
610                max_backoff_ms,
611            )?),
612            write: Arc::new(OperationThrottle::new(
613                config.write,
614                burst,
615                max_retries,
616                min_backoff_ms,
617                max_backoff_ms,
618            )?),
619            delete: Arc::new(OperationThrottle::new(
620                config.delete,
621                burst,
622                max_retries,
623                min_backoff_ms,
624                max_backoff_ms,
625            )?),
626            list: Arc::new(OperationThrottle::new(
627                config.list,
628                burst,
629                max_retries,
630                min_backoff_ms,
631                max_backoff_ms,
632            )?),
633        })
634    }
635}
636
// Every trait method is forwarded to `target`. Request/response style calls go
// through `OperationThrottle::throttled` (token + retry); stream-returning
// calls only observe the items they yield.
#[async_trait]
#[deny(clippy::missing_trait_methods)]
impl ObjectStore for AimdThrottledStore {
    /// Throttled put, counted against the write category.
    ///
    /// The payload and options are cloned for each attempt because a throttled
    /// call may be retried.
    async fn put_opts(
        &self,
        location: &Path,
        bytes: PutPayload,
        opts: PutOptions,
    ) -> OSResult<PutResult> {
        self.write
            .throttled(|| self.target.put_opts(location, bytes.clone(), opts.clone()))
            .await
    }

    /// Throttled multipart-upload creation, counted against the write category.
    ///
    /// The returned upload is wrapped so that its own calls also go through the
    /// write throttle.
    async fn put_multipart_opts(
        &self,
        location: &Path,
        opts: PutMultipartOptions,
    ) -> OSResult<Box<dyn MultipartUpload>> {
        let target = self
            .write
            .throttled(|| self.target.put_multipart_opts(location, opts.clone()))
            .await?;
        Ok(Box::new(ThrottledMultipartUpload {
            target,
            write: Arc::clone(&self.write),
        }))
    }

    /// Throttled get, counted against the read category.
    async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult<GetResult> {
        self.read
            .throttled(|| self.target.get_opts(location, options.clone()))
            .await
    }

    /// Throttled multi-range get, counted against the read category.
    ///
    /// One token is acquired per call, regardless of how many ranges are requested.
    async fn get_ranges(&self, location: &Path, ranges: &[Range<u64>]) -> OSResult<Vec<Bytes>> {
        self.read
            .throttled(|| self.target.get_ranges(location, ranges))
            .await
    }

    /// Forward a delete stream to the wrapped store without acquiring tokens.
    ///
    /// Each yielded item is reported to the delete throttle and then passed
    /// through unchanged.
    fn delete_stream(
        &self,
        locations: BoxStream<'static, OSResult<Path>>,
    ) -> BoxStream<'static, OSResult<Path>> {
        let delete = Arc::clone(&self.delete);
        self.target
            .delete_stream(locations)
            .map(move |item| {
                delete.observe_outcome(&item);
                item
            })
            .boxed()
    }

    /// Forward a listing stream to the wrapped store without acquiring tokens.
    ///
    /// Each yielded item is reported to the list throttle and then passed
    /// through unchanged.
    fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, OSResult<ObjectMeta>> {
        let throttle = Arc::clone(&self.list);
        self.target
            .list(prefix)
            .map(move |item| {
                throttle.observe_outcome(&item);
                item
            })
            .boxed()
    }

    /// Forward an offset listing stream to the wrapped store without acquiring tokens.
    ///
    /// Each yielded item is reported to the list throttle and then passed
    /// through unchanged.
    fn list_with_offset(
        &self,
        prefix: Option<&Path>,
        offset: &Path,
    ) -> BoxStream<'static, OSResult<ObjectMeta>> {
        let throttle = Arc::clone(&self.list);
        self.target
            .list_with_offset(prefix, offset)
            .map(move |item| {
                throttle.observe_outcome(&item);
                item
            })
            .boxed()
    }

    /// Throttled delimiter listing, counted against the list category.
    async fn list_with_delimiter(&self, prefix: Option<&Path>) -> OSResult<ListResult> {
        self.list
            .throttled(|| self.target.list_with_delimiter(prefix))
            .await
    }

    /// Throttled copy, counted against the write category.
    async fn copy_opts(&self, from: &Path, to: &Path, opts: CopyOptions) -> OSResult<()> {
        self.write
            .throttled(|| self.target.copy_opts(from, to, opts.clone()))
            .await
    }

    /// Throttled rename, counted against the write category.
    async fn rename_opts(&self, from: &Path, to: &Path, opts: RenameOptions) -> OSResult<()> {
        self.write
            .throttled(|| self.target.rename_opts(from, to, opts.clone()))
            .await
    }
}
736
737#[cfg(test)]
738mod tests {
739    use super::*;
740    use object_store::memory::InMemory;
741    use rstest::rstest;
742    use std::collections::VecDeque;
743    use std::sync::atomic::{AtomicU64, Ordering};
744
    // Sample retry-exhausted error text; it contains the "retries, max_retries"
    // marker that `is_throttle_error` looks for.
    const THROTTLE_ERROR_RESPONSE: &str = "request failed, after 3 retries, max_retries: 3, retry_timeout: 30s - Server returned non-2xx status code: 503: x-ms-request-id: azure-request-id";
746
747    fn make_generic_error(msg: &str) -> object_store::Error {
748        object_store::Error::Generic {
749            store: "test",
750            source: msg.into(),
751        }
752    }
753
754    #[rstest]
755    #[case::retry_error("Error after 10 retries, max_retries: 10, retry_timeout: 180s", true)]
756    #[case::retries_in_message(
757        "request failed, after 3 retries, max_retries: 5, retry_timeout: 60s",
758        true
759    )]
760    #[case::not_found("Object not found", false)]
761    #[case::permission_denied("Access denied", false)]
762    #[case::timeout("Connection timed out", false)]
763    #[case::http_429_without_retries("HTTP 429 Too Many Requests", false)]
764    #[case::slowdown_without_retries("SlowDown: Please reduce your request rate", false)]
765    fn test_is_throttle_error(#[case] msg: &str, #[case] expected: bool) {
766        let err = make_generic_error(msg);
767        assert_eq!(
768            is_throttle_error(&err),
769            expected,
770            "is_throttle_error for '{}' should be {}",
771            msg,
772            expected
773        );
774    }
775
776    #[test]
777    fn test_non_generic_errors_are_not_throttle() {
778        let err = object_store::Error::NotFound {
779            path: "test".to_string(),
780            source: "not found".into(),
781        };
782        assert!(!is_throttle_error(&err));
783    }
784
785    #[tokio::test]
786    async fn test_basic_put_get_through_wrapper() {
787        let store = Arc::new(InMemory::new());
788        let config = AimdThrottleConfig::default();
789        let throttled = AimdThrottledStore::new(store, config).unwrap();
790
791        let path = Path::from("test/file.txt");
792        let data = PutPayload::from_static(b"hello world");
793        throttled.put(&path, data).await.unwrap();
794
795        let result = throttled.get(&path).await.unwrap();
796        let bytes = result.bytes().await.unwrap();
797        assert_eq!(bytes.as_ref(), b"hello world");
798    }
799
800    #[tokio::test]
801    async fn test_rate_decreases_on_throttle() {
802        let store = Arc::new(InMemory::new());
803        let config = AimdThrottleConfig::default().with_aimd(
804            AimdConfig::default()
805                .with_initial_rate(100.0)
806                .with_decrease_factor(0.5)
807                .with_window_duration(std::time::Duration::from_millis(10)),
808        );
809        let throttled = AimdThrottledStore::new(store, config).unwrap();
810
811        let initial_rate = throttled.read.controller.current_rate();
812        assert_eq!(initial_rate, 100.0);
813
814        // Simulate a throttle outcome directly
815        throttled
816            .read
817            .controller
818            .record_outcome(RequestOutcome::Throttled);
819
820        // Wait for window to expire and trigger evaluation
821        tokio::time::sleep(std::time::Duration::from_millis(20)).await;
822        throttled
823            .read
824            .controller
825            .record_outcome(RequestOutcome::Success);
826
827        let new_rate = throttled.read.controller.current_rate();
828        assert!(
829            new_rate < initial_rate,
830            "Rate should decrease after throttle: {} < {}",
831            new_rate,
832            initial_rate
833        );
834    }
835
836    #[tokio::test]
837    async fn test_rate_recovers_on_success() {
838        let store = Arc::new(InMemory::new());
839        let config = AimdThrottleConfig::default().with_aimd(
840            AimdConfig::default()
841                .with_initial_rate(100.0)
842                .with_decrease_factor(0.5)
843                .with_additive_increment(10.0)
844                .with_window_duration(std::time::Duration::from_millis(10)),
845        );
846        let throttled = AimdThrottledStore::new(store, config).unwrap();
847
848        // First decrease via throttle
849        throttled
850            .read
851            .controller
852            .record_outcome(RequestOutcome::Throttled);
853        tokio::time::sleep(std::time::Duration::from_millis(20)).await;
854        throttled
855            .read
856            .controller
857            .record_outcome(RequestOutcome::Success);
858        let decreased_rate = throttled.read.controller.current_rate();
859        assert_eq!(decreased_rate, 50.0);
860
861        // Now recover via success
862        tokio::time::sleep(std::time::Duration::from_millis(20)).await;
863        throttled
864            .read
865            .controller
866            .record_outcome(RequestOutcome::Success);
867        let recovered_rate = throttled.read.controller.current_rate();
868        assert_eq!(recovered_rate, 60.0);
869    }
870
871    #[tokio::test]
872    async fn test_as_dyn_object_store() {
873        let store: Arc<dyn ObjectStore> = Arc::new(InMemory::new());
874        let throttled: Arc<dyn ObjectStore> =
875            Arc::new(AimdThrottledStore::new(store, AimdThrottleConfig::default()).unwrap());
876
877        let path = Path::from("test/data.bin");
878        let data = PutPayload::from_static(b"test data");
879        throttled.put(&path, data).await.unwrap();
880
881        let result = throttled.get(&path).await.unwrap();
882        let bytes = result.bytes().await.unwrap();
883        assert_eq!(bytes.as_ref(), b"test data");
884    }
885
886    #[tokio::test]
887    async fn test_token_bucket_delays_when_exhausted() {
888        let store = Arc::new(InMemory::new());
889        // Very low rate and burst capacity to force waiting
890        let config = AimdThrottleConfig::default()
891            .with_burst_capacity(1)
892            .with_aimd(AimdConfig::default().with_initial_rate(10.0));
893        let throttled = Arc::new(AimdThrottledStore::new(store, config).unwrap());
894
895        let path = Path::from("test/file.txt");
896        let data = PutPayload::from_static(b"data");
897        throttled.put(&path, data).await.unwrap();
898
899        // After consuming the burst token, the next request should take ~100ms
900        // (1 token / 10 tokens-per-sec). We verify it takes at least 50ms.
901        let start = std::time::Instant::now();
902        let data2 = PutPayload::from_static(b"data2");
903        throttled.put(&path, data2).await.unwrap();
904        let elapsed = start.elapsed();
905
906        assert!(
907            elapsed >= std::time::Duration::from_millis(50),
908            "Expected delay for token refill, but elapsed was {:?}",
909            elapsed
910        );
911    }
912
913    #[tokio::test]
914    async fn test_list_observes_outcomes() {
915        let store = Arc::new(InMemory::new());
916        let config = AimdThrottleConfig::default();
917        let throttled = AimdThrottledStore::new(store.clone(), config).unwrap();
918
919        let path = Path::from("prefix/file.txt");
920        let data = PutPayload::from_static(b"data");
921        store.put(&path, data).await.unwrap();
922
923        let items: Vec<_> = throttled.list(Some(&Path::from("prefix"))).collect().await;
924        assert_eq!(items.len(), 1);
925        assert!(items[0].is_ok());
926    }
927
    /// A mock store whose `list` stream yields a configurable number of
    /// throttle-style errors followed by the inner store's real listing.
    /// Used to verify that the AIMD wrapper observes errors surfaced inside
    /// list streams. All other operations delegate to `inner` unchanged.
    struct ThrottlingListMockStore {
        /// Backing in-memory store that serves every real request.
        inner: InMemory,
        /// Number of throttle errors to inject at the start of each list call.
        throttle_count: usize,
    }
936
937    impl Display for ThrottlingListMockStore {
938        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
939            write!(f, "ThrottlingListMockStore")
940        }
941    }
942
943    impl Debug for ThrottlingListMockStore {
944        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
945            f.debug_struct("ThrottlingListMockStore").finish()
946        }
947    }
948
949    #[async_trait]
950    impl ObjectStore for ThrottlingListMockStore {
951        async fn put_opts(
952            &self,
953            location: &Path,
954            bytes: PutPayload,
955            opts: PutOptions,
956        ) -> OSResult<PutResult> {
957            self.inner.put_opts(location, bytes, opts).await
958        }
959        async fn put_multipart_opts(
960            &self,
961            location: &Path,
962            opts: PutMultipartOptions,
963        ) -> OSResult<Box<dyn MultipartUpload>> {
964            self.inner.put_multipart_opts(location, opts).await
965        }
966        async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult<GetResult> {
967            self.inner.get_opts(location, options).await
968        }
969        async fn get_ranges(&self, location: &Path, ranges: &[Range<u64>]) -> OSResult<Vec<Bytes>> {
970            self.inner.get_ranges(location, ranges).await
971        }
972        fn delete_stream(
973            &self,
974            locations: BoxStream<'static, OSResult<Path>>,
975        ) -> BoxStream<'static, OSResult<Path>> {
976            self.inner.delete_stream(locations)
977        }
978        fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, OSResult<ObjectMeta>> {
979            let n = self.throttle_count;
980            let inner_stream = self.inner.list(prefix);
981            let errors = futures::stream::iter((0..n).map(|_| {
982                Err(object_store::Error::Generic {
983                    store: "ThrottlingListMock",
984                    source: "request failed, after 3 retries, max_retries: 5, retry_timeout: 60s"
985                        .into(),
986                })
987            }));
988            errors.chain(inner_stream).boxed()
989        }
990        fn list_with_offset(
991            &self,
992            prefix: Option<&Path>,
993            offset: &Path,
994        ) -> BoxStream<'static, OSResult<ObjectMeta>> {
995            self.inner.list_with_offset(prefix, offset)
996        }
997        async fn list_with_delimiter(&self, prefix: Option<&Path>) -> OSResult<ListResult> {
998            self.inner.list_with_delimiter(prefix).await
999        }
1000        async fn copy_opts(&self, from: &Path, to: &Path, opts: CopyOptions) -> OSResult<()> {
1001            self.inner.copy_opts(from, to, opts).await
1002        }
1003    }
1004
1005    #[tokio::test]
1006    async fn test_list_stream_throttle_errors_decrease_rate() {
1007        let mock = Arc::new(ThrottlingListMockStore {
1008            inner: InMemory::new(),
1009            throttle_count: 5,
1010        });
1011
1012        // Seed a file so the real items come through after the errors.
1013        mock.put(
1014            &Path::from("prefix/file.txt"),
1015            PutPayload::from_static(b"data"),
1016        )
1017        .await
1018        .unwrap();
1019
1020        let config = AimdThrottleConfig::default().with_list_aimd(
1021            AimdConfig::default()
1022                .with_initial_rate(100.0)
1023                .with_decrease_factor(0.5)
1024                .with_window_duration(std::time::Duration::from_millis(10)),
1025        );
1026        let throttled = AimdThrottledStore::new(mock as Arc<dyn ObjectStore>, config).unwrap();
1027
1028        let initial_rate = throttled.list.controller.current_rate();
1029        assert_eq!(initial_rate, 100.0);
1030
1031        let items: Vec<_> = throttled.list(Some(&Path::from("prefix"))).collect().await;
1032
1033        // 5 errors + 1 real item
1034        assert_eq!(items.len(), 6);
1035        assert!(items[0].is_err());
1036        assert!(items[5].is_ok());
1037
1038        // Wait for the AIMD window to expire and trigger evaluation.
1039        tokio::time::sleep(std::time::Duration::from_millis(20)).await;
1040        throttled
1041            .list
1042            .controller
1043            .record_outcome(RequestOutcome::Success);
1044
1045        let new_rate = throttled.list.controller.current_rate();
1046        assert!(
1047            new_rate < initial_rate,
1048            "List rate should decrease after stream throttle errors: {} < {}",
1049            new_rate,
1050            initial_rate
1051        );
1052    }
1053
1054    #[tokio::test]
1055    async fn test_per_category_independence() {
1056        let store = Arc::new(InMemory::new());
1057        let config = AimdThrottleConfig::default().with_aimd(
1058            AimdConfig::default()
1059                .with_initial_rate(100.0)
1060                .with_decrease_factor(0.5)
1061                .with_window_duration(std::time::Duration::from_millis(10)),
1062        );
1063        let throttled = AimdThrottledStore::new(store, config).unwrap();
1064
1065        // Push the read controller into a throttled state
1066        throttled
1067            .read
1068            .controller
1069            .record_outcome(RequestOutcome::Throttled);
1070        tokio::time::sleep(std::time::Duration::from_millis(20)).await;
1071        throttled
1072            .read
1073            .controller
1074            .record_outcome(RequestOutcome::Success);
1075
1076        let read_rate = throttled.read.controller.current_rate();
1077        let write_rate = throttled.write.controller.current_rate();
1078        let delete_rate = throttled.delete.controller.current_rate();
1079        let list_rate = throttled.list.controller.current_rate();
1080
1081        assert_eq!(read_rate, 50.0, "Read rate should have decreased");
1082        assert_eq!(write_rate, 100.0, "Write rate should be unaffected");
1083        assert_eq!(delete_rate, 100.0, "Delete rate should be unaffected");
1084        assert_eq!(list_rate, 100.0, "List rate should be unaffected");
1085    }
1086
1087    #[tokio::test]
1088    async fn test_per_category_config() {
1089        let store = Arc::new(InMemory::new());
1090        let config = AimdThrottleConfig::default()
1091            .with_read_aimd(AimdConfig::default().with_initial_rate(200.0))
1092            .with_write_aimd(AimdConfig::default().with_initial_rate(100.0))
1093            .with_delete_aimd(AimdConfig::default().with_initial_rate(50.0))
1094            .with_list_aimd(AimdConfig::default().with_initial_rate(25.0));
1095        let throttled = AimdThrottledStore::new(store, config).unwrap();
1096
1097        assert_eq!(throttled.read.controller.current_rate(), 200.0);
1098        assert_eq!(throttled.write.controller.current_rate(), 100.0);
1099        assert_eq!(throttled.delete.controller.current_rate(), 50.0);
1100        assert_eq!(throttled.list.controller.current_rate(), 25.0);
1101    }
1102
    /// A mock [`ObjectStore`] that measures request rate over a sliding window
    /// and returns 503-style throttle errors when the rate exceeds a
    /// configurable threshold.
    ///
    /// Only `get_opts` and `get_ranges` consult the limiter; puts, multipart
    /// uploads, deletes, listings and copies are forwarded to `inner` without
    /// any rate check.
    struct RateLimitingMockStore {
        /// Backing in-memory store that serves admitted requests.
        inner: InMemory,
        /// Timestamps of recent successful (admitted) requests.
        timestamps: std::sync::Mutex<VecDeque<std::time::Instant>>,
        /// Maximum requests allowed within `window`.
        max_per_window: usize,
        /// Sliding window duration.
        window: std::time::Duration,
        /// Count of rate-checked requests that were admitted.
        success_count: AtomicU64,
        /// Count of rate-checked requests that were rejected.
        throttle_count: AtomicU64,
    }
1117
1118    impl RateLimitingMockStore {
1119        fn new(max_per_window: usize, window: std::time::Duration) -> Self {
1120            Self {
1121                inner: InMemory::new(),
1122                timestamps: std::sync::Mutex::new(VecDeque::new()),
1123                max_per_window,
1124                window,
1125                success_count: AtomicU64::new(0),
1126                throttle_count: AtomicU64::new(0),
1127            }
1128        }
1129
1130        /// Returns `true` if the request is admitted, `false` if throttled.
1131        fn check_rate(&self) -> bool {
1132            let mut ts = self.timestamps.lock().unwrap();
1133            let now = std::time::Instant::now();
1134            while let Some(&front) = ts.front() {
1135                if now.duration_since(front) > self.window {
1136                    ts.pop_front();
1137                } else {
1138                    break;
1139                }
1140            }
1141            if ts.len() >= self.max_per_window {
1142                self.throttle_count.fetch_add(1, Ordering::Relaxed);
1143                false
1144            } else {
1145                ts.push_back(now);
1146                self.success_count.fetch_add(1, Ordering::Relaxed);
1147                true
1148            }
1149        }
1150
1151        fn throttle_error() -> object_store::Error {
1152            object_store::Error::Generic {
1153                store: "RateLimitingMock",
1154                source: THROTTLE_ERROR_RESPONSE.into(),
1155            }
1156        }
1157    }
1158
1159    impl Display for RateLimitingMockStore {
1160        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1161            write!(f, "RateLimitingMockStore")
1162        }
1163    }
1164
1165    impl Debug for RateLimitingMockStore {
1166        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1167            f.debug_struct("RateLimitingMockStore").finish()
1168        }
1169    }
1170
1171    #[async_trait]
1172    impl ObjectStore for RateLimitingMockStore {
1173        async fn put_opts(
1174            &self,
1175            location: &Path,
1176            bytes: PutPayload,
1177            opts: PutOptions,
1178        ) -> OSResult<PutResult> {
1179            self.inner.put_opts(location, bytes, opts).await
1180        }
1181
1182        async fn put_multipart_opts(
1183            &self,
1184            location: &Path,
1185            opts: PutMultipartOptions,
1186        ) -> OSResult<Box<dyn MultipartUpload>> {
1187            self.inner.put_multipart_opts(location, opts).await
1188        }
1189
1190        async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult<GetResult> {
1191            if self.check_rate() {
1192                self.inner.get_opts(location, options).await
1193            } else {
1194                Err(Self::throttle_error())
1195            }
1196        }
1197
1198        async fn get_ranges(&self, location: &Path, ranges: &[Range<u64>]) -> OSResult<Vec<Bytes>> {
1199            if self.check_rate() {
1200                self.inner.get_ranges(location, ranges).await
1201            } else {
1202                Err(Self::throttle_error())
1203            }
1204        }
1205
1206        fn delete_stream(
1207            &self,
1208            locations: BoxStream<'static, OSResult<Path>>,
1209        ) -> BoxStream<'static, OSResult<Path>> {
1210            self.inner.delete_stream(locations)
1211        }
1212
1213        fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, OSResult<ObjectMeta>> {
1214            self.inner.list(prefix)
1215        }
1216
1217        fn list_with_offset(
1218            &self,
1219            prefix: Option<&Path>,
1220            offset: &Path,
1221        ) -> BoxStream<'static, OSResult<ObjectMeta>> {
1222            self.inner.list_with_offset(prefix, offset)
1223        }
1224
1225        async fn list_with_delimiter(&self, prefix: Option<&Path>) -> OSResult<ListResult> {
1226            self.inner.list_with_delimiter(prefix).await
1227        }
1228
1229        async fn copy_opts(&self, from: &Path, to: &Path, opts: CopyOptions) -> OSResult<()> {
1230            self.inner.copy_opts(from, to, opts).await
1231        }
1232    }
1233
1234    /// Verify that multiple concurrent readers sharing an AIMD-throttled store
1235    /// converge to the backend's actual capacity.
1236    ///
1237    /// Setup:
1238    /// - Mock backend allows 30 requests per 100ms (= 300 req/s).
1239    /// - 5 reader tasks, each with their own [`AimdThrottledStore`] wrapping
1240    ///   the shared mock.
1241    /// - AIMD: 100ms window, initial rate 100 req/s, decrease 0.5, increase 2.
1242    /// - Readers issue `head()` requests as fast as the throttle allows for 2s.
1243    ///
1244    /// Expected behaviour:
1245    /// - Initial burst (100 burst tokens × 5 readers) overshoots the mock
1246    ///   capacity, causing many 503s. Each reader's AIMD halves its rate.
1247    /// - After the transient, each reader converges to ~60 req/s (300/5).
1248    /// - Over 2 seconds, total successful requests should be in the range
1249    ///   [300, 900] (theoretical max ≈ 600).
1250    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
1251    async fn test_aimd_throttle_under_concurrent_load() {
1252        let mock = Arc::new(RateLimitingMockStore::new(
1253            30,
1254            std::time::Duration::from_millis(100),
1255        ));
1256
1257        // Seed a test file so head() succeeds when admitted.
1258        let path = Path::from("test/data.bin");
1259        mock.put(&path, PutPayload::from_static(b"test data"))
1260            .await
1261            .unwrap();
1262
1263        let aimd = AimdConfig::default()
1264            .with_initial_rate(100.0)
1265            .with_decrease_factor(0.5)
1266            .with_additive_increment(2.0)
1267            .with_window_duration(std::time::Duration::from_millis(100));
1268        let throttle_config = AimdThrottleConfig::default()
1269            .with_aimd(aimd)
1270            .with_burst_capacity(100);
1271
1272        let num_readers = 5;
1273        let test_duration = std::time::Duration::from_secs(2);
1274        let mut handles = Vec::new();
1275
1276        for _ in 0..num_readers {
1277            let store = Arc::new(
1278                AimdThrottledStore::new(
1279                    mock.clone() as Arc<dyn ObjectStore>,
1280                    throttle_config.clone(),
1281                )
1282                .unwrap(),
1283            );
1284            let p = path.clone();
1285            handles.push(tokio::spawn(async move {
1286                let deadline = std::time::Instant::now() + test_duration;
1287                let mut count = 0u64;
1288                while std::time::Instant::now() < deadline {
1289                    let _ = store.head(&p).await;
1290                    count += 1;
1291                }
1292                count
1293            }));
1294        }
1295
1296        let mut total_reader_requests = 0u64;
1297        for handle in handles {
1298            total_reader_requests += handle.await.unwrap();
1299        }
1300
1301        let successes = mock.success_count.load(Ordering::Relaxed);
1302        let throttled = mock.throttle_count.load(Ordering::Relaxed);
1303        let total_mock = successes + throttled;
1304
1305        // Mock-side count >= reader-side count because the AIMD layer retries
1306        // throttle errors internally, causing multiple mock calls per reader call.
1307        assert!(
1308            total_mock >= total_reader_requests,
1309            "Mock-side count ({total_mock}) should be >= reader-side count ({total_reader_requests})"
1310        );
1311
1312        // Mock capacity is 30/100ms = 300 req/s. Over 2s the theoretical max is
1313        // ~600 successful requests. With AIMD ramp-up, expect somewhat fewer.
1314        assert!(
1315            successes >= 300,
1316            "Expected >= 300 successful requests over 2s, got {successes}"
1317        );
1318        assert!(
1319            successes <= 900,
1320            "Expected <= 900 successful requests, got {successes}"
1321        );
1322
1323        // The initial burst exceeds mock capacity, so throttling must occur.
1324        assert!(throttled > 0, "Expected some throttled requests but got 0");
1325
1326        // Without AIMD, raw tokio tasks against InMemory would fire 100k+ req/s.
1327        // AIMD should keep the total well under 5000 over 2s.
1328        assert!(
1329            total_mock <= 5000,
1330            "AIMD should limit total requests, got {total_mock}"
1331        );
1332    }
1333
    /// A mock store that returns a configurable number of throttle errors
    /// before succeeding on `get` operations. Used to test the retry logic
    /// inside `OperationThrottle::throttled()`.
    ///
    /// Only `get_opts` injects errors and counts calls; all other operations
    /// are forwarded to `inner` unchanged.
    struct RetryTestMockStore {
        /// Backing in-memory store that serves requests once errors run out.
        inner: InMemory,
        /// Number of throttle errors remaining before success.
        errors_remaining: std::sync::Mutex<usize>,
        /// Total number of `get` calls observed.
        get_call_count: AtomicU64,
    }
1344
1345    impl RetryTestMockStore {
1346        fn new(errors_before_success: usize) -> Self {
1347            Self {
1348                inner: InMemory::new(),
1349                errors_remaining: std::sync::Mutex::new(errors_before_success),
1350                get_call_count: AtomicU64::new(0),
1351            }
1352        }
1353    }
1354
1355    impl Display for RetryTestMockStore {
1356        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1357            write!(f, "RetryTestMockStore")
1358        }
1359    }
1360
1361    impl Debug for RetryTestMockStore {
1362        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1363            f.debug_struct("RetryTestMockStore").finish()
1364        }
1365    }
1366
1367    #[async_trait]
1368    impl ObjectStore for RetryTestMockStore {
1369        async fn put_opts(
1370            &self,
1371            location: &Path,
1372            bytes: PutPayload,
1373            opts: PutOptions,
1374        ) -> OSResult<PutResult> {
1375            self.inner.put_opts(location, bytes, opts).await
1376        }
1377        async fn put_multipart_opts(
1378            &self,
1379            location: &Path,
1380            opts: PutMultipartOptions,
1381        ) -> OSResult<Box<dyn MultipartUpload>> {
1382            self.inner.put_multipart_opts(location, opts).await
1383        }
1384        async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult<GetResult> {
1385            self.get_call_count.fetch_add(1, Ordering::Relaxed);
1386            let should_error = {
1387                let mut remaining = self.errors_remaining.lock().unwrap();
1388                if *remaining > 0 {
1389                    *remaining -= 1;
1390                    true
1391                } else {
1392                    false
1393                }
1394            };
1395            if should_error {
1396                Err(object_store::Error::Generic {
1397                    store: "RetryTestMock",
1398                    source: THROTTLE_ERROR_RESPONSE.into(),
1399                })
1400            } else {
1401                self.inner.get_opts(location, options).await
1402            }
1403        }
1404        async fn get_ranges(&self, location: &Path, ranges: &[Range<u64>]) -> OSResult<Vec<Bytes>> {
1405            self.inner.get_ranges(location, ranges).await
1406        }
1407        fn delete_stream(
1408            &self,
1409            locations: BoxStream<'static, OSResult<Path>>,
1410        ) -> BoxStream<'static, OSResult<Path>> {
1411            self.inner.delete_stream(locations)
1412        }
1413        fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, OSResult<ObjectMeta>> {
1414            self.inner.list(prefix)
1415        }
1416        fn list_with_offset(
1417            &self,
1418            prefix: Option<&Path>,
1419            offset: &Path,
1420        ) -> BoxStream<'static, OSResult<ObjectMeta>> {
1421            self.inner.list_with_offset(prefix, offset)
1422        }
1423        async fn list_with_delimiter(&self, prefix: Option<&Path>) -> OSResult<ListResult> {
1424            self.inner.list_with_delimiter(prefix).await
1425        }
1426        async fn copy_opts(&self, from: &Path, to: &Path, opts: CopyOptions) -> OSResult<()> {
1427            self.inner.copy_opts(from, to, opts).await
1428        }
1429    }
1430
1431    #[tokio::test]
1432    async fn test_throttled_retries_on_throttle_error_then_succeeds() {
1433        // Mock returns 2 throttle errors then succeeds (within MAX_RETRIES=3)
1434        let mock = Arc::new(RetryTestMockStore::new(2));
1435        let path = Path::from("test/retry.txt");
1436        mock.put(&path, PutPayload::from_static(b"retry data"))
1437            .await
1438            .unwrap();
1439
1440        let config = AimdThrottleConfig::default();
1441        let throttled =
1442            AimdThrottledStore::new(mock.clone() as Arc<dyn ObjectStore>, config).unwrap();
1443
1444        let result = throttled.get(&path).await;
1445        assert!(result.is_ok(), "Expected success after retries");
1446
1447        let bytes = result.unwrap().bytes().await.unwrap();
1448        assert_eq!(bytes.as_ref(), b"retry data");
1449
1450        // Should have called get 3 times total: 2 failures + 1 success
1451        assert_eq!(mock.get_call_count.load(Ordering::Relaxed), 3);
1452    }
1453
1454    #[tokio::test]
1455    async fn test_throttled_fails_after_max_retries_exceeded() {
1456        // Mock returns 4 throttle errors (more than MAX_RETRIES=3),
1457        // so all 4 attempts (initial + 3 retries) will fail.
1458        let mock = Arc::new(RetryTestMockStore::new(10));
1459        let path = Path::from("test/fail.txt");
1460        mock.put(&path, PutPayload::from_static(b"fail data"))
1461            .await
1462            .unwrap();
1463
1464        let config = AimdThrottleConfig::default();
1465        let throttled =
1466            AimdThrottledStore::new(mock.clone() as Arc<dyn ObjectStore>, config).unwrap();
1467
1468        let result = throttled.get(&path).await;
1469        assert!(result.is_err(), "Expected error after max retries");
1470        let err = result.unwrap_err();
1471        assert!(is_throttle_error(&err));
1472
1473        let lance_error = lance_core::Error::from(err);
1474        let error_message = lance_error.to_string();
1475        assert!(error_message.contains("x-ms-request-id"));
1476        assert!(error_message.contains("azure-request-id"));
1477
1478        // Should have called get 4 times: initial attempt + 3 retries
1479        assert_eq!(mock.get_call_count.load(Ordering::Relaxed), 4);
1480    }
1481
1482    #[tokio::test]
1483    async fn test_throttled_multipart_reorders_parts() {
1484        let store = Arc::new(InMemory::new()) as Arc<dyn ObjectStore>;
1485        let config = AimdThrottleConfig::default();
1486        let throttled = AimdThrottledStore::new(store.clone(), config).unwrap();
1487
1488        let path = Path::from("test/multipart_ordering.bin");
1489        let mut upload = throttled.put_multipart(&path).await.unwrap();
1490
1491        // Create futures for two parts in order: A then B.
1492        let fut_a = upload.put_part(PutPayload::from_static(b"AAAA"));
1493        let fut_b = upload.put_part(PutPayload::from_static(b"BBBB"));
1494
1495        // Await in REVERSE order. Part ordering should be determined by
1496        // creation order (put_part call order), not by await order.
1497        fut_b.await.unwrap();
1498        fut_a.await.unwrap();
1499
1500        upload.complete().await.unwrap();
1501
1502        let result = store.get(&path).await.unwrap();
1503        let bytes = result.bytes().await.unwrap();
1504
1505        assert_eq!(
1506            bytes.as_ref(),
1507            b"AAAABBBB",
1508            "Parts were reordered! Got {:?} instead of AAAABBBB.",
1509            std::str::from_utf8(&bytes).unwrap_or("<non-utf8>"),
1510        );
1511    }
1512}