#![forbid(unsafe_code)]
#![warn(missing_docs)]
use std::collections::BTreeSet;
use syn::{visit::Visit, ExprPath, ExprUnsafe, ItemFn, Path};
/// Deny lists and flags that define what a "pure" function may not touch.
///
/// Build one with [`Policy::deny_compute_impurity`] (also the [`Default`]) or
/// start from [`Policy::empty`] and insert entries yourself.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Policy {
    /// Fully-qualified item paths that are rejected, e.g. `"std::time::Instant::now"`.
    ///
    /// An expression path spelled exactly like an entry matches it. A bare
    /// identifier equal to an entry's last segment (as after
    /// `use rand::thread_rng;`) also matches.
    pub denied_paths: BTreeSet<String>,
    /// Module or crate paths whose whole subtree is rejected, e.g. `"std::fs"`.
    ///
    /// An expression path matches when it equals the prefix or starts with
    /// `"<prefix>::"`.
    pub denied_prefixes: BTreeSet<String>,
    /// When `true`, `unsafe { ... }` blocks inside the checked function are
    /// reported as violations.
    pub deny_unsafe: bool,
}

impl Policy {
    /// A policy that rejects nothing: no denied paths, no denied prefixes,
    /// and `unsafe` allowed.
    #[must_use]
    pub fn empty() -> Self {
        Self {
            denied_paths: BTreeSet::new(),
            denied_prefixes: BTreeSet::new(),
            deny_unsafe: false,
        }
    }

    /// The policy for deterministic compute code.
    ///
    /// It rejects the following:
    /// - clock reads (std, chrono, minstant, quanta, coarsetime, instant, `tokio::time`);
    /// - randomness (rand, getrandom, rdrand);
    /// - the std stdio handles;
    /// - filesystem, network, process and environment access in std, tokio and
    ///   async-std, plus `async_std::task`, `mio`, `socket2` and `libc`;
    /// - `unsafe` blocks.
    #[must_use]
    pub fn deny_compute_impurity() -> Self {
        // Individual items; rejected when referenced by their full path
        // (or by a bare imported name).
        const DENIED_PATHS: &[&str] = &[
            "std::time::Instant::now",
            "std::time::SystemTime::now",
            "std::time::UNIX_EPOCH",
            "chrono::Utc::now",
            "chrono::Local::now",
            "minstant::Instant::now",
            "quanta::Clock::now",
            "coarsetime::Instant::now",
            "instant::Instant::now",
            "rand::random",
            "rand::thread_rng",
            "rand::rngs::OsRng",
            "rand::rngs::ThreadRng",
            "getrandom::getrandom",
            "getrandom::fill",
            "rdrand::RdRand",
            "std::io::stdin",
            "std::io::stdout",
            "std::io::stderr",
        ];
        // Whole modules/crates; everything beneath them is rejected.
        const DENIED_PREFIXES: &[&str] = &[
            "std::fs",
            "std::net",
            "std::process",
            "std::env",
            "tokio::fs",
            "tokio::net",
            "tokio::io",
            "tokio::time",
            "async_std::fs",
            "async_std::net",
            "async_std::io",
            "async_std::task",
            "mio",
            "socket2",
            "libc",
        ];
        Self {
            denied_paths: DENIED_PATHS.iter().copied().map(String::from).collect(),
            denied_prefixes: DENIED_PREFIXES.iter().copied().map(String::from).collect(),
            deny_unsafe: true,
        }
    }
}

impl Default for Policy {
    /// Same as [`Policy::deny_compute_impurity`].
    fn default() -> Self {
        Self::deny_compute_impurity()
    }
}
/// One place where a checked function breaks the purity policy.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PurityViolation {
    /// The deny-list entry that was hit.
    ///
    /// This is a `denied_paths` item or a `denied_prefixes` module, or a
    /// sentinel such as `"unsafe-block"` for unsafe code.
    pub denied_path: String,
    /// Human-readable description of the offending site.
    ///
    /// Examples: `"std::time::Instant::now (exact)"` or `"unsafe { ... }"`.
    pub site: String,
    /// Coarse category of the violation.
    ///
    /// One of `"clock"`, `"rng"`, `"io"`, `"ffi"`, `"unsafe"` or `"other"`.
    pub reason: &'static str,
}
/// Checks `item` against `policy` and returns every violation found.
///
/// Violations are returned in the order the syntax-tree walk encounters them.
/// The result is empty when the function is pure under `policy`.
///
/// What is inspected:
/// - Expression-position paths (function calls, constants, unit structs) are
///   matched against the policy's deny lists.
/// - `unsafe` blocks are reported when `policy.deny_unsafe` is set.
///
/// What is not flagged:
/// - method calls such as `x.now()`;
/// - type-position paths such as `Option<std::time::Instant>`;
/// - paths that occur only inside macro invocations, which `syn` keeps as
///   unparsed token streams.
#[must_use]
pub fn check_purity(item: &ItemFn, policy: &Policy) -> Vec<PurityViolation> {
    let mut visitor = PurityVisitor {
        policy,
        violations: Vec::new(),
    };
    visitor.visit_item_fn(item);
    visitor.violations
}
/// Runs [`check_purity`] with [`Policy::deny_compute_impurity`], the default
/// policy for deterministic compute code.
#[must_use]
pub fn check_purity_default(item: &ItemFn) -> Vec<PurityViolation> {
    check_purity(item, &Policy::deny_compute_impurity())
}
/// `syn` visitor that walks one function and records every policy violation.
struct PurityVisitor<'p> {
    /// Deny lists and flags being enforced.
    policy: &'p Policy,
    /// Violations in the order the walk encountered them.
    violations: Vec<PurityViolation>,
}

