use super::storage::BillingStore;
use crate::error::Result;
/// A refund as returned by a [`StripeRefundClient`].
#[derive(Debug, Clone)]
pub struct Refund {
/// Refund identifier (the mock client issues `re_mock_<n>`).
pub id: String,
/// Amount refunded. NOTE(review): presumably in the currency's smallest unit, as Stripe reports it — confirm.
pub amount: i64,
/// Currency code of the refund (the mock client stores it lowercased, e.g. `gbp`).
pub currency: String,
/// Current lifecycle state of the refund.
pub status: RefundStatus,
/// Reason attached when the refund was created, if any.
pub reason: Option<RefundReason>,
/// Creation timestamp; the mock client fills it with seconds since the UNIX epoch.
pub created: u64,
/// Charge the refund was issued against, when it was created from a charge.
pub charge_id: Option<String>,
/// Payment intent the refund was issued against, when it was created from one.
pub payment_intent_id: Option<String>,
}
/// Lifecycle state of a refund, mirroring the values of Stripe's refund `status` field.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RefundStatus {
    /// `pending`; also the fallback for any status string this enum does not recognise.
    Pending,
    /// `succeeded`
    Succeeded,
    /// `failed`
    Failed,
    /// `canceled`
    Canceled,
}

impl RefundStatus {
    /// Parses a Stripe refund status string.
    ///
    /// `"succeeded"`, `"failed"` and `"canceled"` map to their variants; every other input,
    /// including `"pending"` and unrecognised values, yields [`RefundStatus::Pending`].
    #[must_use]
    pub fn from_stripe(status: &str) -> Self {
        match status {
            "succeeded" => RefundStatus::Succeeded,
            "failed" => RefundStatus::Failed,
            "canceled" => RefundStatus::Canceled,
            // "pending" and anything unknown both land here.
            _ => RefundStatus::Pending,
        }
    }

    /// Returns the Stripe wire value for this status (the inverse of [`Self::from_stripe`]).
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match *self {
            RefundStatus::Pending => "pending",
            RefundStatus::Succeeded => "succeeded",
            RefundStatus::Failed => "failed",
            RefundStatus::Canceled => "canceled",
        }
    }
}
/// Reason attached to a refund, matching Stripe's refund `reason` values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RefundReason {
    /// `duplicate`
    Duplicate,
    /// `fraudulent`
    Fraudulent,
    /// `requested_by_customer`
    RequestedByCustomer,
}

impl RefundReason {
    /// Returns the Stripe wire value for this reason.
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match *self {
            RefundReason::Duplicate => "duplicate",
            RefundReason::Fraudulent => "fraudulent",
            RefundReason::RequestedByCustomer => "requested_by_customer",
        }
    }

    /// Parses a Stripe reason string (exact, case-sensitive match against [`Self::as_str`]).
    ///
    /// Returns `None` for any value this enum does not model.
    #[must_use]
    pub fn from_stripe(reason: &str) -> Option<Self> {
        [
            RefundReason::Duplicate,
            RefundReason::Fraudulent,
            RefundReason::RequestedByCustomer,
        ]
        .iter()
        .copied()
        .find(|candidate| candidate.as_str() == reason)
    }
}
/// Parameters for creating a refund, usually built with [`CreateRefundRequest::for_charge`] or
/// [`CreateRefundRequest::for_payment_intent`] plus the `with_*` builder methods.
#[derive(Debug, Clone, Default)]
pub struct CreateRefundRequest {
/// Charge to refund. Each constructor sets only one of `charge_id` / `payment_intent_id`.
pub charge_id: Option<String>,
/// Payment intent to refund.
pub payment_intent_id: Option<String>,
/// Amount to refund; `None` sends no amount. NOTE(review): presumably Stripe then refunds the full remaining amount — confirm. The mock client substitutes 1000.
pub amount: Option<i64>,
/// Optional reason sent with the refund.
pub reason: Option<RefundReason>,
}
impl CreateRefundRequest {
    /// Starts a request that refunds the charge `charge_id`; every other field is unset.
    #[must_use]
    pub fn for_charge(charge_id: impl Into<String>) -> Self {
        Self {
            charge_id: Some(charge_id.into()),
            payment_intent_id: None,
            amount: None,
            reason: None,
        }
    }

