1use hmac::{Hmac, Mac};
17use serde::{Deserialize, Serialize};
18use sha2::Sha256;
19
/// HMAC-SHA256 MAC type used by `verify_webhook` to compute the expected
/// `v1` webhook signature.
type HmacSha256 = Hmac<Sha256>;
21
/// Credentials for calling the Stripe API and verifying Stripe webhooks.
#[derive(Debug, Clone)]
pub struct StripeConfig {
    /// Secret API key, sent as a `Bearer` token on every API request.
    pub api_key: String,
    /// Webhook signing secret for `verify_webhook`; `None` when
    /// `PYLON_STRIPE_WEBHOOK_SECRET` is unset or blank.
    pub webhook_secret: Option<String>,
}

impl StripeConfig {
    /// Builds a config from the process environment.
    ///
    /// Reads `PYLON_STRIPE_API_KEY` (required) and
    /// `PYLON_STRIPE_WEBHOOK_SECRET` (optional). Surrounding whitespace is
    /// trimmed (e.g. a trailing newline from a secrets file). A blank or
    /// non-UTF-8 value is treated exactly like an unset variable.
    ///
    /// Returns `None` when no usable API key is configured.
    pub fn from_env() -> Option<Self> {
        // A blank secret is worse than a missing one: an empty HMAC key is
        // public knowledge, so webhook signatures keyed with it can be forged.
        fn non_blank(name: &str) -> Option<String> {
            let raw = std::env::var(name).ok()?;
            let trimmed = raw.trim();
            (!trimmed.is_empty()).then(|| trimmed.to_owned())
        }

        Some(Self {
            api_key: non_blank("PYLON_STRIPE_API_KEY")?,
            webhook_secret: non_blank("PYLON_STRIPE_WEBHOOK_SECRET"),
        })
    }
}
42
/// Subset of the customer JSON returned by `StripeConfig::create_customer`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StripeCustomer {
    /// Customer id as returned by Stripe.
    pub id: String,
    /// Customer email; `None` when the field is absent from the JSON.
    #[serde(default)]
    pub email: Option<String>,
    /// Customer name; `None` when the field is absent from the JSON.
    #[serde(default)]
    pub name: Option<String>,
}
51
/// Subset of the Checkout Session JSON returned by
/// `StripeConfig::create_checkout`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckoutSession {
    /// Checkout Session id.
    pub id: String,
    /// Checkout page URL reported by Stripe.
    // NOTE(review): this field is not optional, so deserialization fails if
    // Stripe ever returns `"url": null` for a session — confirm that cannot
    // happen for the sessions this crate creates.
    pub url: String,
    /// Customer id attached to the session; `None` when absent from the JSON.
    #[serde(default)]
    pub customer: Option<String>,
}
60
/// Value of the `mode` form field sent by `StripeConfig::create_checkout`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CheckoutMode {
    /// Sent as `mode=subscription`.
    Subscription,
    /// Sent as `mode=payment`.
    Payment,
}

