1use crate::errors::{AshError, AshErrorCode, InternalReason};
20
/// Name of the header carrying the request timestamp (required).
///
/// Names are lowercase; lookups go through [`HeaderMapView::get_all_ci`].
pub const HDR_TIMESTAMP: &str = "x-ash-ts";

/// Name of the header carrying the request nonce (required).
pub const HDR_NONCE: &str = "x-ash-nonce";

/// Name of the header carrying the request body hash (required).
pub const HDR_BODY_HASH: &str = "x-ash-body-hash";

/// Name of the header carrying the request proof (required).
pub const HDR_PROOF: &str = "x-ash-proof";

/// Name of the header carrying the context identifier (optional).
pub const HDR_CONTEXT_ID: &str = "x-ash-context-id";
37
/// Read-only view over a request's header map, as consumed by
/// [`ash_extract_headers`].
pub trait HeaderMapView {
    /// Returns every value stored under `name`, matching the header name
    /// case-insensitively (hence the `_ci` suffix).
    ///
    /// The extractor treats an empty `Vec` as "header absent" and more than
    /// one entry as "header repeated", so implementations should return one
    /// entry per occurrence rather than joining repeated values.
    fn get_all_ci(&self, name: &str) -> Vec<&str>;
}
70
/// The ASH header values returned by [`ash_extract_headers`].
///
/// Each value has been trimmed and has passed only the header-level checks
/// (present when required, single-valued, free of control characters). The
/// values are otherwise kept as raw strings and are not parsed here.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HeaderBundle {
    /// Value of the `x-ash-ts` header.
    pub ts: String,
    /// Value of the `x-ash-nonce` header.
    pub nonce: String,
    /// Value of the `x-ash-body-hash` header.
    pub body_hash: String,
    /// Value of the `x-ash-proof` header.
    pub proof: String,
    /// Value of the `x-ash-context-id` header, or `None` when it was absent.
    pub context_id: Option<String>,
}
90
91pub fn ash_extract_headers(h: &impl HeaderMapView) -> Result<HeaderBundle, AshError> {
142 let ts = get_one(h, HDR_TIMESTAMP)?;
143 let nonce = get_one(h, HDR_NONCE)?;
144 let body_hash = get_one(h, HDR_BODY_HASH)?;
145 let proof = get_one(h, HDR_PROOF)?;
146 let context_id = get_optional_one(h, HDR_CONTEXT_ID)?;
147
148 Ok(HeaderBundle {
149 ts,
150 nonce,
151 body_hash,
152 proof,
153 context_id,
154 })
155}
156
157fn get_one(h: &impl HeaderMapView, name: &'static str) -> Result<String, AshError> {
159 let vals = h.get_all_ci(name);
160
161 if vals.is_empty() {
162 return Err(
163 AshError::with_reason(
164 AshErrorCode::ValidationError,
165 InternalReason::HdrMissing,
166 format!("Required header '{}' is missing", name),
167 )
168 .with_detail("header", name),
169 );
170 }
171 if vals.len() > 1 {
172 return Err(
173 AshError::with_reason(
174 AshErrorCode::ValidationError,
175 InternalReason::HdrMultiValue,
176 format!("Header '{}' must have exactly one value, got {}", name, vals.len()),
177 )
178 .with_detail("header", name)
179 .with_detail("count", vals.len().to_string()),
180 );
181 }
182
183 let v = vals[0].trim();
184 if contains_ctl_or_newlines(v) {
185 return Err(
186 AshError::with_reason(
187 AshErrorCode::ValidationError,
188 InternalReason::HdrInvalidChars,
189 format!("Header '{}' contains invalid characters", name),
190 )
191 .with_detail("header", name),
192 );
193 }
194
195 Ok(v.to_string())
196}
197
198fn get_optional_one(h: &impl HeaderMapView, name: &'static str) -> Result<Option<String>, AshError> {
200 let vals = h.get_all_ci(name);
201
202 if vals.is_empty() {
203 return Ok(None);
204 }
205 if vals.len() > 1 {
206 return Err(
207 AshError::with_reason(
208 AshErrorCode::ValidationError,
209 InternalReason::HdrMultiValue,
210 format!("Header '{}' must have exactly one value, got {}", name, vals.len()),
211 )
212 .with_detail("header", name)
213 .with_detail("count", vals.len().to_string()),
214 );
215 }
216
217 let v = vals[0].trim();
218 if contains_ctl_or_newlines(v) {
219 return Err(
220 AshError::with_reason(
221 AshErrorCode::ValidationError,
222 InternalReason::HdrInvalidChars,
223 format!("Header '{}' contains invalid characters", name),
224 )
225 .with_detail("header", name),
226 );
227 }
228
229 Ok(Some(v.to_string()))
230}
231
/// Returns `true` if `s` contains a carriage return, a line feed, or any other
/// character for which [`char::is_control`] holds (Unicode category `Cc`, which
/// also covers TAB, DEL and the C1 controls).
fn contains_ctl_or_newlines(s: &str) -> bool {
    s.contains(|ch: char| matches!(ch, '\r' | '\n') || ch.is_control())
}
236
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal in-memory header map: ordered `(name, value)` pairs matched
    /// ASCII-case-insensitively, with repeated names kept as separate entries.
    struct TestHeaders(Vec<(String, String)>);

    impl HeaderMapView for TestHeaders {
        fn get_all_ci(&self, name: &str) -> Vec<&str> {
            self.0
                .iter()
                .filter(|(k, _)| k.eq_ignore_ascii_case(name))
                .map(|(_, v)| v.as_str())
                .collect()
        }
    }

    /// All four required headers with valid values, using mixed-case names.
    fn valid_headers() -> TestHeaders {
        TestHeaders(vec![
            ("X-ASH-TS".into(), "1700000000".into()),
            ("x-ash-nonce".into(), "0123456789abcdef0123456789abcdef".into()),
            ("X-Ash-Body-Hash".into(), "a".repeat(64)),
            ("x-ash-proof".into(), "b".repeat(64)),
        ])
    }

    #[test]
    fn test_extract_all_required() {
        let bundle = ash_extract_headers(&valid_headers()).unwrap();
        assert_eq!(bundle.ts, "1700000000");
        assert_eq!(bundle.nonce, "0123456789abcdef0123456789abcdef");
        assert_eq!(bundle.body_hash, "a".repeat(64));
        assert_eq!(bundle.proof, "b".repeat(64));
        assert!(bundle.context_id.is_none());
    }

    #[test]
    fn test_extract_with_context_id() {
        let mut h = valid_headers();
        h.0.push(("X-ASH-Context-ID".into(), "ctx_abc123".into()));
        let bundle = ash_extract_headers(&h).unwrap();
        assert_eq!(bundle.context_id, Some("ctx_abc123".into()));
    }

    #[test]
    fn test_case_insensitive() {
        let h = TestHeaders(vec![
            ("x-ash-ts".into(), "1700000000".into()),
            ("X-ASH-NONCE".into(), "0123456789abcdef0123456789abcdef".into()),
            ("X-Ash-Body-Hash".into(), "a".repeat(64)),
            ("x-AsH-pRoOf".into(), "b".repeat(64)),
        ]);
        assert!(ash_extract_headers(&h).is_ok());
    }

    #[test]
    fn test_missing_timestamp() {
        let h = TestHeaders(vec![
            ("x-ash-nonce".into(), "0123456789abcdef0123456789abcdef".into()),
            ("x-ash-body-hash".into(), "a".repeat(64)),
            ("x-ash-proof".into(), "b".repeat(64)),
        ]);
        let err = ash_extract_headers(&h).unwrap_err();
        assert_eq!(err.code(), AshErrorCode::ValidationError);
        assert_eq!(err.http_status(), 485);
        assert_eq!(err.reason(), InternalReason::HdrMissing);
        assert!(err.details().unwrap().get("header").unwrap().contains("ts"));
    }

    #[test]
    fn test_missing_nonce() {
        let h = TestHeaders(vec![
            ("x-ash-ts".into(), "1700000000".into()),
            ("x-ash-body-hash".into(), "a".repeat(64)),
            ("x-ash-proof".into(), "b".repeat(64)),
        ]);
        let err = ash_extract_headers(&h).unwrap_err();
        assert_eq!(err.reason(), InternalReason::HdrMissing);
    }

    #[test]
    fn test_missing_proof() {
        let mut h = valid_headers();
        h.0.retain(|(k, _)| !k.eq_ignore_ascii_case(HDR_PROOF));
        let err = ash_extract_headers(&h).unwrap_err();
        assert_eq!(err.reason(), InternalReason::HdrMissing);
        assert!(err.details().unwrap().get("header").unwrap().contains(HDR_PROOF));
    }

    #[test]
    fn test_empty_map_reports_timestamp_first() {
        // Required headers are checked in a fixed order; with nothing present the
        // timestamp must be the one reported.
        let err = ash_extract_headers(&TestHeaders(Vec::new())).unwrap_err();
        assert_eq!(err.reason(), InternalReason::HdrMissing);
        assert!(err.details().unwrap().get("header").unwrap().contains(HDR_TIMESTAMP));
    }

    #[test]
    fn test_multi_value_nonce() {
        let h = TestHeaders(vec![
            ("x-ash-ts".into(), "1700000000".into()),
            ("x-ash-nonce".into(), "aaa".into()),
            ("x-ash-nonce".into(), "bbb".into()),
            ("x-ash-body-hash".into(), "a".repeat(64)),
            ("x-ash-proof".into(), "b".repeat(64)),
        ]);
        let err = ash_extract_headers(&h).unwrap_err();
        assert_eq!(err.code(), AshErrorCode::ValidationError);
        assert_eq!(err.http_status(), 485);
        assert_eq!(err.reason(), InternalReason::HdrMultiValue);
    }

    #[test]
    fn test_control_chars_in_proof() {
        let h = TestHeaders(vec![
            ("x-ash-ts".into(), "1700000000".into()),
            ("x-ash-nonce".into(), "0123456789abcdef0123456789abcdef".into()),
            ("x-ash-body-hash".into(), "a".repeat(64)),
            ("x-ash-proof".into(), "proof\ninjection".into()),
        ]);
        let err = ash_extract_headers(&h).unwrap_err();
        assert_eq!(err.reason(), InternalReason::HdrInvalidChars);
    }

    #[test]
    fn test_control_chars_in_optional_context_id() {
        let mut h = valid_headers();
        h.0.push(("x-ash-context-id".into(), "ctx\u{7}bad".into()));
        let err = ash_extract_headers(&h).unwrap_err();
        assert_eq!(err.reason(), InternalReason::HdrInvalidChars);
    }

    #[test]
    fn test_trimming() {
        let h = TestHeaders(vec![
            ("x-ash-ts".into(), " 1700000000 ".into()),
            ("x-ash-nonce".into(), " 0123456789abcdef0123456789abcdef ".into()),
            ("x-ash-body-hash".into(), format!(" {} ", "a".repeat(64))),
            ("x-ash-proof".into(), format!(" {} ", "b".repeat(64))),
        ]);
        let bundle = ash_extract_headers(&h).unwrap();
        assert_eq!(bundle.ts, "1700000000");
        assert_eq!(bundle.nonce, "0123456789abcdef0123456789abcdef");
    }

    #[test]
    fn test_multi_value_optional_context_id() {
        let mut h = valid_headers();
        h.0.push(("x-ash-context-id".into(), "ctx_1".into()));
        h.0.push(("X-ASH-Context-ID".into(), "ctx_2".into()));
        let err = ash_extract_headers(&h).unwrap_err();
        assert_eq!(err.reason(), InternalReason::HdrMultiValue);
    }
}