Skip to main content

ash_core/
headers.rs

//! Canonical header extraction for the ASH protocol.
//!
//! This module provides a single, authoritative function for extracting
//! ASH-required headers from any HTTP framework. Middlewares should call
//! `ash_extract_headers()` instead of reimplementing header parsing.
//!
//! ## Why This Exists
//!
//! Previously, every middleware (Express, FastAPI, Laravel, Gin, etc.)
//! reimplemented header extraction with subtle differences in:
//! - case-insensitive lookup
//! - multi-value handling
//! - whitespace trimming
//! - control-character rejection
//!
//! These differences caused systemic bugs (null-nonce bypasses, enum
//! mismatches). Moving extraction into Core eliminates this entire bug class.

19use crate::errors::{AshError, AshErrorCode, InternalReason};
20
// ── Header Names (constants) ─────────────────────────────────────────
//
// All names are spelled in lowercase. Lookups go through
// `HeaderMapView::get_all_ci`, which matches case-insensitively, so any
// casing used on the wire still resolves to these constants.

/// Request timestamp header (`x-ash-ts`).
pub const HDR_TIMESTAMP: &str = "x-ash-ts";

/// Per-request nonce header (`x-ash-nonce`).
pub const HDR_NONCE: &str = "x-ash-nonce";

/// Request body hash header (`x-ash-body-hash`).
pub const HDR_BODY_HASH: &str = "x-ash-body-hash";

/// Request proof header (`x-ash-proof`).
pub const HDR_PROOF: &str = "x-ash-proof";

/// Optional context identifier header (`x-ash-context-id`).
pub const HDR_CONTEXT_ID: &str = "x-ash-context-id";

// ── Trait ─────────────────────────────────────────────────────────────

/// Framework-agnostic, read-only view over a request's headers.
///
/// Implement this trait for your HTTP framework's header type to use
/// [`ash_extract_headers`]. Implementations must:
///
/// - match header names case-insensitively, and
/// - return every occurrence of a header as its own entry.
///
/// **Do not pre-join repeated headers.** Some frameworks expose a header that
/// was sent several times as a single comma-joined string. Returning that
/// joined string as one entry hides the duplication: [`ash_extract_headers`]
/// only sees one value, so its "exactly one value" rule cannot fire. Return
/// the raw per-occurrence values instead.
///
/// # Example (test helper)
///
/// ```rust
/// use ash_core::headers::HeaderMapView;
///
/// struct SimpleHeaders(Vec<(String, String)>);
///
/// impl HeaderMapView for SimpleHeaders {
///     fn get_all_ci(&self, name: &str) -> Vec<&str> {
///         self.0
///             .iter()
///             // `eq_ignore_ascii_case` compares in place instead of
///             // allocating a lowercased copy of every key per lookup.
///             .filter(|(k, _)| k.eq_ignore_ascii_case(name))
///             .map(|(_, v)| v.as_str())
///             .collect()
///     }
/// }
///
/// let h = SimpleHeaders(vec![("X-Ash-Nonce".into(), "abc".into())]);
/// assert_eq!(h.get_all_ci("x-ash-nonce"), vec!["abc"]);
/// assert!(h.get_all_ci("x-ash-proof").is_empty());
/// ```
pub trait HeaderMapView {
    /// Return all values for the given header name, matched case-insensitively.
    ///
    /// Must return an empty `Vec` if the header is not present, and one entry
    /// per occurrence (not a comma-joined string) if it appears multiple times.
    fn get_all_ci(&self, name: &str) -> Vec<&str>;
}

// ── Bundle ────────────────────────────────────────────────────────────

/// Extracted ASH headers, validated and trimmed.
///
/// Every required header was present exactly once and contains no control
/// characters. Values are checked only for presence, multiplicity and
/// character hygiene; parsing and format checks are not performed here (see
/// the per-field notes).
///
/// `PartialEq`/`Eq` are derived so bundles can be compared as a whole, e.g.
/// in tests or when checking that two extraction paths agree.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HeaderBundle {
    /// Unix timestamp string (validated present, not yet parsed).
    pub ts: String,
    /// Nonce string (validated present, not yet format-checked).
    pub nonce: String,
    /// Body hash hex string (validated present).
    pub body_hash: String,
    /// Proof hex string (validated present).
    pub proof: String,
    /// Context ID; `None` when the optional context-id header is absent.
    pub context_id: Option<String>,
}

