use crate::error::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
/// Persistence backend for billing state.
///
/// It holds three kinds of data:
/// - the mapping from a billable entity to its Stripe customer id;
/// - the current [`StoredSubscription`] per billable entity;
/// - a record of event ids that have already been handled (for idempotency).
///
/// All methods are `async` (via `async_trait`) and implementations must be `Send + Sync`.
#[async_trait]
pub trait BillingStore: Send + Sync {
    /// Returns the Stripe customer id previously stored for `billable_id`, if any.
    async fn get_stripe_customer_id(&self, billable_id: &str) -> Result<Option<String>>;

    /// Associates `customer_id` with `billable_id`.
    ///
    /// `billable_type` records what kind of entity this is (for example `"org"`).
    async fn set_stripe_customer_id(
        &self,
        billable_id: &str,
        billable_type: &str,
        customer_id: &str,
    ) -> Result<()>;

    /// Returns the stored subscription for `billable_id`, if any.
    async fn get_subscription(&self, billable_id: &str) -> Result<Option<StoredSubscription>>;

    /// Unconditionally stores `subscription` for `billable_id`.
    async fn save_subscription(
        &self,
        billable_id: &str,
        subscription: &StoredSubscription,
    ) -> Result<()>;

    /// Optimistic-concurrency save keyed on `StoredSubscription::updated_at`.
    ///
    /// Returns `Ok(false)`, without writing, when a subscription is already stored and
    /// its `updated_at` differs from `expected_version`. Otherwise it writes
    /// `subscription` and returns `Ok(true)`. When nothing is stored yet, the write
    /// always proceeds.
    ///
    /// The default implementation is a separate read followed by a write, so it is
    /// **not** atomic. In debug builds it logs a one-time warning. Backends used with
    /// concurrent writers should override it with a real compare-and-swap.
    async fn compare_and_save_subscription(
        &self,
        billable_id: &str,
        subscription: &StoredSubscription,
        expected_version: u64,
    ) -> Result<bool> {
        #[cfg(debug_assertions)]
        {
            use std::sync::atomic::{AtomicBool, Ordering};

            // Process-wide flag so the warning is emitted at most once.
            static ALREADY_WARNED: AtomicBool = AtomicBool::new(false);
            if !ALREADY_WARNED.swap(true, Ordering::Relaxed) {
                tracing::warn!(
                    target: "tideway::billing",
                    "Using default non-atomic compare_and_save_subscription implementation. \
                     This is NOT safe for production use with concurrent requests. \
                     Override this method with an atomic compare-and-swap operation."
                );
            }
        }

        // Check-then-act: another writer may slip in between this read and the save.
        let version_mismatch = match self.get_subscription(billable_id).await? {
            Some(existing) => existing.updated_at != expected_version,
            None => false,
        };
        if version_mismatch {
            return Ok(false);
        }

        self.save_subscription(billable_id, subscription).await?;
        Ok(true)
    }

    /// Removes the stored subscription for `billable_id`.
    async fn delete_subscription(&self, billable_id: &str) -> Result<()>;

    /// Reverse lookup by Stripe subscription id.
    ///
    /// Returns the owning `billable_id` together with the stored subscription.
    async fn get_subscription_by_stripe_id(
        &self,
        stripe_subscription_id: &str,
    ) -> Result<Option<(String, StoredSubscription)>>;

    /// Whether `event_id` has already been recorded via `mark_event_processed`.
    async fn is_event_processed(&self, event_id: &str) -> Result<bool>;

    /// Records `event_id` as handled.
    ///
    /// NOTE(review): event ids presumably come from Stripe webhooks; confirm with callers.
    async fn mark_event_processed(&self, event_id: &str) -> Result<()>;

    /// Purges processed-event records older than the given number of days.
    ///
    /// Returns how many records were removed. The default implementation keeps
    /// everything and reports `0`.
    async fn cleanup_old_events(&self, _older_than_days: u32) -> Result<usize> {
        Ok(0)
    }

    /// Counts subscriptions on `plan_id`.
    ///
    /// The in-memory test store counts only `Active` and `Trialing` subscriptions.
    /// NOTE(review): confirm other backends use the same definition.
    async fn count_subscriptions_by_plan(&self, plan_id: &str) -> Result<u32>;
}
/// Snapshot of a Stripe subscription as persisted by a [`BillingStore`].
///
/// `trial_end` is in Unix seconds; it is compared against the current time in seconds.
/// The other timestamp fields presumably use the same unit (TODO confirm).
///
/// `updated_at` doubles as the version checked by
/// [`BillingStore::compare_and_save_subscription`].
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct StoredSubscription {
    /// Stripe subscription id (looked up by `get_subscription_by_stripe_id`).
    pub stripe_subscription_id: String,
    /// Stripe customer id that owns the subscription.
    pub stripe_customer_id: String,
    /// Application plan id (matches `StoredPlan::id`).
    pub plan_id: String,
    /// Current lifecycle status.
    pub status: SubscriptionStatus,
    /// Start of the current billing period.
    pub current_period_start: u64,
    /// End of the current billing period.
    pub current_period_end: u64,
    /// Seats purchased beyond the plan's included seats (presumably; verify).
    pub extra_seats: u32,
    /// End of the trial in Unix seconds, if the subscription has or had a trial.
    pub trial_end: Option<u64>,
    /// Whether the subscription is scheduled to cancel when the current period ends.
    pub cancel_at_period_end: bool,
    /// Stripe subscription item id for the base price, if known.
    pub base_item_id: Option<String>,
    /// Stripe subscription item id for the seat price, if known.
    pub seat_item_id: Option<String>,
    /// Version / last-modified marker used for optimistic concurrency.
    pub updated_at: u64,
}

impl StoredSubscription {
    /// `true` when the status is `Active` or `Trialing`.
    #[must_use]
    pub fn is_active(&self) -> bool {
        self.is_trialing() || self.status == SubscriptionStatus::Active
    }

    /// `true` when the status is `Trialing`.
    #[must_use]
    pub fn is_trialing(&self) -> bool {
        matches!(self.status, SubscriptionStatus::Trialing)
    }

    /// `true` when the status is `PastDue`.
    #[must_use]
    pub fn is_past_due(&self) -> bool {
        matches!(self.status, SubscriptionStatus::PastDue)
    }

    /// `true` when the status is `Canceled`.
    #[must_use]
    pub fn is_canceled(&self) -> bool {
        matches!(self.status, SubscriptionStatus::Canceled)
    }

