use std::collections::HashMap;
use lsp_types::{
CallHierarchyIncomingCall, CallHierarchyIncomingCallsParams, CallHierarchyItem,
CallHierarchyOutgoingCall, CallHierarchyOutgoingCallsParams, CallHierarchyPrepareParams, Range,
SymbolKind, Uri,
};
use crate::ir::ast::{ClassDefinition, ClassType, ComponentReference, Expression};
use crate::ir::visitor::{Visitable, Visitor};
use crate::lsp::utils::{get_word_at_position, parse_document, token_to_range};
/// Visitor that records the source ranges at which one particular function is
/// called.
///
/// `target` is the name being searched for and `ranges` receives one entry per
/// matching call expression handed to the visitor.
struct CallRangeFinder<'a> {
    target: &'a str,
    ranges: Vec<Range>,
}

impl<'a> CallRangeFinder<'a> {
    /// Creates a finder for `target` with no ranges recorded yet.
    fn new(target: &'a str) -> Self {
        Self {
            target,
            ranges: Vec::new(),
        }
    }

    /// Records the call site when `comp` refers to the target function.
    ///
    /// A call matches when its full dotted name equals the target, in which
    /// case the range of the first reference part is recorded. A qualified
    /// call (`Pkg.f(...)`) additionally matches when only its final segment
    /// equals the target: call-hierarchy items in this module carry the simple
    /// class name, so comparing the dotted name alone would never report
    /// qualified call sites. For that case the range of the final segment is
    /// recorded, since that is the token naming the function.
    fn check_function_call(&mut self, comp: &ComponentReference) {
        let target = self.target;
        let matched_part = if get_function_name(comp) == target {
            comp.parts.first()
        } else {
            // Only reachable with a real match for multi-part references: a
            // single-part reference whose text equals the target has already
            // matched in the branch above.
            comp.parts.last().filter(|part| part.ident.text == target)
        };
        if let Some(part) = matched_part {
            self.ranges.push(token_to_range(&part.ident));
        }
    }
}

impl Visitor for CallRangeFinder<'_> {
    /// Inspects function-call expressions; every other expression is ignored.
    fn enter_expression(&mut self, node: &Expression) {
        if let Expression::FunctionCall { comp, .. } = node {
            self.check_function_call(comp);
        }
    }
}
/// Visitor that groups the function calls it is shown by callee name.
///
/// `calls` maps the dotted callee name (as built by `get_function_name`) to
/// the ranges recorded for calls to that callee.
struct FunctionCallCollector {
    calls: HashMap<String, Vec<Range>>,
}

impl FunctionCallCollector {
    /// Creates a collector with no recorded calls.
    fn new() -> Self {
        let calls = HashMap::new();
        Self { calls }
    }

    /// Files the call under its callee name, using the range of the first
    /// reference part as the call site. References without parts are skipped.
    fn record_call(&mut self, comp: &ComponentReference) {
        let Some(first_part) = comp.parts.first() else {
            return;
        };
        let callee = get_function_name(comp);
        let site = token_to_range(&first_part.ident);
        self.calls.entry(callee).or_default().push(site);
    }
}

