use std::collections::HashMap;
use std::fmt;
use std::ops::Range;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use async_trait::async_trait;
use tokio::fs::File;
use tokio::io::AsyncWriteExt;
use tokio::sync::RwLock;
use crate::{
add_prefix_to_storage, OwnedBytes, Storage, StorageErrorKind, StorageFactory, StorageResult,
};
/// Storage that keeps every file in memory.
///
/// Files live in a map from path to content, guarded by an async `RwLock`
/// and held behind an `Arc`: cloning a `RamStorage` therefore yields another
/// handle onto the very same set of files rather than a copy of them.
#[derive(Clone, Default)]
pub struct RamStorage {
    /// Path → file content, shared between all clones of this storage.
    files: Arc<RwLock<HashMap<PathBuf, OwnedBytes>>>,
}
impl fmt::Debug for RamStorage {
    /// Prints only the type name; the stored files are deliberately left out
    /// of the debug output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("RamStorage")
    }
}
impl RamStorage {
pub fn builder() -> RamStorageBuilder {
RamStorageBuilder::default()
}
async fn put_data(&self, path: &Path, payload: OwnedBytes) {
self.files.write().await.insert(path.to_path_buf(), payload);
}
async fn get_data(&self, path: &Path) -> Option<OwnedBytes> {
self.files.read().await.get(path).cloned()
}
pub async fn list_files(&self) -> Vec<PathBuf> {
self.files.read().await.keys().cloned().collect()
}
}
/// Turns the outcome of a lookup in the in-memory file map into a
/// `StorageResult`.
///
/// Returns the looked-up value when there is one; otherwise a
/// `StorageErrorKind::DoesNotExist` error naming `path`. Every operation of
/// the `Storage` implementation below reports a missing file through this
/// helper so that they all produce the same message.
fn require_found<T>(entry: Option<T>, path: &Path) -> StorageResult<T> {
    entry.ok_or_else(|| {
        let err = anyhow::anyhow!("Missing file `{}`", path.display());
        StorageErrorKind::DoesNotExist.with_error(err)
    })
}

#[async_trait]
impl Storage for RamStorage {
    /// Always succeeds: there is nothing to verify for an in-memory map.
    async fn check(&self) -> anyhow::Result<()> {
        Ok(())
    }

    /// Reads `payload` entirely and stores it under `path`, replacing any
    /// previous content.
    ///
    /// # Errors
    /// Propagates the error returned by `payload.read_all()`.
    async fn put(
        &self,
        path: &Path,
        payload: Box<dyn crate::PutPayload>,
    ) -> crate::StorageResult<()> {
        let payload_bytes = payload.read_all().await?;
        self.put_data(path, payload_bytes).await;
        Ok(())
    }

    /// Writes the content stored under `path` to the file `output_path` on
    /// the local file system, creating it (or truncating it) first.
    ///
    /// # Errors
    /// Returns a `DoesNotExist` error when `path` is not in the storage, and
    /// propagates any I/O error raised while creating, writing or flushing
    /// the output file.
    async fn copy_to_file(&self, path: &Path, output_path: &Path) -> StorageResult<()> {
        let payload_bytes = require_found(self.get_data(path).await, path)?;
        let mut file = File::create(output_path).await?;
        file.write_all(&payload_bytes).await?;
        file.flush().await?;
        Ok(())
    }

    /// Returns the bytes `range` of the file stored under `path`.
    ///
    /// # Errors
    /// Returns a `DoesNotExist` error when `path` is not in the storage. A
    /// range that is inverted or reaches past the end of the file yields an
    /// `InvalidInput` I/O error, converted into a storage error the same way
    /// the I/O errors of `copy_to_file` are.
    async fn get_slice(&self, path: &Path, range: Range<usize>) -> StorageResult<OwnedBytes> {
        let payload_bytes = require_found(self.get_data(path).await, path)?;
        let num_bytes = payload_bytes.len();
        // Validate the caller-supplied range up front instead of handing it
        // to `OwnedBytes::slice` unchecked.
        // NOTE(review): `slice` presumably panics on an out-of-bounds range,
        // like slice indexing does -- confirm against `OwnedBytes`.
        if range.start > range.end || range.end > num_bytes {
            let err = std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                format!(
                    "Invalid range {:?} for file `{}` of {} bytes",
                    range,
                    path.display(),
                    num_bytes
                ),
            );
            Err(err.into())
        } else {
            Ok(payload_bytes.slice(range))
        }
    }

    /// Removes the file stored under `path`.
    ///
    /// Deleting a path that is not in the storage is not an error.
    async fn delete(&self, path: &Path) -> StorageResult<()> {
        self.files.write().await.remove(path);
        Ok(())
    }

    /// Returns the whole content of the file stored under `path`.
    ///
    /// # Errors
    /// Returns a `DoesNotExist` error when `path` is not in the storage.
    async fn get_all(&self, path: &Path) -> StorageResult<OwnedBytes> {
        require_found(self.get_data(path).await, path)
    }

    /// Returns the URI of the storage, which is always `ram://`.
    fn uri(&self) -> String {
        "ram://".to_string()
    }

    /// Returns the length in bytes of the file stored under `path`.
    ///
    /// # Errors
    /// Returns a `DoesNotExist` error when `path` is not in the storage.
    async fn file_num_bytes(&self, path: &Path) -> StorageResult<u64> {
        let files = self.files.read().await;
        // Only the length is needed, so read it under the lock rather than
        // cloning the content out of the map.
        let num_bytes = files.get(path).map(|file_bytes| file_bytes.len() as u64);
        require_found(num_bytes, path)
    }
}
/// Builder used to create a [`RamStorage`] that already contains files.
///
/// Files are accumulated with [`RamStorageBuilder::put`] and handed over to
/// the storage by [`RamStorageBuilder::build`].
#[derive(Default)]
pub struct RamStorageBuilder {
    /// Files gathered so far, keyed by path.
    files: HashMap<PathBuf, OwnedBytes>,
}

