use crate::error::SklearsError;
use std::collections::{HashMap, HashSet};
use std::fs;
use std::path::{Path, PathBuf};
/// Result alias used by this module; the error type is always [`SklearsError`].
pub type Result<T> = std::result::Result<T, SklearsError>;
/// Settings that control what an [`UnsafeAuditor`] scans and when an audit passes.
#[derive(Debug, Clone)]
pub struct UnsafeAuditConfig {
/// Paths to scan; each one is joined onto the root path given to `UnsafeAuditor::audit`.
pub scan_paths: Vec<PathBuf>,
/// Paths to skip. A path is skipped when it ends with an entry or when one of
/// its components equals an entry.
pub exclude_paths: Vec<PathBuf>,
/// When `strict_mode` is off, a file with more findings than this fails the audit.
pub max_unsafe_per_file: usize,
/// When set (and `strict_mode` is off), any finding of severity `High` or above
/// that has no justification fails the audit.
pub require_justification: bool,
/// When set, the audit passes only if there are no findings at all.
pub strict_mode: bool,
/// Patterns whose signatures, when found as a substring, mark code as known-safe.
pub allowed_patterns: Vec<UnsafePattern>,
}
/// A named family of unsafe usages that the audit treats as known-safe.
#[derive(Debug, Clone)]
pub struct UnsafePattern {
/// Human-readable name of the pattern.
pub name: String,
/// Substrings that identify the pattern; code containing any of them matches.
pub signatures: Vec<String>,
/// Why usages matching this pattern are considered acceptable.
pub justification: String,
/// Conditions that are expected to hold for the pattern to be sound.
pub preconditions: Vec<String>,
}
/// Result of one audit run as returned by `UnsafeAuditor::audit`.
#[derive(Debug, Clone)]
pub struct UnsafeAuditReport {
/// Whether the audit satisfied the configured policy.
pub passed: bool,
/// Number of `.rs` files that were read.
pub files_scanned: usize,
/// Number of files that produced at least one finding (the size of `findings`).
pub files_with_unsafe: usize,
/// Total number of findings across all scanned files.
pub total_unsafe_blocks: usize,
/// Findings grouped by the file they were found in; files without findings are absent.
pub findings: HashMap<PathBuf, Vec<UnsafeFinding>>,
/// Aggregated statistics over `findings`.
pub summary: UnsafeSummary,
/// Suggested follow-up actions derived from `findings`.
pub recommendations: Vec<SafetyRecommendation>,
}
/// A single occurrence of unsafe code located by the audit.
#[derive(Debug, Clone)]
pub struct UnsafeFinding {
/// File the finding was located in.
pub file: PathBuf,
/// 1-based line number; for block findings this is the line that opens the block.
pub line: usize,
/// Byte offset of the first `unsafe` in the line for line-level findings;
/// `None` for block-level findings.
pub column: Option<usize>,
/// Heuristic classification of the unsafe usage.
pub unsafe_type: UnsafeType,
/// The offending line, or the full text of the block for block findings.
pub code_snippet: String,
/// Justification text extracted from comments, if any was found.
pub justification: Option<String>,
/// Whether the code matched one of the configured allowed patterns.
pub is_known_safe: bool,
/// Assessed severity of the finding.
pub severity: SafetySeverity,
/// Human-readable hints for making the code safer.
pub suggestions: Vec<String>,
}
/// Heuristic classification of an unsafe usage, derived from substring checks
/// on the source text.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum UnsafeType {
/// The code appears to dereference a raw pointer.
RawPointerDeref,
/// The code appears to call a function inside unsafe code.
UnsafeFunctionCall,
/// The code mentions `static mut`.
MutableStatic,
/// The code mentions `union`.
UnionFieldAccess,
/// The code mentions `transmute`.
Transmute,
/// The code mentions `asm!`.
InlineAssembly,
/// Fallback when nothing more specific matched; also used for every
/// block-level finding.
UnsafeBlock,
}
/// Severity of a finding, ordered from least (`Info`) to most (`Critical`) severe.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SafetySeverity {
    Info,
    Low,
    Medium,
    High,
    Critical,
}

