use super::types::UnsafeOp;
use syn::{visit::Visit, Block, Expr};
/// Syntax visitor that accumulates the unsafe operations it recognises while
/// walking a `syn` block.
///
/// Construct one with [`UnsafeOpCollector::new`] (or `Default`), drive it with a
/// `syn::visit::Visit` method, then read the results from [`ops`](Self::ops);
/// [`UnsafeOpCollector::collect`] does all three steps in one call.
pub struct UnsafeOpCollector {
    /// Operations recorded so far, appended in the order the visitor met them.
    pub ops: Vec<UnsafeOp>,
}
impl UnsafeOpCollector {
pub fn new() -> Self {
Self { ops: Vec::new() }
}
pub fn collect(block: &Block) -> Vec<UnsafeOp> {
let mut collector = Self::new();
collector.visit_block(block);
collector.ops
}
}
impl Default for UnsafeOpCollector {
fn default() -> Self {
Self::new()
}
}
impl<'ast> Visit<'ast> for UnsafeOpCollector {
fn visit_expr(&mut self, expr: &'ast Expr) {
match expr {
Expr::Call(call) => {
if let Expr::Path(path) = &*call.func {
let path_str = quote::quote!(#path).to_string();
if path_str.contains("transmute") {
self.ops.push(UnsafeOp::Transmute);
} else if !path_str.starts_with("std::")
&& !path_str.starts_with("core::")
&& path.path.segments.len() == 1
{
self.ops.push(UnsafeOp::FFICall);
}
}
}
Expr::MethodCall(method) => {
let method_name = method.method.to_string();
match method_name.as_str() {
"read" | "read_volatile" | "read_unaligned" => {
self.ops.push(UnsafeOp::RawPointerRead);
}
"write" | "write_volatile" | "write_unaligned" => {
self.ops.push(UnsafeOp::RawPointerWrite);
}
"offset" | "add" | "sub" | "wrapping_offset" | "wrapping_add"
| "wrapping_sub" => {
self.ops.push(UnsafeOp::PointerArithmetic);
}
_ => {}
}
}
Expr::Assign(assign) => {
if matches!(&*assign.left, Expr::Unary(unary) if matches!(unary.op, syn::UnOp::Deref(_)))
{
self.ops.push(UnsafeOp::RawPointerWrite);
}
}
Expr::Unary(unary) if matches!(unary.op, syn::UnOp::Deref(_)) => {
self.ops.push(UnsafeOp::RawPointerRead);
}
Expr::Path(path) => {
let path_str = quote::quote!(#path).to_string();
if path_str.chars().all(|c| c.is_uppercase() || c == '_') {
self.ops.push(UnsafeOp::MutableStatic);
}
}
Expr::Field(_field) => {
self.ops.push(UnsafeOp::UnionFieldRead);
}
_ => {}
}
syn::visit::visit_expr(self, expr);
}
}
/// Returns every unsafe operation the syntactic collector recognises in `block`.
///
/// Thin convenience wrapper over [`UnsafeOpCollector::collect`].
pub fn classify_unsafe_operations(block: &Block) -> Vec<UnsafeOp> {
    UnsafeOpCollector::collect(block)
}
/// Verdict produced by [`analyze_unsafe_block`] for one unsafe block.
///
/// As produced by `analyze_unsafe_block`, at most one of the two flags is `true`;
/// both are `false` when no unsafe operation was recognised (which is also the
/// `Default` value).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct UnsafeAnalysisResult {
    /// At least one recognised operation is classified as impure.
    pub has_impure: bool,
    /// Operations were recognised and none of them is classified as impure.
    pub has_pure_unsafe: bool,
}
pub fn analyze_unsafe_block(block: &Block) -> UnsafeAnalysisResult {
let ops = classify_unsafe_operations(block);
let has_impure = ops.iter().any(|op| {
matches!(
op,
UnsafeOp::FFICall
| UnsafeOp::RawPointerWrite
| UnsafeOp::MutableStatic
| UnsafeOp::UnionFieldWrite
)
});
UnsafeAnalysisResult {
has_impure,
has_pure_unsafe: !has_impure && !ops.is_empty(),
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    /// Asserts that `block` yields `expected` among its classified operations,
    /// and that the block verdict is impure exactly when `impure` is set and
    /// purely-unsafe otherwise.
    fn assert_classified(block: Block, expected: UnsafeOp, impure: bool) {
        let ops = classify_unsafe_operations(&block);
        assert!(ops.contains(&expected));
        let verdict = analyze_unsafe_block(&block);
        assert_eq!(verdict.has_impure, impure);
        assert_eq!(verdict.has_pure_unsafe, !impure);
    }

    #[test]
    fn test_transmute_is_pure() {
        assert_classified(
            parse_quote!({ std::mem::transmute(x) }),
            UnsafeOp::Transmute,
            false,
        );
    }

    #[test]
    fn test_pointer_read_is_pure() {
        assert_classified(parse_quote!({ ptr.read() }), UnsafeOp::RawPointerRead, false);
    }

    #[test]
    fn test_pointer_write_is_impure() {
        assert_classified(
            parse_quote!({ ptr.write(value) }),
            UnsafeOp::RawPointerWrite,
            true,
        );
    }

    #[test]
    fn test_pointer_arithmetic_is_pure() {
        assert_classified(
            parse_quote!({ ptr.offset(1) }),
            UnsafeOp::PointerArithmetic,
            false,
        );
    }
}