use anyhow::{anyhow, bail, Result};
use std::collections::{HashMap, HashSet};
/// Runtime security state: the active [`SecurityPolicy`], optional per-feature
/// trackers, and a log of recorded violations.
#[derive(Debug)]
pub struct SecurityContext {
    /// Policy that gates which `enable_*` methods succeed and whether recording
    /// a violation panics.
    policy: SecurityPolicy,
    /// CFI tracker; `None` until `enable_cfi` succeeds.
    cfi_state: Option<CfiState>,
    /// Stack-canary tracker; `None` until `enable_stack_canaries` succeeds.
    canary_state: Option<CanaryState>,
    /// ASLR offsets; `None` until `enable_aslr` succeeds.
    aslr_state: Option<AslrState>,
    /// Capability registry; `None` until `enable_attestation` succeeds.
    attestation_state: Option<AttestationState>,
    /// Violations recorded via `record_violation`, in the order they occurred.
    violations: Vec<SecurityViolation>,
}
/// Switches that decide which protections a `SecurityContext` may turn on and
/// how recorded violations are handled.
#[derive(Debug, Clone)]
pub struct SecurityPolicy {
    /// Permit `SecurityContext::enable_cfi`.
    pub enable_cfi: bool,
    /// Permit `SecurityContext::enable_stack_canaries`.
    pub enable_stack_canaries: bool,
    /// Permit `SecurityContext::enable_aslr`.
    pub enable_aslr: bool,
    /// Permit `SecurityContext::enable_attestation`.
    pub enable_attestation: bool,
    /// Panic immediately after a violation has been recorded.
    pub panic_on_violation: bool,
    /// Configured violation budget.
    /// NOTE(review): nothing in this file reads this field — confirm who enforces it.
    pub max_violations: usize,
}

impl SecurityPolicy {
    /// Everything on, panic on the first violation, `max_violations` of 0.
    pub fn strict() -> Self {
        Self {
            enable_aslr: true,
            enable_attestation: true,
            panic_on_violation: true,
            max_violations: 0,
            ..Self::standard()
        }
    }

    /// CFI and stack canaries on; ASLR, attestation and panicking off;
    /// `max_violations` of 10.
    pub fn standard() -> Self {
        Self {
            enable_cfi: true,
            enable_stack_canaries: true,
            max_violations: 10,
            ..Self::permissive()
        }
    }

    /// Every protection off, never panics, `max_violations` of 100.
    pub fn permissive() -> Self {
        Self {
            enable_cfi: false,
            enable_stack_canaries: false,
            enable_aslr: false,
            enable_attestation: false,
            panic_on_violation: false,
            max_violations: 100,
        }
    }
}

impl Default for SecurityPolicy {
    /// Equivalent to [`SecurityPolicy::standard`].
    fn default() -> Self {
        SecurityPolicy::standard()
    }
}
#[derive(Debug)]
pub struct CfiState {
valid_call_targets: HashSet<u32>,
call_depth: usize,
max_call_depth: usize,
violations: usize,
shadow_stack: Vec<u32>,
}
impl CfiState {
pub fn new(max_call_depth: usize) -> Self {
Self {
valid_call_targets: HashSet::new(),
call_depth: 0,
max_call_depth,
violations: 0,
shadow_stack: Vec::with_capacity(max_call_depth),
}
}
pub fn add_call_target(&mut self, address: u32) {
self.valid_call_targets.insert(address);
}
pub fn validate_call(&mut self, target: u32) -> Result<()> {
if !self.valid_call_targets.contains(&target) {
self.violations += 1;
bail!(
"CFI violation: Invalid call target 0x{:08x} (not in valid target set)",
target
);
}
if self.call_depth >= self.max_call_depth {
self.violations += 1;
bail!(
"CFI violation: Maximum call depth {} exceeded",
self.max_call_depth
);
}
self.call_depth += 1;
Ok(())
}
pub fn validate_return(&mut self, return_address: u32) -> Result<()> {
if let Some(expected) = self.shadow_stack.pop() {
if expected != return_address {
self.violations += 1;
bail!(
"CFI violation: Return address mismatch. Expected 0x{:08x}, got 0x{:08x}",
expected,
return_address
);
}
} else {
self.violations += 1;
bail!("CFI violation: Return with empty shadow stack");
}
if self.call_depth > 0 {
self.call_depth -= 1;
}
Ok(())
}
pub fn push_return_address(&mut self, address: u32) -> Result<()> {
if self.shadow_stack.len() >= self.max_call_depth {
bail!("Shadow stack overflow");
}
self.shadow_stack.push(address);
Ok(())
}
pub fn stats(&self) -> CfiStats {
CfiStats {
valid_targets: self.valid_call_targets.len(),
current_depth: self.call_depth,
max_depth: self.max_call_depth,
violations: self.violations,
shadow_stack_size: self.shadow_stack.len(),
}
}
}
#[derive(Debug, Clone)]
pub struct CfiStats {
pub valid_targets: usize,
pub current_depth: usize,
pub max_depth: usize,
pub violations: usize,
pub shadow_stack_size: usize,
}
/// Stack-canary bookkeeping: one secret canary value shared by all frames, plus
/// a record of which frames currently have a canary placed.
#[derive(Debug)]
pub struct CanaryState {
    /// Secret value every checked canary is compared against.
    canary_value: u64,
    /// Maps a frame id to the address its canary was placed at.
    frame_canaries: HashMap<u32, u32>,
    /// Number of failed `verify_canary` checks.
    violations: usize,
}

