use hkdf::Hkdf;
use sha2::Sha256;
use zeroize::Zeroize;
/// Errors returned by label validation and by the labelled key-derivation
/// methods (`KeyDeriver::derive_agent_key`, `KeyDeriver::derive_ssh_user_key`).
#[derive(Debug, thiserror::Error)]
pub enum DeriveError {
    /// The supplied label (agent name or SSH-user label) was the empty string.
    #[error("key derivation label must not be empty")]
    EmptyLabel,
}
/// Checks that `label` can be used as a key-derivation label.
///
/// Only emptiness is checked; any non-empty string is accepted as-is.
///
/// # Errors
///
/// Returns [`DeriveError::EmptyLabel`] when `label` is `""`.
pub fn validate_label(label: &str) -> Result<(), DeriveError> {
    match label {
        "" => Err(DeriveError::EmptyLabel),
        _ => Ok(()),
    }
}
/// Salt for the top-level HKDF-SHA256 extract performed in `KeyDeriver::new`.
/// It is an input to every derived key, so changing it changes all of them.
const HKDF_SALT: &[u8] = b"styrene-identity-v1";
/// Salt for the second-stage HKDF extract behind `KeyDeriver::derive_agent_key`.
const HKDF_SALT_AGENT: &[u8] = b"styrene-identity-agent-v1";
/// Salt for the second-stage HKDF extract behind `KeyDeriver::derive_ssh_user_key`.
const HKDF_SALT_SSH_USER: &[u8] = b"styrene-identity-ssh-user-v1";
/// Purpose of a key derived directly from the root secret.
///
/// Each variant maps to its own HKDF `info` label (see [`KeyPurpose::info`]).
/// Using distinct labels is what makes the per-purpose keys differ from one another.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KeyPurpose {
    /// HKDF info label `styrene-rns-encryption-v1`.
    RnsEncryption,
    /// HKDF info label `styrene-rns-signing-v1`.
    RnsSigning,
    /// HKDF info label `styrene-yggdrasil-v1`.
    Yggdrasil,
    /// HKDF info label `styrene-wireguard-v1`.
    WireGuard,
    /// HKDF info label `styrene-ssh-host-v1`.
    SshHost,
    /// HKDF info label `styrene-age-v1`.
    Age,
    /// HKDF info label `styrene-git-signing-v1`.
    GitSigning,
}

