Skip to main content

ash_core/
testkit.rs

1//! Testkit — Conformance adapter runner.
2//!
3//! Provides a reusable framework for running conformance vectors against
4//! any SDK implementation. New SDKs implement the `AshAdapter` trait
5//! and get full conformance testing for free.
6//!
7//! ## Usage
8//!
9//! 1. Implement `AshAdapter` for your SDK
10//! 2. Call `load_vectors()` to parse vectors.json
11//! 3. Call `run_vectors()` to execute all vectors
12//! 4. Inspect the `TestReport` for pass/fail + diffs
13//!
14//! ## Example
15//!
16//! ```rust,ignore
17//! use ash_core::testkit::{load_vectors, run_vectors, AshAdapter, AdapterResult};
18//!
19//! struct MyAdapter;
20//! impl AshAdapter for MyAdapter {
21//!     fn canonicalize_json(&self, input: &str) -> AdapterResult {
22//!         match my_sdk::canonicalize(input) {
23//!             Ok(s) => AdapterResult::ok(s),
24//!             Err(e) => AdapterResult::error(e.code, e.status),
25//!         }
26//!     }
27//!     // ... implement other methods
28//! }
29//!
30//! let vectors = load_vectors(include_bytes!("../../tests/conformance/vectors.json")).unwrap();
31//! let report = run_vectors(&vectors, &MyAdapter);
32//! assert!(report.all_passed(), "Failures: {:?}", report.failures());
33//! ```
34
35use serde::Deserialize;
36use std::collections::BTreeMap;
37
38// ── Vector Types ─────────────────────────────────────────────────────
39
40/// Top-level vectors file.
41#[derive(Debug, Deserialize)]
42pub struct VectorFile {
43    /// Schema version
44    pub schema_version: u32,
45    /// ASH version these vectors are locked to
46    pub ash_version: String,
47    /// All vector categories
48    #[serde(default)]
49    pub categories: BTreeMap<String, Vec<Vector>>,
50    /// Flat list (alternative format)
51    #[serde(default)]
52    pub vectors: Vec<Vector>,
53}
54
55/// A single conformance vector.
56#[derive(Debug, Clone, Deserialize)]
57pub struct Vector {
58    /// Unique vector ID (e.g., "json_001")
59    pub id: String,
60    /// Category (e.g., "json_canonicalization")
61    #[serde(default)]
62    pub category: String,
63    /// Human-readable description
64    #[serde(default)]
65    pub description: String,
66    /// Input data (varies by category)
67    #[serde(default)]
68    pub input: serde_json::Value,
69    /// Expected output (varies by category)
70    #[serde(default)]
71    pub expected: serde_json::Value,
72}
73
74// ── Adapter Interface ────────────────────────────────────────────────
75
/// Outcome of a single adapter call.
///
/// The constructors below produce three shapes:
/// - success: `ok == true` and `output` is `Some`;
/// - failure: `ok == false` with `error_code` and `error_status` set;
/// - skip: `ok == true` with every other field `None`.
#[derive(Debug, Clone)]
pub struct AdapterResult {
    /// Output on success (canonical text, hash, proof, ...).
    pub output: Option<String>,
    /// `true` unless the operation failed.
    pub ok: bool,
    /// Error code on failure, e.g. `"ASH_VALIDATION_ERROR"`.
    pub error_code: Option<String>,
    /// HTTP status paired with `error_code` on failure.
    pub error_status: Option<u16>,
}

impl AdapterResult {
    /// Builds a success carrying `output`.
    pub fn ok(output: impl Into<String>) -> Self {
        let output = Some(output.into());
        Self { ok: true, output, error_code: None, error_status: None }
    }

    /// Builds a success whose output is the text `"true"` or `"false"`
    /// (used by the timing-safe comparison category).
    pub fn ok_bool(val: bool) -> Self {
        Self::ok(val.to_string())
    }

    /// Builds a failure carrying an error code and its HTTP status.
    pub fn error(code: impl Into<String>, status: u16) -> Self {
        Self {
            ok: false,
            output: None,
            error_code: Some(code.into()),
            error_status: Some(status),
        }
    }

