use super::storage::BillingStore;
use crate::error::Result;
/// A payment method attached to a Stripe customer, reduced to the card details
/// this module exposes.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare methods directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PaymentMethod {
    /// Payment-method identifier (values such as `pm_1` are used in tests).
    pub id: String,
    /// Card brand (e.g. `"visa"`), if reported.
    pub card_brand: Option<String>,
    /// Last four digits of the card number, if reported.
    pub card_last4: Option<String>,
    /// Card expiration month, if reported.
    pub card_exp_month: Option<u32>,
    /// Card expiration year, if reported.
    pub card_exp_year: Option<u32>,
    /// Whether this is the customer's default payment method.
    pub is_default: bool,
}
/// One page of a customer's payment methods, as returned by
/// `StripePaymentMethodClient::list_payment_methods`.
#[derive(Debug, Clone)]
pub struct PaymentMethodList {
/// The payment methods contained in this page.
pub methods: Vec<PaymentMethod>,
/// Whether the client reported that further methods exist beyond this page.
pub has_more: bool,
}
/// Stripe operations on a customer's payment methods, used by `PaymentMethodManager`.
// `async_fn_in_trait` warns that callers cannot require the futures returned by
// these methods to be `Send`. NOTE(review): presumably allowed because all
// implementors live in this crate — confirm the concrete client's futures are
// `Send` before spawning them on a multi-threaded executor.
#[allow(async_fn_in_trait)]
pub trait StripePaymentMethodClient: Send + Sync {
/// Lists payment methods attached to `customer_id`.
///
/// `limit` is the requested page size; `PaymentMethodManager` only passes
/// values in `1..=100`. `has_more` on the result signals truncation.
async fn list_payment_methods(&self, customer_id: &str, limit: u8)
-> Result<PaymentMethodList>;
/// Attaches `payment_method_id` to `customer_id` and returns the attached method.
///
/// Note the argument order: payment method first, customer second (the
/// reverse of `set_default_payment_method`).
async fn attach_payment_method(
&self,
payment_method_id: &str,
customer_id: &str,
) -> Result<PaymentMethod>;
/// Detaches `payment_method_id`.
///
/// Takes no customer id; `PaymentMethodManager::remove` checks that the
/// method belongs to the customer before calling this.
async fn detach_payment_method(&self, payment_method_id: &str) -> Result<()>;
/// Makes `payment_method_id` the default payment method of `customer_id`.
async fn set_default_payment_method(
&self,
customer_id: &str,
payment_method_id: &str,
) -> Result<()>;
}
/// Page size used by `PaymentMethodManager::new` when listing payment methods.
const DEFAULT_PAYMENT_METHOD_LIMIT: u8 = 100;

/// Manages the Stripe payment methods of a billable entity.
///
/// A billable id is resolved to a Stripe customer id through the subscription
/// held in the store `S`; Stripe calls go through the client `C`.
///
/// The `BillingStore` / `StripePaymentMethodClient` bounds live on the `impl`
/// blocks rather than on the type, so merely naming the type needs no bounds.
pub struct PaymentMethodManager<S, C> {
    store: S,
    client: C,
    /// Page size for list requests; the constructors keep it within `1..=100`.
    list_limit: u8,
}
impl<S: BillingStore, C: StripePaymentMethodClient> PaymentMethodManager<S, C> {
#[must_use]
pub fn new(store: S, client: C) -> Self {
Self {
store,
client,
list_limit: DEFAULT_PAYMENT_METHOD_LIMIT,
}
}
#[must_use]
pub fn with_limit(store: S, client: C, list_limit: u8) -> Self {
Self {
store,
client,
list_limit: list_limit.clamp(1, 100),
}
}
pub async fn list_payment_methods(&self, billable_id: &str) -> Result<PaymentMethodList> {
let sub = self
.store
.get_subscription(billable_id)
.await?
.ok_or_else(|| super::error::BillingError::NoSubscription {
billable_id: billable_id.to_string(),
})?;
self.client
.list_payment_methods(&sub.stripe_customer_id, self.list_limit)
.await
}
pub async fn list_payment_methods_with_limit(
&self,
billable_id: &str,
limit: u8,
) -> Result<PaymentMethodList> {
let sub = self
.store
.get_subscription(billable_id)
.await?
.ok_or_else(|| super::error::BillingError::NoSubscription {
billable_id: billable_id.to_string(),
})?;
self.client
.list_payment_methods(&sub.stripe_customer_id, limit.clamp(1, 100))
.await
}
pub async fn set_default(&self, billable_id: &str, payment_method_id: &str) -> Result<()> {
let sub = self
.store
.get_subscription(billable_id)
.await?
.ok_or_else(|| super::error::BillingError::NoSubscription {
billable_id: billable_id.to_string(),
})?;
let methods = self
.client
.list_payment_methods(&sub.stripe_customer_id, self.list_limit)
.await?;
if !methods.methods.iter().any(|m| m.id == payment_method_id) {
return Err(super::error::BillingError::PaymentMethodNotFound {
payment_method_id: payment_method_id.to_string(),
}
.into());
}
self.client
.set_default_payment_method(&sub.stripe_customer_id, payment_method_id)
.await
}
pub async fn remove(&self, billable_id: &str, payment_method_id: &str) -> Result<()> {
let sub = self
.store
.get_subscription(billable_id)
.await?
.ok_or_else(|| super::error::BillingError::NoSubscription {
billable_id: billable_id.to_string(),
})?;
let methods = self
.client
.list_payment_methods(&sub.stripe_customer_id, self.list_limit)
.await?;
if !methods.methods.iter().any(|m| m.id == payment_method_id) {
return Err(super::error::BillingError::PaymentMethodNotFound {
payment_method_id: payment_method_id.to_string(),
}
.into());
}
self.client.detach_payment_method(payment_method_id).await
}
pub async fn attach(
&self,
billable_id: &str,
payment_method_id: &str,
) -> Result<PaymentMethod> {
let sub = self
.store
.get_subscription(billable_id)
.await?
.ok_or_else(|| super::error::BillingError::NoSubscription {
billable_id: billable_id.to_string(),
})?;
self.client
.attach_payment_method(payment_method_id, &sub.stripe_customer_id)
.await
}
pub async fn get_default(&self, billable_id: &str) -> Result<Option<PaymentMethod>> {
let methods = self.list_payment_methods(billable_id).await?;
Ok(methods.methods.into_iter().find(|m| m.is_default))
}
}
#[cfg(any(test, feature = "test-billing"))]
pub mod test {
    use super::*;
    use std::collections::HashMap;
    use std::sync::RwLock;

