use crate::errors::{AshError, AshErrorCode, InternalReason};
/// Name of the required timestamp header; extracted into [`HeaderBundle::ts`].
///
/// All header names here are lowercase; lookups go through
/// [`HeaderMapView::get_all_ci`].
pub const HDR_TIMESTAMP: &str = "x-ash-ts";
/// Name of the required nonce header; extracted into [`HeaderBundle::nonce`].
pub const HDR_NONCE: &str = "x-ash-nonce";
/// Name of the required body-hash header; extracted into [`HeaderBundle::body_hash`].
pub const HDR_BODY_HASH: &str = "x-ash-body-hash";
/// Name of the required proof header; extracted into [`HeaderBundle::proof`].
pub const HDR_PROOF: &str = "x-ash-proof";
/// Name of the optional context-id header; extracted into [`HeaderBundle::context_id`].
pub const HDR_CONTEXT_ID: &str = "x-ash-context-id";
/// Read-only view over a request's headers, as consumed by [`ash_extract_headers`].
pub trait HeaderMapView {
/// Returns every value stored under the header `name`.
///
/// The extractor uses the length of the returned vector to tell an absent header
/// (empty) from a repeated one (more than one element).
///
/// NOTE(review): the `_ci` suffix suggests implementations must match `name`
/// case-insensitively (the test implementation does); nothing here enforces it.
fn get_all_ci(&self, name: &str) -> Vec<&str>;
}
/// The ASH header values extracted by [`ash_extract_headers`].
///
/// Every value is stored as an owned `String` with surrounding whitespace trimmed.
/// Values are not parsed or format-checked here.
#[derive(Debug, Clone)]
pub struct HeaderBundle {
/// Value of the `x-ash-ts` header.
pub ts: String,
/// Value of the `x-ash-nonce` header.
pub nonce: String,
/// Value of the `x-ash-body-hash` header.
pub body_hash: String,
/// Value of the `x-ash-proof` header.
pub proof: String,
/// Value of the `x-ash-context-id` header, or `None` when it was absent.
pub context_id: Option<String>,
}
/// Extracts the ASH headers from `h` into a [`HeaderBundle`].
///
/// The timestamp, nonce, body-hash and proof headers are required and must each occur
/// exactly once. The context-id header is optional but must not be repeated.
///
/// # Errors
///
/// Headers are read in the order timestamp, nonce, body hash, proof, context id.
/// The first `AshError` encountered (missing, repeated, or invalid characters) is
/// returned unchanged.
pub fn ash_extract_headers(h: &impl HeaderMapView) -> Result<HeaderBundle, AshError> {
    // Struct-expression fields are evaluated in the order written, so the lookup order
    // (and therefore which error surfaces first) follows the field list below.
    Ok(HeaderBundle {
        ts: get_one(h, HDR_TIMESTAMP)?,
        nonce: get_one(h, HDR_NONCE)?,
        body_hash: get_one(h, HDR_BODY_HASH)?,
        proof: get_one(h, HDR_PROOF)?,
        context_id: get_optional_one(h, HDR_CONTEXT_ID)?,
    })
}
/// Reads a header that must be present exactly once and returns its sanitized value.
///
/// The value is cleaned by [`normalize_value`], which trims surrounding whitespace and
/// rejects control characters (CR/LF included, even at the edges).
///
/// # Errors
///
/// Returns an `AshErrorCode::ValidationError` whose reason is:
/// - `InternalReason::HdrMissing` if the header is absent (detail `header`),
/// - `InternalReason::HdrMultiValue` if it occurs more than once (details `header`, `count`),
/// - `InternalReason::HdrInvalidChars` if the value contains a control character.
///
/// NOTE(review): a header present with an empty/blank value yields `Ok("")`; presumably
/// format validation happens downstream — confirm.
fn get_one(h: &impl HeaderMapView, name: &'static str) -> Result<String, AshError> {
    // Presence is the only difference from the optional variant, so delegate and turn
    // "absent" into an error. The check order (missing, multi-value, chars) is unchanged.
    get_optional_one(h, name)?.ok_or_else(|| {
        AshError::with_reason(
            AshErrorCode::ValidationError,
            InternalReason::HdrMissing,
            format!("Required header '{}' is missing", name),
        )
        .with_detail("header", name)
    })
}

