use std::fs;
use std::io;
use std::path::Path;
use tree_sitter::Node;
use super::super::cross_file_types::{CallType, ImportDef};
/// Render `path` as a forward-slash string, made relative to `root` when possible.
///
/// If `root` is `Some` and is a prefix of `path`, that prefix is stripped; otherwise
/// `path` is used as-is. The result is produced with [`Path::to_string_lossy`] (invalid
/// sequences are replaced), and every `\` is rewritten to `/`.
pub fn normalize_path(path: &Path, root: Option<&Path>) -> String {
    let relative = root
        .and_then(|base| path.strip_prefix(base).ok())
        .unwrap_or(path);
    relative.to_string_lossy().replace('\\', "/")
}
/// Read a source file into a `String`, tolerating non-UTF-8 content.
///
/// The bytes are decoded as UTF-8 when valid. Otherwise each byte is mapped to the
/// Unicode code point with the same value (U+0000..=U+00FF), i.e. the file is decoded as
/// ISO-8859-1 (Latin-1). That fallback cannot fail, so any readable file yields a string.
///
/// # Errors
///
/// Returns the I/O error from [`fs::read`], e.g. when the file does not exist or cannot
/// be read.
pub fn read_source_safely(path: &Path) -> Result<String, io::Error> {
    let bytes = fs::read(path)?;
    // `from_utf8` takes ownership of the buffer. On failure, get the buffer back from the
    // error instead of cloning the whole file up front.
    Ok(String::from_utf8(bytes)
        .unwrap_or_else(|err| err.into_bytes().into_iter().map(char::from).collect()))
}
/// Borrow the slice of `source` covered by `node` as a `&str`.
///
/// If [`Node::utf8_text`] reports that the node's bytes are not valid UTF-8, the error
/// is discarded and an empty string is returned instead.
pub fn get_node_text<'a>(node: &Node, source: &'a [u8]) -> &'a str {
    match node.utf8_text(source) {
        Ok(text) => text,
        Err(_) => "",
    }
}
/// Owned counterpart of [`get_node_text`]: copies the node's text into a new `String`
/// (empty when the node's bytes are not valid UTF-8).
pub fn get_node_text_owned(node: &Node, source: &[u8]) -> String {
    String::from(get_node_text(node, source))
}
pub fn extract_call_name(node: &Node, source: &str) -> Option<String> {
let source_bytes = source.as_bytes();
for child_name in &["function", "callee", "receiver"] {
if let Some(func_node) = node.child_by_field_name(child_name) {
return Some(get_node_text_owned(&func_node, source_bytes));
}
}
for i in 0..node.named_child_count() {
if let Some(child) = node.named_child(i) {
if child.kind().contains("argument") {
continue;
}
return Some(get_node_text_owned(&child, source_bytes));
}
}
None
}
/// Classify a call target string into a [`CallType`].
///
/// Rules are applied in this order, first match wins:
/// 1. contains `.` -> [`CallType::Attr`]
/// 2. contains `::` -> [`CallType::Static`]
/// 3. exact member of `defined_funcs` -> [`CallType::Intra`]
/// 4. otherwise -> [`CallType::Direct`]
///
/// Because `.` is tested first, a target such as `a::b.c` is classified as `Attr`.
pub fn determine_call_type(
    target: &str,
    defined_funcs: &std::collections::HashSet<String>,
) -> CallType {
    if target.contains('.') {
        CallType::Attr
    } else if target.contains("::") {
        CallType::Static
    } else if defined_funcs.contains(target) {
        CallType::Intra
    } else {
        CallType::Direct
    }
}
/// Build an [`ImportDef`] from a module name, imported names, form flag and level.
///
/// `module`, `names` (copied in order into owned strings), `is_from` and `level` are
/// stored as given. Every other field gets its empty value: `alias`, `aliases` and
/// `resolved_module` are `None`, and all remaining flags are `false`.
pub fn make_import(module: &str, names: &[&str], is_from: bool, level: u8) -> ImportDef {
    ImportDef {
        module: String::from(module),
        names: names.iter().copied().map(String::from).collect(),
        is_from,
        level,
        alias: None,
        aliases: None,
        resolved_module: None,
        is_default: false,
        is_namespace: false,
        is_mod: false,
        is_type_checking: false,
    }
}
/// Build an [`ImportDef`] for a module bound under a single alias.
///
/// `module` and `level` are stored as given, and `alias` is stored as `Some(alias)`.
/// `names` is empty and `is_from` is `false`. `aliases` and `resolved_module` are
/// `None`, and all remaining flags are `false`.
pub fn make_import_with_alias(module: &str, alias: &str, level: u8) -> ImportDef {
    ImportDef {
        module: String::from(module),
        alias: Some(String::from(alias)),
        level,
        names: Vec::new(),
        is_from: false,
        aliases: None,
        resolved_module: None,
        is_default: false,
        is_namespace: false,
        is_mod: false,
        is_type_checking: false,
    }
}
/// Pre-order (parent before children, children left to right) iterator over a node and
/// all of its descendants, driven by a single [`tree_sitter::TreeCursor`].
///
/// The cursor is created from the starting node, and tree-sitter treats that node as the
/// cursor's root. The walk therefore never climbs above the starting node or visits its
/// siblings. Once exhausted, the iterator keeps returning `None` (it is fused).
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct TreeWalker<'a> {
    cursor: tree_sitter::TreeCursor<'a>,
    /// Set after the last node has been yielded; every later `next` returns `None`.
    done: bool,
}

