use hmac::{Hmac, Mac};
use rand_core::{OsRng, RngCore};
use sha2::Sha256;
use super::error::SecurityError;
/// Length in bytes of a challenge nonce (256 bits).
pub const FORMATION_CHALLENGE_SIZE: usize = 256 / 8;
/// Length in bytes of a challenge response; equal to the HMAC-SHA256 output length.
pub const FORMATION_RESPONSE_SIZE: usize = 256 / 8;
/// HKDF `info` string used when deriving a formation's HMAC key.
const HKDF_INFO_FORMATION: &[u8] = "peat-protocol-v1-formation".as_bytes();
/// HMAC instantiated with SHA-256; used to compute formation challenge responses.
type HmacSha256 = Hmac<Sha256>;
/// Secret key material for one formation, used for challenge-response authentication.
///
/// `Clone` duplicates the key material; the `Debug` impl redacts it.
#[derive(Clone)]
pub struct FormationKey {
    /// HMAC key derived from the formation's shared secret (secret material).
    hmac_key: [u8; 32],
    /// Identifier of the formation this key belongs to.
    formation_id: String,
}
impl FormationKey {
pub fn new(formation_id: &str, shared_secret: &[u8; 32]) -> Self {
use hkdf::Hkdf;
let hk = Hkdf::<Sha256>::new(Some(formation_id.as_bytes()), shared_secret);
let mut hmac_key = [0u8; 32];
hk.expand(HKDF_INFO_FORMATION, &mut hmac_key)
.expect("HKDF expand should never fail with 32-byte output");
Self {
formation_id: formation_id.to_string(),
hmac_key,
}
}
pub fn from_base64(formation_id: &str, base64_secret: &str) -> Result<Self, SecurityError> {
use base64::{engine::general_purpose::STANDARD, Engine};
use sha2::{Digest, Sha256};
let secret_bytes = STANDARD.decode(base64_secret.trim()).map_err(|e| {
SecurityError::AuthenticationFailed(format!("Invalid base64 shared secret: {}", e))
})?;
let secret: [u8; 32] = if secret_bytes.len() == 32 {
let mut arr = [0u8; 32];
arr.copy_from_slice(&secret_bytes);
arr
} else {
let mut hasher = Sha256::new();
hasher.update(&secret_bytes);
hasher.finalize().into()
};
Ok(Self::new(formation_id, &secret))
}
pub fn generate_secret() -> String {
use base64::{engine::general_purpose::STANDARD, Engine};
let mut secret = [0u8; 32];
OsRng.fill_bytes(&mut secret);
STANDARD.encode(secret)
}
pub fn formation_id(&self) -> &str {
&self.formation_id
}
pub fn create_challenge(
&self,
) -> (
[u8; FORMATION_CHALLENGE_SIZE],
[u8; FORMATION_RESPONSE_SIZE],
) {
let mut nonce = [0u8; FORMATION_CHALLENGE_SIZE];
OsRng.fill_bytes(&mut nonce);
let expected = self.compute_response(&nonce);
(nonce, expected)
}
pub fn respond_to_challenge(&self, nonce: &[u8]) -> [u8; FORMATION_RESPONSE_SIZE] {
self.compute_response(nonce)
}
pub fn verify_response(&self, nonce: &[u8], response: &[u8; FORMATION_RESPONSE_SIZE]) -> bool {
let expected = self.compute_response(nonce);
use subtle::ConstantTimeEq;
expected.ct_eq(response).into()
}
fn compute_response(&self, nonce: &[u8]) -> [u8; FORMATION_RESPONSE_SIZE] {
let mut mac =
HmacSha256::new_from_slice(&self.hmac_key).expect("HMAC key should be valid length");
mac.update(nonce);
mac.update(self.formation_id.as_bytes());
let result = mac.finalize();
let mut response = [0u8; FORMATION_RESPONSE_SIZE];
response.copy_from_slice(&result.into_bytes());
response
}
}
impl std::fmt::Debug for FormationKey {
    /// Formats the formation ID for diagnostics and replaces the key material
    /// with a fixed placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field to the struct forces a
        // decision here about whether that field is safe to print.
        let Self {
            formation_id,
            hmac_key: _,
        } = self;
        let mut builder = f.debug_struct("FormationKey");
        builder.field("formation_id", formation_id);
        builder.field("hmac_key", &"[REDACTED]");
        builder.finish()
    }
}
/// Challenge message: the challenger's formation ID plus a nonce to be answered.
///
/// Wire format: little-endian `u16` byte length of the ID, the UTF-8 ID bytes,
/// then exactly `FORMATION_CHALLENGE_SIZE` nonce bytes.
#[derive(Debug, Clone)]
pub struct FormationChallenge {
    /// Formation ID of the challenger.
    pub formation_id: String,
    /// Nonce the peer must answer.
    pub nonce: [u8; FORMATION_CHALLENGE_SIZE],
}

