use std::collections::HashMap;
use base64::Engine as _;
use base64::prelude::BASE64_STANDARD;
use p256::ecdsa::VerifyingKey;
use p256::pkcs8::DecodePublicKey as _;
use sha2::{Digest, Sha256};
use tracing::warn;
use super::error::ScittError;
/// Maximum accepted length, in characters, of the Base64 SPKI field of a key string.
///
/// `parse_c2sp_key` rejects longer input before Base64-decoding it, which bounds the
/// size of the decode buffer.
const MAX_SPKI_BASE64_LEN: usize = 1 << 10;
/// A trusted public key, as produced by `parse_c2sp_key`.
#[derive(Clone, Debug)]
pub struct TrustedKey {
    /// Key name: the first `+`-delimited field of the key string (never empty).
    pub name: String,
    /// Four-byte key identifier; equals the first four bytes of SHA-256 over the key's
    /// SPKI DER encoding.
    pub kid: [u8; 4],
    /// The P-256 ECDSA verifying (public) key.
    pub key: VerifyingKey,
}
/// An in-memory collection of [`TrustedKey`]s, looked up by their 4-byte key ID.
#[derive(Clone, Debug)]
pub struct ScittKeyStore {
    /// Key ID -> trusted key. A `HashMap`, so at most one key is held per ID.
    keys: HashMap<[u8; 4], TrustedKey>,
}
impl ScittKeyStore {
pub fn from_c2sp_keys(key_strings: &[String]) -> Result<Self, ScittError> {
let mut keys: HashMap<[u8; 4], TrustedKey> = HashMap::new();
for key_string in key_strings {
match parse_c2sp_key(key_string) {
Ok(trusted_key) => {
if let Some(existing) = keys.get(&trusted_key.kid)
&& existing.name != trusted_key.name
{
warn!(
kid = %hex::encode(trusted_key.kid),
existing = %existing.name,
new = %trusted_key.name,
"Key ID collision detected — overwriting existing key"
);
}
keys.insert(trusted_key.kid, trusted_key);
}
Err(err) => {
warn!(key = %key_string, error = %err, "Skipping invalid C2SP key");
}
}
}
if keys.is_empty() {
return Err(ScittError::InvalidKeyFormat(
"no valid keys could be parsed from input".to_string(),
));
}
Ok(Self { keys })
}
pub fn get(&self, kid: [u8; 4]) -> Result<&TrustedKey, ScittError> {
self.keys.get(&kid).ok_or(ScittError::UnknownKeyId(kid))
}
pub fn len(&self) -> usize {
self.keys.len()
}
pub fn is_empty(&self) -> bool {
self.keys.is_empty()
}
pub fn merge_from(&self, additional_key_strings: &[String]) -> Self {
let mut keys = self.keys.clone();
let before = keys.len();
for key_string in additional_key_strings {
match parse_c2sp_key(key_string) {
Ok(trusted_key) => {
use std::collections::hash_map::Entry;
match keys.entry(trusted_key.kid) {
Entry::Occupied(e) => {
if e.get().name != trusted_key.name {
warn!(
kid = %hex::encode(trusted_key.kid),
existing = %e.get().name,
new = %trusted_key.name,
"Key ID collision detected during merge — keeping existing key"
);
}
}
Entry::Vacant(e) => {
e.insert(trusted_key);
}
}
}
Err(err) => {
warn!(key = %key_string, error = %err, "Skipping invalid C2SP key during merge");
}
}
}
let added = keys.len() - before;
if added > 0 {
tracing::debug!(added, total = keys.len(), "Merged new root keys into store");
}
Self { keys }
}
}
/// Parses one C2SP-style key string, `<name>+<key_hash_hex>+<spki_base64>`, into a
/// [`TrustedKey`].
///
/// The Base64 field decodes to DER-encoded SubjectPublicKeyInfo bytes. If the decoded
/// bytes are longer than one byte and start with `0x02`, that byte is dropped before
/// hashing and parsing. The key hash must equal the first four bytes of SHA-256 over the
/// resulting SPKI bytes.
///
/// # Errors
///
/// - [`ScittError::InvalidKeyFormat`] if any of these hold:
///   - the string does not split into three `+`-delimited parts;
///   - the name is empty;
///   - the key hash is not valid hex or not exactly 4 bytes;
///   - the Base64 field is longer than [`MAX_SPKI_BASE64_LEN`] characters;
///   - the Base64 field is not valid Base64.
/// - [`ScittError::KeyHashMismatch`] if the key hash does not match the SPKI digest.
/// - [`ScittError::InvalidPublicKey`] if the SPKI bytes are not a valid P-256 public key.
pub fn parse_c2sp_key(key_string: &str) -> Result<TrustedKey, ScittError> {
    // At most three fields: any further `+` (part of the standard Base64 alphabet) stays
    // inside the last field.
    let fields: Vec<&str> = key_string.splitn(3, '+').collect();
    let &[name, key_hash_hex, spki_b64] = fields.as_slice() else {
        return Err(ScittError::InvalidKeyFormat(format!(
            "expected 3 '+'-delimited parts, got {}",
            fields.len()
        )));
    };

    if name.is_empty() {
        return Err(ScittError::InvalidKeyFormat(
            "name (part 0) is empty".to_string(),
        ));
    }

    let hash_bytes = hex::decode(key_hash_hex)
        .map_err(|e| ScittError::InvalidKeyFormat(format!("key_hash is not valid hex: {e}")))?;
    let kid = <[u8; 4]>::try_from(hash_bytes.as_slice()).map_err(|_| {
        ScittError::InvalidKeyFormat(format!(
            "key_hash must be 4 bytes (8 hex chars), got {} bytes",
            hash_bytes.len()
        ))
    })?;

    // Length guard runs before decoding so oversized input is rejected without allocating.
    if spki_b64.len() > MAX_SPKI_BASE64_LEN {
        return Err(ScittError::InvalidKeyFormat(format!(
            "SPKI base64 is {} chars, maximum is {MAX_SPKI_BASE64_LEN}",
            spki_b64.len()
        )));
    }
    let decoded = BASE64_STANDARD
        .decode(spki_b64)
        .map_err(|e| ScittError::InvalidKeyFormat(format!("SPKI-DER is not valid Base64: {e}")))?;

    // NOTE(review): a single leading 0x02 byte is stripped (presumably a key-type tag;
    // confirm against the key-format spec). A lone `[0x02]` is kept as-is.
    let spki_der: &[u8] = match decoded.as_slice() {
        [0x02, rest @ ..] if !rest.is_empty() => rest,
        whole => whole,
    };

    let digest = Sha256::digest(spki_der);
    if digest[..4] != kid[..] {
        return Err(ScittError::KeyHashMismatch);
    }

    let key = VerifyingKey::from_public_key_der(spki_der)
        .map_err(|e| ScittError::InvalidPublicKey(e.to_string()))?;
    Ok(TrustedKey {
        name: name.to_owned(),
        kid,
        key,
    })
}
// Tests may unwrap/expect freely: a panic is exactly the failure signal we want here.
#[allow(clippy::unwrap_used, clippy::expect_used)]
#[cfg(test)]
mod tests {
    use p256::ecdsa::SigningKey;
    use p256::pkcs8::EncodePublicKey as _;

