use super::customer::{CustomerManager, StripeClient};
use super::plans::Plans;
use super::storage::{BillableEntity, BillingStore};
use crate::error::Result;
use url::Url;
/// Creates Stripe Checkout Sessions for billable entities.
///
/// Combines a [`CustomerManager`] (to resolve the Stripe customer), a Checkout-capable
/// Stripe client, the plan catalogue and a [`CheckoutConfig`].
pub struct CheckoutManager<S, C>
where
    S: BillingStore,
    C: StripeClient + StripeCheckoutClient,
{
    /// Looks up or creates the Stripe customer for an entity.
    customer_manager: CustomerManager<S, C>,
    /// Client used to create Checkout Sessions.
    client: C,
    /// Plan catalogue used to resolve prices and trial defaults.
    plans: Plans,
    /// Redirect-URL validation rules and session defaults.
    config: CheckoutConfig,
}
impl<S: BillingStore + Clone, C: StripeClient + StripeCheckoutClient + Clone>
    CheckoutManager<S, C>
{
    /// Creates a checkout manager.
    ///
    /// `client` is cloned once: the clone backs the internal [`CustomerManager`]
    /// (together with `store`), and the original is used to create Checkout Sessions.
    #[must_use]
    pub fn new(store: S, client: C, plans: Plans, config: CheckoutConfig) -> Self {
        Self {
            customer_manager: CustomerManager::new(store, client.clone()),
            client,
            plans,
            config,
        }
    }

    /// Creates a subscription-mode Checkout Session for `entity` on `request.plan_id`.
    ///
    /// The request is fully validated before `get_or_create_customer` is called:
    /// redirect URLs, the coupon/promotion-code conflict, the plan lookup and
    /// extra-seat support. A rejected request therefore never reaches the customer step.
    ///
    /// # Errors
    ///
    /// Returns `TidewayError::BadRequest` in these cases:
    /// - either redirect URL fails [`CheckoutConfig::validate_redirect_url`];
    /// - a coupon is combined with promotion codes;
    /// - the plan is unknown;
    /// - extra seats are requested on a plan without an extra-seat price.
    ///
    /// Errors from customer resolution and from the Stripe client are propagated.
    pub async fn create_checkout_session(
        &self,
        entity: &impl BillableEntity,
        request: CheckoutRequest,
    ) -> Result<CheckoutSession> {
        self.config.validate_redirect_url(&request.success_url)?;
        self.config.validate_redirect_url(&request.cancel_url)?;

        // Resolve the effective setting once so that the conflict check below and the
        // value sent to Stripe can never disagree.
        let allow_promotion_codes = request
            .allow_promotion_codes
            .unwrap_or(self.config.allow_promotion_codes);
        if request.coupon.is_some() && allow_promotion_codes {
            return Err(crate::error::TidewayError::BadRequest(
                "Cannot use both a coupon and allow_promotion_codes. \
                 Either apply a specific coupon or let users enter promotion codes, not both."
                    .to_string(),
            ));
        }

        let plan = self.plans.get(&request.plan_id).ok_or_else(|| {
            crate::error::TidewayError::BadRequest(format!("Unknown plan: {}", request.plan_id))
        })?;

        let mut line_items = vec![CheckoutLineItem {
            price_id: plan.stripe_price_id.clone(),
            quantity: 1,
        }];
        // Zero extra seats is treated the same as not asking for any.
        if let Some(extra_seats) = request.extra_seats.filter(|&seats| seats > 0) {
            let seat_price = plan.extra_seat_price_id.as_ref().ok_or_else(|| {
                crate::error::TidewayError::BadRequest(
                    "Plan does not support extra seats".to_string(),
                )
            })?;
            line_items.push(CheckoutLineItem {
                price_id: seat_price.clone(),
                quantity: extra_seats,
            });
        }

        // A per-request trial length overrides the plan's default.
        let trial_period_days = request.trial_days.or(plan.trial_days);

        // Only touch the customer record once the request is known to be valid, so that
        // rejected requests do not create or look up customers as a side effect.
        let customer_id = self.customer_manager.get_or_create_customer(entity).await?;

        self.client
            .create_checkout_session(CreateCheckoutSessionRequest {
                customer_id,
                line_items,
                success_url: request.success_url,
                cancel_url: request.cancel_url,
                mode: CheckoutMode::Subscription,
                allow_promotion_codes,
                trial_period_days,
                metadata: CheckoutMetadata {
                    billable_id: entity.billable_id().to_string(),
                    billable_type: entity.billable_type().to_string(),
                    plan_id: request.plan_id,
                },
                tax_id_collection: self.config.collect_tax_id,
                billing_address_collection: self.config.collect_billing_address,
                coupon: request.coupon,
                payment_method_collection: request.payment_method_collection,
            })
            .await
    }

    /// Formerly created a Checkout Session for seat add-ons; now always fails.
    ///
    /// # Errors
    ///
    /// Always returns `TidewayError::BadRequest` directing callers to `SeatManager::add_seats`.
    #[deprecated(note = "Use SeatManager::add_seats to update subscriptions with proration.")]
    pub async fn create_seat_checkout_session(
        &self,
        entity: &impl BillableEntity,
        request: SeatCheckoutRequest,
    ) -> Result<CheckoutSession> {
        // Arguments are kept for API compatibility only.
        let _ = (entity, request);
        Err(crate::error::TidewayError::BadRequest(
            "Seat add-ons via Checkout are not supported. Use SeatManager::add_seats instead."
                .to_string(),
        ))
    }
}
/// Settings that shape every Checkout Session a `CheckoutManager` creates.
#[derive(Clone, Debug)]
pub struct CheckoutConfig {
    /// Default for letting customers enter promotion codes; requests may override it.
    pub allow_promotion_codes: bool,
    /// Whether tax-ID collection is requested on the session.
    pub collect_tax_id: bool,
    /// Whether billing-address collection is requested on the session.
    pub collect_billing_address: bool,
    /// Hosts allowed as redirect targets (each entry also admits its subdomains).
    /// An empty list accepts any host.
    pub allowed_redirect_domains: Vec<String>,
    /// Permit plain `http://` redirects to localhost / loopback addresses.
    pub allow_localhost_http: bool,
}
impl Default for CheckoutConfig {
fn default() -> Self {
Self {
allow_promotion_codes: true,
collect_tax_id: false,
collect_billing_address: false,
allowed_redirect_domains: Vec::new(),
allow_localhost_http: false,
}
}
}
impl CheckoutConfig {
    /// Creates a configuration with the default settings.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets whether customers may enter promotion codes when a request does not specify it.
    #[must_use]
    pub fn allow_promotion_codes(mut self, allow: bool) -> Self {
        self.allow_promotion_codes = allow;
        self
    }