    /// `true` when the subscription is set to cancel at the end of the current period.
    #[must_use]
    pub fn will_cancel(&self) -> bool {
        self.cancel_at_period_end
    }

    /// Whole days left until `trial_end`, rounded down.
    ///
    /// Returns `None` when there is no trial or the trial has already ended. If the
    /// system clock reads earlier than the Unix epoch, "now" is taken as 0.
    #[must_use]
    pub fn trial_days_remaining(&self) -> Option<u32> {
        const SECS_PER_DAY: u64 = 86400;

        let trial_end = self.trial_end?;
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_or(0, |elapsed| elapsed.as_secs());
        (trial_end > now).then(|| ((trial_end - now) / SECS_PER_DAY) as u32)
    }
}
/// Lifecycle state of a subscription, mirroring Stripe's status strings.
///
/// Serialized in `snake_case`, which is the same spelling that
/// [`SubscriptionStatus::as_str`] produces. Variant order is left unchanged on purpose,
/// because index-based serde formats depend on it.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SubscriptionStatus {
    /// Stripe status `"active"`.
    Active,
    /// Stripe status `"trialing"`.
    Trialing,
    /// Stripe status `"past_due"`.
    PastDue,
    /// Stripe status `"canceled"`; also the fallback for unrecognised strings.
    Canceled,
    /// Stripe status `"incomplete"`.
    Incomplete,
    /// Stripe status `"incomplete_expired"`.
    IncompleteExpired,
    /// Stripe status `"paused"`.
    Paused,
    /// Stripe status `"unpaid"`.
    Unpaid,
}

impl SubscriptionStatus {
    /// Parses a Stripe status string.
    ///
    /// Any string that is not one of the known statuses maps to `Canceled`, which
    /// `StoredSubscription::is_active` treats as inactive.
    #[must_use]
    pub fn from_stripe(status: &str) -> Self {
        const KNOWN: [SubscriptionStatus; 8] = [
            SubscriptionStatus::Active,
            SubscriptionStatus::Trialing,
            SubscriptionStatus::PastDue,
            SubscriptionStatus::Canceled,
            SubscriptionStatus::Incomplete,
            SubscriptionStatus::IncompleteExpired,
            SubscriptionStatus::Paused,
            SubscriptionStatus::Unpaid,
        ];
        KNOWN
            .iter()
            .copied()
            .find(|candidate| candidate.as_str() == status)
            .unwrap_or(SubscriptionStatus::Canceled)
    }

    /// The Stripe / serde string for this status.
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match *self {
            SubscriptionStatus::Active => "active",
            SubscriptionStatus::Trialing => "trialing",
            SubscriptionStatus::PastDue => "past_due",
            SubscriptionStatus::Canceled => "canceled",
            SubscriptionStatus::Incomplete => "incomplete",
            SubscriptionStatus::IncompleteExpired => "incomplete_expired",
            SubscriptionStatus::Paused => "paused",
            SubscriptionStatus::Unpaid => "unpaid",
        }
    }
}

impl std::fmt::Display for SubscriptionStatus {
    /// Writes the same string as [`SubscriptionStatus::as_str`].
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
/// A purchasable plan as stored by a [`PlanStore`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct StoredPlan {
    /// Application-level plan id (key used by `PlanStore::get_plan`).
    pub id: String,
    /// Human-readable plan name.
    pub name: String,
    /// Optional longer description.
    pub description: Option<String>,
    /// Stripe price for the base item (matched by `PlanStore::get_plan_by_stripe_price`).
    pub stripe_price_id: String,
    /// Optional Stripe price for additional seats (presumably per-seat add-ons; verify).
    pub stripe_seat_price_id: Option<String>,
    /// Price in the currency's minor unit; `formatted_price` divides it by 100.
    pub price_cents: i64,
    /// Currency code; `new` defaults it to `"usd"`.
    pub currency: String,
    /// Billing interval.
    pub interval: PlanInterval,
    /// Seats included in the base price.
    pub included_seats: u32,
    /// JSON object of boolean feature flags, read by `has_feature`.
    pub features: serde_json::Value,
    /// JSON object of integer limits, read by `get_limit` / `check_limit`.
    pub limits: serde_json::Value,
    /// Trial length in days, if the plan offers a trial.
    pub trial_days: Option<u32>,
    /// Whether the plan is currently offered (only active plans are returned by `list_plans`).
    pub is_active: bool,
    /// Ascending display order.
    pub sort_order: i32,
    /// Creation time in Unix seconds.
    pub created_at: u64,
    /// Last modification time in Unix seconds.
    pub updated_at: u64,
}

impl StoredPlan {
    /// Creates a plan with the given id and base Stripe price, filling in defaults.
    ///
    /// Defaults: an empty name, `"usd"`, monthly billing, 1 included seat, no features
    /// or limits, active, and sort order 0. Both timestamps are set to the current time
    /// in Unix seconds, or 0 if the clock reads earlier than the epoch.
    #[must_use]
    pub fn new(id: impl Into<String>, stripe_price_id: impl Into<String>) -> Self {
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_secs())
            .unwrap_or(0);
        Self {
            id: id.into(),
            name: String::new(),
            description: None,
            stripe_price_id: stripe_price_id.into(),
            stripe_seat_price_id: None,
            price_cents: 0,
            currency: "usd".to_string(),
            interval: PlanInterval::Monthly,
            included_seats: 1,
            features: serde_json::json!({}),
            limits: serde_json::json!({}),
            trial_days: None,
            is_active: true,
            sort_order: 0,
            created_at: now,
            updated_at: now,
        }
    }

    /// `true` only if `features[feature]` is the JSON boolean `true`.
    ///
    /// A missing key or a non-boolean value counts as disabled.
    #[must_use]
    pub fn has_feature(&self, feature: &str) -> bool {
        self.features
            .get(feature)
            .and_then(|v| v.as_bool())
            .unwrap_or(false)
    }

    /// The integer limit stored under `limits[limit]`.
    ///
    /// Returns `None` if the key is missing or the value is not representable as `i64`.
    #[must_use]
    pub fn get_limit(&self, limit: &str) -> Option<i64> {
        self.limits.get(limit).and_then(|v| v.as_i64())
    }

    /// Whether `current` usage is still strictly below the limit for `resource`.
    ///
    /// A resource with no limit configured is always allowed.
    #[must_use]
    pub fn check_limit(&self, resource: &str, current: i64) -> bool {
        match self.get_limit(resource) {
            // No configured limit means the resource is unlimited.
            None => true,
            Some(max) => current < max,
        }
    }

    /// Formats the price for display, e.g. `"$19.99"`.
    ///
    /// `usd`, `gbp` and `eur` get their symbols, matched ASCII case-insensitively.
    /// Stripe sends lowercase codes, while ISO 4217 spells them in uppercase.
    /// Any other code is prefixed verbatim, e.g. `"chf19.99"`.
    ///
    /// NOTE(review): this always divides by 100, i.e. it assumes a two-decimal
    /// currency. Zero-decimal currencies (e.g. JPY) would be shown 100x too small;
    /// confirm whether such plans can exist.
    #[must_use]
    pub fn formatted_price(&self) -> String {
        let symbol = match self.currency.to_ascii_lowercase().as_str() {
            "usd" => "$",
            "gbp" => "£",
            "eur" => "€",
            _ => self.currency.as_str(),
        };
        let dollars = self.price_cents as f64 / 100.0;
        format!("{}{:.2}", symbol, dollars)
    }
}
/// How often a plan is billed.
///
/// Serialized in `snake_case`, matching [`PlanInterval::as_str`]. Variant order is left
/// unchanged because index-based serde formats depend on it.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PlanInterval {
    /// Billed every month.
    Monthly,
    /// Billed every year.
    Yearly,
    /// Single up-front payment.
    OneTime,
}

