// rustauth-stripe 0.3.0 — Stripe integration for RustAuth.
//
// Integration tests for the Stripe webhook's subscription and free-trial hooks.
#![allow(clippy::unwrap_used)]

use http::{Method, Request, StatusCode};

#[path = "common/mod.rs"]
mod common;
use rustauth_core::context::create_auth_context_with_adapter;
use rustauth_core::db::{Create, DbAdapter, DbValue, FindOne, MemoryAdapter, Where};
use rustauth_core::options::RustAuthOptions;
use rustauth_stripe::options::{FreeTrialOptions, StripeOptions, StripePlan, SubscriptionOptions};
use rustauth_stripe::stripe;
use rustauth_stripe::stripe_api::{
    StripeClient, StripeRequest, StripeResponse, StripeTransport, StripeTransportFuture,
};
use std::sync::{
    atomic::{AtomicUsize, Ordering},
    Arc,
};
use time::OffsetDateTime;

/// Fake [`StripeTransport`] that can serve exactly one subscription.
///
/// A request whose path is `/v1/subscriptions/stripe_sub_complete` gets a 200
/// response with a canned `trialing` subscription for customer `cus_123` (one
/// licensed monthly `price_pro` item, quantity 2). Any other path gets a 404 with a
/// `{"error": {"message": "not found"}}` body. The HTTP method is never inspected.
struct RetrieveSubscriptionTransport;

impl RetrieveSubscriptionTransport {
    /// The only request path this transport answers successfully.
    const KNOWN_SUBSCRIPTION_PATH: &'static str = "/v1/subscriptions/stripe_sub_complete";

    /// Canned Stripe subscription object served for [`Self::KNOWN_SUBSCRIPTION_PATH`].
    fn known_subscription_body() -> serde_json::Value {
        serde_json::json!({
            "id": "stripe_sub_complete",
            "object": "subscription",
            "customer": "cus_123",
            "status": "trialing",
            "cancel_at_period_end": false,
            "trial_start": 1700000000,
            "trial_end": 1700604800,
            "items": {
                "data": [{
                    "id": "si_complete",
                    "price": {
                        "id": "price_pro",
                        "recurring": {
                            "interval": "month",
                            "usage_type": "licensed"
                        }
                    },
                    "quantity": 2,
                    "current_period_start": 1700000000,
                    "current_period_end": 1702592000
                }]
            }
        })
    }

    /// Error body served for every other path.
    fn not_found_body() -> serde_json::Value {
        serde_json::json!({
            "error": { "message": "not found" }
        })
    }
}

impl StripeTransport for RetrieveSubscriptionTransport {
    fn send<'a>(&'a self, request: StripeRequest) -> StripeTransportFuture<'a> {
        Box::pin(async move {
            let (status, body) = if request.path == Self::KNOWN_SUBSCRIPTION_PATH {
                (200, Self::known_subscription_body())
            } else {
                (404, Self::not_found_body())
            };
            Ok(StripeResponse { status, body })
        })
    }
}