    /// Sets whether sessions request tax-ID collection.
    #[must_use]
    pub fn collect_tax_id(mut self, collect: bool) -> Self {
        self.collect_tax_id = collect;
        self
    }

    /// Sets whether sessions request billing-address collection.
    #[must_use]
    pub fn collect_billing_address(mut self, collect: bool) -> Self {
        self.collect_billing_address = collect;
        self
    }

    /// Replaces the list of allowed redirect domains.
    ///
    /// Each entry admits the host itself and any of its subdomains. Matching is ASCII
    /// case-insensitive, and leading dots (`".example.com"`) are ignored.
    #[must_use]
    pub fn allowed_redirect_domains<I, S>(mut self, domains: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.allowed_redirect_domains = domains.into_iter().map(Into::into).collect();
        self
    }

    /// Appends one domain to the list of allowed redirect domains.
    #[must_use]
    pub fn add_allowed_domain(mut self, domain: impl Into<String>) -> Self {
        self.allowed_redirect_domains.push(domain.into());
        self
    }

    /// Sets whether `http://` redirects to localhost / loopback addresses are accepted.
    #[must_use]
    pub fn allow_localhost_http(mut self, allow: bool) -> Self {
        self.allow_localhost_http = allow;
        self
    }

    /// Returns `true` for the loopback host spellings accepted over plain HTTP.
    fn is_localhost(host: &str) -> bool {
        matches!(host, "localhost" | "127.0.0.1" | "[::1]" | "::1")
    }