impl CanaryState {
    /// Creates a tracker with a freshly drawn canary value and no placed frames.
    pub fn new() -> Self {
        let canary_value = Self::generate_canary();
        Self {
            canary_value,
            frame_canaries: HashMap::new(),
            violations: 0,
        }
    }

    /// Draws a fresh canary value.
    ///
    /// A canary only helps if an attacker cannot predict it. A value derived
    /// solely from the wall clock can be reconstructed from the creation time,
    /// and two calls can also collide on coarse clocks. This draws instead from
    /// std's `RandomState`, whose hasher keys are randomly generated and differ
    /// per instance, and mixes the clock in only as extra input.
    /// NOTE(review): std does not promise CSPRNG-grade output here; switch to an
    /// OS RNG if a stronger guarantee is required.
    fn generate_canary() -> u64 {
        use std::collections::hash_map::RandomState;
        use std::hash::{BuildHasher, Hasher};
        use std::time::{SystemTime, UNIX_EPOCH};

        // The clock is supplementary entropy only, so a pre-epoch clock is not fatal.
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_nanos() as u64) // truncating u128 -> u64 is fine for mixing
            .unwrap_or(0);
        let mut hasher = RandomState::new().build_hasher();
        hasher.write_u64(nanos);
        hasher.finish()
    }

    /// Records that `frame_id` has a canary at `canary_address`
    /// (replacing any previous entry for that frame).
    pub fn place_canary(&mut self, frame_id: u32, canary_address: u32) {
        self.frame_canaries.insert(frame_id, canary_address);
    }

    /// Compares `actual_value` with the secret canary.
    ///
    /// `frame_id` is used only in the error message; the frame does not need to
    /// have been placed.
    ///
    /// # Errors
    /// Counts a violation and fails if the values differ.
    pub fn verify_canary(&mut self, frame_id: u32, actual_value: u64) -> Result<()> {
        if actual_value != self.canary_value {
            self.violations += 1;
            bail!(
                "Stack canary violation: Frame {} - Expected 0x{:016x}, got 0x{:016x}",
                frame_id,
                self.canary_value,
                actual_value
            );
        }
        Ok(())
    }

    /// Forgets the canary placement for `frame_id`, if any.
    pub fn remove_canary(&mut self, frame_id: u32) {
        self.frame_canaries.remove(&frame_id);
    }

    /// The secret canary value to write into frames.
    pub fn get_canary_value(&self) -> u64 {
        self.canary_value
    }

    /// Snapshot of placed-frame and violation counts.
    pub fn stats(&self) -> CanaryStats {
        CanaryStats {
            active_canaries: self.frame_canaries.len(),
            violations: self.violations,
        }
    }
}