impl PlanInterval {
    /// Leniently parses an interval string.
    ///
    /// Accepted spellings:
    /// - `Yearly`: `yearly`, `year`, `annual`;
    /// - `OneTime`: `one_time`, `onetime`, `lifetime`;
    /// - `Monthly`: `monthly`, `month`, and anything unrecognised.
    ///
    /// Matching is case-sensitive.
    // The lint is silenced deliberately: this is infallible and never errors, so it
    // does not fit the `FromStr` signature.
    #[allow(clippy::should_implement_trait)]
    #[must_use]
    pub fn from_str(s: &str) -> Self {
        if matches!(s, "yearly" | "year" | "annual") {
            PlanInterval::Yearly
        } else if matches!(s, "one_time" | "onetime" | "lifetime") {
            PlanInterval::OneTime
        } else {
            // "monthly", "month", and every unknown value.
            PlanInterval::Monthly
        }
    }

    /// Canonical string for this interval (the same spelling serde uses).
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match *self {
            PlanInterval::Monthly => "monthly",
            PlanInterval::Yearly => "yearly",
            PlanInterval::OneTime => "one_time",
        }
    }
}

impl std::fmt::Display for PlanInterval {
    /// Writes the same string as [`PlanInterval::as_str`].
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
/// Persistence for the plan catalogue.
///
/// [`CachedPlanStore`] can wrap any implementation with a TTL cache.
#[async_trait]
pub trait PlanStore: Send + Sync {
    /// Plans currently offered.
    ///
    /// The in-memory test store returns only plans with `is_active == true`, ordered by
    /// `sort_order`.
    async fn list_plans(&self) -> Result<Vec<StoredPlan>>;

    /// Every plan, including inactive ones.
    async fn list_all_plans(&self) -> Result<Vec<StoredPlan>>;

    /// Looks up a plan by its id.
    async fn get_plan(&self, plan_id: &str) -> Result<Option<StoredPlan>>;

    /// Looks up the plan whose `stripe_price_id` equals `stripe_price_id`.
    async fn get_plan_by_stripe_price(
        &self,
        stripe_price_id: &str,
    ) -> Result<Option<StoredPlan>>;

    /// Inserts a new plan.
    async fn create_plan(&self, plan: &StoredPlan) -> Result<()>;

    /// Replaces the stored plan with the same `id`.
    async fn update_plan(&self, plan: &StoredPlan) -> Result<()>;

    /// Removes the plan with id `plan_id`.
    async fn delete_plan(&self, plan_id: &str) -> Result<()>;

    /// Sets the plan's `is_active` flag without otherwise changing it.
    async fn set_plan_active(&self, plan_id: &str, is_active: bool) -> Result<()>;
}
/// Something that can be billed, for example an organisation (as in the tests).
///
/// Implementors expose the identity and contact details needed to create or find a
/// Stripe customer.
pub trait BillableEntity: Send + Sync {
    /// Identifier of the entity; presumably the `billable_id` passed to
    /// [`BillingStore`] methods (verify with callers).
    fn billable_id(&self) -> &str;

    /// Kind of entity, for example `"org"`.
    fn billable_type(&self) -> &str;

    /// Email address associated with the entity.
    fn email(&self) -> &str;

    /// Optional display name.
    fn name(&self) -> Option<&str>;
}
/// In-memory [`BillingStore`] and [`PlanStore`] implementation for tests.
///
/// Compiled under `cfg(test)` or with the `test-billing` feature. All state sits behind
/// `std::sync::RwLock`s shared through an `Arc`, so every clone of a store sees the same
/// data. Lock poisoning is treated as a bug and panics via `unwrap()`.
#[cfg(any(test, feature = "test-billing"))]
pub mod test {
    use super::*;
    use std::collections::HashMap;
    use std::sync::{Arc, RwLock};

    /// Seconds per day, used to turn `older_than_days` into a cutoff timestamp.
    const SECS_PER_DAY: u64 = 86_400;

