use std::any::Any;
use std::collections::HashSet;
use std::path::Path;
use std::sync::{Arc, Mutex};
use crate::{SchemaProvider, TableProvider, TableProviderFactory};
use crate::Session;
use datafusion_common::{DFSchema, DataFusionError, HashMap, TableReference};
use datafusion_expr::CreateExternalTable;
use async_trait::async_trait;
use futures::TryStreamExt;
use itertools::Itertools;
use object_store::ObjectStore;
/// A schema provider whose tables are discovered by listing an object store.
///
/// Calling `refresh` lists the objects under `path` in `store` and registers
/// one table for every direct child (file or directory) of that path that is
/// not already registered.
#[derive(Debug)]
pub struct ListingSchemaProvider {
/// Prefix joined (with a `/`) in front of each table path to build the
/// location handed to the table factory.
authority: String,
/// Prefix inside `store` whose direct children are treated as tables.
path: object_store::path::Path,
/// Factory used to build a `TableProvider` for each discovered table.
factory: Arc<dyn TableProviderFactory>,
/// Object store that is listed to discover tables.
store: Arc<dyn ObjectStore>,
/// Registered tables keyed by table name, guarded by a mutex.
tables: Arc<Mutex<HashMap<String, Arc<dyn TableProvider>>>>,
/// Value passed as the `file_type` of every table created during a refresh.
format: String,
}
impl ListingSchemaProvider {
    /// Create a new `ListingSchemaProvider` with an empty table registry.
    ///
    /// # Arguments
    ///
    /// * `authority` - prefix joined (with a `/`) in front of each table path
    ///   when building a table location
    /// * `path` - prefix in `store` whose direct children become tables
    /// * `factory` - builds the `TableProvider` for each discovered table
    /// * `store` - object store that `refresh` lists
    /// * `format` - used as the `file_type` of every created table
    pub fn new(
        authority: String,
        path: object_store::path::Path,
        factory: Arc<dyn TableProviderFactory>,
        store: Arc<dyn ObjectStore>,
        format: String,
    ) -> Self {
        let tables = Arc::new(Mutex::new(HashMap::new()));
        Self {
            authority,
            path,
            factory,
            store,
            tables,
            format,
        }
    }

    /// Discover tables under the configured path and register the new ones.
    ///
    /// Every object listed under `self.path` contributes the direct child of
    /// that path it lives in: the file itself, or the top-level directory
    /// containing it. The table name is the child's file name up to the first
    /// `.`. Names that are already registered are skipped; this method never
    /// removes a table.
    ///
    /// # Errors
    ///
    /// Returns an error if listing the store fails, if a file name or table
    /// path cannot be converted to a string, or if the factory or the
    /// registration fails.
    pub async fn refresh(&self, state: &dyn Session) -> datafusion_common::Result<()> {
        let entries: Vec<_> = self.store.list(Some(&self.path)).try_collect().await?;
        let base = Path::new(self.path.as_ref());

        // Collect the distinct direct children of `base`; the set removes the
        // duplicates produced by several objects sharing one directory.
        let mut tables = HashSet::new();
        for object in &entries {
            let location = Path::new(object.location.as_ref());
            // `ancestors` yields the location itself first (depth 0), then
            // each enclosing directory, so any depth above 0 is a directory.
            for (depth, candidate) in location.ancestors().enumerate() {
                if candidate.parent() == Some(base) {
                    tables.insert(TablePath {
                        is_dir: depth > 0,
                        path: candidate,
                    });
                }
            }
        }

        let parse_error =
            || DataFusionError::Internal("Cannot parse file name!".to_string());

        for table in &tables {
            let file_name = table
                .path
                .file_name()
                .and_then(|name| name.to_str())
                .ok_or_else(parse_error)?;
            // `split` always yields at least one item, so index 0 exists.
            let segments = file_name.split('.').collect_vec();
            let table_name = segments[0];
            let table_path = table.to_string().ok_or_else(parse_error)?;

            if self.table_exist(table_name) {
                continue;
            }

            let cmd = CreateExternalTable {
                schema: Arc::new(DFSchema::empty()),
                name: TableReference::bare(table_name),
                location: format!("{}/{}", self.authority, table_path),
                file_type: self.format.clone(),
                table_partition_cols: vec![],
                if_not_exists: false,
                temporary: false,
                definition: None,
                order_exprs: vec![],
                unbounded: false,
                options: Default::default(),
                constraints: Default::default(),
                column_defaults: Default::default(),
            };
            let provider = self.factory.create(state, &cmd).await?;
            let _ = self.register_table(table_name.to_string(), provider)?;
        }

        Ok(())
    }
}
/// `SchemaProvider` implementation backed by the in-memory table registry.
///
/// Every method that touches the registry panics with "Can't lock tables" if
/// the mutex guarding it is poisoned.
#[async_trait]
impl SchemaProvider for ListingSchemaProvider {
    /// Returns this provider as `Any` so callers can downcast it.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Returns the names of all currently registered tables.
    fn table_names(&self) -> Vec<String> {
        let registry = self.tables.lock().expect("Can't lock tables");
        registry.keys().cloned().collect()
    }

