1use std::collections::HashMap;
23use std::sync::Mutex;
24use std::time::{SystemTime, UNIX_EPOCH};
25
26use hmac::{Hmac, Mac};
27use serde::{Deserialize, Serialize};
28use sha2::Sha256;
29
30use crate::Plugin;
31
/// Default maximum age, in seconds, of a webhook's signed timestamp (five minutes).
const DEFAULT_TOLERANCE_SECS: u64 = 5 * 60;
34
/// Reasons a webhook signature header is rejected by `verify_signature`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SignatureError {
    /// No `t=` element whose value parsed as an unsigned integer
    /// (when several `t=` elements are present, the last one decides).
    MissingTimestamp,
    /// The header contained no `v1=` element.
    MissingSignature,
    /// The signed timestamp is more than the tolerance older than "now".
    Replayed,
    /// No `v1` value matched the expected HMAC. `StripePlugin::verify_webhook`
    /// also returns this when a correctly signed payload fails to deserialize.
    InvalidSignature,
    /// A comma-separated element of the header contained no `=`.
    BadHeaderFormat,
}

impl std::fmt::Display for SignatureError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Lowercase, no trailing punctuation, per the Rust API guidelines for
        // error messages.
        let message = match self {
            SignatureError::MissingTimestamp => "signature header is missing a valid `t` timestamp",
            SignatureError::MissingSignature => "signature header is missing a `v1` signature",
            SignatureError::Replayed => "signature timestamp is outside the tolerance window",
            SignatureError::InvalidSignature => "invalid webhook signature",
            SignatureError::BadHeaderFormat => "malformed signature header",
        };
        f.write_str(message)
    }
}

/// Lets `SignatureError` flow into `Box<dyn Error>` and other `Error`-based
/// error handling.
impl std::error::Error for SignatureError {}
43
44pub fn verify_signature(
54 header: &str,
55 payload: &[u8],
56 secret: &str,
57 now_unix_secs: u64,
58 tolerance_secs: u64,
59) -> Result<(), SignatureError> {
60 let mut timestamp: Option<u64> = None;
61 let mut sigs: Vec<&str> = Vec::new();
62
63 for part in header.split(',') {
64 let mut kv = part.splitn(2, '=');
65 let key = kv.next().unwrap_or("").trim();
66 let val = kv.next().ok_or(SignatureError::BadHeaderFormat)?.trim();
67 match key {
68 "t" => timestamp = val.parse().ok(),
69 "v1" => sigs.push(val),
70 _ => {} }
72 }
73
74 let ts = timestamp.ok_or(SignatureError::MissingTimestamp)?;
75 if sigs.is_empty() {
76 return Err(SignatureError::MissingSignature);
77 }
78 if now_unix_secs.saturating_sub(ts) > tolerance_secs {
79 return Err(SignatureError::Replayed);
80 }
81
82 let signed_payload = format!("{ts}.");
83 let mut mac = Hmac::<Sha256>::new_from_slice(secret.as_bytes())
84 .map_err(|_| SignatureError::InvalidSignature)?;
85 mac.update(signed_payload.as_bytes());
86 mac.update(payload);
87 let expected = mac.finalize().into_bytes();
88 let expected_hex = hex_encode(&expected);
89
90 if sigs
92 .iter()
93 .any(|s| ct_eq(s.as_bytes(), expected_hex.as_bytes()))
94 {
95 Ok(())
96 } else {
97 Err(SignatureError::InvalidSignature)
98 }
99}
100
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// If the system clock reports a time before the epoch, this returns `0`.
/// NOTE(review): `verify_signature` computes the timestamp age with
/// `saturating_sub`, so a `now` of 0 makes its replay check always pass.
/// Confirm that this fail-open fallback is acceptable.
pub fn current_unix_secs() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
107
108#[derive(Debug, Clone, Serialize, Deserialize)]
114pub struct StripeEvent {
115 pub id: String,
116 #[serde(rename = "type")]
117 pub event_type: String,
118 pub created: u64,
119 pub data: serde_json::Value,
120}
121
122impl StripeEvent {
123 pub fn from_payload(bytes: &[u8]) -> Result<Self, serde_json::Error> {
124 serde_json::from_slice(bytes)
125 }
126
127 pub fn object_id(&self) -> Option<&str> {
128 self.data
129 .get("object")
130 .and_then(|o| o.get("id"))
131 .and_then(|v| v.as_str())
132 }
133}
134
/// In-memory mapping from user IDs to Stripe customer IDs.
///
/// All access goes through an internal `Mutex`, so the store can be used
/// from several threads through a shared reference.
pub struct StripeCustomerStore {
    /// user ID -> Stripe customer ID.
    map: Mutex<HashMap<String, String>>,
}

