1use std::collections::HashMap;
13use std::fmt;
14use std::sync::Arc;
15use std::time::Duration;
16
17#[cfg(test)]
18use mock_instant::thread_local::{SystemTime, UNIX_EPOCH};
19
20#[cfg(not(test))]
21use std::time::{SystemTime, UNIX_EPOCH};
22
23use async_trait::async_trait;
24use lance_namespace::LanceNamespace;
25use lance_namespace::models::DescribeTableRequest;
26use tokio::sync::RwLock;
27
28use crate::{Error, Result};
29
/// Storage-option key whose value is the absolute expiration time of the options,
/// expressed in milliseconds since the Unix epoch.
pub const EXPIRES_AT_MILLIS_KEY: &str = "expires_at_millis";

/// Storage-option key whose value is how many milliseconds before expiration the
/// options should be considered stale and refreshed.
pub const REFRESH_OFFSET_MILLIS_KEY: &str = "refresh_offset_millis";

/// Refresh offset (in milliseconds) used when no valid
/// [`REFRESH_OFFSET_MILLIS_KEY`] value is available.
const DEFAULT_REFRESH_OFFSET_MILLIS: u64 = 60 * 1_000;
38
/// A source of storage options that can be re-fetched on demand.
///
/// A [`StorageOptionsAccessor`] calls the provider whenever its cached options
/// are missing or close to the time given by [`EXPIRES_AT_MILLIS_KEY`].
#[async_trait]
pub trait StorageOptionsProvider: Send + Sync + fmt::Debug {
    /// Fetches the current storage options.
    ///
    /// Returning `Ok(None)` means the provider has no options to offer. In that
    /// case [`StorageOptionsAccessor`] falls back to its initial options, or
    /// reports an error if it has none. A returned map may include
    /// [`EXPIRES_AT_MILLIS_KEY`], which bounds how long the accessor caches it.
    async fn fetch_storage_options(&self) -> Result<Option<HashMap<String, String>>>;

    /// Returns a human-readable identifier for this provider.
    ///
    /// The identifier appears in log messages, and it is also returned as the
    /// accessor id of a [`StorageOptionsAccessor`] that wraps this provider.
    fn provider_id(&self) -> String;
}
94
/// [`StorageOptionsProvider`] that obtains storage options for a single table.
///
/// It calls `describe_table` on a [`LanceNamespace`] and returns the storage
/// options carried in the response.
pub struct LanceNamespaceStorageOptionsProvider {
    /// Namespace that is asked to describe the table.
    namespace: Arc<dyn LanceNamespace>,
    /// Identifier of the table, sent as the `id` of each `DescribeTableRequest`.
    table_id: Vec<String>,
}
100
101impl fmt::Debug for LanceNamespaceStorageOptionsProvider {
102 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
103 write!(f, "{}", self.provider_id())
104 }
105}
106
107impl fmt::Display for LanceNamespaceStorageOptionsProvider {
108 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
109 write!(f, "{}", self.provider_id())
110 }
111}
112
113impl LanceNamespaceStorageOptionsProvider {
114 pub fn new(namespace: Arc<dyn LanceNamespace>, table_id: Vec<String>) -> Self {
120 Self {
121 namespace,
122 table_id,
123 }
124 }
125}
126
127#[async_trait]
128impl StorageOptionsProvider for LanceNamespaceStorageOptionsProvider {
129 async fn fetch_storage_options(&self) -> Result<Option<HashMap<String, String>>> {
130 let request = DescribeTableRequest {
131 id: Some(self.table_id.clone()),
132 ..Default::default()
133 };
134
135 let response = self.namespace.describe_table(request).await.map_err(|e| {
136 Error::io_source(Box::new(std::io::Error::other(format!(
137 "Failed to fetch storage options: {}",
138 e
139 ))))
140 })?;
141
142 Ok(response.storage_options)
143 }
144
145 fn provider_id(&self) -> String {
146 format!(
147 "LanceNamespaceStorageOptionsProvider {{ namespace: {}, table_id: {:?} }}",
148 self.namespace.namespace_id(),
149 self.table_id
150 )
151 }
152}
153
/// Hands out storage options and caches the most recently obtained set.
///
/// Options can come from initial (static) options, from a
/// [`StorageOptionsProvider`], or from both.
///
/// If the cached options carry an [`EXPIRES_AT_MILLIS_KEY`] entry, they become
/// stale once the current time is within `refresh_offset` of that expiration.
/// At that point they are re-fetched from the provider, if one is configured.
pub struct StorageOptionsAccessor {
    /// Options supplied at construction time.
    ///
    /// They are used as a fallback when there is no provider, or when the
    /// provider returns no options.
    initial_options: Option<HashMap<String, String>>,