impl Default for CanaryState {
    fn default() -> Self {
        Self::new()
    }
}

/// Counters reported by [`CanaryState::stats`].
#[derive(Debug, Clone)]
pub struct CanaryStats {
    /// Frames that currently have a canary placed.
    pub active_canaries: usize,
    /// Failed canary checks so far.
    pub violations: usize,
}
/// Randomized relocation offsets for the base, heap, and stack regions.
#[derive(Debug)]
pub struct AslrState {
    /// Offset added by `get_base_address`.
    base_offset: u32,
    /// Offset added by `get_heap_address`.
    heap_offset: u32,
    /// Offset added by `get_stack_address`.
    stack_offset: u32,
    /// Set by `mark_applied`; reported via `is_applied` / `stats`.
    applied: bool,
}

impl AslrState {
    /// Draws three independent offsets. Each is a non-zero multiple of
    /// `0x1_0000` (64 KiB), i.e. in `0x0001_0000..=0xFFFF_0000`.
    pub fn new() -> Self {
        Self {
            base_offset: Self::random_offset(),
            heap_offset: Self::random_offset(),
            stack_offset: Self::random_offset(),
            applied: false,
        }
    }

    /// Returns one random, 64 KiB-aligned, non-zero offset.
    ///
    /// Each offset is drawn separately from std's randomly keyed `RandomState`
    /// hasher rather than sliced out of a single timestamp. The three regions
    /// are therefore uncorrelated and cannot be recovered from the creation
    /// time. Zero is excluded because a zero offset leaves its region
    /// unrandomized.
    fn random_offset() -> u32 {
        use std::collections::hash_map::RandomState;
        use std::hash::{BuildHasher, Hasher};
        use std::time::{SystemTime, UNIX_EPOCH};

        // The clock is supplementary input only; a pre-epoch clock is not fatal.
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_nanos() as u64) // truncating u128 -> u64 is fine for mixing
            .unwrap_or(0);
        let mut hasher = RandomState::new().build_hasher();
        hasher.write_u64(nanos);
        // Map onto 1..=0xFFFF so the 64 KiB page index is never zero.
        let page = (hasher.finish() % 0xFFFF) as u32 + 1;
        page << 16
    }

    /// `original` shifted by the base offset (wrapping on overflow).
    pub fn get_base_address(&self, original: u32) -> u32 {
        original.wrapping_add(self.base_offset)
    }

    /// `original` shifted by the heap offset (wrapping on overflow).
    pub fn get_heap_address(&self, original: u32) -> u32 {
        original.wrapping_add(self.heap_offset)
    }

    /// `original` shifted by the stack offset (wrapping on overflow).
    pub fn get_stack_address(&self, original: u32) -> u32 {
        original.wrapping_add(self.stack_offset)
    }

    /// Flags that the offsets have been applied.
    pub fn mark_applied(&mut self) {
        self.applied = true;
    }

    /// Whether `mark_applied` has been called.
    pub fn is_applied(&self) -> bool {
        self.applied
    }

    /// Snapshot of the offsets and the applied flag.
    pub fn stats(&self) -> AslrStats {
        AslrStats {
            base_offset: self.base_offset,
            heap_offset: self.heap_offset,
            stack_offset: self.stack_offset,
            applied: self.applied,
        }
    }
}

impl Default for AslrState {
    fn default() -> Self {
        Self::new()
    }
}

/// Values reported by [`AslrState::stats`].
#[derive(Debug, Clone)]
pub struct AslrStats {
    pub base_offset: u32,
    pub heap_offset: u32,
    pub stack_offset: u32,
    pub applied: bool,
}
/// In-memory registry of capability tokens, their delegation tree, and the set
/// of revoked capability ids.
#[derive(Debug)]
pub struct AttestationState {
    /// Every token ever issued, keyed by capability id. Revoked tokens stay here.
    tokens: HashMap<String, CapabilityToken>,
    /// Maps a delegated (child) capability id to the id it was delegated from.
    delegation_chain: HashMap<String, String>,
    /// Ids revoked directly or through a revoked ancestor.
    revoked: HashSet<String>,
}

