use crate::indexer::language::Language;
/// A contiguous span of source text as produced by `AstChunker`, tagged with the kind of
/// definition it holds.
#[derive(Debug, Clone)]
pub struct SemanticUnit {
    /// The span's source text. For continuation chunks produced by `split_large_unit`
    /// this is additionally prefixed by the unit's signature lines, which are not counted
    /// in `start_line..=end_line`.
    pub content: String,
    /// First line of the span, 1-based.
    pub start_line: usize,
    /// Last line of the span, 1-based and inclusive.
    pub end_line: usize,
    /// Category of the definition (`Other` for text between definitions).
    pub unit_type: SemanticUnitType,
    /// Identifier declared by the definition, when one could be extracted.
    pub symbol_name: Option<String>,
}
/// Category assigned to a `SemanticUnit`, derived from its tree-sitter node kind.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
// `Import` is never produced by `from_node_kind` in this file; the allow presumably keeps
// it (and any other unused variant) available for callers or future grammars — TODO confirm.
#[allow(dead_code)]
pub enum SemanticUnitType {
    /// Free function; Rust macro definitions and JS arrow functions are also reported here.
    Function,
    /// Class definition.
    Class,
    /// Method or constructor attached to a type.
    Method,
    /// Struct definition.
    Struct,
    /// Enum definition.
    Enum,
    /// Interface definition.
    Interface,
    /// Module or namespace.
    Module,
    /// Rust `impl` block.
    Impl,
    /// Rust trait definition.
    Trait,
    /// Import statement.
    Import,
    /// Constant or static item.
    Constant,
    /// Variable declaration.
    Variable,
    /// Type alias or type declaration.
    Type,
    /// Anything not mapped to a more specific category, including gap text between units.
    Other,
}
impl SemanticUnitType {
pub fn from_node_kind(kind: &str, language: Language) -> Self {
match language {
Language::Rust => match kind {
"function_item" => SemanticUnitType::Function,
"impl_item" => SemanticUnitType::Impl,
"struct_item" => SemanticUnitType::Struct,
"enum_item" => SemanticUnitType::Enum,
"mod_item" => SemanticUnitType::Module,
"trait_item" => SemanticUnitType::Trait,
"type_item" => SemanticUnitType::Type,
"const_item" | "static_item" => SemanticUnitType::Constant,
"macro_definition" => SemanticUnitType::Function,
_ => SemanticUnitType::Other,
},
Language::TypeScript | Language::Tsx | Language::JavaScript | Language::Jsx => {
match kind {
"function_declaration" => SemanticUnitType::Function,
"class_declaration" => SemanticUnitType::Class,
"method_definition" => SemanticUnitType::Method,
"arrow_function" => SemanticUnitType::Function,
"interface_declaration" => SemanticUnitType::Interface,
"type_alias_declaration" => SemanticUnitType::Type,
"enum_declaration" => SemanticUnitType::Enum,
"export_statement" => SemanticUnitType::Other,
_ => SemanticUnitType::Other,
}
}
Language::Python => match kind {
"function_definition" => SemanticUnitType::Function,
"class_definition" => SemanticUnitType::Class,
"decorated_definition" => SemanticUnitType::Function,
_ => SemanticUnitType::Other,
},
Language::Go => match kind {
"function_declaration" => SemanticUnitType::Function,
"method_declaration" => SemanticUnitType::Method,
"type_declaration" => SemanticUnitType::Type,
"const_declaration" => SemanticUnitType::Constant,
"var_declaration" => SemanticUnitType::Variable,
_ => SemanticUnitType::Other,
},
Language::Java => match kind {
"class_declaration" => SemanticUnitType::Class,
"method_declaration" => SemanticUnitType::Method,
"interface_declaration" => SemanticUnitType::Interface,
"enum_declaration" => SemanticUnitType::Enum,
"constructor_declaration" => SemanticUnitType::Method,
_ => SemanticUnitType::Other,
},
Language::C => match kind {
"function_definition" => SemanticUnitType::Function,
"struct_specifier" => SemanticUnitType::Struct,
"enum_specifier" => SemanticUnitType::Enum,
"type_definition" => SemanticUnitType::Type,
_ => SemanticUnitType::Other,
},
Language::Cpp => match kind {
"function_definition" => SemanticUnitType::Function,
"class_specifier" => SemanticUnitType::Class,
"struct_specifier" => SemanticUnitType::Struct,
"enum_specifier" => SemanticUnitType::Enum,
"namespace_definition" => SemanticUnitType::Module,
"template_declaration" => SemanticUnitType::Other,
_ => SemanticUnitType::Other,
},
_ => SemanticUnitType::Other,
}
}
}
/// Splits source text into syntax-aware units using tree-sitter, and splits oversized
/// units into overlapping, line-aligned chunks.
///
/// Only compiled when the `semantic-chunking` feature is enabled.
#[cfg(feature = "semantic-chunking")]
pub struct AstChunker {
    /// Byte budget per chunk used by `split_large_unit`.
    max_chunk_size: usize,
    /// Approximate number of trailing bytes of a chunk repeated at the start of the next.
    overlap: usize,
}
#[cfg(feature = "semantic-chunking")]
impl AstChunker {
    /// Creates a chunker.
    ///
    /// * `max_chunk_size` – byte budget that `split_large_unit` packs lines into. A single
    ///   line longer than this is never split, so a chunk can exceed it.
    /// * `overlap` – approximate number of trailing bytes of one chunk repeated at the
    ///   start of the next (see `find_overlap_start` for the exact rule).
    pub fn new(max_chunk_size: usize, overlap: usize) -> Self {
        Self {
            max_chunk_size,
            overlap,
        }
    }

