Skip to main content

swift_demangler/
context.rs

1//! Symbol context representation.
2//!
3//! This module provides types for representing the context (location) of a Swift symbol,
4//! such as the module, type, and extension path.
5
6use crate::raw::{Node, NodeKind};
7
8/// The context (location) of a Swift symbol.
9///
10/// A context represents where a symbol is defined, including module, type, and extension information.
11#[derive(Clone, Copy)]
12pub struct SymbolContext<'ctx> {
13    raw: Node<'ctx>,
14}
15
16impl<'ctx> SymbolContext<'ctx> {
17    /// Create a SymbolContext from a raw node.
18    pub fn new(raw: Node<'ctx>) -> Self {
19        Self { raw }
20    }
21
22    /// Get the underlying raw node.
23    pub fn raw(&self) -> Node<'ctx> {
24        self.raw
25    }
26
27    /// Get the module name if this context is in a module.
28    pub fn module(&self) -> Option<&'ctx str> {
29        Self::find_module_in_context(self.raw)
30    }
31
32    fn find_module_in_context(node: Node<'ctx>) -> Option<&'ctx str> {
33        // First check direct children
34        for child in node.children() {
35            if child.kind() == NodeKind::Module {
36                return child.text();
37            }
38        }
39        // Then check the context chain
40        for child in node.children() {
41            match child.kind() {
42                NodeKind::Class
43                | NodeKind::Structure
44                | NodeKind::Enum
45                | NodeKind::Protocol
46                | NodeKind::Extension
47                | NodeKind::TypeAlias
48                | NodeKind::OtherNominalType => {
49                    if let Some(module) = Self::find_module_in_context(child) {
50                        return Some(module);
51                    }
52                }
53                _ => {}
54            }
55        }
56        None
57    }
58
59    /// Get the type name if this context is within a type.
60    pub fn type_name(&self) -> Option<&'ctx str> {
61        self.find_type_name_in_context(self.raw)
62    }
63
64    fn find_type_name_in_context(&self, node: Node<'ctx>) -> Option<&'ctx str> {
65        for child in node.children() {
66            match child.kind() {
67                NodeKind::Class
68                | NodeKind::Structure
69                | NodeKind::Enum
70                | NodeKind::Protocol
71                | NodeKind::TypeAlias
72                | NodeKind::OtherNominalType => {
73                    return self.extract_identifier(child);
74                }
75                NodeKind::Extension => {
76                    // Extension's type is its first child
77                    if let Some(inner) = child.child(0) {
78                        return self
79                            .find_type_name_in_context(inner)
80                            .or_else(|| self.extract_identifier(inner));
81                    }
82                }
83                _ => {}
84            }
85        }
86        None
87    }
88
89    fn extract_identifier(&self, node: Node<'ctx>) -> Option<&'ctx str> {
90        for child in node.children() {
91            if child.kind() == NodeKind::Identifier {
92                return child.text();
93            }
94        }
95        node.text()
96    }
97
98    /// Get the full path as a string (e.g., "ModuleName.TypeName").
99    pub fn full_path(&self) -> String {
100        let components: Vec<String> = self.components().map(|c| c.name().to_string()).collect();
101        components.join(".")
102    }
103
104    /// Check if this context is an extension.
105    pub fn is_extension(&self) -> bool {
106        self.raw.kind() == NodeKind::Extension
107            || self.raw.children().any(|c| c.kind() == NodeKind::Extension)
108    }
109
110    /// Iterate over the context components from outermost (module) to innermost.
111    pub fn components(&self) -> impl Iterator<Item = ContextComponent<'ctx>> + use<'ctx> {
112        let mut components = Vec::new();
113        self.collect_components(self.raw, &mut components);
114        components.into_iter()
115    }
116
117    fn collect_components(&self, node: Node<'ctx>, components: &mut Vec<ContextComponent<'ctx>>) {
118        for child in node.children() {
119            match child.kind() {
120                NodeKind::Module => {
121                    if let Some(name) = child.text() {
122                        components.push(ContextComponent::Module(name));
123                    }
124                }
125                NodeKind::Class => {
126                    // First, find and add the module from inside the class
127                    self.collect_module_from_type(child, components);
128                    // Then add the class itself
129                    if let Some(name) = self.extract_identifier(child) {
130                        components.push(ContextComponent::Class { name, raw: child });
131                    }
132                }
133                NodeKind::Structure => {
134                    self.collect_module_from_type(child, components);
135                    if let Some(name) = self.extract_identifier(child) {
136                        components.push(ContextComponent::Struct { name, raw: child });
137                    }
138                }
139                NodeKind::Enum => {
140                    self.collect_module_from_type(child, components);
141                    if let Some(name) = self.extract_identifier(child) {
142                        components.push(ContextComponent::Enum { name, raw: child });
143                    }
144                }
145                NodeKind::Protocol => {
146                    self.collect_module_from_type(child, components);
147                    if let Some(name) = self.extract_identifier(child) {
148                        components.push(ContextComponent::Protocol { name, raw: child });
149                    }
150                }
151                NodeKind::Extension => {
152                    // Extension wraps the extended type
153                    if let Some(extended_type) = child.child(0) {
154                        let base = self.context_component_from_type(extended_type);
155                        components.push(ContextComponent::Extension {
156                            base: Box::new(base),
157                            raw: child,
158                        });
159                    }
160                }
161                NodeKind::TypeAlias => {
162                    self.collect_module_from_type(child, components);
163                    if let Some(name) = self.extract_identifier(child) {
164                        components.push(ContextComponent::TypeAlias { name, raw: child });
165                    }
166                }
167                _ => {}
168            }
169        }
170    }
171
172    fn collect_module_from_type(
173        &self,
174        type_node: Node<'ctx>,
175        components: &mut Vec<ContextComponent<'ctx>>,
176    ) {
177        for child in type_node.children() {
178            if child.kind() == NodeKind::Module
179                && let Some(name) = child.text()
180            {
181                components.push(ContextComponent::Module(name));
182                return;
183            }
184        }
185    }
186
187    fn context_component_from_type(&self, node: Node<'ctx>) -> ContextComponent<'ctx> {
188        match node.kind() {
189            NodeKind::Class => ContextComponent::Class {
190                name: self.extract_identifier(node).unwrap_or(""),
191                raw: node,
192            },
193            NodeKind::Structure => ContextComponent::Struct {
194                name: self.extract_identifier(node).unwrap_or(""),
195                raw: node,
196            },
197            NodeKind::Enum => ContextComponent::Enum {
198                name: self.extract_identifier(node).unwrap_or(""),
199                raw: node,
200            },
201            NodeKind::Protocol => ContextComponent::Protocol {
202                name: self.extract_identifier(node).unwrap_or(""),
203                raw: node,
204            },
205            NodeKind::TypeAlias => ContextComponent::TypeAlias {
206                name: self.extract_identifier(node).unwrap_or(""),
207                raw: node,
208            },
209            NodeKind::Module => ContextComponent::Module(node.text().unwrap_or("")),
210            _ => ContextComponent::Other(node),
211        }
212    }
213}
214
215impl std::fmt::Debug for SymbolContext<'_> {
216    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
217        f.debug_struct("SymbolContext")
218            .field("module", &self.module())
219            .field("type_name", &self.type_name())
220            .field("full_path", &self.full_path())
221            .field("is_extension", &self.is_extension())
222            .finish()
223    }
224}
225
226impl std::fmt::Display for SymbolContext<'_> {
227    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
228        write!(f, "{}", self.full_path())
229    }
230}
231
232/// A component in a symbol's context path.
233#[derive(Debug)]
234pub enum ContextComponent<'ctx> {
235    /// A module.
236    Module(&'ctx str),
237    /// A class.
238    Class { name: &'ctx str, raw: Node<'ctx> },
239    /// A struct.
240    Struct { name: &'ctx str, raw: Node<'ctx> },
241    /// An enum.
242    Enum { name: &'ctx str, raw: Node<'ctx> },
243    /// A protocol.
244    Protocol { name: &'ctx str, raw: Node<'ctx> },
245    /// An extension of another type.
246    Extension {
247        base: Box<ContextComponent<'ctx>>,
248        raw: Node<'ctx>,
249    },
250    /// A type alias.
251    TypeAlias { name: &'ctx str, raw: Node<'ctx> },
252    /// Other context component.
253    Other(Node<'ctx>),
254}
255
256impl<'ctx> ContextComponent<'ctx> {
257    /// Get the name of this context component.
258    pub fn name(&self) -> &'ctx str {
259        match self {
260            ContextComponent::Module(name) => name,
261            ContextComponent::Class { name, .. } => name,
262            ContextComponent::Struct { name, .. } => name,
263            ContextComponent::Enum { name, .. } => name,
264            ContextComponent::Protocol { name, .. } => name,
265            ContextComponent::Extension { base, .. } => base.name(),
266            ContextComponent::TypeAlias { name, .. } => name,
267            ContextComponent::Other(_) => "",
268        }
269    }
270
271    /// Get the raw node for this component, if available.
272    pub fn raw(&self) -> Option<Node<'ctx>> {
273        match self {
274            ContextComponent::Module(_) => None,
275            ContextComponent::Class { raw, .. } => Some(*raw),
276            ContextComponent::Struct { raw, .. } => Some(*raw),
277            ContextComponent::Enum { raw, .. } => Some(*raw),
278            ContextComponent::Protocol { raw, .. } => Some(*raw),
279            ContextComponent::Extension { raw, .. } => Some(*raw),
280            ContextComponent::TypeAlias { raw, .. } => Some(*raw),
281            ContextComponent::Other(raw) => Some(*raw),
282        }
283    }
284
285    /// Check if this component is a type (class, struct, enum, protocol).
286    pub fn is_type(&self) -> bool {
287        matches!(
288            self,
289            ContextComponent::Class { .. }
290                | ContextComponent::Struct { .. }
291                | ContextComponent::Enum { .. }
292                | ContextComponent::Protocol { .. }
293                | ContextComponent::TypeAlias { .. }
294        )
295    }
296
297    /// Check if this component is an extension.
298    pub fn is_extension(&self) -> bool {
299        matches!(self, ContextComponent::Extension { .. })
300    }
301}
302
303/// Extract the context from a symbol node.
304///
305/// This function navigates from a symbol node (Function, Getter, etc.)
306/// to find its containing context.
307pub fn extract_context<'ctx>(symbol_node: Node<'ctx>) -> SymbolContext<'ctx> {
308    // The symbol node itself contains the context as children
309    // For a function: Function -> [Module, Identifier, Type, ...]
310    // For a method: Function -> [Class/Struct/etc -> [Module, Identifier], Identifier, Type, ...]
311    SymbolContext::new(symbol_node)
312}
313
#[cfg(test)]
mod tests {
    use super::*;
    use crate::raw::Context;

