use std::collections::HashMap;
use std::sync::Arc;
/// A single node of a tree stored in flattened (list) form.
///
/// Clones are cheap with respect to the label: it is an `Arc<str>`, so cloning
/// a node only bumps a reference count rather than copying the text.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct FlatNode {
    /// Label of the node (`FlatTree::from_ast_node` uses the AST node's kind).
    pub label: Arc<str>,
    /// Nesting depth of the node (`FlatTree::from_ast_node` starts at 0 for the root).
    pub depth: usize,
    /// Scope value of the node; `FlatTree::from_ast_node` stores the node's own
    /// index in the flattened list here.
    pub scope: usize,
}

impl FlatNode {
    /// Builds a node from anything convertible into an `Arc<str>` label
    /// (for example `&str` or `String`) plus its depth and scope.
    pub fn new(label: impl Into<Arc<str>>, depth: usize, scope: usize) -> Self {
        let label: Arc<str> = label.into();
        FlatNode {
            label,
            depth,
            scope,
        }
    }
}
/// A tree stored as a flat list of [`FlatNode`]s, tagged with an identifier
/// and optional descriptive metadata.
#[derive(Debug, Clone)]
pub struct FlatTree {
/// The flattened nodes. When built by `FlatTree::from_ast_node`, each parent
/// appears before its children and carries a smaller depth than they do.
pub nodes: Vec<FlatNode>,
/// Caller-supplied identifier for this tree.
pub tree_id: u64,
/// Optional extra information about the tree; `None` unless the tree was
/// built with `FlatTree::with_metadata`.
pub metadata: Option<TreeMetadata>,
}
/// Optional descriptive information attached to a [`FlatTree`].
///
/// Every field is optional; nothing in this file reads them, so their exact
/// meaning is defined by the callers that populate them.
#[derive(Debug, Clone)]
pub struct TreeMetadata {
/// NOTE(review): presumably the path of the file the tree came from — confirm.
pub path: Option<String>,
/// NOTE(review): presumably the programming language of the tree — confirm.
pub language: Option<String>,
/// NOTE(review): presumably the source text or its origin — confirm against callers.
pub source: Option<String>,
}
impl FlatTree {
pub fn new(nodes: Vec<FlatNode>, tree_id: u64) -> Self {
Self {
nodes,
tree_id,
metadata: None,
}
}
pub fn with_metadata(nodes: Vec<FlatNode>, tree_id: u64, metadata: TreeMetadata) -> Self {
Self {
nodes,
tree_id,
metadata: Some(metadata),
}
}
pub fn len(&self) -> usize {
self.nodes.len()
}
pub fn is_empty(&self) -> bool {
self.nodes.is_empty()
}
pub fn from_ast_node(node: &super::super::ast::AstNode, tree_id: u64) -> Self {
let mut nodes = Vec::new();
Self::flatten_recursive(node, 0, &mut nodes);
Self::new(nodes, tree_id)
}
fn flatten_recursive(
node: &super::super::ast::AstNode,
depth: usize,
nodes: &mut Vec<FlatNode>,
) {
let scope = nodes.len();
nodes.push(FlatNode::new(node.kind.as_str(), depth, scope));
for child in &node.children {
Self::flatten_recursive(child, depth + 1, nodes);
}
}
pub fn label_positions(&self) -> HashMap<Arc<str>, Vec<usize>> {
let mut positions: HashMap<Arc<str>, Vec<usize>> = HashMap::new();
for (i, node) in self.nodes.iter().enumerate() {
positions
.entry(Arc::clone(&node.label))
.or_default()
.push(i);
}
positions
}
pub fn extract_subtree(&self, start: usize) -> Option<Vec<FlatNode>> {
if start >= self.nodes.len() {
return None;
}
let start_depth = self.nodes[start].depth;
let mut end = start + 1;
while end < self.nodes.len() && self.nodes[end].depth > start_depth {
end += 1;
}
Some(self.nodes[start..end].to_vec())
}
}
/// One node of a subtree pattern: a label plus a depth.
///
/// When created through `PatternNode::from_flat`, the depth is expressed
/// relative to a caller-supplied base depth rather than to the tree root.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct PatternNode {
/// Label of the node, shared via `Arc<str>` so clones do not copy the text.
pub label: Arc<str>,
/// Depth of the node within the pattern.
pub depth: usize,
}
impl PatternNode {
pub fn new(label: impl Into<Arc<str>>, depth: usize) -> Self {
Self {
label: label.into(),
depth,
}
}
pub fn from_flat(node: &FlatNode, base_depth: usize) -> Self {
Self {
label: Arc::clone(&node.label),
depth: node.depth.saturating_sub(base_depth),
}
}
}
/// A subtree pattern together with the statistics recorded for it.
///
/// Equality and hashing of this type (implemented manually elsewhere in this
/// file) look only at `nodes`; the statistics and `pattern_id` are ignored.
#[derive(Debug, Clone)]
pub struct SubtreePattern {
/// The pattern's nodes, in sequence.
pub nodes: Vec<PatternNode>,
/// Support count supplied at construction.
pub support: usize,
/// `support` divided by the total tree count given to `SubtreePattern::new`,
/// or `0.0` when that total was zero.
pub support_ratio: f64,
/// NOTE(review): presumably the ids of the trees in which the pattern occurs — confirm.
pub occurrences: Vec<u64>,
/// Caller-supplied identifier for this pattern.
pub pattern_id: u64,
}
impl SubtreePattern {
pub fn new(
nodes: Vec<PatternNode>,
support: usize,
total_trees: usize,
occurrences: Vec<u64>,
pattern_id: u64,
) -> Self {
let support_ratio = if total_trees > 0 {
support as f64 / total_trees as f64
} else {
0.0
};
Self {
nodes,
support,
support_ratio,
occurrences,
pattern_id,
}
}
pub fn size(&self) -> usize {
self.nodes.len()
}
pub fn max_depth(&self) -> usize {
self.nodes.iter().map(|n| n.depth).max().unwrap_or(0)
}
pub fn contains(&self, other: &SubtreePattern) -> bool {
if self.nodes.len() < other.nodes.len() {
return false;
}
let mut other_idx = 0;
for self_node in &self.nodes {
if other_idx < other.nodes.len() && self_node == &other.nodes[other_idx] {
other_idx += 1;
}
}
other_idx == other.nodes.len()
}
pub fn to_string_repr(&self) -> String {
let mut parts = Vec::new();
for node in &self.nodes {
let indent = " ".repeat(node.depth);
parts.push(format!("{}{}", indent, node.label));
}
parts.join("\n")
}
pub fn root_label(&self) -> Option<&str> {
self.nodes.first().map(|n| n.label.as_ref())
}
}
/// Two patterns are equal when their node sequences are equal; support
/// statistics, occurrences and `pattern_id` are deliberately not compared.
impl PartialEq for SubtreePattern {
    fn eq(&self, other: &Self) -> bool {
        let (ours, theirs) = (self.nodes.as_slice(), other.nodes.as_slice());
        ours == theirs
    }
}
// Marker impl: the manual `PartialEq` above compares only `nodes`, whose
// element type derives `Eq`, so full equivalence holds.
impl Eq for SubtreePattern {}
/// Hashes only the node sequence, matching the fields compared by this type's
/// `PartialEq` impl so that equal patterns always hash equally.
impl std::hash::Hash for SubtreePattern {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        std::hash::Hash::hash(&self.nodes, state);
    }
}
/// Text encoding of pattern node sequences, plus a hash derived from it.
pub mod encoding {
    use super::*;

