use crate::error::VerifierError;
use alloc::{string::String, vec::Vec};
use serde::{Deserialize, Serialize};
/// Raw public-key bytes in the order `(dilithium, falcon, sphincs)`, as
/// returned by [`PublicKeysResponse::decode_all`].
pub type DecodedKeyTriple = (Vec<u8>, Vec<u8>, Vec<u8>);

/// JSON document describing the public keys for a single key epoch.
///
/// Build one from text with [`PublicKeysResponse::from_json`]. Then turn the
/// base64 key material into bytes with [`PublicKeysResponse::decode_all`].
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub struct PublicKeysResponse {
    /// Epoch identifier string (the test fixture uses values shaped like
    /// `"h33-substrate-<hex>"`; format is not validated here).
    pub epoch: String,
    /// Presumably `true` when `epoch` is the active epoch. NOTE(review):
    /// confirm against the server contract; this module only stores it.
    pub is_current: bool,
    /// List of epoch identifiers, assumed to be the rotation history.
    /// Neither its ordering nor its relation to `epoch` is checked here.
    pub rotation_history: Vec<String>,
    /// The per-algorithm public keys for this epoch.
    pub keys: PublicKeyBundle,
}
/// One public-key entry for each of the three signature schemes carried in a
/// [`PublicKeysResponse`].
///
/// Each entry stays base64-encoded until it is decoded with
/// [`PublicKeyEntry::decode_bytes`].
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub struct PublicKeyBundle {
    /// Dilithium key entry (the test fixture labels it `"ML-DSA-65"`).
    pub dilithium: PublicKeyEntry,
    /// Falcon key entry (the test fixture labels it `"FALCON-512"`).
    pub falcon: PublicKeyEntry,
    /// SPHINCS+ key entry (the test fixture labels it `"SPHINCS+-SHA2-128f"`).
    pub sphincs: PublicKeyEntry,
}
/// A single public key as delivered over the wire, with its key material
/// still base64-encoded.
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub struct PublicKeyEntry {
    /// Algorithm label such as `"ML-DSA-65"`. Carried through as-is; this
    /// module does not check it.
    pub algorithm: String,
    /// Encoding label such as `"raw"`. Carried through as-is;
    /// [`PublicKeyEntry::decode_bytes`] does not consult it.
    pub format: String,
    /// Base64 (standard alphabet) encoding of the key bytes.
    pub key_b64: String,
}
impl PublicKeyEntry {
    /// Base64-decodes [`Self::key_b64`] into the raw key bytes.
    ///
    /// Decoding uses the `base64` crate's `general_purpose::STANDARD` engine
    /// (standard alphabet, padded). The `algorithm` and `format` fields are
    /// ignored.
    ///
    /// `field_name` is copied into the error unchanged. This lets callers that
    /// decode several entries report which one was malformed.
    ///
    /// # Errors
    ///
    /// Returns [`VerifierError::PublicKeysBase64`] when `key_b64` is not valid
    /// base64 for that engine. The error's `field` is `field_name`, and its
    /// `detail` is the decoder's error message.
    pub fn decode_bytes(&self, field_name: &'static str) -> Result<Vec<u8>, VerifierError> {
        use base64::{engine::general_purpose::STANDARD, Engine as _};

        match STANDARD.decode(self.key_b64.as_bytes()) {
            Ok(bytes) => Ok(bytes),
            Err(e) => Err(VerifierError::PublicKeysBase64 {
                field: field_name,
                detail: alloc::format!("{e}"),
            }),
        }
    }
}
impl PublicKeysResponse {
    /// Parses a public-keys response from its JSON text.
    ///
    /// # Errors
    ///
    /// Returns [`VerifierError::PublicKeysParse`] carrying the `serde_json`
    /// error message. This happens when `json` is not valid JSON or does not
    /// match the expected shape (for example, a required field is missing).
    pub fn from_json(json: &str) -> Result<Self, VerifierError> {
        match serde_json::from_str::<Self>(json) {
            Ok(parsed) => Ok(parsed),
            Err(e) => Err(VerifierError::PublicKeysParse(alloc::format!("{e}"))),
        }
    }

    /// Base64-decodes all three keys in the bundle.
    ///
    /// The result is ordered `(dilithium, falcon, sphincs)`. Keys are decoded
    /// in that same order, and decoding stops at the first failure.
    ///
    /// # Errors
    ///
    /// Returns [`VerifierError::PublicKeysBase64`] for the first key that fails
    /// to decode. Its `field` is `"dilithium"`, `"falcon"`, or `"sphincs"`.
    pub fn decode_all(&self) -> Result<DecodedKeyTriple, VerifierError> {
        let bundle = &self.keys;
        // Sequential `?` keeps the first-failure precedence explicit.
        let dilithium = bundle.dilithium.decode_bytes("dilithium")?;
        let falcon = bundle.falcon.decode_bytes("falcon")?;
        let sphincs = bundle.sphincs.decode_bytes("sphincs")?;
        Ok((dilithium, falcon, sphincs))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Well-formed response. The three `key_b64` values decode to
    /// `"hello world"`, `"foobar"` and `"baz"`.
    const FIXTURE_JSON: &str = r#"{
        "epoch": "h33-substrate-abcdef1234567890",
        "is_current": true,
        "rotation_history": ["h33-substrate-abcdef1234567890"],
        "keys": {
            "dilithium": {
                "algorithm": "ML-DSA-65",
                "format": "raw",
                "key_b64": "aGVsbG8gd29ybGQ="
            },
            "falcon": {
                "algorithm": "FALCON-512",
                "format": "raw",
                "key_b64": "Zm9vYmFy"
            },
            "sphincs": {
                "algorithm": "SPHINCS+-SHA2-128f",
                "format": "raw",
                "key_b64": "YmF6"
            }
        }
    }"#;