impl FormationChallenge {
    /// Encodes the challenge as `id_len (u16 LE) || formation_id || nonce`.
    ///
    /// The length prefix is a `u16`, so at most 65 535 ID bytes can be carried.
    /// A longer ID is cut to that length so the prefix always matches the bytes
    /// written and the peer still finds the nonce at the right offset. The peer
    /// then sees a different ID (or invalid UTF-8, if the cut split a character)
    /// instead of silently mis-framing the message.
    pub fn to_bytes(&self) -> Vec<u8> {
        let full_id = self.formation_id.as_bytes();
        let id_bytes = &full_id[..full_id.len().min(usize::from(u16::MAX))];
        // Lossless: id_bytes.len() <= u16::MAX by construction above.
        let id_len = id_bytes.len() as u16;
        let mut bytes = Vec::with_capacity(2 + id_bytes.len() + FORMATION_CHALLENGE_SIZE);
        bytes.extend_from_slice(&id_len.to_le_bytes());
        bytes.extend_from_slice(id_bytes);
        bytes.extend_from_slice(&self.nonce);
        bytes
    }

    /// Decodes a challenge produced by [`FormationChallenge::to_bytes`].
    ///
    /// Bytes after the nonce are ignored.
    ///
    /// # Errors
    /// Returns `SecurityError::AuthenticationFailed` in these cases:
    /// - the input is shorter than the 2-byte length prefix;
    /// - the input is shorter than the prefix says;
    /// - the ID is not valid UTF-8.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, SecurityError> {
        if bytes.len() < 2 {
            return Err(SecurityError::AuthenticationFailed(
                "Challenge too short".to_string(),
            ));
        }
        let id_len = usize::from(u16::from_le_bytes([bytes[0], bytes[1]]));
        // Cannot overflow: id_len <= u16::MAX.
        let nonce_start = 2 + id_len;
        let nonce_end = nonce_start + FORMATION_CHALLENGE_SIZE;
        if bytes.len() < nonce_end {
            return Err(SecurityError::AuthenticationFailed(
                "Challenge truncated".to_string(),
            ));
        }
        // Validate in place; allocate only once the ID is known to be valid UTF-8.
        let formation_id = std::str::from_utf8(&bytes[2..nonce_start])
            .map_err(|e| {
                SecurityError::AuthenticationFailed(format!("Invalid formation ID: {}", e))
            })?
            .to_owned();
        let mut nonce = [0u8; FORMATION_CHALLENGE_SIZE];
        nonce.copy_from_slice(&bytes[nonce_start..nonce_end]);
        Ok(Self {
            formation_id,
            nonce,
        })
    }
}
/// Response message: the answer to a [`FormationChallenge`] nonce.
///
/// Wire format: the raw `FORMATION_RESPONSE_SIZE` response bytes, with no framing.
#[derive(Debug, Clone)]
pub struct FormationChallengeResponse {
    /// Raw response bytes.
    pub response: [u8; FORMATION_RESPONSE_SIZE],
}

impl FormationChallengeResponse {
    /// Encodes the response as its raw bytes.
    pub fn to_bytes(&self) -> Vec<u8> {
        Vec::from(self.response)
    }

    /// Decodes a response from the leading `FORMATION_RESPONSE_SIZE` bytes of `bytes`.
    ///
    /// Any bytes past the response are ignored.
    ///
    /// # Errors
    /// Returns `SecurityError::AuthenticationFailed` if fewer bytes are supplied.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, SecurityError> {
        let head = bytes.get(..FORMATION_RESPONSE_SIZE).ok_or_else(|| {
            SecurityError::AuthenticationFailed("Response too short".to_string())
        })?;
        let mut response = [0u8; FORMATION_RESPONSE_SIZE];
        response.copy_from_slice(head);
        Ok(Self { response })
    }
}
/// Final verdict of a formation handshake, sent to the peer as a single byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FormationAuthResult {
    /// Authentication succeeded; encoded as `0x01`.
    Accepted,
    /// Authentication failed; encoded as `0x00`.
    Rejected,
}

impl FormationAuthResult {
    /// Returns the one-byte wire encoding of this result.
    pub fn to_byte(self) -> u8 {
        match self {
            Self::Rejected => 0x00,
            Self::Accepted => 0x01,
        }
    }

