Skip to main content

pylon_auth/
stripe.rs

//! Stripe billing — minimal surface area focused on what auth /
//! orgs actually need:
//!   - Create a Stripe customer for a user (or org)
//!   - Create a Checkout Session (subscription or one-time)
//!   - Verify webhook signatures (avoid trusting unauthenticated POSTs)
//!   - Map webhook events to a `BillingEvent` enum the host app can match on
//!
//! Out of scope (apps can call the Stripe API directly):
//!   - Retrieving / updating existing customers
//!   - Invoice / refund / payment-intent management
//!   - Customer portal (a one-line redirect the app handles itself)
//!   - Discounts / promo codes (set in Checkout config)
//!
//! Stripe API docs: <https://docs.stripe.com/api>
//! Webhook signing: <https://docs.stripe.com/webhooks/signatures>
15
16use hmac::{Hmac, Mac};
17use serde::{Deserialize, Serialize};
18use sha2::Sha256;
19
20type HmacSha256 = Hmac<Sha256>;
21
/// Server-side Stripe credentials.
///
/// `Debug` is implemented by hand rather than derived so that `{:?}`
/// logging of a config never prints the API key or the webhook secret.
#[derive(Clone)]
pub struct StripeConfig {
    /// Server-side secret key (`sk_live_…` or `sk_test_…`).
    pub api_key: String,
    /// Webhook signing secret (`whsec_…`) for the configured endpoint.
    /// Apps with multiple webhooks should run a separate verifier per
    /// endpoint with each endpoint's own secret.
    pub webhook_secret: Option<String>,
}

impl std::fmt::Debug for StripeConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never print secret material; only report whether a webhook
        // secret is configured at all.
        f.debug_struct("StripeConfig")
            .field("api_key", &"<redacted>")
            .field(
                "webhook_secret",
                &self.webhook_secret.as_ref().map(|_| "<redacted>"),
            )
            .finish()
    }
}

impl StripeConfig {
    /// Build a config from the environment.
    ///
    /// Reads `PYLON_STRIPE_API_KEY` (required) and
    /// `PYLON_STRIPE_WEBHOOK_SECRET` (optional). Each value is trimmed. A
    /// variable that is unset, not valid Unicode, or blank after trimming
    /// is treated as absent:
    ///
    /// - no usable API key → `None` (Stripe is not configured);
    /// - no usable webhook secret → `webhook_secret: None`.
    ///
    /// Treating a blank webhook secret as absent is a security measure.
    /// An HMAC keyed with the empty string authenticates nothing, because
    /// anyone can compute it.
    pub fn from_env() -> Option<Self> {
        let api_key = Self::env_value("PYLON_STRIPE_API_KEY")?;
        let webhook_secret = Self::env_value("PYLON_STRIPE_WEBHOOK_SECRET");
        Some(Self {
            api_key,
            webhook_secret,
        })
    }

    /// Trimmed value of env var `name`; `None` if unset, non-Unicode, or blank.
    fn env_value(name: &str) -> Option<String> {
        std::env::var(name)
            .ok()
            .map(|v| v.trim().to_string())
            .filter(|v| !v.is_empty())
    }
}
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct StripeCustomer {
45    pub id: String,
46    #[serde(default)]
47    pub email: Option<String>,
48    #[serde(default)]
49    pub name: Option<String>,
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct CheckoutSession {
54    pub id: String,
55    /// Hosted checkout URL — what you 302 the user to.
56    pub url: String,
57    #[serde(default)]
58    pub customer: Option<String>,
59}
60
/// Kind of Checkout Session to create; sent as Stripe's `mode` parameter.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CheckoutMode {
    /// Subscription billing — recurring price.
    Subscription,
    /// One-time payment.
    Payment,
}

