use crate::core::types::{EventType, InfraEvent};
use std::collections::HashMap;
/// Settings that control which incoming webhook requests are accepted.
#[derive(Clone, Debug)]
pub struct WebhookConfig {
    /// Port for the webhook listener (not read by the functions in this module).
    pub port: u16,
    /// Shared secret used to verify the `x-forjar-signature` header; `None` skips the check.
    pub secret: Option<String>,
    /// Largest accepted request body, in bytes.
    pub max_body_bytes: usize,
    /// Request paths that are accepted (exact match); an empty list accepts every path.
    pub allowed_paths: Vec<String>,
}

impl Default for WebhookConfig {
    /// Port 8484, no secret, a 64 KiB body limit, and only the `/webhook` path allowed.
    fn default() -> Self {
        const DEFAULT_PATH: &str = "/webhook";
        WebhookConfig {
            port: 8484,
            secret: None,
            max_body_bytes: 64 * 1024,
            allowed_paths: vec![String::from(DEFAULT_PATH)],
        }
    }
}
/// A received HTTP request, reduced to the parts webhook handling inspects.
#[derive(Clone, Debug)]
pub struct WebhookRequest {
    /// HTTP method as received (e.g. `"POST"`).
    pub method: String,
    /// Request path, compared exactly against `WebhookConfig::allowed_paths`.
    pub path: String,
    /// Request headers keyed by header name.
    pub headers: HashMap<String, String>,
    /// Raw request body, expected to hold a JSON object.
    pub body: String,
    /// Peer address, when the caller knows it.
    pub source_ip: Option<String>,
}
/// Outcome of `validate_request`: `Valid`, or the first check that rejected the request.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ValidationResult {
    /// Every check passed.
    Valid,
    /// The body was `size` bytes, above the configured limit of `max` bytes.
    BodyTooLarge { size: usize, max: usize },
    /// The request path is not in the configured allow-list.
    PathNotAllowed { path: String },
    /// A secret is configured but the request carried no signature header.
    SignatureMissing,
    /// The signature header did not match the expected value.
    SignatureInvalid,
    /// The request used a method other than POST.
    MethodNotAllowed { method: String },
}

