use core::fmt::{Debug, Display, Formatter};
use crate::{
DiscoKeyPair, MachineKeyPair, MachinePrivateKey, NetworkLockKeyPair, NetworkLockPrivateKey,
NodeKeyPair, NodePrivateKey, NodePublicKey,
};
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct PersistState {
pub machine_key: MachinePrivateKey,
pub network_lock_key: NetworkLockPrivateKey,
pub node_key: NodePrivateKey,
#[cfg_attr(feature = "serde", serde(default))]
pub old_node_key: Option<NodePublicKey>,
#[cfg_attr(feature = "serde", serde(default))]
pub acme_account_key: Option<zeroize::Zeroizing<alloc::vec::Vec<u8>>>,
}
impl PersistState {
    /// Installs a freshly generated node key (`NodePrivateKey::random()`).
    ///
    /// The public half of the key being replaced is recorded in `old_node_key`.
    /// Any previously recorded `old_node_key` is overwritten, so only the
    /// immediately preceding key is remembered.
    pub fn rotate_node_key(&mut self) {
        let retired = core::mem::replace(&mut self.node_key, NodePrivateKey::random());
        self.old_node_key = Some(retired.public_key());
    }
}
impl From<&NodeState> for PersistState {
fn from(value: &NodeState) -> Self {
Self {
node_key: value.node_keys.private,
machine_key: value.machine_keys.private,
network_lock_key: value.network_lock_keys.private,
old_node_key: value.old_node_key,
acme_account_key: value.acme_account_key.clone(),
}
}
}
impl From<NodeState> for PersistState {
    /// Consumes a [`NodeState`], producing the same result as
    /// `PersistState::from(&value)`.
    ///
    /// Because `value` is owned, its ACME account key is moved rather than
    /// cloned, avoiding an extra heap copy of the secret bytes.
    fn from(mut value: NodeState) -> Self {
        // Take the key out *before* delegating, so the borrowed conversion
        // only clones `None` instead of duplicating the key buffer.
        let acme_account_key = value.acme_account_key.take();
        let mut state = Self::from(&value);
        state.acme_account_key = acme_account_key;
        state
    }
}
impl Default for PersistState {
    /// Generates new machine, network-lock and node private keys via their
    /// `random()` constructors.
    ///
    /// There is no previous node key and no ACME account key.
    fn default() -> Self {
        let machine_key = MachinePrivateKey::random();
        let network_lock_key = NetworkLockPrivateKey::random();
        let node_key = NodePrivateKey::random();

        Self {
            machine_key,
            network_lock_key,
            node_key,
            old_node_key: None,
            acme_account_key: None,
        }
    }
}
/// Complete in-memory key state of a node: full key pairs (private and public
/// halves) for each identity, plus the previous node public key and an optional
/// ACME account key.
///
/// Unlike [`PersistState`], this also carries disco keys. With the `serde`
/// feature it derives only `Deserialize` (not `Serialize`).
// NOTE: field order and the serde attributes below are part of the
// deserialization contract; do not reorder fields casually.
#[derive(Clone, Default)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize))]
pub struct NodeState {
    /// Disco key pair. Not stored in `PersistState`; `From<&PersistState>`
    /// fills it with `Default::default()`.
    pub disco_keys: DiscoKeyPair,
    /// Machine key pair.
    pub machine_keys: MachineKeyPair,
    /// Network-lock key pair.
    pub network_lock_keys: NetworkLockKeyPair,
    /// Current node key pair.
    pub node_keys: NodeKeyPair,
    /// Public half of the node key in use before the last rotation, if any
    /// (see `PersistState::rotate_node_key`).
    #[cfg_attr(feature = "serde", serde(default))]
    pub old_node_key: Option<NodePublicKey>,
    /// Optional ACME account key bytes; `Zeroizing` wipes the buffer on drop.
    /// Not included in this type's `Debug` output.
    #[cfg_attr(feature = "serde", serde(default))]
    pub acme_account_key: Option<zeroize::Zeroizing<alloc::vec::Vec<u8>>>,
}
/// Prints `NodeState(...)` with only the public halves of the key pairs, in
/// the order machine, node, disco, network-lock.
///
/// Private keys, `old_node_key` and the ACME account key are not printed.
impl Debug for NodeState {
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        let public_keys: [&dyn Debug; 4] = [
            &self.machine_keys.public,
            &self.node_keys.public,
            &self.disco_keys.public,
            &self.network_lock_keys.public,
        ];

