1use std::collections::HashMap;
13use std::fmt;
14use std::sync::Arc;
15use std::time::Duration;
16
17#[cfg(test)]
18use mock_instant::thread_local::{SystemTime, UNIX_EPOCH};
19
20#[cfg(not(test))]
21use std::time::{SystemTime, UNIX_EPOCH};
22
23use async_trait::async_trait;
24use lance_namespace::LanceNamespace;
25use lance_namespace::models::DescribeTableRequest;
26use tokio::sync::RwLock;
27
28use crate::{Error, Result};
29
/// Storage-option key whose value is the absolute expiration time of the
/// options, in milliseconds since the Unix epoch.
pub const EXPIRES_AT_MILLIS_KEY: &str = "expires_at_millis";

/// Storage-option key whose value (milliseconds) says how long before the
/// expiration time the options are already treated as stale and refreshed.
pub const REFRESH_OFFSET_MILLIS_KEY: &str = "refresh_offset_millis";

/// Refresh offset used when [`REFRESH_OFFSET_MILLIS_KEY`] is not supplied:
/// one minute.
const DEFAULT_REFRESH_OFFSET_MILLIS: u64 = 60 * 1_000;
38
39#[async_trait]
72pub trait StorageOptionsProvider: Send + Sync + fmt::Debug {
73 async fn fetch_storage_options(&self) -> Result<Option<HashMap<String, String>>>;
81
82 fn provider_id(&self) -> String;
93}
94
95pub struct LanceNamespaceStorageOptionsProvider {
97 namespace_client: Arc<dyn LanceNamespace>,
98 table_id: Vec<String>,
99}
100
101impl fmt::Debug for LanceNamespaceStorageOptionsProvider {
102 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
103 write!(f, "{}", self.provider_id())
104 }
105}
106
107impl fmt::Display for LanceNamespaceStorageOptionsProvider {
108 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
109 write!(f, "{}", self.provider_id())
110 }
111}
112
113impl LanceNamespaceStorageOptionsProvider {
114 pub fn new(namespace_client: Arc<dyn LanceNamespace>, table_id: Vec<String>) -> Self {
120 Self {
121 namespace_client,
122 table_id,
123 }
124 }
125}
126
127#[async_trait]
128impl StorageOptionsProvider for LanceNamespaceStorageOptionsProvider {
129 async fn fetch_storage_options(&self) -> Result<Option<HashMap<String, String>>> {
130 let request = DescribeTableRequest {
131 id: Some(self.table_id.clone()),
132 ..Default::default()
133 };
134
135 let response = self
136 .namespace_client
137 .describe_table(request)
138 .await
139 .map_err(|e| {
140 Error::io_source(Box::new(std::io::Error::other(format!(
141 "Failed to fetch storage options: {}",
142 e
143 ))))
144 })?;
145
146 Ok(response.storage_options)
147 }
148
149 fn provider_id(&self) -> String {
150 format!(
151 "LanceNamespaceStorageOptionsProvider {{ namespace_client: {}, table_id: {:?} }}",
152 self.namespace_client.namespace_id(),
153 self.table_id
154 )
155 }
156}
157
158pub struct StorageOptionsAccessor {
176 initial_options: Option<HashMap<String, String>>,
178
179 provider: Option<Arc<dyn StorageOptionsProvider>>,
181
182 cache: Arc<RwLock<Option<CachedStorageOptions>>>,
184
185 refresh_offset: Duration,
187}
188
189impl fmt::Debug for StorageOptionsAccessor {
190 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
191 f.debug_struct("StorageOptionsAccessor")
192 .field("has_initial_options", &self.initial_options.is_some())
193 .field("has_provider", &self.provider.is_some())
194 .field("refresh_offset", &self.refresh_offset)
195 .finish()
196 }
197}
198
/// A snapshot of storage options held in the accessor's cache.
#[derive(Debug, Clone)]
struct CachedStorageOptions {
    /// The options exactly as received.
    options: HashMap<String, String>,
    /// Expiration time in milliseconds since the Unix epoch, taken from the
    /// `expires_at_millis` entry when present and parseable as `u64`.
    expires_at_millis: Option<u64>,
}
204
205impl StorageOptionsAccessor {
206 fn extract_refresh_offset(options: &HashMap<String, String>) -> Duration {
208 options
209 .get(REFRESH_OFFSET_MILLIS_KEY)
210 .and_then(|s| s.parse::<u64>().ok())
211 .map(Duration::from_millis)
212 .unwrap_or(Duration::from_millis(DEFAULT_REFRESH_OFFSET_MILLIS))
213 }
214
215 pub fn with_static_options(options: HashMap<String, String>) -> Self {
220 let expires_at_millis = options
221 .get(EXPIRES_AT_MILLIS_KEY)
222 .and_then(|s| s.parse::<u64>().ok());
223 let refresh_offset = Self::extract_refresh_offset(&options);
224
225 Self {
226 initial_options: Some(options.clone()),
227 provider: None,
228 cache: Arc::new(RwLock::new(Some(CachedStorageOptions {
229 options,
230 expires_at_millis,
231 }))),
232 refresh_offset,
233 }
234 }
235
236 pub fn with_provider(provider: Arc<dyn StorageOptionsProvider>) -> Self {
245 Self {
246 initial_options: None,
247 provider: Some(provider),
248 cache: Arc::new(RwLock::new(None)),
249 refresh_offset: Duration::from_millis(DEFAULT_REFRESH_OFFSET_MILLIS),
250 }
251 }
252
253 pub fn with_initial_and_provider(
263 initial_options: HashMap<String, String>,
264 provider: Arc<dyn StorageOptionsProvider>,
265 ) -> Self {
266 let expires_at_millis = initial_options
267 .get(EXPIRES_AT_MILLIS_KEY)
268 .and_then(|s| s.parse::<u64>().ok());
269 let refresh_offset = Self::extract_refresh_offset(&initial_options);
270
271 Self {
272 initial_options: Some(initial_options.clone()),
273 provider: Some(provider),
274 cache: Arc::new(RwLock::new(Some(CachedStorageOptions {
275 options: initial_options,
276 expires_at_millis,
277 }))),
278 refresh_offset,
279 }
280 }
281
282 pub async fn get_storage_options(&self) -> Result<super::StorageOptions> {
294 loop {
295 match self.do_get_storage_options().await? {
296 Some(options) => return Ok(options),
297 None => {
298 tokio::time::sleep(Duration::from_millis(10)).await;
300 continue;
301 }
302 }
303 }
304 }
305
306 #[cfg(any(feature = "aws", feature = "azure", feature = "gcp"))]
311 pub(crate) async fn refresh_storage_options(&self) -> Result<super::StorageOptions> {
312 let Some(provider) = &self.provider else {
313 return self.get_storage_options().await;
314 };
315
316 log::debug!(
317 "Refreshing storage options from provider: {}",
318 provider.provider_id()
319 );
320
321 let storage_options_map = provider.fetch_storage_options().await.map_err(|e| {
322 Error::io_source(Box::new(std::io::Error::other(format!(
323 "Failed to fetch storage options: {}",
324 e
325 ))))
326 })?;
327
328 let Some(options) = storage_options_map else {
329 if let Some(initial) = &self.initial_options {
330 return Ok(super::StorageOptions(initial.clone()));
331 }
332 log::debug!(
333 "Provider {} returned no storage options, using default credentials",
334 provider.provider_id()
335 );
336 return Ok(super::StorageOptions(HashMap::new()));
337 };
338
339 let expires_at_millis = options
340 .get(EXPIRES_AT_MILLIS_KEY)
341 .and_then(|s| s.parse::<u64>().ok());
342
343 let mut cache = self.cache.write().await;
344 *cache = Some(CachedStorageOptions {
345 options: options.clone(),
346 expires_at_millis,
347 });
348
349 Ok(super::StorageOptions(options))
350 }
351
352 async fn do_get_storage_options(&self) -> Result<Option<super::StorageOptions>> {
353 {
355 let cached = self.cache.read().await;
356 if !self.needs_refresh(&cached)
357 && let Some(cached_opts) = &*cached
358 {
359 return Ok(Some(super::StorageOptions(cached_opts.options.clone())));
360 }
361 }
362
363 let Some(provider) = &self.provider else {
365 return if let Some(initial) = &self.initial_options {
366 Ok(Some(super::StorageOptions(initial.clone())))
367 } else {
368 Ok(Some(super::StorageOptions(HashMap::new())))
370 };
371 };
372
373 let Ok(mut cache) = self.cache.try_write() else {
375 return Ok(None);
376 };
377
378 if !self.needs_refresh(&cache)
381 && let Some(cached_opts) = &*cache
382 {
383 return Ok(Some(super::StorageOptions(cached_opts.options.clone())));
384 }
385 log::debug!(
386 "Refreshing storage options from provider: {}",
387 provider.provider_id()
388 );
389
390 let storage_options_map = provider.fetch_storage_options().await.map_err(|e| {
391 Error::io_source(Box::new(std::io::Error::other(format!(
392 "Failed to fetch storage options: {}",
393 e
394 ))))
395 })?;
396
397 let Some(options) = storage_options_map else {
398 if let Some(initial) = &self.initial_options {
400 return Ok(Some(super::StorageOptions(initial.clone())));
401 }
402 log::debug!(
406 "Provider {} returned no storage options, using default credentials",
407 provider.provider_id()
408 );
409 return Ok(Some(super::StorageOptions(HashMap::new())));
410 };
411
412 let expires_at_millis = options
413 .get(EXPIRES_AT_MILLIS_KEY)
414 .and_then(|s| s.parse::<u64>().ok());
415
416 if let Some(expires_at) = expires_at_millis {
417 let now_ms = SystemTime::now()
418 .duration_since(UNIX_EPOCH)
419 .unwrap_or(Duration::from_secs(0))
420 .as_millis() as u64;
421 let expires_in_secs = (expires_at.saturating_sub(now_ms)) / 1000;
422 log::debug!(
423 "Successfully refreshed storage options from provider: {}, options expire in {} seconds",
424 provider.provider_id(),
425 expires_in_secs
426 );
427 } else {
428 log::debug!(
429 "Successfully refreshed storage options from provider: {} (no expiration)",
430 provider.provider_id()
431 );
432 }
433
434 *cache = Some(CachedStorageOptions {
435 options: options.clone(),
436 expires_at_millis,
437 });
438
439 Ok(Some(super::StorageOptions(options)))
440 }
441
442 fn needs_refresh(&self, cached: &Option<CachedStorageOptions>) -> bool {
443 match cached {
444 None => true,
445 Some(cached_opts) => {
446 if let Some(expires_at_millis) = cached_opts.expires_at_millis {
447 let now_ms = SystemTime::now()
448 .duration_since(UNIX_EPOCH)
449 .unwrap_or(Duration::from_secs(0))
450 .as_millis() as u64;
451
452 let refresh_offset_millis = self.refresh_offset.as_millis() as u64;
454 now_ms + refresh_offset_millis >= expires_at_millis
455 } else {
456 false
458 }
459 }
460 }
461 }
462
463 pub fn initial_storage_options(&self) -> Option<&HashMap<String, String>> {
468 self.initial_options.as_ref()
469 }
470
471 pub fn accessor_id(&self) -> String {
476 if let Some(provider) = &self.provider {
477 provider.provider_id()
478 } else if let Some(initial) = &self.initial_options {
479 use std::collections::hash_map::DefaultHasher;
481 use std::hash::{Hash, Hasher};
482
483 let mut hasher = DefaultHasher::new();
484 let mut keys: Vec<_> = initial.keys().collect();
485 keys.sort();
486 for key in keys {
487 key.hash(&mut hasher);
488 initial.get(key).hash(&mut hasher);
489 }
490 format!("static_options_{:x}", hasher.finish())
491 } else {
492 "empty_accessor".to_string()
493 }
494 }
495
496 pub fn has_provider(&self) -> bool {
498 self.provider.is_some()
499 }
500
501 pub fn refresh_offset(&self) -> Duration {
503 self.refresh_offset
504 }
505
506 pub fn provider(&self) -> Option<&Arc<dyn StorageOptionsProvider>> {
508 self.provider.as_ref()
509 }
510}
511
#[cfg(test)]
mod tests {
    use super::*;
    use mock_instant::thread_local::MockClock;