    /// Parses `content` and returns its outermost semantic units in source order,
    /// interleaved with `Other` units for the non-blank lines between them.
    ///
    /// The whole input comes back as a single `Other` unit when `language` declares no
    /// semantic node types, or when no semantic node is found.
    ///
    /// # Errors
    ///
    /// * `AstChunkError::UnsupportedLanguage` – `language` has no tree-sitter grammar.
    /// * `AstChunkError::ParserError` – the grammar could not be loaded into the parser.
    /// * `AstChunkError::ParseFailed` – tree-sitter returned no tree.
    pub fn extract_semantic_units(
        &self,
        content: &str,
        language: Language,
    ) -> Result<Vec<SemanticUnit>, AstChunkError> {
        let ts_language = language
            .tree_sitter_language()
            .ok_or(AstChunkError::UnsupportedLanguage)?;
        let semantic_types = language.semantic_node_types();
        if semantic_types.is_empty() {
            // Nothing to search for, so skip building a parse tree at all.
            return Ok(vec![Self::whole_content_unit(content)]);
        }
        let mut parser = tree_sitter::Parser::new();
        parser
            .set_language(&ts_language)
            .map_err(|e| AstChunkError::ParserError(e.to_string()))?;
        let tree = parser
            .parse(content, None)
            .ok_or(AstChunkError::ParseFailed)?;
        let mut units = Vec::new();
        self.collect_semantic_units(
            tree.root_node(),
            content,
            semantic_types,
            language,
            &mut units,
        );
        // The depth-first walk already yields document order; the stable sort is a cheap guard.
        units.sort_by_key(|u| u.start_line);
        Ok(self.fill_gaps(content, units))
    }

    /// Depth-first walk that records every node whose kind is in `semantic_types`.
    ///
    /// A matching node is not descended into, so nested definitions (for example, methods
    /// inside a captured class) stay part of the enclosing unit.
    fn collect_semantic_units(
        &self,
        node: tree_sitter::Node,
        content: &str,
        semantic_types: &[&str],
        language: Language,
        units: &mut Vec<SemanticUnit>,
    ) {
        let kind = node.kind();
        if semantic_types.contains(&kind) {
            units.push(SemanticUnit {
                content: Self::node_text(node, content),
                // tree-sitter rows are 0-based; units use 1-based line numbers.
                start_line: node.start_position().row + 1,
                end_line: node.end_position().row + 1,
                unit_type: SemanticUnitType::from_node_kind(kind, language),
                symbol_name: self.extract_symbol_name(node, content, language),
            });
            return;
        }
        let mut cursor = node.walk();
        for child in node.children(&mut cursor) {
            self.collect_semantic_units(child, content, semantic_types, language, units);
        }
    }

