use super::device_id::DeviceId;
use super::error::SecurityError;
use super::keypair::DeviceKeypair;
use super::{CHALLENGE_NONCE_SIZE, DEFAULT_CHALLENGE_TIMEOUT_SECS};
use peat_schema::security::v1::{Challenge, SignedChallengeResponse};
use rand_core::{OsRng, RngCore};
use std::collections::HashMap;
use std::sync::RwLock;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
/// Highest challenge/response protocol version this build speaks.
///
/// From version 1 on, the little-endian protocol version is appended to the
/// signed message; version 0 omits it.
pub const CURRENT_PROTOCOL_VERSION: u32 = 1;

/// Prefix of the `AuthenticationFailed` message emitted when a peer claims a
/// protocol version newer than [`CURRENT_PROTOCOL_VERSION`]; callers can match
/// on this prefix to tell version mismatches apart from other failures.
pub const INCOMPATIBLE_PROTOCOL_VERSION_PREFIX: &str = "incompatible protocol version:";
/// Performs challenge/response authentication of peer devices with this
/// device's signing keypair, and caches the peers that pass verification.
pub struct DeviceAuthenticator {
    /// This device's keypair: signs challenge responses and determines `device_id()`.
    keypair: DeviceKeypair,
    /// Peers that passed `verify_response`, keyed by their device id.
    verified_peers: RwLock<HashMap<DeviceId, VerifiedPeer>>,
    /// How long after issuance a challenge from `generate_challenge` remains valid.
    challenge_timeout: Duration,
}
#[derive(Debug, Clone)]
pub struct VerifiedPeer {
pub device_id: DeviceId,
pub public_key: [u8; 32],
pub verified_at: SystemTime,
}
impl DeviceAuthenticator {
    /// Creates an authenticator whose generated challenges stay valid for
    /// `DEFAULT_CHALLENGE_TIMEOUT_SECS` seconds.
    pub fn new(keypair: DeviceKeypair) -> Self {
        Self::with_timeout(keypair, Duration::from_secs(DEFAULT_CHALLENGE_TIMEOUT_SECS))
    }

    /// Creates an authenticator whose generated challenges expire
    /// `challenge_timeout` after issuance.
    ///
    /// Any `Duration` is accepted: an expiry that would overflow is clamped by
    /// [`Self::generate_challenge`] instead of panicking.
    pub fn with_timeout(keypair: DeviceKeypair, challenge_timeout: Duration) -> Self {
        DeviceAuthenticator {
            keypair,
            verified_peers: RwLock::new(HashMap::new()),
            challenge_timeout,
        }
    }

    /// Returns this device's identifier, obtained from its keypair.
    pub fn device_id(&self) -> DeviceId {
        self.keypair.device_id()
    }

    /// Returns this device's 32-byte public key.
    pub fn public_key_bytes(&self) -> [u8; 32] {
        self.keypair.public_key_bytes()
    }

    /// Builds a fresh challenge for a peer to sign.
    ///
    /// The challenge carries a random `CHALLENGE_NONCE_SIZE`-byte nonce from the
    /// OS RNG, this device's hex id as `challenger_id`, the current time, an expiry
    /// `challenge_timeout` later, and [`CURRENT_PROTOCOL_VERSION`]. The challenge is
    /// not signed and is not recorded locally.
    pub fn generate_challenge(&self) -> Challenge {
        let mut nonce = [0u8; CHALLENGE_NONCE_SIZE];
        OsRng.fill_bytes(&mut nonce);
        // A clock set before the Unix epoch falls back to 0 instead of panicking.
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default();
        // `Duration + Duration` panics on overflow, which a large caller-supplied
        // timeout (e.g. `Duration::MAX` via `with_timeout`) would trigger.
        let expires = now.saturating_add(self.challenge_timeout);
        Challenge {
            nonce: nonce.to_vec(),
            timestamp: Some(peat_schema::common::v1::Timestamp {
                seconds: now.as_secs(),
                nanos: now.subsec_nanos(),
            }),
            challenger_id: self.device_id().to_hex(),
            expires_at: Some(peat_schema::common::v1::Timestamp {
                seconds: expires.as_secs(),
                nanos: expires.subsec_nanos(),
            }),
            protocol_version: CURRENT_PROTOCOL_VERSION,
            capabilities: Vec::new(),
        }
    }

