use std::collections::{HashMap, HashSet};
use std::str::FromStr;
use rayon::prelude::*;
use regex::Regex;
use serde::{Deserialize, Serialize};
use similar::{ChangeTag, TextDiff};
use time::Date;
use crate::constants::STOP_WORDS;
use crate::uslm::{BillAmendment, ElementData, TextContentField, USLMElement, bill_parser::Bill};
/// Word-level difference of one text field between two versions of the same
/// element, as produced by `diff_field`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct FieldChangeEvent {
/// Which text field of the element was compared.
pub field_name: TextContentField,
/// Date carried by the "from" element.
pub from_date: Date,
/// Date carried by the "to" element.
pub to_date: Date,
/// Full text of the field in the "from" element (empty when the field is absent).
pub old_value: String,
/// Full text of the field in the "to" element (empty when the field is absent).
pub new_value: String,
/// Every token reported by the word diff, including unchanged (`Equal`) ones.
pub changes: Vec<TextChange>,
}
/// A single token emitted by the word diff together with its classification.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct TextChange {
/// The token text exactly as reported by the diff (not trimmed).
pub value: String,
/// Index the diff reports for the old text, converted to `i32`; `None` when
/// the diff reports no old index for this token.
pub old_index: Option<i32>,
/// Index the diff reports for the new text, converted to `i32`; `None` when
/// the diff reports no new index for this token.
pub new_index: Option<i32>,
/// Whether the token was inserted, deleted or left unchanged.
pub tag: TextChangeType,
}
/// Serializable counterpart of `similar::ChangeTag`; `diff_field` maps each
/// `ChangeTag` variant to the variant of the same name here.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TextChangeType {
/// Mapped from `ChangeTag::Insert`.
Insert,
/// Mapped from `ChangeTag::Delete`.
Delete,
/// Mapped from `ChangeTag::Equal`.
Equal,
}
/// Recursive difference between two versions of an element tree that share
/// the same path. Built by `TreeDiff::from_elements`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct TreeDiff {
/// Path shared by both compared elements.
pub root_path: String,
/// Text-field changes of the element itself (children are not included).
pub changes: Vec<FieldChangeEvent>,
/// Data of the "from" element.
pub from_element: ElementData,
/// Data of the "to" element.
pub to_element: ElementData,
/// Data of direct children whose path exists only in the "to" element.
pub added: Vec<ElementData>,
/// Data of direct children whose path exists only in the "from" element.
pub removed: Vec<ElementData>,
/// Diffs of direct children present in both elements; children for which
/// no difference was recorded are omitted.
pub child_diffs: Vec<TreeDiff>,
}
impl TreeDiff {
pub fn mention_regex(&self) -> Option<Regex> {
if self.root_path.contains("section") {
let mut mreg = String::from(self.section_regex().unwrap().as_str());
mreg.truncate(mreg.len() - 2);
let split: Vec<_> = self.root_path.split("/").collect();
let mut started = false;
for part in split {
let Some((part_name, part_num)) = part.split_once("_") else {
continue;
};
if started {
mreg += r"\(";
mreg += part_num;
mreg += r"\)\s*"
}
if part_name == "section" {
started = true;
}
}
Some(Regex::from_str(mreg.as_str()).unwrap())
} else {
None
}
}
pub fn section_regex(&self) -> Option<Regex> {
if self.root_path.contains("section") {
let mut regex = String::from(r"[Ss]ection\s*");
let split: Vec<_> = self.root_path.split("/").collect();
for part in split {
let Some((part_name, part_num)) = part.split_once("_") else {
continue;
};
if part_name == "section" {
regex += part_num;
regex += r"\D";
return Some(Regex::from_str(regex.as_str()).unwrap());
}
}
}
None
}
pub fn all_regexes(&self) -> Vec<Regex> {
let mut res = Vec::new();
if let Some(sreg) = self.section_regex() {
res.push(sreg.clone());
if let Some(mreg) = self.mention_regex()
&& mreg.as_str() != sreg.as_str()
{
res.push(mreg);
}
}
res
}
/// Recursively compares two versions of the same element.
///
/// Direct children are paired by path. Children found on both sides are
/// diffed recursively and kept when the resulting diff records any
/// difference (field changes, nested diffs, or added/removed children).
/// Children found on one side only are reported in `removed` / `added`.
///
/// `removed` and `child_diffs` follow the order of `from_element.children`;
/// `added` follows the order of `to_element.children`. When several siblings
/// share a path, only the last one is considered.
///
/// # Panics
///
/// Panics if the two elements do not have the same path. `diff_elements`
/// additionally asserts that paired elements have the same element type.
pub fn from_elements(from_element: &USLMElement, to_element: &USLMElement) -> TreeDiff {
    assert!(from_element.data.path == to_element.data.path);
    let root_path = from_element.data.path.clone();
    let changes = diff_elements(from_element, to_element);

    // Lookup tables only; iteration below goes over the child lists themselves
    // so the output order does not depend on hash-map ordering.
    let children_a: HashMap<&str, &USLMElement> = from_element
        .children
        .iter()
        .map(|child| (child.data.path.as_str(), child))
        .collect();
    let children_b: HashMap<&str, &USLMElement> = to_element
        .children
        .iter()
        .map(|child| (child.data.path.as_str(), child))
        .collect();

    let mut removed = Vec::new();
    let mut child_diffs = Vec::new();
    for child_a in from_element.children.iter() {
        let path = child_a.data.path.as_str();
        // Skip siblings shadowed by a later sibling with the same path.
        let is_representative = children_a
            .get(path)
            .is_some_and(|kept| std::ptr::eq(*kept, child_a));
        if !is_representative {
            continue;
        }
        match children_b.get(path) {
            Some(child_b) => {
                let child_diff = TreeDiff::from_elements(child_a, child_b);
                // Structural changes count as differences too; otherwise
                // additions/removals below an unchanged child would be lost.
                let has_differences = !child_diff.changes.is_empty()
                    || !child_diff.child_diffs.is_empty()
                    || !child_diff.added.is_empty()
                    || !child_diff.removed.is_empty();
                if has_differences {
                    child_diffs.push(child_diff);
                }
            }
            None => removed.push(child_a.data.clone()),
        }
    }

    let added = to_element
        .children
        .iter()
        .filter(|child_b| {
            let path = child_b.data.path.as_str();
            !children_a.contains_key(path)
                && children_b
                    .get(path)
                    .is_some_and(|kept| std::ptr::eq(*kept, *child_b))
        })
        .map(|child_b| child_b.data.clone())
        .collect();

    TreeDiff {
        changes,
        root_path,
        from_element: from_element.data.clone(),
        to_element: to_element.data.clone(),
        added,
        removed,
        child_diffs,
    }
}
/// Looks up the diff node whose `root_path` equals `path`, searching this
/// node and its descendants.
///
/// Returns `None` when `path` is not below this node or when no diff was
/// recorded for it. Never panics: paths that merely share a string prefix
/// with a node (e.g. `.../section_1` vs `.../section_12`) are not matches.
pub fn find(&self, path: &str) -> Option<&TreeDiff> {
    if path == self.root_path.as_str() {
        return Some(self);
    }
    // Anything that does not start with our own path cannot be below us.
    path.strip_prefix(self.root_path.as_str())?;

    self.child_diffs
        .iter()
        .filter(|child| {
            // Descend only into a child whose path is a prefix of `path`
            // ending on a segment boundary; a plain suffix comparison would
            // confuse e.g. `subsection_1` with `section_1`.
            path.strip_prefix(child.root_path.as_str())
                .is_some_and(|rest| rest.is_empty() || rest.starts_with('/'))
        })
        .find_map(|child| child.find(path))
}
/// Scores every node of this diff tree that has field changes against the
/// amendments of `data` and returns, keyed by node path, the best-scoring
/// amendment match found for that node. Nodes without a positive score are
/// absent from the map.
pub fn calculate_amendment_similarities(
    &self,
    data: &Bill,
) -> HashMap<String, AmendmentSimilarity> {
    let mut best_by_path: HashMap<String, AmendmentSimilarity> = HashMap::new();
    self.calculate_similarities_recursive(&mut best_by_path, data);
    best_by_path
}
fn calculate_similarities_recursive(
&self,
result: &mut HashMap<String, AmendmentSimilarity>,
data: &Bill,
) {
if !self.changes.is_empty() {
for (amendment_id, amendment) in &data.amendments {
if amendment.changes.is_empty() {
continue;
}
let similarity = self.calculate_match_with_amendment(amendment_id, amendment);
if similarity.score > 0.0 {
let entry = result
.entry(self.root_path.clone())
.or_insert(similarity.clone());
if similarity.score > entry.score {
*entry = similarity;
}
}
}
}
for child_diff in &self.child_diffs {
child_diff.calculate_similarities_recursive(result, data);
}
}
/// Scores how well one amendment explains the changed words of this node.
///
/// The node's changed words (see `collect_tree_diff_words`) are compared with
/// the lower-cased, non-stop-word `removed`/`added` words of each change in
/// the amendment. For every change:
///
/// * precision = shared words / distinct changed words of the node,
/// * recall = shared words / distinct words of the amendment change,
/// * score = harmonic mean of the two (0 when both are 0).
///
/// The figures of the highest-scoring change are returned; the first change
/// wins ties. All figures are 0 when nothing matches.
fn calculate_match_with_amendment(
    &self,
    amendment_id: &str,
    amendment: &BillAmendment,
) -> AmendmentSimilarity {
    let node_vocab: HashSet<String> = self.collect_tree_diff_words();
    let node_vocab_len = node_vocab.len();

    // (score, precision, recall, matched word count) of the best change so far.
    let mut best = (0.0_f32, 0.0_f32, 0.0_f32, 0_i32);

    for bill_diff in &amendment.changes {
        let mut bill_vocab: HashSet<String> = HashSet::new();
        {
            // Expects an already trimmed token.
            let mut keep = |token: &str| {
                if !(token.is_empty() || is_stop_word(token)) {
                    bill_vocab.insert(token.to_lowercase());
                }
            };
            for word in &bill_diff.removed {
                keep(word.trim());
            }
            for word in &bill_diff.added {
                keep(word.trim());
            }
        }
        if bill_vocab.is_empty() {
            continue;
        }

        let matched = node_vocab.intersection(&bill_vocab).count() as i32;
        let bill_vocab_len = bill_vocab.len();

        let precision = if node_vocab_len > 0 {
            matched as f32 / node_vocab_len as f32
        } else {
            0.0
        };
        let recall = if bill_vocab_len > 0 {
            matched as f32 / bill_vocab_len as f32
        } else {
            0.0
        };
        let score = if precision + recall > 0.0 {
            2.0 * precision * recall / (precision + recall)
        } else {
            0.0
        };

        if score > best.0 {
            best = (score, precision, recall, matched);
        }
    }

    let (score, precision, recall, matched_words) = best;
    AmendmentSimilarity {
        tree_diff_path: self.root_path.clone(),
        amendment_id: amendment_id.to_string(),
        score,
        precision,
        recall,
        matched_words,
        tree_diff_words: node_vocab_len as i32,
    }
}
/// Collects the distinct words this node's own field changes inserted or
/// deleted: trimmed, lower-cased, with empty tokens and stop words dropped.
/// Unchanged (`Equal`) tokens are ignored.
fn collect_tree_diff_words(&self) -> HashSet<String> {
    self.changes
        .iter()
        .flat_map(|event| event.changes.iter())
        .filter(|change| match change.tag {
            TextChangeType::Insert | TextChangeType::Delete => true,
            TextChangeType::Equal => false,
        })
        .map(|change| change.value.trim())
        .filter(|token| !token.is_empty() && !is_stop_word(token))
        .map(|token| token.to_lowercase())
        .collect()
}
/// Searches the amending text of every amendment in `data` for citations of
/// the nodes of this diff tree.
///
/// Each regex from `collect_regexes_with_paths` is tried once per amendment
/// (first match only). When several regexes of the same node match, the
/// longest matched text is kept; on equal length the earlier regex wins.
///
/// Returns, keyed by amendment id, the surviving matches in the order in
/// which their node first appears in the collected regex list. Amendments
/// without any match are absent from the map.
pub fn scan_for_mentions(&self, data: &Bill) -> HashMap<String, Vec<MentionMatch>> {
    let regex_with_paths = self.collect_regexes_with_paths();
    let mut results: HashMap<String, Vec<MentionMatch>> = HashMap::new();

    for (amendment_id, amendment) in &data.amendments {
        let text = &amendment.amending_text;
        let all_matches: Vec<MentionMatch> = regex_with_paths
            .par_iter()
            .filter_map(|(path, reg)| {
                reg.find(text).map(|mat| MentionMatch {
                    tree_diff_path: path.clone(),
                    matched_text: mat.as_str().to_string(),
                })
            })
            .collect();

        // `best` keeps first-seen order; `slot_by_path` only maps a path to
        // its position in `best`, so hash-map ordering never leaks into the
        // output.
        let mut slot_by_path: HashMap<&str, usize> = HashMap::new();
        let mut best: Vec<&MentionMatch> = Vec::new();
        for m in &all_matches {
            match slot_by_path.get(m.tree_diff_path.as_str()).copied() {
                Some(slot) => {
                    if m.matched_text.len() > best[slot].matched_text.len() {
                        best[slot] = m;
                    }
                }
                None => {
                    slot_by_path.insert(m.tree_diff_path.as_str(), best.len());
                    best.push(m);
                }
            }
        }

        if !best.is_empty() {
            let matches: Vec<MentionMatch> = best.into_iter().cloned().collect();
            results.insert(amendment_id.clone(), matches);
        }
    }
    results
}
pub fn shallow(&self) -> TreeDiff {
TreeDiff {
root_path: self.root_path.clone(),
changes: self.changes.clone(),
from_element: self.from_element.clone(),
to_element: self.to_element.clone(),
added: self.added.clone(),
removed: self.removed.clone(),
child_diffs: vec![],
}
}
/// Flattens the tree into `(node path, regex)` pairs: this node's regexes
/// from `all_regexes` first, then those of each child subtree in
/// `child_diffs` order.
fn collect_regexes_with_paths(&self) -> Vec<(String, Regex)> {
    let own = self
        .all_regexes()
        .into_iter()
        .map(|regex| (self.root_path.clone(), regex));
    let nested = self
        .child_diffs
        .iter()
        .flat_map(|child| child.collect_regexes_with_paths());
    own.chain(nested).collect()
}
}
/// Word-overlap score between the changed words of one diff node and one
/// amendment, as computed by `TreeDiff::calculate_match_with_amendment`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct AmendmentSimilarity {
/// `root_path` of the diff node that was scored.
pub tree_diff_path: String,
/// Identifier of the amendment the node was compared with.
pub amendment_id: String,
/// Harmonic mean of `precision` and `recall` for the best-scoring change of
/// the amendment; 0 when nothing matched.
pub score: f32,
/// Matched words divided by the number of distinct changed words of the node.
pub precision: f32,
/// Matched words divided by the number of distinct words of the amendment change.
pub recall: f32,
/// Number of distinct words shared by the node and the best-scoring change.
pub matched_words: i32,
/// Number of distinct changed words collected from the node.
pub tree_diff_words: i32,
}
/// A citation of a diff node found in amendment text by
/// `TreeDiff::scan_for_mentions`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct MentionMatch {
/// `root_path` of the diff node whose regex matched.
pub tree_diff_path: String,
/// The exact text the regex matched.
pub matched_text: String,
}
/// Returns `true` when the lower-cased `word` is contained in `STOP_WORDS`.
/// The word is not trimmed here; callers pass trimmed tokens.
fn is_stop_word(word: &str) -> bool {
    let normalized = word.to_lowercase();
    let needle: &str = normalized.as_str();
    STOP_WORDS.contains(&needle)
}
/// Diffs the text fields (heading, chapeau, proviso, content, continuation,
/// in that order) of two versions of the same element and returns one event
/// per field that contains at least one inserted or deleted token.
///
/// # Panics
///
/// Panics if the elements differ in path or element type.
pub fn diff_elements(element_a: &USLMElement, element_b: &USLMElement) -> Vec<FieldChangeEvent> {
    assert!(element_a.data.path == element_b.data.path);
    assert!(element_a.data.element_type == element_b.data.element_type);

    let fields = [
        TextContentField::Heading,
        TextContentField::Chapeau,
        TextContentField::Proviso,
        TextContentField::Content,
        TextContentField::Continuation,
    ];

    fields
        .into_iter()
        .map(|field| diff_field(element_a, element_b, field))
        // Drop fields whose diff consists solely of unchanged tokens.
        .filter(|event| {
            !event
                .changes
                .iter()
                .all(|change| change.tag == TextChangeType::Equal)
        })
        .collect()
}
/// Converts an optional `usize` index into an optional `i32`.
///
/// Values that do not fit into `i32` yield `None` instead of silently
/// wrapping to a negative number, as a plain `as` cast would.
fn rewrap_usize(s: Option<usize>) -> Option<i32> {
    s.and_then(|val| i32::try_from(val).ok())
}
/// Produces the word-level diff of one text field between two elements.
///
/// A field missing on either side is treated as empty text. The returned
/// event carries both full texts, the dates of both elements, and every token
/// reported by the diff (unchanged ones included).
fn diff_field(
    element_a: &USLMElement,
    element_b: &USLMElement,
    field_name: TextContentField,
) -> FieldChangeEvent {
    let old_value = element_a
        .data
        .get_text_content(field_name)
        .unwrap_or_default();
    let new_value = element_b
        .data
        .get_text_content(field_name)
        .unwrap_or_default();

    let diff = TextDiff::from_words(old_value.as_str(), new_value.as_str());
    let mut changes: Vec<TextChange> = Vec::new();
    for change in diff.iter_all_changes() {
        // Translate the diff library's tag into the serializable one.
        let tag = match change.tag() {
            ChangeTag::Equal => TextChangeType::Equal,
            ChangeTag::Insert => TextChangeType::Insert,
            ChangeTag::Delete => TextChangeType::Delete,
        };
        changes.push(TextChange {
            value: String::from(change.value()),
            old_index: rewrap_usize(change.old_index()),
            new_index: rewrap_usize(change.new_index()),
            tag,
        });
    }

    FieldChangeEvent {
        field_name,
        from_date: element_a.data.date,
        to_date: element_b.data.date,
        old_value,
        new_value,
        changes,
    }
}