1use crate::raw::{Node, NodeKind};
7
/// Read-only view over the declaring context of a demangled symbol.
///
/// Only the raw node is stored; module, type name and path components are
/// recomputed from its children each time they are requested.
#[derive(Clone, Copy)]
pub struct SymbolContext<'ctx> {
    /// Node whose children describe the symbol's context.
    raw: Node<'ctx>,
}
15
16impl<'ctx> SymbolContext<'ctx> {
17 pub fn new(raw: Node<'ctx>) -> Self {
19 Self { raw }
20 }
21
22 pub fn raw(&self) -> Node<'ctx> {
24 self.raw
25 }
26
27 pub fn module(&self) -> Option<&'ctx str> {
29 Self::find_module_in_context(self.raw)
30 }
31
32 fn find_module_in_context(node: Node<'ctx>) -> Option<&'ctx str> {
33 for child in node.children() {
35 if child.kind() == NodeKind::Module {
36 return child.text();
37 }
38 }
39 for child in node.children() {
41 match child.kind() {
42 NodeKind::Class
43 | NodeKind::Structure
44 | NodeKind::Enum
45 | NodeKind::Protocol
46 | NodeKind::Extension
47 | NodeKind::TypeAlias
48 | NodeKind::OtherNominalType => {
49 if let Some(module) = Self::find_module_in_context(child) {
50 return Some(module);
51 }
52 }
53 _ => {}
54 }
55 }
56 None
57 }
58
59 pub fn type_name(&self) -> Option<&'ctx str> {
61 self.find_type_name_in_context(self.raw)
62 }
63
64 fn find_type_name_in_context(&self, node: Node<'ctx>) -> Option<&'ctx str> {
65 for child in node.children() {
66 match child.kind() {
67 NodeKind::Class
68 | NodeKind::Structure
69 | NodeKind::Enum
70 | NodeKind::Protocol
71 | NodeKind::TypeAlias
72 | NodeKind::OtherNominalType => {
73 return self.extract_identifier(child);
74 }
75 NodeKind::Extension => {
76 if let Some(inner) = child.child(0) {
78 return self
79 .find_type_name_in_context(inner)
80 .or_else(|| self.extract_identifier(inner));
81 }
82 }
83 _ => {}
84 }
85 }
86 None
87 }
88
89 fn extract_identifier(&self, node: Node<'ctx>) -> Option<&'ctx str> {
90 for child in node.children() {
91 if child.kind() == NodeKind::Identifier {
92 return child.text();
93 }
94 }
95 node.text()
96 }
97
98 pub fn full_path(&self) -> String {
100 let components: Vec<String> = self.components().map(|c| c.name().to_string()).collect();
101 components.join(".")
102 }
103
104 pub fn is_extension(&self) -> bool {
106 self.raw.kind() == NodeKind::Extension
107 || self.raw.children().any(|c| c.kind() == NodeKind::Extension)
108 }
109
110 pub fn components(&self) -> impl Iterator<Item = ContextComponent<'ctx>> + use<'ctx> {
112 let mut components = Vec::new();
113 self.collect_components(self.raw, &mut components);
114 components.into_iter()
115 }
116
117 fn collect_components(&self, node: Node<'ctx>, components: &mut Vec<ContextComponent<'ctx>>) {
118 for child in node.children() {
119 match child.kind() {
120 NodeKind::Module => {
121 if let Some(name) = child.text() {
122 components.push(ContextComponent::Module(name));
123 }
124 }
125 NodeKind::Class => {
126 self.collect_module_from_type(child, components);
128 if let Some(name) = self.extract_identifier(child) {
130 components.push(ContextComponent::Class { name, raw: child });
131 }
132 }
133 NodeKind::Structure => {
134 self.collect_module_from_type(child, components);
135 if let Some(name) = self.extract_identifier(child) {
136 components.push(ContextComponent::Struct { name, raw: child });
137 }
138 }
139 NodeKind::Enum => {
140 self.collect_module_from_type(child, components);
141 if let Some(name) = self.extract_identifier(child) {
142 components.push(ContextComponent::Enum { name, raw: child });
143 }
144 }
145 NodeKind::Protocol => {
146 self.collect_module_from_type(child, components);
147 if let Some(name) = self.extract_identifier(child) {
148 components.push(ContextComponent::Protocol { name, raw: child });
149 }
150 }
151 NodeKind::Extension => {
152 if let Some(extended_type) = child.child(0) {
154 let base = self.context_component_from_type(extended_type);
155 components.push(ContextComponent::Extension {
156 base: Box::new(base),
157 raw: child,
158 });
159 }
160 }
161 NodeKind::TypeAlias => {
162 self.collect_module_from_type(child, components);
163 if let Some(name) = self.extract_identifier(child) {
164 components.push(ContextComponent::TypeAlias { name, raw: child });
165 }
166 }
167 _ => {}
168 }
169 }
170 }
171
172 fn collect_module_from_type(
173 &self,
174 type_node: Node<'ctx>,
175 components: &mut Vec<ContextComponent<'ctx>>,
176 ) {
177 for child in type_node.children() {
178 if child.kind() == NodeKind::Module
179 && let Some(name) = child.text()
180 {
181 components.push(ContextComponent::Module(name));
182 return;
183 }
184 }
185 }
186
187 fn context_component_from_type(&self, node: Node<'ctx>) -> ContextComponent<'ctx> {
188 match node.kind() {
189 NodeKind::Class => ContextComponent::Class {
190 name: self.extract_identifier(node).unwrap_or(""),
191 raw: node,
192 },
193 NodeKind::Structure => ContextComponent::Struct {
194 name: self.extract_identifier(node).unwrap_or(""),
195 raw: node,
196 },
197 NodeKind::Enum => ContextComponent::Enum {
198 name: self.extract_identifier(node).unwrap_or(""),
199 raw: node,
200 },
201 NodeKind::Protocol => ContextComponent::Protocol {
202 name: self.extract_identifier(node).unwrap_or(""),
203 raw: node,
204 },
205 NodeKind::TypeAlias => ContextComponent::TypeAlias {
206 name: self.extract_identifier(node).unwrap_or(""),
207 raw: node,
208 },
209 NodeKind::Module => ContextComponent::Module(node.text().unwrap_or("")),
210 _ => ContextComponent::Other(node),
211 }
212 }
213}
214
215impl std::fmt::Debug for SymbolContext<'_> {
216 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
217 f.debug_struct("SymbolContext")
218 .field("module", &self.module())
219 .field("type_name", &self.type_name())
220 .field("full_path", &self.full_path())
221 .field("is_extension", &self.is_extension())
222 .finish()
223 }
224}
225
226impl std::fmt::Display for SymbolContext<'_> {
227 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
228 write!(f, "{}", self.full_path())
229 }
230}
231
232#[derive(Debug)]
234pub enum ContextComponent<'ctx> {
235 Module(&'ctx str),
237 Class { name: &'ctx str, raw: Node<'ctx> },
239 Struct { name: &'ctx str, raw: Node<'ctx> },
241 Enum { name: &'ctx str, raw: Node<'ctx> },
243 Protocol { name: &'ctx str, raw: Node<'ctx> },
245 Extension {
247 base: Box<ContextComponent<'ctx>>,
248 raw: Node<'ctx>,
249 },
250 TypeAlias { name: &'ctx str, raw: Node<'ctx> },
252 Other(Node<'ctx>),
254}
255
256impl<'ctx> ContextComponent<'ctx> {
257 pub fn name(&self) -> &'ctx str {
259 match self {
260 ContextComponent::Module(name) => name,
261 ContextComponent::Class { name, .. } => name,
262 ContextComponent::Struct { name, .. } => name,
263 ContextComponent::Enum { name, .. } => name,
264 ContextComponent::Protocol { name, .. } => name,
265 ContextComponent::Extension { base, .. } => base.name(),
266 ContextComponent::TypeAlias { name, .. } => name,
267 ContextComponent::Other(_) => "",
268 }
269 }
270
271 pub fn raw(&self) -> Option<Node<'ctx>> {
273 match self {
274 ContextComponent::Module(_) => None,
275 ContextComponent::Class { raw, .. } => Some(*raw),
276 ContextComponent::Struct { raw, .. } => Some(*raw),
277 ContextComponent::Enum { raw, .. } => Some(*raw),
278 ContextComponent::Protocol { raw, .. } => Some(*raw),
279 ContextComponent::Extension { raw, .. } => Some(*raw),
280 ContextComponent::TypeAlias { raw, .. } => Some(*raw),
281 ContextComponent::Other(raw) => Some(*raw),
282 }
283 }
284
285 pub fn is_type(&self) -> bool {
287 matches!(
288 self,
289 ContextComponent::Class { .. }
290 | ContextComponent::Struct { .. }
291 | ContextComponent::Enum { .. }
292 | ContextComponent::Protocol { .. }
293 | ContextComponent::TypeAlias { .. }
294 )
295 }
296
297 pub fn is_extension(&self) -> bool {
299 matches!(self, ContextComponent::Extension { .. })
300 }
301}
302
303pub fn extract_context<'ctx>(symbol_node: Node<'ctx>) -> SymbolContext<'ctx> {
308 SymbolContext::new(symbol_node)
312}
313
#[cfg(test)]
mod tests {
    use super::*;
    use crate::raw::Context;

