use super::error::BillingError;
use super::storage::StoredPlan;
use crate::error::Result;
use async_trait::async_trait;
/// Largest accepted billable id, measured in bytes (`str::len`).
const MAX_BILLABLE_ID_LENGTH: usize = 256;
/// Largest accepted plan id, measured in bytes (`str::len`).
const MAX_PLAN_ID_LENGTH: usize = 64;
/// Validates the identifier of a billable entity.
///
/// A valid id is non-empty, at most `MAX_BILLABLE_ID_LENGTH` bytes long, and
/// made up only of ASCII alphanumerics, `_` and `-`.
///
/// # Errors
///
/// Returns `BillingError::InvalidBillableId` (converted into the crate error)
/// for the first rule that is violated. The offending id is only echoed back
/// after `sanitize_for_error` has truncated it and replaced unexpected
/// characters, so untrusted input cannot push control characters or unbounded
/// text into error messages/logs.
pub fn validate_billable_id(id: &str) -> Result<()> {
    if id.is_empty() {
        return Err(BillingError::InvalidBillableId {
            id: id.to_string(),
            reason: "billable_id cannot be empty".to_string(),
        }
        .into());
    }
    if id.len() > MAX_BILLABLE_ID_LENGTH {
        // The character set has not been checked yet, so the id may hold
        // arbitrary (multi-byte, control) characters. Sanitize it instead of
        // byte-slicing, which would panic mid-character.
        return Err(BillingError::InvalidBillableId {
            id: sanitize_for_error(id),
            reason: format!(
                "billable_id exceeds maximum length of {}",
                MAX_BILLABLE_ID_LENGTH
            ),
        }
        .into());
    }
    if !id
        .chars()
        .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-')
    {
        return Err(BillingError::InvalidBillableId {
            id: sanitize_for_error(id),
            reason: "billable_id contains invalid characters (only alphanumeric, underscore, and hyphen allowed)".to_string(),
        }
        .into());
    }
    Ok(())
}
/// Validates a plan identifier.
///
/// A valid id is non-empty, at most `MAX_PLAN_ID_LENGTH` bytes long, and made
/// up only of ASCII alphanumerics, `_` and `-`.
///
/// # Errors
///
/// Returns `BillingError::InvalidPlanId` (converted into the crate error) for
/// the first rule that is violated. The offending id is only echoed back after
/// `sanitize_for_error` has truncated it and replaced unexpected characters.
pub fn validate_plan_id(id: &str) -> Result<()> {
    if id.is_empty() {
        return Err(BillingError::InvalidPlanId {
            id: id.to_string(),
            reason: "plan_id cannot be empty".to_string(),
        }
        .into());
    }
    if id.len() > MAX_PLAN_ID_LENGTH {
        // Characters are not validated yet: sanitize rather than byte-slice,
        // which would panic on a multi-byte character boundary.
        return Err(BillingError::InvalidPlanId {
            id: sanitize_for_error(id),
            reason: format!("plan_id exceeds maximum length of {}", MAX_PLAN_ID_LENGTH),
        }
        .into());
    }
    if !id
        .chars()
        .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-')
    {
        return Err(BillingError::InvalidPlanId {
            id: sanitize_for_error(id),
            reason: "plan_id contains invalid characters".to_string(),
        }
        .into());
    }
    Ok(())
}
/// Lower-case ISO 4217 codes that a plan's `currency` may use.
///
/// This is an allowlist of supported currencies, not the complete ISO 4217 set.
const VALID_CURRENCIES: &[&str] = &[
    "usd", "eur", "gbp", "cad", "aud", "jpy", "chf",
    "sek", "nok", "dkk", "nzd", "sgd", "hkd", "inr",
    "brl", "mxn", "pln", "czk", "huf", "ron",
];
/// Largest accepted plan name, in bytes.
const MAX_PLAN_NAME_LENGTH: usize = 128;
/// Largest accepted plan description, in bytes.
const MAX_PLAN_DESCRIPTION_LENGTH: usize = 1024;
/// Largest accepted Stripe price id, in bytes.
const MAX_STRIPE_PRICE_ID_LENGTH: usize = 256;
pub fn validate_plan(plan: &StoredPlan) -> Result<()> {
validate_plan_id(&plan.id)?;
if plan.name.is_empty() {
return Err(BillingError::InvalidPlanId {
id: plan.id.clone(),
reason: "plan name cannot be empty".to_string(),
}
.into());
}
if plan.name.len() > MAX_PLAN_NAME_LENGTH {
return Err(BillingError::InvalidPlanId {
id: plan.id.clone(),
reason: format!(
"plan name exceeds maximum length of {}",
MAX_PLAN_NAME_LENGTH
),
}
.into());
}
if let Some(ref desc) = plan.description {
if desc.len() > MAX_PLAN_DESCRIPTION_LENGTH {
return Err(BillingError::InvalidPlanId {
id: plan.id.clone(),
reason: format!(
"plan description exceeds maximum length of {}",
MAX_PLAN_DESCRIPTION_LENGTH
),
}
.into());
}
}
validate_stripe_price_id(&plan.stripe_price_id, &plan.id)?;
if let Some(ref seat_price_id) = plan.stripe_seat_price_id {
validate_stripe_price_id(seat_price_id, &plan.id)?;
}
if plan.price_cents < 0 {
return Err(BillingError::InvalidPlanId {
id: plan.id.clone(),
reason: "price_cents cannot be negative".to_string(),
}
.into());
}
let currency_lower = plan.currency.to_lowercase();
if !VALID_CURRENCIES.contains(¤cy_lower.as_str()) {
return Err(BillingError::InvalidPlanId {
id: plan.id.clone(),
reason: format!(
"invalid currency '{}', must be a valid ISO 4217 code",
plan.currency
),
}
.into());
}
if plan.included_seats == 0 {
return Err(BillingError::InvalidPlanId {
id: plan.id.clone(),
reason: "included_seats must be at least 1".to_string(),
}
.into());
}
Ok(())
}
/// Format-checks a Stripe price id that belongs to the plan `plan_id`.
///
/// The id must be non-empty, at most `MAX_STRIPE_PRICE_ID_LENGTH` bytes, and
/// start with `price_`. Whether the price actually exists in Stripe is checked
/// separately by [`validate_plan_with_stripe`].
///
/// # Errors
///
/// Returns `BillingError::InvalidPlanId` carrying `plan_id` (not the price id)
/// for the first rule that fails, converted into the crate error.
fn validate_stripe_price_id(price_id: &str, plan_id: &str) -> Result<()> {
    // Rules are evaluated in order; the first failing one supplies the reason.
    let failure: Option<String> = if price_id.is_empty() {
        Some("stripe_price_id cannot be empty".to_string())
    } else if price_id.len() > MAX_STRIPE_PRICE_ID_LENGTH {
        Some(format!(
            "stripe_price_id exceeds maximum length of {}",
            MAX_STRIPE_PRICE_ID_LENGTH
        ))
    } else if !price_id.starts_with("price_") {
        Some("stripe_price_id should start with 'price_'".to_string())
    } else {
        None
    };

    match failure {
        None => Ok(()),
        Some(reason) => Err(BillingError::InvalidPlanId {
            id: plan_id.to_string(),
            reason,
        }
        .into()),
    }
}
/// Shortens `s` for inclusion in an error message.
///
/// Strings of at most 50 characters are returned unchanged. Longer ones are
/// cut to their first 47 characters followed by `...` (50 characters total).
/// Counting is done in `char`s, so multi-byte UTF-8 input is never split
/// mid-character. Byte slicing such as `&s[..47]` panics in that case.
fn truncate_for_error(s: &str) -> String {
    const MAX_CHARS: usize = 50;
    const ELLIPSIS: &str = "...";
    const KEEP_CHARS: usize = MAX_CHARS - ELLIPSIS.len();

    // `nth` stops early instead of walking a possibly huge input to the end.
    if s.chars().nth(MAX_CHARS).is_none() {
        return s.to_string();
    }
    // Byte offset of the first dropped char; always a valid char boundary.
    let cut = s
        .char_indices()
        .nth(KEEP_CHARS)
        .map_or(s.len(), |(idx, _)| idx);
    format!("{}{}", &s[..cut], ELLIPSIS)
}
/// Makes untrusted text safe to echo in an error message.
///
/// Keeps at most the first 50 characters. Every character outside
/// `[A-Za-z0-9_-]` becomes `?`, and `...` is appended only when characters
/// were actually dropped.
fn sanitize_for_error(s: &str) -> String {
    const MAX_CHARS: usize = 50;

    let mut out: String = s
        .chars()
        .take(MAX_CHARS)
        .map(|c| {
            if c.is_ascii_alphanumeric() || c == '_' || c == '-' {
                c
            } else {
                '?'
            }
        })
        .collect();
    // Decide truncation in chars, not bytes. A short string of multi-byte
    // characters was kept whole and must not get an ellipsis.
    if s.chars().nth(MAX_CHARS).is_some() {
        out.push_str("...");
    }
    out
}
/// Price details returned by a [`StripePriceValidator`] lookup.
///
/// `validate_plan_with_stripe` inspects only `active` and `currency`; the
/// remaining fields are carried along but not checked in this module.
#[derive(Debug, Clone)]
pub struct StripePrice {
/// Stripe price id, e.g. `price_abc123`.
pub id: String,
/// Whether the price is active; an inactive price fails plan validation.
pub active: bool,
/// Currency code; compared case-insensitively with the plan's currency.
pub currency: String,
/// Unit amount; presumably in the currency's minor unit (cents) — TODO confirm.
pub unit_amount: Option<i64>,
/// Billing interval (the test mock uses `"month"`); meaning of `None` not
/// visible here — presumably a non-recurring price, verify against the client.
pub interval: Option<String>,
/// Id of the owning product (the test mock produces `prod_…` ids).
pub product_id: String,
}
/// Looks up Stripe prices on behalf of [`validate_plan_with_stripe`].
///
/// `MockPriceValidator` (compiled under `cfg(any(test, feature = "test-billing"))`)
/// is an in-memory implementation. Production implementations presumably call
/// the Stripe API; they are not visible in this module.
#[async_trait]
pub trait StripePriceValidator: Send + Sync {
/// Fetches the price whose id is `price_id`.
///
/// `Ok(None)` means the price was not found; `validate_plan_with_stripe`
/// reports that as "does not exist in Stripe". An `Err` is propagated to
/// that function's caller via `?`.
async fn get_price(&self, price_id: &str) -> Result<Option<StripePrice>>;
}
/// Validates `plan` locally and then against Stripe.
///
/// First runs [`validate_plan`]. Then, through `validator`, it checks that the
/// plan's `stripe_price_id`, and its `stripe_seat_price_id` when set:
/// - exists in Stripe;
/// - is active;
/// - uses the same currency as the plan (ASCII case-insensitive).
///
/// # Errors
///
/// - any error from [`validate_plan`];
/// - any error returned by `validator.get_price`;
/// - `BillingError::InvalidStripePrice` naming the offending price id when a
///   price is missing, inactive, or has a different currency.
pub async fn validate_plan_with_stripe<V: StripePriceValidator>(
    plan: &StoredPlan,
    validator: &V,
) -> Result<()> {
    validate_plan(plan)?;

    check_stripe_price(
        validator,
        &plan.stripe_price_id,
        &plan.currency,
        "price",
        "currency mismatch",
    )
    .await?;

    if let Some(seat_price_id) = &plan.stripe_seat_price_id {
        check_stripe_price(
            validator,
            seat_price_id,
            &plan.currency,
            "seat price",
            "seat price currency mismatch",
        )
        .await?;
    }
    Ok(())
}

