use std::{
    fs, io,
    path::{Path, PathBuf},
};

use rustolio_utils::{
    bytes::encoding::{decode_from_std, encode_to_std},
    crypto::signature::PublicKey,
};
use tokio::sync::{mpsc, oneshot};

use crate::{Error, Key, KeyType, Result, Value};
/// A request handled by the disk-store task started by [`Store::init`].
pub enum Action {
    /// Look up a key, passing the optional requesting signer to the read-access
    /// check, and send the outcome back on the oneshot sender.
    Get(Key, Option<PublicKey>, oneshot::Sender<Result<Option<Value>>>),
    /// Store a value under a key. No reply is sent for this request.
    Set(Key, Value),
    /// Stop accepting new requests; the sender is notified once the task has
    /// drained the requests that were already queued.
    Shutdown(oneshot::Sender<()>),
}
/// File-backed store: every entry lives in its own file inside one directory,
/// named after the key's hash.
#[derive(Debug)]
pub struct Store {
    /// Root directory that holds the per-key files.
    path: PathBuf,
}
impl Store {
pub fn init(dir: PathBuf, mut disk_channel: mpsc::Receiver<Action>) {
tokio::spawn(async move {
let mut shutdown_tx = None;
let mut disk_store = Self::new(dir);
loop {
match disk_channel.recv().await {
Some(action) => match action {
Action::Shutdown(tx) => {
shutdown_tx = Some(tx);
disk_channel.close();
}
Action::Get(k, s, tx) => {
let value = disk_store.get(k, s).await;
tx.send(value).unwrap();
}
Action::Set(k, v) => {
disk_store.set(k, v).await.unwrap();
}
},
None => {
shutdown_tx.unwrap().send(()).unwrap();
break;
}
}
}
});
}
fn new(path: PathBuf) -> Self {
fs::create_dir_all(&path).unwrap();
Self { path }
}
async fn get(&self, key: Key, signer: Option<PublicKey>) -> Result<Option<Value>> {
let file = self.file(key);
tokio::task::spawn_blocking(move || read(&file, key.ty(), signer))
.await
.unwrap()
}
async fn set(&mut self, key: Key, value: Value) -> Result<()> {
let file = self.file(key);
tokio::task::spawn_blocking(move || write(&file, &value));
Ok(())
}
fn file(&self, key: Key) -> PathBuf {
self.path.join(key.hash().to_string())
}
}
fn read(dir: &PathBuf, ty: KeyType, signer: Option<PublicKey>) -> Result<Option<Value>> {
let mut file = match fs::File::open(dir) {
Ok(f) => f,
Err(e) => {
if e.kind() == io::ErrorKind::NotFound {
return Ok(None);
}
return Err(Error::FileError(e));
}
};
let value: Value = decode_from_std(&mut file).map_err(Error::EncodingError)?;
match ty {
KeyType::ReadWrite | KeyType::ReadSecureWrite => {}
KeyType::SecureReadWrite => {
if value.signer() != signer {
return Err(crate::Error::NotAllowed);
}
}
}
Ok(Some(value))
}
fn write(dir: &PathBuf, value: &crate::Value) -> Result<()> {
if let Ok(mut file) = fs::File::open(dir) {
let current_signer = decode_from_std(&mut file).map_err(Error::EncodingError)?;
if value.signer() != current_signer {
return Err(crate::Error::NotAllowed);
}
}
let mut file = fs::File::create(dir).map_err(Error::FileError)?;
encode_to_std(value, &mut file).map_err(Error::EncodingError)?;
Ok(())
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use crate::{test_utils::file::generate_dir, KeyType};

    /// Round-trips a value through the synchronous `read`/`write` helpers.
    #[test]
    fn test_read_write() {
        let (dir, _guard) = generate_dir("read_write");
        let file = dir.join("filename");

        let res = read(&file, KeyType::ReadWrite, None).unwrap();
        assert!(res.is_none());

        let value: Value = Value::from_value(&"foo").unwrap();
        write(&file, &value).unwrap();

        let res = read(&file, KeyType::ReadWrite, None).unwrap();
        assert_eq!(res.unwrap(), value);
    }

    /// Exercises `Store::get`/`Store::set` directly, without the channel.
    #[tokio::test]
    async fn test_disk_store() {
        let (dir, _guard) = generate_dir("disk_store");
        let mut disk_store = Store::new(dir.clone());
        let key = Key::from_value(&0, KeyType::ReadWrite);
        let value = Value::from_value(&"foo").unwrap();

        let res = disk_store.get(key, None).await.unwrap();
        assert!(res.is_none());

        disk_store.set(key, value.clone()).await.unwrap();
        // Allow time for the write to reach disk before reading it back.
        tokio::time::sleep(Duration::from_millis(50)).await;

        let res = disk_store.get(key, None).await.unwrap();
        assert_eq!(res.unwrap(), value);
    }

    /// Exercises the store through the `Action` channel served by `Store::init`.
    #[tokio::test]
    async fn test_disk_store_channel() {
        let (dir, _guard) = generate_dir("disk_store_channel");
        let (disk_channel, rx) = mpsc::channel(5);
        Store::init(dir.clone(), rx);
        let key = Key::from_value(&0, KeyType::ReadWrite);
        let value = Value::from_value(&"foo").unwrap();

        let (tx, rx) = oneshot::channel();
        disk_channel.send(Action::Get(key, None, tx)).await.unwrap();
        let res = rx.await.unwrap().unwrap();
        assert!(res.is_none());

        disk_channel
            .send(Action::Set(key, value.clone()))
            .await
            .unwrap();
        // Allow time for the write to reach disk before reading it back.
        tokio::time::sleep(Duration::from_millis(50)).await;

        let (tx, rx) = oneshot::channel();
        disk_channel.send(Action::Get(key, None, tx)).await.unwrap();
        let res = rx.await.unwrap().unwrap();
        assert_eq!(res.unwrap(), value);
    }

    /// A shutdown request is acknowledged and the channel then refuses new work.
    #[tokio::test]
    async fn test_disk_store_shutdown() {
        let (dir, _guard) = generate_dir("disk_store_shutdown");
        let (disk_channel, rx) = mpsc::channel(5);
        Store::init(dir, rx);

        let (tx, rx) = oneshot::channel();
        disk_channel.send(Action::Shutdown(tx)).await.unwrap();
        rx.await.unwrap();

        let key = Key::from_value(&0, KeyType::ReadWrite);
        let (tx, _rx) = oneshot::channel();
        assert!(disk_channel.send(Action::Get(key, None, tx)).await.is_err());
    }
}