use super::error::BillingError;
use super::storage::BillingStore;
use crate::error::Result;
use url::Url;
/// Coordinates creation of billing-portal sessions.
///
/// Bundles the three collaborators each session request needs: a
/// [`BillingStore`] that is asked for the Stripe customer id of a billable
/// entity, a [`StripePortalClient`] that the session creation is delegated
/// to, and the [`PortalConfig`] used to validate return URLs.
pub struct PortalManager<S: BillingStore, C: StripePortalClient> {
    /// Storage backend queried for the Stripe customer id.
    store: S,
    /// Client that performs the actual session creation.
    client: C,
    /// Return-URL policy and optional portal configuration id.
    config: PortalConfig,
}
impl<S: BillingStore, C: StripePortalClient> PortalManager<S, C> {
#[must_use]
pub fn new(store: S, client: C, config: PortalConfig) -> Self {
Self {
store,
client,
config,
}
}
pub async fn create_portal_session(
&self,
billable_id: &str,
return_url: &str,
) -> Result<PortalSession> {
self.config.validate_return_url(return_url)?;
let customer_id = self
.store
.get_stripe_customer_id(billable_id)
.await?
.ok_or_else(|| {
crate::error::TidewayError::NotFound("No Stripe customer found".to_string())
})?;
let session = self
.client
.create_portal_session(CreatePortalSessionRequest {
customer_id,
return_url: return_url.to_string(),
configuration_id: self.config.configuration_id.clone(),
})
.await?;
Ok(session)
}
pub async fn create_portal_session_with_flow(
&self,
billable_id: &str,
return_url: &str,
flow: PortalFlow,
) -> Result<PortalSession> {
self.config.validate_return_url(return_url)?;
let customer_id = self
.store
.get_stripe_customer_id(billable_id)
.await?
.ok_or_else(|| {
crate::error::TidewayError::NotFound("No Stripe customer found".to_string())
})?;
let session = self
.client
.create_portal_session_with_flow(
CreatePortalSessionRequest {
customer_id,
return_url: return_url.to_string(),
configuration_id: self.config.configuration_id.clone(),
},
flow,
)
.await?;
Ok(session)
}
}
/// Settings that govern portal session creation and return-URL validation.
///
/// The derived `Default` yields no configuration id, an empty domain list
/// and `allow_localhost_http == false`.
#[derive(Debug, Clone, Default)]
pub struct PortalConfig {
    /// Optional portal configuration id copied into every
    /// `CreatePortalSessionRequest`.
    pub configuration_id: Option<String>,
    /// Domains a return URL's host may belong to (the domain itself or any
    /// subdomain). An empty list disables the domain check.
    pub allowed_return_domains: Vec<String>,
    /// When `true`, plain `http` return URLs are accepted for `localhost`,
    /// `127.0.0.1` and `[::1]`; every other host still requires `https`.
    pub allow_localhost_http: bool,
}
impl PortalConfig {
    /// Creates a configuration with no portal configuration id, no domain
    /// restriction and plain-HTTP localhost URLs disallowed.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the portal configuration id forwarded with each session request.
    #[must_use]
    pub fn configuration_id(mut self, id: impl Into<String>) -> Self {
        self.configuration_id = Some(id.into());
        self
    }

    /// Replaces the list of domains a return URL may point at.
    ///
    /// A host is accepted when it equals an entry or is a subdomain of one.
    /// Leaving the list empty disables the domain check entirely.
    #[must_use]
    pub fn allowed_return_domains<I, S>(mut self, domains: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.allowed_return_domains = domains.into_iter().map(Into::into).collect();
        self
    }

    /// Controls whether plain `http` is accepted for localhost return URLs.
    #[must_use]
    pub fn allow_localhost_http(mut self, allow: bool) -> Self {
        self.allow_localhost_http = allow;
        self
    }

    /// Returns `true` for the loopback host spellings accepted over plain HTTP.
    fn is_localhost(host: &str) -> bool {
        matches!(host, "localhost" | "127.0.0.1" | "[::1]" | "::1")
    }

