bctx-cloud-core 0.1.4

bctx-cloud-core — cloud client and server for Vault sync, the dashboard API, and billing
Documentation
use crate::server::{AppError, AppState};
use anyhow::anyhow;
use axum::{
    body::Bytes, extract::State, http::HeaderMap, http::StatusCode, response::IntoResponse,
};
use hmac::{Hmac, Mac};
use serde::Deserialize;
use sha2::Sha256;

// ── Stripe price ID → tier mapping ───────────────────────────────────────────
// Set BCTX_STRIPE_BEACON_PRICE, BCTX_STRIPE_STUDIO_PRICE, BCTX_STRIPE_ENTERPRISE_PRICE
// env vars to match your Stripe dashboard price IDs.

/// Maps a Stripe price ID to the internal tier name it is configured for.
///
/// Each tier's price ID comes from its environment variable, checked in the order beacon,
/// studio, enterprise. The first non-empty variable equal to `price_id` wins. Unset or empty
/// variables never match. Returns `None` when no configured price matches.
fn tier_for_price(price_id: &str) -> Option<&'static str> {
    const PRICE_ENV_TIERS: [(&str, &str); 3] = [
        ("BCTX_STRIPE_BEACON_PRICE", "beacon"),
        ("BCTX_STRIPE_STUDIO_PRICE", "studio"),
        ("BCTX_STRIPE_ENTERPRISE_PRICE", "enterprise"),
    ];

    PRICE_ENV_TIERS.iter().find_map(|&(env_key, tier_name)| {
        let configured = std::env::var(env_key).unwrap_or_default();
        // An empty setting means "not configured" and must not match an empty price_id.
        if configured.is_empty() || configured != price_id {
            None
        } else {
            Some(tier_name)
        }
    })
}

// ── Stripe event shapes (minimal subset we care about) ────────────────────────

/// Envelope of an incoming Stripe webhook event.
///
/// Only the fields this module reads are declared. No `deny_unknown_fields` is set, so serde
/// ignores every other key in the payload.
#[derive(Deserialize, Debug)]
struct StripeEvent {
    /// Event name, e.g. `checkout.session.completed` (JSON key `type`).
    #[serde(rename = "type")]
    event_type: String,
    /// Wrapper around the Stripe object the event refers to.
    data: StripeEventData,
}

/// The `data` member of a [`StripeEvent`].
#[derive(Deserialize, Debug)]
struct StripeEventData {
    /// The affected Stripe object, kept as untyped JSON because handlers index into it by key.
    object: serde_json::Value,
}

// ── Webhook handler ───────────────────────────────────────────────────────────

/// POST /billing/webhook — receives Stripe events and updates user tiers.
/// POST /billing/webhook — receives Stripe events and updates user tiers.
///
/// The raw body is authenticated against `BCTX_STRIPE_WEBHOOK_SECRET` using the
/// `Stripe-Signature` header before it is parsed. The event is then dispatched on its `type`:
/// checkout completion and subscription updates/deletions change the user's tier. A failed
/// invoice payment is only logged. Other event types are acknowledged and ignored.
///
/// # Errors
/// Returns [`AppError`] when:
/// - the webhook secret is not configured;
/// - the `Stripe-Signature` header is missing or does not verify;
/// - the body is not a valid event;
/// - applying a tier change fails.
pub async fn handle(
    State(state): State<AppState>,
    headers: HeaderMap,
    body: Bytes,
) -> Result<impl IntoResponse, AppError> {
    // Fail closed: without a secret there is no way to authenticate the request, and an
    // unauthenticated endpoint would let any caller forge events that change user tiers.
    let secret = std::env::var("BCTX_STRIPE_WEBHOOK_SECRET").unwrap_or_default();
    if secret.is_empty() {
        return Err(AppError(anyhow!(
            "BCTX_STRIPE_WEBHOOK_SECRET is not configured; refusing unverified Stripe webhook"
        )));
    }

    let sig = headers
        .get("Stripe-Signature")
        .and_then(|v| v.to_str().ok())
        .ok_or_else(|| AppError(anyhow!("missing Stripe-Signature header")))?;
    verify_stripe_signature(&body, sig, &secret)
        .map_err(|e| AppError(anyhow!("signature verification failed: {e}")))?;

    let event: StripeEvent = serde_json::from_slice(&body)
        .map_err(|e| AppError(anyhow!("invalid event payload: {e}")))?;

    match event.event_type.as_str() {
        "checkout.session.completed" => {
            handle_checkout_completed(&state, &event.data.object)?;
        }
        "customer.subscription.updated" => {
            handle_subscription_updated(&state, &event.data.object)?;
        }
        "customer.subscription.deleted" => {
            handle_subscription_deleted(&state, &event.data.object)?;
        }
        "invoice.payment_failed" => {
            // Log but don't downgrade immediately — give a grace period
            tracing::warn!(
                customer = event.data.object["customer"].as_str().unwrap_or("?"),
                "Stripe payment failed"
            );
        }
        other => {
            tracing::debug!(event_type = other, "unhandled Stripe event");
        }
    }

    Ok((StatusCode::OK, "ok"))
}