impl Visitor for FunctionCallCollector {
    /// Records function-call expressions; every other expression is ignored.
    fn enter_expression(&mut self, node: &Expression) {
        let Expression::FunctionCall { comp, .. } = node else {
            return;
        };
        self.record_call(comp);
    }
}
/// Resolves the word under the cursor to a call-hierarchy item.
///
/// Looks up the text of the requested document in `documents`, extracts the
/// word at the requested position, parses the document and returns the first
/// class (top-level or nested) whose name equals that word, wrapped in a
/// single-element vector.
///
/// Returns `None` when the document is not in `documents`, when no word is
/// found at the position, when `parse_document` yields no AST, or when no
/// class carries that name.
pub fn handle_prepare_call_hierarchy(
    documents: &HashMap<Uri, String>,
    params: CallHierarchyPrepareParams,
) -> Option<Vec<CallHierarchyItem>> {
    let position_params = &params.text_document_position_params;
    let uri = &position_params.text_document.uri;
    let text = documents.get(uri)?;
    let word = get_word_at_position(text, position_params.position)?;
    let ast = parse_document(text, uri.path().as_str())?;
    ast.class_list
        .values()
        .find_map(|class| find_call_hierarchy_item(class, &word, uri))
        .map(|item| vec![item])
}
/// Collects the callers of the function named by `params.item`.
///
/// Every document in `documents` is parsed and each of its top-level classes
/// is handed to `collect_incoming_calls`, which appends one entry per class
/// that calls the target. Documents for which `parse_document` yields no AST
/// are skipped.
///
/// Returns `None` when no caller was found, otherwise the collected calls.
/// Documents are visited in the iteration order of the `HashMap`.
pub fn handle_incoming_calls(
    documents: &HashMap<Uri, String>,
    params: CallHierarchyIncomingCallsParams,
) -> Option<Vec<CallHierarchyIncomingCall>> {
    let target_name = params.item.name.as_str();
    let mut calls = Vec::new();
    for (uri, text) in documents {
        let Some(ast) = parse_document(text, uri.path().as_str()) else {
            continue;
        };
        for class in ast.class_list.values() {
            collect_incoming_calls(class, target_name, uri, &mut calls);
        }
    }
    (!calls.is_empty()).then_some(calls)
}
pub fn handle_outgoing_calls(
documents: &HashMap<Uri, String>,
params: CallHierarchyOutgoingCallsParams,
) -> Option<Vec<CallHierarchyOutgoingCall>> {
let uri = ¶ms.item.uri;
let text = documents.get(uri)?;
let path = uri.path().as_str();
let source_name = ¶ms.item.name;
let ast = parse_document(text, path)?;
let mut calls = Vec::new();
for class in ast.class_list.values() {
if class.name.text == *source_name {
collect_outgoing_calls(class, uri, documents, &mut calls);
}
for nested in class.classes.values() {
if nested.name.text == *source_name {
collect_outgoing_calls(nested, uri, documents, &mut calls);
}
}
}
if calls.is_empty() { None } else { Some(calls) }
}
/// Searches `class` and, depth-first, its nested classes for one named `name`.
///
/// The class itself is checked before its nested classes. On a match the
/// returned item carries the class name, a symbol kind derived from the class
/// type, the `Debug` rendering of the class type as detail, and the range of
/// the class-name token as both `range` and `selection_range`.
///
/// Returns `None` when neither `class` nor any nested class has that name.
fn find_call_hierarchy_item(
    class: &ClassDefinition,
    name: &str,
    uri: &Uri,
) -> Option<CallHierarchyItem> {
    if class.name.text != name {
        // Not this class: the first nested class that yields an item wins.
        return class
            .classes
            .values()
            .find_map(|nested| find_call_hierarchy_item(nested, name, uri));
    }

    let kind = match class.class_type {
        ClassType::Function => SymbolKind::FUNCTION,
        ClassType::Model | ClassType::Class | ClassType::Block => SymbolKind::CLASS,
        ClassType::Record => SymbolKind::STRUCT,
        ClassType::Connector => SymbolKind::INTERFACE,
        ClassType::Package => SymbolKind::MODULE,
        ClassType::Type => SymbolKind::TYPE_PARAMETER,
        ClassType::Operator => SymbolKind::OPERATOR,
    };
    let name_range = token_to_range(&class.name);
    let detail = format!("{:?}", class.class_type);

    Some(CallHierarchyItem {
        name: class.name.text.clone(),
        kind,
        tags: None,
        detail: Some(detail),
        uri: uri.clone(),
        range: name_range,
        selection_range: name_range,
        data: None,
    })
}
fn collect_incoming_calls(
class: &ClassDefinition,
target_name: &str,
uri: &Uri,
calls: &mut Vec<CallHierarchyIncomingCall>,
) {
let mut finder = CallRangeFinder::new(target_name);
class.accept(&mut finder);
if !finder.ranges.is_empty() {
let kind = match class.class_type {
ClassType::Function => SymbolKind::FUNCTION,
_ => SymbolKind::CLASS,
};
let range = token_to_range(&class.name);
calls.push(CallHierarchyIncomingCall {
from: CallHierarchyItem {
name: class.name.text.clone(),
kind,
tags: None,
detail: Some(format!("{:?}", class.class_type)),
uri: uri.clone(),
range,
selection_range: range,
data: None,
},
from_ranges: finder.ranges,
});
}
for nested in class.classes.values() {
collect_incoming_calls(nested, target_name, uri, calls);
}
}
/// Appends one outgoing-call entry per distinct callee found in `class`.
///
/// A `FunctionCallCollector` is run over `class`; each collected callee name
/// is resolved with `find_function_definition`, and callees that cannot be
/// resolved are dropped. The ranges recorded for a callee become the entry's
/// `from_ranges`.
fn collect_outgoing_calls(
    class: &ClassDefinition,
    uri: &Uri,
    documents: &HashMap<Uri, String>,
    calls: &mut Vec<CallHierarchyOutgoingCall>,
) {
    let mut collector = FunctionCallCollector::new();
    class.accept(&mut collector);

    let resolved = collector
        .calls
        .into_iter()
        .filter_map(|(callee, from_ranges)| {
            find_function_definition(&callee, uri, documents)
                .map(|to| CallHierarchyOutgoingCall { to, from_ranges })
        });
    calls.extend(resolved);
}
/// Builds the dotted name of a component reference by joining the identifier
/// text of each part with `.`; a reference without parts gives an empty string.
fn get_function_name(comp: &ComponentReference) -> String {
    let mut qualified = String::new();
    for (index, part) in comp.parts.iter().enumerate() {
        if index > 0 {
            qualified.push('.');
        }
        qualified.push_str(&part.ident.text);
    }
    qualified
}
/// Finds the class that defines the function called as `name` in any of the
/// `documents`.
///
/// Each document is parsed once (documents for which `parse_document` yields
/// no AST are skipped) and its classes are searched with
/// `find_call_hierarchy_item`:
///
/// 1. A class whose name equals `name` exactly is returned as soon as it is
///    found.
/// 2. Otherwise, when `name` is dotted (`Pkg.f`), the first class encountered
///    whose name equals the final segment (`f`) is returned. Classes are only
///    ever matched by their simple name, so without this step a qualified
///    call could never be resolved. This is a name-based heuristic: it does
///    not check that the class actually lives under the given qualifier.
///
/// `_current_uri` is currently unused. Returns `None` when nothing matches.
fn find_function_definition(
    name: &str,
    _current_uri: &Uri,
    documents: &HashMap<Uri, String>,
) -> Option<CallHierarchyItem> {
    // `None` for unqualified names, so the fallback never repeats the exact search.
    let simple_name = name.rsplit_once('.').map(|(_, last)| last);
    let mut fallback = None;

    for (uri, text) in documents {
        let Some(ast) = parse_document(text, uri.path().as_str()) else {
            continue;
        };
        for class in ast.class_list.values() {
            if let Some(item) = find_call_hierarchy_item(class, name, uri) {
                return Some(item);
            }
            if fallback.is_none() {
                fallback =
                    simple_name.and_then(|simple| find_call_hierarchy_item(class, simple, uri));
            }
        }
    }
    fallback
}
#[cfg(test)]
mod tests {
    use super::get_function_name;
    use crate::ir::ast::{ComponentRefPart, ComponentReference, Location, Token};

    /// Builds a non-local component reference with one subscript-free part
    /// per entry of `names`.
    fn comp_ref(names: &[&str]) -> ComponentReference {
        ComponentReference {
            local: false,
            parts: names
                .iter()
                .map(|&name| ComponentRefPart {
                    ident: Token {
                        text: name.to_string(),
                        location: Location::default(),
                        token_number: 0,
                        token_type: 0,
                    },
                    subs: None,
                })
                .collect(),
        }
    }

    #[test]
    fn test_get_function_name() {
        assert_eq!(get_function_name(&comp_ref(&["sin"])), "sin".to_string());
    }

    #[test]
    fn test_get_function_name_qualified() {
        // Parts are joined with a single dot, in order.
        assert_eq!(
            get_function_name(&comp_ref(&["Modelica", "Math", "sin"])),
            "Modelica.Math.sin".to_string()
        );
    }

    #[test]
    fn test_get_function_name_empty() {
        // A reference without parts produces an empty name.
        assert_eq!(get_function_name(&comp_ref(&[])), String::new());
    }
}