impl CheckoutMode {
    /// Returns the exact string Stripe expects for this mode.
    fn as_str(&self) -> &'static str {
        match *self {
            CheckoutMode::Payment => "payment",
            CheckoutMode::Subscription => "subscription",
        }
    }
}
77
78impl StripeConfig {
79 pub fn create_customer(&self, email: &str, name: Option<&str>) -> Result<StripeCustomer, String> {
84 let mut body = format!("email={}", url_encode(email));
85 if let Some(n) = name {
86 body.push_str("&name=");
87 body.push_str(&url_encode(n));
88 }
89 self.post("https://api.stripe.com/v1/customers", &body)
90 }
91
92 pub fn create_checkout(
96 &self,
97 customer_id: Option<&str>,
98 price_ids: &[&str],
99 mode: CheckoutMode,
100 success_url: &str,
101 cancel_url: &str,
102 ) -> Result<CheckoutSession, String> {
103 let mut body = format!(
104 "mode={}&success_url={}&cancel_url={}",
105 mode.as_str(),
106 url_encode(success_url),
107 url_encode(cancel_url),
108 );
109 if let Some(cid) = customer_id {
110 body.push_str("&customer=");
111 body.push_str(&url_encode(cid));
112 }
113 for (i, pid) in price_ids.iter().enumerate() {
114 body.push_str(&format!(
115 "&line_items[{i}][price]={}&line_items[{i}][quantity]=1",
116 url_encode(pid)
117 ));
118 }
119 self.post("https://api.stripe.com/v1/checkout/sessions", &body)
120 }
121
122 fn post<T: for<'de> Deserialize<'de>>(&self, url: &str, body: &str) -> Result<T, String> {
123 let agent = ureq::AgentBuilder::new()
124 .timeout_connect(std::time::Duration::from_secs(10))
125 .timeout_read(std::time::Duration::from_secs(10))
126 .user_agent("pylon-auth/0.1")
127 .build();
128 let resp = agent
129 .post(url)
130 .set("Authorization", &format!("Bearer {}", self.api_key))
131 .set("Content-Type", "application/x-www-form-urlencoded")
132 .send_string(body)
133 .map_err(|e| match e {
134 ureq::Error::Status(code, r) => {
135 let body = r.into_string().unwrap_or_default();
136 format!("stripe HTTP {code}: {body}")
137 }
138 e => format!("stripe network: {e}"),
139 })?;
140 let txt = resp
141 .into_string()
142 .map_err(|e| format!("stripe body: {e}"))?;
143 serde_json::from_str(&txt).map_err(|e| format!("stripe JSON: {e}"))
144 }
145}
146
/// A verified Stripe webhook event, reduced to the fields billing logic reads.
///
/// `parse_event` produces these values. Missing string fields in the payload
/// become `""` in the non-optional variants, and a missing
/// `current_period_end` becomes `0`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BillingEvent {
    /// Event type `checkout.session.completed`.
    CheckoutCompleted {
        customer_id: Option<String>,
        subscription_id: Option<String>,
        client_reference_id: Option<String>,
    },
    /// Event type `customer.subscription.created` or
    /// `customer.subscription.updated`.
    SubscriptionChanged {
        subscription_id: String,
        customer_id: String,
        status: String,
        /// Raw `current_period_end` value (presumably Unix seconds — confirm);
        /// `0` when absent.
        current_period_end: u64,
    },
    /// Event type `customer.subscription.deleted`.
    SubscriptionDeleted {
        subscription_id: String,
        customer_id: String,
    },
    /// Event type `invoice.payment_failed`.
    PaymentFailed {
        customer_id: String,
        invoice_id: String,
    },
    /// Any other event type (`""` if the event had no `type`), carrying the
    /// complete event JSON.
    Other {
        event_type: String,
        body: serde_json::Value,
    },
}
192
/// Reasons `verify_webhook` rejects an incoming webhook.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WebhookError {
    /// The `Stripe-Signature` header lacks a parseable `t=` timestamp or has
    /// no `v1=` signature.
    MissingSignature,
    /// The signed timestamp is too far from the caller-supplied current time.
    StaleTimestamp,
    /// No supplied signature matched the expected HMAC.
    BadSignature,
    /// The signature was valid but the body is not valid JSON.
    BadJson,
}

impl std::fmt::Display for WebhookError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::MissingSignature => "Stripe-Signature header missing",
            Self::StaleTimestamp => "webhook timestamp outside ±5min tolerance",
            Self::BadSignature => "webhook signature mismatch",
            Self::BadJson => "webhook body not valid JSON",
        })
    }
}