impl<'a> TreeWalker<'a> {
    /// Start a pre-order walk rooted at `node`; `node` itself is yielded first.
    pub fn new(node: Node<'a>) -> Self {
        Self {
            cursor: node.walk(),
            done: false,
        }
    }
}

impl<'a> Iterator for TreeWalker<'a> {
    type Item = Node<'a>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.done {
            return None;
        }
        let node = self.cursor.node();
        // Move the cursor to the pre-order successor before yielding `node`. That is the
        // first child if any, else the next sibling, else the next sibling of the nearest
        // ancestor that has one.
        if !self.cursor.goto_first_child() && !self.cursor.goto_next_sibling() {
            loop {
                if !self.cursor.goto_parent() {
                    // Back at the cursor's root with nothing left: `node` is the last one.
                    self.done = true;
                    break;
                }
                if self.cursor.goto_next_sibling() {
                    break;
                }
            }
        }
        Some(node)
    }
}

// `done` is never reset, so `None` is returned forever after exhaustion.
impl std::iter::FusedIterator for TreeWalker<'_> {}
pub fn walk_tree(node: Node<'_>) -> TreeWalker<'_> {
TreeWalker::new(node)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Parse `source` with the Python grammar; shared by the tree-based test modules.
    fn parse_python(source: &str) -> tree_sitter::Tree {
        let mut parser = tree_sitter::Parser::new();
        parser
            .set_language(&tree_sitter_python::LANGUAGE.into())
            .unwrap();
        parser.parse(source, None).unwrap()
    }

    mod normalize_path_tests {
        use super::*;

        #[test]
        fn test_normalize_forward_slashes() {
            assert_eq!(
                normalize_path(Path::new("src/main.py"), None),
                "src/main.py"
            );
        }

        #[test]
        fn test_normalize_backslashes() {
            // `normalize_path` rewrites every backslash unconditionally, so the exact
            // result is the same on every platform.
            let path = Path::new("src\\sub\\main.py");
            assert_eq!(normalize_path(path, None), "src/sub/main.py");
        }

        #[test]
        fn test_normalize_with_root() {
            let path = Path::new("/project/src/main.py");
            let root = Path::new("/project");
            assert_eq!(normalize_path(path, Some(root)), "src/main.py");
        }

        #[test]
        fn test_normalize_root_not_prefix() {
            let path = Path::new("/other/src/main.py");
            let root = Path::new("/project");
            assert_eq!(normalize_path(path, Some(root)), "/other/src/main.py");
        }

        #[test]
        fn test_normalize_root_equals_path() {
            let path = Path::new("/project");
            assert_eq!(normalize_path(path, Some(path)), "");
        }
    }

    mod read_source_tests {
        use super::*;
        use std::io::Write;
        use tempfile::NamedTempFile;

        #[test]
        fn test_read_utf8_file() {
            let mut file = NamedTempFile::new().unwrap();
            writeln!(file, "def hello(): pass").unwrap();
            let content = read_source_safely(file.path()).unwrap();
            assert!(content.contains("def hello()"));
        }

        #[test]
        fn test_read_latin1_fallback() {
            let mut file = NamedTempFile::new().unwrap();
            file.write_all(b"caf\xe9").unwrap();
            let content = read_source_safely(file.path()).unwrap();
            assert_eq!(content, "caf\u{00e9}");
        }

        #[test]
        fn test_read_empty_file() {
            let file = NamedTempFile::new().unwrap();
            assert_eq!(read_source_safely(file.path()).unwrap(), "");
        }

        #[test]
        fn test_read_nonexistent_file() {
            let result = read_source_safely(Path::new("/nonexistent/file.py"));
            assert!(result.is_err());
        }
    }

    mod call_type_tests {
        use super::*;

        #[test]
        fn test_determine_call_type_intra() {
            let mut defined = HashSet::new();
            defined.insert("local_func".to_string());
            assert_eq!(determine_call_type("local_func", &defined), CallType::Intra);
        }

        #[test]
        fn test_determine_call_type_attr() {
            let defined = HashSet::new();
            assert_eq!(determine_call_type("obj.method", &defined), CallType::Attr);
            assert_eq!(determine_call_type("a.b.c", &defined), CallType::Attr);
        }

        #[test]
        fn test_determine_call_type_static() {
            let defined = HashSet::new();
            assert_eq!(
                determine_call_type("Class::method", &defined),
                CallType::Static
            );
        }

        #[test]
        fn test_determine_call_type_direct() {
            let defined = HashSet::new();
            assert_eq!(
                determine_call_type("external_func", &defined),
                CallType::Direct
            );
        }
    }

    mod import_helper_tests {
        use super::*;

        #[test]
        fn test_make_import_simple() {
            let imp = make_import("os", &[], false, 0);
            assert_eq!(imp.module, "os");
            assert!(!imp.is_from);
            assert!(imp.names.is_empty());
            assert_eq!(imp.level, 0);
        }

        #[test]
        fn test_make_import_from() {
            let imp = make_import("os.path", &["join", "exists"], true, 0);
            assert_eq!(imp.module, "os.path");
            assert!(imp.is_from);
            assert_eq!(imp.names, vec!["join", "exists"]);
        }

        #[test]
        fn test_make_import_relative() {
            let imp = make_import("", &["utils"], true, 1);
            assert!(imp.is_from);
            assert_eq!(imp.level, 1);
            assert!(imp.is_relative());
        }

        #[test]
        fn test_make_import_with_alias() {
            let imp = make_import_with_alias("numpy", "np", 0);
            assert_eq!(imp.module, "numpy");
            assert_eq!(imp.alias, Some("np".to_string()));
            assert!(!imp.is_from);
        }
    }

    mod extract_call_name_tests {
        use super::*;

        /// Extract the callee name of the first `call` node found in `source`.
        fn first_call_name(source: &str) -> Option<String> {
            let tree = parse_python(source);
            let call = walk_tree(tree.root_node()).find(|n| n.kind() == "call")?;
            extract_call_name(&call, source)
        }

        #[test]
        fn test_extract_plain_call() {
            assert_eq!(first_call_name("foo(1)"), Some("foo".to_string()));
        }

        #[test]
        fn test_extract_method_call() {
            assert_eq!(
                first_call_name("obj.method(x)"),
                Some("obj.method".to_string())
            );
        }
    }

    mod tree_walker_tests {
        use super::*;

        #[test]
        fn test_walk_tree_simple() {
            let tree = parse_python("def hello(): pass");
            let nodes: Vec<_> = walk_tree(tree.root_node()).collect();
            assert!(!nodes.is_empty());
            let kinds: Vec<_> = nodes.iter().map(|n| n.kind()).collect();
            // Pre-order: the starting node is yielded first.
            assert_eq!(kinds.first(), Some(&"module"));
            assert!(kinds.contains(&"function_definition"));
        }

        #[test]
        fn test_walk_tree_stays_within_subtree() {
            let source = "def a(): pass\ndef b(): pass\n";
            let tree = parse_python(source);
            let first_def = tree.root_node().named_child(0).unwrap();
            let names: Vec<_> = walk_tree(first_def)
                .filter(|n| n.kind() == "identifier")
                .map(|n| get_node_text(&n, source.as_bytes()))
                .collect();
            assert_eq!(names, vec!["a"]);
        }

        #[test]
        fn test_walk_tree_stays_exhausted() {
            let tree = parse_python("x = 1");
            let mut walker = walk_tree(tree.root_node());
            assert!(walker.by_ref().count() > 0);
            assert!(walker.next().is_none());
            assert!(walker.next().is_none());
        }
    }
}