/// Checks one Stripe price referenced by a plan: it must exist, be active,
/// and be denominated in `plan_currency`.
///
/// `label` starts the "does not exist" / "is not active" reasons (e.g.
/// "seat price does not exist in Stripe"). `mismatch_label` starts the
/// currency-mismatch reason.
async fn check_stripe_price<V: StripePriceValidator>(
    validator: &V,
    price_id: &str,
    plan_currency: &str,
    label: &str,
    mismatch_label: &str,
) -> Result<()> {
    let invalid = |reason: String| BillingError::InvalidStripePrice {
        price_id: price_id.to_string(),
        reason,
    };

    let stripe_price = match validator.get_price(price_id).await? {
        Some(price) => price,
        None => {
            return Err(invalid(format!("{} does not exist in Stripe", label)).into());
        }
    };
    if !stripe_price.active {
        return Err(invalid(format!("{} is not active in Stripe", label)).into());
    }
    // Currency codes are ASCII, so ASCII case folding suffices and avoids
    // allocating two lower-cased copies.
    if !stripe_price.currency.eq_ignore_ascii_case(plan_currency) {
        return Err(invalid(format!(
            "{}: plan has '{}' but Stripe price has '{}'",
            mismatch_label, plan_currency, stripe_price.currency
        ))
        .into());
    }
    Ok(())
}
/// Test doubles for this module, compiled for unit tests or when the
/// `test-billing` feature is enabled.
#[cfg(any(test, feature = "test-billing"))]
pub mod test {
    use super::*;
    use std::collections::HashMap;
    use std::sync::RwLock;

