1use rustc_hash::FxHashMap;
2
3use super::default::默认目标函数参数;
4use super::metric::FingeringMetric;
5use super::metric::FingeringMetricUniform;
6use super::metric::LevelMetricUniform;
7use super::metric::分组指标;
8use super::metric::层级指标;
9use super::metric::键长指标;
10use crate::config::部分权重;
11use crate::{
12 最大按键组合长度, 编码, 部分编码信息, 键位分布损失函数
13};
14use std::iter::zip;
15
/// Incrementally maintained counters from which the partial objective (metrics + loss) is derived.
///
/// Counters are updated entry by entry through `增减` and turned into metrics by `汇总`.
#[derive(Debug, Clone)]
pub struct 缓存 {
    /// Copy of the weight configuration; decides which counters are maintained and reported.
    partial_weights: 部分权重,
    /// Number of entries; used as the tier size when a tier has no explicit `top`.
    total_count: usize,
    /// Signed sum of the frequencies of all entries currently accounted for.
    total_frequency: i64,
    /// Sum over entries of `(code length - 1) * frequency`.
    total_pairs: i64,
    /// Denominator of the extended pair equivalence metric.
    /// NOTE(review): initialised to 0 and not updated anywhere in the visible code.
    total_extended_pairs: i64,
    /// Per-key accumulated frequency, indexed by key number (`radix` slots).
    distribution: Vec<i64>,
    /// Frequency-weighted sum of the equivalence values looked up for each key combination.
    total_pair_equivalence: f64,
    /// Numerator of the extended pair equivalence metric.
    /// NOTE(review): initialised to 0.0 and not updated anywhere in the visible code.
    total_extended_pair_equivalence: f64,
    /// Accumulated frequency of entries whose duplicate flag is non-zero.
    total_duplication: i64,
    /// Frequency-weighted counts for each of the eight fingering categories.
    total_fingering: [i64; 8],
    /// Per configured level: accumulated frequency of entries whose code has that length.
    total_levels: Vec<i64>,
    /// Per tier: number (not frequency) of entries with a non-zero duplicate flag.
    tiers_duplication: Vec<i64>,
    /// Per tier: running sum of `2 * duplicate - 1` over entries with a non-zero duplicate flag.
    tiers_duplication_squared: Vec<i64>,
    /// Per tier and configured level: number of entries whose code has that length.
    tiers_levels: Vec<Vec<i64>>,
    /// Per tier: unweighted counts for each of the eight fingering categories.
    tiers_fingering: Vec<[i64; 8]>,
    /// Modulus applied to a code to obtain the lookup index of its current key combination.
    max_index: u64,
    /// Divisor applied to a code to advance to the next key combination
    /// (`radix ^ (最大按键组合长度 - 1)`).
    segment: u64,
    /// `radix ^ 0 ..= radix ^ 8`; the first breakpoint exceeding a code gives its length.
    length_breakpoints: Vec<u64>,
    /// Numeric base in which codes store their keys.
    radix: u64,
}
39
40impl 缓存 {
41 #[inline(always)]
42 pub fn 处理(
43 &mut self,
44 序号: usize,
45 频率: u64,
46 编码信息: &mut 部分编码信息,
47 参数: &默认目标函数参数,
48 ) {
49 if !编码信息.有变化 {
50 return;
51 }
52 编码信息.有变化 = false;
53 self.增减(序号, 频率, 编码信息.实际编码, 编码信息.选重标记, 参数, 1);
54 if 编码信息.上一个实际编码 == 0 {
55 return;
56 }
57 self.增减(
58 序号,
59 频率,
60 编码信息.上一个实际编码,
61 编码信息.上一个选重标记,
62 参数,
63 -1,
64 );
65 }
66
67 pub fn 汇总(&self, 参数: &默认目标函数参数) -> (分组指标, f64) {
68 let partial_weights = &self.partial_weights;
69 let 键位分布信息 = &参数.键位分布信息;
70 let mut 分组指标 = 分组指标 {
72 tiers: None,
73 key_distribution: None,
74 key_distribution_loss: None,
75 pair_equivalence: None,
76 extended_pair_equivalence: None,
77 fingering: None,
78 duplication: None,
79 levels: None,
80 };
81 let mut 损失函数 = 0.0;
82 if let Some(key_distribution_weight) = partial_weights.key_distribution {
85 let 总频率: i64 = self.distribution.iter().sum();
87 let 分布 = self
88 .distribution
89 .iter()
90 .map(|x| *x as f64 / 总频率 as f64)
91 .collect();
92 let 距离 = 缓存::计算键位分布距离(&分布, 键位分布信息);
93 let mut 分布映射 = FxHashMap::default();
94 for (i, x) in 分布.iter().enumerate() {
95 if let Some(键) = 参数.数字转键.get(&(i as u64)) {
96 分布映射.insert(*键, *x);
97 }
98 }
99 分组指标.key_distribution = Some(分布映射);
100 分组指标.key_distribution_loss = Some(距离);
101 损失函数 += 距离 * key_distribution_weight;
102 }
103 if let Some(equivalence_weight) = partial_weights.pair_equivalence {
105 let equivalence = self.total_pair_equivalence / self.total_pairs as f64;
106 分组指标.pair_equivalence = Some(equivalence);
107 损失函数 += equivalence * equivalence_weight;
108 }
109 if let Some(equivalence_weight) = partial_weights.extended_pair_equivalence {
111 let equivalence =
112 self.total_extended_pair_equivalence / self.total_extended_pairs as f64;
113 分组指标.extended_pair_equivalence = Some(equivalence);
114 损失函数 += equivalence * equivalence_weight;
115 }
116 if let Some(fingering_weight) = &partial_weights.fingering {
118 let mut fingering = FingeringMetric::default();
119 for (i, weight) in fingering_weight.iter().enumerate() {
120 if let Some(weight) = weight {
121 fingering[i] = Some(self.total_fingering[i] as f64 / self.total_pairs as f64);
122 损失函数 += self.total_fingering[i] as f64 * weight;
123 }
124 }
125 分组指标.fingering = Some(fingering);
126 }
127 if let Some(duplication_weight) = partial_weights.duplication {
129 let duplication = self.total_duplication as f64 / self.total_frequency as f64;
130 分组指标.duplication = Some(duplication);
131 损失函数 += duplication * duplication_weight;
132 }
133 if let Some(levels_weight) = &partial_weights.levels {
135 let mut levels: Vec<键长指标> = Vec::new();
136 for (ilevel, level) in levels_weight.iter().enumerate() {
137 let value = self.total_levels[ilevel] as f64 / self.total_frequency as f64;
138 损失函数 += value * level.frequency;
139 levels.push(键长指标 {
140 length: level.length,
141 frequency: value,
142 });
143 }
144 分组指标.levels = Some(levels);
145 }
146 if let Some(tiers_weight) = &partial_weights.tiers {
148 let mut tiers: Vec<层级指标> = tiers_weight
149 .iter()
150 .map(|x| 层级指标 {
151 top: x.top,
152 duplication: None,
153 duplication_squared: None,
154 levels: None,
155 fingering: None,
156 })
157 .collect();
158 for (itier, tier_weights) in tiers_weight.iter().enumerate() {
159 let count = tier_weights.top.unwrap_or(self.total_count) as f64;
160 if let Some(duplication_weight) = tier_weights.duplication {
162 let duplication = self.tiers_duplication[itier];
163 损失函数 += duplication as f64 / count * duplication_weight;
164 tiers[itier].duplication = Some(duplication as u64);
165 }
166 if let Some(duplication_squared_weight) = tier_weights.duplication_squared {
168 let duplication_squared = self.tiers_duplication_squared[itier];
169 损失函数 += duplication_squared as f64 / count * duplication_squared_weight;
170 tiers[itier].duplication_squared = Some(duplication_squared as u64);
171 }
172 if let Some(level_weight) = &tier_weights.levels {
174 for (ilevel, level) in level_weight.iter().enumerate() {
175 损失函数 +=
176 self.tiers_levels[itier][ilevel] as f64 / count * level.frequency;
177 }
178 tiers[itier].levels = Some(
179 level_weight
180 .iter()
181 .enumerate()
182 .map(|(i, v)| LevelMetricUniform {
183 length: v.length,
184 frequency: self.tiers_levels[itier][i] as u64,
185 })
186 .collect(),
187 );
188 }
189 if let Some(fingering_weight) = &tier_weights.fingering {
191 let mut fingering = FingeringMetricUniform::default();
192 for (i, weight) in fingering_weight.iter().enumerate() {
193 if let Some(weight) = weight {
194 let value = self.tiers_fingering[itier][i];
195 fingering[i] = Some(value as u64);
196 损失函数 += value as f64 / count * weight;
197 }
198 }
199 tiers[itier].fingering = Some(fingering);
200 }
201 }
202 分组指标.tiers = Some(tiers);
203 }
204 (分组指标, 损失函数)
205 }
206}
207
208impl 缓存 {
209 pub fn new(
210 partial_weights: &部分权重,
211 radix: u64,
212 total_count: usize,
213 max_index: u64,
214 ) -> Self {
215 let total_frequency = 0;
216 let total_pairs = 0;
217 let total_extended_pairs = 0;
218 let distribution = vec![0; radix as usize];
221 let total_pair_equivalence = 0.0;
222 let total_extended_pair_equivalence = 0.0;
223 let total_duplication = 0;
225 let total_fingering = [0; 8];
226 let nlevel = partial_weights.levels.as_ref().map_or(0, |v| v.len());
227 let total_levels = vec![0; nlevel];
228 let ntier = partial_weights.tiers.as_ref().map_or(0, |v| v.len());
230 let tiers_duplication = vec![0; ntier];
231 let tiers_duplication_squared = vec![0; ntier];
232 let mut tiers_levels = vec![];
233 if let Some(tiers) = &partial_weights.tiers {
234 for tier in tiers {
235 let vec = vec![0; tier.levels.as_ref().map_or(0, |v| v.len())];
236 tiers_levels.push(vec);
237 }
238 }
239 let tiers_fingering = vec![[0; 8]; ntier];
240 let segment = radix.pow((最大按键组合长度 - 1) as u32);
241 let length_breakpoints: Vec<u64> = (0..=8).map(|x| radix.pow(x)).collect();
242
243 Self {
244 partial_weights: partial_weights.clone(),
245 total_count,
246 total_frequency,
247 total_pairs,
248 total_extended_pairs,
249 distribution,
250 total_pair_equivalence,
251 total_extended_pair_equivalence,
252 total_duplication,
253 total_fingering,
254 total_levels,
255 tiers_duplication,
256 tiers_duplication_squared,
257 tiers_levels,
258 tiers_fingering,
259 max_index,
260 segment,
261 length_breakpoints,
262 radix,
263 }
264 }
265
266 fn 计算键位分布距离(
269 distribution: &Vec<f64>,
270 ideal_distribution: &Vec<键位分布损失函数>,
271 ) -> f64 {
272 let mut distance = 0.0;
273 for (frequency, loss) in zip(distribution, ideal_distribution) {
274 let diff = frequency - loss.理想值;
275 if diff > 0.0 {
276 distance += loss.高于惩罚 * diff;
277 } else {
278 distance -= loss.低于惩罚 * diff;
279 }
280 }
281 distance
282 }
283
284 #[inline(always)]
285 pub fn 增减(
286 &mut self,
287 index: usize,
288 frequency: u64,
289 code: 编码,
290 duplicate: u8,
291 parameters: &默认目标函数参数,
292 sign: i64,
293 ) {
294 let frequency = frequency as i64 * sign;
295 let radix = self.radix;
296 let length = self
297 .length_breakpoints
298 .iter()
299 .position(|&x| code < x)
300 .unwrap() as u64;
301 self.total_frequency += frequency;
302 self.total_pairs += (length - 1) as i64 * frequency;
303 let partial_weights = &self.partial_weights;
304 if partial_weights.key_distribution.is_some() {
307 let mut current = code;
308 while current > 0 {
309 let key = current % self.radix;
310 if let Some(x) = self.distribution.get_mut(key as usize) {
311 *x += frequency;
312 }
313 current /= self.radix;
314 }
315 }
316 if partial_weights.pair_equivalence.is_some() {
318 let mut code = code;
319 while code > self.radix {
320 let partial_code = (code % self.max_index) as usize;
321 self.total_pair_equivalence += parameters.当量信息[partial_code] * frequency as f64;
322 code /= self.segment;
323 }
324 }
325 if let Some(fingering) = &partial_weights.fingering {
327 let mut code = code;
328 while code > radix {
329 let label = parameters.指法计数[(code % self.max_index) as usize];
330 for (i, weight) in fingering.iter().enumerate() {
331 if weight.is_some() {
332 self.total_fingering[i] += frequency * label[i] as i64;
333 }
334 }
335 code /= self.segment;
336 }
337 }
338 if duplicate > 0 {
340 self.total_duplication += frequency;
341 }
342 if let Some(levels) = &partial_weights.levels {
344 for (ilevel, level) in levels.iter().enumerate() {
345 if level.length == length as usize {
346 self.total_levels[ilevel] += frequency;
347 }
348 }
349 }
350 if let Some(tiers) = &partial_weights.tiers {
352 for (itier, tier) in tiers.iter().enumerate() {
353 if index >= tier.top.unwrap_or(self.total_count) {
354 continue;
355 }
356 if duplicate > 0 {
358 self.tiers_duplication[itier] += sign;
359 self.tiers_duplication_squared[itier] += sign * (2 * duplicate as i64 - 1);
360 }
361 if let Some(levels) = &tier.levels {
363 for (ilevel, level) in levels.iter().enumerate() {
364 if level.length == length as usize {
365 self.tiers_levels[itier][ilevel] += sign;
366 }
367 }
368 }
369 if let Some(fingering) = &tier.fingering {
371 let mut code = code;
372 while code > radix {
373 let label = parameters.指法计数[(code % self.max_index) as usize];
374 for (i, weight) in fingering.iter().enumerate() {
375 if weight.is_some() {
376 self.tiers_fingering[itier][i] += sign * label[i] as i64;
377 }
378 }
379 code /= self.segment;
380 }
381 }
382 }
383 }
384 }
385}