1use std::collections::HashMap;
13use std::fmt;
14use std::sync::Arc;
15use std::time::Duration;
16
17#[cfg(test)]
18use mock_instant::thread_local::{SystemTime, UNIX_EPOCH};
19
20#[cfg(not(test))]
21use std::time::{SystemTime, UNIX_EPOCH};
22
23use async_trait::async_trait;
24use lance_namespace::LanceNamespace;
25use lance_namespace::models::DescribeTableRequest;
26use tokio::sync::RwLock;
27
28use crate::{Error, Result};
29
/// Key under which a storage-options map carries its expiration time.
///
/// The value is parsed as a `u64` and compared against the current time
/// expressed in milliseconds since the Unix epoch.
pub const EXPIRES_AT_MILLIS_KEY: &str = "expires_at_millis";

/// Key under which a storage-options map may carry a custom refresh offset.
///
/// The value is parsed as a `u64` number of milliseconds; options are treated
/// as needing a refresh this long before they expire.
pub const REFRESH_OFFSET_MILLIS_KEY: &str = "refresh_offset_millis";

/// Refresh offset (60 seconds, in milliseconds) used when
/// [`REFRESH_OFFSET_MILLIS_KEY`] is absent or cannot be parsed.
const DEFAULT_REFRESH_OFFSET_MILLIS: u64 = 60_000;
38
/// A source of (possibly short-lived) storage options.
///
/// Implementations are shared behind `Arc<dyn StorageOptionsProvider>` and
/// must therefore be `Send + Sync` and `Debug`.
#[async_trait]
pub trait StorageOptionsProvider: Send + Sync + fmt::Debug {
    /// Fetches a fresh set of storage options.
    ///
    /// Returns `Ok(None)` when the provider has no options to offer; the
    /// accessor in this file then falls back to its initial options or to an
    /// empty map. If the returned map contains [`EXPIRES_AT_MILLIS_KEY`], the
    /// accessor uses that value to decide when to fetch again.
    async fn fetch_storage_options(&self) -> Result<Option<HashMap<String, String>>>;

    /// Returns a string identifying this provider.
    ///
    /// The accessor in this file uses it in log messages and as its own
    /// identifier when a provider is configured.
    fn provider_id(&self) -> String;
}
94
/// [`StorageOptionsProvider`] that obtains storage options by asking a
/// [`LanceNamespace`] to describe a table.
pub struct LanceNamespaceStorageOptionsProvider {
    /// Namespace client used to issue the `describe_table` request.
    namespace_client: Arc<dyn LanceNamespace>,
    /// Identifier of the table, sent as the `id` of the describe request.
    table_id: Vec<String>,
}
100
101impl fmt::Debug for LanceNamespaceStorageOptionsProvider {
102 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
103 write!(f, "{}", self.provider_id())
104 }
105}
106
107impl fmt::Display for LanceNamespaceStorageOptionsProvider {
108 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
109 write!(f, "{}", self.provider_id())
110 }
111}
112
113impl LanceNamespaceStorageOptionsProvider {
114 pub fn new(namespace_client: Arc<dyn LanceNamespace>, table_id: Vec<String>) -> Self {
120 Self {
121 namespace_client,
122 table_id,
123 }
124 }
125}
126
127#[async_trait]
128impl StorageOptionsProvider for LanceNamespaceStorageOptionsProvider {
129 async fn fetch_storage_options(&self) -> Result<Option<HashMap<String, String>>> {
130 let request = DescribeTableRequest {
131 id: Some(self.table_id.clone()),
132 ..Default::default()
133 };
134
135 let response = self
136 .namespace_client
137 .describe_table(request)
138 .await
139 .map_err(|e| {
140 Error::io_source(Box::new(std::io::Error::other(format!(
141 "Failed to fetch storage options: {}",
142 e
143 ))))
144 })?;
145
146 Ok(response.storage_options)
147 }
148
149 fn provider_id(&self) -> String {
150 format!(
151 "LanceNamespaceStorageOptionsProvider {{ namespace_client: {}, table_id: {:?} }}",
152 self.namespace_client.namespace_id(),
153 self.table_id
154 )
155 }
156}
157
/// Hands out storage options, combining an optional set of initial options
/// with an optional [`StorageOptionsProvider`] and caching the latest result.
pub struct StorageOptionsAccessor {
    /// Options supplied at construction time. Returned as a fallback when no
    /// provider is configured or when the provider yields no options.
    initial_options: Option<HashMap<String, String>>,

    /// Optional source of fresh options.
    provider: Option<Arc<dyn StorageOptionsProvider>>,