    use super::*;

    /// Builds a deterministic P-256 key from `seed` and returns its C2SP key string
    /// (`name+kid_hex+base64(spki_der)`) together with the `TrustedKey` it should parse to.
    fn make_c2sp_key(seed: u8, name: &str) -> (String, TrustedKey) {
        let seed_bytes = [seed; 32];
        let signing_key = SigningKey::from_slice(&seed_bytes).unwrap();
        let verifying_key = signing_key.verifying_key();
        let spki_doc = verifying_key.to_public_key_der().unwrap();
        let spki_der = spki_doc.as_bytes();
        let digest = Sha256::digest(spki_der);
        let kid: [u8; 4] = [digest[0], digest[1], digest[2], digest[3]];
        let key_hash_hex = hex::encode(kid);
        let spki_b64 = BASE64_STANDARD.encode(spki_der);
        let key_string = format!("{name}+{key_hash_hex}+{spki_b64}");
        let trusted_key = TrustedKey {
            name: name.to_string(),
            kid,
            key: *verifying_key,
        };
        (key_string, trusted_key)
    }

    #[test]
    fn parse_valid_c2sp_key() {
        let (key_string, expected) = make_c2sp_key(1, "tl.example.com");
        let parsed = parse_c2sp_key(&key_string).unwrap();
        assert_eq!(parsed.name, expected.name);
        assert_eq!(parsed.kid, expected.kid);
        let parsed_der = parsed.key.to_public_key_der().unwrap();
        let expected_der = expected.key.to_public_key_der().unwrap();
        assert_eq!(parsed_der.as_bytes(), expected_der.as_bytes());
    }

