use std::time::Instant;
use serde::{Deserialize, Serialize};
use super::isolation::{get_isolation_manager, QueryRoute};
use super::quotas::{get_quota_manager, QuotaResult};
use super::registry::{get_registry, TenantConfig, TenantError};
use super::validation::{escape_string_literal, validate_ip_address, validate_tenant_id};
/// Outcome of a tenant-scoped operation.
///
/// Keeps quota rejections and tenant errors distinct from other failures so callers
/// can react to each case separately.
#[derive(Debug, Clone)]
pub enum OperationResult<T> {
    /// The operation completed and produced a value.
    Success(T),
    /// The quota manager refused the operation; carries its verdict.
    QuotaDenied(QuotaResult),
    /// A tenant registry / context error.
    TenantError(TenantError),
    /// Any other failure, described by a message.
    Error(String),
}

impl<T> OperationResult<T> {
    /// Returns `true` only for [`OperationResult::Success`].
    pub fn is_success(&self) -> bool {
        matches!(self, Self::Success(_))
    }

    /// Returns the success value.
    ///
    /// # Panics
    ///
    /// Panics for every non-`Success` variant, including the quota verdict or error in
    /// the message. `#[track_caller]` makes the panic report the caller's location, as
    /// `Option::unwrap` does, instead of a line inside this module.
    #[track_caller]
    pub fn unwrap(self) -> T {
        match self {
            Self::Success(v) => v,
            Self::QuotaDenied(q) => panic!("Quota denied: {:?}", q),
            Self::TenantError(e) => panic!("Tenant error: {}", e),
            Self::Error(e) => panic!("Operation error: {}", e),
        }
    }

    /// Converts into a `Result` with a human-readable `String` error.
    ///
    /// For `QuotaDenied` the message is the verdict's `error_message()`, or
    /// `"Quota denied"` when the verdict has none.
    pub fn into_result(self) -> Result<T, String> {
        match self {
            Self::Success(v) => Ok(v),
            Self::QuotaDenied(q) => Err(q
                .error_message()
                .unwrap_or_else(|| "Quota denied".to_string())),
            Self::TenantError(e) => Err(e.to_string()),
            Self::Error(e) => Err(e),
        }
    }
}
/// Tenant scope resolved for the current operation.
#[derive(Debug, Clone)]
pub struct TenantContext {
    /// Tenant identifier; `"*"` for the wildcard admin context.
    pub tenant_id: String,
    /// Registry configuration for the tenant (a fresh `TenantConfig::new("*")` for the
    /// wildcard context).
    pub config: TenantConfig,
    /// Query routing for the `embeddings` table, resolved when the context is built.
    pub route: QueryRoute,
    /// `true` only for the wildcard context.
    pub is_admin: bool,
}
/// A tenant identifier that has passed `validate_tenant_id`.
#[derive(Debug, Clone)]
pub struct ValidatedTenantId(String);

impl ValidatedTenantId {
    /// Validates `tenant_id` and wraps an owned copy of it.
    ///
    /// # Errors
    ///
    /// `TenantError::InvalidId` carrying the validator's message when validation fails.
    pub fn new(tenant_id: &str) -> Result<Self, TenantError> {
        match validate_tenant_id(tenant_id) {
            Ok(_) => Ok(Self(tenant_id.to_owned())),
            Err(reason) => Err(TenantError::InvalidId(reason.to_string())),
        }
    }

    /// Borrows the validated identifier.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
impl TenantContext {
    /// Builds the context for the tenant bound to the current execution.
    ///
    /// The wildcard id `"*"` yields an admin context: a default `TenantConfig` and a
    /// shared-table route whose filter is `"true"`. Any other id follows the same path
    /// as [`TenantContext::for_tenant`].
    ///
    /// # Errors
    ///
    /// Propagates errors from `get_current_tenant_id` and the registry's
    /// `validate_context`.
    pub fn current() -> Result<Self, TenantError> {
        let tenant_id = get_current_tenant_id()?;
        if tenant_id != "*" {
            return Self::for_tenant(&tenant_id);
        }
        let wildcard = String::from("*");
        Ok(Self {
            config: TenantConfig::new(wildcard.clone()),
            route: QueryRoute::SharedWithFilter {
                table: String::new(),
                filter: "true".to_string(),
                tenant_param: None,
            },
            tenant_id: wildcard,
            is_admin: true,
        })
    }

