use std::collections::{HashMap, HashSet};

use serde::{Deserialize, Serialize};

use super::AstNode;
/// Structural context (enclosing definitions) collected around one or more AST nodes.
///
/// Nodes are added with `add_node` or `merge`; `build` then fills in
/// `visible_lines`, `headers`, and `summary` from the source text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AstContext {
/// Context entries. After `build` they are ordered by `indent` (outermost first).
pub nodes: Vec<ContextNode>,
/// Source line numbers that `build` selected for display.
pub visible_lines: HashSet<usize>,
/// Header snippets recorded by `build` (one per node) and extended by `merge`.
pub headers: Vec<HeaderInfo>,
/// Human-readable breadcrumb of the shallowest nodes, joined with `" > "`.
pub summary: String,
}
/// One entry of an `AstContext`: an AST node plus how it should be presented.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextNode {
/// Owned copy of the AST node this entry describes.
pub node: AstNode,
/// Short label such as `fn name` or `struct name`, produced by `AstContext::format_node`.
pub display_text: String,
/// Position in the ancestor chain the node was added with; 0 is the outermost
/// significant ancestor.
pub indent: usize,
/// `true` for the node passed directly to `add_node` (innermost in its chain),
/// `false` for ancestors pulled in as context.
pub is_match_container: bool,
/// Collapse flag; always initialised to `false` in this file.
/// NOTE(review): presumably toggled by a presentation layer — confirm.
pub is_collapsed: bool,
}
/// Header snippet recorded by `AstContext::build` for a single context node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HeaderInfo {
/// Line on which the node starts (`node.start.line`).
pub start_line: usize,
/// The node's end line, capped at `start_line + 3`.
pub end_line: usize,
/// Index of the described node in the `nodes` of the context whose `build` produced it.
pub node_index: usize,
/// Source text of the node's start line (shortened with `...` when long), or the
/// node's formatted label when that line is not present in the source.
pub text: String,
}
impl AstContext {
pub fn new() -> Self {
Self {
nodes: Vec::new(),
visible_lines: HashSet::new(),
headers: Vec::new(),
summary: String::new(),
}
}
pub fn add_node(&mut self, node: &AstNode, all_nodes: &[AstNode]) {
let mut chain = vec![node.clone()];
let mut current = node.parent;
while let Some(parent_idx) = current {
if let Some(parent) = all_nodes.get(parent_idx) {
if self.is_significant_node(parent) {
chain.push(parent.clone());
}
current = parent.parent;
} else {
break;
}
}
chain.reverse();
let chain_len = chain.len();
for (indent, node) in chain.into_iter().enumerate() {
let is_match_container = indent == chain_len - 1;
let context_node = ContextNode {
display_text: self.format_node(&node),
indent,
is_match_container,
is_collapsed: false,
node,
};
if !self.nodes.iter().any(|cn| {
cn.node.range == context_node.node.range &&
cn.node.kind == context_node.node.kind
}) {
self.nodes.push(context_node);
}
}
}
pub fn build(&mut self, source: &str) {
let lines: Vec<&str> = source.lines().collect();
self.nodes.sort_by_key(|n| n.indent);
for (idx, context_node) in self.nodes.iter().enumerate() {
let node = &context_node.node;
let header_lines = self.calculate_header_lines(node, &lines);
for line in header_lines {
self.visible_lines.insert(line);
}
let header_text = self.extract_header_text(node, &lines);
self.headers.push(HeaderInfo {
start_line: node.start.line,
end_line: node.end.line.min(node.start.line + 3),
node_index: idx,
text: header_text,
});
}
self.summary = self.generate_summary();
}
fn is_significant_node(&self, node: &AstNode) -> bool {
matches!(
node.kind.as_str(),
"function_item" | "function_definition" | "function_declaration" |
"method_definition" | "method_declaration" |
"class_definition" | "class_declaration" |
"struct_item" | "struct_specifier" |
"enum_item" | "enum_declaration" |
"trait_item" | "trait_declaration" |
"impl_item" | "implementation" |
"mod_item" | "module" |
"interface_declaration" | "namespace_declaration"
) || node.metadata.is_scope || node.metadata.is_definition
}
fn format_node(&self, node: &AstNode) -> String {
let name = node.name.as_deref().unwrap_or("<anonymous>");
let visibility = node.metadata.visibility
.as_ref()
.map(|v| format!("{} ", v))
.unwrap_or_default();
match node.kind.as_str() {
"function_item" | "function_definition" | "function_declaration" => {
format!("{}fn {}", visibility, name)
},
"method_definition" | "method_declaration" => {
format!("{}method {}", visibility, name)
},
"class_definition" | "class_declaration" => {
format!("{}class {}", visibility, name)
},
"struct_item" | "struct_specifier" => {
format!("{}struct {}", visibility, name)
},
"enum_item" | "enum_declaration" => {
format!("{}enum {}", visibility, name)
},
"trait_item" | "trait_declaration" => {
format!("{}trait {}", visibility, name)
},
"impl_item" => {
format!("{}impl {}", visibility, name)
},
"interface_declaration" => {
format!("{}interface {}", visibility, name)
},
_ => format!("{} {}", node.kind, name),
}
}
fn calculate_header_lines(&self, node: &AstNode, _lines: &[&str]) -> Vec<usize> {
let start_line = node.start.line;
let end_line = node.end.line;
if start_line == end_line {
vec![start_line]
} else if end_line - start_line <= 5 {
(start_line..=end_line).collect()
} else {
let mut lines_to_show = vec![start_line];
for i in 1..=3 {
if start_line + i < end_line {
lines_to_show.push(start_line + i);
}
}
if end_line > start_line + 3 {
lines_to_show.push(end_line);
}
lines_to_show
}
}
fn extract_header_text(&self, node: &AstNode, lines: &[&str]) -> String {
if node.start.line < lines.len() {
let line = lines[node.start.line];
if line.len() > 100 {
format!("{}...", &line[..97])
} else {
line.to_string()
}
} else {
self.format_node(node)
}
}
fn generate_summary(&self) -> String {
if self.nodes.is_empty() {
return "No context available".to_string();
}
let mut parts = Vec::new();
for node in &self.nodes {
if node.indent < 3 { parts.push(self.format_node(&node.node));
}
}
if parts.is_empty() {
"Global scope".to_string()
} else {
parts.join(" > ")
}
}
pub fn merge(&mut self, other: &AstContext) {
let mut existing_keys = HashSet::new();
for node in &self.nodes {
existing_keys.insert((node.node.start, node.node.end, node.node.kind.clone()));
}
for context_node in &other.nodes {
let key = (context_node.node.start, context_node.node.end, context_node.node.kind.clone());
if !existing_keys.contains(&key) {
self.nodes.push(context_node.clone());
existing_keys.insert(key);
}
}
self.visible_lines.extend(&other.visible_lines);
self.headers.extend(other.headers.clone());
self.summary = self.generate_summary();
}
pub fn get_path(&self) -> String {
self.nodes
.iter()
.take(3) .map(|n| {
n.node.name.as_deref()
.unwrap_or(&n.node.kind)
})
.collect::<Vec<_>>()
.join("::")
}
pub fn get_deepest_node(&self) -> Option<&ContextNode> {
self.nodes.iter().max_by_key(|n| n.indent)
}
pub fn get_nodes_at_level(&self, level: usize) -> Vec<&ContextNode> {
self.nodes.iter()
.filter(|n| n.indent == level)
.collect()
}
pub fn contains_node_type(&self, node_type: &str) -> bool {
self.nodes.iter().any(|n| n.node.kind == node_type)
}
}
impl Default for AstContext {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{NodeMetadata, Position};
    use std::ops::Range;