    /// Best-effort extraction of the identifier declared by a semantic node.
    ///
    /// The lookups are tried in this order:
    /// 1. The grammar's name field(s): `declarator` then `name` for C/C++, and `name`
    ///    otherwise. Nested declarators (e.g. pointer → function → identifier) are
    ///    followed down to the innermost one, so `int *foo(void)` yields `foo`.
    /// 2. For a Rust `impl_item`, its `type` field, i.e. the implementing type rather
    ///    than the trait.
    /// 3. For wrapper nodes with a `definition` field (Python `decorated_definition`) or
    ///    a `declaration` field (JS/TS `export_statement`), the name of the wrapped node.
    /// 4. The first direct `identifier` / `type_identifier` child.
    fn extract_symbol_name(
        &self,
        node: tree_sitter::Node,
        content: &str,
        language: Language,
    ) -> Option<String> {
        let name_fields: &[&str] = match language {
            Language::C | Language::Cpp => &["declarator", "name"],
            _ => &["name"],
        };
        for &field in name_fields {
            if let Some(mut name_node) = node.child_by_field_name(field) {
                while let Some(inner) = name_node.child_by_field_name("declarator") {
                    name_node = inner;
                }
                return Some(Self::node_text(name_node, content));
            }
        }

        if matches!(language, Language::Rust) && node.kind() == "impl_item" {
            if let Some(self_type) = node.child_by_field_name("type") {
                return Some(Self::node_text(self_type, content));
            }
        }

        for wrapper in ["definition", "declaration"] {
            if let Some(inner) = node.child_by_field_name(wrapper) {
                return self.extract_symbol_name(inner, content, language);
            }
        }

        let mut cursor = node.walk();
        let identifier = node
            .children(&mut cursor)
            .find(|child| matches!(child.kind(), "identifier" | "type_identifier"));
        identifier.map(|child| Self::node_text(child, content))
    }

    /// Interleaves `units` (expected sorted by `start_line`) with `Other` units covering
    /// the non-blank lines before, between, and after them.
    ///
    /// With no units, the whole of `content` becomes a single `Other` unit.
    fn fill_gaps(&self, content: &str, units: Vec<SemanticUnit>) -> Vec<SemanticUnit> {
        if units.is_empty() {
            return vec![Self::whole_content_unit(content)];
        }
        let lines: Vec<&str> = content.lines().collect();
        let total_lines = lines.len();
        let mut result = Vec::with_capacity(units.len() * 2 + 1);
        let mut current_line = 1;
        for unit in units {
            if unit.start_line > current_line {
                result.extend(Self::gap_unit(&lines, current_line, unit.start_line - 1));
            }
            // `max` keeps the cursor monotonic even if a unit ends before an earlier one did.
            current_line = current_line.max(unit.end_line + 1);
            result.push(unit);
        }
        if current_line <= total_lines {
            result.extend(Self::gap_unit(&lines, current_line, total_lines));
        }
        result
    }

    /// Builds an `Other` unit for the 1-based inclusive line range `first..=last`.
    ///
    /// Returns `None` if the range is empty, out of bounds, or only whitespace.
    fn gap_unit(lines: &[&str], first: usize, last: usize) -> Option<SemanticUnit> {
        let last = last.min(lines.len());
        if first == 0 || first > last {
            return None;
        }
        let text = lines[first - 1..last].join("\n");
        if text.trim().is_empty() {
            return None;
        }
        Some(SemanticUnit {
            content: text,
            start_line: first,
            end_line: last,
            unit_type: SemanticUnitType::Other,
            symbol_name: None,
        })
    }