    #[test]
    fn parse_strips_leading_0x02_prefix_byte() {
        let (_, expected) = make_c2sp_key(4, "tl.example.com");
        let spki_doc = expected.key.to_public_key_der().unwrap();
        let mut prefixed = vec![0x02_u8];
        prefixed.extend_from_slice(spki_doc.as_bytes());
        let key_string = format!(
            "tl.example.com+{}+{}",
            hex::encode(expected.kid),
            BASE64_STANDARD.encode(&prefixed)
        );
        let parsed = parse_c2sp_key(&key_string).unwrap();
        assert_eq!(parsed.kid, expected.kid);
        let parsed_der = parsed.key.to_public_key_der().unwrap();
        assert_eq!(parsed_der.as_bytes(), spki_doc.as_bytes());
    }

    #[test]
    fn error_zero_plus_delimiters() {
        let err = parse_c2sp_key("noplusdelimiters").unwrap_err();
        assert!(matches!(err, ScittError::InvalidKeyFormat(_)));
    }

    #[test]
    fn error_one_plus_delimiter() {
        let err = parse_c2sp_key("tl.example.com+a1b2c3d4").unwrap_err();
        assert!(matches!(err, ScittError::InvalidKeyFormat(_)));
    }

    #[test]
    fn error_no_extra_parts_via_splitn() {
        let err = parse_c2sp_key("a+b").unwrap_err();
        assert!(matches!(err, ScittError::InvalidKeyFormat(_)));
    }

    #[test]
    fn error_empty_name() {
        let err = parse_c2sp_key("+a1b2c3d4+YWJj").unwrap_err();
        assert!(matches!(err, ScittError::InvalidKeyFormat(_)));
    }

    #[test]
    fn error_key_hash_not_valid_hex() {
        let err = parse_c2sp_key("tl.example.com+ZZZZZZZZ+YWJj").unwrap_err();
        assert!(matches!(err, ScittError::InvalidKeyFormat(_)));
    }

    #[test]
    fn error_key_hash_too_short_3_bytes() {
        let err = parse_c2sp_key("tl.example.com+a1b2c3+YWJj").unwrap_err();
        assert!(matches!(err, ScittError::InvalidKeyFormat(_)));
    }

    #[test]
    fn error_key_hash_too_long_5_bytes() {
        let err = parse_c2sp_key("tl.example.com+a1b2c3d4e5+YWJj").unwrap_err();
        assert!(matches!(err, ScittError::InvalidKeyFormat(_)));
    }

    #[test]
    fn error_spki_not_valid_base64() {
        let err = parse_c2sp_key("tl.example.com+a1b2c3d4+!!!not_base64!!!").unwrap_err();
        assert!(matches!(err, ScittError::InvalidKeyFormat(_)));
    }

