use std::collections::BTreeMap;
use coz::digest::Digest;
use coz::sha2::{Sha256, Sha384, Sha512};
use coz::{Cad, Czd, Thumbprint};
use crate::multihash::MultihashDigest;
/// Newtype over a [`crate::multihash::MultihashDigest`]; this is the value
/// returned by [`compute_kr`].
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct KeyRoot(pub crate::multihash::MultihashDigest);

impl KeyRoot {
    /// Borrows the wrapped multihash digest.
    #[must_use]
    pub fn as_multihash(&self) -> &crate::multihash::MultihashDigest {
        let Self(inner) = self;
        inner
    }

    /// Looks up the digest bytes stored for `alg` by delegating to the
    /// wrapped multihash's `get`.
    #[must_use]
    pub fn get(&self, alg: HashAlg) -> Option<&[u8]> {
        let Self(inner) = self;
        inner.get(alg)
    }
}
/// Newtype over a [`crate::multihash::MultihashDigest`]; this is the value
/// returned by [`compute_commit_id`] and [`compute_commit_id_tagged`].
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct CommitID(pub crate::multihash::MultihashDigest);

impl CommitID {
    /// Borrows the wrapped multihash digest.
    #[must_use]
    pub fn as_multihash(&self) -> &crate::multihash::MultihashDigest {
        let Self(inner) = self;
        inner
    }

    /// Looks up the digest bytes stored for `alg` by delegating to the
    /// wrapped multihash's `get`.
    #[must_use]
    pub fn get(&self, alg: HashAlg) -> Option<&[u8]> {
        let Self(inner) = self;
        inner.get(alg)
    }
}
/// Newtype over a [`MultihashDigest`]; this is the value returned by
/// [`compute_ar`].
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct AuthRoot(pub MultihashDigest);

impl AuthRoot {
    /// Borrows the wrapped multihash digest.
    #[must_use]
    pub fn as_multihash(&self) -> &MultihashDigest {
        let Self(inner) = self;
        inner
    }

    /// Looks up the digest bytes stored for `alg` by delegating to the
    /// wrapped multihash's `get`.
    #[must_use]
    pub fn get(&self, alg: HashAlg) -> Option<&[u8]> {
        let Self(inner) = self;
        inner.get(alg)
    }
}
/// Newtype over a [`MultihashDigest`]; this is the value returned by
/// [`compute_sr`].
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct StateRoot(pub MultihashDigest);

impl StateRoot {
    /// Borrows the wrapped multihash digest.
    #[must_use]
    pub fn as_multihash(&self) -> &MultihashDigest {
        let Self(inner) = self;
        inner
    }

    /// Looks up the digest bytes stored for `alg` by delegating to the
    /// wrapped multihash's `get`.
    #[must_use]
    pub fn get(&self, alg: HashAlg) -> Option<&[u8]> {
        let Self(inner) = self;
        inner.get(alg)
    }
}
/// Newtype over a single [`Cad`]; this is the value returned by
/// [`compute_dr`].
///
/// Unlike the other root types in this file it wraps one digest rather than
/// a multihash.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DataRoot(pub Cad);

impl DataRoot {
    /// Borrows the wrapped [`Cad`].
    // `#[must_use]` matches the getters on the sibling root newtypes.
    #[must_use]
    pub fn as_cad(&self) -> &Cad {
        &self.0
    }
}
/// Newtype over a [`MultihashDigest`]; this is the value returned by
/// [`compute_pr`].
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct PrincipalRoot(pub MultihashDigest);

impl PrincipalRoot {
    /// Borrows the wrapped multihash digest.
    #[must_use]
    pub fn as_multihash(&self) -> &MultihashDigest {
        let Self(inner) = self;
        inner
    }

