1use crate::multi_tenancy::types::{MultiTenancyError, MultiTenancyResult, TenantOperation};
4use chrono::{DateTime, Duration, Utc};
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::sync::{Arc, Mutex};
8
9#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
11pub enum BillingPeriod {
12 Hourly,
13 Daily,
14 Monthly,
15 Annual,
16}
17
18impl BillingPeriod {
19 pub fn duration_secs(&self) -> i64 {
21 match self {
22 Self::Hourly => 3600,
23 Self::Daily => 86400,
24 Self::Monthly => 2592000, Self::Annual => 31536000, }
27 }
28}
29
/// How a tenant is charged for its usage.
///
/// Variant and field order are part of the serialized form; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PricingModel {
    /// Flat price for every unit of usage, regardless of operation.
    PerRequest {
        /// Price charged per unit of `count`.
        cost_per_request: f64,
    },
    /// Price expressed per thousand vectors.
    PerVector {
        /// Price per 1000 vectors (per the field name — confirm units with callers).
        cost_per_1k_vectors: f64,
    },
    /// Storage-volume pricing; `PricingModel::calculate_cost` does not charge
    /// per operation for this variant.
    PerStorage {
        /// Price per gigabyte (per the field name).
        cost_per_gb: f64,
    },
    /// Price per compute unit, scaled by each operation's `default_cost_weight()`.
    PerComputeUnit {
        /// Price per weighted compute unit.
        cost_per_unit: f64,
    },
    /// Fixed subscription with an included request allowance.
    /// `PricingModel::calculate_cost` does not charge per operation for this variant;
    /// NOTE(review): the fee/overage presumably need period-level accounting elsewhere.
    Subscription {
        /// Recurring monthly fee.
        monthly_fee: f64,
        /// Requests covered by `monthly_fee`.
        included_requests: u64,
        /// Price per request beyond `included_requests` (presumably).
        overage_cost: f64,
    },
    /// Per-operation price table plus a base fee.
    Custom {
        /// Base fee; not used by `PricingModel::calculate_cost`.
        base_fee: f64,
        /// Price per unit keyed by `TenantOperation::name()`.
        operation_costs: HashMap<String, f64>,
    },
}
70
71impl PricingModel {
72 pub fn calculate_cost(&self, operation: TenantOperation, count: u64) -> f64 {
74 match self {
75 Self::PerRequest { cost_per_request } => *cost_per_request * count as f64,
76 Self::PerComputeUnit { cost_per_unit } => {
77 *cost_per_unit * operation.default_cost_weight() * count as f64
78 }
79 Self::Custom {
80 operation_costs, ..
81 } => {
82 let op_cost = operation_costs
83 .get(operation.name())
84 .copied()
85 .unwrap_or(0.01);
86 op_cost * count as f64
87 }
88 _ => 0.0, }
90 }
91}
92
/// A single metered usage event for one tenant.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UsageRecord {
    /// Tenant that performed the operation.
    pub tenant_id: String,
    /// Operation that was performed.
    pub operation: TenantOperation,
    /// Number of units of `operation` this record covers.
    pub count: u64,
    /// When the record was created (`UsageRecord::new` sets it to `Utc::now()`).
    pub timestamp: DateTime<Utc>,
    /// Cost charged for this record; 0.0 until `calculate_cost` is called.
    pub cost: f64,
    /// Free-form annotations; `new` leaves this empty.
    pub metadata: HashMap<String, String>,
}
109
110impl UsageRecord {
111 pub fn new(tenant_id: impl Into<String>, operation: TenantOperation, count: u64) -> Self {
113 Self {
114 tenant_id: tenant_id.into(),
115 operation,
116 count,
117 timestamp: Utc::now(),
118 cost: 0.0,
119 metadata: HashMap::new(),
120 }
121 }
122
123 pub fn calculate_cost(&mut self, pricing: &PricingModel) {
125 self.cost = pricing.calculate_cost(self.operation, self.count);
126 }
127}
128
/// Running billing aggregates for a single tenant over one billing window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BillingMetrics {
    /// Tenant these metrics belong to.
    pub tenant_id: String,

    /// Start of the current billing window.
    pub period_start: DateTime<Utc>,

    /// End of the current billing window (`period_start` plus the billing period).
    pub period_end: DateTime<Utc>,

    /// Sum of `UsageRecord::cost` recorded in this window.
    pub total_cost: f64,

    /// Sum of `UsageRecord::count` recorded in this window.
    pub total_requests: u64,

    /// `total_cost / total_requests`; stays 0.0 while `total_requests` is 0.
    pub avg_request_cost: f64,

    /// Accumulated cost per operation name.
    pub cost_by_operation: HashMap<String, f64>,

    /// Accumulated `count` per operation name.
    pub requests_by_operation: HashMap<String, u64>,

    /// Presumably the highest single-day cost in the window.
    /// NOTE(review): only ever initialized in this file — confirm who maintains it.
    pub peak_daily_cost: f64,

    /// `total_cost` extrapolated linearly to a 30-day month from the time
    /// elapsed since `period_start`.
    pub estimated_monthly_cost: f64,
}
162
163impl BillingMetrics {
164 pub fn new(tenant_id: impl Into<String>, period: BillingPeriod) -> Self {
166 let now = Utc::now();
167 let period_end = now + Duration::seconds(period.duration_secs());
168
169 Self {
170 tenant_id: tenant_id.into(),
171 period_start: now,
172 period_end,
173 total_cost: 0.0,
174 total_requests: 0,
175 avg_request_cost: 0.0,
176 cost_by_operation: HashMap::new(),
177 requests_by_operation: HashMap::new(),
178 peak_daily_cost: 0.0,
179 estimated_monthly_cost: 0.0,
180 }
181 }
182
183 pub fn record_usage(&mut self, record: &UsageRecord) {
185 self.total_cost += record.cost;
186 self.total_requests += record.count;
187
188 let op_name = record.operation.name().to_string();
189 *self.cost_by_operation.entry(op_name.clone()).or_insert(0.0) += record.cost;
190 *self.requests_by_operation.entry(op_name).or_insert(0) += record.count;
191
192 if self.total_requests > 0 {
194 self.avg_request_cost = self.total_cost / self.total_requests as f64;
195 }
196
197 let elapsed_secs = (Utc::now() - self.period_start).num_seconds() as f64;
199 if elapsed_secs > 0.0 {
200 let monthly_secs = 2592000.0; self.estimated_monthly_cost = self.total_cost * (monthly_secs / elapsed_secs);
202 }
203 }
204
205 pub fn reset(&mut self, period: BillingPeriod) {
207 self.period_start = Utc::now();
208 self.period_end = self.period_start + Duration::seconds(period.duration_secs());
209 self.total_cost = 0.0;
210 self.total_requests = 0;
211 self.avg_request_cost = 0.0;
212 self.cost_by_operation.clear();
213 self.requests_by_operation.clear();
214 }
215}
216
217pub struct BillingEngine {
219 pricing: Arc<Mutex<HashMap<String, PricingModel>>>,
221
222 usage_history: Arc<Mutex<Vec<UsageRecord>>>,
224
225 metrics: Arc<Mutex<HashMap<String, BillingMetrics>>>,
227
228 period: BillingPeriod,
230}
231
232impl BillingEngine {
233 pub fn new(period: BillingPeriod) -> Self {
235 Self {
236 pricing: Arc::new(Mutex::new(HashMap::new())),
237 usage_history: Arc::new(Mutex::new(Vec::new())),
238 metrics: Arc::new(Mutex::new(HashMap::new())),
239 period,
240 }
241 }
242
243 pub fn set_pricing(
245 &self,
246 tenant_id: impl Into<String>,
247 pricing: PricingModel,
248 ) -> MultiTenancyResult<()> {
249 let tenant_id = tenant_id.into();
250
251 self.pricing
252 .lock()
253 .map_err(|e| MultiTenancyError::InternalError {
254 message: format!("Lock error: {}", e),
255 })?
256 .insert(tenant_id.clone(), pricing);
257
258 self.metrics
260 .lock()
261 .map_err(|e| MultiTenancyError::InternalError {
262 message: format!("Lock error: {}", e),
263 })?
264 .entry(tenant_id.clone())
265 .or_insert_with(|| BillingMetrics::new(tenant_id, self.period));
266
267 Ok(())
268 }
269
270 pub fn record_usage(
272 &self,
273 tenant_id: &str,
274 operation: TenantOperation,
275 count: u64,
276 ) -> MultiTenancyResult<f64> {
277 let mut record = UsageRecord::new(tenant_id, operation, count);
278
279 let pricing = self
281 .pricing
282 .lock()
283 .map_err(|e| MultiTenancyError::InternalError {
284 message: format!("Lock error: {}", e),
285 })?
286 .get(tenant_id)
287 .cloned()
288 .ok_or_else(|| MultiTenancyError::BillingError {
289 message: format!("No pricing model for tenant: {}", tenant_id),
290 })?;
291
292 record.calculate_cost(&pricing);
293 let cost = record.cost;
294
295 let mut metrics = self
297 .metrics
298 .lock()
299 .map_err(|e| MultiTenancyError::InternalError {
300 message: format!("Lock error: {}", e),
301 })?;
302
303 metrics
304 .entry(tenant_id.to_string())
305 .or_insert_with(|| BillingMetrics::new(tenant_id, self.period))
306 .record_usage(&record);
307
308 self.usage_history
310 .lock()
311 .map_err(|e| MultiTenancyError::InternalError {
312 message: format!("Lock error: {}", e),
313 })?
314 .push(record);
315
316 Ok(cost)
317 }
318
319 pub fn get_metrics(&self, tenant_id: &str) -> MultiTenancyResult<BillingMetrics> {
321 self.metrics
322 .lock()
323 .map_err(|e| MultiTenancyError::InternalError {
324 message: format!("Lock error: {}", e),
325 })?
326 .get(tenant_id)
327 .cloned()
328 .ok_or_else(|| MultiTenancyError::TenantNotFound {
329 tenant_id: tenant_id.to_string(),
330 })
331 }
332
333 pub fn get_usage_history(
335 &self,
336 tenant_id: &str,
337 start: DateTime<Utc>,
338 end: DateTime<Utc>,
339 ) -> MultiTenancyResult<Vec<UsageRecord>> {
340 let history = self
341 .usage_history
342 .lock()
343 .map_err(|e| MultiTenancyError::InternalError {
344 message: format!("Lock error: {}", e),
345 })?;
346
347 Ok(history
348 .iter()
349 .filter(|r| r.tenant_id == tenant_id && r.timestamp >= start && r.timestamp <= end)
350 .cloned()
351 .collect())
352 }
353
354 pub fn reset_period(&self, tenant_id: &str) -> MultiTenancyResult<()> {
356 let mut metrics = self
357 .metrics
358 .lock()
359 .map_err(|e| MultiTenancyError::InternalError {
360 message: format!("Lock error: {}", e),
361 })?;
362
363 metrics
364 .get_mut(tenant_id)
365 .ok_or_else(|| MultiTenancyError::TenantNotFound {
366 tenant_id: tenant_id.to_string(),
367 })?
368 .reset(self.period);
369
370 Ok(())
371 }
372}
373
// Unit tests for pricing, usage records, metrics aggregation and the engine.
#[cfg(test)]
mod tests {
    use super::*;

