use std::time::{Duration, SystemTime, UNIX_EPOCH};
use base64::Engine;
use base64::engine::general_purpose::STANDARD as B64_STANDARD;
use serde::{Deserialize, Serialize};
use subtle::ConstantTimeEq;
use crate::error::Error;
use crate::helpers::webhook_verification::{
compute_webhook_signature, parse_webhook_signature_header,
};
use crate::models::EventSchema;
/// Tolerance installed by `ActionsHelper::new`: the largest allowed gap
/// between a signature's timestamp and the helper's clock (30 seconds).
const DEFAULT_TOLERANCE: Duration = Duration::from_millis(30_000);
/// Kind of action a signed response refers to.
///
/// Variants are (de)serialized as snake_case strings: `"authentication"` and
/// `"user_registration"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ActionType {
    /// Serialized as `"authentication"`.
    Authentication,
    /// Serialized as `"user_registration"`.
    UserRegistration,
}
impl ActionType {
fn as_str(self) -> &'static str {
match self {
ActionType::Authentication => "authentication",
ActionType::UserRegistration => "user_registration",
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ActionVerdict {
Allow,
Deny,
}
impl ActionVerdict {
fn as_str(self) -> &'static str {
match self {
ActionVerdict::Allow => "Allow",
ActionVerdict::Deny => "Deny",
}
}
}
/// Signed response produced by `ActionsHelper::sign_response`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActionSignedResponse {
/// Base64 (standard alphabet) encoding of the JSON response body.
pub payload: String,
/// Signature header value in the form `t=<timestamp>,v1=<signature>`, where
/// the signature is computed over `payload`.
pub sig: String,
}
/// Boxed clock callback used to obtain the current time.
type Clock = Box<dyn Fn() -> SystemTime + Send + Sync>;