impl AttestationState {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            tokens: HashMap::new(),
            delegation_chain: HashMap::new(),
            revoked: HashSet::new(),
        }
    }

    /// Issues a new root capability with the given permissions.
    ///
    /// # Errors
    /// Fails if a token with `capability_id` already exists (revoked or not).
    pub fn create_token(
        &mut self,
        capability_id: String,
        permissions: Vec<String>,
    ) -> Result<CapabilityToken> {
        if self.tokens.contains_key(&capability_id) {
            bail!("Capability token already exists: {}", capability_id);
        }
        let token = CapabilityToken::new(capability_id.clone(), permissions);
        self.tokens.insert(capability_id, token.clone());
        Ok(token)
    }

    /// Delegates a subset of `parent_id`'s permissions to a new capability `child_id`.
    ///
    /// # Errors
    /// Fails if the parent does not exist or is revoked, if any requested
    /// permission is not held by the parent, or if `child_id` is already in use.
    pub fn delegate(
        &mut self,
        parent_id: &str,
        child_id: String,
        permissions: Vec<String>,
    ) -> Result<CapabilityToken> {
        let parent_token = self
            .tokens
            .get(parent_id)
            .ok_or_else(|| anyhow!("Parent capability not found: {}", parent_id))?;
        if self.revoked.contains(parent_id) {
            bail!("Parent capability is revoked: {}", parent_id);
        }
        for perm in &permissions {
            if !parent_token.permissions.contains(perm) {
                bail!(
                    "Permission '{}' not in parent capability {}",
                    perm,
                    parent_id
                );
            }
        }
        // Never reuse an existing id. Overwriting would let a caller replace an
        // arbitrary token, including `parent_id` itself. A self-delegation would
        // also put a cycle into `delegation_chain`, which `verify` walks with no
        // termination check. Rejecting reuse keeps the delegation graph a tree.
        if self.tokens.contains_key(&child_id) {
            bail!("Capability token already exists: {}", child_id);
        }
        let child_token = CapabilityToken::new(child_id.clone(), permissions);
        self.tokens.insert(child_id.clone(), child_token.clone());
        self.delegation_chain
            .insert(child_id, parent_id.to_string());
        Ok(child_token)
    }

    /// Checks that `capability_id` holds `required_permission` and that neither
    /// it nor any ancestor in its delegation chain is revoked.
    ///
    /// # Errors
    /// Fails if the capability is revoked or unknown, lacks the permission, or
    /// has a revoked ancestor.
    pub fn verify(&self, capability_id: &str, required_permission: &str) -> Result<()> {
        if self.revoked.contains(capability_id) {
            bail!("Capability is revoked: {}", capability_id);
        }
        let token = self
            .tokens
            .get(capability_id)
            .ok_or_else(|| anyhow!("Capability not found: {}", capability_id))?;
        if !token
            .permissions
            .iter()
            .any(|p| p.as_str() == required_permission)
        {
            bail!(
                "Capability {} does not have permission '{}'",
                capability_id,
                required_permission
            );
        }
        // Walk up to the root. This terminates because `delegate` only links a
        // fresh child id to an existing parent, so the chain is acyclic.
        let mut current = capability_id;
        while let Some(parent) = self.delegation_chain.get(current) {
            if self.revoked.contains(parent) {
                bail!("Parent capability {} in chain is revoked", parent);
            }
            current = parent;
        }
        Ok(())
    }

    /// Revokes `capability_id` and every capability delegated from it, transitively.
    ///
    /// # Errors
    /// Fails if `capability_id` is unknown.
    pub fn revoke(&mut self, capability_id: &str) -> Result<()> {
        if !self.tokens.contains_key(capability_id) {
            bail!("Capability not found: {}", capability_id);
        }
        // Iterative worklist instead of recursion, so a long delegation chain
        // cannot exhaust the call stack.
        let mut pending = vec![capability_id.to_string()];
        while let Some(id) = pending.pop() {
            // An already-revoked id had its whole subtree revoked at that time,
            // and `delegate` refuses revoked parents, so nothing new lies below it.
            if !self.revoked.insert(id.clone()) {
                continue;
            }
            pending.extend(
                self.delegation_chain
                    .iter()
                    .filter(|(_, parent)| parent.as_str() == id.as_str())
                    .map(|(child, _)| child.clone()),
            );
        }
        Ok(())
    }

    /// Snapshot of token, revocation, and delegation-link counts.
    pub fn stats(&self) -> AttestationStats {
        AttestationStats {
            total_tokens: self.tokens.len(),
            revoked_tokens: self.revoked.len(),
            delegation_chains: self.delegation_chain.len(),
        }
    }
}