/// Reads a header that may be absent but must not be repeated.
///
/// Returns `Ok(None)` when the header is absent. When it occurs exactly once, returns
/// `Ok(Some(value))` with the value sanitized by [`normalize_value`].
///
/// # Errors
///
/// Returns an `AshErrorCode::ValidationError` whose reason is
/// `InternalReason::HdrMultiValue` (details `header`, `count`) for repeated headers, or
/// `InternalReason::HdrInvalidChars` (detail `header`) for values with control characters.
fn get_optional_one(
    h: &impl HeaderMapView,
    name: &'static str,
) -> Result<Option<String>, AshError> {
    let vals = h.get_all_ci(name);
    match vals.as_slice() {
        [] => Ok(None),
        [raw] => normalize_value(name, raw).map(Some),
        many => Err(
            AshError::with_reason(
                AshErrorCode::ValidationError,
                InternalReason::HdrMultiValue,
                format!("Header '{}' must have exactly one value, got {}", name, many.len()),
            )
            .with_detail("header", name)
            .with_detail("count", many.len().to_string()),
        ),
    }
}

/// Trims a single header value and rejects it if it contains any control character.
///
/// Only HTAB (HTTP optional whitespace) and whitespace that is *not* a control character
/// are stripped from the edges.
///
/// CR, LF, VT, FF and NEL are whitespace to `str::trim` but are also control characters.
/// They are deliberately left in place so the check below sees them. Otherwise a value
/// such as `"abc\r\n"` would be silently accepted while `"ab\r\nc"` is rejected.
///
/// For every value that passes, the result is identical to `raw.trim()`.
fn normalize_value(name: &'static str, raw: &str) -> Result<String, AshError> {
    let v = raw.trim_matches(|c: char| c == '\t' || (c.is_whitespace() && !c.is_control()));
    if contains_ctl_or_newlines(v) {
        return Err(
            AshError::with_reason(
                AshErrorCode::ValidationError,
                InternalReason::HdrInvalidChars,
                format!("Header '{}' contains invalid characters", name),
            )
            .with_detail("header", name),
        );
    }
    Ok(v.to_string())
}
/// Reports whether `s` contains any Unicode control character (general category `Cc`).
///
/// This covers CR (`\r`), LF (`\n`), HTAB, NUL and NEL (`\u{85}`). Non-control
/// whitespace such as the space or U+00A0 is not flagged.
fn contains_ctl_or_newlines(s: &str) -> bool {
    // `\r` and `\n` are themselves `Cc` characters, so one predicate suffices.
    s.contains(char::is_control)
}
#[cfg(test)]
mod tests {
use super::*;
// In-memory header list. Lookup lowercases both the stored name and the query,
// giving the case-insensitive behavior implied by `get_all_ci`.
struct TestHeaders(Vec<(String, String)>);
impl HeaderMapView for TestHeaders {
fn get_all_ci(&self, name: &str) -> Vec<&str> {
let name_lower = name.to_ascii_lowercase();
self.0
.iter()
.filter(|(k, _)| k.to_ascii_lowercase() == name_lower)
.map(|(_, v)| v.as_str())
.collect()
}
}
// The four required headers, with names in mixed case and no context-id header.
fn valid_headers() -> TestHeaders {
TestHeaders(vec![
("X-ASH-TS".into(), "1700000000".into()),
("x-ash-nonce".into(), "0123456789abcdef0123456789abcdef".into()),
("X-Ash-Body-Hash".into(), "a".repeat(64)),
("x-ash-proof".into(), "b".repeat(64)),
])
}
// All required headers are copied through verbatim; absent context id maps to None.
#[test]
fn test_extract_all_required() {
let bundle = ash_extract_headers(&valid_headers()).unwrap();
assert_eq!(bundle.ts, "1700000000");
assert_eq!(bundle.nonce, "0123456789abcdef0123456789abcdef");
assert_eq!(bundle.body_hash, "a".repeat(64));
assert_eq!(bundle.proof, "b".repeat(64));
assert!(bundle.context_id.is_none());
}
// The optional context-id header is picked up when present once.
#[test]
fn test_extract_with_context_id() {
let mut h = valid_headers();
h.0.push(("X-ASH-Context-ID".into(), "ctx_abc123".into()));
let bundle = ash_extract_headers(&h).unwrap();
assert_eq!(bundle.context_id, Some("ctx_abc123".into()));
}
// Header-name casing does not affect extraction.
#[test]
fn test_case_insensitive() {
let h = TestHeaders(vec![
("x-ash-ts".into(), "1700000000".into()),
("X-ASH-NONCE".into(), "0123456789abcdef0123456789abcdef".into()),
("X-Ash-Body-Hash".into(), "a".repeat(64)),
("x-AsH-pRoOf".into(), "b".repeat(64)),
]);
assert!(ash_extract_headers(&h).is_ok());
}
// A missing required header yields HdrMissing, and the `header` detail names it.
// The expected HTTP status (485) is defined by the project's error mapping.
#[test]
fn test_missing_timestamp() {
let h = TestHeaders(vec![
("x-ash-nonce".into(), "0123456789abcdef0123456789abcdef".into()),
("x-ash-body-hash".into(), "a".repeat(64)),
("x-ash-proof".into(), "b".repeat(64)),
]);
let err = ash_extract_headers(&h).unwrap_err();
assert_eq!(err.code(), AshErrorCode::ValidationError);
assert_eq!(err.http_status(), 485);
assert_eq!(err.reason(), InternalReason::HdrMissing);
assert!(err.details().unwrap().get("header").unwrap().contains("ts"));
}
#[test]
fn test_missing_nonce() {
let h = TestHeaders(vec![
("x-ash-ts".into(), "1700000000".into()),
("x-ash-body-hash".into(), "a".repeat(64)),
("x-ash-proof".into(), "b".repeat(64)),
]);
let err = ash_extract_headers(&h).unwrap_err();
assert_eq!(err.reason(), InternalReason::HdrMissing);
}
// A repeated required header is rejected rather than picking one of the values.
#[test]
fn test_multi_value_nonce() {
let h = TestHeaders(vec![
("x-ash-ts".into(), "1700000000".into()),
("x-ash-nonce".into(), "aaa".into()),
("x-ash-nonce".into(), "bbb".into()),
("x-ash-body-hash".into(), "a".repeat(64)),
("x-ash-proof".into(), "b".repeat(64)),
]);
let err = ash_extract_headers(&h).unwrap_err();
assert_eq!(err.code(), AshErrorCode::ValidationError);
assert_eq!(err.http_status(), 485);
assert_eq!(err.reason(), InternalReason::HdrMultiValue);
}
// An embedded newline (header-injection attempt) is rejected as invalid characters.
#[test]
fn test_control_chars_in_proof() {
let h = TestHeaders(vec![
("x-ash-ts".into(), "1700000000".into()),
("x-ash-nonce".into(), "0123456789abcdef0123456789abcdef".into()),
("x-ash-body-hash".into(), "a".repeat(64)),
("x-ash-proof".into(), "proof\ninjection".into()),
]);
let err = ash_extract_headers(&h).unwrap_err();
assert_eq!(err.reason(), InternalReason::HdrInvalidChars);
}
// Surrounding spaces are stripped from extracted values.
#[test]
fn test_trimming() {
let h = TestHeaders(vec![
("x-ash-ts".into(), " 1700000000 ".into()),
("x-ash-nonce".into(), " 0123456789abcdef0123456789abcdef ".into()),
("x-ash-body-hash".into(), format!(" {} ", "a".repeat(64))),
("x-ash-proof".into(), format!(" {} ", "b".repeat(64))),
]);
let bundle = ash_extract_headers(&h).unwrap();
assert_eq!(bundle.ts, "1700000000");
assert_eq!(bundle.nonce, "0123456789abcdef0123456789abcdef");
}
// The optional header is also single-valued: two differently-cased entries collide.
#[test]
fn test_multi_value_optional_context_id() {
let mut h = valid_headers();
h.0.push(("x-ash-context-id".into(), "ctx_1".into()));
h.0.push(("X-ASH-Context-ID".into(), "ctx_2".into()));
let err = ash_extract_headers(&h).unwrap_err();
assert_eq!(err.reason(), InternalReason::HdrMultiValue);
}
}