        let mut tuple = f.debug_tuple("NodeState");
        for key in public_keys.iter() {
            tuple.field(*key);
        }
        tuple.finish()
    }
}
/// Displays exactly the `Debug` representation.
///
/// The formatter is forwarded as-is, so flags such as `{:#}` carry over to the
/// `Debug` output.
impl Display for NodeState {
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        <Self as Debug>::fmt(self, f)
    }
}
impl NodeState {
    /// Creates a `NodeState` from the `Default` impls of its fields.
    ///
    /// NOTE(review): this presumably yields fresh random key pairs; that
    /// depends on the key-pair `Default` impls, which are not shown here.
    pub fn generate() -> Self {
        Self::default()
    }
}
impl From<&PersistState> for NodeState {
fn from(value: &PersistState) -> Self {
Self {
disco_keys: Default::default(),
node_keys: value.node_key.into(),
machine_keys: value.machine_key.into(),
network_lock_keys: value.network_lock_key.into(),
old_node_key: value.old_node_key,
acme_account_key: value.acme_account_key.clone(),
}
}
}
impl From<PersistState> for NodeState {
    /// Consumes a [`PersistState`], producing the same result as
    /// `NodeState::from(&value)`.
    ///
    /// Because `value` is owned, its ACME account key is moved rather than
    /// cloned, avoiding an extra heap copy of the secret bytes.
    fn from(mut value: PersistState) -> Self {
        // Take the key out *before* delegating, so the borrowed conversion
        // only clones `None` instead of duplicating the key buffer.
        let acme_account_key = value.acme_account_key.take();
        let mut state = Self::from(&value);
        state.acme_account_key = acme_account_key;
        state
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Rotation records the outgoing public key and installs a different key.
    #[test]
    fn rotate_node_key_sets_old_and_fresh() {
        let mut state = PersistState::default();
        let before_pub = state.node_key.public_key();
        state.rotate_node_key();
        assert_eq!(state.old_node_key, Some(before_pub));
        assert_ne!(state.node_key.public_key(), before_pub);
    }

    /// `old_node_key` survives PersistState -> NodeState -> PersistState
    /// through the borrowed conversions.
    #[test]
    fn node_state_threads_old_node_key() {
        let mut persist = PersistState::default();
        let some_pub = NodePrivateKey::random().public_key();
        persist.old_node_key = Some(some_pub);
        let node_state = NodeState::from(&persist);
        assert_eq!(node_state.old_node_key, Some(some_pub));
        let round_trip = PersistState::from(&node_state);
        assert_eq!(round_trip.old_node_key, Some(some_pub));
    }

    #[test]
    fn default_persist_state_has_no_old_key() {
        assert!(PersistState::default().old_node_key.is_none());
    }

    /// The by-value conversions must carry the same data as the borrowed
    /// ones: node key, previous node key and ACME account key.
    #[test]
    fn owned_conversions_preserve_fields() {
        use alloc::vec;
        let mut persist = PersistState {
            acme_account_key: Some(zeroize::Zeroizing::new(vec![9u8, 8, 7])),
            ..Default::default()
        };
        persist.rotate_node_key();
        let node_pub = persist.node_key.public_key();
        let old = persist.old_node_key;

        let node_state = NodeState::from(persist.clone());
        assert_eq!(node_state.old_node_key, old);
        assert_eq!(
            node_state.acme_account_key.as_deref().map(|v| v.as_slice()),
            Some(&[9u8, 8, 7][..])
        );

        let round_trip = PersistState::from(node_state);
        assert_eq!(round_trip.node_key.public_key(), node_pub);
        assert_eq!(round_trip.old_node_key, old);
        assert_eq!(
            round_trip.acme_account_key.as_deref().map(|v| v.as_slice()),
            Some(&[9u8, 8, 7][..])
        );
    }

    /// Serialized data that lacks `old_node_key` still deserializes, as `None`.
    #[cfg(feature = "serde")]
    #[test]
    fn persist_state_old_node_key_serde_default() {
        let json = serde_json::to_string(&PersistState::default()).unwrap();
        let parsed: PersistState = serde_json::from_str(&json).unwrap();
        assert!(parsed.old_node_key.is_none());
        let mut value: serde_json::Value = serde_json::from_str(&json).unwrap();
        value
            .as_object_mut()
            .unwrap()
            .remove("old_node_key")
            .expect("default serializes the field");
        let parsed: PersistState =
            serde_json::from_value(value).expect("missing old_node_key deserializes via default");
        assert!(parsed.old_node_key.is_none());
    }

    /// `acme_account_key` defaults to `None` when absent, serializes as a bare
    /// byte array, and survives both borrowed conversions.
    #[cfg(feature = "serde")]
    #[test]
    fn persist_state_acme_account_key_serde_default_and_round_trip() {
        use alloc::vec;
        let json = serde_json::to_string(&PersistState::default()).unwrap();
        let mut value: serde_json::Value = serde_json::from_str(&json).unwrap();
        value
            .as_object_mut()
            .unwrap()
            .remove("acme_account_key")
            .expect("default serializes the field");
        let parsed: PersistState = serde_json::from_value(value)
            .expect("missing acme_account_key deserializes via default");
        assert!(parsed.acme_account_key.is_none());
        let state = PersistState {
            acme_account_key: Some(zeroize::Zeroizing::new(vec![1u8, 2, 3, 4])),
            ..Default::default()
        };
        let json = serde_json::to_string(&state).unwrap();
        assert!(
            json.contains("\"acme_account_key\":[1,2,3,4]"),
            "Zeroizing must serialize as the bare byte array (unchanged JSON shape): {json}"
        );
        let parsed: PersistState = serde_json::from_str(&json).unwrap();
        assert_eq!(
            parsed.acme_account_key.as_deref().map(|v| v.as_slice()),
            Some(&[1u8, 2, 3, 4][..])
        );
        let node_state = NodeState::from(&state);
        assert_eq!(
            node_state.acme_account_key.as_deref().map(|v| v.as_slice()),
            Some(&[1u8, 2, 3, 4][..])
        );
        let round_trip = PersistState::from(&node_state);
        assert_eq!(
            round_trip.acme_account_key.as_deref().map(|v| v.as_slice()),
            Some(&[1u8, 2, 3, 4][..])
        );
    }
}