Skip to main content

pylon_auth/
stripe.rs

//! Stripe billing — minimal surface area focused on what auth /
//! orgs actually need:
//!   - Create a Stripe customer for a user (or org)
//!   - Create a Checkout Session (subscription or one-time)
//!   - Verify webhook signatures (avoid trusting unauthenticated POSTs)
//!   - Map webhook events to a `BillingEvent` enum the host app can match on
//!
//! Out of scope (apps can call the Stripe API directly):
//!   - Customer lookup / retrieval (store the customer id on your own row)
//!   - Invoice / refund / payment-intent management
//!   - Customer portal (one-line redirect; app handles)
//!   - Discounts / promo codes (set in Checkout config)
//!
//! Stripe API docs: <https://docs.stripe.com/api>
//! Webhook signing: <https://docs.stripe.com/webhooks/signatures>
15
16use hmac::{Hmac, Mac};
17use serde::{Deserialize, Serialize};
18use sha2::Sha256;
19
20type HmacSha256 = Hmac<Sha256>;
21
/// Credentials used to talk to Stripe.
///
/// `Debug` is implemented by hand (not derived) so that formatting a config,
/// e.g. in a log line or a panic message, never prints the secret key or the
/// webhook signing secret.
#[derive(Clone)]
pub struct StripeConfig {
    /// Server-side secret key (`sk_live_…` or `sk_test_…`).
    pub api_key: String,
    /// Webhook signing secret (`whsec_…`) for the configured endpoint.
    /// Apps with multiple webhooks should run a separate verifier per
    /// endpoint with each endpoint's own secret.
    pub webhook_secret: Option<String>,
}

impl std::fmt::Debug for StripeConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "<redacted>";
        f.debug_struct("StripeConfig")
            .field("api_key", &REDACTED)
            // Still shows whether a webhook secret is configured at all.
            .field(
                "webhook_secret",
                &self.webhook_secret.as_ref().map(|_| REDACTED),
            )
            .finish()
    }
}

impl StripeConfig {
    /// Build a config from the environment.
    ///
    /// Reads `PYLON_STRIPE_API_KEY` (required) and `PYLON_STRIPE_WEBHOOK_SECRET`
    /// (optional). Returns `None` when the API key is unset, not valid Unicode,
    /// or blank. A blank webhook secret is treated as unset: an empty HMAC key
    /// would let anyone compute a "valid" webhook signature.
    pub fn from_env() -> Option<Self> {
        // Env var value, or `None` when unset / non-Unicode / whitespace-only.
        fn non_blank(name: &str) -> Option<String> {
            std::env::var(name).ok().filter(|v| !v.trim().is_empty())
        }
        Some(Self {
            api_key: non_blank("PYLON_STRIPE_API_KEY")?,
            webhook_secret: non_blank("PYLON_STRIPE_WEBHOOK_SECRET"),
        })
    }
}
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct StripeCustomer {
45    pub id: String,
46    #[serde(default)]
47    pub email: Option<String>,
48    #[serde(default)]
49    pub name: Option<String>,
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct CheckoutSession {
54    pub id: String,
55    /// Hosted checkout URL — what you 302 the user to.
56    pub url: String,
57    #[serde(default)]
58    pub customer: Option<String>,
59}
60
/// Which kind of purchase a Checkout Session collects.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CheckoutMode {
    /// Subscription billing — recurring price.
    Subscription,
    /// One-time payment.
    Payment,
}