    // Pins the fixed number of seconds for each billing period.
    #[test]
    fn test_billing_period() {
        assert_eq!(BillingPeriod::Hourly.duration_secs(), 3600);
        assert_eq!(BillingPeriod::Daily.duration_secs(), 86400);
        assert_eq!(BillingPeriod::Monthly.duration_secs(), 2592000);
    }

    // PerRequest multiplies by count; PerComputeUnit is only checked to be
    // positive, since it depends on the operation's default_cost_weight().
    #[test]
    fn test_pricing_models() {
        let model = PricingModel::PerRequest {
            cost_per_request: 0.01,
        };
        assert_eq!(
            model.calculate_cost(TenantOperation::VectorSearch, 100),
            1.0
        );

        let model = PricingModel::PerComputeUnit { cost_per_unit: 0.1 };
        let cost = model.calculate_cost(TenantOperation::IndexBuild, 1);
        assert!(cost > 0.0);
    }

    // A new record starts unpriced; calculate_cost fills in `cost`.
    #[test]
    fn test_usage_record() {
        let mut record = UsageRecord::new("tenant1", TenantOperation::VectorSearch, 100);
        assert_eq!(record.count, 100);
        assert_eq!(record.cost, 0.0);

        let pricing = PricingModel::PerRequest {
            cost_per_request: 0.01,
        };
        record.calculate_cost(&pricing);
        assert_eq!(record.cost, 1.0);
    }

