Skip to main content

mangle_analysis/
name_trie.rs

1// Copyright 2025 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Name trie for Mangle's name constant hierarchy.
16//!
17//! Maps name constants (like `/animal/dog`) to their most precise type
18//! via longest-prefix matching. Used by the bounds checker to infer types
19//! for name constants appearing in facts.
20
21use fxhash::FxHashMap;
22use mangle_ir::{Inst, InstId, Ir};
23
24use crate::type_expr;
25
/// A trie over the `/`-separated segments of Mangle name constants.
///
/// Every node stands for the sequence of segments on the path from the
/// root to that node; the root itself corresponds to the empty path.
#[derive(Debug, Default)]
pub struct NameTrie {
    /// Child nodes, keyed by the next path segment (stored without any `/`).
    children: FxHashMap<String, NameTrie>,
    /// True when the path ending at this node was itself inserted via
    /// `add`, rather than existing only as a prefix of a longer name.
    is_terminal: bool,
}
32
33impl NameTrie {
34    pub fn new() -> Self {
35        Self::default()
36    }
37
38    /// Inserts a name path (e.g. "/foo/bar") into the trie.
39    pub fn add(&mut self, name: &str) {
40        let parts = split_name(name);
41        let mut node = self;
42        for part in parts {
43            node = node
44                .children
45                .entry(part.to_string())
46                .or_default();
47        }
48        node.is_terminal = true;
49    }
50
51    /// Returns true if the exact name is in the trie.
52    pub fn contains(&self, name: &str) -> bool {
53        let parts = split_name(name);
54        let mut node = self;
55        for part in parts {
56            match node.children.get(part) {
57                Some(child) => node = child,
58                None => return false,
59            }
60        }
61        node.is_terminal
62    }
63
64    /// Finds the longest prefix of `name` that exists in the trie.
65    /// Returns the prefix as a name string, or `/name` if no prefix found.
66    ///
67    /// Example: if trie contains `/animal` and `/animal/dog`, and we look up
68    /// `/animal/dog/poodle`, returns `/animal/dog`.
69    pub fn prefix_name(&self, name: &str) -> String {
70        let parts = split_name(name);
71        let mut node = self;
72        let mut last_terminal_idx: Option<usize> = None;
73
74        for (i, part) in parts.iter().enumerate() {
75            match node.children.get(*part) {
76                Some(child) => {
77                    if child.is_terminal {
78                        last_terminal_idx = Some(i);
79                    }
80                    node = child;
81                }
82                None => break,
83            }
84        }
85
86        match last_terminal_idx {
87            Some(idx) => {
88                let prefix_parts = &parts[..=idx];
89                format!("/{}", prefix_parts.join("/"))
90            }
91            None => "/name".to_string(),
92        }
93    }
94
95    /// Collects all name constants from a type expression into this trie.
96    ///
97    /// Special handling for `fn:TaggedUnion`: skips the tag field (arg 0)
98    /// and tag values (odd-indexed args), only recurses into variant
99    /// struct types (even-indexed args from index 2).
100    pub fn collect(&mut self, ir: &Ir, id: InstId) {
101        match ir.get(id) {
102            Inst::Name(n) => {
103                let name = ir.resolve_name(*n);
104                if !type_expr::is_base_type(ir, id) && name.starts_with('/') {
105                    self.add(name);
106                }
107            }
108            Inst::ApplyFn { function, args } => {
109                let fname = ir.resolve_name(*function);
110                if fname == type_expr::FN_TAGGED_UNION && args.len() >= 3 {
111                    // Only collect from variant struct types, not tag
112                    // field name or variant tag values.
113                    for i in (2..args.len()).step_by(2) {
114                        self.collect(ir, args[i]);
115                    }
116                } else {
117                    for arg in args {
118                        self.collect(ir, *arg);
119                    }
120                }
121            }
122            _ => {}
123        }
124    }
125}
126
/// Splits a name like "/foo/bar" into its segments `["foo", "bar"]`.
///
/// Empty segments are dropped, so leading, trailing and repeated `/`
/// characters do not produce entries; the returned slices borrow from `name`.
fn split_name(name: &str) -> Vec<&str> {
    let mut segments = Vec::new();
    for segment in name.split('/') {
        if !segment.is_empty() {
            segments.push(segment);
        }
    }
    segments
}
133
#[cfg(test)]
mod tests {
    use super::*;