    /// Source of refreshed options, if any.
    provider: Option<Arc<dyn StorageOptionsProvider>>,

    /// Most recently obtained options together with their parsed expiration.
    /// Shared behind an async (tokio) read/write lock.
    cache: Arc<RwLock<Option<CachedStorageOptions>>>,

    /// How long before expiration cached options are treated as stale.
    refresh_offset: Duration,
}
184
185impl fmt::Debug for StorageOptionsAccessor {
186 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
187 f.debug_struct("StorageOptionsAccessor")
188 .field("has_initial_options", &self.initial_options.is_some())
189 .field("has_provider", &self.provider.is_some())
190 .field("refresh_offset", &self.refresh_offset)
191 .finish()
192 }
193}
194
/// Snapshot of storage options held in a [`StorageOptionsAccessor`]'s cache.
#[derive(Debug, Clone)]
struct CachedStorageOptions {
    /// The cached option map.
    options: HashMap<String, String>,
    /// Parsed value of [`EXPIRES_AT_MILLIS_KEY`], in milliseconds since the Unix epoch.
    ///
    /// `None` when the key is absent or not a valid `u64`; such an entry is never
    /// considered stale.
    expires_at_millis: Option<u64>,
}
200
201impl StorageOptionsAccessor {
202 fn extract_refresh_offset(options: &HashMap<String, String>) -> Duration {
204 options
205 .get(REFRESH_OFFSET_MILLIS_KEY)
206 .and_then(|s| s.parse::<u64>().ok())
207 .map(Duration::from_millis)
208 .unwrap_or(Duration::from_millis(DEFAULT_REFRESH_OFFSET_MILLIS))
209 }
210
211 pub fn with_static_options(options: HashMap<String, String>) -> Self {
216 let expires_at_millis = options
217 .get(EXPIRES_AT_MILLIS_KEY)
218 .and_then(|s| s.parse::<u64>().ok());
219 let refresh_offset = Self::extract_refresh_offset(&options);
220
221 Self {
222 initial_options: Some(options.clone()),
223 provider: None,
224 cache: Arc::new(RwLock::new(Some(CachedStorageOptions {
225 options,
226 expires_at_millis,
227 }))),
228 refresh_offset,
229 }
230 }
231
232 pub fn with_provider(provider: Arc<dyn StorageOptionsProvider>) -> Self {
241 Self {
242 initial_options: None,
243 provider: Some(provider),
244 cache: Arc::new(RwLock::new(None)),
245 refresh_offset: Duration::from_millis(DEFAULT_REFRESH_OFFSET_MILLIS),
246 }
247 }
248
249 pub fn with_initial_and_provider(
259 initial_options: HashMap<String, String>,
260 provider: Arc<dyn StorageOptionsProvider>,
261 ) -> Self {
262 let expires_at_millis = initial_options
263 .get(EXPIRES_AT_MILLIS_KEY)
264 .and_then(|s| s.parse::<u64>().ok());
265 let refresh_offset = Self::extract_refresh_offset(&initial_options);
266
267 Self {
268 initial_options: Some(initial_options.clone()),
269 provider: Some(provider),
270 cache: Arc::new(RwLock::new(Some(CachedStorageOptions {
271 options: initial_options,
272 expires_at_millis,
273 }))),
274 refresh_offset,
275 }
276 }
277
278 pub async fn get_storage_options(&self) -> Result<super::StorageOptions> {
290 loop {
291 match self.do_get_storage_options().await? {
292 Some(options) => return Ok(options),
293 None => {
294 tokio::time::sleep(Duration::from_millis(10)).await;
296 continue;
297 }
298 }
299 }
300 }
301
302 async fn do_get_storage_options(&self) -> Result<Option<super::StorageOptions>> {
303 {
305 let cached = self.cache.read().await;
306 if !self.needs_refresh(&cached)
307 && let Some(cached_opts) = &*cached
308 {
309 return Ok(Some(super::StorageOptions(cached_opts.options.clone())));
310 }
311 }
312
313 let Some(provider) = &self.provider else {
315 return if let Some(initial) = &self.initial_options {
316 Ok(Some(super::StorageOptions(initial.clone())))
317 } else {
318 Err(Error::io_source(Box::new(std::io::Error::other(
319 "No storage options available",
320 ))))
321 };
322 };
323
324 let Ok(mut cache) = self.cache.try_write() else {
326 return Ok(None);
327 };
328
329 if !self.needs_refresh(&cache)
332 && let Some(cached_opts) = &*cache
333 {
334 return Ok(Some(super::StorageOptions(cached_opts.options.clone())));
335 }
336
337 log::debug!(
338 "Refreshing storage options from provider: {}",
339 provider.provider_id()
340 );
341
342 let storage_options_map = provider.fetch_storage_options().await.map_err(|e| {
343 Error::io_source(Box::new(std::io::Error::other(format!(
344 "Failed to fetch storage options: {}",
345 e
346 ))))
347 })?;
348
349 let Some(options) = storage_options_map else {
350 if let Some(initial) = &self.initial_options {
352 return Ok(Some(super::StorageOptions(initial.clone())));
353 }
354 return Err(Error::io_source(Box::new(std::io::Error::other(
355 "Provider returned no storage options",
356 ))));
357 };
358
359 let expires_at_millis = options
360 .get(EXPIRES_AT_MILLIS_KEY)
361 .and_then(|s| s.parse::<u64>().ok());
362
363 if let Some(expires_at) = expires_at_millis {
364 let now_ms = SystemTime::now()
365 .duration_since(UNIX_EPOCH)
366 .unwrap_or(Duration::from_secs(0))
367 .as_millis() as u64;
368 let expires_in_secs = (expires_at.saturating_sub(now_ms)) / 1000;
369 log::debug!(
370 "Successfully refreshed storage options from provider: {}, options expire in {} seconds",
371 provider.provider_id(),
372 expires_in_secs
373 );
374 } else {
375 log::debug!(
376 "Successfully refreshed storage options from provider: {} (no expiration)",
377 provider.provider_id()
378 );
379 }
380
381 *cache = Some(CachedStorageOptions {
382 options: options.clone(),
383 expires_at_millis,
384 });
385
386 Ok(Some(super::StorageOptions(options)))
387 }
388
389 fn needs_refresh(&self, cached: &Option<CachedStorageOptions>) -> bool {
390 match cached {
391 None => true,
392 Some(cached_opts) => {
393 if let Some(expires_at_millis) = cached_opts.expires_at_millis {
394 let now_ms = SystemTime::now()
395 .duration_since(UNIX_EPOCH)
396 .unwrap_or(Duration::from_secs(0))
397 .as_millis() as u64;
398
399 let refresh_offset_millis = self.refresh_offset.as_millis() as u64;
401 now_ms + refresh_offset_millis >= expires_at_millis
402 } else {
403 false
405 }
406 }
407 }
408 }
409
410 pub fn initial_storage_options(&self) -> Option<&HashMap<String, String>> {
415 self.initial_options.as_ref()
416 }
417
418 pub fn accessor_id(&self) -> String {
423 if let Some(provider) = &self.provider {
424 provider.provider_id()
425 } else if let Some(initial) = &self.initial_options {
426 use std::collections::hash_map::DefaultHasher;
428 use std::hash::{Hash, Hasher};
429
430 let mut hasher = DefaultHasher::new();
431 let mut keys: Vec<_> = initial.keys().collect();
432 keys.sort();
433 for key in keys {
434 key.hash(&mut hasher);
435 initial.get(key).hash(&mut hasher);
436 }
437 format!("static_options_{:x}", hasher.finish())
438 } else {
439 "empty_accessor".to_string()
440 }
441 }
442
443 pub fn has_provider(&self) -> bool {
445 self.provider.is_some()
446 }
447
448 pub fn refresh_offset(&self) -> Duration {
450 self.refresh_offset
451 }
452
453 pub fn provider(&self) -> Option<&Arc<dyn StorageOptionsProvider>> {
455 self.provider.as_ref()
456 }
457}
458
#[cfg(test)]
mod tests {
    use super::*;
    use mock_instant::thread_local::MockClock;