impl RamStorageBuilder {
    /// Registers a file at `path` holding a copy of `payload`.
    ///
    /// Putting the same path twice keeps only the last payload.
    pub fn put(mut self, path: &str, payload: &[u8]) -> Self {
        let file_path = PathBuf::from(path);
        let file_content = OwnedBytes::new(payload.to_vec());
        self.files.insert(file_path, file_content);
        self
    }

    /// Consumes the builder and returns a [`RamStorage`] holding the
    /// registered files.
    pub fn build(self) -> RamStorage {
        let RamStorageBuilder { files } = self;
        let shared_files = Arc::new(RwLock::new(files));
        RamStorage {
            files: shared_files,
        }
    }
}
/// Storage factory for the `ram` protocol.
///
/// The factory owns a single storage; `resolve` clones the handle to that
/// one instance and passes it, with the remainder of the URI, to
/// `add_prefix_to_storage`.
pub struct RamStorageFactory {
    /// The one storage backing every URI resolved by this factory.
    ram_storage: Arc<dyn Storage>,
}

impl Default for RamStorageFactory {
    /// Creates a factory backed by a fresh, empty [`RamStorage`].
    fn default() -> Self {
        let ram_storage: Arc<dyn Storage> = Arc::new(RamStorage::default());
        Self { ram_storage }
    }
}
impl StorageFactory for RamStorageFactory {
    /// Returns the protocol handled by this factory: `ram`.
    fn protocol(&self) -> String {
        "ram".to_string()
    }

    /// Resolves a `ram://<prefix>` URI into the factory's storage, passing
    /// `<prefix>` to `add_prefix_to_storage`.
    ///
    /// Everything following the leading `ram://` is the prefix, including
    /// any further `://` it may contain (`ram://a://b` resolves to the
    /// prefix `a://b`).
    ///
    /// # Errors
    /// Returns a `DoesNotExist` error when `uri` does not start with
    /// `ram://`.
    fn resolve(&self, uri: &str) -> crate::StorageResult<Arc<dyn Storage>> {
        // `strip_prefix` both validates the scheme and yields the whole
        // remainder; splitting on "://" would cut the prefix short at the
        // next occurrence of the separator.
        let prefix = uri.strip_prefix("ram://").ok_or_else(|| {
            let err_msg = anyhow::anyhow!(
                "{:?} is an invalid ram storage uri. Only ram:// is accepted.",
                uri
            );
            StorageErrorKind::DoesNotExist.with_error(err_msg)
        })?;
        Ok(add_prefix_to_storage(self.ram_storage.clone(), prefix))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_suite::storage_test_suite;

    /// Runs the shared storage test suite against an empty `RamStorage`.
    #[tokio::test]
    async fn test_storage() -> anyhow::Result<()> {
        let mut storage = RamStorage::default();
        storage_test_suite(&mut storage).await?;
        Ok(())
    }

    /// Checks that the factory rejects foreign schemes and that the resolved
    /// URI depends on the requested prefix.
    #[test]
    fn test_ram_storage_factory() {
        let factory = RamStorageFactory::default();

        // A scheme other than `ram://` must be refused.
        let wrong_scheme_err = factory.resolve("rom://toto").err().unwrap();
        assert_eq!(wrong_scheme_err.kind(), StorageErrorKind::DoesNotExist);

        // Two different prefixes resolve to storages with different URIs.
        let data_storage = factory.resolve("ram://data").ok().unwrap();
        let home_storage = factory.resolve("ram://home/data").ok().unwrap();
        assert_ne!(data_storage.uri(), home_storage.uri());

        // Resolving the same URI twice yields the same storage URI.
        let data_storage_again = factory.resolve("ram://data").ok().unwrap();
        assert_eq!(data_storage.uri(), data_storage_again.uri());
    }

    /// Checks that the builder stores the files it is given and that the
    /// last `put` on a path wins.
    #[tokio::test]
    async fn test_ram_storage_builder() -> anyhow::Result<()> {
        let builder = RamStorage::builder()
            .put("path1", b"path1_payload")
            .put("path2", b"path2_payload")
            .put("path1", b"path1_payloadb");
        let ram_storage = builder.build();

        let path1_content = ram_storage.get_all(Path::new("path1")).await?;
        assert_eq!(&path1_content, &b"path1_payloadb"[..]);

        let path2_content = ram_storage.get_all(Path::new("path2")).await?;
        assert_eq!(&path2_content, &b"path2_payload"[..]);

        Ok(())
    }
}