impl KeyPurpose {
    /// Returns the HKDF `info` label used to derive the key for this purpose.
    ///
    /// These bytes are part of the derivation itself. Editing one changes the
    /// key for that purpose. This is a `const fn`, so the labels are also
    /// available in constant contexts.
    #[must_use]
    pub const fn info(&self) -> &'static [u8] {
        match self {
            Self::RnsEncryption => b"styrene-rns-encryption-v1",
            Self::RnsSigning => b"styrene-rns-signing-v1",
            Self::Yggdrasil => b"styrene-yggdrasil-v1",
            Self::WireGuard => b"styrene-wireguard-v1",
            Self::SshHost => b"styrene-ssh-host-v1",
            Self::Age => b"styrene-age-v1",
            Self::GitSigning => b"styrene-git-signing-v1",
        }
    }

    /// Returns every purpose, in declaration order.
    ///
    /// This list is maintained by hand. When a variant is added, it must also be
    /// added here, because no exhaustiveness check covers this slice.
    #[must_use]
    pub const fn all() -> &'static [KeyPurpose] {
        &[
            Self::RnsEncryption,
            Self::RnsSigning,
            Self::Yggdrasil,
            Self::WireGuard,
            Self::SshHost,
            Self::Age,
            Self::GitSigning,
        ]
    }
}
pub struct KeyDeriver {
prk: [u8; 32],
}
impl Drop for KeyDeriver {
fn drop(&mut self) {
self.prk.zeroize();
}
}
impl KeyDeriver {
pub fn new(root_secret: &[u8; 32]) -> Self {
let (prk_hmac, _) = Hkdf::<Sha256>::extract(Some(HKDF_SALT), root_secret);
let mut prk_bytes = [0u8; 32];
prk_bytes.copy_from_slice(prk_hmac.as_slice());
Self { prk: prk_bytes }
}
fn expander(&self) -> Hkdf<Sha256> {
Hkdf::<Sha256>::from_prk(&self.prk).expect("32-byte PRK is always valid for HKDF-SHA256")
}
pub fn derive(&self, purpose: KeyPurpose) -> [u8; 32] {
let mut okm = [0u8; 32];
self.expander()
.expand(purpose.info(), &mut okm)
.expect("HKDF-SHA256 expand to 32 bytes should never fail");
okm
}
pub fn derive_all(&self) -> DerivedKeys {
DerivedKeys {
rns_encryption: self.derive(KeyPurpose::RnsEncryption),
rns_signing: self.derive(KeyPurpose::RnsSigning),
yggdrasil: self.derive(KeyPurpose::Yggdrasil),
wireguard: self.derive(KeyPurpose::WireGuard),
ssh_host: self.derive(KeyPurpose::SshHost),
age: self.derive(KeyPurpose::Age),
git_signing: self.derive(KeyPurpose::GitSigning),
}
}
pub fn ssh_host_seed(&self) -> [u8; 32] {
self.derive(KeyPurpose::SshHost)
}
pub fn age_secret(&self) -> [u8; 32] {
self.derive(KeyPurpose::Age)
}
pub fn git_signing_seed(&self) -> [u8; 32] {
self.derive(KeyPurpose::GitSigning)
}
pub fn derive_agent_key(&self, agent_name: &str) -> Result<[u8; 32], DeriveError> {
if agent_name.is_empty() {
return Err(DeriveError::EmptyLabel);
}
let mut master = [0u8; 32];
self.expander()
.expand(b"styrene-agent-master-v1", &mut master)
.expect("HKDF expand should not fail");
let hk2 = Hkdf::<Sha256>::new(Some(HKDF_SALT_AGENT), &master);
master.zeroize();
let mut okm = [0u8; 32];
hk2.expand(agent_name.as_bytes(), &mut okm).expect("HKDF expand should not fail");
Ok(okm)
}
pub fn derive_ssh_user_key(&self, label: &str) -> Result<[u8; 32], DeriveError> {
if label.is_empty() {
return Err(DeriveError::EmptyLabel);
}
let mut master = [0u8; 32];
self.expander()
.expand(b"styrene-ssh-user-master-v1", &mut master)
.expect("HKDF expand should not fail");
let hk2 = Hkdf::<Sha256>::new(Some(HKDF_SALT_SSH_USER), &master);
master.zeroize();
let mut okm = [0u8; 32];
hk2.expand(label.as_bytes(), &mut okm).expect("HKDF expand should not fail");
Ok(okm)
}
}
/// One-shot helper that derives the key for `purpose` from `root_secret`.
///
/// Each call builds a fresh [`KeyDeriver`], which repeats the HKDF extract step.
/// Reuse a `KeyDeriver` when deriving several keys from the same root.
pub fn derive_key(root_secret: &[u8; 32], purpose: KeyPurpose) -> [u8; 32] {
    let deriver = KeyDeriver::new(root_secret);
    deriver.derive(purpose)
}
/// Every purpose-specific key produced by `KeyDeriver::derive_all`, with one
/// field per `KeyPurpose` variant.
///
/// All fields are zeroized when the value is dropped (`#[zeroize(drop)]`).
/// The fields are public `[u8; 32]`, which is `Copy`. Any copies a caller makes
/// of a field are therefore not covered by this zeroization.
#[derive(Zeroize)]
#[zeroize(drop)]
pub struct DerivedKeys {
    /// Key for `KeyPurpose::RnsEncryption`.
    pub rns_encryption: [u8; 32],
    /// Key for `KeyPurpose::RnsSigning`.
    pub rns_signing: [u8; 32],
    /// Key for `KeyPurpose::Yggdrasil`.
    pub yggdrasil: [u8; 32],
    /// Key for `KeyPurpose::WireGuard`.
    pub wireguard: [u8; 32],
    /// Key for `KeyPurpose::SshHost`.
    pub ssh_host: [u8; 32],
    /// Key for `KeyPurpose::Age`.
    pub age: [u8; 32],
    /// Key for `KeyPurpose::GitSigning`.
    pub git_signing: [u8; 32],
}
impl std::fmt::Debug for DerivedKeys {
    /// Prints a fixed placeholder and never the key bytes, so `{:?}` on this
    /// type cannot leak key material into logs.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "DerivedKeys([REDACTED])")
    }
}
/// One-shot helper that derives the key for every `KeyPurpose` from `root_secret`.
///
/// The temporary [`KeyDeriver`] is dropped, and its PRK wiped, before returning.
pub fn derive_keys(root_secret: &[u8; 32]) -> DerivedKeys {
    let deriver = KeyDeriver::new(root_secret);
    deriver.derive_all()
}
// Unit tests for the HKDF key derivation. They cover:
// - determinism;
// - separation between purposes, key families and roots;
// - empty-label rejection;
// - fixed golden vectors.
#[cfg(test)]
mod tests {
    use super::*;

    // Same root and purpose must always give the same key.
    #[test]
    fn derive_key_deterministic() {
        let root = [42u8; 32];
        let k1 = derive_key(&root, KeyPurpose::RnsEncryption);
        let k2 = derive_key(&root, KeyPurpose::RnsEncryption);
        assert_eq!(k1, k2);
    }