    /// In-memory stand-in for a `StripePaymentMethodClient`, for tests.
    ///
    /// Honors the requested list `limit` and reports `has_more`, so that
    /// pagination and truncation behavior can be exercised in tests.
    #[derive(Default)]
    pub struct MockStripePaymentMethodClient {
        /// Attached payment methods per customer id, in attach order.
        payment_methods: std::sync::Arc<RwLock<HashMap<String, Vec<PaymentMethod>>>>,
        /// Default payment-method id per customer id.
        default_methods: std::sync::Arc<RwLock<HashMap<String, String>>>,
    }

    impl MockStripePaymentMethodClient {
        /// Creates an empty mock with no customers or payment methods.
        #[must_use]
        pub fn new() -> Self {
            Self::default()
        }

        /// Seeds `method` as attached to `customer_id` (appended after any
        /// existing methods).
        pub fn add_payment_method(&self, customer_id: &str, method: PaymentMethod) {
            let mut methods = self.payment_methods.write().unwrap();
            methods
                .entry(customer_id.to_string())
                .or_default()
                .push(method);
        }
    }

    impl StripePaymentMethodClient for MockStripePaymentMethodClient {
        /// Returns at most `limit` of the customer's methods, with `is_default`
        /// derived from the recorded default, and `has_more` set when the list
        /// was truncated.
        async fn list_payment_methods(
            &self,
            customer_id: &str,
            limit: u8,
        ) -> Result<PaymentMethodList> {
            let methods = self.payment_methods.read().unwrap();
            let defaults = self.default_methods.read().unwrap();
            let default_id = defaults.get(customer_id);
            let attached = methods
                .get(customer_id)
                .map(Vec::as_slice)
                .unwrap_or_default();
            let page_size = usize::from(limit);
            let page = attached
                .iter()
                .take(page_size)
                .map(|m| PaymentMethod {
                    is_default: default_id == Some(&m.id),
                    ..m.clone()
                })
                .collect();
            Ok(PaymentMethodList {
                methods: page,
                has_more: attached.len() > page_size,
            })
        }

        /// Records a fixed test card (visa 4242, 12/2099) under `customer_id`.
        async fn attach_payment_method(
            &self,
            payment_method_id: &str,
            customer_id: &str,
        ) -> Result<PaymentMethod> {
            let method = PaymentMethod {
                id: payment_method_id.to_string(),
                card_brand: Some("visa".to_string()),
                card_last4: Some("4242".to_string()),
                card_exp_month: Some(12),
                card_exp_year: Some(2099),
                is_default: false,
            };
            let mut methods = self.payment_methods.write().unwrap();
            methods
                .entry(customer_id.to_string())
                .or_default()
                .push(method.clone());
            Ok(method)
        }

