Skip to main content

pylon_plugin/builtin/
stripe.rs

1//! Stripe billing primitives.
2//!
3//! This plugin is intentionally small. It does NOT make HTTP calls to Stripe
4//! — apps drive the API from TypeScript functions where they already have a
5//! networking story. What this module *does* provide is the security-critical
6//! and easy-to-mess-up bits:
7//!
//! - **Webhook signature verification** (`verify_signature`) — your endpoint
//!   must reject any webhook whose `Stripe-Signature` header does not
//!   validate against your endpoint secret; otherwise anyone can post forged
//!   events to it. Getting the HMAC + timestamp comparison right is subtle
//!   (constant-time compare, replay window). This implementation follows
//!   Stripe's published signature-verification algorithm.
13//! - **Event payload typing** (`StripeEvent`) — a tiny shape over what arrives
14//!   from a webhook, so app code can match on `event.type` without re-parsing
15//!   raw JSON.
16//! - **Customer lookup state** (`StripeCustomerStore`) — an optional in-memory
17//!   map from app user id → Stripe customer id, useful in dev. Production
18//!   apps store this in their own user table.
19//!
//! See <https://stripe.com/docs/webhooks/signatures> for the algorithm spec.
21
22use std::collections::HashMap;
23use std::sync::Mutex;
24use std::time::{SystemTime, UNIX_EPOCH};
25
26use hmac::{Hmac, Mac};
27use serde::{Deserialize, Serialize};
28use sha2::Sha256;
29
30use crate::Plugin;
31
/// Default replay window for webhook timestamps, in seconds.
///
/// Five minutes, Stripe's default tolerance; `StripePlugin::new` uses it.
const DEFAULT_TOLERANCE_SECS: u64 = 5 * 60;
34
/// Reasons a Stripe webhook signature check can fail.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SignatureError {
    /// No `t=` timestamp could be read from the `Stripe-Signature` header.
    MissingTimestamp,
    /// The `Stripe-Signature` header contains no `v1=` signature.
    MissingSignature,
    /// The signed timestamp is older than the allowed tolerance window.
    Replayed,
    /// No supplied signature matched the expected HMAC (or the verified body
    /// could not be used).
    InvalidSignature,
    /// The `Stripe-Signature` header is malformed (e.g. an entry that is not
    /// a `key=value` pair).
    BadHeaderFormat,
}

impl std::fmt::Display for SignatureError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive on purpose: adding a variant must force a message here.
        let msg = match self {
            SignatureError::MissingTimestamp => "Stripe-Signature header has no timestamp",
            SignatureError::MissingSignature => "Stripe-Signature header has no v1 signature",
            SignatureError::Replayed => "webhook timestamp is outside the tolerance window",
            SignatureError::InvalidSignature => "webhook signature is invalid",
            SignatureError::BadHeaderFormat => "Stripe-Signature header is malformed",
        };
        f.write_str(msg)
    }
}