    /// Starts a request that refunds the payment intent `payment_intent_id`; every other
    /// field is unset.
    #[must_use]
    pub fn for_payment_intent(payment_intent_id: impl Into<String>) -> Self {
        Self {
            charge_id: None,
            payment_intent_id: Some(payment_intent_id.into()),
            amount: None,
            reason: None,
        }
    }

    /// Sets an explicit refund amount.
    ///
    /// # Panics
    /// In builds with debug assertions enabled, panics if `amount` is not positive.
    ///
    /// In release builds a non-positive `amount` is silently clamped to 1, so callers that
    /// accept untrusted amounts should validate them before calling this.
    #[must_use]
    pub fn with_amount(self, amount: i64) -> Self {
        debug_assert!(amount > 0, "Refund amount must be positive, got {}", amount);
        Self {
            amount: Some(std::cmp::max(amount, 1)),
            ..self
        }
    }

    /// Attaches a refund `reason`, replacing any previously set one.
    #[must_use]
    pub fn with_reason(self, reason: RefundReason) -> Self {
        Self {
            reason: Some(reason),
            ..self
        }
    }
}
// `async fn` in a public trait triggers rustc's `async_fn_in_trait` lint, because callers cannot
// add a `Send` bound to the returned futures; the lint is silenced here.
#[allow(async_fn_in_trait)]
/// Refund operations, plus the charge / payment-intent owner lookups that the refund
/// managers need for ownership checks.
pub trait StripeRefundClient: Send + Sync {
/// Creates a refund described by `request`.
async fn create_refund(&self, request: CreateRefundRequest) -> Result<Refund>;
/// Fetches a single refund by its ID.
async fn get_refund(&self, refund_id: &str) -> Result<Refund>;
/// Lists at most `limit` refunds recorded against `charge_id`.
async fn list_refunds(&self, charge_id: &str, limit: u8) -> Result<Vec<Refund>>;
/// Returns the Stripe customer ID that owns `charge_id`.
async fn get_charge_customer_id(&self, charge_id: &str) -> Result<String>;
/// Returns the Stripe customer ID that owns `payment_intent_id`.
async fn get_payment_intent_customer_id(&self, payment_intent_id: &str) -> Result<String>;
}
/// Issues and looks up refunds through a [`StripeRefundClient`] without any ownership checks.
///
/// See [`SecureRefundManager`] for the variant that verifies the caller's Stripe customer owns
/// the charge or payment intent being refunded.
pub struct RefundManager<C: StripeRefundClient> {
client: C,
}
impl<C: StripeRefundClient> RefundManager<C> {
#[must_use]
pub fn new(client: C) -> Self {
Self { client }
}
pub async fn refund_charge(&self, charge_id: &str, amount: Option<i64>) -> Result<Refund> {
let mut request = CreateRefundRequest::for_charge(charge_id);
if let Some(amt) = amount {
request = request.with_amount(amt);
}
self.client.create_refund(request).await
}
pub async fn refund_payment_intent(
&self,
payment_intent_id: &str,
amount: Option<i64>,
) -> Result<Refund> {
let mut request = CreateRefundRequest::for_payment_intent(payment_intent_id);
if let Some(amt) = amount {
request = request.with_amount(amt);
}
self.client.create_refund(request).await
}
pub async fn refund_with_reason(
&self,
charge_id: &str,
amount: Option<i64>,
reason: RefundReason,
) -> Result<Refund> {
let mut request = CreateRefundRequest::for_charge(charge_id).with_reason(reason);
if let Some(amt) = amount {
request = request.with_amount(amt);
}
self.client.create_refund(request).await
}
pub async fn get_refund(&self, refund_id: &str) -> Result<Refund> {
self.client.get_refund(refund_id).await
}
pub async fn list_refunds_for_charge(&self, charge_id: &str, limit: u8) -> Result<Vec<Refund>> {
self.client.list_refunds(charge_id, limit).await
}
}
/// Refund manager that checks the charge or payment intent being refunded belongs to the
/// Stripe customer on the subscription stored for a billable ID.
///
/// Note: `get_refund` does not perform this check.
pub struct SecureRefundManager<S: BillingStore, C: StripeRefundClient> {
/// Looks up the subscription (and so the Stripe customer ID) for each billable ID.
store: S,
/// Used for owner lookups and for creating / fetching refunds.
client: C,
}
impl<S: BillingStore, C: StripeRefundClient> SecureRefundManager<S, C> {
#[must_use]
pub fn new(store: S, client: C) -> Self {
Self { store, client }
}
async fn verify_charge_ownership(&self, billable_id: &str, charge_id: &str) -> Result<()> {
let sub = self
.store
.get_subscription(billable_id)
.await?
.ok_or_else(|| super::error::BillingError::NoSubscription {
billable_id: billable_id.to_string(),
})?;
let charge_customer_id = self.client.get_charge_customer_id(charge_id).await?;
if charge_customer_id != sub.stripe_customer_id {
return Err(super::error::BillingError::ChargeNotFound {
charge_id: charge_id.to_string(),
}
.into());
}
Ok(())
}
async fn verify_payment_intent_ownership(
&self,
billable_id: &str,
payment_intent_id: &str,
) -> Result<()> {
let sub = self
.store
.get_subscription(billable_id)
.await?
.ok_or_else(|| super::error::BillingError::NoSubscription {
billable_id: billable_id.to_string(),
})?;
let pi_customer_id = self
.client
.get_payment_intent_customer_id(payment_intent_id)
.await?;
if pi_customer_id != sub.stripe_customer_id {
return Err(super::error::BillingError::RefundFailed {
reason: "Payment intent does not belong to this customer".to_string(),
}
.into());
}
Ok(())
}
pub async fn refund_charge(
&self,
billable_id: &str,
charge_id: &str,
amount: Option<i64>,
) -> Result<Refund> {
self.verify_charge_ownership(billable_id, charge_id).await?;
let mut request = CreateRefundRequest::for_charge(charge_id);
if let Some(amt) = amount {
request = request.with_amount(amt);
}
self.client.create_refund(request).await
}
pub async fn refund_payment_intent(
&self,
billable_id: &str,
payment_intent_id: &str,
amount: Option<i64>,
) -> Result<Refund> {
self.verify_payment_intent_ownership(billable_id, payment_intent_id)
.await?;
let mut request = CreateRefundRequest::for_payment_intent(payment_intent_id);
if let Some(amt) = amount {
request = request.with_amount(amt);
}
self.client.create_refund(request).await
}
pub async fn refund_with_reason(
&self,
billable_id: &str,
charge_id: &str,
amount: Option<i64>,
reason: RefundReason,
) -> Result<Refund> {
self.verify_charge_ownership(billable_id, charge_id).await?;
let mut request = CreateRefundRequest::for_charge(charge_id).with_reason(reason);
if let Some(amt) = amount {
request = request.with_amount(amt);
}
self.client.create_refund(request).await
}
pub async fn get_refund(&self, refund_id: &str) -> Result<Refund> {
self.client.get_refund(refund_id).await
}
pub async fn list_refunds_for_charge(
&self,
billable_id: &str,
charge_id: &str,
limit: u8,
) -> Result<Vec<Refund>> {
self.verify_charge_ownership(billable_id, charge_id).await?;
self.client.list_refunds(charge_id, limit).await
}
}
#[cfg(any(test, feature = "test-billing"))]
pub mod test {
    //! In-memory [`StripeRefundClient`] for tests: refunds and owner mappings live in
    //! lock-protected maps, and no network calls are made.
    use super::*;
    use std::collections::HashMap;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::sync::{Arc, RwLock};

