use http::{Method, Request, StatusCode};
use rustauth_core::db::{Create, DbAdapter, DbValue, Where};
use rustauth_stripe::options::{StripeOptions, StripePlan, SubscriptionOptions};
use rustauth_stripe::stripe;
use rustauth_stripe::stripe_api::{StripeClient, StripeRequest, StripeResponse, StripeTransport};
use serde_json::json;
use std::sync::{Arc, Mutex};
use crate::common::webhook::sign_webhook_payload;
#[derive(Default)]
struct CheckoutWebhookTransport {
requests: Mutex<Vec<StripeRequest>>,
}
impl StripeTransport for CheckoutWebhookTransport {
fn send<'a>(
&'a self,
request: StripeRequest,
) -> rustauth_stripe::stripe_api::StripeTransportFuture<'a> {
let body = match request.path.as_str() {
"/v1/subscriptions/stripe_sub_checkout" => json!({
"id": "stripe_sub_checkout",
"object": "subscription",
"status": "active",
"customer": "cus_123",
"items": {
"data": [{
"id": "si_checkout",
"price": {
"id": "price_pro",
"recurring": { "interval": "month" }
},
"quantity": 1,
"current_period_start": 1700000000,
"current_period_end": 1702592000
}]
}
}),
_ => json!({ "id": "ok" }),
};
if let Err(error) = self
.requests
.lock()
.map(|mut requests| requests.push(request))
{
let message = error.to_string();
return Box::pin(async move {
Err(rustauth_stripe::stripe_api::StripeApiError::Transport(
message,
))
});
}
Box::pin(async move { Ok(StripeResponse { status: 200, body }) })
}
}
#[tokio::test]
async fn completes_without_subscription_id_metadata_when_client_reference_set(
) -> Result<(), Box<dyn std::error::Error>> {
let transport = Arc::new(CheckoutWebhookTransport::default());
let secret = "whsec_test".to_owned();
let options = StripeOptions::new(
StripeClient::with_transport("sk_test", transport),
secret.clone(),
)
.subscription(SubscriptionOptions::enabled(vec![
StripePlan::new("pro").price_id("price_pro")
]));
let plugin = stripe(options).unwrap();
let endpoint = plugin
.endpoints
.iter()
.find(|endpoint| endpoint.path == "/stripe/webhook")
.ok_or("webhook endpoint")?;
let adapter = rustauth_core::db::MemoryAdapter::new();
adapter
.create(
Create::new("subscription")
.data("id", DbValue::String("sub_local".to_owned()))
.data("plan", DbValue::String("pro".to_owned()))
.data("reference_id", DbValue::String("user_1".to_owned()))
.data("status", DbValue::String("incomplete".to_owned()))
.data("stripe_customer_id", DbValue::Null)
.data("stripe_subscription_id", DbValue::Null)
.data("cancel_at_period_end", DbValue::Boolean(false))
.data("seats", DbValue::Number(1))
.data("billing_interval", DbValue::String("month".to_owned()))
.force_allow_id(),
)
.await?;
let adapter_arc: Arc<dyn rustauth_core::db::DbAdapter> = Arc::new(adapter.clone());
let context = rustauth_core::context::create_auth_context_with_adapter(
rustauth_core::options::RustAuthOptions {
secret: Some("secret-a-at-least-32-chars-long!!".to_owned()),
base_url: Some("http://localhost:3000".to_owned()),
..rustauth_core::options::RustAuthOptions::default()
},
adapter_arc,
)?;
let payload = br#"{"id":"evt_checkout_ref","type":"checkout.session.completed","data":{"object":{"id":"cs_ref","mode":"subscription","customer":"cus_123","subscription":"stripe_sub_checkout","client_reference_id":"user_1","metadata":{"userId":"user_1","referenceId":"user_1"}}}}"#;
let timestamp = time::OffsetDateTime::now_utc().unix_timestamp();
let signature = sign_webhook_payload(&secret, payload, timestamp)?;
let request = Request::builder()
.method(Method::POST)
.uri("http://localhost:3000/api/auth/stripe/webhook")
.header("stripe-signature", signature)
.body(payload.to_vec())?;
let response = (endpoint.handler)(&context, request).await?;
assert_eq!(response.status(), StatusCode::OK);
let subscription = adapter
.find_one(
rustauth_core::db::FindOne::new("subscription")
.where_clause(Where::new("id", DbValue::String("sub_local".to_owned()))),
)
.await?
.ok_or("subscription")?;
assert_eq!(
subscription.get("stripe_subscription_id"),
Some(&DbValue::String("stripe_sub_checkout".to_owned()))
);
assert_eq!(
subscription.get("status"),
Some(&DbValue::String("active".to_owned()))
);
Ok(())
}
#[tokio::test]
async fn checkout_completed_ignores_injected_stripe_customer_id_metadata(
) -> Result<(), Box<dyn std::error::Error>> {
let transport = Arc::new(CheckoutWebhookTransport::default());
let secret = "whsec_test".to_owned();
let options = StripeOptions::new(
StripeClient::with_transport("sk_test", transport),
secret.clone(),
)
.subscription(SubscriptionOptions::enabled(vec![
StripePlan::new("pro").price_id("price_pro")
]));
let plugin = stripe(options).unwrap();
let endpoint = plugin
.endpoints
.iter()
.find(|endpoint| endpoint.path == "/stripe/webhook")
.ok_or("webhook endpoint")?;
let adapter = rustauth_core::db::MemoryAdapter::new();
adapter
.create(
Create::new("subscription")
.data("id", DbValue::String("sub_local".to_owned()))
.data("plan", DbValue::String("pro".to_owned()))
.data("reference_id", DbValue::String("user_1".to_owned()))
.data("status", DbValue::String("incomplete".to_owned()))
.data("stripe_customer_id", DbValue::Null)
.data("stripe_subscription_id", DbValue::Null)
.data("cancel_at_period_end", DbValue::Boolean(false))
.data("seats", DbValue::Number(1))
.data("billing_interval", DbValue::String("month".to_owned()))
.force_allow_id(),
)
.await?;
let adapter_arc: Arc<dyn rustauth_core::db::DbAdapter> = Arc::new(adapter.clone());
let context = rustauth_core::context::create_auth_context_with_adapter(
rustauth_core::options::RustAuthOptions {
secret: Some("secret-a-at-least-32-chars-long!!".to_owned()),
base_url: Some("http://localhost:3000".to_owned()),
..rustauth_core::options::RustAuthOptions::default()
},
adapter_arc,
)?;
let payload = br#"{"id":"evt_checkout_cus_spoof","type":"checkout.session.completed","data":{"object":{"id":"cs_spoof","mode":"subscription","customer":"cus_123","subscription":"stripe_sub_checkout","client_reference_id":"user_1","metadata":{"userId":"user_1","referenceId":"user_1","subscriptionId":"sub_local","stripeCustomerId":"cus_victim"}}}}"#;
let timestamp = time::OffsetDateTime::now_utc().unix_timestamp();
let signature = sign_webhook_payload(&secret, payload, timestamp)?;
let request = Request::builder()
.method(Method::POST)
.uri("http://localhost:3000/api/auth/stripe/webhook")
.header("stripe-signature", signature)
.body(payload.to_vec())?;
let response = (endpoint.handler)(&context, request).await?;
assert_eq!(response.status(), StatusCode::OK);
let subscription = adapter
.find_one(
rustauth_core::db::FindOne::new("subscription")
.where_clause(Where::new("id", DbValue::String("sub_local".to_owned()))),
)
.await?
.ok_or("subscription")?;
assert_eq!(
subscription.get("stripe_customer_id"),
Some(&DbValue::String("cus_123".to_owned()))
);
Ok(())
}
#[tokio::test]
async fn no_op_when_neither_metadata_nor_reference() -> Result<(), Box<dyn std::error::Error>> {
let transport = Arc::new(CheckoutWebhookTransport::default());
let secret = "whsec_test".to_owned();
let options = StripeOptions::new(
StripeClient::with_transport("sk_test", transport),
secret.clone(),
)
.subscription(SubscriptionOptions::enabled(vec![
StripePlan::new("pro").price_id("price_pro")
]));
let plugin = stripe(options).unwrap();
let endpoint = plugin
.endpoints
.iter()
.find(|endpoint| endpoint.path == "/stripe/webhook")
.ok_or("webhook endpoint")?;
let adapter = rustauth_core::db::MemoryAdapter::new();
let adapter_arc: Arc<dyn rustauth_core::db::DbAdapter> = Arc::new(adapter);
let context = rustauth_core::context::create_auth_context_with_adapter(
rustauth_core::options::RustAuthOptions {
secret: Some("secret-a-at-least-32-chars-long!!".to_owned()),
base_url: Some("http://localhost:3000".to_owned()),
..rustauth_core::options::RustAuthOptions::default()
},
adapter_arc,
)?;
let payload = br#"{"id":"evt_noop","type":"checkout.session.completed","data":{"object":{"id":"cs_noop","mode":"subscription","customer":"cus_123","subscription":"stripe_sub_checkout","metadata":{}}}}"#;
let timestamp = time::OffsetDateTime::now_utc().unix_timestamp();
let signature = sign_webhook_payload(&secret, payload, timestamp)?;
let request = Request::builder()
.method(Method::POST)
.uri("http://localhost:3000/api/auth/stripe/webhook")
.header("stripe-signature", signature)
.body(payload.to_vec())?;
let response = (endpoint.handler)(&context, request).await?;
assert_eq!(response.status(), StatusCode::OK);
Ok(())
}