use crate::error::{KernelError, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tensorlogic_ir::TLExpr;
/// A node of a labelled, ordered tree; a whole tree is represented by its root node.
///
/// Equality and hashing are derived, so two nodes compare equal only when their
/// labels and their entire child lists are equal.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TreeNode {
/// Label attached to this node.
pub label: String,
/// Child subtrees in order; empty for a leaf.
pub children: Vec<TreeNode>,
}
impl TreeNode {
    /// Creates a leaf node (no children) carrying `label`.
    pub fn new(label: impl Into<String>) -> Self {
        Self::with_children(label, Vec::new())
    }

    /// Creates a node carrying `label` whose children are `children`, in order.
    pub fn with_children(label: impl Into<String>, children: Vec<TreeNode>) -> Self {
        let label = label.into();
        Self { label, children }
    }

    /// Number of nodes on the longest path from this node down to a leaf.
    ///
    /// A leaf has height 1.
    pub fn height(&self) -> usize {
        // Folding from 0 makes the leaf case fall out naturally: 0 + 1.
        let tallest_child = self
            .children
            .iter()
            .fold(0, |tallest, child| tallest.max(child.height()));
        tallest_child + 1
    }

    /// Total number of nodes in the tree rooted here, this node included.
    pub fn num_nodes(&self) -> usize {
        let mut total = 1;
        for child in &self.children {
            total += child.num_nodes();
        }
        total
    }

    /// Returns `true` when this node has no children.
    pub fn is_leaf(&self) -> bool {
        self.children.is_empty()
    }

    /// Converts a logic expression into a labelled tree.
    ///
    /// Predicates become `Pred(<name>)` leaves, connectives become `And`, `Or`,
    /// `Not` and `Imply` nodes over their converted operands, and quantifiers
    /// become `Exists(<var>, <domain>)` / `ForAll(<var>, <domain>)` nodes over
    /// their converted body. Every other expression kind collapses to a single
    /// `Expr` leaf.
    pub fn from_tlexpr(expr: &TLExpr) -> Self {
        let (label, children) = match expr {
            TLExpr::Pred { name, .. } => (format!("Pred({})", name), Vec::new()),
            TLExpr::And(lhs, rhs) => (
                String::from("And"),
                vec![Self::from_tlexpr(lhs), Self::from_tlexpr(rhs)],
            ),
            TLExpr::Or(lhs, rhs) => (
                String::from("Or"),
                vec![Self::from_tlexpr(lhs), Self::from_tlexpr(rhs)],
            ),
            TLExpr::Not(inner) => (String::from("Not"), vec![Self::from_tlexpr(inner)]),
            TLExpr::Imply(lhs, rhs) => (
                String::from("Imply"),
                vec![Self::from_tlexpr(lhs), Self::from_tlexpr(rhs)],
            ),
            TLExpr::Exists { var, domain, body } => (
                format!("Exists({}, {})", var, domain),
                vec![Self::from_tlexpr(body)],
            ),
            TLExpr::ForAll { var, domain, body } => (
                format!("ForAll({}, {})", var, domain),
                vec![Self::from_tlexpr(body)],
            ),
            _ => (String::from("Expr"), Vec::new()),
        };
        Self { label, children }
    }

    /// Returns a clone of every subtree rooted at a node of this tree, in
    /// pre-order: this node first, then each child's subtrees in child order.
    fn get_all_subtrees(&self) -> Vec<TreeNode> {
        std::iter::once(self.clone())
            .chain(
                self.children
                    .iter()
                    .flat_map(|child| child.get_all_subtrees()),
            )
            .collect()
    }
}
/// Settings for [`SubtreeKernel`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubtreeKernelConfig {
/// When `true`, the kernel divides the raw match count by a normalisation
/// factor computed from the two input trees; when `false`, the raw count is
/// returned.
pub normalize: bool,
}
impl SubtreeKernelConfig {
    /// Creates a configuration with normalisation switched on.
    pub fn new() -> Self {
        let normalize = true;
        Self { normalize }
    }