    /// Returns `true` when `host` is `allowed` itself or a subdomain of it.
    ///
    /// The comparison ignores ASCII case: the `url` crate hands back
    /// lowercased hosts for `http`/`https`, so a configured entry such as
    /// `"Example.com"` would otherwise never match. An empty entry matches
    /// nothing; treating it as a suffix would accept every host that ends
    /// in a dot.
    fn host_matches_domain(host: &str, allowed: &str) -> bool {
        // Work on bytes so the split below can never hit a char boundary.
        let (host, allowed) = (host.as_bytes(), allowed.as_bytes());
        if allowed.is_empty() {
            return false;
        }
        if host.eq_ignore_ascii_case(allowed) {
            return true;
        }
        let Some(prefix_len) = host.len().checked_sub(allowed.len()) else {
            return false;
        };
        let (prefix, suffix) = host.split_at(prefix_len);
        // Require a label boundary so "notexample.com" does not match "example.com".
        prefix.last() == Some(&b'.') && suffix.eq_ignore_ascii_case(allowed)
    }

    /// Checks that `url` is an acceptable portal return URL.
    ///
    /// The URL must parse, and must use `https` unless
    /// `allow_localhost_http` is set and the URL is `http` on a loopback
    /// host. If `allowed_return_domains` is non-empty, the host must also
    /// match one of its entries; this applies to localhost URLs as well.
    ///
    /// # Errors
    ///
    /// * `BillingError::InvalidRedirectUrl` if the URL does not parse, uses a
    ///   disallowed scheme, or has no host while a domain list is configured.
    /// * `BillingError::RedirectDomainNotAllowed` if the host matches no
    ///   configured domain.
    pub fn validate_return_url(&self, url: &str) -> Result<()> {
        let parsed = Url::parse(url).map_err(|e| BillingError::InvalidRedirectUrl {
            url: url.to_string(),
            reason: format!("invalid URL: {}", e),
        })?;

        let scheme_ok = match parsed.scheme() {
            "https" => true,
            "http" => {
                self.allow_localhost_http && parsed.host_str().is_some_and(Self::is_localhost)
            }
            _ => false,
        };
        if !scheme_ok {
            return Err(BillingError::InvalidRedirectUrl {
                url: url.to_string(),
                reason: "return URL must use HTTPS".to_string(),
            }
            .into());
        }

        // An empty allowlist means "any domain".
        if self.allowed_return_domains.is_empty() {
            return Ok(());
        }

        let host = parsed
            .host_str()
            .ok_or_else(|| BillingError::InvalidRedirectUrl {
                url: url.to_string(),
                reason: "return URL must have a host".to_string(),
            })?;

        let domain_allowed = self
            .allowed_return_domains
            .iter()
            .any(|allowed| Self::host_matches_domain(host, allowed));
        if !domain_allowed {
            return Err(BillingError::RedirectDomainNotAllowed {
                domain: host.to_string(),
            }
            .into());
        }

        Ok(())
    }
}
/// A portal session as returned by a `StripePortalClient` implementation.
#[derive(Debug, Clone)]
#[must_use]
pub struct PortalSession {
    /// Identifier of the session.
    pub id: String,
    /// URL of the session.
    pub url: String,
}
/// Parameters handed to a `StripePortalClient` when creating a portal session.
#[derive(Debug, Clone)]
pub struct CreatePortalSessionRequest {
    /// Stripe customer id the session is created for.
    pub customer_id: String,
    /// Return URL supplied by the caller of the session request.
    pub return_url: String,
    /// Optional portal configuration id; `None` leaves the choice to the client.
    pub configuration_id: Option<String>,
}
/// Flow selector accepted by the `create_portal_session_with_flow` calls.
#[derive(Debug, Clone)]
pub enum PortalFlow {
    /// Payment-method update flow; carries no extra data.
    PaymentMethodUpdate,
    /// Subscription update flow for the given subscription.
    SubscriptionUpdate {
        /// Id of the subscription to update.
        subscription_id: String,
    },
    /// Subscription cancellation flow for the given subscription.
    SubscriptionCancel {
        /// Id of the subscription to cancel.
        subscription_id: String,
    },
}