    /// Current time in whole seconds since the Unix epoch; `0` if the clock reads earlier.
    fn unix_now() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_secs())
            .unwrap_or(0)
    }

    /// Sorts plans by `sort_order`, breaking ties by `id`.
    ///
    /// The tie-break keeps results independent of `HashMap` iteration order.
    fn sort_plans(plans: &mut [StoredPlan]) {
        plans.sort_by(|a, b| {
            a.sort_order
                .cmp(&b.sort_order)
                .then_with(|| a.id.cmp(&b.id))
        });
    }

    /// Cloneable handle to shared in-memory billing and plan state.
    #[derive(Default, Clone)]
    pub struct InMemoryBillingStore {
        inner: Arc<InMemoryBillingStoreInner>,
    }

    #[derive(Default)]
    struct InMemoryBillingStoreInner {
        /// billable_id -> Stripe customer record.
        customers: RwLock<HashMap<String, CustomerRecord>>,
        /// billable_id -> current subscription.
        subscriptions: RwLock<HashMap<String, StoredSubscription>>,
        /// event_id -> Unix time (seconds) at which it was marked processed.
        processed_events: RwLock<HashMap<String, u64>>,
        /// plan id -> plan.
        plans: RwLock<HashMap<String, StoredPlan>>,
    }

    #[derive(Clone)]
    struct CustomerRecord {
        // Stored for parity with real backends but never read back by this store.
        #[allow(dead_code)]
        billable_type: String,
        stripe_customer_id: String,
    }

    impl InMemoryBillingStore {
        /// Creates an empty store.
        #[must_use]
        pub fn new() -> Self {
            Self::default()
        }

        /// Snapshot of every stored subscription, keyed by billable id.
        pub fn get_all_subscriptions(&self) -> HashMap<String, StoredSubscription> {
            self.inner.subscriptions.read().unwrap().clone()
        }

        /// Ids of all processed events, in unspecified order.
        pub fn get_processed_events(&self) -> Vec<String> {
            self.inner
                .processed_events
                .read()
                .unwrap()
                .keys()
                .cloned()
                .collect()
        }

        /// Snapshot of every stored plan, keyed by plan id.
        pub fn get_all_plans(&self) -> HashMap<String, StoredPlan> {
            self.inner.plans.read().unwrap().clone()
        }

        /// Inserts (or overwrites by id) the given plans.
        pub fn seed_plans(&self, plans: Vec<StoredPlan>) {
            let mut store = self.inner.plans.write().unwrap();
            store.extend(plans.into_iter().map(|plan| (plan.id.clone(), plan)));
        }
    }

    #[async_trait]
    impl PlanStore for InMemoryBillingStore {
        async fn list_plans(&self) -> Result<Vec<StoredPlan>> {
            let mut active: Vec<StoredPlan> = self
                .inner
                .plans
                .read()
                .unwrap()
                .values()
                .filter(|p| p.is_active)
                .cloned()
                .collect();
            sort_plans(&mut active);
            Ok(active)
        }

        async fn list_all_plans(&self) -> Result<Vec<StoredPlan>> {
            let mut all: Vec<StoredPlan> =
                self.inner.plans.read().unwrap().values().cloned().collect();
            sort_plans(&mut all);
            Ok(all)
        }

        async fn get_plan(&self, plan_id: &str) -> Result<Option<StoredPlan>> {
            Ok(self.inner.plans.read().unwrap().get(plan_id).cloned())
        }

        async fn get_plan_by_stripe_price(
            &self,
            stripe_price_id: &str,
        ) -> Result<Option<StoredPlan>> {
            let plans = self.inner.plans.read().unwrap();
            Ok(plans
                .values()
                .find(|p| p.stripe_price_id == stripe_price_id)
                .cloned())
        }

        async fn create_plan(&self, plan: &StoredPlan) -> Result<()> {
            self.inner
                .plans
                .write()
                .unwrap()
                .insert(plan.id.clone(), plan.clone());
            Ok(())
        }

        async fn update_plan(&self, plan: &StoredPlan) -> Result<()> {
            // Updating an unknown plan is a silent no-op, like `set_plan_active`.
            if let Some(existing) = self.inner.plans.write().unwrap().get_mut(&plan.id) {
                *existing = plan.clone();
            }
            Ok(())
        }

        async fn delete_plan(&self, plan_id: &str) -> Result<()> {
            self.inner.plans.write().unwrap().remove(plan_id);
            Ok(())
        }

        async fn set_plan_active(&self, plan_id: &str, is_active: bool) -> Result<()> {
            if let Some(plan) = self.inner.plans.write().unwrap().get_mut(plan_id) {
                plan.is_active = is_active;
            }
            Ok(())
        }
    }

    #[async_trait]
    impl BillingStore for InMemoryBillingStore {
        async fn get_stripe_customer_id(&self, billable_id: &str) -> Result<Option<String>> {
            Ok(self
                .inner
                .customers
                .read()
                .unwrap()
                .get(billable_id)
                .map(|r| r.stripe_customer_id.clone()))
        }

        async fn set_stripe_customer_id(
            &self,
            billable_id: &str,
            billable_type: &str,
            customer_id: &str,
        ) -> Result<()> {
            self.inner.customers.write().unwrap().insert(
                billable_id.to_string(),
                CustomerRecord {
                    billable_type: billable_type.to_string(),
                    stripe_customer_id: customer_id.to_string(),
                },
            );
            Ok(())
        }

        async fn get_subscription(&self, billable_id: &str) -> Result<Option<StoredSubscription>> {
            Ok(self
                .inner
                .subscriptions
                .read()
                .unwrap()
                .get(billable_id)
                .cloned())
        }

        async fn save_subscription(
            &self,
            billable_id: &str,
            subscription: &StoredSubscription,
        ) -> Result<()> {
            self.inner
                .subscriptions
                .write()
                .unwrap()
                .insert(billable_id.to_string(), subscription.clone());
            Ok(())
        }

        /// Atomic version of the trait's compare-and-save.
        ///
        /// The version check and the insert both happen under one write-lock guard.
        async fn compare_and_save_subscription(
            &self,
            billable_id: &str,
            subscription: &StoredSubscription,
            expected_version: u64,
        ) -> Result<bool> {
            let mut subs = self.inner.subscriptions.write().unwrap();
            if let Some(current) = subs.get(billable_id) {
                if current.updated_at != expected_version {
                    return Ok(false);
                }
            }
            subs.insert(billable_id.to_string(), subscription.clone());
            Ok(true)
        }

        async fn delete_subscription(&self, billable_id: &str) -> Result<()> {
            self.inner
                .subscriptions
                .write()
                .unwrap()
                .remove(billable_id);
            Ok(())
        }

        async fn get_subscription_by_stripe_id(
            &self,
            stripe_subscription_id: &str,
        ) -> Result<Option<(String, StoredSubscription)>> {
            let subs = self.inner.subscriptions.read().unwrap();
            Ok(subs
                .iter()
                .find(|(_, sub)| sub.stripe_subscription_id == stripe_subscription_id)
                .map(|(billable_id, sub)| (billable_id.clone(), sub.clone())))
        }

        async fn is_event_processed(&self, event_id: &str) -> Result<bool> {
            Ok(self
                .inner
                .processed_events
                .read()
                .unwrap()
                .contains_key(event_id))
        }

        async fn mark_event_processed(&self, event_id: &str) -> Result<()> {
            let now = unix_now();
            self.inner
                .processed_events
                .write()
                .unwrap()
                .insert(event_id.to_string(), now);
            Ok(())
        }

        async fn cleanup_old_events(&self, older_than_days: u32) -> Result<usize> {
            // Subtract with saturation: a window longer than the time since the epoch, or
            // a clock reading 0, previously underflowed (a panic in debug builds). A
            // saturated cutoff of 0 simply keeps every record.
            let cutoff = unix_now().saturating_sub(u64::from(older_than_days) * SECS_PER_DAY);
            let mut events = self.inner.processed_events.write().unwrap();
            let before = events.len();
            events.retain(|_, marked_at| *marked_at >= cutoff);
            Ok(before - events.len())
        }

        /// Counts only subscriptions on `plan_id` whose status is `Active` or `Trialing`.
        async fn count_subscriptions_by_plan(&self, plan_id: &str) -> Result<u32> {
            let subs = self.inner.subscriptions.read().unwrap();
            let count = subs
                .values()
                .filter(|s| s.plan_id == plan_id && s.is_active())
                .count();
            Ok(u32::try_from(count).unwrap_or(u32::MAX))
        }
    }
}
/// Read-through TTL cache in front of any [`PlanStore`].
///
/// How reads are served:
/// - `list_plans`, `list_all_plans` and `get_plan` answer from memory while the cached
///   entry is younger than `ttl`.
/// - `get_plan_by_stripe_price` scans a fresh cached `list_all_plans` result if there is
///   one, and otherwise asks the inner store without caching.
/// - On a miss, the inner store is queried and the cache is refilled.
///
/// How writes are handled:
/// - Writes made through this wrapper invalidate the affected entries after the inner
///   write succeeds.
/// - Every invalidation bumps a generation counter. A fill that started before an
///   invalidation is discarded, so a read racing with a write cannot put stale data
///   back for a full TTL.
/// - Writes made directly against the inner store are seen only after the TTL expires,
///   or after [`CachedPlanStore::invalidate`] / [`CachedPlanStore::invalidate_plan`].
///
/// If the cache lock is poisoned, the cache is bypassed and every read goes to the
/// inner store.
pub struct CachedPlanStore<S: PlanStore> {
    inner: S,
    cache: std::sync::Arc<std::sync::RwLock<PlanCache>>,
    ttl: std::time::Duration,
}