/// Delivers a signed `checkout.session.completed` event for an existing local
/// subscription (`sub_local`, status `incomplete`).
///
/// The webhook must still answer 200 OK even though both the plan's
/// `on_trial_start` hook and `on_subscription_complete` return errors. Each hook
/// must have been invoked exactly once.
#[tokio::test]
async fn checkout_completed_hooks_run_without_failing_webhook(
) -> Result<(), Box<dyn std::error::Error>> {
    const CHECKOUT_COMPLETED: &[u8] = br#"{"id":"evt_complete","type":"checkout.session.completed","data":{"object":{"id":"cs_complete","mode":"subscription","customer":"cus_123","subscription":"stripe_sub_complete","client_reference_id":"user_1","metadata":{"userId":"user_1","referenceId":"user_1","subscriptionId":"sub_local"}}}}"#;

    let trial_start_hits = Arc::new(AtomicUsize::new(0));
    let complete_hits = Arc::new(AtomicUsize::new(0));

    let trial_start_counter = Arc::clone(&trial_start_hits);
    let pro_trial = FreeTrialOptions::new(14).on_trial_start(move |subscription| {
        let counter = Arc::clone(&trial_start_counter);
        Box::pin(async move {
            assert_eq!(subscription.id, "sub_local");
            counter.fetch_add(1, Ordering::SeqCst);
            // Deliberately fail: the webhook response must not depend on hooks.
            Err(rustauth_core::error::RustAuthError::Api(
                "trial hook failed".to_owned(),
            ))
        })
    });
    let pro_plan = StripePlan::new("pro")
        .price_id("price_pro")
        .free_trial(pro_trial);

    let complete_counter = Arc::clone(&complete_hits);
    let subscriptions =
        SubscriptionOptions::enabled(vec![pro_plan]).on_subscription_complete(move |input| {
            let counter = Arc::clone(&complete_counter);
            Box::pin(async move {
                assert_eq!(input.event.event_type, "checkout.session.completed");
                assert_eq!(input.subscription.id, "sub_local");
                let remote_id = input.stripe_subscription.as_ref().map(|sub| sub.id.as_str());
                assert_eq!(remote_id, Some("stripe_sub_complete"));
                let plan_name = input.plan.as_ref().map(|plan| plan.name());
                assert_eq!(plan_name, Some("pro"));
                counter.fetch_add(1, Ordering::SeqCst);
                Err(rustauth_core::error::RustAuthError::Api(
                    "complete hook failed".to_owned(),
                ))
            })
        });

    // The checkout event only names the subscription, so the plugin fetches it
    // through this fake transport.
    let client =
        StripeClient::with_transport("sk_test", Arc::new(RetrieveSubscriptionTransport));
    let options = StripeOptions::new(client, "whsec_test").subscription(subscriptions);
    let plugin = stripe(options).unwrap();

    let (context, adapter) = context_with_user_customer(plugin).await?;
    create_local_subscription(&adapter, "sub_local", "stripe_sub_complete", "incomplete").await?;
    let endpoint = stripe_webhook_endpoint(&context)?;

    let request = signed_webhook_request(CHECKOUT_COMPLETED)?;
    let response = (endpoint.handler)(&context, request).await?;

    assert_eq!(response.status(), StatusCode::OK);
    assert_eq!(trial_start_hits.load(Ordering::SeqCst), 1);
    assert_eq!(complete_hits.load(Ordering::SeqCst), 1);
    Ok(())
}

