1use crate::error::{Result, TextError};
8use std::collections::HashMap;
9
/// Names an alignment strategy.
///
/// NOTE(review): nothing in this part of the module dispatches on this enum, so
/// the variant descriptions below are inferred from the names — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AlignmentMethod {
    /// Word-level baseline (presumably the co-occurrence-driven `word_alignment`).
    WordBaseline,
    /// Alignment over BPE / subword pairs (assumed from the name).
    BpePair,
    /// fast_align-style alignment (assumed from the name).
    FastAlign,
}
24
/// A single alignment link between two token positions.
///
/// By convention in this module the pair is `(source_index, target_index)`;
/// the only exception is the `tgt_to_src` input of `symmetrize_alignments`,
/// which is given in reversed `(target_index, source_index)` order.
pub type AlignmentPair = (usize, usize);
27
28pub fn word_alignment(
42 source_tokens: &[String],
43 target_tokens: &[String],
44 co_occurrence: &HashMap<(String, String), usize>,
45) -> Result<Vec<AlignmentPair>> {
46 if source_tokens.is_empty() {
47 return Err(TextError::InvalidInput(
48 "source_tokens must not be empty".to_string(),
49 ));
50 }
51 if target_tokens.is_empty() {
52 return Err(TextError::InvalidInput(
53 "target_tokens must not be empty".to_string(),
54 ));
55 }
56
57 let mut alignments: Vec<AlignmentPair> = Vec::new();
58
59 for (si, src) in source_tokens.iter().enumerate() {
60 let best = target_tokens
61 .iter()
62 .enumerate()
63 .filter_map(|(ti, tgt)| {
64 co_occurrence
65 .get(&(src.clone(), tgt.clone()))
66 .map(|&cnt| (ti, cnt))
67 })
68 .max_by_key(|&(_, cnt)| cnt);
69
70 if let Some((ti, _)) = best {
71 alignments.push((si, ti));
72 }
73 }
74
75 Ok(alignments)
76}
77
78pub fn ibm_model1(
94 sentence_pairs: &[(Vec<String>, Vec<String>)],
95 n_iter: usize,
96) -> Result<HashMap<(String, String), f64>> {
97 if sentence_pairs.is_empty() {
98 return Err(TextError::InvalidInput(
99 "sentence_pairs must not be empty".to_string(),
100 ));
101 }
102 if n_iter == 0 {
103 return Err(TextError::InvalidInput(
104 "n_iter must be at least 1".to_string(),
105 ));
106 }
107
108 const NULL: &str = "<NULL>";
109
110 let mut src_vocab: std::collections::HashSet<String> = std::collections::HashSet::new();
112 let mut tgt_vocab: std::collections::HashSet<String> = std::collections::HashSet::new();
113
114 for (src_sent, tgt_sent) in sentence_pairs {
115 for w in src_sent {
116 src_vocab.insert(w.clone());
117 }
118 for w in tgt_sent {
119 tgt_vocab.insert(w.clone());
120 }
121 }
122 src_vocab.insert(NULL.to_string());
123
124 let uniform = if tgt_vocab.is_empty() {
126 1.0
127 } else {
128 1.0 / tgt_vocab.len() as f64
129 };
130
131 let mut t: HashMap<(String, String), f64> = HashMap::new();
132 for s in &src_vocab {
133 for e in &tgt_vocab {
134 t.insert((s.clone(), e.clone()), uniform);
135 }
136 }
137
138 for _ in 0..n_iter {
140 let mut count: HashMap<(String, String), f64> = HashMap::new();
142 let mut total_s: HashMap<String, f64> = HashMap::new();
143
144 for (src_sent, tgt_sent) in sentence_pairs {
145 let augmented_src: Vec<&str> = std::iter::once(NULL)
147 .chain(src_sent.iter().map(|s| s.as_str()))
148 .collect();
149
150 for e in tgt_sent {
152 let s_total: f64 = augmented_src
153 .iter()
154 .map(|&s| {
155 t.get(&(s.to_string(), e.clone()))
156 .copied()
157 .unwrap_or(uniform)
158 })
159 .sum();
160
161 if s_total > 0.0 {
162 for &s in &augmented_src {
163 let prob = t
164 .get(&(s.to_string(), e.clone()))
165 .copied()
166 .unwrap_or(uniform);
167 let delta = prob / s_total;
168 *count.entry((s.to_string(), e.clone())).or_insert(0.0) += delta;
169 *total_s.entry(s.to_string()).or_insert(0.0) += delta;
170 }
171 }
172 }
173 }
174
175 for ((s, e), c) in &count {
177 let total = total_s.get(s).copied().unwrap_or(1.0);
178 t.insert((s.clone(), e.clone()), c / total);
179 }
180 }
181
182 t.retain(|(s, _), _| s != NULL);
184 Ok(t)
185}
186
187pub fn symmetrize_alignments(
203 src_to_tgt: &[AlignmentPair],
204 tgt_to_src: &[AlignmentPair],
205) -> Result<Vec<AlignmentPair>> {
206 if src_to_tgt.is_empty() && tgt_to_src.is_empty() {
207 return Err(TextError::ProcessingError(
208 "Both alignment sets are empty; cannot symmetrize".to_string(),
209 ));
210 }
211
212 let s2t_set: std::collections::HashSet<AlignmentPair> = src_to_tgt.iter().copied().collect();
214 let t2s_set: std::collections::HashSet<AlignmentPair> =
216 tgt_to_src.iter().map(|&(ti, si)| (si, ti)).collect();
217
218 let mut result: std::collections::HashSet<AlignmentPair> =
219 s2t_set.intersection(&t2s_set).copied().collect();
220
221 let aligned_src = |set: &std::collections::HashSet<AlignmentPair>, si: usize| {
223 set.iter().any(|&(s, _)| s == si)
224 };
225 let aligned_tgt = |set: &std::collections::HashSet<AlignmentPair>, ti: usize| {
226 set.iter().any(|&(_, t)| t == ti)
227 };
228
229 let union: std::collections::HashSet<AlignmentPair> =
231 s2t_set.union(&t2s_set).copied().collect();
232
233 let neighbors: [(i32, i32); 4] = [(-1, 0), (1, 0), (0, -1), (0, 1)];
236 let mut changed = true;
237 while changed {
238 changed = false;
239 let current: Vec<AlignmentPair> = result.iter().copied().collect();
240 for (si, ti) in ¤t {
241 for (ds, dt) in &neighbors {
242 let ns = (*si as i32 + ds) as usize;
243 let nt = (*ti as i32 + dt) as usize;
244 let candidate = (ns, nt);
245 if union.contains(&candidate) && !result.contains(&candidate) {
246 result.insert(candidate);
247 changed = true;
248 }
249 }
250 }
251 }
252
253 for &(si, ti) in &union {
255 if !aligned_src(&result, si) || !aligned_tgt(&result, ti) {
256 result.insert((si, ti));
257 }
258 }
259
260 let mut out: Vec<AlignmentPair> = result.into_iter().collect();
261 out.sort_unstable();
262 Ok(out)
263}
264
265pub fn alignment_f1(
279 pred_alignments: &[AlignmentPair],
280 gold_alignments: &[AlignmentPair],
281) -> Result<(f64, f64, f64)> {
282 if pred_alignments.is_empty() && gold_alignments.is_empty() {
283 return Err(TextError::InvalidInput(
284 "Both pred and gold alignment sets are empty".to_string(),
285 ));
286 }
287
288 let pred_set: std::collections::HashSet<AlignmentPair> =
289 pred_alignments.iter().copied().collect();
290 let gold_set: std::collections::HashSet<AlignmentPair> =
291 gold_alignments.iter().copied().collect();
292
293 let tp = pred_set.intersection(&gold_set).count() as f64;
294
295 let precision = if pred_set.is_empty() {
296 0.0
297 } else {
298 tp / pred_set.len() as f64
299 };
300
301 let recall = if gold_set.is_empty() {
302 0.0
303 } else {
304 tp / gold_set.len() as f64
305 };
306
307 let f1 = if precision + recall < f64::EPSILON {
308 0.0
309 } else {
310 2.0 * precision * recall / (precision + recall)
311 };
312
313 Ok((precision, recall, f1))
314}
315
316#[derive(Debug)]
323pub struct AlignedCorpus {
324 pub source: Vec<Vec<String>>,
326 pub target: Vec<Vec<String>>,
328 pub t_table: HashMap<(String, String), f64>,
330}
331
332impl AlignedCorpus {
333 pub fn train(sentence_pairs: Vec<(Vec<String>, Vec<String>)>, n_iter: usize) -> Result<Self> {
339 let t_table = ibm_model1(&sentence_pairs, n_iter)?;
340 let (source, target) = sentence_pairs.into_iter().unzip();
341 Ok(Self {
342 source,
343 target,
344 t_table,
345 })
346 }
347
348 pub fn viterbi_align(&self, idx: usize) -> Result<Vec<AlignmentPair>> {
356 if idx >= self.source.len() {
357 return Err(TextError::InvalidInput(format!(
358 "Sentence pair index {} is out of range (corpus has {} pairs)",
359 idx,
360 self.source.len()
361 )));
362 }
363
364 const NULL: &str = "<NULL>";
365 let src = &self.source[idx];
366 let tgt = &self.target[idx];
367
368 let mut alignments = Vec::new();
369
370 for (ti, tgt_word) in tgt.iter().enumerate() {
371 let null_prob = self
373 .t_table
374 .get(&(NULL.to_string(), tgt_word.clone()))
375 .copied()
376 .unwrap_or(0.0);
377
378 let best = src
379 .iter()
380 .enumerate()
381 .map(|(si, src_word)| {
382 let p = self
383 .t_table
384 .get(&(src_word.clone(), tgt_word.clone()))
385 .copied()
386 .unwrap_or(0.0);
387 (si, p)
388 })
389 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
390
391 if let Some((si, best_prob)) = best {
392 if best_prob >= null_prob {
393 alignments.push((si, ti));
394 }
395 }
396 }
397
398 Ok(alignments)
399 }
400}
401
/// Unit tests for the alignment utilities.
///
/// Covers the happy paths plus every documented error branch
/// (empty inputs, zero iterations, out-of-range corpus index).
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned token vector from string slices.
    fn tok(words: &[&str]) -> Vec<String> {
        words.iter().map(|w| w.to_string()).collect()
    }

    #[test]
    fn test_word_alignment_basic() {
        let mut cooc: HashMap<(String, String), usize> = HashMap::new();
        cooc.insert(("cat".to_string(), "gato".to_string()), 10);
        cooc.insert(("dog".to_string(), "perro".to_string()), 8);

        let src = tok(&["cat", "dog"]);
        let tgt = tok(&["gato", "perro"]);

        let aligns = word_alignment(&src, &tgt, &cooc).expect("alignment failed");
        assert!(aligns.contains(&(0, 0)));
        assert!(aligns.contains(&(1, 1)));
    }

    #[test]
    fn test_word_alignment_empty_source() {
        let cooc: HashMap<(String, String), usize> = HashMap::new();
        let res = word_alignment(&[], &tok(&["a"]), &cooc);
        assert!(res.is_err());
    }

    #[test]
    fn test_word_alignment_empty_target() {
        let cooc: HashMap<(String, String), usize> = HashMap::new();
        let res = word_alignment(&tok(&["a"]), &[], &cooc);
        assert!(res.is_err());
    }

    #[test]
    fn test_word_alignment_skips_unseen_source() {
        let mut cooc: HashMap<(String, String), usize> = HashMap::new();
        cooc.insert(("cat".to_string(), "gato".to_string()), 1);
        let aligns = word_alignment(&tok(&["cat", "xyz"]), &tok(&["gato"]), &cooc)
            .expect("alignment failed");
        assert_eq!(aligns, vec![(0, 0)]);
    }

    #[test]
    fn test_ibm_model1_basic() {
        let pairs = vec![
            (tok(&["the", "cat"]), tok(&["le", "chat"])),
            (tok(&["the", "dog"]), tok(&["le", "chien"])),
            (tok(&["a", "cat"]), tok(&["un", "chat"])),
        ];
        let t = ibm_model1(&pairs, 5).expect("ibm_model1 failed");

        let p_chat_cat = t
            .get(&("cat".to_string(), "chat".to_string()))
            .copied()
            .unwrap_or(0.0);
        assert!(
            p_chat_cat > 0.0,
            "Expected positive probability for (cat, chat)"
        );
    }

    #[test]
    fn test_ibm_model1_zero_iters() {
        let pairs = vec![(tok(&["a"]), tok(&["b"]))];
        assert!(ibm_model1(&pairs, 0).is_err());
    }

    #[test]
    fn test_ibm_model1_empty_pairs() {
        let pairs: Vec<(Vec<String>, Vec<String>)> = Vec::new();
        assert!(ibm_model1(&pairs, 1).is_err());
    }

    #[test]
    fn test_ibm_model1_excludes_null_rows() {
        let pairs = vec![(tok(&["a", "b"]), tok(&["x", "y"]))];
        let t = ibm_model1(&pairs, 2).expect("ibm_model1 failed");
        assert!(t.keys().all(|(s, _)| s != "<NULL>"));
    }

    #[test]
    fn test_symmetrize_alignments() {
        let s2t = vec![(0, 0), (1, 1)];
        let t2s = vec![(0, 0), (1, 1)];
        let sym = symmetrize_alignments(&s2t, &t2s).expect("symmetrize failed");
        assert!(sym.contains(&(0, 0)));
        assert!(sym.contains(&(1, 1)));
    }

    #[test]
    fn test_symmetrize_one_sided() {
        let sym = symmetrize_alignments(&[(0, 0)], &[]).expect("symmetrize failed");
        assert_eq!(sym, vec![(0, 0)]);
    }

    #[test]
    fn test_symmetrize_both_empty() {
        assert!(symmetrize_alignments(&[], &[]).is_err());
    }

    #[test]
    fn test_alignment_f1_perfect() {
        let aligns = vec![(0, 0), (1, 1), (2, 2)];
        let (p, r, f1) = alignment_f1(&aligns, &aligns).expect("f1 failed");
        assert!((p - 1.0).abs() < 1e-9);
        assert!((r - 1.0).abs() < 1e-9);
        assert!((f1 - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_alignment_f1_no_overlap() {
        let pred = vec![(0, 1)];
        let gold = vec![(0, 0)];
        let (p, r, f1) = alignment_f1(&pred, &gold).expect("f1 failed");
        assert!((p - 0.0).abs() < 1e-9);
        assert!((r - 0.0).abs() < 1e-9);
        assert!((f1 - 0.0).abs() < 1e-9);
    }

    #[test]
    fn test_alignment_f1_partial() {
        let pred = vec![(0, 0), (1, 1)];
        let gold = vec![(0, 0)];
        let (p, r, f1) = alignment_f1(&pred, &gold).expect("f1 failed");
        assert!((p - 0.5).abs() < 1e-9);
        assert!((r - 1.0).abs() < 1e-9);
        assert!((f1 - 2.0 / 3.0).abs() < 1e-9);
    }

    #[test]
    fn test_alignment_f1_both_empty() {
        assert!(alignment_f1(&[], &[]).is_err());
    }

    #[test]
    fn test_aligned_corpus_train_viterbi() {
        let pairs = vec![
            (tok(&["the", "cat"]), tok(&["le", "chat"])),
            (tok(&["the", "dog"]), tok(&["le", "chien"])),
            (tok(&["a", "cat"]), tok(&["un", "chat"])),
        ];
        let corpus = AlignedCorpus::train(pairs, 10).expect("train failed");
        let aligns = corpus.viterbi_align(0).expect("viterbi failed");
        assert!(!aligns.is_empty());
    }

    #[test]
    fn test_aligned_corpus_viterbi_out_of_range() {
        let pairs = vec![(tok(&["a"]), tok(&["b"]))];
        let corpus = AlignedCorpus::train(pairs, 1).expect("train failed");
        assert!(corpus.viterbi_align(1).is_err());
    }
}