use std::collections::{HashMap, HashSet, VecDeque};
use syn::visit::Visit;
use syn::{Item, Visibility};
/// Returns `true` for names the dependency graph must never treat as locally-defined items.
///
/// This covers:
/// * Rust primitives;
/// * common `std` types and traits;
/// * C FFI scalar aliases (`c_int`, `c_void`, ...);
/// * `Self` / `self`;
/// * a few associated-type or formatting names (`Target`, `Output`, `Error`, `Formatter`,
///   `Arguments`);
/// * any C fixed-width integer typedef recognised by `c_integer_primitive`.
pub fn is_builtin(name: &str) -> bool {
    const KNOWN_NAMES: &[&str] = &[
        // Primitive types.
        "bool", "u8", "u16", "u32", "u64", "u128", "usize", "i8", "i16", "i32", "i64", "i128",
        "isize", "f32", "f64", "str", "String", "char",
        // Common std containers.
        "Option", "Result", "Vec", "Box",
        // Marker and derivable traits.
        "Sized", "Send", "Sync", "Copy", "Clone", "Debug", "Display", "Default", "PartialEq",
        "Eq", "PartialOrd", "Ord", "Hash", "Drop",
        // Conversion, iteration, closure and deref traits.
        "From", "Into", "TryFrom", "TryInto", "AsRef", "AsMut", "Iterator", "IntoIterator",
        "Fn", "FnMut", "FnOnce", "Deref", "DerefMut",
        // C FFI scalar aliases.
        "c_void", "c_char", "c_schar", "c_uchar", "c_short", "c_ushort", "c_int", "c_uint",
        "c_long", "c_ulong", "c_longlong", "c_ulonglong", "c_float", "c_double",
        // Self keywords and associated-type / formatting names.
        "Self", "self", "Target", "Error", "Output", "Formatter", "Arguments",
    ];
    is_c_integer_alias(name) || KNOWN_NAMES.contains(&name)
}
/// Returns `true` when `c_integer_primitive` can map `name` to a Rust integer primitive.
fn is_c_integer_alias(name: &str) -> bool {
    matches!(c_integer_primitive(name), Some(_))
}
/// Maps a C fixed-width integer typedef name to the equivalent Rust primitive type name.
///
/// Accepted names are `intN_t`, `uintN_t` and `u_intN_t` for N in {8, 16, 32, 64}. A single
/// leading `__` is allowed (e.g. `__uint32_t` → `"u32"`). Every other name yields `None`.
pub fn c_integer_primitive(name: &str) -> Option<&'static str> {
    let base = name.strip_prefix("__").unwrap_or(name);
    // `u_int` / `uint` mark unsigned typedefs; a bare `int` prefix marks signed ones.
    let (unsigned, suffix) = match base
        .strip_prefix("u_int")
        .or_else(|| base.strip_prefix("uint"))
    {
        Some(rest) => (true, rest),
        None => (false, base.strip_prefix("int")?),
    };
    let bits = suffix.strip_suffix("_t")?;
    let primitive = match (unsigned, bits) {
        (true, "8") => "u8",
        (true, "16") => "u16",
        (true, "32") => "u32",
        (true, "64") => "u64",
        (false, "8") => "i8",
        (false, "16") => "i16",
        (false, "32") => "i32",
        (false, "64") => "i64",
        _ => return None,
    };
    Some(primitive)
}
/// Returns `true` when `path`'s first segment is an external crate name (`std`, `core`,
/// `alloc`, `objc`, `libc`) or a module keyword (`crate`, `super`, `self`).
///
/// A leading `::` is not inspected, so `::std::...` is also external. An empty path is not.
fn is_external_path(path: &syn::Path) -> bool {
    const EXTERNAL_ROOTS: [&str; 8] = [
        "std", "core", "alloc", "objc", "libc", "crate", "super", "self",
    ];
    path.segments.first().map_or(false, |head| {
        let root = head.ident.to_string();
        EXTERNAL_ROOTS.contains(&root.as_str())
    })
}
/// Syntax visitor that accumulates the names of types referenced in type position.
///
/// Only the last segment of each path is stored (`foo::Bar` → `"Bar"`). Filtering of
/// external and builtin names happens in this type's `Visit` implementation.
pub(crate) struct TypeRefCollector {
    /// Type names collected so far.
    pub(crate) types: HashSet<String>,
}