    /// Builds a response whose three `key_b64` strings are chosen by the
    /// caller, so a test can corrupt exactly one key at a time.
    fn response_with_keys(dilithium: &str, falcon: &str, sphincs: &str) -> PublicKeysResponse {
        let entry = |algorithm: &str, key_b64: &str| PublicKeyEntry {
            algorithm: String::from(algorithm),
            format: String::from("raw"),
            key_b64: String::from(key_b64),
        };
        PublicKeysResponse {
            epoch: String::from("e"),
            is_current: true,
            rotation_history: alloc::vec![String::from("e")],
            keys: PublicKeyBundle {
                dilithium: entry("ML-DSA-65", dilithium),
                falcon: entry("FALCON-512", falcon),
                sphincs: entry("SPHINCS+-SHA2-128f", sphincs),
            },
        }
    }

    #[test]
    fn parses_fixture_json() {
        let parsed = PublicKeysResponse::from_json(FIXTURE_JSON).unwrap();
        assert_eq!(parsed.epoch, "h33-substrate-abcdef1234567890");
        assert!(parsed.is_current);
        assert_eq!(parsed.rotation_history.len(), 1);
        assert_eq!(parsed.keys.dilithium.algorithm, "ML-DSA-65");
        assert_eq!(parsed.keys.falcon.algorithm, "FALCON-512");
        assert_eq!(parsed.keys.sphincs.algorithm, "SPHINCS+-SHA2-128f");
    }

    #[test]
    fn decodes_base64_bytes() {
        let parsed = PublicKeysResponse::from_json(FIXTURE_JSON).unwrap();
        let (dil, fal, sph) = parsed.decode_all().unwrap();
        assert_eq!(dil, b"hello world");
        assert_eq!(fal, b"foobar");
        assert_eq!(sph, b"baz");
    }

    #[test]
    fn rejects_malformed_json() {
        let bad = "{not valid json";
        assert!(matches!(
            PublicKeysResponse::from_json(bad),
            Err(VerifierError::PublicKeysParse(_))
        ));
    }

    // Syntactically valid JSON that does not match the schema must also be
    // reported as a parse error, not accepted with defaults.
    #[test]
    fn rejects_json_missing_required_fields() {
        assert!(matches!(
            PublicKeysResponse::from_json(r#"{"epoch": "e"}"#),
            Err(VerifierError::PublicKeysParse(_))
        ));
    }

    // Exercises the derived `Serialize` impl: re-encoding and re-parsing must
    // reproduce an identical value.
    #[test]
    fn round_trips_through_serde_json() {
        let parsed = PublicKeysResponse::from_json(FIXTURE_JSON).unwrap();
        let reencoded = serde_json::to_string(&parsed).unwrap();
        let reparsed = PublicKeysResponse::from_json(&reencoded).unwrap();
        assert_eq!(reparsed, parsed);
    }

    #[test]
    fn rejects_bad_base64_in_key() {
        let bad_json = r#"{
            "epoch": "e",
            "is_current": true,
            "rotation_history": ["e"],
            "keys": {
                "dilithium": { "algorithm": "ML-DSA-65", "format": "raw", "key_b64": "@@@@" },
                "falcon": { "algorithm": "FALCON-512", "format": "raw", "key_b64": "Zm9v" },
                "sphincs": { "algorithm": "SPHINCS+-SHA2-128f", "format": "raw", "key_b64": "YmF6" }
            }
        }"#;
        let parsed = PublicKeysResponse::from_json(bad_json).unwrap();
        let err = parsed.decode_all().unwrap_err();
        assert!(matches!(
            err,
            VerifierError::PublicKeysBase64 {
                field: "dilithium",
                ..
            }
        ));
    }

    // The field tag must name the key that actually failed, not a neighbour.
    #[test]
    fn tags_bad_falcon_key_with_its_field_name() {
        let err = response_with_keys("Zm9v", "@@@@", "YmF6")
            .decode_all()
            .unwrap_err();
        assert!(matches!(
            err,
            VerifierError::PublicKeysBase64 {
                field: "falcon",
                ..
            }
        ));
    }

    #[test]
    fn tags_bad_sphincs_key_with_its_field_name() {
        let err = response_with_keys("Zm9v", "YmF6", "@@@@")
            .decode_all()
            .unwrap_err();
        assert!(matches!(
            err,
            VerifierError::PublicKeysBase64 {
                field: "sphincs",
                ..
            }
        ));
    }

    // When several keys are bad, the first in (dilithium, falcon, sphincs)
    // order is the one reported.
    #[test]
    fn reports_first_bad_key_when_several_fail() {
        let err = response_with_keys("@@@@", "@@@@", "@@@@")
            .decode_all()
            .unwrap_err();
        assert!(matches!(
            err,
            VerifierError::PublicKeysBase64 {
                field: "dilithium",
                ..
            }
        ));
    }
}