1use std::collections::HashMap;
2use std::ops::ControlFlow;
3use std::sync::Arc;
4
5use php_ast::visitor::{Visitor, walk_expr, walk_stmt};
6use php_ast::{ClassMemberKind, EnumMemberKind, ExprKind, NamespaceBody, Span, Stmt, StmtKind};
7use tower_lsp::lsp_types::{
8 CallHierarchyIncomingCall, CallHierarchyItem, CallHierarchyOutgoingCall, Position, Range,
9 SymbolKind, Url,
10};
11
12use crate::ast::{ParsedDoc, SourceView, span_to_range};
13use crate::references::find_references;
14
15pub fn prepare_call_hierarchy(
17 name: &str,
18 all_docs: &[(Url, Arc<ParsedDoc>)],
19) -> Option<CallHierarchyItem> {
20 for (uri, doc) in all_docs {
21 let sv = doc.view();
22 if let Some(item) = find_declaration_item(name, &doc.program().stmts, sv, uri) {
23 return Some(item);
24 }
25 }
26 None
27}
28
29pub fn incoming_calls(
31 item: &CallHierarchyItem,
32 all_docs: &[(Url, Arc<ParsedDoc>)],
33) -> Vec<CallHierarchyIncomingCall> {
34 let call_sites = find_references(&item.name, all_docs, false, None);
35 let doc_map: HashMap<&Url, &Arc<ParsedDoc>> = all_docs.iter().map(|(u, d)| (u, d)).collect();
37 let mut result: Vec<CallHierarchyIncomingCall> = Vec::new();
38 let mut index: HashMap<(String, Url), usize> = HashMap::new();
40
41 for loc in call_sites {
42 let caller = doc_map.get(&loc.uri).and_then(|doc| {
43 enclosing_function(doc.view(), &doc.program().stmts, loc.range.start, &loc.uri)
44 });
45
46 let key = if let Some(ref ci) = caller {
47 (ci.name.clone(), ci.uri.clone())
48 } else {
49 ("<file scope>".to_string(), loc.uri.clone())
50 };
51
52 if let Some(&idx) = index.get(&key) {
53 result[idx].from_ranges.push(loc.range);
54 } else {
55 let from = caller.unwrap_or_else(|| CallHierarchyItem {
56 name: "<file scope>".to_string(),
57 kind: SymbolKind::FILE,
58 tags: None,
59 detail: None,
60 uri: loc.uri.clone(),
61 range: loc.range,
62 selection_range: loc.range,
63 data: None,
64 });
65 let idx = result.len();
66 index.insert(key, idx);
67 result.push(CallHierarchyIncomingCall {
68 from,
69 from_ranges: vec![loc.range],
70 });
71 }
72 }
73
74 result
75}
76
77pub fn outgoing_calls(
79 item: &CallHierarchyItem,
80 all_docs: &[(Url, Arc<ParsedDoc>)],
81) -> Vec<CallHierarchyOutgoingCall> {
82 let Some((_, doc)) = all_docs.iter().find(|(uri, _)| *uri == item.uri) else {
83 return Vec::new();
84 };
85 let item_source = doc.source();
87 let mut calls: Vec<(String, Span)> = Vec::new();
88 collect_calls_for(&item.name, &doc.program().stmts, &mut calls);
89
90 let mut result: Vec<CallHierarchyOutgoingCall> = Vec::new();
91 let mut index: HashMap<String, usize> = HashMap::new();
93 let item_line_starts = doc.line_starts();
94 for (callee_name, span) in calls {
95 let call_range = span_to_range(item_source, item_line_starts, span);
96 if let Some(&idx) = index.get(&callee_name) {
97 result[idx].from_ranges.push(call_range);
98 } else if let Some(callee_item) = prepare_call_hierarchy(&callee_name, all_docs) {
99 let idx = result.len();
100 index.insert(callee_name, idx);
101 result.push(CallHierarchyOutgoingCall {
102 to: callee_item,
103 from_ranges: vec![call_range],
104 });
105 }
106 }
107
108 result
109}
110
111fn find_declaration_item(
114 name: &str,
115 stmts: &[Stmt<'_, '_>],
116 sv: SourceView<'_>,
117 uri: &Url,
118) -> Option<CallHierarchyItem> {
119 for stmt in stmts {
120 match &stmt.kind {
121 StmtKind::Function(f) if f.name == name => {
122 let range = sv.range_of(stmt.span);
123 let sel = sv.name_range(f.name);
124 return Some(CallHierarchyItem {
125 name: name.to_string(),
126 kind: SymbolKind::FUNCTION,
127 tags: None,
128 detail: None,
129 uri: uri.clone(),
130 range,
131 selection_range: sel,
132 data: None,
133 });
134 }
135 StmtKind::Class(c) => {
136 for member in c.members.iter() {
137 if let ClassMemberKind::Method(m) = &member.kind
138 && m.name == name
139 {
140 let range = sv.range_of(member.span);
141 let sel = sv.name_range(m.name);
142 return Some(CallHierarchyItem {
143 name: name.to_string(),
144 kind: SymbolKind::METHOD,
145 tags: None,
146 detail: c.name.map(|n| n.to_string()),
147 uri: uri.clone(),
148 range,
149 selection_range: sel,
150 data: None,
151 });
152 }
153 }
154 }
155 StmtKind::Trait(t) => {
156 for member in t.members.iter() {
157 if let ClassMemberKind::Method(m) = &member.kind
158 && m.name == name
159 {
160 let range = sv.range_of(member.span);
161 let sel = sv.name_range(m.name);
162 return Some(CallHierarchyItem {
163 name: name.to_string(),
164 kind: SymbolKind::METHOD,
165 tags: None,
166 detail: Some(t.name.to_string()),
167 uri: uri.clone(),
168 range,
169 selection_range: sel,
170 data: None,
171 });
172 }
173 }
174 }
175 StmtKind::Enum(e) => {
176 for member in e.members.iter() {
177 if let EnumMemberKind::Method(m) = &member.kind
178 && m.name == name
179 {
180 let range = sv.range_of(member.span);
181 let sel = sv.name_range(m.name);
182 return Some(CallHierarchyItem {
183 name: name.to_string(),
184 kind: SymbolKind::METHOD,
185 tags: None,
186 detail: Some(e.name.to_string()),
187 uri: uri.clone(),
188 range,
189 selection_range: sel,
190 data: None,
191 });
192 }
193 }
194 }
195 StmtKind::Namespace(ns) => {
196 if let NamespaceBody::Braced(inner) = &ns.body
197 && let Some(item) = find_declaration_item(name, inner, sv, uri)
198 {
199 return Some(item);
200 }
201 }
202 _ => {}
203 }
204 }
205 None
206}
207
208fn enclosing_function(
209 sv: SourceView<'_>,
210 stmts: &[Stmt<'_, '_>],
211 pos: Position,
212 uri: &Url,
213) -> Option<CallHierarchyItem> {
214 for stmt in stmts {
215 if let Some(item) = enclosing_in_stmt(sv, stmt, pos, uri) {
216 return Some(item);
217 }
218 }
219 None
220}
221
222fn enclosing_in_stmt(
223 sv: SourceView<'_>,
224 stmt: &Stmt<'_, '_>,
225 pos: Position,
226 uri: &Url,
227) -> Option<CallHierarchyItem> {
228 let range = sv.range_of(stmt.span);
229 if !range_contains(range, pos) {
230 return None;
231 }
232 match &stmt.kind {
233 StmtKind::Function(f) => {
234 let sel = sv.name_range(f.name);
235 Some(CallHierarchyItem {
236 name: f.name.to_string(),
237 kind: SymbolKind::FUNCTION,
238 tags: None,
239 detail: None,
240 uri: uri.clone(),
241 range,
242 selection_range: sel,
243 data: None,
244 })
245 }
246 StmtKind::Class(c) => {
247 for member in c.members.iter() {
248 let m_range = sv.range_of(member.span);
249 if range_contains(m_range, pos)
250 && let ClassMemberKind::Method(m) = &member.kind
251 {
252 let sel = sv.name_range(m.name);
253 return Some(CallHierarchyItem {
254 name: m.name.to_string(),
255 kind: SymbolKind::METHOD,
256 tags: None,
257 detail: c.name.map(|n| n.to_string()),
258 uri: uri.clone(),
259 range: m_range,
260 selection_range: sel,
261 data: None,
262 });
263 }
264 }
265 None
266 }
267 StmtKind::Trait(t) => {
268 for member in t.members.iter() {
269 let m_range = sv.range_of(member.span);
270 if range_contains(m_range, pos)
271 && let ClassMemberKind::Method(m) = &member.kind
272 {
273 let sel = sv.name_range(m.name);
274 return Some(CallHierarchyItem {
275 name: m.name.to_string(),
276 kind: SymbolKind::METHOD,
277 tags: None,
278 detail: Some(t.name.to_string()),
279 uri: uri.clone(),
280 range: m_range,
281 selection_range: sel,
282 data: None,
283 });
284 }
285 }
286 None
287 }
288 StmtKind::Enum(e) => {
289 for member in e.members.iter() {
290 let m_range = sv.range_of(member.span);
291 if range_contains(m_range, pos)
292 && let EnumMemberKind::Method(m) = &member.kind
293 {
294 let sel = sv.name_range(m.name);
295 return Some(CallHierarchyItem {
296 name: m.name.to_string(),
297 kind: SymbolKind::METHOD,
298 tags: None,
299 detail: Some(e.name.to_string()),
300 uri: uri.clone(),
301 range: m_range,
302 selection_range: sel,
303 data: None,
304 });
305 }
306 }
307 None
308 }
309 StmtKind::Namespace(ns) => {
310 if let NamespaceBody::Braced(inner) = &ns.body {
311 return enclosing_function(sv, inner, pos, uri);
312 }
313 None
314 }
315 _ => None,
316 }
317}
318
319fn range_contains(range: Range, pos: Position) -> bool {
320 if pos.line < range.start.line || pos.line > range.end.line {
321 return false;
322 }
323 if pos.line == range.start.line && pos.character < range.start.character {
324 return false;
325 }
326 if pos.line == range.end.line && pos.character >= range.end.character {
327 return false;
328 }
329 true
330}
331
332fn collect_calls_for(fn_name: &str, stmts: &[Stmt<'_, '_>], out: &mut Vec<(String, Span)>) {
334 for stmt in stmts {
335 match &stmt.kind {
336 StmtKind::Function(f) if f.name == fn_name => {
337 calls_in_stmts(&f.body, out);
338 return;
339 }
340 StmtKind::Class(c) => {
341 for member in c.members.iter() {
342 if let ClassMemberKind::Method(m) = &member.kind
343 && m.name == fn_name
344 && let Some(body) = &m.body
345 {
346 calls_in_stmts(body, out);
347 return;
348 }
349 }
350 }
351 StmtKind::Trait(t) => {
352 for member in t.members.iter() {
353 if let ClassMemberKind::Method(m) = &member.kind
354 && m.name == fn_name
355 && let Some(body) = &m.body
356 {
357 calls_in_stmts(body, out);
358 return;
359 }
360 }
361 }
362 StmtKind::Enum(e) => {
363 for member in e.members.iter() {
364 if let EnumMemberKind::Method(m) = &member.kind
365 && m.name == fn_name
366 && let Some(body) = &m.body
367 {
368 calls_in_stmts(body, out);
369 return;
370 }
371 }
372 }
373 StmtKind::Namespace(ns) => {
374 if let NamespaceBody::Braced(inner) = &ns.body {
375 collect_calls_for(fn_name, inner, out);
376 }
377 }
378 _ => {}
379 }
380 }
381}
382
383fn calls_in_stmts(stmts: &[Stmt<'_, '_>], out: &mut Vec<(String, Span)>) {
386 let mut collector = CallCollector { out };
387 for stmt in stmts {
388 let _ = collector.visit_stmt(stmt);
389 }
390}
391
392struct CallCollector<'c> {
393 out: &'c mut Vec<(String, Span)>,
394}
395
396impl<'arena, 'src> Visitor<'arena, 'src> for CallCollector<'_> {
397 fn visit_expr(&mut self, expr: &php_ast::Expr<'arena, 'src>) -> ControlFlow<()> {
398 match &expr.kind {
399 ExprKind::FunctionCall(f) => {
400 if let ExprKind::Identifier(name) = &f.name.kind {
401 self.out.push((name.to_string(), f.name.span));
402 }
403 }
404 ExprKind::MethodCall(m) | ExprKind::NullsafeMethodCall(m) => {
405 if let ExprKind::Identifier(name) = &m.method.kind {
406 self.out.push((name.to_string(), m.method.span));
407 }
408 }
409 ExprKind::StaticMethodCall(s) => {
410 if let ExprKind::Identifier(name) = &s.method.kind {
411 self.out.push((name.to_string(), s.method.span));
412 }
413 }
414 _ => {}
415 }
416 walk_expr(self, expr)
417 }
418
419 fn visit_stmt(&mut self, stmt: &php_ast::Stmt<'arena, 'src>) -> ControlFlow<()> {
420 match &stmt.kind {
424 StmtKind::Function(_)
425 | StmtKind::Class(_)
426 | StmtKind::Trait(_)
427 | StmtKind::Enum(_)
428 | StmtKind::Interface(_) => ControlFlow::Continue(()),
429 _ => walk_stmt(self, stmt),
430 }
431 }
432}
433
434#[cfg(test)]
435mod tests {
436 use super::*;
437
438 #[test]
441 fn range_contains_excludes_exact_end_position() {
442 let range = Range {
446 start: Position {
447 line: 1,
448 character: 0,
449 },
450 end: Position {
451 line: 3,
452 character: 5,
453 },
454 };
455 assert!(
457 !range_contains(
458 range,
459 Position {
460 line: 3,
461 character: 6
462 }
463 ),
464 "position after end must be outside"
465 );
466 assert!(
468 !range_contains(
469 range,
470 Position {
471 line: 3,
472 character: 5
473 }
474 ),
475 "position exactly at range.end must be outside (half-open range)"
476 );
477 assert!(
479 range_contains(
480 range,
481 Position {
482 line: 3,
483 character: 4
484 }
485 ),
486 "position just before end must be inside"
487 );
488 assert!(
490 range_contains(
491 range,
492 Position {
493 line: 1,
494 character: 0
495 }
496 ),
497 "start position must be inside"
498 );
499 }
500}