    /// Interns `name` in `ir` and appends an `Inst::Name` referring to it,
    /// returning the id of the new instruction.
    fn name_inst(ir: &mut Ir, name: &'static str) -> InstId {
        let interned = ir.intern_name(name);
        ir.add_inst(Inst::Name(interned))
    }

    #[test]
    fn basic_trie_operations() {
        let mut trie = NameTrie::new();
        for name in ["/animal", "/animal/dog", "/color"] {
            trie.add(name);
        }

        for present in ["/animal", "/animal/dog", "/color"] {
            assert!(trie.contains(present), "{} should be present", present);
        }
        for absent in ["/animal/cat", "/plant"] {
            assert!(!trie.contains(absent), "{} should be absent", absent);
        }
    }

    #[test]
    fn prefix_name_lookup() {
        let mut trie = NameTrie::new();
        trie.add("/animal");
        trie.add("/animal/dog");

        // (query, expected longest inserted prefix)
        let cases = [
            ("/animal/dog/poodle", "/animal/dog"),
            ("/animal/cat", "/animal"),
            ("/animal", "/animal"),
            // Nothing under /plant was inserted, so the fallback applies.
            ("/plant/rose", "/name"),
        ];
        for (query, expected) in cases {
            assert_eq!(trie.prefix_name(query), expected, "query {}", query);
        }
    }

    #[test]
    fn collect_from_type_expr() {
        let mut ir = Ir::new();
        let mut trie = NameTrie::new();

        // Build: fn:Struct(/x, /animal, /y, /color)
        let field_names = ["/x", "/animal", "/y", "/color"];
        let mut fields = Vec::new();
        for name in field_names {
            fields.push(name_inst(&mut ir, name));
        }
        let struct_name = ir.intern_name("fn:Struct");
        let struct_type = ir.add_inst(Inst::ApplyFn {
            function: struct_name,
            args: fields,
        });

        trie.collect(&ir, struct_type);

        // Every argument of a plain function application is visited.
        for name in field_names {
            assert!(trie.contains(name), "{} should be collected", name);
        }
    }

    #[test]
    fn collect_from_tagged_union_skips_tags() {
        let mut ir = Ir::new();
        let mut trie = NameTrie::new();

        // fn:TaggedUnion(/kind, /move, fn:Struct(/x, /number), /quit, fn:Struct())
        let kind = name_inst(&mut ir, "/kind");
        let move_ = name_inst(&mut ir, "/move");
        let x = name_inst(&mut ir, "/x");
        let number = name_inst(&mut ir, "/number");
        let struct_name = ir.intern_name("fn:Struct");
        let move_struct = ir.add_inst(Inst::ApplyFn {
            function: struct_name,
            args: vec![x, number],
        });
        let quit = name_inst(&mut ir, "/quit");
        let quit_struct = ir.add_inst(Inst::ApplyFn {
            function: struct_name,
            args: vec![],
        });
        let tu_name = ir.intern_name("fn:TaggedUnion");
        let tu = ir.add_inst(Inst::ApplyFn {
            function: tu_name,
            args: vec![kind, move_, move_struct, quit, quit_struct],
        });

        trie.collect(&ir, tu);

        // /x comes from the /move variant struct and should be collected.
        assert!(trie.contains("/x"));
        // /number is a base type, so it should NOT be collected.
        assert!(!trie.contains("/number"));
        // Should NOT collect /kind (tag field), /move, /quit (tag values).
        assert!(!trie.contains("/kind"));
        assert!(!trie.contains("/move"));
        assert!(!trie.contains("/quit"));
    }
}
251}