use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
/// Security state tracked for a device: a bundle-version threshold, per-module
/// signature-time records, and two timestamps.
///
/// Timestamps are plain `u64` values (presumably Unix seconds — confirm with callers).
/// When deserializing, `last_verification_time` and `module_versions` may be absent and
/// fall back to `None` / an empty map.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct DeviceSecurityState {
    /// Bundle version threshold; `check_bundle_version` rejects anything lower.
    pub bundle_version: u32,
    /// Time of the most recently recorded verification, if one has been recorded.
    #[serde(default)]
    pub last_verification_time: Option<u64>,
    /// Firmware build timestamp supplied when the state was created.
    pub firmware_build_time: u64,
    /// Per-module records keyed by module identifier (ordered map, so serialization
    /// order is deterministic).
    #[serde(default)]
    pub module_versions: BTreeMap<String, ModuleVersionInfo>,
}
impl DeviceSecurityState {
    /// Creates a fresh state: bundle version 0, no recorded verification time and no
    /// tracked modules. `firmware_build_time` is stored as given.
    pub fn new(firmware_build_time: u64) -> Self {
        Self {
            bundle_version: 0,
            last_verification_time: None,
            firmware_build_time,
            module_versions: BTreeMap::new(),
        }
    }

    /// Creates a fresh state using `crate::time::BUILD_TIMESTAMP` as the firmware build
    /// time.
    pub fn with_build_timestamp() -> Self {
        Self::new(crate::time::BUILD_TIMESTAMP)
    }

    /// Returns `true` if `version` is not older than the stored bundle version.
    ///
    /// Equal versions are accepted, so re-checking the currently recorded bundle passes.
    #[must_use]
    pub fn check_bundle_version(&self, version: u32) -> bool {
        version >= self.bundle_version
    }

    /// Raises the stored bundle version to `version`.
    ///
    /// Lower or equal values are ignored, so the stored version never decreases.
    pub fn update_bundle_version(&mut self, version: u32) {
        if version > self.bundle_version {
            self.bundle_version = version;
        }
    }

    /// Records `time` as the last verification time, overwriting any previous value.
    // NOTE(review): unlike the version updaters this accepts an earlier `time` than the one
    // already stored; confirm whether that is intended.
    pub fn update_verification_time(&mut self, time: u64) {
        self.last_verification_time = Some(time);
    }

    /// Records `info` as the tracked version of `module_id`.
    ///
    /// If a record already exists with a strictly newer `signature_time`, the update is
    /// ignored. The threshold enforced by [`Self::check_module_version`] therefore never
    /// moves backwards, mirroring [`Self::update_bundle_version`]. A record with an equal
    /// `signature_time` replaces the existing one (e.g. to refresh the hash or the version
    /// label).
    pub fn update_module_version(&mut self, module_id: &str, info: ModuleVersionInfo) {
        if let Some(existing) = self.module_versions.get(module_id) {
            // Storing an older signature time would let rolled-back modules pass the check.
            if info.signature_time < existing.signature_time {
                return;
            }
        }
        self.module_versions.insert(module_id.to_string(), info);
    }

    /// Returns `true` if a module signed at `signature_time` is acceptable for `module_id`.
    ///
    /// Unknown modules are always accepted. Known modules require `signature_time` to be at
    /// least the recorded signature time.
    #[must_use]
    pub fn check_module_version(&self, module_id: &str, signature_time: u64) -> bool {
        match self.module_versions.get(module_id) {
            Some(info) => signature_time >= info.signature_time,
            None => true,
        }
    }

    /// Serializes the state as pretty-printed JSON bytes.
    ///
    /// # Errors
    ///
    /// Returns `WSError::InternalError` if serialization fails.
    pub fn to_json(&self) -> Result<Vec<u8>, crate::error::WSError> {
        serde_json::to_vec_pretty(self).map_err(|e| {
            crate::error::WSError::InternalError(format!("Failed to serialize state: {}", e))
        })
    }