    // Pairwise check: every flat purpose yields a distinct key for one root.
    #[test]
    fn different_purposes_produce_different_keys() {
        let root = [42u8; 32];
        let keys: Vec<[u8; 32]> = KeyPurpose::all().iter().map(|p| derive_key(&root, *p)).collect();
        for i in 0..keys.len() {
            for j in (i + 1)..keys.len() {
                assert_ne!(keys[i], keys[j], "collision between purposes {i} and {j}");
            }
        }
    }

    #[test]
    fn different_roots_produce_different_keys() {
        let k1 = derive_key(&[1u8; 32], KeyPurpose::RnsEncryption);
        let k2 = derive_key(&[2u8; 32], KeyPurpose::RnsEncryption);
        assert_ne!(k1, k2);
    }

    // Guards against a field being left at its all-zero initial value.
    #[test]
    fn derive_keys_produces_all_seven() {
        let root = [99u8; 32];
        let keys = derive_keys(&root);
        assert_ne!(keys.rns_encryption, [0u8; 32]);
        assert_ne!(keys.rns_signing, [0u8; 32]);
        assert_ne!(keys.yggdrasil, [0u8; 32]);
        assert_ne!(keys.wireguard, [0u8; 32]);
        assert_ne!(keys.ssh_host, [0u8; 32]);
        assert_ne!(keys.age, [0u8; 32]);
        assert_ne!(keys.git_signing, [0u8; 32]);
        assert_ne!(keys.rns_encryption, keys.rns_signing);
    }

    // `KeyPurpose::all()` is hand-maintained; this fails if its length drifts from 7.
    #[test]
    fn all_purposes_covered() {
        assert_eq!(KeyPurpose::all().len(), 7);
    }

    #[test]
    fn key_deriver_matches_free_function() {
        let root = [42u8; 32];
        let deriver = KeyDeriver::new(&root);
        for purpose in KeyPurpose::all() {
            assert_eq!(deriver.derive(*purpose), derive_key(&root, *purpose));
        }
    }

    // Checks that each `DerivedKeys` field is wired to the matching purpose.
    #[test]
    fn key_deriver_derive_all_matches_individual() {
        let root = [77u8; 32];
        let deriver = KeyDeriver::new(&root);
        let all = deriver.derive_all();
        assert_eq!(all.rns_encryption, deriver.derive(KeyPurpose::RnsEncryption));
        assert_eq!(all.rns_signing, deriver.derive(KeyPurpose::RnsSigning));
        assert_eq!(all.yggdrasil, deriver.derive(KeyPurpose::Yggdrasil));
        assert_eq!(all.wireguard, deriver.derive(KeyPurpose::WireGuard));
        assert_eq!(all.ssh_host, deriver.derive(KeyPurpose::SshHost));
        assert_eq!(all.age, deriver.derive(KeyPurpose::Age));
        assert_eq!(all.git_signing, deriver.derive(KeyPurpose::GitSigning));
    }

    #[test]
    fn ssh_host_and_age_non_zero_and_distinct() {
        let root = [55u8; 32];
        let deriver = KeyDeriver::new(&root);
        let ssh = deriver.ssh_host_seed();
        let age = deriver.age_secret();
        assert_ne!(ssh, [0u8; 32]);
        assert_ne!(age, [0u8; 32]);
        assert_ne!(ssh, age);
    }

    #[test]
    fn ssh_user_key_deterministic() {
        let d = KeyDeriver::new(&[42u8; 32]);
        let k1 = d.derive_ssh_user_key("github").unwrap();
        let k2 = d.derive_ssh_user_key("github").unwrap();
        assert_eq!(k1, k2);
    }

    #[test]
    fn ssh_user_key_different_labels() {
        let d = KeyDeriver::new(&[42u8; 32]);
        let github = d.derive_ssh_user_key("github").unwrap();
        let work = d.derive_ssh_user_key("work").unwrap();
        assert_ne!(github, work);
    }

    // The labelled SSH-user family must not overlap any flat-purpose key.
    #[test]
    fn ssh_user_key_no_collision_with_flat_purposes() {
        let d = KeyDeriver::new(&[42u8; 32]);
        let ssh_user = d.derive_ssh_user_key("github").unwrap();
        for purpose in KeyPurpose::all() {
            let flat = d.derive(*purpose);
            assert_ne!(ssh_user, flat, "SSH user key collides with {:?}", purpose);
        }
        assert_ne!(ssh_user, d.ssh_host_seed());
        assert_ne!(ssh_user, d.age_secret());
    }