impl PartialOrd for SafetySeverity {
    /// Always returns `Some`: the ordering is total and comes from [`Ord::cmp`].
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl Ord for SafetySeverity {
    /// Compares two severities by their position on the
    /// `Info < Low < Medium < High < Critical` scale.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Explicit ranks keep the ordering independent of declaration order.
        fn rank(level: &SafetySeverity) -> u8 {
            match level {
                SafetySeverity::Info => 0,
                SafetySeverity::Low => 1,
                SafetySeverity::Medium => 2,
                SafetySeverity::High => 3,
                SafetySeverity::Critical => 4,
            }
        }

        rank(self).cmp(&rank(other))
    }
}
/// Aggregated statistics over all findings of an audit.
#[derive(Debug, Clone)]
pub struct UnsafeSummary {
/// Number of findings per [`UnsafeType`].
pub types_breakdown: HashMap<UnsafeType, usize>,
/// Number of findings per [`SafetySeverity`].
pub severity_breakdown: HashMap<SafetySeverity, usize>,
/// Up to ten files with the most findings, highest count first.
pub top_unsafe_files: Vec<(PathBuf, usize)>,
/// Labels of recurring patterns spotted in the finding snippets.
pub common_patterns: Vec<String>,
}
/// A follow-up action suggested by the audit for a group of files.
#[derive(Debug, Clone)]
pub struct SafetyRecommendation {
/// Kind of action being recommended.
pub recommendation_type: RecommendationType,
/// Human-readable description of the recommendation.
pub description: String,
/// Files the recommendation applies to.
pub affected_files: Vec<PathBuf>,
/// Estimated effort required to act on the recommendation.
pub effort: EffortLevel,
/// Expected safety benefit of acting on the recommendation.
pub safety_impact: SafetyImpact,
}
/// Kinds of follow-up action a [`SafetyRecommendation`] can propose.
#[derive(Debug, Clone)]
pub enum RecommendationType {
/// Replace the unsafe code with a safe alternative.
ReplaceWithSafe,
/// Add or improve safety justifications.
ImproveDocumentation,
/// Narrow the extent of the unsafe code.
ReduceScope,
/// Add runtime or debug checks around the unsafe code.
AddSafetyChecks,
/// Restructure files that contain a lot of unsafe code.
Refactor,
/// Move to safer abstractions.
UseSaferAbstractions,
}
/// Estimated effort of a recommendation.
///
/// The derived ordering follows declaration order:
/// `Minimal < Low < Medium < High`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum EffortLevel {
Minimal,
Low,
Medium,
High,
}
/// Expected safety benefit of a recommendation.
///
/// NOTE: the derived ordering follows declaration order, so `Critical`
/// compares as the *smallest* value and `Low` as the largest
/// (`Critical < High < Medium < Low`).
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum SafetyImpact {
Critical,
High,
Medium,
Low,
}
impl Default for UnsafeAuditConfig {
    /// Default policy: scan `src`, skip `target`, `benches` and `examples`,
    /// allow at most five findings per file, require justifications, leave
    /// strict mode off and use the built-in allowed patterns.
    fn default() -> Self {
        let exclude_paths: Vec<PathBuf> = ["target", "benches", "examples"]
            .iter()
            .map(|name| PathBuf::from(*name))
            .collect();
        let allowed_patterns = Self::default_safe_patterns();

        Self {
            scan_paths: vec![PathBuf::from("src")],
            exclude_paths,
            max_unsafe_per_file: 5,
            require_justification: true,
            strict_mode: false,
            allowed_patterns,
        }
    }
}
impl UnsafeAuditConfig {
    /// Built-in allow-list used by the default configuration: SIMD
    /// operations, slices built from raw parts, and FFI bindings.
    fn default_safe_patterns() -> Vec<UnsafePattern> {
        // Turns a list of string literals into owned strings.
        fn owned(items: &[&str]) -> Vec<String> {
            items.iter().map(|item| (*item).to_string()).collect()
        }

        // Assembles one pattern from borrowed pieces.
        fn pattern(
            name: &str,
            signatures: &[&str],
            justification: &str,
            preconditions: &[&str],
        ) -> UnsafePattern {
            UnsafePattern {
                name: name.to_string(),
                signatures: owned(signatures),
                justification: justification.to_string(),
                preconditions: owned(preconditions),
            }
        }

        vec![
            pattern(
                "SIMD Operations",
                &["std::simd::", "std::arch::"],
                "SIMD operations are generally safe when used correctly",
                &[
                    "Input arrays are properly aligned",
                    "Array bounds are checked",
                ],
            ),
            pattern(
                "Slice from Raw Parts",
                &[
                    "std::slice::from_raw_parts",
                    "std::slice::from_raw_parts_mut",
                ],
                "Safe when pointer and length are valid",
                &[
                    "Pointer is non-null and properly aligned",
                    "Length is accurate and doesn't overflow",
                    "Memory is valid for the lifetime",
                ],
            ),
            pattern(
                "FFI Bindings",
                &["extern"],
                "FFI calls to well-tested C libraries",
                &[
                    "C library is memory-safe",
                    "Parameters are validated",
                    "Return values are checked",
                ],
            ),
        ]
    }
}
/// Scans Rust source files for unsafe code and evaluates the findings
/// against an [`UnsafeAuditConfig`].
pub struct UnsafeAuditor {
// Policy and scan settings used by every audit run of this auditor.
config: UnsafeAuditConfig,
}
impl UnsafeAuditor {
pub fn new() -> Self {
Self {
config: UnsafeAuditConfig::default(),
}
}
pub fn with_config(config: UnsafeAuditConfig) -> Self {
Self { config }
}
pub fn audit<P: AsRef<Path>>(&self, root_path: P) -> Result<UnsafeAuditReport> {
let root_path = root_path.as_ref();
let mut findings = HashMap::new();
let mut files_scanned = 0;
let mut total_unsafe_blocks = 0;
for scan_path in &self.config.scan_paths {
let full_path = root_path.join(scan_path);
if full_path.exists() {
self.scan_directory(
&full_path,
&mut findings,
&mut files_scanned,
&mut total_unsafe_blocks,
)?;
}
}
let files_with_unsafe = findings.len();
let passed = self.evaluate_audit_results(&findings);
let summary = self.generate_summary(&findings);
let recommendations = self.generate_recommendations(&findings);
Ok(UnsafeAuditReport {
passed,
files_scanned,
files_with_unsafe,
total_unsafe_blocks,
findings,
summary,
recommendations,
})
}
fn scan_directory(
&self,
dir: &Path,
findings: &mut HashMap<PathBuf, Vec<UnsafeFinding>>,
files_scanned: &mut usize,
total_unsafe: &mut usize,
) -> Result<()> {
if self.should_exclude(dir) {
return Ok(());
}
let entries = fs::read_dir(dir)
.map_err(|e| SklearsError::InvalidInput(format!("Failed to read directory: {e}")))?;
for entry in entries {
let entry = entry
.map_err(|e| SklearsError::InvalidInput(format!("Failed to read entry: {e}")))?;
let path = entry.path();
if path.is_dir() {
self.scan_directory(&path, findings, files_scanned, total_unsafe)?;
} else if path.extension().map(|ext| ext == "rs").unwrap_or(false)
&& !self.should_exclude(&path)
{
*files_scanned += 1;
let file_findings = self.scan_file(&path)?;
*total_unsafe += file_findings.len();
if !file_findings.is_empty() {
findings.insert(path, file_findings);
}
}
}
Ok(())
}
fn scan_file(&self, file_path: &Path) -> Result<Vec<UnsafeFinding>> {
let content = fs::read_to_string(file_path)
.map_err(|e| SklearsError::InvalidInput(format!("Failed to read file: {e}")))?;
let mut findings = Vec::new();
let lines: Vec<&str> = content.lines().collect();
for (line_num, line) in lines.iter().enumerate() {
if let Some(finding) = self.analyze_line(file_path, line_num + 1, line) {
findings.push(finding);
}
}
findings.extend(self.analyze_unsafe_blocks(file_path, &content)?);
Ok(findings)
}
fn analyze_line(&self, file_path: &Path, line_num: usize, line: &str) -> Option<UnsafeFinding> {
let trimmed = line.trim();
if trimmed.starts_with("unsafe") {
let unsafe_type = self.determine_unsafe_type(line);
let severity = self.assess_severity(&unsafe_type, line);
let is_known_safe = self.is_known_safe_pattern(line);
let justification = self.extract_justification(line);
let suggestions = self.generate_suggestions(&unsafe_type, line);
Some(UnsafeFinding {
file: file_path.to_path_buf(),
line: line_num,
column: line.find("unsafe"),
unsafe_type,
code_snippet: line.to_string(),
justification,
is_known_safe,
severity,
suggestions,
})
} else {
None
}
}
fn analyze_unsafe_blocks(&self, file_path: &Path, content: &str) -> Result<Vec<UnsafeFinding>> {
let mut findings = Vec::new();
let mut in_unsafe_block = false;
let mut block_start = 0;
let mut brace_count = 0;
for (line_num, line) in content.lines().enumerate() {
if line.contains("unsafe {") {
in_unsafe_block = true;
block_start = line_num + 1;
brace_count = 1;
} else if in_unsafe_block {
brace_count += line.matches('{').count();
brace_count -= line.matches('}').count();
if brace_count == 0 {
in_unsafe_block = false;
let block_lines: Vec<&str> = content
.lines()
.skip(block_start - 1)
.take(line_num - block_start + 2)
.collect();
let block_content = block_lines.join("\n");
let unsafe_type = UnsafeType::UnsafeBlock;
let severity = self.assess_block_severity(&block_content);
let is_known_safe = self.is_known_safe_pattern(&block_content);
let justification = self.extract_block_justification(&block_content);
let suggestions = self.generate_block_suggestions(&block_content);
findings.push(UnsafeFinding {
file: file_path.to_path_buf(),
line: block_start,
column: None,
unsafe_type,
code_snippet: block_content,
justification,
is_known_safe,
severity,
suggestions,
});
}
}
}
Ok(findings)
}
fn determine_unsafe_type(&self, line: &str) -> UnsafeType {
if line.contains("transmute") {
UnsafeType::Transmute
} else if line.contains("asm!") {
UnsafeType::InlineAssembly
} else if line.contains("static mut") {
UnsafeType::MutableStatic
} else if line.contains("union") {
UnsafeType::UnionFieldAccess
} else if line.contains("*ptr")
|| (line.contains("*") && (line.contains("as *") || line.contains("->")))
{
UnsafeType::RawPointerDeref
} else if line.contains("func()")
|| (line.contains("(")
&& line.contains(")")
&& !line.contains("asm!")
&& !line.contains("transmute"))
{
UnsafeType::UnsafeFunctionCall
} else {
UnsafeType::UnsafeBlock
}
}
fn assess_severity(&self, unsafe_type: &UnsafeType, code: &str) -> SafetySeverity {
match unsafe_type {
UnsafeType::Transmute => SafetySeverity::Critical,
UnsafeType::InlineAssembly => SafetySeverity::Critical,
UnsafeType::MutableStatic => SafetySeverity::High,
UnsafeType::RawPointerDeref => {
if code.contains("null") || code.contains("dangling") {
SafetySeverity::Critical
} else {
SafetySeverity::High
}
}
UnsafeType::UnsafeFunctionCall => {
if self.is_known_safe_pattern(code) {
SafetySeverity::Low
} else {
SafetySeverity::Medium
}
}
UnsafeType::UnionFieldAccess => SafetySeverity::Medium,
UnsafeType::UnsafeBlock => SafetySeverity::Medium,
}
}
fn assess_block_severity(&self, block_content: &str) -> SafetySeverity {
let critical_patterns = ["transmute", "asm!", "null"];
let high_patterns = ["static mut", "*mut", "*const"];
for pattern in &critical_patterns {
if block_content.contains(pattern) {
return SafetySeverity::Critical;
}
}
for pattern in &high_patterns {
if block_content.contains(pattern) {
return SafetySeverity::High;
}
}
SafetySeverity::Medium
}
fn is_known_safe_pattern(&self, code: &str) -> bool {
for pattern in &self.config.allowed_patterns {
for signature in &pattern.signatures {
if code.contains(signature) {
return true;
}
}
}
false
}
fn extract_justification(&self, line: &str) -> Option<String> {
if let Some(comment_start) = line.find("//") {
let comment = &line[comment_start + 2..].trim();
if !comment.is_empty() {
Some(comment.to_string())
} else {
None
}
} else {
None
}
}
fn extract_block_justification(&self, block: &str) -> Option<String> {
let lines: Vec<&str> = block.lines().collect();
for line in lines {
if let Some(comment_start) = line.find("//") {
let comment = &line[comment_start + 2..].trim();
if comment.to_lowercase().contains("safety")
|| comment.to_lowercase().contains("justification")
|| comment.to_lowercase().contains("safe because")
{
return Some(comment.to_string());
}
}
}
None
}
fn generate_suggestions(&self, unsafe_type: &UnsafeType, _code: &str) -> Vec<String> {
match unsafe_type {
UnsafeType::RawPointerDeref => vec![
"Consider using safe array indexing with bounds checking".to_string(),
"Use slice methods instead of raw pointer arithmetic".to_string(),
"Add explicit null pointer checks".to_string(),
],
UnsafeType::UnsafeFunctionCall => vec![
"Document why this function call is safe".to_string(),
"Consider wrapping in a safe abstraction".to_string(),
"Validate all parameters before calling".to_string(),
],
UnsafeType::Transmute => vec![
"Use safe type conversion methods instead".to_string(),
"Consider using union types for type punning".to_string(),
"Add size and alignment assertions".to_string(),
],
UnsafeType::MutableStatic => vec![
"Use thread-local storage or synchronization".to_string(),
"Consider using lazy_static or once_cell".to_string(),
"Document thread safety guarantees".to_string(),
],
UnsafeType::InlineAssembly => vec![
"Document assembly code thoroughly".to_string(),
"Consider using intrinsics instead".to_string(),
"Add extensive testing for different platforms".to_string(),
],
UnsafeType::UnionFieldAccess => vec![
"Document which field is active".to_string(),
"Use tagged unions for safety".to_string(),
"Consider using enums instead".to_string(),
],
UnsafeType::UnsafeBlock => vec![
"Minimize the scope of the unsafe block".to_string(),
"Document all safety invariants".to_string(),
"Add safety assertions where possible".to_string(),
],
}
}
fn generate_block_suggestions(&self, block: &str) -> Vec<String> {
let mut suggestions = Vec::new();
if !block.contains("//") {
suggestions.push("Add comments explaining why this unsafe code is safe".to_string());
}
if block.lines().count() > 10 {
suggestions
.push("Consider breaking this large unsafe block into smaller pieces".to_string());
}
if block.contains("panic!") {
suggestions.push("Avoid panicking inside unsafe blocks".to_string());
}
suggestions.push("Add debug assertions to validate safety invariants".to_string());
suggestions.push("Consider creating a safe wrapper function".to_string());
suggestions
}
fn should_exclude(&self, path: &Path) -> bool {
for exclude_path in &self.config.exclude_paths {
if path.ends_with(exclude_path)
|| path
.components()
.any(|c| c.as_os_str() == exclude_path.as_os_str())
{
return true;
}
}
false
}
fn evaluate_audit_results(&self, findings: &HashMap<PathBuf, Vec<UnsafeFinding>>) -> bool {
if self.config.strict_mode {
return findings.is_empty();
}
for file_findings in findings.values() {
if file_findings.len() > self.config.max_unsafe_per_file {
return false;
}
if self.config.require_justification {
for finding in file_findings {
if finding.severity >= SafetySeverity::High && finding.justification.is_none() {
return false;
}
}
}
}
true
}
fn generate_summary(&self, findings: &HashMap<PathBuf, Vec<UnsafeFinding>>) -> UnsafeSummary {
let mut types_breakdown = HashMap::new();
let mut severity_breakdown = HashMap::new();
let mut file_counts = Vec::new();
let mut patterns = HashSet::new();
for (file, file_findings) in findings {
file_counts.push((file.clone(), file_findings.len()));
for finding in file_findings {
*types_breakdown
.entry(finding.unsafe_type.clone())
.or_insert(0) += 1;
*severity_breakdown
.entry(finding.severity.clone())
.or_insert(0) += 1;
if finding.code_snippet.contains("transmute") {
patterns.insert("transmute usage".to_string());
}
if finding.code_snippet.contains("*mut") || finding.code_snippet.contains("*const")
{
patterns.insert("raw pointer usage".to_string());
}
if finding.code_snippet.contains("std::slice::from_raw_parts") {
patterns.insert("slice from raw parts".to_string());
}
}
}
file_counts.sort_by_key(|item| std::cmp::Reverse(item.1));
let top_unsafe_files = file_counts.into_iter().take(10).collect();
UnsafeSummary {
types_breakdown,
severity_breakdown,
top_unsafe_files,
common_patterns: patterns.into_iter().collect(),
}
}
fn generate_recommendations(
&self,
findings: &HashMap<PathBuf, Vec<UnsafeFinding>>,
) -> Vec<SafetyRecommendation> {
let mut recommendations = Vec::new();
let mut files_with_high_severity = Vec::new();
let mut files_without_justification = Vec::new();
let mut files_with_many_unsafe = Vec::new();
for (file, file_findings) in findings {
let high_severity_count = file_findings
.iter()
.filter(|f| f.severity >= SafetySeverity::High)
.count();
let missing_justification_count = file_findings
.iter()
.filter(|f| f.severity >= SafetySeverity::Medium && f.justification.is_none())
.count();
if high_severity_count > 0 {
files_with_high_severity.push(file.clone());
}
if missing_justification_count > 0 {
files_without_justification.push(file.clone());
}
if file_findings.len() > self.config.max_unsafe_per_file {
files_with_many_unsafe.push(file.clone());
}
}
if !files_with_high_severity.is_empty() {
recommendations.push(SafetyRecommendation {
recommendation_type: RecommendationType::ReplaceWithSafe,
description: "Replace high-severity unsafe code with safe alternatives".to_string(),
affected_files: files_with_high_severity,
effort: EffortLevel::High,
safety_impact: SafetyImpact::Critical,
});
}
if !files_without_justification.is_empty() {
recommendations.push(SafetyRecommendation {
recommendation_type: RecommendationType::ImproveDocumentation,
description: "Add safety justifications for all unsafe code".to_string(),
affected_files: files_without_justification,
effort: EffortLevel::Low,
safety_impact: SafetyImpact::Medium,
});
}
if !files_with_many_unsafe.is_empty() {
recommendations.push(SafetyRecommendation {
recommendation_type: RecommendationType::Refactor,
description: "Refactor files with excessive unsafe code".to_string(),
affected_files: files_with_many_unsafe,
effort: EffortLevel::High,
safety_impact: SafetyImpact::High,
});
}
recommendations
}
pub fn config(&self) -> &UnsafeAuditConfig {
&self.config
}
pub fn set_config(&mut self, config: UnsafeAuditConfig) {
self.config = config;
}
}
impl Default for UnsafeAuditor {
fn default() -> Self {
Self::new()
}
}
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration carries the documented policy values.
    #[test]
    fn test_unsafe_audit_config_default() {
        let UnsafeAuditConfig {
            max_unsafe_per_file,
            require_justification,
            strict_mode,
            allowed_patterns,
            ..
        } = UnsafeAuditConfig::default();
        assert_eq!(max_unsafe_per_file, 5);
        assert!(require_justification);
        assert!(!strict_mode);
        assert!(!allowed_patterns.is_empty());
    }

    /// `new` wires in the default configuration.
    #[test]
    fn test_unsafe_auditor_creation() {
        let limit = UnsafeAuditor::new().config().max_unsafe_per_file;
        assert_eq!(limit, 5);
    }

    /// Each snippet is classified into the expected unsafe type.
    #[test]
    fn test_determine_unsafe_type() {
        let auditor = UnsafeAuditor::new();
        let cases = [
            ("unsafe { *ptr }", UnsafeType::RawPointerDeref),
            ("unsafe { transmute(x) }", UnsafeType::Transmute),
            ("unsafe { static mut X }", UnsafeType::MutableStatic),
            ("unsafe { asm!() }", UnsafeType::InlineAssembly),
            ("unsafe { func() }", UnsafeType::UnsafeFunctionCall),
        ];
        for (code, expected) in cases.iter() {
            assert_eq!(&auditor.determine_unsafe_type(code), expected);
        }
    }

    /// Severity depends on the unsafe type and on hints in the code.
    #[test]
    fn test_assess_severity() {
        let auditor = UnsafeAuditor::new();
        let cases = [
            (UnsafeType::Transmute, "transmute", SafetySeverity::Critical),
            (UnsafeType::InlineAssembly, "asm!", SafetySeverity::Critical),
            (UnsafeType::MutableStatic, "static mut", SafetySeverity::High),
            (UnsafeType::RawPointerDeref, "*null", SafetySeverity::Critical),
            (UnsafeType::RawPointerDeref, "*ptr", SafetySeverity::High),
        ];
        for (kind, code, expected) in cases.iter() {
            assert_eq!(&auditor.assess_severity(kind, code), expected);
        }
    }

    /// Default allowed patterns match by substring.
    #[test]
    fn test_is_known_safe_pattern() {
        let auditor = UnsafeAuditor::new();
        for known in ["std::simd::f32x4::new()", "std::slice::from_raw_parts(ptr, len)"].iter() {
            assert!(auditor.is_known_safe_pattern(known));
        }
        assert!(!auditor.is_known_safe_pattern("transmute(x)"));
    }

    /// A trailing comment is returned as the justification; no comment yields `None`.
    #[test]
    fn test_extract_justification() {
        let auditor = UnsafeAuditor::new();
        let with_comment =
            auditor.extract_justification("unsafe { *ptr } // SAFETY: ptr is guaranteed non-null");
        assert_eq!(
            with_comment,
            Some("SAFETY: ptr is guaranteed non-null".to_string())
        );
        let without_comment = auditor.extract_justification("unsafe { *ptr }");
        assert_eq!(without_comment, None);
    }

    /// Suggestions are tailored to the unsafe type.
    #[test]
    fn test_generate_suggestions() {
        let auditor = UnsafeAuditor::new();
        let pointer_tips = auditor.generate_suggestions(&UnsafeType::RawPointerDeref, "*ptr");
        assert!(!pointer_tips.is_empty());
        assert!(pointer_tips.iter().any(|s| s.contains("bounds checking")));
        let transmute_tips = auditor.generate_suggestions(&UnsafeType::Transmute, "transmute");
        assert!(transmute_tips
            .iter()
            .any(|s| s.contains("safe type conversion")));
    }

    /// Paths containing an excluded component are skipped; others are kept.
    #[test]
    fn test_should_exclude() {
        let auditor = UnsafeAuditor::with_config(UnsafeAuditConfig {
            exclude_paths: vec![PathBuf::from("target"), PathBuf::from("benches")],
            ..Default::default()
        });
        for skipped in ["target/debug/foo", "benches/benchmark.rs"].iter() {
            assert!(auditor.should_exclude(Path::new(skipped)));
        }
        assert!(!auditor.should_exclude(Path::new("src/lib.rs")));
    }

    /// A finding keeps the values it was constructed with.
    #[test]
    fn test_unsafe_finding_creation() {
        let finding = UnsafeFinding {
            file: PathBuf::from("test.rs"),
            line: 10,
            column: Some(5),
            unsafe_type: UnsafeType::RawPointerDeref,
            code_snippet: "unsafe { *ptr }".to_string(),
            justification: Some("ptr is non-null".to_string()),
            is_known_safe: false,
            severity: SafetySeverity::High,
            suggestions: vec!["Use safe indexing".to_string()],
        };
        assert_eq!(finding.file, PathBuf::from("test.rs"));
        assert_eq!(finding.line, 10);
        assert_eq!(finding.unsafe_type, UnsafeType::RawPointerDeref);
        assert_eq!(finding.severity, SafetySeverity::High);
    }

    /// Every severity is strictly greater than the one below it.
    #[test]
    fn test_safety_severity_ordering() {
        let ascending = [
            SafetySeverity::Info,
            SafetySeverity::Low,
            SafetySeverity::Medium,
            SafetySeverity::High,
            SafetySeverity::Critical,
        ];
        for pair in ascending.windows(2) {
            assert!(pair[1] > pair[0]);
        }
    }

    /// Every effort level is strictly greater than the one below it.
    #[test]
    fn test_effort_level_ordering() {
        let ascending = [
            EffortLevel::Minimal,
            EffortLevel::Low,
            EffortLevel::Medium,
            EffortLevel::High,
        ];
        for pair in ascending.windows(2) {
            assert!(pair[1] > pair[0]);
        }
    }
}