    /// Returns `true` if `host` equals the configured `allowed` entry or is a subdomain of it.
    fn domain_matches(host: &str, allowed: &str) -> bool {
        // Tolerate cookie-style entries such as ".example.com" and any letter case.
        let allowed = allowed.trim_start_matches('.').to_ascii_lowercase();
        // An empty entry would otherwise make the subdomain test below accept every
        // host that ends in '.', so it matches nothing.
        if allowed.is_empty() {
            return false;
        }
        let host = host.to_ascii_lowercase();
        host == allowed
            || host
                .strip_suffix(allowed.as_str())
                .is_some_and(|prefix| prefix.ends_with('.'))
    }

    /// Checks that `url` is an acceptable Checkout redirect target.
    ///
    /// The URL must parse and use HTTPS. Plain HTTP is accepted only for localhost when
    /// `allow_localhost_http` is set. When `allowed_redirect_domains` is non-empty, the
    /// host must match one of its entries.
    ///
    /// # Errors
    ///
    /// Returns `TidewayError::BadRequest` describing the first rule the URL violates.
    pub fn validate_redirect_url(&self, url: &str) -> Result<()> {
        let parsed = Url::parse(url).map_err(|e| {
            crate::error::TidewayError::BadRequest(format!("Invalid redirect URL: {}", e))
        })?;

        let is_https = parsed.scheme() == "https";
        let is_localhost_http = self.allow_localhost_http
            && parsed.scheme() == "http"
            && parsed.host_str().is_some_and(Self::is_localhost);
        if !is_https && !is_localhost_http {
            return Err(crate::error::TidewayError::BadRequest(
                "Redirect URL must use HTTPS".to_string(),
            ));
        }

        // No allow-list configured: any host is acceptable.
        if self.allowed_redirect_domains.is_empty() {
            return Ok(());
        }

        let host = parsed.host_str().ok_or_else(|| {
            crate::error::TidewayError::BadRequest("Redirect URL must have a host".to_string())
        })?;
        let domain_allowed = self
            .allowed_redirect_domains
            .iter()
            .any(|allowed| Self::domain_matches(host, allowed));
        if !domain_allowed {
            return Err(crate::error::TidewayError::BadRequest(format!(
                "Redirect URL domain '{}' is not allowed",
                host
            )));
        }
        Ok(())
    }
}
/// Whether Checkout must collect a payment method (`Always`, the default) or only
/// when one is needed (`IfRequired`).
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum PaymentMethodCollection {
    #[default]
    Always,
    IfRequired,
}

impl PaymentMethodCollection {
    /// Wire-format string for this variant.
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match *self {
            PaymentMethodCollection::Always => "always",
            PaymentMethodCollection::IfRequired => "if_required",
        }
    }
}
/// Caller-supplied parameters for a subscription Checkout Session.
#[derive(Clone, Debug)]
pub struct CheckoutRequest {
    /// Identifier of the plan to subscribe to.
    pub plan_id: String,
    /// Where the customer is sent after completing checkout.
    pub success_url: String,
    /// Where the customer is sent after abandoning checkout.
    pub cancel_url: String,
    /// Extra seats added as a separate line item at the plan's extra-seat price;
    /// `None` or `Some(0)` adds no seat line item.
    pub extra_seats: Option<u32>,
    /// Trial length override; `None` falls back to the plan's trial.
    pub trial_days: Option<u32>,
    /// Promotion-code override; `None` falls back to the configuration default.
    pub allow_promotion_codes: Option<bool>,
    /// Coupon to apply; cannot be combined with promotion codes.
    pub coupon: Option<String>,
    /// Payment-method collection mode passed through to the session.
    pub payment_method_collection: Option<PaymentMethodCollection>,
}
impl CheckoutRequest {
    /// Builds a request for `plan_id` with the given redirect URLs and every optional
    /// setting left unset.
    #[must_use]
    pub fn new(
        plan_id: impl Into<String>,
        success_url: impl Into<String>,
        cancel_url: impl Into<String>,
    ) -> Self {
        let plan_id = plan_id.into();
        let success_url = success_url.into();
        let cancel_url = cancel_url.into();
        Self {
            plan_id,
            success_url,
            cancel_url,
            extra_seats: None,
            trial_days: None,
            allow_promotion_codes: None,
            coupon: None,
            payment_method_collection: None,
        }
    }