/// Delivers a signed `customer.subscription.created` event for a Stripe
/// subscription that has no local row yet and checks three things:
/// - the webhook answers 200 OK although `on_subscription_created` returns an error;
/// - the hook ran exactly once with the expected event, reference and plan;
/// - a local `subscription` row linked to `stripe_sub_created` was persisted.
#[tokio::test]
async fn subscription_created_hook_runs_without_failing_webhook(
) -> Result<(), Box<dyn std::error::Error>> {
    const SUBSCRIPTION_CREATED: &[u8] = br#"{"id":"evt_created","type":"customer.subscription.created","data":{"object":{"id":"stripe_sub_created","customer":"cus_123","status":"trialing","metadata":{},"cancel_at_period_end":false,"trial_start":1700000000,"trial_end":1700604800,"items":{"data":[{"id":"si_created","price":{"id":"price_pro","recurring":{"interval":"month","usage_type":"licensed"}},"quantity":3,"current_period_start":1700000000,"current_period_end":1702592000}]}}}}"#;

    let created_hits = Arc::new(AtomicUsize::new(0));
    let created_counter = Arc::clone(&created_hits);

    let pro_plan = StripePlan::new("pro").price_id("price_pro");
    let subscriptions =
        SubscriptionOptions::enabled(vec![pro_plan]).on_subscription_created(move |input| {
            let counter = Arc::clone(&created_counter);
            Box::pin(async move {
                assert_eq!(input.event.event_type, "customer.subscription.created");
                assert_eq!(input.subscription.reference_id, "user_1");
                let local_link = input.subscription.stripe_subscription_id.as_deref();
                assert_eq!(local_link, Some("stripe_sub_created"));
                let remote_id = input.stripe_subscription.as_ref().map(|sub| sub.id.as_str());
                assert_eq!(remote_id, Some("stripe_sub_created"));
                let plan_name = input.plan.as_ref().map(|plan| plan.name());
                assert_eq!(plan_name, Some("pro"));
                counter.fetch_add(1, Ordering::SeqCst);
                // Deliberately fail: the webhook response must not depend on hooks.
                Err(rustauth_core::error::RustAuthError::Api(
                    "hook failed".to_owned(),
                ))
            })
        });
    let options = StripeOptions::new(StripeClient::new("sk_test"), "whsec_test")
        .subscription(subscriptions);
    let plugin = stripe(options).unwrap();

    let (context, adapter) = context_with_user_customer(plugin).await?;
    let endpoint = stripe_webhook_endpoint(&context)?;

    let request = signed_webhook_request(SUBSCRIPTION_CREATED)?;
    let response = (endpoint.handler)(&context, request).await?;

    assert_eq!(response.status(), StatusCode::OK);
    assert_eq!(created_hits.load(Ordering::SeqCst), 1);

    let by_stripe_id = Where::new(
        "stripe_subscription_id",
        DbValue::String("stripe_sub_created".to_owned()),
    );
    let stored = adapter
        .find_one(FindOne::new("subscription").where_clause(by_stripe_id))
        .await?;
    assert!(stored.is_some());
    Ok(())
}

/// Delivers a signed `customer.subscription.deleted` event for an active local
/// subscription (`sub_local`).
///
/// The webhook must answer 200 OK although `on_subscription_deleted` returns an
/// error, and the hook must have run exactly once.
#[tokio::test]
async fn subscription_deleted_hook_runs_without_failing_webhook(
) -> Result<(), Box<dyn std::error::Error>> {
    const SUBSCRIPTION_DELETED: &[u8] = br#"{"id":"evt_deleted","type":"customer.subscription.deleted","data":{"object":{"id":"stripe_sub_deleted","customer":"cus_123","status":"canceled","metadata":{},"cancel_at_period_end":false,"canceled_at":1700100000,"ended_at":1700200000,"items":{"data":[{"id":"si_deleted","price":{"id":"price_pro","recurring":{"interval":"month","usage_type":"licensed"}},"quantity":1,"current_period_start":1700000000,"current_period_end":1702592000}]}}}}"#;

    let deleted_hits = Arc::new(AtomicUsize::new(0));
    let deleted_counter = Arc::clone(&deleted_hits);

    let pro_plan = StripePlan::new("pro").price_id("price_pro");
    let subscriptions =
        SubscriptionOptions::enabled(vec![pro_plan]).on_subscription_deleted(move |input| {
            let counter = Arc::clone(&deleted_counter);
            Box::pin(async move {
                assert_eq!(input.event.event_type, "customer.subscription.deleted");
                assert_eq!(input.subscription.id, "sub_local");
                let remote_id = input.stripe_subscription.as_ref().map(|sub| sub.id.as_str());
                assert_eq!(remote_id, Some("stripe_sub_deleted"));
                counter.fetch_add(1, Ordering::SeqCst);
                // Deliberately fail: the webhook response must not depend on hooks.
                Err(rustauth_core::error::RustAuthError::Api(
                    "hook failed".to_owned(),
                ))
            })
        });
    let options = StripeOptions::new(StripeClient::new("sk_test"), "whsec_test")
        .subscription(subscriptions);
    let plugin = stripe(options).unwrap();

    let (context, adapter) = context_with_user_customer(plugin).await?;
    create_local_subscription(&adapter, "sub_local", "stripe_sub_deleted", "active").await?;
    let endpoint = stripe_webhook_endpoint(&context)?;

    let request = signed_webhook_request(SUBSCRIPTION_DELETED)?;
    let response = (endpoint.handler)(&context, request).await?;

    assert_eq!(response.status(), StatusCode::OK);
    assert_eq!(deleted_hits.load(Ordering::SeqCst), 1);
    Ok(())
}

