use std::collections::{HashMap, HashSet};
use serde::{Deserialize, Serialize};
use similar::{ChangeTag, TextDiff};
use time::Date;
use crate::constants::STOP_WORDS;
use crate::uslm::{
BillAmendment, ElementData, TextContentField, USLMElement, bill_parser::AmendmentData,
};
/// The word-level difference found in one text field of an element between
/// two versions of that element.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct FieldChangeEvent {
/// Which text field of the element this event describes.
pub field_name: TextContentField,
/// `data.date` of the element the diff starts from.
pub from_date: Date,
/// `data.date` of the element the diff ends at.
pub to_date: Date,
/// Text of the field in the "from" element (the default value when the
/// element has no text for this field).
pub old_value: String,
/// Text of the field in the "to" element (the default value when the
/// element has no text for this field).
pub new_value: String,
/// Individual token changes that turn `old_value` into `new_value`.
pub changes: Vec<TextChange>,
}
/// A single token-level change produced by the text diff.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct TextChange {
/// The token text this change refers to.
pub value: String,
/// Index of the token on the old side as reported by the diff library, if any.
/// NOTE(review): per the `similar` docs this is absent for inserted tokens — confirm.
pub old_index: Option<i32>,
/// Index of the token on the new side as reported by the diff library, if any.
/// NOTE(review): per the `similar` docs this is absent for deleted tokens — confirm.
pub new_index: Option<i32>,
/// Whether the token was inserted, deleted or left unchanged.
pub tag: TextChangeType,
}
/// Kind of a [`TextChange`]; serialized in snake case (`insert`, `delete`, `equal`).
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TextChangeType {
/// The token exists only in the new text.
Insert,
/// The token exists only in the old text.
Delete,
/// The token is unchanged. `diff_field` in this file filters these out, so
/// they only show up in data that was built or deserialized elsewhere.
Equal,
}
/// Recursive diff between two versions of the same element subtree.
///
/// Both versions share one path (`root_path`); children are paired up by
/// their own paths.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct TreeDiff {
/// Path shared by `from_element` and `to_element`.
pub root_path: String,
/// Text-field changes on the element itself (not on its children).
pub changes: Vec<FieldChangeEvent>,
/// Data of the element in the "from" version.
pub from_element: ElementData,
/// Data of the element in the "to" version.
pub to_element: ElementData,
/// Data of direct children whose path exists only in the "to" version.
pub added: Vec<ElementData>,
/// Data of direct children whose path exists only in the "from" version.
pub removed: Vec<ElementData>,
/// Diffs of direct children present in both versions; children with
/// nothing to report are omitted.
pub child_diffs: Vec<TreeDiff>,
}
impl TreeDiff {
    /// Builds the recursive diff between two versions of the same element.
    ///
    /// Children are paired by `data.path`. A child present on both sides is
    /// diffed recursively and kept when it carries any difference: text
    /// changes, added or removed children, or nested child diffs. `added`,
    /// `removed` and `child_diffs` follow the document order of the
    /// corresponding `children` list, so the result is deterministic.
    ///
    /// If several children of one element share a path, only the last one is
    /// considered.
    ///
    /// # Panics
    ///
    /// Panics if the two elements do not share the same path (and, through
    /// `diff_elements`, if they do not share the same element type).
    pub fn from_elements(from_element: &USLMElement, to_element: &USLMElement) -> TreeDiff {
        assert_eq!(
            from_element.data.path, to_element.data.path,
            "TreeDiff::from_elements requires both elements to share a path"
        );
        let root_path = from_element.data.path.clone();
        let changes = diff_elements(from_element, to_element);

        // Path -> child lookup tables. They are used only for membership and
        // pairing; iteration below walks the `children` lists so the output
        // order does not depend on hash-map ordering.
        let children_a: HashMap<&str, &USLMElement> = from_element
            .children
            .iter()
            .map(|child| (child.data.path.as_str(), child))
            .collect();
        let children_b: HashMap<&str, &USLMElement> = to_element
            .children
            .iter()
            .map(|child| (child.data.path.as_str(), child))
            .collect();

        let mut added = Vec::new();
        let mut removed = Vec::new();
        let mut child_diffs = Vec::new();

        for child_a in from_element.children.iter() {
            let path = child_a.data.path.as_str();
            // Skip children shadowed by a later sibling with the same path.
            if !std::ptr::eq(children_a[path], child_a) {
                continue;
            }
            match children_b.get(path) {
                Some(child_b) => {
                    let child_diff = TreeDiff::from_elements(child_a, child_b);
                    if child_diff.has_differences() {
                        child_diffs.push(child_diff);
                    }
                }
                None => removed.push(child_a.data.clone()),
            }
        }

        for child_b in to_element.children.iter() {
            let path = child_b.data.path.as_str();
            if children_a.contains_key(path) || !std::ptr::eq(children_b[path], child_b) {
                continue;
            }
            added.push(child_b.data.clone());
        }

        TreeDiff {
            changes,
            root_path,
            from_element: from_element.data.clone(),
            to_element: to_element.data.clone(),
            added,
            removed,
            child_diffs,
        }
    }