    /// Parses a state from JSON bytes.
    ///
    /// # Errors
    ///
    /// Returns `WSError::InternalError` if `data` is not valid JSON for this type.
    pub fn from_json(data: &[u8]) -> Result<Self, crate::error::WSError> {
        serde_json::from_slice(data).map_err(|e| {
            crate::error::WSError::InternalError(format!("Failed to parse state: {}", e))
        })
    }
}
/// Version record for a single module, as stored in `DeviceSecurityState::module_versions`.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct ModuleVersionInfo {
    /// Optional human-readable version label; may be absent when deserializing.
    #[serde(default)]
    pub version: Option<String>,
    /// Signature time that acts as this module's acceptance threshold.
    pub signature_time: u64,
    /// Module hash as a hex string (`ModuleVersionInfo::new` hex-encodes raw bytes).
    pub module_hash: String,
}
impl ModuleVersionInfo {
    /// Builds a record without a version label, storing `module_hash` hex-encoded.
    pub fn new(signature_time: u64, module_hash: &[u8]) -> Self {
        let encoded_hash = hex::encode(module_hash);
        Self {
            signature_time,
            module_hash: encoded_hash,
            version: None,
        }
    }

    /// Builder-style setter that attaches (or replaces) the version label.
    pub fn with_version(self, version: &str) -> Self {
        Self {
            version: Some(version.to_owned()),
            ..self
        }
    }
}
/// Storage backend able to load and save a [`DeviceSecurityState`].
///
/// Both methods take `&self`, so an implementation that changes stored data needs
/// interior mutability.
pub trait SecureStorage {
    /// Returns the stored state.
    ///
    /// # Errors
    ///
    /// Returns a `WSError` when no state can be produced; the exact conditions are
    /// implementation-defined.
    fn load_state(&self) -> Result<DeviceSecurityState, crate::error::WSError>;
    /// Persists `state`.
    ///
    /// # Errors
    ///
    /// Returns a `WSError` if the state cannot be stored; conditions are
    /// implementation-defined.
    fn save_state(&self, state: &DeviceSecurityState) -> Result<(), crate::error::WSError>;
}
/// Non-persistent storage that keeps the state in memory behind an `RwLock`.
///
/// Saved data lives only as long as this value does.
#[derive(Default)]
#[derive(Debug)]
pub struct MemoryStorage {
    /// The last saved state, or `None` if nothing has been stored yet.
    state: std::sync::RwLock<Option<DeviceSecurityState>>,
}
impl MemoryStorage {
    /// Creates an empty storage; nothing is stored until `save_state` is called.
    pub fn new() -> Self {
        Self {
            state: std::sync::RwLock::new(None),
        }
    }