    /// Signs `challenge` with this device's key, producing a response the
    /// challenger checks with [`Self::verify_response`].
    ///
    /// The protocol version is negotiated down to the lower of the challenger's
    /// advertised version and [`CURRENT_PROTOCOL_VERSION`]. The signed bytes are
    /// `nonce || challenger_id || response timestamp seconds`, plus the version for
    /// versions >= 1 (see `build_signed_message`). Only whole seconds are signed, so
    /// the response timestamp's `nanos` is always 0.
    ///
    /// # Errors
    ///
    /// Returns [`SecurityError::ChallengeExpired`] if the challenge has an
    /// `expires_at` that is already in the past. A challenge without `expires_at`
    /// is never treated as expired.
    pub fn respond_to_challenge(
        &self,
        challenge: &Challenge,
    ) -> Result<SignedChallengeResponse, SecurityError> {
        self.check_challenge_expiry(challenge)?;
        let negotiated_version = challenge.protocol_version.min(CURRENT_PROTOCOL_VERSION);
        let response_ts_seconds = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        let message = build_signed_message(
            &challenge.nonce,
            &challenge.challenger_id,
            response_ts_seconds,
            negotiated_version,
        );
        let signature = self.keypair.sign(&message);
        Ok(SignedChallengeResponse {
            challenge_nonce: challenge.nonce.clone(),
            public_key: self.keypair.public_key_bytes().to_vec(),
            signature: signature.to_bytes().to_vec(),
            timestamp: Some(peat_schema::common::v1::Timestamp {
                seconds: response_ts_seconds,
                nanos: 0,
            }),
            device_type: 0,
            certificates: vec![],
            protocol_version: negotiated_version,
            capabilities: Vec::new(),
        })
    }

    /// Verifies a peer's signed response and, on success, caches the peer as
    /// verified and returns its device id.
    ///
    /// The signature is checked over the message rebuilt from the response's
    /// nonce, timestamp and protocol version, with *this* device's id as the
    /// challenger. A response produced for a different challenger therefore fails.
    /// A missing response timestamp is treated as 0.
    ///
    /// NOTE(review): this method does not check that `challenge_nonce` was issued
    /// by this authenticator, whether it was already used, or how old the response
    /// timestamp is. A captured valid response can be re-submitted later and still
    /// verify. The challenge's `protocol_version` is not signed either, so a
    /// tampered challenge can steer a responder to an older version that is still
    /// accepted here. Confirm that callers track outstanding challenges if they
    /// need replay/downgrade protection.
    ///
    /// # Errors
    ///
    /// - [`SecurityError::AuthenticationFailed`] with a message starting with
    ///   [`INCOMPATIBLE_PROTOCOL_VERSION_PREFIX`] if the response claims a version
    ///   newer than [`CURRENT_PROTOCOL_VERSION`]. This is checked before any
    ///   cryptography.
    /// - Errors from decoding the public key or signature, or from signature
    ///   verification.
    /// - [`SecurityError::Internal`] if the verified-peer lock is poisoned.
    pub fn verify_response(
        &self,
        response: &SignedChallengeResponse,
    ) -> Result<DeviceId, SecurityError> {
        if response.protocol_version > CURRENT_PROTOCOL_VERSION {
            return Err(SecurityError::AuthenticationFailed(format!(
                "{INCOMPATIBLE_PROTOCOL_VERSION_PREFIX} peer claims {peer}, our maximum is {ours}",
                peer = response.protocol_version,
                ours = CURRENT_PROTOCOL_VERSION,
            )));
        }
        let public_key = DeviceKeypair::verifying_key_from_bytes(&response.public_key)?;
        let peer_device_id = DeviceId::from_public_key(&public_key);
        let response_ts_seconds = response
            .timestamp
            .as_ref()
            .map(|ts| ts.seconds)
            .unwrap_or(0);
        let message = build_signed_message(
            &response.challenge_nonce,
            &self.device_id().to_hex(),
            response_ts_seconds,
            response.protocol_version,
        );
        let signature = DeviceKeypair::signature_from_bytes(&response.signature)?;
        DeviceKeypair::verify_with_key(&public_key, &message, &signature)?;
        let verified_peer = VerifiedPeer {
            device_id: peer_device_id,
            public_key: public_key.to_bytes(),
            verified_at: SystemTime::now(),
        };
        self.verified_peers
            .write()
            .map_err(|e| SecurityError::Internal(format!("lock poisoned: {}", e)))?
            .insert(peer_device_id, verified_peer);
        Ok(peer_device_id)
    }