impl CheckoutMode {
    /// Wire value sent as the `mode` form field when creating a session.
    fn as_str(&self) -> &'static str {
        match *self {
            CheckoutMode::Payment => "payment",
            CheckoutMode::Subscription => "subscription",
        }
    }
}
77
78impl StripeConfig {
79    /// Create or retrieve a Stripe customer for the given email. If
80    /// you already store `stripeCustomerId` on the user/org row, pass
81    /// it instead — `retrieve_or_create` does the lookup-or-create
82    /// dance based on which field is populated.
83    pub fn create_customer(&self, email: &str, name: Option<&str>) -> Result<StripeCustomer, String> {
84        let mut body = format!("email={}", url_encode(email));
85        if let Some(n) = name {
86            body.push_str("&name=");
87            body.push_str(&url_encode(n));
88        }
89        self.post("https://api.stripe.com/v1/customers", &body)
90    }
91
92    /// Create a Checkout Session — the standard hosted-payment flow.
93    /// `price_ids` are the Stripe Price ids the customer is buying
94    /// (1 for subscriptions, N for cart-style one-time payments).
95    pub fn create_checkout(
96        &self,
97        customer_id: Option<&str>,
98        price_ids: &[&str],
99        mode: CheckoutMode,
100        success_url: &str,
101        cancel_url: &str,
102    ) -> Result<CheckoutSession, String> {
103        let mut body = format!(
104            "mode={}&success_url={}&cancel_url={}",
105            mode.as_str(),
106            url_encode(success_url),
107            url_encode(cancel_url),
108        );
109        if let Some(cid) = customer_id {
110            body.push_str("&customer=");
111            body.push_str(&url_encode(cid));
112        }
113        for (i, pid) in price_ids.iter().enumerate() {
114            body.push_str(&format!(
115                "&line_items[{i}][price]={}&line_items[{i}][quantity]=1",
116                url_encode(pid)
117            ));
118        }
119        self.post("https://api.stripe.com/v1/checkout/sessions", &body)
120    }
121
122    fn post<T: for<'de> Deserialize<'de>>(&self, url: &str, body: &str) -> Result<T, String> {
123        let agent = ureq::AgentBuilder::new()
124            .timeout_connect(std::time::Duration::from_secs(10))
125            .timeout_read(std::time::Duration::from_secs(10))
126            .user_agent("pylon-auth/0.1")
127            .build();
128        let resp = agent
129            .post(url)
130            .set("Authorization", &format!("Bearer {}", self.api_key))
131            .set("Content-Type", "application/x-www-form-urlencoded")
132            .send_string(body)
133            .map_err(|e| match e {
134                ureq::Error::Status(code, r) => {
135                    let body = r.into_string().unwrap_or_default();
136                    format!("stripe HTTP {code}: {body}")
137                }
138                e => format!("stripe network: {e}"),
139            })?;
140        let txt = resp
141            .into_string()
142            .map_err(|e| format!("stripe body: {e}"))?;
143        serde_json::from_str(&txt).map_err(|e| format!("stripe JSON: {e}"))
144    }
145}
146
147// ---------------------------------------------------------------------------
148// Webhook signature verification + event parsing
149// ---------------------------------------------------------------------------
150
151/// Subset of Stripe webhook events pylon directly supports. Apps
152/// receiving any other event get the raw `event_type` string in
153/// [`BillingEvent::Other`] and can match on it themselves.
154#[derive(Debug, Clone, PartialEq, Eq)]
155pub enum BillingEvent {
156    /// `checkout.session.completed` — customer finished checkout.
157    /// Attach `customer_id` + `subscription_id` to your user/org row.
158    CheckoutCompleted {
159        customer_id: Option<String>,
160        subscription_id: Option<String>,
161        client_reference_id: Option<String>,
162    },
163    /// `customer.subscription.updated` / `created` — subscription
164    /// state changed (renewed, plan changed, paused). Map `status`
165    /// to your app's "is this org allowed to use the paid feature"
166    /// gate.
167    SubscriptionChanged {
168        subscription_id: String,
169        customer_id: String,
170        status: String,
171        current_period_end: u64,
172    },
173    /// `customer.subscription.deleted` — subscription canceled or
174    /// ended. Revoke paid access.
175    SubscriptionDeleted {
176        subscription_id: String,
177        customer_id: String,
178    },
179    /// `invoice.payment_failed` — billing problem; usually pylon
180    /// surfaces this to the org's billing email.
181    PaymentFailed {
182        customer_id: String,
183        invoice_id: String,
184    },
185    /// Any event pylon doesn't model. Carries the raw event type
186    /// + the full JSON body for the app to parse.
187    Other {
188        event_type: String,
189        body: serde_json::Value,
190    },
191}
192
/// Why [`verify_webhook`] rejected a request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WebhookError {
    /// `Stripe-Signature` header missing or malformed (no parsable `t=`
    /// timestamp, or no `v1=` signature).
    MissingSignature,
    /// Timestamp more than 5 minutes in the past or in the future — replay
    /// protection per Stripe's docs.
    StaleTimestamp,
    /// HMAC-SHA256 mismatch — the payload was tampered with or the secret
    /// is wrong.
    BadSignature,
    /// Body wasn't valid JSON.
    BadJson,
}

impl std::fmt::Display for WebhookError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::MissingSignature => "Stripe-Signature header missing",
            Self::StaleTimestamp => "webhook timestamp outside ±5min tolerance",
            Self::BadSignature => "webhook signature mismatch",
            Self::BadJson => "webhook body not valid JSON",
        })
    }
}