    /// Creates a storage pre-populated with `state`.
    pub fn with_state(state: DeviceSecurityState) -> Self {
        let initial = Some(state);
        Self {
            state: std::sync::RwLock::new(initial),
        }
    }
}
impl SecureStorage for MemoryStorage {
fn load_state(&self) -> Result<DeviceSecurityState, crate::error::WSError> {
self.state
.read()
.map_err(|_| crate::error::WSError::InternalError("Lock poisoned".to_string()))?
.clone()
.ok_or_else(|| crate::error::WSError::InternalError("No state stored".to_string()))
}
fn save_state(&self, state: &DeviceSecurityState) -> Result<(), crate::error::WSError> {
let mut guard = self
.state
.write()
.map_err(|_| crate::error::WSError::InternalError("Lock poisoned".to_string()))?;
*guard = Some(state.clone());
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Firmware build time shared by the tests (2024-01-01T00:00:00Z as Unix seconds).
    const BUILD_TIME: u64 = 1704067200;

    #[test]
    fn test_device_state_creation() {
        let state = DeviceSecurityState::new(BUILD_TIME);
        assert_eq!(state.bundle_version, 0);
        assert!(state.last_verification_time.is_none());
        assert_eq!(state.firmware_build_time, BUILD_TIME);
        assert!(state.module_versions.is_empty());
    }

    #[test]
    fn test_bundle_version_check() {
        let mut state = DeviceSecurityState::new(BUILD_TIME);
        state.bundle_version = 5;
        // Equal and newer versions pass; older ones are rejected.
        assert!(state.check_bundle_version(5));
        assert!(state.check_bundle_version(6));
        assert!(!state.check_bundle_version(4));
    }

    #[test]
    fn test_bundle_version_update() {
        let mut state = DeviceSecurityState::new(BUILD_TIME);
        state.update_bundle_version(5);
        assert_eq!(state.bundle_version, 5);
        // A lower version must not roll the stored version back.
        state.update_bundle_version(3);
        assert_eq!(state.bundle_version, 5);
        state.update_bundle_version(7);
        assert_eq!(state.bundle_version, 7);
    }

    #[test]
    fn test_module_version_tracking() {
        let mut state = DeviceSecurityState::new(BUILD_TIME);
        // Unknown modules are accepted.
        assert!(state.check_module_version("my-module", 1000));
        state.update_module_version("my-module", ModuleVersionInfo::new(1000, &[0u8; 32]));
        assert!(state.check_module_version("my-module", 1000));
        assert!(state.check_module_version("my-module", 2000));
        assert!(!state.check_module_version("my-module", 500));
    }

    #[test]
    fn test_state_json_roundtrip() {
        let mut state = DeviceSecurityState::new(BUILD_TIME);
        state.bundle_version = 42;
        state.update_verification_time(1704100000);
        state.update_module_version(
            "my-module",
            ModuleVersionInfo::new(1000, &[0xabu8; 4]).with_version("1.2.3"),
        );

        let json = state.to_json().unwrap();
        let parsed = DeviceSecurityState::from_json(&json).unwrap();

        assert_eq!(parsed.bundle_version, 42);
        assert_eq!(parsed.last_verification_time, Some(1704100000));
        assert_eq!(parsed.firmware_build_time, BUILD_TIME);
        let module = parsed
            .module_versions
            .get("my-module")
            .expect("module record should survive the roundtrip");
        assert_eq!(module.signature_time, 1000);
        assert_eq!(module.module_hash, "abababab");
        assert_eq!(module.version.as_deref(), Some("1.2.3"));
    }

    #[test]
    fn test_state_json_optional_fields_default() {
        let parsed =
            DeviceSecurityState::from_json(br#"{"bundle_version": 3, "firmware_build_time": 10}"#)
                .unwrap();
        assert_eq!(parsed.bundle_version, 3);
        assert_eq!(parsed.firmware_build_time, 10);
        assert!(parsed.last_verification_time.is_none());
        assert!(parsed.module_versions.is_empty());
    }

    #[test]
    fn test_state_json_invalid_input() {
        assert!(DeviceSecurityState::from_json(b"not json").is_err());
    }

    #[test]
    fn test_memory_storage() {
        let storage = MemoryStorage::new();
        assert!(storage.load_state().is_err());

        let state = DeviceSecurityState::new(BUILD_TIME);
        storage.save_state(&state).unwrap();
        let loaded = storage.load_state().unwrap();
        assert_eq!(loaded.firmware_build_time, BUILD_TIME);

        // A second save replaces the first.
        let mut newer = state.clone();
        newer.bundle_version = 9;
        storage.save_state(&newer).unwrap();
        assert_eq!(storage.load_state().unwrap().bundle_version, 9);
    }

    #[test]
    fn test_memory_storage_with_state() {
        let storage = MemoryStorage::with_state(DeviceSecurityState::new(BUILD_TIME));
        assert_eq!(storage.load_state().unwrap().firmware_build_time, BUILD_TIME);
    }
}