91// ── Extraction ────────────────────────────────────────────────────────
92
93/// Extract and validate all required ASH headers from a request.
94///
95/// # Validation Rules
96///
97/// - Case-insensitive header lookup
98/// - Missing required header → `ASH_VALIDATION_ERROR` (485)
99/// - Multiple values for a single-value header → `ASH_VALIDATION_ERROR` (485)
100/// - Control characters or newlines in value → `ASH_VALIDATION_ERROR` (485)
101/// - Leading/trailing whitespace is trimmed
102///
103/// # Required Headers
104///
105/// - `x-ash-ts` — timestamp
106/// - `x-ash-nonce` — nonce
107/// - `x-ash-body-hash` — body hash
108/// - `x-ash-proof` — proof
109///
110/// # Optional Headers
111///
112/// - `x-ash-context-id` — context ID (present if server-managed contexts are used)
113///
114/// # Example
115///
116/// ```rust
117/// use ash_core::headers::{HeaderMapView, ash_extract_headers};
118///
119/// struct TestHeaders(Vec<(String, String)>);
120/// impl HeaderMapView for TestHeaders {
121///     fn get_all_ci(&self, name: &str) -> Vec<&str> {
122///         let n = name.to_ascii_lowercase();
123///         self.0.iter()
124///             .filter(|(k, _)| k.to_ascii_lowercase() == n)
125///             .map(|(_, v)| v.as_str())
126///             .collect()
127///     }
128/// }
129///
130/// let headers = TestHeaders(vec![
131///     ("X-ASH-TS".into(), "1700000000".into()),
132///     ("x-ash-nonce".into(), "0123456789abcdef0123456789abcdef".into()),
133///     ("X-Ash-Body-Hash".into(), "a".repeat(64)),
134///     ("x-ash-proof".into(), "b".repeat(64)),
135/// ]);
136///
137/// let bundle = ash_extract_headers(&headers).unwrap();
138/// assert_eq!(bundle.ts, "1700000000");
139/// assert!(bundle.context_id.is_none());
140/// ```
141pub fn ash_extract_headers(h: &impl HeaderMapView) -> Result<HeaderBundle, AshError> {
142    let ts = get_one(h, HDR_TIMESTAMP)?;
143    let nonce = get_one(h, HDR_NONCE)?;
144    let body_hash = get_one(h, HDR_BODY_HASH)?;
145    let proof = get_one(h, HDR_PROOF)?;
146    let context_id = get_optional_one(h, HDR_CONTEXT_ID)?;
147
148    Ok(HeaderBundle {
149        ts,
150        nonce,
151        body_hash,
152        proof,
153        context_id,
154    })
155}
156
157/// Extract exactly one value for a required header.
158fn get_one(h: &impl HeaderMapView, name: &'static str) -> Result<String, AshError> {
159    let vals = h.get_all_ci(name);
160
161    if vals.is_empty() {
162        return Err(
163            AshError::with_reason(
164                AshErrorCode::ValidationError,
165                InternalReason::HdrMissing,
166                format!("Required header '{}' is missing", name),
167            )
168            .with_detail("header", name),
169        );
170    }
171    if vals.len() > 1 {
172        return Err(
173            AshError::with_reason(
174                AshErrorCode::ValidationError,
175                InternalReason::HdrMultiValue,
176                format!("Header '{}' must have exactly one value, got {}", name, vals.len()),
177            )
178            .with_detail("header", name)
179            .with_detail("count", vals.len().to_string()),
180        );
181    }
182
183    let v = vals[0].trim();
184    if contains_ctl_or_newlines(v) {
185        return Err(
186            AshError::with_reason(
187                AshErrorCode::ValidationError,
188                InternalReason::HdrInvalidChars,
189                format!("Header '{}' contains invalid characters", name),
190            )
191            .with_detail("header", name),
192        );
193    }
194
195    Ok(v.to_string())
196}
197
198/// Extract at most one value for an optional header.
199fn get_optional_one(h: &impl HeaderMapView, name: &'static str) -> Result<Option<String>, AshError> {
200    let vals = h.get_all_ci(name);
201
202    if vals.is_empty() {
203        return Ok(None);
204    }
205    if vals.len() > 1 {
206        return Err(
207            AshError::with_reason(
208                AshErrorCode::ValidationError,
209                InternalReason::HdrMultiValue,
210                format!("Header '{}' must have exactly one value, got {}", name, vals.len()),
211            )
212            .with_detail("header", name)
213            .with_detail("count", vals.len().to_string()),
214        );
215    }
216
217    let v = vals[0].trim();
218    if contains_ctl_or_newlines(v) {
219        return Err(
220            AshError::with_reason(
221                AshErrorCode::ValidationError,
222                InternalReason::HdrInvalidChars,
223                format!("Header '{}' contains invalid characters", name),
224            )
225            .with_detail("header", name),
226        );
227    }
228
229    Ok(Some(v.to_string()))
230}
231
232/// Check if a string contains control characters or newlines.
233fn contains_ctl_or_newlines(s: &str) -> bool {
234    s.chars().any(|c| c == '\r' || c == '\n' || c.is_control())
235}
236
237#[cfg(test)]
238mod tests {
239    use super::*;
240
241    /// Simple test implementation of HeaderMapView.
242    struct TestHeaders(Vec<(String, String)>);
243
244    impl HeaderMapView for TestHeaders {
245        fn get_all_ci(&self, name: &str) -> Vec<&str> {
246            let name_lower = name.to_ascii_lowercase();
247            self.0
248                .iter()
249                .filter(|(k, _)| k.to_ascii_lowercase() == name_lower)
250                .map(|(_, v)| v.as_str())
251                .collect()
252        }
253    }
254
255    fn valid_headers() -> TestHeaders {
256        TestHeaders(vec![
257            ("X-ASH-TS".into(), "1700000000".into()),
258            ("x-ash-nonce".into(), "0123456789abcdef0123456789abcdef".into()),
259            ("X-Ash-Body-Hash".into(), "a".repeat(64)),
260            ("x-ash-proof".into(), "b".repeat(64)),
261        ])
262    }
263
264    #[test]
265    fn test_extract_all_required() {
266        let bundle = ash_extract_headers(&valid_headers()).unwrap();
267        assert_eq!(bundle.ts, "1700000000");
268        assert_eq!(bundle.nonce, "0123456789abcdef0123456789abcdef");
269        assert_eq!(bundle.body_hash, "a".repeat(64));
270        assert_eq!(bundle.proof, "b".repeat(64));
271        assert!(bundle.context_id.is_none());
272    }
273
274    #[test]
275    fn test_extract_with_context_id() {
276        let mut h = valid_headers();
277        h.0.push(("X-ASH-Context-ID".into(), "ctx_abc123".into()));
278        let bundle = ash_extract_headers(&h).unwrap();
279        assert_eq!(bundle.context_id, Some("ctx_abc123".into()));
280    }
281
282    #[test]
283    fn test_case_insensitive() {
284        let h = TestHeaders(vec![
285            ("x-ash-ts".into(), "1700000000".into()),
286            ("X-ASH-NONCE".into(), "0123456789abcdef0123456789abcdef".into()),
287            ("X-Ash-Body-Hash".into(), "a".repeat(64)),
288            ("x-AsH-pRoOf".into(), "b".repeat(64)),
289        ]);
290        assert!(ash_extract_headers(&h).is_ok());
291    }
292
293    #[test]
294    fn test_missing_timestamp() {
295        let h = TestHeaders(vec![
296            ("x-ash-nonce".into(), "0123456789abcdef0123456789abcdef".into()),
297            ("x-ash-body-hash".into(), "a".repeat(64)),
298            ("x-ash-proof".into(), "b".repeat(64)),
299        ]);
300        let err = ash_extract_headers(&h).unwrap_err();
301        assert_eq!(err.code(), AshErrorCode::ValidationError);
302        assert_eq!(err.http_status(), 485);
303        assert_eq!(err.reason(), InternalReason::HdrMissing);
304        assert!(err.details().unwrap().get("header").unwrap().contains("ts"));
305    }
306
307    #[test]
308    fn test_missing_nonce() {
309        let h = TestHeaders(vec![
310            ("x-ash-ts".into(), "1700000000".into()),
311            ("x-ash-body-hash".into(), "a".repeat(64)),
312            ("x-ash-proof".into(), "b".repeat(64)),
313        ]);
314        let err = ash_extract_headers(&h).unwrap_err();
315        assert_eq!(err.reason(), InternalReason::HdrMissing);
316    }
317
318    #[test]
319    fn test_multi_value_nonce() {
320        let h = TestHeaders(vec![
321            ("x-ash-ts".into(), "1700000000".into()),
322            ("x-ash-nonce".into(), "aaa".into()),
323            ("x-ash-nonce".into(), "bbb".into()),
324            ("x-ash-body-hash".into(), "a".repeat(64)),
325            ("x-ash-proof".into(), "b".repeat(64)),
326        ]);
327        let err = ash_extract_headers(&h).unwrap_err();
328        assert_eq!(err.code(), AshErrorCode::ValidationError);
329        assert_eq!(err.http_status(), 485);
330        assert_eq!(err.reason(), InternalReason::HdrMultiValue);
331    }
332
333    #[test]
334    fn test_control_chars_in_proof() {
335        let h = TestHeaders(vec![
336            ("x-ash-ts".into(), "1700000000".into()),
337            ("x-ash-nonce".into(), "0123456789abcdef0123456789abcdef".into()),
338            ("x-ash-body-hash".into(), "a".repeat(64)),
339            ("x-ash-proof".into(), "proof\ninjection".into()),
340        ]);
341        let err = ash_extract_headers(&h).unwrap_err();
342        assert_eq!(err.reason(), InternalReason::HdrInvalidChars);
343    }
344
345    #[test]
346    fn test_trimming() {
347        let h = TestHeaders(vec![
348            ("x-ash-ts".into(), "  1700000000  ".into()),
349            ("x-ash-nonce".into(), " 0123456789abcdef0123456789abcdef ".into()),
350            ("x-ash-body-hash".into(), format!(" {} ", "a".repeat(64))),
351            ("x-ash-proof".into(), format!(" {} ", "b".repeat(64))),
352        ]);
353        let bundle = ash_extract_headers(&h).unwrap();
354        assert_eq!(bundle.ts, "1700000000");
355        assert_eq!(bundle.nonce, "0123456789abcdef0123456789abcdef");
356    }
357
358    #[test]
359    fn test_multi_value_optional_context_id() {
360        let mut h = valid_headers();
361        h.0.push(("x-ash-context-id".into(), "ctx_1".into()));
362        h.0.push(("X-ASH-Context-ID".into(), "ctx_2".into()));
363        let err = ash_extract_headers(&h).unwrap_err();
364        assert_eq!(err.reason(), InternalReason::HdrMultiValue);
365    }
366}