/// Mutable cache state guarded by a single lock.
struct PlanCache {
    /// Per-id lookups, including negative (`None`) results.
    plans: std::collections::HashMap<String, CachedPlan>,
    /// Cached result of `list_plans`.
    active_plans: Option<CachedPlanList>,
    /// Cached result of `list_all_plans`; also scanned by `get_plan_by_stripe_price`.
    all_plans: Option<CachedPlanList>,
    /// Incremented on every invalidation; fills taken under an older value are dropped.
    generation: u64,
}

/// A cached `get_plan` result and its expiry time.
struct CachedPlan {
    plan: Option<StoredPlan>,
    expires_at: std::time::Instant,
}

/// A cached plan list and its expiry time.
struct CachedPlanList {
    plans: Vec<StoredPlan>,
    expires_at: std::time::Instant,
}

impl<S: PlanStore> CachedPlanStore<S> {
    /// Wraps `inner`; cached entries stay fresh for `ttl` after being filled.
    #[must_use]
    pub fn new(inner: S, ttl: std::time::Duration) -> Self {
        Self {
            inner,
            cache: std::sync::Arc::new(std::sync::RwLock::new(PlanCache {
                plans: std::collections::HashMap::new(),
                active_plans: None,
                all_plans: None,
                generation: 0,
            })),
            ttl,
        }
    }

    /// Drops every cached entry.
    pub fn invalidate(&self) {
        if let Ok(mut cache) = self.cache.write() {
            cache.plans.clear();
            cache.active_plans = None;
            cache.all_plans = None;
            cache.generation = cache.generation.wrapping_add(1);
        }
    }

    /// Drops the cached entry for `plan_id` and both cached lists.
    ///
    /// The lists are dropped too because they may contain the plan.
    pub fn invalidate_plan(&self, plan_id: &str) {
        if let Ok(mut cache) = self.cache.write() {
            cache.plans.remove(plan_id);
            cache.active_plans = None;
            cache.all_plans = None;
            cache.generation = cache.generation.wrapping_add(1);
        }
    }

    /// Number of per-plan entries currently held.
    ///
    /// The count may include expired entries not yet pruned. Returns 0 if the lock is
    /// poisoned.
    #[must_use]
    pub fn cache_size(&self) -> usize {
        self.cache.read().map(|c| c.plans.len()).unwrap_or(0)
    }

    /// Snapshot of the invalidation counter, taken before querying the inner store.
    ///
    /// Returns `None` when the lock is poisoned; the result is then not cached.
    fn generation(&self) -> Option<u64> {
        self.cache.read().ok().map(|cache| cache.generation)
    }

    /// Runs `fill` under the write lock with the new expiry time.
    ///
    /// `fill` runs only if no invalidation happened since `generation` was taken.
    /// Otherwise the freshly fetched value may already be stale, and it is dropped.
    fn fill_if_current(
        &self,
        generation: Option<u64>,
        fill: impl FnOnce(&mut PlanCache, std::time::Instant),
    ) {
        if let Some(generation) = generation {
            if let Ok(mut cache) = self.cache.write() {
                if cache.generation == generation {
                    let expires_at = std::time::Instant::now() + self.ttl;
                    fill(&mut *cache, expires_at);
                }
            }
        }
    }
}

#[async_trait::async_trait]
impl<S: PlanStore + Send + Sync> PlanStore for CachedPlanStore<S> {
    async fn list_plans(&self) -> Result<Vec<StoredPlan>> {
        if let Ok(cache) = self.cache.read() {
            if let Some(ref cached) = cache.active_plans {
                if cached.expires_at > std::time::Instant::now() {
                    return Ok(cached.plans.clone());
                }
            }
        }
        let generation = self.generation();
        let plans = self.inner.list_plans().await?;
        self.fill_if_current(generation, |cache, expires_at| {
            cache.active_plans = Some(CachedPlanList {
                plans: plans.clone(),
                expires_at,
            });
        });
        Ok(plans)
    }

    async fn list_all_plans(&self) -> Result<Vec<StoredPlan>> {
        if let Ok(cache) = self.cache.read() {
            if let Some(ref cached) = cache.all_plans {
                if cached.expires_at > std::time::Instant::now() {
                    return Ok(cached.plans.clone());
                }
            }
        }
        let generation = self.generation();
        let plans = self.inner.list_all_plans().await?;
        self.fill_if_current(generation, |cache, expires_at| {
            cache.all_plans = Some(CachedPlanList {
                plans: plans.clone(),
                expires_at,
            });
        });
        Ok(plans)
    }

