1use std::collections::HashMap;
13use std::fmt;
14use std::sync::Arc;
15use std::time::Duration;
16
17#[cfg(test)]
18use mock_instant::thread_local::{SystemTime, UNIX_EPOCH};
19
20#[cfg(not(test))]
21use std::time::{SystemTime, UNIX_EPOCH};
22
23use async_trait::async_trait;
24use lance_namespace::models::DescribeTableRequest;
25use lance_namespace::LanceNamespace;
26use snafu::location;
27use tokio::sync::RwLock;
28
29use crate::{Error, Result};
30
/// Storage-option key whose value is parsed as a `u64` and compared against
/// the current time in milliseconds since the Unix epoch to decide whether
/// cached options are due for a refresh.
pub const EXPIRES_AT_MILLIS_KEY: &str = "expires_at_millis";

/// Storage-option key whose value is parsed as a `u64` number of milliseconds
/// and used as the refresh offset, i.e. how long before the expiration time
/// the options are already treated as needing a refresh.
pub const REFRESH_OFFSET_MILLIS_KEY: &str = "refresh_offset_millis";

/// Refresh offset in milliseconds (one minute) used when the options carry no
/// parseable value under `REFRESH_OFFSET_MILLIS_KEY`.
const DEFAULT_REFRESH_OFFSET_MILLIS: u64 = 60 * 1_000;
39
40#[async_trait]
73pub trait StorageOptionsProvider: Send + Sync + fmt::Debug {
74 async fn fetch_storage_options(&self) -> Result<Option<HashMap<String, String>>>;
82
83 fn provider_id(&self) -> String;
94}
95
96pub struct LanceNamespaceStorageOptionsProvider {
98 namespace: Arc<dyn LanceNamespace>,
99 table_id: Vec<String>,
100}
101
102impl fmt::Debug for LanceNamespaceStorageOptionsProvider {
103 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
104 write!(f, "{}", self.provider_id())
105 }
106}
107
108impl fmt::Display for LanceNamespaceStorageOptionsProvider {
109 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
110 write!(f, "{}", self.provider_id())
111 }
112}
113
114impl LanceNamespaceStorageOptionsProvider {
115 pub fn new(namespace: Arc<dyn LanceNamespace>, table_id: Vec<String>) -> Self {
121 Self {
122 namespace,
123 table_id,
124 }
125 }
126}
127
128#[async_trait]
129impl StorageOptionsProvider for LanceNamespaceStorageOptionsProvider {
130 async fn fetch_storage_options(&self) -> Result<Option<HashMap<String, String>>> {
131 let request = DescribeTableRequest {
132 id: Some(self.table_id.clone()),
133 ..Default::default()
134 };
135
136 let response = self
137 .namespace
138 .describe_table(request)
139 .await
140 .map_err(|e| Error::IO {
141 source: Box::new(std::io::Error::other(format!(
142 "Failed to fetch storage options: {}",
143 e
144 ))),
145 location: location!(),
146 })?;
147
148 Ok(response.storage_options)
149 }
150
151 fn provider_id(&self) -> String {
152 format!(
153 "LanceNamespaceStorageOptionsProvider {{ namespace: {}, table_id: {:?} }}",
154 self.namespace.namespace_id(),
155 self.table_id
156 )
157 }
158}
159
160pub struct StorageOptionsAccessor {
178 initial_options: Option<HashMap<String, String>>,
180
181 provider: Option<Arc<dyn StorageOptionsProvider>>,
183
184 cache: Arc<RwLock<Option<CachedStorageOptions>>>,
186
187 refresh_offset: Duration,
189}
190
191impl fmt::Debug for StorageOptionsAccessor {
192 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
193 f.debug_struct("StorageOptionsAccessor")
194 .field("has_initial_options", &self.initial_options.is_some())
195 .field("has_provider", &self.provider.is_some())
196 .field("refresh_offset", &self.refresh_offset)
197 .finish()
198 }
199}
200
/// A set of storage options held in the accessor's cache.
#[derive(Clone, Debug)]
struct CachedStorageOptions {
    /// The cached key/value options.
    options: HashMap<String, String>,
    /// Expiration time in milliseconds since the Unix epoch, parsed from the
    /// options; `None` means the options are never considered expired.
    expires_at_millis: Option<u64>,
}
206
207impl StorageOptionsAccessor {
208 fn extract_refresh_offset(options: &HashMap<String, String>) -> Duration {
210 options
211 .get(REFRESH_OFFSET_MILLIS_KEY)
212 .and_then(|s| s.parse::<u64>().ok())
213 .map(Duration::from_millis)
214 .unwrap_or(Duration::from_millis(DEFAULT_REFRESH_OFFSET_MILLIS))
215 }
216
217 pub fn with_static_options(options: HashMap<String, String>) -> Self {
222 let expires_at_millis = options
223 .get(EXPIRES_AT_MILLIS_KEY)
224 .and_then(|s| s.parse::<u64>().ok());
225 let refresh_offset = Self::extract_refresh_offset(&options);
226
227 Self {
228 initial_options: Some(options.clone()),
229 provider: None,
230 cache: Arc::new(RwLock::new(Some(CachedStorageOptions {
231 options,
232 expires_at_millis,
233 }))),
234 refresh_offset,
235 }
236 }
237
238 pub fn with_provider(provider: Arc<dyn StorageOptionsProvider>) -> Self {
247 Self {
248 initial_options: None,
249 provider: Some(provider),
250 cache: Arc::new(RwLock::new(None)),
251 refresh_offset: Duration::from_millis(DEFAULT_REFRESH_OFFSET_MILLIS),
252 }
253 }
254
255 pub fn with_initial_and_provider(
265 initial_options: HashMap<String, String>,
266 provider: Arc<dyn StorageOptionsProvider>,
267 ) -> Self {
268 let expires_at_millis = initial_options
269 .get(EXPIRES_AT_MILLIS_KEY)
270 .and_then(|s| s.parse::<u64>().ok());
271 let refresh_offset = Self::extract_refresh_offset(&initial_options);
272
273 Self {
274 initial_options: Some(initial_options.clone()),
275 provider: Some(provider),
276 cache: Arc::new(RwLock::new(Some(CachedStorageOptions {
277 options: initial_options,
278 expires_at_millis,
279 }))),
280 refresh_offset,
281 }
282 }
283
284 pub async fn get_storage_options(&self) -> Result<super::StorageOptions> {
296 loop {
297 match self.do_get_storage_options().await? {
298 Some(options) => return Ok(options),
299 None => {
300 tokio::time::sleep(Duration::from_millis(10)).await;
302 continue;
303 }
304 }
305 }
306 }
307
308 async fn do_get_storage_options(&self) -> Result<Option<super::StorageOptions>> {
309 {
311 let cached = self.cache.read().await;
312 if !self.needs_refresh(&cached) {
313 if let Some(cached_opts) = &*cached {
314 return Ok(Some(super::StorageOptions(cached_opts.options.clone())));
315 }
316 }
317 }
318
319 let Some(provider) = &self.provider else {
321 return if let Some(initial) = &self.initial_options {
322 Ok(Some(super::StorageOptions(initial.clone())))
323 } else {
324 Err(Error::IO {
325 source: Box::new(std::io::Error::other("No storage options available")),
326 location: location!(),
327 })
328 };
329 };
330
331 let Ok(mut cache) = self.cache.try_write() else {
333 return Ok(None);
334 };
335
336 if !self.needs_refresh(&cache) {
339 if let Some(cached_opts) = &*cache {
340 return Ok(Some(super::StorageOptions(cached_opts.options.clone())));
341 }
342 }
343
344 log::debug!(
345 "Refreshing storage options from provider: {}",
346 provider.provider_id()
347 );
348
349 let storage_options_map =
350 provider
351 .fetch_storage_options()
352 .await
353 .map_err(|e| Error::IO {
354 source: Box::new(std::io::Error::other(format!(
355 "Failed to fetch storage options: {}",
356 e
357 ))),
358 location: location!(),
359 })?;
360
361 let Some(options) = storage_options_map else {
362 if let Some(initial) = &self.initial_options {
364 return Ok(Some(super::StorageOptions(initial.clone())));
365 }
366 return Err(Error::IO {
367 source: Box::new(std::io::Error::other(
368 "Provider returned no storage options",
369 )),
370 location: location!(),
371 });
372 };
373
374 let expires_at_millis = options
375 .get(EXPIRES_AT_MILLIS_KEY)
376 .and_then(|s| s.parse::<u64>().ok());
377
378 if let Some(expires_at) = expires_at_millis {
379 let now_ms = SystemTime::now()
380 .duration_since(UNIX_EPOCH)
381 .unwrap_or(Duration::from_secs(0))
382 .as_millis() as u64;
383 let expires_in_secs = (expires_at.saturating_sub(now_ms)) / 1000;
384 log::debug!(
385 "Successfully refreshed storage options from provider: {}, options expire in {} seconds",
386 provider.provider_id(),
387 expires_in_secs
388 );
389 } else {
390 log::debug!(
391 "Successfully refreshed storage options from provider: {} (no expiration)",
392 provider.provider_id()
393 );
394 }
395
396 *cache = Some(CachedStorageOptions {
397 options: options.clone(),
398 expires_at_millis,
399 });
400
401 Ok(Some(super::StorageOptions(options)))
402 }
403
404 fn needs_refresh(&self, cached: &Option<CachedStorageOptions>) -> bool {
405 match cached {
406 None => true,
407 Some(cached_opts) => {
408 if let Some(expires_at_millis) = cached_opts.expires_at_millis {
409 let now_ms = SystemTime::now()
410 .duration_since(UNIX_EPOCH)
411 .unwrap_or(Duration::from_secs(0))
412 .as_millis() as u64;
413
414 let refresh_offset_millis = self.refresh_offset.as_millis() as u64;
416 now_ms + refresh_offset_millis >= expires_at_millis
417 } else {
418 false
420 }
421 }
422 }
423 }
424
425 pub fn initial_storage_options(&self) -> Option<&HashMap<String, String>> {
430 self.initial_options.as_ref()
431 }
432
433 pub fn accessor_id(&self) -> String {
438 if let Some(provider) = &self.provider {
439 provider.provider_id()
440 } else if let Some(initial) = &self.initial_options {
441 use std::collections::hash_map::DefaultHasher;
443 use std::hash::{Hash, Hasher};
444
445 let mut hasher = DefaultHasher::new();
446 let mut keys: Vec<_> = initial.keys().collect();
447 keys.sort();
448 for key in keys {
449 key.hash(&mut hasher);
450 initial.get(key).hash(&mut hasher);
451 }
452 format!("static_options_{:x}", hasher.finish())
453 } else {
454 "empty_accessor".to_string()
455 }
456 }
457
458 pub fn has_provider(&self) -> bool {
460 self.provider.is_some()
461 }
462
463 pub fn refresh_offset(&self) -> Duration {
465 self.refresh_offset
466 }
467
468 pub fn provider(&self) -> Option<&Arc<dyn StorageOptionsProvider>> {
470 self.provider.as_ref()
471 }
472}
473
#[cfg(test)]
mod tests {
    use super::*;
    use mock_instant::thread_local::MockClock;