impl ValidationResult {
    /// Returns `true` exactly when this is [`ValidationResult::Valid`].
    pub fn is_valid(&self) -> bool {
        *self == ValidationResult::Valid
    }
}
/// Validates an incoming webhook request against `config`.
///
/// The checks run in this order, and the first failure is returned:
/// 1. the method must be `POST` (ASCII case-insensitive);
/// 2. the body must be at most `config.max_body_bytes` bytes;
/// 3. when `config.allowed_paths` is non-empty, the path must equal one of its entries;
/// 4. when `config.secret` is set, the `x-forjar-signature` header must be present and
///    equal to `compute_hmac_hex(secret, body)`. The header name is matched
///    case-insensitively, as HTTP header names are; an exact-case key is preferred.
///
/// The signature comparison does not stop at the first differing byte, so response timing
/// does not reveal how much of a forged signature was correct.
///
/// Returns [`ValidationResult::Valid`] when every check passes.
pub fn validate_request(config: &WebhookConfig, request: &WebhookRequest) -> ValidationResult {
    const SIGNATURE_HEADER: &str = "x-forjar-signature";

    /// Byte-string equality that folds over every byte instead of short-circuiting.
    /// Only the lengths (public for a fixed-size MAC) decide an early `false`.
    /// Best-effort: the optimizer is not formally guaranteed to preserve the timing.
    fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
        a.len() == b.len() && a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
    }

    // ASCII-only comparison: Unicode `to_uppercase` would map e.g. "poſt" (long s) to "POST".
    if !request.method.eq_ignore_ascii_case("POST") {
        return ValidationResult::MethodNotAllowed {
            method: request.method.clone(),
        };
    }

    let size = request.body.len();
    if size > config.max_body_bytes {
        return ValidationResult::BodyTooLarge {
            size,
            max: config.max_body_bytes,
        };
    }

    if !config.allowed_paths.is_empty() && !config.allowed_paths.iter().any(|p| p == &request.path)
    {
        return ValidationResult::PathNotAllowed {
            path: request.path.clone(),
        };
    }

    if let Some(secret) = config.secret.as_deref() {
        // Exact key first (the common, already-normalized case), then any casing.
        let provided = request.headers.get(SIGNATURE_HEADER).or_else(|| {
            request
                .headers
                .iter()
                .find(|(name, _)| name.eq_ignore_ascii_case(SIGNATURE_HEADER))
                .map(|(_, value)| value)
        });
        match provided {
            None => return ValidationResult::SignatureMissing,
            Some(sig) => {
                let expected = compute_hmac_hex(secret, &request.body);
                if !constant_time_eq(sig.as_bytes(), expected.as_bytes()) {
                    return ValidationResult::SignatureInvalid;
                }
            }
        }
    }

    ValidationResult::Valid
}
/// Parses a webhook body into a flat map of string values.
///
/// The body must be a JSON object. String values are stored as-is (without quotes). Every
/// other value (numbers, booleans, `null`, arrays, nested objects) is stored as its compact
/// JSON text, e.g. `42` or `{"a":1}`.
///
/// # Errors
/// Returns `"invalid JSON: …"` when the body does not parse, and
/// `"webhook body must be a JSON object"` when it parses to something other than an object.
pub fn parse_json_payload(body: &str) -> Result<HashMap<String, String>, String> {
    let parsed = serde_json::from_str::<serde_json::Value>(body)
        .map_err(|err| format!("invalid JSON: {err}"))?;
    if let serde_json::Value::Object(fields) = parsed {
        let flattened = fields
            .into_iter()
            .map(|(name, field)| {
                let text = if let serde_json::Value::String(s) = field {
                    s
                } else {
                    field.to_string()
                };
                (name, text)
            })
            .collect();
        Ok(flattened)
    } else {
        Err(String::from("webhook body must be a JSON object"))
    }
}
pub fn request_to_event(request: &WebhookRequest) -> Result<InfraEvent, String> {
let mut payload = parse_json_payload(&request.body)?;
payload.insert("_path".to_string(), request.path.clone());
if let Some(ref ip) = request.source_ip {
payload.insert("_source_ip".to_string(), ip.clone());
}
Ok(InfraEvent {
event_type: EventType::WebhookReceived,
timestamp: now_iso8601(),
machine: None,
payload,
})
}
/// Computes a hex-encoded BLAKE3 keyed hash (a MAC) of `data` under `key`.
///
/// Keyed BLAKE3 needs a 32-byte key, so the secret is first hashed with plain BLAKE3 to
/// derive one. Despite the name, this is not the RFC 2104 HMAC construction. The result is
/// 64 hex characters and is deterministic for a given `(key, data)` pair.
pub fn compute_hmac_hex(key: &str, data: &str) -> String {
    let derived: blake3::Hash = blake3::hash(key.as_bytes());
    blake3::keyed_hash(derived.as_bytes(), data.as_bytes())
        .to_hex()
        .to_string()
}
/// Builds a complete HTTP/1.1 response whose JSON body is `{"status":"<message>"}`.
///
/// `message` is escaped as a JSON string, so quotes, backslashes and control characters
/// cannot corrupt the body. The reason phrase comes from `status_reason`, and
/// `Content-Length` is the body's length in bytes (not characters).
pub fn ack_response(status: u16, message: &str) -> String {
    let escaped = escape_json_string(message);
    let body = format!(r#"{{"status":"{escaped}"}}"#);
    let reason = status_reason(status);
    let len = body.len();
    format!(
        "HTTP/1.1 {status} {reason}\r\n\
         Content-Type: application/json\r\n\
         Content-Length: {len}\r\n\
         \r\n\
         {body}",
    )
}

/// Escapes `raw` so it can sit between double quotes in a JSON document (RFC 8259 §7).
/// Non-ASCII characters are passed through unchanged, which is valid UTF-8 JSON.
fn escape_json_string(raw: &str) -> String {
    use std::fmt::Write as _;
    let mut out = String::with_capacity(raw.len());
    for ch in raw.chars() {
        match ch {
            '"' => out.push_str("\\\""),
            '\\' => out.push_str("\\\\"),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\t' => out.push_str("\\t"),
            // Remaining C0 control characters must use the \uXXXX form.
            c if (c as u32) < 0x20 => {
                // Writing into a String cannot fail.
                let _ = write!(out, "\\u{:04x}", c as u32);
            }
            c => out.push(c),
        }
    }
    out
}

/// Returns the standard reason phrase for `code`, or `"Unknown"` for codes not listed here.
fn status_reason(code: u16) -> &'static str {
    match code {
        200 => "OK",
        201 => "Created",
        202 => "Accepted",
        400 => "Bad Request",
        401 => "Unauthorized",
        403 => "Forbidden",
        404 => "Not Found",
        405 => "Method Not Allowed",
        413 => "Payload Too Large",
        429 => "Too Many Requests",
        500 => "Internal Server Error",
        503 => "Service Unavailable",
        _ => "Unknown",
    }
}
/// Returns the current UTC time as an ISO 8601 / RFC 3339 timestamp, e.g. `2023-11-14T22:13:20Z`.
///
/// The previous implementation returned `"<unix-seconds>Z"`, which is not ISO 8601.
/// A system clock set before the Unix epoch yields `1970-01-01T00:00:00Z` instead of panicking.
fn now_iso8601() -> String {
    let secs = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();
    format_unix_secs_utc(secs)
}