impl StripeCustomerStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self {
            map: Mutex::new(HashMap::new()),
        }
    }

    /// Links `user_id` to `stripe_customer_id`, replacing any previous link.
    pub fn link(&self, user_id: &str, stripe_customer_id: &str) {
        self.locked()
            .insert(user_id.to_owned(), stripe_customer_id.to_owned());
    }

    /// Returns a copy of the customer ID linked to `user_id`, if any.
    pub fn lookup(&self, user_id: &str) -> Option<String> {
        self.locked().get(user_id).cloned()
    }

    /// Removes the link for `user_id` and returns the customer ID it held, if any.
    pub fn unlink(&self, user_id: &str) -> Option<String> {
        self.locked().remove(user_id)
    }

    /// Locks the map, recovering it if a previous holder panicked.
    ///
    /// Poisoning only records that a panic happened while the lock was held.
    /// Each method here performs a single `HashMap` call on plain `String`
    /// pairs, so no multi-step invariant can be left half-applied. Recovering
    /// stops one panic from turning every later call into a panic, which
    /// `lock().unwrap()` would do.
    fn locked(&self) -> std::sync::MutexGuard<'_, HashMap<String, String>> {
        self.map
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
165
166impl Default for StripeCustomerStore {
167 fn default() -> Self {
168 Self::new()
169 }
170}
171
172pub struct StripePlugin {
174 pub customers: StripeCustomerStore,
175 pub webhook_secret: String,
176 pub tolerance_secs: u64,
177}
178
179impl StripePlugin {
180 pub fn new(webhook_secret: impl Into<String>) -> Self {
181 Self {
182 customers: StripeCustomerStore::new(),
183 webhook_secret: webhook_secret.into(),
184 tolerance_secs: DEFAULT_TOLERANCE_SECS,
185 }
186 }
187
188 pub fn verify_webhook(
189 &self,
190 header: &str,
191 payload: &[u8],
192 ) -> Result<StripeEvent, SignatureError> {
193 verify_signature(
194 header,
195 payload,
196 &self.webhook_secret,
197 current_unix_secs(),
198 self.tolerance_secs,
199 )?;
200 StripeEvent::from_payload(payload).map_err(|_| SignatureError::InvalidSignature)
201 }
202}
203
204impl Plugin for StripePlugin {
205 fn name(&self) -> &str {
206 "stripe"
207 }
208}
209
/// Encodes `bytes` as lowercase hexadecimal: two digits per byte, high nibble first.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let nibble_char = |nibble: u8| char::from(DIGITS[usize::from(nibble)]);

    let mut encoded = String::with_capacity(2 * bytes.len());
    bytes.iter().for_each(|&byte| {
        encoded.push(nibble_char(byte >> 4));
        encoded.push(nibble_char(byte & 0x0f));
    });
    encoded
}
223
/// Compares two byte slices without stopping at the first differing byte.
///
/// Slices of different lengths are rejected immediately, so only the content
/// comparison avoids an early exit.
/// NOTE(review): Rust gives no formal guarantee that the optimizer keeps
/// this loop free of data-dependent branches.
fn ct_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        // OR together every XOR difference; the result is 0 only if all bytes match.
        let accumulated = a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y));
        accumulated == 0
    } else {
        false
    }
}
234
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `t=...,v1=...` header signed the way `verify_signature` expects.
    fn signed_header(ts: u64, payload: &[u8], secret: &str) -> String {
        let mut mac = Hmac::<Sha256>::new_from_slice(secret.as_bytes()).unwrap();
        mac.update(format!("{ts}.").as_bytes());
        mac.update(payload);
        let sig = hex_encode(&mac.finalize().into_bytes());
        format!("t={ts},v1={sig}")
    }

    #[test]
    fn verifies_valid_signature() {
        let payload = br#"{"id":"evt_1","type":"checkout.session.completed","created":1,"data":{"object":{"id":"cs_1"}}}"#;
        let secret = "whsec_test";
        let ts = 1_700_000_000;
        let header = signed_header(ts, payload, secret);
        verify_signature(&header, payload, secret, ts + 5, 300).unwrap();
    }

    #[test]
    fn rejects_tampered_payload() {
        let secret = "whsec_test";
        let ts = 1_700_000_000;
        let header = signed_header(ts, b"original", secret);
        let err = verify_signature(&header, b"tampered", secret, ts, 300).unwrap_err();
        assert_eq!(err, SignatureError::InvalidSignature);
    }

    #[test]
    fn rejects_wrong_secret() {
        let payload = b"hi";
        let header = signed_header(100, payload, "whsec_a");
        let err = verify_signature(&header, payload, "whsec_b", 100, 300).unwrap_err();
        assert_eq!(err, SignatureError::InvalidSignature);
    }

    #[test]
    fn rejects_replay_outside_tolerance() {
        let payload = b"hi";
        let secret = "whsec";
        let ts = 1_000;
        let header = signed_header(ts, payload, secret);
        let err = verify_signature(&header, payload, secret, ts + 1000, 300).unwrap_err();
        assert_eq!(err, SignatureError::Replayed);
    }

    #[test]
    fn accepts_timestamp_exactly_at_tolerance() {
        // The window is inclusive: age == tolerance passes, age == tolerance + 1 fails.
        let payload = b"hi";
        let secret = "whsec";
        let ts = 1_000;
        let header = signed_header(ts, payload, secret);
        verify_signature(&header, payload, secret, ts + 300, 300).unwrap();
        let err = verify_signature(&header, payload, secret, ts + 301, 300).unwrap_err();
        assert_eq!(err, SignatureError::Replayed);
    }

    #[test]
    fn rejects_signature_for_different_timestamp() {
        // The timestamp is part of the signed message, so it cannot be swapped.
        let payload = b"hi";
        let secret = "whsec";
        let header = signed_header(1_000, payload, secret).replacen("t=1000", "t=1001", 1);
        let err = verify_signature(&header, payload, secret, 1_001, 300).unwrap_err();
        assert_eq!(err, SignatureError::InvalidSignature);
    }

    #[test]
    fn rejects_missing_timestamp() {
        let err = verify_signature("v1=abc", b"hi", "secret", 0, 300).unwrap_err();
        assert_eq!(err, SignatureError::MissingTimestamp);
    }

    #[test]
    fn rejects_missing_signature() {
        let err = verify_signature("t=100", b"hi", "secret", 100, 300).unwrap_err();
        assert_eq!(err, SignatureError::MissingSignature);
    }

    #[test]
    fn rejects_element_without_equals() {
        let err = verify_signature("t=100,v1", b"hi", "secret", 100, 300).unwrap_err();
        assert_eq!(err, SignatureError::BadHeaderFormat);
    }

    #[test]
    fn accepts_one_of_multiple_v1_signatures() {
        let payload = b"hi";
        let secret = "whsec";
        let ts = 100;
        let valid = signed_header(ts, payload, secret);
        let v1 = valid.split(',').find(|p| p.starts_with("v1=")).unwrap();
        let header = format!("t={ts},v1=deadbeef,{v1}");
        verify_signature(&header, payload, secret, ts, 300).unwrap();
    }

    #[test]
    fn tolerates_whitespace_and_unknown_elements() {
        let payload = b"hi";
        let secret = "whsec";
        let ts = 100;
        let sig = signed_header(ts, payload, secret)
            .split(',')
            .find_map(|p| p.strip_prefix("v1="))
            .unwrap()
            .to_string();
        let header = format!(" t={ts} , v0=ignored , v1={sig} ");
        verify_signature(&header, payload, secret, ts, 300).unwrap();
    }

    #[test]
    fn hex_encode_uses_lowercase_two_digit_bytes() {
        assert_eq!(hex_encode(&[0x00, 0x0f, 0xa5, 0xff]), "000fa5ff");
        assert_eq!(hex_encode(&[]), "");
    }

    #[test]
    fn ct_eq_compares_length_and_content() {
        assert!(ct_eq(b"abc", b"abc"));
        assert!(!ct_eq(b"abc", b"abd"));
        assert!(!ct_eq(b"abc", b"ab"));
        assert!(ct_eq(b"", b""));
    }

    #[test]
    fn parses_event_payload() {
        let bytes = br#"{"id":"evt_X","type":"customer.created","created":42,"data":{"object":{"id":"cus_1"}}}"#;
        let ev = StripeEvent::from_payload(bytes).unwrap();
        assert_eq!(ev.id, "evt_X");
        assert_eq!(ev.event_type, "customer.created");
        assert_eq!(ev.created, 42);
        assert_eq!(ev.object_id(), Some("cus_1"));
    }

    #[test]
    fn object_id_is_none_without_object() {
        let bytes = br#"{"id":"evt_2","type":"x","created":1,"data":{}}"#;
        let ev = StripeEvent::from_payload(bytes).unwrap();
        assert_eq!(ev.object_id(), None);
    }

    #[test]
    fn customer_store_round_trip() {
        let s = StripeCustomerStore::new();
        s.link("user_1", "cus_abc");
        assert_eq!(s.lookup("user_1").as_deref(), Some("cus_abc"));
        assert_eq!(s.unlink("user_1").as_deref(), Some("cus_abc"));
        assert_eq!(s.lookup("user_1"), None);
    }

    #[test]
    fn plugin_verify_webhook_end_to_end() {
        let secret = "whsec_E2E";
        let payload = br#"{"id":"evt_1","type":"x","created":1,"data":{}}"#;
        let plugin = StripePlugin::new(secret);
        let ts = current_unix_secs();
        let header = signed_header(ts, payload, secret);
        let ev = plugin.verify_webhook(&header, payload).unwrap();
        assert_eq!(ev.event_type, "x");
    }
}
336}