use elara_core::{ElaraError, ElaraResult, NodeId, PacketClass, RepresentationProfile, SessionId};
use elara_crypto::{ReplayManager, ReplayWindow, SecureFrameProcessor, KEY_SIZE};
use elara_wire::Extensions;
use std::time::Instant;
/// Tunable parameters shared by the security test harnesses in this file.
#[derive(Debug, Clone)]
pub struct SecurityTestConfig {
    /// Number of timed executions recorded for each measured operation
    /// (warm-up executions are not counted).
    pub iterations: usize,
    /// Timing threshold in nanoseconds.
    ///
    /// NOTE(review): nothing visible in this file reads this field; confirm
    /// how callers are expected to use it.
    pub timing_threshold_ns: u64,
    /// Largest relative timing spread that is still accepted: it bounds the
    /// coefficient of variation of a single operation and the relative
    /// deviation of per-payload means in the encryption timing test.
    pub max_timing_variance: f64,
}

impl Default for SecurityTestConfig {
    /// Baseline configuration: 1000 iterations, a 100 ns threshold and a
    /// maximum timing variance of 0.1.
    fn default() -> Self {
        let iterations = 1000;
        let timing_threshold_ns = 100;
        let max_timing_variance = 0.1;
        Self {
            iterations,
            timing_threshold_ns,
            max_timing_variance,
        }
    }
}
/// Outcome of one security test.
#[derive(Debug, Clone)]
pub struct SecurityTestResult {
    /// `true` when the test succeeded.
    pub passed: bool,
    /// Human-readable label or summary of the test.
    pub description: String,
    /// Failure detail; `None` for results built with [`SecurityTestResult::pass`].
    pub error: Option<String>,
    /// Timing statistics, present only after [`SecurityTestResult::with_timing`].
    pub timing_stats: Option<TimingStats>,
}

impl SecurityTestResult {
    /// Builds a passing result carrying only a description.
    pub fn pass(description: impl Into<String>) -> Self {
        let description = description.into();
        Self {
            passed: true,
            description,
            error: None,
            timing_stats: None,
        }
    }

    /// Builds a failing result with a description and an error message.
    pub fn fail(description: impl Into<String>, error: impl Into<String>) -> Self {
        let description = description.into();
        let error = Some(error.into());
        Self {
            passed: false,
            description,
            error,
            timing_stats: None,
        }
    }

    /// Consumes the result and returns it with `stats` attached, replacing
    /// any previously attached statistics.
    pub fn with_timing(self, stats: TimingStats) -> Self {
        Self {
            timing_stats: Some(stats),
            ..self
        }
    }
}
/// Summary statistics over a set of timing samples, in nanoseconds.
#[derive(Debug, Clone)]
pub struct TimingStats {
    /// Smallest sample (0 when there were no samples).
    pub min_ns: u64,
    /// Largest sample (0 when there were no samples).
    pub max_ns: u64,
    /// Arithmetic mean of the samples.
    pub mean_ns: f64,
    /// Population standard deviation (variance divides by the sample count).
    pub std_dev_ns: f64,
    /// `std_dev_ns / mean_ns`, or 0.0 when the mean is not positive.
    pub coefficient_of_variation: f64,
}

impl TimingStats {
    /// Computes statistics for `measurements`.
    ///
    /// An empty slice yields all-zero statistics rather than NaN values, so
    /// the result is always safe to compare and print. The sum is accumulated
    /// in `u128`, so large samples cannot overflow.
    pub fn from_measurements(measurements: &[u64]) -> Self {
        if measurements.is_empty() {
            // Dividing by a zero count would produce NaN for mean and std-dev.
            return TimingStats {
                min_ns: 0,
                max_ns: 0,
                mean_ns: 0.0,
                std_dev_ns: 0.0,
                coefficient_of_variation: 0.0,
            };
        }

        let count = measurements.len() as f64;
        let min_ns = measurements.iter().copied().min().unwrap_or(0);
        let max_ns = measurements.iter().copied().max().unwrap_or(0);

        // A u64 accumulator can overflow (panic in debug, wrap in release)
        // for long runs of large samples; u128 cannot for any slice length.
        let total: u128 = measurements.iter().map(|&x| u128::from(x)).sum();
        let mean_ns = total as f64 / count;

        let variance = measurements
            .iter()
            .map(|&x| {
                let diff = x as f64 - mean_ns;
                diff * diff
            })
            .sum::<f64>()
            / count;
        let std_dev_ns = variance.sqrt();

        let coefficient_of_variation = if mean_ns > 0.0 {
            std_dev_ns / mean_ns
        } else {
            0.0
        };

        TimingStats {
            min_ns,
            max_ns,
            mean_ns,
            std_dev_ns,
            coefficient_of_variation,
        }
    }