/// Sends the same `customer.subscription.updated` event twice. The event carries
/// `cancel_at_period_end: true` and a `cancellation_requested` reason, and targets
/// a local subscription seeded with `cancel_at_period_end = false`.
///
/// `on_subscription_cancel` must fire on the first delivery only. Every delivery
/// must be answered with 200 OK.
#[tokio::test]
async fn subscription_cancel_hook_runs_only_for_new_pending_cancel(
) -> Result<(), Box<dyn std::error::Error>> {
    const PENDING_CANCEL_UPDATE: &[u8] = br#"{"id":"evt_cancel","type":"customer.subscription.updated","data":{"object":{"id":"stripe_sub_cancel","customer":"cus_123","status":"active","metadata":{},"cancel_at_period_end":true,"cancellation_details":{"reason":"cancellation_requested"},"items":{"data":[{"id":"si_cancel","price":{"id":"price_pro","recurring":{"interval":"month","usage_type":"licensed"}},"quantity":1,"current_period_start":1700000000,"current_period_end":1702592000}]}}}}"#;

    let cancel_hits = Arc::new(AtomicUsize::new(0));
    let cancel_counter = Arc::clone(&cancel_hits);

    let pro_plan = StripePlan::new("pro").price_id("price_pro");
    let subscriptions =
        SubscriptionOptions::enabled(vec![pro_plan]).on_subscription_cancel(move |input| {
            let counter = Arc::clone(&cancel_counter);
            Box::pin(async move {
                assert_eq!(input.subscription.id, "sub_local");
                let reason = input
                    .cancellation_details
                    .as_ref()
                    .and_then(|details| details.get("reason"))
                    .and_then(serde_json::Value::as_str);
                assert_eq!(reason, Some("cancellation_requested"));
                counter.fetch_add(1, Ordering::SeqCst);
                Ok(())
            })
        });
    let options = StripeOptions::new(StripeClient::new("sk_test"), "whsec_test")
        .subscription(subscriptions);
    let plugin = stripe(options).unwrap();

    let (context, adapter) = context_with_user_customer(plugin).await?;
    create_local_subscription(&adapter, "sub_local", "stripe_sub_cancel", "active").await?;
    let endpoint = stripe_webhook_endpoint(&context)?;

    // First delivery: the pending cancellation is new, so the hook fires once.
    // Second delivery: an identical replay, so the count must stay at 1.
    for _delivery in 0..2 {
        let request = signed_webhook_request(PENDING_CANCEL_UPDATE)?;
        let response = (endpoint.handler)(&context, request).await?;
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(cancel_hits.load(Ordering::SeqCst), 1);
    }
    Ok(())
}