// Lets callers propagate with `?` into `Box<dyn Error>` / `anyhow::Error`.
impl std::error::Error for SignatureError {}
43
44/// Verify a Stripe webhook signature.
45///
46/// `header` is the raw `Stripe-Signature` header value. `payload` is the raw
47/// request body bytes (NOT a re-serialized JSON value — Stripe signs the
48/// exact bytes they sent). `secret` is the endpoint signing secret from your
49/// Stripe dashboard (`whsec_...`).
50///
51/// `now_unix_secs` is injected so tests can pin the clock; production callers
52/// pass `current_unix_secs()`.
53pub fn verify_signature(
54    header: &str,
55    payload: &[u8],
56    secret: &str,
57    now_unix_secs: u64,
58    tolerance_secs: u64,
59) -> Result<(), SignatureError> {
60    let mut timestamp: Option<u64> = None;
61    let mut sigs: Vec<&str> = Vec::new();
62
63    for part in header.split(',') {
64        let mut kv = part.splitn(2, '=');
65        let key = kv.next().unwrap_or("").trim();
66        let val = kv.next().ok_or(SignatureError::BadHeaderFormat)?.trim();
67        match key {
68            "t" => timestamp = val.parse().ok(),
69            "v1" => sigs.push(val),
70            _ => {} // older / future schemes ignored
71        }
72    }
73
74    let ts = timestamp.ok_or(SignatureError::MissingTimestamp)?;
75    if sigs.is_empty() {
76        return Err(SignatureError::MissingSignature);
77    }
78    if now_unix_secs.saturating_sub(ts) > tolerance_secs {
79        return Err(SignatureError::Replayed);
80    }
81
82    let signed_payload = format!("{ts}.");
83    let mut mac = Hmac::<Sha256>::new_from_slice(secret.as_bytes())
84        .map_err(|_| SignatureError::InvalidSignature)?;
85    mac.update(signed_payload.as_bytes());
86    mac.update(payload);
87    let expected = mac.finalize().into_bytes();
88    let expected_hex = hex_encode(&expected);
89
90    // Constant-time compare against any matching v1 signature.
91    if sigs
92        .iter()
93        .any(|s| ct_eq(s.as_bytes(), expected_hex.as_bytes()))
94    {
95        Ok(())
96    } else {
97        Err(SignatureError::InvalidSignature)
98    }
99}
100
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// This is the `now_unix_secs` value production callers pass to
/// [`verify_signature`]. If the system clock reads earlier than the epoch the
/// result is `0`. Note that a `now` of `0` never trips the replay check,
/// because `0.saturating_sub(ts)` is `0`.
pub fn current_unix_secs() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        // Clock set before 1970: fall back to 0 rather than panicking.
        Err(_) => 0,
    }
}
107
108/// Minimal projection of a Stripe webhook event.
109///
110/// Apps usually want `event_type` to dispatch on, then `data` for the actual
111/// object payload. Skipping the full Stripe schema keeps this plugin compile-
112/// independent of any one Stripe API version.
113#[derive(Debug, Clone, Serialize, Deserialize)]
114pub struct StripeEvent {
115    pub id: String,
116    #[serde(rename = "type")]
117    pub event_type: String,
118    pub created: u64,
119    pub data: serde_json::Value,
120}
121
122impl StripeEvent {
123    pub fn from_payload(bytes: &[u8]) -> Result<Self, serde_json::Error> {
124        serde_json::from_slice(bytes)
125    }
126
127    pub fn object_id(&self) -> Option<&str> {
128        self.data
129            .get("object")
130            .and_then(|o| o.get("id"))
131            .and_then(|v| v.as_str())
132    }
133}
134
/// In-memory mapping from app user id → Stripe customer id.
///
/// For dev / tests / single-process deployments. Production apps store this
/// on the User entity directly so it survives restarts.
///
/// The map is guarded by a `Mutex` and every method takes `&self`, so one
/// store can be shared between threads.
#[derive(Debug)]
pub struct StripeCustomerStore {
    map: Mutex<HashMap<String, String>>,
}

impl StripeCustomerStore {
    /// Create an empty store.
    pub fn new() -> Self {
        Self {
            map: Mutex::new(HashMap::new()),
        }
    }

    /// Link `user_id` to `stripe_customer_id`, replacing any previous link.
    pub fn link(&self, user_id: &str, stripe_customer_id: &str) {
        self.entries()
            .insert(user_id.into(), stripe_customer_id.into());
    }

    /// The Stripe customer id linked to `user_id`, if any.
    pub fn lookup(&self, user_id: &str) -> Option<String> {
        self.entries().get(user_id).cloned()
    }

    /// Remove the link for `user_id`, returning the customer id it held.
    pub fn unlink(&self, user_id: &str) -> Option<String> {
        self.entries().remove(user_id)
    }

    /// Lock the map, recovering the guard if the mutex is poisoned.
    ///
    /// Each critical section here is a single `HashMap` operation, so there
    /// is no multi-step invariant for poisoning to protect. Recovering keeps
    /// the store usable instead of turning one unrelated panic into a panic
    /// on every later call.
    fn entries(&self) -> std::sync::MutexGuard<'_, HashMap<String, String>> {
        self.map
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }
}

