// diffr_lib/best_projection.rs

use std::collections::hash_map::Entry::*;
use std::collections::HashMap;
use std::convert::TryFrom;

use crate::HashedSlice;
use crate::Tokenization;

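/// A position in the search space of `optimize_partition`: the number of
/// LCS tokens and of sequence tokens consumed so far.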
#[derive(PartialEq, Eq, PartialOrd, Ord, Debug, Clone, Copy, Hash)]
struct Coord {
    next_lcs: usize,
    next_seq: usize,
}

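/// For each token of the LCS, the sorted list of the positions at which it
/// occurs in the target sequence.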
#[derive(Debug)]
struct Context<'a> {
    seq_index: HashMap<HashedSlice<'a>, Vec<usize>>,
}

impl<'a> Context<'a> {
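    /// Builds the index: an entry is created for every token of `lcs`, then
    /// the positions of those tokens in `seq` are recorded in increasing order.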
    fn new(seq: &'a Tokenization<'a>, lcs: &'a Tokenization<'a>) -> Self {
        let mut seq_index = HashMap::new();
        for i in 0..to_isize(lcs.nb_tokens()) {
            match seq_index.entry(lcs.nth_token(i)) {
                Occupied(_) => (),
                Vacant(e) => {
                    e.insert(vec![]);
                }
            }
        }
        for i in 0..seq.nb_tokens() {
            match seq_index.entry(seq.nth_token(to_isize(i))) {
                Occupied(e) => {
                    e.into_mut().push(i);
                }
                Vacant(_) => (),
            }
        }
        Context { seq_index }
    }

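    /// Returns the positions of `tok` in the sequence that are `>= min_value`.
    /// The binary search is valid because each position list is sorted.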
    fn get_indexes(&self, tok: &HashedSlice<'a>, min_value: usize) -> &[usize] {
        match self.seq_index.get(tok) {
            Some(values) => {
                let min_idx = match values.binary_search(&min_value) {
                    Ok(i) | Err(i) => i,
                };
                &values[min_idx..]
            }
            None => &[],
        }
    }
}

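/// The result of `optimize_partition`: `path` holds the token indices at
/// which the segments of the sequence start and end, alternating between
/// shared and distinct segments; `starts_with_shared` records whether the
/// first segment is shared with the LCS.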
#[derive(Debug)]
pub struct NormalizationResult {
    pub path: Vec<isize>,
    pub starts_with_shared: bool,
}

impl NormalizationResult {
    /// Iterates over the spans of the segments of `seq` that are shared
    /// with the LCS.
    pub fn shared_segments<'a>(&'a self, seq: &'a Tokenization) -> SharedSegments<'a> {
        SharedSegments::new(self, seq)
    }
}

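/// Length of the longest common prefix (the "snake") of
/// `lcs[start_lcs..]` and `seq[start_seq..]`.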
fn snake_len(seq: &Tokenization, lcs: &Tokenization, start_lcs: usize, start_seq: usize) -> usize {
    let lcs_len = lcs.nb_tokens() - start_lcs;
    let seq_len = seq.nb_tokens() - start_seq;
    let max_snake_len = lcs_len.min(seq_len);
    let mut snake_len = 0;
    while snake_len < max_snake_len {
        let lcs_tok = lcs.nth_token(to_isize(start_lcs + snake_len));
        let seq_tok = seq.nth_token(to_isize(start_seq + snake_len));
        if lcs_tok != seq_tok {
            break;
        }
        snake_len += 1;
    }
    snake_len
}

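/// Computes an embedding of `lcs` into `seq` that uses as few contiguous
/// shared segments as possible. The search is breadth-first over `Coord`s:
/// each layer consumes exactly one snake, so the first coordinate to
/// exhaust the LCS yields a partition with a minimal number of segments.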
pub fn optimize_partition(seq: &Tokenization, lcs: &Tokenization) -> NormalizationResult {
    let context = Context::new(seq, lcs);
    let root = Coord {
        next_lcs: 0,
        next_seq: 0,
    };
    let target = Coord {
        next_lcs: lcs.nb_tokens(),
        next_seq: seq.nb_tokens(),
    };
    let mut frontier = vec![root];
    let mut new_frontier = vec![];
    let mut prev = HashMap::new();
    let mut found_seq = None;
    // Breadth-first search: each iteration of this loop extends every
    // coordinate of the frontier by one snake.
    while !frontier.is_empty() && found_seq.is_none() {
        new_frontier.clear();
        for &coord in frontier.iter() {
            if coord.next_lcs == target.next_lcs {
                found_seq = Some(coord.next_seq);
                if coord.next_seq == target.next_seq {
                    break;
                } else {
                    continue;
                }
            }
            let start_lcs = coord.next_lcs;
            let lcs_len = lcs.nb_tokens() - start_lcs;
            let mut last_enqueued_snake_len = 0;
            for start_seq in
                context.get_indexes(&lcs.nth_token(to_isize(coord.next_lcs)), coord.next_seq)
            {
                if start_seq + lcs_len > seq.nb_tokens() {
                    // Too few tokens left in `seq` to embed the rest of the LCS.
                    break;
                }
                let snake_len = 1 + snake_len(seq, lcs, start_lcs + 1, start_seq + 1);
                let next_coord = Coord {
                    next_lcs: start_lcs + snake_len,
                    next_seq: start_seq + snake_len,
                };
                // A later occurrence is only worth enqueueing if it yields a
                // strictly longer snake (or reaches the target exactly).
                if last_enqueued_snake_len < snake_len || next_coord == target {
                    if next_coord.next_lcs == target.next_lcs
                        && (next_coord.next_seq == target.next_seq || found_seq.is_none())
                    {
                        found_seq = Some(next_coord.next_seq);
                    }
                    match prev.entry(next_coord) {
                        Occupied(_) => continue,
                        Vacant(e) => e.insert(coord),
                    };
                    new_frontier.push(next_coord);
                    last_enqueued_snake_len = snake_len;
                }
            }
        }
        std::mem::swap(&mut frontier, &mut new_frontier);
    }

    // Reconstruct the partition by walking back from the target through
    // `prev`; `push_if_not_last` drops the boundaries of empty segments.
    let target = found_seq.map(|next_seq| Coord {
        next_lcs: lcs.nb_tokens(),
        next_seq,
    });
    let mut path = vec![];
    let mut starts_with_shared = false;
    let mut coord = target.as_ref();
    let mut seq = seq.nb_tokens();
    let mut lcs = lcs.nb_tokens();
    while let Some(&coord_content) = coord {
        let next_seq = coord_content.next_seq;
        let next_lcs = coord_content.next_lcs;
        let snake_len = lcs - next_lcs;
        push_if_not_last(&mut path, to_isize(seq - snake_len));
        starts_with_shared = !push_if_not_last(&mut path, to_isize(next_seq));

        coord = prev.get(&coord_content);

        seq = next_seq;
        lcs = next_lcs;
    }
    path.reverse();
    NormalizationResult {
        path,
        starts_with_shared,
    }
}

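/// Pushes `val` unless it is already the last element of `v`; returns
/// whether a push happened.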
fn push_if_not_last(v: &mut Vec<isize>, val: isize) -> bool {
    let should_push = v.last() != Some(&val);
    if should_push {
        v.push(val);
    }
    should_push
}

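/// Checked conversion from `usize` to `isize`; panics if the value does not fit.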
fn to_isize(input: usize) -> isize {
    isize::try_from(input).unwrap()
}

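/// Iterator over the `(lo, hi)` spans of the segments of the sequence that
/// are shared with the LCS, in order of appearance.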
pub struct SharedSegments<'a> {
    index: usize,
    normalization: &'a Vec<isize>,
    seq: &'a Tokenization<'a>,
}

impl<'a> SharedSegments<'a> {
    fn new(normalization: &'a NormalizationResult, seq: &'a Tokenization) -> Self {
        SharedSegments {
            // Boundaries in `path` alternate between shared and distinct
            // segments; skip the first boundary when it is not shared.
            index: if normalization.starts_with_shared {
                0
            } else {
                1
            },
            normalization: &normalization.path,
            seq,
        }
    }
}

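// Each shared segment is delimited by two consecutive entries of `path`:
// the tokens `path[index]..path[index + 1]` of the sequence. The iterator
// converts these token indices into the corresponding spans.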
impl<'a> Iterator for SharedSegments<'a> {
    type Item = (usize, usize);
    fn next(&mut self) -> Option<Self::Item> {
        if self.index + 1 < self.normalization.len() {
            let prev = self.normalization[self.index];
            let curr = self.normalization[self.index + 1];
            let from = self.seq.nth_span(prev).lo;
            let to = self.seq.nth_span(curr - 1).hi;
            self.index += 2;
            Some((from, to))
        } else {
            None
        }
    }
}
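
// A minimal sanity check for the two self-contained helpers above. This is
// a sketch: the crate's real test suite may cover these functions elsewhere,
// and the test names here are illustrative, not part of the original file.
#[cfg(test)]
mod tests {
    use super::*;

    // `push_if_not_last` refuses to push a value equal to the current last
    // element; `optimize_partition` relies on this to drop empty segments.
    #[test]
    fn push_if_not_last_deduplicates_adjacent_values() {
        let mut v = vec![];
        assert!(push_if_not_last(&mut v, 0));
        assert!(!push_if_not_last(&mut v, 0));
        assert!(push_if_not_last(&mut v, 3));
        assert!(push_if_not_last(&mut v, 0));
        assert_eq!(v, vec![0, 3, 0]);
    }

    // `to_isize` is a checked cast that panics on overflow rather than wrapping.
    #[test]
    fn to_isize_preserves_small_values() {
        assert_eq!(to_isize(0), 0);
        assert_eq!(to_isize(42), 42);
    }
}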