    /// In-memory [`StripePriceValidator`] backed by a map keyed by price id.
    #[derive(Default)]
    pub struct MockPriceValidator {
        prices: RwLock<HashMap<String, StripePrice>>,
    }

    impl MockPriceValidator {
        /// Creates a validator that knows no prices.
        #[must_use]
        pub fn new() -> Self {
            Self::default()
        }

        /// Registers `price`, replacing any entry with the same id.
        ///
        /// # Panics
        ///
        /// Panics if the lock was poisoned by an earlier panicking test.
        pub fn add_price(&self, price: StripePrice) {
            self.prices
                .write()
                .expect("MockPriceValidator lock poisoned")
                .insert(price.id.clone(), price);
        }

        /// Registers an active monthly price.
        ///
        /// The product id is `prod_` followed by `price_id` with a leading
        /// `price_` removed.
        pub fn add_active_price(&self, price_id: &str, currency: &str, amount: i64) {
            // Strip only the prefix; `str::replace` would also remove
            // "price_" occurring later in the id.
            let product_suffix = price_id.strip_prefix("price_").unwrap_or(price_id);
            self.add_price(StripePrice {
                id: price_id.to_string(),
                active: true,
                currency: currency.to_string(),
                unit_amount: Some(amount),
                interval: Some("month".to_string()),
                product_id: format!("prod_{}", product_suffix),
            });
        }
    }