    /// Provider double that returns numbered fake AWS credentials
    /// (`AKID_<n>`, `SECRET_<n>`, `TOKEN_<n>`, where `n` is the 1-based call
    /// number) and records how many times it was called.
    #[derive(Debug)]
    struct MockStorageOptionsProvider {
        call_count: Arc<RwLock<usize>>,
        /// If set, each response expires this many milliseconds after the
        /// (mocked) current time.
        expires_in_millis: Option<u64>,
    }

    impl MockStorageOptionsProvider {
        fn new(expires_in_millis: Option<u64>) -> Self {
            Self {
                expires_in_millis,
                call_count: Arc::new(RwLock::new(0)),
            }
        }

        async fn get_call_count(&self) -> usize {
            let count = self.call_count.read().await;
            *count
        }
    }

    #[async_trait]
    impl StorageOptionsProvider for MockStorageOptionsProvider {
        async fn fetch_storage_options(&self) -> Result<Option<HashMap<String, String>>> {
            let call_number = {
                let mut count = self.call_count.write().await;
                *count += 1;
                *count
            };

            let mut options: HashMap<String, String> = [
                ("aws_access_key_id", "AKID"),
                ("aws_secret_access_key", "SECRET"),
                ("aws_session_token", "TOKEN"),
            ]
            .into_iter()
            .map(|(key, prefix)| (key.to_string(), format!("{}_{}", prefix, call_number)))
            .collect();

            if let Some(ttl_millis) = self.expires_in_millis {
                let now_ms = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .unwrap()
                    .as_millis() as u64;
                options.insert(
                    EXPIRES_AT_MILLIS_KEY.to_string(),
                    (now_ms + ttl_millis).to_string(),
                );
            }

            Ok(Some(options))
        }