    /// Marks the operation as unsupported by the adapter; the runner
    /// reports such results as skipped.
    pub fn skip() -> Self {
        Self { ok: true, output: None, error_code: None, error_status: None }
    }
}
130
131/// Trait that SDK implementations must implement for conformance testing.
132///
133/// Each method corresponds to a vector category. Return `AdapterResult::skip()`
134/// for categories your SDK doesn't implement yet.
135pub trait AshAdapter {
136    /// JSON canonicalization: input JSON text → canonical JSON text
137    fn canonicalize_json(&self, input: &str) -> AdapterResult { let _ = input; AdapterResult::skip() }
138    /// Query canonicalization: raw query → canonical query
139    fn canonicalize_query(&self, input: &str) -> AdapterResult { let _ = input; AdapterResult::skip() }
140    /// URL-encoded canonicalization: raw → canonical
141    fn canonicalize_urlencoded(&self, input: &str) -> AdapterResult { let _ = input; AdapterResult::skip() }
142    /// Binding normalization: (method, path, query) → binding string
143    fn normalize_binding(&self, method: &str, path: &str, query: &str) -> AdapterResult { let _ = (method, path, query); AdapterResult::skip() }
144    /// Body hashing: canonical body → hex hash
145    fn hash_body(&self, body: &str) -> AdapterResult { let _ = body; AdapterResult::skip() }
146    /// Client secret derivation: (nonce, context_id, binding) → hex secret
147    fn derive_client_secret(&self, nonce: &str, context_id: &str, binding: &str) -> AdapterResult { let _ = (nonce, context_id, binding); AdapterResult::skip() }
148    /// Proof generation: full inputs → hex proof
149    fn build_proof(&self, secret: &str, ts: &str, binding: &str, body_hash: &str) -> AdapterResult { let _ = (secret, ts, binding, body_hash); AdapterResult::skip() }
150    /// Timing-safe comparison: (a, b) → bool
151    fn timing_safe_equal(&self, a: &str, b: &str) -> AdapterResult { let _ = (a, b); AdapterResult::skip() }
152    /// Timestamp validation: ts → ok or error
153    fn validate_timestamp(&self, ts: &str) -> AdapterResult { let _ = ts; AdapterResult::skip() }
154    /// Error behavior: trigger → error code + status
155    fn trigger_error(&self, input: &serde_json::Value) -> AdapterResult { let _ = input; AdapterResult::skip() }
156    /// Scoped field extraction: (payload, fields, mode) → extracted/hash
157    fn extract_scoped_fields(&self, payload: &str, fields: &[String], strict: bool) -> AdapterResult { let _ = (payload, fields, strict); AdapterResult::skip() }
158    /// Unified proof: full inputs → proof + scope_hash + chain_hash
159    fn build_unified_proof(&self, input: &serde_json::Value) -> AdapterResult { let _ = input; AdapterResult::skip() }
160}
161
162// ── Loading ──────────────────────────────────────────────────────────
163
164/// Load vectors from raw JSON bytes (e.g., `include_bytes!`).
165///
166/// Accepts the standard vectors.json format.
167pub fn load_vectors(data: &[u8]) -> Result<Vec<Vector>, String> {
168    let file: serde_json::Value =
169        serde_json::from_slice(data).map_err(|e| format!("Failed to parse vectors JSON: {}", e))?;
170
171    let mut all_vectors = Vec::new();
172
173    // Extract vectors from categorized format
174    if let Some(obj) = file.as_object() {
175        for (key, val) in obj {
176            // Skip metadata fields
177            if matches!(
178                key.as_str(),
179                "schema_version"
180                    | "ash_version"
181                    | "generated_from"
182                    | "generated_at"
183                    | "generator_version"
184                    | "platform"
185            ) {
186                continue;
187            }
188
189            // Each category is an array of vectors
190            if let Some(arr) = val.as_array() {
191                for item in arr {
192                    if let Ok(mut vec) = serde_json::from_value::<Vector>(item.clone()) {
193                        if vec.category.is_empty() {
194                            vec.category = key.clone();
195                        }
196                        all_vectors.push(vec);
197                    }
198                }
199            }
200        }
201    }
202
203    Ok(all_vectors)
204}
205
206/// Load vectors from a file path.
207pub fn load_vectors_from_file(path: &str) -> Result<Vec<Vector>, String> {
208    let data = std::fs::read(path).map_err(|e| format!("Failed to read {}: {}", path, e))?;
209    load_vectors(&data)
210}
211
212// ── Running ──────────────────────────────────────────────────────────
213
/// Outcome of executing one [`Vector`] against an adapter.
#[derive(Debug, Clone)]
pub struct VectorResult {
    /// ID of the vector that was run.
    pub id: String,
    /// Category of the vector that was run.
    pub category: String,
    /// `true` when the adapter's result matched the expectation.
    pub passed: bool,
    /// `true` when the vector was not exercised (unsupported or unknown).
    pub skipped: bool,
    /// Expected value as rendered for the report.
    pub expected: String,
    /// Actual value as rendered for the report.
    pub actual: String,
    /// Human-readable mismatch description, if any.
    pub diff: Option<String>,
}

