use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use super::error::BillingError;
use super::storage::BillingStore;
use crate::error::Result;
/// Shortens an invoice ID for inclusion in an error value.
///
/// IDs longer than 10 characters are cut to their first 10 characters followed by
/// `...`; shorter IDs are returned unchanged.
///
/// Truncation counts `char`s rather than bytes. The previous byte-based slice
/// (`&id[..10]`) panicked whenever byte 10 fell inside a multi-byte character of
/// caller-supplied input.
fn sanitize_invoice_id(invoice_id: &str) -> String {
    const MAX_CHARS: usize = 10;
    // The byte offset of the 11th character, if there is one, is exactly where the
    // first 10 characters end — always a valid char boundary.
    match invoice_id.char_indices().nth(MAX_CHARS) {
        Some((cut, _)) => format!("{}...", &invoice_id[..cut]),
        None => invoice_id.to_string(),
    }
}
/// Tunables for `InvoiceManager`.
///
/// The `with_*` builder methods clamp numeric values into `1..=100`. The fields are
/// public, so values assigned directly bypass that clamping.
#[derive(Debug, Clone)]
pub struct InvoiceConfig {
    /// Page size used by `list_invoices` when the caller supplies none.
    pub default_limit: u8,
    /// Whether `get_invoice` fetches line items for invoices returned without them.
    pub include_line_items: bool,
    /// Line-item page size used by `get_invoice`, and by `get_invoice_line_items`
    /// when the caller supplies no limit.
    pub max_line_items: u8,
    /// Status filter applied by `list_invoices` when the caller supplies none.
    pub default_status_filter: Option<InvoiceStatus>,
}

impl Default for InvoiceConfig {
    /// `default_limit = 10`, no line items, `max_line_items = 100`, no status filter.
    fn default() -> Self {
        Self {
            default_status_filter: None,
            include_line_items: false,
            max_line_items: 100,
            default_limit: 10,
        }
    }
}

impl InvoiceConfig {
    /// Same as [`InvoiceConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets `default_limit`, clamped into `1..=100`.
    #[must_use]
    pub fn with_default_limit(self, limit: u8) -> Self {
        Self {
            default_limit: limit.clamp(1, 100),
            ..self
        }
    }

    /// Sets whether `get_invoice` should attach line items.
    #[must_use]
    pub fn with_line_items(self, include: bool) -> Self {
        Self {
            include_line_items: include,
            ..self
        }
    }

    /// Sets `max_line_items`, clamped into `1..=100`.
    #[must_use]
    pub fn with_max_line_items(self, max: u8) -> Self {
        Self {
            max_line_items: max.clamp(1, 100),
            ..self
        }
    }

    /// Sets the status filter used when a list call supplies none.
    #[must_use]
    pub fn with_status_filter(self, status: InvoiceStatus) -> Self {
        Self {
            default_status_filter: Some(status),
            ..self
        }
    }
}
/// Status of an invoice.
///
/// Serialized in snake_case (`"draft"`, `"open"`, `"paid"`, `"uncollectible"`,
/// `"void"`). These are the same spellings returned by `as_str` and accepted by
/// `FromStr`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum InvoiceStatus {
Draft,
Open,
Paid,
Uncollectible,
Void,
}
impl InvoiceStatus {
#[must_use]
pub fn as_str(&self) -> &'static str {
match self {
Self::Draft => "draft",
Self::Open => "open",
Self::Paid => "paid",
Self::Uncollectible => "uncollectible",
Self::Void => "void",
}
}
}
impl std::str::FromStr for InvoiceStatus {
type Err = InvoiceStatusParseError;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
match s {
"draft" => Ok(Self::Draft),
"open" => Ok(Self::Open),
"paid" => Ok(Self::Paid),
"uncollectible" => Ok(Self::Uncollectible),
"void" => Ok(Self::Void),
_ => Err(InvoiceStatusParseError(s.to_string())),
}
}
}
/// Error returned when a string is not a recognised invoice status name.
///
/// Holds the rejected input verbatim.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InvoiceStatusParseError(pub String);

impl std::fmt::Display for InvoiceStatusParseError {
    /// Renders as `invalid invoice status: '<input>'`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(raw) = self;
        write!(f, "invalid invoice status: '{raw}'")
    }
}