impl<'ast, 'p> Visit<'ast> for PurityVisitor<'p> {
    /// Matches every expression-position path (calls, constants, unit structs)
    /// against the deny lists.
    ///
    /// Method-call names and type-position paths never reach this hook.
    fn visit_expr_path(&mut self, node: &'ast ExprPath) {
        let path_str = path_to_string(&node.path);
        if let Some((denied, kind)) = self.match_against_deny_list(&node.path, &path_str) {
            self.violations.push(PurityViolation {
                denied_path: denied.to_string(),
                site: format!("{path_str} ({kind})"),
                reason: classify_reason(denied),
            });
        }
        syn::visit::visit_expr_path(self, node);
    }

    /// Reports every `unsafe { ... }` block when `deny_unsafe` is set, then
    /// keeps walking so paths inside the block are still checked.
    fn visit_expr_unsafe(&mut self, node: &'ast ExprUnsafe) {
        if self.policy.deny_unsafe {
            self.violations.push(PurityViolation {
                denied_path: "unsafe-block".to_string(),
                site: "unsafe { ... }".to_string(),
                reason: "unsafe",
            });
        }
        syn::visit::visit_expr_unsafe(self, node);
    }

    /// Reports `unsafe fn` signatures when `deny_unsafe` is set.
    ///
    /// This covers the checked function itself and any nested fn or method.
    /// An `unsafe fn` body may dereference raw pointers or call unsafe code
    /// without any `unsafe { ... }` block, so checking blocks alone would let
    /// it through.
    fn visit_signature(&mut self, node: &'ast syn::Signature) {
        if self.policy.deny_unsafe && node.unsafety.is_some() {
            self.violations.push(PurityViolation {
                denied_path: "unsafe-fn".to_string(),
                site: format!("unsafe fn {}", node.ident),
                reason: "unsafe",
            });
        }
        syn::visit::visit_signature(self, node);
    }
}
/// How an expression path matched the deny list.
///
/// The kind is rendered in parentheses at the end of a violation's `site`
/// string.
#[derive(Copy, Clone)]
enum MatchKind {
    /// The path equals a `denied_paths` entry.
    Exact,
    /// A path written without its leading module segments matched the
    /// trailing segments of a `denied_paths` entry.
    ///
    /// Example: a bare `thread_rng` after `use rand::thread_rng;`.
    SingleIdentSuffix,
    /// The path equals, or lies beneath, a `denied_prefixes` module.
    Prefix,
}

impl MatchKind {
    /// Short tag printed by the `Display` impl.
    fn label(self) -> &'static str {
        match self {
            Self::Exact => "exact",
            Self::SingleIdentSuffix => "imported-ident",
            Self::Prefix => "prefix",
        }
    }
}

