use common::{
DakeraError, NamespaceId, QuotaCheckResult, QuotaConfig, QuotaEnforcement, QuotaStatus,
QuotaUsage, Result, Vector,
};
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
/// Tracks quota configuration and current usage per namespace.
///
/// Each piece of state sits behind its own `RwLock`, so every method takes
/// `&self`.
pub struct QuotaManager {
/// Quota configuration registered for individual namespaces.
configs: RwLock<HashMap<NamespaceId, QuotaConfig>>,
/// Last recorded usage counters for individual namespaces.
usage: RwLock<HashMap<NamespaceId, QuotaUsage>>,
/// Fallback configuration used when a namespace has no entry in `configs`.
default_config: RwLock<Option<QuotaConfig>>,
}
impl QuotaManager {
/// Creates a manager with no namespace quotas, no recorded usage and no
/// default quota configuration.
pub fn new() -> Self {
    let configs = RwLock::new(HashMap::new());
    let usage = RwLock::new(HashMap::new());
    let default_config = RwLock::new(None);
    Self {
        configs,
        usage,
        default_config,
    }
}
/// Creates an empty manager whose fallback quota is `default_config`.
pub fn with_default(default_config: QuotaConfig) -> Self {
    let fallback = RwLock::new(Some(default_config));
    let configs = RwLock::new(HashMap::new());
    let usage = RwLock::new(HashMap::new());
    Self {
        configs,
        usage,
        default_config: fallback,
    }
}
/// Replaces the fallback quota configuration; `None` clears it.
pub fn set_default_config(&self, config: Option<QuotaConfig>) {
    let mut slot = self.default_config.write();
    *slot = config;
}
/// Returns a copy of the fallback quota configuration, if one is set.
pub fn get_default_config(&self) -> Option<QuotaConfig> {
    let current = self.default_config.read();
    current.clone()
}
/// Stores `config` as the quota for `namespace`, replacing any previous
/// entry, and logs the update.
pub fn set_quota(&self, namespace: &NamespaceId, config: QuotaConfig) {
    {
        // Scoped so the write lock is released before logging.
        let mut configs = self.configs.write();
        configs.insert(namespace.clone(), config);
    }
    tracing::info!(namespace = %namespace, "Quota configuration updated");
}
/// Returns the effective quota for `namespace`: its own configuration when
/// one is registered, otherwise the default configuration (if any).
pub fn get_quota(&self, namespace: &NamespaceId) -> Option<QuotaConfig> {
    let configs = self.configs.read();
    match configs.get(namespace) {
        Some(specific) => Some(specific.clone()),
        None => self.default_config.read().clone(),
    }
}
/// Removes the namespace-specific quota and returns it if one was present.
pub fn remove_quota(&self, namespace: &NamespaceId) -> Option<QuotaConfig> {
    let mut configs = self.configs.write();
    configs.remove(namespace)
}
/// Records the initial usage counters for `namespace` and, when the
/// namespace has no quota of its own, copies the current default quota (if
/// any) into its configuration.
///
/// An existing namespace-specific quota is never overwritten.
pub fn init_usage(&self, namespace: &NamespaceId, vector_count: u64, storage_bytes: u64) {
    let usage = QuotaUsage::new(vector_count, storage_bytes);
    self.usage.write().insert(namespace.clone(), usage);

    // Fast path: nothing to seed when the namespace already has a quota.
    if self.configs.read().contains_key(namespace) {
        return;
    }

    // Snapshot the default and release its lock *before* locking `configs`.
    // `get_quota` acquires `configs` first and `default_config` second, so
    // holding `default_config` while waiting for `configs` would invert the
    // lock order.
    let default = self.default_config.read().clone();
    if let Some(default) = default {
        // `entry` makes the "insert only if absent" step atomic, so a quota
        // set concurrently between the check above and here is preserved.
        self.configs
            .write()
            .entry(namespace.clone())
            .or_insert(default);
    }
}
/// Returns a copy of the usage counters recorded for `namespace`, if any.
pub fn get_usage(&self, namespace: &NamespaceId) -> Option<QuotaUsage> {
    let tracked = self.usage.read();
    tracked.get(namespace).cloned()
}
/// Overwrites the usage counters for `namespace` with absolute values,
/// creating the entry when the namespace is not tracked yet.
pub fn update_usage(&self, namespace: &NamespaceId, vector_count: u64, storage_bytes: u64) {
    let mut tracked = self.usage.write();
    match tracked.get_mut(namespace) {
        Some(entry) => {
            entry.vector_count = vector_count;
            entry.storage_bytes = storage_bytes;
            entry.touch();
        }
        None => {
            let fresh = QuotaUsage::new(vector_count, storage_bytes);
            tracked.insert(namespace.clone(), fresh);
        }
    }
}
/// Adds `added_vectors` / `added_bytes` to the usage counters of
/// `namespace`, creating the entry when the namespace is not tracked yet.
///
/// Counters saturate at `u64::MAX` instead of overflowing, mirroring the
/// saturating arithmetic used by `decrement_usage`.
pub fn increment_usage(&self, namespace: &NamespaceId, added_vectors: u64, added_bytes: u64) {
    let mut usage_map = self.usage.write();
    match usage_map.get_mut(namespace) {
        Some(usage) => {
            usage.vector_count = usage.vector_count.saturating_add(added_vectors);
            usage.storage_bytes = usage.storage_bytes.saturating_add(added_bytes);
            usage.touch();
        }
        None => {
            usage_map.insert(
                namespace.clone(),
                QuotaUsage::new(added_vectors, added_bytes),
            );
        }
    }
}
/// Subtracts `removed_vectors` / `removed_bytes` from the usage counters of
/// `namespace`, clamping at zero. Untracked namespaces are left untouched.
pub fn decrement_usage(
    &self,
    namespace: &NamespaceId,
    removed_vectors: u64,
    removed_bytes: u64,
) {
    // The write guard lives for the whole `if let`, so the update is atomic.
    if let Some(entry) = self.usage.write().get_mut(namespace) {
        let vectors_left = entry.vector_count.saturating_sub(removed_vectors);
        let bytes_left = entry.storage_bytes.saturating_sub(removed_bytes);
        entry.vector_count = vectors_left;
        entry.storage_bytes = bytes_left;
        entry.touch();
    }
}
/// Drops the usage counters recorded for `namespace`, if any.
pub fn remove_usage(&self, namespace: &NamespaceId) {
    let mut tracked = self.usage.write();
    tracked.remove(namespace);
}
/// Builds the quota status for `namespace`.
///
/// Returns `None` when neither a namespace-specific nor a default quota
/// applies. Missing usage is reported as `QuotaUsage::default()`.
pub fn get_status(&self, namespace: &NamespaceId) -> Option<QuotaStatus> {
    self.get_quota(namespace).map(|config| {
        let usage = self.get_usage(namespace).unwrap_or_default();
        QuotaStatus::new(namespace.clone(), config, usage)
    })
}
/// Builds the quota status of every known namespace.
///
/// A namespace is included when it has its own quota, or when it has
/// recorded usage and a default quota exists. Namespaces with usage but no
/// applicable quota are skipped. The order of the returned vector is
/// unspecified.
pub fn get_all_status(&self) -> Vec<QuotaStatus> {
    // Snapshot the default once (and release its lock) instead of locking
    // and cloning it again for every namespace.
    let default = self.default_config.read().clone();
    let configs = self.configs.read();
    let usage = self.usage.read();

    // Namespaces with an explicit quota, paired with their usage (if any).
    let mut statuses: Vec<QuotaStatus> = configs
        .iter()
        .map(|(ns, config)| {
            let used = usage.get(ns).cloned().unwrap_or_default();
            QuotaStatus::new(ns.clone(), config.clone(), used)
        })
        .collect();

    // Namespaces that only have usage fall back to the default quota.
    if let Some(default) = default {
        statuses.extend(
            usage
                .iter()
                .filter(|(ns, _)| !configs.contains_key(*ns))
                .map(|(ns, used)| QuotaStatus::new(ns.clone(), default.clone(), used.clone())),
        );
    }

    statuses
}
/// Estimates the storage footprint, in bytes, of `vectors`.
///
/// Per vector the estimate adds the id length, four bytes per value, the
/// length of the JSON-serialized metadata (0 when absent or when
/// serialization fails), and two fixed allowances of 16 and 64 bytes.
pub fn estimate_storage_size(vectors: &[Vector]) -> u64 {
    const TTL_ALLOWANCE: u64 = 16;
    const FIXED_ALLOWANCE: u64 = 64;

    let estimate_one = |vector: &Vector| -> u64 {
        let id_bytes = vector.id.len() as u64;
        let value_bytes = (vector.values.len() * std::mem::size_of::<f32>()) as u64;
        let metadata_len = match &vector.metadata {
            Some(meta) => serde_json::to_string(meta)
                .map(|json| json.len())
                .unwrap_or(0),
            None => 0,
        };
        let metadata_bytes = metadata_len as u64;
        id_bytes + value_bytes + metadata_bytes + TTL_ALLOWANCE + FIXED_ALLOWANCE
    };

    vectors.iter().map(estimate_one).sum()
}
pub fn check_upsert(&self, namespace: &NamespaceId, vectors: &[Vector]) -> QuotaCheckResult {
let config = match self.get_quota(namespace) {
Some(c) => c,
None => {
return QuotaCheckResult {
allowed: true,
reason: None,
usage: self.get_usage(namespace).unwrap_or_default(),
exceeded_quota: None,
};
}
};
if config.enforcement == QuotaEnforcement::None {
return QuotaCheckResult {
allowed: true,
reason: None,
usage: self.get_usage(namespace).unwrap_or_default(),
exceeded_quota: None,
};
}
let usage = self.get_usage(namespace).unwrap_or_default();
let new_vectors = vectors.len() as u64;
let new_bytes = Self::estimate_storage_size(vectors);
if let Some(max_vectors) = config.max_vectors {
let projected = usage.vector_count + new_vectors;
if projected > max_vectors {
let reason = format!(
"Would exceed max vectors: {} + {} = {} > {}",
usage.vector_count, new_vectors, projected, max_vectors
);
if config.enforcement == QuotaEnforcement::Hard {
return QuotaCheckResult {
allowed: false,
reason: Some(reason),
usage,
exceeded_quota: Some("max_vectors".to_string()),
};
} else {
tracing::warn!(
namespace = %namespace,
reason = %reason,
"Soft quota exceeded"
);
}
}
}
if let Some(max_storage) = config.max_storage_bytes {
let projected = usage.storage_bytes + new_bytes;
if projected > max_storage {
let reason = format!(
"Would exceed max storage: {} + {} = {} > {} bytes",
usage.storage_bytes, new_bytes, projected, max_storage
);
if config.enforcement == QuotaEnforcement::Hard {
return QuotaCheckResult {
allowed: false,
reason: Some(reason),
usage,
exceeded_quota: Some("max_storage_bytes".to_string()),
};
} else {
tracing::warn!(
namespace = %namespace,
reason = %reason,
"Soft quota exceeded"
);
}
}
}
if let Some(max_dim) = config.max_dimensions {
for v in vectors {
if v.values.len() > max_dim {
let reason = format!(
"Vector '{}' exceeds max dimensions: {} > {}",
v.id,
v.values.len(),
max_dim
);
if config.enforcement == QuotaEnforcement::Hard {
return QuotaCheckResult {
allowed: false,
reason: Some(reason),
usage,
exceeded_quota: Some("max_dimensions".to_string()),
};
} else {
tracing::warn!(
namespace = %namespace,
vector_id = %v.id,
reason = %reason,
"Soft quota exceeded"
);
}
}
}
}
if let Some(max_meta) = config.max_metadata_bytes {
for v in vectors {
if let Some(meta) = &v.metadata {
let meta_size = serde_json::to_string(meta).map(|s| s.len()).unwrap_or(0);
if meta_size > max_meta {
let reason = format!(
"Vector '{}' exceeds max metadata size: {} > {} bytes",
v.id, meta_size, max_meta
);
if config.enforcement == QuotaEnforcement::Hard {
return QuotaCheckResult {
allowed: false,
reason: Some(reason),
usage,
exceeded_quota: Some("max_metadata_bytes".to_string()),
};
} else {
tracing::warn!(
namespace = %namespace,
vector_id = %v.id,
reason = %reason,
"Soft quota exceeded"
);
}
}
}
}
}
QuotaCheckResult {
allowed: true,
reason: None,
usage,
exceeded_quota: None,
}
}
/// Runs `check_upsert` and turns a denied result into an error.
///
/// # Errors
///
/// Returns `DakeraError::QuotaExceeded` carrying the namespace and the
/// denial reason (or "Quota exceeded" when no reason was given).
pub fn enforce_upsert(&self, namespace: &NamespaceId, vectors: &[Vector]) -> Result<()> {
    let outcome = self.check_upsert(namespace, vectors);
    if outcome.allowed {
        return Ok(());
    }
    let reason = outcome
        .reason
        .unwrap_or_else(|| "Quota exceeded".to_string());
    Err(DakeraError::QuotaExceeded {
        namespace: namespace.clone(),
        reason,
    })
}
}
impl Default for QuotaManager {
fn default() -> Self {
Self::new()
}
}
/// A [`QuotaManager`] wrapped in an [`Arc`] so one instance can be shared.
pub type SharedQuotaManager = Arc<QuotaManager>;
#[cfg(test)]
mod tests {
    use super::*;

