use std::collections::HashMap;
use std::sync::Arc;
use pubky_common::crypto::PublicKey;
use crate::persistence::files::layer_domain_error::LayerDomainError;
use crate::persistence::sql::uexecutor;
use crate::services::user_service::{UserService, FILE_METADATA_SIZE};
use crate::shared::webdav::EntryPath;
use opendal::raw::*;
use opendal::Result;
/// OpenDAL layer that tracks and enforces per-user storage quotas.
///
/// Applying the layer to an accessor yields a `UserQuotaAccessor`, whose
/// writers and deleters adjust the owning user's `used_bytes`.
///
/// `Debug` is derived so the layer can be logged and inspected like the
/// accessor it produces, which already derives it over the same field types.
#[derive(Debug, Clone)]
pub struct UserQuotaLayer {
    /// Service used to load and persist users while accounting storage usage.
    user_service: UserService,
    /// Optional system-wide storage limit in megabytes, handed to
    /// `resolve_storage_max_bytes` as the fallback value.
    default_storage_mb: Option<u64>,
}
impl UserQuotaLayer {
pub fn new(user_service: UserService, default_storage_mb: Option<u64>) -> Self {
Self {
user_service,
default_storage_mb,
}
}
}
/// Reports whether applying `bytes_delta` to `current_bytes` would push the
/// usage above `max_bytes`.
///
/// * `max_bytes == None` means "no limit": the answer is always `false`.
/// * The projected total is computed in `i128`, so neither a huge current
///   usage nor an extreme delta can overflow.
/// * A projected total that is zero or negative never counts as exceeding.
pub(crate) fn would_exceed_limit(
    current_bytes: u64,
    bytes_delta: i64,
    max_bytes: Option<u64>,
) -> bool {
    match max_bytes {
        None => false,
        Some(limit) => {
            let projected = i128::from(current_bytes) + i128::from(bytes_delta);
            projected > 0 && projected > i128::from(limit)
        }
    }
}
/// Resolves the storage limit that applies to `user`, expressed in bytes.
///
/// The user's `storage_quota_mb` setting is resolved against
/// `default_storage_mb` via `resolve_with_default`; a resulting megabyte
/// value is converted to bytes with a saturating multiplication. `None` is
/// returned when the resolution yields no value.
pub(crate) fn resolve_storage_max_bytes(
    user: &crate::persistence::sql::user::UserEntity,
    default_storage_mb: Option<u64>,
) -> Option<u64> {
    const BYTES_PER_MB: u64 = 1024 * 1024;

    let limit_mb = user
        .quota()
        .storage_quota_mb
        .resolve_with_default(default_storage_mb);
    limit_mb.map(|mb| mb.saturating_mul(BYTES_PER_MB))
}
impl<A: Access> Layer<A> for UserQuotaLayer {
    type LayeredAccess = UserQuotaAccessor<A>;

    /// Wraps `inner` in a `UserQuotaAccessor` that shares this layer's user
    /// service and default storage limit.
    fn layer(&self, inner: A) -> Self::LayeredAccess {
        let Self {
            user_service,
            default_storage_mb,
        } = self;
        UserQuotaAccessor {
            inner: Arc::new(inner),
            user_service: user_service.clone(),
            default_storage_mb: *default_storage_mb,
        }
    }
}
/// Accessor created by `UserQuotaLayer::layer`; it wraps another accessor and
/// hands out quota-aware writers and deleters.
#[derive(Debug, Clone)]
pub struct UserQuotaAccessor<A: Access> {
/// The wrapped accessor, shared with every writer/deleter this accessor creates.
inner: Arc<A>,
/// Service used to load and update users.
user_service: UserService,
/// Optional fallback storage limit in megabytes, copied from the layer.
default_storage_mb: Option<u64>,
}
impl<A: Access> LayeredAccess for UserQuotaAccessor<A> {
    type Inner = A;
    type Reader = A::Reader;
    type Writer = WriterWrapper<A::Writer, A>;
    type Lister = A::Lister;
    type Deleter = DeleterWrapper<A::Deleter, A>;

    /// Returns the wrapped accessor.
    fn inner(&self) -> &Self::Inner {
        &*self.inner
    }

    /// Creates a directory after checking that `path` parses as an entry path.
    async fn create_dir(&self, path: &str, args: OpCreateDir) -> Result<RpCreateDir> {
        let target = EntryPath::parse_opendal(path)?;
        self.inner.create_dir(target.as_str(), args).await
    }