impl std::fmt::Display for MatchKind {
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        out.write_str(self.label())
    }
}
impl<'p> PurityVisitor<'p> {
    /// Finds the deny-list entry that `path` hits, if any.
    ///
    /// Returns the matching entry and how it matched. `path_str` is `path`
    /// rendered with its segments joined by `::`. When `path` has a leading
    /// `::`, that prefix is ignored for matching, so `::std::fs::read` is
    /// treated exactly like `std::fs::read` instead of slipping past the lists.
    ///
    /// Checks run in this order, and the first hit wins:
    ///
    /// 1. **Exact**: the path equals a `denied_paths` entry.
    /// 2. **Suffix**: the path is the trailing `::`-segments of a
    ///    `denied_paths` entry, as when the item was imported with `use`
    ///    (`thread_rng`, `Instant::now`).
    ///    - This is purely name-based: the `use` items are never consulted.
    ///      A local function named like a denied item (e.g. `random`) is
    ///      therefore a known false positive.
    ///    - When several entries share the suffix, the lexicographically first
    ///      entry (`BTreeSet` order) is reported.
    /// 3. **Prefix**: the path equals a `denied_prefixes` module or lies
    ///    beneath it.
    ///
    /// Limitation: relative paths into a denied module (e.g. `fs::read` after
    /// `use std::fs;`) are not detected.
    fn match_against_deny_list<'d>(
        &'d self,
        path: &Path,
        path_str: &str,
    ) -> Option<(&'d str, MatchKind)> {
        let key = if path.leading_colon.is_some() {
            path_str.strip_prefix("::").unwrap_or(path_str)
        } else {
            path_str
        };
        let denied_paths = &self.policy.denied_paths;

        if let Some(entry) = denied_paths.get(key) {
            return Some((entry.as_str(), MatchKind::Exact));
        }

        // The leading `::` keeps the match segment-aligned: `now` matches
        // `chrono::Local::now` but `snow` does not.
        let needle = format!("::{key}");
        if let Some(entry) = denied_paths.iter().find(|entry| entry.ends_with(&needle)) {
            return Some((entry.as_str(), MatchKind::SingleIdentSuffix));
        }

        self.policy
            .denied_prefixes
            .iter()
            .find(|prefix| match key.strip_prefix(prefix.as_str()) {
                // Require a segment boundary so `std::fsx` does not match `std::fs`.
                Some(rest) => rest.is_empty() || rest.starts_with("::"),
                None => false,
            })
            .map(|prefix| (prefix.as_str(), MatchKind::Prefix))
    }
}
/// Renders `path` as its identifiers joined by `::`.
///
/// A leading `::` is kept when present. Generic arguments (e.g. the `<u8>` in
/// `Vec::<u8>::new`) are dropped.
fn path_to_string(path: &Path) -> String {
    let mut rendered = String::from(if path.leading_colon.is_some() { "::" } else { "" });
    for (index, segment) in path.segments.iter().enumerate() {
        if index != 0 {
            rendered.push_str("::");
        }
        rendered.push_str(&segment.ident.to_string());
    }
    rendered
}
/// Maps a deny-list entry (or the `"unsafe-block"` sentinel) to a coarse
/// violation category.
///
/// The categories are `"unsafe"`, `"clock"`, `"rng"`, `"io"`, `"ffi"` and
/// `"other"`. They are tried in that order, and the first hit wins.
///
/// Clock, I/O and FFI detection compares whole `::`-separated segments rather
/// than raw substrings. Substring matching caused three misclassifications:
/// - `tokio::runtime::Handle` became a clock, because `"runtime::"` contains
///   `"time::"`;
/// - `std::mem::offset_of` became I/O, because it contains `"fs"`;
/// - a bare `std::time` prefix became `"other"`.
fn classify_reason(denied: &str) -> &'static str {
    // Segments naming a clock source or a clock crate.
    const CLOCK_SEGMENTS: &[&str] = &[
        "time",
        "chrono",
        "minstant",
        "quanta",
        "coarsetime",
        "instant",
        "Instant",
        "SystemTime",
        "UNIX_EPOCH",
    ];
    // Segments naming filesystem, network, stdio, process or environment access.
    const IO_SEGMENTS: &[&str] = &["fs", "net", "io", "process", "env", "mio", "socket2"];

    /// True when any `::`-separated segment of `path` is one of `names`.
    fn has_segment(path: &str, names: &[&str]) -> bool {
        path.split("::").any(|segment| names.contains(&segment))
    }

    if denied == "unsafe-block" {
        "unsafe"
    } else if has_segment(denied, CLOCK_SEGMENTS) {
        "clock"
    } else if denied.contains("rand") || denied.contains("Rng") {
        // Substring match on purpose. It covers `rand`, `getrandom`, `rdrand`,
        // `fastrand`, `OsRng`, `ThreadRng`, ...
        "rng"
    } else if has_segment(denied, IO_SEGMENTS) || denied.contains("async_std::task") {
        "io"
    } else if has_segment(denied, &["libc"]) {
        "ffi"
    } else {
        "other"
    }
}
// Unit tests: each case parses a small function with `parse_quote!` and checks
// the violations reported under the default policy (or an empty one).
#[cfg(test)]
// NOTE(review): presumably these restriction lints are enabled crate-wide (not
// visible here); tests opt out because assertion failures panic by design.
#[allow(clippy::panic, clippy::unwrap_used)]
mod tests {
    use super::*;
    use syn::parse_quote;

