1use std::collections::HashMap;
24use std::fmt::{Debug, Display, Formatter};
25use std::ops::Range;
26use std::sync::Arc;
27
28use async_trait::async_trait;
29use bytes::Bytes;
30use futures::StreamExt;
31use futures::stream::BoxStream;
32use lance_core::utils::aimd::{AimdConfig, AimdController, RequestOutcome};
33use lance_core::utils::tracing::TRACE_OBJECT_STORE_THROTTLE;
34#[cfg(test)]
35use object_store::ObjectStoreExt;
36use object_store::path::Path;
37use object_store::{
38 CopyOptions, GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta, ObjectStore,
39 PutMultipartOptions, PutOptions, PutPayload, PutResult, RenameOptions, Result as OSResult,
40 UploadPart,
41};
42use rand::Rng;
43use tokio::sync::Mutex;
44use tracing::{debug, warn};
45
46pub fn is_throttle_error(err: &object_store::Error) -> bool {
64 if let object_store::Error::Generic { source, .. } = err {
66 let message = source.to_string();
67 let lowercase = message.to_ascii_lowercase();
68 lowercase.contains("retries, max_retries")
69 || lowercase.contains("serverbusy")
70 || lowercase.contains("server busy")
71 || lowercase.contains("egress is over the account limit")
72 || lowercase.contains("http 429")
73 || lowercase.contains("status code: 429")
74 || lowercase.contains("429 too many requests")
75 || lowercase.contains("too many requests")
76 || lowercase.contains("slowdown")
77 || lowercase.contains("please reduce your request rate")
78 || lowercase.contains("rate limit")
79 || lowercase.contains("throttling")
80 || lowercase.contains("throttled")
81 } else {
82 false
83 }
84}
85
/// Configuration for [`AimdThrottledStore`].
///
/// Each operation category (read, write, delete, list) gets its own AIMD
/// controller settings. The remaining fields (token bucket capacity, retry
/// count and backoff range) are shared by all categories.
#[derive(Debug, Clone)]
pub struct AimdThrottleConfig {
    /// AIMD settings for read operations (`get_opts`, `get_ranges`).
    pub read: AimdConfig,
    /// AIMD settings for write operations (`put_opts`, multipart uploads,
    /// `copy_opts`, `rename_opts`).
    pub write: AimdConfig,
    /// AIMD settings for the controller that observes `delete_stream` results.
    pub delete: AimdConfig,
    /// AIMD settings for list operations (`list`, `list_with_offset`,
    /// `list_with_delimiter`).
    pub list: AimdConfig,
    /// Maximum number of tokens each category's token bucket can hold, i.e.
    /// how many requests may be issued back-to-back before the rate limit
    /// starts delaying them.
    pub burst_capacity: u32,
    /// How many times a request that failed with a throttle error is retried.
    /// A value of 0 makes [`AimdThrottleConfig::is_disabled`] return `true`.
    pub max_retries: usize,
    /// Inclusive lower bound, in milliseconds, of the random backoff before a retry.
    pub min_backoff_ms: u64,
    /// Inclusive upper bound, in milliseconds, of the random backoff before a retry.
    pub max_backoff_ms: u64,
}
111
112impl Default for AimdThrottleConfig {
113 fn default() -> Self {
114 let aimd = AimdConfig::default();
115 Self {
116 read: aimd.clone(),
117 write: aimd.clone(),
118 delete: aimd.clone(),
119 list: aimd,
120 burst_capacity: 100,
121 max_retries: 3,
122 min_backoff_ms: 100,
123 max_backoff_ms: 300,
124 }
125 }
126}
127
128impl AimdThrottleConfig {
129 pub fn with_aimd(self, aimd: AimdConfig) -> Self {
131 Self {
132 read: aimd.clone(),
133 write: aimd.clone(),
134 delete: aimd.clone(),
135 list: aimd,
136 ..self
137 }
138 }
139
140 pub fn with_read_aimd(self, aimd: AimdConfig) -> Self {
142 Self { read: aimd, ..self }
143 }
144
145 pub fn with_write_aimd(self, aimd: AimdConfig) -> Self {
147 Self {
148 write: aimd,
149 ..self
150 }
151 }
152
153 pub fn with_delete_aimd(self, aimd: AimdConfig) -> Self {
155 Self {
156 delete: aimd,
157 ..self
158 }
159 }
160
161 pub fn with_list_aimd(self, aimd: AimdConfig) -> Self {
163 Self { list: aimd, ..self }
164 }
165
166 pub fn is_disabled(&self) -> bool {
168 self.max_retries == 0
169 }
170
171 pub fn with_burst_capacity(self, burst_capacity: u32) -> Self {
172 Self {
173 burst_capacity,
174 ..self
175 }
176 }
177
178 pub fn from_storage_options(
196 storage_options: Option<&HashMap<String, String>>,
197 ) -> lance_core::Result<Self> {
198 fn resolve_f64(
199 key: &str,
200 storage_options: Option<&HashMap<String, String>>,
201 default: f64,
202 ) -> lance_core::Result<f64> {
203 let env_key = key.to_ascii_uppercase();
204 if let Some(val) = storage_options.and_then(|opts| opts.get(key)) {
205 val.parse::<f64>().map_err(|_| {
206 lance_core::Error::invalid_input(format!(
207 "Invalid value for storage option '{key}': '{val}'"
208 ))
209 })
210 } else if let Ok(val) = std::env::var(&env_key) {
211 val.parse::<f64>().map_err(|_| {
212 lance_core::Error::invalid_input(format!(
213 "Invalid value for env var '{env_key}': '{val}'"
214 ))
215 })
216 } else {
217 Ok(default)
218 }
219 }
220
221 fn resolve_u32(
222 key: &str,
223 storage_options: Option<&HashMap<String, String>>,
224 default: u32,
225 ) -> lance_core::Result<u32> {
226 let env_key = key.to_ascii_uppercase();
227 if let Some(val) = storage_options.and_then(|opts| opts.get(key)) {
228 val.parse::<u32>().map_err(|_| {
229 lance_core::Error::invalid_input(format!(
230 "Invalid value for storage option '{key}': '{val}'"
231 ))
232 })
233 } else if let Ok(val) = std::env::var(&env_key) {
234 val.parse::<u32>().map_err(|_| {
235 lance_core::Error::invalid_input(format!(
236 "Invalid value for env var '{env_key}': '{val}'"
237 ))
238 })
239 } else {
240 Ok(default)
241 }
242 }
243
244 fn resolve_usize(
245 key: &str,
246 storage_options: Option<&HashMap<String, String>>,
247 default: usize,
248 ) -> lance_core::Result<usize> {
249 let env_key = key.to_ascii_uppercase();
250 if let Some(val) = storage_options.and_then(|opts| opts.get(key)) {
251 val.parse::<usize>().map_err(|_| {
252 lance_core::Error::invalid_input(format!(
253 "Invalid value for storage option '{key}': '{val}'"
254 ))
255 })
256 } else if let Ok(val) = std::env::var(&env_key) {
257 val.parse::<usize>().map_err(|_| {
258 lance_core::Error::invalid_input(format!(
259 "Invalid value for env var '{env_key}': '{val}'"
260 ))
261 })
262 } else {
263 Ok(default)
264 }
265 }
266
267 fn resolve_u64(
268 key: &str,
269 storage_options: Option<&HashMap<String, String>>,
270 default: u64,
271 ) -> lance_core::Result<u64> {
272 let env_key = key.to_ascii_uppercase();
273 if let Some(val) = storage_options.and_then(|opts| opts.get(key)) {
274 val.parse::<u64>().map_err(|_| {
275 lance_core::Error::invalid_input(format!(
276 "Invalid value for storage option '{key}': '{val}'"
277 ))
278 })
279 } else if let Ok(val) = std::env::var(&env_key) {
280 val.parse::<u64>().map_err(|_| {
281 lance_core::Error::invalid_input(format!(
282 "Invalid value for env var '{env_key}': '{val}'"
283 ))
284 })
285 } else {
286 Ok(default)
287 }
288 }
289
290 let initial_rate = resolve_f64("lance_aimd_initial_rate", storage_options, 2000.0)?;
291 let min_rate = resolve_f64("lance_aimd_min_rate", storage_options, 1.0)?;
292 let max_rate = resolve_f64("lance_aimd_max_rate", storage_options, 5000.0)?;
293 let decrease_factor = resolve_f64("lance_aimd_decrease_factor", storage_options, 0.5)?;
294 let additive_increment =
295 resolve_f64("lance_aimd_additive_increment", storage_options, 300.0)?;
296 let burst_capacity = resolve_u32("lance_aimd_burst_capacity", storage_options, 100)?;
297 let max_retries = resolve_usize("lance_aimd_max_retries", storage_options, 3)?;
298 let min_backoff_ms = resolve_u64("lance_aimd_min_backoff_ms", storage_options, 100)?;
299 let max_backoff_ms = resolve_u64("lance_aimd_max_backoff_ms", storage_options, 300)?;
300
301 let aimd = AimdConfig::default()
302 .with_initial_rate(initial_rate)
303 .with_min_rate(min_rate)
304 .with_max_rate(max_rate)
305 .with_decrease_factor(decrease_factor)
306 .with_additive_increment(additive_increment);
307
308 Ok(Self {
309 max_retries,
310 min_backoff_ms,
311 max_backoff_ms,
312 ..Self::default()
313 .with_aimd(aimd)
314 .with_burst_capacity(burst_capacity)
315 })
316 }
317}
318
/// Mutable state of a token bucket rate limiter.
struct TokenBucketState {
    /// Currently available tokens. This may go negative: `acquire_token`
    /// always deducts a token first and then sleeps until the deficit would
    /// be refilled.
    tokens: f64,
    /// Time at which `tokens` was last refilled.
    last_refill: tokio::time::Instant,
    /// Refill rate in tokens per second. It is overwritten with the AIMD
    /// controller's rate after recorded request outcomes.
    rate: f64,
}
324
/// Rate limiter plus retry policy for one category of object store operations.
///
/// Pairs an AIMD controller, which adjusts the allowed request rate from
/// recorded request outcomes, with a token bucket that enforces that rate.
struct OperationThrottle {
    /// Adjusts the allowed request rate from recorded request outcomes.
    controller: AimdController,
    /// Token bucket enforcing the controller's current rate.
    bucket: Mutex<TokenBucketState>,
    /// Maximum number of tokens the bucket can accumulate.
    burst_capacity: f64,
    /// Number of retries after a throttle error.
    max_retries: usize,
    /// Inclusive lower bound of the random retry backoff, in milliseconds.
    min_backoff_ms: u64,
    /// Inclusive upper bound of the random retry backoff, in milliseconds.
    max_backoff_ms: u64,
}
334
335impl OperationThrottle {
336 fn new(
337 aimd_config: AimdConfig,
338 burst_capacity: f64,
339 max_retries: usize,
340 min_backoff_ms: u64,
341 max_backoff_ms: u64,
342 ) -> lance_core::Result<Self> {
343 let initial_rate = aimd_config.initial_rate;
344 let controller = AimdController::new(aimd_config)?;
345 Ok(Self {
346 controller,
347 bucket: Mutex::new(TokenBucketState {
348 tokens: burst_capacity,
349 last_refill: tokio::time::Instant::now(),
350 rate: initial_rate,
351 }),
352 burst_capacity,
353 max_retries,
354 min_backoff_ms,
355 max_backoff_ms,
356 })
357 }
358
359 async fn acquire_token(&self) {
365 let sleep_duration = {
366 let mut bucket = self.bucket.lock().await;
367 let now = tokio::time::Instant::now();
368 let elapsed = now.duration_since(bucket.last_refill).as_secs_f64();
369 bucket.tokens = (bucket.tokens + elapsed * bucket.rate).min(self.burst_capacity);
370 bucket.last_refill = now;
371
372 bucket.tokens -= 1.0;
374
375 if bucket.tokens >= 0.0 {
376 return;
378 }
379
380 std::time::Duration::from_secs_f64(-bucket.tokens / bucket.rate)
382 };
383
384 tokio::time::sleep(sleep_duration).await;
385 }
386
387 async fn update_bucket_rate(&self, new_rate: f64) {
389 let mut bucket = self.bucket.lock().await;
390 bucket.rate = new_rate;
391 }
392
393 fn observe_outcome<T>(&self, result: &OSResult<T>) {
398 let outcome = match result {
399 Ok(_) => RequestOutcome::Success,
400 Err(err) if is_throttle_error(err) => {
401 debug!(
402 target: TRACE_OBJECT_STORE_THROTTLE,
403 error = %err,
404 "Throttle error detected in stream"
405 );
406 RequestOutcome::Throttled
407 }
408 Err(_) => RequestOutcome::Success,
409 };
410 let prev_rate = self.controller.current_rate();
411 let new_rate = self.controller.record_outcome(outcome);
412 if new_rate < prev_rate
413 && let Err(err) = result.as_ref()
414 {
415 warn!(
416 target: TRACE_OBJECT_STORE_THROTTLE,
417 previous_rate = format!("{prev_rate:.1}"),
418 new_rate = format!("{new_rate:.1}"),
419 error = %err,
420 "AIMD throttle: rate reduced due to throttle errors"
421 );
422 }
423 if let Ok(mut bucket) = self.bucket.try_lock() {
424 bucket.rate = new_rate;
425 }
426 }
427
428 async fn throttled<T, F, Fut>(&self, f: F) -> OSResult<T>
432 where
433 F: Fn() -> Fut,
434 Fut: std::future::Future<Output = OSResult<T>>,
435 {
436 for attempt in 0..=self.max_retries {
437 self.acquire_token().await;
438 let result = f().await;
439 let outcome = match &result {
440 Ok(_) => RequestOutcome::Success,
441 Err(err) if is_throttle_error(err) => {
442 debug!(
443 target: TRACE_OBJECT_STORE_THROTTLE,
444 error = %err,
445 "Throttle error detected"
446 );
447 RequestOutcome::Throttled
448 }
449 Err(_) => RequestOutcome::Success, };
451 let prev_rate = self.controller.current_rate();
452 let new_rate = self.controller.record_outcome(outcome);
453 if new_rate < prev_rate
454 && let Err(err) = result.as_ref()
455 {
456 warn!(
457 target: TRACE_OBJECT_STORE_THROTTLE,
458 previous_rate = format!("{prev_rate:.1}"),
459 new_rate = format!("{new_rate:.1}"),
460 error = %err,
461 "AIMD throttle: rate reduced due to throttle errors"
462 );
463 }
464 self.update_bucket_rate(new_rate).await;
465
466 match &result {
467 Err(err) if is_throttle_error(err) && attempt < self.max_retries => {
468 let backoff_ms =
469 rand::rng().random_range(self.min_backoff_ms..=self.max_backoff_ms);
470 debug!(
471 target: TRACE_OBJECT_STORE_THROTTLE,
472 attempt = attempt + 1,
473 max_retries = self.max_retries,
474 backoff_ms,
475 error = %err,
476 "Retrying after throttle error"
477 );
478 tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
479 continue;
480 }
481 _ => return result,
482 }
483 }
484 unreachable!()
485 }
486}
487
488impl Debug for OperationThrottle {
489 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
490 f.debug_struct("OperationThrottle")
491 .field("controller", &self.controller)
492 .field("burst_capacity", &self.burst_capacity)
493 .finish()
494 }
495}
496
/// A [`MultipartUpload`] wrapper that routes every upload request through the
/// write throttle.
struct ThrottledMultipartUpload {
    /// The upload returned by the wrapped store.
    target: Box<dyn MultipartUpload>,
    /// Write throttle shared with the owning store's other write operations.
    write: Arc<OperationThrottle>,
}
504
505impl Debug for ThrottledMultipartUpload {
506 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
507 f.debug_struct("ThrottledMultipartUpload").finish()
508 }
509}
510
511#[async_trait]
512impl MultipartUpload for ThrottledMultipartUpload {
513 fn put_part(&mut self, data: PutPayload) -> UploadPart {
514 let write = Arc::clone(&self.write);
515 let fut = self.target.put_part(data);
518 Box::pin(async move {
519 write.acquire_token().await;
520 let result = fut.await;
521 write.observe_outcome(&result);
522 result
523 })
524 }
525
526 async fn complete(&mut self) -> OSResult<PutResult> {
527 let target = &mut self.target;
528 for attempt in 0..=self.write.max_retries {
529 self.write.acquire_token().await;
530 let result = target.complete().await;
531 self.write.observe_outcome(&result);
532
533 match &result {
534 Err(err) if is_throttle_error(err) && attempt < self.write.max_retries => {
535 let backoff_ms = rand::rng()
536 .random_range(self.write.min_backoff_ms..=self.write.max_backoff_ms);
537 tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
538 continue;
539 }
540 _ => return result,
541 }
542 }
543 unreachable!()
544 }
545
546 async fn abort(&mut self) -> OSResult<()> {
547 let target = &mut self.target;
548 for attempt in 0..=self.write.max_retries {
549 self.write.acquire_token().await;
550 let result = target.abort().await;
551 self.write.observe_outcome(&result);
552
553 match &result {
554 Err(err) if is_throttle_error(err) && attempt < self.write.max_retries => {
555 let backoff_ms = rand::rng()
556 .random_range(self.write.min_backoff_ms..=self.write.max_backoff_ms);
557 tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
558 continue;
559 }
560 _ => return result,
561 }
562 }
563 unreachable!()
564 }
565}
566
/// An [`ObjectStore`] wrapper that adaptively rate-limits requests to the
/// wrapped store.
///
/// Requests fall into four categories (read, write, delete, list). Each
/// category has its own throttle: a token bucket whose refill rate is driven
/// by an AIMD controller that reacts to throttle errors, as detected by
/// [`is_throttle_error`].
pub struct AimdThrottledStore {
    /// The wrapped store that actually serves requests.
    target: Arc<dyn ObjectStore>,
    /// Throttle for `get_opts` and `get_ranges`.
    read: Arc<OperationThrottle>,
    /// Throttle for `put_opts`, multipart uploads, `copy_opts` and `rename_opts`.
    write: Arc<OperationThrottle>,
    /// Throttle whose controller observes `delete_stream` results. Deletes do
    /// not wait for tokens.
    delete: Arc<OperationThrottle>,
    /// Throttle for `list`, `list_with_offset` and `list_with_delimiter`.
    list: Arc<OperationThrottle>,
}
589
590impl Debug for AimdThrottledStore {
591 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
592 f.debug_struct("AimdThrottledStore")
593 .field("target", &self.target)
594 .field("read", &self.read)
595 .field("write", &self.write)
596 .field("delete", &self.delete)
597 .field("list", &self.list)
598 .finish()
599 }
600}
601
602impl Display for AimdThrottledStore {
603 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
604 write!(f, "AimdThrottledStore({})", self.target)
605 }
606}
607
608impl AimdThrottledStore {
609 pub fn new(
610 target: Arc<dyn ObjectStore>,
611 config: AimdThrottleConfig,
612 ) -> lance_core::Result<Self> {
613 let burst = config.burst_capacity as f64;
614 let max_retries = config.max_retries;
615 let min_backoff_ms = config.min_backoff_ms;
616 let max_backoff_ms = config.max_backoff_ms;
617 Ok(Self {
618 target,
619 read: Arc::new(OperationThrottle::new(
620 config.read,
621 burst,
622 max_retries,
623 min_backoff_ms,
624 max_backoff_ms,
625 )?),
626 write: Arc::new(OperationThrottle::new(
627 config.write,
628 burst,
629 max_retries,
630 min_backoff_ms,
631 max_backoff_ms,
632 )?),
633 delete: Arc::new(OperationThrottle::new(
634 config.delete,
635 burst,
636 max_retries,
637 min_backoff_ms,
638 max_backoff_ms,
639 )?),
640 list: Arc::new(OperationThrottle::new(
641 config.list,
642 burst,
643 max_retries,
644 min_backoff_ms,
645 max_backoff_ms,
646 )?),
647 })
648 }
649}
650
651#[async_trait]
652#[deny(clippy::missing_trait_methods)]
653impl ObjectStore for AimdThrottledStore {
654 async fn put_opts(
655 &self,
656 location: &Path,
657 bytes: PutPayload,
658 opts: PutOptions,
659 ) -> OSResult<PutResult> {
660 self.write
661 .throttled(|| self.target.put_opts(location, bytes.clone(), opts.clone()))
662 .await
663 }
664
665 async fn put_multipart_opts(
666 &self,
667 location: &Path,
668 opts: PutMultipartOptions,
669 ) -> OSResult<Box<dyn MultipartUpload>> {
670 let target = self
671 .write
672 .throttled(|| self.target.put_multipart_opts(location, opts.clone()))
673 .await?;
674 Ok(Box::new(ThrottledMultipartUpload {
675 target,
676 write: Arc::clone(&self.write),
677 }))
678 }
679
680 async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult<GetResult> {
681 self.read
682 .throttled(|| self.target.get_opts(location, options.clone()))
683 .await
684 }
685
686 async fn get_ranges(&self, location: &Path, ranges: &[Range<u64>]) -> OSResult<Vec<Bytes>> {
687 self.read
688 .throttled(|| self.target.get_ranges(location, ranges))
689 .await
690 }
691
692 fn delete_stream(
693 &self,
694 locations: BoxStream<'static, OSResult<Path>>,
695 ) -> BoxStream<'static, OSResult<Path>> {
696 let delete = Arc::clone(&self.delete);
697 self.target
698 .delete_stream(locations)
699 .map(move |item| {
700 delete.observe_outcome(&item);
701 item
702 })
703 .boxed()
704 }
705
706 fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, OSResult<ObjectMeta>> {
707 let throttle = Arc::clone(&self.list);
708 let throttle_for_start = Arc::clone(&throttle);
709 let target = Arc::clone(&self.target);
710 let prefix = prefix.cloned();
711 futures::stream::once(async move {
712 throttle_for_start.acquire_token().await;
713 target.list(prefix.as_ref())
714 })
715 .flatten()
716 .map(move |item| {
717 throttle.observe_outcome(&item);
718 item
719 })
720 .boxed()
721 }
722
723 fn list_with_offset(
724 &self,
725 prefix: Option<&Path>,
726 offset: &Path,
727 ) -> BoxStream<'static, OSResult<ObjectMeta>> {
728 let throttle = Arc::clone(&self.list);
729 let throttle_for_start = Arc::clone(&throttle);
730 let target = Arc::clone(&self.target);
731 let prefix = prefix.cloned();
732 let offset = offset.clone();
733 futures::stream::once(async move {
734 throttle_for_start.acquire_token().await;
735 target.list_with_offset(prefix.as_ref(), &offset)
736 })
737 .flatten()
738 .map(move |item| {
739 throttle.observe_outcome(&item);
740 item
741 })
742 .boxed()
743 }
744
745 async fn list_with_delimiter(&self, prefix: Option<&Path>) -> OSResult<ListResult> {
746 self.list
747 .throttled(|| self.target.list_with_delimiter(prefix))
748 .await
749 }
750
751 async fn copy_opts(&self, from: &Path, to: &Path, opts: CopyOptions) -> OSResult<()> {
752 self.write
753 .throttled(|| self.target.copy_opts(from, to, opts.clone()))
754 .await
755 }
756
757 async fn rename_opts(&self, from: &Path, to: &Path, opts: RenameOptions) -> OSResult<()> {
758 self.write
759 .throttled(|| self.target.rename_opts(from, to, opts.clone()))
760 .await
761 }
762}
763
764#[cfg(test)]
765mod tests {
766 use super::*;
767 use object_store::memory::InMemory;
768 use rstest::rstest;
769 use std::collections::VecDeque;
770 use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
771
772 const THROTTLE_ERROR_RESPONSE: &str = "request failed, after 3 retries, max_retries: 3, retry_timeout: 30s - Server returned non-2xx status code: 503: x-ms-request-id: azure-request-id";
773
774 fn make_generic_error(msg: &str) -> object_store::Error {
775 object_store::Error::Generic {
776 store: "test",
777 source: msg.into(),
778 }
779 }
780
781 #[rstest]
782 #[case::retry_error("Error after 10 retries, max_retries: 10, retry_timeout: 180s", true)]
783 #[case::retries_in_message(
784 "request failed, after 3 retries, max_retries: 5, retry_timeout: 60s",
785 true
786 )]
787 #[case::not_found("Object not found", false)]
788 #[case::permission_denied("Access denied", false)]
789 #[case::timeout("Connection timed out", false)]
790 #[case::http_429_without_retries("HTTP 429 Too Many Requests", true)]
791 #[case::slowdown_without_retries("SlowDown: Please reduce your request rate", true)]
792 #[case::azure_server_busy("Code: ServerBusy", true)]
793 #[case::azure_egress_limit("Message: Egress is over the account limit", true)]
794 fn test_is_throttle_error(#[case] msg: &str, #[case] expected: bool) {
795 let err = make_generic_error(msg);
796 assert_eq!(
797 is_throttle_error(&err),
798 expected,
799 "is_throttle_error for '{}' should be {}",
800 msg,
801 expected
802 );
803 }
804
805 #[test]
806 fn test_non_generic_errors_are_not_throttle() {
807 let err = object_store::Error::NotFound {
808 path: "test".to_string(),
809 source: "not found".into(),
810 };
811 assert!(!is_throttle_error(&err));
812 }
813
814 #[tokio::test]
815 async fn test_basic_put_get_through_wrapper() {
816 let store = Arc::new(InMemory::new());
817 let config = AimdThrottleConfig::default();
818 let throttled = AimdThrottledStore::new(store, config).unwrap();
819
820 let path = Path::from("test/file.txt");
821 let data = PutPayload::from_static(b"hello world");
822 throttled.put(&path, data).await.unwrap();
823
824 let result = throttled.get(&path).await.unwrap();
825 let bytes = result.bytes().await.unwrap();
826 assert_eq!(bytes.as_ref(), b"hello world");
827 }
828
829 #[tokio::test]
830 async fn test_rate_decreases_on_throttle() {
831 let store = Arc::new(InMemory::new());
832 let config = AimdThrottleConfig::default().with_aimd(
833 AimdConfig::default()
834 .with_initial_rate(100.0)
835 .with_decrease_factor(0.5)
836 .with_window_duration(std::time::Duration::from_millis(10)),
837 );
838 let throttled = AimdThrottledStore::new(store, config).unwrap();
839
840 let initial_rate = throttled.read.controller.current_rate();
841 assert_eq!(initial_rate, 100.0);
842
843 throttled
845 .read
846 .controller
847 .record_outcome(RequestOutcome::Throttled);
848
849 tokio::time::sleep(std::time::Duration::from_millis(20)).await;
851 throttled
852 .read
853 .controller
854 .record_outcome(RequestOutcome::Success);
855
856 let new_rate = throttled.read.controller.current_rate();
857 assert!(
858 new_rate < initial_rate,
859 "Rate should decrease after throttle: {} < {}",
860 new_rate,
861 initial_rate
862 );
863 }
864
865 #[tokio::test]
866 async fn test_rate_recovers_on_success() {
867 let store = Arc::new(InMemory::new());
868 let config = AimdThrottleConfig::default().with_aimd(
869 AimdConfig::default()
870 .with_initial_rate(100.0)
871 .with_decrease_factor(0.5)
872 .with_additive_increment(10.0)
873 .with_window_duration(std::time::Duration::from_millis(10)),
874 );
875 let throttled = AimdThrottledStore::new(store, config).unwrap();
876
877 throttled
879 .read
880 .controller
881 .record_outcome(RequestOutcome::Throttled);
882 tokio::time::sleep(std::time::Duration::from_millis(20)).await;
883 throttled
884 .read
885 .controller
886 .record_outcome(RequestOutcome::Success);
887 let decreased_rate = throttled.read.controller.current_rate();
888 assert_eq!(decreased_rate, 50.0);
889
890 tokio::time::sleep(std::time::Duration::from_millis(20)).await;
892 throttled
893 .read
894 .controller
895 .record_outcome(RequestOutcome::Success);
896 let recovered_rate = throttled.read.controller.current_rate();
897 assert_eq!(recovered_rate, 60.0);
898 }
899
900 #[tokio::test]
901 async fn test_as_dyn_object_store() {
902 let store: Arc<dyn ObjectStore> = Arc::new(InMemory::new());
903 let throttled: Arc<dyn ObjectStore> =
904 Arc::new(AimdThrottledStore::new(store, AimdThrottleConfig::default()).unwrap());
905
906 let path = Path::from("test/data.bin");
907 let data = PutPayload::from_static(b"test data");
908 throttled.put(&path, data).await.unwrap();
909
910 let result = throttled.get(&path).await.unwrap();
911 let bytes = result.bytes().await.unwrap();
912 assert_eq!(bytes.as_ref(), b"test data");
913 }
914
915 #[tokio::test]
916 async fn test_token_bucket_delays_when_exhausted() {
917 let store = Arc::new(InMemory::new());
918 let config = AimdThrottleConfig::default()
920 .with_burst_capacity(1)
921 .with_aimd(AimdConfig::default().with_initial_rate(10.0));
922 let throttled = Arc::new(AimdThrottledStore::new(store, config).unwrap());
923
924 let path = Path::from("test/file.txt");
925 let data = PutPayload::from_static(b"data");
926 throttled.put(&path, data).await.unwrap();
927
928 let start = std::time::Instant::now();
931 let data2 = PutPayload::from_static(b"data2");
932 throttled.put(&path, data2).await.unwrap();
933 let elapsed = start.elapsed();
934
935 assert!(
936 elapsed >= std::time::Duration::from_millis(50),
937 "Expected delay for token refill, but elapsed was {:?}",
938 elapsed
939 );
940 }
941
942 #[tokio::test]
943 async fn test_list_observes_outcomes() {
944 let store = Arc::new(InMemory::new());
945 let config = AimdThrottleConfig::default();
946 let throttled = AimdThrottledStore::new(store.clone(), config).unwrap();
947
948 let path = Path::from("prefix/file.txt");
949 let data = PutPayload::from_static(b"data");
950 store.put(&path, data).await.unwrap();
951
952 let items: Vec<_> = throttled.list(Some(&Path::from("prefix"))).collect().await;
953 assert_eq!(items.len(), 1);
954 assert!(items[0].is_ok());
955 }
956
957 struct ThrottlingListMockStore {
961 inner: InMemory,
962 throttle_count: usize,
964 }
965
966 impl Display for ThrottlingListMockStore {
967 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
968 write!(f, "ThrottlingListMockStore")
969 }
970 }
971
972 impl Debug for ThrottlingListMockStore {
973 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
974 f.debug_struct("ThrottlingListMockStore").finish()
975 }
976 }
977
978 #[async_trait]
979 impl ObjectStore for ThrottlingListMockStore {
980 async fn put_opts(
981 &self,
982 location: &Path,
983 bytes: PutPayload,
984 opts: PutOptions,
985 ) -> OSResult<PutResult> {
986 self.inner.put_opts(location, bytes, opts).await
987 }
988 async fn put_multipart_opts(
989 &self,
990 location: &Path,
991 opts: PutMultipartOptions,
992 ) -> OSResult<Box<dyn MultipartUpload>> {
993 self.inner.put_multipart_opts(location, opts).await
994 }
995 async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult<GetResult> {
996 self.inner.get_opts(location, options).await
997 }
998 async fn get_ranges(&self, location: &Path, ranges: &[Range<u64>]) -> OSResult<Vec<Bytes>> {
999 self.inner.get_ranges(location, ranges).await
1000 }
1001 fn delete_stream(
1002 &self,
1003 locations: BoxStream<'static, OSResult<Path>>,
1004 ) -> BoxStream<'static, OSResult<Path>> {
1005 self.inner.delete_stream(locations)
1006 }
1007 fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, OSResult<ObjectMeta>> {
1008 let n = self.throttle_count;
1009 let inner_stream = self.inner.list(prefix);
1010 let errors = futures::stream::iter((0..n).map(|_| {
1011 Err(object_store::Error::Generic {
1012 store: "ThrottlingListMock",
1013 source: "request failed, after 3 retries, max_retries: 5, retry_timeout: 60s"
1014 .into(),
1015 })
1016 }));
1017 errors.chain(inner_stream).boxed()
1018 }
1019 fn list_with_offset(
1020 &self,
1021 prefix: Option<&Path>,
1022 offset: &Path,
1023 ) -> BoxStream<'static, OSResult<ObjectMeta>> {
1024 self.inner.list_with_offset(prefix, offset)
1025 }
1026 async fn list_with_delimiter(&self, prefix: Option<&Path>) -> OSResult<ListResult> {
1027 self.inner.list_with_delimiter(prefix).await
1028 }
1029 async fn copy_opts(&self, from: &Path, to: &Path, opts: CopyOptions) -> OSResult<()> {
1030 self.inner.copy_opts(from, to, opts).await
1031 }
1032 }
1033
1034 #[tokio::test]
1035 async fn test_list_stream_throttle_errors_decrease_rate() {
1036 let mock = Arc::new(ThrottlingListMockStore {
1037 inner: InMemory::new(),
1038 throttle_count: 5,
1039 });
1040
1041 mock.put(
1043 &Path::from("prefix/file.txt"),
1044 PutPayload::from_static(b"data"),
1045 )
1046 .await
1047 .unwrap();
1048
1049 let config = AimdThrottleConfig::default().with_list_aimd(
1050 AimdConfig::default()
1051 .with_initial_rate(100.0)
1052 .with_decrease_factor(0.5)
1053 .with_window_duration(std::time::Duration::from_millis(10)),
1054 );
1055 let throttled = AimdThrottledStore::new(mock as Arc<dyn ObjectStore>, config).unwrap();
1056
1057 let initial_rate = throttled.list.controller.current_rate();
1058 assert_eq!(initial_rate, 100.0);
1059
1060 let items: Vec<_> = throttled.list(Some(&Path::from("prefix"))).collect().await;
1061
1062 assert_eq!(items.len(), 6);
1064 assert!(items[0].is_err());
1065 assert!(items[5].is_ok());
1066
1067 tokio::time::sleep(std::time::Duration::from_millis(20)).await;
1069 throttled
1070 .list
1071 .controller
1072 .record_outcome(RequestOutcome::Success);
1073
1074 let new_rate = throttled.list.controller.current_rate();
1075 assert!(
1076 new_rate < initial_rate,
1077 "List rate should decrease after stream throttle errors: {} < {}",
1078 new_rate,
1079 initial_rate
1080 );
1081 }
1082
1083 struct CountingListStartStore {
1084 inner: InMemory,
1085 list_calls: AtomicUsize,
1086 offset_calls: AtomicUsize,
1087 }
1088
1089 impl Default for CountingListStartStore {
1090 fn default() -> Self {
1091 Self {
1092 inner: InMemory::new(),
1093 list_calls: AtomicUsize::new(0),
1094 offset_calls: AtomicUsize::new(0),
1095 }
1096 }
1097 }
1098
1099 impl CountingListStartStore {
1100 fn list_calls(&self) -> usize {
1101 self.list_calls.load(Ordering::SeqCst)
1102 }
1103
1104 fn offset_calls(&self) -> usize {
1105 self.offset_calls.load(Ordering::SeqCst)
1106 }
1107 }
1108
1109 impl Display for CountingListStartStore {
1110 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1111 write!(f, "CountingListStartStore")
1112 }
1113 }
1114
1115 impl Debug for CountingListStartStore {
1116 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1117 f.debug_struct("CountingListStartStore").finish()
1118 }
1119 }
1120
1121 #[async_trait]
1122 impl ObjectStore for CountingListStartStore {
1123 async fn put_opts(
1124 &self,
1125 location: &Path,
1126 bytes: PutPayload,
1127 opts: PutOptions,
1128 ) -> OSResult<PutResult> {
1129 self.inner.put_opts(location, bytes, opts).await
1130 }
1131
1132 async fn put_multipart_opts(
1133 &self,
1134 location: &Path,
1135 opts: PutMultipartOptions,
1136 ) -> OSResult<Box<dyn MultipartUpload>> {
1137 self.inner.put_multipart_opts(location, opts).await
1138 }
1139
1140 async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult<GetResult> {
1141 self.inner.get_opts(location, options).await
1142 }
1143
1144 async fn get_ranges(&self, location: &Path, ranges: &[Range<u64>]) -> OSResult<Vec<Bytes>> {
1145 self.inner.get_ranges(location, ranges).await
1146 }
1147
1148 fn delete_stream(
1149 &self,
1150 locations: BoxStream<'static, OSResult<Path>>,
1151 ) -> BoxStream<'static, OSResult<Path>> {
1152 self.inner.delete_stream(locations)
1153 }
1154
1155 fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, OSResult<ObjectMeta>> {
1156 self.list_calls.fetch_add(1, Ordering::SeqCst);
1157 self.inner.list(prefix)
1158 }
1159
1160 fn list_with_offset(
1161 &self,
1162 prefix: Option<&Path>,
1163 offset: &Path,
1164 ) -> BoxStream<'static, OSResult<ObjectMeta>> {
1165 self.offset_calls.fetch_add(1, Ordering::SeqCst);
1166 self.inner.list_with_offset(prefix, offset)
1167 }
1168
1169 async fn list_with_delimiter(&self, prefix: Option<&Path>) -> OSResult<ListResult> {
1170 self.inner.list_with_delimiter(prefix).await
1171 }
1172
1173 async fn copy_opts(&self, from: &Path, to: &Path, opts: CopyOptions) -> OSResult<()> {
1174 self.inner.copy_opts(from, to, opts).await
1175 }
1176 }
1177
1178 fn list_start_throttle_config() -> AimdThrottleConfig {
1179 AimdThrottleConfig::default()
1183 .with_burst_capacity(0)
1184 .with_list_aimd(AimdConfig::default().with_initial_rate(10.0))
1185 }
1186
1187 #[tokio::test(start_paused = true)]
1188 async fn test_list_acquires_token_before_starting_underlying_stream() {
1189 let store = Arc::new(CountingListStartStore::default());
1190 store
1191 .put(
1192 &Path::from("prefix/file.txt"),
1193 PutPayload::from_static(b"data"),
1194 )
1195 .await
1196 .unwrap();
1197 let throttled = AimdThrottledStore::new(
1198 store.clone() as Arc<dyn ObjectStore>,
1199 list_start_throttle_config(),
1200 )
1201 .unwrap();
1202
1203 let mut stream = throttled.list(Some(&Path::from("prefix")));
1204 assert_eq!(store.list_calls(), 0);
1205 assert!(
1208 tokio::time::timeout(std::time::Duration::from_millis(50), stream.next())
1209 .await
1210 .is_err()
1211 );
1212 assert_eq!(store.list_calls(), 0);
1213
1214 let item = tokio::time::timeout(std::time::Duration::from_millis(300), stream.next())
1215 .await
1216 .unwrap()
1217 .unwrap()
1218 .unwrap();
1219 assert_eq!(item.location, Path::from("prefix/file.txt"));
1220 assert_eq!(store.list_calls(), 1);
1221 }
1222
1223 #[tokio::test(start_paused = true)]
1224 async fn test_list_with_offset_acquires_token_before_starting_underlying_stream() {
1225 let store = Arc::new(CountingListStartStore::default());
1226 store
1227 .put(&Path::from("prefix/b"), PutPayload::from_static(b"data"))
1228 .await
1229 .unwrap();
1230 let throttled = AimdThrottledStore::new(
1231 store.clone() as Arc<dyn ObjectStore>,
1232 list_start_throttle_config(),
1233 )
1234 .unwrap();
1235
1236 let mut stream =
1237 throttled.list_with_offset(Some(&Path::from("prefix")), &Path::from("prefix/a"));
1238 assert_eq!(store.offset_calls(), 0);
1239 assert!(
1242 tokio::time::timeout(std::time::Duration::from_millis(50), stream.next())
1243 .await
1244 .is_err()
1245 );
1246 assert_eq!(store.offset_calls(), 0);
1247
1248 let item = tokio::time::timeout(std::time::Duration::from_millis(300), stream.next())
1249 .await
1250 .unwrap()
1251 .unwrap()
1252 .unwrap();
1253 assert_eq!(item.location, Path::from("prefix/b"));
1254 assert_eq!(store.offset_calls(), 1);
1255 }
1256
1257 #[tokio::test]
1258 async fn test_per_category_independence() {
1259 let store = Arc::new(InMemory::new());
1260 let config = AimdThrottleConfig::default().with_aimd(
1261 AimdConfig::default()
1262 .with_initial_rate(100.0)
1263 .with_decrease_factor(0.5)
1264 .with_window_duration(std::time::Duration::from_millis(10)),
1265 );
1266 let throttled = AimdThrottledStore::new(store, config).unwrap();
1267
1268 throttled
1270 .read
1271 .controller
1272 .record_outcome(RequestOutcome::Throttled);
1273 tokio::time::sleep(std::time::Duration::from_millis(20)).await;
1274 throttled
1275 .read
1276 .controller
1277 .record_outcome(RequestOutcome::Success);
1278
1279 let read_rate = throttled.read.controller.current_rate();
1280 let write_rate = throttled.write.controller.current_rate();
1281 let delete_rate = throttled.delete.controller.current_rate();
1282 let list_rate = throttled.list.controller.current_rate();
1283
1284 assert_eq!(read_rate, 50.0, "Read rate should have decreased");
1285 assert_eq!(write_rate, 100.0, "Write rate should be unaffected");
1286 assert_eq!(delete_rate, 100.0, "Delete rate should be unaffected");
1287 assert_eq!(list_rate, 100.0, "List rate should be unaffected");
1288 }
1289
1290 #[tokio::test]
1291 async fn test_per_category_config() {
1292 let store = Arc::new(InMemory::new());
1293 let config = AimdThrottleConfig::default()
1294 .with_read_aimd(AimdConfig::default().with_initial_rate(200.0))
1295 .with_write_aimd(AimdConfig::default().with_initial_rate(100.0))
1296 .with_delete_aimd(AimdConfig::default().with_initial_rate(50.0))
1297 .with_list_aimd(AimdConfig::default().with_initial_rate(25.0));
1298 let throttled = AimdThrottledStore::new(store, config).unwrap();
1299
1300 assert_eq!(throttled.read.controller.current_rate(), 200.0);
1301 assert_eq!(throttled.write.controller.current_rate(), 100.0);
1302 assert_eq!(throttled.delete.controller.current_rate(), 50.0);
1303 assert_eq!(throttled.list.controller.current_rate(), 25.0);
1304 }
1305
1306 struct RateLimitingMockStore {
1310 inner: InMemory,
1311 timestamps: std::sync::Mutex<VecDeque<std::time::Instant>>,
1313 max_per_window: usize,
1315 window: std::time::Duration,
1317 success_count: AtomicU64,
1318 throttle_count: AtomicU64,
1319 }
1320
1321 impl RateLimitingMockStore {
1322 fn new(max_per_window: usize, window: std::time::Duration) -> Self {
1323 Self {
1324 inner: InMemory::new(),
1325 timestamps: std::sync::Mutex::new(VecDeque::new()),
1326 max_per_window,
1327 window,
1328 success_count: AtomicU64::new(0),
1329 throttle_count: AtomicU64::new(0),
1330 }
1331 }
1332
1333 fn check_rate(&self) -> bool {
1335 let mut ts = self.timestamps.lock().unwrap();
1336 let now = std::time::Instant::now();
1337 while let Some(&front) = ts.front() {
1338 if now.duration_since(front) > self.window {
1339 ts.pop_front();
1340 } else {
1341 break;
1342 }
1343 }
1344 if ts.len() >= self.max_per_window {
1345 self.throttle_count.fetch_add(1, Ordering::Relaxed);
1346 false
1347 } else {
1348 ts.push_back(now);
1349 self.success_count.fetch_add(1, Ordering::Relaxed);
1350 true
1351 }
1352 }
1353
1354 fn throttle_error() -> object_store::Error {
1355 object_store::Error::Generic {
1356 store: "RateLimitingMock",
1357 source: THROTTLE_ERROR_RESPONSE.into(),
1358 }
1359 }
1360 }
1361
1362 impl Display for RateLimitingMockStore {
1363 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1364 write!(f, "RateLimitingMockStore")
1365 }
1366 }
1367
1368 impl Debug for RateLimitingMockStore {
1369 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1370 f.debug_struct("RateLimitingMockStore").finish()
1371 }
1372 }
1373
1374 #[async_trait]
1375 impl ObjectStore for RateLimitingMockStore {
1376 async fn put_opts(
1377 &self,
1378 location: &Path,
1379 bytes: PutPayload,
1380 opts: PutOptions,
1381 ) -> OSResult<PutResult> {
1382 self.inner.put_opts(location, bytes, opts).await
1383 }
1384
1385 async fn put_multipart_opts(
1386 &self,
1387 location: &Path,
1388 opts: PutMultipartOptions,
1389 ) -> OSResult<Box<dyn MultipartUpload>> {
1390 self.inner.put_multipart_opts(location, opts).await
1391 }
1392
1393 async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult<GetResult> {
1394 if self.check_rate() {
1395 self.inner.get_opts(location, options).await
1396 } else {
1397 Err(Self::throttle_error())
1398 }
1399 }
1400
1401 async fn get_ranges(&self, location: &Path, ranges: &[Range<u64>]) -> OSResult<Vec<Bytes>> {
1402 if self.check_rate() {
1403 self.inner.get_ranges(location, ranges).await
1404 } else {
1405 Err(Self::throttle_error())
1406 }
1407 }
1408
1409 fn delete_stream(
1410 &self,
1411 locations: BoxStream<'static, OSResult<Path>>,
1412 ) -> BoxStream<'static, OSResult<Path>> {
1413 self.inner.delete_stream(locations)
1414 }
1415
1416 fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, OSResult<ObjectMeta>> {
1417 self.inner.list(prefix)
1418 }
1419
1420 fn list_with_offset(
1421 &self,
1422 prefix: Option<&Path>,
1423 offset: &Path,
1424 ) -> BoxStream<'static, OSResult<ObjectMeta>> {
1425 self.inner.list_with_offset(prefix, offset)
1426 }
1427
1428 async fn list_with_delimiter(&self, prefix: Option<&Path>) -> OSResult<ListResult> {
1429 self.inner.list_with_delimiter(prefix).await
1430 }
1431
1432 async fn copy_opts(&self, from: &Path, to: &Path, opts: CopyOptions) -> OSResult<()> {
1433 self.inner.copy_opts(from, to, opts).await
1434 }
1435 }
1436
1437 #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
1454 async fn test_aimd_throttle_under_concurrent_load() {
1455 let mock = Arc::new(RateLimitingMockStore::new(
1456 30,
1457 std::time::Duration::from_millis(100),
1458 ));
1459
1460 let path = Path::from("test/data.bin");
1462 mock.put(&path, PutPayload::from_static(b"test data"))
1463 .await
1464 .unwrap();
1465
1466 let aimd = AimdConfig::default()
1467 .with_initial_rate(100.0)
1468 .with_decrease_factor(0.5)
1469 .with_additive_increment(2.0)
1470 .with_window_duration(std::time::Duration::from_millis(100));
1471 let throttle_config = AimdThrottleConfig::default()
1472 .with_aimd(aimd)
1473 .with_burst_capacity(100);
1474
1475 let num_readers = 5;
1476 let test_duration = std::time::Duration::from_secs(2);
1477 let mut handles = Vec::new();
1478
1479 for _ in 0..num_readers {
1480 let store = Arc::new(
1481 AimdThrottledStore::new(
1482 mock.clone() as Arc<dyn ObjectStore>,
1483 throttle_config.clone(),
1484 )
1485 .unwrap(),
1486 );
1487 let p = path.clone();
1488 handles.push(tokio::spawn(async move {
1489 let deadline = std::time::Instant::now() + test_duration;
1490 let mut count = 0u64;
1491 while std::time::Instant::now() < deadline {
1492 let _ = store.head(&p).await;
1493 count += 1;
1494 }
1495 count
1496 }));
1497 }
1498
1499 let mut total_reader_requests = 0u64;
1500 for handle in handles {
1501 total_reader_requests += handle.await.unwrap();
1502 }
1503
1504 let successes = mock.success_count.load(Ordering::Relaxed);
1505 let throttled = mock.throttle_count.load(Ordering::Relaxed);
1506 let total_mock = successes + throttled;
1507
1508 assert!(
1511 total_mock >= total_reader_requests,
1512 "Mock-side count ({total_mock}) should be >= reader-side count ({total_reader_requests})"
1513 );
1514
1515 assert!(
1518 successes >= 300,
1519 "Expected >= 300 successful requests over 2s, got {successes}"
1520 );
1521 assert!(
1522 successes <= 900,
1523 "Expected <= 900 successful requests, got {successes}"
1524 );
1525
1526 assert!(throttled > 0, "Expected some throttled requests but got 0");
1528
1529 assert!(
1532 total_mock <= 5000,
1533 "AIMD should limit total requests, got {total_mock}"
1534 );
1535 }
1536
1537 struct RetryTestMockStore {
1541 inner: InMemory,
1542 errors_remaining: std::sync::Mutex<usize>,
1544 get_call_count: AtomicU64,
1546 }
1547
1548 impl RetryTestMockStore {
1549 fn new(errors_before_success: usize) -> Self {
1550 Self {
1551 inner: InMemory::new(),
1552 errors_remaining: std::sync::Mutex::new(errors_before_success),
1553 get_call_count: AtomicU64::new(0),
1554 }
1555 }
1556 }
1557
1558 impl Display for RetryTestMockStore {
1559 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1560 write!(f, "RetryTestMockStore")
1561 }
1562 }
1563
1564 impl Debug for RetryTestMockStore {
1565 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1566 f.debug_struct("RetryTestMockStore").finish()
1567 }
1568 }
1569
1570 #[async_trait]
1571 impl ObjectStore for RetryTestMockStore {
1572 async fn put_opts(
1573 &self,
1574 location: &Path,
1575 bytes: PutPayload,
1576 opts: PutOptions,
1577 ) -> OSResult<PutResult> {
1578 self.inner.put_opts(location, bytes, opts).await
1579 }
1580 async fn put_multipart_opts(
1581 &self,
1582 location: &Path,
1583 opts: PutMultipartOptions,
1584 ) -> OSResult<Box<dyn MultipartUpload>> {
1585 self.inner.put_multipart_opts(location, opts).await
1586 }
1587 async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult<GetResult> {
1588 self.get_call_count.fetch_add(1, Ordering::Relaxed);
1589 let should_error = {
1590 let mut remaining = self.errors_remaining.lock().unwrap();
1591 if *remaining > 0 {
1592 *remaining -= 1;
1593 true
1594 } else {
1595 false
1596 }
1597 };
1598 if should_error {
1599 Err(object_store::Error::Generic {
1600 store: "RetryTestMock",
1601 source: THROTTLE_ERROR_RESPONSE.into(),
1602 })
1603 } else {
1604 self.inner.get_opts(location, options).await
1605 }
1606 }
1607 async fn get_ranges(&self, location: &Path, ranges: &[Range<u64>]) -> OSResult<Vec<Bytes>> {
1608 self.inner.get_ranges(location, ranges).await
1609 }
1610 fn delete_stream(
1611 &self,
1612 locations: BoxStream<'static, OSResult<Path>>,
1613 ) -> BoxStream<'static, OSResult<Path>> {
1614 self.inner.delete_stream(locations)
1615 }
1616 fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, OSResult<ObjectMeta>> {
1617 self.inner.list(prefix)
1618 }
1619 fn list_with_offset(
1620 &self,
1621 prefix: Option<&Path>,
1622 offset: &Path,
1623 ) -> BoxStream<'static, OSResult<ObjectMeta>> {
1624 self.inner.list_with_offset(prefix, offset)
1625 }
1626 async fn list_with_delimiter(&self, prefix: Option<&Path>) -> OSResult<ListResult> {
1627 self.inner.list_with_delimiter(prefix).await
1628 }
1629 async fn copy_opts(&self, from: &Path, to: &Path, opts: CopyOptions) -> OSResult<()> {
1630 self.inner.copy_opts(from, to, opts).await
1631 }
1632 }
1633
1634 #[tokio::test]
1635 async fn test_throttled_retries_on_throttle_error_then_succeeds() {
1636 let mock = Arc::new(RetryTestMockStore::new(2));
1638 let path = Path::from("test/retry.txt");
1639 mock.put(&path, PutPayload::from_static(b"retry data"))
1640 .await
1641 .unwrap();
1642
1643 let config = AimdThrottleConfig::default();
1644 let throttled =
1645 AimdThrottledStore::new(mock.clone() as Arc<dyn ObjectStore>, config).unwrap();
1646
1647 let result = throttled.get(&path).await;
1648 assert!(result.is_ok(), "Expected success after retries");
1649
1650 let bytes = result.unwrap().bytes().await.unwrap();
1651 assert_eq!(bytes.as_ref(), b"retry data");
1652
1653 assert_eq!(mock.get_call_count.load(Ordering::Relaxed), 3);
1655 }
1656
1657 #[tokio::test]
1658 async fn test_throttled_fails_after_max_retries_exceeded() {
1659 let mock = Arc::new(RetryTestMockStore::new(10));
1662 let path = Path::from("test/fail.txt");
1663 mock.put(&path, PutPayload::from_static(b"fail data"))
1664 .await
1665 .unwrap();
1666
1667 let config = AimdThrottleConfig::default();
1668 let throttled =
1669 AimdThrottledStore::new(mock.clone() as Arc<dyn ObjectStore>, config).unwrap();
1670
1671 let result = throttled.get(&path).await;
1672 assert!(result.is_err(), "Expected error after max retries");
1673 let err = result.unwrap_err();
1674 assert!(is_throttle_error(&err));
1675
1676 let lance_error = lance_core::Error::from(err);
1677 let error_message = lance_error.to_string();
1678 assert!(error_message.contains("x-ms-request-id"));
1679 assert!(error_message.contains("azure-request-id"));
1680
1681 assert_eq!(mock.get_call_count.load(Ordering::Relaxed), 4);
1683 }
1684
1685 #[tokio::test]
1686 async fn test_throttled_multipart_reorders_parts() {
1687 let store = Arc::new(InMemory::new()) as Arc<dyn ObjectStore>;
1688 let config = AimdThrottleConfig::default();
1689 let throttled = AimdThrottledStore::new(store.clone(), config).unwrap();
1690
1691 let path = Path::from("test/multipart_ordering.bin");
1692 let mut upload = throttled.put_multipart(&path).await.unwrap();
1693
1694 let fut_a = upload.put_part(PutPayload::from_static(b"AAAA"));
1696 let fut_b = upload.put_part(PutPayload::from_static(b"BBBB"));
1697
1698 fut_b.await.unwrap();
1701 fut_a.await.unwrap();
1702
1703 upload.complete().await.unwrap();
1704
1705 let result = store.get(&path).await.unwrap();
1706 let bytes = result.bytes().await.unwrap();
1707
1708 assert_eq!(
1709 bytes.as_ref(),
1710 b"AAAABBBB",
1711 "Parts were reordered! Got {:?} instead of AAAABBBB.",
1712 std::str::from_utf8(&bytes).unwrap_or("<non-utf8>"),
1713 );
1714 }
1715}