    /// Builds a non-admin context for an explicit `tenant_id`.
    ///
    /// # Errors
    ///
    /// Propagates the registry's `validate_context` error.
    pub fn for_tenant(tenant_id: &str) -> Result<Self, TenantError> {
        let config = get_registry().validate_context(tenant_id)?;
        let routed = get_isolation_manager().route_query(tenant_id, "embeddings");
        Ok(Self {
            tenant_id: tenant_id.to_owned(),
            config,
            route: routed,
            is_admin: false,
        })
    }

    /// Table reference that `base_table` resolves to for this tenant.
    pub fn table_ref(&self, base_table: &str) -> String {
        // Bound to a local first so the manager handle is released before
        // `table_reference` runs.
        let routed = get_isolation_manager().route_query(&self.tenant_id, base_table);
        routed.table_reference()
    }

    /// Tenant `WHERE` predicate for `base_table`, if its route needs one.
    pub fn where_clause(&self, base_table: &str) -> Option<String> {
        let routed = get_isolation_manager().route_query(&self.tenant_id, base_table);
        routed.where_clause()
    }
}
#[cfg(feature = "pg_test")]
thread_local! {
    /// Per-thread tenant id injected by tests through `set_mock_tenant_id`.
    ///
    /// Declared once at module level so the setter and `get_current_tenant_id`
    /// share the same storage. A `thread_local!` declared inside each function
    /// would create two unrelated statics. An empty string means "no context".
    static MOCK_TENANT_ID: std::cell::RefCell<String> = std::cell::RefCell::new(String::new());
}

/// Returns the tenant id bound to the current execution context.
///
/// With the `pg_test` feature, this reads the value stored on the current thread by
/// `set_mock_tenant_id`. Other builds have no context source wired up here, so the
/// call always fails.
///
/// # Errors
///
/// `TenantError::NoContext` when no tenant id is available.
pub fn get_current_tenant_id() -> Result<String, TenantError> {
    #[cfg(feature = "pg_test")]
    {
        MOCK_TENANT_ID.with(|id| {
            let tenant_id = id.borrow().clone();
            if tenant_id.is_empty() {
                Err(TenantError::NoContext)
            } else {
                Ok(tenant_id)
            }
        })
    }
    #[cfg(not(feature = "pg_test"))]
    {
        Err(TenantError::NoContext)
    }
}

/// Binds `tenant_id` as the current thread's tenant for tests. Pass `""` to clear it.
#[cfg(feature = "pg_test")]
pub fn set_mock_tenant_id(tenant_id: &str) {
    MOCK_TENANT_ID.with(|id| {
        *id.borrow_mut() = tenant_id.to_string();
    });
}
/// Builder that collects vectors for insertion into a tenant-scoped table.
///
/// NOTE(review): `execute` checks and records quota usage and resolves the target
/// table, but it does not write any rows itself.
pub struct TenantVectorInsert<'a> {
    ctx: &'a TenantContext,
    vectors: Vec<(Vec<f32>, Option<serde_json::Value>)>,
    table_name: String,
    // Fixed per-vector size charged against the storage quota.
    estimated_bytes_per_vector: usize,
}

impl<'a> TenantVectorInsert<'a> {
    /// Starts an empty insert into `table_name` for the tenant in `ctx`.
    pub fn new(ctx: &'a TenantContext, table_name: &str) -> Self {
        // 4 bytes per f32 x 1536 components, plus 100 bytes of headroom (presumably
        // metadata / row overhead). This assumes 1536-dimensional vectors.
        const DEFAULT_ESTIMATE: usize = 4 * 1536 + 100;
        Self {
            table_name: table_name.to_string(),
            estimated_bytes_per_vector: DEFAULT_ESTIMATE,
            vectors: Vec::new(),
            ctx,
        }
    }

