use crate::{Error, Result};
use log::warn;
use std::{path::Path, sync::Arc};
use tokio::sync::Mutex;
/// File name, inside each local store's directory, in which that store's
/// used-space counter is persisted.
const USED_SPACE_FILENAME: &str = "used_space";

/// Identifier assigned to each local store registered with [`UsedSpace`].
pub type StoreId = u64;

/// Cloneable handle to shared used-space accounting.
///
/// All clones refer to the same state, which is kept behind an
/// `Arc<Mutex<_>>` so it can be updated from concurrent async tasks.
#[derive(Debug, Clone)]
pub struct UsedSpace {
    inner: Arc<Mutex<inner::UsedSpace>>,
}
impl UsedSpace {
    /// Creates a tracker with no registered local stores; `increase` refuses
    /// to push the aggregate counter above `max_capacity`.
    pub fn new(max_capacity: u64) -> Self {
        let state = inner::UsedSpace::new(max_capacity);
        Self {
            inner: Arc::new(Mutex::new(state)),
        }
    }

    /// Zeroes the aggregate counter and every local store's counter.
    pub async fn reset(&self) {
        let shared = Arc::clone(&self.inner);
        inner::UsedSpace::reset(shared).await
    }

    /// Returns the capacity limit this tracker was created with.
    pub async fn max_capacity(&self) -> u64 {
        let shared = Arc::clone(&self.inner);
        inner::UsedSpace::max_capacity(shared).await
    }

    /// Returns the aggregate counter maintained by `increase` / `decrease`.
    pub async fn total(&self) -> u64 {
        let shared = Arc::clone(&self.inner);
        inner::UsedSpace::total(shared).await
    }

    /// Returns the counter of store `id`, or 0 when no such store is registered.
    // Presumably only called from tests in some builds, hence the allow — confirm before removing.
    #[allow(unused)]
    pub async fn local(&self, id: StoreId) -> u64 {
        let shared = Arc::clone(&self.inner);
        inner::UsedSpace::local(shared, id).await
    }

    /// Registers the local store rooted at `dir` (its counter lives in a
    /// `used_space` file in that directory) and returns the id assigned to it.
    pub async fn add_local_store<T: AsRef<Path>>(&self, dir: T) -> Result<StoreId> {
        let shared = Arc::clone(&self.inner);
        inner::UsedSpace::add_local_store(shared, dir).await
    }

    /// Accounts `consumed` more units to store `id`.
    ///
    /// # Errors
    /// `Error::NotEnoughSpace` if the aggregate would overflow or exceed the
    /// capacity, `Error::NoStoreId` for an unknown `id`, or an error from
    /// persisting the new value.
    pub async fn increase(&self, id: StoreId, consumed: u64) -> Result<()> {
        let shared = Arc::clone(&self.inner);
        inner::UsedSpace::increase(shared, id, consumed).await
    }