    /// Serializes `nodes` as `depth:label` entries joined by `|`.
    ///
    /// An empty slice encodes to the empty string. Labels are written
    /// verbatim (no escaping), so a label that itself contains `|` will not
    /// survive a round trip through [`decode_pattern`].
    pub fn encode_pattern(nodes: &[PatternNode]) -> String {
        // Exactly one entry per node: `collect` sizes the vector from the
        // iterator's exact length instead of reserving twice what is needed.
        nodes
            .iter()
            .map(|node| format!("{}:{}", node.depth, node.label))
            .collect::<Vec<String>>()
            .join("|")
    }

    /// Parses text produced by [`encode_pattern`] back into pattern nodes.
    ///
    /// Each `|`-separated entry is split at its first `:` into a depth and a
    /// label, so labels may contain `:`. Entries without a `:` or whose depth
    /// does not parse as a number are silently skipped; the empty string
    /// therefore decodes to an empty vector.
    pub fn decode_pattern(encoded: &str) -> Vec<PatternNode> {
        encoded
            .split('|')
            .filter_map(|part| {
                // Split only once so any further ':' stays inside the label.
                let mut split = part.splitn(2, ':');
                let depth = split.next()?.parse().ok()?;
                let label = split.next()?;
                Some(PatternNode::new(label, depth))
            })
            .collect()
    }

    /// Hashes a node sequence by feeding the bytes of its encoded text form
    /// to `crate::util::hash::safe_hash`.
    pub fn pattern_hash(nodes: &[PatternNode]) -> u64 {
        crate::util::hash::safe_hash(encode_pattern(nodes).as_bytes())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `FlatNode::new` keeps the label and depth it was given.
    #[test]
    fn test_flat_node_creation() {
        let created = FlatNode::new("function_definition", 0, 0);
        assert_eq!(&*created.label, "function_definition");
        assert_eq!(created.depth, 0);
    }

    /// A tree reports as many nodes as it was built from.
    #[test]
    fn test_flat_tree_creation() {
        let tree = FlatTree::new(
            vec![
                FlatNode::new("root", 0, 0),
                FlatNode::new("child1", 1, 1),
                FlatNode::new("child2", 1, 2),
            ],
            1,
        );
        assert_eq!(tree.len(), 3);
    }

    /// Extracting at `child1` yields it and its descendant, but not its sibling.
    #[test]
    fn test_extract_subtree() {
        let tree = FlatTree::new(
            vec![
                FlatNode::new("root", 0, 0),
                FlatNode::new("child1", 1, 1),
                FlatNode::new("grandchild", 2, 2),
                FlatNode::new("child2", 1, 3),
            ],
            1,
        );
        let extracted = tree.extract_subtree(1).unwrap();
        let labels: Vec<&str> = extracted.iter().map(|node| &*node.label).collect();
        assert_eq!(labels, ["child1", "grandchild"]);
    }

    /// Encoding produces the expected text and decoding restores the nodes.
    #[test]
    fn test_pattern_encoding() {
        let original = vec![
            PatternNode::new("A", 0),
            PatternNode::new("B", 1),
            PatternNode::new("C", 1),
        ];
        let text = encoding::encode_pattern(&original);
        assert_eq!(text, "0:A|1:B|1:C");
        assert_eq!(encoding::decode_pattern(&text), original);
    }

    /// Construction derives size, support ratio and root label correctly.
    #[test]
    fn test_subtree_pattern() {
        let shape = vec![
            PatternNode::new("function", 0),
            PatternNode::new("params", 1),
            PatternNode::new("body", 1),
        ];
        let pattern = SubtreePattern::new(shape, 10, 100, vec![1, 2, 3], 42);
        assert_eq!(pattern.size(), 3);
        assert_eq!(pattern.support, 10);
        let ratio_error = (pattern.support_ratio - 0.1).abs();
        assert!(ratio_error < 1e-6);
        assert_eq!(pattern.root_label(), Some("function"));
    }
}