    /// Looks up the digest bytes stored for `alg` by delegating to the
    /// wrapped multihash's `get`.
    #[must_use]
    pub fn get(&self, alg: HashAlg) -> Option<&[u8]> {
        let Self(inner) = self;
        inner.get(alg)
    }
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PrincipalGenesis(pub MultihashDigest);
impl PrincipalGenesis {
pub fn from_bytes(bytes: Vec<u8>) -> Self {
Self(MultihashDigest::from_single(HashAlg::Sha256, bytes))
}
pub fn from_initial(ps: &PrincipalRoot) -> Self {
Self(ps.0.clone())
}
#[must_use]
pub fn as_multihash(&self) -> &MultihashDigest {
&self.0
}
#[must_use]
pub fn get(&self, alg: HashAlg) -> Option<&[u8]> {
self.0.get(alg)
}
}
pub use coz::HashAlg;
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct TaggedDigest {
alg: HashAlg,
digest: Vec<u8>,
}
impl TaggedDigest {
#[must_use]
pub fn alg(&self) -> HashAlg {
self.alg
}
#[must_use]
pub fn as_bytes(&self) -> &[u8] {
&self.digest
}
#[must_use]
pub const fn expected_len(alg: HashAlg) -> usize {
match alg {
HashAlg::Sha256 => 32,
HashAlg::Sha384 => 48,
HashAlg::Sha512 => 64,
}
}
fn parse_alg(s: &str) -> Result<HashAlg, crate::error::Error> {
match s {
"SHA-256" => Ok(HashAlg::Sha256),
"SHA-384" => Ok(HashAlg::Sha384),
"SHA-512" => Ok(HashAlg::Sha512),
_ => Err(crate::error::Error::UnsupportedAlgorithm(s.to_string())),
}
}
}
impl std::str::FromStr for TaggedDigest {
type Err = crate::error::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let (alg_str, digest_b64) =
s.split_once(':')
.ok_or(crate::error::Error::MalformedDigest(
"missing ':' separator",
))?;
let alg = Self::parse_alg(alg_str)?;
use coz::base64ct::{Base64UrlUnpadded, Encoding};
let expected = Self::expected_len(alg);
let mut buf = [0u8; 64];
let decoded = Base64UrlUnpadded::decode(digest_b64, &mut buf)
.map_err(|_| crate::error::Error::MalformedDigest("invalid base64"))?;
if decoded.len() != expected {
return Err(crate::error::Error::DigestLengthMismatch {
alg,
expected,
actual: decoded.len(),
});
}
Ok(Self {
alg,
digest: decoded.to_vec(),
})
}
}
/// Renders the digest as `<alg>:<base64url-unpadded digest>`.
impl std::fmt::Display for TaggedDigest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        use coz::base64ct::{Base64UrlUnpadded, Encoding};

        let encoded = Base64UrlUnpadded::encode_string(&self.digest);
        write!(f, "{}:{}", self.alg, encoded)
    }
}
/// Serializes the digest as a single string, using its `Display` rendering.
impl serde::Serialize for TaggedDigest {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let rendered = self.to_string();
        serializer.serialize_str(rendered.as_str())
    }
}
impl<'de> serde::Deserialize<'de> for TaggedDigest {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
s.parse().map_err(serde::de::Error::custom)
}
}
pub fn hash_alg_from_str(alg: &str) -> crate::error::Result<HashAlg> {
coz::Alg::from_str(alg)
.map(coz::Alg::hash_alg)
.ok_or_else(|| crate::error::Error::UnsupportedAlgorithm(alg.to_string()))
}
/// Collects the distinct hash algorithms used by the active keys in `keys`.
///
/// Keys for which `is_active()` is false are skipped, as are keys whose `alg`
/// string [`hash_alg_from_str`] cannot resolve. The result has no duplicates
/// and is in `HashAlg`'s `Ord` order.
pub fn derive_hash_algs(keys: &[&crate::Key]) -> Vec<HashAlg> {
    use std::collections::BTreeSet;

    let mut unique = BTreeSet::new();
    for key in keys.iter().filter(|key| key.is_active()) {
        // Unresolvable algorithm names are dropped silently.
        if let Ok(alg) = hash_alg_from_str(&key.alg) {
            unique.insert(alg);
        }
    }
    unique.into_iter().collect()
}
/// Same as [`hash_sorted_concat_bytes`], with the digest wrapped in a [`Cad`].
fn hash_sorted_concat(alg: HashAlg, components: &[&[u8]]) -> Cad {
    let digest = hash_sorted_concat_bytes(alg, components);
    Cad::from_bytes(digest)
}
pub(crate) fn hash_sorted_concat_bytes(alg: HashAlg, components: &[&[u8]]) -> Vec<u8> {
let mut sorted: Vec<&[u8]> = components.to_vec();
sorted.sort();
match alg {
HashAlg::Sha256 => {
let mut h = Sha256::new();
for c in sorted {
h.update(c);
}
h.finalize().to_vec()
},
HashAlg::Sha384 => {
let mut h = Sha384::new();
for c in sorted {
h.update(c);
}
h.finalize().to_vec()
},
HashAlg::Sha512 => {
let mut h = Sha512::new();
for c in sorted {
h.update(c);
}
h.finalize().to_vec()
},
}
}
pub(crate) fn hash_concat_bytes(alg: HashAlg, components: &[&[u8]]) -> Vec<u8> {
match alg {
HashAlg::Sha256 => {
let mut h = Sha256::new();
for c in components {
h.update(c);
}
h.finalize().to_vec()
},
HashAlg::Sha384 => {
let mut h = Sha384::new();
for c in components {
h.update(c);
}
h.finalize().to_vec()
},
HashAlg::Sha512 => {
let mut h = Sha512::new();
for c in components {
h.update(c);
}
h.finalize().to_vec()
},
}
}
/// Hashes `data` with `alg` and returns the digest bytes.
pub(crate) fn hash_bytes(alg: HashAlg, data: &[u8]) -> Vec<u8> {
    // `Digest::digest` is the one-shot form of new / update / finalize.
    match alg {
        HashAlg::Sha256 => Sha256::digest(data).to_vec(),
        HashAlg::Sha384 => Sha384::digest(data).to_vec(),
        HashAlg::Sha512 => Sha512::digest(data).to_vec(),
    }
}
pub fn compute_kr(
thumbprints: &[&Thumbprint],
nonce: Option<&[u8]>,
algs: &[HashAlg],
) -> crate::error::Result<KeyRoot> {
use crate::multihash::MultihashDigest;
use std::collections::BTreeMap;
if algs.is_empty() {
return Err(crate::error::Error::NoActiveKeys);
}
if thumbprints.len() == 1 && nonce.is_none() {
let alg = algs[0];
return Ok(KeyRoot(MultihashDigest::from_single(
alg,
thumbprints[0].as_bytes().to_vec(),
)));
}
let mut components: Vec<&[u8]> = thumbprints.iter().map(|t| t.as_bytes()).collect();
if let Some(n) = nonce {
components.push(n);
}
let mut variants = BTreeMap::new();
for &alg in algs {
let digest = hash_sorted_concat_bytes(alg, &components);
variants.insert(alg, digest.into_boxed_slice());
}
Ok(KeyRoot(MultihashDigest::new(variants)?))
}
pub fn compute_commit_id(
czds: &[&Czd],
nonce: Option<&[u8]>,
algs: &[HashAlg],
) -> Option<CommitID> {
if czds.is_empty() && nonce.is_none() {
return None;
}
if czds.len() == 1 && nonce.is_none() {
let czd_bytes = czds[0].as_bytes();
return Some(CommitID(MultihashDigest::from_single(
algs.first().copied().unwrap_or(HashAlg::Sha256),
czd_bytes.to_vec(),
)));
}
let mut components: Vec<&[u8]> = czds.iter().map(|c| c.as_bytes()).collect();
if let Some(n) = nonce {
components.push(n);
}
let mut variants = BTreeMap::new();
for &alg in algs {
let digest = hash_concat_bytes(alg, &components);
variants.insert(alg, digest.into_boxed_slice());
}
Some(CommitID(MultihashDigest::new(variants).ok()?))
}
/// A borrowed [`Czd`] paired with the hash algorithm it is tagged with.
#[derive(Debug, Clone)]
pub struct TaggedCzd<'a> {
    /// The borrowed czd.
    pub czd: &'a Czd,
    /// The algorithm tag attached to `czd`.
    pub alg: HashAlg,
}

