1use hmac::{Hmac, Mac};
17use serde::{Deserialize, Serialize};
18use sha2::Sha256;
19
/// HMAC-SHA256 instance used to compute the expected `v1` webhook signature.
type HmacSha256 = Hmac<Sha256>;
21
/// Credentials for calling the Stripe REST API and verifying Stripe webhooks.
///
/// `Debug` is implemented by hand rather than derived so that `{:?}` never
/// prints the API key or webhook secret (e.g. into logs or panic messages).
#[derive(Clone)]
pub struct StripeConfig {
    /// Secret API key, sent as a `Bearer` token on every request.
    pub api_key: String,
    /// Webhook signing secret; `None` when webhook verification is not configured.
    pub webhook_secret: Option<String>,
}

impl std::fmt::Debug for StripeConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StripeConfig")
            .field("api_key", &"<redacted>")
            // Still reveal *whether* a secret is configured, just not its value.
            .field(
                "webhook_secret",
                &self.webhook_secret.as_ref().map(|_| "<redacted>"),
            )
            .finish()
    }
}

impl StripeConfig {
    /// Builds a config from `PYLON_STRIPE_API_KEY` (required) and
    /// `PYLON_STRIPE_WEBHOOK_SECRET` (optional).
    ///
    /// Returns `None` when the API key is unset, not valid Unicode, or blank.
    /// A blank webhook secret is treated as unset. An empty HMAC key would let
    /// anyone produce a "valid" webhook signature, so it must never be used.
    pub fn from_env() -> Option<Self> {
        let non_blank = |name: &str| {
            std::env::var(name)
                .ok()
                .filter(|value| !value.trim().is_empty())
        };
        let api_key = non_blank("PYLON_STRIPE_API_KEY")?;
        let webhook_secret = non_blank("PYLON_STRIPE_WEBHOOK_SECRET");
        Some(Self {
            api_key,
            webhook_secret,
        })
    }
}
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct StripeCustomer {
45 pub id: String,
46 #[serde(default)]
47 pub email: Option<String>,
48 #[serde(default)]
49 pub name: Option<String>,
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct CheckoutSession {
54 pub id: String,
55 pub url: String,
57 #[serde(default)]
58 pub customer: Option<String>,
59}
60
/// Value sent as the `mode` form field when creating a Checkout Session.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CheckoutMode {
    /// Sent as `mode=subscription`.
    Subscription,
    /// Sent as `mode=payment`.
    Payment,
}