/// Helper that verifies signature headers and signs action responses.
pub struct ActionsHelper {
    /// Largest accepted gap between a signature's timestamp and the clock.
    tolerance: Duration,
    /// Source of the current time; replaceable through `with_clock`.
    now: Clock,
}
impl Default for ActionsHelper {
fn default() -> Self {
Self::new()
}
}
impl ActionsHelper {
pub fn new() -> Self {
Self {
tolerance: DEFAULT_TOLERANCE,
now: Box::new(SystemTime::now),
}
}
pub fn with_tolerance(mut self, tolerance: Duration) -> Self {
self.tolerance = tolerance;
self
}
pub fn with_clock(mut self, now: impl Fn() -> SystemTime + Send + Sync + 'static) -> Self {
self.now = Box::new(now);
self
}
pub fn verify_header(
&self,
payload: &str,
sig_header: &str,
secret: &str,
) -> Result<(), Error> {
if sig_header.is_empty() {
return Err(Error::Webhook("webhook not signed".to_string()));
}
let (timestamp, signature) = parse_webhook_signature_header(sig_header)?;
let ts: u64 = timestamp
.parse()
.map_err(|_| Error::Webhook("invalid timestamp in signature header".to_string()))?;
let signed_at = UNIX_EPOCH
.checked_add(Duration::from_millis(ts))
.ok_or_else(|| Error::Webhook("invalid timestamp in signature header".to_string()))?;
let now = (self.now)();
let diff = match now.duration_since(signed_at) {
Ok(d) => d,
Err(e) => e.duration(),
};
if diff > self.tolerance {
return Err(Error::Webhook("timestamp outside tolerance".to_string()));
}
let expected = compute_webhook_signature(secret, ×tamp, payload);
if expected.as_bytes().ct_eq(signature.as_bytes()).unwrap_u8() != 1 {
return Err(Error::Webhook("no valid signature found".to_string()));
}
Ok(())
}
/// Verifies the signature header and then deserializes `payload` as JSON
/// into an `EventSchema`.
///
/// # Errors
///
/// Returns whatever `verify_header` reports when verification fails, or the
/// converted `serde_json` error when the payload cannot be deserialized.
pub fn construct_action(
    &self,
    payload: &str,
    sig_header: &str,
    secret: &str,
) -> Result<EventSchema, Error> {
    // Nothing is parsed until the signature has been checked.
    self.verify_header(payload, sig_header, secret)?;
    let event: EventSchema = serde_json::from_str(payload)?;
    Ok(event)
}
/// Builds and signs an action response.
///
/// The response body is a JSON object with the fields `type`, `verdict` and
/// `error_message`. It is serialized and base64-encoded with the standard
/// alphabet. The signature is computed by `compute_webhook_signature` over
/// the base64 text together with the current time in milliseconds since the
/// Unix epoch, and returned as `t=<timestamp>,v1=<signature>`.
///
/// # Errors
///
/// Returns the converted `serde_json` error if the body cannot be serialized,
/// or `Error::Crypto` if the helper's clock reports a time before the Unix
/// epoch.
pub fn sign_response(
    &self,
    action_type: ActionType,
    verdict: ActionVerdict,
    error_message: &str,
    secret: &str,
) -> Result<ActionSignedResponse, Error> {
    let body = serde_json::json!({
        "type": action_type.as_str(),
        "verdict": verdict.as_str(),
        "error_message": error_message,
    });
    let json_bytes = serde_json::to_vec(&body)?;
    let b64_payload = B64_STANDARD.encode(&json_bytes);

    // Format the `u128` millisecond count directly; narrowing it with `as`
    // first could silently wrap.
    let timestamp = (self.now)()
        .duration_since(UNIX_EPOCH)
        .map_err(|e| Error::Crypto(format!("clock before epoch: {e}")))?
        .as_millis()
        .to_string();

    // The signature covers the base64 text, not the raw JSON bytes.
    let sig = compute_webhook_signature(secret, &timestamp, &b64_payload);
    Ok(ActionSignedResponse {
        payload: b64_payload,
        sig: format!("t={timestamp},v1={sig}"),
    })
}
}
#[cfg(test)]
mod tests {
use super::*;
/// Returns the instant `ms` milliseconds after the Unix epoch, for use as a
/// deterministic clock in tests.
fn fixed_now(ms: u64) -> SystemTime {
    let since_epoch = Duration::from_millis(ms);
    UNIX_EPOCH + since_epoch
}
// A header signed at exactly the helper's clock time must verify.
#[test]
fn verify_header_round_trip() {
    const SIGNED_AT_MS: u64 = 1_700_000_000_000;
    let (secret, body) = ("shh", "hello");

    let stamp = SIGNED_AT_MS.to_string();
    let signature = compute_webhook_signature(secret, &stamp, body);
    let header = format!("t={stamp},v1={signature}");

    let helper = ActionsHelper::new().with_clock(|| fixed_now(SIGNED_AT_MS));
    helper.verify_header(body, &header, secret).unwrap();
}
// A correctly signed header is rejected when the clock is 60 s ahead of the
// signature, which exceeds the 30 s default tolerance.
#[test]
fn verify_header_rejects_outside_tolerance() {
    let (secret, body) = ("shh", "hello");
    let signed_at_ms: u64 = 1_700_000_000_000;

    let signature = compute_webhook_signature(secret, &signed_at_ms.to_string(), body);
    let header = format!("t={signed_at_ms},v1={signature}");

    let helper = ActionsHelper::new().with_clock(move || fixed_now(signed_at_ms + 60_000));
    let outcome = helper.verify_header(body, &header, secret);
    assert!(outcome.is_err());
}
// A signed response decodes back to the expected JSON body, and its signature
// header matches a signature recomputed over the base64 payload.
#[test]
fn sign_response_round_trips() {
    let secret = "shh";
    let response = ActionsHelper::new()
        .with_clock(|| fixed_now(1_700_000_000_000))
        .sign_response(ActionType::Authentication, ActionVerdict::Allow, "", secret)
        .unwrap();

    let decoded = B64_STANDARD.decode(&response.payload).unwrap();
    let body: serde_json::Value = serde_json::from_slice(&decoded).unwrap();
    assert_eq!(body["type"], "authentication");
    assert_eq!(body["verdict"], "Allow");

    let (timestamp, signature) = parse_webhook_signature_header(&response.sig).unwrap();
    let recomputed = compute_webhook_signature(secret, &timestamp, &response.payload);
    assert_eq!(signature, recomputed);
}
}