    /// Returns whether `device_id` is in the verified-peer cache.
    /// Returns `false` if the cache lock is poisoned.
    pub fn is_verified(&self, device_id: &DeviceId) -> bool {
        self.verified_peers
            .read()
            .map(|cache| cache.contains_key(device_id))
            .unwrap_or(false)
    }

    /// Returns a copy of the cached entry for `device_id`, or `None` if it is
    /// absent or the cache lock is poisoned.
    pub fn get_verified_peer(&self, device_id: &DeviceId) -> Option<VerifiedPeer> {
        self.verified_peers
            .read()
            .ok()
            .and_then(|cache| cache.get(device_id).cloned())
    }

    /// Removes `device_id` from the verified-peer cache.
    /// Best-effort: does nothing if the cache lock is poisoned.
    pub fn remove_peer(&self, device_id: &DeviceId) {
        if let Ok(mut cache) = self.verified_peers.write() {
            cache.remove(device_id);
        }
    }

    /// Empties the verified-peer cache.
    /// Best-effort: does nothing if the cache lock is poisoned.
    pub fn clear_verified_peers(&self) {
        if let Ok(mut cache) = self.verified_peers.write() {
            cache.clear();
        }
    }

    /// Returns the number of cached verified peers, or 0 if the cache lock is poisoned.
    pub fn verified_peer_count(&self) -> usize {
        self.verified_peers
            .read()
            .map(|cache| cache.len())
            .unwrap_or(0)
    }

    /// Fails with [`SecurityError::ChallengeExpired`] if `challenge.expires_at` is
    /// earlier than the current time; challenges without an expiry always pass.
    fn check_challenge_expiry(&self, challenge: &Challenge) -> Result<(), SecurityError> {
        if let Some(expires) = &challenge.expires_at {
            let now = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap_or_default();
            // Compare (seconds, nanos) lexicographically so an expired challenge is
            // not honoured for up to a second longer. A tuple is used rather than
            // `Duration::new` because the timestamp comes from the peer, and
            // `Duration::new` panics if the nanos carry overflows the seconds.
            if (now.as_secs(), now.subsec_nanos()) > (expires.seconds, expires.nanos) {
                return Err(SecurityError::ChallengeExpired(expires.seconds));
            }
        }
        Ok(())
    }
}
/// Assembles the exact byte string that a challenge response signs.
///
/// Layout: `nonce || challenger_id (UTF-8 bytes) || response_ts_seconds (u64 LE)`,
/// followed by `protocol_version (u32 LE)` only when `protocol_version >= 1`.
/// Version 0 omits the suffix.
///
/// Fields are concatenated with no separators or length prefixes, so the
/// nonce/challenger-id boundary is not encoded in the bytes. The layout is part
/// of the signed wire format and must not change without a protocol version bump.
fn build_signed_message(
    nonce: &[u8],
    challenger_id: &str,
    response_ts_seconds: u64,
    protocol_version: u32,
) -> Vec<u8> {
    let id_bytes = challenger_id.as_bytes();
    let ts_bytes = response_ts_seconds.to_le_bytes();
    let version_suffix: Option<[u8; 4]> =
        (protocol_version >= 1).then(|| protocol_version.to_le_bytes());

    let suffix_len = version_suffix.map_or(0, |bytes| bytes.len());
    let mut out = Vec::with_capacity(nonce.len() + id_bytes.len() + ts_bytes.len() + suffix_len);
    for part in [nonce, id_bytes, &ts_bytes[..]] {
        out.extend_from_slice(part);
    }
    if let Some(bytes) = version_suffix {
        out.extend_from_slice(&bytes);
    }
    out
}
impl std::fmt::Debug for DeviceAuthenticator {
    /// Prints the device id, the number of cached verified peers and the
    /// challenge timeout; the keypair field is deliberately left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let device_id = self.device_id();
        let peer_count = self.verified_peer_count();
        let mut builder = f.debug_struct("DeviceAuthenticator");
        builder.field("device_id", &device_id);
        builder.field("verified_peer_count", &peer_count);
        builder.field("challenge_timeout", &self.challenge_timeout);
        builder.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an authenticator around a freshly generated keypair with the default timeout.
    fn create_test_authenticator() -> DeviceAuthenticator {
        let keypair = DeviceKeypair::generate();
        DeviceAuthenticator::new(keypair)
    }