    /// Returns `true` when the coefficient of variation does not exceed
    /// `max_variance`.
    pub fn is_constant_time(&self, max_variance: f64) -> bool {
        self.coefficient_of_variation <= max_variance
    }
}
pub struct ReplayProtectionTestHarness {
manager: ReplayManager,
node_id: NodeId,
class: PacketClass,
}
impl ReplayProtectionTestHarness {
pub fn new(node_id: NodeId, class: PacketClass) -> Self {
ReplayProtectionTestHarness {
manager: ReplayManager::new(),
node_id,
class,
}
}
pub fn accept(&mut self, seq: u16) -> ElaraResult<()> {
self.manager.accept(self.node_id, self.class, seq)
}
pub fn check(&self, seq: u16) -> bool {
self.manager.check(self.node_id, self.class, seq)
}
pub fn get_window(&self) -> Option<&ReplayWindow> {
self.manager.get_window(self.node_id, self.class)
}
pub fn test_replay_attack(&mut self, seq: u16) -> SecurityTestResult {
if let Err(e) = self.accept(seq) {
return SecurityTestResult::fail(
"Replay attack test",
format!("First packet acceptance failed: {:?}", e),
);
}
match self.accept(seq) {
Err(ElaraError::ReplayDetected(_)) => {
SecurityTestResult::pass("Replay attack correctly detected")
}
Ok(_) => SecurityTestResult::fail(
"Replay attack test",
"Replay attack not detected - same packet accepted twice",
),
Err(e) => SecurityTestResult::fail(
"Replay attack test",
format!("Unexpected error: {:?}", e),
),
}
}
pub fn test_window_advancement(&mut self, initial_seq: u16, jump: u16) -> SecurityTestResult {
if let Err(e) = self.accept(initial_seq) {
return SecurityTestResult::fail(
"Window advancement test",
format!("Initial packet acceptance failed: {:?}", e),
);
}
let window_before = self.get_window().map(|w| w.min_seq());
let new_seq = initial_seq.wrapping_add(jump);
if let Err(e) = self.accept(new_seq) {
return SecurityTestResult::fail(
"Window advancement test",
format!("Jump packet acceptance failed: {:?}", e),
);
}
let window_after = self.get_window().map(|w| w.min_seq());
match (window_before, window_after) {
(Some(before), Some(after)) => {
if jump >= self.class.replay_window_size() {
if after > before {
SecurityTestResult::pass("Replay window advanced correctly")
} else {
SecurityTestResult::fail(
"Window advancement test",
format!("Window did not advance: before={}, after={}", before, after),
)
}
} else {
SecurityTestResult::pass("Replay window handled small jump correctly")
}
}
_ => SecurityTestResult::fail(
"Window advancement test",
"Could not retrieve window state",
),
}
}
pub fn test_sequence_wraparound(&mut self) -> SecurityTestResult {
let mut current = 0u16;
let step = 10000u16;
while current < 65000 {
if let Err(e) = self.accept(current) {
return SecurityTestResult::fail(
"Sequence wraparound test",
format!("Failed to advance window at {}: {:?}", current, e),
);
}
current = current.saturating_add(step);
}
for seq in 65530..=65535 {
if let Err(e) = self.accept(seq) {
return SecurityTestResult::fail(
"Sequence wraparound test",
format!("Failed to accept sequence {}: {:?}", seq, e),
);
}
}
for seq in 0..5 {
if let Err(e) = self.accept(seq) {
return SecurityTestResult::fail(
"Sequence wraparound test",
format!("Failed to accept wrapped sequence {}: {:?}", seq, e),
);
}
}
SecurityTestResult::pass("Sequence number wraparound handled correctly")
}
}
/// Test harness that checks a [`SecureFrameProcessor`] rejects modified or
/// truncated frames.
pub struct MessageAuthenticationTestHarness {
    /// Processor used for both encryption and decryption.
    processor: SecureFrameProcessor,
    /// Session identifier the processor was created with (kept, not read).
    _session_id: SessionId,
    /// Node identifier the processor was created with (kept, not read).
    _node_id: NodeId,
}