        fn provider_id(&self) -> String {
            format!(
                "MockStorageOptionsProvider {{ id: {} }}",
                Arc::as_ptr(&self.call_count) as usize
            )
        }
    }

    /// The mocked "now", in milliseconds since the Unix epoch.
    fn mock_now_ms() -> u64 {
        MockClock::system_time().as_millis() as u64
    }

    #[tokio::test]
    async fn test_static_options_only() {
        let options: HashMap<String, String> = [("key1", "value1"), ("key2", "value2")]
            .into_iter()
            .map(|(k, v)| (k.to_string(), v.to_string()))
            .collect();
        let accessor = StorageOptionsAccessor::with_static_options(options.clone());

        let resolved = accessor.get_storage_options().await.unwrap();

        assert_eq!(resolved.0, options);
        assert!(!accessor.has_provider());
        assert_eq!(accessor.initial_storage_options(), Some(&options));
    }

    #[tokio::test]
    async fn test_provider_only() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let provider = Arc::new(MockStorageOptionsProvider::new(Some(600_000)));
        let accessor = StorageOptionsAccessor::with_provider(provider.clone());

        let resolved = accessor.get_storage_options().await.unwrap();

        assert!(resolved.0.contains_key("aws_access_key_id"));
        assert_eq!(resolved.0.get("aws_access_key_id").unwrap(), "AKID_1");
        assert!(accessor.has_provider());
        assert_eq!(accessor.initial_storage_options(), None);
        assert_eq!(provider.get_call_count().await, 1);
    }

    #[tokio::test]
    async fn test_initial_and_provider_uses_initial_first() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        // Initial credentials valid for ten more minutes.
        let valid_until = mock_now_ms() + 600_000;
        let initial = HashMap::from([
            ("aws_access_key_id".to_string(), "INITIAL_KEY".to_string()),
            (
                "aws_secret_access_key".to_string(),
                "INITIAL_SECRET".to_string(),
            ),
            (EXPIRES_AT_MILLIS_KEY.to_string(), valid_until.to_string()),
        ]);
        let provider = Arc::new(MockStorageOptionsProvider::new(Some(600_000)));
        let accessor =
            StorageOptionsAccessor::with_initial_and_provider(initial.clone(), provider.clone());

        let resolved = accessor.get_storage_options().await.unwrap();

        // Fresh initial options are served without consulting the provider.
        assert_eq!(resolved.0.get("aws_access_key_id").unwrap(), "INITIAL_KEY");
        assert_eq!(provider.get_call_count().await, 0);
    }

    #[tokio::test]
    async fn test_caching_and_refresh() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let provider = Arc::new(MockStorageOptionsProvider::new(Some(600_000)));
        // Expire in 10 minutes, refresh 5 minutes early.
        let valid_until = mock_now_ms() + 600_000;
        let initial = HashMap::from([
            (EXPIRES_AT_MILLIS_KEY.to_string(), valid_until.to_string()),
            (REFRESH_OFFSET_MILLIS_KEY.to_string(), "300000".to_string()),
        ]);
        let accessor =
            StorageOptionsAccessor::with_initial_and_provider(initial, provider.clone());

        let first = accessor.get_storage_options().await.unwrap();
        assert!(first.0.contains_key(EXPIRES_AT_MILLIS_KEY));
        assert_eq!(provider.get_call_count().await, 0);

        // Six minutes later the refresh window (last five minutes) is entered.
        MockClock::set_system_time(Duration::from_secs(100_000 + 360));
        let second = accessor.get_storage_options().await.unwrap();
        assert_eq!(second.0.get("aws_access_key_id").unwrap(), "AKID_1");
        assert_eq!(provider.get_call_count().await, 1);
    }

    #[tokio::test]
    async fn test_expired_initial_triggers_refresh() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        // Initial credentials expired one second ago.
        let expired_at = mock_now_ms() - 1_000;
        let initial = HashMap::from([
            ("aws_access_key_id".to_string(), "EXPIRED_KEY".to_string()),
            (EXPIRES_AT_MILLIS_KEY.to_string(), expired_at.to_string()),
        ]);
        let provider = Arc::new(MockStorageOptionsProvider::new(Some(600_000)));
        let accessor =
            StorageOptionsAccessor::with_initial_and_provider(initial, provider.clone());

        let resolved = accessor.get_storage_options().await.unwrap();

        assert_eq!(resolved.0.get("aws_access_key_id").unwrap(), "AKID_1");
        assert_eq!(provider.get_call_count().await, 1);
    }

    #[tokio::test]
    async fn test_accessor_id_with_provider() {
        let accessor =
            StorageOptionsAccessor::with_provider(Arc::new(MockStorageOptionsProvider::new(None)));

        assert!(accessor.accessor_id().starts_with("MockStorageOptionsProvider"));
    }

    #[tokio::test]
    async fn test_accessor_id_static() {
        let accessor = StorageOptionsAccessor::with_static_options(HashMap::from([(
            "key".to_string(),
            "value".to_string(),
        )]));

        assert!(accessor.accessor_id().starts_with("static_options_"));
    }

    #[tokio::test]
    async fn test_concurrent_access() {
        // Effectively never expires, so one fetch must satisfy every task.
        let provider = Arc::new(MockStorageOptionsProvider::new(Some(9999999999999)));
        let accessor = Arc::new(StorageOptionsAccessor::with_provider(provider.clone()));

        let handles: Vec<_> = (0..10)
            .map(|i| {
                let shared = Arc::clone(&accessor);
                tokio::spawn(async move {
                    let resolved = shared.get_storage_options().await.unwrap();
                    assert_eq!(resolved.0.get("aws_access_key_id").unwrap(), "AKID_1");
                    i
                })
            })
            .collect();

        let finished: Vec<_> = futures::future::join_all(handles)
            .await
            .into_iter()
            .map(|joined| joined.unwrap())
            .collect();
        assert_eq!(finished.len(), 10);

        let calls = provider.get_call_count().await;
        assert_eq!(
            calls, 1,
            "Provider should be called exactly once despite concurrent access"
        );
    }

    #[tokio::test]
    async fn test_no_expiration_never_refreshes() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let provider = Arc::new(MockStorageOptionsProvider::new(None));
        let accessor = StorageOptionsAccessor::with_provider(provider.clone());

        accessor.get_storage_options().await.unwrap();
        assert_eq!(provider.get_call_count().await, 1);

        // Jumping far ahead must not trigger a refresh of non-expiring options.
        MockClock::set_system_time(Duration::from_secs(200_000));
        accessor.get_storage_options().await.unwrap();
        assert_eq!(provider.get_call_count().await, 1);
    }
}