impl CheckoutMode {
    /// Wire value for the `mode` form field.
    fn as_str(&self) -> &'static str {
        match *self {
            CheckoutMode::Payment => "payment",
            CheckoutMode::Subscription => "subscription",
        }
    }
}
77
78impl StripeConfig {
79    /// Create or retrieve a Stripe customer for the given email. If
80    /// you already store `stripeCustomerId` on the user/org row, pass
81    /// it instead — `retrieve_or_create` does the lookup-or-create
82    /// dance based on which field is populated.
83    pub fn create_customer(
84        &self,
85        email: &str,
86        name: Option<&str>,
87    ) -> Result<StripeCustomer, String> {
88        let mut body = format!("email={}", url_encode(email));
89        if let Some(n) = name {
90            body.push_str("&name=");
91            body.push_str(&url_encode(n));
92        }
93        self.post("https://api.stripe.com/v1/customers", &body)
94    }
95
96    /// Create a Checkout Session — the standard hosted-payment flow.
97    /// `price_ids` are the Stripe Price ids the customer is buying
98    /// (1 for subscriptions, N for cart-style one-time payments).
99    pub fn create_checkout(
100        &self,
101        customer_id: Option<&str>,
102        price_ids: &[&str],
103        mode: CheckoutMode,
104        success_url: &str,
105        cancel_url: &str,
106    ) -> Result<CheckoutSession, String> {
107        let mut body = format!(
108            "mode={}&success_url={}&cancel_url={}",
109            mode.as_str(),
110            url_encode(success_url),
111            url_encode(cancel_url),
112        );
113        if let Some(cid) = customer_id {
114            body.push_str("&customer=");
115            body.push_str(&url_encode(cid));
116        }
117        for (i, pid) in price_ids.iter().enumerate() {
118            body.push_str(&format!(
119                "&line_items[{i}][price]={}&line_items[{i}][quantity]=1",
120                url_encode(pid)
121            ));
122        }
123        self.post("https://api.stripe.com/v1/checkout/sessions", &body)
124    }
125
126    fn post<T: for<'de> Deserialize<'de>>(&self, url: &str, body: &str) -> Result<T, String> {
127        let agent = ureq::AgentBuilder::new()
128            .timeout_connect(std::time::Duration::from_secs(10))
129            .timeout_read(std::time::Duration::from_secs(10))
130            .user_agent("pylon-auth/0.1")
131            .build();
132        let resp = agent
133            .post(url)
134            .set("Authorization", &format!("Bearer {}", self.api_key))
135            .set("Content-Type", "application/x-www-form-urlencoded")
136            .send_string(body)
137            .map_err(|e| match e {
138                ureq::Error::Status(code, r) => {
139                    let body = r.into_string().unwrap_or_default();
140                    format!("stripe HTTP {code}: {body}")
141                }
142                e => format!("stripe network: {e}"),
143            })?;
144        let txt = resp
145            .into_string()
146            .map_err(|e| format!("stripe body: {e}"))?;
147        serde_json::from_str(&txt).map_err(|e| format!("stripe JSON: {e}"))
148    }
149}
150
151// ---------------------------------------------------------------------------
152// Webhook signature verification + event parsing
153// ---------------------------------------------------------------------------
154
155/// Subset of Stripe webhook events pylon directly supports. Apps
156/// receiving any other event get the raw `event_type` string in
157/// [`BillingEvent::Other`] and can match on it themselves.
158#[derive(Debug, Clone, PartialEq, Eq)]
159pub enum BillingEvent {
160    /// `checkout.session.completed` — customer finished checkout.
161    /// Attach `customer_id` + `subscription_id` to your user/org row.
162    CheckoutCompleted {
163        customer_id: Option<String>,
164        subscription_id: Option<String>,
165        client_reference_id: Option<String>,
166    },
167    /// `customer.subscription.updated` / `created` — subscription
168    /// state changed (renewed, plan changed, paused). Map `status`
169    /// to your app's "is this org allowed to use the paid feature"
170    /// gate.
171    SubscriptionChanged {
172        subscription_id: String,
173        customer_id: String,
174        status: String,
175        current_period_end: u64,
176    },
177    /// `customer.subscription.deleted` — subscription canceled or
178    /// ended. Revoke paid access.
179    SubscriptionDeleted {
180        subscription_id: String,
181        customer_id: String,
182    },
183    /// `invoice.payment_failed` — billing problem; usually pylon
184    /// surfaces this to the org's billing email.
185    PaymentFailed {
186        customer_id: String,
187        invoice_id: String,
188    },
189    /// Any event pylon doesn't model. Carries the raw event type
190    /// + the full JSON body for the app to parse.
191    Other {
192        event_type: String,
193        body: serde_json::Value,
194    },
195}
196
/// Reasons [`verify_webhook`] can reject a webhook delivery.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WebhookError {
    /// `Stripe-Signature` header missing or malformed: no parsable `t=`
    /// timestamp, or no `v1=` signature.
    MissingSignature,
    /// Signed timestamp is more than 5 minutes away from the receiver's
    /// clock, in either direction — replay protection per Stripe's docs.
    StaleTimestamp,
    /// HMAC-SHA256 mismatch — payload was tampered with or the
    /// secret is wrong.
    BadSignature,
    /// The signature checked out, but the body wasn't valid JSON.
    BadJson,
}