    /// Legacy (`_T`-prefixed) mangled method whose context is class `bar` in module `foo`.
    const LEGACY_METHOD: &str = "_TFC3foo3bar3basfT3zimCS_3zim_T_";

    #[test]
    fn test_simple_function_context() {
        let demangle_ctx = Context::new();
        let tree = Node::parse(&demangle_ctx, "$s4main5helloSSyYaKF").unwrap();
        let symbol = extract_context(tree.child(0).unwrap());

        assert_eq!(symbol.module(), Some("main"));
    }

    #[test]
    fn test_method_context() {
        let demangle_ctx = Context::new();
        let tree = Node::parse(&demangle_ctx, LEGACY_METHOD).unwrap();
        let symbol = extract_context(tree.child(0).unwrap());

        assert_eq!(symbol.module(), Some("foo"));
        assert_eq!(symbol.type_name(), Some("bar"));
    }

    #[test]
    fn test_context_full_path() {
        let demangle_ctx = Context::new();
        let tree = Node::parse(&demangle_ctx, LEGACY_METHOD).unwrap();
        let path = extract_context(tree.child(0).unwrap()).full_path();

        assert!(
            path.contains("foo"),
            "path should contain module 'foo': {path}"
        );
        assert!(
            path.contains("bar"),
            "path should contain type 'bar': {path}"
        );
    }

    #[test]
    fn test_context_components() {
        let demangle_ctx = Context::new();
        let tree = Node::parse(&demangle_ctx, LEGACY_METHOD).unwrap();
        let symbol = extract_context(tree.child(0).unwrap());

        let names: Vec<&str> = symbol.components().map(|c| c.name()).collect();

        assert!(!names.is_empty(), "should have at least one component");
        assert!(
            names.contains(&"foo"),
            "should have module 'foo' in components: {:?}",
            names
        );
    }
}