impl Default for AttestationState {
    fn default() -> Self {
        Self::new()
    }
}

/// Counters reported by [`AttestationState::stats`].
#[derive(Debug, Clone)]
pub struct AttestationStats {
    /// Tokens issued, including revoked ones.
    pub total_tokens: usize,
    /// Revoked capability ids.
    pub revoked_tokens: usize,
    /// Child-to-parent delegation links.
    pub delegation_chains: usize,
}
/// A capability grant: an id, its permission list, a creation timestamp, and an
/// integrity tag over those fields.
///
/// NOTE(review): `signature` is an unkeyed FNV-1a digest, not a cryptographic
/// signature. Anyone can recompute it for forged contents, so it only detects
/// corruption or edits that do not also update the tag.
#[derive(Debug, Clone)]
pub struct CapabilityToken {
    /// Capability identifier.
    pub id: String,
    /// Permissions granted by this token.
    pub permissions: Vec<String>,
    /// 32-byte integrity tag produced by `generate_signature`.
    pub signature: [u8; 32],
    /// Creation time in whole seconds since the UNIX epoch.
    pub created_at: u64,
}

impl CapabilityToken {
    /// Builds a token stamped with the current time and its integrity tag.
    ///
    /// # Panics
    /// Panics if the system clock reports a time before the UNIX epoch.
    pub fn new(id: String, permissions: Vec<String>) -> Self {
        use std::time::{SystemTime, UNIX_EPOCH};
        let created_at = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("SystemTime before UNIX_EPOCH")
            .as_secs();
        let signature = Self::generate_signature(&id, &permissions, created_at);
        Self {
            id,
            permissions,
            signature,
            created_at,
        }
    }

    /// Computes the 32-byte integrity tag for the given fields.
    ///
    /// Every variable-length field is length-prefixed before hashing. Without
    /// the prefixes, distinct inputs such as `("x", ["ab", "c"])` and
    /// `("x", ["a", "bc"])`, or `("ab", ["c"])` and `("a", ["bc"])`, would feed
    /// identical byte streams and produce identical tags.
    fn generate_signature(id: &str, permissions: &[String], timestamp: u64) -> [u8; 32] {
        // 64-bit FNV-1a parameters.
        const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
        const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

        /// Folds `bytes` into the FNV-1a state `state`.
        fn absorb(state: u64, bytes: &[u8]) -> u64 {
            bytes
                .iter()
                .fold(state, |h, &b| (h ^ u64::from(b)).wrapping_mul(FNV_PRIME))
        }

        // `usize as u64` is lossless on every platform Rust supports (usize <= 64 bits).
        let mut h = absorb(FNV_OFFSET_BASIS, &(id.len() as u64).to_le_bytes());
        h = absorb(h, id.as_bytes());
        h = absorb(h, &(permissions.len() as u64).to_le_bytes());
        for perm in permissions {
            h = absorb(h, &(perm.len() as u64).to_le_bytes());
            h = absorb(h, perm.as_bytes());
        }
        h = absorb(h, &timestamp.to_le_bytes());

        // Stretch the 64-bit state into 32 bytes. This adds no entropy; the tag
        // still carries at most 64 bits of information.
        let mut out = [0u8; 32];
        for word in out.chunks_exact_mut(8) {
            word.copy_from_slice(&h.to_le_bytes());
            h = h.wrapping_mul(FNV_PRIME);
        }
        out
    }