impl std::error::Error for InvoiceStatusParseError {}
impl std::fmt::Display for InvoiceStatus {
    /// Writes the lowercase wire name from `as_str`.
    ///
    /// Width and alignment flags are ignored.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
/// An invoice as returned by a [`StripeInvoiceClient`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Invoice {
/// Invoice ID; the value `get_invoice` looks invoices up by.
pub id: String,
/// Customer that owns the invoice. `InvoiceManager` compares it with the customer
/// stored for the billable to enforce ownership.
pub customer_id: String,
pub subscription_id: Option<String>,
pub status: InvoiceStatus,
// NOTE(review): amounts are presumably in the currency's minor unit (Stripe
// convention, e.g. 2999 = 29.99) — confirm against the real client.
pub amount_due: i64,
pub amount_paid: i64,
pub amount_remaining: i64,
/// Currency code; the mock fixtures store it lowercased (e.g. `"gbp"`).
pub currency: String,
// NOTE(review): the mock fixtures fill these timestamps with Unix seconds;
// presumably the real client does the same — confirm.
pub created: u64,
pub due_date: Option<u64>,
pub period_start: u64,
pub period_end: u64,
pub invoice_pdf: Option<String>,
pub hosted_invoice_url: Option<String>,
pub number: Option<String>,
pub paid: bool,
/// Line items, when loaded.
///
/// `InvoiceManager::get_invoice` fills this in when
/// `InvoiceConfig::include_line_items` is set. The field is omitted from
/// serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub line_items: Option<Vec<InvoiceLineItem>>,
}
/// A single line of an invoice, as returned by
/// `StripeInvoiceClient::list_invoice_line_items`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InvoiceLineItem {
pub id: String,
pub description: Option<String>,
// NOTE(review): presumably in the currency's minor unit, like `Invoice::amount_due`.
pub amount: i64,
pub currency: String,
pub quantity: Option<u32>,
pub price_id: Option<String>,
// NOTE(review): Unix seconds in the mock fixtures — confirm for the real client.
pub period_start: u64,
pub period_end: u64,
}
/// Per-call options for `list_invoices`.
#[derive(Debug, Clone, Default)]
pub struct InvoiceListParams {
/// Page size; `None` means `InvoiceConfig::default_limit`.
pub limit: Option<u8>,
/// Pagination cursor passed through to the client. The mock client treats it as an
/// invoice ID and returns the invoices that follow it.
pub starting_after: Option<String>,
/// Status filter; `None` means `InvoiceConfig::default_status_filter`.
pub status: Option<InvoiceStatus>,
/// Skip the cache lookup. Only read by `CachedInvoiceManager::list_invoices`;
/// `InvoiceManager` ignores it.
pub force_refresh: bool,
}
/// One page of invoices.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InvoiceList {
pub invoices: Vec<Invoice>,
/// Whether more invoices exist after this page.
pub has_more: bool,
/// Cursor to pass as `starting_after` to fetch the next page.
// NOTE(review): the mock sets this to the last returned invoice's ID even when
// `has_more` is false — confirm the real client's convention.
pub next_cursor: Option<String>,
}
/// Low-level invoice access keyed by Stripe customer and invoice IDs.
///
/// The methods take no billable ID. `InvoiceManager` does the ownership check
/// itself, by comparing `Invoice::customer_id` with the stored customer.
#[async_trait]
pub trait StripeInvoiceClient: Send + Sync {
/// Returns one page of `customer_id`'s invoices, optionally filtered by `status`,
/// continuing after the `starting_after` cursor when given.
async fn list_invoices(
&self,
customer_id: &str,
limit: u8,
starting_after: Option<&str>,
status: Option<InvoiceStatus>,
) -> Result<InvoiceList>;
/// Fetches an invoice by ID. No customer is supplied, so ownership is not checked.
async fn get_invoice(&self, invoice_id: &str) -> Result<Invoice>;
/// Returns the customer's upcoming invoice, or `None` if there is none.
async fn get_upcoming_invoice(&self, customer_id: &str) -> Result<Option<Invoice>>;
/// Returns line items for `invoice_id`.
// NOTE(review): presumably at most `limit` items; the mock ignores `limit`.
async fn list_invoice_line_items(
&self,
invoice_id: &str,
limit: u8,
) -> Result<Vec<InvoiceLineItem>>;
}
/// High-level invoice operations keyed by the application's billable ID.
///
/// Implemented in this file by `InvoiceManager` and `CachedInvoiceManager`. Both
/// resolve the billable's Stripe customer and check invoice ownership; the cached
/// variant does so by delegating to its inner manager.
#[async_trait]
pub trait InvoiceOperations: Send + Sync {
/// Lists one page of the billable's invoices.
async fn list_invoices(
&self,
billable_id: &str,
params: InvoiceListParams,
) -> Result<InvoiceList>;
/// Fetches one invoice, failing if it belongs to a different customer.
async fn get_invoice(&self, billable_id: &str, invoice_id: &str) -> Result<Invoice>;
/// Fetches the billable's upcoming invoice, if any.
async fn get_upcoming_invoice(&self, billable_id: &str) -> Result<Option<Invoice>>;
/// Lists an invoice's line items, failing if it belongs to a different customer.
async fn get_invoice_line_items(
&self,
billable_id: &str,
invoice_id: &str,
limit: Option<u8>,
) -> Result<Vec<InvoiceLineItem>>;
}
/// Invoice operations for a billable.
///
/// Combines a [`BillingStore`] (used to map the billable ID to a Stripe customer ID)
/// with a [`StripeInvoiceClient`] and an [`InvoiceConfig`].
pub struct InvoiceManager<S, C> {
store: S,
client: C,
config: InvoiceConfig,
}
impl<S: Clone, C: Clone> Clone for InvoiceManager<S, C> {
    /// Clones the store, the client and the config, in that order.
    ///
    /// Only `S: Clone` and `C: Clone` are required, not the `BillingStore` or
    /// `StripeInvoiceClient` bounds.
    fn clone(&self) -> Self {
        let Self {
            store,
            client,
            config,
        } = self;
        Self {
            store: store.clone(),
            client: client.clone(),
            config: config.clone(),
        }
    }
}
impl<S: BillingStore, C: StripeInvoiceClient> InvoiceManager<S, C> {
    /// Smallest page size ever forwarded to the Stripe client.
    const MIN_PAGE_LIMIT: u8 = 1;
    /// Largest page size ever forwarded to the Stripe client.
    ///
    /// Matches the range the `InvoiceConfig` builder methods clamp to.
    const MAX_PAGE_LIMIT: u8 = 100;

