Skip to main content

php_lsp/
call_hierarchy.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use php_ast::{ClassMemberKind, EnumMemberKind, ExprKind, NamespaceBody, Span, Stmt, StmtKind};
5use tower_lsp::lsp_types::{
6    CallHierarchyIncomingCall, CallHierarchyItem, CallHierarchyOutgoingCall, Position, Range,
7    SymbolKind, Url,
8};
9
10use crate::ast::{ParsedDoc, SourceView, span_to_range};
11use crate::references::find_references;
12
13/// Find the declaration matching `name` and return a `CallHierarchyItem`.
14pub fn prepare_call_hierarchy(
15    name: &str,
16    all_docs: &[(Url, Arc<ParsedDoc>)],
17) -> Option<CallHierarchyItem> {
18    for (uri, doc) in all_docs {
19        let sv = doc.view();
20        if let Some(item) = find_declaration_item(name, &doc.program().stmts, sv, uri) {
21            return Some(item);
22        }
23    }
24    None
25}
26
27/// Find all callers of `item.name` and return them grouped by enclosing function.
28pub fn incoming_calls(
29    item: &CallHierarchyItem,
30    all_docs: &[(Url, Arc<ParsedDoc>)],
31) -> Vec<CallHierarchyIncomingCall> {
32    let call_sites = find_references(&item.name, all_docs, false, None);
33    // Build O(1) URI → doc map to avoid scanning all_docs for each call site.
34    let doc_map: HashMap<&Url, &Arc<ParsedDoc>> = all_docs.iter().map(|(u, d)| (u, d)).collect();
35    let mut result: Vec<CallHierarchyIncomingCall> = Vec::new();
36    // Track (caller_name, caller_uri) → index in `result` for O(1) dedup.
37    let mut index: HashMap<(String, Url), usize> = HashMap::new();
38
39    for loc in call_sites {
40        let caller = doc_map.get(&loc.uri).and_then(|doc| {
41            enclosing_function(doc.view(), &doc.program().stmts, loc.range.start, &loc.uri)
42        });
43
44        let key = if let Some(ref ci) = caller {
45            (ci.name.clone(), ci.uri.clone())
46        } else {
47            ("<file scope>".to_string(), loc.uri.clone())
48        };
49
50        if let Some(&idx) = index.get(&key) {
51            result[idx].from_ranges.push(loc.range);
52        } else {
53            let from = caller.unwrap_or_else(|| CallHierarchyItem {
54                name: "<file scope>".to_string(),
55                kind: SymbolKind::FILE,
56                tags: None,
57                detail: None,
58                uri: loc.uri.clone(),
59                range: loc.range,
60                selection_range: loc.range,
61                data: None,
62            });
63            let idx = result.len();
64            index.insert(key, idx);
65            result.push(CallHierarchyIncomingCall {
66                from,
67                from_ranges: vec![loc.range],
68            });
69        }
70    }
71
72    result
73}
74
75/// Find all calls made by the body of `item.name`.
76pub fn outgoing_calls(
77    item: &CallHierarchyItem,
78    all_docs: &[(Url, Arc<ParsedDoc>)],
79) -> Vec<CallHierarchyOutgoingCall> {
80    let Some((_, doc)) = all_docs.iter().find(|(uri, _)| *uri == item.uri) else {
81        return Vec::new();
82    };
83    // Borrow sv.source() directly from the Arc to avoid cloning the whole file.
84    let item_source = doc.source();
85    let mut calls: Vec<(String, Span)> = Vec::new();
86    collect_calls_for(&item.name, &doc.program().stmts, &mut calls);
87
88    let mut result: Vec<CallHierarchyOutgoingCall> = Vec::new();
89    // Track callee_name → index in `result` for O(1) dedup.
90    let mut index: HashMap<String, usize> = HashMap::new();
91    let item_line_starts = doc.line_starts();
92    for (callee_name, span) in calls {
93        let call_range = span_to_range(item_source, item_line_starts, span);
94        if let Some(&idx) = index.get(&callee_name) {
95            result[idx].from_ranges.push(call_range);
96        } else if let Some(callee_item) = prepare_call_hierarchy(&callee_name, all_docs) {
97            let idx = result.len();
98            index.insert(callee_name, idx);
99            result.push(CallHierarchyOutgoingCall {
100                to: callee_item,
101                from_ranges: vec![call_range],
102            });
103        }
104    }
105
106    result
107}
108
109// === Internal helpers ===
110
111fn find_declaration_item(
112    name: &str,
113    stmts: &[Stmt<'_, '_>],
114    sv: SourceView<'_>,
115    uri: &Url,
116) -> Option<CallHierarchyItem> {
117    for stmt in stmts {
118        match &stmt.kind {
119            StmtKind::Function(f) if f.name == name => {
120                let range = sv.range_of(stmt.span);
121                let sel = sv.name_range(f.name);
122                return Some(CallHierarchyItem {
123                    name: name.to_string(),
124                    kind: SymbolKind::FUNCTION,
125                    tags: None,
126                    detail: None,
127                    uri: uri.clone(),
128                    range,
129                    selection_range: sel,
130                    data: None,
131                });
132            }
133            StmtKind::Class(c) => {
134                for member in c.members.iter() {
135                    if let ClassMemberKind::Method(m) = &member.kind
136                        && m.name == name
137                    {
138                        let range = sv.range_of(member.span);
139                        let sel = sv.name_range(m.name);
140                        return Some(CallHierarchyItem {
141                            name: name.to_string(),
142                            kind: SymbolKind::METHOD,
143                            tags: None,
144                            detail: c.name.map(|n| n.to_string()),
145                            uri: uri.clone(),
146                            range,
147                            selection_range: sel,
148                            data: None,
149                        });
150                    }
151                }
152            }
153            StmtKind::Trait(t) => {
154                for member in t.members.iter() {
155                    if let ClassMemberKind::Method(m) = &member.kind
156                        && m.name == name
157                    {
158                        let range = sv.range_of(member.span);
159                        let sel = sv.name_range(m.name);
160                        return Some(CallHierarchyItem {
161                            name: name.to_string(),
162                            kind: SymbolKind::METHOD,
163                            tags: None,
164                            detail: Some(t.name.to_string()),
165                            uri: uri.clone(),
166                            range,
167                            selection_range: sel,
168                            data: None,
169                        });
170                    }
171                }
172            }
173            StmtKind::Enum(e) => {
174                for member in e.members.iter() {
175                    if let EnumMemberKind::Method(m) = &member.kind
176                        && m.name == name
177                    {
178                        let range = sv.range_of(member.span);
179                        let sel = sv.name_range(m.name);
180                        return Some(CallHierarchyItem {
181                            name: name.to_string(),
182                            kind: SymbolKind::METHOD,
183                            tags: None,
184                            detail: Some(e.name.to_string()),
185                            uri: uri.clone(),
186                            range,
187                            selection_range: sel,
188                            data: None,
189                        });
190                    }
191                }
192            }
193            StmtKind::Namespace(ns) => {
194                if let NamespaceBody::Braced(inner) = &ns.body
195                    && let Some(item) = find_declaration_item(name, inner, sv, uri)
196                {
197                    return Some(item);
198                }
199            }
200            _ => {}
201        }
202    }
203    None
204}
205
206fn enclosing_function(
207    sv: SourceView<'_>,
208    stmts: &[Stmt<'_, '_>],
209    pos: Position,
210    uri: &Url,
211) -> Option<CallHierarchyItem> {
212    for stmt in stmts {
213        if let Some(item) = enclosing_in_stmt(sv, stmt, pos, uri) {
214            return Some(item);
215        }
216    }
217    None
218}
219
220fn enclosing_in_stmt(
221    sv: SourceView<'_>,
222    stmt: &Stmt<'_, '_>,
223    pos: Position,
224    uri: &Url,
225) -> Option<CallHierarchyItem> {
226    let range = sv.range_of(stmt.span);
227    if !range_contains(range, pos) {
228        return None;
229    }
230    match &stmt.kind {
231        StmtKind::Function(f) => {
232            let sel = sv.name_range(f.name);
233            Some(CallHierarchyItem {
234                name: f.name.to_string(),
235                kind: SymbolKind::FUNCTION,
236                tags: None,
237                detail: None,
238                uri: uri.clone(),
239                range,
240                selection_range: sel,
241                data: None,
242            })
243        }
244        StmtKind::Class(c) => {
245            for member in c.members.iter() {
246                let m_range = sv.range_of(member.span);
247                if range_contains(m_range, pos)
248                    && let ClassMemberKind::Method(m) = &member.kind
249                {
250                    let sel = sv.name_range(m.name);
251                    return Some(CallHierarchyItem {
252                        name: m.name.to_string(),
253                        kind: SymbolKind::METHOD,
254                        tags: None,
255                        detail: c.name.map(|n| n.to_string()),
256                        uri: uri.clone(),
257                        range: m_range,
258                        selection_range: sel,
259                        data: None,
260                    });
261                }
262            }
263            None
264        }
265        StmtKind::Trait(t) => {
266            for member in t.members.iter() {
267                let m_range = sv.range_of(member.span);
268                if range_contains(m_range, pos)
269                    && let ClassMemberKind::Method(m) = &member.kind
270                {
271                    let sel = sv.name_range(m.name);
272                    return Some(CallHierarchyItem {
273                        name: m.name.to_string(),
274                        kind: SymbolKind::METHOD,
275                        tags: None,
276                        detail: Some(t.name.to_string()),
277                        uri: uri.clone(),
278                        range: m_range,
279                        selection_range: sel,
280                        data: None,
281                    });
282                }
283            }
284            None
285        }
286        StmtKind::Enum(e) => {
287            for member in e.members.iter() {
288                let m_range = sv.range_of(member.span);
289                if range_contains(m_range, pos)
290                    && let EnumMemberKind::Method(m) = &member.kind
291                {
292                    let sel = sv.name_range(m.name);
293                    return Some(CallHierarchyItem {
294                        name: m.name.to_string(),
295                        kind: SymbolKind::METHOD,
296                        tags: None,
297                        detail: Some(e.name.to_string()),
298                        uri: uri.clone(),
299                        range: m_range,
300                        selection_range: sel,
301                        data: None,
302                    });
303                }
304            }
305            None
306        }
307        StmtKind::Namespace(ns) => {
308            if let NamespaceBody::Braced(inner) = &ns.body {
309                return enclosing_function(sv, inner, pos, uri);
310            }
311            None
312        }
313        _ => None,
314    }
315}
316
317fn range_contains(range: Range, pos: Position) -> bool {
318    if pos.line < range.start.line || pos.line > range.end.line {
319        return false;
320    }
321    if pos.line == range.start.line && pos.character < range.start.character {
322        return false;
323    }
324    if pos.line == range.end.line && pos.character >= range.end.character {
325        return false;
326    }
327    true
328}
329
330/// Collect all (callee_name, span) for calls made inside the body of `fn_name`.
331fn collect_calls_for(fn_name: &str, stmts: &[Stmt<'_, '_>], out: &mut Vec<(String, Span)>) {
332    for stmt in stmts {
333        match &stmt.kind {
334            StmtKind::Function(f) if f.name == fn_name => {
335                calls_in_stmts(&f.body, out);
336                return;
337            }
338            StmtKind::Class(c) => {
339                for member in c.members.iter() {
340                    if let ClassMemberKind::Method(m) = &member.kind
341                        && m.name == fn_name
342                        && let Some(body) = &m.body
343                    {
344                        calls_in_stmts(body, out);
345                        return;
346                    }
347                }
348            }
349            StmtKind::Trait(t) => {
350                for member in t.members.iter() {
351                    if let ClassMemberKind::Method(m) = &member.kind
352                        && m.name == fn_name
353                        && let Some(body) = &m.body
354                    {
355                        calls_in_stmts(body, out);
356                        return;
357                    }
358                }
359            }
360            StmtKind::Enum(e) => {
361                for member in e.members.iter() {
362                    if let EnumMemberKind::Method(m) = &member.kind
363                        && m.name == fn_name
364                        && let Some(body) = &m.body
365                    {
366                        calls_in_stmts(body, out);
367                        return;
368                    }
369                }
370            }
371            StmtKind::Namespace(ns) => {
372                if let NamespaceBody::Braced(inner) = &ns.body {
373                    collect_calls_for(fn_name, inner, out);
374                }
375            }
376            _ => {}
377        }
378    }
379}
380
381fn calls_in_stmts(stmts: &[Stmt<'_, '_>], out: &mut Vec<(String, Span)>) {
382    for stmt in stmts {
383        calls_in_stmt(stmt, out);
384    }
385}
386
387fn calls_in_stmt(stmt: &Stmt<'_, '_>, out: &mut Vec<(String, Span)>) {
388    match &stmt.kind {
389        StmtKind::Expression(e) => calls_in_expr(e, out),
390        StmtKind::Return(Some(v)) => calls_in_expr(v, out),
391        StmtKind::Echo(exprs) => {
392            for expr in exprs.iter() {
393                calls_in_expr(expr, out);
394            }
395        }
396        StmtKind::If(i) => {
397            calls_in_expr(&i.condition, out);
398            calls_in_stmt(i.then_branch, out);
399            for ei in i.elseif_branches.iter() {
400                calls_in_expr(&ei.condition, out);
401                calls_in_stmt(&ei.body, out);
402            }
403            if let Some(e) = &i.else_branch {
404                calls_in_stmt(e, out);
405            }
406        }
407        StmtKind::While(w) => {
408            calls_in_expr(&w.condition, out);
409            calls_in_stmt(w.body, out);
410        }
411        StmtKind::For(f) => {
412            for e in f.init.iter() {
413                calls_in_expr(e, out);
414            }
415            for cond in f.condition.iter() {
416                calls_in_expr(cond, out);
417            }
418            for e in f.update.iter() {
419                calls_in_expr(e, out);
420            }
421            calls_in_stmt(f.body, out);
422        }
423        StmtKind::Foreach(f) => {
424            calls_in_expr(&f.expr, out);
425            calls_in_stmt(f.body, out);
426        }
427        StmtKind::TryCatch(t) => {
428            calls_in_stmts(&t.body, out);
429            for catch in t.catches.iter() {
430                calls_in_stmts(&catch.body, out);
431            }
432            if let Some(finally) = &t.finally {
433                calls_in_stmts(finally, out);
434            }
435        }
436        StmtKind::Block(stmts) => calls_in_stmts(stmts, out),
437        _ => {}
438    }
439}
440
441fn calls_in_expr(expr: &php_ast::Expr<'_, '_>, out: &mut Vec<(String, Span)>) {
442    match &expr.kind {
443        ExprKind::FunctionCall(f) => {
444            if let ExprKind::Identifier(name) = &f.name.kind {
445                out.push((name.to_string(), f.name.span));
446            } else {
447                calls_in_expr(f.name, out);
448            }
449            for arg in f.args.iter() {
450                calls_in_expr(&arg.value, out);
451            }
452        }
453        ExprKind::MethodCall(m) => {
454            calls_in_expr(m.object, out);
455            if let ExprKind::Identifier(name) = &m.method.kind {
456                out.push((name.to_string(), m.method.span));
457            }
458            for arg in m.args.iter() {
459                calls_in_expr(&arg.value, out);
460            }
461        }
462        ExprKind::NullsafeMethodCall(m) => {
463            calls_in_expr(m.object, out);
464            if let ExprKind::Identifier(name) = &m.method.kind {
465                out.push((name.to_string(), m.method.span));
466            }
467            for arg in m.args.iter() {
468                calls_in_expr(&arg.value, out);
469            }
470        }
471        ExprKind::StaticMethodCall(s) => {
472            calls_in_expr(s.class, out);
473            for arg in s.args.iter() {
474                calls_in_expr(&arg.value, out);
475            }
476        }
477        ExprKind::Assign(a) => {
478            calls_in_expr(a.target, out);
479            calls_in_expr(a.value, out);
480        }
481        ExprKind::Ternary(t) => {
482            calls_in_expr(t.condition, out);
483            if let Some(then_expr) = t.then_expr {
484                calls_in_expr(then_expr, out);
485            }
486            calls_in_expr(t.else_expr, out);
487        }
488        ExprKind::NullCoalesce(n) => {
489            calls_in_expr(n.left, out);
490            calls_in_expr(n.right, out);
491        }
492        ExprKind::Binary(b) => {
493            calls_in_expr(b.left, out);
494            calls_in_expr(b.right, out);
495        }
496        ExprKind::Parenthesized(e) => calls_in_expr(e, out),
497        _ => {}
498    }
499}
500
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `file://` URL for an absolute test path such as `/a.php`.
    fn file_url(path: &str) -> Url {
        Url::parse(&format!("file://{path}")).unwrap()
    }

    /// Parse `src` and pair it with the URL for `path`, in the shape the public
    /// API expects for `all_docs`.
    fn parsed(path: &str, src: &str) -> (Url, Arc<ParsedDoc>) {
        let url = file_url(path);
        let document = Arc::new(ParsedDoc::parse(src.to_string()));
        (url, document)
    }

    #[test]
    fn prepare_finds_function_declaration() {
        let workspace = vec![parsed("/a.php", "<?php\nfunction greet() {}")];
        let found = prepare_call_hierarchy("greet", &workspace);
        assert!(found.is_some(), "should find greet");
        let found = found.unwrap();
        assert_eq!(found.name, "greet");
        assert_eq!(found.kind, SymbolKind::FUNCTION);
    }

    #[test]
    fn prepare_finds_method_declaration() {
        let workspace = vec![parsed(
            "/a.php",
            "<?php\nclass Foo { public function run() {} }",
        )];
        let found = prepare_call_hierarchy("run", &workspace);
        assert!(found.is_some(), "should find run");
        let found = found.unwrap();
        assert_eq!(found.name, "run");
        assert_eq!(found.kind, SymbolKind::METHOD);
    }

    #[test]
    fn prepare_returns_none_for_unknown() {
        let workspace = vec![parsed("/a.php", "<?php\nfunction greet() {}")];
        assert!(prepare_call_hierarchy("nonexistent", &workspace).is_none());
    }

    #[test]
    fn prepare_returns_none_for_empty_docs() {
        let workspace: Vec<(Url, Arc<ParsedDoc>)> = Vec::new();
        assert!(prepare_call_hierarchy("anything", &workspace).is_none());
    }

    #[test]
    fn incoming_calls_finds_callers() {
        let workspace = vec![parsed(
            "/a.php",
            "<?php\nfunction greet() {}\nfunction main() { greet(); }",
        )];
        let target = prepare_call_hierarchy("greet", &workspace).unwrap();
        let callers = incoming_calls(&target, &workspace);
        assert!(!callers.is_empty(), "should find at least one caller");
        assert!(
            callers.iter().any(|call| call.from.name == "main"),
            "main should be a caller"
        );
    }

    #[test]
    fn incoming_calls_empty_when_no_callers() {
        let workspace = vec![parsed("/a.php", "<?php\nfunction unused() {}")];
        let target = prepare_call_hierarchy("unused", &workspace).unwrap();
        let callers = incoming_calls(&target, &workspace);
        assert!(callers.is_empty(), "no callers expected");
    }

    #[test]
    fn outgoing_calls_finds_callees() {
        let workspace = vec![parsed(
            "/a.php",
            "<?php\nfunction helper() {}\nfunction main() { helper(); }",
        )];
        let source = prepare_call_hierarchy("main", &workspace).unwrap();
        let callees = outgoing_calls(&source, &workspace);
        assert!(!callees.is_empty(), "should find at least one callee");
        assert!(
            callees.iter().any(|call| call.to.name == "helper"),
            "helper should be a callee"
        );
    }

    #[test]
    fn outgoing_calls_empty_for_function_with_no_calls() {
        let workspace = vec![parsed("/a.php", "<?php\nfunction noop() { $x = 1; }")];
        let source = prepare_call_hierarchy("noop", &workspace).unwrap();
        let callees = outgoing_calls(&source, &workspace);
        assert!(callees.is_empty(), "no outgoing calls expected");
    }

    #[test]
    fn outgoing_calls_cross_file() {
        let workspace = vec![
            parsed("/a.php", "<?php\nfunction helper() {}"),
            parsed("/b.php", "<?php\nfunction main() { helper(); }"),
        ];
        let source = prepare_call_hierarchy("main", &workspace).unwrap();
        let callees = outgoing_calls(&source, &workspace);
        assert!(
            callees.iter().any(|call| call.to.name == "helper"),
            "cross-file callee not found"
        );
    }

    #[test]
    fn incoming_calls_cross_file() {
        let workspace = vec![
            parsed("/a.php", "<?php\nfunction greet() {}"),
            parsed("/b.php", "<?php\nfunction run() { greet(); }"),
        ];
        let target = prepare_call_hierarchy("greet", &workspace).unwrap();
        let callers = incoming_calls(&target, &workspace);
        assert!(
            callers.iter().any(|call| call.from.name == "run"),
            "cross-file caller not found"
        );
    }

    #[test]
    fn prepare_finds_enum_method_declaration() {
        let workspace = vec![parsed(
            "/a.php",
            "<?php\nenum Suit { public function label(): string { return 'x'; } }",
        )];
        let found = prepare_call_hierarchy("label", &workspace);
        assert!(found.is_some(), "should find enum method 'label'");
        let found = found.unwrap();
        assert_eq!(found.name, "label");
        assert_eq!(found.kind, SymbolKind::METHOD);
    }

    #[test]
    fn outgoing_calls_from_enum_method() {
        let workspace = vec![parsed(
            "/a.php",
            "<?php\nfunction fmt(): string { return ''; }\nenum Suit { public function label(): string { return fmt(); } }",
        )];
        let source = prepare_call_hierarchy("label", &workspace).unwrap();
        let callees = outgoing_calls(&source, &workspace);
        assert!(
            callees.iter().any(|call| call.to.name == "fmt"),
            "should find outgoing call to fmt from enum method"
        );
    }

    #[test]
    fn outgoing_calls_from_for_init_and_update() {
        let workspace = vec![parsed(
            "/a.php",
            "<?php\nfunction start(): int { return 0; }\nfunction step(): void {}\nfunction main(): void { for ($i = start(); $i < 10; step()) {} }",
        )];
        let source = prepare_call_hierarchy("main", &workspace).unwrap();
        let callees = outgoing_calls(&source, &workspace);
        let calls_to = |callee: &str| callees.iter().any(|call| call.to.name == callee);
        assert!(calls_to("start"), "should find call to start() in for-init");
        assert!(calls_to("step"), "should find call to step() in for-update");
    }

    #[test]
    fn outgoing_calls_deduplicates_same_callee() {
        let workspace = vec![parsed(
            "/a.php",
            "<?php\nfunction helper() {}\nfunction main() { helper(); helper(); }",
        )];
        let source = prepare_call_hierarchy("main", &workspace).unwrap();
        let callees = outgoing_calls(&source, &workspace);
        let helper_calls: Vec<_> = callees
            .iter()
            .filter(|call| call.to.name == "helper")
            .collect();
        assert_eq!(
            helper_calls.len(),
            1,
            "helper should appear once (with two from_ranges)"
        );
        assert_eq!(
            helper_calls[0].from_ranges.len(),
            2,
            "should have two call-site ranges"
        );
    }

    // ── range_contains boundary regression tests ─────────────────────────────

    #[test]
    fn range_contains_excludes_exact_end_position() {
        // LSP ranges are half-open [start, end): a position exactly at
        // `range.end` lies OUTSIDE the range. An earlier version compared with
        // `>` instead of `>=` and wrongly treated the end position as inside.
        let at = |line: u32, character: u32| Position { line, character };
        let range = Range {
            start: at(1, 0),
            end: at(3, 5),
        };
        // One past the last character on the end line — clearly outside.
        assert!(
            !range_contains(range, at(3, 6)),
            "position after end must be outside"
        );
        // Exactly at end — outside per LSP half-open semantics.
        assert!(
            !range_contains(range, at(3, 5)),
            "position exactly at range.end must be outside (half-open range)"
        );
        // One before end — inside.
        assert!(
            range_contains(range, at(3, 4)),
            "position just before end must be inside"
        );
        // Start of range — inside.
        assert!(
            range_contains(range, at(1, 0)),
            "start position must be inside"
        );
    }
}