    /// Queues one vector with optional JSON metadata.
    pub fn add(&mut self, vector: Vec<f32>, metadata: Option<serde_json::Value>) -> &mut Self {
        let entry = (vector, metadata);
        self.vectors.push(entry);
        self
    }

    /// Queues several `(vector, metadata)` pairs, preserving their order.
    pub fn add_batch(&mut self, vectors: Vec<(Vec<f32>, Option<serde_json::Value>)>) -> &mut Self {
        self.vectors.extend(vectors);
        self
    }

    /// Checks the tenant's insert quota, then records the usage.
    ///
    /// The storage charge is `count * estimated_bytes_per_vector`. Returns
    /// `OperationResult::QuotaDenied` when the quota manager refuses the batch.
    pub fn execute(self) -> OperationResult<InsertResult> {
        let quotas = get_quota_manager();
        let count = self.vectors.len();
        let count_u64 = count as u64;
        let total_bytes = count_u64 * self.estimated_bytes_per_vector as u64;

        let verdict = quotas.check_vector_insert(&self.ctx.tenant_id, count_u64, total_bytes);
        if !verdict.is_allowed() {
            return OperationResult::QuotaDenied(verdict);
        }

        let table_used = self.ctx.table_ref(&self.table_name);
        let timer = Instant::now();
        quotas.record_vector_insert(&self.ctx.tenant_id, count_u64, total_bytes);

        let duration_ms = timer.elapsed().as_millis() as u64;
        OperationResult::Success(InsertResult {
            inserted_count: count,
            table_used,
            duration_ms,
        })
    }
}

/// Outcome of [`TenantVectorInsert::execute`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InsertResult {
    /// Number of vectors accounted for.
    pub inserted_count: usize,
    /// Table reference resolved for the tenant.
    pub table_used: String,
    /// Elapsed wall time, in milliseconds.
    pub duration_ms: u64,
}
pub struct TenantVectorSearch<'a> {
ctx: &'a TenantContext,
query: Vec<f32>,
k: usize,
table_name: String,
ef_search: Option<usize>,
filter: Option<String>,
}
impl<'a> TenantVectorSearch<'a> {
pub fn new(ctx: &'a TenantContext, query: Vec<f32>, k: usize, table_name: &str) -> Self {
Self {
ctx,
query,
k,
table_name: table_name.to_string(),
ef_search: None,
filter: None,
}
}
pub fn with_ef_search(mut self, ef: usize) -> Self {
self.ef_search = Some(ef);
self
}
pub fn with_filter(mut self, filter: &str) -> Self {
self.filter = Some(filter.to_string());
self
}
pub fn execute(self) -> OperationResult<SearchResult> {
let quota_manager = get_quota_manager();
let rate_check = quota_manager.check_query(&self.ctx.tenant_id);
if !rate_check.is_allowed() {
return OperationResult::QuotaDenied(rate_check);
}
quota_manager.start_query(&self.ctx.tenant_id);
let start = Instant::now();
let table_ref = self.ctx.table_ref(&self.table_name);
let tenant_filter = self.ctx.where_clause(&self.table_name);
let combined_filter = match (&tenant_filter, &self.filter) {
(Some(tf), Some(f)) => Some(format!("({}) AND ({})", tf, f)),
(Some(tf), None) => Some(tf.clone()),
(None, Some(f)) => Some(f.clone()),
(None, None) => None,
};
let results: Vec<(i64, f32)> = Vec::new();
quota_manager.end_query(&self.ctx.tenant_id);
OperationResult::Success(SearchResult {
results,
k: self.k,
table_used: table_ref,
filter_applied: combined_filter,
duration_ms: start.elapsed().as_millis() as u64,
})
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
pub results: Vec<(i64, f32)>,
pub k: usize,
pub table_used: String,
pub filter_applied: Option<String>,
pub duration_ms: u64,
}
/// Builder for deleting vectors by id from a tenant-scoped table.
pub struct TenantVectorDelete<'a> {
    ctx: &'a TenantContext,
    ids: Vec<i64>,
    table_name: String,
}