    /// Test provider that counts fetches and returns AWS-style options numbered
    /// by the fetch count (first fetch yields `AKID_1`, `SECRET_1`, `TOKEN_1`).
    #[derive(Debug)]
    struct MockStorageOptionsProvider {
        /// Number of `fetch_storage_options` calls made so far.
        call_count: Arc<RwLock<usize>>,
        /// When set, each fetch adds an `EXPIRES_AT_MILLIS_KEY` entry equal to
        /// the (mocked) current time plus this many milliseconds.
        expires_in_millis: Option<u64>,
    }

    impl MockStorageOptionsProvider {
        fn new(expires_in_millis: Option<u64>) -> Self {
            Self {
                call_count: Arc::new(RwLock::new(0)),
                expires_in_millis,
            }
        }

        async fn get_call_count(&self) -> usize {
            *self.call_count.read().await
        }
    }

    #[async_trait]
    impl StorageOptionsProvider for MockStorageOptionsProvider {
        async fn fetch_storage_options(&self) -> Result<Option<HashMap<String, String>>> {
            // Increment before building the options so the first fetch is numbered 1.
            let count = {
                let mut c = self.call_count.write().await;
                *c += 1;
                *c
            };

            let mut options = HashMap::from([
                ("aws_access_key_id".to_string(), format!("AKID_{}", count)),
                (
                    "aws_secret_access_key".to_string(),
                    format!("SECRET_{}", count),
                ),
                ("aws_session_token".to_string(), format!("TOKEN_{}", count)),
            ]);

            if let Some(expires_in) = self.expires_in_millis {
                // Under cfg(test), SystemTime is mock_instant's, so this reads the mock clock.
                let now_ms = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .unwrap()
                    .as_millis() as u64;
                let expires_at = now_ms + expires_in;
                options.insert(EXPIRES_AT_MILLIS_KEY.to_string(), expires_at.to_string());
            }

            Ok(Some(options))
        }