    /// Forwards reads untouched; the path is not parsed here.
    async fn read(&self, path: &str, args: OpRead) -> Result<(RpRead, Self::Reader)> {
        self.inner.read(path, args).await
    }

    /// Opens a writer on the parsed entry path and wraps it so the written
    /// bytes are counted and checked against the quota when it is closed.
    async fn write(&self, path: &str, args: OpWrite) -> Result<(RpWrite, Self::Writer)> {
        let entry_path = EntryPath::parse_opendal(path)?;
        let (reply, inner_writer) = self.inner.write(&entry_path.to_string(), args).await?;
        let wrapped = WriterWrapper {
            inner: inner_writer,
            user_service: self.user_service.clone(),
            bytes_count: 0,
            entry_path,
            inner_accessor: Arc::clone(&self.inner),
            default_storage_mb: self.default_storage_mb,
        };
        Ok((reply, wrapped))
    }

    /// Copies between two paths, both of which must parse as entry paths.
    async fn copy(&self, from: &str, to: &str, args: OpCopy) -> Result<RpCopy> {
        let source = EntryPath::parse_opendal(from)?;
        let destination = EntryPath::parse_opendal(to)?;
        self.inner
            .copy(source.as_str(), destination.as_str(), args)
            .await
    }

    /// Renames between two paths, both of which must parse as entry paths.
    async fn rename(&self, from: &str, to: &str, args: OpRename) -> Result<RpRename> {
        let source = EntryPath::parse_opendal(from)?;
        let destination = EntryPath::parse_opendal(to)?;
        self.inner
            .rename(source.as_str(), destination.as_str(), args)
            .await
    }

    /// Forwards stat untouched; the path is not parsed here.
    async fn stat(&self, path: &str, args: OpStat) -> Result<RpStat> {
        self.inner.stat(path, args).await
    }

    /// Creates a deleter wrapped so that flushed deletions are subtracted
    /// from the owning users' usage.
    async fn delete(&self) -> Result<(RpDelete, Self::Deleter)> {
        let (reply, inner_deleter) = self.inner.delete().await?;
        let wrapped = DeleterWrapper {
            inner: inner_deleter,
            user_service: self.user_service.clone(),
            inner_accessor: Arc::clone(&self.inner),
            path_queue: Vec::new(),
        };
        Ok((reply, wrapped))
    }

    /// Forwards list untouched; the path is not parsed here.
    async fn list(&self, path: &str, args: OpList) -> Result<(RpList, Self::Lister)> {
        self.inner.list(path, args).await
    }

