Skip to main content

apple_bindgen/deps/
depgraph.rs

1//! Symbol dependency graph construction and reachability analysis.
2//!
3//! Parses generated Rust bindings with `syn` and builds a directed graph
4//! of type references between symbols. Then computes the set of reachable
5//! symbols from a given set of root (owned) symbols via BFS.
6
7use std::collections::{HashMap, HashSet, VecDeque};
8use syn::visit::Visit;
9use syn::{Item, Visibility};
10
11/// Built-in types and paths that should not be tracked as dependencies.
12pub fn is_builtin(name: &str) -> bool {
13    // C/BSD integer type aliases: u_int32_t, int64_t, uint16_t, __int32_t, __uint16_t, etc.
14    if is_c_integer_alias(name) {
15        return true;
16    }
17
18    matches!(
19        name,
20        // Rust primitives
21        "bool"
22            | "u8"
23            | "u16"
24            | "u32"
25            | "u64"
26            | "u128"
27            | "usize"
28            | "i8"
29            | "i16"
30            | "i32"
31            | "i64"
32            | "i128"
33            | "isize"
34            | "f32"
35            | "f64"
36            | "str"
37            | "String"
38            | "char"
39            // std/core types
40            | "Option"
41            | "Result"
42            | "Vec"
43            | "Box"
44            | "Sized"
45            | "Send"
46            | "Sync"
47            | "Copy"
48            | "Clone"
49            | "Debug"
50            | "Display"
51            | "Default"
52            | "PartialEq"
53            | "Eq"
54            | "PartialOrd"
55            | "Ord"
56            | "Hash"
57            | "Drop"
58            | "From"
59            | "Into"
60            | "TryFrom"
61            | "TryInto"
62            | "AsRef"
63            | "AsMut"
64            | "Iterator"
65            | "IntoIterator"
66            | "Fn"
67            | "FnMut"
68            | "FnOnce"
69            | "Deref"
70            | "DerefMut"
71            // C types from std::os::raw
72            | "c_void"
73            | "c_char"
74            | "c_schar"
75            | "c_uchar"
76            | "c_short"
77            | "c_ushort"
78            | "c_int"
79            | "c_uint"
80            | "c_long"
81            | "c_ulong"
82            | "c_longlong"
83            | "c_ulonglong"
84            | "c_float"
85            | "c_double"
86            // Special
87            | "Self"
88            | "self"
89            | "Target"
90            | "Error"
91            | "Output"
92            | "Formatter"
93            | "Arguments"
94    )
95}
96
97/// C/BSD integer type aliases that are just primitive mappings.
98/// Patterns: int32_t, uint16_t, u_int64_t, __int32_t, __uint16_t, etc.
99fn is_c_integer_alias(name: &str) -> bool {
100    c_integer_primitive(name).is_some()
101}
102
/// Translate a C fixed-width integer alias into the matching Rust primitive name.
///
/// One leading `__` is ignored. The remainder must be `u_int`, `uint`, or `int`
/// followed by exactly `8_t`, `16_t`, `32_t`, or `64_t`; anything else yields
/// `None`.
pub fn c_integer_primitive(name: &str) -> Option<&'static str> {
    let stem = name.strip_prefix("__").unwrap_or(name);

    // Try the unsigned spellings first; `int` would otherwise never see them
    // anyway, but the order mirrors the accepted prefixes.
    let (unsigned, width) = stem
        .strip_prefix("u_int")
        .or_else(|| stem.strip_prefix("uint"))
        .map(|width| (true, width))
        .or_else(|| stem.strip_prefix("int").map(|width| (false, width)))?;

    let primitive = match (unsigned, width) {
        (true, "8_t") => "u8",
        (true, "16_t") => "u16",
        (true, "32_t") => "u32",
        (true, "64_t") => "u64",
        (false, "8_t") => "i8",
        (false, "16_t") => "i16",
        (false, "32_t") => "i32",
        (false, "64_t") => "i64",
        _ => return None,
    };
    Some(primitive)
}
128
129/// Check if a path starts with a well-known module prefix (std, core, objc, etc.)
130fn is_external_path(path: &syn::Path) -> bool {
131    if let Some(first) = path.segments.first() {
132        let name = first.ident.to_string();
133        matches!(
134            name.as_str(),
135            "std" | "core" | "alloc" | "objc" | "libc" | "crate" | "super" | "self"
136        )
137    } else {
138        false
139    }
140}
141
/// Visitor state that accumulates the type names referenced by a syn AST node.
pub(crate) struct TypeRefCollector {
    /// Names gathered so far (duplicates collapse because this is a set).
    pub(crate) types: HashSet<String>,
}