    /// Releases `released` units from store `id`; counters never drop below zero.
    ///
    /// # Errors
    /// `Error::NoStoreId` for an unknown `id`, or an error from persisting the
    /// new value.
    pub async fn decrease(&self, id: StoreId, released: u64) -> Result<()> {
        let shared = Arc::clone(&self.inner);
        inner::UsedSpace::decrease(shared, id, released).await
    }
}
mod inner {
use super::*;
use std::{collections::HashMap, io::SeekFrom};
use tokio::{
fs::{File, OpenOptions},
io::{AsyncReadExt, AsyncWriteExt},
};
#[derive(Debug)]
pub struct UsedSpace {
max_capacity: u64,
total_value: u64,
local_stores: HashMap<StoreId, LocalUsedSpace>,
next_id: StoreId,
}
#[derive(Debug)]
struct LocalUsedSpace {
pub local_value: u64,
pub local_record: File,
}
impl UsedSpace {
pub fn new(max_capacity: u64) -> Self {
Self {
max_capacity,
total_value: 0u64,
local_stores: HashMap::new(),
next_id: 0u64,
}
}
pub async fn reset(used_space: Arc<Mutex<UsedSpace>>) {
let mut used_space_lock = used_space.lock().await;
used_space_lock.total_value = 0;
for (_id, local_used_space) in used_space_lock.local_stores.iter_mut() {
local_used_space.local_value = 0;
if let Err(err) =
Self::write_local_to_file(&mut local_used_space.local_record, 0).await
{
warn!("Error updating used_space file on disk: {}", err);
}
}
}
pub async fn max_capacity(used_space: Arc<Mutex<UsedSpace>>) -> u64 {
let used_space_lock = used_space.lock().await;
used_space_lock.max_capacity
}
pub async fn total(used_space: Arc<Mutex<UsedSpace>>) -> u64 {
let used_space_lock = used_space.lock().await;
used_space_lock.total_value
}
pub async fn local(used_space: Arc<Mutex<UsedSpace>>, id: StoreId) -> u64 {
let used_space_lock = used_space.lock().await;
used_space_lock
.local_stores
.get(&id)
.map_or(0, |res| res.local_value)
}
pub async fn add_local_store<T: AsRef<Path>>(
used_space: Arc<Mutex<UsedSpace>>,
dir: T,
) -> Result<StoreId> {
let mut local_record = OpenOptions::new()
.read(true)
.write(true)
.create(true)
.open(dir.as_ref().join(USED_SPACE_FILENAME))
.await?;
let mut buffer = vec![];
let could_read = local_record.read_to_end(&mut buffer).await.is_ok();
let has_value = !buffer.is_empty();
let local_value = if could_read && has_value {
bincode::deserialize::<u64>(&buffer)?
} else {
let mut bytes = Vec::<u8>::new();
bincode::serialize_into(&mut bytes, &0_u64)?;
local_record.write_all(&bytes).await?;
0
};
let local_store = LocalUsedSpace {
local_value,
local_record,
};
let mut used_space_lock = used_space.lock().await;
let id = used_space_lock.next_id;
used_space_lock.next_id += 1;
let _ = used_space_lock.local_stores.insert(id, local_store);
Ok(id)
}
pub async fn increase(
used_space: Arc<Mutex<UsedSpace>>,
id: StoreId,
consumed: u64,
) -> Result<()> {
let mut used_space_lock = used_space.lock().await;
let new_total = used_space_lock
.total_value
.checked_add(consumed)
.ok_or(Error::NotEnoughSpace)?;
if new_total > used_space_lock.max_capacity {
return Err(Error::NotEnoughSpace);
}
let new_local = used_space_lock
.local_stores
.get(&id)
.ok_or(Error::NoStoreId)?
.local_value
.checked_add(consumed)
.ok_or(Error::NotEnoughSpace)?;
{
let record = &mut used_space_lock
.local_stores
.get_mut(&id)
.ok_or(Error::NoStoreId)?
.local_record;
Self::write_local_to_file(record, new_local).await?;
}
used_space_lock.total_value = new_total;
used_space_lock
.local_stores
.get_mut(&id)
.ok_or(Error::NoStoreId)?
.local_value = new_local;
Ok(())
}
pub async fn decrease(
used_space: Arc<Mutex<UsedSpace>>,
id: StoreId,
released: u64,
) -> Result<()> {
let mut used_space_lock = used_space.lock().await;
let new_local = used_space_lock
.local_stores
.get_mut(&id)
.ok_or(Error::NoStoreId)?
.local_value
.saturating_sub(released);
let new_total = used_space_lock.total_value.saturating_sub(released);
{
let record = &mut used_space_lock
.local_stores
.get_mut(&id)
.ok_or(Error::NoStoreId)?
.local_record;
Self::write_local_to_file(record, new_local).await?;
}
used_space_lock.total_value = new_total;
used_space_lock
.local_stores
.get_mut(&id)
.ok_or(Error::NoStoreId)?
.local_value = new_local;
Ok(())
}
async fn write_local_to_file(record: &mut File, local: u64) -> Result<()> {
record.set_len(0).await?;
let _ = record.seek(SeekFrom::Start(0)).await?;
let mut contents = Vec::<u8>::new();
bincode::serialize_into(&mut contents, &local)?;
record.write_all(&contents).await?;
record.sync_all().await?;
Ok(())
}
}
}
#[cfg(test)]
mod tests {
    use super::{Error, Result, UsedSpace};
    use tempdir::TempDir;

    const TEST_STORE_MAX_SIZE: u64 = u64::MAX;

    /// Creates a fresh temporary directory under the system temp dir.
    fn create_temp_root() -> Result<TempDir> {
        TempDir::new("temp_store_root").map_err(|e| Error::TempDirCreationFailed(e.to_string()))
    }

    /// Creates a temporary store directory nested inside `temp_root`.
    ///
    /// `new_in` places the directory explicitly under `temp_root`. Passing the
    /// joined absolute path as a *prefix* to `TempDir::new` only worked by
    /// accident, and it needed a fallible path-to-str conversion.
    fn create_temp_store(temp_root: &TempDir) -> Result<TempDir> {
        TempDir::new_in(temp_root.path(), "temp_store")
            .map_err(|e| Error::TempDirCreationFailed(e.to_string()))
    }

    /// Many concurrent increases followed by the matching decreases must leave
    /// both the total and the local counter at the expected values.
    #[tokio::test]
    async fn used_space_multiwriter_test() -> Result<()> {
        const NUMS_TO_ADD: usize = 128;
        const NUM_SIZE: usize = std::mem::size_of::<u32>();

        let root_dir = create_temp_root()?;
        let store_dir = create_temp_store(&root_dir)?;
        let used_space = UsedSpace::new(TEST_STORE_MAX_SIZE);
        let id = used_space.add_local_store(&store_dir).await?;

        let mut rng = rand::thread_rng();
        let bytes = crate::utils::random_vec(&mut rng, NUM_SIZE * NUMS_TO_ADD);
        // Each 4-byte chunk becomes one little-endian u32. Widening to u64 means
        // the sum of all of them cannot overflow.
        let nums: Vec<u64> = bytes
            .as_slice()
            .chunks_exact(NUM_SIZE)
            .map(|chunk| {
                let mut le_bytes = [0u8; NUM_SIZE];
                le_bytes.copy_from_slice(chunk);
                u64::from(u32::from_le_bytes(le_bytes))
            })
            .collect();
        let total: u64 = nums.iter().sum();

        let increases = nums.iter().map(|n| used_space.increase(id, *n));
        let _ = futures::future::try_join_all(increases).await?;
        assert_eq!(total, used_space.total().await);
        assert_eq!(total, used_space.local(id).await);

        let decreases = nums.iter().map(|n| used_space.decrease(id, *n));
        let _ = futures::future::try_join_all(decreases).await?;
        assert_eq!(0, used_space.total().await);
        assert_eq!(0, used_space.local(id).await);

        Ok(())
    }
}