/// Aggregate outcome of [`run_vectors`].
#[derive(Debug)]
pub struct TestReport {
    /// One entry per vector, in execution order.
    pub results: Vec<VectorResult>,
    /// Number of vectors processed.
    pub total: usize,
    /// Number of vectors that passed.
    pub passed: usize,
    /// Number of vectors that failed.
    pub failed: usize,
    /// Number of vectors that were skipped.
    pub skipped: usize,
}

impl TestReport {
    /// `true` when no vector failed.
    ///
    /// Skipped vectors are not failures, so a report consisting only of skips
    /// also returns `true`.
    pub fn all_passed(&self) -> bool {
        self.failed == 0
    }

    /// The results that neither passed nor were skipped.
    pub fn failures(&self) -> Vec<&VectorResult> {
        let mut out = Vec::new();
        for result in &self.results {
            if !(result.passed || result.skipped) {
                out.push(result);
            }
        }
        out
    }

    /// One-line summary such as `"5/5 passed, 0 failed, 0 skipped"`.
    pub fn summary(&self) -> String {
        let TestReport { total, passed, failed, skipped, .. } = self;
        format!("{}/{} passed, {} failed, {} skipped", passed, total, failed, skipped)
    }
}
267
268/// Run all vectors against an adapter.
269///
270/// Dispatches each vector to the appropriate adapter method based on category,
271/// compares outputs, and collects results into a `TestReport`.
272pub fn run_vectors(vectors: &[Vector], adapter: &dyn AshAdapter) -> TestReport {
273    let mut results = Vec::with_capacity(vectors.len());
274    let mut passed = 0;
275    let mut failed = 0;
276    let mut skipped = 0;
277
278    for vec in vectors {
279        let result = run_single_vector(vec, adapter);
280        if result.skipped {
281            skipped += 1;
282        } else if result.passed {
283            passed += 1;
284        } else {
285            failed += 1;
286        }
287        results.push(result);
288    }
289
290    TestReport {
291        total: vectors.len(),
292        passed,
293        failed,
294        skipped,
295        results,
296    }
297}
298
299fn run_single_vector(vec: &Vector, adapter: &dyn AshAdapter) -> VectorResult {
300    let category = vec.category.as_str();
301
302    let (adapter_result, expected_str) = match category {
303        "json_canonicalization" => {
304            let input = vec.input.get("input_json_text")
305                .or_else(|| vec.input.get("input"))
306                .and_then(|v| v.as_str())
307                .unwrap_or("");
308            let expected = vec.expected.get("canonical_json")
309                .or_else(|| vec.expected.as_str().map(|_| &vec.expected))
310                .and_then(|v| v.as_str())
311                .unwrap_or("");
312            (adapter.canonicalize_json(input), expected.to_string())
313        }
314        "query_canonicalization" => {
315            let input = vec.input.get("raw_query")
316                .or_else(|| vec.input.get("input"))
317                .and_then(|v| v.as_str())
318                .unwrap_or("");
319            let expected = vec.expected.get("canonical_query")
320                .or_else(|| vec.expected.as_str().map(|_| &vec.expected))
321                .and_then(|v| v.as_str())
322                .unwrap_or("");
323            (adapter.canonicalize_query(input), expected.to_string())
324        }
325        "urlencoded_canonicalization" => {
326            let input = vec.input.get("input")
327                .and_then(|v| v.as_str())
328                .unwrap_or("");
329            let expected = vec.expected.get("canonical")
330                .or_else(|| vec.expected.as_str().map(|_| &vec.expected))
331                .and_then(|v| v.as_str())
332                .unwrap_or("");
333            (adapter.canonicalize_urlencoded(input), expected.to_string())
334        }
335        "binding_normalization" => {
336            let method = vec.input.get("method").and_then(|v| v.as_str()).unwrap_or("");
337            let path = vec.input.get("path").and_then(|v| v.as_str()).unwrap_or("");
338            let query = vec.input.get("query").and_then(|v| v.as_str()).unwrap_or("");
339            let expected = vec.expected.get("binding")
340                .or_else(|| vec.expected.as_str().map(|_| &vec.expected))
341                .and_then(|v| v.as_str())
342                .unwrap_or("");
343            (adapter.normalize_binding(method, path, query), expected.to_string())
344        }
345        "body_hashing" => {
346            let input = vec.input.get("body")
347                .or_else(|| vec.input.get("input"))
348                .and_then(|v| v.as_str())
349                .unwrap_or("");
350            let expected = vec.expected.get("hash")
351                .or_else(|| vec.expected.as_str().map(|_| &vec.expected))
352                .and_then(|v| v.as_str())
353                .unwrap_or("");
354            (adapter.hash_body(input), expected.to_string())
355        }
356        "client_secret_derivation" => {
357            let nonce = vec.input.get("nonce").and_then(|v| v.as_str()).unwrap_or("");
358            let ctx = vec.input.get("context_id").and_then(|v| v.as_str()).unwrap_or("");
359            let binding = vec.input.get("binding").and_then(|v| v.as_str()).unwrap_or("");
360            let expected = vec.expected.get("client_secret")
361                .or_else(|| vec.expected.as_str().map(|_| &vec.expected))
362                .and_then(|v| v.as_str())
363                .unwrap_or("");
364            (adapter.derive_client_secret(nonce, ctx, binding), expected.to_string())
365        }
366        "proof_generation" => {
367            let secret = vec.input.get("client_secret").and_then(|v| v.as_str()).unwrap_or("");
368            let ts = vec.input.get("timestamp").and_then(|v| v.as_str()).unwrap_or("");
369            let binding = vec.input.get("binding").and_then(|v| v.as_str()).unwrap_or("");
370            let body_hash = vec.input.get("body_hash").and_then(|v| v.as_str()).unwrap_or("");
371            let expected = vec.expected.get("proof")
372                .or_else(|| vec.expected.as_str().map(|_| &vec.expected))
373                .and_then(|v| v.as_str())
374                .unwrap_or("");
375            (adapter.build_proof(secret, ts, binding, body_hash), expected.to_string())
376        }
377        "timing_safe_comparison" => {
378            let a = vec.input.get("a").and_then(|v| v.as_str()).unwrap_or("");
379            let b = vec.input.get("b").and_then(|v| v.as_str()).unwrap_or("");
380            let expected = vec.expected.get("equal")
381                .and_then(|v| v.as_bool())
382                .map(|b| b.to_string())
383                .unwrap_or_default();
384            (adapter.timing_safe_equal(a, b), expected)
385        }
386        "error_behavior" => {
387            let expected_code = vec.expected.get("error_code")
388                .and_then(|v| v.as_str())
389                .unwrap_or("");
390            let expected_status = vec.expected.get("http_status")
391                .and_then(|v| v.as_u64())
392                .unwrap_or(0) as u16;
393            let result = adapter.trigger_error(&vec.input);
394            let expected_str = format!("{}:{}", expected_code, expected_status);
395            let actual_str = if result.ok {
396                "ok".to_string()
397            } else {
398                format!("{}:{}", result.error_code.as_deref().unwrap_or(""), result.error_status.unwrap_or(0))
399            };
400            return VectorResult {
401                id: vec.id.clone(),
402                category: vec.category.clone(),
403                passed: !result.ok
404                    && result.error_code.as_deref() == Some(expected_code)
405                    && result.error_status == Some(expected_status),
406                skipped: result.output.is_none() && result.ok && result.error_code.is_none(),
407                expected: expected_str,
408                actual: actual_str,
409                diff: None,
410            };
411        }
412        "timestamp_validation" => {
413            let ts = vec.input.get("timestamp").and_then(|v| v.as_str()).unwrap_or("");
414            let should_pass = vec.expected.get("valid").and_then(|v| v.as_bool()).unwrap_or(false);
415            let result = adapter.validate_timestamp(ts);
416            let actual_ok = result.ok;
417            return VectorResult {
418                id: vec.id.clone(),
419                category: vec.category.clone(),
420                passed: actual_ok == should_pass,
421                skipped: result.output.is_none() && result.ok && result.error_code.is_none(),
422                expected: format!("valid={}", should_pass),
423                actual: format!("valid={}", actual_ok),
424                diff: if actual_ok != should_pass {
425                    Some(format!("Expected valid={}, got valid={}", should_pass, actual_ok))
426                } else {
427                    None
428                },
429            };
430        }
431        _ => {
432            return VectorResult {
433                id: vec.id.clone(),
434                category: vec.category.clone(),
435                passed: false,
436                skipped: true,
437                expected: String::new(),
438                actual: String::new(),
439                diff: Some(format!("Unknown category: {}", category)),
440            };
441        }
442    };
443
444    // Standard comparison for most categories
445    if adapter_result.output.is_none() && adapter_result.ok && adapter_result.error_code.is_none() {
446        return VectorResult {
447            id: vec.id.clone(),
448            category: vec.category.clone(),
449            passed: false,
450            skipped: true,
451            expected: expected_str,
452            actual: String::new(),
453            diff: None,
454        };
455    }
456
457    let actual = adapter_result.output.unwrap_or_default();
458    let pass = actual == expected_str;
459
460    VectorResult {
461        id: vec.id.clone(),
462        category: vec.category.clone(),
463        passed: pass,
464        skipped: false,
465        expected: expected_str.clone(),
466        actual: actual.clone(),
467        diff: if pass {
468            None
469        } else {
470            Some(format!("expected: {}\n  actual: {}", expected_str, actual))
471        },
472    }
473}
474
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an in-memory vector without going through JSON loading.
    fn make_vector(
        id: &str,
        category: &str,
        description: &str,
        input: serde_json::Value,
        expected: serde_json::Value,
    ) -> Vector {
        Vector {
            id: id.to_string(),
            category: category.to_string(),
            description: description.to_string(),
            input,
            expected,
        }
    }

    // ── AdapterResult construction ────────────────────────────────────

    #[test]
    fn test_adapter_result_ok() {
        let result = AdapterResult::ok("hello");
        assert!(result.ok);
        assert_eq!(result.output, Some("hello".to_string()));
        assert!(result.error_code.is_none());
    }

    #[test]
    fn test_adapter_result_error() {
        let result = AdapterResult::error("ASH_VALIDATION_ERROR", 485);
        assert!(!result.ok);
        assert!(result.output.is_none());
        assert_eq!(result.error_code, Some("ASH_VALIDATION_ERROR".to_string()));
        assert_eq!(result.error_status, Some(485));
    }

    #[test]
    fn test_adapter_result_skip() {
        let result = AdapterResult::skip();
        assert!(result.ok);
        assert!(result.output.is_none());
    }

    #[test]
    fn test_adapter_result_ok_bool() {
        assert_eq!(AdapterResult::ok_bool(true).output, Some("true".to_string()));
    }

    // ── TestReport ────────────────────────────────────────────────────

    #[test]
    fn test_report_all_passed() {
        let report = TestReport { results: Vec::new(), total: 5, passed: 5, failed: 0, skipped: 0 };
        assert!(report.all_passed());
        assert_eq!(report.summary(), "5/5 passed, 0 failed, 0 skipped");
    }

    #[test]
    fn test_report_with_failures() {
        let failing = VectorResult {
            id: "test_001".to_string(),
            category: "json".to_string(),
            passed: false,
            skipped: false,
            expected: "a".to_string(),
            actual: "b".to_string(),
            diff: Some("expected: a\n  actual: b".to_string()),
        };
        let report = TestReport {
            results: vec![failing],
            total: 1,
            passed: 0,
            failed: 1,
            skipped: 0,
        };
        assert!(!report.all_passed());
        assert_eq!(report.failures().len(), 1);
    }

    // ── Default adapter returns skips ─────────────────────────────────

    // Adapter relying entirely on the trait's default (skipping) methods.
    struct EmptyAdapter;
    impl AshAdapter for EmptyAdapter {}

    #[test]
    fn test_empty_adapter_skips_all() {
        let vector = make_vector(
            "test",
            "json_canonicalization",
            "test",
            serde_json::json!({"input_json_text": "{}"}),
            serde_json::json!({"canonical_json": "{}"}),
        );
        let report = run_vectors(&[vector], &EmptyAdapter);
        assert_eq!(report.skipped, 1);
    }

    // ── Rust core adapter (proves testkit works) ──────────────────────

    // Turn a core `Result` into an `AdapterResult`, mapping an error to its
    // code and HTTP status.
    macro_rules! from_core {
        ($call:expr) => {
            match $call {
                Ok(s) => AdapterResult::ok(s),
                Err(e) => AdapterResult::error(e.code().as_str(), e.http_status()),
            }
        };
    }

    // Adapter backed by this crate's own implementation.
    struct RustCoreAdapter;
    impl AshAdapter for RustCoreAdapter {
        fn canonicalize_json(&self, input: &str) -> AdapterResult {
            from_core!(crate::ash_canonicalize_json(input))
        }
        fn canonicalize_query(&self, input: &str) -> AdapterResult {
            from_core!(crate::ash_canonicalize_query(input))
        }
        fn hash_body(&self, body: &str) -> AdapterResult {
            AdapterResult::ok(crate::ash_hash_body(body))
        }
        fn derive_client_secret(&self, nonce: &str, ctx: &str, binding: &str) -> AdapterResult {
            from_core!(crate::ash_derive_client_secret(nonce, ctx, binding))
        }
        fn build_proof(&self, secret: &str, ts: &str, binding: &str, body_hash: &str) -> AdapterResult {
            from_core!(crate::ash_build_proof(secret, ts, binding, body_hash))
        }
        fn timing_safe_equal(&self, a: &str, b: &str) -> AdapterResult {
            AdapterResult::ok_bool(crate::ash_timing_safe_equal(a.as_bytes(), b.as_bytes()))
        }
        fn normalize_binding(&self, method: &str, path: &str, query: &str) -> AdapterResult {
            from_core!(crate::ash_normalize_binding(method, path, query))
        }
    }

    #[test]
    fn test_rust_core_adapter_json() {
        let vector = make_vector(
            "json_inline",
            "json_canonicalization",
            "sort keys",
            serde_json::json!({"input_json_text": r#"{"z":1,"a":2}"#}),
            serde_json::json!({"canonical_json": r#"{"a":2,"z":1}"#}),
        );
        let report = run_vectors(&[vector], &RustCoreAdapter);
        assert!(report.all_passed(), "Failures: {:?}", report.failures());
    }

    #[test]
    fn test_rust_core_adapter_body_hash() {
        let digest = crate::ash_hash_body("test");
        let vector = make_vector(
            "hash_inline",
            "body_hashing",
            "hash test",
            serde_json::json!({"body": "test"}),
            serde_json::json!({"hash": digest}),
        );
        let report = run_vectors(&[vector], &RustCoreAdapter);
        assert!(report.all_passed());
    }

    #[test]
    fn test_rust_core_adapter_timing_safe() {
        let cases = [
            ("ts_eq", "equal", "hello", true),
            ("ts_neq", "not equal", "world", false),
        ];
        let vectors: Vec<Vector> = cases
            .iter()
            .map(|&(id, description, b, equal)| {
                make_vector(
                    id,
                    "timing_safe_comparison",
                    description,
                    serde_json::json!({"a": "hello", "b": b}),
                    serde_json::json!({"equal": equal}),
                )
            })
            .collect();
        let report = run_vectors(&vectors, &RustCoreAdapter);
        assert!(report.all_passed());
        assert_eq!(report.passed, 2);
    }
}