impl TypeRefCollector {
    /// Create a collector that has not recorded any type names yet.
    pub(crate) fn new() -> Self {
        let types = HashSet::new();
        TypeRefCollector { types }
    }
}
154
155impl<'ast> Visit<'ast> for TypeRefCollector {
156    fn visit_type_path(&mut self, node: &'ast syn::TypePath) {
157        // Skip paths starting with std::, core::, objc::, etc.
158        if !is_external_path(&node.path) {
159            if let Some(seg) = node.path.segments.last() {
160                let name = seg.ident.to_string();
161                if !is_builtin(&name) {
162                    self.types.insert(name);
163                }
164            }
165        }
166        // Continue visiting generic arguments within the path
167        syn::visit::visit_type_path(self, node);
168    }
169}
170
171/// Extract the name of a public item (same logic as isolation.rs extract_item_name)
172fn item_name(item: &Item) -> Option<String> {
173    match item {
174        Item::Struct(s) if matches!(s.vis, Visibility::Public(_)) => Some(s.ident.to_string()),
175        Item::Enum(e) if matches!(e.vis, Visibility::Public(_)) => Some(e.ident.to_string()),
176        Item::Type(t) if matches!(t.vis, Visibility::Public(_)) => Some(t.ident.to_string()),
177        Item::Fn(f) if matches!(f.vis, Visibility::Public(_)) => Some(f.sig.ident.to_string()),
178        Item::Const(c) if matches!(c.vis, Visibility::Public(_)) => Some(c.ident.to_string()),
179        Item::Static(s) if matches!(s.vis, Visibility::Public(_)) => Some(s.ident.to_string()),
180        Item::Trait(t) if matches!(t.vis, Visibility::Public(_)) => Some(t.ident.to_string()),
181        Item::Union(u) if matches!(u.vis, Visibility::Public(_)) => Some(u.ident.to_string()),
182        _ => None,
183    }
184}
185
186/// Extract `pub use self::X as Y;` renames and register them in the dep graphs.
187///
188/// Handles patterns like:
189/// - `pub use self::ppd_ui_e as ppd_ui_t;` → Y="ppd_ui_t" depends on X="ppd_ui_e"
190/// - `pub use self::{A as B, C as D};` → group renames
191fn extract_use_renames(
192    tree: &syn::UseTree,
193    def_graph: &mut HashMap<String, HashSet<String>>,
194    all_graph: &mut HashMap<String, HashSet<String>>,
195) {
196    match tree {
197        syn::UseTree::Path(path) if path.ident == "self" => {
198            extract_use_renames(&path.tree, def_graph, all_graph);
199        }
200        syn::UseTree::Rename(rename) => {
201            let source = rename.ident.to_string();
202            let alias = rename.rename.to_string();
203            let deps: HashSet<String> = [source].into_iter().collect();
204            def_graph.insert(alias.clone(), deps.clone());
205            all_graph.insert(alias, deps);
206        }
207        syn::UseTree::Group(group) => {
208            for item in &group.items {
209                extract_use_renames(item, def_graph, all_graph);
210            }
211        }
212        _ => {}
213    }
214}
215
/// Split dependency graphs: definition-level deps vs all deps (including impl blocks).
///
/// `definition_deps` contains only dependencies from struct/type/fn/extern definitions.
/// `all_deps` additionally includes dependencies from impl blocks merged into the self type.
///
/// The split is needed because impl block deps (e.g., ObjC category extensions referencing
/// types from other frameworks) should not cause the base type to be removed during
/// dependency closure in the ownership phase.
///
/// The type derives the common value traits so callers can print, copy,
/// compare, and default-construct (two empty maps) a pair of graphs.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct DependencyGraphs {
    /// Dependencies from definitions only (struct, type, fn, extern).
    /// Used for dependency closure (removing symbols whose deps are unavailable).
    pub definition_deps: HashMap<String, HashSet<String>>,
    /// All dependencies including impl block references.
    /// Used for BFS reachability computation.
    pub all_deps: HashMap<String, HashSet<String>>,
}
232
233/// Build split dependency graphs from generated Rust code.
234///
235/// Returns `DependencyGraphs` with both definition-only and full (incl. impl) dep maps.
236pub fn build_dependency_graphs(code: &str) -> DependencyGraphs {
237    let file = match syn::parse_file(code) {
238        Ok(f) => f,
239        Err(e) => {
240            eprintln!(
241                "Warning: Failed to parse generated code for dep graph: {}",
242                e
243            );
244            return DependencyGraphs {
245                definition_deps: HashMap::new(),
246                all_deps: HashMap::new(),
247            };
248        }
249    };
250
251    let mut def_graph: HashMap<String, HashSet<String>> = HashMap::new();
252    let mut all_graph: HashMap<String, HashSet<String>> = HashMap::new();
253
254    for item in &file.items {
255        match item {
256            Item::ForeignMod(fm) => {
257                for foreign_item in &fm.items {
258                    let name = match foreign_item {
259                        syn::ForeignItem::Fn(f) => Some(f.sig.ident.to_string()),
260                        syn::ForeignItem::Static(s) => Some(s.ident.to_string()),
261                        syn::ForeignItem::Type(t) => Some(t.ident.to_string()),
262                        _ => None,
263                    };
264                    if let Some(name) = name {
265                        let mut collector = TypeRefCollector::new();
266                        collector.visit_foreign_item(foreign_item);
267                        def_graph.insert(name.clone(), collector.types.clone());
268                        all_graph.insert(name, collector.types);
269                    }
270                }
271            }
272            Item::Impl(impl_item) => {
273                let type_name = match impl_item.self_ty.as_ref() {
274                    syn::Type::Path(tp) => tp.path.segments.last().map(|s| s.ident.to_string()),
275                    _ => None,
276                };
277
278                if let Some(type_name) = type_name {
279                    let mut collector = TypeRefCollector::new();
280                    collector.visit_item_impl(impl_item);
281
282                    // Impl block deps go ONLY into all_deps, not definition_deps.
283                    // This prevents ObjC category extensions from pulling in
284                    // cross-framework types that would cause the base type to be
285                    // removed during dependency closure.
286                    def_graph.entry(type_name.clone()).or_default();
287                    let all_entry = all_graph.entry(type_name).or_default();
288                    all_entry.extend(collector.types);
289
290                    if let Some((_, path, _)) = &impl_item.trait_ {
291                        if let Some(seg) = path.segments.last() {
292                            let trait_name = seg.ident.to_string();
293                            if !is_builtin(&trait_name) {
294                                all_entry.insert(trait_name);
295                            }
296                        }
297                    }
298                }
299            }
300            Item::Use(use_item) => {
301                // Track `pub use self::X as Y;` re-exports.
302                // Registers Y as a symbol with dependency on X, so that:
303                // - BFS discovers X transitively through Y
304                // - Dependency closure verifies X is available before keeping Y
305                extract_use_renames(&use_item.tree, &mut def_graph, &mut all_graph);
306            }
307            _ => {
308                if let Some(name) = item_name(item) {
309                    let mut collector = TypeRefCollector::new();
310                    collector.visit_item(item);
311
312                    let mut refs = collector.types;
313                    refs.remove(&name);
314                    def_graph.insert(name.clone(), refs.clone());
315                    all_graph.insert(name, refs);
316                }
317            }
318        }
319    }
320
321    DependencyGraphs {
322        definition_deps: def_graph,
323        all_deps: all_graph,
324    }
325}
326
327/// Build a symbol dependency graph from generated Rust code.
328///
329/// Returns a map from each symbol name to the set of symbol names it references.
330/// This is the legacy API that merges impl block deps into the type entry.
331pub fn build_dependency_graph(code: &str) -> HashMap<String, HashSet<String>> {
332    build_dependency_graphs(code).all_deps
333}
334
335/// Extract type references from a single impl block.
336///
337/// Returns all non-builtin type names referenced within the impl block,
338/// including the trait name (if it's a trait impl). The trait path goes
339/// through `visit_path` rather than `visit_type_path`, so it must be
340/// handled explicitly.
341pub fn impl_block_deps(impl_item: &syn::ItemImpl) -> HashSet<String> {
342    let mut collector = TypeRefCollector::new();
343    collector.visit_item_impl(impl_item);
344
345    // Trait name is visited via visit_path, not visit_type_path,
346    // so TypeRefCollector doesn't capture it. Add it explicitly.
347    if let Some((_, path, _)) = &impl_item.trait_ {
348        if !is_external_path(path) {
349            if let Some(seg) = path.segments.last() {
350                let trait_name = seg.ident.to_string();
351                if !is_builtin(&trait_name) {
352                    collector.types.insert(trait_name);
353                }
354            }
355        }
356    }
357
358    collector.types
359}
360
/// Breadth-first search over `graph`, starting from `roots`.
///
/// Roots that are not keys of `graph` are ignored. Every dependency reached
/// along the way is part of the result, including names that have no entry of
/// their own in `graph`; such names are simply not expanded any further.
pub fn compute_reachable(
    graph: &HashMap<String, HashSet<String>>,
    roots: &HashSet<String>,
) -> HashSet<String> {
    let mut visited: HashSet<String> = HashSet::new();
    let mut pending: VecDeque<&String> = VecDeque::new();

    // Start from the roots the graph actually knows about.
    for root in roots.iter().filter(|root| graph.contains_key(*root)) {
        if visited.insert(root.clone()) {
            pending.push_back(root);
        }
    }

    while let Some(node) = pending.pop_front() {
        for dep in graph.get(node).into_iter().flatten() {
            let first_visit = visited.insert(dep.clone());
            // Leaf names (no outgoing edges recorded) are reachable but
            // there is nothing to expand, so they never enter the queue.
            if first_visit && graph.contains_key(dep) {
                pending.push_back(dep);
            }
        }
    }

    visited
}
391
392/// High-level API: compute reachable symbols from generated code and owned symbol set.
393///
394/// 1. Builds the dependency graph from the generated Rust code
395/// 2. Intersects owned_symbols with the graph (to find roots that actually exist)
396/// 3. BFS from roots to find all transitively reachable symbols
397pub fn compute_reachable_symbols(code: &str, owned_symbols: &HashSet<String>) -> HashSet<String> {
398    let graph = build_dependency_graph(code);
399    compute_reachable(&graph, owned_symbols)
400}
401
#[cfg(test)]
mod tests {
    use super::*;