    /// Shared, lock-protected map used for each piece of mock state.
    type Shared<K, V> = Arc<RwLock<HashMap<K, V>>>;

    /// Mock refund client. Every created refund succeeds immediately.
    pub struct MockStripeRefundClient {
        /// Refund ID -> refund.
        refunds: Shared<String, Refund>,
        /// Charge ID -> refund IDs created against it, in creation order.
        charge_refunds: Shared<String, Vec<String>>,
        /// Charge ID -> owning customer ID (populated via `add_charge`).
        charge_customers: Shared<String, String>,
        /// Payment intent ID -> owning customer ID (populated via `add_payment_intent`).
        payment_intent_customers: Shared<String, String>,
        /// Source of the numeric suffix in generated `re_mock_<n>` IDs.
        refund_counter: Arc<AtomicU64>,
        /// Currency stamped on every created refund.
        pub default_currency: String,
    }

    impl Default for MockStripeRefundClient {
        fn default() -> Self {
            Self {
                refunds: Arc::default(),
                charge_refunds: Arc::default(),
                charge_customers: Arc::default(),
                payment_intent_customers: Arc::default(),
                refund_counter: Arc::default(),
                default_currency: "gbp".to_string(),
            }
        }
    }

    impl MockStripeRefundClient {
        /// Creates an empty mock whose refunds are in `gbp`.
        #[must_use]
        pub fn new() -> Self {
            Self::default()
        }

