use crate::{
data::{AlignmentStatus, CharInterval, Extraction},
exceptions::LangExtractResult,
};
use std::cmp::min;
/// Options that control how extraction text is located inside source text.
#[derive(Debug, Clone)]
pub struct AlignmentConfig {
    /// When `true`, a word-overlap search is attempted after exact matching fails.
    pub enable_fuzzy_alignment: bool,
    /// Minimum fraction of extraction words that must occur in a candidate
    /// window of source words for the window to count as a fuzzy match.
    pub fuzzy_alignment_threshold: f32,
    /// When `true`, a span anchored on the first and last extraction words may
    /// be accepted when the full text is not found verbatim.
    pub accept_match_lesser: bool,
    /// When `false`, source and extraction text are lowercased before comparing.
    pub case_sensitive: bool,
    /// Largest number of source words a fuzzy-search window may contain.
    pub max_search_window: usize,
}

impl Default for AlignmentConfig {
    /// Fuzzy matching on with a 0.4 threshold, lesser matches accepted,
    /// case-insensitive comparison, and windows of at most 100 words.
    fn default() -> Self {
        AlignmentConfig {
            max_search_window: 100,
            case_sensitive: false,
            accept_match_lesser: true,
            fuzzy_alignment_threshold: 0.4,
            enable_fuzzy_alignment: true,
        }
    }
}
/// Locates extraction text inside source text according to an [`AlignmentConfig`].
pub struct TextAligner {
// Matching options consulted by every alignment call.
config: AlignmentConfig,
}
impl TextAligner {
/// Creates an aligner configured with [`AlignmentConfig::default`].
pub fn new() -> Self {
    let config = AlignmentConfig::default();
    Self { config }
}
pub fn with_config(config: AlignmentConfig) -> Self {
Self { config }
}
/// Aligns every extraction in `extractions` against `source_text`.
///
/// The source is normalised once (lowercased unless `case_sensitive` is set)
/// and split into whitespace-separated words once; that prepared data is then
/// reused for each extraction.
///
/// For each extraction that can be located, `char_interval` is set to the
/// matched span with `char_offset` added to both ends. An extraction that
/// cannot be located keeps whatever `char_interval` it already had.
///
/// # Returns
/// The number of extractions that were aligned.
///
/// # Errors
/// Propagates any error returned while aligning an individual extraction.
///
/// NOTE(review): positions are byte offsets into the normalised text. For
/// non-ASCII input they can differ from character indices, and lowercasing may
/// change byte lengths relative to the original — confirm what consumers of
/// `CharInterval` expect.
#[tracing::instrument(skip_all, fields(num_extractions = extractions.len(), source_len = source_text.len(), char_offset))]
pub fn align_extractions(
    &self,
    extractions: &mut [Extraction],
    source_text: &str,
    char_offset: usize,
) -> LangExtractResult<usize> {
    // Borrow the caller's text in case-sensitive mode instead of copying it.
    let normalized: std::borrow::Cow<'_, str> = if self.config.case_sensitive {
        std::borrow::Cow::Borrowed(source_text)
    } else {
        std::borrow::Cow::Owned(source_text.to_lowercase())
    };
    let search_text: &str = &normalized;

    // Tokenise once; every word is a sub-slice of `search_text`, so its byte
    // range is recovered from the distance between the two start pointers.
    let source_words: Vec<&str> = search_text.split_whitespace().collect();
    let base = search_text.as_ptr() as usize;
    let word_byte_offsets: Vec<(usize, usize)> = source_words
        .iter()
        .map(|word| {
            let start = word.as_ptr() as usize - base;
            (start, start + word.len())
        })
        .collect();