    /// Provider double that hands out numbered credentials and records how
    /// many times it has been asked for them.
    #[derive(Debug)]
    struct MockStorageOptionsProvider {
        /// Number of `fetch_storage_options` calls made so far.
        call_count: Arc<RwLock<usize>>,
        /// When set, each fetched map expires this many milliseconds after
        /// the (mock) current time.
        expires_in_millis: Option<u64>,
    }

    impl MockStorageOptionsProvider {
        fn new(expires_in_millis: Option<u64>) -> Self {
            let call_count = Arc::new(RwLock::new(0));
            Self {
                call_count,
                expires_in_millis,
            }
        }

        async fn get_call_count(&self) -> usize {
            let guard = self.call_count.read().await;
            *guard
        }
    }

    #[async_trait]
    impl StorageOptionsProvider for MockStorageOptionsProvider {
        async fn fetch_storage_options(&self) -> Result<Option<HashMap<String, String>>> {
            // Bump the counter first; the new value is embedded in the
            // credentials so tests can tell which fetch produced them.
            let nth_call = {
                let mut guard = self.call_count.write().await;
                *guard += 1;
                *guard
            };

            let mut fetched = HashMap::new();
            fetched.insert(
                "aws_access_key_id".to_string(),
                format!("AKID_{}", nth_call),
            );
            fetched.insert(
                "aws_secret_access_key".to_string(),
                format!("SECRET_{}", nth_call),
            );
            fetched.insert(
                "aws_session_token".to_string(),
                format!("TOKEN_{}", nth_call),
            );

            if let Some(ttl_millis) = self.expires_in_millis {
                let now_ms = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .unwrap()
                    .as_millis() as u64;
                fetched.insert(
                    EXPIRES_AT_MILLIS_KEY.to_string(),
                    (now_ms + ttl_millis).to_string(),
                );
            }

            Ok(Some(fetched))
        }