impl<'a> TenantVectorDelete<'a> {
    /// Creates a delete of `ids` from `table_name` on behalf of `ctx`'s tenant.
    pub fn new(ctx: &'a TenantContext, ids: Vec<i64>, table_name: &str) -> Self {
        Self {
            ctx,
            ids,
            table_name: table_name.to_string(),
        }
    }

    /// Resolves the tenant's table and records the deletion with the quota manager.
    ///
    /// Duplicate ids are counted once. The storage refund uses the same per-vector
    /// estimate that `TenantVectorInsert` charges, so an insert followed by a delete
    /// of the same vectors returns the tenant's storage usage to its starting value.
    ///
    /// NOTE(review): no rows are removed here; only quota usage is updated.
    pub fn execute(self) -> OperationResult<DeleteResult> {
        // Must match the default `estimated_bytes_per_vector` in `TenantVectorInsert::new`
        // (4-byte f32 x 1536 components + 100 bytes). The previous `4 * 1536` refund
        // left 100 phantom bytes in the storage quota for every deleted vector.
        const ESTIMATED_BYTES_PER_VECTOR: u64 = 4 * 1536 + 100;

        let quota_manager = get_quota_manager();
        let start = Instant::now();
        let table_ref = self.ctx.table_ref(&self.table_name);

        // A repeated id can delete at most one row, so count distinct ids only.
        let mut ids = self.ids;
        ids.sort_unstable();
        ids.dedup();
        let deleted_count = ids.len();

        // Computed in u64 so a large batch cannot overflow a 32-bit `usize`.
        let deleted_bytes = (deleted_count as u64).saturating_mul(ESTIMATED_BYTES_PER_VECTOR);
        quota_manager.record_vector_delete(
            &self.ctx.tenant_id,
            deleted_count as u64,
            deleted_bytes,
        );
        OperationResult::Success(DeleteResult {
            deleted_count,
            table_used: table_ref,
            duration_ms: start.elapsed().as_millis() as u64,
        })
    }
}

/// Outcome of [`TenantVectorDelete::execute`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeleteResult {
    /// Number of distinct ids accounted for.
    pub deleted_count: usize,
    /// Table reference resolved for the tenant.
    pub table_used: String,
    /// Elapsed wall time, in milliseconds.
    pub duration_ms: u64,
}
/// Snapshot of a tenant's usage and health, produced by [`get_tenant_stats`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TenantStats {
    pub tenant_id: String,
    pub vector_count: u64,
    pub storage_bytes: u64,
    pub collection_count: u32,
    pub isolation_level: String,
    /// `"normal"`, `"stress"`, `"critical"`, or `"unknown"`.
    pub integrity_state: String,
    pub lambda_cut: f32,
    pub is_suspended: bool,
    /// `vector_count` as a percentage of the configured `max_vectors`.
    pub quota_usage_percent: f32,
}

/// Collects usage, isolation and integrity information for `tenant_id`.
///
/// Fallbacks:
/// - Usage falls back to its `Default` when the quota manager has none recorded.
/// - The integrity fields fall back to `"unknown"` / `1.0` when the registry has no
///   shared state for the tenant.
///
/// `quota_usage_percent` is `0.0` for a tenant with no vectors, which avoids a NaN
/// (0 / 0) when `max_vectors` is 0. A non-empty tenant with a zero limit still yields
/// infinity.
///
/// # Errors
///
/// `TenantError::NotFound` if the registry has no entry for `tenant_id`.
pub fn get_tenant_stats(tenant_id: &str) -> Result<TenantStats, TenantError> {
    let config = get_registry()
        .get(tenant_id)
        .ok_or_else(|| TenantError::NotFound(tenant_id.to_string()))?;
    let usage = get_quota_manager().get_usage(tenant_id).unwrap_or_default();
    let shared_state = get_registry().get_shared_state(tenant_id);
    let (integrity_state, lambda_cut) = match shared_state {
        Some(state) => {
            // Numeric codes stored in the shared state's `integrity_state` atomic.
            let integrity = match state
                .integrity_state
                .load(std::sync::atomic::Ordering::Relaxed)
            {
                0 => "normal",
                1 => "stress",
                2 => "critical",
                _ => "unknown",
            };
            (integrity.to_string(), state.lambda_cut())
        }
        None => ("unknown".to_string(), 1.0),
    };
    let quota_usage_percent = if usage.vector_count == 0 {
        0.0
    } else {
        (usage.vector_count as f64 / config.quota.max_vectors as f64 * 100.0) as f32
    };
    Ok(TenantStats {
        tenant_id: tenant_id.to_string(),
        vector_count: usage.vector_count,
        storage_bytes: usage.storage_bytes,
        collection_count: usage.collection_count,
        isolation_level: config.isolation_level.as_str().to_string(),
        integrity_state,
        lambda_cut,
        is_suspended: config.is_suspended(),
        quota_usage_percent,
    })
}
/// One record destined for `ruvector.tenant_audit_log`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditLogEntry {
    pub tenant_id: String,
    pub operation: String,
    pub user_id: Option<String>,
    /// Free-form JSON details (an empty object by default).
    pub details: serde_json::Value,
    /// Creation time in milliseconds since the Unix epoch.
    /// NOTE(review): not included in either generated INSERT statement.
    pub timestamp: i64,
    /// Emitted only if it passes `validate_ip_address`; otherwise written as NULL.
    pub ip_address: Option<String>,
    pub success: bool,
    pub error: Option<String>,
}