    /// Creates a manager using [`InvoiceConfig::default`].
    #[must_use]
    pub fn new(store: S, client: C) -> Self {
        Self::with_config(store, client, InvoiceConfig::default())
    }

    /// Creates a manager with an explicit configuration.
    #[must_use]
    pub fn with_config(store: S, client: C, config: InvoiceConfig) -> Self {
        Self {
            store,
            client,
            config,
        }
    }

    /// Returns the configuration this manager was built with.
    #[must_use]
    pub fn config(&self) -> &InvoiceConfig {
        &self.config
    }

    /// Lists invoices for the Stripe customer linked to `billable_id`.
    ///
    /// A missing `params.limit` falls back to `config.default_limit`. A missing
    /// `params.status` falls back to `config.default_status_filter`.
    ///
    /// The effective limit is clamped into `1..=100` before it reaches the client.
    /// `params.limit` and the public config field would otherwise bypass the
    /// builder's clamping. `params.force_refresh` is not consulted here.
    ///
    /// # Errors
    ///
    /// Fails with a `BillingError::NoCustomer` error when no Stripe customer is
    /// stored for `billable_id`. Store and client errors are propagated.
    pub async fn list_invoices(
        &self,
        billable_id: &str,
        params: InvoiceListParams,
    ) -> Result<InvoiceList> {
        tracing::debug!(
            billable_id = %billable_id,
            limit = ?params.limit,
            status = ?params.status,
            "listing invoices"
        );
        let customer_id = self.get_customer_id(billable_id).await?;
        let limit = Self::clamp_limit(params.limit.unwrap_or(self.config.default_limit));
        let status = params.status.or(self.config.default_status_filter);
        let result = self
            .client
            .list_invoices(
                &customer_id,
                limit,
                params.starting_after.as_deref(),
                status,
            )
            .await?;
        tracing::debug!(
            billable_id = %billable_id,
            count = result.invoices.len(),
            has_more = result.has_more,
            "listed invoices"
        );
        Ok(result)
    }

    /// Fetches one invoice after checking that it belongs to `billable_id`'s customer.
    ///
    /// When `config.include_line_items` is set and the client returned no line items,
    /// they are fetched (page size `config.max_line_items`, clamped into `1..=100`)
    /// and attached.
    ///
    /// # Errors
    ///
    /// - A `BillingError::NoCustomer` error when no customer is stored for the billable.
    /// - A `BillingError::InvoiceNotFound` error, with a shortened ID, when the
    ///   invoice belongs to another customer.
    /// - Store and client errors are propagated.
    pub async fn get_invoice(&self, billable_id: &str, invoice_id: &str) -> Result<Invoice> {
        tracing::debug!(
            billable_id = %billable_id,
            invoice_id = %invoice_id,
            "fetching invoice"
        );
        let customer_id = self.get_customer_id(billable_id).await?;
        let mut invoice = self.client.get_invoice(invoice_id).await?;
        if invoice.customer_id != customer_id {
            tracing::warn!(
                billable_id = %billable_id,
                invoice_id = %invoice_id,
                "invoice ownership verification failed"
            );
            // Reported as "not found", presumably so callers cannot probe for other
            // customers' invoice IDs.
            return Err(BillingError::InvoiceNotFound {
                invoice_id: sanitize_invoice_id(invoice_id),
            }
            .into());
        }
        if self.config.include_line_items && invoice.line_items.is_none() {
            tracing::debug!(invoice_id = %invoice_id, "fetching line items");
            let line_items = self
                .client
                .list_invoice_line_items(
                    invoice_id,
                    Self::clamp_limit(self.config.max_line_items),
                )
                .await?;
            invoice.line_items = Some(line_items);
        }
        tracing::debug!(
            billable_id = %billable_id,
            invoice_id = %invoice_id,
            status = %invoice.status,
            "fetched invoice"
        );
        Ok(invoice)
    }

    /// Fetches the upcoming invoice for `billable_id`'s customer, if any.
    ///
    /// # Errors
    ///
    /// Fails with a `BillingError::NoCustomer` error when no customer is stored.
    /// Store and client errors are propagated.
    pub async fn get_upcoming_invoice(&self, billable_id: &str) -> Result<Option<Invoice>> {
        tracing::debug!(billable_id = %billable_id, "fetching upcoming invoice");
        let customer_id = self.get_customer_id(billable_id).await?;
        self.client.get_upcoming_invoice(&customer_id).await
    }