    /// Wraps all of `content` in a single untyped unit spanning lines `1..=max(1, lines)`.
    fn whole_content_unit(content: &str) -> SemanticUnit {
        SemanticUnit {
            content: content.to_string(),
            start_line: 1,
            end_line: content.lines().count().max(1),
            unit_type: SemanticUnitType::Other,
            symbol_name: None,
        }
    }

    /// Returns the slice of `content` spanned by `node` as an owned string.
    fn node_text(node: tree_sitter::Node, content: &str) -> String {
        content[node.start_byte()..node.end_byte()].to_string()
    }

    /// Splits a unit whose content exceeds `max_chunk_size` bytes into line-aligned chunks.
    ///
    /// A unit already within budget (or with no lines) is returned unchanged as a single
    /// clone. Otherwise, lines are packed greedily.
    ///
    /// Every chunk after the first:
    /// * starts with the tail of the previous chunk (see `find_overlap_start`);
    /// * is prefixed with the unit's signature (see `extract_signature`) on its own line.
    ///
    /// A chunk's `start_line..=end_line` covers only its body (overlap plus new lines),
    /// not the repeated signature. Line numbers are offsets from `unit.start_line`.
    pub fn split_large_unit(&self, unit: &SemanticUnit) -> Vec<SemanticUnit> {
        if unit.content.len() <= self.max_chunk_size {
            return vec![unit.clone()];
        }
        let lines: Vec<&str> = unit.content.lines().collect();
        if lines.is_empty() {
            return vec![unit.clone()];
        }
        let signature = self.extract_signature(&unit.content);
        let last_index = lines.len() - 1;

        let mut chunks = Vec::new();
        let mut current_chunk = String::new();
        let mut chunk_start_line = unit.start_line;
        // Number of lines of the unit appended to some chunk so far.
        let mut consumed_lines = 0;
        let mut is_first_chunk = true;

        for (i, line) in lines.iter().enumerate() {
            let line_with_newline = if i < last_index {
                format!("{}\n", line)
            } else {
                line.to_string()
            };
            // Continuation chunks also pay for the signature plus its separating newline.
            let prefix_len = if is_first_chunk { 0 } else { signature.len() + 1 };
            let would_exceed = prefix_len + current_chunk.len() + line_with_newline.len()
                > self.max_chunk_size;
            if would_exceed && !current_chunk.is_empty() {
                chunks.push(SemanticUnit {
                    content: Self::chunk_content(&signature, &current_chunk, is_first_chunk),
                    start_line: chunk_start_line,
                    end_line: unit.start_line + consumed_lines - 1,
                    unit_type: unit.unit_type,
                    symbol_name: unit.symbol_name.clone(),
                });
                // Keep the tail of the emitted chunk as leading context for the next one.
                let overlap_start = self.find_overlap_start(&current_chunk);
                current_chunk.replace_range(..overlap_start, "");
                chunk_start_line =
                    unit.start_line + consumed_lines - current_chunk.lines().count();
                is_first_chunk = false;
            }
            current_chunk.push_str(&line_with_newline);
            consumed_lines = i + 1;
        }

        if !current_chunk.trim().is_empty() {
            chunks.push(SemanticUnit {
                content: Self::chunk_content(&signature, &current_chunk, is_first_chunk),
                start_line: chunk_start_line,
                end_line: unit.end_line,
                unit_type: unit.unit_type,
                symbol_name: unit.symbol_name.clone(),
            });
        }
        chunks
    }

    /// Prefixes a continuation chunk's `body` with `signature` on its own line.
    ///
    /// The first chunk begins at the start of the unit, so it is returned as is.
    fn chunk_content(signature: &str, body: &str, is_first_chunk: bool) -> String {
        if is_first_chunk {
            body.to_string()
        } else {
            format!("{}\n{}", signature, body)
        }
    }

