use std::fs::create_dir_all;
use std::path::Path;
use arrow_array::RecordBatchReader;
use lance::io::object_store::ObjectStore;
use snafu::prelude::*;
use crate::error::{CreateDirSnafu, Result};
use crate::table::Table;
/// Handle to a database rooted at a single URI.
///
/// Each table lives under the root as an entry named `<table>.lance`.
pub struct Database {
    /// Root URI exactly as it was passed to [`Database::connect`].
    pub(crate) uri: String,
    /// Object store resolved from `uri`; used to list and delete table entries.
    object_store: ObjectStore,
}

/// Extension identifying a table entry under the database root.
const LANCE_EXTENSION: &str = "lance";
impl Database {
    /// Connects to the database at `uri`, initializing it if needed.
    ///
    /// When `uri` resolves to a local object store, the directory (and any
    /// missing parents) is created so later table operations have a root to
    /// write into.
    ///
    /// # Errors
    ///
    /// Returns an error if no object store can be resolved from `uri`, or if
    /// the local directory cannot be created. This includes the case where
    /// `uri` names an existing non-directory file.
    pub async fn connect(uri: &str) -> Result<Database> {
        let (object_store, _) = ObjectStore::from_uri(uri).await?;
        if object_store.is_local() {
            Self::try_create_dir(uri).context(CreateDirSnafu { path: uri })?;
        }
        Ok(Database {
            uri: uri.to_string(),
            object_store,
        })
    }

    /// Ensures `path` exists as a directory, creating it and any missing parents.
    ///
    /// `create_dir_all` already returns `Ok` when `path` is an existing
    /// directory, so no separate existence check is needed. Dropping that check
    /// also means an existing *file* at `path` is reported as an error instead
    /// of being silently accepted.
    ///
    /// NOTE(review): this is a blocking `std::fs` call made from the async
    /// `connect`; consider `spawn_blocking` if that ever matters.
    fn try_create_dir(path: &str) -> core::result::Result<(), std::io::Error> {
        create_dir_all(path)
    }

    /// Lists the names of the tables stored in this database.
    ///
    /// Every entry directly under the database URI whose extension is `lance`
    /// is a table. The returned name is that entry's file stem (e.g. `table1`
    /// for `table1.lance`). Names are returned in the order the object store
    /// lists them; no sorting is applied.
    ///
    /// # Errors
    ///
    /// Returns an error if the database URI cannot be listed.
    pub async fn table_names(&self) -> Result<Vec<String>> {
        let entries = self.object_store.read_dir(self.uri.as_str()).await?;
        let names = entries
            .iter()
            .map(Path::new)
            .filter(|path| path.extension().and_then(|ext| ext.to_str()) == Some(LANCE_EXTENSION))
            .filter_map(|path| path.file_stem()?.to_str().map(str::to_owned))
            .collect();
        Ok(names)
    }

    /// Creates a table named `name` from the record batches in `batches`.
    ///
    /// Delegates to [`Table::create`] with this database's URI.
    pub async fn create_table(
        &self,
        name: &str,
        batches: Box<dyn RecordBatchReader>,
    ) -> Result<Table> {
        Table::create(&self.uri, name, batches).await
    }

    /// Opens the table named `name`.
    ///
    /// Delegates to [`Table::open`] with this database's URI.
    pub async fn open_table(&self, name: &str) -> Result<Table> {
        Table::open(&self.uri, name).await
    }

    /// Drops the table named `name` by recursively removing
    /// `<uri>/<name>.lance` from the object store.
    ///
    /// # Errors
    ///
    /// Returns an error if the object store fails to remove the directory.
    pub async fn drop_table(&self, name: &str) -> Result<()> {
        // NOTE(review): `name` is interpolated without validation; a name
        // containing `/` or `..` would target a path outside a single table.
        let dir_name = format!("{}/{}.{}", self.uri, name, LANCE_EXTENSION);
        self.object_store.remove_dir_all(dir_name).await?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use std::fs::create_dir_all;

    use tempfile::tempdir;

    use crate::database::Database;

    /// Connecting to an existing local directory keeps the URI verbatim.
    #[tokio::test]
    async fn test_connect() {
        let tmp_dir = tempdir().unwrap();
        let uri = tmp_dir.path().to_str().unwrap();
        let db = Database::connect(uri).await.unwrap();
        assert_eq!(db.uri, uri);
    }

    /// Only `*.lance` entries are reported, by their file stem.
    #[tokio::test]
    async fn test_table_names() {
        let tmp_dir = tempdir().unwrap();
        create_dir_all(tmp_dir.path().join("table1.lance")).unwrap();
        create_dir_all(tmp_dir.path().join("table2.lance")).unwrap();
        create_dir_all(tmp_dir.path().join("invalidlance")).unwrap();
        let uri = tmp_dir.path().to_str().unwrap();
        let db = Database::connect(uri).await.unwrap();
        let tables = db.table_names().await.unwrap();
        assert_eq!(tables.len(), 2);
        assert!(tables.contains(&String::from("table1")));
        assert!(tables.contains(&String::from("table2")));
    }

    /// Placeholder for S3 connectivity coverage. It is ignored (rather than an
    /// empty passing test) so it does not report coverage that doesn't exist.
    #[tokio::test]
    #[ignore = "not implemented: requires an S3 endpoint and credentials"]
    async fn test_connect_s3() {
        todo!("connect to an S3 URI and verify the database is usable")
    }

    /// Dropping a table removes its directory and hides it from `table_names`.
    #[tokio::test]
    async fn test_drop_table() {
        let tmp_dir = tempdir().unwrap();
        let table_dir = tmp_dir.path().join("table1.lance");
        create_dir_all(&table_dir).unwrap();
        let uri = tmp_dir.path().to_str().unwrap();
        let db = Database::connect(uri).await.unwrap();
        db.drop_table("table1").await.unwrap();
        assert!(!table_dir.exists(), "table directory should be removed from disk");
        let tables = db.table_names().await.unwrap();
        assert!(tables.is_empty());
    }
}