use std::any::Any;
use std::sync::{Arc, Weak};
use crate::object_storage::{get_object_store, AwsOptions, GcpOptions};
use datafusion::catalog::{CatalogProvider, CatalogProviderList, SchemaProvider};
use datafusion::common::plan_datafusion_err;
use datafusion::datasource::listing::ListingTableUrl;
use datafusion::datasource::TableProvider;
use datafusion::error::Result;
use datafusion::execution::context::SessionState;
use datafusion::execution::session_state::SessionStateBuilder;
use async_trait::async_trait;
use dirs::home_dir;
use parking_lot::RwLock;
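
/// Wraps another [`CatalogProviderList`] so that table paths which point at an
/// object store (for example `s3://`, `gs://`, or `http://` URLs) get their
/// object store registered with the session on the fly, the first time they are
/// queried.
///
/// A minimal usage sketch, mirroring the test setup below (the `ctx` variable is
/// assumed to be an existing `SessionContext`):
///
/// ```ignore
/// ctx.register_catalog_list(Arc::new(DynamicObjectStoreCatalog::new(
///     ctx.state().catalog_list().clone(),
///     ctx.state_weak_ref(),
/// )));
/// ```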
#[derive(Debug)]
pub struct DynamicObjectStoreCatalog {
    inner: Arc<dyn CatalogProviderList>,
    state: Weak<RwLock<SessionState>>,
}

impl DynamicObjectStoreCatalog {
    pub fn new(
        inner: Arc<dyn CatalogProviderList>,
        state: Weak<RwLock<SessionState>>,
    ) -> Self {
        Self { inner, state }
    }
}

impl CatalogProviderList for DynamicObjectStoreCatalog {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn register_catalog(
        &self,
        name: String,
        catalog: Arc<dyn CatalogProvider>,
    ) -> Option<Arc<dyn CatalogProvider>> {
        self.inner.register_catalog(name, catalog)
    }

    fn catalog_names(&self) -> Vec<String> {
        self.inner.catalog_names()
    }

    fn catalog(&self, name: &str) -> Option<Arc<dyn CatalogProvider>> {
        let state = self.state.clone();
        self.inner.catalog(name).map(|catalog| {
            Arc::new(DynamicObjectStoreCatalogProvider::new(catalog, state)) as _
        })
    }
}
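
/// Wraps another [`CatalogProvider`] so that every schema it returns is wrapped
/// in a [`DynamicObjectStoreSchemaProvider`], which performs the on-the-fly
/// object store registration.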
#[derive(Debug)]
struct DynamicObjectStoreCatalogProvider {
    inner: Arc<dyn CatalogProvider>,
    state: Weak<RwLock<SessionState>>,
}

impl DynamicObjectStoreCatalogProvider {
    pub fn new(
        inner: Arc<dyn CatalogProvider>,
        state: Weak<RwLock<SessionState>>,
    ) -> Self {
        Self { inner, state }
    }
}

impl CatalogProvider for DynamicObjectStoreCatalogProvider {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema_names(&self) -> Vec<String> {
        self.inner.schema_names()
    }

    fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>> {
        let state = self.state.clone();
        self.inner.schema(name).map(|schema| {
            Arc::new(DynamicObjectStoreSchemaProvider::new(schema, state)) as _
        })
    }

    fn register_schema(
        &self,
        name: &str,
        schema: Arc<dyn SchemaProvider>,
    ) -> Result<Option<Arc<dyn SchemaProvider>>> {
        self.inner.register_schema(name, schema)
    }
}
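
/// Wraps another [`SchemaProvider`] and, when a table lookup misses, tries to
/// interpret the table name as a listing URL, registering the matching object
/// store with the session state if it is not registered yet.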
#[derive(Debug)]
struct DynamicObjectStoreSchemaProvider {
    inner: Arc<dyn SchemaProvider>,
    state: Weak<RwLock<SessionState>>,
}

impl DynamicObjectStoreSchemaProvider {
    pub fn new(
        inner: Arc<dyn SchemaProvider>,
        state: Weak<RwLock<SessionState>>,
    ) -> Self {
        Self { inner, state }
    }
}

#[async_trait]
impl SchemaProvider for DynamicObjectStoreSchemaProvider {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn table_names(&self) -> Vec<String> {
        self.inner.table_names()
    }

    fn register_table(
        &self,
        name: String,
        table: Arc<dyn TableProvider>,
    ) -> Result<Option<Arc<dyn TableProvider>>> {
        self.inner.register_table(name, table)
    }

    async fn table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>> {
        // If the inner schema provider already knows the table, return it as-is.
        let inner_table = self.inner.table(name).await;
        if let Ok(Some(inner_table)) = inner_table {
            return Ok(Some(inner_table));
        }

        // Otherwise treat the table name as a URL and make sure an object store
        // serving that URL is registered with the runtime environment.
        let mut state = self
            .state
            .upgrade()
            .ok_or_else(|| plan_datafusion_err!("locking error"))?
            .read()
            .clone();
        let mut builder = SessionStateBuilder::from(state.clone());
        let optimized_name = substitute_tilde(name.to_owned());
        let table_url = ListingTableUrl::parse(optimized_name.as_str())?;
        let scheme = table_url.scheme();
        let url = table_url.as_ref();

        // If `get_store` succeeds the store is already registered and there is
        // nothing to do; otherwise register a new store for this URL.
        match state.runtime_env().object_store_registry.get_store(url) {
            Ok(_) => {}
            Err(_) => {
                // No command options are available here, so fall back to the
                // default options for the detected scheme.
                match scheme {
                    "s3" | "oss" | "cos" => {
                        if let Some(table_options) = builder.table_options() {
                            table_options.extensions.insert(AwsOptions::default())
                        }
                    }
                    "gs" | "gcs" => {
                        if let Some(table_options) = builder.table_options() {
                            table_options.extensions.insert(GcpOptions::default())
                        }
                    }
                    _ => {}
                };
                state = builder.build();
                let store = get_object_store(
                    &state,
                    table_url.scheme(),
                    url,
                    &state.default_table_options(),
                    false,
                )
                .await?;
                state.runtime_env().register_object_store(url, store);
            }
        }
        self.inner.table(name).await
    }

    fn deregister_table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>> {
        self.inner.deregister_table(name)
    }

    fn table_exist(&self, name: &str) -> bool {
        self.inner.table_exist(name)
    }
}
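
/// Replaces a leading `~` in `cur` with the current user's home directory,
/// returning the input unchanged when no home directory can be determined.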
pub fn substitute_tilde(cur: String) -> String {
    if let Some(usr_dir_path) = home_dir() {
        if let Some(usr_dir) = usr_dir_path.to_str() {
            if cur.starts_with('~') && !usr_dir.is_empty() {
                return cur.replacen('~', usr_dir, 1);
            }
        }
    }
    cur
}

#[cfg(test)]
mod tests {
    use std::{env, vec};

    use super::*;

    use datafusion::catalog::SchemaProvider;
    use datafusion::prelude::SessionContext;

    /// Builds a `SessionContext` wrapped with `DynamicObjectStoreCatalog` and
    /// returns it together with the default schema provider.
    fn setup_context() -> (SessionContext, Arc<dyn SchemaProvider>) {
        let ctx = SessionContext::new();
        ctx.register_catalog_list(Arc::new(DynamicObjectStoreCatalog::new(
            ctx.state().catalog_list().clone(),
            ctx.state_weak_ref(),
        )));

        let provider = &DynamicObjectStoreCatalog::new(
            ctx.state().catalog_list().clone(),
            ctx.state_weak_ref(),
        ) as &dyn CatalogProviderList;
        let catalog = provider
            .catalog(provider.catalog_names().first().unwrap())
            .unwrap();
        let schema = catalog
            .schema(catalog.schema_names().first().unwrap())
            .unwrap();
        (ctx, schema)
    }

    #[tokio::test]
    async fn query_http_location_test() -> Result<()> {
        // This location does not exist, so the table lookup should return None,
        // but an object store for the URL should still get registered.
        let domain = "example.com";
        let location = format!("http://{domain}/file.parquet");

        let (ctx, schema) = setup_context();
        let table = schema.table(&location).await?;
        assert!(table.is_none());

        let store = ctx
            .runtime_env()
            .object_store(ListingTableUrl::parse(location)?)?;
        assert_eq!(format!("{store}"), "HttpStore");

        // The store must be configured for this domain
        let expected_domain = format!("Domain(\"{domain}\")");
        assert!(format!("{store:?}").contains(&expected_domain));

        Ok(())
    }

    #[tokio::test]
    async fn query_s3_location_test() -> Result<()> {
        let aws_envs = vec![
            "AWS_ENDPOINT",
            "AWS_ACCESS_KEY_ID",
            "AWS_SECRET_ACCESS_KEY",
            "AWS_ALLOW_HTTP",
        ];
        // Skip the test unless all required AWS environment variables are set
        for aws_env in aws_envs {
            if env::var(aws_env).is_err() {
                eprintln!("aws envs not set, skipping s3 test");
                return Ok(());
            }
        }

        let bucket = "examples3bucket";
        let location = format!("s3://{bucket}/file.parquet");

        let (ctx, schema) = setup_context();
        let table = schema.table(&location).await?;
        assert!(table.is_none());

        let store = ctx
            .runtime_env()
            .object_store(ListingTableUrl::parse(location)?)?;
        assert_eq!(format!("{store}"), format!("AmazonS3({bucket})"));

        // The store must be configured for this bucket
        let expected_bucket = format!("bucket: \"{bucket}\"");
        assert!(format!("{store:?}").contains(&expected_bucket));

        Ok(())
    }

    #[tokio::test]
    async fn query_gs_location_test() -> Result<()> {
        let bucket = "examplegsbucket";
        let location = format!("gs://{bucket}/file.parquet");

        let (ctx, schema) = setup_context();
        let table = schema.table(&location).await?;
        assert!(table.is_none());

        let store = ctx
            .runtime_env()
            .object_store(ListingTableUrl::parse(location)?)?;
        assert_eq!(format!("{store}"), format!("GoogleCloudStorage({bucket})"));

        let expected_bucket = format!("bucket_name_encoded: \"{bucket}\"");
        assert!(format!("{store:?}").contains(&expected_bucket));

        Ok(())
    }

    #[tokio::test]
    async fn query_invalid_location_test() {
        let location = "ts://file.parquet";
        let (_ctx, schema) = setup_context();

        assert!(schema.table(location).await.is_err());
    }

    #[cfg(not(target_os = "windows"))]
    #[test]
    fn test_substitute_tilde() {
        use std::{env, path::PathBuf};

        // Point the home directory at a known location for the duration of the test
        let original_home = home_dir();
        let test_home_path = if cfg!(windows) {
            "C:\\Users\\user"
        } else {
            "/home/user"
        };
        env::set_var(
            if cfg!(windows) { "USERPROFILE" } else { "HOME" },
            test_home_path,
        );

        let input = "~/Code/datafusion/benchmarks/data/tpch_sf1/part/part-0.parquet";
        let expected = PathBuf::from(test_home_path)
            .join("Code")
            .join("datafusion")
            .join("benchmarks")
            .join("data")
            .join("tpch_sf1")
            .join("part")
            .join("part-0.parquet")
            .to_string_lossy()
            .to_string();
        let actual = substitute_tilde(input.to_string());
        assert_eq!(actual, expected);

        // Restore the original home directory so other tests are unaffected
        match original_home {
            Some(home_path) => env::set_var(
                if cfg!(windows) { "USERPROFILE" } else { "HOME" },
                home_path.to_str().unwrap(),
            ),
            None => {
                env::remove_var(if cfg!(windows) { "USERPROFILE" } else { "HOME" })
            }
        }
    }
}