use crate::compare::ash_timing_safe_equal;
use crate::errors::{AshError, AshErrorCode, InternalReason};
use crate::headers::{HeaderMapView, HDR_BODY_HASH, HDR_PROOF, HDR_TIMESTAMP};
use crate::proof::{ash_build_proof, ash_derive_client_secret, ash_validate_timestamp_format};
use crate::validate::ash_validate_nonce;
/// Everything [`verify_incoming_request`] needs to check one incoming request.
///
/// All data is borrowed from the caller for the duration of the check.
pub struct VerifyRequestInput<'a, H>
where
    H: HeaderMapView,
{
    /// Header map from which the timestamp, body-hash and proof headers are read.
    pub headers: &'a H,
    /// HTTP method; normalized together with `path` and `raw_query` into the binding.
    pub method: &'a str,
    /// Request path; normalized together with `method` and `raw_query` into the binding.
    pub path: &'a str,
    /// Query string as received; normalization into the binding happens during verification.
    pub raw_query: &'a str,
    /// Request body, hashed as-is and compared with the body-hash header.
    /// It is not canonicalized here, so the caller must supply the canonical form.
    pub canonical_body: &'a str,
    /// Nonce; validated, then used with `context_id` to derive the client secret.
    pub nonce: &'a str,
    /// Context identifier used when deriving the client secret.
    pub context_id: &'a str,
    /// Maximum accepted timestamp age, in seconds.
    pub max_age_seconds: u64,
    /// Allowed clock skew, in seconds, forwarded to the timestamp check.
    pub clock_skew_seconds: u64,
}
/// Outcome of verifying one request.
///
/// When built through `fail`/`success`, `ok` is `true` exactly when `error` is `None`.
#[derive(Debug)]
pub struct VerifyResult {
    /// `true` when every verification step passed.
    pub ok: bool,
    /// The first failure encountered, set when `ok` is `false`.
    pub error: Option<AshError>,
    /// Optional diagnostic values computed during a successful verification.
    pub meta: Option<VerifyMeta>,
}

/// Intermediate values computed while verifying a request, kept for diagnostics.
#[derive(Debug)]
pub struct VerifyMeta {
    /// The request's query component, as reported by the verifier.
    pub canonical_query: String,
    /// Hash computed over the supplied body (the value compared with the body-hash header).
    pub computed_body_hash: String,
    /// Normalized method/path/query binding the proof was checked against.
    pub binding: String,
}

impl VerifyResult {
    /// Builds a failed result carrying `error` and no metadata.
    fn fail(error: AshError) -> Self {
        Self {
            ok: false,
            error: Some(error),
            meta: None,
        }
    }

