use std::collections::HashMap;
use tree_sitter::Tree;
use crate::types::Language;
/// Returns the tree-sitter node kinds that are counted as identifier occurrences for `language`.
///
/// [`count_identifiers_in_tree`] compares each visited node's `kind()` against this list.
/// Where a grammar uses a separate kind for names in type position (e.g. `type_identifier`),
/// that kind is listed as well. This way a type referenced only in annotations, field types or
/// parameter types still contributes to the count.
pub fn identifier_node_types(language: Language) -> &'static [&'static str] {
    match language {
        Language::Python => &["identifier"],
        Language::TypeScript | Language::JavaScript => &[
            "identifier",
            "property_identifier",
            "shorthand_property_identifier",
            "type_identifier",
        ],
        Language::Go => &["identifier", "field_identifier", "type_identifier"],
        Language::Rust => &["identifier", "field_identifier", "type_identifier"],
        Language::Java => &["identifier", "type_identifier"],
        Language::C | Language::Cpp => &["identifier", "field_identifier", "type_identifier"],
        Language::Ruby => &["identifier", "constant"],
        Language::Php => &["name"],
        Language::Kotlin => &["identifier"],
        Language::Swift => &["simple_identifier", "type_identifier"],
        Language::CSharp => &["identifier"],
        // tree-sitter-scala aliases names in type position (`val s: MyService`) to
        // `type_identifier`; without it, types used only in annotations are undercounted.
        Language::Scala => &["identifier", "type_identifier"],
        Language::Elixir => &["identifier"],
        Language::Lua | Language::Luau => &["identifier"],
        Language::Ocaml => &["value_name", "type_constructor"],
    }
}
pub fn count_identifiers_in_tree(
tree: &Tree,
source: &[u8],
language: Language,
) -> HashMap<String, usize> {
let id_types = identifier_node_types(language);
let mut counts: HashMap<String, usize> = HashMap::new();
let mut cursor = tree.walk();
let mut reached_root = false;
loop {
let node = cursor.node();
if id_types.contains(&node.kind()) {
let start = node.start_byte();
let end = node.end_byte();
if start <= end && end <= source.len() {
if let Ok(text) = std::str::from_utf8(&source[start..end]) {
if !text.is_empty() {
*counts.entry(text.to_string()).or_insert(0) += 1;
}
}
}
}
if cursor.goto_first_child() {
continue;
}
if cursor.goto_next_sibling() {
continue;
}
loop {
if !cursor.goto_parent() {
reached_root = true;
break;
}
if cursor.goto_next_sibling() {
break;
}
}
if reached_root {
break;
}
}
counts
}
pub fn is_rescued_by_refcount(name: &str, ref_counts: &HashMap<String, usize>) -> bool {
let bare_name = if name.contains('.') {
name.rsplit('.').next().unwrap_or(name)
} else if name.contains(':') {
name.rsplit(':').next().unwrap_or(name)
} else {
name
};
let min_refs = if bare_name.len() < 3 { 5 } else { 2 };
if let Some(&count) = ref_counts.get(bare_name) {
if count >= min_refs {
return true;
}
}
if bare_name != name {
if let Some(&count) = ref_counts.get(name) {
if count >= min_refs {
return true;
}
}
}
false
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::parser::parse;

    /// Parses `source` as `language` and returns its identifier tally.
    fn tally(source: &str, language: Language) -> HashMap<String, usize> {
        let tree = parse(source, language).unwrap();
        count_identifiers_in_tree(&tree, source.as_bytes(), language)
    }

    /// Number of times `name` was tallied, treating an absent entry as zero.
    fn occurrences(counts: &HashMap<String, usize>, name: &str) -> usize {
        counts.get(name).copied().unwrap_or(0)
    }

    /// Builds a reference-count map containing exactly one `name -> count` entry.
    fn single_ref(name: &str, count: usize) -> HashMap<String, usize> {
        std::iter::once((name.to_string(), count)).collect()
    }

    #[test]
    fn test_count_identifiers_python() {
        let counts = tally("def hello():\n x = 1\n return hello()\n", Language::Python);

        let hello = counts.get("hello");
        assert!(
            hello.copied().unwrap_or(0) >= 2,
            "Expected 'hello' count >= 2, got {:?}",
            hello
        );

        let x = counts.get("x");
        assert!(
            x.copied().unwrap_or(0) >= 1,
            "Expected 'x' count >= 1, got {:?}",
            x
        );
    }

    #[test]
    fn test_is_rescued_short_name_low_count() {
        let ref_counts = single_ref("fn", 3);
        assert!(
            !is_rescued_by_refcount("fn", &ref_counts),
            "Short name 'fn' (2 chars) with count=3 should not be rescued (needs >= 5)"
        );
    }

    #[test]
    fn test_is_rescued_short_name_high_count() {
        let ref_counts = single_ref("cn", 50);
        assert!(
            is_rescued_by_refcount("cn", &ref_counts),
            "Short name 'cn' (2 chars) with count=50 should be rescued (>= 5)"
        );
    }

    #[test]
    fn test_is_rescued_referenced() {
        let ref_counts = single_ref("handle_signal", 3);
        assert!(
            is_rescued_by_refcount("handle_signal", &ref_counts),
            "handle_signal with count=3 should be rescued"
        );
    }

    #[test]
    fn test_is_rescued_only_definition() {
        let ref_counts = single_ref("unused_helper", 1);
        assert!(
            !is_rescued_by_refcount("unused_helper", &ref_counts),
            "unused_helper with count=1 should not be rescued"
        );
    }

    #[test]
    fn test_class_method_bare_name() {
        let ref_counts = single_ref("process", 5);
        assert!(
            is_rescued_by_refcount("MyClass.process", &ref_counts),
            "MyClass.process should be rescued via bare name 'process' with count=5"
        );
    }

    #[test]
    fn test_identifier_node_types_all_languages() {
        let languages = [
            Language::Python,
            Language::TypeScript,
            Language::JavaScript,
            Language::Go,
            Language::Rust,
            Language::Java,
            Language::C,
            Language::Cpp,
            Language::Ruby,
            Language::Kotlin,
            Language::Swift,
            Language::CSharp,
            Language::Scala,
            Language::Php,
            Language::Lua,
            Language::Luau,
            Language::Elixir,
            Language::Ocaml,
        ];
        for lang in languages.iter().copied() {
            assert!(
                !identifier_node_types(lang).is_empty(),
                "identifier_node_types({:?}) returned empty slice",
                lang
            );
        }
    }

    #[test]
    fn test_jsx_element_name_refcount() {
        let source = r#"
function Comp(props) {
return <div>hello</div>;
}
function App() {
return <Comp foo="bar" />;
}
"#;
        let comp_count = occurrences(&tally(source, Language::TypeScript), "Comp");
        assert!(
            comp_count >= 2,
            "Expected Comp refcount >= 2 (definition + JSX usage), got {}",
            comp_count
        );
    }

    #[test]
    fn test_java_type_identifier_refcount() {
        let source = r#"
class MyService {
public void run() {}
}
class App {
private MyService svc;
public void start(MyService service) {}
}
"#;
        let svc_count = occurrences(&tally(source, Language::Java), "MyService");
        assert!(
            svc_count >= 3,
            "Expected MyService refcount >= 3 (class def + field type + param type), got {}",
            svc_count
        );
    }

    #[test]
    fn test_kotlin_type_identifier_refcount() {
        let source = r#"
class MyHelper {
fun run() {}
}
fun main() {
val helper: MyHelper = MyHelper()
}
"#;
        let helper_count = occurrences(&tally(source, Language::Kotlin), "MyHelper");
        assert!(
            helper_count >= 2,
            "Expected MyHelper refcount >= 2 (class def + type annotation or constructor), got {}",
            helper_count
        );
    }

    #[test]
    fn test_swift_type_identifier_refcount() {
        let source = r#"
class MyManager {
func run() {}
}
func setup() {
let mgr: MyManager = MyManager()
}
"#;
        let mgr_count = occurrences(&tally(source, Language::Swift), "MyManager");
        assert!(
            mgr_count >= 2,
            "Expected MyManager refcount >= 2 (class def + type annotation or constructor), got {}",
            mgr_count
        );
    }
}