impl MessageAuthenticationTestHarness {
    /// Creates a harness whose processor is built from `session_id`,
    /// `node_id` and `session_key`.
    pub fn new(session_id: SessionId, node_id: NodeId, session_key: [u8; KEY_SIZE]) -> Self {
        MessageAuthenticationTestHarness {
            processor: SecureFrameProcessor::new(session_id, node_id, session_key),
            _session_id: session_id,
            _node_id: node_id,
        }
    }

    /// Encrypts `payload` as a frame of the given `class`, using the
    /// `Textual` representation profile, a literal `0` for the third
    /// `encrypt_frame` argument and default extensions.
    ///
    /// # Errors
    /// Returns whatever error `SecureFrameProcessor::encrypt_frame` reports.
    pub fn encrypt_message(
        &mut self,
        class: PacketClass,
        payload: &[u8],
    ) -> ElaraResult<Vec<u8>> {
        self.processor.encrypt_frame(
            class,
            RepresentationProfile::Textual,
            0,
            Extensions::default(),
            payload,
        )
    }

    /// Decrypts `data` and returns a copy of the decrypted payload.
    ///
    /// # Errors
    /// Returns whatever error `SecureFrameProcessor::decrypt_frame` reports.
    pub fn decrypt_message(&mut self, data: &[u8]) -> ElaraResult<Vec<u8>> {
        let decrypted = self.processor.decrypt_frame(data)?;
        Ok(decrypted.payload.clone())
    }

    /// Encrypts `payload`, flips one bit in the middle of the frame and
    /// passes only if decryption of the modified frame fails.
    ///
    /// Frames of 20 bytes or fewer are reported as a failure without being
    /// decrypted: previously such frames were decrypted unmodified, so the
    /// verdict said nothing about tamper detection.
    pub fn test_message_tampering(&mut self, payload: &[u8]) -> SecurityTestResult {
        const NAME: &str = "Message tampering test";
        // Frames must be longer than this before a byte is flipped.
        const MIN_TAMPERABLE_LEN: usize = 20;

        // The ciphertext is only needed in its tampered form, so mutate it in place.
        let mut tampered = match self.encrypt_message(PacketClass::Core, payload) {
            Ok(data) => data,
            Err(e) => {
                return SecurityTestResult::fail(NAME, format!("Encryption failed: {:?}", e));
            }
        };

        if tampered.len() <= MIN_TAMPERABLE_LEN {
            return SecurityTestResult::fail(
                NAME,
                format!(
                    "Encrypted message too short to tamper with ({} bytes)",
                    tampered.len()
                ),
            );
        }

        let idx = tampered.len() / 2;
        tampered[idx] ^= 0x01;

        match self.decrypt_message(&tampered) {
            Err(_) => SecurityTestResult::pass("Message tampering correctly detected"),
            Ok(_) => SecurityTestResult::fail(
                NAME,
                "Tampered message was accepted - authentication failed",
            ),
        }
    }

    /// Encrypts `payload`, removes the final 16 bytes of the frame and
    /// passes only if decryption of the shortened frame fails.
    ///
    /// NOTE(review): the 16-byte tail is presumably the authentication tag;
    /// confirm the tag length against `elara_crypto`.
    pub fn test_mac_verification(&mut self, payload: &[u8]) -> SecurityTestResult {
        const NAME: &str = "MAC verification test";
        // Number of trailing bytes removed from the frame.
        const STRIPPED_TAIL_LEN: usize = 16;

        let encrypted = match self.encrypt_message(PacketClass::Core, payload) {
            Ok(data) => data,
            Err(e) => {
                return SecurityTestResult::fail(NAME, format!("Encryption failed: {:?}", e));
            }
        };

        if encrypted.len() < STRIPPED_TAIL_LEN {
            return SecurityTestResult::fail(NAME, "Encrypted message too short");
        }

        let truncated = &encrypted[..encrypted.len() - STRIPPED_TAIL_LEN];
        match self.decrypt_message(truncated) {
            Err(_) => SecurityTestResult::pass("MAC verification correctly enforced"),
            Ok(_) => SecurityTestResult::fail(NAME, "Message without MAC was accepted"),
        }
    }
}
/// Test harness holding several independently keyed
/// [`SecureFrameProcessor`]s.
pub struct KeyIsolationTestHarness {
    /// One processor per simulated session, in creation order.
    processors: Vec<SecureFrameProcessor>,
}