    /// Builds a successful result with optional diagnostic `meta`.
    fn success(meta: Option<VerifyMeta>) -> Self {
        Self {
            ok: true,
            error: None,
            meta,
        }
    }
}
pub fn verify_incoming_request<H: HeaderMapView>(input: &VerifyRequestInput<'_, H>) -> VerifyResult {
let ts = match extract_single_header(input.headers, HDR_TIMESTAMP) {
Ok(v) => v,
Err(e) => return VerifyResult::fail(e),
};
let header_body_hash = match extract_single_header(input.headers, HDR_BODY_HASH) {
Ok(v) => v,
Err(e) => return VerifyResult::fail(e),
};
let proof = match extract_single_header(input.headers, HDR_PROOF) {
Ok(v) => v,
Err(e) => return VerifyResult::fail(e),
};
if let Err(e) = ash_validate_timestamp_format(&ts) {
return VerifyResult::fail(e);
}
if let Err(e) = validate_timestamp_with_reference(
&ts,
input.max_age_seconds,
input.clock_skew_seconds,
) {
return VerifyResult::fail(e);
}
if let Err(e) = ash_validate_nonce(input.nonce) {
return VerifyResult::fail(e);
}
let binding = match crate::ash_normalize_binding(input.method, input.path, input.raw_query) {
Ok(b) => b,
Err(e) => return VerifyResult::fail(e),
};
let computed_body_hash = crate::proof::ash_hash_body(input.canonical_body);
if !ash_timing_safe_equal(computed_body_hash.as_bytes(), header_body_hash.as_bytes()) {
return VerifyResult::fail(AshError::with_reason(
AshErrorCode::ValidationError,
InternalReason::General,
"Body hash mismatch",
));
}
let client_secret = match ash_derive_client_secret(input.nonce, input.context_id, &binding) {
Ok(s) => s,
Err(e) => return VerifyResult::fail(e),
};
let expected_proof = match ash_build_proof(&client_secret, &ts, &binding, &computed_body_hash) {
Ok(p) => p,
Err(e) => return VerifyResult::fail(e),
};
if !ash_timing_safe_equal(expected_proof.as_bytes(), proof.as_bytes()) {
return VerifyResult::fail(AshError::new(
AshErrorCode::ProofInvalid,
"Proof verification failed",
));
}
let meta = if cfg!(debug_assertions) {
Some(VerifyMeta {
canonical_query: input.raw_query.to_string(),
computed_body_hash,
binding,
})
} else {
None
};
VerifyResult::success(meta)
}
/// Returns the single value of header `name`, with surrounding whitespace trimmed.
///
/// All values are fetched through `HeaderMapView::get_all_ci`; presumably the lookup is
/// case-insensitive, as the `_ci` suffix suggests.
///
/// # Errors
///
/// Every failure is an `AshErrorCode::ValidationError` with a `"header"` detail. The
/// internal reason is:
/// - `HdrMissing` when no value is present;
/// - `HdrMultiValue` when several values are present, with an extra `"count"` detail;
/// - `HdrInvalidChars` when the trimmed value contains any control character (CR and LF
///   included).
fn extract_single_header(h: &impl HeaderMapView, name: &'static str) -> Result<String, AshError> {
    let values = h.get_all_ci(name);
    let count = values.len();

    let candidate = match count {
        0 => {
            return Err(AshError::with_reason(
                AshErrorCode::ValidationError,
                InternalReason::HdrMissing,
                format!("Required header '{}' is missing", name),
            )
            .with_detail("header", name));
        }
        1 => values[0].trim(),
        _ => {
            return Err(AshError::with_reason(
                AshErrorCode::ValidationError,
                InternalReason::HdrMultiValue,
                format!("Header '{}' must have exactly one value, got {}", name, count),
            )
            .with_detail("header", name)
            .with_detail("count", count.to_string()));
        }
    };

    // `char::is_control` covers '\r' and '\n', which blocks header-injection payloads.
    let has_control_char = candidate.chars().any(char::is_control);
    if has_control_char {
        return Err(AshError::with_reason(
            AshErrorCode::ValidationError,
            InternalReason::HdrInvalidChars,
            format!("Header '{}' contains invalid characters", name),
        )
        .with_detail("header", name));
    }

    Ok(candidate.to_owned())
}
fn validate_timestamp_with_reference(
timestamp: &str,
max_age_seconds: u64,
clock_skew_seconds: u64,
) -> Result<(), AshError> {
crate::proof::ash_validate_timestamp(timestamp, max_age_seconds, clock_skew_seconds)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Nonce accepted by `ash_validate_nonce` (used by every fixture below).
    const VALID_NONCE: &str = "0123456789abcdef0123456789abcdef";

    /// Minimal `HeaderMapView` backed by (name, value) pairs; names match ignoring ASCII case.
    struct TestHeaders(Vec<(String, String)>);

    impl HeaderMapView for TestHeaders {
        fn get_all_ci(&self, name: &str) -> Vec<&str> {
            self.0
                .iter()
                .filter(|(k, _)| k.eq_ignore_ascii_case(name))
                .map(|(_, v)| v.as_str())
                .collect()
        }
    }

    /// Current Unix time in whole seconds, as a decimal string.
    fn now_ts() -> String {
        use std::time::{SystemTime, UNIX_EPOCH};
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
            .to_string()
    }

    /// Builds headers for a correctly signed `POST /api/transfer` request with body
    /// `{"amount":100}`, nonce [`VALID_NONCE`] and context `ctx_test123`.
    ///
    /// Returns `(headers, canonical_body, nonce)`.
    fn make_valid_request() -> (TestHeaders, String, String) {
        let context_id = "ctx_test123";
        // Must equal what `ash_normalize_binding("POST", "/api/transfer", "")` produces.
        let binding = "POST|/api/transfer|";
        let timestamp = now_ts();
        let canonical_body = r#"{"amount":100}"#;
        let body_hash = crate::proof::ash_hash_body(canonical_body);
        let client_secret = ash_derive_client_secret(VALID_NONCE, context_id, binding).unwrap();
        let proof = ash_build_proof(&client_secret, &timestamp, binding, &body_hash).unwrap();
        let headers = TestHeaders(vec![
            ("x-ash-ts".into(), timestamp),
            ("x-ash-body-hash".into(), body_hash),
            ("x-ash-proof".into(), proof),
        ]);
        (headers, canonical_body.to_string(), VALID_NONCE.to_string())
    }

    /// Verification input matching [`make_valid_request`]'s method, path and context.
    fn transfer_input<'a>(
        headers: &'a TestHeaders,
        canonical_body: &'a str,
        nonce: &'a str,
    ) -> VerifyRequestInput<'a, TestHeaders> {
        VerifyRequestInput {
            headers,
            method: "POST",
            path: "/api/transfer",
            raw_query: "",
            canonical_body,
            nonce,
            context_id: "ctx_test123",
            max_age_seconds: 300,
            clock_skew_seconds: 60,
        }
    }

    /// Verification input for `POST /api/test` with body `{}`, used by the negative tests.
    fn probe_input(headers: &TestHeaders) -> VerifyRequestInput<'_, TestHeaders> {
        VerifyRequestInput {
            headers,
            method: "POST",
            path: "/api/test",
            raw_query: "",
            canonical_body: "{}",
            nonce: VALID_NONCE,
            context_id: "ctx_test",
            max_age_seconds: 300,
            clock_skew_seconds: 60,
        }
    }

    /// Overwrites every value of header `name` (ASCII case-insensitive) with `value`.
    fn set_header(headers: &mut TestHeaders, name: &str, value: &str) {
        for (k, v) in &mut headers.0 {
            if k.eq_ignore_ascii_case(name) {
                *v = value.to_string();
            }
        }
    }

    #[test]
    fn test_valid_request_passes() {
        let (headers, canonical_body, nonce) = make_valid_request();
        let result = verify_incoming_request(&transfer_input(&headers, &canonical_body, &nonce));
        assert!(result.ok, "Expected ok, got error: {:?}", result.error);
        assert!(result.error.is_none());
    }

    #[test]
    fn test_missing_timestamp_fails() {
        let headers = TestHeaders(vec![
            ("x-ash-body-hash".into(), "a".repeat(64)),
            ("x-ash-proof".into(), "b".repeat(64)),
        ]);
        let result = verify_incoming_request(&probe_input(&headers));
        assert!(!result.ok);
        let err = result.error.unwrap();
        assert_eq!(err.code(), AshErrorCode::ValidationError);
        assert_eq!(err.reason(), InternalReason::HdrMissing);
    }

    #[test]
    fn test_invalid_timestamp_format_fails() {
        let headers = TestHeaders(vec![
            ("x-ash-ts".into(), "not_a_number".into()),
            ("x-ash-body-hash".into(), "a".repeat(64)),
            ("x-ash-proof".into(), "b".repeat(64)),
        ]);
        let result = verify_incoming_request(&probe_input(&headers));
        assert!(!result.ok);
        assert_eq!(result.error.unwrap().code(), AshErrorCode::TimestampInvalid);
    }

    #[test]
    fn test_expired_timestamp_fails() {
        let headers = TestHeaders(vec![
            // 2001-09-09: far outside the 300 s window.
            ("x-ash-ts".into(), "1000000000".into()),
            ("x-ash-body-hash".into(), "a".repeat(64)),
            ("x-ash-proof".into(), "b".repeat(64)),
        ]);
        let result = verify_incoming_request(&probe_input(&headers));
        assert!(!result.ok);
        assert_eq!(result.error.unwrap().code(), AshErrorCode::TimestampInvalid);
    }

    #[test]
    fn test_body_hash_mismatch_fails() {
        let (mut headers, canonical_body, nonce) = make_valid_request();
        set_header(&mut headers, "x-ash-body-hash", &"f".repeat(64));
        let result = verify_incoming_request(&transfer_input(&headers, &canonical_body, &nonce));
        assert!(!result.ok);
        let err = result.error.unwrap();
        assert_eq!(err.code(), AshErrorCode::ValidationError);
        assert!(err.message().contains("Body hash"));
    }

    #[test]
    fn test_wrong_proof_fails() {
        let (mut headers, canonical_body, nonce) = make_valid_request();
        set_header(&mut headers, "x-ash-proof", &"f".repeat(64));
        let result = verify_incoming_request(&transfer_input(&headers, &canonical_body, &nonce));
        assert!(!result.ok);
        assert_eq!(result.error.unwrap().code(), AshErrorCode::ProofInvalid);
    }

    #[test]
    fn test_tampered_body_fails() {
        let (headers, _canonical_body, nonce) = make_valid_request();
        let tampered = r#"{"amount":999}"#;
        let result = verify_incoming_request(&transfer_input(&headers, tampered, &nonce));
        assert!(!result.ok);
        let err = result.error.unwrap();
        assert_eq!(err.code(), AshErrorCode::ValidationError);
    }

    #[test]
    fn test_wrong_context_id_fails_proof() {
        let (headers, canonical_body, nonce) = make_valid_request();
        let input = VerifyRequestInput {
            context_id: "ctx_other",
            ..transfer_input(&headers, &canonical_body, &nonce)
        };
        let result = verify_incoming_request(&input);
        assert!(!result.ok);
        assert_eq!(result.error.unwrap().code(), AshErrorCode::ProofInvalid);
    }

    #[test]
    fn test_wrong_path_fails_proof() {
        let (headers, canonical_body, nonce) = make_valid_request();
        let input = VerifyRequestInput {
            path: "/api/other",
            ..transfer_input(&headers, &canonical_body, &nonce)
        };
        let result = verify_incoming_request(&input);
        assert!(!result.ok);
        assert_eq!(result.error.unwrap().code(), AshErrorCode::ProofInvalid);
    }

    #[test]
    fn duplicate_header_is_rejected() {
        let headers = TestHeaders(vec![
            ("x-ash-ts".into(), "1".into()),
            ("X-ASH-TS".into(), "2".into()),
        ]);
        let err = extract_single_header(&headers, HDR_TIMESTAMP).unwrap_err();
        assert_eq!(err.code(), AshErrorCode::ValidationError);
        assert_eq!(err.reason(), InternalReason::HdrMultiValue);
    }

    #[test]
    fn control_characters_in_header_are_rejected() {
        for bad in ["ab\r\ncd", "ab\u{7}cd", "ab\0cd"] {
            let headers = TestHeaders(vec![("x-ash-proof".into(), bad.into())]);
            let err = extract_single_header(&headers, HDR_PROOF).unwrap_err();
            assert_eq!(err.code(), AshErrorCode::ValidationError);
            assert_eq!(err.reason(), InternalReason::HdrInvalidChars);
        }
    }

    #[test]
    fn header_value_is_trimmed() {
        let headers = TestHeaders(vec![("x-ash-body-hash".into(), "  abc\t".into())]);
        assert_eq!(extract_single_header(&headers, HDR_BODY_HASH).unwrap(), "abc");
    }

    #[test]
    fn precedence_missing_ts_before_body_hash_mismatch() {
        let headers = TestHeaders(vec![
            ("x-ash-body-hash".into(), "wrong".repeat(10)),
            ("x-ash-proof".into(), "b".repeat(64)),
        ]);
        let result = verify_incoming_request(&probe_input(&headers));
        assert!(!result.ok);
        assert_eq!(result.error.unwrap().reason(), InternalReason::HdrMissing);
    }

    #[test]
    fn precedence_bad_ts_format_before_bad_nonce() {
        let headers = TestHeaders(vec![
            ("x-ash-ts".into(), "not_number".into()),
            ("x-ash-body-hash".into(), "a".repeat(64)),
            ("x-ash-proof".into(), "b".repeat(64)),
        ]);
        // The nonce is deliberately invalid; the timestamp error must win.
        let input = VerifyRequestInput {
            nonce: "short",
            ..probe_input(&headers)
        };
        let result = verify_incoming_request(&input);
        assert!(!result.ok);
        assert_eq!(result.error.unwrap().code(), AshErrorCode::TimestampInvalid);
    }
}