use std::collections::HashSet;
use std::fs;
use std::path::{Path, PathBuf};
use anyhow::Result;
use clap::Args;
use tree_sitter::{Node, Parser};
use tree_sitter_python::LANGUAGE as PYTHON_LANGUAGE;
use super::error::{RemainingError, RemainingResult};
use super::types::{DefinitionResult, Location, SymbolInfo, SymbolKind};
use crate::output::OutputWriter;
use tldr_core::Language;
/// Maximum number of import hops `resolve_cross_file` follows from the starting
/// file before giving up on a cross-file lookup.
const MAX_IMPORT_DEPTH: usize = 10;
/// Names that `is_builtin` treats as Python built-ins (functions and builtin
/// types). Symbols in this list are reported as coming from the `builtins` module
/// instead of being resolved to a source location.
const PYTHON_BUILTINS: &[&str] = &[
    "abs",
    "aiter",
    "all",
    "any",
    "anext",
    "ascii",
    "bin",
    "bool",
    "breakpoint",
    "bytearray",
    "bytes",
    "callable",
    "chr",
    "classmethod",
    "compile",
    "complex",
    "delattr",
    "dict",
    "dir",
    "divmod",
    "enumerate",
    "eval",
    "exec",
    "filter",
    "float",
    "format",
    "frozenset",
    "getattr",
    "globals",
    "hasattr",
    "hash",
    "help",
    "hex",
    "id",
    "input",
    "int",
    "isinstance",
    "issubclass",
    "iter",
    "len",
    "list",
    "locals",
    "map",
    "max",
    "memoryview",
    "min",
    "next",
    "object",
    "oct",
    "open",
    "ord",
    "pow",
    "print",
    "property",
    "range",
    "repr",
    "reversed",
    "round",
    "set",
    "setattr",
    "slice",
    "sorted",
    "staticmethod",
    "str",
    "sum",
    "super",
    "tuple",
    "type",
    "vars",
    "zip",
    "__import__",
];
/// Remembers which `(file, symbol)` pairs have already been explored during
/// cross-file resolution, so circular imports cannot recurse forever.
#[derive(Default)]
pub struct DefinitionCycleDetector {
    visited: HashSet<(PathBuf, String)>,
}

