1use std::collections::{HashMap, HashSet, VecDeque};
8use syn::visit::Visit;
9use syn::{Item, Visibility};
10
11pub fn is_builtin(name: &str) -> bool {
13 if is_c_integer_alias(name) {
15 return true;
16 }
17
18 matches!(
19 name,
20 "bool"
22 | "u8"
23 | "u16"
24 | "u32"
25 | "u64"
26 | "u128"
27 | "usize"
28 | "i8"
29 | "i16"
30 | "i32"
31 | "i64"
32 | "i128"
33 | "isize"
34 | "f32"
35 | "f64"
36 | "str"
37 | "String"
38 | "char"
39 | "Option"
41 | "Result"
42 | "Vec"
43 | "Box"
44 | "Sized"
45 | "Send"
46 | "Sync"
47 | "Copy"
48 | "Clone"
49 | "Debug"
50 | "Display"
51 | "Default"
52 | "PartialEq"
53 | "Eq"
54 | "PartialOrd"
55 | "Ord"
56 | "Hash"
57 | "Drop"
58 | "From"
59 | "Into"
60 | "TryFrom"
61 | "TryInto"
62 | "AsRef"
63 | "AsMut"
64 | "Iterator"
65 | "IntoIterator"
66 | "Fn"
67 | "FnMut"
68 | "FnOnce"
69 | "Deref"
70 | "DerefMut"
71 | "c_void"
73 | "c_char"
74 | "c_schar"
75 | "c_uchar"
76 | "c_short"
77 | "c_ushort"
78 | "c_int"
79 | "c_uint"
80 | "c_long"
81 | "c_ulong"
82 | "c_longlong"
83 | "c_ulonglong"
84 | "c_float"
85 | "c_double"
86 | "Self"
88 | "self"
89 | "Target"
90 | "Error"
91 | "Output"
92 | "Formatter"
93 | "Arguments"
94 )
95}
96
97fn is_c_integer_alias(name: &str) -> bool {
100 c_integer_primitive(name).is_some()
101}
102
/// Maps a C fixed-width integer typedef name to the equivalent Rust primitive.
///
/// Recognised spellings are `intN_t`, `uintN_t` and `u_intN_t` with N in
/// {8, 16, 32, 64}. Each may carry a single leading `__` (e.g. `__uint32_t`).
/// Any other name, including other widths, yields `None`.
pub fn c_integer_primitive(name: &str) -> Option<&'static str> {
    let base = name.strip_prefix("__").unwrap_or(name);

    // Split into signedness and the remaining "<bits>_t" tail.
    let (unsigned, tail) = match base
        .strip_prefix("u_int")
        .or_else(|| base.strip_prefix("uint"))
    {
        Some(tail) => (true, tail),
        None => (false, base.strip_prefix("int")?),
    };

    let bits = tail.strip_suffix("_t")?;
    let rust_type = match (unsigned, bits) {
        (true, "8") => "u8",
        (true, "16") => "u16",
        (true, "32") => "u32",
        (true, "64") => "u64",
        (false, "8") => "i8",
        (false, "16") => "i16",
        (false, "32") => "i32",
        (false, "64") => "i64",
        _ => return None,
    };
    Some(rust_type)
}
128
129fn is_external_path(path: &syn::Path) -> bool {
131 if let Some(first) = path.segments.first() {
132 let name = first.ident.to_string();
133 matches!(
134 name.as_str(),
135 "std" | "core" | "alloc" | "objc" | "libc" | "crate" | "super" | "self"
136 )
137 } else {
138 false
139 }
140}
141
/// AST visitor that accumulates the names of type paths referenced by a node.
///
/// Only the final segment of each recorded path is kept, so `a::b::Foo` is
/// stored as `Foo`.
pub(crate) struct TypeRefCollector {
    /// Final-segment names of the type paths recorded so far.
    pub(crate) types: HashSet<String>,
}