    /// Looks up a registered table by name, returning `None` when absent.
    async fn table(
        &self,
        name: &str,
    ) -> Result<Option<Arc<dyn TableProvider>>, DataFusionError> {
        let found = self
            .tables
            .lock()
            .expect("Can't lock tables")
            .get(name)
            .map(Arc::clone);
        Ok(found)
    }

    /// Stores `table` under `name`, replacing any existing entry.
    ///
    /// Note: the returned value is the table that was just registered, not a
    /// previously registered one.
    fn register_table(
        &self,
        name: String,
        table: Arc<dyn TableProvider>,
    ) -> datafusion_common::Result<Option<Arc<dyn TableProvider>>> {
        let mut registry = self.tables.lock().expect("Can't lock tables");
        registry.insert(name, Arc::clone(&table));
        Ok(Some(table))
    }

    /// Removes the table registered under `name` and returns it, if any.
    fn deregister_table(
        &self,
        name: &str,
    ) -> datafusion_common::Result<Option<Arc<dyn TableProvider>>> {
        let removed = self.tables.lock().expect("Can't lock tables").remove(name);
        Ok(removed)
    }

    /// Returns `true` when a table is registered under `name`.
    fn table_exist(&self, name: &str) -> bool {
        let registry = self.tables.lock().expect("Can't lock tables");
        registry.contains_key(name)
    }
}
/// A candidate table location together with a flag saying whether it is a
/// directory.
#[derive(Eq, PartialEq, Hash, Debug)]
struct TablePath<'a> {
    /// Borrowed path of the table location.
    path: &'a Path,
    /// `true` when the location is a directory rather than a single file.
    is_dir: bool,
}

impl TablePath<'_> {
    /// Renders the path as a string, appending `/` for directories.
    ///
    /// Returns `None` when the path is not valid UTF-8.
    fn to_string(&self) -> Option<String> {
        let text = self.path.to_str()?;
        let suffix = if self.is_dir { "/" } else { "" };
        Some(format!("{text}{suffix}"))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Renders the fixture path `/file` with the given directory flag.
    fn render(is_dir: bool) -> Option<String> {
        let fixture = TablePath {
            path: Path::new("/file"),
            is_dir,
        };
        fixture.to_string()
    }

    /// A directory table path is rendered with a trailing slash.
    #[test]
    fn table_path_ends_with_slash_when_is_dir() {
        let rendered = render(true).expect("table path");
        assert!(rendered.ends_with('/'));
    }

    /// A file table path is rendered without a trailing slash.
    #[test]
    fn dir_table_path_str_does_not_end_with_slash_when_not_is_dir() {
        let rendered = render(false).expect("table_path");
        assert!(!rendered.ends_with('/'));
    }
}