1use super::default::默认目标函数参数;
2use super::metric::FingeringMetric;
3use super::metric::FingeringMetricUniform;
4use super::metric::LevelMetricUniform;
5use super::metric::分组指标;
6use super::metric::层级指标;
7use super::metric::键长指标;
8use crate::config::PartialWeights;
9use crate::data::最大按键组合长度;
10use crate::data::编码;
11use crate::data::部分编码信息;
12use crate::data::键位分布损失函数;
13use std::collections::HashMap;
14use std::iter::zip;
15
/// Incrementally maintained totals from which the objective metrics and loss
/// are derived. `增减` adds or removes one entry's contribution; `汇总` turns
/// the totals into metrics.
#[derive(Debug, Clone)]
pub struct 缓存 {
    /// Weight configuration (cloned in `new`); decides which metrics are tracked.
    partial_weights: PartialWeights,
    /// Entry count given to `new`; used as the cutoff for tiers whose `top` is `None`.
    total_count: usize,
    /// Signed running sum of the frequencies passed to `增减`.
    total_frequency: i64,
    /// Running sum of `(code length - 1) * frequency`.
    total_pairs: i64,
    /// Denominator of the extended pair equivalence metric.
    /// NOTE(review): initialised to 0 in `new` and never updated in the visible code.
    total_extended_pairs: i64,
    /// Frequency accumulated per key number; `new` sizes it to `radix` slots.
    distribution: Vec<i64>,
    /// Frequency-weighted sum of values looked up in `当量信息`.
    total_pair_equivalence: f64,
    /// Numerator of the extended pair equivalence metric.
    /// NOTE(review): initialised to 0.0 in `new` and never updated in the visible code.
    total_extended_pair_equivalence: f64,
    /// Total frequency of entries flagged as duplicates.
    total_duplication: i64,
    /// Frequency-weighted fingering label counts, one slot per label index.
    total_fingering: [i64; 8],
    /// Frequency per configured level, parallel to `partial_weights.levels`.
    total_levels: Vec<i64>,
    /// Per-tier number of duplicate entries (changed by `sign`, not by frequency).
    tiers_duplication: Vec<i64>,
    /// Per-tier, per-level entry counts, parallel to each tier's `levels`.
    tiers_levels: Vec<Vec<i64>>,
    /// Per-tier fingering label counts (weighted by `sign`, not by frequency).
    tiers_fingering: Vec<[i64; 8]>,
    /// Modulus applied to a code to obtain the index into the lookup tables.
    max_index: u64,
    /// `radix^(最大按键组合长度 - 1)`; divisor used to step to the next key combination.
    segment: u64,
    /// `radix^0 ..= radix^8`; the index of the first breakpoint above a code is its length.
    length_breakpoints: Vec<u64>,
    /// Base of the positional code representation.
    radix: u64,
}
38
39impl 缓存 {
40 #[inline(always)]
41 pub fn 处理(
42 &mut self,
43 序号: usize,
44 频率: u64,
45 编码信息: &mut 部分编码信息,
46 参数: &默认目标函数参数,
47 ) {
48 if !编码信息.有变化 {
49 return;
50 }
51 编码信息.有变化 = false;
52 self.增减(序号, 频率, 编码信息.实际编码, 编码信息.选重标记, 参数, 1);
53 if 编码信息.上一个实际编码 == 0 {
54 return;
55 }
56 self.增减(
57 序号,
58 频率,
59 编码信息.上一个实际编码,
60 编码信息.上一个选重标记,
61 参数,
62 -1,
63 );
64 }
65
66 pub fn 汇总(&self, 参数: &默认目标函数参数) -> (分组指标, f64) {
67 let partial_weights = &self.partial_weights;
68 let 键位分布信息 = &参数.键位分布信息;
69 let mut 分组指标 = 分组指标 {
71 tiers: None,
72 key_distribution: None,
73 key_distribution_loss: None,
74 pair_equivalence: None,
75 extended_pair_equivalence: None,
76 fingering: None,
77 duplication: None,
78 levels: None,
79 };
80 let mut 损失函数 = 0.0;
81 if let Some(key_distribution_weight) = partial_weights.key_distribution {
84 let 总频率: i64 = self.distribution.iter().sum();
86 let 分布 = self
87 .distribution
88 .iter()
89 .map(|x| *x as f64 / 总频率 as f64)
90 .collect();
91 let 距离 = 缓存::计算键位分布距离(&分布, 键位分布信息);
92 let mut 分布映射 = HashMap::new();
93 for (i, x) in 分布.iter().enumerate() {
94 if let Some(键) = 参数.数字转键.get(&(i as u64)) {
95 分布映射.insert(*键, *x);
96 }
97 }
98 分组指标.key_distribution = Some(分布映射);
99 分组指标.key_distribution_loss = Some(距离);
100 损失函数 += 距离 * key_distribution_weight;
101 }
102 if let Some(equivalence_weight) = partial_weights.pair_equivalence {
104 let equivalence = self.total_pair_equivalence / self.total_pairs as f64;
105 分组指标.pair_equivalence = Some(equivalence);
106 损失函数 += equivalence * equivalence_weight;
107 }
108 if let Some(equivalence_weight) = partial_weights.extended_pair_equivalence {
110 let equivalence =
111 self.total_extended_pair_equivalence / self.total_extended_pairs as f64;
112 分组指标.extended_pair_equivalence = Some(equivalence);
113 损失函数 += equivalence * equivalence_weight;
114 }
115 if let Some(fingering_weight) = &partial_weights.fingering {
117 let mut fingering = FingeringMetric::default();
118 for (i, weight) in fingering_weight.iter().enumerate() {
119 if let Some(weight) = weight {
120 fingering[i] = Some(self.total_fingering[i] as f64 / self.total_pairs as f64);
121 损失函数 += self.total_fingering[i] as f64 * weight;
122 }
123 }
124 分组指标.fingering = Some(fingering);
125 }
126 if let Some(duplication_weight) = partial_weights.duplication {
128 let duplication = self.total_duplication as f64 / self.total_frequency as f64;
129 分组指标.duplication = Some(duplication);
130 损失函数 += duplication * duplication_weight;
131 }
132 if let Some(levels_weight) = &partial_weights.levels {
134 let mut levels: Vec<键长指标> = Vec::new();
135 for (ilevel, level) in levels_weight.iter().enumerate() {
136 let value = self.total_levels[ilevel] as f64 / self.total_frequency as f64;
137 损失函数 += value * level.frequency;
138 levels.push(键长指标 {
139 length: level.length,
140 frequency: value,
141 });
142 }
143 分组指标.levels = Some(levels);
144 }
145 if let Some(tiers_weight) = &partial_weights.tiers {
147 let mut tiers: Vec<层级指标> = tiers_weight
148 .iter()
149 .map(|x| 层级指标 {
150 top: x.top,
151 duplication: None,
152 levels: None,
153 fingering: None,
154 })
155 .collect();
156 for (itier, tier_weights) in tiers_weight.iter().enumerate() {
157 let count = tier_weights.top.unwrap_or(self.total_count) as f64;
158 if let Some(duplication_weight) = tier_weights.duplication {
160 let duplication = self.tiers_duplication[itier];
161 损失函数 += duplication as f64 / count * duplication_weight;
162 tiers[itier].duplication = Some(duplication as u64);
163 }
164 if let Some(level_weight) = &tier_weights.levels {
166 for (ilevel, level) in level_weight.iter().enumerate() {
167 损失函数 +=
168 self.tiers_levels[itier][ilevel] as f64 / count * level.frequency;
169 }
170 tiers[itier].levels = Some(
171 level_weight
172 .iter()
173 .enumerate()
174 .map(|(i, v)| LevelMetricUniform {
175 length: v.length,
176 frequency: self.tiers_levels[itier][i] as u64,
177 })
178 .collect(),
179 );
180 }
181 if let Some(fingering_weight) = &tier_weights.fingering {
183 let mut fingering = FingeringMetricUniform::default();
184 for (i, weight) in fingering_weight.iter().enumerate() {
185 if let Some(weight) = weight {
186 let value = self.tiers_fingering[itier][i];
187 fingering[i] = Some(value as u64);
188 损失函数 += value as f64 / count * weight;
189 }
190 }
191 tiers[itier].fingering = Some(fingering);
192 }
193 }
194 分组指标.tiers = Some(tiers);
195 }
196 (分组指标, 损失函数)
197 }
198}
199
200impl 缓存 {
201 pub fn new(
202 partial_weights: &PartialWeights,
203 radix: u64,
204 total_count: usize,
205 max_index: u64,
206 ) -> Self {
207 let total_frequency = 0;
208 let total_pairs = 0;
209 let total_extended_pairs = 0;
210 let distribution = vec![0; radix as usize];
213 let total_pair_equivalence = 0.0;
214 let total_extended_pair_equivalence = 0.0;
215 let total_duplication = 0;
217 let total_fingering = [0; 8];
218 let nlevel = partial_weights.levels.as_ref().map_or(0, |v| v.len());
219 let total_levels = vec![0; nlevel];
220 let ntier = partial_weights.tiers.as_ref().map_or(0, |v| v.len());
222 let tiers_duplication = vec![0; ntier];
223 let mut tiers_levels = vec![];
224 if let Some(tiers) = &partial_weights.tiers {
225 for tier in tiers {
226 let vec = vec![0; tier.levels.as_ref().map_or(0, |v| v.len())];
227 tiers_levels.push(vec);
228 }
229 }
230 let tiers_fingering = vec![[0; 8]; ntier];
231 let segment = radix.pow((最大按键组合长度 - 1) as u32);
232 let length_breakpoints: Vec<u64> = (0..=8).map(|x| radix.pow(x)).collect();
233
234 Self {
235 partial_weights: partial_weights.clone(),
236 total_count,
237 total_frequency,
238 total_pairs,
239 total_extended_pairs,
240 distribution,
241 total_pair_equivalence,
242 total_extended_pair_equivalence,
243 total_duplication,
244 total_fingering,
245 total_levels,
246 tiers_duplication,
247 tiers_levels,
248 tiers_fingering,
249 max_index,
250 segment,
251 length_breakpoints,
252 radix,
253 }
254 }
255
256 fn 计算键位分布距离(
259 distribution: &Vec<f64>,
260 ideal_distribution: &Vec<键位分布损失函数>,
261 ) -> f64 {
262 let mut distance = 0.0;
263 for (frequency, loss) in zip(distribution, ideal_distribution) {
264 let diff = frequency - loss.ideal;
265 if diff > 0.0 {
266 distance += loss.gt_penalty * diff;
267 } else {
268 distance -= loss.lt_penalty * diff;
269 }
270 }
271 distance
272 }
273
274 #[inline(always)]
275 pub fn 增减(
276 &mut self,
277 index: usize,
278 frequency: u64,
279 code: 编码,
280 duplicate: bool,
281 parameters: &默认目标函数参数,
282 sign: i64,
283 ) {
284 let frequency = frequency as i64 * sign;
285 let radix = self.radix;
286 let length = self
287 .length_breakpoints
288 .iter()
289 .position(|&x| code < x)
290 .unwrap() as u64;
291 self.total_frequency += frequency;
292 self.total_pairs += (length - 1) as i64 * frequency;
293 let partial_weights = &self.partial_weights;
294 if partial_weights.key_distribution.is_some() {
297 let mut current = code;
298 while current > 0 {
299 let key = current % self.radix;
300 if let Some(x) = self.distribution.get_mut(key as usize) {
301 *x += frequency;
302 }
303 current /= self.radix;
304 }
305 }
306 if partial_weights.pair_equivalence.is_some() {
308 let mut code = code;
309 while code > self.radix {
310 let partial_code = (code % self.max_index) as usize;
311 self.total_pair_equivalence += parameters.当量信息[partial_code] * frequency as f64;
312 code /= self.segment;
313 }
314 }
315 if let Some(fingering) = &partial_weights.fingering {
317 let mut code = code;
318 while code > radix {
319 let label = parameters.指法计数[(code % self.max_index) as usize];
320 for (i, weight) in fingering.iter().enumerate() {
321 if weight.is_some() {
322 self.total_fingering[i] += frequency * label[i] as i64;
323 }
324 }
325 code /= self.segment;
326 }
327 }
328 if duplicate {
330 self.total_duplication += frequency;
331 }
332 if let Some(levels) = &partial_weights.levels {
334 for (ilevel, level) in levels.iter().enumerate() {
335 if level.length == length as usize {
336 self.total_levels[ilevel] += frequency;
337 }
338 }
339 }
340 if let Some(tiers) = &partial_weights.tiers {
342 for (itier, tier) in tiers.iter().enumerate() {
343 if index >= tier.top.unwrap_or(self.total_count) {
344 continue;
345 }
346 if duplicate {
348 self.tiers_duplication[itier] += sign;
349 }
350 if let Some(levels) = &tier.levels {
352 for (ilevel, level) in levels.iter().enumerate() {
353 if level.length == length as usize {
354 self.tiers_levels[itier][ilevel] += sign;
355 }
356 }
357 }
358 if let Some(fingering) = &tier.fingering {
360 let mut code = code;
361 while code > radix {
362 let label = parameters.指法计数[(code % self.max_index) as usize];
363 for (i, weight) in fingering.iter().enumerate() {
364 if weight.is_some() {
365 self.tiers_fingering[itier][i] += sign * label[i] as i64;
366 }
367 }
368 code /= self.segment;
369 }
370 }
371 }
372 }
373 }
374}