Skip to main content

kya_validator/
validator.rs

1use crate::inspector::inspect_manifest;
2use crate::policy::{apply_policy, AllowedRegionRule, MaxTransactionValueRule, PolicyRule};
3use crate::types::{
4    AttestationCheckConfig, ContentCheckType, DigestConfig, HashAlgorithm, LinkCheckConfig,
5    Manifest, PolicyContext, TEEType, TlsPinConfig, ValidationConfig, ValidationOptions,
6    ValidationReport,
7};
8use crate::verifier::verify_manifest_proofs;
9use chrono::{DateTime, Utc};
10use serde_json::Value;
11use sha2::{Digest, Sha256, Sha384, Sha512};
12use std::time::Duration;
13
14# [cfg(not(target_arch = "wasm32"))]
15use jsonschema::{Draft, JSONSchema};
16
17#[cfg(not(target_arch = "wasm32"))]
18use once_cell::sync::Lazy;
19
20#[cfg(not(target_arch = "wasm32"))]
21use regex::Regex;
22
23#[cfg(not(target_arch = "wasm32"))]
24use reqwest::blocking::Client;
25
26#[cfg(not(target_arch = "wasm32"))]
27use std::collections::HashMap;
28
29#[cfg(not(target_arch = "wasm32"))]
30use std::sync::Mutex;
31
32#[cfg(not(target_arch = "wasm32"))]
33use std::time::Instant;
34
35static KYA_SCHEMA: &str = include_str!("../schema/kya-manifest.schema.json");
36
37#[cfg(not(target_arch = "wasm32"))]
38static COMPILED_SCHEMA: Lazy<JSONSchema> = Lazy::new(|| {
39    let schema_json: Value =
40        serde_json::from_str(KYA_SCHEMA).expect("KYA schema JSON should parse at compile time");
41    JSONSchema::options()
42        .with_draft(Draft::Draft7)
43        .compile(&schema_json)
44        .expect("KYA schema should compile")
45});
46
47#[cfg(not(target_arch = "wasm32"))]
48static CACHE: Lazy<Mutex<HashMap<String, (Value, Instant)>>> =
49    Lazy::new(|| Mutex::new(HashMap::new()));
50
51fn parse_datetime(value: &Value, field: &str) -> Result<Option<DateTime<Utc>>, String> {
52    match value.get(field) {
53        Some(Value::String(raw)) => DateTime::parse_from_rfc3339(raw)
54            .map(|dt| Some(dt.with_timezone(&Utc)))
55            .map_err(|err| format!("Invalid {}: {}", field, err)),
56        Some(_) => Err(format!("{} must be an RFC3339 string", field)),
57        None => Ok(None),
58    }
59}
60
61fn validate_ttl(manifest: &Value, now: DateTime<Utc>) -> (bool, Vec<String>) {
62    let mut errors = Vec::new();
63
64    let issuance = parse_datetime(manifest, "issuanceDate");
65    let expiration = parse_datetime(manifest, "expirationDate");
66
67    if let Err(err) = issuance.as_ref() {
68        errors.push(err.to_string());
69    }
70    if let Err(err) = expiration.as_ref() {
71        errors.push(err.to_string());
72    }
73
74    let issuance = issuance.ok().flatten();
75    let expiration = expiration.ok().flatten();
76
77    if let Some(issuance) = issuance {
78        if issuance > now {
79            errors.push("issuanceDate is in the future".to_string());
80        }
81    }
82
83    if let Some(expiration) = expiration {
84        if expiration < now {
85            errors.push("expirationDate is in the past".to_string());
86        }
87    }
88
89    (errors.is_empty(), errors)
90}
91
92pub fn validate_manifest_value(manifest: &Value) -> ValidationReport {
93    validate_manifest_with_config(manifest, &ValidationConfig::default())
94}
95
96fn check_required_fields(manifest: &Value, required_fields: &[String]) -> Vec<String> {
97    let mut errors = Vec::new();
98    for pointer in required_fields {
99        if manifest.pointer(pointer).is_none() {
100            errors.push(format!("Missing required field {}", pointer));
101        }
102    }
103    errors
104}
105
106fn check_required_field_pairs(manifest: &Value, pairs: &[(String, String)]) -> Vec<String> {
107    let mut errors = Vec::new();
108    for (left, right) in pairs {
109        let left_value = manifest.pointer(left);
110        let right_value = manifest.pointer(right);
111        if left_value.is_some() && right_value.is_none() {
112            errors.push(format!("Field {} requires {}", left, right));
113        }
114        if right_value.is_some() && left_value.is_none() {
115            errors.push(format!("Field {} requires {}", right, left));
116        }
117    }
118    errors
119}
120
121#[cfg(not(target_arch = "wasm32"))]
122#[allow(dead_code)]
123fn verify_tls_pin(_cert_der: &[u8], _pin_config: &TlsPinConfig) -> Result<(), String> {
124    // TODO: Implement TLS pinning
125    // Note: reqwest doesn't provide direct access to peer certificates
126    // Would need to use native-tls and custom connector for full TLS pinning support
127    // For now, this is a placeholder
128    Ok(())
129}
130
131fn validate_domain_allowlist(url: &str, allowed_domains: &[String]) -> Result<(), String> {
132    let parsed = url::Url::parse(url).map_err(|e| format!("Invalid URL: {}", e))?;
133    let host = parsed.host_str().ok_or("URL has no host")?;
134
135    if !allowed_domains.is_empty() && !allowed_domains.iter().any(|domain| host.ends_with(domain)) {
136        return Err(format!("Domain {} is not in allowlist", host));
137    }
138
139    Ok(())
140}
141
142#[cfg(not(target_arch = "wasm32"))]
143fn perform_content_check(
144    text: &str,
145    content_value: &Value,
146    check_config: &crate::types::ContentCheck,
147) -> Result<(), String> {
148    match check_config.check_type {
149        ContentCheckType::StringContains => {
150            if !text.contains(&check_config.expected_value) {
151                return Err(format!(
152                    "Content does not contain expected string: {}",
153                    check_config.expected_value
154                ));
155            }
156        }
157        ContentCheckType::StringEquals => {
158            if text != check_config.expected_value {
159                return Err(format!(
160                    "Content does not equal expected value: {}",
161                    check_config.expected_value
162                ));
163            }
164        }
165        ContentCheckType::StringMatchesRegex => {
166            let regex = Regex::new(&check_config.expected_value)
167                .map_err(|e| format!("Invalid regex: {}", e))?;
168            if !regex.is_match(text) {
169                return Err(format!(
170                    "Content does not match regex: {}",
171                    check_config.expected_value
172                ));
173            }
174        }
175        ContentCheckType::JsonPointerEquals | ContentCheckType::JsonPointerMatchesRegex => {
176            let pointer = check_config
177                .json_pointer
178                .as_ref()
179                .ok_or("json_pointer required for JSON pointer checks")?;
180            let target_value = content_value
181                .pointer(pointer)
182                .ok_or(format!("JSON pointer {} not found in content", pointer))?;
183
184            let target_str = target_value
185                .as_str()
186                .ok_or("Target value is not a string")?;
187
188            if check_config.check_type == ContentCheckType::JsonPointerEquals {
189                if target_str != check_config.expected_value {
190                    return Err(format!(
191                        "JSON pointer value does not match: {} != {}",
192                        target_str, check_config.expected_value
193                    ));
194                }
195            } else {
196                let regex = Regex::new(&check_config.expected_value)
197                    .map_err(|e| format!("Invalid regex: {}", e))?;
198                if !regex.is_match(target_str) {
199                    return Err(format!(
200                        "JSON pointer value does not match regex: {}",
201                        check_config.expected_value
202                    ));
203                }
204            }
205        }
206    }
207    Ok(())
208}
209
210#[cfg(not(target_arch = "wasm32"))]
211fn fetch_with_retry(
212    url: &str,
213    timeout_secs: Option<u64>,
214    max_retries: Option<u32>,
215) -> Result<reqwest::blocking::Response, String> {
216    let timeout = Duration::from_secs(timeout_secs.unwrap_or(30));
217    let retries = max_retries.unwrap_or(3);
218
219    let mut last_error: String = String::new();
220
221    for attempt in 0..retries {
222        let client = Client::builder()
223            .timeout(timeout)
224            .build()
225            .map_err(|e| format!("Failed to build HTTP client: {}", e))?;
226
227        match client.get(url).send() {
228            Ok(response) => return Ok(response),
229            Err(err) => {
230                last_error = err.to_string();
231                if attempt < retries - 1 {
232                    std::thread::sleep(Duration::from_millis(1000 * (attempt + 1) as u64));
233                }
234            }
235        }
236    }
237
238    Err(format!("Failed after {} attempts: {}", retries, last_error))
239}
240
241fn compute_hash(data: &[u8], algorithm: HashAlgorithm) -> String {
242    match algorithm {
243        HashAlgorithm::Sha256 => {
244            let mut hasher = Sha256::new();
245            hasher.update(data);
246            let result = hasher.finalize();
247            hex::encode(result)
248        }
249        HashAlgorithm::Sha384 => {
250            let mut hasher = Sha384::new();
251            hasher.update(data);
252            let result = hasher.finalize();
253            hex::encode(result)
254        }
255        HashAlgorithm::Sha512 => {
256            let mut hasher = Sha512::new();
257            hasher.update(data);
258            let result = hasher.finalize();
259            hex::encode(result)
260        }
261    }
262}
263
264fn verify_content_hash(data: &[u8], config: &DigestConfig) -> Result<(), String> {
265    let computed = compute_hash(data, config.algorithm);
266    let expected = config.expected_hash.to_lowercase();
267    let computed = computed.to_lowercase();
268
269    if computed != expected {
270        return Err(format!(
271            "Hash mismatch. Expected: {}, Got: {}",
272            expected, computed
273        ));
274    }
275
276    Ok(())
277}
278
279#[cfg(not(target_arch = "wasm32"))]
280fn check_external_links(manifest: &Value, link_checks: &[LinkCheckConfig]) -> Vec<String> {
281    let mut errors = Vec::new();
282
283    for check in link_checks {
284        let url_value = manifest.pointer(&check.json_pointer);
285        let url = match url_value.and_then(|value| value.as_str()) {
286            Some(url) => url,
287            None => {
288                errors.push(format!("Missing URL for {}", check.json_pointer));
289                continue;
290            }
291        };
292
293        if let Some(ref allowed_domains) = check.allowed_domains {
294            if let Err(e) = validate_domain_allowlist(url, allowed_domains) {
295                errors.push(e);
296                continue;
297            }
298        }
299
300        let cache_key = url.to_string();
301        let cached_result: Option<Value> = check.cache_ttl_secs.and_then(|ttl| {
302            let cache = CACHE.lock().ok()?;
303            if let Some((value, timestamp)) = cache.get(&cache_key) {
304                if timestamp.elapsed() < Duration::from_secs(ttl) {
305                    return Some(value.clone());
306                }
307            }
308            None
309        });
310
311        let response_bytes = if let Some(cached) = cached_result {
312            cached
313        } else {
314            let response = match fetch_with_retry(url, check.timeout_secs, check.max_retries) {
315                Ok(resp) => resp,
316                Err(err) => {
317                    errors.push(format!("Failed to fetch {}: {}", url, err));
318                    continue;
319                }
320            };
321
322            let bytes: Vec<u8> = match response.bytes() {
323                Ok(bytes) => bytes.to_vec(),
324                Err(err) => {
325                    errors.push(format!("Failed to read {}: {}", url, err));
326                    continue;
327                }
328            };
329
330            if let Some(ref digest_config) = check.verify_digest {
331                if let Err(e) = verify_content_hash(&bytes, digest_config) {
332                    errors.push(format!("Digest verification failed for {}: {}", url, e));
333                    continue;
334                }
335            }
336
337            if let Some(_ttl) = check.cache_ttl_secs {
338                if let Ok(mut cache) = CACHE.lock() {
339                    // Cache the response text for content checks
340                    let text = String::from_utf8_lossy(&bytes).to_string();
341                    let cache_entry: (Value, Instant) = (Value::String(text), Instant::now());
342                    cache.insert(cache_key, cache_entry);
343                }
344            }
345
346            Value::String(String::from_utf8_lossy(&bytes).to_string())
347        };
348
349        if let Some(expected) = check.required_contains.as_ref() {
350            let text = response_bytes.as_str().unwrap_or("");
351            if !text.contains(expected) {
352                errors.push(format!(
353                    "{} did not contain expected string: {}",
354                    url, expected
355                ));
356            }
357        }
358
359        if let Some(ref content_check) = check.content_check {
360            if let Err(e) = perform_content_check(
361                response_bytes.as_str().unwrap_or(""),
362                &response_bytes,
363                content_check,
364            ) {
365                errors.push(format!("Content check failed for {}: {}", url, e));
366            }
367        }
368    }
369
370    errors
371}
372
373fn verify_sgx_attestation(
374    _attestation_data: &[u8],
375    config: &AttestationCheckConfig,
376) -> Result<(), String> {
377    // Placeholder for SGX attestation verification
378    // In production, this would use mc-sgx-dcap-quoteverify or similar
379    if config.require_root_certificate {
380        // Verify root certificate chain
381        // This is a stub - actual implementation would verify against Intel root CAs
382    }
383
384    if let Some(ref _tcb_info) = config.expected_tcb_info {
385        // Verify TCB version and SVN
386        // This is a stub - actual implementation would parse quote and verify TCB
387    }
388
389    Ok(())
390}
391
392fn verify_nitro_attestation(
393    _attestation_data: &[u8],
394    _config: &AttestationCheckConfig,
395) -> Result<(), String> {
396    // Placeholder for AWS Nitro attestation verification
397    // Parse Nitro attestation document and verify against AWS root certificate
398    Ok(())
399}
400
401fn verify_sev_snp_attestation(
402    _attestation_data: &[u8],
403    _config: &AttestationCheckConfig,
404) -> Result<(), String> {
405    // Placeholder for AMD SEV-SNP attestation verification
406    // Verify VCEK report and certificate chain
407    Ok(())
408}
409
410fn check_attestations(
411    manifest: &Value,
412    attestation_checks: &[AttestationCheckConfig],
413) -> Vec<String> {
414    let mut errors = Vec::new();
415
416    for check in attestation_checks {
417        let attestation_value = manifest.pointer(&check.json_pointer);
418        let attestation_data: Result<Vec<u8>, String> =
419            match attestation_value.and_then(|value| value.as_str()) {
420                Some(data) => hex::decode(data)
421                    .map_err(|e| format!("Failed to decode attestation hex: {}", e)),
422                None => Err("Missing attestation data".to_string()),
423            };
424
425        let attestation_data: Vec<u8> = match attestation_data {
426            Ok(data) => data,
427            Err(e) => {
428                errors.push(format!("Attestation check {}: {}", check.json_pointer, e));
429                continue;
430            }
431        };
432
433        let result = match check.tee_type {
434            TEEType::SGX => verify_sgx_attestation(&attestation_data, check),
435            TEEType::Nitro => verify_nitro_attestation(&attestation_data, check),
436            TEEType::SevSnp => verify_sev_snp_attestation(&attestation_data, check),
437        };
438
439        if let Err(e) = result {
440            errors.push(format!(
441                "Attestation verification failed for {}: {:?}",
442                check.json_pointer, check.tee_type
443            ));
444            errors.push(e);
445        }
446    }
447
448    errors
449}
450
451fn check_allowed_controllers(manifest: &Value, allowed: &[String]) -> Vec<String> {
452    if allowed.is_empty() {
453        return Vec::new();
454    }
455    let controller = match manifest
456        .pointer("/agentId")
457        .and_then(|value| value.as_str())
458    {
459        Some(controller) => controller,
460        None => return vec!["Missing agentId for controller allowlist".to_string()],
461    };
462    if !allowed.iter().any(|item| item == controller) {
463        return vec![format!("Controller {} is not in allowlist", controller)];
464    }
465    Vec::new()
466}
467
468fn check_required_vc_types(manifest: &Value, required: &[String]) -> Vec<String> {
469    if required.is_empty() {
470        return Vec::new();
471    }
472    let vcs = match manifest.pointer("/verifiableCredential") {
473        Some(Value::Array(entries)) => entries,
474        _ => {
475            return vec![format!(
476                "Missing required VC types: {}",
477                required.join(", ")
478            )];
479        }
480    };
481
482    let mut missing = Vec::new();
483    for required_type in required {
484        let mut found = false;
485        for entry in vcs {
486            if let Some(Value::Array(types)) = entry.get("type") {
487                if types
488                    .iter()
489                    .any(|value| value.as_str() == Some(required_type))
490                {
491                    found = true;
492                    break;
493                }
494            }
495        }
496        if !found {
497            missing.push(required_type.clone());
498        }
499    }
500
501    if missing.is_empty() {
502        Vec::new()
503    } else {
504        vec![format!("Missing required VC types: {}", missing.join(", "))]
505    }
506}
507
508pub fn validate_manifest_with_config(
509    manifest: &Value,
510    config: &ValidationConfig,
511) -> ValidationReport {
512    let mut report = ValidationReport::ok();
513
514    // Schema validation only available on native targets (requires jsonschema with reqwest blocking)
515    #[cfg(not(target_arch = "wasm32"))]
516    {
517        let schema_result = COMPILED_SCHEMA.validate(manifest);
518        if let Err(errors) = schema_result {
519            report.schema_valid = false;
520            report.schema_errors = errors
521                .map(|err: jsonschema::ValidationError| err.to_string())
522                .collect();
523        }
524    }
525
526    // On WASM, skip schema validation (would require async fetch for external refs)
527    #[cfg(target_arch = "wasm32")]
528    {
529        report.schema_valid = true;
530        report.schema_errors = vec!["Schema validation skipped on WASM (use browser fetch for remote schemas)".to_string()];
531    }
532
533    let (ttl_valid, ttl_errors) = validate_ttl(manifest, Utc::now());
534    report.ttl_valid = ttl_valid;
535    report.ttl_errors = ttl_errors;
536
537    let required_field_errors = check_required_fields(manifest, &config.required_fields);
538    if !required_field_errors.is_empty() {
539        report.inspector_valid = false;
540        report.inspector_errors.extend(required_field_errors);
541    }
542
543    let required_pair_errors = check_required_field_pairs(manifest, &config.required_field_pairs);
544    if !required_pair_errors.is_empty() {
545        report.inspector_valid = false;
546        report.inspector_errors.extend(required_pair_errors);
547    }
548
549    let controller_errors = check_allowed_controllers(manifest, &config.allowed_controllers);
550    if !controller_errors.is_empty() {
551        report.inspector_valid = false;
552        report.inspector_errors.extend(controller_errors);
553    }
554
555    let vc_errors = check_required_vc_types(manifest, &config.required_vc_types);
556    if !vc_errors.is_empty() {
557        report.inspector_valid = false;
558        report.inspector_errors.extend(vc_errors);
559    }
560
561    // External link checking only available on native (requires blocking HTTP client)
562    #[cfg(not(target_arch = "wasm32"))]
563    if config.check_external_links {
564        let link_errors = check_external_links(manifest, &config.link_checks);
565        if !link_errors.is_empty() {
566            report.inspector_valid = false;
567            report.inspector_errors.extend(link_errors);
568        }
569    }
570
571    if !config.attestation_checks.is_empty() {
572        let attestation_errors = check_attestations(manifest, &config.attestation_checks);
573        if !attestation_errors.is_empty() {
574            report.inspector_valid = false;
575            report.inspector_errors.extend(attestation_errors);
576        }
577    }
578
579    if let Ok(parsed_manifest) = Manifest::from_value(manifest) {
580        if config.require_all_proofs && parsed_manifest.proof.is_empty() {
581            report.crypto_valid = false;
582            report.crypto_errors.push("No proofs provided".to_string());
583        }
584        let options = ValidationOptions {
585            allowed_kya_versions: config.allowed_kya_versions.clone(),
586            enforce_schema_url: false,
587        };
588        let (inspector_valid, inspector_errors) = inspect_manifest(&parsed_manifest, &options);
589        if !inspector_valid {
590            report.inspector_valid = false;
591            report.inspector_errors.extend(inspector_errors);
592        }
593
594        let (crypto_valid, crypto_errors, crypto_report) =
595            verify_manifest_proofs(&parsed_manifest, manifest, config);
596        report.crypto_valid = crypto_valid;
597        report.crypto_errors = crypto_errors;
598        report.crypto_report = Some(crypto_report);
599
600        let rules: Vec<Box<dyn PolicyRule>> = vec![
601            Box::new(AllowedRegionRule),
602            Box::new(MaxTransactionValueRule),
603        ];
604        let context = PolicyContext::default();
605        let (policy_valid, policy_errors) = apply_policy(&parsed_manifest, &context, &rules);
606        report.policy_valid = policy_valid;
607        report.policy_errors = policy_errors;
608    } else {
609        report.inspector_valid = false;
610        report
611            .inspector_errors
612            .push("Failed to parse manifest".to_string());
613        report.crypto_valid = false;
614        report
615            .crypto_errors
616            .push("Failed to parse manifest".to_string());
617        report.policy_valid = false;
618        report
619            .policy_errors
620            .push("Failed to parse manifest".to_string());
621    }
622    report
623}
624
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// An empty object is expected to fail schema validation with at least one error.
    #[test]
    fn schema_validation_fails_on_empty_object() {
        let report = validate_manifest_value(&json!({}));

        assert!(!report.schema_valid);
        assert!(!report.schema_errors.is_empty());
    }

    /// A future issuance date and a past expiration date each produce one TTL error.
    #[test]
    fn ttl_validation_detects_future_and_expired() {
        let manifest = json!({
            "issuanceDate": "2999-01-01T00:00:00Z",
            "expirationDate": "2000-01-01T00:00:00Z"
        });

        let report = validate_manifest_value(&manifest);

        assert!(!report.ttl_valid);
        assert_eq!(report.ttl_errors.len(), 2);
    }
}