    /// Returns the leading header lines of `content`.
    ///
    /// Lines are taken up to and including the first one whose trimmed text ends with
    /// `{`, `:` or `->`, capped at three lines. Empty content yields an empty string.
    fn extract_signature(&self, content: &str) -> String {
        let mut signature_lines = Vec::new();
        for line in content.lines() {
            signature_lines.push(line);
            let trimmed = line.trim();
            let ends_header =
                trimmed.ends_with('{') || trimmed.ends_with(':') || trimmed.ends_with("->");
            if ends_header || signature_lines.len() >= 3 {
                break;
            }
        }
        signature_lines.join("\n")
    }

    /// Returns the byte offset in `content` where the overlap carried into the next chunk
    /// begins.
    ///
    /// * Returns 0 (keep everything) when `content` is no longer than `overlap`.
    /// * Otherwise, targets the last `overlap` bytes, backed off to a UTF-8 boundary, and
    ///   widens to the start of the line containing that position.
    /// * If that position lies on the first line, the (mid-line) position itself is
    ///   returned.
    fn find_overlap_start(&self, content: &str) -> usize {
        if content.len() <= self.overlap {
            return 0;
        }
        let mut start = content.len() - self.overlap;
        // Offset 0 is always a char boundary, so this loop terminates.
        while !content.is_char_boundary(start) {
            start -= 1;
        }
        content[..start]
            .rfind('\n')
            .map_or(start, |newline| newline + 1)
    }
}
/// Errors raised while preparing or running the tree-sitter parse in `AstChunker`.
#[derive(Debug, Clone)]
pub enum AstChunkError {
    /// The language has no tree-sitter grammar available.
    UnsupportedLanguage,
    /// The grammar could not be installed into the parser; carries the underlying message.
    ParserError(String),
    /// The parser produced no syntax tree.
    ParseFailed,
}

impl std::fmt::Display for AstChunkError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::UnsupportedLanguage => "Language not supported for AST parsing",
            Self::ParseFailed => "Failed to parse content",
            Self::ParserError(detail) => return write!(f, "Parser error: {}", detail),
        };
        f.write_str(message)
    }
}

/// Marker impl; no underlying `source()` error is exposed.
impl std::error::Error for AstChunkError {}
#[cfg(all(test, feature = "semantic-chunking"))]
mod tests {
    use super::*;

    #[test]
    fn test_extract_rust_functions() {
        let content = r#"
fn foo() {
    println!("hello");
}

fn bar(x: i32) -> i32 {
    x + 1
}
"#;
        let chunker = AstChunker::new(1000, 100);
        let units = chunker
            .extract_semantic_units(content, Language::Rust)
            .unwrap();
        assert!(units.len() >= 2);
        let functions: Vec<_> = units
            .iter()
            .filter(|u| u.unit_type == SemanticUnitType::Function)
            .collect();
        assert_eq!(functions.len(), 2);
        assert_eq!(functions[0].symbol_name.as_deref(), Some("foo"));
        assert_eq!(functions[1].symbol_name.as_deref(), Some("bar"));
    }

    #[test]
    fn test_extract_python_classes() {
        // Indentation is significant: the methods must be nested inside the class body.
        let content = r#"
class MyClass:
    def __init__(self):
        pass

    def method(self):
        return 42

def standalone():
    pass
"#;
        let chunker = AstChunker::new(1000, 100);
        let units = chunker
            .extract_semantic_units(content, Language::Python)
            .unwrap();
        let classes: Vec<_> = units
            .iter()
            .filter(|u| u.unit_type == SemanticUnitType::Class)
            .collect();
        assert_eq!(classes.len(), 1);
        assert_eq!(classes[0].symbol_name.as_deref(), Some("MyClass"));
        let functions: Vec<_> = units
            .iter()
            .filter(|u| u.unit_type == SemanticUnitType::Function)
            .collect();
        assert_eq!(functions.len(), 1);
        assert_eq!(functions[0].symbol_name.as_deref(), Some("standalone"));
    }

    /// Builds a 52-line `fn large_function() { ... }` unit starting at `start_line`.
    fn large_function_unit(start_line: usize) -> SemanticUnit {
        let body = (0..50)
            .map(|i| format!("    line{};", i))
            .collect::<Vec<_>>()
            .join("\n");
        SemanticUnit {
            content: format!("fn large_function() {{\n{}\n}}", body),
            start_line,
            end_line: start_line + 51,
            unit_type: SemanticUnitType::Function,
            symbol_name: Some("large_function".to_string()),
        }
    }