    /// Recomputes the tag from the current fields and compares it with `signature`.
    pub fn verify_signature(&self) -> bool {
        let expected = Self::generate_signature(&self.id, &self.permissions, self.created_at);
        self.signature == expected
    }
}
/// One violation recorded by `SecurityContext`.
#[derive(Debug, Clone)]
pub struct SecurityViolation {
    /// Which protection reported the violation.
    pub violation_type: ViolationType,
    /// Human-readable text; for CFI it is the `to_string()` of the error
    /// returned by `CfiState::validate_call`.
    pub description: String,
    /// Whole seconds since the UNIX epoch at the moment it was recorded.
    pub timestamp: u64,
}
/// Category of a [`SecurityViolation`].
///
/// NOTE(review): in the code visible here, only `CfiViolation` is ever recorded
/// (by `SecurityContext::validate_call`); the other variants are not produced.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ViolationType {
    /// A control-flow integrity check failed.
    CfiViolation,
    /// A stack canary check failed.
    CanaryViolation,
    /// ASLR-related violation.
    AslrViolation,
    /// A capability attestation check failed.
    AttestationViolation,
}
impl SecurityContext {
pub fn new(policy: SecurityPolicy) -> Self {
Self {
policy,
cfi_state: None,
canary_state: None,
aslr_state: None,
attestation_state: None,
violations: Vec::new(),
}
}
pub fn enable_cfi(&mut self, max_call_depth: usize) -> Result<()> {
if !self.policy.enable_cfi {
bail!("CFI is disabled in security policy");
}
self.cfi_state = Some(CfiState::new(max_call_depth));
Ok(())
}
pub fn enable_stack_canaries(&mut self) -> Result<()> {
if !self.policy.enable_stack_canaries {
bail!("Stack canaries are disabled in security policy");
}
self.canary_state = Some(CanaryState::new());
Ok(())
}
pub fn enable_aslr(&mut self) -> Result<()> {
if !self.policy.enable_aslr {
bail!("ASLR is disabled in security policy");
}
self.aslr_state = Some(AslrState::new());
Ok(())
}
pub fn enable_attestation(&mut self) -> Result<()> {
if !self.policy.enable_attestation {
bail!("Attestation is disabled in security policy");
}
self.attestation_state = Some(AttestationState::new());
Ok(())
}
pub fn validate_call(&mut self, target: u32) -> Result<()> {
if let Some(cfi) = &mut self.cfi_state {
cfi.validate_call(target).inspect_err(|e| {
self.record_violation(ViolationType::CfiViolation, e.to_string());
})
} else {
Ok(())
}
}
pub fn cfi_state_mut(&mut self) -> Option<&mut CfiState> {
self.cfi_state.as_mut()
}
pub fn canary_state_mut(&mut self) -> Option<&mut CanaryState> {
self.canary_state.as_mut()
}
pub fn aslr_state_mut(&mut self) -> Option<&mut AslrState> {
self.aslr_state.as_mut()
}
pub fn attestation_state_mut(&mut self) -> Option<&mut AttestationState> {
self.attestation_state.as_mut()
}
fn record_violation(&mut self, violation_type: ViolationType, description: String) {
use std::time::{SystemTime, UNIX_EPOCH};
let timestamp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("SystemTime before UNIX_EPOCH")
.as_secs();
let desc_clone = description.clone();
self.violations.push(SecurityViolation {
violation_type,
description,
timestamp,
});
if self.policy.panic_on_violation {
panic!("Security violation: {}", desc_clone);
}
}
pub fn get_violations(&self) -> &[SecurityViolation] {
&self.violations
}
pub fn stats(&self) -> SecurityStats {
SecurityStats {
cfi_stats: self.cfi_state.as_ref().map(|s| s.stats()),
canary_stats: self.canary_state.as_ref().map(|s| s.stats()),
aslr_stats: self.aslr_state.as_ref().map(|s| s.stats()),
attestation_stats: self.attestation_state.as_ref().map(|s| s.stats()),
total_violations: self.violations.len(),
}
}
}
/// Snapshot returned by `SecurityContext::stats`. Each per-feature entry is
/// `None` when that feature's state has not been enabled.
#[derive(Debug, Clone)]
pub struct SecurityStats {
    /// CFI counters, if CFI is enabled.
    pub cfi_stats: Option<CfiStats>,
    /// Stack-canary counters, if canaries are enabled.
    pub canary_stats: Option<CanaryStats>,
    /// ASLR offsets, if ASLR is enabled.
    pub aslr_stats: Option<AslrStats>,
    /// Capability registry counters, if attestation is enabled.
    pub attestation_stats: Option<AttestationStats>,
    /// Number of violations in the context's log.
    pub total_violations: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_security_policy_presets() {
        let strict = SecurityPolicy::strict();
        assert!(strict.enable_cfi);
        assert!(strict.enable_stack_canaries);
        assert!(strict.enable_aslr);
        assert!(strict.enable_attestation);
        assert!(strict.panic_on_violation);
        let standard = SecurityPolicy::standard();
        assert!(standard.enable_cfi);
        assert!(standard.enable_stack_canaries);
        let permissive = SecurityPolicy::permissive();
        assert!(!permissive.enable_cfi);
        assert!(!permissive.enable_stack_canaries);
    }

