use crate::error::{Result, SpliceError};
use ropey::Rope;
use std::path::Path;
/// A Python function or class definition located in a source file.
///
/// Instances are produced by [`extract_python_symbols`]; the field notes below
/// describe how that extractor fills them in.
#[derive(Debug, Clone, PartialEq)]
pub struct PythonSymbol {
/// Text of the first `identifier` child of the definition node (the name in
/// the `def` / `class` header).
pub name: String,
/// Function or class (see [`PythonSymbolKind`]).
pub kind: PythonSymbolKind,
/// Byte offset where the definition node starts (`Node::start_byte`).
pub byte_start: usize,
/// Byte offset where the definition node ends (`Node::end_byte`, exclusive).
pub byte_end: usize,
/// 1-based line containing `byte_start`.
///
/// NOTE(review): lines are counted with `ropey`, which (with its default
/// features) also treats characters such as form feed or U+2028 as line
/// breaks, so this may disagree with Python/tree-sitter numbering — confirm
/// this is intended.
pub line_start: usize,
/// 1-based line containing `byte_end` (same counting caveat as `line_start`).
pub line_end: usize,
/// 0-based column of `byte_start`, measured in bytes (not characters) from
/// the start of its line.
pub col_start: usize,
/// 0-based column of `byte_end`, measured in bytes from the start of its line.
pub col_end: usize,
/// Declared parameter names in order, omitting any parameter named `self`.
pub parameters: Vec<String>,
/// `"module"` for top-level definitions, extended with `::ClassName` for each
/// enclosing class. Enclosing functions do not extend it.
pub module_path: String,
/// `module_path` followed by `::` and `name`.
pub fully_qualified: String,
/// `true` when the definition node has an `async` keyword child (`async def`).
pub is_async: bool,
}
/// Category of a Python definition reported by [`extract_python_symbols`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PythonSymbolKind {
    /// A `def` / `async def`. The extractor in this module reports every
    /// `function_definition` with this kind, including ones inside class bodies.
    Function,
    /// A `class` definition.
    Class,
    /// A method. Not currently produced by the extractor in this module.
    Method,
    /// A variable binding. Not currently produced by the extractor in this module.
    Variable,
}