/// Formats whole seconds since the Unix epoch as `YYYY-MM-DDTHH:MM:SSZ` (UTC, no leap seconds).
fn format_unix_secs_utc(secs: u64) -> String {
    let (days, rem) = (secs / 86_400, secs % 86_400);
    let (hour, minute, second) = (rem / 3_600, rem % 3_600 / 60, rem % 60);
    let (year, month, day) = civil_from_days(days);
    format!("{year:04}-{month:02}-{day:02}T{hour:02}:{minute:02}:{second:02}Z")
}

/// Converts days since 1970-01-01 into a proleptic Gregorian `(year, month, day)`.
///
/// This is Howard Hinnant's `civil_from_days`, specialised to non-negative input so that
/// unsigned arithmetic is enough.
fn civil_from_days(days: u64) -> (u64, u64, u64) {
    // Shift the epoch to 0000-03-01 so each leap day falls at the end of a year of the 400-year era.
    let z = days + 719_468;
    let era = z / 146_097;
    let doe = z - era * 146_097; // day of era, [0, 146096]
    let yoe = (doe - doe / 1_460 + doe / 36_524 - doe / 146_096) / 365; // year of era, [0, 399]
    let doy = doe - (365 * yoe + yoe / 4 - yoe / 100); // day of March-based year, [0, 365]
    let mp = (5 * doy + 2) / 153; // March-based month, [0, 11]
    let day = doy - (153 * mp + 2) / 5 + 1; // [1, 31]
    let month = if mp < 10 { mp + 3 } else { mp - 9 }; // [1, 12]
    let year = yoe + era * 400 + u64::from(month <= 2);
    (year, month, day)
}
// Unit tests for request validation, JSON payload parsing, event construction,
// signature hashing and the HTTP acknowledgement format.
#[cfg(test)]
mod tests {
use super::*;
// Baseline config: `WebhookConfig::default()` (no secret, 64 KiB body limit, only `/webhook`).
fn default_config() -> WebhookConfig {
WebhookConfig::default()
}
// Builds a POST request to `path` with no headers and a loopback source IP.
fn post_request(path: &str, body: &str) -> WebhookRequest {
WebhookRequest {
method: "POST".into(),
path: path.into(),
headers: HashMap::new(),
body: body.into(),
source_ip: Some("127.0.0.1".into()),
}
}
// A POST to an allowed path with no secret configured passes every check.
#[test]
fn validate_valid_post() {
let config = default_config();
let req = post_request("/webhook", r#"{"action":"deploy"}"#);
assert!(validate_request(&config, &req).is_valid());
}
// Non-POST methods are rejected first, echoing the method back.
#[test]
fn validate_method_not_allowed() {
let config = default_config();
let req = WebhookRequest {
method: "GET".into(),
path: "/webhook".into(),
headers: HashMap::new(),
body: String::new(),
source_ip: None,
};
assert_eq!(
validate_request(&config, &req),
ValidationResult::MethodNotAllowed {
method: "GET".into()
}
);
}
// A body longer than `max_body_bytes` is rejected with its size and the limit.
#[test]
fn validate_body_too_large() {
let config = WebhookConfig {
max_body_bytes: 10,
..default_config()
};
let req = post_request("/webhook", "a]long body that exceeds limit");
match validate_request(&config, &req) {
ValidationResult::BodyTooLarge { size, max } => {
assert!(size > max);
}
other => panic!("expected BodyTooLarge, got {other:?}"),
}
}
// Paths outside the allow-list are rejected.
#[test]
fn validate_path_not_allowed() {
let config = default_config();
let req = post_request("/admin/hack", r#"{}"#);
assert_eq!(
validate_request(&config, &req),
ValidationResult::PathNotAllowed {
path: "/admin/hack".into()
}
);
}
// With a secret configured, a request without the signature header is rejected.
#[test]
fn validate_signature_missing() {
let config = WebhookConfig {
secret: Some("mysecret".into()),
..default_config()
};
let req = post_request("/webhook", r#"{}"#);
assert_eq!(
validate_request(&config, &req),
ValidationResult::SignatureMissing
);
}
// A signature produced by `compute_hmac_hex` with the same secret and body is accepted.
#[test]
fn validate_signature_valid() {
let secret = "test-secret";
let body = r#"{"deploy":true}"#;
let sig = compute_hmac_hex(secret, body);
let config = WebhookConfig {
secret: Some(secret.into()),
..default_config()
};
let mut req = post_request("/webhook", body);
req.headers.insert("x-forjar-signature".into(), sig);
assert!(validate_request(&config, &req).is_valid());
}
// A signature header whose value does not match is rejected.
#[test]
fn validate_signature_invalid() {
let config = WebhookConfig {
secret: Some("real-secret".into()),
..default_config()
};
let mut req = post_request("/webhook", r#"{}"#);
req.headers
.insert("x-forjar-signature".into(), "bad-sig".into());
assert_eq!(
validate_request(&config, &req),
ValidationResult::SignatureInvalid
);
}
// String values are stored without JSON quotes.
#[test]
fn parse_json_object() {
let payload = parse_json_payload(r#"{"action":"deploy","env":"prod"}"#).unwrap();
assert_eq!(payload.get("action").unwrap(), "deploy");
assert_eq!(payload.get("env").unwrap(), "prod");
}
// Non-string values (numbers, nested objects) are stored as their JSON text.
#[test]
fn parse_json_nested() {
let payload = parse_json_payload(r#"{"count":42,"nested":{"a":1}}"#).unwrap();
assert_eq!(payload.get("count").unwrap(), "42");
assert!(payload.get("nested").unwrap().contains("\"a\":1"));
}
// Unparseable bodies are an error.
#[test]
fn parse_json_invalid() {
assert!(parse_json_payload("not json").is_err());
}
// Valid JSON that is not an object (here an array) is an error.
#[test]
fn parse_json_non_object() {
assert!(parse_json_payload("[1,2,3]").is_err());
}
// The event carries the parsed body plus `_path` and `_source_ip` metadata.
#[test]
fn request_to_event_valid() {
let req = post_request("/webhook", r#"{"action":"restart"}"#);
let event = request_to_event(&req).unwrap();
assert_eq!(event.event_type, EventType::WebhookReceived);
assert_eq!(event.payload.get("action").unwrap(), "restart");
assert_eq!(event.payload.get("_path").unwrap(), "/webhook");
assert_eq!(event.payload.get("_source_ip").unwrap(), "127.0.0.1");
}
// Body parse errors propagate out of `request_to_event`.
#[test]
fn request_to_event_invalid_body() {
let req = post_request("/webhook", "not json");
assert!(request_to_event(&req).is_err());
}
// Same inputs give the same 64-character hex digest.
#[test]
fn hmac_deterministic() {
let h1 = compute_hmac_hex("key", "data");
let h2 = compute_hmac_hex("key", "data");
assert_eq!(h1, h2);
assert_eq!(h1.len(), 64);
}
// Different keys give different digests for the same data.
#[test]
fn hmac_different_keys() {
let h1 = compute_hmac_hex("key1", "data");
let h2 = compute_hmac_hex("key2", "data");
assert_ne!(h1, h2);
}
// Status line, JSON content type and message all appear in the response.
#[test]
fn ack_response_format() {
let resp = ack_response(200, "accepted");
assert!(resp.starts_with("HTTP/1.1 200 OK"));
assert!(resp.contains("application/json"));
assert!(resp.contains("accepted"));
}
// Error codes get their reason phrase in the status line.
#[test]
fn ack_response_error() {
let resp = ack_response(400, "bad request");
assert!(resp.contains("400 Bad Request"));
}
// Only `Valid` reports `is_valid() == true`.
#[test]
fn validation_result_is_valid() {
assert!(ValidationResult::Valid.is_valid());
assert!(!ValidationResult::SignatureMissing.is_valid());
assert!(!ValidationResult::SignatureInvalid.is_valid());
}
// Pins the `Default` values of `WebhookConfig`.
#[test]
fn default_webhook_config() {
let config = WebhookConfig::default();
assert_eq!(config.port, 8484);
assert!(config.secret.is_none());
assert_eq!(config.max_body_bytes, 64 * 1024);
assert_eq!(config.allowed_paths, vec!["/webhook"]);
}
}