use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;
use std::time::Duration;
#[cfg(test)]
use mock_instant::thread_local::{SystemTime, UNIX_EPOCH};
#[cfg(not(test))]
use std::time::{SystemTime, UNIX_EPOCH};
use async_trait::async_trait;
use lance_namespace::LanceNamespace;
use lance_namespace::models::DescribeTableRequest;
use tokio::sync::RwLock;
use crate::{Error, Result};
/// Storage-option key holding the expiry time of the options.
///
/// The value is parsed as a `u64` and compared against the current time expressed in
/// milliseconds since the Unix epoch. A missing or unparseable value is treated as
/// "no expiration".
pub const EXPIRES_AT_MILLIS_KEY: &str = "expires_at_millis";
/// Storage-option key holding how long before expiry (in milliseconds) a refresh is triggered.
///
/// The value is parsed as a `u64`; a missing or unparseable value falls back to
/// `DEFAULT_REFRESH_OFFSET_MILLIS`.
pub const REFRESH_OFFSET_MILLIS_KEY: &str = "refresh_offset_millis";
/// Refresh offset used when the options do not carry a usable `refresh_offset_millis` entry
/// (60 000 ms).
const DEFAULT_REFRESH_OFFSET_MILLIS: u64 = 60_000;
/// A source of storage options that can be queried repeatedly.
///
/// `StorageOptionsAccessor` calls into this trait whenever its cached options are missing
/// or about to expire.
#[async_trait]
pub trait StorageOptionsProvider: Send + Sync + fmt::Debug {
/// Fetches the current storage options.
///
/// Returns `Ok(None)` when the provider has no options to offer. The returned map may
/// contain an `expires_at_millis` entry, which the accessor uses to decide when to
/// fetch again.
async fn fetch_storage_options(&self) -> Result<Option<HashMap<String, String>>>;
/// Returns a string identifying this provider.
///
/// The accessor uses it in log messages and as its own `accessor_id` when a provider
/// is configured.
fn provider_id(&self) -> String;
}
/// A `StorageOptionsProvider` that obtains storage options by describing a table through a
/// `LanceNamespace`.
pub struct LanceNamespaceStorageOptionsProvider {
// Namespace client used to issue the `describe_table` request.
namespace: Arc<dyn LanceNamespace>,
// Identifier of the table, sent as the `id` field of the request.
table_id: Vec<String>,
}
impl fmt::Debug for LanceNamespaceStorageOptionsProvider {
    /// Writes the provider id string as the debug representation.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let id = self.provider_id();
        f.write_str(&id)
    }
}
impl fmt::Display for LanceNamespaceStorageOptionsProvider {
    /// Writes the provider id string as the user-facing representation.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let id = self.provider_id();
        f.write_str(&id)
    }
}
impl LanceNamespaceStorageOptionsProvider {
    /// Creates a provider that describes `table_id` through `namespace` each time storage
    /// options are fetched.
    pub fn new(namespace: Arc<dyn LanceNamespace>, table_id: Vec<String>) -> Self {
        Self { namespace, table_id }
    }
}
#[async_trait]
impl StorageOptionsProvider for LanceNamespaceStorageOptionsProvider {
async fn fetch_storage_options(&self) -> Result<Option<HashMap<String, String>>> {
let request = DescribeTableRequest {
id: Some(self.table_id.clone()),
..Default::default()
};
let response = self.namespace.describe_table(request).await.map_err(|e| {
Error::io_source(Box::new(std::io::Error::other(format!(
"Failed to fetch storage options: {}",
e
))))
})?;
Ok(response.storage_options)
}
fn provider_id(&self) -> String {
format!(
"LanceNamespaceStorageOptionsProvider {{ namespace: {}, table_id: {:?} }}",
self.namespace.namespace_id(),
self.table_id
)
}
}
/// Hands out storage options, caching them and refreshing them from an optional
/// `StorageOptionsProvider` shortly before they expire.
pub struct StorageOptionsAccessor {
// Options supplied at construction time; used as a fallback when there is no provider
// or the provider returns nothing.
initial_options: Option<HashMap<String, String>>,
// Source of refreshed options; `None` for purely static accessors.
provider: Option<Arc<dyn StorageOptionsProvider>>,
// Most recently stored options together with their parsed expiry.
cache: Arc<RwLock<Option<CachedStorageOptions>>>,
// How long before expiry the cached options are considered stale. Fixed at construction.
refresh_offset: Duration,
}
impl fmt::Debug for StorageOptionsAccessor {
    /// Prints only whether initial options and a provider are present, plus the refresh
    /// offset; the option values themselves are never written out.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let has_initial_options = self.initial_options.is_some();
        let has_provider = self.provider.is_some();