        /// Removes `payment_method_id` from every customer, and clears it as a
        /// default so no stale default entry survives the detach.
        async fn detach_payment_method(&self, payment_method_id: &str) -> Result<()> {
            for customer_methods in self.payment_methods.write().unwrap().values_mut() {
                customer_methods.retain(|m| m.id != payment_method_id);
            }
            self.default_methods
                .write()
                .unwrap()
                .retain(|_, default_id| default_id.as_str() != payment_method_id);
            Ok(())
        }

        /// Records `payment_method_id` as the default for `customer_id`.
        async fn set_default_payment_method(
            &self,
            customer_id: &str,
            payment_method_id: &str,
        ) -> Result<()> {
            let mut defaults = self.default_methods.write().unwrap();
            defaults.insert(customer_id.to_string(), payment_method_id.to_string());
            Ok(())
        }
    }
}
#[cfg(test)]
mod tests {
    use super::test::MockStripePaymentMethodClient;
    use super::*;
    use crate::billing::storage::test::InMemoryBillingStore;
    use crate::billing::storage::{StoredSubscription, SubscriptionStatus};

    /// An active "starter" subscription for customer `cus_123`.
    fn starter_subscription() -> StoredSubscription {
        StoredSubscription {
            stripe_subscription_id: "sub_123".to_string(),
            stripe_customer_id: "cus_123".to_string(),
            plan_id: "starter".to_string(),
            status: SubscriptionStatus::Active,
            current_period_start: 1700000000,
            current_period_end: 1702592000,
            extra_seats: 0,
            trial_end: None,
            cancel_at_period_end: false,
            base_item_id: None,
            seat_item_id: None,
            updated_at: 1700000000,
        }
    }

    /// Builds a non-default card payment method.
    fn card(id: &str, brand: &str, last4: &str, exp_month: u32, exp_year: u32) -> PaymentMethod {
        PaymentMethod {
            id: id.to_string(),
            card_brand: Some(brand.to_string()),
            card_last4: Some(last4.to_string()),
            card_exp_month: Some(exp_month),
            card_exp_year: Some(exp_year),
            is_default: false,
        }
    }

    /// A store holding `org_123`'s subscription, plus an empty mock client.
    async fn seeded() -> (InMemoryBillingStore, MockStripePaymentMethodClient) {
        let store = InMemoryBillingStore::new();
        store
            .save_subscription("org_123", &starter_subscription())
            .await
            .unwrap();
        (store, MockStripePaymentMethodClient::new())
    }

    #[tokio::test]
    async fn test_list_payment_methods() {
        let (store, client) = seeded().await;
        client.add_payment_method("cus_123", card("pm_1", "visa", "4242", 12, 2099));
        client.add_payment_method("cus_123", card("pm_2", "mastercard", "5555", 6, 2026));

        let manager = PaymentMethodManager::new(store, client);
        let listed = manager.list_payment_methods("org_123").await.unwrap();

        assert_eq!(listed.methods.len(), 2);
        assert!(!listed.has_more);
    }

    #[tokio::test]
    async fn test_set_default_payment_method() {
        let (store, client) = seeded().await;
        client.add_payment_method("cus_123", card("pm_1", "visa", "4242", 12, 2099));

        let manager = PaymentMethodManager::new(store, client);
        manager.set_default("org_123", "pm_1").await.unwrap();

        let default = manager
            .get_default("org_123")
            .await
            .unwrap()
            .expect("a default payment method should be set");
        assert_eq!(default.id, "pm_1");
    }

    #[tokio::test]
    async fn test_attach_payment_method() {
        let (store, client) = seeded().await;
        let manager = PaymentMethodManager::new(store, client);

        let attached = manager.attach("org_123", "pm_new").await.unwrap();
        assert_eq!(attached.id, "pm_new");
        assert_eq!(attached.card_brand.as_deref(), Some("visa"));

        let listed = manager.list_payment_methods("org_123").await.unwrap();
        assert_eq!(listed.methods.len(), 1);
    }

    #[tokio::test]
    async fn test_remove_payment_method() {
        let (store, client) = seeded().await;
        client.add_payment_method("cus_123", card("pm_1", "visa", "4242", 12, 2099));
        let manager = PaymentMethodManager::new(store, client);

        let before = manager.list_payment_methods("org_123").await.unwrap();
        assert_eq!(before.methods.len(), 1);

        manager.remove("org_123", "pm_1").await.unwrap();

        let after = manager.list_payment_methods("org_123").await.unwrap();
        assert!(after.methods.is_empty());
    }

    #[tokio::test]
    async fn test_no_subscription_error() {
        let manager = PaymentMethodManager::new(
            InMemoryBillingStore::new(),
            MockStripePaymentMethodClient::new(),
        );
        assert!(manager.list_payment_methods("nonexistent").await.is_err());
    }
}