    #[async_trait]
    impl StripePriceValidator for MockPriceValidator {
        /// Returns a clone of the registered price, or `None` if unknown.
        async fn get_price(&self, price_id: &str) -> Result<Option<StripePrice>> {
            Ok(self
                .prices
                .read()
                .expect("MockPriceValidator lock poisoned")
                .get(price_id)
                .cloned())
        }
    }
}
#[cfg(test)]
mod tests {
    use super::super::storage::PlanInterval;
    use super::test::MockPriceValidator;
    use super::*;

    /// A plan that passes `validate_plan`; tests mutate one field at a time.
    fn make_valid_plan() -> StoredPlan {
        StoredPlan {
            id: "starter".to_string(),
            name: "Starter Plan".to_string(),
            description: Some("A great starter plan".to_string()),
            stripe_price_id: "price_abc123".to_string(),
            stripe_seat_price_id: None,
            price_cents: 999,
            currency: "usd".to_string(),
            interval: PlanInterval::Monthly,
            included_seats: 1,
            features: serde_json::json!({}),
            limits: serde_json::json!({}),
            trial_days: Some(14),
            is_active: true,
            sort_order: 0,
            created_at: 0,
            updated_at: 0,
        }
    }

    // ---- validate_billable_id ----

    #[test]
    fn test_validate_billable_id_valid() {
        assert!(validate_billable_id("org_123").is_ok());
        assert!(validate_billable_id("user-456").is_ok());
        assert!(validate_billable_id("ABC123").is_ok());
        assert!(validate_billable_id("a").is_ok());
    }

