use crate::multi_tenancy::types::{MultiTenancyError, MultiTenancyResult, TenantOperation};
use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
/// Length of a billing cycle.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BillingPeriod {
    /// One hour (3 600 s).
    Hourly,
    /// One day (86 400 s).
    Daily,
    /// A fixed 30-day month (2 592 000 s); calendar month lengths are not considered.
    Monthly,
    /// A fixed 365-day year (31 536 000 s); leap years are not considered.
    Annual,
}

impl BillingPeriod {
    /// Returns the nominal length of this period in whole seconds.
    ///
    /// Months and years are fixed approximations of 30 and 365 days respectively.
    pub fn duration_secs(&self) -> i64 {
        const HOUR: i64 = 60 * 60;
        const DAY: i64 = 24 * HOUR;

        match *self {
            Self::Hourly => HOUR,
            Self::Daily => DAY,
            Self::Monthly => 30 * DAY,
            Self::Annual => 365 * DAY,
        }
    }
}
/// Pricing scheme attached to a tenant.
///
/// `PricingModel::calculate_cost` prices individual operations only for `PerRequest`,
/// `PerComputeUnit` and `Custom`; every other variant yields `0.0` there.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PricingModel {
/// Flat charge for each counted request.
PerRequest {
cost_per_request: f64,
},
/// Presumably a charge per 1 000 vectors (inferred from the field name).
/// NOTE(review): not priced by `calculate_cost` (yields 0.0); confirm where it is billed.
PerVector {
cost_per_1k_vectors: f64,
},
/// Presumably a charge per GB stored (inferred from the field name).
/// NOTE(review): not priced by `calculate_cost` (yields 0.0); confirm where it is billed.
PerStorage {
cost_per_gb: f64,
},
/// `cost_per_unit` scaled by the operation's `default_cost_weight()` and the count.
PerComputeUnit {
cost_per_unit: f64,
},
/// Fixed fee with an included request allowance and a per-request overage price
/// (inferred from the field names).
/// NOTE(review): not priced by `calculate_cost` (yields 0.0); confirm where it is billed.
Subscription {
monthly_fee: f64,
included_requests: u64,
overage_cost: f64,
},
/// Per-operation prices keyed by operation name. `calculate_cost` ignores `base_fee`.
Custom {
base_fee: f64,
operation_costs: HashMap<String, f64>,
},
}
impl PricingModel {
    /// Returns the cost of `count` occurrences of `operation` under this model.
    ///
    /// * `PerRequest` — `cost_per_request * count`.
    /// * `PerComputeUnit` — `cost_per_unit * operation.default_cost_weight() * count`.
    /// * `Custom` — the price listed in `operation_costs` under `operation.name()`
    ///   (0.01 when the operation is not listed), times `count`. `base_fee` is not added.
    /// * `PerVector`, `PerStorage`, `Subscription` — `0.0`; they are not metered per
    ///   operation by this method.
    pub fn calculate_cost(&self, operation: TenantOperation, count: u64) -> f64 {
        // Per-occurrence price `Custom` charges for operations missing from `operation_costs`.
        const DEFAULT_CUSTOM_OPERATION_COST: f64 = 0.01;

        let count = count as f64;
        match self {
            Self::PerRequest { cost_per_request } => *cost_per_request * count,
            Self::PerComputeUnit { cost_per_unit } => {
                *cost_per_unit * operation.default_cost_weight() * count
            }
            Self::Custom {
                operation_costs, ..
            } => {
                let op_cost = operation_costs
                    .get(operation.name())
                    .copied()
                    .unwrap_or(DEFAULT_CUSTOM_OPERATION_COST);
                op_cost * count
            }
            // Listed explicitly instead of `_` so that adding a variant forces a pricing
            // decision here rather than silently billing it at zero.
            // NOTE(review): these models need inputs this method does not receive
            // (vector/storage volume, cumulative requests for overage) — confirm where
            // they are billed.
            Self::PerVector { .. } | Self::PerStorage { .. } | Self::Subscription { .. } => 0.0,
        }
    }
}
/// A single metered usage event attributed to a tenant.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UsageRecord {
    /// Tenant the usage belongs to.
    pub tenant_id: String,
    /// Operation that was performed.
    pub operation: TenantOperation,
    /// Number of occurrences covered by this record.
    pub count: u64,
    /// UTC time at which the record was constructed.
    pub timestamp: DateTime<Utc>,
    /// Cost of the record; `0.0` until `calculate_cost` is called.
    pub cost: f64,
    /// Free-form annotations; empty on construction.
    pub metadata: HashMap<String, String>,
}