    async fn get_plan(&self, plan_id: &str) -> Result<Option<StoredPlan>> {
        if let Ok(cache) = self.cache.read() {
            if let Some(cached) = cache.plans.get(plan_id) {
                if cached.expires_at > std::time::Instant::now() {
                    return Ok(cached.plan.clone());
                }
            }
        }
        let generation = self.generation();
        let plan = self.inner.get_plan(plan_id).await?;
        self.fill_if_current(generation, |cache, expires_at| {
            // Prune expired entries so that lookups of many distinct (possibly
            // non-existent) ids only keep entries that are still within their TTL.
            let now = std::time::Instant::now();
            cache.plans.retain(|_, entry| entry.expires_at > now);
            cache.plans.insert(
                plan_id.to_string(),
                CachedPlan {
                    plan: plan.clone(),
                    expires_at,
                },
            );
        });
        Ok(plan)
    }

    async fn get_plan_by_stripe_price(&self, stripe_price_id: &str) -> Result<Option<StoredPlan>> {
        // Answer from a fresh `list_all_plans` result when one is cached. On a miss,
        // delegate to the inner store without populating anything.
        if let Ok(cache) = self.cache.read() {
            if let Some(ref cached) = cache.all_plans {
                if cached.expires_at > std::time::Instant::now() {
                    return Ok(cached
                        .plans
                        .iter()
                        .find(|p| p.stripe_price_id == stripe_price_id)
                        .cloned());
                }
            }
        }
        self.inner.get_plan_by_stripe_price(stripe_price_id).await
    }

    async fn create_plan(&self, plan: &StoredPlan) -> Result<()> {
        self.inner.create_plan(plan).await?;
        self.invalidate();
        Ok(())
    }

    async fn update_plan(&self, plan: &StoredPlan) -> Result<()> {
        self.inner.update_plan(plan).await?;
        self.invalidate_plan(&plan.id);
        Ok(())
    }

    async fn delete_plan(&self, plan_id: &str) -> Result<()> {
        self.inner.delete_plan(plan_id).await?;
        self.invalidate_plan(plan_id);
        Ok(())
    }