    #[test]
    fn test_generate_challenge() {
        let auth = create_test_authenticator();
        let challenge = auth.generate_challenge();
        assert_eq!(challenge.nonce.len(), CHALLENGE_NONCE_SIZE);
        assert!(!challenge.challenger_id.is_empty());
        assert!(challenge.timestamp.is_some());
        assert!(challenge.expires_at.is_some());
    }

    #[test]
    fn test_challenge_nonce_unique() {
        let auth = create_test_authenticator();
        let c1 = auth.generate_challenge();
        let c2 = auth.generate_challenge();
        assert_ne!(c1.nonce, c2.nonce);
    }

    #[test]
    fn test_respond_to_challenge() {
        let auth1 = create_test_authenticator();
        let auth2 = create_test_authenticator();
        let challenge = auth1.generate_challenge();
        let response = auth2.respond_to_challenge(&challenge).unwrap();
        assert_eq!(response.public_key.len(), 32);
        assert_eq!(response.signature.len(), 64);
        assert_eq!(response.challenge_nonce, challenge.nonce);
    }

    #[test]
    fn test_full_authentication_flow() {
        let auth1 = create_test_authenticator();
        let auth2 = create_test_authenticator();
        let challenge = auth1.generate_challenge();
        let response = auth2.respond_to_challenge(&challenge).unwrap();
        let peer_id = auth1.verify_response(&response).unwrap();
        assert_eq!(peer_id, auth2.device_id());
        assert!(auth1.is_verified(&peer_id));
    }

    #[test]
    fn test_expired_challenge_rejected() {
        let auth = create_test_authenticator();
        let mut challenge = auth.generate_challenge();
        challenge.expires_at = Some(peat_schema::common::v1::Timestamp {
            seconds: 0,
            nanos: 0,
        });
        let result = auth.respond_to_challenge(&challenge);
        assert!(matches!(result, Err(SecurityError::ChallengeExpired(_))));
    }

    #[test]
    fn test_invalid_signature_rejected() {
        let auth1 = create_test_authenticator();
        let auth2 = create_test_authenticator();
        let challenge = auth1.generate_challenge();
        let mut response = auth2.respond_to_challenge(&challenge).unwrap();
        response.signature[0] ^= 0xFF;
        let result = auth1.verify_response(&response);
        assert!(matches!(result, Err(SecurityError::InvalidSignature(_))));
    }

    #[test]
    fn response_for_different_challenger_rejected() {
        // The challenger's id is part of the signed message, so a response
        // produced for auth1 must not verify at auth3.
        let auth1 = create_test_authenticator();
        let auth2 = create_test_authenticator();
        let auth3 = create_test_authenticator();
        let challenge = auth1.generate_challenge();
        let response = auth2.respond_to_challenge(&challenge).unwrap();
        let result = auth3.verify_response(&response);
        assert!(matches!(result, Err(SecurityError::InvalidSignature(_))));
        assert!(!auth3.is_verified(&auth2.device_id()));
        assert_eq!(auth3.verified_peer_count(), 0);
    }