    /// Namespace used by every test.
    fn test_namespace() -> NamespaceId {
        "test".to_string()
    }

    /// Builds a vector without metadata or expiry information.
    fn plain_vector(id: &str, values: Vec<f32>) -> Vector {
        Vector {
            id: id.to_string(),
            values,
            metadata: None,
            ttl_seconds: None,
            expires_at: None,
        }
    }

    /// Without any quota configured, upserts are allowed.
    #[test]
    fn test_quota_manager_basic() {
        let manager = QuotaManager::new();
        let namespace = test_namespace();
        let batch = vec![plain_vector("v1", vec![1.0, 2.0, 3.0])];

        let check = manager.check_upsert(&namespace, &batch);
        assert!(check.allowed);
    }

    /// A hard vector limit admits a batch that fits and rejects one that
    /// does not.
    #[test]
    fn test_quota_enforcement_hard() {
        let manager = QuotaManager::new();
        let namespace = test_namespace();
        let config = QuotaConfig {
            max_vectors: Some(2),
            max_storage_bytes: None,
            max_dimensions: None,
            max_metadata_bytes: None,
            enforcement: QuotaEnforcement::Hard,
        };
        manager.set_quota(&namespace, config);
        manager.init_usage(&namespace, 1, 100);

        // 1 existing + 1 new = 2, exactly at the limit.
        let fits = vec![plain_vector("v1", vec![1.0, 2.0])];
        let check = manager.check_upsert(&namespace, &fits);
        assert!(check.allowed);

        // 1 existing + 2 new = 3, over the limit.
        let too_many = vec![
            plain_vector("v2", vec![1.0, 2.0]),
            plain_vector("v3", vec![1.0, 2.0]),
        ];
        let check = manager.check_upsert(&namespace, &too_many);
        assert!(!check.allowed);
        assert_eq!(check.exceeded_quota, Some("max_vectors".to_string()));
    }