impl UsageRecord {
    /// Builds an unpriced record stamped with the current UTC time and no metadata.
    pub fn new(tenant_id: impl Into<String>, operation: TenantOperation, count: u64) -> Self {
        let tenant_id: String = tenant_id.into();
        let timestamp = Utc::now();
        Self {
            tenant_id,
            operation,
            count,
            timestamp,
            cost: 0.0,
            metadata: HashMap::default(),
        }
    }

    /// Prices this record with `pricing` and stores the result in `cost`, replacing any
    /// previous value.
    pub fn calculate_cost(&mut self, pricing: &PricingModel) {
        let priced = pricing.calculate_cost(self.operation, self.count);
        self.cost = priced;
    }
}
/// Running billing totals for one tenant over the current billing period.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BillingMetrics {
/// Tenant these metrics belong to.
pub tenant_id: String,
/// When the current period began (creation or last reset time).
pub period_start: DateTime<Utc>,
/// `period_start` plus the period length; informational only, since `record_usage`
/// does not check it.
pub period_end: DateTime<Utc>,
/// Sum of `cost` over all recorded usage in this period.
pub total_cost: f64,
/// Sum of `count` over all recorded usage in this period.
pub total_requests: u64,
/// `total_cost / total_requests` (0.0 while no requests are recorded).
pub avg_request_cost: f64,
/// Cost per operation, keyed by operation name.
pub cost_by_operation: HashMap<String, f64>,
/// Request count per operation, keyed by operation name.
pub requests_by_operation: HashMap<String, u64>,
/// NOTE(review): `record_usage` never updates this field — confirm whether it is
/// maintained elsewhere.
pub peak_daily_cost: f64,
/// `total_cost` linearly extrapolated from the elapsed time to a 30-day month.
pub estimated_monthly_cost: f64,
}
impl BillingMetrics {
pub fn new(tenant_id: impl Into<String>, period: BillingPeriod) -> Self {
let now = Utc::now();
let period_end = now + Duration::seconds(period.duration_secs());
Self {
tenant_id: tenant_id.into(),
period_start: now,
period_end,
total_cost: 0.0,
total_requests: 0,
avg_request_cost: 0.0,
cost_by_operation: HashMap::new(),
requests_by_operation: HashMap::new(),
peak_daily_cost: 0.0,
estimated_monthly_cost: 0.0,
}
}
pub fn record_usage(&mut self, record: &UsageRecord) {
self.total_cost += record.cost;
self.total_requests += record.count;
let op_name = record.operation.name().to_string();
*self.cost_by_operation.entry(op_name.clone()).or_insert(0.0) += record.cost;
*self.requests_by_operation.entry(op_name).or_insert(0) += record.count;
if self.total_requests > 0 {
self.avg_request_cost = self.total_cost / self.total_requests as f64;
}
let elapsed_secs = (Utc::now() - self.period_start).num_seconds() as f64;
if elapsed_secs > 0.0 {
let monthly_secs = 2592000.0; self.estimated_monthly_cost = self.total_cost * (monthly_secs / elapsed_secs);
}
}
pub fn reset(&mut self, period: BillingPeriod) {
self.period_start = Utc::now();
self.period_end = self.period_start + Duration::seconds(period.duration_secs());
self.total_cost = 0.0;
self.total_requests = 0;
self.avg_request_cost = 0.0;
self.cost_by_operation.clear();
self.requests_by_operation.clear();
}
}
/// Billing ledger holding per-tenant pricing, running metrics, and a log of usage records.
///
/// All mutable state sits behind `Arc<Mutex<_>>`, which is why every method takes `&self`.
pub struct BillingEngine {
/// Pricing model per tenant id; `record_usage` fails for tenants without one.
pricing: Arc<Mutex<HashMap<String, PricingModel>>>,
/// Every record accepted by `record_usage`, in the order it was appended.
/// NOTE(review): nothing in this impl prunes it, so it grows for the engine's lifetime.
usage_history: Arc<Mutex<Vec<UsageRecord>>>,
/// Running metrics per tenant id.
metrics: Arc<Mutex<HashMap<String, BillingMetrics>>>,
/// Period length used whenever a tenant's metrics are created or reset.
period: BillingPeriod,
}
impl BillingEngine {
    /// Creates an engine with no tenants.
    ///
    /// `period` is the billing-period length applied whenever a tenant's metrics are
    /// created or reset.
    pub fn new(period: BillingPeriod) -> Self {
        Self {
            pricing: Arc::new(Mutex::new(HashMap::new())),
            usage_history: Arc::new(Mutex::new(Vec::new())),
            metrics: Arc::new(Mutex::new(HashMap::new())),
            period,
        }
    }