    #[test]
    fn tampered_nonce_rejected() {
        let auth1 = create_test_authenticator();
        let auth2 = create_test_authenticator();
        let challenge = auth1.generate_challenge();
        let mut response = auth2.respond_to_challenge(&challenge).unwrap();
        response.challenge_nonce[0] ^= 0xFF;
        let result = auth1.verify_response(&response);
        assert!(matches!(result, Err(SecurityError::InvalidSignature(_))));
        // A failed verification must not leave the peer cached.
        assert_eq!(auth1.verified_peer_count(), 0);
    }

    #[test]
    fn tampered_timestamp_rejected() {
        let auth1 = create_test_authenticator();
        let auth2 = create_test_authenticator();
        let challenge = auth1.generate_challenge();
        let mut response = auth2.respond_to_challenge(&challenge).unwrap();
        let ts = response
            .timestamp
            .as_mut()
            .expect("responder always sets a timestamp");
        ts.seconds = ts.seconds.wrapping_add(1);
        let result = auth1.verify_response(&response);
        assert!(matches!(result, Err(SecurityError::InvalidSignature(_))));
        assert_eq!(auth1.verified_peer_count(), 0);
    }

    #[test]
    fn get_verified_peer_returns_recorded_identity() {
        let auth1 = create_test_authenticator();
        let auth2 = create_test_authenticator();
        assert!(auth1.get_verified_peer(&auth2.device_id()).is_none());
        let challenge = auth1.generate_challenge();
        let response = auth2.respond_to_challenge(&challenge).unwrap();
        let peer_id = auth1.verify_response(&response).unwrap();
        let peer = auth1
            .get_verified_peer(&peer_id)
            .expect("peer is cached after successful verification");
        assert_eq!(peer.device_id, auth2.device_id());
        assert_eq!(peer.public_key, auth2.public_key_bytes());
    }

    #[test]
    fn test_remove_peer() {
        let auth1 = create_test_authenticator();
        let auth2 = create_test_authenticator();
        let challenge = auth1.generate_challenge();
        let response = auth2.respond_to_challenge(&challenge).unwrap();
        let peer_id = auth1.verify_response(&response).unwrap();
        assert!(auth1.is_verified(&peer_id));
        auth1.remove_peer(&peer_id);
        assert!(!auth1.is_verified(&peer_id));
    }

    #[test]
    fn test_clear_verified_peers() {
        let auth1 = create_test_authenticator();
        let auth2 = create_test_authenticator();
        let auth3 = create_test_authenticator();
        let c1 = auth1.generate_challenge();
        let r1 = auth2.respond_to_challenge(&c1).unwrap();
        auth1.verify_response(&r1).unwrap();
        let c2 = auth1.generate_challenge();
        let r2 = auth3.respond_to_challenge(&c2).unwrap();
        auth1.verify_response(&r2).unwrap();
        assert_eq!(auth1.verified_peer_count(), 2);
        auth1.clear_verified_peers();
        assert_eq!(auth1.verified_peer_count(), 0);
    }

    #[test]
    fn timestamp_mismatch_between_challenge_and_response_does_not_break_verification() {
        // The challenge is issued 5s before the response; verification uses the
        // response timestamp, so the gap must not matter.
        let auth1 = create_test_authenticator();
        let auth2 = create_test_authenticator();
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default();
        let challenge = Challenge {
            nonce: vec![7u8; CHALLENGE_NONCE_SIZE],
            timestamp: Some(peat_schema::common::v1::Timestamp {
                seconds: now.as_secs() - 5,
                nanos: 0,
            }),
            challenger_id: auth1.device_id().to_hex(),
            expires_at: Some(peat_schema::common::v1::Timestamp {
                seconds: now.as_secs() + 60,
                nanos: 0,
            }),
            protocol_version: CURRENT_PROTOCOL_VERSION,
            capabilities: Vec::new(),
        };
        let response = auth2
            .respond_to_challenge(&challenge)
            .expect("respond_to_challenge succeeds (challenge not expired)");
        let peer_id = auth1
            .verify_response(&response)
            .expect("verify_response must succeed regardless of when the challenge was issued");
        assert_eq!(peer_id, auth2.device_id());
    }

