use crate::error::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
/// A single usage report to be sent for a subscription item.
///
/// Passed to [`UsageManager::report_usage`] and produced by [`UsageTracker::flush`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UsageRecord {
    /// Subscription item the usage is reported against.
    pub subscription_item_id: String,
    /// Amount of usage being reported.
    pub quantity: u64,
    /// Optional timestamp for the usage; `None` leaves the choice to the client.
    /// NOTE(review): presumably Unix seconds (the mock client falls back to the
    /// current Unix time in seconds) — confirm against the production client.
    pub timestamp: Option<i64>,
    /// How `quantity` should be applied; see [`UsageAction`].
    pub action: UsageAction,
}
/// How a reported quantity is applied to a subscription item's usage.
///
/// NOTE(review): the derived serde representation uses the variant names
/// (`"Increment"` / `"Set"`), not the lowercase strings returned by
/// [`UsageAction::as_str`]; confirm which form consumers of the serialized
/// data expect.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
pub enum UsageAction {
    /// Incremental report; the default action (and the one `UsageTracker::flush` emits).
    /// Presumably added to existing usage by the remote API.
    #[default]
    Increment,
    /// Absolute report; presumably replaces existing usage on the remote side.
    Set,
}
impl UsageAction {
#[must_use]
pub fn as_str(&self) -> &'static str {
match self {
Self::Increment => "increment",
Self::Set => "set",
}
}
}
/// Record returned after creating usage via [`StripeUsageClient::create_usage_record`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UsageRecordResult {
    /// Identifier of the created record (the mock client generates `mbur_<uuid>`).
    pub id: String,
    /// Quantity that was reported.
    pub quantity: u64,
    /// Record timestamp. The mock client uses the requested timestamp, or the
    /// current Unix time in seconds when none was given.
    pub timestamp: i64,
    /// Subscription item the record belongs to.
    pub subscription_item_id: String,
}
/// Aggregated usage over a period, broken down per subscription item.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UsageSummary {
    /// Overall usage total.
    /// NOTE(review): not computed or validated in this file; confirm it is
    /// expected to equal the sum of `items[..].total_usage`.
    pub total_usage: u64,
    /// Per-item usage totals.
    pub items: Vec<UsageItemSummary>,
    /// Start of the summarized period (presumably a Unix timestamp — confirm).
    pub period_start: i64,
    /// End of the summarized period (presumably a Unix timestamp — confirm).
    pub period_end: i64,
}
/// Usage total for a single subscription item within a [`UsageSummary`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UsageItemSummary {
    /// Subscription item the total applies to.
    pub subscription_item_id: String,
    /// Total usage recorded for this item.
    pub total_usage: u64,
}
/// Abstraction over the usage-record API used by [`UsageManager`].
///
/// Implemented in this file by the mock in the `test` module; presumably also
/// implemented by a production Stripe client elsewhere. Implementors must be
/// `Send + Sync`.
#[async_trait]
pub trait StripeUsageClient: Send + Sync {
    /// Creates one usage record for `subscription_item_id`.
    ///
    /// # Errors
    /// Returns an error if the record could not be created.
    async fn create_usage_record(
        &self,
        subscription_item_id: &str,
        quantity: u64,
        timestamp: Option<i64>,
        action: UsageAction,
    ) -> Result<UsageRecordResult>;
    /// Lists usage record summaries for `subscription_item_id`.
    ///
    /// # Errors
    /// Returns an error if the summaries could not be fetched.
    async fn list_usage_records(
        &self,
        subscription_item_id: &str,
    ) -> Result<Vec<UsageRecordSummary>>;
}
/// Usage summary for one period, as returned by
/// [`StripeUsageClient::list_usage_records`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UsageRecordSummary {
    /// Identifier of the summary.
    pub id: String,
    /// Total usage in the period.
    pub total_usage: u64,
    /// Start of the period (presumably a Unix timestamp — confirm).
    pub period_start: i64,
    /// End of the period (presumably a Unix timestamp — confirm).
    pub period_end: i64,
    /// Associated invoice, if any (presumably a Stripe invoice ID — confirm).
    pub invoice: Option<String>,
}
/// Reports usage through a [`StripeUsageClient`].
///
/// Generic over the client type so tests can substitute a mock client.
pub struct UsageManager<C: StripeUsageClient> {
    client: C,
}
impl<C: StripeUsageClient> UsageManager<C> {
#[must_use]
pub fn new(client: C) -> Self {
Self { client }
}
pub async fn report_usage(&self, record: UsageRecord) -> Result<UsageRecordResult> {
self.client
.create_usage_record(
&record.subscription_item_id,
record.quantity,
record.timestamp,
record.action,
)
.await
}
pub async fn report_usage_batch(
&self,
records: Vec<UsageRecord>,
) -> Result<Vec<UsageRecordResult>> {
let futures: Vec<_> = records.into_iter().map(|r| self.report_usage(r)).collect();
let results = futures::future::try_join_all(futures).await?;
Ok(results)
}
pub async fn get_usage_records(
&self,
subscription_item_id: &str,
) -> Result<Vec<UsageRecordSummary>> {
self.client.list_usage_records(subscription_item_id).await
}
}
/// In-process accumulator of usage totals keyed by subscription item ID.
///
/// Totals are kept behind an `RwLock`, so a tracker can be shared between
/// threads. `flush` drains them into [`UsageRecord`]s for reporting.
#[derive(Debug, Default)]
pub struct UsageTracker {
    usage: std::sync::RwLock<std::collections::HashMap<String, u64>>,
}
impl UsageTracker {
    /// Creates a tracker with no recorded usage.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Shared access to the totals, recovering the data if the lock is poisoned.
    ///
    /// See [`Self::write_usage`] for why recovering is safe here.
    fn read_usage(&self) -> std::sync::RwLockReadGuard<'_, std::collections::HashMap<String, u64>> {
        self.usage
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Exclusive access to the totals, recovering the data if the lock is poisoned.
    ///
    /// The updates made in this impl (saturating add, insert, drain) cannot leave
    /// the map half-modified, so a poisoned lock still guards usable totals.
    /// Recovering them is preferable to silently discarding billable usage.
    fn write_usage(
        &self,
    ) -> std::sync::RwLockWriteGuard<'_, std::collections::HashMap<String, u64>> {
        self.usage
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Adds `quantity` to the running total for `subscription_item_id`.
    ///
    /// Totals saturate at `u64::MAX` instead of overflowing. A plain `+=` would
    /// panic in debug builds while holding the write lock, which poisons it. In
    /// release builds it would silently wrap and under-report usage.
    pub fn track(&self, subscription_item_id: &str, quantity: u64) {
        let mut usage = self.write_usage();
        // Look up by `&str` first so repeat reports for a known item don't
        // allocate a fresh `String` key.
        if let Some(total) = usage.get_mut(subscription_item_id) {
            *total = total.saturating_add(quantity);
        } else {
            usage.insert(subscription_item_id.to_owned(), quantity);
        }
    }