// Lets callers propagate the error with `?` into `Box<dyn std::error::Error>`
// and other error-trait-based handling; it has no underlying source.
impl std::error::Error for WebhookError {}
217
218pub fn verify_webhook(
225 secret: &str,
226 body: &[u8],
227 signature_header: &str,
228 now_secs: u64,
229) -> Result<BillingEvent, WebhookError> {
230 let mut t: Option<u64> = None;
231 let mut v1_sigs: Vec<&str> = Vec::new();
232 const MAX_V1_SIGS: usize = 8;
237 for kv in signature_header.split(',') {
238 let kv = kv.trim();
239 if let Some(v) = kv.strip_prefix("t=") {
240 t = v.parse().ok();
241 } else if let Some(v) = kv.strip_prefix("v1=") {
242 if v1_sigs.len() < MAX_V1_SIGS {
243 v1_sigs.push(v);
244 }
245 }
246 }
247 let ts = t.ok_or(WebhookError::MissingSignature)?;
248 if v1_sigs.is_empty() {
249 return Err(WebhookError::MissingSignature);
250 }
251 let diff = if now_secs > ts { now_secs - ts } else { ts - now_secs };
253 if diff > 5 * 60 {
254 return Err(WebhookError::StaleTimestamp);
255 }
256
257 let mut mac = HmacSha256::new_from_slice(secret.as_bytes())
259 .expect("HMAC accepts any key length");
260 mac.update(format!("{ts}.").as_bytes());
261 mac.update(body);
262 let expected = mac.finalize().into_bytes();
263 let expected_hex = bytes_to_hex(&expected);
264
265 let any_match = v1_sigs
266 .iter()
267 .any(|s| crate::constant_time_eq(s.as_bytes(), expected_hex.as_bytes()));
268 if !any_match {
269 return Err(WebhookError::BadSignature);
270 }
271
272 let body_json: serde_json::Value =
273 serde_json::from_slice(body).map_err(|_| WebhookError::BadJson)?;
274 Ok(parse_event(body_json))
275}
276
277fn parse_event(body: serde_json::Value) -> BillingEvent {
278 let event_type = body
279 .get("type")
280 .and_then(|v| v.as_str())
281 .unwrap_or("")
282 .to_string();
283 let object = body.pointer("/data/object").cloned().unwrap_or_default();
284 match event_type.as_str() {
285 "checkout.session.completed" => BillingEvent::CheckoutCompleted {
286 customer_id: object
287 .get("customer")
288 .and_then(|v| v.as_str())
289 .map(String::from),
290 subscription_id: object
291 .get("subscription")
292 .and_then(|v| v.as_str())
293 .map(String::from),
294 client_reference_id: object
295 .get("client_reference_id")
296 .and_then(|v| v.as_str())
297 .map(String::from),
298 },
299 "customer.subscription.updated" | "customer.subscription.created" => {
300 BillingEvent::SubscriptionChanged {
301 subscription_id: object
302 .get("id")
303 .and_then(|v| v.as_str())
304 .unwrap_or("")
305 .to_string(),
306 customer_id: object
307 .get("customer")
308 .and_then(|v| v.as_str())
309 .unwrap_or("")
310 .to_string(),
311 status: object
312 .get("status")
313 .and_then(|v| v.as_str())
314 .unwrap_or("")
315 .to_string(),
316 current_period_end: object
317 .get("current_period_end")
318 .and_then(|v| v.as_u64())
319 .unwrap_or(0),
320 }
321 }
322 "customer.subscription.deleted" => BillingEvent::SubscriptionDeleted {
323 subscription_id: object
324 .get("id")
325 .and_then(|v| v.as_str())
326 .unwrap_or("")
327 .to_string(),
328 customer_id: object
329 .get("customer")
330 .and_then(|v| v.as_str())
331 .unwrap_or("")
332 .to_string(),
333 },
334 "invoice.payment_failed" => BillingEvent::PaymentFailed {
335 customer_id: object
336 .get("customer")
337 .and_then(|v| v.as_str())
338 .unwrap_or("")
339 .to_string(),
340 invoice_id: object
341 .get("id")
342 .and_then(|v| v.as_str())
343 .unwrap_or("")
344 .to_string(),
345 },
346 _ => BillingEvent::Other {
347 event_type,
348 body,
349 },
350 }
351}
352
/// Percent-encodes `s` for an `application/x-www-form-urlencoded` body.
///
/// Every byte outside the RFC 3986 unreserved set (`A-Z a-z 0-9 - _ . ~`) is
/// emitted as `%XX` with uppercase hex. This includes spaces (`%20`, never
/// `+`) and each byte of a multi-byte UTF-8 sequence.
fn url_encode(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    let mut out = String::with_capacity(s.len());
    for b in s.bytes() {
        if b.is_ascii_alphanumeric() || matches!(b, b'-' | b'_' | b'.' | b'~') {
            out.push(char::from(b));
        } else {
            // Push the escape directly instead of allocating a temporary
            // String per byte via `format!`.
            out.push('%');
            out.push(char::from(HEX[usize::from(b >> 4)]));
            out.push(char::from(HEX[usize::from(b & 0x0F)]));
        }
    }
    out
}
365
/// Renders `bytes` as lowercase hexadecimal, two digits per byte.
fn bytes_to_hex(bytes: &[u8]) -> String {
    use std::fmt::Write;

    bytes
        .iter()
        .fold(String::with_capacity(bytes.len() * 2), |mut acc, byte| {
            // Formatting into a String is infallible.
            let _ = write!(acc, "{byte:02x}");
            acc
        })
}
374
#[cfg(test)]
mod tests {
    use super::*;
    use sha2::Sha256;