    #[test]
    fn v1_v1_roundtrip_negotiates_to_current_version() {
        let auth1 = create_test_authenticator();
        let auth2 = create_test_authenticator();
        let challenge = auth1.generate_challenge();
        assert_eq!(
            challenge.protocol_version, CURRENT_PROTOCOL_VERSION,
            "generate_challenge must advertise our current version"
        );
        let response = auth2.respond_to_challenge(&challenge).unwrap();
        assert_eq!(
            response.protocol_version, CURRENT_PROTOCOL_VERSION,
            "v1 responder must negotiate to CURRENT_PROTOCOL_VERSION when peer also speaks it"
        );
        let peer_id = auth1.verify_response(&response).expect("verify v1");
        assert_eq!(peer_id, auth2.device_id());
    }

    #[test]
    fn v0_challenger_negotiates_down_to_v0_construction() {
        let auth1 = create_test_authenticator();
        let auth2 = create_test_authenticator();
        let now_secs = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        let challenge = Challenge {
            nonce: vec![3u8; CHALLENGE_NONCE_SIZE],
            timestamp: Some(peat_schema::common::v1::Timestamp {
                seconds: now_secs,
                nanos: 0,
            }),
            challenger_id: auth1.device_id().to_hex(),
            expires_at: Some(peat_schema::common::v1::Timestamp {
                seconds: now_secs + 60,
                nanos: 0,
            }),
            protocol_version: 0,
            capabilities: Vec::new(),
        };
        let response = auth2.respond_to_challenge(&challenge).unwrap();
        assert_eq!(
            response.protocol_version, 0,
            "v1 responder must negotiate down to 0 when peer advertises 0"
        );
        let peer_id = auth1
            .verify_response(&response)
            .expect("v0 roundtrip verifies cleanly");
        assert_eq!(peer_id, auth2.device_id());
    }

    #[test]
    fn v1_verifier_rejects_future_protocol_version_with_distinct_error() {
        // All-zero key/signature: the version check must fire before any crypto runs.
        let auth = create_test_authenticator();
        let response = SignedChallengeResponse {
            challenge_nonce: vec![0u8; CHALLENGE_NONCE_SIZE],
            public_key: vec![0u8; 32],
            signature: vec![0u8; 64],
            timestamp: Some(peat_schema::common::v1::Timestamp {
                seconds: 0,
                nanos: 0,
            }),
            device_type: 0,
            certificates: vec![],
            protocol_version: u32::MAX,
            capabilities: Vec::new(),
        };
        let err = auth
            .verify_response(&response)
            .expect_err("version-mismatch must surface as Err");
        match err {
            SecurityError::AuthenticationFailed(msg) => {
                assert!(
                    msg.starts_with(INCOMPATIBLE_PROTOCOL_VERSION_PREFIX),
                    "operator-visible message must start with the documented \
                     prefix for callers to detect; got: {msg}"
                );
                assert!(
                    msg.contains(&u32::MAX.to_string())
                        && msg.contains(&CURRENT_PROTOCOL_VERSION.to_string()),
                    "diagnostic must name both the peer's claimed version \
                     and our maximum; got: {msg}"
                );
            }
            other => panic!(
                "version mismatch must surface as AuthenticationFailed \
                 (distinct from InvalidSignature), not: {other:?}"
            ),
        }
    }

    #[test]
    fn v1_signature_binds_protocol_version_against_mitm_downgrade() {
        let auth1 = create_test_authenticator();
        let auth2 = create_test_authenticator();
        let challenge = auth1.generate_challenge();
        let mut response = auth2.respond_to_challenge(&challenge).unwrap();
        assert_eq!(response.protocol_version, CURRENT_PROTOCOL_VERSION);
        response.protocol_version = 0;
        let err = auth1
            .verify_response(&response)
            .expect_err("MITM-downgraded version must fail signature verification");
        assert!(
            matches!(err, SecurityError::InvalidSignature(_)),
            "downgrade attempt must surface as InvalidSignature (the bytes \
             don't match), not as a clean negotiation result; got: {err:?}"
        );
    }
}