use crate::server::{AppError, AppState};
use anyhow::anyhow;
use axum::{
body::Bytes, extract::State, http::HeaderMap, http::StatusCode, response::IntoResponse,
};
use hmac::{Hmac, Mac};
use serde::Deserialize;
use sha2::Sha256;
/// Maps a Stripe price id to the internal subscription tier name.
///
/// Each tier's price id is read from its own environment variable
/// (`BCTX_STRIPE_BEACON_PRICE`, `BCTX_STRIPE_STUDIO_PRICE`,
/// `BCTX_STRIPE_ENTERPRISE_PRICE`). Unset or empty variables never match.
/// Returns `None` when `price_id` matches no configured price.
fn tier_for_price(price_id: &str) -> Option<&'static str> {
    // Checked in order; the first configured price equal to `price_id` wins.
    const PRICE_ENV_TIERS: [(&str, &str); 3] = [
        ("BCTX_STRIPE_BEACON_PRICE", "beacon"),
        ("BCTX_STRIPE_STUDIO_PRICE", "studio"),
        ("BCTX_STRIPE_ENTERPRISE_PRICE", "enterprise"),
    ];
    PRICE_ENV_TIERS.iter().find_map(|&(env_key, tier)| {
        let configured = std::env::var(env_key).unwrap_or_default();
        if !configured.is_empty() && configured == price_id {
            Some(tier)
        } else {
            None
        }
    })
}
/// Minimal view of a Stripe webhook event: only the fields this module reads.
/// Any other JSON fields in the event are ignored by serde's default behaviour.
#[derive(Debug, Deserialize)]
struct StripeEvent {
    /// Event type string (JSON key `type`), e.g. `"checkout.session.completed"`;
    /// `handle` dispatches on this value.
    #[serde(rename = "type")]
    event_type: String,
    /// The event's `data` member.
    data: StripeEventData,
}
/// The `data` member of a Stripe event.
#[derive(Debug, Deserialize)]
struct StripeEventData {
    /// The event's subject object (checkout session, subscription, invoice, ...),
    /// kept as untyped JSON and inspected field-by-field by the handlers.
    object: serde_json::Value,
}
/// Axum handler for Stripe webhook deliveries.
///
/// When `BCTX_STRIPE_WEBHOOK_SECRET` is set, the `Stripe-Signature` header is
/// required and verified against the raw body before the body is parsed. The
/// event is then dispatched on its `type`. Unrecognised types are logged at
/// debug level and acknowledged with `200 OK`.
///
/// # Errors
/// Returns an [`AppError`] when the signature header is missing or invalid,
/// when the body is not a valid event, or when a tier update fails.
pub async fn handle(
    State(state): State<AppState>,
    headers: HeaderMap,
    body: Bytes,
) -> Result<impl IntoResponse, AppError> {
    let secret = std::env::var("BCTX_STRIPE_WEBHOOK_SECRET").unwrap_or_default();
    if secret.is_empty() {
        // Fail-open: without a secret, anyone can POST a forged event and change a
        // user's tier. Presumably kept for local development; warn on every request
        // so a misconfigured deployment is noticed. Consider failing closed in production.
        tracing::warn!(
            "BCTX_STRIPE_WEBHOOK_SECRET is not set; accepting Stripe webhook without signature verification"
        );
    } else {
        let sig = headers
            .get("Stripe-Signature")
            .and_then(|v| v.to_str().ok())
            .ok_or_else(|| AppError(anyhow!("missing Stripe-Signature header")))?;
        verify_stripe_signature(&body, sig, &secret)
            .map_err(|e| AppError(anyhow!("signature verification failed: {e}")))?;
    }
    let event: StripeEvent = serde_json::from_slice(&body)
        .map_err(|e| AppError(anyhow!("invalid event payload: {e}")))?;
    match event.event_type.as_str() {
        "checkout.session.completed" => {
            handle_checkout_completed(&state, &event.data.object)?;
        }
        "customer.subscription.updated" => {
            handle_subscription_updated(&state, &event.data.object)?;
        }
        "customer.subscription.deleted" => {
            handle_subscription_deleted(&state, &event.data.object)?;
        }
        "invoice.payment_failed" => {
            tracing::warn!(
                customer = event.data.object["customer"].as_str().unwrap_or("?"),
                "Stripe payment failed"
            );
        }
        other => {
            tracing::debug!(event_type = other, "unhandled Stripe event");
        }
    }
    Ok((StatusCode::OK, "ok"))
}
/// Applies the tier bought in a completed Checkout Session to the matching user.
///
/// The customer email comes from `customer_details.email`, falling back to the
/// session's top-level `customer_email`. This matches the fallback used by the
/// subscription handlers. The price id comes from `metadata.price_id`, falling
/// back to `line_items.data[0].price.id`. A missing or unrecognised price
/// defaults to the `"beacon"` tier. A missing email makes the update a no-op.
fn handle_checkout_completed(state: &AppState, obj: &serde_json::Value) -> anyhow::Result<()> {
    let email = obj["customer_details"]["email"]
        .as_str()
        .or_else(|| obj["customer_email"].as_str())
        .unwrap_or("");
    // NOTE(review): Stripe documents `line_items` as an expandable field; confirm it is
    // actually present in webhook payloads, otherwise this fallback never fires.
    let price_id = obj["metadata"]["price_id"]
        .as_str()
        .or_else(|| obj["line_items"]["data"][0]["price"]["id"].as_str())
        .unwrap_or("");
    let tier = tier_for_price(price_id).unwrap_or("beacon");
    upgrade_user_by_email(state, email, tier)
}
/// Syncs a user's tier with an updated Stripe subscription.
///
/// - `active` / `trialing`: the user gets the tier for the subscription's first
///   item price (default `"beacon"` when the price is unknown).
/// - `canceled`, `unpaid`, `incomplete_expired`, `paused`: the subscription grants
///   no paid access, so the user is moved to `"free"`.
/// - Any other status (e.g. `past_due`, `incomplete`) leaves the tier unchanged.
///
/// A missing `status` is treated as `active`, as before.
fn handle_subscription_updated(state: &AppState, obj: &serde_json::Value) -> anyhow::Result<()> {
    let price_id = obj["items"]["data"][0]["price"]["id"]
        .as_str()
        .unwrap_or("");
    // NOTE(review): Subscription objects may not carry a top-level `customer_email`;
    // confirm `metadata.email` is always set when creating subscriptions.
    let customer_email = obj["metadata"]["email"]
        .as_str()
        .or_else(|| obj["customer_email"].as_str())
        .unwrap_or("");
    match obj["status"].as_str().unwrap_or("active") {
        "active" | "trialing" => {
            let tier = tier_for_price(price_id).unwrap_or("beacon");
            upgrade_user_by_email(state, customer_email, tier)
        }
        // Previously only `canceled`/`unpaid` downgraded. A subscription whose first
        // payment expired or that was paused also must not keep paid access.
        "canceled" | "unpaid" | "incomplete_expired" | "paused" => {
            upgrade_user_by_email(state, customer_email, "free")
        }
        _ => Ok(()),
    }
}
/// Moves the owner of a deleted subscription back to the `"free"` tier.
///
/// The owner's email is the first string found in `metadata.email`, then
/// `customer_email`. If neither is present, the update is a no-op.
fn handle_subscription_deleted(state: &AppState, obj: &serde_json::Value) -> anyhow::Result<()> {
    let candidates = [&obj["metadata"]["email"], &obj["customer_email"]];
    let owner_email = candidates
        .iter()
        .find_map(|field| field.as_str())
        .unwrap_or("");
    upgrade_user_by_email(state, owner_email, "free")
}
fn upgrade_user_by_email(state: &AppState, email: &str, tier: &str) -> anyhow::Result<()> {
if email.is_empty() {
return Ok(());
}
let conn = state.db.conn();
let uid: Option<String> = conn
.query_row(
"SELECT id FROM users WHERE email=?1",
rusqlite::params![email],
|row| row.get(0),
)
.ok();
drop(conn);
if let Some(uid) = uid {
state.db.upsert_user(&uid, email, tier)?;
tracing::info!(email, tier, "user tier updated via Stripe webhook");
}
Ok(())
}
/// Verifies a `Stripe-Signature` header against the raw webhook body.
///
/// The header is a comma-separated list of `key=value` pairs. `t=` carries the
/// Unix timestamp (seconds) at which the event was signed. Each `v1=` carries a
/// hex-encoded HMAC-SHA256 of `"{t}.{body}"`, keyed with `secret`. Several `v1=`
/// entries may be present (e.g. while a secret is being rolled); the request is
/// accepted if any of them matches. The body is fed to the HMAC as raw bytes.
///
/// # Errors
/// Fails when the header is malformed (no `t=`, non-numeric `t=`, or no
/// non-empty `v1=`). It also fails when the signed timestamp is more than
/// `TOLERANCE_SECS` in the past (replay protection) or when no `v1` signature
/// matches.
fn verify_stripe_signature(payload: &[u8], sig_header: &str, secret: &str) -> anyhow::Result<()> {
    // Maximum accepted age of a signed event; 300s is the default used by
    // Stripe's official client libraries.
    const TOLERANCE_SECS: u64 = 300;

    let mut timestamp: Option<&str> = None;
    let mut signatures: Vec<&str> = Vec::new();
    for part in sig_header.split(',').map(str::trim) {
        if let Some(v) = part.strip_prefix("t=") {
            // As before, the last `t=` entry wins.
            timestamp = Some(v);
        } else if let Some(v) = part.strip_prefix("v1=") {
            if !v.is_empty() {
                signatures.push(v);
            }
        }
    }
    let timestamp = match timestamp {
        Some(t) if !t.is_empty() && !signatures.is_empty() => t,
        _ => anyhow::bail!("malformed Stripe-Signature header"),
    };
    let signed_at: u64 = timestamp
        .parse()
        .map_err(|_| anyhow!("malformed Stripe-Signature header"))?;

    let now = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map_err(|e| anyhow!("system clock is before the Unix epoch: {e}"))?
        .as_secs();
    if now.saturating_sub(signed_at) > TOLERANCE_SECS {
        anyhow::bail!("Stripe-Signature timestamp outside tolerance");
    }

    type HmacSha256 = Hmac<Sha256>;
    let mut mac = HmacSha256::new_from_slice(secret.as_bytes())
        .map_err(|e| anyhow!("invalid HMAC key: {e}"))?;
    // Sign the exact bytes `"{t}." + body`; no UTF-8 conversion of the body.
    mac.update(timestamp.as_bytes());
    mac.update(b".");
    mac.update(payload);
    let expected_hex: String = mac
        .finalize()
        .into_bytes()
        .iter()
        .map(|b| format!("{b:02x}"))
        .collect();

    let matched = signatures
        .iter()
        .any(|sig| constant_time_eq(expected_hex.as_bytes(), sig.as_bytes()));
    if !matched {
        anyhow::bail!("signature mismatch");
    }
    Ok(())
}
/// Compares two byte slices for equality without short-circuiting on the first
/// differing byte.
///
/// Slices of different length return `false` immediately. Only the contents,
/// not the length, are compared in a data-independent loop.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len() && {
        let mut diff = 0u8;
        for (lhs, rhs) in a.iter().zip(b) {
            diff |= lhs ^ rhs;
        }
        diff == 0
    }
}