impl KeyIsolationTestHarness {
    /// Creates `num_sessions` processors. Processor `i` uses session id `i`,
    /// node id `i` and a key whose every byte is `i as u8` (so the key byte
    /// repeats once `i` exceeds 255).
    pub fn new(num_sessions: usize) -> Self {
        let processors = (0..num_sessions)
            .map(|index| {
                let session_id = SessionId::new(index as u64);
                let node_id = NodeId::new(index as u64);
                let session_key = [index as u8; KEY_SIZE];
                SecureFrameProcessor::new(session_id, node_id, session_key)
            })
            .collect();
        Self { processors }
    }

    /// Encrypts `payload` with the first processor and passes only if the
    /// second processor fails to decrypt the resulting frame.
    ///
    /// Fails immediately when fewer than two processors exist.
    pub fn test_session_key_isolation(&mut self, payload: &[u8]) -> SecurityTestResult {
        const NAME: &str = "Session key isolation test";

        let (sender, receiver) = match self.processors.as_mut_slice() {
            [first, second, ..] => (first, second),
            _ => {
                return SecurityTestResult::fail(
                    NAME,
                    "Need at least 2 sessions for isolation test",
                );
            }
        };

        let frame = sender.encrypt_frame(
            PacketClass::Core,
            RepresentationProfile::Textual,
            0,
            Extensions::default(),
            payload,
        );
        let encrypted = match frame {
            Ok(data) => data,
            Err(e) => {
                return SecurityTestResult::fail(NAME, format!("Encryption failed: {:?}", e));
            }
        };

        if receiver.decrypt_frame(&encrypted).is_err() {
            SecurityTestResult::pass("Session keys are properly isolated")
        } else {
            SecurityTestResult::fail(
                NAME,
                "Message encrypted with one session key was decrypted with another - key isolation failed",
            )
        }
    }

    /// Encrypts `payload` five times with the first processor and passes
    /// only if all five frames differ from one another.
    ///
    /// Fails immediately when no processor exists, and stops at the first
    /// encryption error.
    pub fn test_key_derivation_independence(&mut self, payload: &[u8]) -> SecurityTestResult {
        const NAME: &str = "Key derivation independence test";
        const ROUNDS: usize = 5;

        let sender = match self.processors.first_mut() {
            Some(processor) => processor,
            None => return SecurityTestResult::fail(NAME, "Need at least 1 session"),
        };

        // Collecting into a `Result` stops at the first failed encryption.
        let frames = (0..ROUNDS)
            .map(|_| {
                sender.encrypt_frame(
                    PacketClass::Core,
                    RepresentationProfile::Textual,
                    0,
                    Extensions::default(),
                    payload,
                )
            })
            .collect::<Result<Vec<_>, _>>();
        let ciphertexts = match frames {
            Ok(all) => all,
            Err(e) => {
                return SecurityTestResult::fail(NAME, format!("Encryption failed: {:?}", e));
            }
        };

        // Compare every frame with each frame produced after it.
        let has_duplicate = ciphertexts.iter().enumerate().any(|(position, earlier)| {
            ciphertexts[position + 1..]
                .iter()
                .any(|later| later == earlier)
        });

        if has_duplicate {
            SecurityTestResult::fail(
                NAME,
                "Multiple encryptions produced identical ciphertext - key derivation not independent",
            )
        } else {
            SecurityTestResult::pass("Key derivation is independent across messages")
        }
    }
}
/// Test harness that times repeated executions of an operation to look for
/// timing variation.
pub struct TimingAttackTestHarness {
    /// Supplies the iteration count and the tolerated timing variance.
    config: SecurityTestConfig,
}

impl TimingAttackTestHarness {
    /// Creates a harness that measures according to `config`.
    pub fn new(config: SecurityTestConfig) -> Self {
        TimingAttackTestHarness { config }
    }

    /// Runs `operation` ten times untimed as warm-up, then
    /// `config.iterations` more times, returning the elapsed nanoseconds of
    /// each timed run.
    fn measure_operation<F>(&self, mut operation: F) -> Vec<u64>
    where
        F: FnMut(),
    {
        const WARMUP_RUNS: usize = 10;

        for _ in 0..WARMUP_RUNS {
            operation();
        }

        let mut measurements = Vec::with_capacity(self.config.iterations);
        for _ in 0..self.config.iterations {
            let start = Instant::now();
            operation();
            // `as_nanos` is u128; clamp to u64::MAX instead of silently truncating.
            let elapsed = start.elapsed().as_nanos().min(u128::from(u64::MAX)) as u64;
            measurements.push(elapsed);
        }
        measurements
    }