    /// Returns a snapshot of the current per-item totals.
    ///
    /// Entries tracked with a quantity of zero appear here with a total of `0`.
    #[must_use]
    pub fn current(&self) -> std::collections::HashMap<String, u64> {
        self.read_usage().clone()
    }

    /// Drains all accumulated usage into `Increment` records, resetting the tracker.
    ///
    /// Items with a zero total are discarded rather than reported. Records are
    /// returned in unspecified (hash map) order, with no timestamp.
    #[must_use = "flushed usage is removed from the tracker and lost if the records are dropped"]
    pub fn flush(&self) -> Vec<UsageRecord> {
        self.write_usage()
            .drain()
            .filter(|&(_, quantity)| quantity > 0)
            .map(|(subscription_item_id, quantity)| UsageRecord {
                subscription_item_id,
                quantity,
                timestamp: None,
                action: UsageAction::Increment,
            })
            .collect()
    }

    /// Returns `true` if [`flush`](Self::flush) would currently produce at least
    /// one record, i.e. some item has a non-zero total.
    #[must_use]
    pub fn has_usage(&self) -> bool {
        self.read_usage().values().any(|&quantity| quantity > 0)
    }
}
/// Usage limits for a subscription item, evaluated by [`check_usage`].
///
/// When usage reaches both limits, the hard limit wins because `check_usage`
/// checks it first.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UsageThreshold {
    /// Subscription item these limits apply to (not read by `check_usage`).
    pub subscription_item_id: String,
    /// Usage at or above this value yields `UsageCheckResult::Warning`;
    /// `None` disables the warning.
    pub warning_threshold: Option<u64>,
    /// Usage at or above this value yields `UsageCheckResult::Exceeded`;
    /// `None` means there is no hard limit.
    pub hard_limit: Option<u64>,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum UsageCheckResult {
Ok,
Warning { current: u64, threshold: u64 },
Exceeded { current: u64, limit: u64 },
}
impl UsageCheckResult {
#[must_use]
pub fn is_allowed(&self) -> bool {
!matches!(self, Self::Exceeded { .. })
}
#[must_use]
pub fn is_warning(&self) -> bool {
matches!(self, Self::Warning { .. })
}
}
/// Classifies `current` usage against `threshold`.
///
/// The hard limit is checked first: reaching it (`current >= limit`) yields
/// [`UsageCheckResult::Exceeded`] even if a warning threshold is also reached.
/// Otherwise, reaching the warning threshold yields [`UsageCheckResult::Warning`].
/// Unset (`None`) limits are ignored, and everything else is
/// [`UsageCheckResult::Ok`].
#[must_use]
pub fn check_usage(current: u64, threshold: &UsageThreshold) -> UsageCheckResult {
    let reached_limit = threshold.hard_limit.filter(|&limit| current >= limit);
    if let Some(limit) = reached_limit {
        return UsageCheckResult::Exceeded { current, limit };
    }
    match threshold.warning_threshold {
        Some(warning) if current >= warning => UsageCheckResult::Warning {
            current,
            threshold: warning,
        },
        _ => UsageCheckResult::Ok,
    }
}
/// Test doubles for the usage API.
///
/// Only compiled for unit tests or when the `test-billing` feature is enabled.
#[cfg(any(test, feature = "test-billing"))]
pub mod test {
    use super::*;
    use std::sync::{Arc, RwLock};