    #[test]
    fn error_spki_base64_too_long() {
        // A length that is a multiple of 4 keeps the field valid Base64, so only the length
        // guard can produce InvalidKeyFormat here.
        let long_b64 = "A".repeat((MAX_SPKI_BASE64_LEN / 4 + 1) * 4);
        let err = parse_c2sp_key(&format!("tl.example.com+a1b2c3d4+{long_b64}")).unwrap_err();
        assert!(matches!(err, ScittError::InvalidKeyFormat(_)));
    }

    #[test]
    fn error_key_hash_mismatch() {
        let (key_string, trusted) = make_c2sp_key(1, "tl.example.com");
        // Flip every bit of the first kid byte so the tampered hash is guaranteed to differ,
        // whatever the real kid is.
        let mut bad_kid = trusted.kid;
        bad_kid[0] ^= 0xff;
        let parts: Vec<&str> = key_string.splitn(3, '+').collect();
        let tampered = format!("{}+{}+{}", parts[0], hex::encode(bad_kid), parts[2]);
        let err = parse_c2sp_key(&tampered).unwrap_err();
        assert!(matches!(err, ScittError::KeyHashMismatch));
    }

    #[test]
    fn error_spki_valid_base64_but_not_p256() {
        let fake_der = vec![0u8; 32];
        let digest = Sha256::digest(&fake_der);
        let kid: [u8; 4] = [digest[0], digest[1], digest[2], digest[3]];
        let key_hash_hex = hex::encode(kid);
        let spki_b64 = BASE64_STANDARD.encode(&fake_der);
        let key_string = format!("tl.example.com+{key_hash_hex}+{spki_b64}");
        let err = parse_c2sp_key(&key_string).unwrap_err();
        assert!(matches!(err, ScittError::InvalidPublicKey(_)));
    }

    #[test]
    fn keystore_lookup_found_and_not_found() {
        let (key_string, expected) = make_c2sp_key(1, "tl.example.com");
        let store = ScittKeyStore::from_c2sp_keys(&[key_string]).unwrap();
        let found = store.get(expected.kid).unwrap();
        assert_eq!(found.name, "tl.example.com");
        assert_eq!(found.kid, expected.kid);
        let other_kid = [0xde, 0xad, 0xbe, 0xef];
        let err = store.get(other_kid).unwrap_err();
        assert!(matches!(err, ScittError::UnknownKeyId(k) if k == other_kid));
    }

    #[test]
    fn keystore_multiple_keys() {
        let (k1, trusted1) = make_c2sp_key(1, "tl.example.com");
        let (k2, trusted2) = make_c2sp_key(2, "tl2.example.com");
        let store = ScittKeyStore::from_c2sp_keys(&[k1, k2]).unwrap();
        assert_eq!(store.len(), 2);
        assert!(!store.is_empty());
        let found1 = store.get(trusted1.kid).unwrap();
        assert_eq!(found1.name, "tl.example.com");
        let found2 = store.get(trusted2.kid).unwrap();
        assert_eq!(found2.name, "tl2.example.com");
    }

    #[test]
    fn keystore_duplicate_kid_last_entry_wins() {
        let (first, trusted) = make_c2sp_key(1, "first.example.com");
        let (second, _) = make_c2sp_key(1, "second.example.com");
        let store = ScittKeyStore::from_c2sp_keys(&[first, second]).unwrap();
        assert_eq!(store.len(), 1);
        assert_eq!(store.get(trusted.kid).unwrap().name, "second.example.com");
    }

    #[test]
    fn keystore_all_invalid_returns_error() {
        let bad_keys = ["no+plus".to_string(), "also+bad".to_string()];
        let err = ScittKeyStore::from_c2sp_keys(&bad_keys).unwrap_err();
        assert!(matches!(err, ScittError::InvalidKeyFormat(_)));
    }