    #[test]
    fn test_cfi_state_creation() {
        let cfi = CfiState::new(100);
        assert_eq!(cfi.max_call_depth, 100);
        assert_eq!(cfi.call_depth, 0);
        assert_eq!(cfi.violations, 0);
    }

    #[test]
    fn test_cfi_call_validation() {
        let mut cfi = CfiState::new(10);
        cfi.add_call_target(0x1000);
        cfi.add_call_target(0x2000);
        assert!(cfi.validate_call(0x1000).is_ok());
        assert_eq!(cfi.call_depth, 1);
        assert!(cfi.validate_call(0x3000).is_err());
        assert_eq!(cfi.violations, 1);
    }

    #[test]
    fn test_cfi_shadow_stack() {
        let mut cfi = CfiState::new(10);
        cfi.add_call_target(0x1000);
        assert!(cfi.push_return_address(0x5000).is_ok());
        assert!(cfi.validate_return(0x5000).is_ok());
        // The shadow stack is now empty, so any further return is rejected.
        assert!(cfi.validate_return(0x6000).is_err());
    }

    #[test]
    fn test_cfi_max_depth() {
        let mut cfi = CfiState::new(2);
        cfi.add_call_target(0x1000);
        assert!(cfi.validate_call(0x1000).is_ok());
        assert!(cfi.validate_call(0x1000).is_ok());
        assert!(cfi.validate_call(0x1000).is_err());
    }

    #[test]
    fn test_canary_generation() {
        let canary1 = CanaryState::new();
        let canary2 = CanaryState::new();
        assert_ne!(canary1.canary_value, canary2.canary_value);
    }

    #[test]
    fn test_canary_verification() {
        let mut canary = CanaryState::new();
        let canary_value = canary.get_canary_value();
        canary.place_canary(1, 0x1000);
        assert!(canary.verify_canary(1, canary_value).is_ok());
        // `wrapping_add`: a plain `+ 1` overflow-panics in debug builds when the
        // drawn canary happens to be u64::MAX.
        assert!(canary.verify_canary(1, canary_value.wrapping_add(1)).is_err());
        assert_eq!(canary.violations, 1);
    }

    #[test]
    fn test_aslr_randomization() {
        use std::collections::HashSet;
        use std::thread;
        use std::time::Duration;
        let mut base_offsets = HashSet::new();
        let mut heap_offsets = HashSet::new();
        let mut stack_offsets = HashSet::new();
        for i in 0..5 {
            if i > 0 {
                thread::sleep(Duration::from_micros(10));
            }
            let aslr = AslrState::new();
            base_offsets.insert(aslr.base_offset);
            heap_offsets.insert(aslr.heap_offset);
            stack_offsets.insert(aslr.stack_offset);
        }
        let has_variation =
            base_offsets.len() > 1 || heap_offsets.len() > 1 || stack_offsets.len() > 1;
        assert!(
            has_variation,
            "ASLR should produce varied offsets: base={} heap={} stack={}",
            base_offsets.len(),
            heap_offsets.len(),
            stack_offsets.len()
        );
    }