        fn provider_id(&self) -> String {
            // The counter's address serves as a per-instance id.
            let ptr = Arc::as_ptr(&self.call_count) as usize;
            format!("MockStorageOptionsProvider {{ id: {} }}", ptr)
        }
    }

    #[tokio::test]
    async fn test_static_options_only() {
        let expected = HashMap::from([
            ("key1".to_string(), "value1".to_string()),
            ("key2".to_string(), "value2".to_string()),
        ]);
        let accessor = StorageOptionsAccessor::with_static_options(expected.clone());

        let fetched = accessor.get_storage_options().await.unwrap();

        assert_eq!(fetched.0, expected);
        assert!(!accessor.has_provider());
        assert_eq!(accessor.initial_storage_options(), Some(&expected));
    }

    #[tokio::test]
    async fn test_provider_only() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let provider = Arc::new(MockStorageOptionsProvider::new(Some(600_000)));
        let accessor = StorageOptionsAccessor::with_provider(provider.clone());

        let fetched = accessor.get_storage_options().await.unwrap();

        assert!(fetched.0.contains_key("aws_access_key_id"));
        assert_eq!(fetched.0.get("aws_access_key_id").unwrap(), "AKID_1");
        assert!(accessor.has_provider());
        assert_eq!(accessor.initial_storage_options(), None);
        assert_eq!(provider.get_call_count().await, 1);
    }

    #[tokio::test]
    async fn test_initial_and_provider_uses_initial_first() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        // Initial options that expire 600 s from the mock "now".
        let now_ms = MockClock::system_time().as_millis() as u64;
        let expires_at = now_ms + 600_000;
        let seed = HashMap::from([
            ("aws_access_key_id".to_string(), "INITIAL_KEY".to_string()),
            (
                "aws_secret_access_key".to_string(),
                "INITIAL_SECRET".to_string(),
            ),
            (EXPIRES_AT_MILLIS_KEY.to_string(), expires_at.to_string()),
        ]);
        let provider = Arc::new(MockStorageOptionsProvider::new(Some(600_000)));

        let accessor =
            StorageOptionsAccessor::with_initial_and_provider(seed.clone(), provider.clone());

        let fetched = accessor.get_storage_options().await.unwrap();

        // Still fresh, so the provider must not have been consulted.
        assert_eq!(fetched.0.get("aws_access_key_id").unwrap(), "INITIAL_KEY");
        assert_eq!(provider.get_call_count().await, 0);
    }

    #[tokio::test]
    async fn test_caching_and_refresh() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let provider = Arc::new(MockStorageOptionsProvider::new(Some(600_000)));

        // Expires in 600 s, with a 300 s refresh offset.
        let now_ms = MockClock::system_time().as_millis() as u64;
        let expires_at = now_ms + 600_000;
        let seed = HashMap::from([
            (EXPIRES_AT_MILLIS_KEY.to_string(), expires_at.to_string()),
            (REFRESH_OFFSET_MILLIS_KEY.to_string(), "300000".to_string()),
        ]);
        let accessor = StorageOptionsAccessor::with_initial_and_provider(seed, provider.clone());

        // First read is served from the seeded cache.
        let first = accessor.get_storage_options().await.unwrap();
        assert!(first.0.contains_key(EXPIRES_AT_MILLIS_KEY));
        assert_eq!(provider.get_call_count().await, 0);

        // 360 s later, now + offset is past the expiration: a refresh is due.
        MockClock::set_system_time(Duration::from_secs(100_000 + 360));
        let second = accessor.get_storage_options().await.unwrap();
        assert_eq!(second.0.get("aws_access_key_id").unwrap(), "AKID_1");
        assert_eq!(provider.get_call_count().await, 1);
    }

    #[tokio::test]
    async fn test_expired_initial_triggers_refresh() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        // Initial options that expired one second before the mock "now".
        let now_ms = MockClock::system_time().as_millis() as u64;
        let expired_time = now_ms - 1_000;
        let seed = HashMap::from([
            ("aws_access_key_id".to_string(), "EXPIRED_KEY".to_string()),
            (EXPIRES_AT_MILLIS_KEY.to_string(), expired_time.to_string()),
        ]);
        let provider = Arc::new(MockStorageOptionsProvider::new(Some(600_000)));

        let accessor = StorageOptionsAccessor::with_initial_and_provider(seed, provider.clone());

        let fetched = accessor.get_storage_options().await.unwrap();

        assert_eq!(fetched.0.get("aws_access_key_id").unwrap(), "AKID_1");
        assert_eq!(provider.get_call_count().await, 1);
    }

    #[tokio::test]
    async fn test_accessor_id_with_provider() {
        let provider = Arc::new(MockStorageOptionsProvider::new(None));
        let accessor = StorageOptionsAccessor::with_provider(provider);

        let accessor_id = accessor.accessor_id();

        assert!(accessor_id.starts_with("MockStorageOptionsProvider"));
    }

    #[tokio::test]
    async fn test_accessor_id_static() {
        let static_options = HashMap::from([("key".to_string(), "value".to_string())]);
        let accessor = StorageOptionsAccessor::with_static_options(static_options);

        let accessor_id = accessor.accessor_id();

        assert!(accessor_id.starts_with("static_options_"));
    }

    #[tokio::test]
    async fn test_concurrent_access() {
        // Expiration far enough away that no second refresh can be triggered.
        let provider = Arc::new(MockStorageOptionsProvider::new(Some(9999999999999)));

        let accessor = Arc::new(StorageOptionsAccessor::with_provider(provider.clone()));

        // Ten tasks all ask for the options; each hands back its own index.
        let handles: Vec<_> = (0..10)
            .map(|i| {
                let shared = accessor.clone();
                tokio::spawn(async move {
                    let fetched = shared.get_storage_options().await.unwrap();
                    assert_eq!(fetched.0.get("aws_access_key_id").unwrap(), "AKID_1");
                    i
                })
            })
            .collect();

        let mut finished = Vec::new();
        for joined in futures::future::join_all(handles).await {
            finished.push(joined.unwrap());
        }

        assert_eq!(finished.len(), 10);

        let call_count = provider.get_call_count().await;
        assert_eq!(
            call_count, 1,
            "Provider should be called exactly once despite concurrent access"
        );
    }

    #[tokio::test]
    async fn test_no_expiration_never_refreshes() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        // The provider never attaches an expiration to what it returns.
        let provider = Arc::new(MockStorageOptionsProvider::new(None));
        let accessor = StorageOptionsAccessor::with_provider(provider.clone());

        accessor.get_storage_options().await.unwrap();
        assert_eq!(provider.get_call_count().await, 1);

        // Moving the clock far ahead must not cause another fetch.
        MockClock::set_system_time(Duration::from_secs(200_000));

        accessor.get_storage_options().await.unwrap();
        assert_eq!(provider.get_call_count().await, 1);
    }
}