use once_cell::sync::Lazy;
use regex::Regex;
use serde_json::Value;
use crate::error::Error;
/// Version tag that every receipt must carry verbatim in its `v` field.
pub const PROTOCOL_VERSION: &str = "tp/0.1";

/// Media type that an envelope's `payloadType` field must equal exactly.
pub const PAYLOAD_TYPE: &str = "application/vnd.agent-toolprint+json";
/// `sha256:` followed by exactly 64 lowercase hex digits.
static HASH_HEX: Lazy<Regex> =
    Lazy::new(|| Regex::new(r"^sha256:[0-9a-f]{64}$").expect("HASH_HEX pattern is valid"));

/// Standard-alphabet base64 *shape*: characters from `A-Z a-z 0-9 + /`
/// followed by at most two `=` padding characters.
///
/// NOTE(review): the length is unconstrained, so the empty string and lengths
/// that are not a multiple of 4 (e.g. `"abc"`) are accepted. This is a shape
/// check only and does not guarantee that the text decodes.
static BASE64: Lazy<Regex> =
    Lazy::new(|| Regex::new(r"^[A-Za-z0-9+/]*={0,2}$").expect("BASE64 pattern is valid"));

/// Exactly 43 base64 characters plus a single `=` pad, which is the padded
/// base64 length of a 32-byte value.
static NONCE32: Lazy<Regex> =
    Lazy::new(|| Regex::new(r"^[A-Za-z0-9+/]{43}=$").expect("NONCE32 pattern is valid"));

/// `did:key:` followed by `z` and one or more base58 alphabet characters
/// (digits 1-9 and letters other than `I`, `O`, `l`).
static DID_KEY: Lazy<Regex> = Lazy::new(|| {
    Regex::new(r"^did:key:z[1-9A-HJ-NP-Za-km-z]+$").expect("DID_KEY pattern is valid")
});

