use std::path::PathBuf;
use base64::Engine;
use serde::Deserialize;
use vectorpin::{
hash::{hash_text, hash_vector, VecDtype, VectorRef},
signer::PinOptions,
Pin, Signer, Verifier, VerifyError,
};
/// Top-level shape of `testvectors/v1.json`: one signing key shared by every
/// positive fixture, followed by the fixtures themselves.
#[derive(Deserialize, Debug)]
struct FixtureBundle {
    /// Expected public key, URL-safe base64 without padding (decoded via `b64`).
    public_key_b64: String,
    /// Private seed the `Signer` is rebuilt from, same base64 encoding.
    private_seed_b64: String,
    /// Key identifier handed to the `Signer` and registered with the `Verifier`.
    key_id: String,
    /// Positive fixtures; each one is exercised by `run_fixture`.
    fixtures: Vec<Fixture>,
}
/// A single positive fixture: named inputs plus the outputs the Rust
/// implementation must reproduce exactly.
#[derive(Deserialize, Debug)]
struct Fixture {
    /// Human-readable fixture name, used in assertion messages.
    name: String,
    /// Inputs fed to the hashing and pinning APIs.
    input: FixtureInput,
    /// Expected hashes, canonical header bytes and pin JSON.
    expected: FixtureExpected,
}
/// Inputs for one positive fixture.
#[derive(Deserialize, Debug)]
struct FixtureInput {
    /// Source text; hashed with `hash_text` and pinned alongside the vector.
    source: String,
    /// Model identifier passed through to `pin_with_options`.
    model: String,
    /// Raw little-endian vector bytes, URL-safe base64 without padding.
    vector_b64: String,
    /// Element type name, parsed with `VecDtype::parse`.
    vec_dtype: String,
    /// Number of vector elements (not bytes).
    vec_dim: usize,
    /// Fixed timestamp passed verbatim in `PinOptions` so output is reproducible.
    timestamp: String,
}
/// Outputs a positive fixture expects the Rust implementation to reproduce
/// byte-for-byte.
#[derive(Deserialize, Debug)]
struct FixtureExpected {
    /// Exact pin JSON; also parsed and verified as a foreign-produced pin.
    pin_json: String,
    /// Canonical header bytes, URL-safe base64 without padding.
    canonical_header_b64: String,
    /// Expected `hash_vector` output for the decoded vector.
    vec_hash: String,
    /// Expected `hash_text` output for the source text.
    source_hash: String,
}
fn b64(s: &str) -> Vec<u8> {
base64::engine::general_purpose::URL_SAFE_NO_PAD
.decode(s.as_bytes())
.expect("base64 fixture input")
}
fn fixtures_path() -> PathBuf {
PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.join("..")
.join("..")
.join("testvectors")
.join("v1.json")
}
fn negative_path() -> PathBuf {
PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.join("..")
.join("..")
.join("testvectors")
.join("negative_v1.json")
}
/// Decodes `dim` little-endian `f32` values from `bytes`.
///
/// # Panics
/// Panics if `dim * 4` overflows `usize`, or if `bytes` is not exactly
/// `dim * 4` bytes long.
fn parse_vec_f32(bytes: &[u8], dim: usize) -> Vec<f32> {
    const WIDTH: usize = std::mem::size_of::<f32>();
    // Checked so a huge fixture dimension fails loudly instead of wrapping in release builds.
    let expected_len = dim
        .checked_mul(WIDTH)
        .expect("f32 fixture dimension overflows usize");
    assert_eq!(bytes.len(), expected_len, "f32 fixture length sanity check");
    bytes
        .chunks_exact(WIDTH)
        .map(|chunk| {
            let arr: [u8; WIDTH] = chunk
                .try_into()
                .expect("chunks_exact yields WIDTH-byte chunks");
            f32::from_le_bytes(arr)
        })
        .collect()
}
/// Decodes `dim` little-endian `f64` values from `bytes`.
///
/// # Panics
/// Panics if `dim * 8` overflows `usize`, or if `bytes` is not exactly
/// `dim * 8` bytes long.
fn parse_vec_f64(bytes: &[u8], dim: usize) -> Vec<f64> {
    const WIDTH: usize = std::mem::size_of::<f64>();
    // Checked so a huge fixture dimension fails loudly instead of wrapping in release builds.
    let expected_len = dim
        .checked_mul(WIDTH)
        .expect("f64 fixture dimension overflows usize");
    assert_eq!(bytes.len(), expected_len, "f64 fixture length sanity check");
    bytes
        .chunks_exact(WIDTH)
        .map(|chunk| {
            let arr: [u8; WIDTH] = chunk
                .try_into()
                .expect("chunks_exact yields WIDTH-byte chunks");
            f64::from_le_bytes(arr)
        })
        .collect()
}
/// Runs one positive cross-language fixture end to end.
///
/// Checks, in order:
/// 1. `hash_text(source)` equals the expected source hash;
/// 2. `hash_vector` over the decoded vector equals the expected vector hash;
/// 3. the signer rebuilt from the bundle's private seed exposes the bundle's public key;
/// 4. pinning with the fixture's dtype and fixed timestamp reproduces the expected
///    canonical header bytes and pin JSON exactly;
/// 5. both the Rust-produced pin and the expected pin JSON parse and verify
///    against the source text.
///
/// # Panics
/// Panics (failing the calling test) on any mismatch or library error.
fn run_fixture(bundle: &FixtureBundle, fx: &Fixture) {
    /// The fixture vector, decoded once into its declared element type.
    enum DecodedVector {
        F32(Vec<f32>),
        F64(Vec<f64>),
    }

    eprintln!("running fixture: {}", fx.name);
    let dtype = VecDtype::parse(&fx.input.vec_dtype).expect("known dtype");
    let raw_bytes = b64(&fx.input.vector_b64);

    let computed_source_hash = hash_text(&fx.input.source);
    assert_eq!(
        computed_source_hash, fx.expected.source_hash,
        "source_hash mismatch for {}",
        fx.name
    );

    // Decode once and reuse the result for both hashing and pinning.
    let vector = match dtype {
        VecDtype::F32 => DecodedVector::F32(parse_vec_f32(&raw_bytes, fx.input.vec_dim)),
        VecDtype::F64 => DecodedVector::F64(parse_vec_f64(&raw_bytes, fx.input.vec_dim)),
    };

    let computed_vec_hash = match &vector {
        DecodedVector::F32(v) => hash_vector(VectorRef::F32(v), dtype),
        DecodedVector::F64(v) => hash_vector(VectorRef::F64(v), dtype),
    };
    assert_eq!(
        computed_vec_hash, fx.expected.vec_hash,
        "vec_hash mismatch for {}",
        fx.name
    );

    let signer = Signer::from_private_bytes(&b64(&bundle.private_seed_b64), bundle.key_id.clone())
        .expect("seed loads");
    assert_eq!(
        signer.public_key_bytes().to_vec(),
        b64(&bundle.public_key_b64),
        "signer public key derived from private_seed_b64 does not match public_key_b64"
    );

    // The fixed timestamp makes the produced pin deterministic, so it can be
    // compared byte-for-byte with the expected output.
    let options = PinOptions {
        dtype: Some(dtype),
        timestamp: Some(fx.input.timestamp.clone()),
        ..PinOptions::default()
    };
    let pin = match &vector {
        DecodedVector::F32(v) => {
            signer.pin_with_options(&fx.input.source, &fx.input.model, v.as_slice(), options)
        }
        DecodedVector::F64(v) => {
            signer.pin_with_options(&fx.input.source, &fx.input.model, v.as_slice(), options)
        }
    }
    .unwrap_or_else(|e| panic!("pinning failed for {}: {:?}", fx.name, e));

    let canonical = pin.header.canonicalize();
    let expected_canonical = b64(&fx.expected.canonical_header_b64);
    assert_eq!(
        canonical, expected_canonical,
        "canonical header bytes mismatch for {}",
        fx.name
    );

    let produced_json = pin.to_json();
    assert_eq!(
        produced_json, fx.expected.pin_json,
        "pin JSON mismatch for {}",
        fx.name
    );

    let parsed = Pin::from_json(&produced_json).expect("rust parses its own JSON");
    let mut verifier = Verifier::new();
    verifier.add_key(&bundle.key_id, signer.public_key_bytes());
    verifier
        .verify_full::<&[f32]>(&parsed, Some(&fx.input.source), None, None)
        .expect("rust verifies own pin");

    // The expected JSON comes from the other implementation; it must verify here too.
    let python_pin = Pin::from_json(&fx.expected.pin_json).expect("rust parses python JSON");
    verifier
        .verify_full::<&[f32]>(&python_pin, Some(&fx.input.source), None, None)
        .expect("rust verifies python-produced pin");
}
/// Loads `testvectors/v1.json` and runs every positive fixture it contains.
/// An empty fixture list is treated as a failure, so a broken export cannot pass silently.
#[test]
fn cross_language_positive_fixtures() {
    let bundle: FixtureBundle = {
        let text = std::fs::read_to_string(fixtures_path()).expect("read v1.json");
        serde_json::from_str(&text).expect("parse v1.json")
    };
    assert!(!bundle.fixtures.is_empty(), "no fixtures to test");
    bundle
        .fixtures
        .iter()
        .for_each(|fixture| run_fixture(&bundle, fixture));
}
/// Shape of `testvectors/negative_v1.json`: a pin paired with a vector that
/// must not verify against it.
#[derive(Deserialize, Debug)]
struct NegativeFixture {
    /// Pin JSON to verify.
    pin_json: String,
    /// Tampered vector as little-endian `f32` bytes, URL-safe base64 without padding.
    tampered_vector_b64: String,
    /// Expected error tag; the test requires it to be `"vector_tampered"`.
    expected_error: String,
}
/// Negative cross-language check: the pin in `negative_v1.json` must be
/// rejected with `VerifyError::VectorTampered` when verified against the
/// fixture's tampered vector.
///
/// The verifying key comes from the positive bundle (`v1.json`). This assumes
/// the negative pin was signed with that same key; if it was not, the final
/// assertion reports whichever other error the verifier returns.
#[test]
fn cross_language_negative_tampered_vector() {
    let raw = std::fs::read_to_string(negative_path()).expect("read negative_v1.json");
    let neg: NegativeFixture = serde_json::from_str(&raw).expect("parse negative_v1.json");
    assert_eq!(neg.expected_error, "vector_tampered");

    let pin = Pin::from_json(&neg.pin_json).expect("parse pin");
    // Checked conversion instead of `as`, so an out-of-range dimension cannot silently truncate.
    let dim = usize::try_from(pin.header.vec_dim).expect("pin vec_dim fits in usize");
    let tampered = parse_vec_f32(&b64(&neg.tampered_vector_b64), dim);

    let raw_pos = std::fs::read_to_string(fixtures_path()).expect("read v1.json");
    let bundle: FixtureBundle = serde_json::from_str(&raw_pos).expect("parse v1.json");

    let mut verifier = Verifier::new();
    verifier.add_key(
        &bundle.key_id,
        b64(&bundle.public_key_b64)
            .try_into()
            .expect("bundle public key has the length the verifier expects"),
    );

    let err = verifier
        .verify_full::<&[f32]>(&pin, None, Some(tampered.as_slice()), None)
        .expect_err("tampered vector must fail");
    assert_eq!(
        err,
        VerifyError::VectorTampered,
        "tampered vector was rejected with the wrong error"
    );
}