    async fn set_plan_active(&self, plan_id: &str, is_active: bool) -> Result<()> {
        self.inner.set_plan_active(plan_id, is_active).await?;
        self.invalidate_plan(plan_id);
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::test::InMemoryBillingStore;
    use super::*;
    use std::time::Duration;

    /// Builds a subscription with zeroed timestamps, no trial, no seat items, and no
    /// extra seats. Individual tests override fields with struct-update syntax.
    fn test_subscription(
        stripe_subscription_id: &str,
        stripe_customer_id: &str,
        plan_id: &str,
        status: SubscriptionStatus,
    ) -> StoredSubscription {
        StoredSubscription {
            stripe_subscription_id: stripe_subscription_id.to_string(),
            stripe_customer_id: stripe_customer_id.to_string(),
            plan_id: plan_id.to_string(),
            status,
            current_period_start: 0,
            current_period_end: 0,
            extra_seats: 0,
            trial_end: None,
            cancel_at_period_end: false,
            base_item_id: None,
            seat_item_id: None,
            updated_at: 0,
        }
    }

    /// Current Unix time in seconds, for building trial end dates relative to now.
    fn now_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    fn create_test_plan(
        id: &str,
        price_cents: i64,
        is_active: bool,
        sort_order: i32,
    ) -> StoredPlan {
        StoredPlan {
            id: id.to_string(),
            name: format!("{} Plan", id),
            description: Some(format!("Description for {}", id)),
            stripe_price_id: format!("price_{}", id),
            stripe_seat_price_id: None,
            price_cents,
            currency: "usd".to_string(),
            interval: PlanInterval::Monthly,
            included_seats: 1,
            features: serde_json::json!({"basic": true}),
            limits: serde_json::json!({"projects": 10}),
            trial_days: Some(14),
            is_active,
            sort_order,
            created_at: 0,
            updated_at: 0,
        }
    }

    #[test]
    fn test_subscription_status_from_stripe() {
        assert_eq!(
            SubscriptionStatus::from_stripe("active"),
            SubscriptionStatus::Active
        );
        assert_eq!(
            SubscriptionStatus::from_stripe("trialing"),
            SubscriptionStatus::Trialing
        );
        assert_eq!(
            SubscriptionStatus::from_stripe("past_due"),
            SubscriptionStatus::PastDue
        );
        assert_eq!(
            SubscriptionStatus::from_stripe("canceled"),
            SubscriptionStatus::Canceled
        );
        // Unknown statuses fail closed.
        assert_eq!(
            SubscriptionStatus::from_stripe("unknown"),
            SubscriptionStatus::Canceled
        );
    }

    #[test]
    fn test_subscription_status_round_trips_through_stripe_strings() {
        let all = [
            SubscriptionStatus::Active,
            SubscriptionStatus::Trialing,
            SubscriptionStatus::PastDue,
            SubscriptionStatus::Canceled,
            SubscriptionStatus::Incomplete,
            SubscriptionStatus::IncompleteExpired,
            SubscriptionStatus::Paused,
            SubscriptionStatus::Unpaid,
        ];
        for status in all.iter().copied() {
            assert_eq!(SubscriptionStatus::from_stripe(status.as_str()), status);
            assert_eq!(status.to_string(), status.as_str());
        }
    }

    #[test]
    fn test_subscription_is_active() {
        let sub = test_subscription("sub_123", "cus_123", "starter", SubscriptionStatus::Active);
        assert!(sub.is_active());
        assert!(!sub.is_trialing());
        assert!(!sub.is_past_due());
    }

    #[test]
    fn test_subscription_trialing() {
        let sub = StoredSubscription {
            trial_end: Some(now_secs() + 86400 * 7),
            ..test_subscription("sub_123", "cus_123", "starter", SubscriptionStatus::Trialing)
        };
        assert!(sub.is_active());
        assert!(sub.is_trialing());
        assert!(sub.trial_days_remaining().unwrap() >= 6);
    }

    #[test]
    fn test_trial_days_remaining_without_future_trial() {
        let no_trial =
            test_subscription("sub_123", "cus_123", "starter", SubscriptionStatus::Trialing);
        assert_eq!(no_trial.trial_days_remaining(), None);

        // A trial that ended long ago reports no remaining days.
        let expired = StoredSubscription {
            trial_end: Some(1),
            ..no_trial.clone()
        };
        assert_eq!(expired.trial_days_remaining(), None);
    }

    #[tokio::test]
    async fn test_in_memory_store() {
        let store = InMemoryBillingStore::new();
        assert!(store
            .get_stripe_customer_id("org_123")
            .await
            .unwrap()
            .is_none());

        store
            .set_stripe_customer_id("org_123", "org", "cus_abc")
            .await
            .unwrap();
        assert_eq!(
            store
                .get_stripe_customer_id("org_123")
                .await
                .unwrap()
                .unwrap(),
            "cus_abc"
        );

        let sub = StoredSubscription {
            extra_seats: 2,
            ..test_subscription("sub_123", "cus_abc", "starter", SubscriptionStatus::Active)
        };
        store.save_subscription("org_123", &sub).await.unwrap();
        let loaded = store.get_subscription("org_123").await.unwrap().unwrap();
        assert_eq!(loaded.plan_id, "starter");
        assert_eq!(loaded.extra_seats, 2);

        assert!(!store.is_event_processed("evt_123").await.unwrap());
        store.mark_event_processed("evt_123").await.unwrap();
        assert!(store.is_event_processed("evt_123").await.unwrap());
    }

    #[tokio::test]
    async fn test_in_memory_compare_and_save() {
        let store = InMemoryBillingStore::new();
        let v1 = StoredSubscription {
            updated_at: 1,
            ..test_subscription("sub_1", "cus_1", "starter", SubscriptionStatus::Active)
        };
        // With nothing stored yet, the write goes through regardless of expected version.
        assert!(store
            .compare_and_save_subscription("org_1", &v1, 0)
            .await
            .unwrap());

        let v2 = StoredSubscription {
            updated_at: 2,
            plan_id: "pro".to_string(),
            ..v1.clone()
        };
        // Stale expected version: rejected and nothing changes.
        assert!(!store
            .compare_and_save_subscription("org_1", &v2, 0)
            .await
            .unwrap());
        assert_eq!(store.get_subscription("org_1").await.unwrap().unwrap(), v1);

        // Matching expected version: accepted.
        assert!(store
            .compare_and_save_subscription("org_1", &v2, 1)
            .await
            .unwrap());
        assert_eq!(store.get_subscription("org_1").await.unwrap().unwrap(), v2);
    }

    #[tokio::test]
    async fn test_in_memory_get_subscription_by_stripe_id() {
        let store = InMemoryBillingStore::new();
        let sub = test_subscription("sub_9", "cus_9", "starter", SubscriptionStatus::Active);
        store.save_subscription("org_9", &sub).await.unwrap();

        let (billable_id, found) = store
            .get_subscription_by_stripe_id("sub_9")
            .await
            .unwrap()
            .unwrap();
        assert_eq!(billable_id, "org_9");
        assert_eq!(found, sub);
        assert!(store
            .get_subscription_by_stripe_id("sub_missing")
            .await
            .unwrap()
            .is_none());
    }

    #[tokio::test]
    async fn test_in_memory_cleanup_keeps_recent_events() {
        let store = InMemoryBillingStore::new();
        store.mark_event_processed("evt_1").await.unwrap();
        assert_eq!(store.cleanup_old_events(1).await.unwrap(), 0);
        assert!(store.is_event_processed("evt_1").await.unwrap());
    }

    #[tokio::test]
    async fn test_in_memory_plan_store() {
        let store = InMemoryBillingStore::new();
        assert!(store.list_plans().await.unwrap().is_empty());
        assert!(store.list_all_plans().await.unwrap().is_empty());

        let starter = create_test_plan("starter", 999, true, 1);
        let pro = create_test_plan("pro", 2999, true, 2);
        let inactive = create_test_plan("legacy", 499, false, 0);
        store.create_plan(&starter).await.unwrap();
        store.create_plan(&pro).await.unwrap();
        store.create_plan(&inactive).await.unwrap();

        let active = store.list_plans().await.unwrap();
        assert_eq!(active.len(), 2);
        assert_eq!(active[0].id, "starter");
        assert_eq!(active[1].id, "pro");

        let all = store.list_all_plans().await.unwrap();
        assert_eq!(all.len(), 3);

        let plan = store.get_plan("starter").await.unwrap().unwrap();
        assert_eq!(plan.price_cents, 999);

        let plan = store
            .get_plan_by_stripe_price("price_pro")
            .await
            .unwrap()
            .unwrap();
        assert_eq!(plan.id, "pro");

        let mut updated = starter.clone();
        updated.price_cents = 1499;
        store.update_plan(&updated).await.unwrap();
        let plan = store.get_plan("starter").await.unwrap().unwrap();
        assert_eq!(plan.price_cents, 1499);

        store.set_plan_active("starter", false).await.unwrap();
        let active = store.list_plans().await.unwrap();
        assert_eq!(active.len(), 1);
        assert_eq!(active[0].id, "pro");

        store.delete_plan("pro").await.unwrap();
        assert!(store.get_plan("pro").await.unwrap().is_none());
    }

    #[test]
    fn test_stored_plan_helpers() {
        let plan = StoredPlan {
            included_seats: 5,
            features: serde_json::json!({"api_access": true, "support": false}),
            limits: serde_json::json!({"projects": 50, "storage_mb": 1000}),
            ..create_test_plan("test", 1999, true, 0)
        };

        assert!(plan.has_feature("api_access"));
        assert!(!plan.has_feature("support"));
        assert!(!plan.has_feature("nonexistent"));

        assert_eq!(plan.get_limit("projects"), Some(50));
        assert_eq!(plan.get_limit("storage_mb"), Some(1000));
        assert_eq!(plan.get_limit("nonexistent"), None);

        // Strictly below the limit is allowed; reaching it is not.
        assert!(plan.check_limit("projects", 49));
        assert!(!plan.check_limit("projects", 50));
        // No configured limit means unlimited.
        assert!(plan.check_limit("nonexistent", 9999));

        assert_eq!(plan.formatted_price(), "$19.99");
    }

    #[test]
    fn test_formatted_price_symbols() {
        let mut plan = create_test_plan("p", 500, true, 0);
        plan.currency = "gbp".to_string();
        assert_eq!(plan.formatted_price(), "£5.00");
        plan.currency = "eur".to_string();
        assert_eq!(plan.formatted_price(), "€5.00");
        // Unknown currency codes are used verbatim as the prefix.
        plan.currency = "chf".to_string();
        assert_eq!(plan.formatted_price(), "chf5.00");
        plan.price_cents = 5;
        assert_eq!(plan.formatted_price(), "chf0.05");
    }

    #[test]
    fn test_plan_interval() {
        assert_eq!(PlanInterval::from_str("monthly"), PlanInterval::Monthly);
        assert_eq!(PlanInterval::from_str("month"), PlanInterval::Monthly);
        assert_eq!(PlanInterval::from_str("yearly"), PlanInterval::Yearly);
        assert_eq!(PlanInterval::from_str("year"), PlanInterval::Yearly);
        assert_eq!(PlanInterval::from_str("annual"), PlanInterval::Yearly);
        assert_eq!(PlanInterval::from_str("one_time"), PlanInterval::OneTime);
        assert_eq!(PlanInterval::from_str("onetime"), PlanInterval::OneTime);
        assert_eq!(PlanInterval::from_str("lifetime"), PlanInterval::OneTime);
        assert_eq!(PlanInterval::from_str("unknown"), PlanInterval::Monthly);

        assert_eq!(PlanInterval::Monthly.as_str(), "monthly");
        assert_eq!(PlanInterval::Yearly.as_str(), "yearly");
        assert_eq!(PlanInterval::OneTime.as_str(), "one_time");
        assert_eq!(PlanInterval::Yearly.to_string(), "yearly");
    }

    #[tokio::test]
    async fn test_count_subscriptions_by_plan() {
        let store = InMemoryBillingStore::new();
        assert_eq!(store.count_subscriptions_by_plan("starter").await.unwrap(), 0);

        let sub1 = test_subscription("sub_1", "cus_1", "starter", SubscriptionStatus::Active);
        store.save_subscription("org_1", &sub1).await.unwrap();
        assert_eq!(store.count_subscriptions_by_plan("starter").await.unwrap(), 1);

        let sub2 = StoredSubscription {
            trial_end: Some(99999999),
            ..test_subscription("sub_2", "cus_2", "starter", SubscriptionStatus::Trialing)
        };
        store.save_subscription("org_2", &sub2).await.unwrap();
        assert_eq!(store.count_subscriptions_by_plan("starter").await.unwrap(), 2);

        // Canceled subscriptions are not counted.
        let sub3 = test_subscription("sub_3", "cus_3", "starter", SubscriptionStatus::Canceled);
        store.save_subscription("org_3", &sub3).await.unwrap();
        assert_eq!(store.count_subscriptions_by_plan("starter").await.unwrap(), 2);

        // Subscriptions on other plans are not counted.
        let sub4 = test_subscription("sub_4", "cus_4", "pro", SubscriptionStatus::Active);
        store.save_subscription("org_4", &sub4).await.unwrap();
        assert_eq!(store.count_subscriptions_by_plan("starter").await.unwrap(), 2);
        assert_eq!(store.count_subscriptions_by_plan("pro").await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_cached_plan_store_caches_results() {
        let inner = InMemoryBillingStore::new();
        let starter = create_test_plan("starter", 999, true, 1);
        inner.create_plan(&starter).await.unwrap();

        // A generous TTL keeps the "still cached" assertion from flaking on slow runners.
        let cached = CachedPlanStore::new(inner.clone(), Duration::from_secs(60));
        let plans = cached.list_plans().await.unwrap();
        assert_eq!(plans.len(), 1);
        assert_eq!(plans[0].id, "starter");

        // A write that bypasses the cache is not seen until invalidation.
        let pro = create_test_plan("pro", 2999, true, 2);
        inner.create_plan(&pro).await.unwrap();
        let plans = cached.list_plans().await.unwrap();
        assert_eq!(plans.len(), 1);

        cached.invalidate();
        let plans = cached.list_plans().await.unwrap();
        assert_eq!(plans.len(), 2);
    }

    #[tokio::test]
    async fn test_cached_plan_store_get_plan() {
        let inner = InMemoryBillingStore::new();
        let starter = create_test_plan("starter", 999, true, 1);
        inner.create_plan(&starter).await.unwrap();

        let cached = CachedPlanStore::new(inner.clone(), Duration::from_secs(60));
        let plan = cached.get_plan("starter").await.unwrap().unwrap();
        assert_eq!(plan.price_cents, 999);
        assert_eq!(cached.cache_size(), 1);

        let mut updated = starter.clone();
        updated.price_cents = 1499;
        inner.update_plan(&updated).await.unwrap();
        let plan = cached.get_plan("starter").await.unwrap().unwrap();
        assert_eq!(plan.price_cents, 999);

        cached.invalidate_plan("starter");
        let plan = cached.get_plan("starter").await.unwrap().unwrap();
        assert_eq!(plan.price_cents, 1499);
    }

    #[tokio::test]
    async fn test_cached_plan_store_write_operations_invalidate() {
        let inner = InMemoryBillingStore::new();
        let cached = CachedPlanStore::new(inner, Duration::from_secs(60));

        let starter = create_test_plan("starter", 999, true, 1);
        cached.create_plan(&starter).await.unwrap();
        let plans = cached.list_plans().await.unwrap();
        assert_eq!(plans.len(), 1);

        let mut updated = starter.clone();
        updated.price_cents = 1499;
        cached.update_plan(&updated).await.unwrap();
        let plan = cached.get_plan("starter").await.unwrap().unwrap();
        assert_eq!(plan.price_cents, 1499);

        cached.delete_plan("starter").await.unwrap();
        assert!(cached.get_plan("starter").await.unwrap().is_none());
        assert!(cached.list_plans().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_cached_plan_store_set_active_invalidates() {
        let inner = InMemoryBillingStore::new();
        let cached = CachedPlanStore::new(inner, Duration::from_secs(60));

        let starter = create_test_plan("starter", 999, true, 1);
        cached.create_plan(&starter).await.unwrap();
        let plans = cached.list_plans().await.unwrap();
        assert_eq!(plans.len(), 1);

        cached.set_plan_active("starter", false).await.unwrap();
        let active_plans = cached.list_plans().await.unwrap();
        assert!(active_plans.is_empty());

        let all_plans = cached.list_all_plans().await.unwrap();
        assert_eq!(all_plans.len(), 1);
        assert!(!all_plans[0].is_active);
    }
}