/// Lets `WebhookError` propagate with `?` into `Box<dyn std::error::Error>`
/// and other error types built on the standard `Error` trait.
impl std::error::Error for WebhookError {}
217
218/// Verify a webhook payload + parse it into a `BillingEvent`.
219///
220/// `signature_header` is the raw `Stripe-Signature` header value,
221/// shaped `t=<unix_ts>,v1=<hex_sig>[,v0=<old>]`. We accept any v1
222/// matching the configured secret; v0 is the deprecated scheme and
223/// we ignore it.
224pub fn verify_webhook(
225    secret: &str,
226    body: &[u8],
227    signature_header: &str,
228    now_secs: u64,
229) -> Result<BillingEvent, WebhookError> {
230    let mut t: Option<u64> = None;
231    let mut v1_sigs: Vec<&str> = Vec::new();
232    // P3-6 (codex Wave-4 review): cap v1 sigs at 8 to prevent
233    // header-amplification DoS — an attacker who can submit a
234    // 10MB Stripe-Signature header would otherwise force 10000
235    // constant-time HMAC comparisons.
236    const MAX_V1_SIGS: usize = 8;
237    for kv in signature_header.split(',') {
238        let kv = kv.trim();
239        if let Some(v) = kv.strip_prefix("t=") {
240            t = v.parse().ok();
241        } else if let Some(v) = kv.strip_prefix("v1=") {
242            if v1_sigs.len() < MAX_V1_SIGS {
243                v1_sigs.push(v);
244            }
245        }
246    }
247    let ts = t.ok_or(WebhookError::MissingSignature)?;
248    if v1_sigs.is_empty() {
249        return Err(WebhookError::MissingSignature);
250    }
251    // ±5min tolerance, Stripe's documented default.
252    let diff = if now_secs > ts { now_secs - ts } else { ts - now_secs };
253    if diff > 5 * 60 {
254        return Err(WebhookError::StaleTimestamp);
255    }
256
257    // Signed payload = "<ts>." + body
258    let mut mac = HmacSha256::new_from_slice(secret.as_bytes())
259        .expect("HMAC accepts any key length");
260    mac.update(format!("{ts}.").as_bytes());
261    mac.update(body);
262    let expected = mac.finalize().into_bytes();
263    let expected_hex = bytes_to_hex(&expected);
264
265    let any_match = v1_sigs
266        .iter()
267        .any(|s| crate::constant_time_eq(s.as_bytes(), expected_hex.as_bytes()));
268    if !any_match {
269        return Err(WebhookError::BadSignature);
270    }
271
272    let body_json: serde_json::Value =
273        serde_json::from_slice(body).map_err(|_| WebhookError::BadJson)?;
274    Ok(parse_event(body_json))
275}
276
277fn parse_event(body: serde_json::Value) -> BillingEvent {
278    let event_type = body
279        .get("type")
280        .and_then(|v| v.as_str())
281        .unwrap_or("")
282        .to_string();
283    let object = body.pointer("/data/object").cloned().unwrap_or_default();
284    match event_type.as_str() {
285        "checkout.session.completed" => BillingEvent::CheckoutCompleted {
286            customer_id: object
287                .get("customer")
288                .and_then(|v| v.as_str())
289                .map(String::from),
290            subscription_id: object
291                .get("subscription")
292                .and_then(|v| v.as_str())
293                .map(String::from),
294            client_reference_id: object
295                .get("client_reference_id")
296                .and_then(|v| v.as_str())
297                .map(String::from),
298        },
299        "customer.subscription.updated" | "customer.subscription.created" => {
300            BillingEvent::SubscriptionChanged {
301                subscription_id: object
302                    .get("id")
303                    .and_then(|v| v.as_str())
304                    .unwrap_or("")
305                    .to_string(),
306                customer_id: object
307                    .get("customer")
308                    .and_then(|v| v.as_str())
309                    .unwrap_or("")
310                    .to_string(),
311                status: object
312                    .get("status")
313                    .and_then(|v| v.as_str())
314                    .unwrap_or("")
315                    .to_string(),
316                current_period_end: object
317                    .get("current_period_end")
318                    .and_then(|v| v.as_u64())
319                    .unwrap_or(0),
320            }
321        }
322        "customer.subscription.deleted" => BillingEvent::SubscriptionDeleted {
323            subscription_id: object
324                .get("id")
325                .and_then(|v| v.as_str())
326                .unwrap_or("")
327                .to_string(),
328            customer_id: object
329                .get("customer")
330                .and_then(|v| v.as_str())
331                .unwrap_or("")
332                .to_string(),
333        },
334        "invoice.payment_failed" => BillingEvent::PaymentFailed {
335            customer_id: object
336                .get("customer")
337                .and_then(|v| v.as_str())
338                .unwrap_or("")
339                .to_string(),
340            invoice_id: object
341                .get("id")
342                .and_then(|v| v.as_str())
343                .unwrap_or("")
344                .to_string(),
345        },
346        _ => BillingEvent::Other {
347            event_type,
348            body,
349        },
350    }
351}
352
/// Percent-encode `s` for an `application/x-www-form-urlencoded` body.
///
/// RFC 3986 unreserved bytes (`A-Z a-z 0-9 - _ . ~`) pass through unchanged;
/// every other UTF-8 byte, including space, becomes `%XX` in uppercase hex.
fn url_encode(s: &str) -> String {
    use std::fmt::Write;

    let mut encoded = String::with_capacity(s.len());
    for byte in s.bytes() {
        let unreserved =
            byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            encoded.push(char::from(byte));
        } else {
            // Writing into a `String` cannot fail.
            let _ = write!(encoded, "%{:02X}", byte);
        }
    }
    encoded
}
365
/// Lowercase hex encoding of `bytes`: two characters per byte, no separators.
fn bytes_to_hex(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut hex = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        hex.push(char::from(DIGITS[usize::from(byte >> 4)]));
        hex.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    hex
}
374
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed "current time" shared by the webhook tests.
    const TS: u64 = 1_700_000_000;

    /// The hex `v1` signature Stripe would send for `body` signed at `ts`.
    fn stripe_v1(secret: &str, ts: u64, body: &[u8]) -> String {
        let mut mac = HmacSha256::new_from_slice(secret.as_bytes())
            .expect("HMAC accepts any key length");
        mac.update(format!("{ts}.").as_bytes());
        mac.update(body);
        bytes_to_hex(&mac.finalize().into_bytes())
    }

    /// A well-formed single-signature `Stripe-Signature` header value.
    fn header(ts: u64, sig: &str) -> String {
        format!("t={ts},v1={sig}")
    }

    #[test]
    fn verify_webhook_round_trip_checkout_completed() {
        let secret = "whsec_test_secret";
        let body = br#"{
            "type": "checkout.session.completed",
            "data": { "object": {
                "customer": "cus_xyz",
                "subscription": "sub_abc",
                "client_reference_id": "user_123"
            }}
        }"#;
        let sig = stripe_v1(secret, TS, body);
        let event = verify_webhook(secret, body, &header(TS, &sig), TS)
            .expect("a correctly signed payload must verify");
        assert_eq!(
            event,
            BillingEvent::CheckoutCompleted {
                customer_id: Some("cus_xyz".to_string()),
                subscription_id: Some("sub_abc".to_string()),
                client_reference_id: Some("user_123".to_string()),
            }
        );
    }

    #[test]
    fn verify_webhook_rejects_bad_signature() {
        let outcome = verify_webhook("secret", b"{}", &header(TS, "deadbeefdeadbeef"), TS);
        assert_eq!(outcome, Err(WebhookError::BadSignature));
    }

    #[test]
    fn verify_webhook_rejects_stale_timestamp() {
        let (secret, body) = ("s", b"{}");
        let sig = stripe_v1(secret, TS, body);
        // Six minutes later — outside Stripe's ±5min tolerance.
        let six_minutes_later = TS + 6 * 60;
        assert_eq!(
            verify_webhook(secret, body, &header(TS, &sig), six_minutes_later),
            Err(WebhookError::StaleTimestamp)
        );
    }

    #[test]
    fn verify_webhook_missing_signature_header() {
        // An empty header, and a header with a timestamp but no `v1=` entry.
        for &(hdr, now) in &[("", 0_u64), ("t=100", 100)] {
            assert_eq!(
                verify_webhook("s", b"{}", hdr, now),
                Err(WebhookError::MissingSignature),
                "header {hdr:?}"
            );
        }
    }

    #[test]
    fn parse_subscription_changed() {
        let event = parse_event(serde_json::json!({
            "type": "customer.subscription.updated",
            "data": { "object": {
                "id": "sub_xyz",
                "customer": "cus_abc",
                "status": "active",
                "current_period_end": 9_999_999_999u64
            }}
        }));
        assert_eq!(
            event,
            BillingEvent::SubscriptionChanged {
                subscription_id: "sub_xyz".to_string(),
                customer_id: "cus_abc".to_string(),
                status: "active".to_string(),
                current_period_end: 9_999_999_999,
            }
        );
    }

    #[test]
    fn unknown_event_falls_through_to_other() {
        let event = parse_event(serde_json::json!({"type": "some.weird.event", "data": {}}));
        assert!(
            matches!(
                &event,
                BillingEvent::Other { event_type, .. } if event_type == "some.weird.event"
            ),
            "expected Other, got {event:?}"
        );
    }

    #[test]
    fn webhook_accepts_multiple_v1_sigs() {
        // During endpoint signing-secret rotation Stripe puts one `v1` per
        // active secret into a single Stripe-Signature header; the verifier
        // must accept the header if any of them matches.
        let secret = "new_secret";
        let body = br#"{"type":"x"}"#;
        let rotated = format!("t={TS},v1=deadbeef,v1={}", stripe_v1(secret, TS, body));
        assert!(verify_webhook(secret, body, &rotated, TS).is_ok());
    }
}
504        let sig_new = sign(secret, ts, body);
505        let header = format!("t={ts},v1=deadbeef,v1={sig_new}");
506        assert!(verify_webhook(secret, body, &header, ts).is_ok());
507    }
508}