    /// Builder-style setter: returns the configuration with `normalize` replaced.
    pub fn with_normalize(self, normalize: bool) -> Self {
        let mut updated = self;
        updated.normalize = normalize;
        updated
    }
}
impl Default for SubtreeKernelConfig {
fn default() -> Self {
Self::new()
}
}
/// Tree kernel that scores two trees by counting pairs of identical subtrees.
pub struct SubtreeKernel {
// Controls whether the raw count is normalised before being returned.
config: SubtreeKernelConfig,
}
impl SubtreeKernel {
pub fn new(config: SubtreeKernelConfig) -> Self {
Self { config }
}
pub fn compute_trees(&self, tree1: &TreeNode, tree2: &TreeNode) -> Result<f64> {
let subtrees1 = tree1.get_all_subtrees();
let subtrees2 = tree2.get_all_subtrees();
let mut count = 0;
for st1 in &subtrees1 {
for st2 in &subtrees2 {
if st1 == st2 {
count += 1;
}
}
}
let similarity = count as f64;
if self.config.normalize {
let self_sim1 = subtrees1.len() as f64;
let self_sim2 = subtrees2.len() as f64;
let norm = (self_sim1 * self_sim2).sqrt();
if norm > 0.0 {
Ok(similarity / norm)
} else {
Ok(0.0)
}
} else {
Ok(similarity)
}
}
}
/// Settings for [`SubsetTreeKernel`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubsetTreeKernelConfig {
/// When `true`, the kernel value is divided by the square root of the product
/// of the two trees' self-similarities.
pub normalize: bool,
/// Multiplier applied to every matching node pair and to the contribution of
/// its children. `with_decay` only accepts values in `0.0..=1.0`.
pub decay: f64,
}
impl SubsetTreeKernelConfig {
    /// Creates the default configuration: normalisation on, decay `1.0`.
    ///
    /// # Errors
    ///
    /// Always returns `Ok`; the `Result` is part of the existing signature.
    pub fn new() -> Result<Self> {
        let defaults = Self {
            normalize: true,
            decay: 1.0,
        };
        Ok(defaults)
    }

    /// Builder-style setter: returns the configuration with `normalize` replaced.
    pub fn with_normalize(self, normalize: bool) -> Self {
        Self { normalize, ..self }
    }

    /// Builder-style setter for the decay factor.
    ///
    /// # Errors
    ///
    /// Returns [`KernelError::InvalidParameter`] when `decay` lies outside
    /// `0.0..=1.0` (NaN is rejected as well, since it is inside no range).
    pub fn with_decay(self, decay: f64) -> Result<Self> {
        if (0.0..=1.0).contains(&decay) {
            Ok(Self { decay, ..self })
        } else {
            Err(KernelError::InvalidParameter {
                parameter: "decay".to_string(),
                value: decay.to_string(),
                reason: "must be between 0.0 and 1.0".to_string(),
            })
        }
    }
}
impl Default for SubsetTreeKernelConfig {
fn default() -> Self {
Self::new().expect("default SubsetTreeKernelConfig parameters are valid")
}
}
/// Tree kernel that recursively scores pairs of nodes with equal labels,
/// weighting each matching pair by the configured decay factor.
pub struct SubsetTreeKernel {
// Supplies the decay factor and the normalisation switch.
config: SubsetTreeKernelConfig,
}
impl SubsetTreeKernel {
    /// Creates a kernel that uses `config`.
    pub fn new(config: SubsetTreeKernelConfig) -> Self {
        Self { config }
    }