/// Drives a `trialing` local subscription through two `customer.subscription.updated`
/// events. Each free-trial hook must fire only for its own transition:
///
/// 1. `trialing -> active` must call `on_trial_end` once and must not call
///    `on_trial_expired`.
/// 2. After the row is reset to `trialing`, `trialing -> incomplete_expired` must
///    call `on_trial_expired` once and must not call `on_trial_end` again.
#[tokio::test]
async fn trial_transition_hooks_run_for_trial_end_and_expiration(
) -> Result<(), Box<dyn std::error::Error>> {
    let trial_end_calls = Arc::new(AtomicUsize::new(0));
    let trial_expired_calls = Arc::new(AtomicUsize::new(0));
    let trial_end_for_options = Arc::clone(&trial_end_calls);
    let trial_expired_for_options = Arc::clone(&trial_expired_calls);
    let plugin = stripe(
        StripeOptions::new(StripeClient::new("sk_test"), "whsec_test").subscription(
            SubscriptionOptions::enabled(vec![StripePlan::new("pro")
                .price_id("price_pro")
                .free_trial(
                    FreeTrialOptions::new(14)
                        .on_trial_end(move |subscription, _| {
                            let trial_end_calls = Arc::clone(&trial_end_for_options);
                            Box::pin(async move {
                                assert_eq!(subscription.id, "sub_local");
                                trial_end_calls.fetch_add(1, Ordering::SeqCst);
                                Ok(())
                            })
                        })
                        .on_trial_expired(move |subscription, _| {
                            let trial_expired_calls = Arc::clone(&trial_expired_for_options);
                            Box::pin(async move {
                                assert_eq!(subscription.id, "sub_local");
                                trial_expired_calls.fetch_add(1, Ordering::SeqCst);
                                Ok(())
                            })
                        }),
                )]),
        ),
    )
    .unwrap();
    let (context, adapter) = context_with_user_customer(plugin).await?;
    create_local_subscription(&adapter, "sub_local", "stripe_sub_trial", "trialing").await?;
    let endpoint = stripe_webhook_endpoint(&context)?;
    let ended = br#"{"id":"evt_trial_end","type":"customer.subscription.updated","data":{"object":{"id":"stripe_sub_trial","customer":"cus_123","status":"active","metadata":{},"cancel_at_period_end":false,"items":{"data":[{"id":"si_trial","price":{"id":"price_pro","recurring":{"interval":"month","usage_type":"licensed"}},"quantity":1,"current_period_start":1700000000,"current_period_end":1702592000}]}}}}"#;

    let response = (endpoint.handler)(&context, signed_webhook_request(ended)?).await?;

    assert_eq!(response.status(), StatusCode::OK);
    assert_eq!(trial_end_calls.load(Ordering::SeqCst), 1);
    assert_eq!(
        trial_expired_calls.load(Ordering::SeqCst),
        0,
        "trialing -> active must not fire on_trial_expired"
    );

    // Put the row back into `trialing` so the next event is again a transition
    // out of a trial. The first event presumably synced the row to `active`.
    rustauth_core::db::DbAdapter::update(
        &adapter,
        rustauth_core::db::Update::new("subscription")
            .where_clause(Where::new("id", DbValue::String("sub_local".to_owned())))
            .data("status", DbValue::String("trialing".to_owned())),
    )
    .await?;
    let expired = br#"{"id":"evt_trial_expired","type":"customer.subscription.updated","data":{"object":{"id":"stripe_sub_trial","customer":"cus_123","status":"incomplete_expired","metadata":{},"cancel_at_period_end":false,"items":{"data":[{"id":"si_trial","price":{"id":"price_pro","recurring":{"interval":"month","usage_type":"licensed"}},"quantity":1,"current_period_start":1700000000,"current_period_end":1702592000}]}}}}"#;
    let response = (endpoint.handler)(&context, signed_webhook_request(expired)?).await?;
    assert_eq!(response.status(), StatusCode::OK);
    assert_eq!(trial_expired_calls.load(Ordering::SeqCst), 1);
    assert_eq!(
        trial_end_calls.load(Ordering::SeqCst),
        1,
        "trialing -> incomplete_expired must not fire on_trial_end again"
    );
    Ok(())
}