    /// Requests `seats` additional seats on top of the base plan.
    #[must_use]
    pub fn with_extra_seats(self, seats: u32) -> Self {
        Self {
            extra_seats: Some(seats),
            ..self
        }
    }

    /// Overrides the plan's trial length.
    #[must_use]
    pub fn with_trial_days(self, days: u32) -> Self {
        Self {
            trial_days: Some(days),
            ..self
        }
    }

    /// Overrides whether customers may enter promotion codes.
    #[must_use]
    pub fn with_promotion_codes(self, allow: bool) -> Self {
        Self {
            allow_promotion_codes: Some(allow),
            ..self
        }
    }

    /// Applies a specific coupon to the session.
    #[must_use]
    pub fn with_coupon(self, coupon: impl Into<String>) -> Self {
        Self {
            coupon: Some(coupon.into()),
            ..self
        }
    }

    /// Sets the payment-method collection mode.
    #[must_use]
    pub fn with_payment_method_collection(self, collection: PaymentMethodCollection) -> Self {
        Self {
            payment_method_collection: Some(collection),
            ..self
        }
    }

    /// Shorthand for collecting a payment method only when one is required.
    #[must_use]
    pub fn skip_payment_collection(self) -> Self {
        Self {
            payment_method_collection: Some(PaymentMethodCollection::IfRequired),
            ..self
        }
    }
}
/// Parameters of the deprecated seat add-on checkout flow.
#[derive(Clone, Debug)]
pub struct SeatCheckoutRequest {
    /// Plan the seats belong to.
    pub plan_id: String,
    /// Number of seats requested.
    pub seats: u32,
    /// Redirect target after a successful checkout.
    pub success_url: String,
    /// Redirect target after an abandoned checkout.
    pub cancel_url: String,
}
/// A created Checkout Session: its identifier and the URL to send the customer to.
#[derive(Clone, Debug)]
#[must_use]
pub struct CheckoutSession {
    /// Session identifier.
    pub id: String,
    /// Hosted checkout page URL.
    pub url: String,
}
/// One priced line of a Checkout Session.
#[derive(Clone, Debug)]
pub struct CheckoutLineItem {
    /// Stripe price identifier.
    pub price_id: String,
    /// Units of that price.
    pub quantity: u32,
}
/// Kind of Checkout Session to create.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CheckoutMode {
    Payment,
    Subscription,
    Setup,
}