/// Applies a `checkout.session.completed` event by upgrading the purchaser's tier.
///
/// - Email: `customer_details.email`, falling back to the session's top-level
///   `customer_email` when the former is missing, null, or empty.
/// - Price ID: `metadata.price_id`, falling back to the first entry of `line_items.data`.
/// - Tier: unknown or missing prices default to `"beacon"`.
///
/// # Errors
/// Propagates any error from `upgrade_user_by_email`.
fn handle_checkout_completed(state: &AppState, obj: &serde_json::Value) -> anyhow::Result<()> {
    // Without the fallback, a session whose `customer_details` is null would skip the
    // upgrade silently, because an empty email is treated as "nothing to do".
    let email = obj["customer_details"]["email"]
        .as_str()
        .filter(|e| !e.is_empty())
        .or_else(|| obj["customer_email"].as_str())
        .unwrap_or("");
    // NOTE(review): Stripe omits `line_items` from webhook payloads unless they are expanded,
    // so in practice `metadata.price_id` is probably the source that matters — confirm.
    let price_id = obj["metadata"]["price_id"]
        .as_str()
        .or_else(|| obj["line_items"]["data"][0]["price"]["id"].as_str())
        .unwrap_or("");

    let tier = tier_for_price(price_id).unwrap_or("beacon");
    upgrade_user_by_email(state, email, tier)
}

/// Applies a `customer.subscription.updated` event.
///
/// - `active` / `trialing`: sets the tier that matches the first subscription item's price
///   (unknown prices default to `"beacon"`).
/// - `canceled`, `unpaid`, `incomplete_expired`, `paused`: the subscription no longer pays
///   for access, so the user drops to `"free"`.
/// - Anything else (e.g. `past_due`, `incomplete`): the tier is left unchanged. This gives the
///   same grace period the `invoice.payment_failed` handling relies on.
///
/// A missing `status` is treated as `"active"`.
///
/// # Errors
/// Propagates any error from `upgrade_user_by_email`.
fn handle_subscription_updated(state: &AppState, obj: &serde_json::Value) -> anyhow::Result<()> {
    let price_id = obj["items"]["data"][0]["price"]["id"]
        .as_str()
        .unwrap_or("");
    // NOTE(review): Stripe Subscription objects do not appear to carry a top-level
    // `customer_email`; in practice the email likely has to come from `metadata.email` — confirm.
    let customer_email = obj["metadata"]["email"]
        .as_str()
        .or_else(|| obj["customer_email"].as_str())
        .unwrap_or("");
    let status = obj["status"].as_str().unwrap_or("active");

    match status {
        "active" | "trialing" => {
            let tier = tier_for_price(price_id).unwrap_or("beacon");
            upgrade_user_by_email(state, customer_email, tier)
        }
        // These end access: `incomplete_expired` means the first invoice was never paid, and
        // `paused` means the trial ended without a payment method.
        "canceled" | "unpaid" | "incomplete_expired" | "paused" => {
            upgrade_user_by_email(state, customer_email, "free")
        }
        _ => Ok(()),
    }
}

/// Applies a `customer.subscription.deleted` event by moving the customer to the `"free"` tier.
///
/// The customer is identified by `metadata.email`, falling back to a top-level
/// `customer_email`. If neither is a string, an empty email is passed, which
/// `upgrade_user_by_email` treats as a no-op.
///
/// # Errors
/// Propagates any error from `upgrade_user_by_email`.
fn handle_subscription_deleted(state: &AppState, obj: &serde_json::Value) -> anyhow::Result<()> {
    let email = match obj["metadata"]["email"].as_str() {
        Some(from_metadata) => from_metadata,
        None => obj["customer_email"].as_str().unwrap_or(""),
    };
    upgrade_user_by_email(state, email, "free")
}