impl<'a> TaggedCzd<'a> {
    /// Pairs `czd` with `alg`.
    pub fn new(czd: &'a Czd, alg: HashAlg) -> Self {
        Self { czd, alg }
    }

    /// Returns the czd bytes in terms of `target`.
    ///
    /// When the tag already equals `target` the bytes are copied as-is;
    /// otherwise the czd bytes are hashed with `target` via [`hash_bytes`].
    pub fn convert_to(&self, target: HashAlg) -> Vec<u8> {
        let raw = self.czd.as_bytes();
        if self.alg != target {
            return hash_bytes(target, raw);
        }
        raw.to_vec()
    }
}
pub fn compute_commit_id_tagged(
czds: &[TaggedCzd<'_>],
nonce: Option<&[u8]>,
algs: &[HashAlg],
) -> Option<CommitID> {
if czds.is_empty() && nonce.is_none() {
return None;
}
if czds.len() == 1 && nonce.is_none() {
let target_alg = algs.first().copied().unwrap_or(HashAlg::Sha256);
let converted = czds[0].convert_to(target_alg);
return Some(CommitID(MultihashDigest::from_single(
target_alg, converted,
)));
}
let mut variants = BTreeMap::new();
for &target_alg in algs {
let converted: Vec<Vec<u8>> = czds.iter().map(|tc| tc.convert_to(target_alg)).collect();
let mut components: Vec<&[u8]> = converted.iter().map(|v| v.as_slice()).collect();
if let Some(n) = nonce {
components.push(n);
}
let digest = hash_concat_bytes(target_alg, &components);
variants.insert(target_alg, digest.into_boxed_slice());
}
Some(CommitID(MultihashDigest::new(variants).ok()?))
}
/// Computes the [`AuthRoot`] from a [`KeyRoot`] and optional extra inputs.
///
/// * Neither `nonce` nor `embedding` given: the key root's multihash is
///   cloned unchanged and `algs` is not consulted.
/// * Otherwise: for every entry in `algs`, the key root's variant for that
///   algorithm plus the supplied nonce and embedding are hashed with
///   [`hash_sorted_concat_bytes`].
///
/// # Errors
///
/// Propagates the error from `get_or_err` when `ks` lacks a variant for one
/// of `algs`, and any error from `MultihashDigest::new`.
pub fn compute_ar(
    ks: &KeyRoot,
    nonce: Option<&[u8]>,
    embedding: Option<&[u8]>,
    algs: &[HashAlg],
) -> crate::error::Result<AuthRoot> {
    // Promotion: with nothing to mix in, the key root is the auth root.
    if let (None, None) = (nonce, embedding) {
        return Ok(AuthRoot(ks.0.clone()));
    }

    let mut variants = BTreeMap::new();
    for &alg in algs {
        let kr_bytes = ks.0.get_or_err(alg)?;
        let components: Vec<&[u8]> = std::iter::once(kr_bytes)
            .chain(nonce)
            .chain(embedding)
            .collect();
        let digest = hash_sorted_concat_bytes(alg, &components);
        variants.insert(alg, digest.into_boxed_slice());
    }

    let multihash = MultihashDigest::new(variants)?;
    Ok(AuthRoot(multihash))
}
/// Computes the [`StateRoot`] from an [`AuthRoot`] and optional extra inputs.
///
/// * Neither `ds` nor `embedding` given: the auth root's multihash is cloned
///   unchanged and `algs` is not consulted.
/// * Otherwise: for every entry in `algs`, the auth root's variant for that
///   algorithm plus the data root bytes and embedding (whichever are present)
///   are hashed with [`hash_sorted_concat_bytes`].
///
/// # Errors
///
/// Propagates the error from `get_or_err` when `auth_root` lacks a variant
/// for one of `algs`, and any error from `MultihashDigest::new`.
pub fn compute_sr(
    auth_root: &AuthRoot,
    ds: Option<&DataRoot>,
    embedding: Option<&[u8]>,
    algs: &[HashAlg],
) -> crate::error::Result<StateRoot> {
    // Promotion: with nothing to mix in, the auth root is the state root.
    if let (None, None) = (ds, embedding) {
        return Ok(StateRoot(auth_root.0.clone()));
    }

    let ds_bytes: Option<&[u8]> = ds.map(|d| d.0.as_bytes());

    let mut variants = BTreeMap::new();
    for &alg in algs {
        let ar_bytes = auth_root.0.get_or_err(alg)?;
        let components: Vec<&[u8]> = std::iter::once(ar_bytes)
            .chain(ds_bytes)
            .chain(embedding)
            .collect();
        let digest = hash_sorted_concat_bytes(alg, &components);
        variants.insert(alg, digest.into_boxed_slice());
    }

    let multihash = MultihashDigest::new(variants)?;
    Ok(StateRoot(multihash))
}
/// Computes the [`DataRoot`] from action czds and an optional nonce.
///
/// * No czds and no nonce: returns `None`.
/// * Exactly one czd and no nonce: the czd bytes are wrapped unchanged and
///   `alg` is not used.
/// * Otherwise: the czds plus the nonce are hashed with `alg` via
///   [`hash_sorted_concat`], so the result does not depend on input order.
pub fn compute_dr(action_czds: &[&Czd], nonce: Option<&[u8]>, alg: HashAlg) -> Option<DataRoot> {
    match (action_czds, nonce) {
        ([], None) => None,
        // Promotion: a lone action czd is the data root itself.
        ([only], None) => {
            let promoted = Cad::from_bytes(only.as_bytes().to_vec());
            Some(DataRoot(promoted))
        },
        _ => {
            let components: Vec<&[u8]> = action_czds
                .iter()
                .map(|c| c.as_bytes())
                .chain(nonce)
                .collect();
            Some(DataRoot(hash_sorted_concat(alg, &components)))
        },
    }
}
/// Computes the [`PrincipalRoot`] from a [`StateRoot`] and optional extra
/// inputs.
///
/// * Neither `cr` nor `embedding` given: the state root's multihash is cloned
///   unchanged and `algs` is not consulted.
/// * Otherwise: for every entry in `algs`, the state root's variant for that
///   algorithm, the commit root's variant for the same algorithm (if `cr` is
///   given) and the embedding (if given) are hashed with
///   [`hash_sorted_concat_bytes`].
///
/// # Errors
///
/// Propagates the error from `get_or_err` when `state_root` or `cr` lacks a
/// variant for one of `algs`, and any error from `MultihashDigest::new`.
pub fn compute_pr(
    state_root: &StateRoot,
    cr: Option<&crate::commit_root::CommitRoot>,
    embedding: Option<&[u8]>,
    algs: &[HashAlg],
) -> crate::error::Result<PrincipalRoot> {
    // Promotion: with nothing to mix in, the state root is the principal root.
    if let (None, None) = (cr, embedding) {
        return Ok(PrincipalRoot(state_root.0.clone()));
    }

    let mut variants = BTreeMap::new();
    for &alg in algs {
        let sr_bytes = state_root.0.get_or_err(alg)?;
        // Looked up after the state root so lookup errors surface in the
        // same order as the inputs.
        let cr_bytes = cr
            .map(|cr_digest| cr_digest.0.get_or_err(alg))
            .transpose()?;
        let components: Vec<&[u8]> = std::iter::once(sr_bytes)
            .chain(cr_bytes)
            .chain(embedding)
            .collect();
        let digest = hash_sorted_concat_bytes(alg, &components);
        variants.insert(alg, digest.into_boxed_slice());
    }

    let multihash = MultihashDigest::new(variants)?;
    Ok(PrincipalRoot(multihash))
}
/// Derives the key root, auth root and state root in one call.
///
/// Chains [`compute_kr`], [`compute_ar`] and [`compute_sr`] with no nonce and
/// no embedding at any level; `dr` is passed to [`compute_sr`] only.
///
/// # Errors
///
/// Propagates the first error returned by any of the three steps.
pub(crate) fn derive_auth_state(
    thumbprints: &[&Thumbprint],
    dr: Option<&DataRoot>,
    algs: &[HashAlg],
) -> crate::error::Result<(KeyRoot, AuthRoot, StateRoot)> {
    let kr = compute_kr(thumbprints, None, algs)?;
    let ar = compute_ar(&kr, None, None, algs)?;
    compute_sr(&ar, dr, None, algs).map(|sr| (kr, ar, sr))
}
/// Unit tests for root derivation and `TaggedDigest` parsing/formatting.
#[cfg(test)]
mod tests {
    use super::*;

    /// Valid SHA-256 tagged digest shared by the `TaggedDigest` tests.
    const SHA256_DIGEST: &str = "SHA-256:U5XUZots-WmQYcQWmsO751Xk0yeVi9XUKWQ2mGz6Aqg";

    /// A single thumbprint with no nonce becomes the key root unchanged.
    #[test]
    fn ks_single_key_promotion() {
        let tmb = Thumbprint::from_bytes(vec![1, 2, 3, 4]);
        let ks = compute_kr(&[&tmb], None, &[HashAlg::Sha256]).unwrap();
        assert_eq!(ks.0.len(), 1);
        assert_eq!(ks.get(HashAlg::Sha256).unwrap(), tmb.as_bytes());
    }

    /// Two thumbprints are hashed rather than promoted.
    #[test]
    fn ks_multi_key_hashes() {
        let tmb1 = Thumbprint::from_bytes(vec![1, 2, 3]);
        let tmb2 = Thumbprint::from_bytes(vec![4, 5, 6]);
        let ks = compute_kr(&[&tmb1, &tmb2], None, &[HashAlg::Sha256]).unwrap();
        let digest = ks.get(HashAlg::Sha256).unwrap();
        assert_eq!(digest.len(), 32);
        assert_ne!(digest, tmb1.as_bytes());
    }

    /// The key root hashes sorted components, so input order must not matter.
    #[test]
    fn kr_is_independent_of_thumbprint_order() {
        let tmb1 = Thumbprint::from_bytes(vec![1, 2, 3]);
        let tmb2 = Thumbprint::from_bytes(vec![4, 5, 6]);
        let forward = compute_kr(&[&tmb1, &tmb2], None, &[HashAlg::Sha256]).unwrap();
        let reverse = compute_kr(&[&tmb2, &tmb1], None, &[HashAlg::Sha256]).unwrap();
        assert_eq!(forward, reverse);
    }

    /// An empty algorithm list is rejected before any promotion or hashing.
    #[test]
    fn kr_without_algs_is_rejected() {
        let tmb = Thumbprint::from_bytes(vec![1, 2, 3, 4]);
        let result = compute_kr(&[&tmb], None, &[]);
        assert!(
            matches!(result, Err(crate::error::Error::NoActiveKeys)),
            "expected NoActiveKeys, got {:?}",
            result
        );
    }

    /// A nonce forces hashing even for a single thumbprint.
    #[test]
    fn ks_with_nonce_hashes() {
        let tmb = Thumbprint::from_bytes(vec![1, 2, 3, 4]);
        let nonce = vec![0xAA, 0xBB];
        let ks = compute_kr(&[&tmb], Some(&nonce), &[HashAlg::Sha256]).unwrap();
        let digest = ks.get(HashAlg::Sha256).unwrap();
        assert_eq!(digest.len(), 32);
        assert_ne!(digest, tmb.as_bytes());
    }

    /// No czds and no nonce yields no commit ID.
    #[test]
    fn commit_id_empty_is_none() {
        let cid = compute_commit_id(&[], None, &[HashAlg::Sha256]);
        assert!(cid.is_none());
    }

    /// A single czd with no nonce becomes the commit ID unchanged.
    #[test]
    fn commit_id_single_czd_promotion() {
        let czd = Czd::from_bytes(vec![10, 20, 30]);
        let cid = compute_commit_id(&[&czd], None, &[HashAlg::Sha256])
            .expect("single czd yields a commit id");
        assert_eq!(cid.get(HashAlg::Sha256).unwrap(), czd.as_bytes());
    }

    /// Without nonce or embedding the auth root equals the key root.
    #[test]
    fn as_promotion_from_ks() {
        let tmb = Thumbprint::from_bytes(vec![1, 2, 3, 4]);
        let ks = compute_kr(&[&tmb], None, &[HashAlg::Sha256]).unwrap();
        let auth_root = compute_ar(&ks, None, None, &[HashAlg::Sha256]).unwrap();
        assert_eq!(
            auth_root.get(HashAlg::Sha256).unwrap(),
            ks.get(HashAlg::Sha256).unwrap()
        );
    }

    /// Supplying a data root makes the state root a fresh hash.
    #[test]
    fn cs_with_ds_hashes() {
        let tmb = Thumbprint::from_bytes(vec![1, 2, 3, 4]);
        let ks = compute_kr(&[&tmb], None, &[HashAlg::Sha256]).unwrap();
        let auth_root = compute_ar(&ks, None, None, &[HashAlg::Sha256]).unwrap();
        let czd = Czd::from_bytes(vec![10, 20, 30]);
        let ds = compute_dr(&[&czd], None, HashAlg::Sha256).unwrap();
        let cs = compute_sr(&auth_root, Some(&ds), None, &[HashAlg::Sha256]).unwrap();
        let cs_bytes = cs.get(HashAlg::Sha256).unwrap();
        assert_eq!(cs_bytes.len(), 32);
        assert_ne!(cs_bytes, auth_root.get(HashAlg::Sha256).unwrap());
    }

    /// Without commit root or embedding the principal root equals the state
    /// root, which here equals the auth root.
    #[test]
    fn ps_promotion_from_sr() {
        let tmb = Thumbprint::from_bytes(vec![1, 2, 3, 4]);
        let ks = compute_kr(&[&tmb], None, &[HashAlg::Sha256]).unwrap();
        let auth_root = compute_ar(&ks, None, None, &[HashAlg::Sha256]).unwrap();
        let sr = compute_sr(&auth_root, None, None, &[HashAlg::Sha256]).unwrap();
        let ps = compute_pr(&sr, None, None, &[HashAlg::Sha256]).unwrap();
        assert_eq!(
            ps.get(HashAlg::Sha256).unwrap(),
            sr.get(HashAlg::Sha256).unwrap()
        );
        assert_eq!(
            ps.get(HashAlg::Sha256).unwrap(),
            auth_root.get(HashAlg::Sha256).unwrap()
        );
    }

    /// With a single key and no extras, every level carries the thumbprint.
    #[test]
    fn full_promotion_chain() {
        let tmb = Thumbprint::from_bytes(vec![0xDE, 0xAD, 0xBE, 0xEF]);
        let ks = compute_kr(&[&tmb], None, &[HashAlg::Sha256]).unwrap();
        let auth_root = compute_ar(&ks, None, None, &[HashAlg::Sha256]).unwrap();
        let sr = compute_sr(&auth_root, None, None, &[HashAlg::Sha256]).unwrap();
        let ps = compute_pr(&sr, None, None, &[HashAlg::Sha256]).unwrap();
        let genesis = PrincipalGenesis::from_initial(&ps);
        assert_eq!(ks.get(HashAlg::Sha256).unwrap(), tmb.as_bytes());
        assert_eq!(auth_root.get(HashAlg::Sha256).unwrap(), tmb.as_bytes());
        assert_eq!(sr.get(HashAlg::Sha256).unwrap(), tmb.as_bytes());
        assert_eq!(ps.get(HashAlg::Sha256).unwrap(), tmb.as_bytes());
        assert_eq!(genesis.get(HashAlg::Sha256).unwrap(), tmb.as_bytes());
    }

    /// Several thumbprints under three algorithms yield one variant per
    /// algorithm, each of that algorithm's digest length.
    #[test]
    fn cross_algorithm_conversion_spec_14_2() {
        let tmb_es256 = Thumbprint::from_bytes(vec![0xAA; 32]);
        let tmb_es384 = Thumbprint::from_bytes(vec![0xBB; 48]);
        let tmb_ed25519 = Thumbprint::from_bytes(vec![0xCC; 64]);
        let all_algs = [HashAlg::Sha256, HashAlg::Sha384, HashAlg::Sha512];
        let ks = compute_kr(&[&tmb_es256, &tmb_es384, &tmb_ed25519], None, &all_algs).unwrap();
        let sha256_variant = ks.get(HashAlg::Sha256);
        let sha384_variant = ks.get(HashAlg::Sha384);
        let sha512_variant = ks.get(HashAlg::Sha512);
        assert!(sha256_variant.is_some(), "SHA-256 variant should exist");
        assert!(sha384_variant.is_some(), "SHA-384 variant should exist");
        assert!(sha512_variant.is_some(), "SHA-512 variant should exist");
        assert_eq!(
            sha256_variant.unwrap().len(),
            32,
            "SHA-256 digest is 32 bytes"
        );
        assert_eq!(
            sha384_variant.unwrap().len(),
            48,
            "SHA-384 digest is 48 bytes"
        );
        assert_eq!(
            sha512_variant.unwrap().len(),
            64,
            "SHA-512 digest is 64 bytes"
        );
        assert_ne!(sha256_variant, sha384_variant);
        assert_ne!(sha384_variant, sha512_variant);
        assert_ne!(sha256_variant, sha512_variant);
    }

    #[test]
    fn tagged_digest_parse_valid_sha256() {
        let digest: TaggedDigest = SHA256_DIGEST.parse().expect("valid digest");
        assert_eq!(digest.alg(), HashAlg::Sha256);
        assert_eq!(digest.as_bytes().len(), 32);
    }

    #[test]
    fn tagged_digest_display_roundtrip() {
        let digest: TaggedDigest = SHA256_DIGEST.parse().expect("valid digest");
        assert_eq!(SHA256_DIGEST, digest.to_string());
    }

    #[test]
    fn tagged_digest_serde_roundtrip() {
        let digest: TaggedDigest = SHA256_DIGEST.parse().expect("valid digest");
        let json = serde_json::to_string(&digest).expect("serialize");
        assert_eq!(json, format!("\"{}\"", SHA256_DIGEST));
        let parsed: TaggedDigest = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(digest, parsed);
    }

    #[test]
    fn tagged_digest_missing_separator() {
        let input = "SHA-256U5XUZots-WmQYcQWmsO751Xk0yeVi9XUKWQ2mGz6Aqg";
        let err = input
            .parse::<TaggedDigest>()
            .expect_err("input without ':' must be rejected");
        assert!(
            matches!(err, crate::error::Error::MalformedDigest(_)),
            "expected MalformedDigest, got {:?}",
            err
        );
    }

    #[test]
    fn tagged_digest_unknown_algorithm() {
        let input = "SHA-999:U5XUZots-WmQYcQWmsO751Xk0yeVi9XUKWQ2mGz6Aqg";
        let err = input
            .parse::<TaggedDigest>()
            .expect_err("unknown algorithm must be rejected");
        assert!(
            matches!(err, crate::error::Error::UnsupportedAlgorithm(_)),
            "expected UnsupportedAlgorithm, got {:?}",
            err
        );
    }

    /// 22 base64 characters decode to 16 bytes, too short for SHA-256.
    #[test]
    fn tagged_digest_length_mismatch() {
        let input = "SHA-256:AAAAAAAAAAAAAAAAAAAAAA";
        let err = input
            .parse::<TaggedDigest>()
            .expect_err("short digest must be rejected");
        assert!(
            matches!(
                err,
                crate::error::Error::DigestLengthMismatch {
                    alg: HashAlg::Sha256,
                    expected: 32,
                    actual: 16,
                }
            ),
            "expected DigestLengthMismatch, got {:?}",
            err
        );
    }
}