1use std::collections::HashMap;
24use std::fmt::{Debug, Display, Formatter};
25use std::ops::Range;
26use std::sync::Arc;
27
28use async_trait::async_trait;
29use bytes::Bytes;
30use futures::StreamExt;
31use futures::stream::BoxStream;
32use lance_core::utils::aimd::{AimdConfig, AimdController, RequestOutcome};
33use lance_core::utils::tracing::TRACE_OBJECT_STORE_THROTTLE;
34#[cfg(test)]
35use object_store::ObjectStoreExt;
36use object_store::path::Path;
37use object_store::{
38 CopyOptions, GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta, ObjectStore,
39 PutMultipartOptions, PutOptions, PutPayload, PutResult, RenameOptions, Result as OSResult,
40 UploadPart,
41};
42use rand::Rng;
43use tokio::sync::Mutex;
44use tracing::{debug, warn};
45
46pub fn is_throttle_error(err: &object_store::Error) -> bool {
64 if let object_store::Error::Generic { source, .. } = err {
66 source.to_string().contains("retries, max_retries")
67 } else {
68 false
69 }
70}
71
/// Configuration for [`AimdThrottledStore`].
///
/// Holds one AIMD configuration per operation category plus the token-bucket
/// and retry settings that are shared by all categories.
#[derive(Debug, Clone)]
pub struct AimdThrottleConfig {
    /// AIMD settings for the read throttle.
    pub read: AimdConfig,
    /// AIMD settings for the write throttle.
    pub write: AimdConfig,
    /// AIMD settings for the delete throttle.
    pub delete: AimdConfig,
    /// AIMD settings for the list throttle.
    pub list: AimdConfig,
    /// Upper bound on the number of tokens each token bucket may hold.
    pub burst_capacity: u32,
    /// Number of retries performed after a throttle error; `0` means a
    /// throttled request is never retried.
    pub max_retries: usize,
    /// Lower bound, in milliseconds, of the random backoff slept before a retry.
    pub min_backoff_ms: u64,
    /// Upper bound (inclusive), in milliseconds, of the random backoff slept
    /// before a retry.
    pub max_backoff_ms: u64,
}
97
98impl Default for AimdThrottleConfig {
99 fn default() -> Self {
100 let aimd = AimdConfig::default();
101 Self {
102 read: aimd.clone(),
103 write: aimd.clone(),
104 delete: aimd.clone(),
105 list: aimd,
106 burst_capacity: 100,
107 max_retries: 3,
108 min_backoff_ms: 100,
109 max_backoff_ms: 300,
110 }
111 }
112}
113
114impl AimdThrottleConfig {
115 pub fn with_aimd(self, aimd: AimdConfig) -> Self {
117 Self {
118 read: aimd.clone(),
119 write: aimd.clone(),
120 delete: aimd.clone(),
121 list: aimd,
122 ..self
123 }
124 }
125
126 pub fn with_read_aimd(self, aimd: AimdConfig) -> Self {
128 Self { read: aimd, ..self }
129 }
130
131 pub fn with_write_aimd(self, aimd: AimdConfig) -> Self {
133 Self {
134 write: aimd,
135 ..self
136 }
137 }
138
139 pub fn with_delete_aimd(self, aimd: AimdConfig) -> Self {
141 Self {
142 delete: aimd,
143 ..self
144 }
145 }
146
147 pub fn with_list_aimd(self, aimd: AimdConfig) -> Self {
149 Self { list: aimd, ..self }
150 }
151
152 pub fn is_disabled(&self) -> bool {
154 self.max_retries == 0
155 }
156
157 pub fn with_burst_capacity(self, burst_capacity: u32) -> Self {
158 Self {
159 burst_capacity,
160 ..self
161 }
162 }
163
164 pub fn from_storage_options(
182 storage_options: Option<&HashMap<String, String>>,
183 ) -> lance_core::Result<Self> {
184 fn resolve_f64(
185 key: &str,
186 storage_options: Option<&HashMap<String, String>>,
187 default: f64,
188 ) -> lance_core::Result<f64> {
189 let env_key = key.to_ascii_uppercase();
190 if let Some(val) = storage_options.and_then(|opts| opts.get(key)) {
191 val.parse::<f64>().map_err(|_| {
192 lance_core::Error::invalid_input(format!(
193 "Invalid value for storage option '{key}': '{val}'"
194 ))
195 })
196 } else if let Ok(val) = std::env::var(&env_key) {
197 val.parse::<f64>().map_err(|_| {
198 lance_core::Error::invalid_input(format!(
199 "Invalid value for env var '{env_key}': '{val}'"
200 ))
201 })
202 } else {
203 Ok(default)
204 }
205 }
206
207 fn resolve_u32(
208 key: &str,
209 storage_options: Option<&HashMap<String, String>>,
210 default: u32,
211 ) -> lance_core::Result<u32> {
212 let env_key = key.to_ascii_uppercase();
213 if let Some(val) = storage_options.and_then(|opts| opts.get(key)) {
214 val.parse::<u32>().map_err(|_| {
215 lance_core::Error::invalid_input(format!(
216 "Invalid value for storage option '{key}': '{val}'"
217 ))
218 })
219 } else if let Ok(val) = std::env::var(&env_key) {
220 val.parse::<u32>().map_err(|_| {
221 lance_core::Error::invalid_input(format!(
222 "Invalid value for env var '{env_key}': '{val}'"
223 ))
224 })
225 } else {
226 Ok(default)
227 }
228 }
229
230 fn resolve_usize(
231 key: &str,
232 storage_options: Option<&HashMap<String, String>>,
233 default: usize,
234 ) -> lance_core::Result<usize> {
235 let env_key = key.to_ascii_uppercase();
236 if let Some(val) = storage_options.and_then(|opts| opts.get(key)) {
237 val.parse::<usize>().map_err(|_| {
238 lance_core::Error::invalid_input(format!(
239 "Invalid value for storage option '{key}': '{val}'"
240 ))
241 })
242 } else if let Ok(val) = std::env::var(&env_key) {
243 val.parse::<usize>().map_err(|_| {
244 lance_core::Error::invalid_input(format!(
245 "Invalid value for env var '{env_key}': '{val}'"
246 ))
247 })
248 } else {
249 Ok(default)
250 }
251 }
252
253 fn resolve_u64(
254 key: &str,
255 storage_options: Option<&HashMap<String, String>>,
256 default: u64,
257 ) -> lance_core::Result<u64> {
258 let env_key = key.to_ascii_uppercase();
259 if let Some(val) = storage_options.and_then(|opts| opts.get(key)) {
260 val.parse::<u64>().map_err(|_| {
261 lance_core::Error::invalid_input(format!(
262 "Invalid value for storage option '{key}': '{val}'"
263 ))
264 })
265 } else if let Ok(val) = std::env::var(&env_key) {
266 val.parse::<u64>().map_err(|_| {
267 lance_core::Error::invalid_input(format!(
268 "Invalid value for env var '{env_key}': '{val}'"
269 ))
270 })
271 } else {
272 Ok(default)
273 }
274 }
275
276 let initial_rate = resolve_f64("lance_aimd_initial_rate", storage_options, 2000.0)?;
277 let min_rate = resolve_f64("lance_aimd_min_rate", storage_options, 1.0)?;
278 let max_rate = resolve_f64("lance_aimd_max_rate", storage_options, 5000.0)?;
279 let decrease_factor = resolve_f64("lance_aimd_decrease_factor", storage_options, 0.5)?;
280 let additive_increment =
281 resolve_f64("lance_aimd_additive_increment", storage_options, 300.0)?;
282 let burst_capacity = resolve_u32("lance_aimd_burst_capacity", storage_options, 100)?;
283 let max_retries = resolve_usize("lance_aimd_max_retries", storage_options, 3)?;
284 let min_backoff_ms = resolve_u64("lance_aimd_min_backoff_ms", storage_options, 100)?;
285 let max_backoff_ms = resolve_u64("lance_aimd_max_backoff_ms", storage_options, 300)?;
286
287 let aimd = AimdConfig::default()
288 .with_initial_rate(initial_rate)
289 .with_min_rate(min_rate)
290 .with_max_rate(max_rate)
291 .with_decrease_factor(decrease_factor)
292 .with_additive_increment(additive_increment);
293
294 Ok(Self {
295 max_retries,
296 min_backoff_ms,
297 max_backoff_ms,
298 ..Self::default()
299 .with_aimd(aimd)
300 .with_burst_capacity(burst_capacity)
301 })
302 }
303}
304
/// Mutable state of a token bucket; guarded by the mutex in
/// `OperationThrottle`.
struct TokenBucketState {
    /// Tokens currently available. Goes negative when callers take a token
    /// that has not been refilled yet.
    tokens: f64,
    /// Instant at which `tokens` was last refilled.
    last_refill: std::time::Instant,
    /// Refill rate in tokens per second.
    rate: f64,
}
310
/// Throttle for one category of operations: an AIMD controller that decides
/// the rate, a token bucket that enforces it, and the retry settings.
struct OperationThrottle {
    /// Rate controller fed with the outcome of every request.
    controller: AimdController,
    /// Token bucket whose refill rate follows the controller's rate.
    bucket: Mutex<TokenBucketState>,
    /// Maximum number of tokens the bucket may hold.
    burst_capacity: f64,
    /// Number of retries performed after a throttle error.
    max_retries: usize,
    /// Lower bound of the random retry backoff, in milliseconds.
    min_backoff_ms: u64,
    /// Inclusive upper bound of the random retry backoff, in milliseconds.
    max_backoff_ms: u64,
}
320
321impl OperationThrottle {
322 fn new(
323 aimd_config: AimdConfig,
324 burst_capacity: f64,
325 max_retries: usize,
326 min_backoff_ms: u64,
327 max_backoff_ms: u64,
328 ) -> lance_core::Result<Self> {
329 let initial_rate = aimd_config.initial_rate;
330 let controller = AimdController::new(aimd_config)?;
331 Ok(Self {
332 controller,
333 bucket: Mutex::new(TokenBucketState {
334 tokens: burst_capacity,
335 last_refill: std::time::Instant::now(),
336 rate: initial_rate,
337 }),
338 burst_capacity,
339 max_retries,
340 min_backoff_ms,
341 max_backoff_ms,
342 })
343 }
344
345 async fn acquire_token(&self) {
351 let sleep_duration = {
352 let mut bucket = self.bucket.lock().await;
353 let now = std::time::Instant::now();
354 let elapsed = now.duration_since(bucket.last_refill).as_secs_f64();
355 bucket.tokens = (bucket.tokens + elapsed * bucket.rate).min(self.burst_capacity);
356 bucket.last_refill = now;
357
358 bucket.tokens -= 1.0;
360
361 if bucket.tokens >= 0.0 {
362 return;
364 }
365
366 std::time::Duration::from_secs_f64(-bucket.tokens / bucket.rate)
368 };
369
370 tokio::time::sleep(sleep_duration).await;
371 }
372
373 async fn update_bucket_rate(&self, new_rate: f64) {
375 let mut bucket = self.bucket.lock().await;
376 bucket.rate = new_rate;
377 }
378
379 fn observe_outcome<T>(&self, result: &OSResult<T>) {
384 let outcome = match result {
385 Ok(_) => RequestOutcome::Success,
386 Err(err) if is_throttle_error(err) => {
387 debug!(
388 target: TRACE_OBJECT_STORE_THROTTLE,
389 error = %err,
390 "Throttle error detected in stream"
391 );
392 RequestOutcome::Throttled
393 }
394 Err(_) => RequestOutcome::Success,
395 };
396 let prev_rate = self.controller.current_rate();
397 let new_rate = self.controller.record_outcome(outcome);
398 if new_rate < prev_rate
399 && let Err(err) = result.as_ref()
400 {
401 warn!(
402 target: TRACE_OBJECT_STORE_THROTTLE,
403 previous_rate = format!("{prev_rate:.1}"),
404 new_rate = format!("{new_rate:.1}"),
405 error = %err,
406 "AIMD throttle: rate reduced due to throttle errors"
407 );
408 }
409 if let Ok(mut bucket) = self.bucket.try_lock() {
410 bucket.rate = new_rate;
411 }
412 }
413
414 async fn throttled<T, F, Fut>(&self, f: F) -> OSResult<T>
418 where
419 F: Fn() -> Fut,
420 Fut: std::future::Future<Output = OSResult<T>>,
421 {
422 for attempt in 0..=self.max_retries {
423 self.acquire_token().await;
424 let result = f().await;
425 let outcome = match &result {
426 Ok(_) => RequestOutcome::Success,
427 Err(err) if is_throttle_error(err) => {
428 debug!(
429 target: TRACE_OBJECT_STORE_THROTTLE,
430 error = %err,
431 "Throttle error detected"
432 );
433 RequestOutcome::Throttled
434 }
435 Err(_) => RequestOutcome::Success, };
437 let prev_rate = self.controller.current_rate();
438 let new_rate = self.controller.record_outcome(outcome);
439 if new_rate < prev_rate
440 && let Err(err) = result.as_ref()
441 {
442 warn!(
443 target: TRACE_OBJECT_STORE_THROTTLE,
444 previous_rate = format!("{prev_rate:.1}"),
445 new_rate = format!("{new_rate:.1}"),
446 error = %err,
447 "AIMD throttle: rate reduced due to throttle errors"
448 );
449 }
450 self.update_bucket_rate(new_rate).await;
451
452 match &result {
453 Err(err) if is_throttle_error(err) && attempt < self.max_retries => {
454 let backoff_ms =
455 rand::rng().random_range(self.min_backoff_ms..=self.max_backoff_ms);
456 debug!(
457 target: TRACE_OBJECT_STORE_THROTTLE,
458 attempt = attempt + 1,
459 max_retries = self.max_retries,
460 backoff_ms,
461 error = %err,
462 "Retrying after throttle error"
463 );
464 tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
465 continue;
466 }
467 _ => return result,
468 }
469 }
470 unreachable!()
471 }
472}
473
474impl Debug for OperationThrottle {
475 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
476 f.debug_struct("OperationThrottle")
477 .field("controller", &self.controller)
478 .field("burst_capacity", &self.burst_capacity)
479 .finish()
480 }
481}
482
/// Multipart upload wrapper that routes every part and the final
/// complete/abort call through the write throttle.
struct ThrottledMultipartUpload {
    /// The wrapped upload that performs the actual work.
    target: Box<dyn MultipartUpload>,
    /// Write throttle shared with the store that created this upload.
    write: Arc<OperationThrottle>,
}
490
491impl Debug for ThrottledMultipartUpload {
492 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
493 f.debug_struct("ThrottledMultipartUpload").finish()
494 }
495}
496
497#[async_trait]
498impl MultipartUpload for ThrottledMultipartUpload {
499 fn put_part(&mut self, data: PutPayload) -> UploadPart {
500 let write = Arc::clone(&self.write);
501 let fut = self.target.put_part(data);
504 Box::pin(async move {
505 write.acquire_token().await;
506 let result = fut.await;
507 write.observe_outcome(&result);
508 result
509 })
510 }
511
512 async fn complete(&mut self) -> OSResult<PutResult> {
513 let target = &mut self.target;
514 for attempt in 0..=self.write.max_retries {
515 self.write.acquire_token().await;
516 let result = target.complete().await;
517 self.write.observe_outcome(&result);
518
519 match &result {
520 Err(err) if is_throttle_error(err) && attempt < self.write.max_retries => {
521 let backoff_ms = rand::rng()
522 .random_range(self.write.min_backoff_ms..=self.write.max_backoff_ms);
523 tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
524 continue;
525 }
526 _ => return result,
527 }
528 }
529 unreachable!()
530 }
531
532 async fn abort(&mut self) -> OSResult<()> {
533 let target = &mut self.target;
534 for attempt in 0..=self.write.max_retries {
535 self.write.acquire_token().await;
536 let result = target.abort().await;
537 self.write.observe_outcome(&result);
538
539 match &result {
540 Err(err) if is_throttle_error(err) && attempt < self.write.max_retries => {
541 let backoff_ms = rand::rng()
542 .random_range(self.write.min_backoff_ms..=self.write.max_backoff_ms);
543 tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
544 continue;
545 }
546 _ => return result,
547 }
548 }
549 unreachable!()
550 }
551}
552
/// An [`ObjectStore`] wrapper that rate-limits requests to an inner store
/// with one AIMD-controlled throttle per operation category.
pub struct AimdThrottledStore {
    /// The wrapped store that performs the actual requests.
    target: Arc<dyn ObjectStore>,
    /// Throttle used by `get_opts` and `get_ranges`.
    read: Arc<OperationThrottle>,
    /// Throttle used by puts, multipart uploads, copies and renames.
    write: Arc<OperationThrottle>,
    /// Throttle that observes the results of `delete_stream`.
    delete: Arc<OperationThrottle>,
    /// Throttle used by the list operations.
    list: Arc<OperationThrottle>,
}
575
576impl Debug for AimdThrottledStore {
577 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
578 f.debug_struct("AimdThrottledStore")
579 .field("target", &self.target)
580 .field("read", &self.read)
581 .field("write", &self.write)
582 .field("delete", &self.delete)
583 .field("list", &self.list)
584 .finish()
585 }
586}
587
588impl Display for AimdThrottledStore {
589 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
590 write!(f, "AimdThrottledStore({})", self.target)
591 }
592}
593
594impl AimdThrottledStore {
595 pub fn new(
596 target: Arc<dyn ObjectStore>,
597 config: AimdThrottleConfig,
598 ) -> lance_core::Result<Self> {
599 let burst = config.burst_capacity as f64;
600 let max_retries = config.max_retries;
601 let min_backoff_ms = config.min_backoff_ms;
602 let max_backoff_ms = config.max_backoff_ms;
603 Ok(Self {
604 target,
605 read: Arc::new(OperationThrottle::new(
606 config.read,
607 burst,
608 max_retries,
609 min_backoff_ms,
610 max_backoff_ms,
611 )?),
612 write: Arc::new(OperationThrottle::new(
613 config.write,
614 burst,
615 max_retries,
616 min_backoff_ms,
617 max_backoff_ms,
618 )?),
619 delete: Arc::new(OperationThrottle::new(
620 config.delete,
621 burst,
622 max_retries,
623 min_backoff_ms,
624 max_backoff_ms,
625 )?),
626 list: Arc::new(OperationThrottle::new(
627 config.list,
628 burst,
629 max_retries,
630 min_backoff_ms,
631 max_backoff_ms,
632 )?),
633 })
634 }
635}
636
637#[async_trait]
638#[deny(clippy::missing_trait_methods)]
639impl ObjectStore for AimdThrottledStore {
640 async fn put_opts(
641 &self,
642 location: &Path,
643 bytes: PutPayload,
644 opts: PutOptions,
645 ) -> OSResult<PutResult> {
646 self.write
647 .throttled(|| self.target.put_opts(location, bytes.clone(), opts.clone()))
648 .await
649 }
650
651 async fn put_multipart_opts(
652 &self,
653 location: &Path,
654 opts: PutMultipartOptions,
655 ) -> OSResult<Box<dyn MultipartUpload>> {
656 let target = self
657 .write
658 .throttled(|| self.target.put_multipart_opts(location, opts.clone()))
659 .await?;
660 Ok(Box::new(ThrottledMultipartUpload {
661 target,
662 write: Arc::clone(&self.write),
663 }))
664 }
665
666 async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult<GetResult> {
667 self.read
668 .throttled(|| self.target.get_opts(location, options.clone()))
669 .await
670 }
671
672 async fn get_ranges(&self, location: &Path, ranges: &[Range<u64>]) -> OSResult<Vec<Bytes>> {
673 self.read
674 .throttled(|| self.target.get_ranges(location, ranges))
675 .await
676 }
677
678 fn delete_stream(
679 &self,
680 locations: BoxStream<'static, OSResult<Path>>,
681 ) -> BoxStream<'static, OSResult<Path>> {
682 let delete = Arc::clone(&self.delete);
683 self.target
684 .delete_stream(locations)
685 .map(move |item| {
686 delete.observe_outcome(&item);
687 item
688 })
689 .boxed()
690 }
691
692 fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, OSResult<ObjectMeta>> {
693 let throttle = Arc::clone(&self.list);
694 self.target
695 .list(prefix)
696 .map(move |item| {
697 throttle.observe_outcome(&item);
698 item
699 })
700 .boxed()
701 }
702
703 fn list_with_offset(
704 &self,
705 prefix: Option<&Path>,
706 offset: &Path,
707 ) -> BoxStream<'static, OSResult<ObjectMeta>> {
708 let throttle = Arc::clone(&self.list);
709 self.target
710 .list_with_offset(prefix, offset)
711 .map(move |item| {
712 throttle.observe_outcome(&item);
713 item
714 })
715 .boxed()
716 }
717
718 async fn list_with_delimiter(&self, prefix: Option<&Path>) -> OSResult<ListResult> {
719 self.list
720 .throttled(|| self.target.list_with_delimiter(prefix))
721 .await
722 }
723
724 async fn copy_opts(&self, from: &Path, to: &Path, opts: CopyOptions) -> OSResult<()> {
725 self.write
726 .throttled(|| self.target.copy_opts(from, to, opts.clone()))
727 .await
728 }
729
730 async fn rename_opts(&self, from: &Path, to: &Path, opts: RenameOptions) -> OSResult<()> {
731 self.write
732 .throttled(|| self.target.rename_opts(from, to, opts.clone()))
733 .await
734 }
735}
736
737#[cfg(test)]
738mod tests {
739 use super::*;
740 use object_store::memory::InMemory;
741 use rstest::rstest;
742 use std::collections::VecDeque;
743 use std::sync::atomic::{AtomicU64, Ordering};
744
// Sample error text returned by the rate-limiting mock; it contains the
// "retries, max_retries" substring that `is_throttle_error` looks for.
const THROTTLE_ERROR_RESPONSE: &str = "request failed, after 3 retries, max_retries: 3, retry_timeout: 30s - Server returned non-2xx status code: 503: x-ms-request-id: azure-request-id";
746
747 fn make_generic_error(msg: &str) -> object_store::Error {
748 object_store::Error::Generic {
749 store: "test",
750 source: msg.into(),
751 }
752 }
753
754 #[rstest]
755 #[case::retry_error("Error after 10 retries, max_retries: 10, retry_timeout: 180s", true)]
756 #[case::retries_in_message(
757 "request failed, after 3 retries, max_retries: 5, retry_timeout: 60s",
758 true
759 )]
760 #[case::not_found("Object not found", false)]
761 #[case::permission_denied("Access denied", false)]
762 #[case::timeout("Connection timed out", false)]
763 #[case::http_429_without_retries("HTTP 429 Too Many Requests", false)]
764 #[case::slowdown_without_retries("SlowDown: Please reduce your request rate", false)]
765 fn test_is_throttle_error(#[case] msg: &str, #[case] expected: bool) {
766 let err = make_generic_error(msg);
767 assert_eq!(
768 is_throttle_error(&err),
769 expected,
770 "is_throttle_error for '{}' should be {}",
771 msg,
772 expected
773 );
774 }
775
/// A non-`Generic` error variant must never be classified as a throttle error.
#[test]
fn test_non_generic_errors_are_not_throttle() {
    let path = "test".to_string();
    let not_found = object_store::Error::NotFound {
        path,
        source: "not found".into(),
    };
    let classified = is_throttle_error(&not_found);
    assert!(!classified);
}
784
785 #[tokio::test]
786 async fn test_basic_put_get_through_wrapper() {
787 let store = Arc::new(InMemory::new());
788 let config = AimdThrottleConfig::default();
789 let throttled = AimdThrottledStore::new(store, config).unwrap();
790
791 let path = Path::from("test/file.txt");
792 let data = PutPayload::from_static(b"hello world");
793 throttled.put(&path, data).await.unwrap();
794
795 let result = throttled.get(&path).await.unwrap();
796 let bytes = result.bytes().await.unwrap();
797 assert_eq!(bytes.as_ref(), b"hello world");
798 }
799
800 #[tokio::test]
801 async fn test_rate_decreases_on_throttle() {
802 let store = Arc::new(InMemory::new());
803 let config = AimdThrottleConfig::default().with_aimd(
804 AimdConfig::default()
805 .with_initial_rate(100.0)
806 .with_decrease_factor(0.5)
807 .with_window_duration(std::time::Duration::from_millis(10)),
808 );
809 let throttled = AimdThrottledStore::new(store, config).unwrap();
810
811 let initial_rate = throttled.read.controller.current_rate();
812 assert_eq!(initial_rate, 100.0);
813
814 throttled
816 .read
817 .controller
818 .record_outcome(RequestOutcome::Throttled);
819
820 tokio::time::sleep(std::time::Duration::from_millis(20)).await;
822 throttled
823 .read
824 .controller
825 .record_outcome(RequestOutcome::Success);
826
827 let new_rate = throttled.read.controller.current_rate();
828 assert!(
829 new_rate < initial_rate,
830 "Rate should decrease after throttle: {} < {}",
831 new_rate,
832 initial_rate
833 );
834 }
835
836 #[tokio::test]
837 async fn test_rate_recovers_on_success() {
838 let store = Arc::new(InMemory::new());
839 let config = AimdThrottleConfig::default().with_aimd(
840 AimdConfig::default()
841 .with_initial_rate(100.0)
842 .with_decrease_factor(0.5)
843 .with_additive_increment(10.0)
844 .with_window_duration(std::time::Duration::from_millis(10)),
845 );
846 let throttled = AimdThrottledStore::new(store, config).unwrap();
847
848 throttled
850 .read
851 .controller
852 .record_outcome(RequestOutcome::Throttled);
853 tokio::time::sleep(std::time::Duration::from_millis(20)).await;
854 throttled
855 .read
856 .controller
857 .record_outcome(RequestOutcome::Success);
858 let decreased_rate = throttled.read.controller.current_rate();
859 assert_eq!(decreased_rate, 50.0);
860
861 tokio::time::sleep(std::time::Duration::from_millis(20)).await;
863 throttled
864 .read
865 .controller
866 .record_outcome(RequestOutcome::Success);
867 let recovered_rate = throttled.read.controller.current_rate();
868 assert_eq!(recovered_rate, 60.0);
869 }
870
871 #[tokio::test]
872 async fn test_as_dyn_object_store() {
873 let store: Arc<dyn ObjectStore> = Arc::new(InMemory::new());
874 let throttled: Arc<dyn ObjectStore> =
875 Arc::new(AimdThrottledStore::new(store, AimdThrottleConfig::default()).unwrap());
876
877 let path = Path::from("test/data.bin");
878 let data = PutPayload::from_static(b"test data");
879 throttled.put(&path, data).await.unwrap();
880
881 let result = throttled.get(&path).await.unwrap();
882 let bytes = result.bytes().await.unwrap();
883 assert_eq!(bytes.as_ref(), b"test data");
884 }
885
886 #[tokio::test]
887 async fn test_token_bucket_delays_when_exhausted() {
888 let store = Arc::new(InMemory::new());
889 let config = AimdThrottleConfig::default()
891 .with_burst_capacity(1)
892 .with_aimd(AimdConfig::default().with_initial_rate(10.0));
893 let throttled = Arc::new(AimdThrottledStore::new(store, config).unwrap());
894
895 let path = Path::from("test/file.txt");
896 let data = PutPayload::from_static(b"data");
897 throttled.put(&path, data).await.unwrap();
898
899 let start = std::time::Instant::now();
902 let data2 = PutPayload::from_static(b"data2");
903 throttled.put(&path, data2).await.unwrap();
904 let elapsed = start.elapsed();
905
906 assert!(
907 elapsed >= std::time::Duration::from_millis(50),
908 "Expected delay for token refill, but elapsed was {:?}",
909 elapsed
910 );
911 }
912
913 #[tokio::test]
914 async fn test_list_observes_outcomes() {
915 let store = Arc::new(InMemory::new());
916 let config = AimdThrottleConfig::default();
917 let throttled = AimdThrottledStore::new(store.clone(), config).unwrap();
918
919 let path = Path::from("prefix/file.txt");
920 let data = PutPayload::from_static(b"data");
921 store.put(&path, data).await.unwrap();
922
923 let items: Vec<_> = throttled.list(Some(&Path::from("prefix"))).collect().await;
924 assert_eq!(items.len(), 1);
925 assert!(items[0].is_ok());
926 }
927
/// Mock store whose `list` stream starts with a fixed number of throttle
/// errors before yielding the real entries.
struct ThrottlingListMockStore {
    /// Backing in-memory store that serves every operation.
    inner: InMemory,
    /// Number of throttle errors emitted at the start of each `list` stream.
    throttle_count: usize,
}
936
937 impl Display for ThrottlingListMockStore {
938 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
939 write!(f, "ThrottlingListMockStore")
940 }
941 }
942
943 impl Debug for ThrottlingListMockStore {
944 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
945 f.debug_struct("ThrottlingListMockStore").finish()
946 }
947 }
948
949 #[async_trait]
950 impl ObjectStore for ThrottlingListMockStore {
951 async fn put_opts(
952 &self,
953 location: &Path,
954 bytes: PutPayload,
955 opts: PutOptions,
956 ) -> OSResult<PutResult> {
957 self.inner.put_opts(location, bytes, opts).await
958 }
959 async fn put_multipart_opts(
960 &self,
961 location: &Path,
962 opts: PutMultipartOptions,
963 ) -> OSResult<Box<dyn MultipartUpload>> {
964 self.inner.put_multipart_opts(location, opts).await
965 }
966 async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult<GetResult> {
967 self.inner.get_opts(location, options).await
968 }
969 async fn get_ranges(&self, location: &Path, ranges: &[Range<u64>]) -> OSResult<Vec<Bytes>> {
970 self.inner.get_ranges(location, ranges).await
971 }
972 fn delete_stream(
973 &self,
974 locations: BoxStream<'static, OSResult<Path>>,
975 ) -> BoxStream<'static, OSResult<Path>> {
976 self.inner.delete_stream(locations)
977 }
978 fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, OSResult<ObjectMeta>> {
979 let n = self.throttle_count;
980 let inner_stream = self.inner.list(prefix);
981 let errors = futures::stream::iter((0..n).map(|_| {
982 Err(object_store::Error::Generic {
983 store: "ThrottlingListMock",
984 source: "request failed, after 3 retries, max_retries: 5, retry_timeout: 60s"
985 .into(),
986 })
987 }));
988 errors.chain(inner_stream).boxed()
989 }
990 fn list_with_offset(
991 &self,
992 prefix: Option<&Path>,
993 offset: &Path,
994 ) -> BoxStream<'static, OSResult<ObjectMeta>> {
995 self.inner.list_with_offset(prefix, offset)
996 }
997 async fn list_with_delimiter(&self, prefix: Option<&Path>) -> OSResult<ListResult> {
998 self.inner.list_with_delimiter(prefix).await
999 }
1000 async fn copy_opts(&self, from: &Path, to: &Path, opts: CopyOptions) -> OSResult<()> {
1001 self.inner.copy_opts(from, to, opts).await
1002 }
1003 }
1004
1005 #[tokio::test]
1006 async fn test_list_stream_throttle_errors_decrease_rate() {
1007 let mock = Arc::new(ThrottlingListMockStore {
1008 inner: InMemory::new(),
1009 throttle_count: 5,
1010 });
1011
1012 mock.put(
1014 &Path::from("prefix/file.txt"),
1015 PutPayload::from_static(b"data"),
1016 )
1017 .await
1018 .unwrap();
1019
1020 let config = AimdThrottleConfig::default().with_list_aimd(
1021 AimdConfig::default()
1022 .with_initial_rate(100.0)
1023 .with_decrease_factor(0.5)
1024 .with_window_duration(std::time::Duration::from_millis(10)),
1025 );
1026 let throttled = AimdThrottledStore::new(mock as Arc<dyn ObjectStore>, config).unwrap();
1027
1028 let initial_rate = throttled.list.controller.current_rate();
1029 assert_eq!(initial_rate, 100.0);
1030
1031 let items: Vec<_> = throttled.list(Some(&Path::from("prefix"))).collect().await;
1032
1033 assert_eq!(items.len(), 6);
1035 assert!(items[0].is_err());
1036 assert!(items[5].is_ok());
1037
1038 tokio::time::sleep(std::time::Duration::from_millis(20)).await;
1040 throttled
1041 .list
1042 .controller
1043 .record_outcome(RequestOutcome::Success);
1044
1045 let new_rate = throttled.list.controller.current_rate();
1046 assert!(
1047 new_rate < initial_rate,
1048 "List rate should decrease after stream throttle errors: {} < {}",
1049 new_rate,
1050 initial_rate
1051 );
1052 }
1053
1054 #[tokio::test]
1055 async fn test_per_category_independence() {
1056 let store = Arc::new(InMemory::new());
1057 let config = AimdThrottleConfig::default().with_aimd(
1058 AimdConfig::default()
1059 .with_initial_rate(100.0)
1060 .with_decrease_factor(0.5)
1061 .with_window_duration(std::time::Duration::from_millis(10)),
1062 );
1063 let throttled = AimdThrottledStore::new(store, config).unwrap();
1064
1065 throttled
1067 .read
1068 .controller
1069 .record_outcome(RequestOutcome::Throttled);
1070 tokio::time::sleep(std::time::Duration::from_millis(20)).await;
1071 throttled
1072 .read
1073 .controller
1074 .record_outcome(RequestOutcome::Success);
1075
1076 let read_rate = throttled.read.controller.current_rate();
1077 let write_rate = throttled.write.controller.current_rate();
1078 let delete_rate = throttled.delete.controller.current_rate();
1079 let list_rate = throttled.list.controller.current_rate();
1080
1081 assert_eq!(read_rate, 50.0, "Read rate should have decreased");
1082 assert_eq!(write_rate, 100.0, "Write rate should be unaffected");
1083 assert_eq!(delete_rate, 100.0, "Delete rate should be unaffected");
1084 assert_eq!(list_rate, 100.0, "List rate should be unaffected");
1085 }
1086
1087 #[tokio::test]
1088 async fn test_per_category_config() {
1089 let store = Arc::new(InMemory::new());
1090 let config = AimdThrottleConfig::default()
1091 .with_read_aimd(AimdConfig::default().with_initial_rate(200.0))
1092 .with_write_aimd(AimdConfig::default().with_initial_rate(100.0))
1093 .with_delete_aimd(AimdConfig::default().with_initial_rate(50.0))
1094 .with_list_aimd(AimdConfig::default().with_initial_rate(25.0));
1095 let throttled = AimdThrottledStore::new(store, config).unwrap();
1096
1097 assert_eq!(throttled.read.controller.current_rate(), 200.0);
1098 assert_eq!(throttled.write.controller.current_rate(), 100.0);
1099 assert_eq!(throttled.delete.controller.current_rate(), 50.0);
1100 assert_eq!(throttled.list.controller.current_rate(), 25.0);
1101 }
1102
/// Mock store that rejects reads with a throttle error once more than
/// `max_per_window` requests have been admitted within `window`.
struct RateLimitingMockStore {
    /// Backing in-memory store that serves admitted requests.
    inner: InMemory,
    /// Timestamps of the requests admitted within the current window.
    timestamps: std::sync::Mutex<VecDeque<std::time::Instant>>,
    /// Maximum number of requests admitted per window.
    max_per_window: usize,
    /// Length of the sliding window.
    window: std::time::Duration,
    /// Number of requests admitted so far.
    success_count: AtomicU64,
    /// Number of requests rejected so far.
    throttle_count: AtomicU64,
}
1117
1118 impl RateLimitingMockStore {
1119 fn new(max_per_window: usize, window: std::time::Duration) -> Self {
1120 Self {
1121 inner: InMemory::new(),
1122 timestamps: std::sync::Mutex::new(VecDeque::new()),
1123 max_per_window,
1124 window,
1125 success_count: AtomicU64::new(0),
1126 throttle_count: AtomicU64::new(0),
1127 }
1128 }
1129
1130 fn check_rate(&self) -> bool {
1132 let mut ts = self.timestamps.lock().unwrap();
1133 let now = std::time::Instant::now();
1134 while let Some(&front) = ts.front() {
1135 if now.duration_since(front) > self.window {
1136 ts.pop_front();
1137 } else {
1138 break;
1139 }
1140 }
1141 if ts.len() >= self.max_per_window {
1142 self.throttle_count.fetch_add(1, Ordering::Relaxed);
1143 false
1144 } else {
1145 ts.push_back(now);
1146 self.success_count.fetch_add(1, Ordering::Relaxed);
1147 true
1148 }
1149 }
1150
1151 fn throttle_error() -> object_store::Error {
1152 object_store::Error::Generic {
1153 store: "RateLimitingMock",
1154 source: THROTTLE_ERROR_RESPONSE.into(),
1155 }
1156 }
1157 }
1158
1159 impl Display for RateLimitingMockStore {
1160 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1161 write!(f, "RateLimitingMockStore")
1162 }
1163 }
1164
1165 impl Debug for RateLimitingMockStore {
1166 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1167 f.debug_struct("RateLimitingMockStore").finish()
1168 }
1169 }
1170
1171 #[async_trait]
1172 impl ObjectStore for RateLimitingMockStore {
1173 async fn put_opts(
1174 &self,
1175 location: &Path,
1176 bytes: PutPayload,
1177 opts: PutOptions,
1178 ) -> OSResult<PutResult> {
1179 self.inner.put_opts(location, bytes, opts).await
1180 }
1181
1182 async fn put_multipart_opts(
1183 &self,
1184 location: &Path,
1185 opts: PutMultipartOptions,
1186 ) -> OSResult<Box<dyn MultipartUpload>> {
1187 self.inner.put_multipart_opts(location, opts).await
1188 }
1189
1190 async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult<GetResult> {
1191 if self.check_rate() {
1192 self.inner.get_opts(location, options).await
1193 } else {
1194 Err(Self::throttle_error())
1195 }
1196 }
1197
1198 async fn get_ranges(&self, location: &Path, ranges: &[Range<u64>]) -> OSResult<Vec<Bytes>> {
1199 if self.check_rate() {
1200 self.inner.get_ranges(location, ranges).await
1201 } else {
1202 Err(Self::throttle_error())
1203 }
1204 }
1205
1206 fn delete_stream(
1207 &self,
1208 locations: BoxStream<'static, OSResult<Path>>,
1209 ) -> BoxStream<'static, OSResult<Path>> {
1210 self.inner.delete_stream(locations)
1211 }
1212
1213 fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, OSResult<ObjectMeta>> {
1214 self.inner.list(prefix)
1215 }
1216
1217 fn list_with_offset(
1218 &self,
1219 prefix: Option<&Path>,
1220 offset: &Path,
1221 ) -> BoxStream<'static, OSResult<ObjectMeta>> {
1222 self.inner.list_with_offset(prefix, offset)
1223 }
1224
1225 async fn list_with_delimiter(&self, prefix: Option<&Path>) -> OSResult<ListResult> {
1226 self.inner.list_with_delimiter(prefix).await
1227 }
1228
1229 async fn copy_opts(&self, from: &Path, to: &Path, opts: CopyOptions) -> OSResult<()> {
1230 self.inner.copy_opts(from, to, opts).await
1231 }
1232 }
1233
1234 #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
1251 async fn test_aimd_throttle_under_concurrent_load() {
1252 let mock = Arc::new(RateLimitingMockStore::new(
1253 30,
1254 std::time::Duration::from_millis(100),
1255 ));
1256
1257 let path = Path::from("test/data.bin");
1259 mock.put(&path, PutPayload::from_static(b"test data"))
1260 .await
1261 .unwrap();
1262
1263 let aimd = AimdConfig::default()
1264 .with_initial_rate(100.0)
1265 .with_decrease_factor(0.5)
1266 .with_additive_increment(2.0)
1267 .with_window_duration(std::time::Duration::from_millis(100));
1268 let throttle_config = AimdThrottleConfig::default()
1269 .with_aimd(aimd)
1270 .with_burst_capacity(100);
1271
1272 let num_readers = 5;
1273 let test_duration = std::time::Duration::from_secs(2);
1274 let mut handles = Vec::new();
1275
1276 for _ in 0..num_readers {
1277 let store = Arc::new(
1278 AimdThrottledStore::new(
1279 mock.clone() as Arc<dyn ObjectStore>,
1280 throttle_config.clone(),
1281 )
1282 .unwrap(),
1283 );
1284 let p = path.clone();
1285 handles.push(tokio::spawn(async move {
1286 let deadline = std::time::Instant::now() + test_duration;
1287 let mut count = 0u64;
1288 while std::time::Instant::now() < deadline {
1289 let _ = store.head(&p).await;
1290 count += 1;
1291 }
1292 count
1293 }));
1294 }
1295
1296 let mut total_reader_requests = 0u64;
1297 for handle in handles {
1298 total_reader_requests += handle.await.unwrap();
1299 }
1300
1301 let successes = mock.success_count.load(Ordering::Relaxed);
1302 let throttled = mock.throttle_count.load(Ordering::Relaxed);
1303 let total_mock = successes + throttled;
1304
1305 assert!(
1308 total_mock >= total_reader_requests,
1309 "Mock-side count ({total_mock}) should be >= reader-side count ({total_reader_requests})"
1310 );
1311
1312 assert!(
1315 successes >= 300,
1316 "Expected >= 300 successful requests over 2s, got {successes}"
1317 );
1318 assert!(
1319 successes <= 900,
1320 "Expected <= 900 successful requests, got {successes}"
1321 );
1322
1323 assert!(throttled > 0, "Expected some throttled requests but got 0");
1325
1326 assert!(
1329 total_mock <= 5000,
1330 "AIMD should limit total requests, got {total_mock}"
1331 );
1332 }
1333
1334 struct RetryTestMockStore {
1338 inner: InMemory,
1339 errors_remaining: std::sync::Mutex<usize>,
1341 get_call_count: AtomicU64,
1343 }
1344
1345 impl RetryTestMockStore {
1346 fn new(errors_before_success: usize) -> Self {
1347 Self {
1348 inner: InMemory::new(),
1349 errors_remaining: std::sync::Mutex::new(errors_before_success),
1350 get_call_count: AtomicU64::new(0),
1351 }
1352 }
1353 }
1354
1355 impl Display for RetryTestMockStore {
1356 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1357 write!(f, "RetryTestMockStore")
1358 }
1359 }
1360
1361 impl Debug for RetryTestMockStore {
1362 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1363 f.debug_struct("RetryTestMockStore").finish()
1364 }
1365 }
1366
1367 #[async_trait]
1368 impl ObjectStore for RetryTestMockStore {
1369 async fn put_opts(
1370 &self,
1371 location: &Path,
1372 bytes: PutPayload,
1373 opts: PutOptions,
1374 ) -> OSResult<PutResult> {
1375 self.inner.put_opts(location, bytes, opts).await
1376 }
1377 async fn put_multipart_opts(
1378 &self,
1379 location: &Path,
1380 opts: PutMultipartOptions,
1381 ) -> OSResult<Box<dyn MultipartUpload>> {
1382 self.inner.put_multipart_opts(location, opts).await
1383 }
1384 async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult<GetResult> {
1385 self.get_call_count.fetch_add(1, Ordering::Relaxed);
1386 let should_error = {
1387 let mut remaining = self.errors_remaining.lock().unwrap();
1388 if *remaining > 0 {
1389 *remaining -= 1;
1390 true
1391 } else {
1392 false
1393 }
1394 };
1395 if should_error {
1396 Err(object_store::Error::Generic {
1397 store: "RetryTestMock",
1398 source: THROTTLE_ERROR_RESPONSE.into(),
1399 })
1400 } else {
1401 self.inner.get_opts(location, options).await
1402 }
1403 }
1404 async fn get_ranges(&self, location: &Path, ranges: &[Range<u64>]) -> OSResult<Vec<Bytes>> {
1405 self.inner.get_ranges(location, ranges).await
1406 }
1407 fn delete_stream(
1408 &self,
1409 locations: BoxStream<'static, OSResult<Path>>,
1410 ) -> BoxStream<'static, OSResult<Path>> {
1411 self.inner.delete_stream(locations)
1412 }
1413 fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, OSResult<ObjectMeta>> {
1414 self.inner.list(prefix)
1415 }
1416 fn list_with_offset(
1417 &self,
1418 prefix: Option<&Path>,
1419 offset: &Path,
1420 ) -> BoxStream<'static, OSResult<ObjectMeta>> {
1421 self.inner.list_with_offset(prefix, offset)
1422 }
1423 async fn list_with_delimiter(&self, prefix: Option<&Path>) -> OSResult<ListResult> {
1424 self.inner.list_with_delimiter(prefix).await
1425 }
1426 async fn copy_opts(&self, from: &Path, to: &Path, opts: CopyOptions) -> OSResult<()> {
1427 self.inner.copy_opts(from, to, opts).await
1428 }
1429 }
1430
1431 #[tokio::test]
1432 async fn test_throttled_retries_on_throttle_error_then_succeeds() {
1433 let mock = Arc::new(RetryTestMockStore::new(2));
1435 let path = Path::from("test/retry.txt");
1436 mock.put(&path, PutPayload::from_static(b"retry data"))
1437 .await
1438 .unwrap();
1439
1440 let config = AimdThrottleConfig::default();
1441 let throttled =
1442 AimdThrottledStore::new(mock.clone() as Arc<dyn ObjectStore>, config).unwrap();
1443
1444 let result = throttled.get(&path).await;
1445 assert!(result.is_ok(), "Expected success after retries");
1446
1447 let bytes = result.unwrap().bytes().await.unwrap();
1448 assert_eq!(bytes.as_ref(), b"retry data");
1449
1450 assert_eq!(mock.get_call_count.load(Ordering::Relaxed), 3);
1452 }
1453
1454 #[tokio::test]
1455 async fn test_throttled_fails_after_max_retries_exceeded() {
1456 let mock = Arc::new(RetryTestMockStore::new(10));
1459 let path = Path::from("test/fail.txt");
1460 mock.put(&path, PutPayload::from_static(b"fail data"))
1461 .await
1462 .unwrap();
1463
1464 let config = AimdThrottleConfig::default();
1465 let throttled =
1466 AimdThrottledStore::new(mock.clone() as Arc<dyn ObjectStore>, config).unwrap();
1467
1468 let result = throttled.get(&path).await;
1469 assert!(result.is_err(), "Expected error after max retries");
1470 let err = result.unwrap_err();
1471 assert!(is_throttle_error(&err));
1472
1473 let lance_error = lance_core::Error::from(err);
1474 let error_message = lance_error.to_string();
1475 assert!(error_message.contains("x-ms-request-id"));
1476 assert!(error_message.contains("azure-request-id"));
1477
1478 assert_eq!(mock.get_call_count.load(Ordering::Relaxed), 4);
1480 }
1481
1482 #[tokio::test]
1483 async fn test_throttled_multipart_reorders_parts() {
1484 let store = Arc::new(InMemory::new()) as Arc<dyn ObjectStore>;
1485 let config = AimdThrottleConfig::default();
1486 let throttled = AimdThrottledStore::new(store.clone(), config).unwrap();
1487
1488 let path = Path::from("test/multipart_ordering.bin");
1489 let mut upload = throttled.put_multipart(&path).await.unwrap();
1490
1491 let fut_a = upload.put_part(PutPayload::from_static(b"AAAA"));
1493 let fut_b = upload.put_part(PutPayload::from_static(b"BBBB"));
1494
1495 fut_b.await.unwrap();
1498 fut_a.await.unwrap();
1499
1500 upload.complete().await.unwrap();
1501
1502 let result = store.get(&path).await.unwrap();
1503 let bytes = result.bytes().await.unwrap();
1504
1505 assert_eq!(
1506 bytes.as_ref(),
1507 b"AAAABBBB",
1508 "Parts were reordered! Got {:?} instead of AAAABBBB.",
1509 std::str::from_utf8(&bytes).unwrap_or("<non-utf8>"),
1510 );
1511 }
1512}