    /// Produces the hex `v1` signature Stripe would send for `body` at `ts`.
    fn sign(secret: &str, ts: u64, body: &[u8]) -> String {
        let mut mac =
            Hmac::<Sha256>::new_from_slice(secret.as_bytes()).expect("HMAC accepts any key length");
        mac.update(format!("{ts}.").as_bytes());
        mac.update(body);
        bytes_to_hex(&mac.finalize().into_bytes())
    }

    #[test]
    fn verify_webhook_round_trip_checkout_completed() {
        let secret = "whsec_test_secret";
        let body = br#"{
            "type": "checkout.session.completed",
            "data": { "object": {
                "customer": "cus_xyz",
                "subscription": "sub_abc",
                "client_reference_id": "user_123"
            }}
        }"#;
        let ts = 1_700_000_000;
        let sig = sign(secret, ts, body);
        let header = format!("t={ts},v1={sig}");
        let event = verify_webhook(secret, body, &header, ts).unwrap();
        match event {
            BillingEvent::CheckoutCompleted {
                customer_id,
                subscription_id,
                client_reference_id,
            } => {
                assert_eq!(customer_id.as_deref(), Some("cus_xyz"));
                assert_eq!(subscription_id.as_deref(), Some("sub_abc"));
                assert_eq!(client_reference_id.as_deref(), Some("user_123"));
            }
            other => panic!("expected CheckoutCompleted, got {other:?}"),
        }
    }

    #[test]
    fn verify_webhook_rejects_bad_signature() {
        let body = b"{}";
        let ts = 1_700_000_000;
        let header = format!("t={ts},v1=deadbeefdeadbeef");
        assert_eq!(
            verify_webhook("secret", body, &header, ts),
            Err(WebhookError::BadSignature)
        );
    }

    #[test]
    fn verify_webhook_rejects_stale_timestamp() {
        let secret = "s";
        let body = b"{}";
        let ts = 1_700_000_000;
        let sig = sign(secret, ts, body);
        let header = format!("t={ts},v1={sig}");
        let now = ts + 6 * 60;
        assert_eq!(
            verify_webhook(secret, body, &header, now),
            Err(WebhookError::StaleTimestamp)
        );
    }

    // The tolerance window is symmetric: timestamps too far in the future are
    // rejected just like old ones.
    #[test]
    fn verify_webhook_rejects_future_timestamp() {
        let secret = "s";
        let body = b"{}";
        let ts = 1_700_000_000;
        let sig = sign(secret, ts, body);
        let header = format!("t={ts},v1={sig}");
        assert_eq!(
            verify_webhook(secret, body, &header, ts - 6 * 60),
            Err(WebhookError::StaleTimestamp)
        );
    }

    // Exactly five minutes of skew is still inside the window.
    #[test]
    fn verify_webhook_accepts_timestamp_at_tolerance_edge() {
        let secret = "s";
        let body = b"{}";
        let ts = 1_700_000_000;
        let sig = sign(secret, ts, body);
        let header = format!("t={ts},v1={sig}");
        assert!(verify_webhook(secret, body, &header, ts + 5 * 60).is_ok());
        assert!(verify_webhook(secret, body, &header, ts - 5 * 60).is_ok());
    }

    #[test]
    fn verify_webhook_rejects_non_json_body_with_valid_signature() {
        let secret = "s";
        let body = b"not json";
        let ts = 1_700_000_000;
        let sig = sign(secret, ts, body);
        let header = format!("t={ts},v1={sig}");
        assert_eq!(
            verify_webhook(secret, body, &header, ts),
            Err(WebhookError::BadJson)
        );
    }

    #[test]
    fn verify_webhook_missing_signature_header() {
        let body = b"{}";
        assert_eq!(
            verify_webhook("s", body, "", 0),
            Err(WebhookError::MissingSignature)
        );
        assert_eq!(
            verify_webhook("s", body, "t=100", 100),
            Err(WebhookError::MissingSignature)
        );
    }

    #[test]
    fn parse_subscription_changed() {
        let body = serde_json::json!({
            "type": "customer.subscription.updated",
            "data": { "object": {
                "id": "sub_xyz",
                "customer": "cus_abc",
                "status": "active",
                "current_period_end": 9_999_999_999u64
            }}
        });
        match parse_event(body) {
            BillingEvent::SubscriptionChanged {
                subscription_id,
                customer_id,
                status,
                current_period_end,
            } => {
                assert_eq!(subscription_id, "sub_xyz");
                assert_eq!(customer_id, "cus_abc");
                assert_eq!(status, "active");
                assert_eq!(current_period_end, 9_999_999_999);
            }
            other => panic!("expected SubscriptionChanged, got {other:?}"),
        }
    }

    #[test]
    fn unknown_event_falls_through_to_other() {
        let body = serde_json::json!({"type": "some.weird.event", "data": {}});
        match parse_event(body) {
            BillingEvent::Other { event_type, .. } => {
                assert_eq!(event_type, "some.weird.event");
            }
            other => panic!("expected Other, got {other:?}"),
        }
    }

    #[test]
    fn webhook_accepts_multiple_v1_sigs() {
        let secret = "new_secret";
        let body = br#"{"type":"x"}"#;
        let ts = 1_700_000_000;
        let sig_new = sign(secret, ts, body);
        let header = format!("t={ts},v1=deadbeef,v1={sig_new}");
        assert!(verify_webhook(secret, body, &header, ts).is_ok());
    }

    #[test]
    fn url_encode_escapes_everything_but_unreserved() {
        assert_eq!(url_encode("AZaz09-_.~"), "AZaz09-_.~");
        assert_eq!(url_encode("a b&c=d/e"), "a%20b%26c%3Dd%2Fe");
        assert_eq!(url_encode("é"), "%C3%A9");
    }

    #[test]
    fn bytes_to_hex_is_lowercase_two_digits_per_byte() {
        assert_eq!(bytes_to_hex(&[]), "");
        assert_eq!(bytes_to_hex(&[0x00, 0x0f, 0xab, 0xff]), "000fabff");
    }
}