use std::collections::HashMap;
use std::path::Path as StdPath;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use ::object_store::{
aws::AmazonS3Builder, azure::MicrosoftAzureBuilder, gcp::GoogleCloudStorageBuilder,
local::LocalFileSystem, memory::InMemory, path::Path, ClientOptions, CredentialProvider,
ObjectStore as OSObjectStore, Result as ObjectStoreResult,
};
use async_trait::async_trait;
use aws_config::default_provider::credentials::DefaultCredentialsChain;
use aws_credential_types::provider::ProvideCredentials;
use futures::{StreamExt, TryStreamExt};
use object_store::aws::AwsCredential as ObjectStoreAwsCredential;
use reqwest::header::{HeaderMap, CACHE_CONTROL};
use shellexpand::tilde;
use tokio::sync::RwLock;
use url::Url;
use crate::error::{Error, Result};
use crate::io::object_reader::CloudObjectReader;
use crate::io::object_writer::ObjectWriter;
use super::commit::{CommitHandler, RenameCommitHandler, UnsafeCommitHandler};
use super::local::LocalObjectReader;
use super::object_reader::ObjectReader;
/// Handle to a storage backend.
///
/// Bundles an `object_store` implementation with the URI scheme it was opened
/// for, a base path, the block size handed to readers, and a commit handler.
#[derive(Debug, Clone)]
pub struct ObjectStore {
/// The underlying `object_store` implementation that performs the I/O.
pub inner: Arc<dyn OSObjectStore>,
/// URI scheme this store was created for (`"file"`, `"s3"`, `"gs"`, `"az"` or `"memory"`).
scheme: String,
/// Base path recorded when the store was constructed.
base_path: Path,
/// Block size handed to readers opened from this store.
block_size: usize,
/// Commit handler associated with this store.
pub commit_handler: Arc<dyn CommitHandler>,
}
impl std::fmt::Display for ObjectStore {
    /// Writes the store as `ObjectStore(<scheme>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("ObjectStore(")?;
        f.write_str(&self.scheme)?;
        f.write_str(")")
    }
}
/// Key under which `AwsCredentialAdapter` stores the AWS credentials in its cache map.
const AWS_CREDS_CACHE_KEY: &str = "aws_credentials";
/// Adapts an AWS SDK credentials provider to `object_store`'s
/// `CredentialProvider`, caching the credentials it fetches.
#[derive(Debug)]
struct AwsCredentialAdapter {
/// Provider asked for fresh credentials when the cache is empty or stale.
pub inner: Arc<dyn ProvideCredentials>,
/// Cached credentials, stored under `AWS_CREDS_CACHE_KEY`.
cache: Arc<RwLock<HashMap<String, Arc<aws_credential_types::Credentials>>>>,
/// How long before their expiry cached credentials are treated as expired.
credentials_refresh_offset: Duration,
}
impl AwsCredentialAdapter {
    /// Creates an adapter around `provider` that starts with an empty cache.
    ///
    /// `credentials_refresh_offset` is how long before their expiry cached
    /// credentials are considered stale.
    fn new(provider: Arc<dyn ProvideCredentials>, credentials_refresh_offset: Duration) -> Self {
        let cache = Arc::new(RwLock::new(HashMap::new()));
        Self {
            inner: provider,
            cache,
            credentials_refresh_offset,
        }
    }
}
#[async_trait]
impl CredentialProvider for AwsCredentialAdapter {
type Credential = ObjectStoreAwsCredential;
async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
let cached_creds = {
let cache_value = self.cache.read().await.get(AWS_CREDS_CACHE_KEY).cloned();
let expired = cache_value
.clone()
.map(|cred| {
cred.expiry()
.map(|exp| {
exp.checked_sub(self.credentials_refresh_offset)
.expect("this time should always be valid")
< SystemTime::now()
})
.unwrap_or(false)
})
.unwrap_or(true); if expired {
None
} else {
cache_value.clone()
}
};
if let Some(creds) = cached_creds {
Ok(Arc::new(Self::Credential {
key_id: creds.access_key_id().to_string(),
secret_key: creds.secret_access_key().to_string(),
token: creds.session_token().map(|s| s.to_string()),
}))
} else {
let refreshed_creds = Arc::new(self.inner.provide_credentials().await.unwrap());
self.cache
.write()
.await
.insert(AWS_CREDS_CACHE_KEY.to_string(), refreshed_creds.clone());
Ok(Arc::new(Self::Credential {
key_id: refreshed_creds.access_key_id().to_string(),
secret_key: refreshed_creds.secret_access_key().to_string(),
token: refreshed_creds.session_token().map(|s| s.to_string()),
}))
}
}
}
/// Builds an S3-backed object store for `uri`.
///
/// The region comes from the default AWS region provider chain, falling back
/// to `us-west-2`. When `credentials` is `None`, the default AWS credentials
/// chain is used, wrapped in an `AwsCredentialAdapter` configured with
/// `credentials_refresh_offset`. Other settings are read from the environment
/// by `AmazonS3Builder::from_env`.
///
/// # Errors
///
/// Returns an error when the S3 builder fails to build the store.
async fn build_s3_object_store(
    uri: &str,
    credentials_refresh_offset: Duration,
    credentials: Option<Arc<dyn CredentialProvider<Credential = ObjectStoreAwsCredential>>>,
) -> Result<Arc<dyn OSObjectStore>> {
    use aws_config::meta::region::RegionProviderChain;

    const DEFAULT_REGION: &str = "us-west-2";

    let region_provider = RegionProviderChain::default_provider().or_else(DEFAULT_REGION);
    // Resolve the region once and share it between the credentials chain and
    // the S3 builder instead of walking the provider chain twice.
    let region = region_provider.region().await;

    let creds = match credentials {
        Some(creds) => creds,
        None => {
            let credentials_provider = DefaultCredentialsChain::builder()
                .region(region.clone())
                .build()
                .await;
            Arc::new(AwsCredentialAdapter::new(
                Arc::new(credentials_provider),
                credentials_refresh_offset,
            ))
        }
    };

    let region_name = region
        .map(|r| r.as_ref().to_string())
        .unwrap_or_else(|| DEFAULT_REGION.to_string());

    Ok(Arc::new(
        AmazonS3Builder::from_env()
            .with_url(uri)
            .with_credentials(creds)
            .with_region(region_name)
            .build()?,
    ))
}
/// Builds a Google Cloud Storage object store for `uri`.
///
/// The builder is configured from the environment, and every request carries
/// a default `Cache-Control: no-cache` header.
async fn build_gcs_object_store(uri: &str) -> Result<Arc<dyn OSObjectStore>> {
    let mut default_headers = HeaderMap::new();
    default_headers.insert(CACHE_CONTROL, "no-cache".parse().unwrap());
    let client_options = ClientOptions::new().with_default_headers(default_headers);

    let store = GoogleCloudStorageBuilder::from_env()
        .with_url(uri)
        .with_client_options(client_options)
        .build()?;
    Ok(Arc::new(store))
}
/// Builds a Microsoft Azure object store for `uri`, configured from the
/// environment.
async fn build_azure_object_store(uri: &str) -> Result<Arc<dyn OSObjectStore>> {
    let store = MicrosoftAzureBuilder::from_env().with_url(uri).build()?;
    Ok(Arc::new(store))
}
/// Hook that lets callers substitute or decorate the underlying object store.
pub trait WrappingObjectStore: std::fmt::Debug + Send + Sync {
/// Receives the store that was built for a URI and returns the store to use in its place.
fn wrap(&self, original: Arc<dyn OSObjectStore>) -> Arc<dyn OSObjectStore>;
}
/// Options that control how an `ObjectStore` is built from a URI.
#[derive(Debug, Clone)]
pub struct ObjectStoreParams {
/// Optional wrapper applied to the underlying store once it has been built.
pub object_store_wrapper: Option<Arc<dyn WrappingObjectStore>>,
/// How long before their expiry cached S3 credentials are refreshed.
pub s3_credentials_refresh_offset: Duration,
/// Optional credential provider for S3; when `None`, the default AWS credentials chain is used.
pub aws_credentials: Option<Arc<dyn CredentialProvider<Credential = ObjectStoreAwsCredential>>>,
/// Optional commit handler; when `None`, a scheme-specific default is chosen.
pub commit_handler: Option<Arc<dyn CommitHandler>>,
}
impl Default for ObjectStoreParams {
    /// Defaults: no wrapper, no injected AWS credentials, no commit handler
    /// override, and a 60 second S3 credentials refresh offset.
    fn default() -> Self {
        let s3_credentials_refresh_offset = Duration::from_secs(60);
        Self {
            object_store_wrapper: None,
            s3_credentials_refresh_offset,
            aws_credentials: None,
            commit_handler: None,
        }
    }
}
impl ObjectStore {
pub async fn from_uri(uri: &str) -> Result<(Self, Path)> {
Self::from_uri_and_params(uri, ObjectStoreParams::default()).await
}
pub async fn from_uri_and_params(uri: &str, params: ObjectStoreParams) -> Result<(Self, Path)> {
let (object_store, base_path) = match Url::parse(uri) {
Ok(url) if url.scheme().len() == 1 && cfg!(windows) => {
Self::new_from_path(uri, ¶ms)
}
Ok(url) => {
let store = Self::new_from_url(url.clone(), ¶ms).await?;
let path = Path::from(url.path());
Ok((store, path))
}
Err(_) => Self::new_from_path(uri, ¶ms),
}?;
Ok((
Self {
inner: params
.object_store_wrapper
.map(|w| w.wrap(object_store.inner.clone()))
.unwrap_or(object_store.inner),
..object_store
},
base_path,
))
}
fn new_from_path(str_path: &str, params: &ObjectStoreParams) -> Result<(Self, Path)> {
let expanded = tilde(str_path).to_string();
let expanded_path = StdPath::new(&expanded);
if !expanded_path.try_exists()? {
std::fs::create_dir_all(expanded_path)?;
} else if !expanded_path.is_dir() {
return Err(Error::IO {
message: format!("{} is not a lance directory", str_path),
});
}
let expanded_path = expanded_path.canonicalize()?;
Ok((
Self {
inner: Arc::new(LocalFileSystem::new()),
scheme: String::from("file"),
base_path: Path::from_absolute_path(&expanded_path)?,
block_size: 4 * 1024, commit_handler: params
.commit_handler
.clone()
.unwrap_or_else(|| Arc::new(RenameCommitHandler)),
},
Path::from_filesystem_path(&expanded_path)?,
))
}
async fn new_from_url(url: Url, params: &ObjectStoreParams) -> Result<Self> {
match url.scheme() {
"s3" => Ok(Self {
inner: build_s3_object_store(
url.to_string().as_str(),
params.s3_credentials_refresh_offset,
params.aws_credentials.clone(),
)
.await?,
scheme: String::from("s3"),
base_path: Path::from(url.path()),
block_size: 64 * 1024,
commit_handler: params
.commit_handler
.clone()
.unwrap_or_else(|| Arc::new(UnsafeCommitHandler)),
}),
"gs" => Ok(Self {
inner: build_gcs_object_store(url.to_string().as_str()).await?,
scheme: String::from("gs"),
base_path: Path::from(url.path()),
block_size: 64 * 1024,
commit_handler: params
.commit_handler
.clone()
.unwrap_or_else(|| Arc::new(RenameCommitHandler)),
}),
"az" => Ok(Self {
inner: build_azure_object_store(url.to_string().as_str()).await?,
scheme: String::from("az"),
base_path: Path::from(url.path()),
block_size: 64 * 1024,
commit_handler: params
.commit_handler
.clone()
.unwrap_or_else(|| Arc::new(RenameCommitHandler)),
}),
"file" => Ok(Self::new_from_path(url.path(), params)?.0),
"memory" => Ok(Self {
inner: Arc::new(InMemory::new()),
scheme: String::from("memory"),
base_path: Path::from(url.path()),
block_size: 64 * 1024,
commit_handler: params
.commit_handler
.clone()
.unwrap_or_else(|| Arc::new(RenameCommitHandler)),
}),
s => Err(Error::IO {
message: format!("Unsupported URI scheme: {}", s),
}),
}
}
pub fn memory() -> Self {
Self {
inner: Arc::new(InMemory::new()),
scheme: String::from("memory"),
base_path: Path::from("/"),
block_size: 64 * 1024,
commit_handler: Arc::new(RenameCommitHandler),
}
}
pub fn is_local(&self) -> bool {
self.scheme == "file"
}
pub fn block_size(&self) -> usize {
self.block_size
}
pub fn set_block_size(&mut self, new_size: usize) {
self.block_size = new_size;
}
pub fn base_path(&self) -> &Path {
&self.base_path
}
pub async fn open(&self, path: &Path) -> Result<Box<dyn ObjectReader>> {
match self.scheme.as_str() {
"file" => LocalObjectReader::open(path, self.block_size),
_ => Ok(Box::new(CloudObjectReader::new(
self,
path.clone(),
self.block_size,
)?)),
}
}
pub async fn create(&self, path: &Path) -> Result<ObjectWriter> {
ObjectWriter::new(self, path).await
}
pub async fn read_dir(&self, dir_path: impl Into<Path>) -> Result<Vec<String>> {
let path = dir_path.into();
let path = Path::parse(&path)?;
let output = self.inner.list_with_delimiter(Some(&path)).await?;
Ok(output
.common_prefixes
.iter()
.chain(output.objects.iter().map(|o| &o.location))
.map(|s| s.filename().unwrap().to_string())
.collect())
}
pub async fn remove_dir_all(&self, dir_path: impl Into<Path>) -> Result<()> {
let path = dir_path.into();
let path = Path::parse(&path)?;
if self.is_local() {
return super::local::remove_dir_all(&path);
}
let sub_entries = self
.inner
.list(Some(&path))
.await?
.map(|m| m.map(|meta| meta.location))
.boxed();
self.inner
.delete_stream(sub_entries)
.try_collect::<Vec<_>>()
.await?;
Ok(())
}
pub async fn exists(&self, path: &Path) -> Result<bool> {
match self.inner.head(path).await {
Ok(_) => Ok(true),
Err(object_store::Error::NotFound { path: _, source: _ }) => Ok(false),
Err(e) => Err(e.into()),
}
}
pub async fn size(&self, path: &Path) -> Result<usize> {
Ok(self.inner.head(path).await?.size)
}
}
// Tests for building an `ObjectStore` from URIs and paths, directory helpers,
// and the wrapper / injected-credentials options.
#[cfg(test)]
mod tests {
use super::*;
use std::env::set_current_dir;
use std::fs::{create_dir_all, write};
use std::sync::atomic::{AtomicBool, Ordering};
// Writes `contents` to `path_str` (after `~` expansion), creating parent directories first.
fn write_to_file(path_str: &str, contents: &str) -> std::io::Result<()> {
let expanded = tilde(path_str).to_string();
let path = StdPath::new(&expanded);
std::fs::create_dir_all(path.parent().unwrap())?;
write(path, contents)
}
// Reads the whole object at `path` from `store` as a UTF-8 string.
// Failures panic via `unwrap`, so the returned `Result` is always `Ok`.
async fn read_from_store(store: ObjectStore, path: &Path) -> Result<String> {
let test_file_store = store.open(path).await.unwrap();
let size = test_file_store.size().await.unwrap();
let bytes = test_file_store.get_range(0..size).await.unwrap();
let contents = String::from_utf8(bytes.to_vec()).unwrap();
Ok(contents)
}
// Absolute paths, including ones with `.` and `..` segments, resolve to the same directory.
#[tokio::test]
async fn test_absolute_paths() {
let tmp_dir = tempfile::tempdir().unwrap();
let tmp_path = tmp_dir.path().to_str().unwrap().to_owned();
write_to_file(
&format!("{tmp_path}/bar/foo.lance/test_file"),
"TEST_CONTENT",
)
.unwrap();
for uri in &[
format!("{tmp_path}/bar/foo.lance"),
format!("{tmp_path}/./bar/foo.lance"),
format!("{tmp_path}/bar/foo.lance/../foo.lance"),
] {
let (store, path) = ObjectStore::from_uri(uri).await.unwrap();
let contents = read_from_store(store, &path.child("test_file"))
.await
.unwrap();
assert_eq!(contents, "TEST_CONTENT");
}
}
// A relative path is resolved against the current working directory.
// NOTE(review): `set_current_dir` changes the working directory for the whole test process.
#[tokio::test]
async fn test_relative_paths() {
let tmp_dir = tempfile::tempdir().unwrap();
let tmp_path = tmp_dir.path().to_str().unwrap().to_owned();
write_to_file(
&format!("{tmp_path}/bar/foo.lance/test_file"),
"RELATIVE_URL",
)
.unwrap();
set_current_dir(StdPath::new(&tmp_path)).expect("Error changing current dir");
let (store, path) = ObjectStore::from_uri("./bar/foo.lance").await.unwrap();
let contents = read_from_store(store, &path.child("test_file"))
.await
.unwrap();
assert_eq!(contents, "RELATIVE_URL");
}
// A leading `~` in the URI is expanded before the path is opened.
// NOTE(review): this writes under `~/foo.lance` and does not remove it afterwards.
#[tokio::test]
async fn test_tilde_expansion() {
let uri = "~/foo.lance";
write_to_file(&format!("{uri}/test_file"), "TILDE").unwrap();
let (store, path) = ObjectStore::from_uri(uri).await.unwrap();
let contents = read_from_store(store, &path.child("test_file"))
.await
.unwrap();
assert_eq!(contents, "TILDE");
}
// `read_dir` returns the direct children of a directory: sub-directories first, then files.
#[tokio::test]
async fn test_read_directory() {
let tmp_dir = tempfile::tempdir().unwrap();
let path = tmp_dir.path();
create_dir_all(path.join("foo").join("bar")).unwrap();
create_dir_all(path.join("foo").join("zoo")).unwrap();
create_dir_all(path.join("foo").join("zoo").join("abc")).unwrap();
write_to_file(
path.join("foo").join("test_file").to_str().unwrap(),
"read_dir",
)
.unwrap();
let (store, base) = ObjectStore::from_uri(path.to_str().unwrap()).await.unwrap();
let sub_dirs = store.read_dir(base.child("foo")).await.unwrap();
assert_eq!(sub_dirs, vec!["bar", "zoo", "test_file"]);
}
// `remove_dir_all` deletes a directory together with nested directories and files.
#[tokio::test]
async fn test_delete_directory() {
let tmp_dir = tempfile::tempdir().unwrap();
let path = tmp_dir.path();
create_dir_all(path.join("foo").join("bar")).unwrap();
create_dir_all(path.join("foo").join("zoo")).unwrap();
create_dir_all(path.join("foo").join("zoo").join("abc")).unwrap();
write_to_file(
path.join("foo")
.join("bar")
.join("test_file")
.to_str()
.unwrap(),
"delete",
)
.unwrap();
write_to_file(path.join("foo").join("top").to_str().unwrap(), "delete_top").unwrap();
let (store, base) = ObjectStore::from_uri(path.to_str().unwrap()).await.unwrap();
store.remove_dir_all(base.child("foo")).await.unwrap();
assert!(!path.join("foo").exists());
}
// Wrapper double: records that `wrap` was called and returns a preset store,
// ignoring the original one.
#[derive(Debug)]
struct TestWrapper {
called: AtomicBool,
return_value: Arc<dyn OSObjectStore>,
}
impl WrappingObjectStore for TestWrapper {
fn wrap(&self, _original: Arc<dyn OSObjectStore>) -> Arc<dyn OSObjectStore> {
self.called.store(true, Ordering::Relaxed);
self.return_value.clone()
}
}
impl TestWrapper {
// Reports whether `wrap` has been invoked.
fn called(&self) -> bool {
self.called.load(Ordering::Relaxed)
}
}
// The wrapper supplied through `ObjectStoreParams` is invoked when the store is built.
#[tokio::test]
async fn test_wrapping_object_store_option_is_used() {
let mock_inner_store: Arc<dyn OSObjectStore> = Arc::new(InMemory::new());
assert_eq!(Arc::strong_count(&mock_inner_store), 1);
let wrapper = Arc::new(TestWrapper {
called: AtomicBool::new(false),
return_value: mock_inner_store.clone(),
});
let params = ObjectStoreParams {
object_store_wrapper: Some(wrapper.clone()),
..ObjectStoreParams::default()
};
assert!(!wrapper.called());
// `let _` drops the returned store immediately.
let _ = ObjectStore::from_uri_and_params("memory:///", params)
.await
.unwrap();
assert!(wrapper.called());
// Presumably the two remaining references are `mock_inner_store` and `wrapper.return_value`.
assert_eq!(Arc::strong_count(&mock_inner_store), 2);
}
// Credential provider double: records that it was asked for credentials and returns empty ones.
#[derive(Debug, Default)]
struct MockAwsCredentialsProvider {
called: AtomicBool,
}
#[async_trait]
impl CredentialProvider for MockAwsCredentialsProvider {
type Credential = ObjectStoreAwsCredential;
async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
self.called.store(true, Ordering::Relaxed);
Ok(Arc::new(Self::Credential {
key_id: "".to_string(),
secret_key: "".to_string(),
token: None,
}))
}
}
// An injected AWS credential provider is consulted when the S3 store is used.
#[tokio::test]
async fn test_injected_aws_creds_option_is_used() {
let mock_provider = Arc::new(MockAwsCredentialsProvider::default());
let params = ObjectStoreParams {
aws_credentials: Some(mock_provider.clone()
as Arc<dyn CredentialProvider<Credential = ObjectStoreAwsCredential>>),
..ObjectStoreParams::default()
};
assert!(!mock_provider.called.load(Ordering::Relaxed));
let (store, _) = ObjectStore::from_uri_and_params("s3://not-a-bucket", params)
.await
.unwrap();
// The result of the read is ignored; only the provider being called matters.
let _ = store
.open(&Path::parse("/").unwrap())
.await
.unwrap()
.get_range(0..1)
.await;
assert!(mock_provider.called.load(Ordering::Relaxed));
}
// Windows drive-letter paths, with either slash style, are opened as local paths.
// NOTE(review): this writes `test_folder` at the root of the temp directory's drive.
#[tokio::test]
#[cfg(windows)]
async fn test_windows_paths() {
use std::path::Component;
use std::path::Prefix;
use std::path::Prefix::*;
// Returns the prefix component of `path`; panics if the path does not start with one.
fn get_path_prefix(path: &StdPath) -> Prefix {
match path.components().next().unwrap() {
Component::Prefix(prefix_component) => prefix_component.kind(),
_ => panic!(),
}
}
// Returns the drive letter of a `Disk` prefix; panics for any other prefix kind.
fn get_drive_letter(prefix: Prefix) -> String {
match prefix {
Disk(bytes) => String::from_utf8(vec![bytes]).unwrap(),
_ => panic!(),
}
}
let tmp_dir = tempfile::tempdir().unwrap();
let tmp_path = tmp_dir.path();
let prefix = get_path_prefix(tmp_path);
let drive_letter = get_drive_letter(prefix);
write_to_file(
&(format!("{drive_letter}:/test_folder/test.lance") + "/test_file"),
"WINDOWS",
)
.unwrap();
for uri in &[
format!("{drive_letter}:/test_folder/test.lance"),
format!("{drive_letter}:\\test_folder\\test.lance"),
] {
let (store, base) = ObjectStore::from_uri(uri).await.unwrap();
let contents = read_from_store(store, &base.child("test_file"))
.await
.unwrap();
assert_eq!(contents, "WINDOWS");
}
}
}