    /// Computes the kernel between `tree1` and `tree2`.
    ///
    /// The raw score of a node pair is `0` when the labels differ; otherwise
    /// it is `decay + decay * S`, where `S` is the sum of the scores of every
    /// (child of `tree1`'s node, child of `tree2`'s node) pair. With
    /// `normalize` enabled the root score is divided by
    /// `sqrt(K(tree1, tree1) * K(tree2, tree2))`, or `0.0` is returned when
    /// that norm is not positive.
    ///
    /// # Errors
    ///
    /// This implementation never returns an error; the `Result` is part of
    /// the existing signature.
    pub fn compute_trees(&self, tree1: &TreeNode, tree2: &TreeNode) -> Result<f64> {
        let similarity = self.compute_recursive(tree1, tree2, &mut HashMap::new());
        if !self.config.normalize {
            return Ok(similarity);
        }
        let self_sim1 = self.compute_recursive(tree1, tree1, &mut HashMap::new());
        let self_sim2 = self.compute_recursive(tree2, tree2, &mut HashMap::new());
        let norm = (self_sim1 * self_sim2).sqrt();
        Ok(if norm > 0.0 { similarity / norm } else { 0.0 })
    }

    /// Scores the node pair `(n1, n2)`, memoising results in `cache`.
    ///
    /// `cache` must only be shared between calls made while the same trees
    /// are borrowed, because its keys are node addresses.
    fn compute_recursive(
        &self,
        n1: &TreeNode,
        n2: &TreeNode,
        cache: &mut HashMap<(usize, usize), f64>,
    ) -> f64 {
        // Key on node identity. Keying on node counts made unrelated subtree
        // pairs of the same size share one entry, so e.g. two different leaves
        // were reported as matching once any pair of leaves had matched.
        let key = (
            n1 as *const TreeNode as usize,
            n2 as *const TreeNode as usize,
        );
        if let Some(&cached) = cache.get(&key) {
            return cached;
        }
        let result = if n1.label == n2.label {
            let decay = self.config.decay;
            let mut score = decay;
            for c1 in &n1.children {
                for c2 in &n2.children {
                    score += decay * self.compute_recursive(c1, c2, cache);
                }
            }
            score
        } else {
            0.0
        };
        cache.insert(key, result);
        result
    }
}
/// Settings for [`PartialTreeKernel`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PartialTreeKernelConfig {
/// When `true`, the kernel value is divided by the square root of the product
/// of the two trees' self-similarities.
pub normalize: bool,
/// Factor by which a node's weight is multiplied for each level of depth.
/// `with_decay` only accepts values in `0.0..=1.0`.
pub decay: f64,
/// Raw (unnormalised) similarities below this value are reported as `0.0`.
/// `with_threshold` only accepts values in `0.0..=1.0`.
pub threshold: f64,
}
impl PartialTreeKernelConfig {
    /// Creates the default configuration: normalisation on, decay `0.8`,
    /// threshold `0.0`.
    ///
    /// # Errors
    ///
    /// Always returns `Ok`; the `Result` is part of the existing signature.
    pub fn new() -> Result<Self> {
        let defaults = Self {
            normalize: true,
            decay: 0.8,
            threshold: 0.0,
        };
        Ok(defaults)
    }

    /// Builder-style setter: returns the configuration with `normalize` replaced.
    pub fn with_normalize(self, normalize: bool) -> Self {
        Self { normalize, ..self }
    }

    /// Builder-style setter for the decay factor.
    ///
    /// # Errors
    ///
    /// Returns [`KernelError::InvalidParameter`] when `decay` lies outside
    /// `0.0..=1.0` (NaN included).
    pub fn with_decay(self, decay: f64) -> Result<Self> {
        Self::require_unit_interval("decay", decay)?;
        Ok(Self { decay, ..self })
    }

    /// Builder-style setter for the similarity threshold.
    ///
    /// # Errors
    ///
    /// Returns [`KernelError::InvalidParameter`] when `threshold` lies outside
    /// `0.0..=1.0` (NaN included).
    pub fn with_threshold(self, threshold: f64) -> Result<Self> {
        Self::require_unit_interval("threshold", threshold)?;
        Ok(Self { threshold, ..self })
    }