        /// Creates an empty mock whose refunds use `currency`, lowercased.
        #[must_use]
        pub fn with_currency(currency: impl Into<String>) -> Self {
            let currency: String = currency.into();
            Self {
                default_currency: currency.to_lowercase(),
                ..Self::default()
            }
        }

        /// Registers `charge_id` as owned by `customer_id`.
        pub fn add_charge(&self, charge_id: &str, customer_id: &str) {
            let mut owners = self.charge_customers.write().unwrap();
            owners.insert(charge_id.to_string(), customer_id.to_string());
        }

        /// Registers `payment_intent_id` as owned by `customer_id`.
        pub fn add_payment_intent(&self, payment_intent_id: &str, customer_id: &str) {
            let mut owners = self.payment_intent_customers.write().unwrap();
            owners.insert(payment_intent_id.to_string(), customer_id.to_string());
        }
    }

    impl StripeRefundClient for MockStripeRefundClient {
        /// Records a `Succeeded` refund, defaulting the amount to 1000 when none is given.
        async fn create_refund(&self, request: CreateRefundRequest) -> Result<Refund> {
            let sequence = self.refund_counter.fetch_add(1, Ordering::SeqCst);
            let id = format!("re_mock_{}", sequence);
            let created = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map_or(0, |elapsed| elapsed.as_secs());
            let CreateRefundRequest {
                charge_id,
                payment_intent_id,
                amount,
                reason,
            } = request;
            let refund = Refund {
                id: id.clone(),
                amount: amount.unwrap_or(1000),
                currency: self.default_currency.clone(),
                status: RefundStatus::Succeeded,
                reason,
                created,
                charge_id: charge_id.clone(),
                payment_intent_id,
            };
            self.refunds
                .write()
                .unwrap()
                .insert(id.clone(), refund.clone());
            if let Some(charge) = charge_id {
                let mut index = self.charge_refunds.write().unwrap();
                index.entry(charge).or_default().push(id);
            }
            Ok(refund)
        }

        async fn get_refund(&self, refund_id: &str) -> Result<Refund> {
            let found = self.refunds.read().unwrap().get(refund_id).cloned();
            match found {
                Some(refund) => Ok(refund),
                None => Err(super::super::error::BillingError::RefundNotFound {
                    refund_id: refund_id.to_string(),
                }
                .into()),
            }
        }

        /// Returns up to `limit` refunds for `charge_id`, oldest first.
        async fn list_refunds(&self, charge_id: &str, limit: u8) -> Result<Vec<Refund>> {
            let index = self.charge_refunds.read().unwrap();
            let by_id = self.refunds.read().unwrap();
            let Some(ids) = index.get(charge_id) else {
                return Ok(Vec::new());
            };
            Ok(ids
                .iter()
                .take(usize::from(limit))
                .filter_map(|id| by_id.get(id).cloned())
                .collect())
        }

