use csv::{Reader, Writer};
use serde::{Deserialize, Serialize};
use std::{path::Path, sync::Arc};
use thiserror::Error;
use tokio::{task, task::JoinError};
/// Immutable settings shared by every operation of a database handle.
struct Config<PA> {
    /// Root directory under which collection files are stored.
    path: PA,
    /// File extension (without the leading dot) appended to collection names.
    extension: String,
}
#[derive(Debug, Error)]
pub enum DbError {
#[error("CSV error")]
Csv(#[from] csv::Error),
#[error("I/O error")]
Io(#[from] std::io::Error),
#[error("Task join error")]
Join(#[from] tokio::task::JoinError),
#[error("No document matched the predicate")]
NoMatch,
}
/// Handle to a set of collections stored as one file per collection.
///
/// `PA` is the type of the root path supplied at construction time.
pub struct Database<PA> {
    /// Reference-counted configuration (root path and file extension).
    config: Arc<Config<PA>>,
}
impl<PA> Database<PA>
where
PA: AsRef<Path> + Send + Sync + Clone + 'static,
{
pub fn new(path: PA, extension: Option<&str>) -> Self {
Self {
config: Arc::new(Config {
path,
extension: String::from(extension.unwrap_or("csv")),
}),
}
}
pub async fn find<T, P>(&self, collection: &str, predicate: P) -> Result<Vec<T>, DbError>
where
T: Serialize + for<'de> Deserialize<'de> + Send + 'static,
P: FnMut(&T) -> bool,
{
let collection = collection.to_string();
let config = self.config.clone();
let results = task::spawn_blocking(move || {
let mut rdr = match Reader::from_path(
config
.path
.as_ref()
.join(format!("{}.{}", collection, config.extension)),
) {
Ok(rdr) => rdr,
Err(_) => return Ok(Vec::new()),
};
rdr.deserialize().collect::<Result<Vec<T>, csv::Error>>()
})
.await??;
Ok(results.into_iter().filter(predicate).collect())
}
pub async fn insert<T>(&self, collection: &str, document: T) -> Result<(), DbError>
where
T: Serialize + for<'de> Deserialize<'de> + Send + 'static,
{
let mut documents: Vec<T> = self.find(collection, |_| true).await?;
documents.push(document);
Ok(self.write(collection, documents).await??)
}
pub async fn delete<T, P>(&self, collection: &str, mut predicate: P) -> Result<(), DbError>
where
T: Serialize + for<'de> Deserialize<'de> + PartialEq + Send + 'static,
P: FnMut(&&T) -> bool,
{
let mut documents: Vec<T> = self.find(collection, |_| true).await?;
documents.retain(|d| !predicate(&d));
Ok(self.write(collection, documents).await??)
}
pub async fn update<T, P>(
&self,
collection: &str,
document: T,
mut predicate: P,
) -> Result<(), DbError>
where
T: Serialize + for<'de> Deserialize<'de> + PartialEq + Send + 'static,
P: FnMut(&&T) -> bool,
{
let mut documents: Vec<T> = self.find(collection, |_| true).await?;
let original_len = documents.len();
documents.retain(|d| !predicate(&d));
if documents.len() == original_len {
return Err(DbError::NoMatch);
}
documents.push(document);
Ok(self.write(collection, documents).await??)
}
async fn write<T>(
&self,
collection: &str,
documents: Vec<T>,
) -> Result<Result<(), csv::Error>, JoinError>
where
T: Serialize + Send + 'static,
{
let collection = collection.to_string();
let config = self.config.clone();
let result = task::spawn_blocking(move || {
let path = config
.path
.as_ref()
.join(format!("{}.{}", collection, config.extension));
if let Some(parent_path) = path.parent() {
std::fs::create_dir_all(parent_path)?
}
let mut wrt = match Writer::from_path(&path) {
Ok(wrt) => wrt,
Err(error) => match error.kind() {
csv::ErrorKind::Io(_) => match std::fs::File::create(&path) {
Ok(_) => csv::Writer::from_path(&path)?,
Err(_) => return Err(error),
},
_ => return Err(error),
},
};
for document in documents {
wrt.serialize(document)?;
}
Ok(())
})
.await;
result
}
}