    let mut aligned_count = 0;
    for extraction in extractions.iter_mut() {
        if let Some(interval) = self.align_single_extraction_with_cache(
            extraction,
            search_text,
            &source_words,
            &word_byte_offsets,
            char_offset,
        )? {
            extraction.char_interval = Some(interval);
            aligned_count += 1;
        }
    }
    Ok(aligned_count)
}
/// Aligns one extraction against already-prepared source data.
///
/// `search_text`, `source_words` and `word_byte_offsets` are the normalised
/// source, its whitespace-separated words, and the byte range of each word.
/// An exact search runs first; the fuzzy search runs only if that fails and
/// `enable_fuzzy_alignment` is set.
///
/// On success `extraction.alignment_status` records the kind of match and the
/// returned interval has `char_offset` added to both ends. On failure the
/// status is reset to `None` and `Ok(None)` is returned.
fn align_single_extraction_with_cache(
    &self,
    extraction: &mut Extraction,
    search_text: &str,
    source_words: &[&str],
    word_byte_offsets: &[(usize, usize)],
    char_offset: usize,
) -> LangExtractResult<Option<CharInterval>> {
    let needle: String = if self.config.case_sensitive {
        extraction.extraction_text.clone()
    } else {
        extraction.extraction_text.to_lowercase()
    };

    // Exact search first; fall through to the fuzzy search only when enabled.
    let located = self.find_exact_match(&needle, search_text).or_else(|| {
        if !self.config.enable_fuzzy_alignment {
            return None;
        }
        self.find_fuzzy_match_with_words(&needle, search_text, source_words, word_byte_offsets)
    });

    match located {
        Some((start, end, status)) => {
            extraction.alignment_status = Some(status);
            let interval = CharInterval::new(Some(start + char_offset), Some(end + char_offset));
            Ok(Some(interval))
        }
        None => {
            extraction.alignment_status = None;
            Ok(None)
        }
    }
}
/// Aligns a single extraction against `source_text`.
///
/// The source is normalised (lowercased unless `case_sensitive` is set) and
/// split into whitespace-separated words before matching. The extraction's
/// `alignment_status` is updated by the matching step; `char_interval` is
/// returned rather than written, so the caller decides whether to store it.
///
/// # Returns
/// `Ok(Some(interval))` with `char_offset` added to both ends when the text is
/// located, `Ok(None)` otherwise.
///
/// # Errors
/// Propagates any error returned by the matching step.
///
/// NOTE(review): positions are byte offsets into the normalised text. For
/// non-ASCII input they can differ from character indices, and lowercasing may
/// change byte lengths relative to the original — confirm what consumers of
/// `CharInterval` expect.
pub fn align_single_extraction(
    &self,
    extraction: &mut Extraction,
    source_text: &str,
    char_offset: usize,
) -> LangExtractResult<Option<CharInterval>> {
    // Borrow the caller's text in case-sensitive mode instead of copying it.
    let normalized: std::borrow::Cow<'_, str> = if self.config.case_sensitive {
        std::borrow::Cow::Borrowed(source_text)
    } else {
        std::borrow::Cow::Owned(source_text.to_lowercase())
    };
    let search_text: &str = &normalized;

    // Tokenise once; every word is a sub-slice of `search_text`, so its byte
    // range is recovered from the distance between the two start pointers.
    let source_words: Vec<&str> = search_text.split_whitespace().collect();
    let base = search_text.as_ptr() as usize;
    let word_byte_offsets: Vec<(usize, usize)> = source_words
        .iter()
        .map(|word| {
            let start = word.as_ptr() as usize - base;
            (start, start + word.len())
        })
        .collect();

    self.align_single_extraction_with_cache(
        extraction,
        search_text,
        &source_words,
        &word_byte_offsets,
        char_offset,
    )
}
/// Searches `source_text` for `extraction_text` verbatim, then by word anchors.
///
/// 1. If the whole extraction occurs as a substring, its first occurrence is
///    returned as `MatchExact`.
/// 2. Otherwise, when `accept_match_lesser` is set and the extraction has at
///    least two words, the first occurrence of the first word is paired with
///    the next occurrence of the last word *after* it. The span is accepted as
///    `MatchLesser` only if it is shorter than twice the extraction's length.
///
/// Both inputs are expected to be normalised already (same casing).
///
/// # Returns
/// `(start, end, status)` as byte offsets into `source_text`, or `None` when
/// nothing qualifies. Empty or whitespace-only extraction text never matches.
fn find_exact_match(
    &self,
    extraction_text: &str,
    source_text: &str,
) -> Option<(usize, usize, AlignmentStatus)> {
    // `str::find("")` succeeds at offset 0, which would report an empty
    // extraction as an exact match of nothing.
    if extraction_text.trim().is_empty() {
        return None;
    }

    if let Some(start) = source_text.find(extraction_text) {
        return Some((start, start + extraction_text.len(), AlignmentStatus::MatchExact));
    }

    if !self.config.accept_match_lesser {
        return None;
    }

    // `next_back` yields `None` when only one word remains, which enforces
    // the "at least two words" requirement.
    let mut words = extraction_text.split_whitespace();
    let first_word = words.next()?;
    let last_word = words.next_back()?;

    let first_start = source_text.find(first_word)?;
    let first_end = first_start + first_word.len();
    // Search strictly after the first word so the last word cannot match
    // inside (or on top of) the first word and produce a degenerate span.
    let last_start = first_end + source_text[first_end..].find(last_word)?;
    let last_end = last_start + last_word.len();

    if last_end - first_start < extraction_text.len() * 2 {
        Some((first_start, last_end, AlignmentStatus::MatchLesser))
    } else {
        None
    }
}
/// Finds the window of source words that best overlaps the extraction's words.
///
/// Window sizes are tried from the number of extraction words up to
/// `max_search_window` (capped at the number of source words). For each size,
/// every window is scored by `calculate_word_similarity_direct`. The search
/// stops at the first size that yields a window scoring at least
/// `fuzzy_alignment_threshold`; within that size the highest score wins and
/// ties keep the earliest window.
///
/// `word_byte_offsets[i]` must hold the byte range of `source_words[i]` within
/// `source_text`.
///
/// # Returns
/// `(start, end, MatchFuzzy)` spanning the first through last word of the
/// chosen window, or `None` when no window reaches the threshold.
fn find_fuzzy_match_with_words(
    &self,
    extraction_text: &str,
    source_text: &str,
    source_words: &[&str],
    word_byte_offsets: &[(usize, usize)],
) -> Option<(usize, usize, AlignmentStatus)> {
    let extraction_words: Vec<&str> = extraction_text.split_whitespace().collect();
    if extraction_words.is_empty() || source_words.is_empty() {
        return None;
    }

    let threshold = self.config.fuzzy_alignment_threshold;
    let max_window = min(source_words.len(), self.config.max_search_window);
    let mut best_match: Option<(usize, usize, f32)> = None;

    // `window_size` is at least 1 here, so `windows` never receives 0.
    for window_size in extraction_words.len()..=max_window {
        for (start_idx, window) in source_words.windows(window_size).enumerate() {
            let similarity = self.calculate_word_similarity_direct(&extraction_words, window);
            // Strict `>` keeps the earliest window when scores tie.
            let improves = best_match.map_or(true, |(_, _, best)| similarity > best);
            if similarity >= threshold && improves {
                best_match = Some((start_idx, start_idx + window_size, similarity));
            }
        }
        // Prefer the tightest window size that produced any match.
        if best_match.is_some() {
            break;
        }
    }

    let (start_word_idx, end_word_idx, _) = best_match?;
    let char_start = word_byte_offsets.get(start_word_idx)?.0;
    // Always end at the last matched word, so whitespace after the final
    // source word is never included in the span.
    let char_end = word_byte_offsets.get(end_word_idx - 1)?.1;
    debug_assert!(
        char_end <= source_text.len(),
        "word offsets must index into source_text"
    );
    Some((char_start, char_end, AlignmentStatus::MatchFuzzy))
}
/// Scores how much of `words1` is covered by `words2`.
///
/// Returns the fraction of entries in `words1` that also occur in `words2`.
/// Two empty slices score `1.0`; exactly one empty slice scores `0.0`.
fn calculate_word_similarity_direct(&self, words1: &[&str], words2: &[&str]) -> f32 {
    match (words1.is_empty(), words2.is_empty()) {
        (true, true) => 1.0,
        (true, false) | (false, true) => 0.0,
        (false, false) => {
            let lookup: std::collections::HashSet<&str> = words2.iter().copied().collect();
            let hits = words1
                .iter()
                .copied()
                .filter(|word| lookup.contains(*word))
                .count();
            hits as f32 / words1.len() as f32
        }
    }
}
/// Aligns extractions against the text of one chunk.
///
/// Delegates directly to `align_extractions`, passing `chunk_text` as the
/// source text and `chunk_char_offset` as the offset added to every interval.
/// Returns whatever that call returns (the number of aligned extractions).
pub fn align_chunk_extractions(
&self,
extractions: &mut [Extraction],
chunk_text: &str,
chunk_char_offset: usize,
) -> LangExtractResult<usize> {
self.align_extractions(extractions, chunk_text, chunk_char_offset)
}
/// Tallies `extractions` by their `alignment_status`.
///
/// `total` is the slice length; each extraction increments exactly one of
/// `exact`, `fuzzy`, `lesser`, `greater`, or (for a `None` status) `unaligned`.
pub fn get_alignment_stats(&self, extractions: &[Extraction]) -> AlignmentStats {
    let mut stats = AlignmentStats {
        total: extractions.len(),
        exact: 0,
        fuzzy: 0,
        lesser: 0,
        greater: 0,
        unaligned: 0,
    };
    for extraction in extractions {
        // Pick the counter for this status, then bump it.
        let counter = match extraction.alignment_status {
            Some(AlignmentStatus::MatchExact) => &mut stats.exact,
            Some(AlignmentStatus::MatchFuzzy) => &mut stats.fuzzy,
            Some(AlignmentStatus::MatchLesser) => &mut stats.lesser,
            Some(AlignmentStatus::MatchGreater) => &mut stats.greater,
            None => &mut stats.unaligned,
        };
        *counter += 1;
    }
    stats
}
}
// `TextAligner::default()` is the same as `TextAligner::new()`.
impl Default for TextAligner {
fn default() -> Self {
Self::new()
}
}
/// Counts of extractions grouped by alignment outcome.
#[derive(Debug, Clone)]
pub struct AlignmentStats {
    /// Number of extractions examined.
    pub total: usize,
    /// Extractions whose status is `MatchExact`.
    pub exact: usize,
    /// Extractions whose status is `MatchFuzzy`.
    pub fuzzy: usize,
    /// Extractions whose status is `MatchLesser`.
    pub lesser: usize,
    /// Extractions whose status is `MatchGreater`.
    pub greater: usize,
    /// Extractions with no alignment status.
    pub unaligned: usize,
}