impl DefinitionCycleDetector {
    /// Creates a detector that has not seen any `(file, symbol)` pair yet.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records a visit to `symbol` in `file`.
    ///
    /// Returns `true` when this exact pair was already recorded (a cycle), and
    /// `false` the first time the pair is seen.
    pub fn visit(&mut self, file: &Path, symbol: &str) -> bool {
        let pair = (PathBuf::from(file), String::from(symbol));
        let newly_recorded = self.visited.insert(pair);
        !newly_recorded
    }
}
// CLI arguments for the `definition` command.
//
// NOTE: plain `//` comments are used deliberately here: clap's derive turns `///`
// doc comments into `--help` text, and documenting the fields must not change the CLI.
#[derive(Debug, Args)]
pub struct DefinitionArgs {
    // Positional: source file for position-based lookup (required unless `--symbol` is used).
    pub file: Option<PathBuf>,
    // Positional: 1-based line of the symbol (converted to 0-based before querying tree-sitter).
    pub line: Option<u32>,
    // Positional: column of the symbol, used as-is (0-based) in a tree-sitter `Point`.
    pub column: Option<u32>,
    // `--symbol NAME`: look up by name instead of by position; requires `--file`.
    #[arg(long)]
    pub symbol: Option<String>,
    // `--file PATH`: file to start from when `--symbol` is given. The id is
    // `target_file` so it does not collide with the positional `file` argument.
    #[arg(long = "file", name = "target_file")]
    pub target_file: Option<PathBuf>,
    // `--project DIR`: project root used to resolve imports across files.
    #[arg(long)]
    pub project: Option<PathBuf>,
    // `--output PATH` / `-O PATH`: write the result to this file instead of
    // emitting it through the output writer.
    #[arg(long, short = 'O')]
    pub output: Option<PathBuf>,
}
impl DefinitionArgs {
    /// Runs the `definition` command and writes the result.
    ///
    /// Two lookup modes are supported:
    /// - `--symbol NAME --file PATH`: resolve `NAME` by name starting from `PATH`.
    ///   Lookup errors are propagated to the caller.
    /// - positional `FILE LINE COLUMN`: resolve the identifier at that position.
    ///   This mode is best-effort: if the lookup fails, the reason is reported via
    ///   `writer.progress` and a placeholder `<unknown at ...>` result is emitted
    ///   instead of failing the command.
    ///
    /// The result goes to `--output` when given (text for `OutputFormat::Text`,
    /// pretty JSON otherwise); without `--output` it is written through the
    /// `OutputWriter`.
    ///
    /// # Errors
    /// Returns an error when required arguments are missing, when a `--symbol`
    /// lookup fails, or when serializing or writing the output fails.
    pub fn run(
        &self,
        format: crate::output::OutputFormat,
        quiet: bool,
        lang: Option<Language>,
    ) -> Result<()> {
        let writer = OutputWriter::new(format, quiet);
        // The lowercased Debug name of `Language` is used as the hint string that
        // `detect_language` matches against; no language means auto-detect.
        let lang_hint = match lang {
            Some(l) => format!("{:?}", l).to_lowercase(),
            None => "auto".to_string(),
        };
        let result = if let Some(ref symbol_name) = self.symbol {
            let file = self.target_file.as_ref().ok_or_else(|| {
                RemainingError::invalid_argument("--file is required with --symbol")
            })?;
            writer.progress(&format!(
                "Finding definition of '{}' in {}...",
                symbol_name,
                file.display()
            ));
            find_definition_by_name(symbol_name, file, self.project.as_deref(), &lang_hint)?
        } else {
            let file = self
                .file
                .as_ref()
                .ok_or_else(|| RemainingError::invalid_argument("file argument is required"))?;
            let line = self
                .line
                .ok_or_else(|| RemainingError::invalid_argument("line argument is required"))?;
            let column = self
                .column
                .ok_or_else(|| RemainingError::invalid_argument("column argument is required"))?;
            writer.progress(&format!(
                "Finding definition at {}:{}:{}...",
                file.display(),
                line,
                column
            ));
            match find_definition_by_position(
                file,
                line,
                column,
                self.project.as_deref(),
                &lang_hint,
            ) {
                Ok(result) => result,
                Err(err) => {
                    // Best-effort mode: still emit a placeholder result, but do not
                    // silently discard why the lookup failed.
                    writer.progress(&format!(
                        "Definition lookup failed at {}:{}:{}: {}",
                        file.display(),
                        line,
                        column,
                        err
                    ));
                    DefinitionResult {
                        symbol: SymbolInfo {
                            name: format!("<unknown at {}:{}:{}>", file.display(), line, column),
                            kind: SymbolKind::Variable,
                            location: Some(Location::with_column(
                                file.display().to_string(),
                                line,
                                column,
                            )),
                            type_annotation: None,
                            docstring: None,
                            is_builtin: false,
                            module: None,
                        },
                        definition: None,
                        type_definition: None,
                    }
                }
            }
        };
        let use_text = format == crate::output::OutputFormat::Text;
        if let Some(ref output_path) = self.output {
            if use_text {
                let text = format_definition_text(&result);
                fs::write(output_path, text)?;
            } else {
                let json = serde_json::to_string_pretty(&result)?;
                fs::write(output_path, json)?;
            }
        } else if use_text {
            let text = format_definition_text(&result);
            writer.write_text(&text)?;
        } else {
            writer.write(&result)?;
        }
        Ok(())
    }
}
/// Finds the definition of `symbol`, starting the search in `file`.
///
/// Resolution order:
/// 1. a definition (`def`, `class`, or assignment) found in `file` itself;
/// 2. a Python builtin of the same name, reported with module `builtins`;
/// 3. when `project` is given, definitions reachable through `file`'s imports.
///
/// A definition in the file is checked before builtins because Python resolves
/// module globals before builtins, so e.g. a local `def open(...)` shadows the
/// builtin `open`. A *method* with a builtin's name does not shadow it (methods are
/// reached through attribute access), so in that case the builtin is returned.
/// Note the in-file search is not scope-aware: the first match in source order wins.
///
/// # Errors
/// `file_not_found` if `file` does not exist, `unsupported_language` for anything
/// other than Python, I/O and parse errors, and `symbol_not_found` if every stage
/// above comes up empty.
pub fn find_definition_by_name(
    symbol: &str,
    file: &Path,
    project: Option<&Path>,
    lang_hint: &str,
) -> RemainingResult<DefinitionResult> {
    if !file.exists() {
        return Err(RemainingError::file_not_found(file));
    }
    let language = detect_language(file, lang_hint)?;
    if language != Language::Python {
        return Err(RemainingError::unsupported_language(format!(
            "{:?}",
            language
        )));
    }
    let source = fs::read_to_string(file).map_err(RemainingError::Io)?;
    let builtin = is_builtin(symbol, &language);

    if let Some(local) = find_symbol_in_file(symbol, file, &source)? {
        let method_named_like_builtin =
            builtin && matches!(local.symbol.kind, SymbolKind::Method);
        if !method_named_like_builtin {
            return Ok(local);
        }
    }

    if builtin {
        return Ok(DefinitionResult {
            symbol: SymbolInfo {
                name: symbol.to_string(),
                kind: SymbolKind::Function,
                location: None,
                type_annotation: None,
                docstring: None,
                is_builtin: true,
                module: Some("builtins".to_string()),
            },
            definition: None,
            type_definition: None,
        });
    }

    if let Some(project_root) = project {
        let mut detector = DefinitionCycleDetector::new();
        if let Some(result) = resolve_cross_file(symbol, file, project_root, &mut detector, 0)? {
            return Ok(result);
        }
    }
    Err(RemainingError::symbol_not_found(symbol, file))
}
/// Finds the definition of the identifier located at `line`:`column` in `file`.
///
/// `line` is 1-based and `column` is 0-based (see `find_symbol_at_position`). The
/// identifier text under the cursor is extracted first and then resolved by name,
/// exactly like `find_definition_by_name` (including cross-file lookup when
/// `project` is given).
///
/// # Errors
/// `file_not_found` if `file` does not exist, `unsupported_language` for anything
/// other than Python, I/O errors reading the file, and whatever the position or
/// name lookup reports.
pub fn find_definition_by_position(
    file: &Path,
    line: u32,
    column: u32,
    project: Option<&Path>,
    lang_hint: &str,
) -> RemainingResult<DefinitionResult> {
    if !file.exists() {
        return Err(RemainingError::file_not_found(file));
    }

    match detect_language(file, lang_hint)? {
        Language::Python => {}
        other => {
            return Err(RemainingError::unsupported_language(format!("{:?}", other)));
        }
    }

    let symbol_name = fs::read_to_string(file)
        .map_err(RemainingError::Io)
        .and_then(|text| find_symbol_at_position(&text, line, column))?;

    find_definition_by_name(&symbol_name, file, project, lang_hint)
}
/// Returns the text of the token at `line`:`column` in Python `source`.
///
/// `line` is 1-based; `column` is 0-based and is passed unchanged into a
/// tree-sitter `Point`. Only a leaf token (normally an `identifier`) is accepted.
/// If the position falls between tokens, tree-sitter returns an enclosing
/// compound node (e.g. a whole call or block). Its text is not a symbol name, so
/// that case is reported as "No symbol found" rather than returned.
///
/// # Errors
/// `parse_error` if the grammar cannot be loaded, parsing fails, or the token is
/// not valid UTF-8; `invalid_argument` if no token lies at the position.
fn find_symbol_at_position(source: &str, line: u32, column: u32) -> RemainingResult<String> {
    let mut parser = Parser::new();
    parser
        .set_language(&PYTHON_LANGUAGE.into())
        .map_err(|e| RemainingError::parse_error(PathBuf::from("<input>"), e.to_string()))?;
    let tree = parser.parse(source, None).ok_or_else(|| {
        RemainingError::parse_error(PathBuf::from("<input>"), "Failed to parse".to_string())
    })?;

    let point = tree_sitter::Point::new(line.saturating_sub(1) as usize, column as usize);
    let no_symbol = || {
        RemainingError::invalid_argument(format!(
            "No symbol found at line {}, column {}",
            line, column
        ))
    };

    let node = tree
        .root_node()
        .descendant_for_point_range(point, point)
        .ok_or_else(no_symbol)?;

    // Identifiers (and keyword tokens such as `print`) are leaves; a node with
    // children means the cursor sits in whitespace/punctuation between tokens.
    if node.child_count() > 0 {
        return Err(no_symbol());
    }

    node.utf8_text(source.as_bytes())
        .map(str::to_string)
        .map_err(|_| {
            RemainingError::parse_error(PathBuf::from("<input>"), "Invalid UTF-8".to_string())
        })
}
/// Parses `source` (the contents of `file`) as Python and searches it for a
/// definition of `symbol`.
///
/// Returns `Ok(None)` when nothing matches. On a hit, both `symbol.location` and
/// `definition` point at the defining name; type annotations and docstrings are
/// not extracted.
///
/// # Errors
/// A `parse_error` tagged with `file` if the grammar cannot be loaded or the
/// parse fails.
fn find_symbol_in_file(
    symbol: &str,
    file: &Path,
    source: &str,
) -> RemainingResult<Option<DefinitionResult>> {
    let parse_failure =
        |message: String| RemainingError::parse_error(file.to_path_buf(), message);

    let mut parser = Parser::new();
    parser
        .set_language(&PYTHON_LANGUAGE.into())
        .map_err(|e| parse_failure(e.to_string()))?;
    let tree = parser
        .parse(source, None)
        .ok_or_else(|| parse_failure("Failed to parse".to_string()))?;

    let hit = find_definition_recursive(tree.root_node(), source, symbol, file);
    Ok(hit.map(|(kind, location)| DefinitionResult {
        symbol: SymbolInfo {
            name: symbol.to_string(),
            kind,
            location: Some(location.clone()),
            type_annotation: None,
            docstring: None,
            is_builtin: false,
            module: None,
        },
        definition: Some(location),
        type_definition: None,
    }))
}
/// Depth-first, pre-order search of the subtree rooted at `node` for the first
/// definition named `target_name`.
///
/// Recognized definitions:
/// - `def` → `SymbolKind::Method` when `is_inside_class` says so, else `Function`;
/// - `class` → `SymbolKind::Class`;
/// - assignment targets → `SymbolKind::Variable`. Plain names are covered, and so
///   are names bound by tuple/list unpacking (`a, b = ...`, `[x, *rest] = ...`).
///   Attribute/subscript targets (`self.x = ...`, `d[k] = ...`) bind no new name
///   and are ignored.
///
/// The search is not scope-aware: the first match in source order wins. The
/// returned location points at the name token (1-based line, 0-based column).
fn find_definition_recursive(
    node: Node,
    source: &str,
    target_name: &str,
    file: &Path,
) -> Option<(SymbolKind, Location)> {
    match node.kind() {
        kind @ ("function_definition" | "class_definition") => {
            if let Some(name_node) = node.child_by_field_name("name") {
                if node_text_is(name_node, source, target_name) {
                    let symbol_kind = if kind == "class_definition" {
                        SymbolKind::Class
                    } else if is_inside_class(node) {
                        SymbolKind::Method
                    } else {
                        SymbolKind::Function
                    };
                    return Some((symbol_kind, node_location(name_node, file)));
                }
            }
        }
        "assignment" => {
            let bound = node
                .child_by_field_name("left")
                .and_then(|left| find_assigned_identifier(left, source, target_name));
            if let Some(name_node) = bound {
                return Some((SymbolKind::Variable, node_location(name_node, file)));
            }
        }
        _ => {}
    }
    for i in 0..node.child_count() {
        if let Some(child) = node.child(i) {
            if let Some(found) = find_definition_recursive(child, source, target_name, file) {
                return Some(found);
            }
        }
    }
    None
}