    /// Returns `true` when this node records any difference at all.
    fn has_differences(&self) -> bool {
        !(self.changes.is_empty()
            && self.added.is_empty()
            && self.removed.is_empty()
            && self.child_diffs.is_empty())
    }

    /// Looks up the diff node whose `root_path` equals `path`.
    ///
    /// Returns `None` when `path` is not inside this subtree or when no diff
    /// was recorded for it. Never panics: a path that merely shares a textual
    /// prefix with a node (`/a/s1` vs `/a/s10`) is treated as "not found".
    pub fn find(&self, path: &str) -> Option<&TreeDiff> {
        if path == self.root_path.as_str() {
            return Some(self);
        }
        if !path.starts_with(self.root_path.as_str()) {
            return None;
        }
        // Descend into the child whose path is `path` itself or an ancestor of
        // it on a `/` segment boundary. Matching on the boundary (instead of a
        // suffix test) keeps `.../s1` from being confused with `.../xs1`.
        self.child_diffs
            .iter()
            .find(|child| {
                path.strip_prefix(child.root_path.as_str())
                    .map_or(false, |rest| rest.is_empty() || rest.starts_with('/'))
            })
            .and_then(|child| child.find(path))
    }

    /// For every node of this diff that has text changes, finds the amendment
    /// in `data` whose word changes overlap best with the node's changed words.
    ///
    /// The result maps a node's `root_path` to its best-scoring
    /// [`AmendmentSimilarity`]; nodes without a positive score are absent.
    pub fn calculate_amendment_similarities(
        &self,
        data: &AmendmentData,
    ) -> HashMap<String, AmendmentSimilarity> {
        let mut result = HashMap::new();
        self.calculate_similarities_recursive(&mut result, data);
        result
    }

    /// Depth-first worker for [`TreeDiff::calculate_amendment_similarities`].
    fn calculate_similarities_recursive(
        &self,
        result: &mut HashMap<String, AmendmentSimilarity>,
        data: &AmendmentData,
    ) {
        if !self.changes.is_empty() {
            for (amendment_id, amendment) in &data.amendments {
                if amendment.changes.is_empty() {
                    continue;
                }
                let similarity = self.calculate_match_with_amendment(amendment_id, amendment);
                if similarity.score > 0.0 {
                    // Keep the strictly better score; on ties the first one seen wins.
                    match result.get_mut(self.root_path.as_str()) {
                        Some(best) => {
                            if similarity.score > best.score {
                                *best = similarity;
                            }
                        }
                        None => {
                            result.insert(self.root_path.clone(), similarity);
                        }
                    }
                }
            }
        }
        for child_diff in &self.child_diffs {
            child_diff.calculate_similarities_recursive(result, data);
        }
    }

    /// Scores one amendment against this node's changed words.
    ///
    /// Each change of the amendment is scored separately and the best one is
    /// reported. With `matched` being the number of words common to both sets:
    /// `precision = matched / |node words|`, `recall = matched / |amendment
    /// change words|`, and `score` is their harmonic mean. All values are zero
    /// when nothing matches.
    fn calculate_match_with_amendment(
        &self,
        amendment_id: &str,
        amendment: &BillAmendment,
    ) -> AmendmentSimilarity {
        let tree_diff_words = self.collect_tree_diff_words();
        let tree_diff_count = tree_diff_words.len();

        let mut best = AmendmentSimilarity {
            tree_diff_path: self.root_path.clone(),
            amendment_id: amendment_id.to_string(),
            score: 0.0,
            precision: 0.0,
            recall: 0.0,
            matched_words: 0,
            tree_diff_words: Self::clamp_to_i32(tree_diff_count),
        };

        for bill_diff in &amendment.changes {
            let bill_diff_words: HashSet<String> = bill_diff
                .removed
                .iter()
                .chain(bill_diff.added.iter())
                .filter_map(|word| Self::normalize_word(word))
                .collect();
            if bill_diff_words.is_empty() {
                continue;
            }

            let matched = tree_diff_words.intersection(&bill_diff_words).count();
            let precision = Self::ratio(matched, tree_diff_count);
            let recall = Self::ratio(matched, bill_diff_words.len());
            let score = if precision + recall > 0.0 {
                2.0 * precision * recall / (precision + recall)
            } else {
                0.0
            };

            if score > best.score {
                best.score = score;
                best.precision = precision;
                best.recall = recall;
                best.matched_words = Self::clamp_to_i32(matched);
            }
        }
        best
    }