    #[test]
    fn ssh_user_key_different_roots() {
        let k1 = KeyDeriver::new(&[1u8; 32]).derive_ssh_user_key("github").unwrap();
        let k2 = KeyDeriver::new(&[2u8; 32]).derive_ssh_user_key("github").unwrap();
        assert_ne!(k1, k2);
    }

    #[test]
    fn ssh_user_key_empty_label_rejected() {
        let d = KeyDeriver::new(&[42u8; 32]);
        assert!(d.derive_ssh_user_key("").is_err());
    }

    #[test]
    fn agent_key_deterministic() {
        let d = KeyDeriver::new(&[42u8; 32]);
        let k1 = d.derive_agent_key("omegon-primary").unwrap();
        let k2 = d.derive_agent_key("omegon-primary").unwrap();
        assert_eq!(k1, k2);
    }

    #[test]
    fn agent_key_different_names() {
        let d = KeyDeriver::new(&[42u8; 32]);
        let primary = d.derive_agent_key("omegon-primary").unwrap();
        let cleave = d.derive_agent_key("omegon-cleave-0").unwrap();
        assert_ne!(primary, cleave);
    }

    #[test]
    fn agent_key_no_collision_with_flat_or_ssh() {
        let d = KeyDeriver::new(&[42u8; 32]);
        let agent = d.derive_agent_key("omegon-primary").unwrap();
        for purpose in KeyPurpose::all() {
            assert_ne!(agent, d.derive(*purpose), "agent key collides with {:?}", purpose);
        }
        assert_ne!(agent, d.derive_ssh_user_key("github").unwrap());
        assert_ne!(agent, d.git_signing_seed());
    }

    // Same label in two families must still give different keys; this is
    // the job of the distinct per-family master labels and salts.
    #[test]
    fn agent_key_differs_from_ssh_user_same_label() {
        let d = KeyDeriver::new(&[42u8; 32]);
        let ssh = d.derive_ssh_user_key("github").unwrap();
        let agent = d.derive_agent_key("github").unwrap();
        assert_ne!(ssh, agent);
    }

    #[test]
    fn agent_key_empty_name_rejected() {
        let d = KeyDeriver::new(&[42u8; 32]);
        assert!(d.derive_agent_key("").is_err());
    }

    #[test]
    fn git_signing_distinct_from_all() {
        let d = KeyDeriver::new(&[42u8; 32]);
        let git = d.git_signing_seed();
        assert_ne!(git, [0u8; 32]);
        assert_ne!(git, d.ssh_host_seed());
        assert_ne!(git, d.age_secret());
        assert_ne!(git, d.derive_ssh_user_key("github").unwrap());
        assert_ne!(git, d.derive_agent_key("omegon-primary").unwrap());
    }

    // Golden vectors for root [0x42; 32]. They pin the exact salts and info
    // labels. A failure here means a derivation constant or step changed, and
    // every previously derived key would change with it.
    #[test]
    fn test_vector_flat_purposes() {
        let d = KeyDeriver::new(&[0x42u8; 32]);
        assert_eq!(
            hex::encode(d.derive(KeyPurpose::RnsEncryption)),
            "aefdbd63fb6746c2edb73bba3bcb34f61909077f65fe033c9372b55f6ace0c0c"
        );
        assert_eq!(
            hex::encode(d.derive(KeyPurpose::GitSigning)),
            "6eb3d3ef12a2447f6de281d6f896eba20ad0b0add3bc6fce80499f36b7343842"
        );
    }

    #[test]
    fn test_vector_ssh_user_key() {
        let d = KeyDeriver::new(&[0x42u8; 32]);
        assert_eq!(
            hex::encode(d.derive_ssh_user_key("github").unwrap()),
            "3c261af80e084a637fd20e0f7274a4106702894f0d23c47e855f6c9adce20d75"
        );
    }

    #[test]
    fn test_vector_agent_key() {
        let d = KeyDeriver::new(&[0x42u8; 32]);
        assert_eq!(
            hex::encode(d.derive_agent_key("omegon-primary").unwrap()),
            "4dd66edcda091a5e3d15aa3fb8ec32d81e212d94760b61915b1d6f204b0672e2"
        );
    }

    // Sanity check that the salt input actually affects the output. Salted and
    // unsalted HKDF over the same root and info must differ.
    #[test]
    fn salt_provides_domain_separation() {
        let root = [42u8; 32];
        let salted = Hkdf::<Sha256>::new(Some(HKDF_SALT), &root);
        let unsalted = Hkdf::<Sha256>::new(None, &root);
        let mut s_out = [0u8; 32];
        let mut u_out = [0u8; 32];
        let info = KeyPurpose::RnsEncryption.info();
        salted.expand(info, &mut s_out).expect("expand");
        unsalted.expand(info, &mut u_out).expect("expand");
        assert_ne!(s_out, u_out, "salt must change derived output");
    }
}