    /// Times `operation` and passes when the coefficient of variation of the
    /// samples does not exceed `config.max_timing_variance`. The computed
    /// statistics are attached to the result.
    ///
    /// Fails without running `operation` when `config.iterations` is zero,
    /// because no samples would be collected and any verdict would be
    /// meaningless.
    pub fn test_constant_time_operation<F>(
        &self,
        description: &str,
        operation: F,
    ) -> SecurityTestResult
    where
        F: FnMut(),
    {
        if self.config.iterations == 0 {
            return SecurityTestResult::fail(
                description,
                "No timing measurements collected (iterations is 0)",
            );
        }

        let measurements = self.measure_operation(operation);
        let stats = TimingStats::from_measurements(&measurements);
        if stats.is_constant_time(self.config.max_timing_variance) {
            SecurityTestResult::pass(format!(
                "{} is constant-time (CV: {:.4})",
                description, stats.coefficient_of_variation
            ))
            .with_timing(stats)
        } else {
            SecurityTestResult::fail(
                description,
                format!(
                    "Operation is not constant-time (CV: {:.4} > {:.4})",
                    stats.coefficient_of_variation, self.config.max_timing_variance
                ),
            )
            .with_timing(stats)
        }
    }

    /// Times encryption of each payload and passes when no payload's mean
    /// time deviates from the mean of all means by more than
    /// `config.max_timing_variance` (relative deviation).
    ///
    /// Fails when fewer than two payloads are given, when
    /// `config.iterations` is zero, or when the overall mean is not a
    /// positive number. In the last case the relative deviation would be NaN,
    /// which the maximum fold would silently drop.
    pub fn test_encryption_timing(
        &self,
        processor: &mut SecureFrameProcessor,
        payloads: &[&[u8]],
    ) -> SecurityTestResult {
        const NAME: &str = "Encryption timing test";

        if payloads.len() < 2 {
            return SecurityTestResult::fail(NAME, "Need at least 2 different payloads");
        }
        if self.config.iterations == 0 {
            return SecurityTestResult::fail(
                NAME,
                "No timing measurements collected (iterations is 0)",
            );
        }

        let mut stats: Vec<TimingStats> = Vec::with_capacity(payloads.len());
        for payload in payloads {
            let measurements = self.measure_operation(|| {
                // The encryption result is discarded; only elapsed time is recorded.
                let _ = processor.encrypt_frame(
                    PacketClass::Core,
                    RepresentationProfile::Textual,
                    0,
                    Extensions::default(),
                    payload,
                );
            });
            stats.push(TimingStats::from_measurements(&measurements));
        }

        let mean_of_means =
            stats.iter().map(|s| s.mean_ns).sum::<f64>() / stats.len() as f64;

        // Written as a negated comparison so that NaN is rejected as well as zero.
        if !(mean_of_means > 0.0) {
            return SecurityTestResult::fail(
                NAME,
                "Mean encryption time is zero or undefined - timing comparison is not meaningful",
            );
        }

        let max_deviation = stats
            .iter()
            .map(|s| ((s.mean_ns - mean_of_means) / mean_of_means).abs())
            .fold(0.0f64, f64::max);

        if max_deviation <= self.config.max_timing_variance {
            SecurityTestResult::pass(format!(
                "Encryption timing is consistent across payloads (max deviation: {:.4})",
                max_deviation
            ))
        } else {
            SecurityTestResult::fail(
                NAME,
                format!(
                    "Encryption timing varies significantly across payloads (max deviation: {:.4} > {:.4})",
                    max_deviation, self.config.max_timing_variance
                ),
            )
        }
    }
}
/// Runs the built-in security tests and collects their results.
pub struct SecurityTestSuite {
    /// Configuration supplied at construction (stored, not read by the suite).
    _config: SecurityTestConfig,
    /// Results in the order they were recorded.
    results: Vec<SecurityTestResult>,
}

impl SecurityTestSuite {
    /// Creates a suite with no recorded results.
    pub fn new(config: SecurityTestConfig) -> Self {
        let results = Vec::new();
        Self {
            _config: config,
            results,
        }
    }

    /// Appends `result` to the recorded results.
    pub fn add_result(&mut self, result: SecurityTestResult) {
        self.results.push(result);
    }

