use std::collections::HashMap;
use ed25519_dalek::{Signature, Verifier as _, VerifyingKey};
use crate::attestation::{Pin, PROTOCOL_VERSION};
use crate::hash::{hash_text, hash_vector, VecDtype, VectorRef};
/// Reasons a [`Pin`] can be rejected by [`Verifier`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum VerifyError {
    /// The pin header's protocol version is not the one this verifier accepts.
    UnsupportedVersion(u32),
    /// No public key is registered under the pin's key id.
    UnknownKey(String),
    /// The signature bytes are not 64 bytes long or do not verify against the header.
    SignatureInvalid,
    /// The supplied vector's hash differs from the hash recorded in the pin.
    VectorTampered,
    /// The supplied source text's hash differs from the hash recorded in the pin.
    SourceMismatch,
    /// The pin records a different model than the caller expected.
    ModelMismatch {
        /// Model name recorded in the pin header.
        pin_model: String,
        /// Model name the caller asked to be verified.
        expected: String,
    },
    /// The supplied vector's length differs from the dimension recorded in the pin.
    ShapeMismatch {
        /// Length of the vector handed to the verifier.
        supplied: usize,
        /// Dimension recorded in the pin header.
        expected: u32,
    },
    /// The pin's dtype string was rejected by the dtype parser.
    UnsupportedDtype(String),
}
impl std::fmt::Display for VerifyError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
VerifyError::UnsupportedVersion(v) => write!(f, "unsupported pin version: {v}"),
VerifyError::UnknownKey(k) => write!(f, "unknown signing key id: {k}"),
VerifyError::SignatureInvalid => write!(f, "ed25519 signature did not verify"),
VerifyError::VectorTampered => write!(
f,
"vector hash mismatch — embedding has been modified after pinning"
),
VerifyError::SourceMismatch => write!(f, "source hash mismatch"),
VerifyError::ModelMismatch {
pin_model,
expected,
} => {
write!(f, "pin model {pin_model:?} != expected {expected:?}")
}
VerifyError::ShapeMismatch { supplied, expected } => {
write!(
f,
"vector shape mismatch: supplied len {supplied}, pin dim {expected}"
)
}
VerifyError::UnsupportedDtype(s) => write!(f, "unsupported canonical dtype: {s}"),
}
}
}
impl std::error::Error for VerifyError {}
#[derive(Default)]
pub struct Verifier {
keys: HashMap<String, VerifyingKey>,
}
impl Verifier {
pub fn new() -> Self {
Self::default()
}
pub fn add_key(&mut self, kid: &str, public_key_bytes: [u8; 32]) {
if let Ok(vk) = VerifyingKey::from_bytes(&public_key_bytes) {
self.keys.insert(kid.to_owned(), vk);
}
}
pub fn key_count(&self) -> usize {
self.keys.len()
}
pub fn verify_signature(&self, pin: &Pin) -> Result<(), VerifyError> {
if pin.header.v != PROTOCOL_VERSION {
return Err(VerifyError::UnsupportedVersion(pin.header.v));
}
let key = self
.keys
.get(&pin.kid)
.ok_or_else(|| VerifyError::UnknownKey(pin.kid.clone()))?;
let sig_bytes: [u8; 64] = pin
.sig
.as_slice()
.try_into()
.map_err(|_| VerifyError::SignatureInvalid)?;
let signature = Signature::from_bytes(&sig_bytes);
key.verify(&pin.header.canonicalize(), &signature)
.map_err(|_| VerifyError::SignatureInvalid)
}
pub fn verify_full<'a, V>(
&self,
pin: &Pin,
source: Option<&str>,
vector: Option<V>,
expected_model: Option<&str>,
) -> Result<(), VerifyError>
where
V: Into<VectorRef<'a>>,
{
self.verify_signature(pin)?;
if let Some(vec) = vector {
let vec = vec.into();
if vec.len() as u32 != pin.header.vec_dim {
return Err(VerifyError::ShapeMismatch {
supplied: vec.len(),
expected: pin.header.vec_dim,
});
}
let dtype = VecDtype::parse(&pin.header.vec_dtype)
.map_err(|_| VerifyError::UnsupportedDtype(pin.header.vec_dtype.clone()))?;
if hash_vector(vec, dtype) != pin.header.vec_hash {
return Err(VerifyError::VectorTampered);
}
}
if let Some(s) = source {
if hash_text(s) != pin.header.source_hash {
return Err(VerifyError::SourceMismatch);
}
}
if let Some(em) = expected_model {
if pin.header.model != em {
return Err(VerifyError::ModelMismatch {
pin_model: pin.header.model.clone(),
expected: em.to_owned(),
});
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::signer::Signer;

    /// Returns a fresh signer for `kid`, a verifier that trusts exactly that signer, and a
    /// 16-element embedding whose entries step by 0.1 starting at 0.0.
    fn trusted_setup(kid: &str) -> (Signer, Verifier, Vec<f32>) {
        let signer = Signer::generate(kid.into());
        let mut verifier = Verifier::new();
        verifier.add_key(signer.key_id(), signer.public_key_bytes());
        let embedding = (0..16)
            .map(|step| (step as f32) * 0.1)
            .collect::<Vec<f32>>();
        (signer, verifier, embedding)
    }

    #[test]
    fn verify_full_passes_on_honest_inputs() {
        let (signer, verifier, embedding) = trusted_setup("k1");
        let pin = signer.pin("hello", "m", embedding.as_slice()).unwrap();
        let outcome =
            verifier.verify_full(&pin, Some("hello"), Some(embedding.as_slice()), None);
        outcome.expect("honest verify must succeed");
    }

    #[test]
    fn verify_signature_only_passes() {
        let (signer, verifier, embedding) = trusted_setup("k1");
        let pin = signer.pin("hello", "m", embedding.as_slice()).unwrap();
        assert_eq!(verifier.verify_signature(&pin), Ok(()));
    }

    #[test]
    fn vector_tamper_is_caught() {
        let (signer, verifier, embedding) = trusted_setup("k1");
        let pin = signer.pin("hello", "m", embedding.as_slice()).unwrap();
        // Nudge a single component by a tiny amount; the hash must still change.
        let mut tampered = embedding.to_vec();
        tampered[0] += 1e-5;
        let result = verifier.verify_full(&pin, None, Some(tampered.as_slice()), None);
        assert_eq!(result, Err(VerifyError::VectorTampered));
    }

    #[test]
    fn source_mismatch_is_caught() {
        let (signer, verifier, embedding) = trusted_setup("k1");
        let pin = signer.pin("hello", "m", embedding.as_slice()).unwrap();
        let result = verifier.verify_full(&pin, Some("HELLO"), None::<&[f32]>, None);
        assert_eq!(result, Err(VerifyError::SourceMismatch));
    }

    #[test]
    fn unknown_key_is_caught() {
        let rogue = Signer::generate("rogue".into());
        let embedding = [1.0_f32, 2.0, 3.0];
        let pin = rogue.pin("x", "m", &embedding[..]).unwrap();
        // The verifier trusts only an unrelated "prod" signer.
        let (_prod, verifier, _) = trusted_setup("prod");
        let result = verifier.verify_signature(&pin);
        assert!(matches!(result, Err(VerifyError::UnknownKey(_))));
    }

    #[test]
    fn shape_mismatch_is_caught() {
        let (signer, verifier, embedding) = trusted_setup("k1");
        let pin = signer.pin("x", "m", embedding.as_slice()).unwrap();
        let first_half = &embedding[..8];
        let result = verifier.verify_full(&pin, None, Some(first_half), None);
        assert!(matches!(result, Err(VerifyError::ShapeMismatch { .. })));
    }

    #[test]
    fn model_mismatch_is_caught() {
        let (signer, verifier, embedding) = trusted_setup("k1");
        let pin = signer.pin("x", "model-A", embedding.as_slice()).unwrap();
        let result = verifier.verify_full(&pin, None, None::<&[f32]>, Some("model-B"));
        assert!(matches!(result, Err(VerifyError::ModelMismatch { .. })));
    }

    #[test]
    fn key_rotation_works() {
        let generations = [
            Signer::generate("2026-04".into()),
            Signer::generate("2026-05".into()),
        ];
        let mut verifier = Verifier::new();
        for signer in &generations {
            verifier.add_key(signer.key_id(), signer.public_key_bytes());
        }
        // Pins from both the outgoing and the incoming key must verify.
        let embedding = [1.0_f32, 2.0];
        for signer in &generations {
            let pin = signer.pin("x", "m", &embedding[..]).unwrap();
            assert_eq!(verifier.verify_signature(&pin), Ok(()));
        }
    }
}