Skip to main content

chai/objectives/
cache.rs

1use rustc_hash::FxHashMap;
2
3use super::default::默认目标函数参数;
4use super::metric::FingeringMetric;
5use super::metric::FingeringMetricUniform;
6use super::metric::LevelMetricUniform;
7use super::metric::分组指标;
8use super::metric::层级指标;
9use super::metric::键长指标;
10use crate::config::部分权重;
11use crate::{
12    最大按键组合长度, 编码, 部分编码信息, 键位分布损失函数
13};
14use std::iter::zip;
15
16// 用于缓存计算目标函数的中间结果,方便实现增量计算
17#[derive(Debug, Clone)]
18pub struct 缓存 {
19    partial_weights: 部分权重,
20    total_count: usize,
21    total_frequency: i64,
22    total_pairs: i64,
23    total_extended_pairs: i64,
24    distribution: Vec<i64>,
25    total_pair_equivalence: f64,
26    total_extended_pair_equivalence: f64,
27    total_duplication: i64,
28    total_fingering: [i64; 8],
29    total_levels: Vec<i64>,
30    tiers_duplication: Vec<i64>,
31    tiers_duplication_squared: Vec<i64>,
32    tiers_levels: Vec<Vec<i64>>,
33    tiers_fingering: Vec<[i64; 8]>,
34    max_index: u64,
35    segment: u64,
36    length_breakpoints: Vec<u64>,
37    radix: u64,
38}
39
40impl 缓存 {
41    #[inline(always)]
42    pub fn 处理(
43        &mut self,
44        序号: usize,
45        频率: u64,
46        编码信息: &mut 部分编码信息,
47        参数: &默认目标函数参数,
48    ) {
49        if !编码信息.有变化 {
50            return;
51        }
52        编码信息.有变化 = false;
53        self.增减(序号, 频率, 编码信息.实际编码, 编码信息.选重标记, 参数, 1);
54        if 编码信息.上一个实际编码 == 0 {
55            return;
56        }
57        self.增减(
58            序号,
59            频率,
60            编码信息.上一个实际编码,
61            编码信息.上一个选重标记,
62            参数,
63            -1,
64        );
65    }
66
67    pub fn 汇总(&self, 参数: &默认目标函数参数) -> (分组指标, f64) {
68        let partial_weights = &self.partial_weights;
69        let 键位分布信息 = &参数.键位分布信息;
70        // 初始化返回值和标量化的损失函数
71        let mut 分组指标 = 分组指标 {
72            tiers: None,
73            key_distribution: None,
74            key_distribution_loss: None,
75            pair_equivalence: None,
76            extended_pair_equivalence: None,
77            fingering: None,
78            duplication: None,
79            levels: None,
80        };
81        let mut 损失函数 = 0.0;
82        // 一、全局指标
83        // 1. 按键分布
84        if let Some(key_distribution_weight) = partial_weights.key_distribution {
85            // 首先归一化
86            let 总频率: i64 = self.distribution.iter().sum();
87            let 分布 = self
88                .distribution
89                .iter()
90                .map(|x| *x as f64 / 总频率 as f64)
91                .collect();
92            let 距离 = 缓存::计算键位分布距离(&分布, 键位分布信息);
93            let mut 分布映射 = FxHashMap::default();
94            for (i, x) in 分布.iter().enumerate() {
95                if let Some(键) = 参数.数字转键.get(&(i as u64)) {
96                    分布映射.insert(*键, *x);
97                }
98            }
99            分组指标.key_distribution = Some(分布映射);
100            分组指标.key_distribution_loss = Some(距离);
101            损失函数 += 距离 * key_distribution_weight;
102        }
103        // 2. 组合当量
104        if let Some(equivalence_weight) = partial_weights.pair_equivalence {
105            let equivalence = self.total_pair_equivalence / self.total_pairs as f64;
106            分组指标.pair_equivalence = Some(equivalence);
107            损失函数 += equivalence * equivalence_weight;
108        }
109        // 3. 词间当量
110        if let Some(equivalence_weight) = partial_weights.extended_pair_equivalence {
111            let equivalence =
112                self.total_extended_pair_equivalence / self.total_extended_pairs as f64;
113            分组指标.extended_pair_equivalence = Some(equivalence);
114            损失函数 += equivalence * equivalence_weight;
115        }
116        // 4. 差指法
117        if let Some(fingering_weight) = &partial_weights.fingering {
118            let mut fingering = FingeringMetric::default();
119            for (i, weight) in fingering_weight.iter().enumerate() {
120                if let Some(weight) = weight {
121                    fingering[i] = Some(self.total_fingering[i] as f64 / self.total_pairs as f64);
122                    损失函数 += self.total_fingering[i] as f64 * weight;
123                }
124            }
125            分组指标.fingering = Some(fingering);
126        }
127        // 5. 重码
128        if let Some(duplication_weight) = partial_weights.duplication {
129            let duplication = self.total_duplication as f64 / self.total_frequency as f64;
130            分组指标.duplication = Some(duplication);
131            损失函数 += duplication * duplication_weight;
132        }
133        // 6. 简码
134        if let Some(levels_weight) = &partial_weights.levels {
135            let mut levels: Vec<键长指标> = Vec::new();
136            for (ilevel, level) in levels_weight.iter().enumerate() {
137                let value = self.total_levels[ilevel] as f64 / self.total_frequency as f64;
138                损失函数 += value * level.frequency;
139                levels.push(键长指标 {
140                    length: level.length,
141                    frequency: value,
142                });
143            }
144            分组指标.levels = Some(levels);
145        }
146        // 二、分级指标
147        if let Some(tiers_weight) = &partial_weights.tiers {
148            let mut tiers: Vec<层级指标> = tiers_weight
149                .iter()
150                .map(|x| 层级指标 {
151                    top: x.top,
152                    duplication: None,
153                    duplication_squared: None,
154                    levels: None,
155                    fingering: None,
156                })
157                .collect();
158            for (itier, tier_weights) in tiers_weight.iter().enumerate() {
159                let count = tier_weights.top.unwrap_or(self.total_count) as f64;
160                // 1. 重码
161                if let Some(duplication_weight) = tier_weights.duplication {
162                    let duplication = self.tiers_duplication[itier];
163                    损失函数 += duplication as f64 / count * duplication_weight;
164                    tiers[itier].duplication = Some(duplication as u64);
165                }
166                // 1b. 选重平方
167                if let Some(duplication_squared_weight) = tier_weights.duplication_squared {
168                    let duplication_squared = self.tiers_duplication_squared[itier];
169                    损失函数 += duplication_squared as f64 / count * duplication_squared_weight;
170                    tiers[itier].duplication_squared = Some(duplication_squared as u64);
171                }
172                // 2. 简码
173                if let Some(level_weight) = &tier_weights.levels {
174                    for (ilevel, level) in level_weight.iter().enumerate() {
175                        损失函数 +=
176                            self.tiers_levels[itier][ilevel] as f64 / count * level.frequency;
177                    }
178                    tiers[itier].levels = Some(
179                        level_weight
180                            .iter()
181                            .enumerate()
182                            .map(|(i, v)| LevelMetricUniform {
183                                length: v.length,
184                                frequency: self.tiers_levels[itier][i] as u64,
185                            })
186                            .collect(),
187                    );
188                }
189                // 3. 差指法
190                if let Some(fingering_weight) = &tier_weights.fingering {
191                    let mut fingering = FingeringMetricUniform::default();
192                    for (i, weight) in fingering_weight.iter().enumerate() {
193                        if let Some(weight) = weight {
194                            let value = self.tiers_fingering[itier][i];
195                            fingering[i] = Some(value as u64);
196                            损失函数 += value as f64 / count * weight;
197                        }
198                    }
199                    tiers[itier].fingering = Some(fingering);
200                }
201            }
202            分组指标.tiers = Some(tiers);
203        }
204        (分组指标, 损失函数)
205    }
206}
207
208impl 缓存 {
209    pub fn new(
210        partial_weights: &部分权重,
211        radix: u64,
212        total_count: usize,
213        max_index: u64,
214    ) -> Self {
215        let total_frequency = 0;
216        let total_pairs = 0;
217        let total_extended_pairs = 0;
218        // 初始化全局指标的变量
219        // 1. 只有加权指标,没有计数指标
220        let distribution = vec![0; radix as usize];
221        let total_pair_equivalence = 0.0;
222        let total_extended_pair_equivalence = 0.0;
223        // 2. 有加权指标,也有计数指标
224        let total_duplication = 0;
225        let total_fingering = [0; 8];
226        let nlevel = partial_weights.levels.as_ref().map_or(0, |v| v.len());
227        let total_levels = vec![0; nlevel];
228        // 初始化分级指标的变量
229        let ntier = partial_weights.tiers.as_ref().map_or(0, |v| v.len());
230        let tiers_duplication = vec![0; ntier];
231        let tiers_duplication_squared = vec![0; ntier];
232        let mut tiers_levels = vec![];
233        if let Some(tiers) = &partial_weights.tiers {
234            for tier in tiers {
235                let vec = vec![0; tier.levels.as_ref().map_or(0, |v| v.len())];
236                tiers_levels.push(vec);
237            }
238        }
239        let tiers_fingering = vec![[0; 8]; ntier];
240        let segment = radix.pow((最大按键组合长度 - 1) as u32);
241        let length_breakpoints: Vec<u64> = (0..=8).map(|x| radix.pow(x)).collect();
242
243        Self {
244            partial_weights: partial_weights.clone(),
245            total_count,
246            total_frequency,
247            total_pairs,
248            total_extended_pairs,
249            distribution,
250            total_pair_equivalence,
251            total_extended_pair_equivalence,
252            total_duplication,
253            total_fingering,
254            total_levels,
255            tiers_duplication,
256            tiers_duplication_squared,
257            tiers_levels,
258            tiers_fingering,
259            max_index,
260            segment,
261            length_breakpoints,
262            radix,
263        }
264    }
265
266    /// 用指分布偏差
267    /// 计算按键使用率与理想使用率之间的偏差。对于每个按键,偏差是实际频率与理想频率之间的差值乘以一个惩罚系数。用户可以根据自己的喜好自定义理想频率和惩罚系数。
268    fn 计算键位分布距离(
269        distribution: &Vec<f64>,
270        ideal_distribution: &Vec<键位分布损失函数>,
271    ) -> f64 {
272        let mut distance = 0.0;
273        for (frequency, loss) in zip(distribution, ideal_distribution) {
274            let diff = frequency - loss.理想值;
275            if diff > 0.0 {
276                distance += loss.高于惩罚 * diff;
277            } else {
278                distance -= loss.低于惩罚 * diff;
279            }
280        }
281        distance
282    }
283
284    #[inline(always)]
285    pub fn 增减(
286        &mut self,
287        index: usize,
288        frequency: u64,
289        code: 编码,
290        duplicate: u8,
291        parameters: &默认目标函数参数,
292        sign: i64,
293    ) {
294        let frequency = frequency as i64 * sign;
295        let radix = self.radix;
296        let length = self
297            .length_breakpoints
298            .iter()
299            .position(|&x| code < x)
300            .unwrap() as u64;
301        self.total_frequency += frequency;
302        self.total_pairs += (length - 1) as i64 * frequency;
303        let partial_weights = &self.partial_weights;
304        // 一、全局指标
305        // 1. 按键分布
306        if partial_weights.key_distribution.is_some() {
307            let mut current = code;
308            while current > 0 {
309                let key = current % self.radix;
310                if let Some(x) = self.distribution.get_mut(key as usize) {
311                    *x += frequency;
312                }
313                current /= self.radix;
314            }
315        }
316        // 2. 组合当量
317        if partial_weights.pair_equivalence.is_some() {
318            let mut code = code;
319            while code > self.radix {
320                let partial_code = (code % self.max_index) as usize;
321                self.total_pair_equivalence += parameters.当量信息[partial_code] * frequency as f64;
322                code /= self.segment;
323            }
324        }
325        // 4. 差指法
326        if let Some(fingering) = &partial_weights.fingering {
327            let mut code = code;
328            while code > radix {
329                let label = parameters.指法计数[(code % self.max_index) as usize];
330                for (i, weight) in fingering.iter().enumerate() {
331                    if weight.is_some() {
332                        self.total_fingering[i] += frequency * label[i] as i64;
333                    }
334                }
335                code /= self.segment;
336            }
337        }
338        // 5. 重码
339        if duplicate > 0 {
340            self.total_duplication += frequency;
341        }
342        // 6. 简码
343        if let Some(levels) = &partial_weights.levels {
344            for (ilevel, level) in levels.iter().enumerate() {
345                if level.length == length as usize {
346                    self.total_levels[ilevel] += frequency;
347                }
348            }
349        }
350        // 二、分级指标
351        if let Some(tiers) = &partial_weights.tiers {
352            for (itier, tier) in tiers.iter().enumerate() {
353                if index >= tier.top.unwrap_or(self.total_count) {
354                    continue;
355                }
356                // 1. 重码
357                if duplicate > 0 {
358                    self.tiers_duplication[itier] += sign;
359                    self.tiers_duplication_squared[itier] += sign * (2 * duplicate as i64 - 1);
360                }
361                // 2. 简码
362                if let Some(levels) = &tier.levels {
363                    for (ilevel, level) in levels.iter().enumerate() {
364                        if level.length == length as usize {
365                            self.tiers_levels[itier][ilevel] += sign;
366                        }
367                    }
368                }
369                // 3. 差指法
370                if let Some(fingering) = &tier.fingering {
371                    let mut code = code;
372                    while code > radix {
373                        let label = parameters.指法计数[(code % self.max_index) as usize];
374                        for (i, weight) in fingering.iter().enumerate() {
375                            if weight.is_some() {
376                                self.tiers_fingering[itier][i] += sign * label[i] as i64;
377                            }
378                        }
379                        code /= self.segment;
380                    }
381                }
382            }
383        }
384    }
385}