impl TypeRefCollector {
    /// Creates a collector with no recorded names.
    pub(crate) fn new() -> Self {
        let types = HashSet::new();
        Self { types }
    }
}
impl<'ast> Visit<'ast> for TypeRefCollector {
    /// Records the last segment of a non-external, non-builtin type path.
    ///
    /// It then always hands the node to syn's default traversal, so nested types (e.g. the
    /// `Foo` in `Option<Foo>` or `std::vec::Vec<Foo>`) are visited as well.
    fn visit_type_path(&mut self, node: &'ast syn::TypePath) {
        let local_name = if is_external_path(&node.path) {
            None
        } else {
            node.path
                .segments
                .last()
                .map(|segment| segment.ident.to_string())
        };
        if let Some(name) = local_name.filter(|name| !is_builtin(name)) {
            self.types.insert(name);
        }
        syn::visit::visit_type_path(self, node);
    }
}
/// Returns the identifier of a `pub` struct, enum, type alias, fn, const, static, trait or
/// union.
///
/// Returns `None` for any other item kind, or for those kinds when the visibility is not
/// plain `pub` (e.g. `pub(crate)` or private).
fn item_name(item: &Item) -> Option<String> {
    let (vis, ident) = match item {
        Item::Struct(s) => (&s.vis, &s.ident),
        Item::Enum(e) => (&e.vis, &e.ident),
        Item::Type(t) => (&t.vis, &t.ident),
        Item::Fn(f) => (&f.vis, &f.sig.ident),
        Item::Const(c) => (&c.vis, &c.ident),
        Item::Static(s) => (&s.vis, &s.ident),
        Item::Trait(t) => (&t.vis, &t.ident),
        Item::Union(u) => (&u.vis, &u.ident),
        _ => return None,
    };
    if matches!(vis, Visibility::Public(_)) {
        Some(ident.to_string())
    } else {
        None
    }
}
/// Records every `Source as Alias` rename in a `use` tree as the edge `Alias -> {Source}` in
/// both graphs.
///
/// Only `self::` prefixes and `{ ... }` groups are descended into. Other path prefixes (such
/// as `crate::`, `super::` or a module name), plain names and globs are ignored. An existing
/// entry for the alias is replaced.
fn extract_use_renames(
    tree: &syn::UseTree,
    def_graph: &mut HashMap<String, HashSet<String>>,
    all_graph: &mut HashMap<String, HashSet<String>>,
) {
    match tree {
        syn::UseTree::Group(group) => {
            for subtree in &group.items {
                extract_use_renames(subtree, def_graph, all_graph);
            }
        }
        syn::UseTree::Path(prefix) if prefix.ident == "self" => {
            extract_use_renames(&prefix.tree, def_graph, all_graph);
        }
        syn::UseTree::Rename(renamed) => {
            let alias = renamed.rename.to_string();
            let target: HashSet<String> = HashSet::from([renamed.ident.to_string()]);
            all_graph.insert(alias.clone(), target.clone());
            def_graph.insert(alias, target);
        }
        _ => {}
    }
}
/// Name-level dependency graphs produced by `build_dependency_graphs`.
///
/// Each map goes from an item name to the set of names that item references.
#[derive(Debug, Clone, Default)]
pub struct DependencyGraphs {
    /// Edges from item definitions, `extern` declarations and `use ... as` renames.
    /// `impl` block contents do not contribute here; an impl'd type only gets an entry.
    pub definition_deps: HashMap<String, HashSet<String>>,
    /// The `definition_deps` edges plus everything referenced from `impl` blocks.
    pub all_deps: HashMap<String, HashSet<String>>,
}
pub fn build_dependency_graphs(code: &str) -> DependencyGraphs {
let file = match syn::parse_file(code) {
Ok(f) => f,
Err(e) => {
eprintln!(
"Warning: Failed to parse generated code for dep graph: {}",
e
);
return DependencyGraphs {
definition_deps: HashMap::new(),
all_deps: HashMap::new(),
};
}
};
let mut def_graph: HashMap<String, HashSet<String>> = HashMap::new();
let mut all_graph: HashMap<String, HashSet<String>> = HashMap::new();
for item in &file.items {
match item {
Item::ForeignMod(fm) => {
for foreign_item in &fm.items {
let name = match foreign_item {
syn::ForeignItem::Fn(f) => Some(f.sig.ident.to_string()),
syn::ForeignItem::Static(s) => Some(s.ident.to_string()),
syn::ForeignItem::Type(t) => Some(t.ident.to_string()),
_ => None,
};
if let Some(name) = name {
let mut collector = TypeRefCollector::new();
collector.visit_foreign_item(foreign_item);
def_graph.insert(name.clone(), collector.types.clone());
all_graph.insert(name, collector.types);
}
}
}
Item::Impl(impl_item) => {
let type_name = match impl_item.self_ty.as_ref() {
syn::Type::Path(tp) => tp.path.segments.last().map(|s| s.ident.to_string()),
_ => None,
};
if let Some(type_name) = type_name {
let mut collector = TypeRefCollector::new();
collector.visit_item_impl(impl_item);
def_graph.entry(type_name.clone()).or_default();
let all_entry = all_graph.entry(type_name).or_default();
all_entry.extend(collector.types);
if let Some((_, path, _)) = &impl_item.trait_ {
if let Some(seg) = path.segments.last() {
let trait_name = seg.ident.to_string();
if !is_builtin(&trait_name) {
all_entry.insert(trait_name);
}
}
}
}
}
Item::Use(use_item) => {
extract_use_renames(&use_item.tree, &mut def_graph, &mut all_graph);
}
_ => {
if let Some(name) = item_name(item) {
let mut collector = TypeRefCollector::new();
collector.visit_item(item);
let mut refs = collector.types;
refs.remove(&name);
def_graph.insert(name.clone(), refs.clone());
all_graph.insert(name, refs);
}
}
}
}
DependencyGraphs {
definition_deps: def_graph,
all_deps: all_graph,
}
}
pub fn build_dependency_graph(code: &str) -> HashMap<String, HashSet<String>> {
build_dependency_graphs(code).all_deps
}
/// Returns the type names referenced anywhere in an `impl` block.
///
/// For a trait impl, the trait's name (last path segment) is also included. It is skipped when
/// the trait path is external (see `is_external_path`) or the name is builtin.
pub fn impl_block_deps(impl_item: &syn::ItemImpl) -> HashSet<String> {
    let mut collector = TypeRefCollector::new();
    collector.visit_item_impl(impl_item);
    let mut deps = collector.types;
    // The trait path is not a type path, so the collector never records it; add it explicitly.
    let trait_name = impl_item
        .trait_
        .as_ref()
        .map(|(_, path, _)| path)
        .filter(|path| !is_external_path(path))
        .and_then(|path| path.segments.last())
        .map(|segment| segment.ident.to_string())
        .filter(|name| !is_builtin(name));
    if let Some(name) = trait_name {
        deps.insert(name);
    }
    deps
}
/// Breadth-first search over `graph`, returning every name reachable from `roots`.
///
/// * Roots with no entry in `graph` are ignored entirely.
/// * A dependency reached from a root is included even when it has no entry in `graph`;
///   such names are leaves and are not expanded further.
/// * Each name is expanded at most once, so cycles terminate.
pub fn compute_reachable(
    graph: &HashMap<String, HashSet<String>>,
    roots: &HashSet<String>,
) -> HashSet<String> {
    let mut reachable: HashSet<String> = HashSet::new();
    // Borrow names from the inputs; only the result set owns copies.
    let mut queue: VecDeque<&String> = VecDeque::new();
    for root in roots {
        if graph.contains_key(root) && reachable.insert(root.clone()) {
            queue.push_back(root);
        }
    }
    while let Some(current) = queue.pop_front() {
        if let Some(deps) = graph.get(current) {
            for dep in deps {
                if reachable.insert(dep.clone()) && graph.contains_key(dep) {
                    queue.push_back(dep);
                }
            }
        }
    }
    reachable
}
/// Parses `code`, builds its full (`all_deps`) dependency graph and returns every name
/// reachable from `owned_symbols`, following the rules of `compute_reachable`.
pub fn compute_reachable_symbols(code: &str, owned_symbols: &HashSet<String>) -> HashSet<String> {
    compute_reachable(&build_dependency_graph(code), owned_symbols)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects string literals into an owned name set.
    fn name_set(names: &[&str]) -> HashSet<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    /// Builds an adjacency map from `(node, deps)` pairs.
    fn graph_from(edges: Vec<(&str, HashSet<String>)>) -> HashMap<String, HashSet<String>> {
        edges
            .into_iter()
            .map(|(node, deps)| (node.to_string(), deps))
            .collect()
    }

    #[test]
    fn test_simple_reachability() {
        let graph = graph_from(vec![
            ("A", name_set(&["B", "C"])),
            ("B", name_set(&["D"])),
            ("C", name_set(&[])),
            ("D", name_set(&[])),
            ("E", name_set(&[])),
        ]);
        let reachable = compute_reachable(&graph, &name_set(&["A"]));
        for node in ["A", "B", "C", "D"] {
            assert!(reachable.contains(node), "expected {} to be reachable", node);
        }
        assert!(!reachable.contains("E"));
    }

    #[test]
    fn test_cyclic_reachability() {
        let graph = graph_from(vec![
            ("A", name_set(&["B"])),
            ("B", name_set(&["A"])),
            ("C", name_set(&[])),
        ]);
        let reachable = compute_reachable(&graph, &name_set(&["A"]));
        assert!(reachable.contains("A") && reachable.contains("B"));
        assert!(!reachable.contains("C"));
    }

    #[test]
    fn test_build_graph_from_code() {
        let code = r#"
pub type CFIndex = ::std::os::raw::c_long;
pub type CFStringRef = *const __CFString;
pub struct __CFString {
_data: [u8; 0],
}
pub struct MyStruct {
pub field: CFIndex,
pub name: CFStringRef,
}
"#;
        let graph = build_dependency_graph(code);
        let depends_on =
            |node: &str, dep: &str| graph.get(node).map_or(false, |deps| deps.contains(dep));
        assert!(depends_on("CFStringRef", "__CFString"));
        assert!(depends_on("MyStruct", "CFIndex"));
        assert!(depends_on("MyStruct", "CFStringRef"));
        // `CFIndex` aliases an external path, so it has no local dependencies.
        assert!(graph.get("CFIndex").map_or(true, |deps| deps.is_empty()));
    }

    #[test]
    fn test_extern_functions() {
        let code = r#"
pub type CFAllocatorRef = *const CFAllocator;
pub struct CFAllocator { _data: [u8; 0] }
pub type CFStringRef = *const __CFString;
pub struct __CFString { _data: [u8; 0] }
unsafe extern "C" {
pub fn CFStringCreateCopy(alloc: CFAllocatorRef, theString: CFStringRef) -> CFStringRef;
}
"#;
        let graph = build_dependency_graph(code);
        let func_deps = graph
            .get("CFStringCreateCopy")
            .expect("extern fn should have a graph entry");
        for dep in ["CFAllocatorRef", "CFStringRef"] {
            assert!(func_deps.contains(dep), "missing dependency {}", dep);
        }
    }

    #[test]
    fn test_reachable_from_code() {
        let code = r#"
pub type CFIndex = ::std::os::raw::c_long;
pub type CFStringRef = *const __CFString;
pub struct __CFString { _data: [u8; 0] }
pub struct Unrelated { pub x: CFIndex }
unsafe extern "C" {
pub fn CFStringGetLength(theString: CFStringRef) -> CFIndex;
}
"#;
        let owned = name_set(&["CFStringGetLength"]);
        let reachable = compute_reachable_symbols(code, &owned);
        for symbol in ["CFStringGetLength", "CFStringRef", "CFIndex", "__CFString"] {
            assert!(reachable.contains(symbol), "expected {} to be reachable", symbol);
        }
        assert!(!reachable.contains("Unrelated"));
    }
}