    #[test]
    fn keystore_mixed_valid_and_invalid() {
        let (valid_key, trusted) = make_c2sp_key(1, "tl.example.com");
        let keys = ["not+valid".to_string(), valid_key, "also+bad".to_string()];
        let store = ScittKeyStore::from_c2sp_keys(&keys).unwrap();
        assert_eq!(store.len(), 1);
        let found = store.get(trusted.kid).unwrap();
        assert_eq!(found.name, "tl.example.com");
    }

    #[test]
    fn keystore_empty_input_returns_error() {
        let err = ScittKeyStore::from_c2sp_keys(&[]).unwrap_err();
        assert!(matches!(err, ScittError::InvalidKeyFormat(_)));
    }

    #[test]
    fn keystore_len_and_is_empty() {
        let (k1, _) = make_c2sp_key(3, "tl.example.com");
        let store = ScittKeyStore::from_c2sp_keys(&[k1]).unwrap();
        assert_eq!(store.len(), 1);
        assert!(!store.is_empty());
    }

    #[test]
    fn merge_from_adds_new_key() {
        let (k1, trusted1) = make_c2sp_key(1, "tl.example.com");
        let (k2, trusted2) = make_c2sp_key(2, "tl2.example.com");
        let store = ScittKeyStore::from_c2sp_keys(&[k1]).unwrap();
        assert_eq!(store.len(), 1);
        let merged = store.merge_from(&[k2]);
        assert_eq!(merged.len(), 2);
        assert!(merged.get(trusted1.kid).is_ok());
        assert!(merged.get(trusted2.kid).is_ok());
        // merge_from builds a new store; the receiver must be untouched.
        assert_eq!(store.len(), 1);
        assert!(store.get(trusted2.kid).is_err());
    }

    #[test]
    fn merge_from_does_not_overwrite_existing_key() {
        let (k1, trusted1) = make_c2sp_key(1, "tl.example.com");
        let store = ScittKeyStore::from_c2sp_keys(&[k1.clone()]).unwrap();
        let original_der = store
            .get(trusted1.kid)
            .unwrap()
            .key
            .to_public_key_der()
            .unwrap();
        let merged = store.merge_from(&[k1]);
        assert_eq!(merged.len(), 1);
        let after_der = merged
            .get(trusted1.kid)
            .unwrap()
            .key
            .to_public_key_der()
            .unwrap();
        assert_eq!(original_der.as_bytes(), after_der.as_bytes());
    }

    #[test]
    fn merge_from_keeps_existing_entry_on_duplicate_kid() {
        let (first, trusted) = make_c2sp_key(1, "first.example.com");
        let (second, _) = make_c2sp_key(1, "second.example.com");
        let store = ScittKeyStore::from_c2sp_keys(&[first]).unwrap();
        let merged = store.merge_from(&[second]);
        assert_eq!(merged.len(), 1);
        assert_eq!(merged.get(trusted.kid).unwrap().name, "first.example.com");
    }

    #[test]
    fn merge_from_with_empty_input_returns_same_keys() {
        let (k1, trusted1) = make_c2sp_key(1, "tl.example.com");
        let (k2, trusted2) = make_c2sp_key(2, "tl2.example.com");
        let store = ScittKeyStore::from_c2sp_keys(&[k1, k2]).unwrap();
        assert_eq!(store.len(), 2);
        let merged = store.merge_from(&[]);
        assert_eq!(merged.len(), 2);
        assert!(merged.get(trusted1.kid).is_ok());
        assert!(merged.get(trusted2.kid).is_ok());
    }

    #[test]
    fn merge_from_skips_invalid_keys() {
        let (k1, trusted1) = make_c2sp_key(1, "tl.example.com");
        let (k2, trusted2) = make_c2sp_key(2, "tl2.example.com");
        let store = ScittKeyStore::from_c2sp_keys(&[k1]).unwrap();
        let merged = store.merge_from(&["not+valid".to_string(), k2, "also+bad".to_string()]);
        assert_eq!(merged.len(), 2);
        assert!(merged.get(trusted1.kid).is_ok());
        assert!(merged.get(trusted2.kid).is_ok());
    }
}