impl Default for StripeCustomerStore {
    fn default() -> Self {
        Self::new()
    }
}
171
172/// Aggregate plugin so apps can register the whole Stripe surface at once.
173pub struct StripePlugin {
174    pub customers: StripeCustomerStore,
175    pub webhook_secret: String,
176    pub tolerance_secs: u64,
177}
178
179impl StripePlugin {
180    pub fn new(webhook_secret: impl Into<String>) -> Self {
181        Self {
182            customers: StripeCustomerStore::new(),
183            webhook_secret: webhook_secret.into(),
184            tolerance_secs: DEFAULT_TOLERANCE_SECS,
185        }
186    }
187
188    pub fn verify_webhook(
189        &self,
190        header: &str,
191        payload: &[u8],
192    ) -> Result<StripeEvent, SignatureError> {
193        verify_signature(
194            header,
195            payload,
196            &self.webhook_secret,
197            current_unix_secs(),
198            self.tolerance_secs,
199        )?;
200        StripeEvent::from_payload(payload).map_err(|_| SignatureError::InvalidSignature)
201    }
202}
203
204impl Plugin for StripePlugin {
205    fn name(&self) -> &str {
206        "stripe"
207    }
208}
209
210// ---------------------------------------------------------------------------
211// helpers
212// ---------------------------------------------------------------------------
213
/// Lowercase hex encoding: two characters per input byte, high nibble first
/// (`[0xab, 0x01]` → `"ab01"`).
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: [char; 16] = [
        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f',
    ];
    let digit = |nibble: u8| DIGITS[usize::from(nibble)];
    bytes
        .iter()
        .flat_map(|&byte| [digit(byte >> 4), digit(byte & 0x0f)])
        .collect()
}
223
/// Byte-string equality with no data-dependent early exit.
///
/// A length mismatch returns `false` immediately. For equal lengths, every
/// byte pair is XOR-ed and OR-ed into one accumulator before the single final
/// check, so the loop does not stop at the first differing byte.
fn ct_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let mismatch = a
        .iter()
        .zip(b)
        .fold(0u8, |acc, (&x, &y)| acc | (x ^ y));
    mismatch == 0
}
234
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `t=<ts>,v1=<hex>` header whose `v1` is the HMAC-SHA256 of
    /// `"{ts}."` + `payload` keyed by `secret`, i.e. the scheme
    /// `verify_signature` checks.
    fn signed_header(ts: u64, payload: &[u8], secret: &str) -> String {
        let mut hmac = Hmac::<Sha256>::new_from_slice(secret.as_bytes()).unwrap();
        hmac.update(format!("{ts}.").as_bytes());
        hmac.update(payload);
        let digest = hmac.finalize().into_bytes();
        format!("t={ts},v1={}", hex_encode(&digest))
    }

    #[test]
    fn verifies_valid_signature() {
        let body = br#"{"id":"evt_1","type":"checkout.session.completed","created":1,"data":{"object":{"id":"cs_1"}}}"#;
        let key = "whsec_test";
        let sent_at = 1_700_000_000;
        let header = signed_header(sent_at, body, key);
        assert_eq!(verify_signature(&header, body, key, sent_at + 5, 300), Ok(()));
    }

    #[test]
    fn rejects_tampered_payload() {
        let key = "whsec_test";
        let sent_at = 1_700_000_000;
        let header = signed_header(sent_at, b"original", key);
        assert_eq!(
            verify_signature(&header, b"tampered", key, sent_at, 300),
            Err(SignatureError::InvalidSignature)
        );
    }

    #[test]
    fn rejects_wrong_secret() {
        let body = b"hi";
        let header = signed_header(100, body, "whsec_a");
        assert_eq!(
            verify_signature(&header, body, "whsec_b", 100, 300),
            Err(SignatureError::InvalidSignature)
        );
    }

    #[test]
    fn rejects_replay_outside_tolerance() {
        let (body, key, sent_at) = (b"hi", "whsec", 1_000);
        let header = signed_header(sent_at, body, key);
        assert_eq!(
            verify_signature(&header, body, key, sent_at + 1000, 300),
            Err(SignatureError::Replayed)
        );
    }

    #[test]
    fn rejects_missing_timestamp() {
        assert_eq!(
            verify_signature("v1=abc", b"hi", "secret", 0, 300),
            Err(SignatureError::MissingTimestamp)
        );
    }

    #[test]
    fn rejects_missing_signature() {
        assert_eq!(
            verify_signature("t=100", b"hi", "secret", 100, 300),
            Err(SignatureError::MissingSignature)
        );
    }

    #[test]
    fn accepts_one_of_multiple_v1_signatures() {
        let (body, key, sent_at) = (b"hi", "whsec", 100);
        let genuine = signed_header(sent_at, body, key);
        // Keep the genuine `v1=` entry but put a bogus one ahead of it.
        let genuine_v1 = genuine
            .split(',')
            .find(|entry| entry.starts_with("v1="))
            .expect("signed_header always emits a v1 entry");
        let header = format!("t={sent_at},v1=deadbeef,{genuine_v1}");
        assert_eq!(verify_signature(&header, body, key, sent_at, 300), Ok(()));
    }

    #[test]
    fn parses_event_payload() {
        let raw = br#"{"id":"evt_X","type":"customer.created","created":42,"data":{"object":{"id":"cus_1"}}}"#;
        let event = StripeEvent::from_payload(raw).expect("fixture is valid JSON");
        assert_eq!(event.id, "evt_X");
        assert_eq!(event.event_type, "customer.created");
        assert_eq!(event.created, 42);
        assert_eq!(event.object_id(), Some("cus_1"));
    }

    #[test]
    fn customer_store_round_trip() {
        let store = StripeCustomerStore::new();
        store.link("user_1", "cus_abc");
        assert_eq!(store.lookup("user_1").as_deref(), Some("cus_abc"));
        assert_eq!(store.unlink("user_1").as_deref(), Some("cus_abc"));
        assert!(store.lookup("user_1").is_none());
    }

    #[test]
    fn plugin_verify_webhook_end_to_end() {
        let key = "whsec_E2E";
        let body = br#"{"id":"evt_1","type":"x","created":1,"data":{}}"#;
        let plugin = StripePlugin::new(key);
        let header = signed_header(current_unix_secs(), body, key);
        let event = plugin
            .verify_webhook(&header, body)
            .expect("freshly signed webhook verifies");
        assert_eq!(event.event_type, "x");
    }
}