1use super::default::Parameters;
2use super::metric::FingeringMetric;
3use super::metric::FingeringMetricUniform;
4use super::metric::LevelMetric;
5use super::metric::LevelMetricUniform;
6use super::metric::PartialMetric;
7use super::metric::TierMetric;
8use crate::config::PartialWeights;
9use crate::data::最大按键组合长度;
10use crate::data::编码;
11use crate::data::部分编码信息;
12use crate::data::键位分布损失函数;
13use std::iter::zip;
14
/// Running totals from which the partial metrics and the loss are derived.
///
/// The totals are updated incrementally by `accumulate` (entries can be added
/// with `sign = 1` and removed with `sign = -1`) and turned into a
/// `PartialMetric` plus a scalar loss by `finalize`.
#[derive(Debug, Clone)]
pub struct Cache {
    /// Copy of the weight configuration; decides which metrics are tracked.
    partial_weights: PartialWeights,
    /// Fallback tier size, used whenever a tier does not specify `top`.
    total_count: usize,
    /// Signed sum of all accumulated frequencies.
    total_frequency: i64,
    /// Sum of `(code length - 1) * frequency` over all accumulated codes.
    total_pairs: i64,
    /// Denominator of the extended pair equivalence metric. Initialised to 0;
    /// nothing in the code visible here updates it.
    total_extended_pairs: i64,
    /// Accumulated frequency per key, indexed by the base-`radix` digit value.
    distribution: Vec<i64>,
    /// Frequency-weighted sum of `parameters.pair_equivalence` lookups.
    total_pair_equivalence: f64,
    /// Numerator of the extended pair equivalence metric. Initialised to 0.0;
    /// nothing in the code visible here updates it.
    total_extended_pair_equivalence: f64,
    /// Signed frequency sum of the entries flagged as duplicates.
    total_duplication: i64,
    /// Frequency-weighted sums of the fingering labels, one slot per label index.
    total_fingering: [i64; 8],
    /// Signed frequency sum per configured level (same order as the level weights).
    total_levels: Vec<i64>,
    /// Per tier: signed count (not frequency-weighted) of duplicate entries.
    tiers_duplication: Vec<i64>,
    /// Per tier and per configured level: signed count of entries with that length.
    tiers_levels: Vec<Vec<i64>>,
    /// Per tier: signed, unweighted sums of the fingering labels.
    tiers_fingering: Vec<[i64; 8]>,
    /// Modulus applied to a code to obtain the index into the parameter tables.
    max_index: u64,
    /// `radix^(最大按键组合长度 - 1)`; divisor used to step through a code.
    segment: u64,
    /// `radix^0 ..= radix^8`; the index of the first breakpoint greater than a
    /// code is taken as that code's length.
    length_breakpoints: Vec<u64>,
    /// Base in which codes are expressed (one digit per key).
    radix: u64,
}
36
37impl Cache {
38 #[inline(always)]
39 pub fn process(
40 &mut self,
41 index: usize,
42 frequency: u64,
43 c: &mut 部分编码信息,
44 parameters: &Parameters,
45 ) {
46 if !c.有变化 {
47 return;
48 }
49 c.有变化 = false;
50 self.accumulate(index, frequency, c.实际编码, c.选重标记, parameters, 1);
51 if c.上一个实际编码 == 0 {
52 return;
53 }
54 self.accumulate(
55 index,
56 frequency,
57 c.上一个实际编码,
58 c.上一个选重标记,
59 parameters,
60 -1,
61 );
62 }
63
64 pub fn finalize(&self, parameters: &Parameters) -> (PartialMetric, f64) {
65 let partial_weights = &self.partial_weights;
66 let ideal_distribution = ¶meters.ideal_distribution;
67 let mut partial_metric = PartialMetric {
69 tiers: None,
70 key_distribution: None,
71 pair_equivalence: None,
72 extended_pair_equivalence: None,
73 fingering: None,
74 duplication: None,
75 levels: None,
76 };
77 let mut loss = 0.0;
78 if let Some(key_distribution_weight) = partial_weights.key_distribution {
81 let total: i64 = self.distribution.iter().sum();
83 let distribution = self
84 .distribution
85 .iter()
86 .map(|x| *x as f64 / total as f64)
87 .collect();
88 let distance = Cache::get_distribution_distance(&distribution, ideal_distribution);
89 partial_metric.key_distribution = Some(distance);
90 loss += distance * key_distribution_weight;
91 }
92 if let Some(equivalence_weight) = partial_weights.pair_equivalence {
94 let equivalence = self.total_pair_equivalence / self.total_pairs as f64;
95 partial_metric.pair_equivalence = Some(equivalence);
96 loss += equivalence * equivalence_weight;
97 }
98 if let Some(equivalence_weight) = partial_weights.extended_pair_equivalence {
100 let equivalence =
101 self.total_extended_pair_equivalence / self.total_extended_pairs as f64;
102 partial_metric.extended_pair_equivalence = Some(equivalence);
103 loss += equivalence * equivalence_weight;
104 }
105 if let Some(fingering_weight) = &partial_weights.fingering {
107 let mut fingering = FingeringMetric::default();
108 for (i, weight) in fingering_weight.iter().enumerate() {
109 if let Some(weight) = weight {
110 fingering[i] = Some(self.total_fingering[i] as f64 / self.total_pairs as f64);
111 loss += self.total_fingering[i] as f64 * weight;
112 }
113 }
114 partial_metric.fingering = Some(fingering);
115 }
116 if let Some(duplication_weight) = partial_weights.duplication {
118 let duplication = self.total_duplication as f64 / self.total_frequency as f64;
119 partial_metric.duplication = Some(duplication);
120 loss += duplication * duplication_weight;
121 }
122 if let Some(levels_weight) = &partial_weights.levels {
124 let mut levels: Vec<LevelMetric> = Vec::new();
125 for (ilevel, level) in levels_weight.iter().enumerate() {
126 let value = self.total_levels[ilevel] as f64 / self.total_frequency as f64;
127 loss += value * level.frequency;
128 levels.push(LevelMetric {
129 length: level.length,
130 frequency: value,
131 });
132 }
133 partial_metric.levels = Some(levels);
134 }
135 if let Some(tiers_weight) = &partial_weights.tiers {
137 let mut tiers: Vec<TierMetric> = tiers_weight
138 .iter()
139 .map(|x| TierMetric {
140 top: x.top,
141 duplication: None,
142 levels: None,
143 fingering: None,
144 })
145 .collect();
146 for (itier, tier_weights) in tiers_weight.iter().enumerate() {
147 let count = tier_weights.top.unwrap_or(self.total_count) as f64;
148 if let Some(duplication_weight) = tier_weights.duplication {
150 let duplication = self.tiers_duplication[itier];
151 loss += duplication as f64 / count * duplication_weight;
152 tiers[itier].duplication = Some(duplication as u64);
153 }
154 if let Some(level_weight) = &tier_weights.levels {
156 for (ilevel, level) in level_weight.iter().enumerate() {
157 loss += self.tiers_levels[itier][ilevel] as f64 / count * level.frequency;
158 }
159 tiers[itier].levels = Some(
160 level_weight
161 .iter()
162 .enumerate()
163 .map(|(i, v)| LevelMetricUniform {
164 length: v.length,
165 frequency: self.tiers_levels[itier][i] as u64,
166 })
167 .collect(),
168 );
169 }
170 if let Some(fingering_weight) = &tier_weights.fingering {
172 let mut fingering = FingeringMetricUniform::default();
173 for (i, weight) in fingering_weight.iter().enumerate() {
174 if let Some(weight) = weight {
175 let value = self.tiers_fingering[itier][i];
176 fingering[i] = Some(value as u64);
177 loss += value as f64 / count * weight;
178 }
179 }
180 tiers[itier].fingering = Some(fingering);
181 }
182 }
183 partial_metric.tiers = Some(tiers);
184 }
185 (partial_metric, loss)
186 }
187}
188
189impl Cache {
190 pub fn new(
191 partial_weights: &PartialWeights,
192 radix: u64,
193 total_count: usize,
194 max_index: u64,
195 ) -> Self {
196 let total_frequency = 0;
197 let total_pairs = 0;
198 let total_extended_pairs = 0;
199 let distribution = vec![0; radix as usize];
202 let total_pair_equivalence = 0.0;
203 let total_extended_pair_equivalence = 0.0;
204 let total_duplication = 0;
206 let total_fingering = [0; 8];
207 let nlevel = partial_weights.levels.as_ref().map_or(0, |v| v.len());
208 let total_levels = vec![0; nlevel];
209 let ntier = partial_weights.tiers.as_ref().map_or(0, |v| v.len());
211 let tiers_duplication = vec![0; ntier];
212 let mut tiers_levels = vec![];
213 if let Some(tiers) = &partial_weights.tiers {
214 for tier in tiers {
215 let vec = vec![0; tier.levels.as_ref().map_or(0, |v| v.len())];
216 tiers_levels.push(vec);
217 }
218 }
219 let tiers_fingering = vec![[0; 8]; ntier];
220 let segment = radix.pow((最大按键组合长度 - 1) as u32);
221 let length_breakpoints: Vec<u64> = (0..=8).map(|x| radix.pow(x)).collect();
222
223 Self {
224 partial_weights: partial_weights.clone(),
225 total_count,
226 total_frequency,
227 total_pairs,
228 total_extended_pairs,
229 distribution,
230 total_pair_equivalence,
231 total_extended_pair_equivalence,
232 total_duplication,
233 total_fingering,
234 total_levels,
235 tiers_duplication,
236 tiers_levels,
237 tiers_fingering,
238 max_index,
239 segment,
240 length_breakpoints,
241 radix,
242 }
243 }
244
245 fn get_distribution_distance(
248 distribution: &Vec<f64>,
249 ideal_distribution: &Vec<键位分布损失函数>,
250 ) -> f64 {
251 let mut distance = 0.0;
252 for (frequency, loss) in zip(distribution, ideal_distribution) {
253 let diff = frequency - loss.ideal;
254 if diff > 0.0 {
255 distance += loss.gt_penalty * diff;
256 } else {
257 distance -= loss.lt_penalty * diff;
258 }
259 }
260 distance
261 }
262
263 #[inline(always)]
264 pub fn accumulate(
265 &mut self,
266 index: usize,
267 frequency: u64,
268 code: 编码,
269 duplicate: bool,
270 parameters: &Parameters,
271 sign: i64,
272 ) {
273 let frequency = frequency as i64 * sign;
274 let radix = self.radix;
275 let length = self
276 .length_breakpoints
277 .iter()
278 .position(|&x| code < x)
279 .unwrap() as u64;
280 self.total_frequency += frequency;
281 self.total_pairs += (length - 1) as i64 * frequency;
282 let partial_weights = &self.partial_weights;
283 if partial_weights.key_distribution.is_some() {
286 let mut current = code;
287 while current > 0 {
288 let key = current % self.radix;
289 if let Some(x) = self.distribution.get_mut(key as usize) {
290 *x += frequency;
291 }
292 current /= self.radix;
293 }
294 }
295 if partial_weights.pair_equivalence.is_some() {
297 let mut code = code;
298 while code > self.radix {
299 let partial_code = (code % self.max_index) as usize;
300 self.total_pair_equivalence +=
301 parameters.pair_equivalence[partial_code] * frequency as f64;
302 code /= self.segment;
303 }
304 }
305 if let Some(fingering) = &partial_weights.fingering {
307 let mut code = code;
308 while code > radix {
309 let label = parameters.fingering_types[(code % self.max_index) as usize];
310 for (i, weight) in fingering.iter().enumerate() {
311 if weight.is_some() {
312 self.total_fingering[i] += frequency * label[i] as i64;
313 }
314 }
315 code /= self.segment;
316 }
317 }
318 if duplicate {
320 self.total_duplication += frequency;
321 }
322 if let Some(levels) = &partial_weights.levels {
324 for (ilevel, level) in levels.iter().enumerate() {
325 if level.length == length as usize {
326 self.total_levels[ilevel] += frequency;
327 }
328 }
329 }
330 if let Some(tiers) = &partial_weights.tiers {
332 for (itier, tier) in tiers.iter().enumerate() {
333 if index >= tier.top.unwrap_or(self.total_count) {
334 continue;
335 }
336 if duplicate {
338 self.tiers_duplication[itier] += sign;
339 }
340 if let Some(levels) = &tier.levels {
342 for (ilevel, level) in levels.iter().enumerate() {
343 if level.length == length as usize {
344 self.tiers_levels[itier][ilevel] += sign;
345 }
346 }
347 }
348 if let Some(fingering) = &tier.fingering {
350 let mut code = code;
351 while code > radix {
352 let label = parameters.fingering_types[(code % self.max_index) as usize];
353 for (i, weight) in fingering.iter().enumerate() {
354 if weight.is_some() {
355 self.tiers_fingering[itier][i] += sign * label[i] as i64;
356 }
357 }
358 code /= self.segment;
359 }
360 }
361 }
362 }
363 }
364}