1use serde::Deserialize;
36use std::collections::BTreeMap;
37
38#[derive(Debug, Deserialize)]
42pub struct VectorFile {
43 pub schema_version: u32,
45 pub ash_version: String,
47 #[serde(default)]
49 pub categories: BTreeMap<String, Vec<Vector>>,
50 #[serde(default)]
52 pub vectors: Vec<Vector>,
53}
54
55#[derive(Debug, Clone, Deserialize)]
57pub struct Vector {
58 pub id: String,
60 #[serde(default)]
62 pub category: String,
63 #[serde(default)]
65 pub description: String,
66 #[serde(default)]
68 pub input: serde_json::Value,
69 #[serde(default)]
71 pub expected: serde_json::Value,
72}
73
/// What an [`AshAdapter`] operation produced for one vector.
///
/// Three shapes are meaningful to the runner:
/// * success — `ok == true` with `output: Some(..)`;
/// * failure — `ok == false` with `error_code` / `error_status` set;
/// * skip — `ok == true`, no output and no error code (see [`AdapterResult::skip`]).
///
/// `PartialEq`/`Eq` are derived so adapter tests can compare results directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AdapterResult {
    /// Rendered output to compare against the vector's expectation.
    pub output: Option<String>,
    /// Whether the operation succeeded.
    pub ok: bool,
    /// Error code reported on failure.
    pub error_code: Option<String>,
    /// HTTP status reported on failure.
    pub error_status: Option<u16>,
}
88
89impl AdapterResult {
90 pub fn ok(output: impl Into<String>) -> Self {
92 Self {
93 output: Some(output.into()),
94 ok: true,
95 error_code: None,
96 error_status: None,
97 }
98 }
99
100 pub fn ok_bool(val: bool) -> Self {
102 Self {
103 output: Some(val.to_string()),
104 ok: true,
105 error_code: None,
106 error_status: None,
107 }
108 }
109
110 pub fn error(code: impl Into<String>, status: u16) -> Self {
112 Self {
113 output: None,
114 ok: false,
115 error_code: Some(code.into()),
116 error_status: Some(status),
117 }
118 }
119
120 pub fn skip() -> Self {
122 Self {
123 output: None,
124 ok: true,
125 error_code: None,
126 error_status: None,
127 }
128 }
129}
130
131pub trait AshAdapter {
136 fn canonicalize_json(&self, input: &str) -> AdapterResult { let _ = input; AdapterResult::skip() }
138 fn canonicalize_query(&self, input: &str) -> AdapterResult { let _ = input; AdapterResult::skip() }
140 fn canonicalize_urlencoded(&self, input: &str) -> AdapterResult { let _ = input; AdapterResult::skip() }
142 fn normalize_binding(&self, method: &str, path: &str, query: &str) -> AdapterResult { let _ = (method, path, query); AdapterResult::skip() }
144 fn hash_body(&self, body: &str) -> AdapterResult { let _ = body; AdapterResult::skip() }
146 fn derive_client_secret(&self, nonce: &str, context_id: &str, binding: &str) -> AdapterResult { let _ = (nonce, context_id, binding); AdapterResult::skip() }
148 fn build_proof(&self, secret: &str, ts: &str, binding: &str, body_hash: &str) -> AdapterResult { let _ = (secret, ts, binding, body_hash); AdapterResult::skip() }
150 fn timing_safe_equal(&self, a: &str, b: &str) -> AdapterResult { let _ = (a, b); AdapterResult::skip() }
152 fn validate_timestamp(&self, ts: &str) -> AdapterResult { let _ = ts; AdapterResult::skip() }
154 fn trigger_error(&self, input: &serde_json::Value) -> AdapterResult { let _ = input; AdapterResult::skip() }
156 fn extract_scoped_fields(&self, payload: &str, fields: &[String], strict: bool) -> AdapterResult { let _ = (payload, fields, strict); AdapterResult::skip() }
158 fn build_unified_proof(&self, input: &serde_json::Value) -> AdapterResult { let _ = input; AdapterResult::skip() }
160}
161
162pub fn load_vectors(data: &[u8]) -> Result<Vec<Vector>, String> {
168 let file: serde_json::Value =
169 serde_json::from_slice(data).map_err(|e| format!("Failed to parse vectors JSON: {}", e))?;
170
171 let mut all_vectors = Vec::new();
172
173 if let Some(obj) = file.as_object() {
175 for (key, val) in obj {
176 if matches!(
178 key.as_str(),
179 "schema_version"
180 | "ash_version"
181 | "generated_from"
182 | "generated_at"
183 | "generator_version"
184 | "platform"
185 ) {
186 continue;
187 }
188
189 if let Some(arr) = val.as_array() {
191 for item in arr {
192 if let Ok(mut vec) = serde_json::from_value::<Vector>(item.clone()) {
193 if vec.category.is_empty() {
194 vec.category = key.clone();
195 }
196 all_vectors.push(vec);
197 }
198 }
199 }
200 }
201 }
202
203 Ok(all_vectors)
204}
205
206pub fn load_vectors_from_file(path: &str) -> Result<Vec<Vector>, String> {
208 let data = std::fs::read(path).map_err(|e| format!("Failed to read {}: {}", path, e))?;
209 load_vectors(&data)
210}
211
/// Outcome of running one [`Vector`] through an adapter.
///
/// `PartialEq`/`Eq` are derived so callers and tests can compare results directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VectorResult {
    /// Identifier of the vector that was run.
    pub id: String,
    /// Category the vector was dispatched under.
    pub category: String,
    /// Whether the adapter's result matched the expectation.
    pub passed: bool,
    /// Whether the vector was not exercised (operation not implemented by the
    /// adapter, or an unknown category).
    pub skipped: bool,
    /// Expected value, rendered as a string for comparison and display.
    pub expected: String,
    /// Actual value produced by the adapter, rendered the same way.
    pub actual: String,
    /// Human-readable mismatch or skip explanation, if any.
    pub diff: Option<String>,
}
232
233#[derive(Debug)]
235pub struct TestReport {
236 pub results: Vec<VectorResult>,
238 pub total: usize,
240 pub passed: usize,
242 pub failed: usize,
244 pub skipped: usize,
246}
247
248impl TestReport {
249 pub fn all_passed(&self) -> bool {
251 self.failed == 0
252 }
253
254 pub fn failures(&self) -> Vec<&VectorResult> {
256 self.results.iter().filter(|r| !r.passed && !r.skipped).collect()
257 }
258
259 pub fn summary(&self) -> String {
261 format!(
262 "{}/{} passed, {} failed, {} skipped",
263 self.passed, self.total, self.failed, self.skipped
264 )
265 }
266}
267
268pub fn run_vectors(vectors: &[Vector], adapter: &dyn AshAdapter) -> TestReport {
273 let mut results = Vec::with_capacity(vectors.len());
274 let mut passed = 0;
275 let mut failed = 0;
276 let mut skipped = 0;
277
278 for vec in vectors {
279 let result = run_single_vector(vec, adapter);
280 if result.skipped {
281 skipped += 1;
282 } else if result.passed {
283 passed += 1;
284 } else {
285 failed += 1;
286 }
287 results.push(result);
288 }
289
290 TestReport {
291 total: vectors.len(),
292 passed,
293 failed,
294 skipped,
295 results,
296 }
297}
298
299fn run_single_vector(vec: &Vector, adapter: &dyn AshAdapter) -> VectorResult {
300 let category = vec.category.as_str();
301
302 let (adapter_result, expected_str) = match category {
303 "json_canonicalization" => {
304 let input = vec.input.get("input_json_text")
305 .or_else(|| vec.input.get("input"))
306 .and_then(|v| v.as_str())
307 .unwrap_or("");
308 let expected = vec.expected.get("canonical_json")
309 .or_else(|| vec.expected.as_str().map(|_| &vec.expected))
310 .and_then(|v| v.as_str())
311 .unwrap_or("");
312 (adapter.canonicalize_json(input), expected.to_string())
313 }
314 "query_canonicalization" => {
315 let input = vec.input.get("raw_query")
316 .or_else(|| vec.input.get("input"))
317 .and_then(|v| v.as_str())
318 .unwrap_or("");
319 let expected = vec.expected.get("canonical_query")
320 .or_else(|| vec.expected.as_str().map(|_| &vec.expected))
321 .and_then(|v| v.as_str())
322 .unwrap_or("");
323 (adapter.canonicalize_query(input), expected.to_string())
324 }
325 "urlencoded_canonicalization" => {
326 let input = vec.input.get("input")
327 .and_then(|v| v.as_str())
328 .unwrap_or("");
329 let expected = vec.expected.get("canonical")
330 .or_else(|| vec.expected.as_str().map(|_| &vec.expected))
331 .and_then(|v| v.as_str())
332 .unwrap_or("");
333 (adapter.canonicalize_urlencoded(input), expected.to_string())
334 }
335 "binding_normalization" => {
336 let method = vec.input.get("method").and_then(|v| v.as_str()).unwrap_or("");
337 let path = vec.input.get("path").and_then(|v| v.as_str()).unwrap_or("");
338 let query = vec.input.get("query").and_then(|v| v.as_str()).unwrap_or("");
339 let expected = vec.expected.get("binding")
340 .or_else(|| vec.expected.as_str().map(|_| &vec.expected))
341 .and_then(|v| v.as_str())
342 .unwrap_or("");
343 (adapter.normalize_binding(method, path, query), expected.to_string())
344 }
345 "body_hashing" => {
346 let input = vec.input.get("body")
347 .or_else(|| vec.input.get("input"))
348 .and_then(|v| v.as_str())
349 .unwrap_or("");
350 let expected = vec.expected.get("hash")
351 .or_else(|| vec.expected.as_str().map(|_| &vec.expected))
352 .and_then(|v| v.as_str())
353 .unwrap_or("");
354 (adapter.hash_body(input), expected.to_string())
355 }
356 "client_secret_derivation" => {
357 let nonce = vec.input.get("nonce").and_then(|v| v.as_str()).unwrap_or("");
358 let ctx = vec.input.get("context_id").and_then(|v| v.as_str()).unwrap_or("");
359 let binding = vec.input.get("binding").and_then(|v| v.as_str()).unwrap_or("");
360 let expected = vec.expected.get("client_secret")
361 .or_else(|| vec.expected.as_str().map(|_| &vec.expected))
362 .and_then(|v| v.as_str())
363 .unwrap_or("");
364 (adapter.derive_client_secret(nonce, ctx, binding), expected.to_string())
365 }
366 "proof_generation" => {
367 let secret = vec.input.get("client_secret").and_then(|v| v.as_str()).unwrap_or("");
368 let ts = vec.input.get("timestamp").and_then(|v| v.as_str()).unwrap_or("");
369 let binding = vec.input.get("binding").and_then(|v| v.as_str()).unwrap_or("");
370 let body_hash = vec.input.get("body_hash").and_then(|v| v.as_str()).unwrap_or("");
371 let expected = vec.expected.get("proof")
372 .or_else(|| vec.expected.as_str().map(|_| &vec.expected))
373 .and_then(|v| v.as_str())
374 .unwrap_or("");
375 (adapter.build_proof(secret, ts, binding, body_hash), expected.to_string())
376 }
377 "timing_safe_comparison" => {
378 let a = vec.input.get("a").and_then(|v| v.as_str()).unwrap_or("");
379 let b = vec.input.get("b").and_then(|v| v.as_str()).unwrap_or("");
380 let expected = vec.expected.get("equal")
381 .and_then(|v| v.as_bool())
382 .map(|b| b.to_string())
383 .unwrap_or_default();
384 (adapter.timing_safe_equal(a, b), expected)
385 }
386 "error_behavior" => {
387 let expected_code = vec.expected.get("error_code")
388 .and_then(|v| v.as_str())
389 .unwrap_or("");
390 let expected_status = vec.expected.get("http_status")
391 .and_then(|v| v.as_u64())
392 .unwrap_or(0) as u16;
393 let result = adapter.trigger_error(&vec.input);
394 let expected_str = format!("{}:{}", expected_code, expected_status);
395 let actual_str = if result.ok {
396 "ok".to_string()
397 } else {
398 format!("{}:{}", result.error_code.as_deref().unwrap_or(""), result.error_status.unwrap_or(0))
399 };
400 return VectorResult {
401 id: vec.id.clone(),
402 category: vec.category.clone(),
403 passed: !result.ok
404 && result.error_code.as_deref() == Some(expected_code)
405 && result.error_status == Some(expected_status),
406 skipped: result.output.is_none() && result.ok && result.error_code.is_none(),
407 expected: expected_str,
408 actual: actual_str,
409 diff: None,
410 };
411 }
412 "timestamp_validation" => {
413 let ts = vec.input.get("timestamp").and_then(|v| v.as_str()).unwrap_or("");
414 let should_pass = vec.expected.get("valid").and_then(|v| v.as_bool()).unwrap_or(false);
415 let result = adapter.validate_timestamp(ts);
416 let actual_ok = result.ok;
417 return VectorResult {
418 id: vec.id.clone(),
419 category: vec.category.clone(),
420 passed: actual_ok == should_pass,
421 skipped: result.output.is_none() && result.ok && result.error_code.is_none(),
422 expected: format!("valid={}", should_pass),
423 actual: format!("valid={}", actual_ok),
424 diff: if actual_ok != should_pass {
425 Some(format!("Expected valid={}, got valid={}", should_pass, actual_ok))
426 } else {
427 None
428 },
429 };
430 }
431 _ => {
432 return VectorResult {
433 id: vec.id.clone(),
434 category: vec.category.clone(),
435 passed: false,
436 skipped: true,
437 expected: String::new(),
438 actual: String::new(),
439 diff: Some(format!("Unknown category: {}", category)),
440 };
441 }
442 };
443
444 if adapter_result.output.is_none() && adapter_result.ok && adapter_result.error_code.is_none() {
446 return VectorResult {
447 id: vec.id.clone(),
448 category: vec.category.clone(),
449 passed: false,
450 skipped: true,
451 expected: expected_str,
452 actual: String::new(),
453 diff: None,
454 };
455 }
456
457 let actual = adapter_result.output.unwrap_or_default();
458 let pass = actual == expected_str;
459
460 VectorResult {
461 id: vec.id.clone(),
462 category: vec.category.clone(),
463 passed: pass,
464 skipped: false,
465 expected: expected_str.clone(),
466 actual: actual.clone(),
467 diff: if pass {
468 None
469 } else {
470 Some(format!("expected: {}\n actual: {}", expected_str, actual))
471 },
472 }
473}
474
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an inline vector so each test spells out only what it varies.
    fn make_vector(
        id: &str,
        category: &str,
        description: &str,
        input: serde_json::Value,
        expected: serde_json::Value,
    ) -> Vector {
        Vector {
            id: id.to_string(),
            category: category.to_string(),
            description: description.to_string(),
            input,
            expected,
        }
    }

    // ---- AdapterResult constructors ----

    #[test]
    fn test_adapter_result_ok() {
        let result = AdapterResult::ok("hello");
        assert!(result.ok);
        assert_eq!(result.output.as_deref(), Some("hello"));
        assert!(result.error_code.is_none());
    }

    #[test]
    fn test_adapter_result_error() {
        let result = AdapterResult::error("ASH_VALIDATION_ERROR", 485);
        assert!(!result.ok);
        assert!(result.output.is_none());
        assert_eq!(result.error_code.as_deref(), Some("ASH_VALIDATION_ERROR"));
        assert_eq!(result.error_status, Some(485));
    }

    #[test]
    fn test_adapter_result_skip() {
        let result = AdapterResult::skip();
        assert!(result.ok);
        assert!(result.output.is_none());
    }

    #[test]
    fn test_adapter_result_ok_bool() {
        let result = AdapterResult::ok_bool(true);
        assert_eq!(result.output.as_deref(), Some("true"));
    }

    // ---- TestReport accounting ----

    #[test]
    fn test_report_all_passed() {
        let report = TestReport {
            results: Vec::new(),
            total: 5,
            passed: 5,
            failed: 0,
            skipped: 0,
        };
        assert!(report.all_passed());
        assert_eq!(report.summary(), "5/5 passed, 0 failed, 0 skipped");
    }

    #[test]
    fn test_report_with_failures() {
        let mismatch = VectorResult {
            id: "test_001".to_string(),
            category: "json".to_string(),
            passed: false,
            skipped: false,
            expected: "a".to_string(),
            actual: "b".to_string(),
            diff: Some("expected: a\n actual: b".to_string()),
        };
        let report = TestReport {
            results: vec![mismatch],
            total: 1,
            passed: 0,
            failed: 1,
            skipped: 0,
        };
        assert!(!report.all_passed());
        assert_eq!(report.failures().len(), 1);
    }

    // ---- Adapter that implements nothing: every vector is skipped ----

    struct EmptyAdapter;

    impl AshAdapter for EmptyAdapter {}

    #[test]
    fn test_empty_adapter_skips_all() {
        let vector = make_vector(
            "test",
            "json_canonicalization",
            "test",
            serde_json::json!({"input_json_text": "{}"}),
            serde_json::json!({"canonical_json": "{}"}),
        );
        let report = run_vectors(&[vector], &EmptyAdapter);
        assert_eq!(report.skipped, 1);
    }

    // ---- Adapter backed by this crate's own implementation ----

    struct RustCoreAdapter;

    impl AshAdapter for RustCoreAdapter {
        fn canonicalize_json(&self, input: &str) -> AdapterResult {
            match crate::ash_canonicalize_json(input) {
                Ok(s) => AdapterResult::ok(s),
                Err(e) => AdapterResult::error(e.code().as_str(), e.http_status()),
            }
        }

        fn canonicalize_query(&self, input: &str) -> AdapterResult {
            match crate::ash_canonicalize_query(input) {
                Ok(s) => AdapterResult::ok(s),
                Err(e) => AdapterResult::error(e.code().as_str(), e.http_status()),
            }
        }

        fn normalize_binding(&self, method: &str, path: &str, query: &str) -> AdapterResult {
            match crate::ash_normalize_binding(method, path, query) {
                Ok(s) => AdapterResult::ok(s),
                Err(e) => AdapterResult::error(e.code().as_str(), e.http_status()),
            }
        }

        fn hash_body(&self, body: &str) -> AdapterResult {
            AdapterResult::ok(crate::ash_hash_body(body))
        }

        fn derive_client_secret(&self, nonce: &str, ctx: &str, binding: &str) -> AdapterResult {
            match crate::ash_derive_client_secret(nonce, ctx, binding) {
                Ok(s) => AdapterResult::ok(s),
                Err(e) => AdapterResult::error(e.code().as_str(), e.http_status()),
            }
        }

        fn build_proof(&self, secret: &str, ts: &str, binding: &str, body_hash: &str) -> AdapterResult {
            match crate::ash_build_proof(secret, ts, binding, body_hash) {
                Ok(s) => AdapterResult::ok(s),
                Err(e) => AdapterResult::error(e.code().as_str(), e.http_status()),
            }
        }

        fn timing_safe_equal(&self, a: &str, b: &str) -> AdapterResult {
            AdapterResult::ok_bool(crate::ash_timing_safe_equal(a.as_bytes(), b.as_bytes()))
        }
    }

    #[test]
    fn test_rust_core_adapter_json() {
        let vector = make_vector(
            "json_inline",
            "json_canonicalization",
            "sort keys",
            serde_json::json!({"input_json_text": r#"{"z":1,"a":2}"#}),
            serde_json::json!({"canonical_json": r#"{"a":2,"z":1}"#}),
        );
        let report = run_vectors(&[vector], &RustCoreAdapter);
        assert!(report.all_passed(), "Failures: {:?}", report.failures());
    }

    #[test]
    fn test_rust_core_adapter_body_hash() {
        let hash = crate::ash_hash_body("test");
        let vector = make_vector(
            "hash_inline",
            "body_hashing",
            "hash test",
            serde_json::json!({"body": "test"}),
            serde_json::json!({"hash": hash}),
        );
        let report = run_vectors(&[vector], &RustCoreAdapter);
        assert!(report.all_passed());
    }

    #[test]
    fn test_rust_core_adapter_timing_safe() {
        let equal = make_vector(
            "ts_eq",
            "timing_safe_comparison",
            "equal",
            serde_json::json!({"a": "hello", "b": "hello"}),
            serde_json::json!({"equal": true}),
        );
        let not_equal = make_vector(
            "ts_neq",
            "timing_safe_comparison",
            "not equal",
            serde_json::json!({"a": "hello", "b": "world"}),
            serde_json::json!({"equal": false}),
        );
        let report = run_vectors(&[equal, not_equal], &RustCoreAdapter);
        assert!(report.all_passed());
        assert_eq!(report.passed, 2);
    }
}