        let mut builder = f.debug_struct("StorageOptionsAccessor");
        builder.field("has_initial_options", &has_initial_options);
        builder.field("has_provider", &has_provider);
        builder.field("refresh_offset", &self.refresh_offset);
        builder.finish()
    }
}
/// A set of storage options held in the accessor's cache.
#[derive(Debug, Clone)]
struct CachedStorageOptions {
// The option key/value pairs as received.
options: HashMap<String, String>,
// Expiry parsed from the `expires_at_millis` entry; `None` means the entry never
// triggers a refresh.
expires_at_millis: Option<u64>,
}
impl StorageOptionsAccessor {
fn extract_refresh_offset(options: &HashMap<String, String>) -> Duration {
options
.get(REFRESH_OFFSET_MILLIS_KEY)
.and_then(|s| s.parse::<u64>().ok())
.map(Duration::from_millis)
.unwrap_or(Duration::from_millis(DEFAULT_REFRESH_OFFSET_MILLIS))
}
pub fn with_static_options(options: HashMap<String, String>) -> Self {
let expires_at_millis = options
.get(EXPIRES_AT_MILLIS_KEY)
.and_then(|s| s.parse::<u64>().ok());
let refresh_offset = Self::extract_refresh_offset(&options);
Self {
initial_options: Some(options.clone()),
provider: None,
cache: Arc::new(RwLock::new(Some(CachedStorageOptions {
options,
expires_at_millis,
}))),
refresh_offset,
}
}
pub fn with_provider(provider: Arc<dyn StorageOptionsProvider>) -> Self {
Self {
initial_options: None,
provider: Some(provider),
cache: Arc::new(RwLock::new(None)),
refresh_offset: Duration::from_millis(DEFAULT_REFRESH_OFFSET_MILLIS),
}
}
pub fn with_initial_and_provider(
initial_options: HashMap<String, String>,
provider: Arc<dyn StorageOptionsProvider>,
) -> Self {
let expires_at_millis = initial_options
.get(EXPIRES_AT_MILLIS_KEY)
.and_then(|s| s.parse::<u64>().ok());
let refresh_offset = Self::extract_refresh_offset(&initial_options);
Self {
initial_options: Some(initial_options.clone()),
provider: Some(provider),
cache: Arc::new(RwLock::new(Some(CachedStorageOptions {
options: initial_options,
expires_at_millis,
}))),
refresh_offset,
}
}
pub async fn get_storage_options(&self) -> Result<super::StorageOptions> {
loop {
match self.do_get_storage_options().await? {
Some(options) => return Ok(options),
None => {
tokio::time::sleep(Duration::from_millis(10)).await;
continue;
}
}
}
}
async fn do_get_storage_options(&self) -> Result<Option<super::StorageOptions>> {
{
let cached = self.cache.read().await;
if !self.needs_refresh(&cached)
&& let Some(cached_opts) = &*cached
{
return Ok(Some(super::StorageOptions(cached_opts.options.clone())));
}
}
let Some(provider) = &self.provider else {
return if let Some(initial) = &self.initial_options {
Ok(Some(super::StorageOptions(initial.clone())))
} else {
Err(Error::io_source(Box::new(std::io::Error::other(
"No storage options available",
))))
};
};
let Ok(mut cache) = self.cache.try_write() else {
return Ok(None);
};
if !self.needs_refresh(&cache)
&& let Some(cached_opts) = &*cache
{
return Ok(Some(super::StorageOptions(cached_opts.options.clone())));
}
log::debug!(
"Refreshing storage options from provider: {}",
provider.provider_id()
);
let storage_options_map = provider.fetch_storage_options().await.map_err(|e| {
Error::io_source(Box::new(std::io::Error::other(format!(
"Failed to fetch storage options: {}",
e
))))
})?;
let Some(options) = storage_options_map else {
if let Some(initial) = &self.initial_options {
return Ok(Some(super::StorageOptions(initial.clone())));
}
return Err(Error::io_source(Box::new(std::io::Error::other(
"Provider returned no storage options",
))));
};
let expires_at_millis = options
.get(EXPIRES_AT_MILLIS_KEY)
.and_then(|s| s.parse::<u64>().ok());
if let Some(expires_at) = expires_at_millis {
let now_ms = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or(Duration::from_secs(0))
.as_millis() as u64;
let expires_in_secs = (expires_at.saturating_sub(now_ms)) / 1000;
log::debug!(
"Successfully refreshed storage options from provider: {}, options expire in {} seconds",
provider.provider_id(),
expires_in_secs
);
} else {
log::debug!(
"Successfully refreshed storage options from provider: {} (no expiration)",
provider.provider_id()
);
}
*cache = Some(CachedStorageOptions {
options: options.clone(),
expires_at_millis,
});
Ok(Some(super::StorageOptions(options)))
}
fn needs_refresh(&self, cached: &Option<CachedStorageOptions>) -> bool {
match cached {
None => true,
Some(cached_opts) => {
if let Some(expires_at_millis) = cached_opts.expires_at_millis {
let now_ms = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or(Duration::from_secs(0))
.as_millis() as u64;
let refresh_offset_millis = self.refresh_offset.as_millis() as u64;
now_ms + refresh_offset_millis >= expires_at_millis
} else {
false
}
}
}
}
pub fn initial_storage_options(&self) -> Option<&HashMap<String, String>> {
self.initial_options.as_ref()
}
pub fn accessor_id(&self) -> String {
if let Some(provider) = &self.provider {
provider.provider_id()
} else if let Some(initial) = &self.initial_options {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
let mut keys: Vec<_> = initial.keys().collect();
keys.sort();
for key in keys {
key.hash(&mut hasher);
initial.get(key).hash(&mut hasher);
}
format!("static_options_{:x}", hasher.finish())
} else {
"empty_accessor".to_string()
}
}
pub fn has_provider(&self) -> bool {
self.provider.is_some()
}
pub fn refresh_offset(&self) -> Duration {
self.refresh_offset
}
pub fn provider(&self) -> Option<&Arc<dyn StorageOptionsProvider>> {
self.provider.as_ref()
}
}
// Unit tests for `StorageOptionsAccessor`. Time is driven through `MockClock`, which backs
// the `SystemTime` alias used by this file in test builds.
#[cfg(test)]
mod tests {
use super::*;
use mock_instant::thread_local::MockClock;
// Test provider that counts its invocations and returns credentials numbered by call
// (`AKID_1`, `AKID_2`, ...).
#[derive(Debug)]
struct MockStorageOptionsProvider {
// Number of times `fetch_storage_options` has been called.
call_count: Arc<RwLock<usize>>,
// When set, each fetch returns an `expires_at_millis` this far past the mocked "now".
expires_in_millis: Option<u64>,
}
impl MockStorageOptionsProvider {
// Creates a provider with a zeroed call counter.
fn new(expires_in_millis: Option<u64>) -> Self {
Self {
call_count: Arc::new(RwLock::new(0)),
expires_in_millis,
}
}
// Returns how many fetches have happened so far.
async fn get_call_count(&self) -> usize {
*self.call_count.read().await
}
}
#[async_trait]
impl StorageOptionsProvider for MockStorageOptionsProvider {
async fn fetch_storage_options(&self) -> Result<Option<HashMap<String, String>>> {
// Bump the counter in its own scope so the write guard is released right away.
let count = {
let mut c = self.call_count.write().await;
*c += 1;
*c
};
let mut options = HashMap::from([
("aws_access_key_id".to_string(), format!("AKID_{}", count)),
(
"aws_secret_access_key".to_string(),
format!("SECRET_{}", count),
),
("aws_session_token".to_string(), format!("TOKEN_{}", count)),
]);
if let Some(expires_in) = self.expires_in_millis {
let now_ms = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_millis() as u64;
let expires_at = now_ms + expires_in;
options.insert(EXPIRES_AT_MILLIS_KEY.to_string(), expires_at.to_string());
}
Ok(Some(options))
}
fn provider_id(&self) -> String {
// The counter's allocation address serves as the instance identifier.
let ptr = Arc::as_ptr(&self.call_count) as usize;
format!("MockStorageOptionsProvider {{ id: {} }}", ptr)
}
}
// A static accessor returns exactly the options it was built with and has no provider.
#[tokio::test]
async fn test_static_options_only() {
let options = HashMap::from([
("key1".to_string(), "value1".to_string()),
("key2".to_string(), "value2".to_string()),
]);
let accessor = StorageOptionsAccessor::with_static_options(options.clone());
let result = accessor.get_storage_options().await.unwrap();
assert_eq!(result.0, options);
assert!(!accessor.has_provider());
assert_eq!(accessor.initial_storage_options(), Some(&options));
}
// With only a provider, the first access performs exactly one fetch.
#[tokio::test]
async fn test_provider_only() {
MockClock::set_system_time(Duration::from_secs(100_000));
let mock_provider = Arc::new(MockStorageOptionsProvider::new(Some(600_000)));
let accessor = StorageOptionsAccessor::with_provider(mock_provider.clone());
let result = accessor.get_storage_options().await.unwrap();
assert!(result.0.contains_key("aws_access_key_id"));
assert_eq!(result.0.get("aws_access_key_id").unwrap(), "AKID_1");
assert!(accessor.has_provider());
assert_eq!(accessor.initial_storage_options(), None);
assert_eq!(mock_provider.get_call_count().await, 1);
}
// Unexpired initial options are served without calling the provider.
#[tokio::test]
async fn test_initial_and_provider_uses_initial_first() {
MockClock::set_system_time(Duration::from_secs(100_000));
let now_ms = MockClock::system_time().as_millis() as u64;
let expires_at = now_ms + 600_000;
let initial = HashMap::from([
("aws_access_key_id".to_string(), "INITIAL_KEY".to_string()),
(
"aws_secret_access_key".to_string(),
"INITIAL_SECRET".to_string(),
),
(EXPIRES_AT_MILLIS_KEY.to_string(), expires_at.to_string()),
]);
let mock_provider = Arc::new(MockStorageOptionsProvider::new(Some(600_000)));
let accessor = StorageOptionsAccessor::with_initial_and_provider(
initial.clone(),
mock_provider.clone(),
);
let result = accessor.get_storage_options().await.unwrap();
assert_eq!(result.0.get("aws_access_key_id").unwrap(), "INITIAL_KEY");
assert_eq!(mock_provider.get_call_count().await, 0); }
// Initial options expire in 600 s with a 300 s refresh offset: the first access hits the
// cache, and after advancing the clock by 360 s the accessor refreshes from the provider.
#[tokio::test]
async fn test_caching_and_refresh() {
MockClock::set_system_time(Duration::from_secs(100_000));
let mock_provider = Arc::new(MockStorageOptionsProvider::new(Some(600_000))); let now_ms = MockClock::system_time().as_millis() as u64;
let expires_at = now_ms + 600_000; let initial = HashMap::from([
(EXPIRES_AT_MILLIS_KEY.to_string(), expires_at.to_string()),
(REFRESH_OFFSET_MILLIS_KEY.to_string(), "300000".to_string()), ]);
let accessor =
StorageOptionsAccessor::with_initial_and_provider(initial, mock_provider.clone());
let result = accessor.get_storage_options().await.unwrap();
assert!(result.0.contains_key(EXPIRES_AT_MILLIS_KEY));
assert_eq!(mock_provider.get_call_count().await, 0);
MockClock::set_system_time(Duration::from_secs(100_000 + 360));
let result = accessor.get_storage_options().await.unwrap();
assert_eq!(result.0.get("aws_access_key_id").unwrap(), "AKID_1");
assert_eq!(mock_provider.get_call_count().await, 1);
}
// Initial options that are already expired trigger a provider fetch on first access.
#[tokio::test]
async fn test_expired_initial_triggers_refresh() {
MockClock::set_system_time(Duration::from_secs(100_000));
let now_ms = MockClock::system_time().as_millis() as u64;
let expired_time = now_ms - 1_000;
let initial = HashMap::from([
("aws_access_key_id".to_string(), "EXPIRED_KEY".to_string()),
(EXPIRES_AT_MILLIS_KEY.to_string(), expired_time.to_string()),
]);
let mock_provider = Arc::new(MockStorageOptionsProvider::new(Some(600_000)));
let accessor =
StorageOptionsAccessor::with_initial_and_provider(initial, mock_provider.clone());
let result = accessor.get_storage_options().await.unwrap();
assert_eq!(result.0.get("aws_access_key_id").unwrap(), "AKID_1");
assert_eq!(mock_provider.get_call_count().await, 1);
}
// With a provider configured, the accessor id is the provider id.
#[tokio::test]
async fn test_accessor_id_with_provider() {
let mock_provider = Arc::new(MockStorageOptionsProvider::new(None));
let accessor = StorageOptionsAccessor::with_provider(mock_provider);
let id = accessor.accessor_id();
assert!(id.starts_with("MockStorageOptionsProvider"));
}
// Without a provider, the accessor id is derived from a hash of the static options.
#[tokio::test]
async fn test_accessor_id_static() {
let options = HashMap::from([("key".to_string(), "value".to_string())]);
let accessor = StorageOptionsAccessor::with_static_options(options);
let id = accessor.accessor_id();
assert!(id.starts_with("static_options_"));
}
// Ten concurrent callers on an empty cache must result in a single provider fetch, and
// every caller must observe the credentials from that fetch.
#[tokio::test]
async fn test_concurrent_access() {
let mock_provider = Arc::new(MockStorageOptionsProvider::new(Some(9999999999999)));
let accessor = Arc::new(StorageOptionsAccessor::with_provider(mock_provider.clone()));
let mut handles = vec![];
for i in 0..10 {
let acc = accessor.clone();
let handle = tokio::spawn(async move {
let result = acc.get_storage_options().await.unwrap();
assert_eq!(result.0.get("aws_access_key_id").unwrap(), "AKID_1");
i
});
handles.push(handle);
}
let results: Vec<_> = futures::future::join_all(handles)
.await
.into_iter()
.map(|r| r.unwrap())
.collect();
assert_eq!(results.len(), 10);
let call_count = mock_provider.get_call_count().await;
assert_eq!(
call_count, 1,
"Provider should be called exactly once despite concurrent access"
);
}
// Options without an expiry are fetched once and never refreshed, however far the clock
// advances.
#[tokio::test]
async fn test_no_expiration_never_refreshes() {
MockClock::set_system_time(Duration::from_secs(100_000));
let mock_provider = Arc::new(MockStorageOptionsProvider::new(None)); let accessor = StorageOptionsAccessor::with_provider(mock_provider.clone());
accessor.get_storage_options().await.unwrap();
assert_eq!(mock_provider.get_call_count().await, 1);
MockClock::set_system_time(Duration::from_secs(200_000));
accessor.get_storage_options().await.unwrap();
assert_eq!(mock_provider.get_call_count().await, 1);
}
}