    // Recording a pre-priced record updates totals and the per-unit average.
    #[test]
    fn test_billing_metrics() {
        let mut metrics = BillingMetrics::new("tenant1", BillingPeriod::Daily);
        assert_eq!(metrics.total_cost, 0.0);
        assert_eq!(metrics.total_requests, 0);

        let mut record = UsageRecord::new("tenant1", TenantOperation::VectorSearch, 100);
        record.cost = 1.0;
        metrics.record_usage(&record);

        assert_eq!(metrics.total_cost, 1.0);
        assert_eq!(metrics.total_requests, 100);
        assert!((metrics.avg_request_cost - 0.01).abs() < 0.001);
    }

    // End-to-end: register pricing, record two usages, and check accumulation.
    #[test]
    fn test_billing_engine() {
        let engine = BillingEngine::new(BillingPeriod::Daily);

        let pricing = PricingModel::PerRequest {
            cost_per_request: 0.01,
        };
        engine.set_pricing("tenant1", pricing).unwrap();

        let cost = engine
            .record_usage("tenant1", TenantOperation::VectorSearch, 100)
            .unwrap();
        assert_eq!(cost, 1.0);

        let metrics = engine.get_metrics("tenant1").unwrap();
        assert_eq!(metrics.total_cost, 1.0);
        assert_eq!(metrics.total_requests, 100);

        engine
            .record_usage("tenant1", TenantOperation::VectorInsert, 50)
            .unwrap();

        let metrics = engine.get_metrics("tenant1").unwrap();
        assert_eq!(metrics.total_cost, 1.5);
        assert_eq!(metrics.total_requests, 150);
    }