    #[test]
    fn test_aslr_address_translation() {
        let aslr = AslrState::new();
        let original: u32 = 0x1000;
        let randomized = aslr.get_base_address(original);
        assert_eq!(randomized, original.wrapping_add(aslr.base_offset));
        assert_eq!(aslr.base_offset & 0xFFFF, 0, "offsets are 64 KiB aligned");
        // Only a non-zero offset moves the address. Asserting inequality
        // unconditionally made this test fail whenever the drawn offset was 0.
        if aslr.base_offset != 0 {
            assert_ne!(randomized, original);
        }
        assert_eq!(randomized, aslr.get_base_address(original));
    }

    #[test]
    fn test_attestation_token_creation() {
        let mut attestation = AttestationState::new();
        let token = attestation
            .create_token(
                "cap1".to_string(),
                vec!["read".to_string(), "write".to_string()],
            )
            .unwrap();
        assert_eq!(token.id, "cap1");
        assert_eq!(token.permissions.len(), 2);
        assert!(token.verify_signature());
    }

    #[test]
    fn test_attestation_delegation() {
        let mut attestation = AttestationState::new();
        attestation
            .create_token(
                "parent".to_string(),
                vec![
                    "read".to_string(),
                    "write".to_string(),
                    "execute".to_string(),
                ],
            )
            .unwrap();
        let child = attestation
            .delegate("parent", "child".to_string(), vec!["read".to_string()])
            .unwrap();
        assert_eq!(child.permissions.len(), 1);
        assert!(attestation.verify("child", "read").is_ok());
        assert!(attestation.verify("child", "write").is_err());
    }

    #[test]
    fn test_attestation_revocation() {
        let mut attestation = AttestationState::new();
        attestation
            .create_token("parent".to_string(), vec!["read".to_string()])
            .unwrap();
        attestation
            .delegate("parent", "child".to_string(), vec!["read".to_string()])
            .unwrap();
        attestation.revoke("parent").unwrap();
        assert!(attestation.verify("parent", "read").is_err());
        assert!(attestation.verify("child", "read").is_err());
    }

    #[test]
    fn test_security_context_creation() {
        let policy = SecurityPolicy::strict();
        let ctx = SecurityContext::new(policy);
        assert!(ctx.cfi_state.is_none());
        assert!(ctx.canary_state.is_none());
        assert!(ctx.aslr_state.is_none());
        assert!(ctx.attestation_state.is_none());
    }

    #[test]
    fn test_security_context_enable_cfi() {
        let policy = SecurityPolicy::strict();
        let mut ctx = SecurityContext::new(policy);
        assert!(ctx.enable_cfi(100).is_ok());
        assert!(ctx.cfi_state.is_some());
    }

    #[test]
    fn test_security_context_enable_canaries() {
        let policy = SecurityPolicy::strict();
        let mut ctx = SecurityContext::new(policy);
        assert!(ctx.enable_stack_canaries().is_ok());
        assert!(ctx.canary_state.is_some());
    }

    #[test]
    fn test_security_context_enable_aslr() {
        let policy = SecurityPolicy::strict();
        let mut ctx = SecurityContext::new(policy);
        assert!(ctx.enable_aslr().is_ok());
        assert!(ctx.aslr_state.is_some());
    }

    #[test]
    fn test_security_context_enable_attestation() {
        let policy = SecurityPolicy::strict();
        let mut ctx = SecurityContext::new(policy);
        assert!(ctx.enable_attestation().is_ok());
        assert!(ctx.attestation_state.is_some());
    }

    #[test]
    fn test_security_context_disabled_features() {
        let policy = SecurityPolicy::permissive();
        let mut ctx = SecurityContext::new(policy);
        assert!(ctx.enable_cfi(100).is_err());
        assert!(ctx.enable_stack_canaries().is_err());
    }

    #[test]
    fn test_capability_token_signature() {
        let token = CapabilityToken::new("test".to_string(), vec!["read".to_string()]);
        assert!(token.verify_signature());
        let mut tampered = token.clone();
        tampered.signature[0] ^= 0xFF;
        assert!(!tampered.verify_signature());
    }
}