    /// Locks `mutex`, reporting a poisoned lock as `MultiTenancyError::InternalError`.
    ///
    /// Centralises the error mapping every public method needs.
    fn lock<T>(mutex: &Mutex<T>) -> MultiTenancyResult<std::sync::MutexGuard<'_, T>> {
        mutex.lock().map_err(|e| MultiTenancyError::InternalError {
            message: format!("Lock error: {}", e),
        })
    }

    /// Installs or replaces the pricing model for `tenant_id`.
    ///
    /// Also creates empty metrics for the tenant if none exist yet. Existing metrics are
    /// kept, so replacing a model does not reset accumulated totals.
    ///
    /// # Errors
    /// `MultiTenancyError::InternalError` if an internal lock is poisoned.
    pub fn set_pricing(
        &self,
        tenant_id: impl Into<String>,
        pricing: PricingModel,
    ) -> MultiTenancyResult<()> {
        let tenant_id = tenant_id.into();
        Self::lock(&self.pricing)?.insert(tenant_id.clone(), pricing);
        Self::lock(&self.metrics)?
            .entry(tenant_id.clone())
            .or_insert_with(|| BillingMetrics::new(tenant_id, self.period));
        Ok(())
    }

    /// Prices `count` occurrences of `operation` for `tenant_id` and records them.
    ///
    /// The record is added to the tenant's metrics (created on demand) and appended to the
    /// usage history. Returns the cost charged for this call.
    ///
    /// # Errors
    /// * `MultiTenancyError::BillingError` if the tenant has no pricing model.
    /// * `MultiTenancyError::InternalError` if an internal lock is poisoned.
    pub fn record_usage(
        &self,
        tenant_id: &str,
        operation: TenantOperation,
        count: u64,
    ) -> MultiTenancyResult<f64> {
        let mut record = UsageRecord::new(tenant_id, operation, count);
        {
            // Price while holding the lock instead of cloning the model; a `Custom` model
            // owns a HashMap, so cloning it on every call was an avoidable allocation.
            let pricing = Self::lock(&self.pricing)?;
            let model = pricing
                .get(tenant_id)
                .ok_or_else(|| MultiTenancyError::BillingError {
                    message: format!("No pricing model for tenant: {}", tenant_id),
                })?;
            record.calculate_cost(model);
        }
        let cost = record.cost;

        // The metrics guard is a temporary, released at the end of this statement, so the
        // metrics and history locks are never held at the same time.
        Self::lock(&self.metrics)?
            .entry(tenant_id.to_string())
            .or_insert_with(|| BillingMetrics::new(tenant_id, self.period))
            .record_usage(&record);

        Self::lock(&self.usage_history)?.push(record);
        Ok(cost)
    }

    /// Returns a snapshot of the metrics for `tenant_id`.
    ///
    /// # Errors
    /// * `MultiTenancyError::TenantNotFound` if the tenant has no metrics.
    /// * `MultiTenancyError::InternalError` if the lock is poisoned.
    pub fn get_metrics(&self, tenant_id: &str) -> MultiTenancyResult<BillingMetrics> {
        Self::lock(&self.metrics)?
            .get(tenant_id)
            .cloned()
            .ok_or_else(|| MultiTenancyError::TenantNotFound {
                tenant_id: tenant_id.to_string(),
            })
    }

    /// Returns copies of the tenant's usage records whose timestamps fall within
    /// `start..=end` (both bounds inclusive), in insertion order.
    ///
    /// Returns an empty vector when nothing matches, including when `start > end`.
    ///
    /// # Errors
    /// `MultiTenancyError::InternalError` if the lock is poisoned.
    pub fn get_usage_history(
        &self,
        tenant_id: &str,
        start: DateTime<Utc>,
        end: DateTime<Utc>,
    ) -> MultiTenancyResult<Vec<UsageRecord>> {
        let window = start..=end;
        let history = Self::lock(&self.usage_history)?;
        Ok(history
            .iter()
            .filter(|r| r.tenant_id == tenant_id && window.contains(&r.timestamp))
            .cloned()
            .collect())
    }

    /// Starts a new billing period for `tenant_id`, clearing its metrics.
    ///
    /// The usage history is left untouched.
    ///
    /// # Errors
    /// * `MultiTenancyError::TenantNotFound` if the tenant has no metrics.
    /// * `MultiTenancyError::InternalError` if the lock is poisoned.
    pub fn reset_period(&self, tenant_id: &str) -> MultiTenancyResult<()> {
        Self::lock(&self.metrics)?
            .get_mut(tenant_id)
            .ok_or_else(|| MultiTenancyError::TenantNotFound {
                tenant_id: tenant_id.to_string(),
            })?
            .reset(self.period);
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;

    #[test]
    fn test_billing_period() {
        assert_eq!(BillingPeriod::Hourly.duration_secs(), 3600);
        assert_eq!(BillingPeriod::Daily.duration_secs(), 86400);
        assert_eq!(BillingPeriod::Monthly.duration_secs(), 2592000);
        assert_eq!(BillingPeriod::Annual.duration_secs(), 31536000);
    }

    #[test]
    fn test_pricing_models() {
        let model = PricingModel::PerRequest {
            cost_per_request: 0.01,
        };
        assert_eq!(
            model.calculate_cost(TenantOperation::VectorSearch, 100),
            1.0
        );
        let model = PricingModel::PerComputeUnit { cost_per_unit: 0.1 };
        let cost = model.calculate_cost(TenantOperation::IndexBuild, 1);
        assert!(cost > 0.0);
    }

    #[test]
    fn test_custom_pricing_falls_back_for_unlisted_operation() {
        // An empty cost table forces the 0.01-per-occurrence fallback.
        let model = PricingModel::Custom {
            base_fee: 5.0,
            operation_costs: HashMap::new(),
        };
        let cost = model.calculate_cost(TenantOperation::VectorSearch, 10);
        assert!((cost - 0.1).abs() < 1e-9);
    }

    #[test]
    fn test_usage_record() {
        let mut record = UsageRecord::new("tenant1", TenantOperation::VectorSearch, 100);
        assert_eq!(record.count, 100);
        assert_eq!(record.cost, 0.0);
        let pricing = PricingModel::PerRequest {
            cost_per_request: 0.01,
        };
        record.calculate_cost(&pricing);
        assert_eq!(record.cost, 1.0);
    }

    #[test]
    fn test_billing_metrics() {
        let mut metrics = BillingMetrics::new("tenant1", BillingPeriod::Daily);
        assert_eq!(metrics.total_cost, 0.0);
        assert_eq!(metrics.total_requests, 0);
        let mut record = UsageRecord::new("tenant1", TenantOperation::VectorSearch, 100);
        record.cost = 1.0;
        metrics.record_usage(&record);
        assert_eq!(metrics.total_cost, 1.0);
        assert_eq!(metrics.total_requests, 100);
        assert!((metrics.avg_request_cost - 0.01).abs() < 0.001);
    }

    #[test]
    fn test_billing_engine() -> Result<()> {
        let engine = BillingEngine::new(BillingPeriod::Daily);
        let pricing = PricingModel::PerRequest {
            cost_per_request: 0.01,
        };
        engine.set_pricing("tenant1", pricing)?;
        let cost = engine.record_usage("tenant1", TenantOperation::VectorSearch, 100)?;
        assert_eq!(cost, 1.0);
        let metrics = engine.get_metrics("tenant1")?;
        assert_eq!(metrics.total_cost, 1.0);
        assert_eq!(metrics.total_requests, 100);
        engine.record_usage("tenant1", TenantOperation::VectorInsert, 50)?;
        let metrics = engine.get_metrics("tenant1")?;
        assert_eq!(metrics.total_cost, 1.5);
        assert_eq!(metrics.total_requests, 150);
        Ok(())
    }

    #[test]
    fn test_unknown_tenant_errors() {
        let engine = BillingEngine::new(BillingPeriod::Daily);
        assert!(matches!(
            engine.record_usage("unknown", TenantOperation::VectorSearch, 1),
            Err(MultiTenancyError::BillingError { .. })
        ));
        assert!(matches!(
            engine.get_metrics("unknown"),
            Err(MultiTenancyError::TenantNotFound { .. })
        ));
        assert!(matches!(
            engine.reset_period("unknown"),
            Err(MultiTenancyError::TenantNotFound { .. })
        ));
    }

    #[test]
    fn test_reset_period() -> Result<()> {
        let engine = BillingEngine::new(BillingPeriod::Daily);
        engine.set_pricing(
            "tenant1",
            PricingModel::PerRequest {
                cost_per_request: 0.01,
            },
        )?;
        engine.record_usage("tenant1", TenantOperation::VectorSearch, 100)?;
        engine.reset_period("tenant1")?;

        let metrics = engine.get_metrics("tenant1")?;
        assert_eq!(metrics.total_cost, 0.0);
        assert_eq!(metrics.total_requests, 0);
        assert!(metrics.cost_by_operation.is_empty());
        assert!(metrics.requests_by_operation.is_empty());

        // Resetting metrics must not discard the usage history.
        let start = Utc::now() - Duration::hours(1);
        let end = Utc::now() + Duration::hours(1);
        assert_eq!(engine.get_usage_history("tenant1", start, end)?.len(), 1);
        Ok(())
    }

    #[test]
    fn test_usage_history() -> Result<()> {
        let engine = BillingEngine::new(BillingPeriod::Daily);
        let pricing = PricingModel::PerRequest {
            cost_per_request: 0.01,
        };
        engine.set_pricing("tenant1", pricing.clone())?;
        engine.set_pricing("tenant2", pricing)?;
        engine.record_usage("tenant1", TenantOperation::VectorSearch, 100)?;
        engine.record_usage("tenant2", TenantOperation::VectorSearch, 7)?;
        engine.record_usage("tenant1", TenantOperation::VectorInsert, 50)?;
        let start = Utc::now() - Duration::hours(1);
        let end = Utc::now() + Duration::hours(1);
        let history = engine.get_usage_history("tenant1", start, end)?;
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].count, 100);
        assert_eq!(history[1].count, 50);
        Ok(())
    }

    #[test]
    fn test_subscription_pricing() {
        let pricing = PricingModel::Subscription {
            monthly_fee: 100.0,
            included_requests: 10000,
            overage_cost: 0.02,
        };
        match pricing {
            PricingModel::Subscription {
                monthly_fee,
                included_requests,
                overage_cost,
            } => {
                assert_eq!(monthly_fee, 100.0);
                assert_eq!(included_requests, 10000);
                assert_eq!(overage_cost, 0.02);
            }
            _ => panic!("Expected subscription pricing"),
        }
    }
}