impl PythonSymbolKind {
    /// Lowercase, human-readable label for this kind (e.g. `"function"`).
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Function => "function",
            Self::Class => "class",
            Self::Method => "method",
            Self::Variable => "variable",
        }
    }
}
/// Parses `source` as Python and returns its function and class definitions in
/// source order.
///
/// `path` is only used to label errors. Top-level definitions get the module
/// path `"module"`; definitions in a class body get `"module::ClassName"`.
///
/// # Errors
///
/// Returns [`SpliceError::Parse`] if the Python grammar cannot be loaded or the
/// parser returns no tree. If `source` is not valid UTF-8, the
/// [`std::str::from_utf8`] error is propagated through `?`.
pub fn extract_python_symbols(path: &Path, source: &[u8]) -> Result<Vec<PythonSymbol>> {
    // Every parse failure is reported against the same file.
    let parse_error = |message: String| SpliceError::Parse {
        file: path.to_path_buf(),
        message,
    };

    let language = tree_sitter_python::language();
    let mut parser = tree_sitter::Parser::new();
    parser
        .set_language(&language)
        .map_err(|e| parse_error(format!("Failed to set Python language: {:?}", e)))?;

    let tree = parser
        .parse(source, None)
        .ok_or_else(|| parse_error("Parse failed - no tree returned".to_string()))?;

    let text = std::str::from_utf8(source)?;
    let rope = Rope::from_str(text);

    let mut found = Vec::new();
    extract_symbols(tree.root_node(), source, &rope, &mut found, "module");
    Ok(found)
}
/// Walks the syntax tree rooted at `node` and appends a [`PythonSymbol`] for
/// every `function_definition` and `class_definition`, in source (pre-)order.
///
/// `module_path` is the path recorded for definitions found under `node`.
/// - A class body is walked with `module_path::ClassName`; the rest of the class
///   node (its name and superclass list) is not searched.
/// - Function bodies are walked with the enclosing path unchanged, so nested
///   functions do not extend it.
/// - Every `def` is reported as [`PythonSymbolKind::Function`], including those
///   inside class bodies.
///
/// An explicit stack is used instead of recursion so that very deep trees
/// (e.g. long chains of binary operators in generated code) cannot overflow the
/// call stack. The output order is the same as a recursive pre-order walk.
fn extract_symbols(
    node: tree_sitter::Node,
    source: &[u8],
    rope: &Rope,
    symbols: &mut Vec<PythonSymbol>,
    module_path: &str,
) {
    // Each path string is stored once. Stack entries refer to it by index, so
    // the walk does not allocate a `String` per node.
    let mut paths = vec![module_path.to_string()];
    let mut stack = vec![(node, 0usize)];
    // Scratch buffer for one node's children, reused across iterations.
    let mut children = Vec::new();

    while let Some((node, path_idx)) = stack.pop() {
        let node_kind = node.kind();
        let symbol_kind = match node_kind {
            "function_definition" => Some(PythonSymbolKind::Function),
            "class_definition" => Some(PythonSymbolKind::Class),
            _ => None,
        };

        if let Some(symbol_kind) = symbol_kind {
            // `async def` appears as an anonymous `async` child token.
            let is_async = node
                .children(&mut node.walk())
                .any(|c| c.kind() == "async");
            if let Some(symbol) =
                extract_symbol(node, source, rope, symbol_kind, &paths[path_idx], is_async)
            {
                if symbol_kind == PythonSymbolKind::Class {
                    let class_path = format!("{}::{}", paths[path_idx], symbol.name);
                    symbols.push(symbol);
                    if let Some(body) = node.child_by_field_name("body") {
                        paths.push(class_path);
                        stack.push((body, paths.len() - 1));
                    }
                    continue;
                }
                symbols.push(symbol);
            }
        }

        let mut cursor = node.walk();
        children.extend(node.children(&mut cursor));
        // Draining from the back pushes children in reverse, so they come off
        // `stack` in source order.
        while let Some(child) = children.pop() {
            // Reached only for a class whose symbol could not be built; its body
            // is skipped, matching the class handling above.
            if node_kind == "class_definition" && child.kind() == "block" {
                continue;
            }
            stack.push((child, path_idx));
        }
    }
}
/// Builds a [`PythonSymbol`] for one `function_definition` / `class_definition`
/// node.
///
/// The name is the text of the node's first `identifier` child. Returns `None`
/// when there is no such child or its text is not valid UTF-8. Lines are
/// 1-based and columns are byte offsets from the start of the line, both
/// computed through `rope`.
fn extract_symbol(
    node: tree_sitter::Node,
    source: &[u8],
    rope: &Rope,
    kind: PythonSymbolKind,
    module_path: &str,
    is_async: bool,
) -> Option<PythonSymbol> {
    let mut walker = node.walk();
    let name_node = node
        .children(&mut walker)
        .find(|child| child.kind() == "identifier")?;
    let name = name_node.utf8_text(source).ok()?.to_string();

    // Maps a byte offset to (0-based rope line, bytes since that line's start).
    let line_and_column = |byte: usize| -> (usize, usize) {
        let line = rope.char_to_line(rope.byte_to_char(byte));
        (line, byte - rope.line_to_byte(line))
    };

    let byte_start = node.start_byte();
    let byte_end = node.end_byte();
    let (first_line, col_start) = line_and_column(byte_start);
    let (last_line, col_end) = line_and_column(byte_end);

    Some(PythonSymbol {
        fully_qualified: format!("{}::{}", module_path, name),
        parameters: extract_parameters(node, source),
        module_path: module_path.to_string(),
        name,
        kind,
        byte_start,
        byte_end,
        line_start: first_line + 1,
        line_end: last_line + 1,
        col_start,
        col_end,
        is_async,
    })
}
/// Collects the declared parameter names of a `function_definition` node.
///
/// Plain (`a`), annotated (`a: T`), defaulted (`a=v`) and annotated-defaulted
/// (`a: T = v`) parameters contribute their own name only. Any parameter named
/// `self` is skipped. Other parameter node kinds (e.g. `*args`, `**kwargs`) are
/// not reported. Nodes without a `parameters` child yield an empty list.
fn extract_parameters(node: tree_sitter::Node, source: &[u8]) -> Vec<String> {
    let mut parameters = Vec::new();
    let mut cursor = node.walk();
    for params in node
        .children(&mut cursor)
        .filter(|child| child.kind() == "parameters")
    {
        let mut param_cursor = params.walk();
        for param in params.children(&mut param_cursor) {
            let name_node = match param.kind() {
                "identifier" => Some(param),
                // Only the `name` field is the parameter. Scanning every
                // identifier child would also pick up a default value that is a
                // bare name, e.g. `b` in `def f(a=b)`.
                "default_parameter" | "typed_default_parameter" => {
                    param.child_by_field_name("name")
                }
                // `typed_parameter` has no `name` field. Its first child is the
                // name (or a splat pattern), and the annotation is nested in a
                // separate `type` node.
                "typed_parameter" => param.child(0),
                _ => None,
            };
            // Non-identifier names (tuple or splat patterns) are skipped.
            if let Some(name) = name_node
                .filter(|n| n.kind() == "identifier")
                .and_then(|n| n.utf8_text(source).ok())
            {
                if name != "self" {
                    parameters.push(name.to_string());
                }
            }
        }
    }
    parameters
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the extractor on `source` as `test.py`; on failure the test fails
    /// with the error's Debug output.
    fn symbols_of(source: &[u8]) -> Vec<PythonSymbol> {
        extract_python_symbols(Path::new("test.py"), source).expect("extraction should succeed")
    }

    #[test]
    fn test_extract_simple_function_basic() {
        let symbols = symbols_of(b"def foo():\n pass\n");
        assert_eq!(symbols.len(), 1);
        assert_eq!(symbols[0].name, "foo");
        assert_eq!(symbols[0].kind, PythonSymbolKind::Function);
        assert_eq!(symbols[0].module_path, "module");
        assert_eq!(symbols[0].fully_qualified, "module::foo");
        assert!(!symbols[0].is_async);
        assert!(symbols[0].parameters.is_empty());
        assert_eq!((symbols[0].line_start, symbols[0].col_start), (1, 0));
    }

    #[test]
    fn test_class_members_use_class_module_path() {
        let symbols = symbols_of(b"class Foo:\n    async def bar(self, x):\n        pass\n");
        assert_eq!(symbols.len(), 2);

        let cls = &symbols[0];
        assert_eq!(cls.name, "Foo");
        assert_eq!(cls.kind, PythonSymbolKind::Class);
        assert_eq!(cls.fully_qualified, "module::Foo");

        let meth = &symbols[1];
        assert_eq!(meth.name, "bar");
        assert_eq!(meth.module_path, "module::Foo");
        assert_eq!(meth.fully_qualified, "module::Foo::bar");
        assert!(meth.is_async);
        assert_eq!(meth.parameters, vec!["x"]);
        assert_eq!((meth.line_start, meth.col_start), (2, 4));
    }

    #[test]
    fn test_parameter_forms_and_line_numbers() {
        let symbols = symbols_of(b"x = 1\n\ndef foo(a, b: int, c=2):\n    return a\n");
        assert_eq!(symbols.len(), 1);
        assert_eq!(symbols[0].parameters, vec!["a", "b", "c"]);
        assert_eq!(symbols[0].line_start, 3);
    }
}