fn upgrade_user_by_email(state: &AppState, email: &str, tier: &str) -> anyhow::Result<()> {
    if email.is_empty() {
        return Ok(());
    }
    let conn = state.db.conn();
    let uid: Option<String> = conn
        .query_row(
            "SELECT id FROM users WHERE email=?1",
            rusqlite::params![email],
            |row| row.get(0),
        )
        .ok();
    drop(conn);
    if let Some(uid) = uid {
        state.db.upsert_user(&uid, email, tier)?;
        tracing::info!(email, tier, "user tier updated via Stripe webhook");
    }
    Ok(())
}

// ── Stripe signature verification (HMAC-SHA256, Stripe spec) ─────────────────
// https://docs.stripe.com/webhooks#verify-official-libraries
// Signed payload = "<timestamp>.<raw_body>"
// Expected = HMAC-SHA256(key=webhook_secret, data=signed_payload) as lowercase hex

/// Verifies a `Stripe-Signature` header against the raw request body.
///
/// The header is a comma-separated list of `key=value` pairs.
/// - `t=<unix seconds>` gives the signing timestamp.
/// - Every `v1=<hex>` entry is a candidate HMAC-SHA256 signature. Stripe can send more than one
///   while endpoint secrets are being rolled.
///
/// The expected signature is `HMAC-SHA256(secret, "<t>.<raw body>")` as lowercase hex. The
/// header is accepted when any `v1` candidate matches it and the timestamp is no more than
/// 300 seconds in the past, which limits replay of captured requests. 300 seconds is the
/// default tolerance of Stripe's official libraries.
///
/// # Errors
/// - The header lacks a timestamp or any `v1` entry.
/// - The timestamp is not an integer.
/// - No candidate signature matches.
/// - The timestamp is too old.
/// - The system clock is before the UNIX epoch.
fn verify_stripe_signature(payload: &[u8], sig_header: &str, secret: &str) -> anyhow::Result<()> {
    const TOLERANCE_SECS: u64 = 300;

    let mut timestamp = "";
    let mut candidates: Vec<&str> = Vec::new();
    for part in sig_header.split(',').map(str::trim) {
        if let Some(v) = part.strip_prefix("t=") {
            timestamp = v;
        } else if let Some(v) = part.strip_prefix("v1=") {
            candidates.push(v);
        }
    }
    if timestamp.is_empty() || candidates.is_empty() {
        anyhow::bail!("malformed Stripe-Signature header");
    }
    let signed_at: u64 = timestamp
        .parse()
        .map_err(|_| anyhow!("malformed Stripe-Signature timestamp"))?;

    type HmacSha256 = Hmac<Sha256>;
    let mut mac = HmacSha256::new_from_slice(secret.as_bytes())
        .map_err(|e| anyhow!("invalid HMAC key: {e}"))?;
    // Feed the signed payload "<t>.<body>" piecewise over the raw bytes, so a body that is not
    // valid UTF-8 is still signed exactly as Stripe signed it.
    mac.update(timestamp.as_bytes());
    mac.update(b".");
    mac.update(payload);
    let expected_hex: String = mac
        .finalize()
        .into_bytes()
        .iter()
        .map(|b| format!("{b:02x}"))
        .collect();

    let matched = candidates
        .iter()
        .any(|sig| constant_time_eq(expected_hex.as_bytes(), sig.as_bytes()));
    if !matched {
        anyhow::bail!("signature mismatch");
    }

    let now = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map_err(|e| anyhow!("system clock is before the UNIX epoch: {e}"))?
        .as_secs();
    // Only reject stale signatures; future timestamps (small clock skew) yield age 0.
    if now.saturating_sub(signed_at) > TOLERANCE_SECS {
        anyhow::bail!("timestamp outside the tolerance zone");
    }
    Ok(())
}

/// Returns `true` when `a` and `b` hold identical bytes.
///
/// For equal-length inputs the XOR of every byte pair is OR-ed into an accumulator. Every byte
/// is visited regardless of where the first difference is, so timing does not reveal the
/// matching prefix length. A length mismatch returns `false` straight away, which only reveals
/// the lengths.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        let mut diff = 0u8;
        for (lhs, rhs) in a.iter().zip(b) {
            diff |= lhs ^ rhs;
        }
        diff == 0
    } else {
        false
    }
}