impl CheckoutMode {
    /// Wire-format string for this mode.
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match *self {
            CheckoutMode::Payment => "payment",
            CheckoutMode::Subscription => "subscription",
            CheckoutMode::Setup => "setup",
        }
    }
}
/// Metadata attached to a session so it can be traced back to the billable entity and plan.
#[derive(Clone, Debug)]
pub struct CheckoutMetadata {
    /// `BillableEntity::billable_id` of the purchaser.
    pub billable_id: String,
    /// `BillableEntity::billable_type` of the purchaser.
    pub billable_type: String,
    /// Plan identifier from the originating request.
    pub plan_id: String,
}
/// Fully-resolved input handed to a [`StripeCheckoutClient`] to create a session.
#[derive(Clone, Debug)]
pub struct CreateCheckoutSessionRequest {
    /// Stripe customer the session belongs to.
    pub customer_id: String,
    /// Priced items to purchase.
    pub line_items: Vec<CheckoutLineItem>,
    /// Redirect after success.
    pub success_url: String,
    /// Redirect after cancellation.
    pub cancel_url: String,
    /// Session kind.
    pub mode: CheckoutMode,
    /// Whether customers may enter promotion codes.
    pub allow_promotion_codes: bool,
    /// Trial length in days, if any.
    pub trial_period_days: Option<u32>,
    /// Tracing metadata for the purchaser and plan.
    pub metadata: CheckoutMetadata,
    /// Whether tax-ID collection is requested.
    pub tax_id_collection: bool,
    /// Whether billing-address collection is requested.
    pub billing_address_collection: bool,
    /// Coupon to apply, if any.
    pub coupon: Option<String>,
    /// Payment-method collection mode, if specified.
    pub payment_method_collection: Option<PaymentMethodCollection>,
}
/// Stripe operation needed by `CheckoutManager`: creating a Checkout Session.
// `async fn` in a public trait triggers the `async_fn_in_trait` lint because users of the
// trait cannot add bounds (such as `Send`) on the returned future; the lint is suppressed
// deliberately here.
#[allow(async_fn_in_trait)]
pub trait StripeCheckoutClient: Send + Sync {
    /// Creates a session from `request` and returns its id and hosted URL.
    ///
    /// # Errors
    ///
    /// Implementation-defined; returned when the session cannot be created.
    async fn create_checkout_session(
        &self,
        request: CreateCheckoutSessionRequest,
    ) -> Result<CheckoutSession>;
}
#[cfg(any(test, feature = "test-billing"))]
pub mod test {
    use super::super::customer as customer_api;
    use super::*;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// In-memory Checkout client that fabricates sequential `cs_test_N` sessions.
    #[derive(Default)]
    pub struct MockStripeCheckoutClient {
        session_counter: AtomicU64,
    }

    impl MockStripeCheckoutClient {
        /// Creates a mock whose first session id is `cs_test_0`.
        #[must_use]
        pub fn new() -> Self {
            Self::default()
        }
    }

    impl StripeCheckoutClient for MockStripeCheckoutClient {
        async fn create_checkout_session(
            &self,
            _request: CreateCheckoutSessionRequest,
        ) -> Result<CheckoutSession> {
            let sequence = self.session_counter.fetch_add(1, Ordering::SeqCst);
            let id = format!("cs_test_{sequence}");
            let url = format!("https://checkout.stripe.com/c/pay/{id}");
            Ok(CheckoutSession { id, url })
        }
    }

    /// Mock implementing both the customer and Checkout client traits by delegation.
    #[derive(Default)]
    pub struct MockFullStripeClient {
        pub customer: customer_api::test::MockStripeClient,
        pub checkout: MockStripeCheckoutClient,
    }

    impl MockFullStripeClient {
        /// Creates a mock with fresh customer and Checkout state.
        #[must_use]
        pub fn new() -> Self {
            Self::default()
        }
    }

    impl customer_api::StripeClient for MockFullStripeClient {
        async fn create_customer(
            &self,
            request: customer_api::CreateCustomerRequest,
        ) -> Result<String> {
            self.customer.create_customer(request).await
        }

        async fn update_customer(
            &self,
            customer_id: &str,
            request: customer_api::UpdateCustomerRequest,
        ) -> Result<()> {
            self.customer.update_customer(customer_id, request).await
        }

        async fn delete_customer(&self, customer_id: &str) -> Result<()> {
            self.customer.delete_customer(customer_id).await
        }

        async fn get_default_payment_method(&self, customer_id: &str) -> Result<Option<String>> {
            self.customer.get_default_payment_method(customer_id).await
        }
    }

    impl StripeCheckoutClient for MockFullStripeClient {
        async fn create_checkout_session(
            &self,
            request: CreateCheckoutSessionRequest,
        ) -> Result<CheckoutSession> {
            self.checkout.create_checkout_session(request).await
        }
    }

    /// NOTE: cloning does not copy state — the clone is a brand-new mock, so its
    /// session counter restarts and its customer state is independent of the original.
    impl Clone for MockFullStripeClient {
        fn clone(&self) -> Self {
            MockFullStripeClient::new()
        }
    }
}
#[cfg(test)]
mod tests {
    use super::test::MockFullStripeClient;
    use super::*;
    use crate::billing::storage::BillableEntity;
    use crate::billing::storage::test::InMemoryBillingStore;