    #[test]
    fn test_split_large_unit() {
        let chunker = AstChunker::new(100, 20);
        let large_unit = large_function_unit(1);
        let chunks = chunker.split_large_unit(&large_unit);
        assert!(chunks.len() > 1, "Should split into multiple chunks");
        assert!(chunks[0].content.starts_with("fn large_function()"));
        for chunk in &chunks[1..] {
            assert!(
                chunk.content.contains("fn large_function()"),
                "Subsequent chunks should include signature context"
            );
        }
    }

    #[test]
    fn test_split_large_unit_line_ranges_match_chunk_bodies() {
        let chunker = AstChunker::new(100, 20);
        let unit = large_function_unit(10);
        let chunks = chunker.split_large_unit(&unit);
        assert_eq!(chunks.first().unwrap().start_line, 10);
        assert_eq!(chunks.last().unwrap().end_line, 61);
        let signature_prefix = "fn large_function() {\n";
        for (index, chunk) in chunks.iter().enumerate() {
            // Continuation chunks carry the signature, which is not part of their line range.
            let body = if index == 0 {
                chunk.content.as_str()
            } else {
                chunk
                    .content
                    .strip_prefix(signature_prefix)
                    .expect("continuation chunk starts with the signature")
            };
            assert!(chunk.start_line <= chunk.end_line);
            assert_eq!(body.lines().count(), chunk.end_line - chunk.start_line + 1);
        }
    }

    #[test]
    fn test_find_overlap_start() {
        let chunker = AstChunker::new(100, 5);
        // The last 5 bytes begin at the newline ending "bbb", so the overlap widens to that line.
        assert_eq!(chunker.find_overlap_start("aaa\nbbb\nccc\n"), 4);
        // Content no longer than the overlap is kept entirely.
        assert_eq!(chunker.find_overlap_start("abc"), 0);
        // A target inside the first line is returned as is.
        assert_eq!(chunker.find_overlap_start("abcdefgh\n"), 4);
        // The target is moved back to a UTF-8 boundary.
        assert_eq!(AstChunker::new(100, 4).find_overlap_start("ééé\n"), 2);
    }

    #[test]
    fn test_fill_gaps_covers_leading_and_trailing_lines() {
        let chunker = AstChunker::new(1000, 100);
        let unit = SemanticUnit {
            content: "b".to_string(),
            start_line: 2,
            end_line: 2,
            unit_type: SemanticUnitType::Function,
            symbol_name: Some("b".to_string()),
        };
        let units = chunker.fill_gaps("a\nb\n\nd", vec![unit]);
        let spans: Vec<_> = units
            .iter()
            .map(|u| (u.start_line, u.end_line, u.unit_type))
            .collect();
        assert_eq!(
            spans,
            vec![
                (1, 1, SemanticUnitType::Other),
                (2, 2, SemanticUnitType::Function),
                (3, 4, SemanticUnitType::Other),
            ]
        );
        assert_eq!(units[0].content, "a");
        assert_eq!(units[2].content, "\nd");

        let whole = chunker.fill_gaps("x\ny", Vec::new());
        assert_eq!(whole.len(), 1);
        assert_eq!((whole[0].start_line, whole[0].end_line), (1, 2));
    }

    #[test]
    fn test_extract_signature() {
        let chunker = AstChunker::new(100, 20);
        let rust_fn = "fn process_data(input: &str) -> Result<Output, Error> {\n    // body\n}";
        let sig = chunker.extract_signature(rust_fn);
        assert!(sig.contains("fn process_data"));
        assert!(sig.contains("{"));
        let python_fn = "def process_data(input):\n    # body\n    pass";
        let sig = chunker.extract_signature(python_fn);
        assert!(sig.contains("def process_data"));
        assert!(sig.contains(":"));
    }
}