    /// Collect string literals into an owned set of symbol names.
    fn names(items: &[&str]) -> HashSet<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    /// Add `node` with outgoing edges `targets` to a hand-built graph.
    fn link(graph: &mut HashMap<String, HashSet<String>>, node: &str, targets: &[&str]) {
        graph.insert(node.to_string(), names(targets));
    }

    #[test]
    fn test_simple_reachability() {
        let mut graph = HashMap::new();
        link(&mut graph, "A", &["B", "C"]);
        link(&mut graph, "B", &["D"]);
        link(&mut graph, "C", &[]);
        link(&mut graph, "D", &[]);
        link(&mut graph, "E", &[]); // unreachable

        let reachable = compute_reachable(&graph, &names(&["A"]));

        for expected in &["A", "B", "C", "D"] {
            assert!(reachable.contains(*expected));
        }
        assert!(!reachable.contains("E"));
    }

    #[test]
    fn test_cyclic_reachability() {
        let mut graph = HashMap::new();
        link(&mut graph, "A", &["B"]);
        link(&mut graph, "B", &["A"]);
        link(&mut graph, "C", &[]);

        let reachable = compute_reachable(&graph, &names(&["A"]));

        // The A <-> B cycle must terminate and include both nodes.
        assert!(reachable.contains("A"));
        assert!(reachable.contains("B"));
        assert!(!reachable.contains("C"));
    }