    /// In-memory [`StripeUsageClient`] that stores every record it is asked to create.
    ///
    /// Storage sits behind `Arc`s, so clones share it. A test can give one clone
    /// to a [`UsageManager`] and inspect the stored records through another.
    #[derive(Default, Clone)]
    pub struct MockStripeUsageClient {
        records: Arc<RwLock<Vec<UsageRecordResult>>>,
        summaries: Arc<RwLock<Vec<UsageRecordSummary>>>,
    }

    impl MockStripeUsageClient {
        /// Creates a mock with no stored records and no canned summaries.
        #[must_use]
        pub fn new() -> Self {
            Self::default()
        }

        /// Returns a copy of every record created so far, in creation order.
        ///
        /// # Panics
        /// Panics if the records lock is poisoned.
        pub fn get_records(&self) -> Vec<UsageRecordResult> {
            let stored = self.records.read().unwrap();
            stored.to_vec()
        }

        /// Sets the summaries that `list_usage_records` will return.
        ///
        /// # Panics
        /// Panics if the summaries lock is poisoned.
        pub fn set_summaries(&self, summaries: Vec<UsageRecordSummary>) {
            let mut slot = self.summaries.write().unwrap();
            *slot = summaries;
        }
    }

    #[async_trait]
    impl StripeUsageClient for MockStripeUsageClient {
        /// Fabricates a record with an `mbur_`-prefixed UUID and stores a copy.
        ///
        /// `action` is ignored. `timestamp` defaults to the current Unix time in
        /// seconds.
        async fn create_usage_record(
            &self,
            subscription_item_id: &str,
            quantity: u64,
            timestamp: Option<i64>,
            _action: UsageAction,
        ) -> Result<UsageRecordResult> {
            // The clock is read unconditionally, even when a timestamp is supplied.
            let elapsed = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap();
            let now_secs = elapsed.as_secs() as i64;
            let created = UsageRecordResult {
                id: format!("mbur_{}", uuid::Uuid::new_v4()),
                quantity,
                timestamp: timestamp.unwrap_or(now_secs),
                subscription_item_id: subscription_item_id.to_string(),
            };
            self.records.write().unwrap().push(created.clone());
            Ok(created)
        }