impl PortalFlow {
    /// Returns the snake_case name of this flow variant.
    ///
    /// The name depends only on the variant; any `subscription_id` is ignored.
    #[must_use]
    pub fn flow_type(&self) -> &'static str {
        match *self {
            PortalFlow::PaymentMethodUpdate => "payment_method_update",
            PortalFlow::SubscriptionUpdate { subscription_id: _ } => "subscription_update",
            PortalFlow::SubscriptionCancel { subscription_id: _ } => "subscription_cancel",
        }
    }
}
/// Abstraction over the service that creates portal sessions.
///
/// Implementations must be `Send + Sync`.
// NOTE(review): the allow silences the lint about `async fn` in public
// traits; presumably acceptable because the trait is only used through
// generics here. Confirm before relying on `Send` futures.
#[allow(async_fn_in_trait)]
pub trait StripePortalClient: Send + Sync {
    /// Creates a portal session from `request`.
    async fn create_portal_session(
        &self,
        request: CreatePortalSessionRequest,
    ) -> Result<PortalSession>;

    /// Creates a portal session from `request` for the given `flow`.
    async fn create_portal_session_with_flow(
        &self,
        request: CreatePortalSessionRequest,
        flow: PortalFlow,
    ) -> Result<PortalSession>;
}
#[cfg(any(test, feature = "test-billing"))]
pub mod test {
    use super::*;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// In-process stand-in for a portal client that fabricates sessions
    /// with sequential ids and never fails.
    #[derive(Default)]
    pub struct MockStripePortalClient {
        /// Number of sessions handed out so far; supplies the id suffix.
        session_counter: AtomicU64,
    }

    impl MockStripePortalClient {
        /// Creates a mock whose first session id is `bps_test_0`.
        #[must_use]
        pub fn new() -> Self {
            Self::default()
        }

        /// Consumes the next counter value and builds the matching session.
        fn next_session(&self) -> PortalSession {
            let sequence = self.session_counter.fetch_add(1, Ordering::SeqCst);
            let id = format!("bps_test_{}", sequence);
            let url = format!("https://billing.stripe.com/p/session/{}", id);
            PortalSession { id, url }
        }
    }

    impl StripePortalClient for MockStripePortalClient {
        /// Ignores the request and returns the next fabricated session.
        async fn create_portal_session(
            &self,
            _request: CreatePortalSessionRequest,
        ) -> Result<PortalSession> {
            Ok(self.next_session())
        }

        /// Ignores the request and the flow and returns the next fabricated session.
        async fn create_portal_session_with_flow(
            &self,
            _request: CreatePortalSessionRequest,
            _flow: PortalFlow,
        ) -> Result<PortalSession> {
            Ok(self.next_session())
        }
    }
}
#[cfg(test)]
mod tests {
    use super::test::MockStripePortalClient;
    use super::*;
    use crate::billing::storage::test::InMemoryBillingStore;

    /// HTTPS return URL used by most tests.
    const BILLING_URL: &str = "https://example.com/billing";

    type TestManager = PortalManager<InMemoryBillingStore, MockStripePortalClient>;

    /// Builds a manager over a fresh in-memory store and mock client.
    ///
    /// When `customer` is `Some((billable_id, stripe_customer_id))`, that
    /// mapping is stored (as an `"org"` entity) before the manager is created.
    async fn manager_for(customer: Option<(&str, &str)>, config: PortalConfig) -> TestManager {
        let store = InMemoryBillingStore::new();
        if let Some((billable_id, stripe_customer_id)) = customer {
            store
                .set_stripe_customer_id(billable_id, "org", stripe_customer_id)
                .await
                .unwrap();
        }
        PortalManager::new(store, MockStripePortalClient::new(), config)
    }

