use std::sync::OnceLock;
use log::{debug, warn};
use xtee_psk::{PskError, PskResult};
/// Trust-on-first-use (TOFU) pin for the server's long-term public key.
///
/// The first key passed to [`ServerIdentityCache::verify`] is recorded. Every later call must
/// present exactly the same bytes, or verification fails.
pub(crate) struct ServerIdentityCache {
    /// The pinned long-term key; written at most once.
    key: OnceLock<Vec<u8>>,
}

impl ServerIdentityCache {
    /// Creates a cache with no key pinned yet.
    ///
    /// This is a `const fn` so it can initialise a `static`.
    pub(crate) const fn new() -> Self {
        Self {
            key: OnceLock::new(),
        }
    }

    /// Checks `long_term_point` against the pinned server key.
    ///
    /// On the first call the key is pinned (TOFU) and `Ok(())` is returned. On every later
    /// call `long_term_point` must be byte-identical to the pinned key.
    ///
    /// # Errors
    ///
    /// Returns [`PskError::VerificationFailed`] when a key is already pinned and
    /// `long_term_point` differs from it.
    pub(crate) fn verify(&self, long_term_point: &[u8]) -> PskResult<()> {
        // Record whether *this* call performed the initialisation, so the TOFU event is
        // logged once. Unlike `set`, which needs an owned `Vec` on every call,
        // `get_or_init` copies the slice only when the lock is still empty.
        let mut pinned_now = false;
        let stored = self.key.get_or_init(|| {
            pinned_now = true;
            long_term_point.to_vec()
        });

        if pinned_now {
            debug!("首次连接,已记录服务端长期公钥(TOFU)");
            return Ok(());
        }

        if stored.as_slice() != long_term_point {
            warn!("服务端长期公钥与已记录的不一致,可能存在 MITM 攻击");
            return Err(PskError::VerificationFailed);
        }
        Ok(())
    }
}
/// Process-wide TOFU cache for the server's long-term public key.
static SERVER_IDENTITY: ServerIdentityCache = ServerIdentityCache::new();

/// Verifies `long_term_point` against the process-wide pinned server identity.
///
/// Delegates to [`ServerIdentityCache::verify`] on [`SERVER_IDENTITY`]. The first key seen in
/// the process is pinned, and every later key must match it.
///
/// # Errors
///
/// Propagates the error from [`ServerIdentityCache::verify`] when the key does not match the
/// pinned one.
pub(crate) fn verify_server_identity(long_term_point: &[u8]) -> PskResult<()> {
    let cache: &ServerIdentityCache = &SERVER_IDENTITY;
    cache.verify(long_term_point)
}
#[cfg(test)]
mod psk_tests {
    use super::*;
    use xtee_psk::SM2_POINT_LEN;

    /// Returns a copy of `point` with its first byte flipped: a different key of the same
    /// length.
    fn tampered(point: &[u8]) -> Vec<u8> {
        let mut out = point.to_vec();
        out[0] ^= 0xFF;
        out
    }

    #[test]
    fn test_server_identity_cache_new_instance() {
        let cache = ServerIdentityCache::new();
        let test_point = vec![0xDEu8; SM2_POINT_LEN];
        cache.verify(&test_point).expect("首次 TOFU 设置应该成功");
        cache.verify(&test_point).expect("相同身份应该通过");
        assert!(
            cache.verify(&tampered(&test_point)).is_err(),
            "篡改的公钥应该被 TOFU 校验拒绝"
        );
        // A rejected key must not replace the pinned one.
        cache.verify(&test_point).expect("相同身份应该通过");
    }

    #[test]
    fn test_verify_server_identity_global() {
        // The global cache is shared with every other test in the crate and may already hold
        // a different key. Checking `get()` first and then pinning is racy. Instead, offer a
        // candidate (ignoring the outcome) and then test against whatever is actually pinned.
        // Once set, that value can no longer change.
        let candidate = vec![0xDEu8; SM2_POINT_LEN];
        let _ = verify_server_identity(&candidate);
        let pinned = SERVER_IDENTITY
            .key
            .get()
            .expect("identity must be pinned after verify_server_identity")
            .clone();

        verify_server_identity(&pinned).expect("已有身份应该通过");
        assert!(
            verify_server_identity(&tampered(&pinned)).is_err(),
            "篡改的公钥应该被 TOFU 校验拒绝"
        );
    }
}