impl AuditLogEntry {
    /// Creates a successful entry for `operation` on `tenant_id`, stamped with the
    /// current time.
    pub fn new(tenant_id: &str, operation: &str) -> Self {
        Self {
            tenant_id: tenant_id.to_owned(),
            operation: operation.to_owned(),
            user_id: None,
            details: serde_json::json!({}),
            timestamp: chrono_now_millis(),
            ip_address: None,
            success: true,
            error: None,
        }
    }

    /// Attaches the acting user's id.
    pub fn with_user(self, user_id: &str) -> Self {
        Self {
            user_id: Some(user_id.to_owned()),
            ..self
        }
    }

    /// Replaces the JSON details payload.
    pub fn with_details(self, details: serde_json::Value) -> Self {
        Self { details, ..self }
    }

    /// Marks the entry as failed with the given error message.
    pub fn failed(self, error: &str) -> Self {
        Self {
            success: false,
            error: Some(error.to_owned()),
            ..self
        }
    }

    /// Details serialized as JSON, or `"{}"` if serialization fails.
    fn serialized_details(&self) -> String {
        serde_json::to_string(&self.details).unwrap_or_else(|_| "{}".to_string())
    }

    /// The IP address, if present and accepted by `validate_ip_address`.
    fn validated_ip(&self) -> Option<&String> {
        self.ip_address.as_ref().filter(|ip| validate_ip_address(ip))
    }

    /// Returns an INSERT statement with `$1..$7` placeholders and its parameter values.
    ///
    /// Parameters follow column order: tenant_id, operation, user_id, details (JSON
    /// text), ip_address, success (`"true"`/`"false"`), error. `None` stands for NULL.
    pub fn insert_sql_parameterized(&self) -> (String, Vec<Option<String>>) {
        let sql = String::from(
            r#"
INSERT INTO ruvector.tenant_audit_log (tenant_id, operation, user_id, details, ip_address, success, error)
VALUES ($1, $2, $3, $4, $5, $6, $7)
"#,
        );
        let params = vec![
            Some(self.tenant_id.clone()),
            Some(self.operation.clone()),
            self.user_id.clone(),
            Some(self.serialized_details()),
            self.validated_ip().cloned(),
            Some(self.success.to_string()),
            self.error.clone(),
        ];
        (sql, params)
    }