impl AlignmentStats {
    /// Fraction of extractions that were aligned in any way.
    ///
    /// Returns `1.0` when `total` is zero. Because the fields are public and
    /// can be set inconsistently, `unaligned > total` is treated as zero
    /// aligned extractions instead of underflowing.
    pub fn success_rate(&self) -> f32 {
        if self.total == 0 {
            return 1.0;
        }
        let aligned = self.total.saturating_sub(self.unaligned);
        aligned as f32 / self.total as f32
    }

    /// Fraction of extractions that matched exactly.
    ///
    /// Returns `0.0` when `total` is zero.
    pub fn exact_match_rate(&self) -> f32 {
        if self.total == 0 {
            return 0.0;
        }
        self.exact as f32 / self.total as f32
    }
}
// Unit tests for `TextAligner` and `AlignmentStats`, all using the default configuration.
#[cfg(test)]
mod tests {
use super::*;
// A verbatim substring is reported as `MatchExact` with its position in the source.
#[test]
fn test_exact_alignment() {
let aligner = TextAligner::new();
let mut extraction = Extraction::new("person".to_string(), "John Doe".to_string());
let source_text = "Hello, John Doe is a software engineer.";
let result = aligner.align_single_extraction(&mut extraction, source_text, 0).unwrap();
assert!(result.is_some());
let interval = result.unwrap();
assert_eq!(interval.start_pos, Some(7));
assert_eq!(interval.end_pos, Some(15));
assert_eq!(extraction.alignment_status, Some(AlignmentStatus::MatchExact));
}
// The default config is case-insensitive, so differing case still yields an exact match.
#[test]
fn test_case_insensitive_alignment() {
let aligner = TextAligner::new();
let mut extraction = Extraction::new("person".to_string(), "JOHN DOE".to_string());
let source_text = "Hello, john doe is a software engineer.";
let result = aligner.align_single_extraction(&mut extraction, source_text, 0).unwrap();
assert!(result.is_some());
let interval = result.unwrap();
assert_eq!(interval.start_pos, Some(7));
assert_eq!(interval.end_pos, Some(15));
assert_eq!(extraction.alignment_status, Some(AlignmentStatus::MatchExact));
}
// "John" and "Smith" are not adjacent in the source, so the verbatim search
// fails; the test expects the fuzzy path to report the match.
#[test]
fn test_fuzzy_alignment() {
let aligner = TextAligner::new();
let mut extraction = Extraction::new("person".to_string(), "John Smith".to_string());
let source_text = "Hello, John is a software engineer named Smith.";
let result = aligner.align_single_extraction(&mut extraction, source_text, 0).unwrap();
assert!(result.is_some());
assert_eq!(extraction.alignment_status, Some(AlignmentStatus::MatchFuzzy));
}
// Neither word of the extraction occurs in the source: no interval and no status.
#[test]
fn test_no_alignment() {
let aligner = TextAligner::new();
let mut extraction = Extraction::new("person".to_string(), "Jane Doe".to_string());
let source_text = "Hello, John Smith is a software engineer.";
let result = aligner.align_single_extraction(&mut extraction, source_text, 0).unwrap();
assert!(result.is_none());
assert_eq!(extraction.alignment_status, None);
}
// The supplied offset (100) is added to both ends of the match found at 0..8.
#[test]
fn test_chunk_offset() {
let aligner = TextAligner::new();
let mut extraction = Extraction::new("person".to_string(), "John Doe".to_string());
let chunk_text = "John Doe is here.";
let chunk_offset = 100;
let result = aligner.align_single_extraction(&mut extraction, chunk_text, chunk_offset).unwrap();
assert!(result.is_some());
let interval = result.unwrap();
assert_eq!(interval.start_pos, Some(100)); assert_eq!(interval.end_pos, Some(108)); }
// One exact, one fuzzy and one unaligned extraction: checks the counters and both rates.
#[test]
fn test_alignment_stats() {
let aligner = TextAligner::new();
let extractions = vec![
Extraction {
extraction_class: "test".to_string(),
extraction_text: "test".to_string(),
alignment_status: Some(AlignmentStatus::MatchExact),
..Default::default()
},
Extraction {
extraction_class: "test".to_string(),
extraction_text: "test".to_string(),
alignment_status: Some(AlignmentStatus::MatchFuzzy),
..Default::default()
},
Extraction {
extraction_class: "test".to_string(),
extraction_text: "test".to_string(),
alignment_status: None,
..Default::default()
},
];
let stats = aligner.get_alignment_stats(&extractions);
assert_eq!(stats.total, 3);
assert_eq!(stats.exact, 1);
assert_eq!(stats.fuzzy, 1);
assert_eq!(stats.unaligned, 1);
// Exact float equality holds because both sides perform the same f32 division.
assert_eq!(stats.success_rate(), 2.0 / 3.0);
assert_eq!(stats.exact_match_rate(), 1.0 / 3.0);
}
}