    // Both records fall inside a +/- 1 hour window and come back in recording order.
    #[test]
    fn test_usage_history() {
        let engine = BillingEngine::new(BillingPeriod::Daily);

        let pricing = PricingModel::PerRequest {
            cost_per_request: 0.01,
        };
        engine.set_pricing("tenant1", pricing).unwrap();

        engine
            .record_usage("tenant1", TenantOperation::VectorSearch, 100)
            .unwrap();
        engine
            .record_usage("tenant1", TenantOperation::VectorInsert, 50)
            .unwrap();

        let start = Utc::now() - Duration::hours(1);
        let end = Utc::now() + Duration::hours(1);
        let history = engine.get_usage_history("tenant1", start, end).unwrap();

        assert_eq!(history.len(), 2);
        assert_eq!(history[0].count, 100);
        assert_eq!(history[1].count, 50);
    }

    // Only checks that the Subscription variant stores its fields;
    // calculate_cost is not exercised for this variant.
    #[test]
    fn test_subscription_pricing() {
        let pricing = PricingModel::Subscription {
            monthly_fee: 100.0,
            included_requests: 10000,
            overage_cost: 0.02,
        };

        match pricing {
            PricingModel::Subscription {
                monthly_fee,
                included_requests,
                overage_cost,
            } => {
                assert_eq!(monthly_fee, 100.0);
                assert_eq!(included_requests, 10000);
                assert_eq!(overage_cost, 0.02);
            }
            _ => panic!("Expected subscription pricing"),
        }
    }
}