    /// Lists an invoice's line items after checking that it belongs to
    /// `billable_id`'s customer.
    ///
    /// `limit` defaults to `config.max_line_items`. The effective value is clamped
    /// into `1..=100`.
    ///
    /// # Errors
    ///
    /// Same as [`InvoiceManager::get_invoice`].
    pub async fn get_invoice_line_items(
        &self,
        billable_id: &str,
        invoice_id: &str,
        limit: Option<u8>,
    ) -> Result<Vec<InvoiceLineItem>> {
        tracing::debug!(
            billable_id = %billable_id,
            invoice_id = %invoice_id,
            "fetching invoice line items"
        );
        let customer_id = self.get_customer_id(billable_id).await?;
        let invoice = self.client.get_invoice(invoice_id).await?;
        if invoice.customer_id != customer_id {
            tracing::warn!(
                billable_id = %billable_id,
                invoice_id = %invoice_id,
                "invoice ownership verification failed for line items"
            );
            return Err(BillingError::InvoiceNotFound {
                invoice_id: sanitize_invoice_id(invoice_id),
            }
            .into());
        }
        let limit = Self::clamp_limit(limit.unwrap_or(self.config.max_line_items));
        self.client.list_invoice_line_items(invoice_id, limit).await
    }

    /// Clamps a page size into `MIN_PAGE_LIMIT..=MAX_PAGE_LIMIT`.
    fn clamp_limit(limit: u8) -> u8 {
        limit.clamp(Self::MIN_PAGE_LIMIT, Self::MAX_PAGE_LIMIT)
    }

    /// Looks up the Stripe customer ID stored for `billable_id`.
    ///
    /// Maps a missing mapping to a `BillingError::NoCustomer` error.
    async fn get_customer_id(&self, billable_id: &str) -> Result<String> {
        self.store
            .get_stripe_customer_id(billable_id)
            .await?
            .ok_or_else(|| {
                BillingError::NoCustomer {
                    billable_id: billable_id.to_string(),
                }
                .into()
            })
    }
}
/// Exposes [`InvoiceManager`] through the [`InvoiceOperations`] trait.
///
/// Every method forwards to the inherent method of the same name.
#[async_trait]
impl<S: BillingStore, C: StripeInvoiceClient> InvoiceOperations for InvoiceManager<S, C> {
// `self.list_invoices(..)` and the other calls below resolve to the inherent
// methods, which take precedence over trait methods in method-call resolution.
// So these calls do not recurse into this impl.
async fn list_invoices(
&self,
billable_id: &str,
params: InvoiceListParams,
) -> Result<InvoiceList> {
self.list_invoices(billable_id, params).await
}
async fn get_invoice(&self, billable_id: &str, invoice_id: &str) -> Result<Invoice> {
self.get_invoice(billable_id, invoice_id).await
}
async fn get_upcoming_invoice(&self, billable_id: &str) -> Result<Option<Invoice>> {
self.get_upcoming_invoice(billable_id).await
}
async fn get_invoice_line_items(
&self,
billable_id: &str,
invoice_id: &str,
limit: Option<u8>,
) -> Result<Vec<InvoiceLineItem>> {
self.get_invoice_line_items(billable_id, invoice_id, limit)
.await
}
}
/// Entry cap (list pages plus upcoming invoices) used by `CachedInvoiceManager::new`.
const DEFAULT_MAX_CACHE_ENTRIES: usize = 1000;
/// Number of cached calls between eviction passes (see `maybe_cleanup`).
const CLEANUP_INTERVAL: u64 = 100;
/// Cache key identifying one `list_invoices` page for one billable.
///
/// Built from the caller's raw `InvoiceListParams`, before `InvoiceManager` applies
/// config defaults.
///
/// `limit` stays an `Option`. An explicit `Some(0)` and `None` can resolve to
/// different effective page sizes, so they must not share a cache slot. The previous
/// `unwrap_or(0)` sentinel merged them.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct ListCacheKey {
    billable_id: String,
    limit: Option<u8>,
    starting_after: Option<String>,
    status: Option<InvoiceStatus>,
}

impl ListCacheKey {
    /// Builds the key for `billable_id` and `params`.
    ///
    /// `force_refresh` is not part of the key. A forced refresh therefore writes to
    /// the same slot that normal calls read.
    fn new(billable_id: &str, params: &InvoiceListParams) -> Self {
        Self {
            billable_id: billable_id.to_owned(),
            limit: params.limit,
            starting_after: params.starting_after.clone(),
            status: params.status,
        }
    }