    /// Builds a detached node (no parent, no children) spanning `start_line..=end_line`.
    fn create_test_node(kind: &str, name: Option<&str>, start_line: usize, end_line: usize) -> AstNode {
        AstNode {
            kind: kind.to_string(),
            name: name.map(|s| s.to_string()),
            start: Position { line: start_line, column: 0, offset: 0 },
            end: Position { line: end_line, column: 0, offset: 0 },
            range: Range { start: 0, end: 100 },
            depth: 0,
            parent: None,
            children: Vec::new(),
            metadata: NodeMetadata::default(),
        }
    }

    #[test]
    fn test_ast_context_creation() {
        // No mutation happens here, so the binding does not need `mut`.
        let context = AstContext::new();
        assert!(context.nodes.is_empty());
        assert!(context.visible_lines.is_empty());
        assert!(context.headers.is_empty());
    }

    #[test]
    fn test_format_node() {
        let context = AstContext::new();
        let node = create_test_node("function_item", Some("test_fn"), 0, 5);
        let formatted = context.format_node(&node);
        assert_eq!(formatted, "fn test_fn");
    }

    #[test]
    fn test_context_path() {
        let mut context = AstContext::new();
        let node1 = create_test_node("class_definition", Some("MyClass"), 0, 10);
        let node2 = create_test_node("function_definition", Some("my_method"), 2, 8);
        context.add_node(&node1, &[]);
        context.add_node(&node2, &[]);
        let path = context.get_path();
        assert!(path.contains("MyClass"));
    }

    #[test]
    fn test_significant_node() {
        let context = AstContext::new();
        let func_node = create_test_node("function_item", Some("test"), 0, 5);
        assert!(context.is_significant_node(&func_node));
        let expr_node = create_test_node("expression", None, 1, 1);
        assert!(!context.is_significant_node(&expr_node));
    }

    #[test]
    fn test_header_lines_for_long_span() {
        let context = AstContext::new();
        let node = create_test_node("function_item", Some("long"), 10, 30);
        assert_eq!(context.calculate_header_lines(&node, &[]), vec![10, 11, 12, 13, 30]);
    }

    #[test]
    fn test_header_lines_for_short_span() {
        let context = AstContext::new();
        let node = create_test_node("function_item", Some("short"), 4, 7);
        assert_eq!(context.calculate_header_lines(&node, &[]), vec![4, 5, 6, 7]);
    }

    #[test]
    fn test_header_text_truncates_long_ascii_line() {
        let context = AstContext::new();
        let node = create_test_node("function_item", Some("f"), 0, 0);
        let long_line = "a".repeat(150);
        let lines = [long_line.as_str()];
        let text = context.extract_header_text(&node, &lines);
        assert_eq!(text.len(), 100);
        assert!(text.ends_with("..."));
    }

    #[test]
    fn test_header_text_keeps_short_line() {
        let context = AstContext::new();
        let node = create_test_node("function_item", Some("f"), 0, 0);
        let lines = ["fn f() {}"];
        assert_eq!(context.extract_header_text(&node, &lines), "fn f() {}");
    }
}