/// Lowercase, hyphenated 8-4-4-4-12 hex UUID. Version and variant nibbles are
/// not checked.
static UUID_RE: Lazy<Regex> = Lazy::new(|| {
    Regex::new(r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$")
        .expect("UUID_RE pattern is valid")
});

/// RFC 3339 date-time *shape*: `YYYY-MM-DDTHH:MM:SS[.frac]` followed by `Z` or a
/// `+HH:MM` / `-HH:MM` offset. Field ranges (month 13, hour 25, ...) are not
/// checked, and only uppercase `T` / `Z` are accepted.
///
/// ASCII `[0-9]` is used deliberately instead of `\d`. In the `regex` crate,
/// `\d` is Unicode-aware by default (`\p{Nd}`) and would also accept non-ASCII
/// decimal digits such as fullwidth `２`.
static RFC3339: Lazy<Regex> = Lazy::new(|| {
    Regex::new(
        r"^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+)?(Z|[+-][0-9]{2}:[0-9]{2})$",
    )
    .expect("RFC3339 pattern is valid")
});
/// Every top-level key a receipt may contain. `parent` is the only entry that
/// is absent from [`RECEIPT_REQUIRED`], i.e. the only optional key.
const RECEIPT_ALLOWED: &[&str] = &[
    "v",
    "id",
    "ts",
    "agent",
    "tool",
    "call",
    "result",
    "nonce",
    "parent",
];

/// Top-level receipt keys that must be present. They are checked in this
/// order, so the first missing one is the one reported.
const RECEIPT_REQUIRED: &[&str] = &[
    "v",
    "id",
    "ts",
    "agent",
    "tool",
    "call",
    "result",
    "nonce",
];

/// Keys permitted inside a party object (`agent` / `tool`).
const PARTY_ALLOWED: &[&str] = &["did", "key_id"];

/// Keys permitted inside `receipt.call`.
const CALL_ALLOWED: &[&str] = &["name", "args_hash"];

/// Keys permitted inside `receipt.result`.
const RESULT_ALLOWED: &[&str] = &["status", "response_hash"];

/// Keys permitted at the envelope's top level. `validate_envelope` also treats
/// every one of them as mandatory.
const ENVELOPE_ALLOWED: &[&str] = &["payloadType", "payload", "signatures"];

/// Keys permitted inside each `envelope.signatures[i]` entry.
const SIG_ALLOWED: &[&str] = &["keyid", "sig"];
fn ensure_object<'a>(
value: &'a Value,
where_: &str,
) -> Result<&'a serde_json::Map<String, Value>, Error> {
value
.as_object()
.ok_or_else(|| Error::Invalid(format!("{where_}: expected object")))
}
fn ensure_string<'a>(value: &'a Value, where_: &str) -> Result<&'a str, Error> {
value
.as_str()
.ok_or_else(|| Error::Invalid(format!("{where_}: expected string")))
}
fn check_pattern(s: &str, re: &Regex, where_: &str) -> Result<(), Error> {
if re.is_match(s) {
Ok(())
} else {
Err(Error::Invalid(format!(
"{where_}: does not match required pattern"
)))
}
}
fn check_keys(
obj: &serde_json::Map<String, Value>,
allowed: &[&str],
where_: &str,
) -> Result<(), Error> {
for k in obj.keys() {
if !allowed.contains(&k.as_str()) {
return Err(Error::Invalid(format!("{where_}: unknown key {k:?}")));
}
}
Ok(())
}
fn validate_party(value: &Value, where_: &str) -> Result<(), Error> {
let obj = ensure_object(value, where_)?;
check_keys(obj, PARTY_ALLOWED, where_)?;
let did = obj
.get("did")
.ok_or_else(|| Error::Invalid(format!("{where_}: missing 'did'")))?;
let kid = obj
.get("key_id")
.ok_or_else(|| Error::Invalid(format!("{where_}: missing 'key_id'")))?;
let did_s = ensure_string(did, &format!("{where_}.did"))?;
check_pattern(did_s, &DID_KEY, &format!("{where_}.did"))?;
let kid_s = ensure_string(kid, &format!("{where_}.key_id"))?;
if kid_s.is_empty() {
return Err(Error::Invalid(format!(
"{where_}.key_id: must be non-empty"
)));
}
Ok(())
}
pub fn validate_receipt(value: &Value) -> Result<(), Error> {
let obj = ensure_object(value, "receipt")?;
check_keys(obj, RECEIPT_ALLOWED, "receipt")?;
for req in RECEIPT_REQUIRED {
if !obj.contains_key(*req) {
return Err(Error::Invalid(format!(
"receipt: missing required key {req:?}"
)));
}
}
if obj["v"].as_str() != Some(PROTOCOL_VERSION) {
return Err(Error::Invalid(format!(
"receipt.v: expected {PROTOCOL_VERSION:?}, got {:?}",
obj["v"]
)));
}
let id = ensure_string(&obj["id"], "receipt.id")?;
check_pattern(id, &UUID_RE, "receipt.id")?;
let ts = ensure_string(&obj["ts"], "receipt.ts")?;
check_pattern(ts, &RFC3339, "receipt.ts")?;
validate_party(&obj["agent"], "receipt.agent")?;
validate_party(&obj["tool"], "receipt.tool")?;
let call = ensure_object(&obj["call"], "receipt.call")?;
check_keys(call, CALL_ALLOWED, "receipt.call")?;
let name = ensure_string(
call.get("name")
.ok_or_else(|| Error::Invalid("receipt.call: missing 'name'".into()))?,
"receipt.call.name",
)?;
if name.is_empty() {
return Err(Error::Invalid(
"receipt.call.name: must be non-empty".into(),
));
}
let args_hash = ensure_string(
call.get("args_hash")
.ok_or_else(|| Error::Invalid("receipt.call: missing 'args_hash'".into()))?,
"receipt.call.args_hash",
)?;
check_pattern(args_hash, &HASH_HEX, "receipt.call.args_hash")?;
let result = ensure_object(&obj["result"], "receipt.result")?;
check_keys(result, RESULT_ALLOWED, "receipt.result")?;
let status = ensure_string(
result
.get("status")
.ok_or_else(|| Error::Invalid("receipt.result: missing 'status'".into()))?,
"receipt.result.status",
)?;
if status != "ok" && status != "error" {
return Err(Error::Invalid(format!(
"receipt.result.status: must be 'ok' or 'error', got {status:?}"
)));
}
let resp_hash = ensure_string(
result
.get("response_hash")
.ok_or_else(|| Error::Invalid("receipt.result: missing 'response_hash'".into()))?,
"receipt.result.response_hash",
)?;
check_pattern(resp_hash, &HASH_HEX, "receipt.result.response_hash")?;
let nonce = ensure_string(&obj["nonce"], "receipt.nonce")?;
check_pattern(nonce, &NONCE32, "receipt.nonce")?;
if let Some(parent) = obj.get("parent") {
let p = ensure_string(parent, "receipt.parent")?;
check_pattern(p, &UUID_RE, "receipt.parent")?;
}
Ok(())
}
pub fn validate_envelope(value: &Value) -> Result<(), Error> {
let obj = ensure_object(value, "envelope")?;
check_keys(obj, ENVELOPE_ALLOWED, "envelope")?;
for req in ENVELOPE_ALLOWED {
if !obj.contains_key(*req) {
return Err(Error::Invalid(format!("envelope: missing key {req:?}")));
}
}
if obj["payloadType"].as_str() != Some(PAYLOAD_TYPE) {
return Err(Error::Invalid(format!(
"envelope.payloadType: expected {PAYLOAD_TYPE:?}, got {:?}",
obj["payloadType"]
)));
}
let payload = ensure_string(&obj["payload"], "envelope.payload")?;
check_pattern(payload, &BASE64, "envelope.payload")?;
let sigs = obj["signatures"]
.as_array()
.ok_or_else(|| Error::Invalid("envelope.signatures: expected array".into()))?;
if sigs.is_empty() || sigs.len() > 2 {
return Err(Error::Invalid(format!(
"envelope.signatures: must have 1 or 2 entries, got {}",
sigs.len()
)));
}
for (i, s) in sigs.iter().enumerate() {
let where_ = format!("envelope.signatures[{i}]");
let so = ensure_object(s, &where_)?;
check_keys(so, SIG_ALLOWED, &where_)?;
let kid = ensure_string(
so.get("keyid")
.ok_or_else(|| Error::Invalid(format!("{where_}: missing 'keyid'")))?,
&format!("{where_}.keyid"),
)?;
if kid.is_empty() {
return Err(Error::Invalid(format!("{where_}.keyid: must be non-empty")));
}
let sig = ensure_string(
so.get("sig")
.ok_or_else(|| Error::Invalid(format!("{where_}: missing 'sig'")))?,
&format!("{where_}.sig"),
)?;
check_pattern(sig, &BASE64, &format!("{where_}.sig"))?;
}
Ok(())
}