    /// Returns a literal INSERT statement with every string value escaped.
    ///
    /// If `tenant_id` fails `validate_tenant_id`, returns the no-op
    /// `SELECT 1 WHERE false` instead.
    pub fn insert_sql(&self) -> String {
        if validate_tenant_id(&self.tenant_id).is_err() {
            return "SELECT 1 WHERE false".to_string();
        }
        let quoted_or_null = |value: Option<&String>| match value {
            Some(text) => format!("'{}'", escape_string_literal(text)),
            None => "NULL".to_string(),
        };
        let tenant = escape_string_literal(&self.tenant_id);
        let operation = escape_string_literal(&self.operation);
        let user = quoted_or_null(self.user_id.as_ref());
        let details = escape_string_literal(&self.serialized_details());
        let ip = quoted_or_null(self.validated_ip());
        let error = quoted_or_null(self.error.as_ref());
        format!(
            r#"
INSERT INTO ruvector.tenant_audit_log (tenant_id, operation, user_id, details, ip_address, success, error)
VALUES ('{}', '{}', {}, '{}', {}, {}, {})
"#,
            tenant, operation, user, details, ip, self.success, error
        )
    }
}
/// Milliseconds since the Unix epoch, or `0` if the system clock reads before the epoch.
fn chrono_now_millis() -> i64 {
    match std::time::UNIX_EPOCH.elapsed() {
        Ok(since_epoch) => since_epoch.as_millis() as i64,
        Err(_) => 0,
    }
}
/// Rejects a request that names a tenant other than the one bound to the current context.
///
/// Returns `Ok(())` when `request_tenant` is `None`, when it equals `context_tenant`,
/// or when `context_tenant` is the wildcard `"*"`.
///
/// # Errors
///
/// `TenantError::TenantMismatch` in every other case.
pub fn validate_cross_tenant(
    context_tenant: &str,
    request_tenant: Option<&str>,
) -> Result<(), TenantError> {
    let req_tenant = match request_tenant {
        Some(tenant) => tenant,
        None => return Ok(()),
    };
    if req_tenant == context_tenant || context_tenant == "*" {
        return Ok(());
    }
    // NOTE(review): this audit entry is built but never persisted or returned, so denied
    // cross-tenant attempts currently leave no audit trail. The underscore binding makes
    // the discard explicit (and silences `unused_variables`) until a sink is wired up.
    let _audit_entry = AuditLogEntry::new(context_tenant, "cross_tenant_attempt")
        .with_details(serde_json::json!({
            "requested_tenant": req_tenant,
            "context_tenant": context_tenant
        }))
        .failed("Cross-tenant access denied");
    Err(TenantError::TenantMismatch {
        context: context_tenant.to_string(),
        request: req_tenant.to_string(),
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_operation_result() {
        let success: OperationResult<i32> = OperationResult::Success(42);
        assert!(success.is_success());
        assert_eq!(success.unwrap(), 42);
        let denied = OperationResult::<i32>::QuotaDenied(QuotaResult::RateLimited {
            retry_after_ms: 100,
        });
        assert!(!denied.is_success());
        assert!(denied.into_result().is_err());
    }

    #[test]
    fn test_cross_tenant_validation() {
        assert!(validate_cross_tenant("tenant-a", Some("tenant-a")).is_ok());
        assert!(validate_cross_tenant("tenant-a", Some("tenant-b")).is_err());
        assert!(validate_cross_tenant("*", Some("tenant-b")).is_ok());
        assert!(validate_cross_tenant("tenant-a", None).is_ok());
    }

    #[test]
    fn test_audit_log_entry() {
        let entry = AuditLogEntry::new("acme-corp", "vector_insert")
            .with_user("user123")
            .with_details(serde_json::json!({"count": 100}));
        assert_eq!(entry.tenant_id, "acme-corp");
        assert_eq!(entry.operation, "vector_insert");
        assert!(entry.success);
        let failed_entry =
            AuditLogEntry::new("acme-corp", "vector_insert").failed("Quota exceeded");
        assert!(!failed_entry.success);
        assert!(failed_entry.error.is_some());
    }

    #[test]
    fn test_insert_result() {
        let result = InsertResult {
            inserted_count: 100,
            table_used: "embeddings".to_string(),
            duration_ms: 50,
        };
        assert_eq!(result.inserted_count, 100);
    }

    #[test]
    fn test_search_result() {
        let result = SearchResult {
            results: vec![(1, 0.1), (2, 0.2)],
            k: 10,
            table_used: "embeddings".to_string(),
            filter_applied: Some("category = 'test'".to_string()),
            duration_ms: 25,
        };
        assert_eq!(result.results.len(), 2);
        assert!(result.filter_applied.is_some());
    }
}