        fn provider_id(&self) -> String {
            // The counter's heap address makes the id distinct per mock instance.
            let ptr = Arc::as_ptr(&self.call_count) as usize;
            format!("MockStorageOptionsProvider {{ id: {} }}", ptr)
        }
    }

    /// Static-only accessors return exactly the supplied options.
    #[tokio::test]
    async fn test_static_options_only() {
        let options = HashMap::from([
            ("key1".to_string(), "value1".to_string()),
            ("key2".to_string(), "value2".to_string()),
        ]);
        let accessor = StorageOptionsAccessor::with_static_options(options.clone());

        let result = accessor.get_storage_options().await.unwrap();
        assert_eq!(result.0, options);
        assert!(!accessor.has_provider());
        assert_eq!(accessor.initial_storage_options(), Some(&options));
    }

    /// A provider-only accessor fetches once on first use.
    #[tokio::test]
    async fn test_provider_only() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let mock_provider = Arc::new(MockStorageOptionsProvider::new(Some(600_000)));
        let accessor = StorageOptionsAccessor::with_provider(mock_provider.clone());

        let result = accessor.get_storage_options().await.unwrap();
        assert!(result.0.contains_key("aws_access_key_id"));
        assert_eq!(result.0.get("aws_access_key_id").unwrap(), "AKID_1");
        assert!(accessor.has_provider());
        assert_eq!(accessor.initial_storage_options(), None);
        assert_eq!(mock_provider.get_call_count().await, 1);
    }

    /// Fresh initial options are served without calling the provider.
    #[tokio::test]
    async fn test_initial_and_provider_uses_initial_first() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        // Initial options expire in 10 minutes: well outside the default 1-minute offset.
        let now_ms = MockClock::system_time().as_millis() as u64;
        let expires_at = now_ms + 600_000;
        let initial = HashMap::from([
            ("aws_access_key_id".to_string(), "INITIAL_KEY".to_string()),
            (
                "aws_secret_access_key".to_string(),
                "INITIAL_SECRET".to_string(),
            ),
            (EXPIRES_AT_MILLIS_KEY.to_string(), expires_at.to_string()),
        ]);
        let mock_provider = Arc::new(MockStorageOptionsProvider::new(Some(600_000)));

        let accessor = StorageOptionsAccessor::with_initial_and_provider(
            initial.clone(),
            mock_provider.clone(),
        );

        let result = accessor.get_storage_options().await.unwrap();
        assert_eq!(result.0.get("aws_access_key_id").unwrap(), "INITIAL_KEY");
        assert_eq!(mock_provider.get_call_count().await, 0);
    }

    /// Initial options are cached until they fall within the refresh offset,
    /// after which the provider is called.
    #[tokio::test]
    async fn test_caching_and_refresh() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let mock_provider = Arc::new(MockStorageOptionsProvider::new(Some(600_000)));
        let now_ms = MockClock::system_time().as_millis() as u64;
        let expires_at = now_ms + 600_000;
        // Expire in 600 s and refresh 300 s early.
        let initial = HashMap::from([
            (EXPIRES_AT_MILLIS_KEY.to_string(), expires_at.to_string()),
            (REFRESH_OFFSET_MILLIS_KEY.to_string(), "300000".to_string()),
        ]);
        let accessor =
            StorageOptionsAccessor::with_initial_and_provider(initial, mock_provider.clone());

        let result = accessor.get_storage_options().await.unwrap();
        assert!(result.0.contains_key(EXPIRES_AT_MILLIS_KEY));
        assert_eq!(mock_provider.get_call_count().await, 0);

        // After 360 s, now + 300 s lies past the expiry, so a refresh is due.
        MockClock::set_system_time(Duration::from_secs(100_000 + 360));
        let result = accessor.get_storage_options().await.unwrap();
        assert_eq!(result.0.get("aws_access_key_id").unwrap(), "AKID_1");
        assert_eq!(mock_provider.get_call_count().await, 1);
    }

    /// Already-expired initial options trigger an immediate provider fetch.
    #[tokio::test]
    async fn test_expired_initial_triggers_refresh() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let now_ms = MockClock::system_time().as_millis() as u64;
        let expired_time = now_ms - 1_000;
        let initial = HashMap::from([
            ("aws_access_key_id".to_string(), "EXPIRED_KEY".to_string()),
            (EXPIRES_AT_MILLIS_KEY.to_string(), expired_time.to_string()),
        ]);
        let mock_provider = Arc::new(MockStorageOptionsProvider::new(Some(600_000)));

        let accessor =
            StorageOptionsAccessor::with_initial_and_provider(initial, mock_provider.clone());

        let result = accessor.get_storage_options().await.unwrap();
        assert_eq!(result.0.get("aws_access_key_id").unwrap(), "AKID_1");
        assert_eq!(mock_provider.get_call_count().await, 1);
    }

    /// With a provider, the accessor id is the provider id.
    #[tokio::test]
    async fn test_accessor_id_with_provider() {
        let mock_provider = Arc::new(MockStorageOptionsProvider::new(None));
        let accessor = StorageOptionsAccessor::with_provider(mock_provider);

        let id = accessor.accessor_id();
        assert!(id.starts_with("MockStorageOptionsProvider"));
    }

    /// Static-only accessors get a hash-based id.
    #[tokio::test]
    async fn test_accessor_id_static() {
        let options = HashMap::from([("key".to_string(), "value".to_string())]);
        let accessor = StorageOptionsAccessor::with_static_options(options);

        let id = accessor.accessor_id();
        assert!(id.starts_with("static_options_"));
    }

    /// Concurrent callers on an empty cache share a single provider fetch.
    #[tokio::test]
    async fn test_concurrent_access() {
        // The mock clock is not set here; the expiry is far enough ahead that the
        // first fetched options stay fresh for the whole test.
        let mock_provider = Arc::new(MockStorageOptionsProvider::new(Some(9999999999999)));

        let accessor = Arc::new(StorageOptionsAccessor::with_provider(mock_provider.clone()));

        let mut handles = vec![];
        for i in 0..10 {
            let acc = accessor.clone();
            let handle = tokio::spawn(async move {
                let result = acc.get_storage_options().await.unwrap();
                assert_eq!(result.0.get("aws_access_key_id").unwrap(), "AKID_1");
                i
            });
            handles.push(handle);
        }

        let results: Vec<_> = futures::future::join_all(handles)
            .await
            .into_iter()
            .map(|r| r.unwrap())
            .collect();

        assert_eq!(results.len(), 10);

        let call_count = mock_provider.get_call_count().await;
        assert_eq!(
            call_count, 1,
            "Provider should be called exactly once despite concurrent access"
        );
    }

    /// Options without an expiration are never refreshed, however much time passes.
    #[tokio::test]
    async fn test_no_expiration_never_refreshes() {
        MockClock::set_system_time(Duration::from_secs(100_000));

        let mock_provider = Arc::new(MockStorageOptionsProvider::new(None));
        let accessor = StorageOptionsAccessor::with_provider(mock_provider.clone());

        accessor.get_storage_options().await.unwrap();
        assert_eq!(mock_provider.get_call_count().await, 1);

        MockClock::set_system_time(Duration::from_secs(200_000));

        accessor.get_storage_options().await.unwrap();
        assert_eq!(mock_provider.get_call_count().await, 1);
    }
}