use crate::analysis::unsafe_ffi_tracker::{RiskLevel, StackFrame};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum RiskFactorType {
RawPointerDereference,
UnsafeDataRace,
InvalidTransmute,
FfiCall,
ManualMemoryManagement,
CrossBoundaryTransfer,
UseAfterFree,
BufferOverflow,
LifetimeViolation,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RiskFactor {
pub factor_type: RiskFactorType,
pub severity: f64,
pub confidence: f64,
pub description: String,
pub source_location: Option<String>,
pub call_stack: Vec<StackFrame>,
pub mitigation: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RiskAssessment {
pub risk_level: RiskLevel,
pub risk_score: f64,
pub risk_factors: Vec<RiskFactor>,
pub confidence_score: f64,
pub mitigation_suggestions: Vec<String>,
pub assessment_timestamp: u64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UnsafeReport {
pub report_id: String,
pub source: UnsafeSource,
pub risk_assessment: RiskAssessment,
pub dynamic_violations: Vec<DynamicViolation>,
pub related_passports: Vec<String>,
pub memory_context: MemoryContext,
pub generated_at: u64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum UnsafeSource {
UnsafeBlock {
location: String,
function: String,
file_path: Option<String>,
line_number: Option<u32>,
},
FfiFunction {
library: String,
function: String,
call_site: String,
},
RawPointer {
operation: String,
location: String,
},
Transmute {
from_type: String,
to_type: String,
location: String,
},
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DynamicViolation {
pub violation_type: ViolationType,
pub memory_address: usize,
pub memory_size: usize,
pub detected_at: u64,
pub call_stack: Vec<StackFrame>,
pub severity: RiskLevel,
pub context: String,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ViolationType {
DoubleFree,
UseAfterFree,
BufferOverflow,
InvalidAccess,
DataRace,
FfiBoundaryViolation,
MemoryLeak,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryContext {
pub total_allocated: usize,
pub active_allocations: usize,
pub memory_pressure: MemoryPressureLevel,
pub allocation_patterns: Vec<AllocationPattern>,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum MemoryPressureLevel {
Low,
Medium,
High,
Critical,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AllocationPattern {
pub pattern_type: String,
pub frequency: u32,
pub average_size: usize,
pub risk_level: RiskLevel,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryPassport {
pub passport_id: String,
pub allocation_ptr: usize,
pub size_bytes: usize,
pub status_at_shutdown: PassportStatus,
pub lifecycle_events: Vec<PassportEvent>,
pub risk_assessment: RiskAssessment,
pub created_at: u64,
pub updated_at: u64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PassportStatus {
FreedByRust,
HandoverToFfi,
FreedByForeign,
ReclaimedByRust,
InForeignCustody,
Unknown,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PassportEvent {
pub event_type: PassportEventType,
pub timestamp: u64,
pub context: String,
pub call_stack: Vec<StackFrame>,
pub metadata: HashMap<String, String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PassportEventType {
AllocatedInRust,
HandoverToFfi,
FreedByForeign,
ReclaimedByRust,
BoundaryAccess,
OwnershipTransfer,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_risk_factor_type_variants() {
let variants = vec![
RiskFactorType::RawPointerDereference,
RiskFactorType::UnsafeDataRace,
RiskFactorType::InvalidTransmute,
RiskFactorType::FfiCall,
RiskFactorType::ManualMemoryManagement,
RiskFactorType::CrossBoundaryTransfer,
RiskFactorType::UseAfterFree,
RiskFactorType::BufferOverflow,
RiskFactorType::LifetimeViolation,
];
for variant in variants {
let debug_str = format!("{variant:?}");
assert!(
!debug_str.is_empty(),
"Variant should have debug representation"
);
}
}
#[test]
fn test_risk_factor_creation() {
let factor = RiskFactor {
factor_type: RiskFactorType::BufferOverflow,
severity: 0.9,
confidence: 0.85,
description: "Test buffer overflow".to_string(),
source_location: Some("test.rs:10".to_string()),
call_stack: vec![],
mitigation: "Use bounds checking".to_string(),
};
assert_eq!(factor.severity, 0.9, "Severity should match");
assert_eq!(factor.confidence, 0.85, "Confidence should match");
assert!(
factor.source_location.is_some(),
"Source location should be present"
);
}
#[test]
fn test_risk_assessment_creation() {
let assessment = RiskAssessment {
risk_level: RiskLevel::High,
risk_score: 75.0,
risk_factors: vec![],
confidence_score: 0.9,
mitigation_suggestions: vec!["Review code".to_string()],
assessment_timestamp: 1000,
};
assert_eq!(
assessment.risk_level,
RiskLevel::High,
"Risk level should be High"
);
assert_eq!(assessment.risk_score, 75.0, "Risk score should match");
assert_eq!(
assessment.mitigation_suggestions.len(),
1,
"Should have one suggestion"
);
}
#[test]
fn test_unsafe_source_variants() {
let block = UnsafeSource::UnsafeBlock {
location: "test.rs:10".to_string(),
function: "test_fn".to_string(),
file_path: Some("test.rs".to_string()),
line_number: Some(10),
};
let ffi = UnsafeSource::FfiFunction {
library: "libc".to_string(),
function: "malloc".to_string(),
call_site: "test.rs:20".to_string(),
};
let raw = UnsafeSource::RawPointer {
operation: "deref".to_string(),
location: "0x1000".to_string(),
};
let transmute = UnsafeSource::Transmute {
from_type: "u8".to_string(),
to_type: "i8".to_string(),
location: "test.rs:30".to_string(),
};
assert!(matches!(block, UnsafeSource::UnsafeBlock { .. }));
assert!(matches!(ffi, UnsafeSource::FfiFunction { .. }));
assert!(matches!(raw, UnsafeSource::RawPointer { .. }));
assert!(matches!(transmute, UnsafeSource::Transmute { .. }));
}
#[test]
fn test_dynamic_violation_creation() {
let violation = DynamicViolation {
violation_type: ViolationType::UseAfterFree,
memory_address: 0x1000,
memory_size: 1024,
detected_at: 1000,
call_stack: vec![],
severity: RiskLevel::Critical,
context: "Use after free detected".to_string(),
};
assert_eq!(
violation.memory_address, 0x1000,
"Memory address should match"
);
assert_eq!(violation.memory_size, 1024, "Memory size should match");
assert_eq!(
violation.severity,
RiskLevel::Critical,
"Severity should be Critical"
);
}
#[test]
fn test_violation_type_equality() {
assert_eq!(ViolationType::DoubleFree, ViolationType::DoubleFree);
assert_eq!(ViolationType::UseAfterFree, ViolationType::UseAfterFree);
assert_eq!(ViolationType::BufferOverflow, ViolationType::BufferOverflow);
assert_eq!(ViolationType::InvalidAccess, ViolationType::InvalidAccess);
assert_eq!(ViolationType::DataRace, ViolationType::DataRace);
assert_eq!(
ViolationType::FfiBoundaryViolation,
ViolationType::FfiBoundaryViolation
);
assert_ne!(ViolationType::DoubleFree, ViolationType::UseAfterFree);
}
#[test]
fn test_memory_context_creation() {
let context = MemoryContext {
total_allocated: 1024 * 1024,
active_allocations: 10,
memory_pressure: MemoryPressureLevel::Medium,
allocation_patterns: vec![],
};
assert_eq!(
context.total_allocated,
1024 * 1024,
"Total allocated should match"
);
assert_eq!(
context.active_allocations, 10,
"Active allocations should match"
);
}
#[test]
fn test_memory_pressure_level() {
let levels = [
MemoryPressureLevel::Low,
MemoryPressureLevel::Medium,
MemoryPressureLevel::High,
MemoryPressureLevel::Critical,
];
for (i, level) in levels.iter().enumerate() {
for (j, other) in levels.iter().enumerate() {
if i == j {
assert_eq!(level, other, "Same levels should be equal");
} else {
assert_ne!(level, other, "Different levels should not be equal");
}
}
}
}
#[test]
fn test_allocation_pattern_creation() {
let pattern = AllocationPattern {
pattern_type: "repeated".to_string(),
frequency: 100,
average_size: 256,
risk_level: RiskLevel::Medium,
};
assert_eq!(pattern.frequency, 100, "Frequency should match");
assert_eq!(pattern.average_size, 256, "Average size should match");
}
#[test]
fn test_memory_passport_creation() {
let passport = MemoryPassport {
passport_id: "passport_123".to_string(),
allocation_ptr: 0x1000,
size_bytes: 1024,
status_at_shutdown: PassportStatus::Unknown,
lifecycle_events: vec![],
risk_assessment: RiskAssessment {
risk_level: RiskLevel::Low,
risk_score: 10.0,
risk_factors: vec![],
confidence_score: 0.5,
mitigation_suggestions: vec![],
assessment_timestamp: 0,
},
created_at: 1000,
updated_at: 1000,
};
assert_eq!(
passport.passport_id, "passport_123",
"Passport ID should match"
);
assert_eq!(
passport.allocation_ptr, 0x1000,
"Allocation pointer should match"
);
assert_eq!(passport.size_bytes, 1024, "Size should match");
}
#[test]
fn test_passport_status_variants() {
let statuses = vec![
PassportStatus::FreedByRust,
PassportStatus::HandoverToFfi,
PassportStatus::FreedByForeign,
PassportStatus::ReclaimedByRust,
PassportStatus::InForeignCustody,
PassportStatus::Unknown,
];
for status in &statuses {
let debug_str = format!("{status:?}");
assert!(
!debug_str.is_empty(),
"Status should have debug representation"
);
}
}
#[test]
fn test_passport_event_creation() {
let event = PassportEvent {
event_type: PassportEventType::HandoverToFfi,
timestamp: 1000,
context: "ffi_transfer".to_string(),
call_stack: vec![],
metadata: HashMap::new(),
};
assert_eq!(event.timestamp, 1000, "Timestamp should match");
assert_eq!(event.context, "ffi_transfer", "Context should match");
}
#[test]
fn test_passport_event_type_variants() {
let event_types = vec![
PassportEventType::AllocatedInRust,
PassportEventType::HandoverToFfi,
PassportEventType::FreedByForeign,
PassportEventType::ReclaimedByRust,
PassportEventType::BoundaryAccess,
PassportEventType::OwnershipTransfer,
];
for event_type in &event_types {
let debug_str = format!("{event_type:?}");
assert!(
!debug_str.is_empty(),
"Event type should have debug representation"
);
}
}
#[test]
fn test_unsafe_report_creation() {
let report = UnsafeReport {
report_id: "UNSAFE-UB-123".to_string(),
source: UnsafeSource::UnsafeBlock {
location: "test.rs".to_string(),
function: "test".to_string(),
file_path: None,
line_number: None,
},
risk_assessment: RiskAssessment {
risk_level: RiskLevel::Medium,
risk_score: 50.0,
risk_factors: vec![],
confidence_score: 0.8,
mitigation_suggestions: vec![],
assessment_timestamp: 0,
},
dynamic_violations: vec![],
related_passports: vec![],
memory_context: MemoryContext {
total_allocated: 0,
active_allocations: 0,
memory_pressure: MemoryPressureLevel::Low,
allocation_patterns: vec![],
},
generated_at: 1000,
};
assert_eq!(report.report_id, "UNSAFE-UB-123", "Report ID should match");
assert_eq!(
report.generated_at, 1000,
"Generated timestamp should match"
);
}
#[test]
fn test_risk_factor_edge_values() {
let zero_factor = RiskFactor {
factor_type: RiskFactorType::UseAfterFree,
severity: 0.0,
confidence: 0.0,
description: String::new(),
source_location: None,
call_stack: vec![],
mitigation: String::new(),
};
let max_factor = RiskFactor {
factor_type: RiskFactorType::BufferOverflow,
severity: 1.0,
confidence: 1.0,
description: "x".repeat(1000),
source_location: Some("x".repeat(1000)),
call_stack: vec![],
mitigation: "x".repeat(1000),
};
assert_eq!(zero_factor.severity, 0.0, "Zero severity should be valid");
assert_eq!(max_factor.severity, 1.0, "Max severity should be valid");
assert_eq!(
max_factor.description.len(),
1000,
"Long description should be preserved"
);
}
#[test]
fn test_serialization() {
let assessment = RiskAssessment {
risk_level: RiskLevel::High,
risk_score: 75.0,
risk_factors: vec![],
confidence_score: 0.9,
mitigation_suggestions: vec!["test".to_string()],
assessment_timestamp: 1000,
};
let json = serde_json::to_string(&assessment);
assert!(json.is_ok(), "Should serialize to JSON");
let deserialized: Result<RiskAssessment, _> = serde_json::from_str(&json.unwrap());
assert!(deserialized.is_ok(), "Should deserialize from JSON");
}
}