    /// Presigns an operation after checking that `path` parses as an entry path.
    async fn presign(&self, path: &str, args: OpPresign) -> Result<RpPresign> {
        let target = EntryPath::parse_opendal(path)?;
        self.inner.presign(target.as_str(), args).await
    }
}
/// Writer wrapper that counts the bytes passed through it and settles the
/// user's storage usage when the write is closed.
pub struct WriterWrapper<R, A: Access> {
/// The wrapped writer that actually stores the data.
inner: R,
/// Service used to lock, read and update the owning user.
user_service: UserService,
/// Running total of bytes handed to `write` so far.
bytes_count: u64,
/// Parsed path of the file being written; also identifies the owning user.
entry_path: EntryPath,
/// Accessor used to stat the file that is about to be replaced, if any.
inner_accessor: Arc<A>,
/// Optional fallback storage limit in megabytes.
default_storage_mb: Option<u64>,
}
impl<R, A: Access> WriterWrapper<R, A> {
    /// Stats the target path through the inner accessor.
    ///
    /// Returns `(content_length, true)` when the file is present and
    /// `(0, false)` when the stat fails with `NotFound`. Any other stat error
    /// is returned unchanged.
    async fn get_current_file_size(&self) -> Result<(u64, bool), opendal::Error> {
        let path = self.entry_path.to_string();
        match self.inner_accessor.stat(&path, OpStat::default()).await {
            Ok(reply) => Ok((reply.into_metadata().content_length(), true)),
            Err(err) if err.kind() == opendal::ErrorKind::NotFound => Ok((0, false)),
            Err(err) => Err(err),
        }
    }
}
impl<R: oio::Write, A: Access> oio::Write for WriterWrapper<R, A> {
async fn write(&mut self, bs: opendal::Buffer) -> Result<()> {
self.bytes_count += bs.len() as u64;
self.inner.write(bs).await
}
async fn abort(&mut self) -> Result<()> {
self.inner.abort().await
}
async fn close(&mut self) -> Result<opendal::Metadata> {
let mut tx = self
.user_service
.pool()
.begin()
.await
.map_err(|e| opendal::Error::new(opendal::ErrorKind::Unexpected, e.to_string()))?;
let mut user = self
.user_service
.get_for_update(self.entry_path.pubkey(), uexecutor!(tx))
.await
.map_err(|e| {
opendal::Error::new(
opendal::ErrorKind::Unexpected,
format!("Failed to get user {}: {}", self.entry_path.pubkey(), e),
)
})?;
let (current_file_size, file_already_exists) = self.get_current_file_size().await?;
let bytes_delta = if file_already_exists {
self.bytes_count as i64 - current_file_size as i64
} else {
self.bytes_count as i64 - current_file_size as i64 + FILE_METADATA_SIZE as i64
};
let max_bytes = resolve_storage_max_bytes(&user, self.default_storage_mb);
if would_exceed_limit(user.used_bytes, bytes_delta, max_bytes) {
return Err(opendal::Error::new(
opendal::ErrorKind::RateLimited,
"User quota exceeded",
)
.set_source(LayerDomainError::DiskSpaceQuotaExceeded));
}
let metadata = self.inner.close().await?;
user.used_bytes = user.used_bytes.saturating_add_signed(bytes_delta);
self.user_service
.update(&user, uexecutor!(tx))
.await
.map_err(|e| opendal::Error::new(opendal::ErrorKind::Unexpected, e.to_string()))?;
tx.commit()
.await
.map_err(|e| opendal::Error::new(opendal::ErrorKind::Unexpected, e.to_string()))?;
Ok(metadata)
}
}
/// A path queued for deletion together with what a stat revealed about it.
struct DeletePath {
/// Parsed path of the file to delete; also identifies the owning user.
entry_path: EntryPath,
/// Content length reported by stat; stays `None` until a stat finds the file.
bytes_count: Option<u64>,
/// `Some(true)` if the stat found the file, `Some(false)` if it reported
/// `NotFound`, `None` while no stat has completed.
exists: Option<bool>,
}
impl DeletePath {
    /// Parses `path` as an entry path and starts with no size or existence
    /// information. Fails if the path cannot be parsed.
    fn new(path: &str) -> anyhow::Result<Self> {
        let parsed = EntryPath::parse_opendal(path)?;
        Ok(Self {
            entry_path: parsed,
            bytes_count: None,
            exists: None,
        })
    }

