use base64::{engine::general_purpose::STANDARD, Engine};
use chrono::{DateTime, Utc};
use ed25519_dalek::{Signature, Verifier, VerifyingKey};
use serde_json::Value;
use crate::canonical::{canonical, sha256_hash};
use crate::did_key::Resolver;
use crate::envelope::{envelope_pae_bytes, envelope_payload_bytes};
use crate::types::{validate_envelope, validate_receipt};
/// Optional plaintext to check against the hashes recorded in a receipt.
///
/// `verify` inspects this only when `VerifyOptions::plaintext` is `Some`.
/// Each `Some` field is hashed with `sha256_hash` and must equal the hash
/// stored in the receipt; `None` fields are not checked.
#[derive(Default, Debug, Clone)]
pub struct Plaintext {
/// Plaintext `args`; its hash must equal `receipt.call.args_hash`.
pub args: Option<Value>,
/// Plaintext `response`; its hash must equal `receipt.result.response_hash`.
pub response: Option<Value>,
}
/// Settings for [`verify`].
///
/// Build with [`VerifyOptions::new`] to get the defaults, then adjust fields.
pub struct VerifyOptions<'a> {
/// Looks up a signer's public key from its DID and key id as of the receipt
/// timestamp; the key must be 32 raw Ed25519 bytes for signatures to verify.
pub resolver: &'a dyn Resolver,
/// Reference time for the timestamp window; `None` means `Utc::now()`.
pub now: Option<DateTime<Utc>>,
/// Largest allowed `|now - receipt.ts|`, in milliseconds (one day via `new`).
pub max_clock_skew_ms: i64,
/// Skip the timestamp-window check; `receipt.ts` must still parse as RFC 3339.
pub skip_timestamp_check: bool,
/// Optional plaintext whose hashes are checked against the receipt.
pub plaintext: Option<Plaintext>,
}
impl<'a> VerifyOptions<'a> {
pub fn new(resolver: &'a dyn Resolver) -> Self {
Self {
resolver,
now: None,
max_clock_skew_ms: 24 * 3600 * 1000,
skip_timestamp_check: false,
plaintext: None,
}
}
}
/// Outcome of [`verify`].
///
/// As built by `verify`:
/// - On success, `ok` is `true`, `receipt` holds the decoded receipt JSON, and
///   `error` is `None`.
/// - On failure, `ok` is `false`, `receipt` is `None`, and `error` describes
///   the first check that failed.
///
/// Marked `#[must_use]` because silently dropping a verification result is
/// almost certainly a bug.
#[derive(Debug)]
#[must_use = "a VerifyResult must be inspected; ignoring it discards the verification outcome"]
pub struct VerifyResult {
    /// Whether every check passed.
    pub ok: bool,
    /// The decoded receipt; `Some` only on success.
    pub receipt: Option<Value>,
    /// Description of the first failed check; `Some` only on failure.
    pub error: Option<String>,
}
/// Build a failed [`VerifyResult`]: `ok = false`, no receipt, and `error` as
/// the message.
fn fail(error: impl Into<String>) -> VerifyResult {
    let message: String = error.into();
    VerifyResult {
        error: Some(message),
        receipt: None,
        ok: false,
    }
}
pub async fn verify(envelope: &Value, opts: &VerifyOptions<'_>) -> VerifyResult {
if let Err(e) = validate_envelope(envelope) {
return fail(format!("envelope schema: {e}"));
}
let sigs = envelope["signatures"].as_array().unwrap();
if sigs.len() != 2 {
return fail(format!(
"verify requires exactly 2 signatures, got {}",
sigs.len()
));
}
let kid0 = sigs[0]["keyid"].as_str().unwrap_or("");
let kid1 = sigs[1]["keyid"].as_str().unwrap_or("");
if kid0 == kid1 {
return fail("verify rejects duplicate keyid across signatures");
}
let payload_bytes = match envelope_payload_bytes(envelope) {
Ok(b) => b,
Err(e) => return fail(format!("payload: {e}")),
};
let receipt: Value = match serde_json::from_slice(&payload_bytes) {
Ok(v) => v,
Err(e) => return fail(format!("payload is not valid JSON: {e}")),
};
if let Err(e) = validate_receipt(&receipt) {
return fail(format!("receipt schema: {e}"));
}
let canonical_bytes = match canonical(&receipt) {
Ok(b) => b,
Err(e) => return fail(format!("canonicalize: {e}")),
};
if payload_bytes != canonical_bytes {
return fail("envelope payload is not JCS-canonical");
}
let agent_kid = receipt["agent"]["key_id"].as_str().unwrap_or("");
let tool_kid = receipt["tool"]["key_id"].as_str().unwrap_or("");
if kid0 != agent_kid {
return fail(format!(
"keyid mismatch: signatures[0].keyid ({kid0}) != receipt.agent.key_id ({agent_kid})"
));
}
if kid1 != tool_kid {
return fail(format!(
"keyid mismatch: signatures[1].keyid ({kid1}) != receipt.tool.key_id ({tool_kid})"
));
}
let ts_str = receipt["ts"].as_str().unwrap_or("");
let rts = match DateTime::parse_from_rfc3339(ts_str) {
Ok(t) => t.with_timezone(&Utc),
Err(_) => return fail(format!("invalid receipt ts: {ts_str}")),
};
if !opts.skip_timestamp_check {
let now = opts.now.unwrap_or_else(Utc::now);
let delta_ms = (now - rts).num_milliseconds().abs();
if delta_ms > opts.max_clock_skew_ms {
return fail(format!(
"timestamp window exceeded: |now - ts| = {delta_ms}ms, max = {}ms",
opts.max_clock_skew_ms
));
}
}
let pae = match envelope_pae_bytes(envelope) {
Ok(b) => b,
Err(e) => return fail(format!("pae: {e}")),
};
let agent_did = receipt["agent"]["did"].as_str().unwrap_or("");
let agent_pk = match opts.resolver.resolve(agent_did, agent_kid, rts).await {
Some(pk) => pk,
None => return fail(format!("agent DID did not resolve: {agent_did}")),
};
if !verify_sig(&agent_pk, &pae, sigs[0]["sig"].as_str().unwrap_or("")) {
return fail("agent signature invalid");
}
let tool_did = receipt["tool"]["did"].as_str().unwrap_or("");
let tool_pk = match opts.resolver.resolve(tool_did, tool_kid, rts).await {
Some(pk) => pk,
None => return fail(format!("tool DID did not resolve: {tool_did}")),
};
if !verify_sig(&tool_pk, &pae, sigs[1]["sig"].as_str().unwrap_or("")) {
return fail("tool signature invalid");
}
if let Some(pt) = &opts.plaintext {
if let Some(args) = &pt.args {
let recomputed = match sha256_hash(args) {
Ok(h) => h,
Err(e) => return fail(format!("args hash: {e}")),
};
let expected = receipt["call"]["args_hash"].as_str().unwrap_or("");
if recomputed != expected {
return fail(format!(
"plaintext args_hash mismatch: expected {expected}, got {recomputed}"
));
}
}
if let Some(response) = &pt.response {
let recomputed = match sha256_hash(response) {
Ok(h) => h,
Err(e) => return fail(format!("response hash: {e}")),
};
let expected = receipt["result"]["response_hash"].as_str().unwrap_or("");
if recomputed != expected {
return fail(format!(
"plaintext response_hash mismatch: expected {expected}, got {recomputed}"
));
}
}
}
VerifyResult {
ok: true,
receipt: Some(receipt),
error: None,
}
}
/// Check an Ed25519 signature over `message`.
///
/// `pk_bytes` must be a 32-byte raw public key. `sig_b64` must be the
/// standard (padded) base64 encoding of a 64-byte signature. Returns `false`
/// on any failure and never panics. The failures are: wrong key length,
/// invalid key, undecodable base64, wrong signature length, or a signature
/// that does not verify.
///
/// NOTE(review): this uses `Verifier::verify` rather than
/// `VerifyingKey::verify_strict`. Confirm that accepting non-strict
/// signatures is intended for interop.
fn verify_sig(pk_bytes: &[u8], message: &[u8], sig_b64: &str) -> bool {
    let key_bytes = match <[u8; 32] as std::convert::TryFrom<&[u8]>>::try_from(pk_bytes) {
        Ok(arr) => arr,
        Err(_) => return false,
    };
    let verifying_key = match VerifyingKey::from_bytes(&key_bytes) {
        Ok(key) => key,
        Err(_) => return false,
    };

    let decoded = match STANDARD.decode(sig_b64) {
        Ok(raw) => raw,
        Err(_) => return false,
    };
    let sig_bytes = match <[u8; 64] as std::convert::TryFrom<&[u8]>>::try_from(decoded.as_slice()) {
        Ok(arr) => arr,
        Err(_) => return false,
    };

    let signature = Signature::from_bytes(&sig_bytes);
    verifying_key.verify(message, &signature).is_ok()
}