    /// Redirect targets on the default test domain.
    const SUCCESS_URL: &str = "https://example.com/success";
    const CANCEL_URL: &str = "https://example.com/cancel";

    /// Minimal organisation-typed entity used as the purchaser.
    struct TestEntity {
        id: String,
        email: String,
    }

    impl BillableEntity for TestEntity {
        fn billable_id(&self) -> &str {
            &self.id
        }
        fn billable_type(&self) -> &str {
            "org"
        }
        fn email(&self) -> &str {
            &self.email
        }
        fn name(&self) -> Option<&str> {
            None
        }
    }

    /// Builds a `TestEntity` with the shared test email address.
    fn test_entity(id: &str) -> TestEntity {
        TestEntity {
            id: id.to_owned(),
            email: String::from("test@example.com"),
        }
    }

    /// Two plans: `starter` (extra seats, 14-day trial) and `pro` (no extra-seat price).
    fn create_test_plans() -> Plans {
        Plans::builder()
            .plan("starter")
            .stripe_price("price_starter")
            .extra_seat_price("price_seat")
            .included_seats(3)
            .trial_days(14)
            .done()
            .expect("starter plan is valid")
            .plan("pro")
            .stripe_price("price_pro")
            .included_seats(5)
            .done()
            .expect("pro plan is valid")
            .build()
            .expect("plan catalogue is valid")
    }

    /// Manager over in-memory storage, the mock Stripe client and the test plans.
    fn build_manager(
        config: CheckoutConfig,
    ) -> CheckoutManager<InMemoryBillingStore, MockFullStripeClient> {
        CheckoutManager::new(
            InMemoryBillingStore::new(),
            MockFullStripeClient::new(),
            create_test_plans(),
            config,
        )
    }

    #[tokio::test]
    async fn test_create_checkout_session() {
        let manager = build_manager(CheckoutConfig::default());
        let org = test_entity("org_123");
        let request = CheckoutRequest::new("starter", SUCCESS_URL, CANCEL_URL);

        let session = manager.create_checkout_session(&org, request).await.unwrap();

        assert!(session.id.starts_with("cs_test_"));
        assert!(session.url.contains("checkout.stripe.com"));
    }

    #[tokio::test]
    async fn test_create_checkout_session_with_extra_seats() {
        let manager = build_manager(CheckoutConfig::default());
        let org = test_entity("org_456");
        let request = CheckoutRequest::new("starter", SUCCESS_URL, CANCEL_URL).with_extra_seats(5);

        let session = manager.create_checkout_session(&org, request).await.unwrap();

        assert!(session.id.starts_with("cs_test_"));
    }