    /// Shared range check for the two validated setters; `parameter` is the
    /// name reported in the error.
    fn require_unit_interval(parameter: &str, value: f64) -> Result<()> {
        if (0.0..=1.0).contains(&value) {
            return Ok(());
        }
        Err(KernelError::InvalidParameter {
            parameter: parameter.to_string(),
            value: value.to_string(),
            reason: "must be between 0.0 and 1.0".to_string(),
        })
    }
}
impl Default for PartialTreeKernelConfig {
fn default() -> Self {
Self::new().expect("default PartialTreeKernelConfig parameters are valid")
}
}
/// Tree kernel that compares children position by position and gives partial
/// credit to node pairs whose labels differ but share characters.
pub struct PartialTreeKernel {
// Supplies the decay factor, the threshold and the normalisation switch.
config: PartialTreeKernelConfig,
}
impl PartialTreeKernel {
    /// Creates a kernel that uses `config`.
    pub fn new(config: PartialTreeKernelConfig) -> Self {
        Self { config }
    }

    /// Computes the partial-match similarity between `tree1` and `tree2`.
    ///
    /// The raw score is computed first; if it is below the configured
    /// threshold, `0.0` is returned. Otherwise the raw score is returned as
    /// is, or, with `normalize` enabled, divided by
    /// `sqrt(K(tree1, tree1) * K(tree2, tree2))` (`0.0` when that norm is not
    /// positive).
    ///
    /// NOTE(review): the threshold is compared with the raw score even when
    /// normalisation is on, although the setter restricts it to `0.0..=1.0`;
    /// confirm that this is intended.
    ///
    /// # Errors
    ///
    /// This implementation never returns an error; the `Result` is part of
    /// the existing signature.
    pub fn compute_trees(&self, tree1: &TreeNode, tree2: &TreeNode) -> Result<f64> {
        let similarity = self.compute_partial_match(tree1, tree2, 1.0);
        if similarity < self.config.threshold {
            return Ok(0.0);
        }
        if !self.config.normalize {
            return Ok(similarity);
        }
        let self_sim1 = self.compute_partial_match(tree1, tree1, 1.0);
        let self_sim2 = self.compute_partial_match(tree2, tree2, 1.0);
        let norm = (self_sim1 * self_sim2).sqrt();
        Ok(if norm > 0.0 { similarity / norm } else { 0.0 })
    }

    /// Scores the node pair `(n1, n2)` and, position by position, as many of
    /// their children as both nodes have.
    ///
    /// Equal labels earn `weight` and pass `weight * decay` down to the
    /// children. Unequal labels earn half of `weight` scaled by the label
    /// similarity and pass half of `weight * decay` down.
    fn compute_partial_match(&self, n1: &TreeNode, n2: &TreeNode, weight: f64) -> f64 {
        let (own_score, child_weight) = if n1.label == n2.label {
            (weight, weight * self.config.decay)
        } else {
            let label_sim = self.label_similarity(&n1.label, &n2.label);
            (weight * label_sim * 0.5, weight * self.config.decay * 0.5)
        };
        // `zip` stops at the shorter child list; surplus children are ignored.
        n1.children
            .iter()
            .zip(&n2.children)
            .fold(own_score, |score, (c1, c2)| {
                score + self.compute_partial_match(c1, c2, child_weight)
            })
    }