    // Arithmetic and method calls only: nothing in the default deny lists.
    #[test]
    fn pure_compute_passes() {
        let f: ItemFn = parse_quote! {
            fn compute(a: u32, b: u32) -> u32 {
                a.wrapping_add(b).wrapping_mul(2)
            }
        };
        let violations = check_purity_default(&f);
        assert!(
            violations.is_empty(),
            "pure compute must not trigger violations: {violations:?}"
        );
    }

    // A fully-qualified denied path produces exactly one exact-match violation.
    #[test]
    fn instant_now_full_path_rejected() {
        let f: ItemFn = parse_quote! {
            fn compute() -> u128 {
                let _now = std::time::Instant::now();
                0
            }
        };
        let violations = check_purity_default(&f);
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].denied_path, "std::time::Instant::now");
        assert_eq!(violations[0].reason, "clock");
    }

    // A bare identifier (as after `use rand::thread_rng;`) matches by suffix.
    #[test]
    fn use_imported_thread_rng_single_ident_rejected() {
        let f: ItemFn = parse_quote! {
            fn compute() -> u32 {
                let _r = thread_rng();
                0
            }
        };
        let violations = check_purity_default(&f);
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].denied_path, "rand::thread_rng");
        assert_eq!(violations[0].reason, "rng");
    }

    // Unit-struct values are expression paths too, so `OsRng` is caught.
    #[test]
    fn os_rng_full_path_rejected() {
        let f: ItemFn = parse_quote! {
            fn compute() -> u32 {
                let _ = rand::rngs::OsRng;
                0
            }
        };
        let violations = check_purity_default(&f);
        assert!(violations
            .iter()
            .any(|v| v.denied_path == "rand::rngs::OsRng"));
    }

    // Constant accesses are expression paths and are checked like calls.
    #[test]
    fn unix_epoch_constant_access_rejected() {
        let f: ItemFn = parse_quote! {
            fn compute() -> u128 {
                let _ = std::time::UNIX_EPOCH;
                0
            }
        };
        let violations = check_purity_default(&f);
        assert!(violations
            .iter()
            .any(|v| v.denied_path == "std::time::UNIX_EPOCH"));
    }

    // Mentioning a denied type in a type annotation is not a violation.
    #[test]
    fn type_position_path_does_not_match() {
        let f: ItemFn = parse_quote! {
            fn compute() -> u32 {
                let _x: Option<std::time::Instant> = None;
                0
            }
        };
        let violations = check_purity_default(&f);
        assert!(
            violations.is_empty(),
            "type-position path must not match: {violations:?}"
        );
    }

    // Method-call names (`s.now()`) are not expression paths and are not matched.
    #[test]
    fn shell_defined_now_method_does_not_match() {
        let f: ItemFn = parse_quote! {
            fn compute(s: ShellState) -> u32 {
                let _ = s.now();
                0
            }
        };
        let violations = check_purity_default(&f);
        assert!(
            violations.is_empty(),
            "shell .now() method must not falsely trigger: {violations:?}"
        );
    }

    // Anything under `std::fs` matches the prefix entry and classifies as io.
    #[test]
    fn fs_namespace_prefix_rejected() {
        let f: ItemFn = parse_quote! {
            fn compute() -> u32 {
                let _ = std::fs::read_to_string("/etc/passwd");
                0
            }
        };
        let violations = check_purity_default(&f);
        assert!(
            violations.iter().any(|v| v.denied_path == "std::fs"),
            "std::fs::* prefix must trigger: {violations:?}"
        );
        assert!(violations.iter().any(|v| v.reason == "io"));
    }

    // Anything under `std::net` matches the prefix entry.
    #[test]
    fn net_namespace_prefix_rejected() {
        let f: ItemFn = parse_quote! {
            fn compute() -> u32 {
                let _ = std::net::TcpStream::connect("0.0.0.0:1");
                0
            }
        };
        let violations = check_purity_default(&f);
        assert!(violations.iter().any(|v| v.denied_path == "std::net"));
    }

    // Anything under `std::process` matches the prefix entry.
    #[test]
    fn process_namespace_prefix_rejected() {
        let f: ItemFn = parse_quote! {
            fn compute() -> u32 {
                let _ = std::process::id();
                0
            }
        };
        let violations = check_purity_default(&f);
        assert!(violations.iter().any(|v| v.denied_path == "std::process"));
    }

    // Anything under `std::env` matches the prefix entry.
    #[test]
    fn env_namespace_prefix_rejected() {
        let f: ItemFn = parse_quote! {
            fn compute() -> u32 {
                let _ = std::env::var("HOME");
                0
            }
        };
        let violations = check_purity_default(&f);
        assert!(violations.iter().any(|v| v.denied_path == "std::env"));
    }

    // The whole `libc` crate is a prefix entry and classifies as ffi.
    #[test]
    fn libc_namespace_prefix_rejected() {
        let f: ItemFn = parse_quote! {
            fn compute() -> u32 {
                let _ = libc::getpid();
                0
            }
        };
        let violations = check_purity_default(&f);
        assert!(violations.iter().any(|v| v.denied_path == "libc"));
        assert!(violations.iter().any(|v| v.reason == "ffi"));
    }

    // An `unsafe { .. }` block is reported under the "unsafe-block" sentinel.
    #[test]
    fn unsafe_block_rejected() {
        let f: ItemFn = parse_quote! {
            fn compute(x: u32) -> u32 {
                unsafe {
                    let p = &x as *const u32;
                    *p
                }
            }
        };
        let violations = check_purity_default(&f);
        assert!(
            violations.iter().any(|v| v.denied_path == "unsafe-block"),
            "unsafe block must trigger: {violations:?}"
        );
        assert!(violations.iter().any(|v| v.reason == "unsafe"));
    }

    // `tokio::time::*` is caught by its prefix entry and classified as clock.
    #[test]
    fn tokio_time_prefix_rejected() {
        let f: ItemFn = parse_quote! {
            fn compute() -> u32 {
                let _ = tokio::time::Instant::now();
                0
            }
        };
        let violations = check_purity_default(&f);
        assert!(
            violations
                .iter()
                .any(|v| v.denied_path == "tokio::time" && v.reason == "clock"),
            "tokio::time::* must trigger as clock: {violations:?}"
        );
    }

    // Prefix entries ending in `::io` must classify as io.
    #[test]
    fn tokio_io_prefix_classified_as_io() {
        assert_eq!(classify_reason("tokio::io"), "io");
        assert_eq!(classify_reason("async_std::io"), "io");
    }

    // One representative entry per category, plus an unlisted path ("other").
    #[test]
    fn classify_reason_categorises_correctly() {
        assert_eq!(classify_reason("std::time::Instant::now"), "clock");
        assert_eq!(classify_reason("rand::thread_rng"), "rng");
        assert_eq!(classify_reason("std::fs"), "io");
        assert_eq!(classify_reason("libc"), "ffi");
        assert_eq!(classify_reason("unsafe-block"), "unsafe");
        assert_eq!(classify_reason("blake3::hash"), "other");
    }

    // With `Policy::empty()` nothing is denied, not even `unsafe`.
    #[test]
    fn empty_policy_accepts_anything() {
        let f: ItemFn = parse_quote! {
            fn compute() -> u32 {
                let _ = std::time::Instant::now();
                let _ = rand::thread_rng();
                let _ = std::fs::read_to_string("/etc/passwd");
                unsafe { let _: u8 = 1; }
                0
            }
        };
        let violations = check_purity(&f, &Policy::empty());
        assert!(violations.is_empty());
    }

    // Documents a known limitation: suffix matching is name-based, so a local
    // `fn random()` is still reported as `rand::random`.
    #[test]
    fn local_fn_named_random_currently_false_positives() {
        let f: ItemFn = parse_quote! {
            fn compute() -> u32 {
                fn random() -> u32 { 42 }
                random()
            }
        };
        let violations = check_purity_default(&f);
        assert!(
            violations.iter().any(|v| v.denied_path == "rand::random"),
            "bare-ident `random()` is a known false positive"
        );
    }

    // `getrandom::fill` is an exact deny-list entry.
    #[test]
    fn getrandom_fill_rejected() {
        let f: ItemFn = parse_quote! {
            fn compute(buf: &mut [u8]) -> () {
                let _ = getrandom::fill(buf);
            }
        };
        let violations = check_purity_default(&f);
        assert!(violations
            .iter()
            .any(|v| v.denied_path == "getrandom::fill"));
    }

    // A third-party pure crate (blake3) is not on any deny list.
    #[test]
    fn pure_with_blake3_passes() {
        let f: ItemFn = parse_quote! {
            fn compute(input: &[u8]) -> [u8; 32] {
                let mut h = blake3::Hasher::new();
                h.update(input);
                *h.finalize().as_bytes()
            }
        };
        let violations = check_purity_default(&f);
        assert!(violations.is_empty());
    }
}