impl CheckoutMode {
    /// The literal string placed in the request body for this mode.
    fn as_str(&self) -> &'static str {
        match *self {
            CheckoutMode::Payment => "payment",
            CheckoutMode::Subscription => "subscription",
        }
    }
}
77
78impl StripeConfig {
79 pub fn create_customer(
84 &self,
85 email: &str,
86 name: Option<&str>,
87 ) -> Result<StripeCustomer, String> {
88 let mut body = format!("email={}", url_encode(email));
89 if let Some(n) = name {
90 body.push_str("&name=");
91 body.push_str(&url_encode(n));
92 }
93 self.post("https://api.stripe.com/v1/customers", &body)
94 }
95
96 pub fn create_checkout(
100 &self,
101 customer_id: Option<&str>,
102 price_ids: &[&str],
103 mode: CheckoutMode,
104 success_url: &str,
105 cancel_url: &str,
106 ) -> Result<CheckoutSession, String> {
107 let mut body = format!(
108 "mode={}&success_url={}&cancel_url={}",
109 mode.as_str(),
110 url_encode(success_url),
111 url_encode(cancel_url),
112 );
113 if let Some(cid) = customer_id {
114 body.push_str("&customer=");
115 body.push_str(&url_encode(cid));
116 }
117 for (i, pid) in price_ids.iter().enumerate() {
118 body.push_str(&format!(
119 "&line_items[{i}][price]={}&line_items[{i}][quantity]=1",
120 url_encode(pid)
121 ));
122 }
123 self.post("https://api.stripe.com/v1/checkout/sessions", &body)
124 }
125
126 fn post<T: for<'de> Deserialize<'de>>(&self, url: &str, body: &str) -> Result<T, String> {
127 let agent = ureq::AgentBuilder::new()
128 .timeout_connect(std::time::Duration::from_secs(10))
129 .timeout_read(std::time::Duration::from_secs(10))
130 .user_agent("pylon-auth/0.1")
131 .build();
132 let resp = agent
133 .post(url)
134 .set("Authorization", &format!("Bearer {}", self.api_key))
135 .set("Content-Type", "application/x-www-form-urlencoded")
136 .send_string(body)
137 .map_err(|e| match e {
138 ureq::Error::Status(code, r) => {
139 let body = r.into_string().unwrap_or_default();
140 format!("stripe HTTP {code}: {body}")
141 }
142 e => format!("stripe network: {e}"),
143 })?;
144 let txt = resp
145 .into_string()
146 .map_err(|e| format!("stripe body: {e}"))?;
147 serde_json::from_str(&txt).map_err(|e| format!("stripe JSON: {e}"))
148 }
149}
150
151#[derive(Debug, Clone, PartialEq, Eq)]
159pub enum BillingEvent {
160 CheckoutCompleted {
163 customer_id: Option<String>,
164 subscription_id: Option<String>,
165 client_reference_id: Option<String>,
166 },
167 SubscriptionChanged {
172 subscription_id: String,
173 customer_id: String,
174 status: String,
175 current_period_end: u64,
176 },
177 SubscriptionDeleted {
180 subscription_id: String,
181 customer_id: String,
182 },
183 PaymentFailed {
186 customer_id: String,
187 invoice_id: String,
188 },
189 Other {
192 event_type: String,
193 body: serde_json::Value,
194 },
195}
196
/// Reasons `verify_webhook` rejects a delivery.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WebhookError {
    /// The `Stripe-Signature` header is empty, or lacks a parseable `t=` or any
    /// `v1=` entry.
    MissingSignature,
    /// The signed timestamp is more than 5 minutes away from the verifier's clock.
    StaleTimestamp,
    /// The signature check failed: no `v1` entry matched the expected HMAC.
    BadSignature,
    /// The body passed signature verification but is not valid JSON.
    BadJson,
}

impl std::fmt::Display for WebhookError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::MissingSignature => "Stripe-Signature header missing",
            Self::StaleTimestamp => "webhook timestamp outside ±5min tolerance",
            Self::BadSignature => "webhook signature mismatch",
            Self::BadJson => "webhook body not valid JSON",
        })
    }
}