    /// Stats the path through `operator` and records the outcome.
    ///
    /// Does nothing when a size is already recorded. A `NotFound` stat marks
    /// the path as non-existent and leaves the size unset; any other stat
    /// error is returned to the caller.
    pub async fn pull_bytes_count<A: Access>(&mut self, operator: &A) -> Result<()> {
        if self.bytes_count.is_none() {
            let outcome = operator
                .stat(self.entry_path.as_str(), OpStat::default())
                .await;
            match outcome {
                Ok(reply) => {
                    self.bytes_count = Some(reply.into_metadata().content_length());
                    self.exists = Some(true);
                }
                Err(err) if err.kind() == opendal::ErrorKind::NotFound => {
                    self.exists = Some(false);
                }
                Err(err) => return Err(err),
            }
        }
        Ok(())
    }
}
/// Deleter wrapper that remembers every queued path so the owning users'
/// storage usage can be reduced once the deletions are flushed.
pub struct DeleterWrapper<R, A: Access> {
/// The wrapped deleter that performs the actual deletions.
inner: R,
/// Service used to lock, read and update the owning users.
user_service: UserService,
/// Accessor used to stat queued paths before they are deleted.
inner_accessor: Arc<A>,
/// Paths passed to `delete` that have not yet been reported as flushed.
path_queue: Vec<DeletePath>,
}
impl<R, A: Access> DeleterWrapper<R, A> {
async fn update_user_quota(&self, deleted_paths: Vec<DeletePath>) -> Result<()> {
let mut user_paths: HashMap<PublicKey, Vec<DeletePath>> = HashMap::new();
for path in deleted_paths {
user_paths
.entry(path.entry_path.pubkey().clone())
.or_default()
.push(path);
}
for (user_pubkey, paths) in user_paths {
let mut tx =
self.user_service.pool().begin().await.map_err(|e| {
opendal::Error::new(opendal::ErrorKind::Unexpected, e.to_string())
})?;
let mut user = match self
.user_service
.get_for_update(&user_pubkey, uexecutor!(tx))
.await
{
Ok(user) => user,
Err(sqlx::Error::RowNotFound) => {
continue;
}
Err(e) => {
return Err(opendal::Error::new(
opendal::ErrorKind::Unexpected,
e.to_string(),
));
}
};
let total_bytes: u64 = paths.iter().filter_map(|p| p.bytes_count).sum();
let files_deleted_count =
paths.iter().filter(|p| p.exists.unwrap_or(false)).count() as u64;
let bytes_delta = (total_bytes + files_deleted_count * FILE_METADATA_SIZE) as i64;
user.used_bytes = user.used_bytes.saturating_add_signed(-bytes_delta);
self.user_service
.update(&user, uexecutor!(tx))
.await
.map_err(|e| opendal::Error::new(opendal::ErrorKind::Unexpected, e.to_string()))?;
tx.commit()
.await
.map_err(|e| opendal::Error::new(opendal::ErrorKind::Unexpected, e.to_string()))?;
}
Ok(())
}
}
impl<R: oio::Delete, A: Access> oio::Delete for DeleterWrapper<R, A> {
async fn flush(&mut self) -> Result<usize> {
for path in self.path_queue.iter_mut() {
path.pull_bytes_count(&self.inner_accessor).await?;
}
let deleted_files_count = self.inner.flush().await?;
let deleted_paths = self
.path_queue
.drain(0..deleted_files_count)
.collect::<Vec<_>>();
self.update_user_quota(deleted_paths).await?;
Ok(deleted_files_count)
}
fn delete(&mut self, path: &str, args: OpDelete) -> Result<()> {
let helper = match DeletePath::new(path) {
Ok(helper) => helper,
Err(e) => {
return Err(opendal::Error::new(
opendal::ErrorKind::PermissionDenied,
e.to_string(),
));
}
};
self.inner.delete(helper.entry_path.as_str(), args)?;
self.path_queue.push(helper);
Ok(())
}
}
#[cfg(test)]
mod tests {
use crate::persistence::files::opendal::opendal_test_operators::{
get_memory_operator, OpendalTestOperators,
};
use crate::persistence::sql::user::UserRepository;
use crate::persistence::sql::SqlDb;
use crate::shared::user_quota::UserQuota;
use super::*;
/// Builds a `UserQuotaLayer` backed by `db` with the given default quota (MB).
fn test_quota_layer(db: &SqlDb, default_quota_mb: Option<u64>) -> UserQuotaLayer {
    UserQuotaLayer::new(UserService::new(db.clone()), default_quota_mb)
}
async fn get_user_data_usage(db: &SqlDb, user_pubkey: &PublicKey) -> anyhow::Result<u64> {
let user = UserRepository::get(user_pubkey, &mut db.pool().into())
.await
.map_err(|e| opendal::Error::new(opendal::ErrorKind::Unexpected, e.to_string()))?;
Ok(user.used_bytes)
}
/// Writes must target a path that starts with a public key, while `stat`
/// is passed through without that validation.
#[tokio::test]
#[pubky_test_utils::test]
async fn test_ensure_valid_path() {
    for (_scheme, base_operator) in OpendalTestOperators::new().operators() {
        let db = SqlDb::test().await;
        let quota_layer = test_quota_layer(&db, None);
        let operator = base_operator.layer(quota_layer);

        // First path segment is not a public key.
        operator
            .write("1234567890/test.txt", vec![0; 10])
            .await
            .expect_err("Should fail because the path doesn't start with a pubkey");

        let pubkey = pubky_common::crypto::Keypair::random().public_key();
        let pubkey_raw = pubkey.z32();
        UserRepository::create(&pubkey, &mut db.pool().into())
            .await
            .unwrap();

        let valid_path = format!("{}/test.txt", pubkey_raw);
        operator
            .write(valid_path.as_str(), vec![0; 10])
            .await
            .expect("Should succeed because the path starts with a pubkey");

        // No public key segment at all.
        operator
            .write("test.txt", vec![0; 10])
            .await
            .expect_err("Should fail because the path doesn't start with a pubkey");

        operator
            .stat("/")
            .await
            .expect("stat on root should succeed");

        let stat_error = operator
            .stat("some_dir/")
            .await
            .expect_err("should fail because path doesn't exist");
        assert_eq!(stat_error.kind(), opendal::ErrorKind::NotFound);
    }
}
/// Usage follows writes, overwrites and deletes, including the per-file
/// `FILE_METADATA_SIZE` overhead.
#[tokio::test]
#[pubky_test_utils::test]
async fn test_quota_updated_write_delete() {
    let db = SqlDb::test().await;
    let quota_layer = test_quota_layer(&db, None);
    let operator = get_memory_operator().layer(quota_layer);

    let user_pubkey1 = pubky_common::crypto::Keypair::random().public_key();
    let user_pubkey1_raw = user_pubkey1.z32();
    UserRepository::create_with_quota_mb(&db, &user_pubkey1, 1).await;

    let first_file = format!("{}/test.txt1", user_pubkey1_raw);
    let second_file = format!("{}/test.txt2", user_pubkey1_raw);

    // New file: content plus metadata overhead.
    operator
        .write(first_file.as_str(), vec![0; 10])
        .await
        .unwrap();
    let usage = get_user_data_usage(&db, &user_pubkey1).await.unwrap();
    assert_eq!(usage, 10 + FILE_METADATA_SIZE);

    // Overwrite: only the content size changes.
    operator
        .write(first_file.as_str(), vec![0; 12])
        .await
        .unwrap();
    let usage = get_user_data_usage(&db, &user_pubkey1).await.unwrap();
    assert_eq!(usage, 12 + FILE_METADATA_SIZE);

    // Second file adds its own content and overhead.
    operator
        .write(second_file.as_str(), vec![0; 5])
        .await
        .unwrap();
    let usage = get_user_data_usage(&db, &user_pubkey1).await.unwrap();
    assert_eq!(usage, 17 + 2 * FILE_METADATA_SIZE);

    // Deleting the first file leaves only the second one accounted for.
    operator.delete(first_file.as_str()).await.unwrap();
    let usage = get_user_data_usage(&db, &user_pubkey1).await.unwrap();
    assert_eq!(usage, 5 + FILE_METADATA_SIZE);

    // Deleting the last file brings usage back to zero.
    operator.delete(second_file.as_str()).await.unwrap();
    let usage = get_user_data_usage(&db, &user_pubkey1).await.unwrap();
    assert_eq!(usage, 0);
}
/// A write over the quota leaves no file, usage, entry or event behind,
/// while a write that lands exactly on the limit succeeds.
#[tokio::test]
#[pubky_test_utils::test]
async fn test_quota_rechead() {
    use crate::persistence::files::entry::entry_layer::EntryLayer;
    use crate::persistence::files::events::{EventRepository, EventsLayer, EventsService};
    use crate::persistence::sql::entry::EntryRepository;
    use crate::shared::webdav::{EntryPath, WebDavPath};

    let db = SqlDb::test().await;
    let events_service = EventsService::new(100);
    let quota_layer = test_quota_layer(&db, None);
    let entries = EntryLayer::new(db.clone());
    let events = EventsLayer::new(db.clone(), events_service);
    let operator = get_memory_operator()
        .layer(quota_layer)
        .layer(entries)
        .layer(events);

    let user_pubkey1 = pubky_common::crypto::Keypair::random().public_key();
    let user_pubkey1_raw = user_pubkey1.z32();
    UserRepository::create_with_quota_mb(&db, &user_pubkey1, 1).await;

    // Largest content that still fits a 1 MB quota with the metadata overhead.
    let quota_bytes: usize = 1024 * 1024;
    let max_content = quota_bytes - FILE_METADATA_SIZE as usize;

    let file_name1 = format!("{}/test1.txt", user_pubkey1_raw);
    let entry_path1 =
        EntryPath::new(user_pubkey1.clone(), WebDavPath::new("/test1.txt").unwrap());

    // One byte too many: rejected, and nothing is persisted anywhere.
    operator
        .write(file_name1.as_str(), vec![0; max_content + 1])
        .await
        .expect_err("Should fail because the user quota is exceeded");
    operator
        .read(file_name1.as_str())
        .await
        .expect_err("Should fail because the file doesn't exist");
    let usage = get_user_data_usage(&db, &user_pubkey1).await.unwrap();
    assert_eq!(usage, 0);
    EntryRepository::get_by_path(&entry_path1, &mut db.pool().into())
        .await
        .expect_err("Entry should not exist because quota was exceeded");
    let recorded = EventRepository::get_by_cursor(None, Some(9999), &mut db.pool().into())
        .await
        .expect("Should succeed");
    assert_eq!(
        recorded.len(),
        0,
        "No events should be created when quota is exceeded"
    );

    // Exactly at the limit: accepted and fully recorded.
    operator
        .write(file_name1.as_str(), vec![0; max_content])
        .await
        .expect("Should succeed because the user quota is exactly the limit");
    operator
        .read(file_name1.as_str())
        .await
        .expect("Should succeed because the file exists");
    let usage = get_user_data_usage(&db, &user_pubkey1).await.unwrap();
    assert_eq!(usage, max_content as u64 + FILE_METADATA_SIZE);
    let stored_entry = EntryRepository::get_by_path(&entry_path1, &mut db.pool().into())
        .await
        .expect("Entry should exist after successful write");
    assert_eq!(stored_entry.content_length as usize, max_content);
    let recorded = EventRepository::get_by_cursor(None, Some(9999), &mut db.pool().into())
        .await
        .expect("Should succeed");
    assert_eq!(
        recorded.len(),
        1,
        "Event should be created after successful write"
    );

    // The quota is now full, so even a single extra byte is refused.
    let file_name2 = format!("{}/test2.txt", user_pubkey1_raw);
    operator
        .write(file_name2.as_str(), vec![0; 1])
        .await
        .expect_err("Should fail because the user quota is exceeded");
}
/// Exercises every way a storage limit can be determined: an explicit
/// per-user value, a zero value, the system default, no default at all,
/// and an explicit `Unlimited` override.
#[tokio::test]
#[pubky_test_utils::test]
async fn test_quota_override_variants() {
    use crate::shared::user_quota::QuotaOverride;

    let db = SqlDb::test().await;
    let quota_layer = test_quota_layer(&db, Some(1));
    let operator = get_memory_operator().layer(quota_layer);

    // Explicit per-user value of 1 MB.
    let pk_value = pubky_common::crypto::Keypair::random().public_key();
    let raw_value = pk_value.z32();
    UserRepository::create_with_quota_mb(&db, &pk_value, 1).await;
    let value_small = format!("{raw_value}/small.txt");
    operator
        .write(value_small.as_str(), vec![0; 31])
        .await
        .expect("Value(1): small write within 1 MB");
    let usage = get_user_data_usage(&db, &pk_value).await.unwrap();
    assert_eq!(usage, 31 + FILE_METADATA_SIZE);
    let value_huge = format!("{raw_value}/huge.txt");
    operator
        .write(value_huge.as_str(), vec![0; 1_048_576])
        .await
        .expect_err("Value(1): >1 MB write should fail");

    // Explicit per-user value of 0 MB.
    let pk_zero = pubky_common::crypto::Keypair::random().public_key();
    let raw_zero = pk_zero.z32();
    UserRepository::create_with_quota_mb(&db, &pk_zero, 0).await;
    let zero_file = format!("{raw_zero}/file.txt");
    operator
        .write(zero_file.as_str(), vec![0; 1])
        .await
        .expect_err("Value(0): even 1 byte should fail");

    // No per-user setting: the 1 MB system default applies.
    let pk_default = pubky_common::crypto::Keypair::random().public_key();
    let raw_default = pk_default.z32();
    UserRepository::create(&pk_default, &mut db.pool().into())
        .await
        .unwrap();
    let default_small = format!("{raw_default}/small.txt");
    operator
        .write(default_small.as_str(), vec![0; 31])
        .await
        .expect("Default: small write within 1 MB system default");
    let default_huge = format!("{raw_default}/huge.txt");
    operator
        .write(default_huge.as_str(), vec![0; 1_048_576])
        .await
        .expect_err("Default: >1 MB write should fail against system default");

    // No per-user setting and no system default, on a separate database.
    let db2 = SqlDb::test().await;
    let layer_no_default = test_quota_layer(&db2, None);
    let op_no_default = get_memory_operator().layer(layer_no_default);
    let pk_no_default = pubky_common::crypto::Keypair::random().public_key();
    let raw_no_default = pk_no_default.z32();
    UserRepository::create(&pk_no_default, &mut db2.pool().into())
        .await
        .unwrap();
    let no_default_big = format!("{raw_no_default}/big.txt");
    op_no_default
        .write(no_default_big.as_str(), vec![0; 2 * 1024 * 1024])
        .await
        .expect("Default + no system default = unlimited");

    // Explicit `Unlimited` override on the database with a 1 MB default.
    let pk_unlimited = pubky_common::crypto::Keypair::random().public_key();
    let raw_unlimited = pk_unlimited.z32();
    let unlimited_user = UserRepository::create(&pk_unlimited, &mut db.pool().into())
        .await
        .unwrap();
    let unlimited_quota = UserQuota {
        storage_quota_mb: QuotaOverride::Unlimited,
        ..Default::default()
    };
    UserRepository::set_quota(unlimited_user.id, &unlimited_quota, &mut db.pool().into())
        .await
        .unwrap();
    let unlimited_big = format!("{raw_unlimited}/big.txt");
    operator
        .write(unlimited_big.as_str(), vec![0; 2 * 1024 * 1024])
        .await
        .expect("Unlimited: bypasses 1 MB system default");
}
/// Raising a user's quota in the database is honoured by the next write.
#[tokio::test]
#[pubky_test_utils::test]
async fn test_storage_quota_change_takes_effect() {
    use crate::shared::user_quota::QuotaOverride;

    let db = SqlDb::test().await;
    let quota_layer = test_quota_layer(&db, None);
    let operator = get_memory_operator().layer(quota_layer);

    let user_pubkey = pubky_common::crypto::Keypair::random().public_key();
    let user_raw = user_pubkey.z32();
    let user = UserRepository::create_with_quota_mb(&db, &user_pubkey, 0).await;
    let file_path = format!("{user_raw}/file.txt");

    // With a zero quota nothing can be written.
    operator
        .write(file_path.as_str(), vec![0; 10])
        .await
        .expect_err("Should fail: zero quota");

    // Raise the quota to 1 MB and retry the very same write.
    let raised_quota = UserQuota {
        storage_quota_mb: QuotaOverride::Value(1),
        ..Default::default()
    };
    UserRepository::set_quota(user.id, &raised_quota, &mut db.pool().into())
        .await
        .unwrap();
    operator
        .write(file_path.as_str(), vec![0; 10])
        .await
        .expect("Should succeed after quota increase");
}
/// Without a limit nothing is ever rejected, even at the numeric extremes.
#[test]
fn no_limit_never_exceeds() {
    let exceeded = would_exceed_limit(u64::MAX, i64::MAX, None);
    assert!(!exceeded);
}

/// Landing exactly on the limit is still allowed.
#[test]
fn exactly_at_limit_does_not_exceed() {
    let exceeded = would_exceed_limit(500, 500, Some(1000));
    assert!(!exceeded);
}

/// A single byte past the limit is rejected.
#[test]
fn one_byte_over_limit_exceeds() {
    let exceeded = would_exceed_limit(500, 501, Some(1000));
    assert!(exceeded);
}

/// A negative delta lowers the projected usage.
#[test]
fn negative_delta_shrinks_usage() {
    let exceeded = would_exceed_limit(1000, -500, Some(1000));
    assert!(!exceeded);
}

/// A projected usage below zero never counts as exceeding.
#[test]
fn negative_delta_below_zero_does_not_exceed() {
    let exceeded = would_exceed_limit(100, -200, Some(50));
    assert!(!exceeded);
}

/// With a zero limit any growth is rejected.
#[test]
fn zero_limit_any_positive_delta_exceeds() {
    let exceeded = would_exceed_limit(0, 1, Some(0));
    assert!(exceeded);
}

/// With a zero limit an unchanged, empty usage is still fine.
#[test]
fn zero_limit_zero_delta_does_not_exceed() {
    let exceeded = would_exceed_limit(0, 0, Some(0));
    assert!(!exceeded);
}

/// The extremes of both integer types do not overflow the computation.
#[test]
fn large_current_with_large_negative_delta() {
    let exceeded = would_exceed_limit(u64::MAX, i64::MIN, Some(u64::MAX));
    assert!(!exceeded);
}

/// The boundary behaves correctly at the very top of the `u64` range.
#[test]
fn large_current_near_limit() {
    let limit = u64::MAX;
    let unchanged = would_exceed_limit(limit, 0, Some(limit));
    let one_more = would_exceed_limit(limit, 1, Some(limit));
    assert!(!unchanged);
    assert!(one_more);
}
}