    /// Most recently stored options together with their parsed expiration,
    /// guarded by an async read/write lock.
    cache: Arc<RwLock<Option<CachedStorageOptions>>>,

    /// How long before their expiration cached options are considered stale.
    refresh_offset: Duration,
}
188
189impl fmt::Debug for StorageOptionsAccessor {
190 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
191 f.debug_struct("StorageOptionsAccessor")
192 .field("has_initial_options", &self.initial_options.is_some())
193 .field("has_provider", &self.provider.is_some())
194 .field("refresh_offset", &self.refresh_offset)
195 .finish()
196 }
197}
198
/// One cached set of storage options plus its parsed expiration.
#[derive(Debug, Clone)]
struct CachedStorageOptions {
    /// The cached key/value options.
    options: HashMap<String, String>,
    /// Expiration parsed from the options, in milliseconds since the Unix
    /// epoch; `None` when the options carry no parseable expiration.
    expires_at_millis: Option<u64>,
}
204
205impl StorageOptionsAccessor {
206 fn extract_refresh_offset(options: &HashMap<String, String>) -> Duration {
208 options
209 .get(REFRESH_OFFSET_MILLIS_KEY)
210 .and_then(|s| s.parse::<u64>().ok())
211 .map(Duration::from_millis)
212 .unwrap_or(Duration::from_millis(DEFAULT_REFRESH_OFFSET_MILLIS))
213 }
214
215 pub fn with_static_options(options: HashMap<String, String>) -> Self {
220 let expires_at_millis = options
221 .get(EXPIRES_AT_MILLIS_KEY)
222 .and_then(|s| s.parse::<u64>().ok());
223 let refresh_offset = Self::extract_refresh_offset(&options);
224
225 Self {
226 initial_options: Some(options.clone()),
227 provider: None,
228 cache: Arc::new(RwLock::new(Some(CachedStorageOptions {
229 options,
230 expires_at_millis,
231 }))),
232 refresh_offset,
233 }
234 }
235
236 pub fn with_provider(provider: Arc<dyn StorageOptionsProvider>) -> Self {
245 Self {
246 initial_options: None,
247 provider: Some(provider),
248 cache: Arc::new(RwLock::new(None)),
249 refresh_offset: Duration::from_millis(DEFAULT_REFRESH_OFFSET_MILLIS),
250 }
251 }
252
253 pub fn with_initial_and_provider(
263 initial_options: HashMap<String, String>,
264 provider: Arc<dyn StorageOptionsProvider>,
265 ) -> Self {
266 let expires_at_millis = initial_options
267 .get(EXPIRES_AT_MILLIS_KEY)
268 .and_then(|s| s.parse::<u64>().ok());
269 let refresh_offset = Self::extract_refresh_offset(&initial_options);
270
271 Self {
272 initial_options: Some(initial_options.clone()),
273 provider: Some(provider),
274 cache: Arc::new(RwLock::new(Some(CachedStorageOptions {
275 options: initial_options,
276 expires_at_millis,
277 }))),
278 refresh_offset,
279 }
280 }
281
282 pub async fn get_storage_options(&self) -> Result<super::StorageOptions> {
294 loop {
295 match self.do_get_storage_options().await? {
296 Some(options) => return Ok(options),
297 None => {
298 tokio::time::sleep(Duration::from_millis(10)).await;
300 continue;
301 }
302 }
303 }
304 }
305
306 pub(crate) async fn refresh_storage_options(&self) -> Result<super::StorageOptions> {
311 let Some(provider) = &self.provider else {
312 return self.get_storage_options().await;
313 };
314
315 log::debug!(
316 "Refreshing storage options from provider: {}",
317 provider.provider_id()
318 );
319
320 let storage_options_map = provider.fetch_storage_options().await.map_err(|e| {
321 Error::io_source(Box::new(std::io::Error::other(format!(
322 "Failed to fetch storage options: {}",
323 e
324 ))))
325 })?;
326
327 let Some(options) = storage_options_map else {
328 if let Some(initial) = &self.initial_options {
329 return Ok(super::StorageOptions(initial.clone()));
330 }
331 log::debug!(
332 "Provider {} returned no storage options, using default credentials",
333 provider.provider_id()
334 );
335 return Ok(super::StorageOptions(HashMap::new()));
336 };
337
338 let expires_at_millis = options
339 .get(EXPIRES_AT_MILLIS_KEY)
340 .and_then(|s| s.parse::<u64>().ok());
341
342 let mut cache = self.cache.write().await;
343 *cache = Some(CachedStorageOptions {
344 options: options.clone(),
345 expires_at_millis,
346 });
347
348 Ok(super::StorageOptions(options))
349 }
350
351 async fn do_get_storage_options(&self) -> Result<Option<super::StorageOptions>> {
352 {
354 let cached = self.cache.read().await;
355 if !self.needs_refresh(&cached)
356 && let Some(cached_opts) = &*cached
357 {
358 return Ok(Some(super::StorageOptions(cached_opts.options.clone())));
359 }
360 }
361
362 let Some(provider) = &self.provider else {
364 return if let Some(initial) = &self.initial_options {
365 Ok(Some(super::StorageOptions(initial.clone())))
366 } else {
367 Ok(Some(super::StorageOptions(HashMap::new())))
369 };
370 };
371
372 let Ok(mut cache) = self.cache.try_write() else {
374 return Ok(None);
375 };
376
377 if !self.needs_refresh(&cache)
380 && let Some(cached_opts) = &*cache
381 {
382 return Ok(Some(super::StorageOptions(cached_opts.options.clone())));
383 }
384 log::debug!(
385 "Refreshing storage options from provider: {}",
386 provider.provider_id()
387 );
388
389 let storage_options_map = provider.fetch_storage_options().await.map_err(|e| {
390 Error::io_source(Box::new(std::io::Error::other(format!(
391 "Failed to fetch storage options: {}",
392 e
393 ))))
394 })?;
395
396 let Some(options) = storage_options_map else {
397 if let Some(initial) = &self.initial_options {
399 return Ok(Some(super::StorageOptions(initial.clone())));
400 }
401 log::debug!(
405 "Provider {} returned no storage options, using default credentials",
406 provider.provider_id()
407 );
408 return Ok(Some(super::StorageOptions(HashMap::new())));
409 };
410
411 let expires_at_millis = options
412 .get(EXPIRES_AT_MILLIS_KEY)
413 .and_then(|s| s.parse::<u64>().ok());
414
415 if let Some(expires_at) = expires_at_millis {
416 let now_ms = SystemTime::now()
417 .duration_since(UNIX_EPOCH)
418 .unwrap_or(Duration::from_secs(0))
419 .as_millis() as u64;
420 let expires_in_secs = (expires_at.saturating_sub(now_ms)) / 1000;
421 log::debug!(
422 "Successfully refreshed storage options from provider: {}, options expire in {} seconds",
423 provider.provider_id(),
424 expires_in_secs
425 );
426 } else {
427 log::debug!(
428 "Successfully refreshed storage options from provider: {} (no expiration)",
429 provider.provider_id()
430 );
431 }
432
433 *cache = Some(CachedStorageOptions {
434 options: options.clone(),
435 expires_at_millis,
436 });
437
438 Ok(Some(super::StorageOptions(options)))
439 }
440
441 fn needs_refresh(&self, cached: &Option<CachedStorageOptions>) -> bool {
442 match cached {
443 None => true,
444 Some(cached_opts) => {
445 if let Some(expires_at_millis) = cached_opts.expires_at_millis {
446 let now_ms = SystemTime::now()
447 .duration_since(UNIX_EPOCH)
448 .unwrap_or(Duration::from_secs(0))
449 .as_millis() as u64;
450
451 let refresh_offset_millis = self.refresh_offset.as_millis() as u64;
453 now_ms + refresh_offset_millis >= expires_at_millis
454 } else {
455 false
457 }
458 }
459 }
460 }
461
/// Returns the options this accessor was constructed with, if any.
pub fn initial_storage_options(&self) -> Option<&HashMap<String, String>> {
    self.initial_options.as_ref()
}
469
470 pub fn accessor_id(&self) -> String {
475 if let Some(provider) = &self.provider {
476 provider.provider_id()
477 } else if let Some(initial) = &self.initial_options {
478 use std::collections::hash_map::DefaultHasher;
480 use std::hash::{Hash, Hasher};
481
482 let mut hasher = DefaultHasher::new();
483 let mut keys: Vec<_> = initial.keys().collect();
484 keys.sort();
485 for key in keys {
486 key.hash(&mut hasher);
487 initial.get(key).hash(&mut hasher);
488 }
489 format!("static_options_{:x}", hasher.finish())
490 } else {
491 "empty_accessor".to_string()
492 }
493 }
494
/// Returns `true` when a provider is configured.
pub fn has_provider(&self) -> bool {
    self.provider.is_some()
}
499
/// Returns how long before expiration cached options are considered stale.
pub fn refresh_offset(&self) -> Duration {
    self.refresh_offset
}
504
/// Returns the configured provider, if any.
pub fn provider(&self) -> Option<&Arc<dyn StorageOptionsProvider>> {
    self.provider.as_ref()
}
509}
510
511#[cfg(test)]
512mod tests {
513 use super::*;
514 use mock_instant::thread_local::MockClock;
515
/// Test double for `StorageOptionsProvider` that counts how often it is asked
/// for options.
#[derive(Debug)]
struct MockStorageOptionsProvider {
    /// Number of fetches performed so far.
    call_count: Arc<RwLock<usize>>,
    /// When set, fetched options expire this many milliseconds after the
    /// current (mock) time; when `None`, they carry no expiration.
    expires_in_millis: Option<u64>,
}
521
522 impl MockStorageOptionsProvider {
523 fn new(expires_in_millis: Option<u64>) -> Self {
524 Self {
525 call_count: Arc::new(RwLock::new(0)),
526 expires_in_millis,
527 }
528 }
529
530 async fn get_call_count(&self) -> usize {
531 *self.call_count.read().await
532 }
533 }
534
535 #[async_trait]
536 impl StorageOptionsProvider for MockStorageOptionsProvider {
537 async fn fetch_storage_options(&self) -> Result<Option<HashMap<String, String>>> {
538 let count = {
539 let mut c = self.call_count.write().await;
540 *c += 1;
541 *c
542 };
543
544 let mut options = HashMap::from([
545 ("aws_access_key_id".to_string(), format!("AKID_{}", count)),
546 (
547 "aws_secret_access_key".to_string(),
548 format!("SECRET_{}", count),
549 ),
550 ("aws_session_token".to_string(), format!("TOKEN_{}", count)),
551 ]);
552
553 if let Some(expires_in) = self.expires_in_millis {
554 let now_ms = SystemTime::now()
555 .duration_since(UNIX_EPOCH)
556 .unwrap()
557 .as_millis() as u64;
558 let expires_at = now_ms + expires_in;
559 options.insert(EXPIRES_AT_MILLIS_KEY.to_string(), expires_at.to_string());
560 }
561
562 Ok(Some(options))
563 }
564
565 fn provider_id(&self) -> String {
566 let ptr = Arc::as_ptr(&self.call_count) as usize;
567 format!("MockStorageOptionsProvider {{ id: {} }}", ptr)
568 }
569 }
570
571 #[tokio::test]
572 async fn test_static_options_only() {
573 let options = HashMap::from([
574 ("key1".to_string(), "value1".to_string()),
575 ("key2".to_string(), "value2".to_string()),
576 ]);
577 let accessor = StorageOptionsAccessor::with_static_options(options.clone());
578
579 let result = accessor.get_storage_options().await.unwrap();
580 assert_eq!(result.0, options);
581 assert!(!accessor.has_provider());
582 assert_eq!(accessor.initial_storage_options(), Some(&options));
583 }
584
585 #[tokio::test]
586 async fn test_provider_only() {
587 MockClock::set_system_time(Duration::from_secs(100_000));
588
589 let mock_provider = Arc::new(MockStorageOptionsProvider::new(Some(600_000)));
590 let accessor = StorageOptionsAccessor::with_provider(mock_provider.clone());
591
592 let result = accessor.get_storage_options().await.unwrap();
593 assert!(result.0.contains_key("aws_access_key_id"));
594 assert_eq!(result.0.get("aws_access_key_id").unwrap(), "AKID_1");
595 assert!(accessor.has_provider());
596 assert_eq!(accessor.initial_storage_options(), None);
597 assert_eq!(mock_provider.get_call_count().await, 1);
598 }
599
600 #[tokio::test]
601 async fn test_initial_and_provider_uses_initial_first() {
602 MockClock::set_system_time(Duration::from_secs(100_000));
603
604 let now_ms = MockClock::system_time().as_millis() as u64;
605 let expires_at = now_ms + 600_000; let initial = HashMap::from([
608 ("aws_access_key_id".to_string(), "INITIAL_KEY".to_string()),
609 (
610 "aws_secret_access_key".to_string(),
611 "INITIAL_SECRET".to_string(),
612 ),
613 (EXPIRES_AT_MILLIS_KEY.to_string(), expires_at.to_string()),
614 ]);
615 let mock_provider = Arc::new(MockStorageOptionsProvider::new(Some(600_000)));
616
617 let accessor = StorageOptionsAccessor::with_initial_and_provider(
618 initial.clone(),
619 mock_provider.clone(),
620 );
621
622 let result = accessor.get_storage_options().await.unwrap();
624 assert_eq!(result.0.get("aws_access_key_id").unwrap(), "INITIAL_KEY");
625 assert_eq!(mock_provider.get_call_count().await, 0); }
627
628 #[tokio::test]
629 async fn test_caching_and_refresh() {
630 MockClock::set_system_time(Duration::from_secs(100_000));
631
632 let mock_provider = Arc::new(MockStorageOptionsProvider::new(Some(600_000))); let now_ms = MockClock::system_time().as_millis() as u64;
635 let expires_at = now_ms + 600_000; let initial = HashMap::from([
637 (EXPIRES_AT_MILLIS_KEY.to_string(), expires_at.to_string()),
638 (REFRESH_OFFSET_MILLIS_KEY.to_string(), "300000".to_string()), ]);
640 let accessor =
641 StorageOptionsAccessor::with_initial_and_provider(initial, mock_provider.clone());
642
643 let result = accessor.get_storage_options().await.unwrap();
645 assert!(result.0.contains_key(EXPIRES_AT_MILLIS_KEY));
646 assert_eq!(mock_provider.get_call_count().await, 0);
647
648 MockClock::set_system_time(Duration::from_secs(100_000 + 360));
650 let result = accessor.get_storage_options().await.unwrap();
651 assert_eq!(result.0.get("aws_access_key_id").unwrap(), "AKID_1");
652 assert_eq!(mock_provider.get_call_count().await, 1);
653 }
654
655 #[tokio::test]
656 async fn test_expired_initial_triggers_refresh() {
657 MockClock::set_system_time(Duration::from_secs(100_000));
658
659 let now_ms = MockClock::system_time().as_millis() as u64;
660 let expired_time = now_ms - 1_000; let initial = HashMap::from([
663 ("aws_access_key_id".to_string(), "EXPIRED_KEY".to_string()),
664 (EXPIRES_AT_MILLIS_KEY.to_string(), expired_time.to_string()),
665 ]);
666 let mock_provider = Arc::new(MockStorageOptionsProvider::new(Some(600_000)));
667
668 let accessor =
669 StorageOptionsAccessor::with_initial_and_provider(initial, mock_provider.clone());
670
671 let result = accessor.get_storage_options().await.unwrap();
673 assert_eq!(result.0.get("aws_access_key_id").unwrap(), "AKID_1");
674 assert_eq!(mock_provider.get_call_count().await, 1);
675 }
676
677 #[tokio::test]
678 async fn test_accessor_id_with_provider() {
679 let mock_provider = Arc::new(MockStorageOptionsProvider::new(None));
680 let accessor = StorageOptionsAccessor::with_provider(mock_provider);
681
682 let id = accessor.accessor_id();
683 assert!(id.starts_with("MockStorageOptionsProvider"));
684 }
685
686 #[tokio::test]
687 async fn test_accessor_id_static() {
688 let options = HashMap::from([("key".to_string(), "value".to_string())]);
689 let accessor = StorageOptionsAccessor::with_static_options(options);
690
691 let id = accessor.accessor_id();
692 assert!(id.starts_with("static_options_"));
693 }
694
695 #[tokio::test]
696 async fn test_concurrent_access() {
697 let mock_provider = Arc::new(MockStorageOptionsProvider::new(Some(9999999999999)));
699
700 let accessor = Arc::new(StorageOptionsAccessor::with_provider(mock_provider.clone()));
701
702 let mut handles = vec![];
704 for i in 0..10 {
705 let acc = accessor.clone();
706 let handle = tokio::spawn(async move {
707 let result = acc.get_storage_options().await.unwrap();
708 assert_eq!(result.0.get("aws_access_key_id").unwrap(), "AKID_1");
709 i
710 });
711 handles.push(handle);
712 }
713
714 let results: Vec<_> = futures::future::join_all(handles)
716 .await
717 .into_iter()
718 .map(|r| r.unwrap())
719 .collect();
720
721 assert_eq!(results.len(), 10);
723
724 let call_count = mock_provider.get_call_count().await;
726 assert_eq!(
727 call_count, 1,
728 "Provider should be called exactly once despite concurrent access"
729 );
730 }
731
732 #[tokio::test]
733 async fn test_no_expiration_never_refreshes() {
734 MockClock::set_system_time(Duration::from_secs(100_000));
735
736 let mock_provider = Arc::new(MockStorageOptionsProvider::new(None)); let accessor = StorageOptionsAccessor::with_provider(mock_provider.clone());
738
739 accessor.get_storage_options().await.unwrap();
741 assert_eq!(mock_provider.get_call_count().await, 1);
742
743 MockClock::set_system_time(Duration::from_secs(200_000));
745
746 accessor.get_storage_options().await.unwrap();
748 assert_eq!(mock_provider.get_call_count().await, 1);
749 }
750}