Skip to main content

php_lsp/
call_hierarchy.rs

1use std::collections::HashMap;
2use std::ops::ControlFlow;
3use std::sync::Arc;
4
5use php_ast::visitor::{Visitor, walk_expr, walk_stmt};
6use php_ast::{ClassMemberKind, EnumMemberKind, ExprKind, NamespaceBody, Span, Stmt, StmtKind};
7use tower_lsp::lsp_types::{
8    CallHierarchyIncomingCall, CallHierarchyItem, CallHierarchyOutgoingCall, Position, Range,
9    SymbolKind, Url,
10};
11
12use crate::ast::{ParsedDoc, SourceView, span_to_range};
13use crate::references::find_references;
14
15/// Find the declaration matching `name` and return a `CallHierarchyItem`.
16pub fn prepare_call_hierarchy(
17    name: &str,
18    all_docs: &[(Url, Arc<ParsedDoc>)],
19) -> Option<CallHierarchyItem> {
20    for (uri, doc) in all_docs {
21        let sv = doc.view();
22        if let Some(item) = find_declaration_item(name, &doc.program().stmts, sv, uri) {
23            return Some(item);
24        }
25    }
26    None
27}
28
29/// Find all callers of `item.name` and return them grouped by enclosing function.
30pub fn incoming_calls(
31    item: &CallHierarchyItem,
32    all_docs: &[(Url, Arc<ParsedDoc>)],
33) -> Vec<CallHierarchyIncomingCall> {
34    let call_sites = find_references(&item.name, all_docs, false, None);
35    // Build O(1) URI → doc map to avoid scanning all_docs for each call site.
36    let doc_map: HashMap<&Url, &Arc<ParsedDoc>> = all_docs.iter().map(|(u, d)| (u, d)).collect();
37    let mut result: Vec<CallHierarchyIncomingCall> = Vec::new();
38    // Track (caller_name, caller_uri) → index in `result` for O(1) dedup.
39    let mut index: HashMap<(String, Url), usize> = HashMap::new();
40
41    for loc in call_sites {
42        let caller = doc_map.get(&loc.uri).and_then(|doc| {
43            enclosing_function(doc.view(), &doc.program().stmts, loc.range.start, &loc.uri)
44        });
45
46        let key = if let Some(ref ci) = caller {
47            (ci.name.clone(), ci.uri.clone())
48        } else {
49            ("<file scope>".to_string(), loc.uri.clone())
50        };
51
52        if let Some(&idx) = index.get(&key) {
53            result[idx].from_ranges.push(loc.range);
54        } else {
55            let from = caller.unwrap_or_else(|| CallHierarchyItem {
56                name: "<file scope>".to_string(),
57                kind: SymbolKind::FILE,
58                tags: None,
59                detail: None,
60                uri: loc.uri.clone(),
61                range: loc.range,
62                selection_range: loc.range,
63                data: None,
64            });
65            let idx = result.len();
66            index.insert(key, idx);
67            result.push(CallHierarchyIncomingCall {
68                from,
69                from_ranges: vec![loc.range],
70            });
71        }
72    }
73
74    result
75}
76
77/// Find all calls made by the body of `item.name`.
78pub fn outgoing_calls(
79    item: &CallHierarchyItem,
80    all_docs: &[(Url, Arc<ParsedDoc>)],
81) -> Vec<CallHierarchyOutgoingCall> {
82    let Some((_, doc)) = all_docs.iter().find(|(uri, _)| *uri == item.uri) else {
83        return Vec::new();
84    };
85    // Borrow sv.source() directly from the Arc to avoid cloning the whole file.
86    let item_source = doc.source();
87    let mut calls: Vec<(String, Span)> = Vec::new();
88    collect_calls_for(&item.name, &doc.program().stmts, &mut calls);
89
90    let mut result: Vec<CallHierarchyOutgoingCall> = Vec::new();
91    // Track callee_name → index in `result` for O(1) dedup.
92    let mut index: HashMap<String, usize> = HashMap::new();
93    let item_line_starts = doc.line_starts();
94    for (callee_name, span) in calls {
95        let call_range = span_to_range(item_source, item_line_starts, span);
96        if let Some(&idx) = index.get(&callee_name) {
97            result[idx].from_ranges.push(call_range);
98        } else if let Some(callee_item) = prepare_call_hierarchy(&callee_name, all_docs) {
99            let idx = result.len();
100            index.insert(callee_name, idx);
101            result.push(CallHierarchyOutgoingCall {
102                to: callee_item,
103                from_ranges: vec![call_range],
104            });
105        }
106    }
107
108    result
109}
110
111// === Internal helpers ===
112
113fn find_declaration_item(
114    name: &str,
115    stmts: &[Stmt<'_, '_>],
116    sv: SourceView<'_>,
117    uri: &Url,
118) -> Option<CallHierarchyItem> {
119    for stmt in stmts {
120        match &stmt.kind {
121            StmtKind::Function(f) if f.name == name => {
122                let range = sv.range_of(stmt.span);
123                let sel = sv.name_range(f.name);
124                return Some(CallHierarchyItem {
125                    name: name.to_string(),
126                    kind: SymbolKind::FUNCTION,
127                    tags: None,
128                    detail: None,
129                    uri: uri.clone(),
130                    range,
131                    selection_range: sel,
132                    data: None,
133                });
134            }
135            StmtKind::Class(c) => {
136                for member in c.members.iter() {
137                    if let ClassMemberKind::Method(m) = &member.kind
138                        && m.name == name
139                    {
140                        let range = sv.range_of(member.span);
141                        let sel = sv.name_range(m.name);
142                        return Some(CallHierarchyItem {
143                            name: name.to_string(),
144                            kind: SymbolKind::METHOD,
145                            tags: None,
146                            detail: c.name.map(|n| n.to_string()),
147                            uri: uri.clone(),
148                            range,
149                            selection_range: sel,
150                            data: None,
151                        });
152                    }
153                }
154            }
155            StmtKind::Trait(t) => {
156                for member in t.members.iter() {
157                    if let ClassMemberKind::Method(m) = &member.kind
158                        && m.name == name
159                    {
160                        let range = sv.range_of(member.span);
161                        let sel = sv.name_range(m.name);
162                        return Some(CallHierarchyItem {
163                            name: name.to_string(),
164                            kind: SymbolKind::METHOD,
165                            tags: None,
166                            detail: Some(t.name.to_string()),
167                            uri: uri.clone(),
168                            range,
169                            selection_range: sel,
170                            data: None,
171                        });
172                    }
173                }
174            }
175            StmtKind::Enum(e) => {
176                for member in e.members.iter() {
177                    if let EnumMemberKind::Method(m) = &member.kind
178                        && m.name == name
179                    {
180                        let range = sv.range_of(member.span);
181                        let sel = sv.name_range(m.name);
182                        return Some(CallHierarchyItem {
183                            name: name.to_string(),
184                            kind: SymbolKind::METHOD,
185                            tags: None,
186                            detail: Some(e.name.to_string()),
187                            uri: uri.clone(),
188                            range,
189                            selection_range: sel,
190                            data: None,
191                        });
192                    }
193                }
194            }
195            StmtKind::Namespace(ns) => {
196                if let NamespaceBody::Braced(inner) = &ns.body
197                    && let Some(item) = find_declaration_item(name, inner, sv, uri)
198                {
199                    return Some(item);
200                }
201            }
202            _ => {}
203        }
204    }
205    None
206}
207
208fn enclosing_function(
209    sv: SourceView<'_>,
210    stmts: &[Stmt<'_, '_>],
211    pos: Position,
212    uri: &Url,
213) -> Option<CallHierarchyItem> {
214    for stmt in stmts {
215        if let Some(item) = enclosing_in_stmt(sv, stmt, pos, uri) {
216            return Some(item);
217        }
218    }
219    None
220}
221
222fn enclosing_in_stmt(
223    sv: SourceView<'_>,
224    stmt: &Stmt<'_, '_>,
225    pos: Position,
226    uri: &Url,
227) -> Option<CallHierarchyItem> {
228    let range = sv.range_of(stmt.span);
229    if !range_contains(range, pos) {
230        return None;
231    }
232    match &stmt.kind {
233        StmtKind::Function(f) => {
234            let sel = sv.name_range(f.name);
235            Some(CallHierarchyItem {
236                name: f.name.to_string(),
237                kind: SymbolKind::FUNCTION,
238                tags: None,
239                detail: None,
240                uri: uri.clone(),
241                range,
242                selection_range: sel,
243                data: None,
244            })
245        }
246        StmtKind::Class(c) => {
247            for member in c.members.iter() {
248                let m_range = sv.range_of(member.span);
249                if range_contains(m_range, pos)
250                    && let ClassMemberKind::Method(m) = &member.kind
251                {
252                    let sel = sv.name_range(m.name);
253                    return Some(CallHierarchyItem {
254                        name: m.name.to_string(),
255                        kind: SymbolKind::METHOD,
256                        tags: None,
257                        detail: c.name.map(|n| n.to_string()),
258                        uri: uri.clone(),
259                        range: m_range,
260                        selection_range: sel,
261                        data: None,
262                    });
263                }
264            }
265            None
266        }
267        StmtKind::Trait(t) => {
268            for member in t.members.iter() {
269                let m_range = sv.range_of(member.span);
270                if range_contains(m_range, pos)
271                    && let ClassMemberKind::Method(m) = &member.kind
272                {
273                    let sel = sv.name_range(m.name);
274                    return Some(CallHierarchyItem {
275                        name: m.name.to_string(),
276                        kind: SymbolKind::METHOD,
277                        tags: None,
278                        detail: Some(t.name.to_string()),
279                        uri: uri.clone(),
280                        range: m_range,
281                        selection_range: sel,
282                        data: None,
283                    });
284                }
285            }
286            None
287        }
288        StmtKind::Enum(e) => {
289            for member in e.members.iter() {
290                let m_range = sv.range_of(member.span);
291                if range_contains(m_range, pos)
292                    && let EnumMemberKind::Method(m) = &member.kind
293                {
294                    let sel = sv.name_range(m.name);
295                    return Some(CallHierarchyItem {
296                        name: m.name.to_string(),
297                        kind: SymbolKind::METHOD,
298                        tags: None,
299                        detail: Some(e.name.to_string()),
300                        uri: uri.clone(),
301                        range: m_range,
302                        selection_range: sel,
303                        data: None,
304                    });
305                }
306            }
307            None
308        }
309        StmtKind::Namespace(ns) => {
310            if let NamespaceBody::Braced(inner) = &ns.body {
311                return enclosing_function(sv, inner, pos, uri);
312            }
313            None
314        }
315        _ => None,
316    }
317}
318
319fn range_contains(range: Range, pos: Position) -> bool {
320    if pos.line < range.start.line || pos.line > range.end.line {
321        return false;
322    }
323    if pos.line == range.start.line && pos.character < range.start.character {
324        return false;
325    }
326    if pos.line == range.end.line && pos.character >= range.end.character {
327        return false;
328    }
329    true
330}
331
332/// Collect all (callee_name, span) for calls made inside the body of `fn_name`.
333fn collect_calls_for(fn_name: &str, stmts: &[Stmt<'_, '_>], out: &mut Vec<(String, Span)>) {
334    for stmt in stmts {
335        match &stmt.kind {
336            StmtKind::Function(f) if f.name == fn_name => {
337                calls_in_stmts(&f.body, out);
338                return;
339            }
340            StmtKind::Class(c) => {
341                for member in c.members.iter() {
342                    if let ClassMemberKind::Method(m) = &member.kind
343                        && m.name == fn_name
344                        && let Some(body) = &m.body
345                    {
346                        calls_in_stmts(body, out);
347                        return;
348                    }
349                }
350            }
351            StmtKind::Trait(t) => {
352                for member in t.members.iter() {
353                    if let ClassMemberKind::Method(m) = &member.kind
354                        && m.name == fn_name
355                        && let Some(body) = &m.body
356                    {
357                        calls_in_stmts(body, out);
358                        return;
359                    }
360                }
361            }
362            StmtKind::Enum(e) => {
363                for member in e.members.iter() {
364                    if let EnumMemberKind::Method(m) = &member.kind
365                        && m.name == fn_name
366                        && let Some(body) = &m.body
367                    {
368                        calls_in_stmts(body, out);
369                        return;
370                    }
371                }
372            }
373            StmtKind::Namespace(ns) => {
374                if let NamespaceBody::Braced(inner) = &ns.body {
375                    collect_calls_for(fn_name, inner, out);
376                }
377            }
378            _ => {}
379        }
380    }
381}
382
383/// Collects all (callee_name, span) call sites reachable from a slice of statements,
384/// without descending into nested named declarations (functions, classes, etc.).
385fn calls_in_stmts(stmts: &[Stmt<'_, '_>], out: &mut Vec<(String, Span)>) {
386    let mut collector = CallCollector { out };
387    for stmt in stmts {
388        let _ = collector.visit_stmt(stmt);
389    }
390}
391
392struct CallCollector<'c> {
393    out: &'c mut Vec<(String, Span)>,
394}
395
396impl<'arena, 'src> Visitor<'arena, 'src> for CallCollector<'_> {
397    fn visit_expr(&mut self, expr: &php_ast::Expr<'arena, 'src>) -> ControlFlow<()> {
398        match &expr.kind {
399            ExprKind::FunctionCall(f) => {
400                if let ExprKind::Identifier(name) = &f.name.kind {
401                    self.out.push((name.to_string(), f.name.span));
402                }
403            }
404            ExprKind::MethodCall(m) | ExprKind::NullsafeMethodCall(m) => {
405                if let ExprKind::Identifier(name) = &m.method.kind {
406                    self.out.push((name.to_string(), m.method.span));
407                }
408            }
409            ExprKind::StaticMethodCall(s) => {
410                if let ExprKind::Identifier(name) = &s.method.kind {
411                    self.out.push((name.to_string(), s.method.span));
412                }
413            }
414            _ => {}
415        }
416        walk_expr(self, expr)
417    }
418
419    fn visit_stmt(&mut self, stmt: &php_ast::Stmt<'arena, 'src>) -> ControlFlow<()> {
420        // Skip nested named declarations — they are separate callable units with
421        // their own call hierarchy entries; their internals are not outgoing calls
422        // of the function currently being analysed.
423        match &stmt.kind {
424            StmtKind::Function(_)
425            | StmtKind::Class(_)
426            | StmtKind::Trait(_)
427            | StmtKind::Enum(_)
428            | StmtKind::Interface(_) => ControlFlow::Continue(()),
429            _ => walk_stmt(self, stmt),
430        }
431    }
432}
433
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building an LSP `Position`.
    fn pos(line: u32, character: u32) -> Position {
        Position { line, character }
    }

    // ── range_contains boundary regression tests ─────────────────────────────

    #[test]
    fn range_contains_excludes_exact_end_position() {
        // LSP ranges are half-open [start, end). A position exactly at
        // range.end is OUTSIDE the range. An earlier version compared with `>`
        // instead of `>=`, which incorrectly included the end position.
        let range = Range {
            start: pos(1, 0),
            end: pos(3, 5),
        };
        // One past the last character on the end line — clearly outside.
        assert!(
            !range_contains(range, pos(3, 6)),
            "position after end must be outside"
        );
        // Exactly at end — outside per LSP half-open semantics.
        assert!(
            !range_contains(range, pos(3, 5)),
            "position exactly at range.end must be outside (half-open range)"
        );
        // One before end — inside.
        assert!(
            range_contains(range, pos(3, 4)),
            "position just before end must be inside"
        );
        // Start of range — inside.
        assert!(
            range_contains(range, pos(1, 0)),
            "start position must be inside"
        );
    }

    #[test]
    fn range_contains_respects_start_boundary_and_line_bounds() {
        let range = Range {
            start: pos(1, 4),
            end: pos(3, 5),
        };
        assert!(
            !range_contains(range, pos(1, 3)),
            "position before start on the start line must be outside"
        );
        assert!(
            range_contains(range, pos(1, 4)),
            "position exactly at range.start must be inside"
        );
        assert!(
            range_contains(range, pos(2, 0)),
            "interior line at column 0 must be inside"
        );
        assert!(
            range_contains(range, pos(2, 999)),
            "interior lines accept any column"
        );
        assert!(
            !range_contains(range, pos(0, 10)),
            "line before the range must be outside"
        );
        assert!(
            !range_contains(range, pos(4, 0)),
            "line after the range must be outside"
        );
    }

    #[test]
    fn range_contains_single_line_and_empty_ranges() {
        let single = Range {
            start: pos(2, 3),
            end: pos(2, 6),
        };
        assert!(range_contains(single, pos(2, 3)), "single-line start is inside");
        assert!(range_contains(single, pos(2, 5)), "last char before end is inside");
        assert!(!range_contains(single, pos(2, 6)), "single-line end is outside");
        assert!(!range_contains(single, pos(2, 2)), "before single-line start is outside");

        let empty = Range {
            start: pos(2, 3),
            end: pos(2, 3),
        };
        assert!(
            !range_contains(empty, pos(2, 3)),
            "an empty half-open range contains nothing"
        );
    }
}