    /// Decodes a result byte.
    ///
    /// Fails closed: only the exact accept marker `0x01` means accepted; every
    /// other value is treated as a rejection.
    pub fn from_byte(byte: u8) -> Self {
        match byte {
            0x01 => Self::Accepted,
            _ => Self::Rejected,
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for formation key derivation, the challenge/response exchange,
    //! and the byte encodings of the handshake messages.
    use super::*;

    // A key reports the formation ID it was constructed with.
    #[test]
    fn test_formation_key_creation() {
        let secret = [0x42u8; 32];
        let key = FormationKey::new("alpha-company", &secret);
        assert_eq!(key.formation_id(), "alpha-company");
    }

    // A secret from `generate_secret` is accepted by `from_base64`.
    #[test]
    fn test_formation_key_from_base64() {
        let secret = FormationKey::generate_secret();
        let key = FormationKey::from_base64("test-formation", &secret).unwrap();
        assert_eq!(key.formation_id(), "test-formation");
    }

    // Input that fails base64 decoding is reported as an error.
    #[test]
    fn test_formation_key_from_base64_invalid() {
        let result = FormationKey::from_base64("test", "not-valid-base64!!!");
        assert!(result.is_err());
    }

    // Secrets that do not decode to exactly 32 bytes (16 and 138 bytes here) are
    // still accepted, and two different inputs yield keys that do not verify
    // each other's responses.
    #[test]
    fn test_formation_key_from_base64_derives_key_for_non_32_bytes() {
        use base64::{engine::general_purpose::STANDARD, Engine};
        let short_secret = STANDARD.encode([0u8; 16]);
        let result = FormationKey::from_base64("test", &short_secret);
        assert!(result.is_ok(), "Short key should be derived via SHA-256");
        let long_secret = STANDARD.encode([0xABu8; 138]);
        let result = FormationKey::from_base64("test", &long_secret);
        assert!(
            result.is_ok(),
            "Long key (EC format) should be derived via SHA-256"
        );
        let key1 = FormationKey::from_base64("test", &short_secret).unwrap();
        let key2 = FormationKey::from_base64("test", &long_secret).unwrap();
        let (nonce, _) = key1.create_challenge();
        let response1 = key1.respond_to_challenge(&nonce);
        assert!(
            !key2.verify_response(&nonce, &response1),
            "Different input keys should produce different derived keys"
        );
    }

    // Happy path: a key verifies its own response to its own challenge.
    #[test]
    fn test_challenge_response_success() {
        let secret = [0x42u8; 32];
        let key = FormationKey::new("alpha-company", &secret);
        let (nonce, _expected) = key.create_challenge();
        let response = key.respond_to_challenge(&nonce);
        assert!(key.verify_response(&nonce, &response));
    }

    // Same formation ID but a different shared secret must be rejected.
    #[test]
    fn test_challenge_response_wrong_key() {
        let secret1 = [0x42u8; 32];
        let secret2 = [0x43u8; 32];
        let key1 = FormationKey::new("alpha-company", &secret1);
        let key2 = FormationKey::new("alpha-company", &secret2);
        let (nonce, _expected) = key1.create_challenge();
        let response = key2.respond_to_challenge(&nonce);
        assert!(!key1.verify_response(&nonce, &response));
    }

    // Same shared secret but a different formation ID must be rejected.
    #[test]
    fn test_challenge_response_wrong_formation() {
        let secret = [0x42u8; 32];
        let key1 = FormationKey::new("alpha-company", &secret);
        let key2 = FormationKey::new("bravo-company", &secret);
        let (nonce, _expected) = key1.create_challenge();
        let response = key2.respond_to_challenge(&nonce);
        assert!(!key1.verify_response(&nonce, &response));
    }

    // Two independently generated nonces give different responses (relies on
    // the RNG not repeating a 32-byte nonce).
    #[test]
    fn test_different_nonces_produce_different_responses() {
        let secret = [0x42u8; 32];
        let key = FormationKey::new("alpha-company", &secret);
        let (nonce1, _) = key.create_challenge();
        let (nonce2, _) = key.create_challenge();
        let response1 = key.respond_to_challenge(&nonce1);
        let response2 = key.respond_to_challenge(&nonce2);
        assert_ne!(response1, response2);
    }

    // FormationChallenge survives a to_bytes/from_bytes round trip.
    #[test]
    fn test_challenge_serialization() {
        let mut nonce = [0u8; FORMATION_CHALLENGE_SIZE];
        nonce[0] = 0x42;
        let challenge = FormationChallenge {
            formation_id: "test-formation".to_string(),
            nonce,
        };
        let bytes = challenge.to_bytes();
        let restored = FormationChallenge::from_bytes(&bytes).unwrap();
        assert_eq!(challenge.formation_id, restored.formation_id);
        assert_eq!(challenge.nonce, restored.nonce);
    }

    // FormationChallengeResponse survives a to_bytes/from_bytes round trip.
    #[test]
    fn test_response_serialization() {
        let mut response_bytes = [0u8; FORMATION_RESPONSE_SIZE];
        response_bytes[0] = 0x42;
        let response = FormationChallengeResponse {
            response: response_bytes,
        };
        let bytes = response.to_bytes();
        let restored = FormationChallengeResponse::from_bytes(&bytes).unwrap();
        assert_eq!(response.response, restored.response);
    }

    // Both auth results round-trip through their one-byte encoding.
    #[test]
    fn test_auth_result_serialization() {
        assert_eq!(
            FormationAuthResult::from_byte(FormationAuthResult::Accepted.to_byte()),
            FormationAuthResult::Accepted
        );
        assert_eq!(
            FormationAuthResult::from_byte(FormationAuthResult::Rejected.to_byte()),
            FormationAuthResult::Rejected
        );
    }

    // Generated secrets differ between calls and decode to 32 bytes.
    #[test]
    fn test_generate_secret() {
        let secret1 = FormationKey::generate_secret();
        let secret2 = FormationKey::generate_secret();
        assert_ne!(secret1, secret2);
        use base64::{engine::general_purpose::STANDARD, Engine};
        let decoded = STANDARD.decode(&secret1).unwrap();
        assert_eq!(decoded.len(), 32);
    }

    // End-to-end exchange over the byte encodings with matching keys: the
    // acceptor's verification succeeds.
    #[test]
    fn test_wire_protocol_accept_with_matching_key() {
        let secret = [0x42u8; 32];
        let acceptor_key = FormationKey::new("test-formation", &secret);
        let connector_key = FormationKey::new("test-formation", &secret);
        let (nonce, _) = acceptor_key.create_challenge();
        let challenge = FormationChallenge {
            formation_id: acceptor_key.formation_id().to_string(),
            nonce,
        };
        let challenge_bytes = challenge.to_bytes();
        let decoded_challenge = FormationChallenge::from_bytes(&challenge_bytes).unwrap();
        assert_eq!(decoded_challenge.formation_id, "test-formation");
        let response = connector_key.respond_to_challenge(&decoded_challenge.nonce);
        let resp = FormationChallengeResponse { response };
        let resp_bytes = resp.to_bytes();
        let decoded_resp = FormationChallengeResponse::from_bytes(&resp_bytes).unwrap();
        assert!(
            acceptor_key.verify_response(&nonce, &decoded_resp.response),
            "Matching keys should produce accepted auth"
        );
    }

    // End-to-end exchange over the byte encodings with a different secret on the
    // connector: the acceptor's verification fails.
    #[test]
    fn test_wire_protocol_reject_with_wrong_key() {
        let acceptor_key = FormationKey::new("test-formation", &[0x42u8; 32]);
        let connector_key = FormationKey::new("test-formation", &[0xFF; 32]);
        let (nonce, _) = acceptor_key.create_challenge();
        let challenge = FormationChallenge {
            formation_id: acceptor_key.formation_id().to_string(),
            nonce,
        };
        let challenge_bytes = challenge.to_bytes();
        let decoded_challenge = FormationChallenge::from_bytes(&challenge_bytes).unwrap();
        let response = connector_key.respond_to_challenge(&decoded_challenge.nonce);
        let resp = FormationChallengeResponse { response };
        let resp_bytes = resp.to_bytes();
        let decoded_resp = FormationChallengeResponse::from_bytes(&resp_bytes).unwrap();
        assert!(
            !acceptor_key.verify_response(&nonce, &decoded_resp.response),
            "Wrong key should produce rejected auth"
        );
    }

    // The decoded challenge carries the acceptor's formation ID, which differs
    // from the connector's.
    // NOTE(review): this only asserts that the IDs differ. It does not exercise
    // any code path that rejects the mismatch; consider also verifying a
    // response from the connector.
    #[test]
    fn test_wire_protocol_formation_id_mismatch() {
        let acceptor_key = FormationKey::new("alpha", &[0x42u8; 32]);
        let connector_key = FormationKey::new("bravo", &[0x42u8; 32]);
        let (nonce, _) = acceptor_key.create_challenge();
        let challenge = FormationChallenge {
            formation_id: acceptor_key.formation_id().to_string(),
            nonce,
        };
        let challenge_bytes = challenge.to_bytes();
        let decoded_challenge = FormationChallenge::from_bytes(&challenge_bytes).unwrap();
        assert_ne!(
            decoded_challenge.formation_id,
            connector_key.formation_id(),
            "Connector should detect formation ID mismatch before responding"
        );
    }
}