// kya_validator/validator.rs

1use crate::inspector::inspect_manifest;
2use crate::policy::{apply_policy, AllowedRegionRule, MaxTransactionValueRule, PolicyRule};
3use crate::types::{
4    AttestationCheckConfig, Manifest, PolicyContext, TEEType, ValidationConfig, ValidationOptions,
5    ValidationReport,
6};
7use crate::verifier::verify_manifest_proofs;
8use chrono::{DateTime, Utc};
9use serde_json::Value;
10
11#[cfg(not(target_arch = "wasm32"))]
12use crate::types::{ContentCheckType, DigestConfig, HashAlgorithm, LinkCheckConfig, TlsPinConfig};
13
14#[cfg(not(target_arch = "wasm32"))]
15use sha2::{Digest, Sha256, Sha384, Sha512};
16
17#[cfg(not(target_arch = "wasm32"))]
18use jsonschema::{Draft, JSONSchema};
19
20#[cfg(not(target_arch = "wasm32"))]
21use once_cell::sync::Lazy;
22
23#[cfg(not(target_arch = "wasm32"))]
24use regex::Regex;
25
26#[cfg(not(target_arch = "wasm32"))]
27use reqwest::blocking::Client;
28
29#[cfg(not(target_arch = "wasm32"))]
30use std::collections::HashMap;
31
32#[cfg(not(target_arch = "wasm32"))]
33use std::sync::Mutex;
34
35#[cfg(not(target_arch = "wasm32"))]
36use std::time::{Duration, Instant};
37
38#[cfg(not(target_arch = "wasm32"))]
39static KYA_SCHEMA: &str = include_str!("../schema/kya-manifest.schema.json");
40
41#[cfg(not(target_arch = "wasm32"))]
42static COMPILED_SCHEMA: Lazy<JSONSchema> = Lazy::new(|| {
43    let schema_json: Value =
44        serde_json::from_str(KYA_SCHEMA).expect("KYA schema JSON should parse at compile time");
45    JSONSchema::options()
46        .with_draft(Draft::Draft7)
47        .compile(&schema_json)
48        .expect("KYA schema should compile")
49});
50
51#[cfg(not(target_arch = "wasm32"))]
52static CACHE: Lazy<Mutex<HashMap<String, (Value, Instant)>>> =
53    Lazy::new(|| Mutex::new(HashMap::new()));
54
55fn parse_datetime(value: &Value, field: &str) -> Result<Option<DateTime<Utc>>, String> {
56    match value.get(field) {
57        Some(Value::String(raw)) => DateTime::parse_from_rfc3339(raw)
58            .map(|dt| Some(dt.with_timezone(&Utc)))
59            .map_err(|err| format!("Invalid {}: {}", field, err)),
60        Some(_) => Err(format!("{} must be an RFC3339 string", field)),
61        None => Ok(None),
62    }
63}
64
65fn validate_ttl(manifest: &Value, now: DateTime<Utc>) -> (bool, Vec<String>) {
66    let mut errors = Vec::new();
67
68    let issuance = parse_datetime(manifest, "issuanceDate");
69    let expiration = parse_datetime(manifest, "expirationDate");
70
71    if let Err(err) = issuance.as_ref() {
72        errors.push(err.to_string());
73    }
74    if let Err(err) = expiration.as_ref() {
75        errors.push(err.to_string());
76    }
77
78    let issuance = issuance.ok().flatten();
79    let expiration = expiration.ok().flatten();
80
81    if let Some(issuance) = issuance {
82        if issuance > now {
83            errors.push("issuanceDate is in the future".to_string());
84        }
85    }
86
87    if let Some(expiration) = expiration {
88        if expiration < now {
89            errors.push("expirationDate is in the past".to_string());
90        }
91    }
92
93    (errors.is_empty(), errors)
94}
95
96pub fn validate_manifest_value(manifest: &Value) -> ValidationReport {
97    validate_manifest_with_config(manifest, &ValidationConfig::default())
98}
99
100fn check_required_fields(manifest: &Value, required_fields: &[String]) -> Vec<String> {
101    let mut errors = Vec::new();
102    for pointer in required_fields {
103        if manifest.pointer(pointer).is_none() {
104            errors.push(format!("Missing required field {}", pointer));
105        }
106    }
107    errors
108}
109
110fn check_required_field_pairs(manifest: &Value, pairs: &[(String, String)]) -> Vec<String> {
111    let mut errors = Vec::new();
112    for (left, right) in pairs {
113        let left_value = manifest.pointer(left);
114        let right_value = manifest.pointer(right);
115        if left_value.is_some() && right_value.is_none() {
116            errors.push(format!("Field {} requires {}", left, right));
117        }
118        if right_value.is_some() && left_value.is_none() {
119            errors.push(format!("Field {} requires {}", right, left));
120        }
121    }
122    errors
123}
124
125#[cfg(not(target_arch = "wasm32"))]
126#[allow(dead_code)]
127fn verify_tls_pin(_cert_der: &[u8], _pin_config: &TlsPinConfig) -> Result<(), String> {
128    // TODO: Implement TLS pinning
129    // Note: reqwest doesn't provide direct access to peer certificates
130    // Would need to use native-tls and custom connector for full TLS pinning support
131    // For now, this is a placeholder
132    Ok(())
133}
134
135#[cfg(not(target_arch = "wasm32"))]
136fn validate_domain_allowlist(url: &str, allowed_domains: &[String]) -> Result<(), String> {
137    let parsed = url::Url::parse(url).map_err(|e| format!("Invalid URL: {}", e))?;
138    let host = parsed.host_str().ok_or("URL has no host")?;
139
140    if !allowed_domains.is_empty() && !allowed_domains.iter().any(|domain| host.ends_with(domain)) {
141        return Err(format!("Domain {} is not in allowlist", host));
142    }
143
144    Ok(())
145}
146
147#[cfg(not(target_arch = "wasm32"))]
148fn perform_content_check(
149    text: &str,
150    content_value: &Value,
151    check_config: &crate::types::ContentCheck,
152) -> Result<(), String> {
153    match check_config.check_type {
154        ContentCheckType::StringContains => {
155            if !text.contains(&check_config.expected_value) {
156                return Err(format!(
157                    "Content does not contain expected string: {}",
158                    check_config.expected_value
159                ));
160            }
161        }
162        ContentCheckType::StringEquals => {
163            if text != check_config.expected_value {
164                return Err(format!(
165                    "Content does not equal expected value: {}",
166                    check_config.expected_value
167                ));
168            }
169        }
170        ContentCheckType::StringMatchesRegex => {
171            let regex = Regex::new(&check_config.expected_value)
172                .map_err(|e| format!("Invalid regex: {}", e))?;
173            if !regex.is_match(text) {
174                return Err(format!(
175                    "Content does not match regex: {}",
176                    check_config.expected_value
177                ));
178            }
179        }
180        ContentCheckType::JsonPointerEquals | ContentCheckType::JsonPointerMatchesRegex => {
181            let pointer = check_config
182                .json_pointer
183                .as_ref()
184                .ok_or("json_pointer required for JSON pointer checks")?;
185            let target_value = content_value
186                .pointer(pointer)
187                .ok_or(format!("JSON pointer {} not found in content", pointer))?;
188
189            let target_str = target_value
190                .as_str()
191                .ok_or("Target value is not a string")?;
192
193            if check_config.check_type == ContentCheckType::JsonPointerEquals {
194                if target_str != check_config.expected_value {
195                    return Err(format!(
196                        "JSON pointer value does not match: {} != {}",
197                        target_str, check_config.expected_value
198                    ));
199                }
200            } else {
201                let regex = Regex::new(&check_config.expected_value)
202                    .map_err(|e| format!("Invalid regex: {}", e))?;
203                if !regex.is_match(target_str) {
204                    return Err(format!(
205                        "JSON pointer value does not match regex: {}",
206                        check_config.expected_value
207                    ));
208                }
209            }
210        }
211    }
212    Ok(())
213}
214
215#[cfg(not(target_arch = "wasm32"))]
216fn fetch_with_retry(
217    url: &str,
218    timeout_secs: Option<u64>,
219    max_retries: Option<u32>,
220) -> Result<reqwest::blocking::Response, String> {
221    let timeout = Duration::from_secs(timeout_secs.unwrap_or(30));
222    let retries = max_retries.unwrap_or(3);
223
224    let mut last_error: String = String::new();
225
226    for attempt in 0..retries {
227        let client = Client::builder()
228            .timeout(timeout)
229            .build()
230            .map_err(|e| format!("Failed to build HTTP client: {}", e))?;
231
232        match client.get(url).send() {
233            Ok(response) => return Ok(response),
234            Err(err) => {
235                last_error = err.to_string();
236                if attempt < retries - 1 {
237                    std::thread::sleep(Duration::from_millis(1000 * (attempt + 1) as u64));
238                }
239            }
240        }
241    }
242
243    Err(format!("Failed after {} attempts: {}", retries, last_error))
244}
245
246#[cfg(not(target_arch = "wasm32"))]
247fn compute_hash(data: &[u8], algorithm: HashAlgorithm) -> String {
248    match algorithm {
249        HashAlgorithm::Sha256 => {
250            let mut hasher = Sha256::new();
251            hasher.update(data);
252            let result = hasher.finalize();
253            hex::encode(result)
254        }
255        HashAlgorithm::Sha384 => {
256            let mut hasher = Sha384::new();
257            hasher.update(data);
258            let result = hasher.finalize();
259            hex::encode(result)
260        }
261        HashAlgorithm::Sha512 => {
262            let mut hasher = Sha512::new();
263            hasher.update(data);
264            let result = hasher.finalize();
265            hex::encode(result)
266        }
267    }
268}
269
270#[cfg(not(target_arch = "wasm32"))]
271fn verify_content_hash(data: &[u8], config: &DigestConfig) -> Result<(), String> {
272    let computed = compute_hash(data, config.algorithm);
273    let expected = config.expected_hash.to_lowercase();
274    let computed = computed.to_lowercase();
275
276    if computed != expected {
277        return Err(format!(
278            "Hash mismatch. Expected: {}, Got: {}",
279            expected, computed
280        ));
281    }
282
283    Ok(())
284}
285
286#[cfg(not(target_arch = "wasm32"))]
287fn check_external_links(manifest: &Value, link_checks: &[LinkCheckConfig]) -> Vec<String> {
288    let mut errors = Vec::new();
289
290    for check in link_checks {
291        let url_value = manifest.pointer(&check.json_pointer);
292        let url = match url_value.and_then(|value| value.as_str()) {
293            Some(url) => url,
294            None => {
295                errors.push(format!("Missing URL for {}", check.json_pointer));
296                continue;
297            }
298        };
299
300        if let Some(ref allowed_domains) = check.allowed_domains {
301            if let Err(e) = validate_domain_allowlist(url, allowed_domains) {
302                errors.push(e);
303                continue;
304            }
305        }
306
307        let cache_key = url.to_string();
308        let cached_result: Option<Value> = check.cache_ttl_secs.and_then(|ttl| {
309            let cache = CACHE.lock().ok()?;
310            if let Some((value, timestamp)) = cache.get(&cache_key) {
311                if timestamp.elapsed() < Duration::from_secs(ttl) {
312                    return Some(value.clone());
313                }
314            }
315            None
316        });
317
318        let response_bytes = if let Some(cached) = cached_result {
319            cached
320        } else {
321            let response = match fetch_with_retry(url, check.timeout_secs, check.max_retries) {
322                Ok(resp) => resp,
323                Err(err) => {
324                    errors.push(format!("Failed to fetch {}: {}", url, err));
325                    continue;
326                }
327            };
328
329            let bytes: Vec<u8> = match response.bytes() {
330                Ok(bytes) => bytes.to_vec(),
331                Err(err) => {
332                    errors.push(format!("Failed to read {}: {}", url, err));
333                    continue;
334                }
335            };
336
337            if let Some(ref digest_config) = check.verify_digest {
338                if let Err(e) = verify_content_hash(&bytes, digest_config) {
339                    errors.push(format!("Digest verification failed for {}: {}", url, e));
340                    continue;
341                }
342            }
343
344            if let Some(_ttl) = check.cache_ttl_secs {
345                if let Ok(mut cache) = CACHE.lock() {
346                    // Cache the response text for content checks
347                    let text = String::from_utf8_lossy(&bytes).to_string();
348                    let cache_entry: (Value, Instant) = (Value::String(text), Instant::now());
349                    cache.insert(cache_key, cache_entry);
350                }
351            }
352
353            Value::String(String::from_utf8_lossy(&bytes).to_string())
354        };
355
356        if let Some(expected) = check.required_contains.as_ref() {
357            let text = response_bytes.as_str().unwrap_or("");
358            if !text.contains(expected) {
359                errors.push(format!(
360                    "{} did not contain expected string: {}",
361                    url, expected
362                ));
363            }
364        }
365
366        if let Some(ref content_check) = check.content_check {
367            if let Err(e) = perform_content_check(
368                response_bytes.as_str().unwrap_or(""),
369                &response_bytes,
370                content_check,
371            ) {
372                errors.push(format!("Content check failed for {}: {}", url, e));
373            }
374        }
375    }
376
377    errors
378}
379
380fn verify_sgx_attestation(
381    _attestation_data: &[u8],
382    config: &AttestationCheckConfig,
383) -> Result<(), String> {
384    // Placeholder for SGX attestation verification
385    // In production, this would use mc-sgx-dcap-quoteverify or similar
386    if config.require_root_certificate {
387        // Verify root certificate chain
388        // This is a stub - actual implementation would verify against Intel root CAs
389    }
390
391    if let Some(ref _tcb_info) = config.expected_tcb_info {
392        // Verify TCB version and SVN
393        // This is a stub - actual implementation would parse quote and verify TCB
394    }
395
396    Ok(())
397}
398
399fn verify_nitro_attestation(
400    _attestation_data: &[u8],
401    _config: &AttestationCheckConfig,
402) -> Result<(), String> {
403    // Placeholder for AWS Nitro attestation verification
404    // Parse Nitro attestation document and verify against AWS root certificate
405    Ok(())
406}
407
408fn verify_sev_snp_attestation(
409    _attestation_data: &[u8],
410    _config: &AttestationCheckConfig,
411) -> Result<(), String> {
412    // Placeholder for AMD SEV-SNP attestation verification
413    // Verify VCEK report and certificate chain
414    Ok(())
415}
416
417fn check_attestations(
418    manifest: &Value,
419    attestation_checks: &[AttestationCheckConfig],
420) -> Vec<String> {
421    let mut errors = Vec::new();
422
423    for check in attestation_checks {
424        let attestation_value = manifest.pointer(&check.json_pointer);
425        let attestation_data: Result<Vec<u8>, String> =
426            match attestation_value.and_then(|value| value.as_str()) {
427                Some(data) => hex::decode(data)
428                    .map_err(|e| format!("Failed to decode attestation hex: {}", e)),
429                None => Err("Missing attestation data".to_string()),
430            };
431
432        let attestation_data: Vec<u8> = match attestation_data {
433            Ok(data) => data,
434            Err(e) => {
435                errors.push(format!("Attestation check {}: {}", check.json_pointer, e));
436                continue;
437            }
438        };
439
440        let result = match check.tee_type {
441            TEEType::SGX => verify_sgx_attestation(&attestation_data, check),
442            TEEType::Nitro => verify_nitro_attestation(&attestation_data, check),
443            TEEType::SevSnp => verify_sev_snp_attestation(&attestation_data, check),
444        };
445
446        if let Err(e) = result {
447            errors.push(format!(
448                "Attestation verification failed for {}: {:?}",
449                check.json_pointer, check.tee_type
450            ));
451            errors.push(e);
452        }
453    }
454
455    errors
456}
457
458fn check_allowed_controllers(manifest: &Value, allowed: &[String]) -> Vec<String> {
459    if allowed.is_empty() {
460        return Vec::new();
461    }
462    let controller = match manifest
463        .pointer("/agentId")
464        .and_then(|value| value.as_str())
465    {
466        Some(controller) => controller,
467        None => return vec!["Missing agentId for controller allowlist".to_string()],
468    };
469    if !allowed.iter().any(|item| item == controller) {
470        return vec![format!("Controller {} is not in allowlist", controller)];
471    }
472    Vec::new()
473}
474
475fn check_required_vc_types(manifest: &Value, required: &[String]) -> Vec<String> {
476    if required.is_empty() {
477        return Vec::new();
478    }
479    let vcs = match manifest.pointer("/verifiableCredential") {
480        Some(Value::Array(entries)) => entries,
481        _ => {
482            return vec![format!(
483                "Missing required VC types: {}",
484                required.join(", ")
485            )];
486        }
487    };
488
489    let mut missing = Vec::new();
490    for required_type in required {
491        let mut found = false;
492        for entry in vcs {
493            if let Some(Value::Array(types)) = entry.get("type") {
494                if types
495                    .iter()
496                    .any(|value| value.as_str() == Some(required_type))
497                {
498                    found = true;
499                    break;
500                }
501            }
502        }
503        if !found {
504            missing.push(required_type.clone());
505        }
506    }
507
508    if missing.is_empty() {
509        Vec::new()
510    } else {
511        vec![format!("Missing required VC types: {}", missing.join(", "))]
512    }
513}
514
515pub fn validate_manifest_with_config(
516    manifest: &Value,
517    config: &ValidationConfig,
518) -> ValidationReport {
519    let mut report = ValidationReport::ok();
520
521    // Schema validation only available on native targets (requires jsonschema with reqwest blocking)
522    #[cfg(not(target_arch = "wasm32"))]
523    {
524        let schema_result = COMPILED_SCHEMA.validate(manifest);
525        if let Err(errors) = schema_result {
526            report.schema_valid = false;
527            report.schema_errors = errors
528                .map(|err: jsonschema::ValidationError| err.to_string())
529                .collect();
530        }
531    }
532
533    // On WASM, skip schema validation (would require async fetch for external refs)
534    #[cfg(target_arch = "wasm32")]
535    {
536        report.schema_valid = true;
537        report.schema_errors = vec!["Schema validation skipped on WASM (use browser fetch for remote schemas)".to_string()];
538    }
539
540    let (ttl_valid, ttl_errors) = validate_ttl(manifest, Utc::now());
541    report.ttl_valid = ttl_valid;
542    report.ttl_errors = ttl_errors;
543
544    let required_field_errors = check_required_fields(manifest, &config.required_fields);
545    if !required_field_errors.is_empty() {
546        report.inspector_valid = false;
547        report.inspector_errors.extend(required_field_errors);
548    }
549
550    let required_pair_errors = check_required_field_pairs(manifest, &config.required_field_pairs);
551    if !required_pair_errors.is_empty() {
552        report.inspector_valid = false;
553        report.inspector_errors.extend(required_pair_errors);
554    }
555
556    let controller_errors = check_allowed_controllers(manifest, &config.allowed_controllers);
557    if !controller_errors.is_empty() {
558        report.inspector_valid = false;
559        report.inspector_errors.extend(controller_errors);
560    }
561
562    let vc_errors = check_required_vc_types(manifest, &config.required_vc_types);
563    if !vc_errors.is_empty() {
564        report.inspector_valid = false;
565        report.inspector_errors.extend(vc_errors);
566    }
567
568    // External link checking only available on native (requires blocking HTTP client)
569    #[cfg(not(target_arch = "wasm32"))]
570    if config.check_external_links {
571        let link_errors = check_external_links(manifest, &config.link_checks);
572        if !link_errors.is_empty() {
573            report.inspector_valid = false;
574            report.inspector_errors.extend(link_errors);
575        }
576    }
577
578    if !config.attestation_checks.is_empty() {
579        let attestation_errors = check_attestations(manifest, &config.attestation_checks);
580        if !attestation_errors.is_empty() {
581            report.inspector_valid = false;
582            report.inspector_errors.extend(attestation_errors);
583        }
584    }
585
586    if let Ok(parsed_manifest) = Manifest::from_value(manifest) {
587        if config.require_all_proofs && parsed_manifest.proof.is_empty() {
588            report.crypto_valid = false;
589            report.crypto_errors.push("No proofs provided".to_string());
590        }
591        let options = ValidationOptions {
592            allowed_kya_versions: config.allowed_kya_versions.clone(),
593            enforce_schema_url: false,
594        };
595        let (inspector_valid, inspector_errors) = inspect_manifest(&parsed_manifest, &options);
596        if !inspector_valid {
597            report.inspector_valid = false;
598            report.inspector_errors.extend(inspector_errors);
599        }
600
601        let (crypto_valid, crypto_errors, crypto_report) =
602            verify_manifest_proofs(&parsed_manifest, manifest, config);
603        report.crypto_valid = crypto_valid;
604        report.crypto_errors = crypto_errors;
605        report.crypto_report = Some(crypto_report);
606
607        let rules: Vec<Box<dyn PolicyRule>> = vec![
608            Box::new(AllowedRegionRule),
609            Box::new(MaxTransactionValueRule),
610        ];
611        let context = PolicyContext::default();
612        let (policy_valid, policy_errors) = apply_policy(&parsed_manifest, &context, &rules);
613        report.policy_valid = policy_valid;
614        report.policy_errors = policy_errors;
615    } else {
616        report.inspector_valid = false;
617        report
618            .inspector_errors
619            .push("Failed to parse manifest".to_string());
620        report.crypto_valid = false;
621        report
622            .crypto_errors
623            .push("Failed to parse manifest".to_string());
624        report.policy_valid = false;
625        report
626            .policy_errors
627            .push("Failed to parse manifest".to_string());
628    }
629    report
630}
631
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// An empty object lacks every required manifest field, so schema validation must reject it.
    #[test]
    fn schema_validation_fails_on_empty_object() {
        let report = validate_manifest_value(&json!({}));

        assert!(
            !report.schema_valid,
            "an empty object must not satisfy the schema"
        );
        assert!(!report.schema_errors.is_empty());
    }

    /// A far-future issuance and a long-past expiration must each produce their own TTL error.
    #[test]
    fn ttl_validation_detects_future_and_expired() {
        let manifest = json!({
            "issuanceDate": "2999-01-01T00:00:00Z",
            "expirationDate": "2000-01-01T00:00:00Z"
        });

        let report = validate_manifest_value(&manifest);

        assert!(!report.ttl_valid);
        assert_eq!(
            report.ttl_errors.len(),
            2,
            "expected one error for the future issuance and one for the past expiration"
        );
    }
}