        /// Returns the summaries set via `set_summaries`, whatever the item ID.
        async fn list_usage_records(
            &self,
            _subscription_item_id: &str,
        ) -> Result<Vec<UsageRecordSummary>> {
            Ok(self.summaries.read().unwrap().to_vec())
        }
    }
}
#[cfg(test)]
mod tests {
    use super::test::MockStripeUsageClient;
    use super::*;

    /// Convenience constructor for a record without an explicit timestamp.
    fn usage(item: &str, quantity: u64, action: UsageAction) -> UsageRecord {
        UsageRecord {
            subscription_item_id: item.to_string(),
            quantity,
            timestamp: None,
            action,
        }
    }

    #[tokio::test]
    async fn test_report_usage() {
        let client = MockStripeUsageClient::new();
        let manager = UsageManager::new(client.clone());

        let result = manager
            .report_usage(usage("si_test", 100, UsageAction::Increment))
            .await
            .unwrap();

        assert_eq!(result.quantity, 100);
        assert_eq!(result.subscription_item_id, "si_test");
        // The mock's shared storage should hold exactly the one record sent.
        assert_eq!(client.get_records().len(), 1);
    }

    #[tokio::test]
    async fn test_report_usage_batch() {
        let client = MockStripeUsageClient::new();
        let manager = UsageManager::new(client.clone());
        let batch = vec![
            usage("si_api", 50, UsageAction::Increment),
            usage("si_storage", 1024, UsageAction::Set),
        ];

        let results = manager.report_usage_batch(batch).await.unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(client.get_records().len(), 2);
    }

    #[test]
    fn test_usage_tracker() {
        let tracker = UsageTracker::new();
        for (item, quantity) in [("si_api", 10), ("si_api", 5), ("si_storage", 100)] {
            tracker.track(item, quantity);
        }

        let snapshot = tracker.current();
        assert_eq!(snapshot.get("si_api"), Some(&15));
        assert_eq!(snapshot.get("si_storage"), Some(&100));

        // Flushing drains every item and leaves the tracker empty.
        assert_eq!(tracker.flush().len(), 2);
        assert!(!tracker.has_usage());
        assert!(tracker.current().is_empty());
    }

    #[test]
    fn test_usage_check() {
        let threshold = UsageThreshold {
            subscription_item_id: "si_test".to_string(),
            warning_threshold: Some(80),
            hard_limit: Some(100),
        };

        let below = check_usage(50, &threshold);
        assert_eq!(below, UsageCheckResult::Ok);
        assert!(below.is_allowed());

        let warning = check_usage(85, &threshold);
        assert_eq!(
            warning,
            UsageCheckResult::Warning {
                current: 85,
                threshold: 80
            }
        );
        assert!(warning.is_allowed());
        assert!(warning.is_warning());

        let exceeded = check_usage(100, &threshold);
        assert_eq!(
            exceeded,
            UsageCheckResult::Exceeded {
                current: 100,
                limit: 100
            }
        );
        assert!(!exceeded.is_allowed());
    }

    #[test]
    fn test_usage_action() {
        let expected = [
            (UsageAction::Increment, "increment"),
            (UsageAction::Set, "set"),
        ];
        for (action, name) in expected {
            assert_eq!(action.as_str(), name);
        }
        assert_eq!(UsageAction::default(), UsageAction::Increment);
    }
}