    #[tokio::test]
    async fn test_create_portal_session() {
        let manager = manager_for(Some(("org_123", "cus_123")), PortalConfig::new()).await;

        let session = manager
            .create_portal_session("org_123", BILLING_URL)
            .await
            .unwrap();

        assert!(session.id.starts_with("bps_test_"));
        assert!(session.url.contains("billing.stripe.com"));
    }

    #[tokio::test]
    async fn test_create_portal_session_no_customer() {
        let manager = manager_for(None, PortalConfig::new()).await;

        let outcome = manager
            .create_portal_session("nonexistent", BILLING_URL)
            .await;

        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_create_portal_session_with_flow() {
        let manager = manager_for(Some(("org_456", "cus_456")), PortalConfig::new()).await;

        let session = manager
            .create_portal_session_with_flow(
                "org_456",
                BILLING_URL,
                PortalFlow::PaymentMethodUpdate,
            )
            .await
            .unwrap();

        assert!(session.id.starts_with("bps_test_"));
    }

    #[tokio::test]
    async fn test_portal_flow_types() {
        let update = PortalFlow::SubscriptionUpdate {
            subscription_id: "sub_123".to_string(),
        };
        let cancel = PortalFlow::SubscriptionCancel {
            subscription_id: "sub_123".to_string(),
        };

        assert_eq!(
            PortalFlow::PaymentMethodUpdate.flow_type(),
            "payment_method_update"
        );
        assert_eq!(update.flow_type(), "subscription_update");
        assert_eq!(cancel.flow_type(), "subscription_cancel");
    }

    #[test]
    fn test_portal_url_validation_https_required() {
        let config = PortalConfig::new();

        assert!(config.validate_return_url(BILLING_URL).is_ok());
        assert!(config
            .validate_return_url("http://example.com/billing")
            .is_err());
    }

    #[test]
    fn test_portal_url_validation_invalid_url() {
        let config = PortalConfig::new();

        for malformed in ["not-a-url", ""] {
            assert!(config.validate_return_url(malformed).is_err());
        }
    }

    #[test]
    fn test_portal_url_validation_allowed_domains() {
        let config = PortalConfig::new().allowed_return_domains(["example.com", "myapp.io"]);

        // Exact matches and a subdomain of an allowed domain all pass.
        for accepted in [
            "https://example.com/billing",
            "https://myapp.io/settings",
            "https://app.example.com/billing",
        ] {
            assert!(config.validate_return_url(accepted).is_ok());
        }

        assert!(config
            .validate_return_url("https://evil.com/billing")
            .is_err());
    }

    #[tokio::test]
    async fn test_portal_rejects_invalid_url() {
        let config = PortalConfig::new().allowed_return_domains(["example.com"]);
        let manager = manager_for(Some(("org_url", "cus_url")), config).await;

        let allowed = manager.create_portal_session("org_url", BILLING_URL).await;
        assert!(allowed.is_ok());

        let rejected = manager
            .create_portal_session("org_url", "https://evil.com/billing")
            .await;
        assert!(rejected.is_err());
    }

    #[test]
    fn test_portal_url_validation_localhost_http() {
        // Without the opt-in, plain HTTP is rejected even for loopback hosts.
        let strict = PortalConfig::new();
        for url in [
            "http://localhost:5173/billing",
            "http://127.0.0.1:3000/billing",
        ] {
            assert!(strict.validate_return_url(url).is_err());
        }

        // With the opt-in, loopback hosts may use HTTP; HTTPS keeps working.
        let relaxed = PortalConfig::new().allow_localhost_http(true);
        for url in [
            "http://localhost:5173/billing",
            "http://127.0.0.1:3000/billing",
            "http://[::1]:8080/billing",
        ] {
            assert!(relaxed.validate_return_url(url).is_ok());
        }
        assert!(relaxed
            .validate_return_url("http://example.com/billing")
            .is_err());
        assert!(relaxed.validate_return_url(BILLING_URL).is_ok());
    }
}