use std::collections::HashMap;
use crate::types::{CstNode, ListOrdering, MatchPair, NodeId};
pub fn match_children(
left_parent: &CstNode,
right_parent: &CstNode,
ordering: ListOrdering,
) -> Vec<MatchPair> {
let left_children = left_parent.children();
let right_children = right_parent.children();
match ordering {
ListOrdering::Ordered => yang_match(left_children, right_children),
ListOrdering::Unordered => bipartite_match(left_children, right_children),
}
}
/// Matches two trees top-down and returns every node pair discovered.
///
/// The pairs are accumulated by `match_trees_recursive`, starting at the two
/// roots.
pub fn match_trees(left: &CstNode, right: &CstNode) -> Vec<MatchPair> {
    let mut collected: Vec<MatchPair> = Vec::new();
    match_trees_recursive(left, right, &mut collected);
    collected
}
/// Recursive worker for `match_trees`.
///
/// It first considers `left`/`right` as a pair. It then descends into the
/// children chosen by `match_children` and appends every discovered pair to
/// `pairs`.
///
/// - Nodes whose `kind()` differs are ignored entirely, including their subtrees.
/// - Two leaves are paired with score 1 only when their `leaf_value()`s are equal.
/// - A leaf never pairs with an interior node.
/// - Two interior nodes are recorded only when their `tree_similarity` is
///   positive, but their children are explored either way.
fn match_trees_recursive(left: &CstNode, right: &CstNode, pairs: &mut Vec<MatchPair>) {
    if left.kind() != right.kind() {
        return;
    }

    match (left.is_leaf(), right.is_leaf()) {
        (true, true) => {
            if left.leaf_value() == right.leaf_value() {
                pairs.push(MatchPair {
                    left: left.id(),
                    right: right.id(),
                    score: 1,
                });
            }
            return;
        }
        (false, false) => {}
        // Exactly one side is a leaf: nothing to pair or descend into.
        _ => return,
    }

    let score = tree_similarity(left, right);
    if score > 0 {
        pairs.push(MatchPair {
            left: left.id(),
            right: right.id(),
            score,
        });
    }

    // Only list nodes carry an ordering (taken from the left side); every other
    // interior node is matched as an ordered sequence.
    let order = if let (CstNode::List { ordering: list_order, .. }, CstNode::List { .. }) =
        (left, right)
    {
        *list_order
    } else {
        ListOrdering::Ordered
    };
    let child_pairs = match_children(left, right, order);

    // Resolve the ids in each child pair back to the actual child nodes.
    let left_kids = left.children();
    let right_kids = right.children();
    let by_id_left: HashMap<NodeId, &CstNode> =
        left_kids.iter().map(|kid| (kid.id(), kid)).collect();
    let by_id_right: HashMap<NodeId, &CstNode> =
        right_kids.iter().map(|kid| (kid.id(), kid)).collect();

    for MatchPair { left: lid, right: rid, .. } in child_pairs {
        if let (Some(lkid), Some(rkid)) = (by_id_left.get(&lid), by_id_right.get(&rid)) {
            match_trees_recursive(lkid, rkid, pairs);
        }
    }
}
/// Order-preserving child matching in the style of Yang's algorithm.
///
/// An LCS-style dynamic program over the two child lists maximises the summed
/// `tree_similarity` of the chosen pairs while keeping their relative order.
///
/// Only pairs with a strictly positive similarity are emitted, which matches
/// the `weights > 0` filter in `bipartite_match`. Nodes that are merely
/// compatible (same kind and leaf-ness) but share no content stay unmatched.
/// The returned pairs are in increasing position order on both sides.
///
/// Returns an empty vector when either list is empty.
fn yang_match(left: &[CstNode], right: &[CstNode]) -> Vec<MatchPair> {
    /// Back-pointer stored for each DP cell during the fill.
    #[derive(Clone, Copy)]
    enum Step {
        /// Row/column 0 boundary; the traceback stops before reaching it.
        Boundary,
        /// `left[i - 1]` is paired with `right[j - 1]`.
        Match,
        /// `left[i - 1]` is left unmatched.
        SkipLeft,
        /// `right[j - 1]` is left unmatched.
        SkipRight,
    }

    let n = left.len();
    let m = right.len();
    if n == 0 || m == 0 {
        return Vec::new();
    }

    // Compute each pairwise similarity exactly once; the traceback reuses it.
    // `tree_similarity` already yields 0 for pairs that fail `can_match`.
    let sim: Vec<Vec<usize>> = left
        .iter()
        .map(|l| right.iter().map(|r| tree_similarity(l, r)).collect())
        .collect();

    let mut dp = vec![vec![0usize; m + 1]; n + 1];
    let mut step = vec![vec![Step::Boundary; m + 1]; n + 1];
    for i in 1..=n {
        for j in 1..=m {
            let s = sim[i - 1][j - 1];
            let diag = dp[i - 1][j - 1] + s;
            let skip_left = dp[i - 1][j];
            let skip_right = dp[i][j - 1];
            // Pair only nodes that actually share content (s > 0). Ties between
            // matching and skipping are resolved in favour of matching.
            if s > 0 && diag >= skip_left && diag >= skip_right {
                dp[i][j] = diag;
                step[i][j] = Step::Match;
            } else if skip_left >= skip_right {
                dp[i][j] = skip_left;
                step[i][j] = Step::SkipLeft;
            } else {
                dp[i][j] = skip_right;
                step[i][j] = Step::SkipRight;
            }
        }
    }

    // Walk back from the bottom-right corner, collecting matched pairs.
    let mut pairs = Vec::new();
    let (mut i, mut j) = (n, m);
    while i > 0 && j > 0 {
        match step[i][j] {
            Step::Match => {
                pairs.push(MatchPair {
                    left: left[i - 1].id(),
                    right: right[j - 1].id(),
                    score: sim[i - 1][j - 1],
                });
                i -= 1;
                j -= 1;
            }
            Step::SkipLeft => i -= 1,
            Step::SkipRight => j -= 1,
            Step::Boundary => break,
        }
    }
    pairs.reverse();
    pairs
}
/// Unordered child matching via maximum-weight bipartite assignment.
///
/// The similarity matrix is padded to a square `size x size` matrix. The
/// phantom rows and columns have weight 0. The matrix is then solved with
/// `hungarian_max`. Only assignments between real children with a strictly
/// positive weight are reported.
///
/// Returns an empty vector when either list is empty.
fn bipartite_match(left: &[CstNode], right: &[CstNode]) -> Vec<MatchPair> {
    if left.is_empty() || right.is_empty() {
        return Vec::new();
    }
    let (n, m) = (left.len(), right.len());
    let size = n.max(m);

    let mut weights = vec![vec![0i64; size]; size];
    for (row, l) in weights.iter_mut().zip(left) {
        for (cell, r) in row.iter_mut().zip(right) {
            if can_match(l, r) {
                *cell = tree_similarity(l, r) as i64;
            }
        }
    }

    hungarian_max(&weights, size)
        .into_iter()
        .enumerate()
        // Drop padding assignments and zero-weight (non-matching) pairings.
        .filter(|&(i, j)| i < n && j < m && weights[i][j] > 0)
        .map(|(i, j)| MatchPair {
            left: left[i].id(),
            right: right[j].id(),
            score: weights[i][j] as usize,
        })
        .collect()
}
/// Returns `true` when two nodes are candidates for pairing.
///
/// That requires both to be leaves or both to be interior nodes, and both to
/// have the same `kind()`. `kind()` is only consulted when the leaf-ness agrees.
fn can_match(left: &CstNode, right: &CstNode) -> bool {
    left.is_leaf() == right.is_leaf() && left.kind() == right.kind()
}
/// Content similarity between two nodes.
///
/// - Returns 0 when the nodes fail `can_match`.
/// - For two `Leaf` nodes, returns 1 if their values are equal and 0 otherwise.
/// - For any other pair, returns the length of the longest common subsequence
///   of the sequences produced by `collect_leaves()` on each side.
pub fn tree_similarity(left: &CstNode, right: &CstNode) -> usize {
    if can_match(left, right) {
        if let (CstNode::Leaf { value: a, .. }, CstNode::Leaf { value: b, .. }) = (left, right) {
            usize::from(a == b)
        } else {
            lcs_length(&left.collect_leaves(), &right.collect_leaves())
        }
    } else {
        0
    }
}
/// Length of the longest common subsequence of `a` and `b`.
///
/// Uses the standard O(n·m)-time dynamic program. Only the previous DP row is
/// ever read, so two rolling rows sized to the *shorter* input are kept. That
/// makes memory O(min(n, m)) instead of a full (n+1)×(m+1) table. LCS length is
/// symmetric, so swapping the roles of `a` and `b` does not change the result.
///
/// Returns 0 when either input is empty.
fn lcs_length<T: PartialEq>(a: &[T], b: &[T]) -> usize {
    let (long, short) = if a.len() >= b.len() { (a, b) } else { (b, a) };
    let mut prev = vec![0usize; short.len() + 1];
    let mut curr = vec![0usize; short.len() + 1];
    for x in long {
        for (j, y) in short.iter().enumerate() {
            curr[j + 1] = if x == y {
                prev[j] + 1
            } else {
                prev[j + 1].max(curr[j])
            };
        }
        // The row just filled becomes the "previous" row for the next element.
        // Index 0 is never written and stays 0 in both rows.
        std::mem::swap(&mut prev, &mut curr);
    }
    prev[short.len()]
}
/// Solves an `n x n` assignment problem, **maximising** the total weight.
///
/// The weights are converted to non-negative costs `max_w - weight`, where
/// `max_w` is the largest entry found anywhere in `weights`. The resulting
/// minimisation problem is then solved with the potential-based Hungarian
/// algorithm. Each of the `n` rows is inserted in turn by growing an
/// alternating tree over columns until a free column is reached. Each growth
/// step scans all columns, so the total work is O(n^3).
///
/// # Arguments
/// * `weights` - weight matrix. Entries `weights[i][j]` for `i, j < n` are
///   indexed, so it must have at least `n` rows of at least `n` columns each,
///   otherwise indexing panics.
/// * `n` - dimension of the square problem.
///
/// # Returns
/// A vector of length `n` where `result[row]` is the 0-based column assigned to
/// `row`. Returns an empty vector when `n == 0`.
fn hungarian_max(weights: &[Vec<i64>], n: usize) -> Vec<usize> {
    if n == 0 {
        return Vec::new();
    }
    // Largest weight in the matrix; subtracting each weight from it turns the
    // maximisation into a minimisation with non-negative costs.
    let max_w = weights
        .iter()
        .flat_map(|row| row.iter())
        .copied()
        .max()
        .unwrap_or(0);
    let mut cost = vec![vec![0i64; n]; n];
    for i in 0..n {
        for j in 0..n {
            cost[i][j] = max_w - weights[i][j];
        }
    }
    // All arrays below are 1-based: index 0 is a sentinel column/row.
    // u[i] / v[j] are the row / column potentials.
    let mut u = vec![0i64; n + 1];
    let mut v = vec![0i64; n + 1];
    // p[j]: 1-based row currently assigned to column j (0 = unassigned);
    // way[j]: previous column on the alternating path that reached column j.
    let mut p = vec![0usize; n + 1]; let mut way = vec![0usize; n + 1];
    for i in 1..=n {
        // Start a new search for row i from the sentinel column 0.
        p[0] = i;
        let mut j0 = 0usize;
        // minv[j]: smallest reduced cost found so far to reach column j.
        let mut minv = vec![i64::MAX; n + 1];
        // used[j]: column j is already part of the alternating tree.
        let mut used = vec![false; n + 1];
        loop {
            used[j0] = true;
            let i0 = p[j0];
            let mut delta = i64::MAX;
            let mut j1 = 0usize;
            // Relax reduced costs from row i0 to every column not yet in the tree,
            // and pick the column j1 with the smallest one (its cost is delta).
            for j in 1..=n {
                if !used[j] {
                    let cur = cost[i0 - 1][j - 1] - u[i0] - v[j];
                    if cur < minv[j] {
                        minv[j] = cur;
                        way[j] = j0;
                    }
                    if minv[j] < delta {
                        delta = minv[j];
                        j1 = j;
                    }
                }
            }
            // Shift the potentials by delta: tree rows/columns absorb it, and the
            // pending minv values of columns outside the tree shrink by it.
            for j in 0..=n {
                if used[j] {
                    u[p[j]] += delta;
                    v[j] -= delta;
                } else {
                    minv[j] -= delta;
                }
            }
            j0 = j1;
            // Reached an unassigned column: an augmenting path has been found.
            if p[j0] == 0 {
                break;
            }
        }
        // Augment: walk `way` back to the sentinel, shifting each row assignment
        // one column along the path.
        loop {
            let j1 = way[j0];
            p[j0] = p[j1];
            j0 = j1;
            if j0 == 0 {
                break;
            }
        }
    }
    // Convert the column->row map (1-based) into a row->column map (0-based).
    // The `p[j] > 0` guard is defensive for columns that would be unassigned.
    let mut result = vec![0usize; n];
    for j in 1..=n {
        if p[j] > 0 {
            result[p[j] - 1] = j - 1;
        }
    }
    result
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `identifier` leaf node with the given id and text.
    fn leaf(id: usize, val: &str) -> CstNode {
        CstNode::Leaf {
            id,
            kind: "identifier".into(),
            value: val.into(),
        }
    }

    /// Builds consecutive `identifier` leaves, with ids starting at `first_id`.
    fn leaves(first_id: usize, vals: &[&str]) -> Vec<CstNode> {
        vals.iter()
            .enumerate()
            .map(|(offset, v)| leaf(first_id + offset, v))
            .collect()
    }

    #[test]
    fn test_yang_match_identical() {
        let pairs = yang_match(&leaves(1, &["a", "b", "c"]), &leaves(4, &["a", "b", "c"]));
        assert_eq!(pairs.len(), 3);
    }

    #[test]
    fn test_yang_match_partial() {
        let pairs = yang_match(&leaves(1, &["a", "b", "c"]), &leaves(4, &["a", "c"]));
        assert_eq!(pairs.len(), 2);
    }

    #[test]
    fn test_bipartite_match() {
        // Same contents in swapped order: an unordered match still pairs both.
        let pairs = bipartite_match(&leaves(1, &["a", "b"]), &leaves(3, &["b", "a"]));
        assert_eq!(pairs.len(), 2);
    }

    #[test]
    fn test_hungarian_simple() {
        let assignment = hungarian_max(&[vec![3, 1], vec![1, 3]], 2);
        assert_eq!(assignment, vec![0, 1]);
    }

    #[test]
    fn test_tree_similarity() {
        let hello = leaf(1, "hello");
        let hello_again = leaf(2, "hello");
        let world = leaf(3, "world");
        assert_eq!(tree_similarity(&hello, &hello_again), 1);
        assert_eq!(tree_similarity(&hello, &world), 0);
    }
}