/// Returns `true` if the source text of `node` is exactly `name`.
fn node_text_is(node: Node, source: &str, name: &str) -> bool {
    matches!(node.utf8_text(source.as_bytes()), Ok(text) if text == name)
}

/// Location of `node`'s start: 1-based line, 0-based column.
fn node_location(node: Node, file: &Path) -> Location {
    let start = node.start_position();
    Location::with_column(
        file.display().to_string(),
        start.row as u32 + 1,
        start.column as u32,
    )
}

/// Finds the identifier named `name` bound by an assignment target, descending
/// into tuple/list unpacking patterns (including starred `*name`).
fn find_assigned_identifier<'tree>(
    target: Node<'tree>,
    source: &str,
    name: &str,
) -> Option<Node<'tree>> {
    match target.kind() {
        "identifier" => node_text_is(target, source, name).then_some(target),
        "pattern_list" | "tuple_pattern" | "list_pattern" | "list_splat_pattern" => {
            let mut cursor = target.walk();
            // Bound to a local so the iterator (which borrows `cursor`) is dropped
            // before `cursor` goes out of scope.
            let found = target
                .named_children(&mut cursor)
                .find_map(|child| find_assigned_identifier(child, source, name));
            found
        }
        _ => None,
    }
}
/// Returns `true` when `node` is defined directly in a class body, i.e. its
/// nearest enclosing function-or-class scope is a `class_definition`.
///
/// A function nested inside a method is scoped to that method, not to the class,
/// so the walk stops with `false` at the first enclosing `function_definition`.
/// Intermediate nodes such as `block` or `decorated_definition` are skipped.
fn is_inside_class(node: Node) -> bool {
    let mut current = node.parent();
    while let Some(ancestor) = current {
        match ancestor.kind() {
            "class_definition" => return true,
            "function_definition" => return false,
            _ => current = ancestor.parent(),
        }
    }
    false
}
/// Follows the imports of `current_file` looking for a definition of `symbol`.
///
/// An import is followed when it names `symbol` explicitly, or when it carries no
/// name list (`import M` or `from M import *`). Each resolved module is searched
/// directly first, and then its own imports are followed recursively.
///
/// Recursion stops at `MAX_IMPORT_DEPTH` or when `detector` reports that this
/// `(file, symbol)` pair was already visited. Returns `Ok(None)` when nothing is
/// found; I/O and parse errors from reached files are propagated.
fn resolve_cross_file(
    symbol: &str,
    current_file: &Path,
    project_root: &Path,
    detector: &mut DefinitionCycleDetector,
    depth: usize,
) -> RemainingResult<Option<DefinitionResult>> {
    // `||` short-circuits, so a visit is only recorded when under the depth limit.
    if depth >= MAX_IMPORT_DEPTH || detector.visit(current_file, symbol) {
        return Ok(None);
    }

    let source = fs::read_to_string(current_file).map_err(RemainingError::Io)?;
    for (module_path, imported_names) in extract_imports(&source) {
        let brings_in_symbol =
            imported_names.is_empty() || imported_names.iter().any(|name| name == symbol);
        if !brings_in_symbol {
            continue;
        }

        let module_file = match resolve_module_path(&module_path, current_file, project_root) {
            Some(path) if path.exists() => path,
            _ => continue,
        };

        let module_source = fs::read_to_string(&module_file).map_err(RemainingError::Io)?;
        if let Some(found) = find_symbol_in_file(symbol, &module_file, &module_source)? {
            return Ok(Some(found));
        }

        let transitive =
            resolve_cross_file(symbol, &module_file, project_root, detector, depth + 1)?;
        if transitive.is_some() {
            return Ok(transitive);
        }
    }
    Ok(None)
}
/// Extracts Python import statements from `source` with a line-based scan.
///
/// Returns `(module, names)` pairs in source order:
/// - `from M import a, b as c` → `("M", ["a", "b"])`. Aliases map back to the
///   original name, and `*` is dropped, so a star import yields an empty list.
/// - `import M` / `import M as X` → `("M", [])`; `import A, B` yields one pair per
///   module.
///
/// Trailing `#` comments are ignored. For `from` imports, parenthesized name
/// lists spanning several lines and backslash continuations are followed. This is
/// a heuristic scan, not a parser: import-looking lines inside string literals are
/// matched too.
fn extract_imports(source: &str) -> Vec<(String, Vec<String>)> {
    // Drops a trailing `# ...` comment (import lines contain no string literals).
    fn strip_comment(line: &str) -> &str {
        line.split('#').next().unwrap_or(line)
    }
    // Reduces `name` or `name as alias` to `name`.
    fn base_name(item: &str) -> &str {
        item.trim().split(" as ").next().unwrap_or("").trim()
    }

    let mut imports = Vec::new();
    let mut lines = source.lines();
    while let Some(raw) = lines.next() {
        let line = strip_comment(raw).trim();
        if let Some(rest) = line.strip_prefix("from ") {
            // Searching after the `from ` prefix also avoids slicing `line[5..4]`
            // on malformed input such as `from import x`.
            if let Some(import_idx) = rest.find(" import ") {
                let module = rest[..import_idx].trim();
                let mut names = rest[import_idx + " import ".len()..].trim().to_string();
                // Pull in continuation lines: an unclosed `(` list or a trailing `\`.
                loop {
                    let open_paren = names.starts_with('(') && !names.contains(')');
                    let continued = names.ends_with('\\');
                    if !open_paren && !continued {
                        break;
                    }
                    match lines.next() {
                        Some(next) => {
                            if continued {
                                names.pop();
                            }
                            names.push(' ');
                            names.push_str(strip_comment(next).trim());
                        }
                        None => break,
                    }
                }
                let inner = names.trim().trim_start_matches('(');
                let inner = inner.split(')').next().unwrap_or(inner);
                let imported: Vec<String> = inner
                    .split(',')
                    .map(base_name)
                    .filter(|name| !name.is_empty() && *name != "*")
                    .map(str::to_string)
                    .collect();
                imports.push((module.to_string(), imported));
            }
        } else if let Some(modules) = line.strip_prefix("import ") {
            imports.extend(
                modules
                    .split(',')
                    .map(base_name)
                    .filter(|module| !module.is_empty())
                    .map(|module| (module.to_string(), Vec::new())),
            );
        }
    }
    imports
}
/// Maps a Python module path taken from an import statement to a file on disk.
///
/// - Relative modules (leading dots): the first dot is the directory containing
///   `current_file`, and each additional dot climbs one parent. A bare-dots module
///   (`from . import x`) resolves to that package's `__init__.py`.
/// - Absolute modules are looked up relative to `current_file`'s directory first,
///   then relative to `project_root`.
///
/// In each base directory `<base>/a/b.py` is preferred over `<base>/a/b/__init__.py`.
/// Returns `None` if no candidate exists, or if a relative import climbs above the
/// filesystem root.
fn resolve_module_path(module: &str, current_file: &Path, project_root: &Path) -> Option<PathBuf> {
    // Tries `<dir>/<rel>.py`, then `<dir>/<rel>/__init__.py`.
    fn probe(dir: &Path, rel: &str) -> Option<PathBuf> {
        let as_module = dir.join(rel).with_extension("py");
        if as_module.exists() {
            return Some(as_module);
        }
        let as_package = dir.join(rel).join("__init__.py");
        as_package.exists().then_some(as_package)
    }

    let current_dir = current_file.parent()?;
    let remainder = module.trim_start_matches('.');
    let levels = module.len() - remainder.len();

    if levels == 0 {
        let rel_path = module.replace('.', "/");
        return probe(current_dir, &rel_path).or_else(|| probe(project_root, &rel_path));
    }

    let mut base = current_dir;
    for _ in 1..levels {
        base = base.parent()?;
    }

    if remainder.is_empty() {
        let init = base.join("__init__.py");
        return init.exists().then_some(init);
    }
    probe(base, &remainder.replace('.', "/"))
}
/// Reports whether `name` is a built-in of `language`.
///
/// Only Python has a builtin table (`PYTHON_BUILTINS`); every other language
/// always answers `false`.
pub fn is_builtin(name: &str, language: &Language) -> bool {
    matches!(language, Language::Python) && PYTHON_BUILTINS.iter().any(|&builtin| builtin == name)
}
/// Determines the language of `file`.
///
/// An explicit `hint` (case-insensitive, e.g. `python`/`py`, `ts`, `golang`) takes
/// precedence. The hint `auto`, in any letter case, falls back to the file
/// extension, which is also matched case-insensitively, so `FOO.PY` is Python. Common extension variants are
/// recognized: `.pyi`/`.pyw` for Python, `.mts`/`.cts` for TypeScript, and
/// `.mjs`/`.cjs` for JavaScript.
///
/// # Errors
/// `unsupported_language` carrying the unrecognized hint or extension as given.
fn detect_language(file: &Path, hint: &str) -> RemainingResult<Language> {
    let hint_lower = hint.to_lowercase();
    if hint_lower != "auto" {
        return match hint_lower.as_str() {
            "python" | "py" => Ok(Language::Python),
            "typescript" | "ts" => Ok(Language::TypeScript),
            "javascript" | "js" => Ok(Language::JavaScript),
            "rust" | "rs" => Ok(Language::Rust),
            "go" | "golang" => Ok(Language::Go),
            _ => Err(RemainingError::unsupported_language(hint)),
        };
    }
    let ext = file.extension().and_then(|e| e.to_str()).unwrap_or("");
    match ext.to_ascii_lowercase().as_str() {
        "py" | "pyi" | "pyw" => Ok(Language::Python),
        "ts" | "tsx" | "mts" | "cts" => Ok(Language::TypeScript),
        "js" | "jsx" | "mjs" | "cjs" => Ok(Language::JavaScript),
        "rs" => Ok(Language::Rust),
        "go" => Ok(Language::Go),
        _ => Err(RemainingError::unsupported_language(ext)),
    }
}
fn format_definition_text(result: &DefinitionResult) -> String {
let mut output = String::new();
output.push_str("=== Definition Result ===\n\n");
output.push_str(&format!("Symbol: {}\n", result.symbol.name));
output.push_str(&format!("Kind: {:?}\n", result.symbol.kind));
if result.symbol.is_builtin {
output.push_str("Type: Built-in\n");
if let Some(ref module) = result.symbol.module {
output.push_str(&format!("Module: {}\n", module));
}
} else if let Some(ref location) = result.definition {
output.push_str("\nDefinition Location:\n");
output.push_str(&format!(" File: {}\n", location.file));
output.push_str(&format!(" Line: {}\n", location.line));
if location.column > 0 {
output.push_str(&format!(" Column: {}\n", location.column));
}
} else {
output.push_str("\nDefinition: Not found\n");
}
if let Some(ref type_def) = result.type_definition {
output.push_str("\nType Definition:\n");
output.push_str(&format!(" File: {}\n", type_def.file));
output.push_str(&format!(" Line: {}\n", type_def.line));
}
if let Some(ref docstring) = result.symbol.docstring {
output.push_str(&format!("\nDocstring:\n {}\n", docstring));
}
output
}
// Unit tests for builtin detection, cycle detection, language detection,
// import extraction, module-path resolution and cross-file lookup.
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_is_builtin_python() {
        assert!(is_builtin("len", &Language::Python));
        assert!(is_builtin("print", &Language::Python));
        assert!(is_builtin("range", &Language::Python));
        assert!(!is_builtin("my_func", &Language::Python));
    }
    #[test]
    fn test_cycle_detector() {
        let mut detector = DefinitionCycleDetector::new();
        assert!(!detector.visit(Path::new("file.py"), "symbol"));
        assert!(detector.visit(Path::new("file.py"), "symbol"));
        assert!(!detector.visit(Path::new("other.py"), "symbol"));
    }
    #[test]
    fn test_detect_language() {
        assert_eq!(
            detect_language(Path::new("test.py"), "auto").unwrap(),
            Language::Python
        );
    }
    #[test]
    fn test_detect_language_with_hint() {
        assert_eq!(
            detect_language(Path::new("test.txt"), "python").unwrap(),
            Language::Python
        );
    }
    #[test]
    fn test_extract_imports() {
        let source = r#"
from os import path, getcwd
from sys import argv
import json
import re as regex
"#;
        let imports = extract_imports(source);
        assert_eq!(imports.len(), 4);
        assert_eq!(imports[0].0, "os");
        assert!(imports[0].1.contains(&"path".to_string()));
        assert!(imports[0].1.contains(&"getcwd".to_string()));
        assert_eq!(imports[1].0, "sys");
        assert!(imports[1].1.contains(&"argv".to_string()));
        assert_eq!(imports[2].0, "json");
        assert_eq!(imports[3].0, "re");
    }
    #[test]
    fn test_extract_imports_relative() {
        let source = r#"
from .utils import echo, make_str
from .exceptions import Abort
from ._utils import FLAG_NEEDS_VALUE
from . import types
"#;
        let imports = extract_imports(source);
        assert_eq!(imports.len(), 4);
        assert_eq!(imports[0].0, ".utils");
        assert!(imports[0].1.contains(&"echo".to_string()));
        assert!(imports[0].1.contains(&"make_str".to_string()));
        assert_eq!(imports[1].0, ".exceptions");
        assert!(imports[1].1.contains(&"Abort".to_string()));
        assert_eq!(imports[2].0, "._utils");
        assert!(imports[2].1.contains(&"FLAG_NEEDS_VALUE".to_string()));
        assert_eq!(imports[3].0, ".");
        assert!(imports[3].1.contains(&"types".to_string()));
    }
    #[test]
    fn test_resolve_module_path_relative_import() {
        let dir = tempfile::tempdir().unwrap();
        let pkg = dir.path().join("mypkg");
        fs::create_dir_all(&pkg).unwrap();
        fs::write(pkg.join("__init__.py"), "").unwrap();
        fs::write(pkg.join("core.py"), "from .utils import helper\n").unwrap();
        fs::write(pkg.join("utils.py"), "def helper(): pass\n").unwrap();
        let current_file = pkg.join("core.py");
        let project_root = dir.path();
        let resolved = resolve_module_path(".utils", &current_file, project_root);
        assert!(
            resolved.is_some(),
            "resolve_module_path should find .utils relative to core.py"
        );
        assert_eq!(
            resolved.unwrap(),
            pkg.join("utils.py"),
            "Should resolve to sibling utils.py"
        );
    }
    #[test]
    fn test_resolve_module_path_relative_import_subpackage() {
        let dir = tempfile::tempdir().unwrap();
        let pkg = dir.path().join("mypkg");
        let sub = pkg.join("sub");
        fs::create_dir_all(&sub).unwrap();
        fs::write(pkg.join("__init__.py"), "").unwrap();
        fs::write(sub.join("__init__.py"), "").unwrap();
        fs::write(pkg.join("core.py"), "").unwrap();
        fs::write(sub.join("helpers.py"), "def helper(): pass\n").unwrap();
        let current_file = pkg.join("core.py");
        let project_root = dir.path();
        let resolved = resolve_module_path(".sub.helpers", &current_file, project_root);
        assert!(
            resolved.is_some(),
            "resolve_module_path should find .sub.helpers relative to core.py"
        );
        assert_eq!(
            resolved.unwrap(),
            sub.join("helpers.py"),
            "Should resolve to sub/helpers.py"
        );
    }
    #[test]
    fn test_cross_file_definition_via_relative_import() {
        let dir = tempfile::tempdir().unwrap();
        let pkg = dir.path().join("mypkg");
        fs::create_dir_all(&pkg).unwrap();
        fs::write(pkg.join("__init__.py"), "").unwrap();
        fs::write(
            pkg.join("core.py"),
            "from .utils import echo\n\ndef main():\n echo('hello')\n",
        )
        .unwrap();
        fs::write(pkg.join("utils.py"), "def echo(msg):\n print(msg)\n").unwrap();
        let result =
            find_definition_by_name("echo", &pkg.join("core.py"), Some(dir.path()), "python");
        assert!(
            result.is_ok(),
            "Should find echo via cross-file resolution: {:?}",
            result.err()
        );
        let result = result.unwrap();
        assert_eq!(result.symbol.name, "echo");
        assert_eq!(result.symbol.kind, SymbolKind::Function);
        assert!(
            result.definition.is_some(),
            "Should have a definition location"
        );
        let def_loc = result.definition.unwrap();
        assert!(
            def_loc.file.contains("utils.py"),
            "Definition should be in utils.py, got: {}",
            def_loc.file
        );
        assert_eq!(def_loc.line, 1, "echo is defined on line 1 of utils.py");
    }
}