1use std::collections::HashMap;
24use std::fmt::{Debug, Display, Formatter};
25use std::ops::Range;
26use std::sync::Arc;
27
28use async_trait::async_trait;
29use bytes::Bytes;
30use futures::StreamExt;
31use futures::stream::BoxStream;
32use lance_core::utils::aimd::{AimdConfig, AimdController, RequestOutcome};
33use object_store::path::Path;
34use object_store::{
35 GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta, ObjectStore,
36 PutMultipartOptions, PutOptions, PutPayload, PutResult, Result as OSResult, UploadPart,
37};
38use rand::Rng;
39use tokio::sync::Mutex;
40use tracing::{debug, warn};
41
42pub fn is_throttle_error(err: &object_store::Error) -> bool {
60 if let object_store::Error::Generic { source, .. } = err {
62 source.to_string().contains("retries, max_retries")
63 } else {
64 false
65 }
66}
67
/// Configuration for [`AimdThrottledStore`].
///
/// Each operation category (read, write, delete, list) has its own
/// [`AimdConfig`], so each category's request rate adapts independently.
/// The token-bucket size and the retry/backoff settings are shared by all
/// categories.
#[derive(Debug, Clone)]
pub struct AimdThrottleConfig {
    /// AIMD settings for read operations (`get`, `get_opts`, `get_range`, `get_ranges`, `head`).
    pub read: AimdConfig,
    /// AIMD settings for write operations (`put*`, multipart uploads, `copy*`, `rename*`).
    pub write: AimdConfig,
    /// AIMD settings for delete operations (`delete`, `delete_stream`).
    pub delete: AimdConfig,
    /// AIMD settings for list operations (`list`, `list_with_offset`, `list_with_delimiter`).
    pub list: AimdConfig,
    /// Maximum number of tokens each category's token bucket can hold.
    pub burst_capacity: u32,
    /// Number of retries after a throttle error (0 means a single attempt, no retry).
    pub max_retries: usize,
    /// Lower bound, in milliseconds (inclusive), of the random backoff between retries.
    /// NOTE: must not exceed `max_backoff_ms`; the retry path samples `min..=max`.
    pub min_backoff_ms: u64,
    /// Upper bound, in milliseconds (inclusive), of the random backoff between retries.
    pub max_backoff_ms: u64,
}
93
94impl Default for AimdThrottleConfig {
95 fn default() -> Self {
96 let aimd = AimdConfig::default();
97 Self {
98 read: aimd.clone(),
99 write: aimd.clone(),
100 delete: aimd.clone(),
101 list: aimd,
102 burst_capacity: 100,
103 max_retries: 3,
104 min_backoff_ms: 100,
105 max_backoff_ms: 300,
106 }
107 }
108}
109
110impl AimdThrottleConfig {
111 pub fn with_aimd(self, aimd: AimdConfig) -> Self {
113 Self {
114 read: aimd.clone(),
115 write: aimd.clone(),
116 delete: aimd.clone(),
117 list: aimd,
118 ..self
119 }
120 }
121
122 pub fn with_read_aimd(self, aimd: AimdConfig) -> Self {
124 Self { read: aimd, ..self }
125 }
126
127 pub fn with_write_aimd(self, aimd: AimdConfig) -> Self {
129 Self {
130 write: aimd,
131 ..self
132 }
133 }
134
135 pub fn with_delete_aimd(self, aimd: AimdConfig) -> Self {
137 Self {
138 delete: aimd,
139 ..self
140 }
141 }
142
143 pub fn with_list_aimd(self, aimd: AimdConfig) -> Self {
145 Self { list: aimd, ..self }
146 }
147
148 pub fn is_disabled(&self) -> bool {
150 self.max_retries == 0
151 }
152
153 pub fn with_burst_capacity(self, burst_capacity: u32) -> Self {
154 Self {
155 burst_capacity,
156 ..self
157 }
158 }
159
160 pub fn from_storage_options(
178 storage_options: Option<&HashMap<String, String>>,
179 ) -> lance_core::Result<Self> {
180 fn resolve_f64(
181 key: &str,
182 storage_options: Option<&HashMap<String, String>>,
183 default: f64,
184 ) -> lance_core::Result<f64> {
185 let env_key = key.to_ascii_uppercase();
186 if let Some(val) = storage_options.and_then(|opts| opts.get(key)) {
187 val.parse::<f64>().map_err(|_| {
188 lance_core::Error::invalid_input(format!(
189 "Invalid value for storage option '{key}': '{val}'"
190 ))
191 })
192 } else if let Ok(val) = std::env::var(&env_key) {
193 val.parse::<f64>().map_err(|_| {
194 lance_core::Error::invalid_input(format!(
195 "Invalid value for env var '{env_key}': '{val}'"
196 ))
197 })
198 } else {
199 Ok(default)
200 }
201 }
202
203 fn resolve_u32(
204 key: &str,
205 storage_options: Option<&HashMap<String, String>>,
206 default: u32,
207 ) -> lance_core::Result<u32> {
208 let env_key = key.to_ascii_uppercase();
209 if let Some(val) = storage_options.and_then(|opts| opts.get(key)) {
210 val.parse::<u32>().map_err(|_| {
211 lance_core::Error::invalid_input(format!(
212 "Invalid value for storage option '{key}': '{val}'"
213 ))
214 })
215 } else if let Ok(val) = std::env::var(&env_key) {
216 val.parse::<u32>().map_err(|_| {
217 lance_core::Error::invalid_input(format!(
218 "Invalid value for env var '{env_key}': '{val}'"
219 ))
220 })
221 } else {
222 Ok(default)
223 }
224 }
225
226 fn resolve_usize(
227 key: &str,
228 storage_options: Option<&HashMap<String, String>>,
229 default: usize,
230 ) -> lance_core::Result<usize> {
231 let env_key = key.to_ascii_uppercase();
232 if let Some(val) = storage_options.and_then(|opts| opts.get(key)) {
233 val.parse::<usize>().map_err(|_| {
234 lance_core::Error::invalid_input(format!(
235 "Invalid value for storage option '{key}': '{val}'"
236 ))
237 })
238 } else if let Ok(val) = std::env::var(&env_key) {
239 val.parse::<usize>().map_err(|_| {
240 lance_core::Error::invalid_input(format!(
241 "Invalid value for env var '{env_key}': '{val}'"
242 ))
243 })
244 } else {
245 Ok(default)
246 }
247 }
248
249 fn resolve_u64(
250 key: &str,
251 storage_options: Option<&HashMap<String, String>>,
252 default: u64,
253 ) -> lance_core::Result<u64> {
254 let env_key = key.to_ascii_uppercase();
255 if let Some(val) = storage_options.and_then(|opts| opts.get(key)) {
256 val.parse::<u64>().map_err(|_| {
257 lance_core::Error::invalid_input(format!(
258 "Invalid value for storage option '{key}': '{val}'"
259 ))
260 })
261 } else if let Ok(val) = std::env::var(&env_key) {
262 val.parse::<u64>().map_err(|_| {
263 lance_core::Error::invalid_input(format!(
264 "Invalid value for env var '{env_key}': '{val}'"
265 ))
266 })
267 } else {
268 Ok(default)
269 }
270 }
271
272 let initial_rate = resolve_f64("lance_aimd_initial_rate", storage_options, 2000.0)?;
273 let min_rate = resolve_f64("lance_aimd_min_rate", storage_options, 1.0)?;
274 let max_rate = resolve_f64("lance_aimd_max_rate", storage_options, 5000.0)?;
275 let decrease_factor = resolve_f64("lance_aimd_decrease_factor", storage_options, 0.5)?;
276 let additive_increment =
277 resolve_f64("lance_aimd_additive_increment", storage_options, 300.0)?;
278 let burst_capacity = resolve_u32("lance_aimd_burst_capacity", storage_options, 100)?;
279 let max_retries = resolve_usize("lance_aimd_max_retries", storage_options, 3)?;
280 let min_backoff_ms = resolve_u64("lance_aimd_min_backoff_ms", storage_options, 100)?;
281 let max_backoff_ms = resolve_u64("lance_aimd_max_backoff_ms", storage_options, 300)?;
282
283 let aimd = AimdConfig::default()
284 .with_initial_rate(initial_rate)
285 .with_min_rate(min_rate)
286 .with_max_rate(max_rate)
287 .with_decrease_factor(decrease_factor)
288 .with_additive_increment(additive_increment);
289
290 Ok(Self {
291 max_retries,
292 min_backoff_ms,
293 max_backoff_ms,
294 ..Self::default()
295 .with_aimd(aimd)
296 .with_burst_capacity(burst_capacity)
297 })
298 }
299}
300
/// Mutable token-bucket state for one operation category.
struct TokenBucketState {
    /// Refill rate in tokens per second (kept in sync with the AIMD controller).
    rate: f64,
    /// Tokens currently available. This may be negative when callers have
    /// reserved tokens ahead of the refill; they then sleep off the deficit.
    tokens: f64,
    /// Instant at which `tokens` was last refilled.
    last_refill: std::time::Instant,
}
306
307struct OperationThrottle {
309 controller: AimdController,
310 bucket: Mutex<TokenBucketState>,
311 burst_capacity: f64,
312 max_retries: usize,
313 min_backoff_ms: u64,
314 max_backoff_ms: u64,
315}
316
317impl OperationThrottle {
318 fn new(
319 aimd_config: AimdConfig,
320 burst_capacity: f64,
321 max_retries: usize,
322 min_backoff_ms: u64,
323 max_backoff_ms: u64,
324 ) -> lance_core::Result<Self> {
325 let initial_rate = aimd_config.initial_rate;
326 let controller = AimdController::new(aimd_config)?;
327 Ok(Self {
328 controller,
329 bucket: Mutex::new(TokenBucketState {
330 tokens: burst_capacity,
331 last_refill: std::time::Instant::now(),
332 rate: initial_rate,
333 }),
334 burst_capacity,
335 max_retries,
336 min_backoff_ms,
337 max_backoff_ms,
338 })
339 }
340
341 async fn acquire_token(&self) {
347 let sleep_duration = {
348 let mut bucket = self.bucket.lock().await;
349 let now = std::time::Instant::now();
350 let elapsed = now.duration_since(bucket.last_refill).as_secs_f64();
351 bucket.tokens = (bucket.tokens + elapsed * bucket.rate).min(self.burst_capacity);
352 bucket.last_refill = now;
353
354 bucket.tokens -= 1.0;
356
357 if bucket.tokens >= 0.0 {
358 return;
360 }
361
362 std::time::Duration::from_secs_f64(-bucket.tokens / bucket.rate)
364 };
365
366 tokio::time::sleep(sleep_duration).await;
367 }
368
369 async fn update_bucket_rate(&self, new_rate: f64) {
371 let mut bucket = self.bucket.lock().await;
372 bucket.rate = new_rate;
373 }
374
375 fn observe_outcome<T>(&self, result: &OSResult<T>) {
380 let outcome = match result {
381 Ok(_) => RequestOutcome::Success,
382 Err(err) if is_throttle_error(err) => {
383 debug!("Throttle error detected in stream");
384 RequestOutcome::Throttled
385 }
386 Err(_) => RequestOutcome::Success,
387 };
388 let prev_rate = self.controller.current_rate();
389 let new_rate = self.controller.record_outcome(outcome);
390 if new_rate < prev_rate {
391 warn!(
392 previous_rate = format!("{prev_rate:.1}"),
393 new_rate = format!("{new_rate:.1}"),
394 "AIMD throttle: rate reduced due to throttle errors"
395 );
396 }
397 if let Ok(mut bucket) = self.bucket.try_lock() {
398 bucket.rate = new_rate;
399 }
400 }
401
402 async fn throttled<T, F, Fut>(&self, f: F) -> OSResult<T>
406 where
407 F: Fn() -> Fut,
408 Fut: std::future::Future<Output = OSResult<T>>,
409 {
410 for attempt in 0..=self.max_retries {
411 self.acquire_token().await;
412 let result = f().await;
413 let outcome = match &result {
414 Ok(_) => RequestOutcome::Success,
415 Err(err) if is_throttle_error(err) => {
416 debug!("Throttle error detected");
417 RequestOutcome::Throttled
418 }
419 Err(_) => RequestOutcome::Success, };
421 let prev_rate = self.controller.current_rate();
422 let new_rate = self.controller.record_outcome(outcome);
423 if new_rate < prev_rate {
424 warn!(
425 previous_rate = format!("{prev_rate:.1}"),
426 new_rate = format!("{new_rate:.1}"),
427 "AIMD throttle: rate reduced due to throttle errors"
428 );
429 }
430 self.update_bucket_rate(new_rate).await;
431
432 match &result {
433 Err(err) if is_throttle_error(err) && attempt < self.max_retries => {
434 let backoff_ms =
435 rand::rng().random_range(self.min_backoff_ms..=self.max_backoff_ms);
436 debug!(
437 attempt = attempt + 1,
438 max_retries = self.max_retries,
439 backoff_ms,
440 "Retrying after throttle error"
441 );
442 tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
443 continue;
444 }
445 _ => return result,
446 }
447 }
448 unreachable!()
449 }
450}
451
452impl Debug for OperationThrottle {
453 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
454 f.debug_struct("OperationThrottle")
455 .field("controller", &self.controller)
456 .field("burst_capacity", &self.burst_capacity)
457 .finish()
458 }
459}
460
/// A [`MultipartUpload`] wrapper that routes every part upload, and the final
/// `complete`/`abort` call, through a write throttle.
struct ThrottledMultipartUpload {
    /// The upload returned by the wrapped store.
    target: Box<dyn MultipartUpload>,
    /// Write throttle shared with the [`AimdThrottledStore`] that created this upload.
    write: Arc<OperationThrottle>,
}
468
469impl Debug for ThrottledMultipartUpload {
470 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
471 f.debug_struct("ThrottledMultipartUpload").finish()
472 }
473}
474
475#[async_trait]
476impl MultipartUpload for ThrottledMultipartUpload {
477 fn put_part(&mut self, data: PutPayload) -> UploadPart {
478 let write = Arc::clone(&self.write);
479 let fut = self.target.put_part(data);
482 Box::pin(async move {
483 write.acquire_token().await;
484 let result = fut.await;
485 write.observe_outcome(&result);
486 result
487 })
488 }
489
490 async fn complete(&mut self) -> OSResult<PutResult> {
491 let target = &mut self.target;
492 for attempt in 0..=self.write.max_retries {
493 self.write.acquire_token().await;
494 let result = target.complete().await;
495 self.write.observe_outcome(&result);
496
497 match &result {
498 Err(err) if is_throttle_error(err) && attempt < self.write.max_retries => {
499 let backoff_ms = rand::rng()
500 .random_range(self.write.min_backoff_ms..=self.write.max_backoff_ms);
501 tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
502 continue;
503 }
504 _ => return result,
505 }
506 }
507 unreachable!()
508 }
509
510 async fn abort(&mut self) -> OSResult<()> {
511 let target = &mut self.target;
512 for attempt in 0..=self.write.max_retries {
513 self.write.acquire_token().await;
514 let result = target.abort().await;
515 self.write.observe_outcome(&result);
516
517 match &result {
518 Err(err) if is_throttle_error(err) && attempt < self.write.max_retries => {
519 let backoff_ms = rand::rng()
520 .random_range(self.write.min_backoff_ms..=self.write.max_backoff_ms);
521 tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
522 continue;
523 }
524 _ => return result,
525 }
526 }
527 unreachable!()
528 }
529}
530
/// An [`ObjectStore`] wrapper that adapts request rates per operation
/// category using AIMD (additive increase, multiplicative decrease).
///
/// Point operations acquire a token from their category's bucket before each
/// attempt. They are retried with random backoff when the wrapped store
/// reports a throttle error (see [`is_throttle_error`]). Streaming operations
/// (`list`, `list_with_offset`, `delete_stream`) are neither token-gated nor
/// retried; their items only feed the AIMD controller.
pub struct AimdThrottledStore {
    /// The wrapped store that performs the actual requests.
    target: Arc<dyn ObjectStore>,
    /// Throttle for `get`, `get_opts`, `get_range`, `get_ranges` and `head`.
    read: Arc<OperationThrottle>,
    /// Throttle for `put*`, multipart uploads, `copy*` and `rename*`.
    write: Arc<OperationThrottle>,
    /// Throttle for `delete` and `delete_stream`.
    delete: Arc<OperationThrottle>,
    /// Throttle for `list`, `list_with_offset` and `list_with_delimiter`.
    list: Arc<OperationThrottle>,
}
553
554impl Debug for AimdThrottledStore {
555 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
556 f.debug_struct("AimdThrottledStore")
557 .field("target", &self.target)
558 .field("read", &self.read)
559 .field("write", &self.write)
560 .field("delete", &self.delete)
561 .field("list", &self.list)
562 .finish()
563 }
564}
565
566impl Display for AimdThrottledStore {
567 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
568 write!(f, "AimdThrottledStore({})", self.target)
569 }
570}
571
572impl AimdThrottledStore {
573 pub fn new(
574 target: Arc<dyn ObjectStore>,
575 config: AimdThrottleConfig,
576 ) -> lance_core::Result<Self> {
577 let burst = config.burst_capacity as f64;
578 let max_retries = config.max_retries;
579 let min_backoff_ms = config.min_backoff_ms;
580 let max_backoff_ms = config.max_backoff_ms;
581 Ok(Self {
582 target,
583 read: Arc::new(OperationThrottle::new(
584 config.read,
585 burst,
586 max_retries,
587 min_backoff_ms,
588 max_backoff_ms,
589 )?),
590 write: Arc::new(OperationThrottle::new(
591 config.write,
592 burst,
593 max_retries,
594 min_backoff_ms,
595 max_backoff_ms,
596 )?),
597 delete: Arc::new(OperationThrottle::new(
598 config.delete,
599 burst,
600 max_retries,
601 min_backoff_ms,
602 max_backoff_ms,
603 )?),
604 list: Arc::new(OperationThrottle::new(
605 config.list,
606 burst,
607 max_retries,
608 min_backoff_ms,
609 max_backoff_ms,
610 )?),
611 })
612 }
613}
614
615#[async_trait]
616#[deny(clippy::missing_trait_methods)]
617impl ObjectStore for AimdThrottledStore {
618 async fn put(&self, location: &Path, bytes: PutPayload) -> OSResult<PutResult> {
619 self.write
620 .throttled(|| self.target.put(location, bytes.clone()))
621 .await
622 }
623
624 async fn put_opts(
625 &self,
626 location: &Path,
627 bytes: PutPayload,
628 opts: PutOptions,
629 ) -> OSResult<PutResult> {
630 self.write
631 .throttled(|| self.target.put_opts(location, bytes.clone(), opts.clone()))
632 .await
633 }
634
635 async fn put_multipart(&self, location: &Path) -> OSResult<Box<dyn MultipartUpload>> {
636 let target = self
637 .write
638 .throttled(|| self.target.put_multipart(location))
639 .await?;
640 Ok(Box::new(ThrottledMultipartUpload {
641 target,
642 write: Arc::clone(&self.write),
643 }))
644 }
645
646 async fn put_multipart_opts(
647 &self,
648 location: &Path,
649 opts: PutMultipartOptions,
650 ) -> OSResult<Box<dyn MultipartUpload>> {
651 let target = self
652 .write
653 .throttled(|| self.target.put_multipart_opts(location, opts.clone()))
654 .await?;
655 Ok(Box::new(ThrottledMultipartUpload {
656 target,
657 write: Arc::clone(&self.write),
658 }))
659 }
660
661 async fn get(&self, location: &Path) -> OSResult<GetResult> {
662 self.read.throttled(|| self.target.get(location)).await
663 }
664
665 async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult<GetResult> {
666 self.read
667 .throttled(|| self.target.get_opts(location, options.clone()))
668 .await
669 }
670
671 async fn get_range(&self, location: &Path, range: Range<u64>) -> OSResult<Bytes> {
672 self.read
673 .throttled(|| self.target.get_range(location, range.clone()))
674 .await
675 }
676
677 async fn get_ranges(&self, location: &Path, ranges: &[Range<u64>]) -> OSResult<Vec<Bytes>> {
678 self.read
679 .throttled(|| self.target.get_ranges(location, ranges))
680 .await
681 }
682
683 async fn head(&self, location: &Path) -> OSResult<ObjectMeta> {
684 self.read.throttled(|| self.target.head(location)).await
685 }
686
687 async fn delete(&self, location: &Path) -> OSResult<()> {
688 self.delete.throttled(|| self.target.delete(location)).await
689 }
690
691 fn delete_stream<'a>(
692 &'a self,
693 locations: BoxStream<'a, OSResult<Path>>,
694 ) -> BoxStream<'a, OSResult<Path>> {
695 self.target
696 .delete_stream(locations)
697 .map(|item| {
698 self.delete.observe_outcome(&item);
699 item
700 })
701 .boxed()
702 }
703
704 fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, OSResult<ObjectMeta>> {
705 let throttle = Arc::clone(&self.list);
706 self.target
707 .list(prefix)
708 .map(move |item| {
709 throttle.observe_outcome(&item);
710 item
711 })
712 .boxed()
713 }
714
715 fn list_with_offset(
716 &self,
717 prefix: Option<&Path>,
718 offset: &Path,
719 ) -> BoxStream<'static, OSResult<ObjectMeta>> {
720 let throttle = Arc::clone(&self.list);
721 self.target
722 .list_with_offset(prefix, offset)
723 .map(move |item| {
724 throttle.observe_outcome(&item);
725 item
726 })
727 .boxed()
728 }
729
730 async fn list_with_delimiter(&self, prefix: Option<&Path>) -> OSResult<ListResult> {
731 self.list
732 .throttled(|| self.target.list_with_delimiter(prefix))
733 .await
734 }
735
736 async fn copy(&self, from: &Path, to: &Path) -> OSResult<()> {
737 self.write.throttled(|| self.target.copy(from, to)).await
738 }
739
740 async fn rename(&self, from: &Path, to: &Path) -> OSResult<()> {
741 self.write.throttled(|| self.target.rename(from, to)).await
742 }
743
744 async fn rename_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> {
745 self.write
746 .throttled(|| self.target.rename_if_not_exists(from, to))
747 .await
748 }
749
750 async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> {
751 self.write
752 .throttled(|| self.target.copy_if_not_exists(from, to))
753 .await
754 }
755}
756
757#[cfg(test)]
758mod tests {
759 use super::*;
760 use object_store::memory::InMemory;
761 use rstest::rstest;
762 use std::collections::VecDeque;
763 use std::sync::atomic::{AtomicU64, Ordering};
764
765 fn make_generic_error(msg: &str) -> object_store::Error {
766 object_store::Error::Generic {
767 store: "test",
768 source: msg.into(),
769 }
770 }
771
772 #[rstest]
773 #[case::retry_error("Error after 10 retries, max_retries: 10, retry_timeout: 180s", true)]
774 #[case::retries_in_message(
775 "request failed, after 3 retries, max_retries: 5, retry_timeout: 60s",
776 true
777 )]
778 #[case::not_found("Object not found", false)]
779 #[case::permission_denied("Access denied", false)]
780 #[case::timeout("Connection timed out", false)]
781 #[case::http_429_without_retries("HTTP 429 Too Many Requests", false)]
782 #[case::slowdown_without_retries("SlowDown: Please reduce your request rate", false)]
783 fn test_is_throttle_error(#[case] msg: &str, #[case] expected: bool) {
784 let err = make_generic_error(msg);
785 assert_eq!(
786 is_throttle_error(&err),
787 expected,
788 "is_throttle_error for '{}' should be {}",
789 msg,
790 expected
791 );
792 }
793
794 #[test]
795 fn test_non_generic_errors_are_not_throttle() {
796 let err = object_store::Error::NotFound {
797 path: "test".to_string(),
798 source: "not found".into(),
799 };
800 assert!(!is_throttle_error(&err));
801 }
802
803 #[tokio::test]
804 async fn test_basic_put_get_through_wrapper() {
805 let store = Arc::new(InMemory::new());
806 let config = AimdThrottleConfig::default();
807 let throttled = AimdThrottledStore::new(store, config).unwrap();
808
809 let path = Path::from("test/file.txt");
810 let data = PutPayload::from_static(b"hello world");
811 throttled.put(&path, data).await.unwrap();
812
813 let result = throttled.get(&path).await.unwrap();
814 let bytes = result.bytes().await.unwrap();
815 assert_eq!(bytes.as_ref(), b"hello world");
816 }
817
818 #[tokio::test]
819 async fn test_rate_decreases_on_throttle() {
820 let store = Arc::new(InMemory::new());
821 let config = AimdThrottleConfig::default().with_aimd(
822 AimdConfig::default()
823 .with_initial_rate(100.0)
824 .with_decrease_factor(0.5)
825 .with_window_duration(std::time::Duration::from_millis(10)),
826 );
827 let throttled = AimdThrottledStore::new(store, config).unwrap();
828
829 let initial_rate = throttled.read.controller.current_rate();
830 assert_eq!(initial_rate, 100.0);
831
832 throttled
834 .read
835 .controller
836 .record_outcome(RequestOutcome::Throttled);
837
838 tokio::time::sleep(std::time::Duration::from_millis(20)).await;
840 throttled
841 .read
842 .controller
843 .record_outcome(RequestOutcome::Success);
844
845 let new_rate = throttled.read.controller.current_rate();
846 assert!(
847 new_rate < initial_rate,
848 "Rate should decrease after throttle: {} < {}",
849 new_rate,
850 initial_rate
851 );
852 }
853
854 #[tokio::test]
855 async fn test_rate_recovers_on_success() {
856 let store = Arc::new(InMemory::new());
857 let config = AimdThrottleConfig::default().with_aimd(
858 AimdConfig::default()
859 .with_initial_rate(100.0)
860 .with_decrease_factor(0.5)
861 .with_additive_increment(10.0)
862 .with_window_duration(std::time::Duration::from_millis(10)),
863 );
864 let throttled = AimdThrottledStore::new(store, config).unwrap();
865
866 throttled
868 .read
869 .controller
870 .record_outcome(RequestOutcome::Throttled);
871 tokio::time::sleep(std::time::Duration::from_millis(20)).await;
872 throttled
873 .read
874 .controller
875 .record_outcome(RequestOutcome::Success);
876 let decreased_rate = throttled.read.controller.current_rate();
877 assert_eq!(decreased_rate, 50.0);
878
879 tokio::time::sleep(std::time::Duration::from_millis(20)).await;
881 throttled
882 .read
883 .controller
884 .record_outcome(RequestOutcome::Success);
885 let recovered_rate = throttled.read.controller.current_rate();
886 assert_eq!(recovered_rate, 60.0);
887 }
888
889 #[tokio::test]
890 async fn test_as_dyn_object_store() {
891 let store: Arc<dyn ObjectStore> = Arc::new(InMemory::new());
892 let throttled: Arc<dyn ObjectStore> =
893 Arc::new(AimdThrottledStore::new(store, AimdThrottleConfig::default()).unwrap());
894
895 let path = Path::from("test/data.bin");
896 let data = PutPayload::from_static(b"test data");
897 throttled.put(&path, data).await.unwrap();
898
899 let result = throttled.get(&path).await.unwrap();
900 let bytes = result.bytes().await.unwrap();
901 assert_eq!(bytes.as_ref(), b"test data");
902 }
903
904 #[tokio::test]
905 async fn test_token_bucket_delays_when_exhausted() {
906 let store = Arc::new(InMemory::new());
907 let config = AimdThrottleConfig::default()
909 .with_burst_capacity(1)
910 .with_aimd(AimdConfig::default().with_initial_rate(10.0));
911 let throttled = Arc::new(AimdThrottledStore::new(store, config).unwrap());
912
913 let path = Path::from("test/file.txt");
914 let data = PutPayload::from_static(b"data");
915 throttled.put(&path, data).await.unwrap();
916
917 let start = std::time::Instant::now();
920 let data2 = PutPayload::from_static(b"data2");
921 throttled.put(&path, data2).await.unwrap();
922 let elapsed = start.elapsed();
923
924 assert!(
925 elapsed >= std::time::Duration::from_millis(50),
926 "Expected delay for token refill, but elapsed was {:?}",
927 elapsed
928 );
929 }
930
931 #[tokio::test]
932 async fn test_list_observes_outcomes() {
933 let store = Arc::new(InMemory::new());
934 let config = AimdThrottleConfig::default();
935 let throttled = AimdThrottledStore::new(store.clone(), config).unwrap();
936
937 let path = Path::from("prefix/file.txt");
938 let data = PutPayload::from_static(b"data");
939 store.put(&path, data).await.unwrap();
940
941 let items: Vec<_> = throttled.list(Some(&Path::from("prefix"))).collect().await;
942 assert_eq!(items.len(), 1);
943 assert!(items[0].is_ok());
944 }
945
946 struct ThrottlingListMockStore {
950 inner: InMemory,
951 throttle_count: usize,
953 }
954
955 impl Display for ThrottlingListMockStore {
956 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
957 write!(f, "ThrottlingListMockStore")
958 }
959 }
960
961 impl Debug for ThrottlingListMockStore {
962 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
963 f.debug_struct("ThrottlingListMockStore").finish()
964 }
965 }
966
967 #[async_trait]
968 impl ObjectStore for ThrottlingListMockStore {
969 async fn put(&self, location: &Path, bytes: PutPayload) -> OSResult<PutResult> {
970 self.inner.put(location, bytes).await
971 }
972 async fn put_opts(
973 &self,
974 location: &Path,
975 bytes: PutPayload,
976 opts: PutOptions,
977 ) -> OSResult<PutResult> {
978 self.inner.put_opts(location, bytes, opts).await
979 }
980 async fn put_multipart(&self, location: &Path) -> OSResult<Box<dyn MultipartUpload>> {
981 self.inner.put_multipart(location).await
982 }
983 async fn put_multipart_opts(
984 &self,
985 location: &Path,
986 opts: PutMultipartOptions,
987 ) -> OSResult<Box<dyn MultipartUpload>> {
988 self.inner.put_multipart_opts(location, opts).await
989 }
990 async fn get(&self, location: &Path) -> OSResult<GetResult> {
991 self.inner.get(location).await
992 }
993 async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult<GetResult> {
994 self.inner.get_opts(location, options).await
995 }
996 async fn get_range(&self, location: &Path, range: Range<u64>) -> OSResult<Bytes> {
997 self.inner.get_range(location, range).await
998 }
999 async fn get_ranges(&self, location: &Path, ranges: &[Range<u64>]) -> OSResult<Vec<Bytes>> {
1000 self.inner.get_ranges(location, ranges).await
1001 }
1002 async fn head(&self, location: &Path) -> OSResult<ObjectMeta> {
1003 self.inner.head(location).await
1004 }
1005 async fn delete(&self, location: &Path) -> OSResult<()> {
1006 self.inner.delete(location).await
1007 }
1008 fn delete_stream<'a>(
1009 &'a self,
1010 locations: BoxStream<'a, OSResult<Path>>,
1011 ) -> BoxStream<'a, OSResult<Path>> {
1012 self.inner.delete_stream(locations)
1013 }
1014 fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, OSResult<ObjectMeta>> {
1015 let n = self.throttle_count;
1016 let inner_stream = self.inner.list(prefix);
1017 let errors = futures::stream::iter((0..n).map(|_| {
1018 Err(object_store::Error::Generic {
1019 store: "ThrottlingListMock",
1020 source: "request failed, after 3 retries, max_retries: 5, retry_timeout: 60s"
1021 .into(),
1022 })
1023 }));
1024 errors.chain(inner_stream).boxed()
1025 }
1026 fn list_with_offset(
1027 &self,
1028 prefix: Option<&Path>,
1029 offset: &Path,
1030 ) -> BoxStream<'static, OSResult<ObjectMeta>> {
1031 self.inner.list_with_offset(prefix, offset)
1032 }
1033 async fn list_with_delimiter(&self, prefix: Option<&Path>) -> OSResult<ListResult> {
1034 self.inner.list_with_delimiter(prefix).await
1035 }
1036 async fn copy(&self, from: &Path, to: &Path) -> OSResult<()> {
1037 self.inner.copy(from, to).await
1038 }
1039 async fn rename(&self, from: &Path, to: &Path) -> OSResult<()> {
1040 self.inner.rename(from, to).await
1041 }
1042 async fn rename_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> {
1043 self.inner.rename_if_not_exists(from, to).await
1044 }
1045 async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> {
1046 self.inner.copy_if_not_exists(from, to).await
1047 }
1048 }
1049
1050 #[tokio::test]
1051 async fn test_list_stream_throttle_errors_decrease_rate() {
1052 let mock = Arc::new(ThrottlingListMockStore {
1053 inner: InMemory::new(),
1054 throttle_count: 5,
1055 });
1056
1057 mock.put(
1059 &Path::from("prefix/file.txt"),
1060 PutPayload::from_static(b"data"),
1061 )
1062 .await
1063 .unwrap();
1064
1065 let config = AimdThrottleConfig::default().with_list_aimd(
1066 AimdConfig::default()
1067 .with_initial_rate(100.0)
1068 .with_decrease_factor(0.5)
1069 .with_window_duration(std::time::Duration::from_millis(10)),
1070 );
1071 let throttled = AimdThrottledStore::new(mock as Arc<dyn ObjectStore>, config).unwrap();
1072
1073 let initial_rate = throttled.list.controller.current_rate();
1074 assert_eq!(initial_rate, 100.0);
1075
1076 let items: Vec<_> = throttled.list(Some(&Path::from("prefix"))).collect().await;
1077
1078 assert_eq!(items.len(), 6);
1080 assert!(items[0].is_err());
1081 assert!(items[5].is_ok());
1082
1083 tokio::time::sleep(std::time::Duration::from_millis(20)).await;
1085 throttled
1086 .list
1087 .controller
1088 .record_outcome(RequestOutcome::Success);
1089
1090 let new_rate = throttled.list.controller.current_rate();
1091 assert!(
1092 new_rate < initial_rate,
1093 "List rate should decrease after stream throttle errors: {} < {}",
1094 new_rate,
1095 initial_rate
1096 );
1097 }
1098
1099 #[tokio::test]
1100 async fn test_per_category_independence() {
1101 let store = Arc::new(InMemory::new());
1102 let config = AimdThrottleConfig::default().with_aimd(
1103 AimdConfig::default()
1104 .with_initial_rate(100.0)
1105 .with_decrease_factor(0.5)
1106 .with_window_duration(std::time::Duration::from_millis(10)),
1107 );
1108 let throttled = AimdThrottledStore::new(store, config).unwrap();
1109
1110 throttled
1112 .read
1113 .controller
1114 .record_outcome(RequestOutcome::Throttled);
1115 tokio::time::sleep(std::time::Duration::from_millis(20)).await;
1116 throttled
1117 .read
1118 .controller
1119 .record_outcome(RequestOutcome::Success);
1120
1121 let read_rate = throttled.read.controller.current_rate();
1122 let write_rate = throttled.write.controller.current_rate();
1123 let delete_rate = throttled.delete.controller.current_rate();
1124 let list_rate = throttled.list.controller.current_rate();
1125
1126 assert_eq!(read_rate, 50.0, "Read rate should have decreased");
1127 assert_eq!(write_rate, 100.0, "Write rate should be unaffected");
1128 assert_eq!(delete_rate, 100.0, "Delete rate should be unaffected");
1129 assert_eq!(list_rate, 100.0, "List rate should be unaffected");
1130 }
1131
1132 #[tokio::test]
1133 async fn test_per_category_config() {
1134 let store = Arc::new(InMemory::new());
1135 let config = AimdThrottleConfig::default()
1136 .with_read_aimd(AimdConfig::default().with_initial_rate(200.0))
1137 .with_write_aimd(AimdConfig::default().with_initial_rate(100.0))
1138 .with_delete_aimd(AimdConfig::default().with_initial_rate(50.0))
1139 .with_list_aimd(AimdConfig::default().with_initial_rate(25.0));
1140 let throttled = AimdThrottledStore::new(store, config).unwrap();
1141
1142 assert_eq!(throttled.read.controller.current_rate(), 200.0);
1143 assert_eq!(throttled.write.controller.current_rate(), 100.0);
1144 assert_eq!(throttled.delete.controller.current_rate(), 50.0);
1145 assert_eq!(throttled.list.controller.current_rate(), 25.0);
1146 }
1147
1148 struct RateLimitingMockStore {
1152 inner: InMemory,
1153 timestamps: std::sync::Mutex<VecDeque<std::time::Instant>>,
1155 max_per_window: usize,
1157 window: std::time::Duration,
1159 success_count: AtomicU64,
1160 throttle_count: AtomicU64,
1161 }
1162
1163 impl RateLimitingMockStore {
1164 fn new(max_per_window: usize, window: std::time::Duration) -> Self {
1165 Self {
1166 inner: InMemory::new(),
1167 timestamps: std::sync::Mutex::new(VecDeque::new()),
1168 max_per_window,
1169 window,
1170 success_count: AtomicU64::new(0),
1171 throttle_count: AtomicU64::new(0),
1172 }
1173 }
1174
1175 fn check_rate(&self) -> bool {
1177 let mut ts = self.timestamps.lock().unwrap();
1178 let now = std::time::Instant::now();
1179 while let Some(&front) = ts.front() {
1180 if now.duration_since(front) > self.window {
1181 ts.pop_front();
1182 } else {
1183 break;
1184 }
1185 }
1186 if ts.len() >= self.max_per_window {
1187 self.throttle_count.fetch_add(1, Ordering::Relaxed);
1188 false
1189 } else {
1190 ts.push_back(now);
1191 self.success_count.fetch_add(1, Ordering::Relaxed);
1192 true
1193 }
1194 }
1195
1196 fn throttle_error() -> object_store::Error {
1197 object_store::Error::Generic {
1198 store: "RateLimitingMock",
1199 source: "request failed, after 10 retries, max_retries: 10, retry_timeout: 180s"
1200 .into(),
1201 }
1202 }
1203 }
1204
1205 impl Display for RateLimitingMockStore {
1206 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1207 write!(f, "RateLimitingMockStore")
1208 }
1209 }
1210
1211 impl Debug for RateLimitingMockStore {
1212 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1213 f.debug_struct("RateLimitingMockStore").finish()
1214 }
1215 }
1216
1217 #[async_trait]
1218 impl ObjectStore for RateLimitingMockStore {
1219 async fn put(&self, location: &Path, bytes: PutPayload) -> OSResult<PutResult> {
1220 self.inner.put(location, bytes).await
1221 }
1222
1223 async fn put_opts(
1224 &self,
1225 location: &Path,
1226 bytes: PutPayload,
1227 opts: PutOptions,
1228 ) -> OSResult<PutResult> {
1229 self.inner.put_opts(location, bytes, opts).await
1230 }
1231
1232 async fn put_multipart(&self, location: &Path) -> OSResult<Box<dyn MultipartUpload>> {
1233 self.inner.put_multipart(location).await
1234 }
1235
1236 async fn put_multipart_opts(
1237 &self,
1238 location: &Path,
1239 opts: PutMultipartOptions,
1240 ) -> OSResult<Box<dyn MultipartUpload>> {
1241 self.inner.put_multipart_opts(location, opts).await
1242 }
1243
1244 async fn get(&self, location: &Path) -> OSResult<GetResult> {
1245 if self.check_rate() {
1246 self.inner.get(location).await
1247 } else {
1248 Err(Self::throttle_error())
1249 }
1250 }
1251
1252 async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult<GetResult> {
1253 if self.check_rate() {
1254 self.inner.get_opts(location, options).await
1255 } else {
1256 Err(Self::throttle_error())
1257 }
1258 }
1259
1260 async fn get_range(&self, location: &Path, range: Range<u64>) -> OSResult<Bytes> {
1261 if self.check_rate() {
1262 self.inner.get_range(location, range).await
1263 } else {
1264 Err(Self::throttle_error())
1265 }
1266 }
1267
1268 async fn get_ranges(&self, location: &Path, ranges: &[Range<u64>]) -> OSResult<Vec<Bytes>> {
1269 if self.check_rate() {
1270 self.inner.get_ranges(location, ranges).await
1271 } else {
1272 Err(Self::throttle_error())
1273 }
1274 }
1275
1276 async fn head(&self, location: &Path) -> OSResult<ObjectMeta> {
1277 if self.check_rate() {
1278 self.inner.head(location).await
1279 } else {
1280 Err(Self::throttle_error())
1281 }
1282 }
1283
1284 async fn delete(&self, location: &Path) -> OSResult<()> {
1285 self.inner.delete(location).await
1286 }
1287
1288 fn delete_stream<'a>(
1289 &'a self,
1290 locations: BoxStream<'a, OSResult<Path>>,
1291 ) -> BoxStream<'a, OSResult<Path>> {
1292 self.inner.delete_stream(locations)
1293 }
1294
1295 fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, OSResult<ObjectMeta>> {
1296 self.inner.list(prefix)
1297 }
1298
1299 fn list_with_offset(
1300 &self,
1301 prefix: Option<&Path>,
1302 offset: &Path,
1303 ) -> BoxStream<'static, OSResult<ObjectMeta>> {
1304 self.inner.list_with_offset(prefix, offset)
1305 }
1306
1307 async fn list_with_delimiter(&self, prefix: Option<&Path>) -> OSResult<ListResult> {
1308 self.inner.list_with_delimiter(prefix).await
1309 }
1310
1311 async fn copy(&self, from: &Path, to: &Path) -> OSResult<()> {
1312 self.inner.copy(from, to).await
1313 }
1314
1315 async fn rename(&self, from: &Path, to: &Path) -> OSResult<()> {
1316 self.inner.rename(from, to).await
1317 }
1318
1319 async fn rename_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> {
1320 self.inner.rename_if_not_exists(from, to).await
1321 }
1322
1323 async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> {
1324 self.inner.copy_if_not_exists(from, to).await
1325 }
1326 }
1327
1328 #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
1345 async fn test_aimd_throttle_under_concurrent_load() {
1346 let mock = Arc::new(RateLimitingMockStore::new(
1347 30,
1348 std::time::Duration::from_millis(100),
1349 ));
1350
1351 let path = Path::from("test/data.bin");
1353 mock.put(&path, PutPayload::from_static(b"test data"))
1354 .await
1355 .unwrap();
1356
1357 let aimd = AimdConfig::default()
1358 .with_initial_rate(100.0)
1359 .with_decrease_factor(0.5)
1360 .with_additive_increment(2.0)
1361 .with_window_duration(std::time::Duration::from_millis(100));
1362 let throttle_config = AimdThrottleConfig::default()
1363 .with_aimd(aimd)
1364 .with_burst_capacity(100);
1365
1366 let num_readers = 5;
1367 let test_duration = std::time::Duration::from_secs(2);
1368 let mut handles = Vec::new();
1369
1370 for _ in 0..num_readers {
1371 let store = Arc::new(
1372 AimdThrottledStore::new(
1373 mock.clone() as Arc<dyn ObjectStore>,
1374 throttle_config.clone(),
1375 )
1376 .unwrap(),
1377 );
1378 let p = path.clone();
1379 handles.push(tokio::spawn(async move {
1380 let deadline = std::time::Instant::now() + test_duration;
1381 let mut count = 0u64;
1382 while std::time::Instant::now() < deadline {
1383 let _ = store.head(&p).await;
1384 count += 1;
1385 }
1386 count
1387 }));
1388 }
1389
1390 let mut total_reader_requests = 0u64;
1391 for handle in handles {
1392 total_reader_requests += handle.await.unwrap();
1393 }
1394
1395 let successes = mock.success_count.load(Ordering::Relaxed);
1396 let throttled = mock.throttle_count.load(Ordering::Relaxed);
1397 let total_mock = successes + throttled;
1398
1399 assert!(
1402 total_mock >= total_reader_requests,
1403 "Mock-side count ({total_mock}) should be >= reader-side count ({total_reader_requests})"
1404 );
1405
1406 assert!(
1409 successes >= 300,
1410 "Expected >= 300 successful requests over 2s, got {successes}"
1411 );
1412 assert!(
1413 successes <= 900,
1414 "Expected <= 900 successful requests, got {successes}"
1415 );
1416
1417 assert!(throttled > 0, "Expected some throttled requests but got 0");
1419
1420 assert!(
1423 total_mock <= 5000,
1424 "AIMD should limit total requests, got {total_mock}"
1425 );
1426 }
1427
1428 struct RetryTestMockStore {
1432 inner: InMemory,
1433 errors_remaining: std::sync::Mutex<usize>,
1435 get_call_count: AtomicU64,
1437 }
1438
1439 impl RetryTestMockStore {
1440 fn new(errors_before_success: usize) -> Self {
1441 Self {
1442 inner: InMemory::new(),
1443 errors_remaining: std::sync::Mutex::new(errors_before_success),
1444 get_call_count: AtomicU64::new(0),
1445 }
1446 }
1447 }
1448
1449 impl Display for RetryTestMockStore {
1450 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1451 write!(f, "RetryTestMockStore")
1452 }
1453 }
1454
1455 impl Debug for RetryTestMockStore {
1456 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1457 f.debug_struct("RetryTestMockStore").finish()
1458 }
1459 }
1460
1461 #[async_trait]
1462 impl ObjectStore for RetryTestMockStore {
1463 async fn put(&self, location: &Path, bytes: PutPayload) -> OSResult<PutResult> {
1464 self.inner.put(location, bytes).await
1465 }
1466 async fn put_opts(
1467 &self,
1468 location: &Path,
1469 bytes: PutPayload,
1470 opts: PutOptions,
1471 ) -> OSResult<PutResult> {
1472 self.inner.put_opts(location, bytes, opts).await
1473 }
1474 async fn put_multipart(&self, location: &Path) -> OSResult<Box<dyn MultipartUpload>> {
1475 self.inner.put_multipart(location).await
1476 }
1477 async fn put_multipart_opts(
1478 &self,
1479 location: &Path,
1480 opts: PutMultipartOptions,
1481 ) -> OSResult<Box<dyn MultipartUpload>> {
1482 self.inner.put_multipart_opts(location, opts).await
1483 }
1484 async fn get(&self, location: &Path) -> OSResult<GetResult> {
1485 self.get_call_count.fetch_add(1, Ordering::Relaxed);
1486 let should_error = {
1487 let mut remaining = self.errors_remaining.lock().unwrap();
1488 if *remaining > 0 {
1489 *remaining -= 1;
1490 true
1491 } else {
1492 false
1493 }
1494 };
1495 if should_error {
1496 Err(object_store::Error::Generic {
1497 store: "RetryTestMock",
1498 source: "request failed, after 3 retries, max_retries: 3, retry_timeout: 30s"
1499 .into(),
1500 })
1501 } else {
1502 self.inner.get(location).await
1503 }
1504 }
1505 async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult<GetResult> {
1506 self.inner.get_opts(location, options).await
1507 }
1508 async fn get_range(&self, location: &Path, range: Range<u64>) -> OSResult<Bytes> {
1509 self.inner.get_range(location, range).await
1510 }
1511 async fn get_ranges(&self, location: &Path, ranges: &[Range<u64>]) -> OSResult<Vec<Bytes>> {
1512 self.inner.get_ranges(location, ranges).await
1513 }
1514 async fn head(&self, location: &Path) -> OSResult<ObjectMeta> {
1515 self.inner.head(location).await
1516 }
1517 async fn delete(&self, location: &Path) -> OSResult<()> {
1518 self.inner.delete(location).await
1519 }
1520 fn delete_stream<'a>(
1521 &'a self,
1522 locations: BoxStream<'a, OSResult<Path>>,
1523 ) -> BoxStream<'a, OSResult<Path>> {
1524 self.inner.delete_stream(locations)
1525 }
1526 fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, OSResult<ObjectMeta>> {
1527 self.inner.list(prefix)
1528 }
1529 fn list_with_offset(
1530 &self,
1531 prefix: Option<&Path>,
1532 offset: &Path,
1533 ) -> BoxStream<'static, OSResult<ObjectMeta>> {
1534 self.inner.list_with_offset(prefix, offset)
1535 }
1536 async fn list_with_delimiter(&self, prefix: Option<&Path>) -> OSResult<ListResult> {
1537 self.inner.list_with_delimiter(prefix).await
1538 }
1539 async fn copy(&self, from: &Path, to: &Path) -> OSResult<()> {
1540 self.inner.copy(from, to).await
1541 }
1542 async fn rename(&self, from: &Path, to: &Path) -> OSResult<()> {
1543 self.inner.rename(from, to).await
1544 }
1545 async fn rename_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> {
1546 self.inner.rename_if_not_exists(from, to).await
1547 }
1548 async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> {
1549 self.inner.copy_if_not_exists(from, to).await
1550 }
1551 }
1552
1553 #[tokio::test]
1554 async fn test_throttled_retries_on_throttle_error_then_succeeds() {
1555 let mock = Arc::new(RetryTestMockStore::new(2));
1557 let path = Path::from("test/retry.txt");
1558 mock.put(&path, PutPayload::from_static(b"retry data"))
1559 .await
1560 .unwrap();
1561
1562 let config = AimdThrottleConfig::default();
1563 let throttled =
1564 AimdThrottledStore::new(mock.clone() as Arc<dyn ObjectStore>, config).unwrap();
1565
1566 let result = throttled.get(&path).await;
1567 assert!(result.is_ok(), "Expected success after retries");
1568
1569 let bytes = result.unwrap().bytes().await.unwrap();
1570 assert_eq!(bytes.as_ref(), b"retry data");
1571
1572 assert_eq!(mock.get_call_count.load(Ordering::Relaxed), 3);
1574 }
1575
1576 #[tokio::test]
1577 async fn test_throttled_fails_after_max_retries_exceeded() {
1578 let mock = Arc::new(RetryTestMockStore::new(10));
1581 let path = Path::from("test/fail.txt");
1582 mock.put(&path, PutPayload::from_static(b"fail data"))
1583 .await
1584 .unwrap();
1585
1586 let config = AimdThrottleConfig::default();
1587 let throttled =
1588 AimdThrottledStore::new(mock.clone() as Arc<dyn ObjectStore>, config).unwrap();
1589
1590 let result = throttled.get(&path).await;
1591 assert!(result.is_err(), "Expected error after max retries");
1592 assert!(is_throttle_error(&result.unwrap_err()));
1593
1594 assert_eq!(mock.get_call_count.load(Ordering::Relaxed), 4);
1596 }
1597
1598 #[tokio::test]
1599 async fn test_throttled_multipart_reorders_parts() {
1600 let store = Arc::new(InMemory::new()) as Arc<dyn ObjectStore>;
1601 let config = AimdThrottleConfig::default();
1602 let throttled = AimdThrottledStore::new(store.clone(), config).unwrap();
1603
1604 let path = Path::from("test/multipart_ordering.bin");
1605 let mut upload = throttled.put_multipart(&path).await.unwrap();
1606
1607 let fut_a = upload.put_part(PutPayload::from_static(b"AAAA"));
1609 let fut_b = upload.put_part(PutPayload::from_static(b"BBBB"));
1610
1611 fut_b.await.unwrap();
1614 fut_a.await.unwrap();
1615
1616 upload.complete().await.unwrap();
1617
1618 let result = store.get(&path).await.unwrap();
1619 let bytes = result.bytes().await.unwrap();
1620
1621 assert_eq!(
1622 bytes.as_ref(),
1623 b"AAAABBBB",
1624 "Parts were reordered! Got {:?} instead of AAAABBBB.",
1625 std::str::from_utf8(&bytes).unwrap_or("<non-utf8>"),
1626 );
1627 }
1628}