    /// Collects the distinct, lower-cased, non-stop-word tokens that were
    /// inserted or deleted on this node (children are not included).
    fn collect_tree_diff_words(&self) -> HashSet<String> {
        self.changes
            .iter()
            .flat_map(|field_change| field_change.changes.iter())
            .filter(|text_change| match text_change.tag {
                TextChangeType::Delete | TextChangeType::Insert => true,
                TextChangeType::Equal => false,
            })
            .filter_map(|text_change| Self::normalize_word(&text_change.value))
            .collect()
    }

    /// Trims and lower-cases `raw`; returns `None` for blank tokens and stop words.
    fn normalize_word(raw: &str) -> Option<String> {
        let word = raw.trim();
        if word.is_empty() || is_stop_word(word) {
            None
        } else {
            Some(word.to_lowercase())
        }
    }

    /// `part / whole` as `f32`, or `0.0` when `whole` is zero.
    fn ratio(part: usize, whole: usize) -> f32 {
        if whole > 0 {
            part as f32 / whole as f32
        } else {
            0.0
        }
    }

    /// Converts a count to `i32`, saturating at `i32::MAX` instead of wrapping.
    fn clamp_to_i32(count: usize) -> i32 {
        count.min(i32::MAX as usize) as i32
    }
}
/// How well the changed words of one diff node match the word changes of one
/// amendment.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct AmendmentSimilarity {
/// `root_path` of the diff node that was scored.
pub tree_diff_path: String,
/// Identifier of the amendment that was scored.
pub amendment_id: String,
/// Harmonic mean of `precision` and `recall` (0.0 when both are zero).
pub score: f32,
/// `matched_words` divided by the number of changed words on the diff node.
pub precision: f32,
/// `matched_words` divided by the number of words in the amendment change.
pub recall: f32,
/// Number of words present in both word sets.
pub matched_words: i32,
/// Number of distinct changed words on the diff node.
pub tree_diff_words: i32,
}
/// Returns `true` when the lower-cased form of `word` is listed in `STOP_WORDS`.
fn is_stop_word(word: &str) -> bool {
    let normalized = word.to_lowercase();
    let candidate: &str = normalized.as_str();
    STOP_WORDS.contains(&candidate)
}
/// Diffs every text field of two versions of the same element.
///
/// The fields are visited in the order heading, chapeau, proviso, content,
/// continuation; only fields with at least one change produce an event.
///
/// # Panics
///
/// Panics if the elements differ in path or in element type.
pub fn diff_elements(element_a: &USLMElement, element_b: &USLMElement) -> Vec<FieldChangeEvent> {
    assert!(element_a.data.path == element_b.data.path);
    assert!(element_a.data.element_type == element_b.data.element_type);

    let text_fields = [
        TextContentField::Heading,
        TextContentField::Chapeau,
        TextContentField::Proviso,
        TextContentField::Content,
        TextContentField::Continuation,
    ];

    text_fields
        .into_iter()
        .map(|field| diff_field(element_a, element_b, field))
        .filter(|event| !event.changes.is_empty())
        .collect()
}
/// Converts an optional `usize` index into an optional `i32`.
///
/// Values above `i32::MAX` saturate to `i32::MAX`; a plain `as` cast would
/// wrap around and produce a negative (or otherwise bogus) index.
fn rewrap_usize(s: Option<usize>) -> Option<i32> {
    s.map(|val| val.min(i32::MAX as usize) as i32)
}
/// Word-diffs one text field between two versions of an element.
///
/// A missing field is treated as its default (empty) text. Only inserted and
/// deleted tokens are recorded; unchanged tokens are skipped. The returned
/// event is produced even when there are no changes, and carries both full
/// texts and the `data.date` of each element.
fn diff_field(
    element_a: &USLMElement,
    element_b: &USLMElement,
    field_name: TextContentField,
) -> FieldChangeEvent {
    let old_value = element_a
        .data
        .get_text_content(field_name)
        .unwrap_or_default();
    let new_value = element_b
        .data
        .get_text_content(field_name)
        .unwrap_or_default();

    let word_diff = TextDiff::from_words(old_value.as_str(), new_value.as_str());
    let mut changes: Vec<TextChange> = Vec::new();
    for change in word_diff.iter_all_changes() {
        let tag = match change.tag() {
            // Unchanged tokens carry no information for the event.
            ChangeTag::Equal => continue,
            ChangeTag::Delete => TextChangeType::Delete,
            ChangeTag::Insert => TextChangeType::Insert,
        };
        changes.push(TextChange {
            value: change.value().to_string(),
            old_index: rewrap_usize(change.old_index()),
            new_index: rewrap_usize(change.new_index()),
            tag,
        });
    }

    FieldChangeEvent {
        field_name,
        from_date: element_a.data.date,
        to_date: element_b.data.date,
        old_value,
        new_value,
        changes,
    }
}