impl std::fmt::Display for WebhookError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            // Also covers a header that is present but lacks `t=` / `v1=`.
            Self::MissingSignature => "Stripe-Signature header missing or malformed",
            Self::StaleTimestamp => "webhook timestamp outside ±5min tolerance",
            Self::BadSignature => "webhook signature mismatch",
            Self::BadJson => "webhook body not valid JSON",
        })
    }
}

/// Lets `WebhookError` propagate with `?` into `Box<dyn Error>` and other
/// `std::error::Error`-based error handling.
impl std::error::Error for WebhookError {}
221
222/// Verify a webhook payload + parse it into a `BillingEvent`.
223///
224/// `signature_header` is the raw `Stripe-Signature` header value,
225/// shaped `t=<unix_ts>,v1=<hex_sig>[,v0=<old>]`. We accept any v1
226/// matching the configured secret; v0 is the deprecated scheme and
227/// we ignore it.
228pub fn verify_webhook(
229    secret: &str,
230    body: &[u8],
231    signature_header: &str,
232    now_secs: u64,
233) -> Result<BillingEvent, WebhookError> {
234    let mut t: Option<u64> = None;
235    let mut v1_sigs: Vec<&str> = Vec::new();
236    // P3-6 (codex Wave-4 review): cap v1 sigs at 8 to prevent
237    // header-amplification DoS — an attacker who can submit a
238    // 10MB Stripe-Signature header would otherwise force 10000
239    // constant-time HMAC comparisons.
240    const MAX_V1_SIGS: usize = 8;
241    for kv in signature_header.split(',') {
242        let kv = kv.trim();
243        if let Some(v) = kv.strip_prefix("t=") {
244            t = v.parse().ok();
245        } else if let Some(v) = kv.strip_prefix("v1=") {
246            if v1_sigs.len() < MAX_V1_SIGS {
247                v1_sigs.push(v);
248            }
249        }
250    }
251    let ts = t.ok_or(WebhookError::MissingSignature)?;
252    if v1_sigs.is_empty() {
253        return Err(WebhookError::MissingSignature);
254    }
255    // ±5min tolerance, Stripe's documented default.
256    let diff = if now_secs > ts {
257        now_secs - ts
258    } else {
259        ts - now_secs
260    };
261    if diff > 5 * 60 {
262        return Err(WebhookError::StaleTimestamp);
263    }
264
265    // Signed payload = "<ts>." + body
266    let mut mac =
267        HmacSha256::new_from_slice(secret.as_bytes()).expect("HMAC accepts any key length");
268    mac.update(format!("{ts}.").as_bytes());
269    mac.update(body);
270    let expected = mac.finalize().into_bytes();
271    let expected_hex = bytes_to_hex(&expected);
272
273    let any_match = v1_sigs
274        .iter()
275        .any(|s| crate::constant_time_eq(s.as_bytes(), expected_hex.as_bytes()));
276    if !any_match {
277        return Err(WebhookError::BadSignature);
278    }
279
280    let body_json: serde_json::Value =
281        serde_json::from_slice(body).map_err(|_| WebhookError::BadJson)?;
282    Ok(parse_event(body_json))
283}
284
285fn parse_event(body: serde_json::Value) -> BillingEvent {
286    let event_type = body
287        .get("type")
288        .and_then(|v| v.as_str())
289        .unwrap_or("")
290        .to_string();
291    let object = body.pointer("/data/object").cloned().unwrap_or_default();
292    match event_type.as_str() {
293        "checkout.session.completed" => BillingEvent::CheckoutCompleted {
294            customer_id: object
295                .get("customer")
296                .and_then(|v| v.as_str())
297                .map(String::from),
298            subscription_id: object
299                .get("subscription")
300                .and_then(|v| v.as_str())
301                .map(String::from),
302            client_reference_id: object
303                .get("client_reference_id")
304                .and_then(|v| v.as_str())
305                .map(String::from),
306        },
307        "customer.subscription.updated" | "customer.subscription.created" => {
308            BillingEvent::SubscriptionChanged {
309                subscription_id: object
310                    .get("id")
311                    .and_then(|v| v.as_str())
312                    .unwrap_or("")
313                    .to_string(),
314                customer_id: object
315                    .get("customer")
316                    .and_then(|v| v.as_str())
317                    .unwrap_or("")
318                    .to_string(),
319                status: object
320                    .get("status")
321                    .and_then(|v| v.as_str())
322                    .unwrap_or("")
323                    .to_string(),
324                current_period_end: object
325                    .get("current_period_end")
326                    .and_then(|v| v.as_u64())
327                    .unwrap_or(0),
328            }
329        }
330        "customer.subscription.deleted" => BillingEvent::SubscriptionDeleted {
331            subscription_id: object
332                .get("id")
333                .and_then(|v| v.as_str())
334                .unwrap_or("")
335                .to_string(),
336            customer_id: object
337                .get("customer")
338                .and_then(|v| v.as_str())
339                .unwrap_or("")
340                .to_string(),
341        },
342        "invoice.payment_failed" => BillingEvent::PaymentFailed {
343            customer_id: object
344                .get("customer")
345                .and_then(|v| v.as_str())
346                .unwrap_or("")
347                .to_string(),
348            invoice_id: object
349                .get("id")
350                .and_then(|v| v.as_str())
351                .unwrap_or("")
352                .to_string(),
353        },
354        _ => BillingEvent::Other { event_type, body },
355    }
356}
357
/// Percent-encode `s` for an `application/x-www-form-urlencoded` body.
///
/// - Bytes in the RFC 3986 unreserved set (`A–Z a–z 0–9 - _ . ~`) pass
///   through unchanged.
/// - Every other byte becomes `%XX` with uppercase hex digits. Multi-byte
///   UTF-8 characters are therefore encoded byte by byte, and a space
///   becomes `%20`, never `+`.
fn url_encode(s: &str) -> String {
    const HEX_UPPER: &[u8; 16] = b"0123456789ABCDEF";
    let mut encoded = String::with_capacity(s.len());
    for byte in s.bytes() {
        let unreserved =
            byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX_UPPER[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX_UPPER[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}
370
/// Render `bytes` as lowercase hex, two digits per byte
/// (e.g. `[0xde, 0xad]` → `"dead"`).
fn bytes_to_hex(bytes: &[u8]) -> String {
    const HEX_LOWER: &[u8; 16] = b"0123456789abcdef";
    let mut hex = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        hex.push(char::from(HEX_LOWER[usize::from(byte >> 4)]));
        hex.push(char::from(HEX_LOWER[usize::from(byte & 0x0f)]));
    }
    hex
}
379
#[cfg(test)]
mod tests {
    use super::*;
    use sha2::Sha256;

    /// The `v1` signature Stripe would send for `body` signed at `ts`.
    fn sign(secret: &str, ts: u64, body: &[u8]) -> String {
        let mut mac =
            Hmac::<Sha256>::new_from_slice(secret.as_bytes()).expect("HMAC accepts any key length");
        mac.update(format!("{ts}.").as_bytes());
        mac.update(body);
        bytes_to_hex(&mac.finalize().into_bytes())
    }

    /// A well-formed `Stripe-Signature` header for `body` signed at `ts`.
    fn header_for(secret: &str, ts: u64, body: &[u8]) -> String {
        format!("t={ts},v1={}", sign(secret, ts, body))
    }

    #[test]
    fn verify_webhook_round_trip_checkout_completed() {
        let secret = "whsec_test_secret";
        let body = br#"{
            "type": "checkout.session.completed",
            "data": { "object": {
                "customer": "cus_xyz",
                "subscription": "sub_abc",
                "client_reference_id": "user_123"
            }}
        }"#;
        let ts = 1_700_000_000;
        let header = header_for(secret, ts, body);
        let event = verify_webhook(secret, body, &header, ts).unwrap();
        match event {
            BillingEvent::CheckoutCompleted {
                customer_id,
                subscription_id,
                client_reference_id,
            } => {
                assert_eq!(customer_id.as_deref(), Some("cus_xyz"));
                assert_eq!(subscription_id.as_deref(), Some("sub_abc"));
                assert_eq!(client_reference_id.as_deref(), Some("user_123"));
            }
            other => panic!("expected CheckoutCompleted, got {other:?}"),
        }
    }

    #[test]
    fn verify_webhook_rejects_bad_signature() {
        let body = b"{}";
        let ts = 1_700_000_000;
        let header = format!("t={ts},v1=deadbeefdeadbeef");
        assert_eq!(
            verify_webhook("secret", body, &header, ts),
            Err(WebhookError::BadSignature)
        );
    }

    #[test]
    fn verify_webhook_rejects_stale_timestamp() {
        let secret = "s";
        let body = b"{}";
        let ts = 1_700_000_000;
        let header = header_for(secret, ts, body);
        // 6 minutes later — outside Stripe's ±5min tolerance.
        let now = ts + 6 * 60;
        assert_eq!(
            verify_webhook(secret, body, &header, now),
            Err(WebhookError::StaleTimestamp)
        );
    }

    #[test]
    fn verify_webhook_rejects_future_timestamp() {
        let secret = "s";
        let body = b"{}";
        let ts = 1_700_000_000;
        let header = header_for(secret, ts, body);
        // Receiver clock 6 minutes behind the signing time.
        let now = ts - 6 * 60;
        assert_eq!(
            verify_webhook(secret, body, &header, now),
            Err(WebhookError::StaleTimestamp)
        );
    }

    #[test]
    fn verify_webhook_accepts_exact_tolerance_edge() {
        let secret = "s";
        let body = br#"{"type":"x"}"#;
        let ts = 1_700_000_000;
        let header = header_for(secret, ts, body);
        assert!(verify_webhook(secret, body, &header, ts + 5 * 60).is_ok());
        assert!(verify_webhook(secret, body, &header, ts - 5 * 60).is_ok());
    }

    #[test]
    fn verify_webhook_missing_signature_header() {
        let body = b"{}";
        assert_eq!(
            verify_webhook("s", body, "", 0),
            Err(WebhookError::MissingSignature)
        );
        // Has timestamp but no v1 sig.
        assert_eq!(
            verify_webhook("s", body, "t=100", 100),
            Err(WebhookError::MissingSignature)
        );
    }

    #[test]
    fn verify_webhook_caps_v1_signatures() {
        // The only matching signature is the 9th `v1` entry. Just the first
        // 8 are considered (header-amplification guard), so it is ignored.
        let secret = "s";
        let body = b"{}";
        let ts = 1_700_000_000;
        let decoys = vec!["v1=00"; 8].join(",");
        let header = format!("t={ts},{decoys},v1={}", sign(secret, ts, body));
        assert_eq!(
            verify_webhook(secret, body, &header, ts),
            Err(WebhookError::BadSignature)
        );
    }

    #[test]
    fn verify_webhook_rejects_non_json_after_valid_signature() {
        let secret = "s";
        let body = b"not json";
        let ts = 1_700_000_000;
        let header = header_for(secret, ts, body);
        assert_eq!(
            verify_webhook(secret, body, &header, ts),
            Err(WebhookError::BadJson)
        );
    }

    #[test]
    fn parse_subscription_changed() {
        let body = serde_json::json!({
            "type": "customer.subscription.updated",
            "data": { "object": {
                "id": "sub_xyz",
                "customer": "cus_abc",
                "status": "active",
                "current_period_end": 9_999_999_999u64
            }}
        });
        match parse_event(body) {
            BillingEvent::SubscriptionChanged {
                subscription_id,
                customer_id,
                status,
                current_period_end,
            } => {
                assert_eq!(subscription_id, "sub_xyz");
                assert_eq!(customer_id, "cus_abc");
                assert_eq!(status, "active");
                assert_eq!(current_period_end, 9_999_999_999);
            }
            other => panic!("expected SubscriptionChanged, got {other:?}"),
        }
    }

    #[test]
    fn parse_known_event_with_missing_fields_uses_defaults() {
        let body = serde_json::json!({ "type": "customer.subscription.updated" });
        assert_eq!(
            parse_event(body),
            BillingEvent::SubscriptionChanged {
                subscription_id: String::new(),
                customer_id: String::new(),
                status: String::new(),
                current_period_end: 0,
            }
        );
    }

    #[test]
    fn parse_subscription_deleted_and_payment_failed() {
        let deleted = serde_json::json!({
            "type": "customer.subscription.deleted",
            "data": { "object": { "id": "sub_1", "customer": "cus_1" } }
        });
        assert_eq!(
            parse_event(deleted),
            BillingEvent::SubscriptionDeleted {
                subscription_id: "sub_1".to_string(),
                customer_id: "cus_1".to_string(),
            }
        );

        let failed = serde_json::json!({
            "type": "invoice.payment_failed",
            "data": { "object": { "id": "in_1", "customer": "cus_2" } }
        });
        assert_eq!(
            parse_event(failed),
            BillingEvent::PaymentFailed {
                customer_id: "cus_2".to_string(),
                invoice_id: "in_1".to_string(),
            }
        );
    }

    #[test]
    fn unknown_event_falls_through_to_other() {
        let body = serde_json::json!({"type": "some.weird.event", "data": {}});
        match parse_event(body) {
            BillingEvent::Other { event_type, .. } => {
                assert_eq!(event_type, "some.weird.event");
            }
            other => panic!("expected Other, got {other:?}"),
        }
    }

    #[test]
    fn webhook_accepts_multiple_v1_sigs() {
        // Stripe rotates webhook secrets via "endpoint signing secret
        // rotation". During the rotation window, a single Stripe-Signature
        // header carries v1 sigs for both the old and new secrets. The
        // verifier must accept any matching one.
        let secret = "new_secret";
        let body = br#"{"type":"x"}"#;
        let ts = 1_700_000_000;
        let sig_new = sign(secret, ts, body);
        let header = format!("t={ts},v1=deadbeef,v1={sig_new}");
        assert!(verify_webhook(secret, body, &header, ts).is_ok());
    }

    #[test]
    fn url_encode_escapes_everything_but_unreserved() {
        assert_eq!(url_encode("AZaz09-_.~"), "AZaz09-_.~");
        assert_eq!(url_encode("a b&c=d"), "a%20b%26c%3Dd");
        assert_eq!(url_encode("é"), "%C3%A9");
    }
}