    /// A soft limit still allows the upsert when exceeded.
    #[test]
    fn test_quota_enforcement_soft() {
        let manager = QuotaManager::new();
        let namespace = test_namespace();
        let config = QuotaConfig {
            max_vectors: Some(1),
            max_storage_bytes: None,
            max_dimensions: None,
            max_metadata_bytes: None,
            enforcement: QuotaEnforcement::Soft,
        };
        manager.set_quota(&namespace, config);
        manager.init_usage(&namespace, 1, 100);

        let batch = vec![plain_vector("v1", vec![1.0, 2.0])];
        let check = manager.check_upsert(&namespace, &batch);
        assert!(check.allowed);
    }

    /// A vector with more values than `max_dimensions` is rejected under
    /// hard enforcement.
    #[test]
    fn test_dimension_quota() {
        let manager = QuotaManager::new();
        let namespace = test_namespace();
        let config = QuotaConfig {
            max_vectors: None,
            max_storage_bytes: None,
            max_dimensions: Some(3),
            max_metadata_bytes: None,
            enforcement: QuotaEnforcement::Hard,
        };
        manager.set_quota(&namespace, config);

        let batch = vec![plain_vector("v1", vec![1.0, 2.0, 3.0, 4.0])];
        let check = manager.check_upsert(&namespace, &batch);
        assert!(!check.allowed);
        assert_eq!(check.exceeded_quota, Some("max_dimensions".to_string()));
    }

    /// Status reports usage percentages relative to the configured limits.
    #[test]
    fn test_quota_status() {
        let manager = QuotaManager::new();
        let namespace = test_namespace();
        let config = QuotaConfig {
            max_vectors: Some(100),
            max_storage_bytes: Some(1_000_000),
            max_dimensions: None,
            max_metadata_bytes: None,
            enforcement: QuotaEnforcement::Hard,
        };
        manager.set_quota(&namespace, config);
        manager.init_usage(&namespace, 50, 500_000);

        let status = manager.get_status(&namespace).unwrap();
        assert_eq!(status.namespace, namespace);
        assert!(!status.is_exceeded);
        assert_eq!(status.vector_usage_percent, Some(50.0));
        assert_eq!(status.storage_usage_percent, Some(50.0));
    }
}