    #[tokio::test]
    async fn test_create_checkout_session_invalid_plan() {
        let manager = build_manager(CheckoutConfig::default());
        let org = test_entity("org_789");
        let request = CheckoutRequest::new("nonexistent", SUCCESS_URL, CANCEL_URL);

        let outcome = manager.create_checkout_session(&org, request).await;

        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_create_checkout_session_no_seat_support() {
        let manager = build_manager(CheckoutConfig::default());
        let org = test_entity("org_abc");
        let request = CheckoutRequest::new("pro", SUCCESS_URL, CANCEL_URL).with_extra_seats(5);

        let outcome = manager.create_checkout_session(&org, request).await;

        assert!(outcome.is_err());
    }

    #[test]
    fn test_url_validation_https_required() {
        let config = CheckoutConfig::new();

        assert!(config.validate_redirect_url("https://example.com/success").is_ok());
        assert!(config.validate_redirect_url("http://example.com/success").is_err());
    }

    #[test]
    fn test_url_validation_invalid_url() {
        let config = CheckoutConfig::new();

        assert!(config.validate_redirect_url("not-a-url").is_err());
        assert!(config.validate_redirect_url("").is_err());
    }

    #[test]
    fn test_url_validation_allowed_domains() {
        let config =
            CheckoutConfig::new().allowed_redirect_domains(["example.com", "app.mysite.com"]);

        for accepted in [
            "https://example.com/success",
            "https://app.mysite.com/cancel",
            "https://app.example.com/success",
            "https://staging.app.mysite.com/success",
        ] {
            assert!(config.validate_redirect_url(accepted).is_ok(), "{accepted}");
        }
        for rejected in ["https://evil.com/redirect", "https://notexample.com/success"] {
            assert!(config.validate_redirect_url(rejected).is_err(), "{rejected}");
        }
    }

    #[test]
    fn test_url_validation_empty_allowed_list() {
        let config = CheckoutConfig::new();

        assert!(config.validate_redirect_url("https://example.com/success").is_ok());
        assert!(config.validate_redirect_url("https://any-domain.com/path").is_ok());
    }

    #[tokio::test]
    async fn test_checkout_rejects_invalid_url() {
        let manager = build_manager(CheckoutConfig::new().allowed_redirect_domains(["example.com"]));
        let org = test_entity("org_url_test");

        let allowed = CheckoutRequest::new("starter", SUCCESS_URL, CANCEL_URL);
        assert!(manager.create_checkout_session(&org, allowed).await.is_ok());

        let foreign = CheckoutRequest::new("starter", "https://evil.com/success", CANCEL_URL);
        assert!(manager.create_checkout_session(&org, foreign).await.is_err());
    }

    #[test]
    fn test_checkout_request_with_coupon() {
        let request = CheckoutRequest::new("pro", SUCCESS_URL, CANCEL_URL)
            .with_coupon("SAVE20")
            .with_promotion_codes(false);

        assert_eq!(request.plan_id, "pro");
        assert_eq!(request.coupon, Some("SAVE20".to_string()));
        assert_eq!(request.allow_promotion_codes, Some(false));
    }

    #[tokio::test]
    async fn test_create_checkout_session_with_coupon() {
        let manager = build_manager(CheckoutConfig::default());
        let org = test_entity("org_coupon_test");
        let request = CheckoutRequest::new("starter", SUCCESS_URL, CANCEL_URL)
            .with_coupon("WELCOME50")
            .with_promotion_codes(false);

        let session = manager.create_checkout_session(&org, request).await.unwrap();

        assert!(session.id.starts_with("cs_test_"));
    }

    #[tokio::test]
    async fn test_checkout_rejects_coupon_with_promotion_codes() {
        let manager = build_manager(CheckoutConfig::default());
        let org = test_entity("org_conflict_test");
        let request = CheckoutRequest::new("starter", SUCCESS_URL, CANCEL_URL)
            .with_coupon("SAVE20")
            .with_promotion_codes(true);

        let outcome = manager.create_checkout_session(&org, request).await;

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("coupon") && message.contains("promotion"));
    }

    #[tokio::test]
    async fn test_checkout_allows_coupon_without_promotion_codes() {
        let manager = build_manager(CheckoutConfig::default());
        let org = test_entity("org_coupon_only");
        let request = CheckoutRequest::new("starter", SUCCESS_URL, CANCEL_URL)
            .with_coupon("SAVE20")
            .with_promotion_codes(false);

        let outcome = manager.create_checkout_session(&org, request).await;

        assert!(outcome.is_ok());
    }

    #[test]
    fn test_url_validation_localhost_http() {
        let strict = CheckoutConfig::new();
        assert!(strict.validate_redirect_url("http://localhost:5173/success").is_err());
        assert!(strict.validate_redirect_url("http://127.0.0.1:3000/cancel").is_err());

        let dev = CheckoutConfig::new().allow_localhost_http(true);
        assert!(dev.validate_redirect_url("http://localhost:5173/success").is_ok());
        assert!(dev.validate_redirect_url("http://127.0.0.1:3000/cancel").is_ok());
        assert!(dev.validate_redirect_url("http://[::1]:8080/success").is_ok());
        assert!(dev.validate_redirect_url("http://example.com/success").is_err());
        assert!(dev.validate_redirect_url("https://example.com/success").is_ok());
    }
}