    /// Similarity of two labels in `0.0..=1.0`: `1.0` for equal labels,
    /// otherwise the number of distinct characters the labels share divided by
    /// the number of distinct characters occurring in either.
    fn label_similarity(&self, label1: &str, label2: &str) -> f64 {
        if label1 == label2 {
            return 1.0;
        }
        let chars1: std::collections::HashSet<char> = label1.chars().collect();
        let chars2: std::collections::HashSet<char> = label2.chars().collect();
        let shared = chars1.intersection(&chars2).count();
        let combined = chars1.union(&chars2).count();
        match combined {
            0 => 0.0,
            _ => shared as f64 / combined as f64,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A node built with `new` is a childless leaf.
    #[test]
    fn test_tree_node_creation() {
        let node = TreeNode::new("root");
        assert_eq!(node.label, "root");
        assert!(node.children.is_empty());
        assert!(node.is_leaf());
    }

    /// A node built with `with_children` keeps its children and is not a leaf.
    #[test]
    fn test_tree_node_with_children() {
        let child1 = TreeNode::new("child1");
        let child2 = TreeNode::new("child2");
        let parent = TreeNode::with_children("parent", vec![child1, child2]);
        assert_eq!(parent.label, "parent");
        assert_eq!(parent.children.len(), 2);
        assert!(!parent.is_leaf());
    }

    /// Height counts the nodes on the longest root-to-leaf path.
    #[test]
    fn test_tree_height() {
        let leaf = TreeNode::new("leaf");
        assert_eq!(leaf.height(), 1);
        let tree = TreeNode::with_children(
            "root",
            vec![
                TreeNode::new("child1"),
                TreeNode::with_children("child2", vec![TreeNode::new("grandchild")]),
            ],
        );
        assert_eq!(tree.height(), 3);
    }

    /// `num_nodes` counts every node, the root included.
    #[test]
    fn test_tree_num_nodes() {
        let tree = TreeNode::with_children(
            "root",
            vec![
                TreeNode::new("child1"),
                TreeNode::with_children("child2", vec![TreeNode::new("grandchild")]),
            ],
        );
        assert_eq!(tree.num_nodes(), 4);
    }

    /// A conjunction converts to an `And` node with two children.
    #[test]
    fn test_tree_from_tlexpr() {
        let expr = TLExpr::and(TLExpr::pred("p1", vec![]), TLExpr::pred("p2", vec![]));
        let tree = TreeNode::from_tlexpr(&expr);
        assert_eq!(tree.label, "And");
        assert_eq!(tree.children.len(), 2);
    }

    /// Identical trees with three distinct subtrees share exactly three.
    #[test]
    fn test_subtree_kernel_identical() {
        let tree1 = TreeNode::with_children(
            "root",
            vec![TreeNode::new("child1"), TreeNode::new("child2")],
        );
        let tree2 = tree1.clone();
        let config = SubtreeKernelConfig::new().with_normalize(false);
        let kernel = SubtreeKernel::new(config);
        let sim = kernel.compute_trees(&tree1, &tree2).expect("unwrap");
        assert!((sim - 3.0).abs() < 1e-12);
    }

    /// Trees whose only child differs share no subtree at all: the roots
    /// differ as whole subtrees because their children differ.
    #[test]
    fn test_subtree_kernel_different() {
        let tree1 = TreeNode::with_children("root", vec![TreeNode::new("child1")]);
        let tree2 = TreeNode::with_children("root", vec![TreeNode::new("child2")]);
        let config = SubtreeKernelConfig::new().with_normalize(false);
        let kernel = SubtreeKernel::new(config);
        let sim = kernel.compute_trees(&tree1, &tree2).expect("unwrap");
        assert_eq!(sim, 0.0);
    }

    /// Only the common `child1` leaf is shared between the two trees.
    #[test]
    fn test_subtree_kernel_partial_match() {
        let tree1 = TreeNode::with_children(
            "root",
            vec![TreeNode::new("child1"), TreeNode::new("child2")],
        );
        let tree2 = TreeNode::with_children(
            "root",
            vec![TreeNode::new("child1"), TreeNode::new("child3")],
        );
        let config = SubtreeKernelConfig::new().with_normalize(false);
        let kernel = SubtreeKernel::new(config);
        let sim = kernel.compute_trees(&tree1, &tree2).expect("unwrap");
        assert!((sim - 1.0).abs() < 1e-12);
    }

    /// A tree compared with an identical copy normalises to 1.
    #[test]
    fn test_subtree_kernel_normalized() {
        let tree1 = TreeNode::with_children(
            "root",
            vec![TreeNode::new("child1"), TreeNode::new("child2")],
        );
        let tree2 = tree1.clone();
        let config = SubtreeKernelConfig::new().with_normalize(true);
        let kernel = SubtreeKernel::new(config);
        let sim = kernel.compute_trees(&tree1, &tree2).expect("unwrap");
        assert!((sim - 1.0).abs() < 1e-6);
    }

    /// Overlapping but different trees score strictly between 0 and 1 when
    /// normalised.
    #[test]
    fn test_subset_tree_kernel() {
        let tree1 = TreeNode::with_children(
            "root",
            vec![TreeNode::new("child1"), TreeNode::new("child2")],
        );
        let tree2 = TreeNode::with_children("root", vec![TreeNode::new("child1")]);
        let config = SubsetTreeKernelConfig::new().expect("unwrap");
        let kernel = SubsetTreeKernel::new(config);
        let sim = kernel.compute_trees(&tree1, &tree2).expect("unwrap");
        assert!(sim > 0.0);
        assert!(sim < 1.0);
    }

    /// A smaller decay factor lowers the unnormalised score.
    #[test]
    fn test_subset_tree_kernel_decay() {
        let tree1 = TreeNode::with_children("root", vec![TreeNode::new("child")]);
        let tree2 = tree1.clone();
        let config1 = SubsetTreeKernelConfig::new()
            .expect("unwrap")
            .with_decay(1.0)
            .expect("unwrap")
            .with_normalize(false);
        let kernel1 = SubsetTreeKernel::new(config1);
        let config2 = SubsetTreeKernelConfig::new()
            .expect("unwrap")
            .with_decay(0.5)
            .expect("unwrap")
            .with_normalize(false);
        let kernel2 = SubsetTreeKernel::new(config2);
        let sim1 = kernel1.compute_trees(&tree1, &tree2).expect("unwrap");
        let sim2 = kernel2.compute_trees(&tree1, &tree2).expect("unwrap");
        assert!(sim2 < sim1);
    }

    /// Similar but unequal trees score strictly between 0 and 1 when
    /// normalised.
    #[test]
    fn test_partial_tree_kernel() {
        let tree1 = TreeNode::with_children(
            "root",
            vec![TreeNode::new("child1"), TreeNode::new("child2")],
        );
        let tree2 = TreeNode::with_children(
            "root",
            vec![TreeNode::new("child1"), TreeNode::new("child3")],
        );
        let config = PartialTreeKernelConfig::new().expect("unwrap");
        let kernel = PartialTreeKernel::new(config);
        let sim = kernel.compute_trees(&tree1, &tree2).expect("unwrap");
        assert!(sim > 0.0);
        assert!(sim < 1.0);
    }

    /// A raw score below the threshold is reported as exactly zero.
    #[test]
    fn test_partial_tree_kernel_threshold() {
        let tree1 = TreeNode::with_children("root1", vec![TreeNode::new("child")]);
        let tree2 = TreeNode::with_children("root2", vec![TreeNode::new("child")]);
        let config = PartialTreeKernelConfig::new()
            .expect("unwrap")
            .with_threshold(0.9)
            .expect("unwrap");
        let kernel = PartialTreeKernel::new(config);
        let sim = kernel.compute_trees(&tree1, &tree2).expect("unwrap");
        assert_eq!(sim, 0.0);
    }

    /// A decay factor above 1 is rejected.
    #[test]
    fn test_partial_tree_kernel_config_invalid_decay() {
        let result = PartialTreeKernelConfig::new()
            .expect("unwrap")
            .with_decay(1.5);
        assert!(result.is_err());
    }

    /// A negative threshold is rejected.
    #[test]
    fn test_partial_tree_kernel_config_invalid_threshold() {
        let result = PartialTreeKernelConfig::new()
            .expect("unwrap")
            .with_threshold(-0.1);
        assert!(result.is_err());
    }

    /// Trees converted from overlapping expressions have positive similarity.
    #[test]
    fn test_tree_kernel_with_tlexpr() {
        let expr1 = TLExpr::and(TLExpr::pred("p1", vec![]), TLExpr::pred("p2", vec![]));
        let expr2 = TLExpr::and(TLExpr::pred("p1", vec![]), TLExpr::pred("p3", vec![]));
        let tree1 = TreeNode::from_tlexpr(&expr1);
        let tree2 = TreeNode::from_tlexpr(&expr2);
        let config = SubtreeKernelConfig::new();
        let kernel = SubtreeKernel::new(config);
        let sim = kernel.compute_trees(&tree1, &tree2).expect("unwrap");
        assert!(sim > 0.0);
    }
}