    #[test]
    fn test_validate_billable_id_empty() {
        let result = validate_billable_id("");
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_billable_id_too_long() {
        let long_id = "a".repeat(300);
        let result = validate_billable_id(&long_id);
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_billable_id_invalid_chars() {
        assert!(validate_billable_id("org<script>").is_err());
        assert!(validate_billable_id("org 123").is_err());
        assert!(validate_billable_id("org/123").is_err());
        assert!(validate_billable_id("org\n123").is_err());
        assert!(validate_billable_id("org;DROP TABLE").is_err());
    }

    // ---- validate_plan_id ----

    #[test]
    fn test_validate_plan_id_valid() {
        assert!(validate_plan_id("starter").is_ok());
        assert!(validate_plan_id("pro-monthly").is_ok());
        assert!(validate_plan_id("enterprise_annual").is_ok());
    }

    #[test]
    fn test_validate_plan_id_invalid() {
        assert!(validate_plan_id("").is_err());
        assert!(validate_plan_id("plan with spaces").is_err());
        assert!(validate_plan_id(&"a".repeat(100)).is_err());
    }

    // ---- error-message helpers ----

    #[test]
    fn test_sanitize_for_error() {
        assert_eq!(sanitize_for_error("valid_id"), "valid_id");
        assert_eq!(sanitize_for_error("has<script>chars"), "has?script?chars");
        let long = "a".repeat(100);
        let result = sanitize_for_error(&long);
        assert!(result.ends_with("..."));
        assert!(result.len() <= 53);
    }

    #[test]
    fn test_truncate_for_error() {
        assert_eq!(truncate_for_error("short"), "short");
        let long = "a".repeat(60);
        assert_eq!(truncate_for_error(&long), format!("{}...", "a".repeat(47)));
    }

    // ---- validate_plan ----

    #[test]
    fn test_validate_plan_valid() {
        let plan = make_valid_plan();
        assert!(validate_plan(&plan).is_ok());
    }

    #[test]
    fn test_validate_plan_empty_name() {
        let mut plan = make_valid_plan();
        plan.name = "".to_string();
        assert!(validate_plan(&plan).is_err());
    }

    #[test]
    fn test_validate_plan_name_too_long() {
        let mut plan = make_valid_plan();
        plan.name = "a".repeat(MAX_PLAN_NAME_LENGTH + 1);
        assert!(validate_plan(&plan).is_err());
        plan.name = "a".repeat(MAX_PLAN_NAME_LENGTH);
        assert!(validate_plan(&plan).is_ok());
    }

    #[test]
    fn test_validate_plan_description() {
        let mut plan = make_valid_plan();
        plan.description = None;
        assert!(validate_plan(&plan).is_ok());
        plan.description = Some("d".repeat(MAX_PLAN_DESCRIPTION_LENGTH + 1));
        assert!(validate_plan(&plan).is_err());
    }

    #[test]
    fn test_validate_plan_invalid_id() {
        let mut plan = make_valid_plan();
        plan.id = "plan with spaces".to_string();
        assert!(validate_plan(&plan).is_err());
    }

    #[test]
    fn test_validate_plan_invalid_stripe_price() {
        let mut plan = make_valid_plan();
        plan.stripe_price_id = "invalid".to_string();
        assert!(validate_plan(&plan).is_err());
        plan.stripe_price_id = "".to_string();
        assert!(validate_plan(&plan).is_err());
    }

    #[test]
    fn test_validate_plan_invalid_seat_price_id() {
        let mut plan = make_valid_plan();
        plan.stripe_seat_price_id = Some("seat_123".to_string());
        assert!(validate_plan(&plan).is_err());
        plan.stripe_seat_price_id = Some("price_seat123".to_string());
        assert!(validate_plan(&plan).is_ok());
    }

    #[test]
    fn test_validate_plan_negative_price() {
        let mut plan = make_valid_plan();
        plan.price_cents = -100;
        assert!(validate_plan(&plan).is_err());
    }

    #[test]
    fn test_validate_plan_invalid_currency() {
        let mut plan = make_valid_plan();
        plan.currency = "xyz".to_string();
        assert!(validate_plan(&plan).is_err());
    }

    #[test]
    fn test_validate_plan_zero_seats() {
        let mut plan = make_valid_plan();
        plan.included_seats = 0;
        assert!(validate_plan(&plan).is_err());
    }

    #[test]
    fn test_validate_plan_currencies() {
        let mut plan = make_valid_plan();
        for currency in &["usd", "EUR", "GBP", "cad", "aud"] {
            plan.currency = currency.to_string();
            assert!(
                validate_plan(&plan).is_ok(),
                "Currency {} should be valid",
                currency
            );
        }
    }

    // ---- validate_plan_with_stripe ----

    #[tokio::test]
    async fn test_validate_plan_with_stripe_success() {
        let validator = MockPriceValidator::new();
        validator.add_active_price("price_abc123", "usd", 999);
        let plan = make_valid_plan();
        assert!(validate_plan_with_stripe(&plan, &validator).await.is_ok());
    }

    #[tokio::test]
    async fn test_validate_plan_with_stripe_runs_local_validation_first() {
        let validator = MockPriceValidator::new();
        validator.add_active_price("price_abc123", "usd", 999);
        let mut plan = make_valid_plan();
        plan.price_cents = -1;
        assert!(validate_plan_with_stripe(&plan, &validator).await.is_err());
    }

    #[tokio::test]
    async fn test_validate_plan_with_stripe_price_not_found() {
        let validator = MockPriceValidator::new();
        let plan = make_valid_plan();
        let result = validate_plan_with_stripe(&plan, &validator).await;
        assert!(result.is_err());
        let err_str = result.unwrap_err().to_string();
        assert!(err_str.contains("does not exist"));
    }

    #[tokio::test]
    async fn test_validate_plan_with_stripe_price_inactive() {
        let validator = MockPriceValidator::new();
        validator.add_price(StripePrice {
            id: "price_abc123".to_string(),
            active: false,
            currency: "usd".to_string(),
            unit_amount: Some(999),
            interval: Some("month".to_string()),
            product_id: "prod_abc123".to_string(),
        });
        let plan = make_valid_plan();
        let result = validate_plan_with_stripe(&plan, &validator).await;
        assert!(result.is_err());
        let err_str = result.unwrap_err().to_string();
        assert!(err_str.contains("not active"));
    }

    #[tokio::test]
    async fn test_validate_plan_with_stripe_currency_mismatch() {
        let validator = MockPriceValidator::new();
        validator.add_active_price("price_abc123", "eur", 999);
        let plan = make_valid_plan();
        let result = validate_plan_with_stripe(&plan, &validator).await;
        assert!(result.is_err());
        let err_str = result.unwrap_err().to_string();
        assert!(err_str.contains("currency mismatch"));
    }

    #[tokio::test]
    async fn test_validate_plan_with_stripe_currency_case_insensitive() {
        let validator = MockPriceValidator::new();
        validator.add_active_price("price_abc123", "USD", 999);
        let plan = make_valid_plan();
        assert!(validate_plan_with_stripe(&plan, &validator).await.is_ok());
    }

    #[tokio::test]
    async fn test_validate_plan_with_stripe_seat_price() {
        let validator = MockPriceValidator::new();
        validator.add_active_price("price_abc123", "usd", 999);
        validator.add_active_price("price_seat123", "usd", 500);
        let mut plan = make_valid_plan();
        plan.stripe_seat_price_id = Some("price_seat123".to_string());
        assert!(validate_plan_with_stripe(&plan, &validator).await.is_ok());
    }

    #[tokio::test]
    async fn test_validate_plan_with_stripe_seat_price_not_found() {
        let validator = MockPriceValidator::new();
        validator.add_active_price("price_abc123", "usd", 999);
        let mut plan = make_valid_plan();
        plan.stripe_seat_price_id = Some("price_seat123".to_string());
        let result = validate_plan_with_stripe(&plan, &validator).await;
        assert!(result.is_err());
        let err_str = result.unwrap_err().to_string();
        assert!(err_str.contains("seat price does not exist"));
    }

    #[tokio::test]
    async fn test_validate_plan_with_stripe_seat_price_inactive() {
        let validator = MockPriceValidator::new();
        validator.add_active_price("price_abc123", "usd", 999);
        validator.add_price(StripePrice {
            id: "price_seat123".to_string(),
            active: false,
            currency: "usd".to_string(),
            unit_amount: Some(500),
            interval: Some("month".to_string()),
            product_id: "prod_seat123".to_string(),
        });
        let mut plan = make_valid_plan();
        plan.stripe_seat_price_id = Some("price_seat123".to_string());
        let err_str = validate_plan_with_stripe(&plan, &validator)
            .await
            .unwrap_err()
            .to_string();
        assert!(err_str.contains("seat price is not active"));
    }

    #[tokio::test]
    async fn test_validate_plan_with_stripe_seat_price_currency_mismatch() {
        let validator = MockPriceValidator::new();
        validator.add_active_price("price_abc123", "usd", 999);
        validator.add_active_price("price_seat123", "eur", 500);
        let mut plan = make_valid_plan();
        plan.stripe_seat_price_id = Some("price_seat123".to_string());
        let result = validate_plan_with_stripe(&plan, &validator).await;
        assert!(result.is_err());
        let err_str = result.unwrap_err().to_string();
        assert!(err_str.contains("seat price currency mismatch"));
    }
}