    /// Returns whether this key belongs to `billable_id`.
    fn matches_billable_id(&self, billable_id: &str) -> bool {
        self.billable_id == billable_id
    }
}
/// [`InvoiceManager`] wrapper that caches `list_invoices` pages and upcoming invoices
/// for a fixed TTL.
///
/// `get_invoice` and `get_invoice_line_items` are passed straight through uncached.
pub struct CachedInvoiceManager<S: BillingStore, C: StripeInvoiceClient> {
inner: InvoiceManager<S, C>,
// All cached data, behind a single async RwLock.
cache: std::sync::Arc<tokio::sync::RwLock<InvoiceCache>>,
// How long an entry remains servable after it is inserted.
ttl: std::time::Duration,
// Soft cap on total entries, enforced by periodic eviction passes.
max_entries: usize,
// Counts calls to `maybe_cleanup` to decide when to run an eviction pass.
operation_counter: std::sync::atomic::AtomicU64,
}
/// Storage behind `CachedInvoiceManager`.
struct InvoiceCache {
// List pages, keyed by billable ID plus the raw list parameters.
lists: std::collections::HashMap<ListCacheKey, CacheEntry<std::sync::Arc<InvoiceList>>>,
// Upcoming invoice (a cached `None` included), keyed by billable ID.
upcoming: std::collections::HashMap<String, CacheEntry<Option<Invoice>>>,
}
/// A cached value plus the timestamps used for expiry and eviction.
struct CacheEntry<T> {
data: T,
// The entry is served only while `Instant::now()` is earlier than this.
expires_at: std::time::Instant,
// Used to choose eviction victims. In this file it is written only when the entry
// is inserted; cache hits do not refresh it.
last_accessed: std::time::Instant,
}
impl<S: BillingStore, C: StripeInvoiceClient> CachedInvoiceManager<S, C> {
#[must_use]
pub fn new(inner: InvoiceManager<S, C>, ttl: std::time::Duration) -> Self {
Self::with_max_entries(inner, ttl, DEFAULT_MAX_CACHE_ENTRIES)
}
#[must_use]
pub fn with_max_entries(
inner: InvoiceManager<S, C>,
ttl: std::time::Duration,
max_entries: usize,
) -> Self {
Self {
inner,
cache: std::sync::Arc::new(tokio::sync::RwLock::new(InvoiceCache {
lists: std::collections::HashMap::new(),
upcoming: std::collections::HashMap::new(),
})),
ttl,
max_entries,
operation_counter: std::sync::atomic::AtomicU64::new(0),
}
}
#[must_use]
pub fn config(&self) -> &InvoiceConfig {
self.inner.config()
}
pub async fn invalidate(&self, billable_id: &str) {
tracing::debug!(billable_id = %billable_id, "invalidating invoice cache");
let mut cache = self.cache.write().await;
cache
.lists
.retain(|k, _| !k.matches_billable_id(billable_id));
cache.upcoming.remove(billable_id);
}
pub async fn clear(&self) {
tracing::debug!("clearing entire invoice cache");
let mut cache = self.cache.write().await;
cache.lists.clear();
cache.upcoming.clear();
}
#[must_use]
pub fn cache_size(&self) -> usize {
self.cache
.try_read()
.map_or(0, |c| c.lists.len() + c.upcoming.len())
}
async fn enforce_max_entries(&self) {
let mut cache = self.cache.write().await;
let total = cache.lists.len() + cache.upcoming.len();
if total <= self.max_entries {
return;
}
let now = std::time::Instant::now();
cache.lists.retain(|_, v| v.expires_at > now);
cache.upcoming.retain(|_, v| v.expires_at > now);
let total = cache.lists.len() + cache.upcoming.len();
if total > self.max_entries {
let to_remove = total - self.max_entries;
let mut removed = 0;
let sample_size = (to_remove * 5).min(cache.lists.len());
if sample_size > 0 && !cache.lists.is_empty() {
let sample: Vec<_> = cache
.lists
.iter()
.take(sample_size)
.map(|(k, v)| (k.clone(), v.last_accessed))
.collect();
let mut sample = sample;
sample.sort_by_key(|(_, t)| *t);
for (key, _) in sample.into_iter().take(to_remove) {
if cache.lists.remove(&key).is_some() {
removed += 1;
}
if removed >= to_remove {
break;
}
}
}
if removed > 0 {
tracing::debug!(removed = removed, "evicted cache entries via sampling");
}
}
}
pub async fn list_invoices(
&self,
billable_id: &str,
params: InvoiceListParams,
) -> Result<InvoiceList> {
self.maybe_cleanup().await;
let cache_key = ListCacheKey::new(billable_id, ¶ms);
if !params.force_refresh {
let cache = self.cache.read().await;
if let Some(entry) = cache.lists.get(&cache_key) {
if entry.expires_at > std::time::Instant::now() {
tracing::debug!(billable_id = %billable_id, "invoice list cache hit");
return Ok((*entry.data).clone());
}
}
} else {
tracing::debug!(billable_id = %billable_id, "force refresh requested");
}
let result = self.inner.list_invoices(billable_id, params).await?;
{
let mut cache = self.cache.write().await;
let now = std::time::Instant::now();
cache.lists.insert(
cache_key,
CacheEntry {
data: std::sync::Arc::new(result.clone()),
expires_at: now + self.ttl,
last_accessed: now,
},
);
}
Ok(result)
}
pub async fn get_invoice(&self, billable_id: &str, invoice_id: &str) -> Result<Invoice> {
self.inner.get_invoice(billable_id, invoice_id).await
}
pub async fn get_upcoming_invoice(&self, billable_id: &str) -> Result<Option<Invoice>> {
self.maybe_cleanup().await;
{
let cache = self.cache.read().await;
if let Some(entry) = cache.upcoming.get(billable_id) {
if entry.expires_at > std::time::Instant::now() {
tracing::debug!(billable_id = %billable_id, "upcoming invoice cache hit");
return Ok(entry.data.clone());
}
}
}
let result = self.inner.get_upcoming_invoice(billable_id).await?;
{
let mut cache = self.cache.write().await;
let now = std::time::Instant::now();
cache.upcoming.insert(
billable_id.to_string(),
CacheEntry {
data: result.clone(),
expires_at: now + self.ttl,
last_accessed: now,
},
);
}
Ok(result)
}
pub async fn get_invoice_line_items(
&self,
billable_id: &str,
invoice_id: &str,
limit: Option<u8>,
) -> Result<Vec<InvoiceLineItem>> {
self.inner
.get_invoice_line_items(billable_id, invoice_id, limit)
.await
}
async fn maybe_cleanup(&self) {
let count = self
.operation_counter
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
if count % CLEANUP_INTERVAL == 0 {
self.enforce_max_entries().await;
}
}
}
/// Exposes [`CachedInvoiceManager`] through the [`InvoiceOperations`] trait.
///
/// Every method forwards to the inherent method of the same name.
#[async_trait]
impl<S: BillingStore, C: StripeInvoiceClient> InvoiceOperations for CachedInvoiceManager<S, C> {
// `self.list_invoices(..)` and the other calls below resolve to the inherent
// methods, which take precedence over trait methods in method-call resolution.
// So these calls do not recurse into this impl.
async fn list_invoices(
&self,
billable_id: &str,
params: InvoiceListParams,
) -> Result<InvoiceList> {
self.list_invoices(billable_id, params).await
}
async fn get_invoice(&self, billable_id: &str, invoice_id: &str) -> Result<Invoice> {
self.get_invoice(billable_id, invoice_id).await
}
async fn get_upcoming_invoice(&self, billable_id: &str) -> Result<Option<Invoice>> {
self.get_upcoming_invoice(billable_id).await
}
async fn get_invoice_line_items(
&self,
billable_id: &str,
invoice_id: &str,
limit: Option<u8>,
) -> Result<Vec<InvoiceLineItem>> {
self.get_invoice_line_items(billable_id, invoice_id, limit)
.await
}
}
/// Test doubles for this module.
///
/// Compiled for unit tests and whenever the `test-billing` feature is enabled.
#[cfg(any(test, feature = "test-billing"))]
pub mod test {
use super::*;
use std::sync::{Arc, RwLock};
/// In-memory [`StripeInvoiceClient`].
///
/// Clones share state: `invoices` and `upcoming` live behind `Arc`, so data added
/// through one clone is visible through every other clone.
#[derive(Clone)]
pub struct MockStripeInvoiceClient {
invoices: Arc<RwLock<Vec<Invoice>>>,
upcoming: Arc<RwLock<Option<Invoice>>>,
/// Currency stamped on the line items returned by `list_invoice_line_items`.
pub default_currency: String,
}
impl Default for MockStripeInvoiceClient {
/// No invoices, no upcoming invoice, currency `"gbp"`.
fn default() -> Self {
Self {
invoices: Arc::new(RwLock::new(Vec::new())),
upcoming: Arc::new(RwLock::new(None)),
default_currency: "gbp".to_string(),
}
}
}
impl MockStripeInvoiceClient {
/// Same as `MockStripeInvoiceClient::default()`.
#[must_use]
pub fn new() -> Self {
Self::default()
}
/// Creates an empty mock whose line items use `currency`, lowercased.
#[must_use]
pub fn with_currency(currency: impl Into<String>) -> Self {
Self {
default_currency: currency.into().to_lowercase(),
..Self::default()
}
}
/// Appends an invoice. Silently does nothing if the lock is poisoned.
pub fn add_invoice(&self, invoice: Invoice) {
if let Ok(mut invoices) = self.invoices.write() {
invoices.push(invoice);
}
}
/// Replaces the upcoming invoice returned for every customer.
/// Silently does nothing if the lock is poisoned.
pub fn set_upcoming(&self, invoice: Option<Invoice>) {
if let Ok(mut upcoming) = self.upcoming.write() {
*upcoming = invoice;
}
}
/// Builds a GBP test invoice.
///
/// Always uses `"gbp"`, independent of any client's `default_currency`.
#[must_use]
pub fn create_test_invoice(id: &str, customer_id: &str, status: InvoiceStatus) -> Invoice {
Self::create_test_invoice_with_currency(id, customer_id, status, "gbp")
}
/// Builds a test invoice for 2999 in `currency` (lowercased).
///
/// - The invoice is fully paid exactly when `status` is `Paid`, and unpaid otherwise.
/// - Timestamps are current Unix seconds.
/// - The due date and period end are 30 days later.
///
/// Panics if the system clock is before the Unix epoch.
#[must_use]
pub fn create_test_invoice_with_currency(
id: &str,
customer_id: &str,
status: InvoiceStatus,
currency: &str,
) -> Invoice {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs();
Invoice {
id: id.to_string(),
customer_id: customer_id.to_string(),
subscription_id: Some("sub_test".to_string()),
status,
amount_due: 2999,
amount_paid: if status == InvoiceStatus::Paid {
2999
} else {
0
},
amount_remaining: if status == InvoiceStatus::Paid {
0
} else {
2999
},
currency: currency.to_lowercase(),
created: now,
due_date: Some(now + 30 * 24 * 60 * 60),
period_start: now,
period_end: now + 30 * 24 * 60 * 60,
invoice_pdf: Some(format!("https://pay.stripe.com/invoice/{}/pdf", id)),
hosted_invoice_url: Some(format!("https://invoice.stripe.com/{}", id)),
number: Some(format!("INV-{}", id)),
paid: status == InvoiceStatus::Paid,
line_items: None,
}
}
}
#[async_trait]
impl StripeInvoiceClient for MockStripeInvoiceClient {
/// Filters stored invoices by customer and optional status, in insertion order.
///
/// - If `starting_after` names an invoice in the filtered set, listing resumes
///   after it. Otherwise the cursor is ignored and listing starts from the beginning.
/// - `next_cursor` is the last returned ID, even when `has_more` is false.
async fn list_invoices(
&self,
customer_id: &str,
limit: u8,
starting_after: Option<&str>,
status: Option<InvoiceStatus>,
) -> Result<InvoiceList> {
let invoices = self
.invoices
.read()
.map_err(|_| crate::error::TidewayError::Internal("Lock poisoned".to_string()))?;
let mut filtered: Vec<Invoice> = invoices
.iter()
.filter(|inv| inv.customer_id == customer_id)
.filter(|inv| status.is_none_or(|s| inv.status == s))
.cloned()
.collect();
if let Some(after) = starting_after {
if let Some(pos) = filtered.iter().position(|inv| inv.id == after) {
filtered = filtered.into_iter().skip(pos + 1).collect();
}
}
let limit = limit as usize;
let has_more = filtered.len() > limit;
let invoices: Vec<_> = filtered.into_iter().take(limit).collect();
let next_cursor = invoices.last().map(|inv| inv.id.clone());
Ok(InvoiceList {
invoices,
has_more,
next_cursor,
})
}
/// Finds an invoice by ID regardless of customer.
///
/// Fails with `InvoiceNotFound`, carrying the full, unshortened ID, when no invoice
/// has that ID.
async fn get_invoice(&self, invoice_id: &str) -> Result<Invoice> {
let invoices = self
.invoices
.read()
.map_err(|_| crate::error::TidewayError::Internal("Lock poisoned".to_string()))?;
invoices
.iter()
.find(|inv| inv.id == invoice_id)
.cloned()
.ok_or_else(|| {
BillingError::InvoiceNotFound {
invoice_id: invoice_id.to_string(),
}
.into()
})
}
/// Returns whatever `set_upcoming` stored. The customer ID is ignored.
async fn get_upcoming_invoice(&self, _customer_id: &str) -> Result<Option<Invoice>> {
let upcoming = self
.upcoming
.read()
.map_err(|_| crate::error::TidewayError::Internal("Lock poisoned".to_string()))?;
Ok(upcoming.clone())
}
/// Always returns one synthetic "Pro Plan" line item.
///
/// Works even for unknown invoice IDs, and `_limit` is ignored.
async fn list_invoice_line_items(
&self,
invoice_id: &str,
_limit: u8,
) -> Result<Vec<InvoiceLineItem>> {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs();
Ok(vec![InvoiceLineItem {
id: format!("il_{}_1", invoice_id),
description: Some("Pro Plan".to_string()),
amount: 2999,
currency: self.default_currency.clone(),
quantity: Some(1),
price_id: Some("price_pro".to_string()),
period_start: now,
period_end: now + 30 * 24 * 60 * 60,
}])
}
}
}
// Unit tests for the invoice config, status parsing, manager and cache.
#[cfg(test)]
mod tests {
use super::test::MockStripeInvoiceClient;
use super::*;
use crate::billing::storage::test::InMemoryBillingStore;
// Builder methods set every field.
#[test]
fn test_invoice_config_builder() {
let config = InvoiceConfig::new()
.with_default_limit(25)
.with_line_items(true)
.with_max_line_items(50)
.with_status_filter(InvoiceStatus::Paid);
assert_eq!(config.default_limit, 25);
assert!(config.include_line_items);
assert_eq!(config.max_line_items, 50);
assert_eq!(config.default_status_filter, Some(InvoiceStatus::Paid));
}
// Out-of-range builder inputs are clamped into 1..=100.
#[test]
fn test_invoice_config_clamping() {
let config = InvoiceConfig::new()
.with_default_limit(200) .with_max_line_items(0);
assert_eq!(config.default_limit, 100);
assert_eq!(config.max_line_items, 1);
}
// FromStr and as_str round-trip the lowercase names; unknown names are rejected.
#[test]
fn test_invoice_status_conversion() {
use std::str::FromStr;
assert_eq!(InvoiceStatus::from_str("paid"), Ok(InvoiceStatus::Paid));
assert_eq!(InvoiceStatus::from_str("open"), Ok(InvoiceStatus::Open));
assert!(InvoiceStatus::from_str("unknown").is_err());
assert_eq!(InvoiceStatus::Paid.as_str(), "paid");
assert_eq!(InvoiceStatus::Draft.as_str(), "draft");
}
// A billable with no stored Stripe customer cannot list invoices.
#[tokio::test]
async fn test_list_invoices_no_customer() {
let store = InMemoryBillingStore::new();
let client = MockStripeInvoiceClient::new();
let manager = InvoiceManager::new(store, client);
let result = manager
.list_invoices("unknown_org", Default::default())
.await;
assert!(result.is_err());
}
// Lists all of the customer's invoices when no filter is given.
#[tokio::test]
async fn test_list_invoices_with_customer() {
let store = InMemoryBillingStore::new();
store
.set_stripe_customer_id("org_123", "org", "cus_test")
.await
.unwrap();
let client = MockStripeInvoiceClient::new();
client.add_invoice(MockStripeInvoiceClient::create_test_invoice(
"in_1",
"cus_test",
InvoiceStatus::Paid,
));
client.add_invoice(MockStripeInvoiceClient::create_test_invoice(
"in_2",
"cus_test",
InvoiceStatus::Open,
));
let manager = InvoiceManager::new(store, client);
let result = manager
.list_invoices("org_123", Default::default())
.await
.unwrap();
assert_eq!(result.invoices.len(), 2);
}
// A status filter in the params narrows the listing.
#[tokio::test]
async fn test_list_invoices_with_status_filter() {
let store = InMemoryBillingStore::new();
store
.set_stripe_customer_id("org_123", "org", "cus_test")
.await
.unwrap();
let client = MockStripeInvoiceClient::new();
client.add_invoice(MockStripeInvoiceClient::create_test_invoice(
"in_1",
"cus_test",
InvoiceStatus::Paid,
));
client.add_invoice(MockStripeInvoiceClient::create_test_invoice(
"in_2",
"cus_test",
InvoiceStatus::Open,
));
let manager = InvoiceManager::new(store, client);
let result = manager
.list_invoices(
"org_123",
InvoiceListParams {
status: Some(InvoiceStatus::Paid),
..Default::default()
},
)
.await
.unwrap();
assert_eq!(result.invoices.len(), 1);
assert_eq!(result.invoices[0].status, InvoiceStatus::Paid);
}
// An invoice owned by another customer is reported as an error.
#[tokio::test]
async fn test_get_invoice_ownership() {
let store = InMemoryBillingStore::new();
store
.set_stripe_customer_id("org_123", "org", "cus_test")
.await
.unwrap();
let client = MockStripeInvoiceClient::new();
client.add_invoice(MockStripeInvoiceClient::create_test_invoice(
"in_1",
"cus_other",
InvoiceStatus::Paid, ));
let manager = InvoiceManager::new(store, client);
let result = manager.get_invoice("org_123", "in_1").await;
assert!(result.is_err());
}
// The first list call populates the cache; the repeat call returns the same data.
#[tokio::test]
async fn test_cached_manager() {
let store = InMemoryBillingStore::new();
store
.set_stripe_customer_id("org_123", "org", "cus_test")
.await
.unwrap();
let client = MockStripeInvoiceClient::new();
client.add_invoice(MockStripeInvoiceClient::create_test_invoice(
"in_1",
"cus_test",
InvoiceStatus::Paid,
));
let inner = InvoiceManager::new(store, client);
let cached = CachedInvoiceManager::new(inner, std::time::Duration::from_secs(60));
assert_eq!(cached.cache_size(), 0);
let result1 = cached
.list_invoices("org_123", Default::default())
.await
.unwrap();
assert_eq!(result1.invoices.len(), 1);
assert!(cached.cache_size() > 0);
let result2 = cached
.list_invoices("org_123", Default::default())
.await
.unwrap();
assert_eq!(result2.invoices.len(), 1);
}
// invalidate() removes all of a billable's cached entries.
#[tokio::test]
async fn test_cached_manager_invalidate() {
let store = InMemoryBillingStore::new();
store
.set_stripe_customer_id("org_123", "org", "cus_test")
.await
.unwrap();
let client = MockStripeInvoiceClient::new();
let inner = InvoiceManager::new(store, client);
let cached = CachedInvoiceManager::new(inner, std::time::Duration::from_secs(60));
let _ = cached.list_invoices("org_123", Default::default()).await;
assert!(cached.cache_size() > 0);
cached.invalidate("org_123").await;
assert_eq!(cached.cache_size(), 0);
}
}