    #[test]
    fn test_build_graph_from_code() {
        let code = r#"
pub type CFIndex = ::std::os::raw::c_long;
pub type CFStringRef = *const __CFString;
pub struct __CFString {
    _data: [u8; 0],
}
pub struct MyStruct {
    pub field: CFIndex,
    pub name: CFStringRef,
}
"#;
        let graph = build_dependency_graph(code);

        // CFStringRef references __CFString
        assert!(matches!(
            graph.get("CFStringRef"),
            Some(deps) if deps.contains("__CFString")
        ));

        // MyStruct references CFIndex and CFStringRef
        let my_deps = &graph["MyStruct"];
        for expected in &["CFIndex", "CFStringRef"] {
            assert!(my_deps.contains(*expected));
        }

        // CFIndex should not reference anything (c_long is builtin)
        assert!(graph.get("CFIndex").into_iter().all(|deps| deps.is_empty()));
    }

    #[test]
    fn test_extern_functions() {
        let code = r#"
pub type CFAllocatorRef = *const CFAllocator;
pub struct CFAllocator { _data: [u8; 0] }
pub type CFStringRef = *const __CFString;
pub struct __CFString { _data: [u8; 0] }
unsafe extern "C" {
    pub fn CFStringCreateCopy(alloc: CFAllocatorRef, theString: CFStringRef) -> CFStringRef;
}
"#;
        let graph = build_dependency_graph(code);

        let func_deps = &graph["CFStringCreateCopy"];
        for expected in &["CFAllocatorRef", "CFStringRef"] {
            assert!(func_deps.contains(*expected));
        }
    }

    #[test]
    fn test_reachable_from_code() {
        let code = r#"
pub type CFIndex = ::std::os::raw::c_long;
pub type CFStringRef = *const __CFString;
pub struct __CFString { _data: [u8; 0] }
pub struct Unrelated { pub x: CFIndex }
unsafe extern "C" {
    pub fn CFStringGetLength(theString: CFStringRef) -> CFIndex;
}
"#;
        let owned = names(&["CFStringGetLength"]);
        let reachable = compute_reachable_symbols(code, &owned);

        for expected in &["CFStringGetLength", "CFStringRef", "CFIndex", "__CFString"] {
            assert!(reachable.contains(*expected));
        }
        assert!(!reachable.contains("Unrelated"));
    }
}
509}