    /// `foo.bar.bas(zim: foo.zim) -> ()`: a method on class `bar` in module `foo`
    /// (old-style `_T` mangling).
    const CLASS_METHOD: &str = "_TFC3foo3bar3basfT3zimCS_3zim_T_";

    #[test]
    fn test_simple_function_context() {
        let ctx = Context::new();
        // main.hello() async throws -> Swift.String
        let root = Node::parse(&ctx, "$s4main5helloSSyYaKF").unwrap();
        let context = extract_context(root.child(0).unwrap());

        assert_eq!(context.module(), Some("main"));
    }

    #[test]
    fn test_method_context() {
        let ctx = Context::new();
        let root = Node::parse(&ctx, CLASS_METHOD).unwrap();
        let context = extract_context(root.child(0).unwrap());

        assert_eq!(context.module(), Some("foo"));
        assert_eq!(context.type_name(), Some("bar"));
    }

    #[test]
    fn test_context_full_path() {
        let ctx = Context::new();
        let root = Node::parse(&ctx, CLASS_METHOD).unwrap();
        let path = extract_context(root.child(0).unwrap()).full_path();

        // Both the module and the enclosing class must appear in the path.
        assert!(path.contains("foo"), "path should contain module 'foo': {path}");
        assert!(path.contains("bar"), "path should contain type 'bar': {path}");
    }

    #[test]
    fn test_context_components() {
        let ctx = Context::new();
        let root = Node::parse(&ctx, CLASS_METHOD).unwrap();
        let names: Vec<&str> = extract_context(root.child(0).unwrap())
            .components()
            .map(|component| component.name())
            .collect();

        // Expect at least the module (and normally the class) to be present.
        assert!(!names.is_empty(), "should have at least one component");
        assert!(
            names.contains(&"foo"),
            "should have module 'foo' in components: {names:?}"
        );
    }
}