/// Makes `WebhookError` usable as `Box<dyn std::error::Error>` and in `?`
/// chains that expect a standard error type.
impl std::error::Error for WebhookError {}
221
222pub fn verify_webhook(
229 secret: &str,
230 body: &[u8],
231 signature_header: &str,
232 now_secs: u64,
233) -> Result<BillingEvent, WebhookError> {
234 let mut t: Option<u64> = None;
235 let mut v1_sigs: Vec<&str> = Vec::new();
236 const MAX_V1_SIGS: usize = 8;
241 for kv in signature_header.split(',') {
242 let kv = kv.trim();
243 if let Some(v) = kv.strip_prefix("t=") {
244 t = v.parse().ok();
245 } else if let Some(v) = kv.strip_prefix("v1=") {
246 if v1_sigs.len() < MAX_V1_SIGS {
247 v1_sigs.push(v);
248 }
249 }
250 }
251 let ts = t.ok_or(WebhookError::MissingSignature)?;
252 if v1_sigs.is_empty() {
253 return Err(WebhookError::MissingSignature);
254 }
255 let diff = if now_secs > ts {
257 now_secs - ts
258 } else {
259 ts - now_secs
260 };
261 if diff > 5 * 60 {
262 return Err(WebhookError::StaleTimestamp);
263 }
264
265 let mut mac =
267 HmacSha256::new_from_slice(secret.as_bytes()).expect("HMAC accepts any key length");
268 mac.update(format!("{ts}.").as_bytes());
269 mac.update(body);
270 let expected = mac.finalize().into_bytes();
271 let expected_hex = bytes_to_hex(&expected);
272
273 let any_match = v1_sigs
274 .iter()
275 .any(|s| crate::constant_time_eq(s.as_bytes(), expected_hex.as_bytes()));
276 if !any_match {
277 return Err(WebhookError::BadSignature);
278 }
279
280 let body_json: serde_json::Value =
281 serde_json::from_slice(body).map_err(|_| WebhookError::BadJson)?;
282 Ok(parse_event(body_json))
283}
284
285fn parse_event(body: serde_json::Value) -> BillingEvent {
286 let event_type = body
287 .get("type")
288 .and_then(|v| v.as_str())
289 .unwrap_or("")
290 .to_string();
291 let object = body.pointer("/data/object").cloned().unwrap_or_default();
292 match event_type.as_str() {
293 "checkout.session.completed" => BillingEvent::CheckoutCompleted {
294 customer_id: object
295 .get("customer")
296 .and_then(|v| v.as_str())
297 .map(String::from),
298 subscription_id: object
299 .get("subscription")
300 .and_then(|v| v.as_str())
301 .map(String::from),
302 client_reference_id: object
303 .get("client_reference_id")
304 .and_then(|v| v.as_str())
305 .map(String::from),
306 },
307 "customer.subscription.updated" | "customer.subscription.created" => {
308 BillingEvent::SubscriptionChanged {
309 subscription_id: object
310 .get("id")
311 .and_then(|v| v.as_str())
312 .unwrap_or("")
313 .to_string(),
314 customer_id: object
315 .get("customer")
316 .and_then(|v| v.as_str())
317 .unwrap_or("")
318 .to_string(),
319 status: object
320 .get("status")
321 .and_then(|v| v.as_str())
322 .unwrap_or("")
323 .to_string(),
324 current_period_end: object
325 .get("current_period_end")
326 .and_then(|v| v.as_u64())
327 .unwrap_or(0),
328 }
329 }
330 "customer.subscription.deleted" => BillingEvent::SubscriptionDeleted {
331 subscription_id: object
332 .get("id")
333 .and_then(|v| v.as_str())
334 .unwrap_or("")
335 .to_string(),
336 customer_id: object
337 .get("customer")
338 .and_then(|v| v.as_str())
339 .unwrap_or("")
340 .to_string(),
341 },
342 "invoice.payment_failed" => BillingEvent::PaymentFailed {
343 customer_id: object
344 .get("customer")
345 .and_then(|v| v.as_str())
346 .unwrap_or("")
347 .to_string(),
348 invoice_id: object
349 .get("id")
350 .and_then(|v| v.as_str())
351 .unwrap_or("")
352 .to_string(),
353 },
354 _ => BillingEvent::Other { event_type, body },
355 }
356}
357
/// Percent-encodes `s` for use as a form-urlencoded value.
///
/// Bytes in the RFC 3986 unreserved set (`A-Z a-z 0-9 - _ . ~`) pass through.
/// Every other byte, including each byte of a multi-byte UTF-8 character,
/// becomes `%XX` with uppercase hex. A space becomes `%20`, not `+`.
fn url_encode(s: &str) -> String {
    const HEX_UPPER: &[u8; 16] = b"0123456789ABCDEF";
    let is_unreserved =
        |byte: u8| byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');

    s.bytes()
        .fold(String::with_capacity(s.len()), |mut encoded, byte| {
            if is_unreserved(byte) {
                encoded.push(char::from(byte));
            } else {
                encoded.push('%');
                encoded.push(char::from(HEX_UPPER[usize::from(byte >> 4)]));
                encoded.push(char::from(HEX_UPPER[usize::from(byte & 0x0F)]));
            }
            encoded
        })
}
370
/// Encodes `bytes` as lowercase hex: two digits per byte, no separators.
fn bytes_to_hex(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut hex = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        hex.push(char::from(DIGITS[usize::from(byte >> 4)]));
        hex.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    hex
}
379
#[cfg(test)]
mod tests {
    use super::*;
    use sha2::Sha256;

    /// Computes the lowercase-hex `v1` signature for `body` signed at `ts`.
    fn sign(secret: &str, ts: u64, body: &[u8]) -> String {
        let mut hmac = Hmac::<Sha256>::new_from_slice(secret.as_bytes())
            .expect("HMAC accepts any key length");
        hmac.update(format!("{ts}.").as_bytes());
        hmac.update(body);
        let digest = hmac.finalize().into_bytes();
        bytes_to_hex(&digest)
    }