    /// Runs the replay-protection, message-authentication and key-isolation
    /// groups, in that order.
    pub fn run_all_tests(&mut self) {
        self.run_replay_protection_tests();
        self.run_message_authentication_tests();
        self.run_key_isolation_tests();
    }

    /// Records the replay attack, window advancement and sequence wraparound
    /// results, each produced by its own fresh harness for node 1 and the
    /// `Core` packet class.
    fn run_replay_protection_tests(&mut self) {
        let node_id = NodeId::new(1);
        let class = PacketClass::Core;
        let fresh_harness = || ReplayProtectionTestHarness::new(node_id, class);

        let replay = fresh_harness().test_replay_attack(100);
        self.add_result(replay);

        let advancement = fresh_harness().test_window_advancement(0, 20);
        self.add_result(advancement);

        let wraparound = fresh_harness().test_sequence_wraparound();
        self.add_result(wraparound);
    }

    /// Records the tampering and MAC verification results, each produced by
    /// its own fresh harness built from the same session, node and key.
    fn run_message_authentication_tests(&mut self) {
        let session_id = SessionId::new(1);
        let node_id = NodeId::new(1);
        let session_key = [0x42; KEY_SIZE];
        let test_payload = b"Test message for authentication";

        let tampering = MessageAuthenticationTestHarness::new(session_id, node_id, session_key)
            .test_message_tampering(test_payload);
        self.add_result(tampering);

        let mac = MessageAuthenticationTestHarness::new(session_id, node_id, session_key)
            .test_mac_verification(test_payload);
        self.add_result(mac);
    }

    /// Records the session key isolation and key derivation independence
    /// results from one shared three-session harness.
    fn run_key_isolation_tests(&mut self) {
        let test_payload = b"Test message for key isolation";
        let mut harness = KeyIsolationTestHarness::new(3);

        let isolation = harness.test_session_key_isolation(test_payload);
        self.add_result(isolation);

        let independence = harness.test_key_derivation_independence(test_payload);
        self.add_result(independence);
    }

    /// Returns every recorded result in insertion order.
    pub fn results(&self) -> &[SecurityTestResult] {
        self.results.as_slice()
    }

    /// Returns `true` when no recorded result failed (also `true` when
    /// nothing has been recorded).
    pub fn all_passed(&self) -> bool {
        !self.results.iter().any(|result| !result.passed)
    }

    /// Counts total, passed and failed results.
    pub fn summary(&self) -> SecurityTestSummary {
        let total = self.results.len();
        let failed = self.results.iter().filter(|result| !result.passed).count();
        let passed = total - failed;
        SecurityTestSummary {
            total,
            passed,
            failed,
        }
    }
}
/// Aggregate counts over a set of security test results.
#[derive(Debug, Clone)]
pub struct SecurityTestSummary {
    /// Number of results counted.
    pub total: usize,
    /// Number of results that passed.
    pub passed: usize,
    /// Number of results that failed.
    pub failed: usize,
}

impl SecurityTestSummary {
    /// Fraction of results that passed, as `passed / total`; 0.0 when
    /// `total` is zero.
    pub fn success_rate(&self) -> f64 {
        match self.total {
            0 => 0.0,
            total => self.passed as f64 / total as f64,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimum, maximum and mean are derived from a small sample set.
    #[test]
    fn test_timing_stats_calculation() {
        let samples: [u64; 10] = [100, 102, 98, 101, 99, 103, 97, 100, 102, 98];
        let TimingStats {
            min_ns,
            max_ns,
            mean_ns,
            ..
        } = TimingStats::from_measurements(&samples);
        assert_eq!(min_ns, 97);
        assert_eq!(max_ns, 103);
        assert!((mean_ns - 100.0).abs() < 1.0);
    }

    /// A sequence number is accepted once and rejected the second time.
    #[test]
    fn test_replay_protection_harness() {
        let mut harness = ReplayProtectionTestHarness::new(NodeId::new(1), PacketClass::Core);
        let first_delivery = harness.accept(0);
        assert!(first_delivery.is_ok());
        let replayed_delivery = harness.accept(0);
        assert!(replayed_delivery.is_err());
    }

    /// Running the suite records results, at least one of which passes.
    #[test]
    fn test_security_test_suite() {
        let mut suite = SecurityTestSuite::new(SecurityTestConfig::default());
        suite.run_all_tests();
        let SecurityTestSummary { total, passed, .. } = suite.summary();
        assert!(total > 0);
        assert!(passed > 0);
    }
}