impl TypeRefCollector {
    /// Creates a collector that has not recorded any type names yet.
    pub(crate) fn new() -> Self {
        let types = HashSet::new();
        TypeRefCollector { types }
    }
}
154
155impl<'ast> Visit<'ast> for TypeRefCollector {
156 fn visit_type_path(&mut self, node: &'ast syn::TypePath) {
157 if !is_external_path(&node.path) {
159 if let Some(seg) = node.path.segments.last() {
160 let name = seg.ident.to_string();
161 if !is_builtin(&name) {
162 self.types.insert(name);
163 }
164 }
165 }
166 syn::visit::visit_type_path(self, node);
168 }
169}
170
171fn item_name(item: &Item) -> Option<String> {
173 match item {
174 Item::Struct(s) if matches!(s.vis, Visibility::Public(_)) => Some(s.ident.to_string()),
175 Item::Enum(e) if matches!(e.vis, Visibility::Public(_)) => Some(e.ident.to_string()),
176 Item::Type(t) if matches!(t.vis, Visibility::Public(_)) => Some(t.ident.to_string()),
177 Item::Fn(f) if matches!(f.vis, Visibility::Public(_)) => Some(f.sig.ident.to_string()),
178 Item::Const(c) if matches!(c.vis, Visibility::Public(_)) => Some(c.ident.to_string()),
179 Item::Static(s) if matches!(s.vis, Visibility::Public(_)) => Some(s.ident.to_string()),
180 Item::Trait(t) if matches!(t.vis, Visibility::Public(_)) => Some(t.ident.to_string()),
181 Item::Union(u) if matches!(u.vis, Visibility::Public(_)) => Some(u.ident.to_string()),
182 _ => None,
183 }
184}
185
186fn extract_use_renames(
192 tree: &syn::UseTree,
193 def_graph: &mut HashMap<String, HashSet<String>>,
194 all_graph: &mut HashMap<String, HashSet<String>>,
195) {
196 match tree {
197 syn::UseTree::Path(path) if path.ident == "self" => {
198 extract_use_renames(&path.tree, def_graph, all_graph);
199 }
200 syn::UseTree::Rename(rename) => {
201 let source = rename.ident.to_string();
202 let alias = rename.rename.to_string();
203 let deps: HashSet<String> = [source].into_iter().collect();
204 def_graph.insert(alias.clone(), deps.clone());
205 all_graph.insert(alias, deps);
206 }
207 syn::UseTree::Group(group) => {
208 for item in &group.items {
209 extract_use_renames(item, def_graph, all_graph);
210 }
211 }
212 _ => {}
213 }
214}
215
/// The two dependency graphs produced by `build_dependency_graphs`.
///
/// Both map a top-level symbol name to the set of (non-builtin) type names it
/// references.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct DependencyGraphs {
    /// References coming from each symbol's own definition only. Types that
    /// have `impl` blocks get an entry here, but the impl's references are
    /// not added.
    pub definition_deps: HashMap<String, HashSet<String>>,
    /// References from each symbol's definition plus those from any `impl`
    /// blocks for it, including the implemented trait's name.
    pub all_deps: HashMap<String, HashSet<String>>,
}
232
233pub fn build_dependency_graphs(code: &str) -> DependencyGraphs {
237 let file = match syn::parse_file(code) {
238 Ok(f) => f,
239 Err(e) => {
240 eprintln!(
241 "Warning: Failed to parse generated code for dep graph: {}",
242 e
243 );
244 return DependencyGraphs {
245 definition_deps: HashMap::new(),
246 all_deps: HashMap::new(),
247 };
248 }
249 };
250
251 let mut def_graph: HashMap<String, HashSet<String>> = HashMap::new();
252 let mut all_graph: HashMap<String, HashSet<String>> = HashMap::new();
253
254 for item in &file.items {
255 match item {
256 Item::ForeignMod(fm) => {
257 for foreign_item in &fm.items {
258 let name = match foreign_item {
259 syn::ForeignItem::Fn(f) => Some(f.sig.ident.to_string()),
260 syn::ForeignItem::Static(s) => Some(s.ident.to_string()),
261 syn::ForeignItem::Type(t) => Some(t.ident.to_string()),
262 _ => None,
263 };
264 if let Some(name) = name {
265 let mut collector = TypeRefCollector::new();
266 collector.visit_foreign_item(foreign_item);
267 def_graph.insert(name.clone(), collector.types.clone());
268 all_graph.insert(name, collector.types);
269 }
270 }
271 }
272 Item::Impl(impl_item) => {
273 let type_name = match impl_item.self_ty.as_ref() {
274 syn::Type::Path(tp) => tp.path.segments.last().map(|s| s.ident.to_string()),
275 _ => None,
276 };
277
278 if let Some(type_name) = type_name {
279 let mut collector = TypeRefCollector::new();
280 collector.visit_item_impl(impl_item);
281
282 def_graph.entry(type_name.clone()).or_default();
287 let all_entry = all_graph.entry(type_name).or_default();
288 all_entry.extend(collector.types);
289
290 if let Some((_, path, _)) = &impl_item.trait_ {
291 if let Some(seg) = path.segments.last() {
292 let trait_name = seg.ident.to_string();
293 if !is_builtin(&trait_name) {
294 all_entry.insert(trait_name);
295 }
296 }
297 }
298 }
299 }
300 Item::Use(use_item) => {
301 extract_use_renames(&use_item.tree, &mut def_graph, &mut all_graph);
306 }
307 _ => {
308 if let Some(name) = item_name(item) {
309 let mut collector = TypeRefCollector::new();
310 collector.visit_item(item);
311
312 let mut refs = collector.types;
313 refs.remove(&name);
314 def_graph.insert(name.clone(), refs.clone());
315 all_graph.insert(name, refs);
316 }
317 }
318 }
319 }
320
321 DependencyGraphs {
322 definition_deps: def_graph,
323 all_deps: all_graph,
324 }
325}
326
327pub fn build_dependency_graph(code: &str) -> HashMap<String, HashSet<String>> {
332 build_dependency_graphs(code).all_deps
333}
334
335pub fn impl_block_deps(impl_item: &syn::ItemImpl) -> HashSet<String> {
342 let mut collector = TypeRefCollector::new();
343 collector.visit_item_impl(impl_item);
344
345 if let Some((_, path, _)) = &impl_item.trait_ {
348 if !is_external_path(path) {
349 if let Some(seg) = path.segments.last() {
350 let trait_name = seg.ident.to_string();
351 if !is_builtin(&trait_name) {
352 collector.types.insert(trait_name);
353 }
354 }
355 }
356 }
357
358 collector.types
359}
360
/// Returns every symbol reachable from `roots` by following edges in `graph`.
///
/// * A root is included and expanded only if it is a key of `graph`. Roots
///   absent from the graph are ignored.
/// * A dependency is included as soon as a reachable symbol references it,
///   but it is expanded further only if it is itself a key of `graph`. Names
///   that are referenced but never defined therefore appear as leaves.
///
/// The walk is breadth-first and terminates on cyclic graphs, because a name
/// is enqueued only the first time it is inserted into the result.
pub fn compute_reachable(
    graph: &HashMap<String, HashSet<String>>,
    roots: &HashSet<String>,
) -> HashSet<String> {
    let mut reachable: HashSet<String> = HashSet::new();
    // Borrow names from the inputs; only the result set owns copies.
    let mut queue: VecDeque<&String> = VecDeque::new();

    for root in roots {
        if graph.contains_key(root) && reachable.insert(root.clone()) {
            queue.push_back(root);
        }
    }

    while let Some(current) = queue.pop_front() {
        if let Some(deps) = graph.get(current) {
            for dep in deps {
                // Record every referenced name, but only keep walking through
                // names that have their own entry in the graph.
                if reachable.insert(dep.clone()) && graph.contains_key(dep) {
                    queue.push_back(dep);
                }
            }
        }
    }

    reachable
}
391
392pub fn compute_reachable_symbols(code: &str, owned_symbols: &HashSet<String>) -> HashSet<String> {
398 let graph = build_dependency_graph(code);
399 compute_reachable(&graph, owned_symbols)
400}
401
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects string literals into an owned set of names.
    fn name_set(names: &[&str]) -> HashSet<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    #[test]
    fn test_simple_reachability() {
        let mut graph = HashMap::new();
        graph.insert("A".to_string(), name_set(&["B", "C"]));
        graph.insert("B".to_string(), name_set(&["D"]));
        for leaf in ["C", "D", "E"].iter() {
            graph.insert(leaf.to_string(), HashSet::new());
        }

        let reachable = compute_reachable(&graph, &name_set(&["A"]));

        for expected in ["A", "B", "C", "D"].iter() {
            assert!(reachable.contains(*expected), "{} should be reachable", expected);
        }
        assert!(!reachable.contains("E"));
    }

    #[test]
    fn test_cyclic_reachability() {
        let mut graph = HashMap::new();
        graph.insert("A".to_string(), name_set(&["B"]));
        graph.insert("B".to_string(), name_set(&["A"]));
        graph.insert("C".to_string(), HashSet::new());

        let reachable = compute_reachable(&graph, &name_set(&["A"]));

        assert!(reachable.contains("A") && reachable.contains("B"));
        assert!(!reachable.contains("C"));
    }

    #[test]
    fn test_build_graph_from_code() {
        let code = r#"
pub type CFIndex = ::std::os::raw::c_long;
pub type CFStringRef = *const __CFString;
pub struct __CFString {
    _data: [u8; 0],
}
pub struct MyStruct {
    pub field: CFIndex,
    pub name: CFStringRef,
}
"#;
        let graph = build_dependency_graph(code);

        // The alias must point at the opaque struct it wraps.
        assert!(matches!(
            graph.get("CFStringRef"),
            Some(deps) if deps.contains("__CFString")
        ));

        let struct_deps = &graph["MyStruct"];
        assert!(struct_deps.contains("CFIndex"));
        assert!(struct_deps.contains("CFStringRef"));

        // `c_long` is builtin, so CFIndex has no local dependencies (or no entry).
        assert!(graph.get("CFIndex").into_iter().all(|deps| deps.is_empty()));
    }

    #[test]
    fn test_extern_functions() {
        let code = r#"
pub type CFAllocatorRef = *const CFAllocator;
pub struct CFAllocator { _data: [u8; 0] }
pub type CFStringRef = *const __CFString;
pub struct __CFString { _data: [u8; 0] }
unsafe extern "C" {
    pub fn CFStringCreateCopy(alloc: CFAllocatorRef, theString: CFStringRef) -> CFStringRef;
}
"#;
        let graph = build_dependency_graph(code);

        let fn_deps = &graph["CFStringCreateCopy"];
        for dep in ["CFAllocatorRef", "CFStringRef"].iter() {
            assert!(fn_deps.contains(*dep), "missing dependency {}", dep);
        }
    }

    #[test]
    fn test_reachable_from_code() {
        let code = r#"
pub type CFIndex = ::std::os::raw::c_long;
pub type CFStringRef = *const __CFString;
pub struct __CFString { _data: [u8; 0] }
pub struct Unrelated { pub x: CFIndex }
unsafe extern "C" {
    pub fn CFStringGetLength(theString: CFStringRef) -> CFIndex;
}
"#;
        let reachable = compute_reachable_symbols(code, &name_set(&["CFStringGetLength"]));

        for expected in ["CFStringGetLength", "CFStringRef", "CFIndex", "__CFString"].iter() {
            assert!(reachable.contains(*expected), "{} should be reachable", expected);
        }
        assert!(!reachable.contains("Unrelated"));
    }
}
509}