    /// Builds a `Stripe-Signature` header value carrying a single valid `v1`.
    fn signed_header(secret: &str, ts: u64, body: &[u8]) -> String {
        let signature = sign(secret, ts, body);
        format!("t={ts},v1={signature}")
    }

    #[test]
    fn verify_webhook_round_trip_checkout_completed() {
        const SECRET: &str = "whsec_test_secret";
        let payload = br#"{
            "type": "checkout.session.completed",
            "data": { "object": {
                "customer": "cus_xyz",
                "subscription": "sub_abc",
                "client_reference_id": "user_123"
            }}
        }"#;
        let sent_at = 1_700_000_000;
        let header = signed_header(SECRET, sent_at, payload);

        let parsed = verify_webhook(SECRET, payload, &header, sent_at).unwrap();
        match parsed {
            BillingEvent::CheckoutCompleted {
                customer_id,
                subscription_id,
                client_reference_id,
            } => {
                assert_eq!(customer_id.as_deref(), Some("cus_xyz"));
                assert_eq!(subscription_id.as_deref(), Some("sub_abc"));
                assert_eq!(client_reference_id.as_deref(), Some("user_123"));
            }
            other => panic!("expected CheckoutCompleted, got {other:?}"),
        }
    }

    #[test]
    fn verify_webhook_rejects_bad_signature() {
        let sent_at = 1_700_000_000;
        let forged = format!("t={sent_at},v1=deadbeefdeadbeef");
        let outcome = verify_webhook("secret", b"{}", &forged, sent_at);
        assert_eq!(outcome, Err(WebhookError::BadSignature));
    }

    #[test]
    fn verify_webhook_rejects_stale_timestamp() {
        let secret = "s";
        let payload = b"{}";
        let sent_at = 1_700_000_000;
        let header = signed_header(secret, sent_at, payload);
        // Six minutes later: beyond the five-minute tolerance.
        let received_at = sent_at + 6 * 60;
        assert_eq!(
            verify_webhook(secret, payload, &header, received_at),
            Err(WebhookError::StaleTimestamp)
        );
    }

    #[test]
    fn verify_webhook_missing_signature_header() {
        let payload = b"{}";
        // Completely empty header.
        assert_eq!(
            verify_webhook("s", payload, "", 0),
            Err(WebhookError::MissingSignature)
        );
        // Timestamp present but no `v1=` entry.
        assert_eq!(
            verify_webhook("s", payload, "t=100", 100),
            Err(WebhookError::MissingSignature)
        );
    }

    #[test]
    fn parse_subscription_changed() {
        let event = serde_json::json!({
            "type": "customer.subscription.updated",
            "data": { "object": {
                "id": "sub_xyz",
                "customer": "cus_abc",
                "status": "active",
                "current_period_end": 9_999_999_999u64
            }}
        });
        match parse_event(event) {
            BillingEvent::SubscriptionChanged {
                subscription_id,
                customer_id,
                status,
                current_period_end,
            } => {
                assert_eq!(subscription_id, "sub_xyz");
                assert_eq!(customer_id, "cus_abc");
                assert_eq!(status, "active");
                assert_eq!(current_period_end, 9_999_999_999);
            }
            other => panic!("expected SubscriptionChanged, got {other:?}"),
        }
    }

    #[test]
    fn unknown_event_falls_through_to_other() {
        let event = serde_json::json!({"type": "some.weird.event", "data": {}});
        match parse_event(event) {
            BillingEvent::Other { event_type, .. } => {
                assert_eq!(event_type, "some.weird.event");
            }
            other => panic!("expected Other, got {other:?}"),
        }
    }

    #[test]
    fn webhook_accepts_multiple_v1_sigs() {
        let secret = "new_secret";
        let payload = br#"{"type":"x"}"#;
        let sent_at = 1_700_000_000;
        // One stale/bogus signature followed by the valid one: any match suffices.
        let valid = sign(secret, sent_at, payload);
        let header = format!("t={sent_at},v1=deadbeef,v1={valid}");
        assert!(verify_webhook(secret, payload, &header, sent_at).is_ok());
    }
}