        async fn get_charge_customer_id(&self, charge_id: &str) -> Result<String> {
            let owner = self.charge_customers.read().unwrap().get(charge_id).cloned();
            match owner {
                Some(customer_id) => Ok(customer_id),
                None => Err(super::super::error::BillingError::ChargeNotFound {
                    charge_id: charge_id.to_string(),
                }
                .into()),
            }
        }

        async fn get_payment_intent_customer_id(&self, payment_intent_id: &str) -> Result<String> {
            let owner = self
                .payment_intent_customers
                .read()
                .unwrap()
                .get(payment_intent_id)
                .cloned();
            match owner {
                Some(customer_id) => Ok(customer_id),
                None => Err(super::super::error::BillingError::RefundFailed {
                    reason: format!("Payment intent not found: {}", payment_intent_id),
                }
                .into()),
            }
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for refund parsing, the request builder, `RefundManager`, and the ownership
    //! checks in `SecureRefundManager`, all backed by in-memory test doubles.
    use super::test::MockStripeRefundClient;
    use super::*;
    use crate::billing::storage::test::InMemoryBillingStore;
    use crate::billing::storage::{StoredSubscription, SubscriptionStatus};

    /// A `RefundManager` wired to a brand-new mock client.
    fn plain_manager() -> RefundManager<MockStripeRefundClient> {
        RefundManager::new(MockStripeRefundClient::new())
    }

    /// An active `starter` subscription belonging to Stripe customer `cus_123`.
    fn create_test_subscription() -> StoredSubscription {
        StoredSubscription {
            stripe_subscription_id: "sub_123".to_string(),
            stripe_customer_id: "cus_123".to_string(),
            plan_id: "starter".to_string(),
            status: SubscriptionStatus::Active,
            current_period_start: 1700000000,
            current_period_end: 1702592000,
            extra_seats: 0,
            trial_end: None,
            cancel_at_period_end: false,
            base_item_id: None,
            seat_item_id: None,
            updated_at: 1700000000,
        }
    }

    /// An in-memory store where `org_123` owns the subscription from `create_test_subscription`.
    async fn store_with_subscription() -> InMemoryBillingStore {
        let store = InMemoryBillingStore::new();
        store
            .save_subscription("org_123", &create_test_subscription())
            .await
            .unwrap();
        store
    }

    #[tokio::test]
    async fn test_refund_charge() {
        let refund = plain_manager().refund_charge("ch_123", None).await.unwrap();
        assert!(refund.id.starts_with("re_mock_"));
        assert_eq!(refund.status, RefundStatus::Succeeded);
        assert_eq!(refund.charge_id.as_deref(), Some("ch_123"));
    }

    #[tokio::test]
    async fn test_refund_partial() {
        let refund = plain_manager()
            .refund_charge("ch_123", Some(500))
            .await
            .unwrap();
        assert_eq!(refund.amount, 500);
    }

    #[tokio::test]
    async fn test_refund_payment_intent() {
        let refund = plain_manager()
            .refund_payment_intent("pi_123", None)
            .await
            .unwrap();
        assert_eq!(refund.payment_intent_id.as_deref(), Some("pi_123"));
    }

    #[tokio::test]
    async fn test_refund_with_reason() {
        let refund = plain_manager()
            .refund_with_reason("ch_123", None, RefundReason::Duplicate)
            .await
            .unwrap();
        assert_eq!(refund.reason, Some(RefundReason::Duplicate));
    }

    #[tokio::test]
    async fn test_get_refund() {
        let manager = plain_manager();
        let created = manager.refund_charge("ch_123", None).await.unwrap();
        let fetched = manager.get_refund(&created.id).await.unwrap();
        assert_eq!(fetched.id, created.id);
    }

    #[tokio::test]
    async fn test_list_refunds_for_charge() {
        let manager = plain_manager();
        for (charge, amount) in [("ch_123", 100), ("ch_123", 200), ("ch_456", 300)] {
            manager.refund_charge(charge, Some(amount)).await.unwrap();
        }
        let listed = manager.list_refunds_for_charge("ch_123", 10).await.unwrap();
        assert_eq!(listed.len(), 2);
    }

    #[test]
    fn test_refund_status() {
        let cases = [
            ("pending", RefundStatus::Pending),
            ("succeeded", RefundStatus::Succeeded),
            ("failed", RefundStatus::Failed),
            ("canceled", RefundStatus::Canceled),
            ("unknown", RefundStatus::Pending),
        ];
        for (raw, expected) in cases {
            assert_eq!(RefundStatus::from_stripe(raw), expected, "status {raw:?}");
        }
    }

    #[test]
    fn test_refund_reason() {
        let wire_values = [
            (RefundReason::Duplicate, "duplicate"),
            (RefundReason::Fraudulent, "fraudulent"),
            (RefundReason::RequestedByCustomer, "requested_by_customer"),
        ];
        for (reason, wire) in wire_values {
            assert_eq!(reason.as_str(), wire);
        }
        assert_eq!(
            RefundReason::from_stripe("duplicate"),
            Some(RefundReason::Duplicate)
        );
        assert_eq!(RefundReason::from_stripe("unknown"), None);
    }

    #[test]
    fn test_create_refund_request_builder() {
        let request = CreateRefundRequest::for_charge("ch_123")
            .with_amount(500)
            .with_reason(RefundReason::RequestedByCustomer);
        assert_eq!(request.charge_id.as_deref(), Some("ch_123"));
        assert_eq!(request.amount, Some(500));
        assert_eq!(request.reason, Some(RefundReason::RequestedByCustomer));
    }

    #[tokio::test]
    async fn test_secure_refund_charge_with_valid_ownership() {
        let client = MockStripeRefundClient::new();
        client.add_charge("ch_test", "cus_123");
        let manager = SecureRefundManager::new(store_with_subscription().await, client);
        let refund = manager
            .refund_charge("org_123", "ch_test", None)
            .await
            .unwrap();
        assert!(refund.id.starts_with("re_mock_"));
    }

    #[tokio::test]
    async fn test_secure_refund_charge_with_invalid_ownership() {
        let client = MockStripeRefundClient::new();
        client.add_charge("ch_other", "cus_other");
        let manager = SecureRefundManager::new(store_with_subscription().await, client);
        let outcome = manager.refund_charge("org_123", "ch_other", None).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_secure_refund_charge_no_subscription() {
        let manager =
            SecureRefundManager::new(InMemoryBillingStore::new(), MockStripeRefundClient::new());
        let outcome = manager.refund_charge("nonexistent", "ch_test", None).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_secure_refund_payment_intent_with_valid_ownership() {
        let client = MockStripeRefundClient::new();
        client.add_payment_intent("pi_test", "cus_123");
        let manager = SecureRefundManager::new(store_with_subscription().await, client);
        let refund = manager
            .refund_payment_intent("org_123", "pi_test", None)
            .await
            .unwrap();
        assert!(refund.id.starts_with("re_mock_"));
    }

    #[tokio::test]
    async fn test_secure_refund_payment_intent_with_invalid_ownership() {
        let client = MockStripeRefundClient::new();
        client.add_payment_intent("pi_other", "cus_other");
        let manager = SecureRefundManager::new(store_with_subscription().await, client);
        let outcome = manager
            .refund_payment_intent("org_123", "pi_other", None)
            .await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_secure_list_refunds_with_ownership() {
        let client = MockStripeRefundClient::new();
        client.add_charge("ch_test", "cus_123");
        let manager = SecureRefundManager::new(store_with_subscription().await, client);
        manager
            .refund_charge("org_123", "ch_test", Some(100))
            .await
            .unwrap();
        let listed = manager
            .list_refunds_for_charge("org_123", "ch_test", 10)
            .await
            .unwrap();
        assert_eq!(listed.len(), 1);
    }
}