/// Builds an auth context backed by a fresh [`MemoryAdapter`] with `plugin` in its
/// options.
///
/// The adapter already holds user `user_1` (`ada@example.com`, email verified),
/// linked to Stripe customer `cus_123`. The second tuple element is the adapter
/// whose clone was handed to the context. Tests seed and query rows through it,
/// which assumes clones share storage.
///
/// # Errors
/// Fails if inserting the user or creating the auth context fails.
async fn context_with_user_customer(
    plugin: rustauth_core::plugin::AuthPlugin,
) -> Result<(rustauth_core::context::AuthContext, MemoryAdapter), Box<dyn std::error::Error>> {
    let adapter = MemoryAdapter::new();
    let seeded_at = OffsetDateTime::now_utc();
    let text = |value: &str| DbValue::String(value.to_owned());

    let user = Create::new("user")
        .data("id", text("user_1"))
        .data("name", text("Ada Lovelace"))
        .data("email", text("ada@example.com"))
        .data("email_verified", DbValue::Boolean(true))
        .data("image", DbValue::Null)
        .data("created_at", DbValue::Timestamp(seeded_at))
        .data("updated_at", DbValue::Timestamp(seeded_at))
        .data("stripe_customer_id", text("cus_123"))
        .force_allow_id();
    adapter.create(user).await?;

    let options = RustAuthOptions {
        secret: Some("secret-a-at-least-32-chars-long!!".to_owned()),
        plugins: vec![plugin],
        ..RustAuthOptions::default()
    };
    let shared_adapter: Arc<dyn DbAdapter> = Arc::new(adapter.clone());
    let context = create_auth_context_with_adapter(options, shared_adapter)?;
    Ok((context, adapter))
}

/// Returns the `/stripe/webhook` endpoint registered by the first plugin whose id is
/// `stripe`.
///
/// # Errors
/// Returns `"stripe webhook endpoint missing"` if no `stripe` plugin is
/// registered, or if that plugin exposes no `/stripe/webhook` endpoint.
fn stripe_webhook_endpoint(
    context: &rustauth_core::context::AuthContext,
) -> Result<&rustauth_core::api::AsyncAuthEndpoint, Box<dyn std::error::Error>> {
    let stripe_plugin = context.plugins.iter().find(|plugin| plugin.id == "stripe");
    let webhook = stripe_plugin.and_then(|plugin| {
        plugin
            .endpoints
            .iter()
            .find(|endpoint| endpoint.path == "/stripe/webhook")
    });
    match webhook {
        Some(endpoint) => Ok(endpoint),
        None => Err("stripe webhook endpoint missing".into()),
    }
}

/// Inserts a local `pro` subscription row owned by `user_1` and Stripe customer
/// `cus_123`.
///
/// `id` becomes the row's primary key, `stripe_subscription_id` links it to the
/// Stripe subscription, and `status` is stored verbatim. The row starts with no
/// pending cancellation (`cancel_at_period_end = false`, `cancel_at = NULL`).
///
/// # Errors
/// Propagates any error from the adapter's `create`.
async fn create_local_subscription(
    adapter: &MemoryAdapter,
    id: &str,
    stripe_subscription_id: &str,
    status: &str,
) -> Result<(), rustauth_core::error::RustAuthError> {
    let text = |value: &str| DbValue::String(value.to_owned());
    let row = Create::new("subscription")
        .data("id", text(id))
        .data("plan", text("pro"))
        .data("reference_id", text("user_1"))
        .data("stripe_customer_id", text("cus_123"))
        .data("stripe_subscription_id", text(stripe_subscription_id))
        .data("status", text(status))
        .data("cancel_at_period_end", DbValue::Boolean(false))
        .data("cancel_at", DbValue::Null)
        .force_allow_id();
    adapter.create(row).await?;
    Ok(())
}

fn signed_webhook_request(payload: &[u8]) -> Result<Request<Vec<u8>>, Box<dyn std::error::Error>> {
    let timestamp = time::OffsetDateTime::now_utc().unix_timestamp();
    let signature = common::webhook::sign_webhook_payload("whsec_test", payload, timestamp)?;
    Ok(Request::builder()
        .method(Method::POST)
        .uri("http://localhost:3000/api/auth/stripe/webhook")
        .header("stripe-signature", signature)
        .body(payload.to_vec())?)
}