chai/objectives/
cache.rs

1use super::default::默认目标函数参数;
2use super::metric::FingeringMetric;
3use super::metric::FingeringMetricUniform;
4use super::metric::LevelMetricUniform;
5use super::metric::分组指标;
6use super::metric::层级指标;
7use super::metric::键长指标;
8use crate::config::PartialWeights;
9use crate::data::最大按键组合长度;
10use crate::data::编码;
11use crate::data::部分编码信息;
12use crate::data::键位分布损失函数;
13use std::collections::HashMap;
14use std::iter::zip;
15
16// 用于缓存计算目标函数的中间结果,方便实现增量计算
17#[derive(Debug, Clone)]
18pub struct 缓存 {
19    partial_weights: PartialWeights,
20    total_count: usize,
21    total_frequency: i64,
22    total_pairs: i64,
23    total_extended_pairs: i64,
24    distribution: Vec<i64>,
25    total_pair_equivalence: f64,
26    total_extended_pair_equivalence: f64,
27    total_duplication: i64,
28    total_fingering: [i64; 8],
29    total_levels: Vec<i64>,
30    tiers_duplication: Vec<i64>,
31    tiers_levels: Vec<Vec<i64>>,
32    tiers_fingering: Vec<[i64; 8]>,
33    max_index: u64,
34    segment: u64,
35    length_breakpoints: Vec<u64>,
36    radix: u64,
37}
38
39impl 缓存 {
40    #[inline(always)]
41    pub fn 处理(
42        &mut self,
43        序号: usize,
44        频率: u64,
45        编码信息: &mut 部分编码信息,
46        参数: &默认目标函数参数,
47    ) {
48        if !编码信息.有变化 {
49            return;
50        }
51        编码信息.有变化 = false;
52        self.增减(序号, 频率, 编码信息.实际编码, 编码信息.选重标记, 参数, 1);
53        if 编码信息.上一个实际编码 == 0 {
54            return;
55        }
56        self.增减(
57            序号,
58            频率,
59            编码信息.上一个实际编码,
60            编码信息.上一个选重标记,
61            参数,
62            -1,
63        );
64    }
65
66    pub fn 汇总(&self, 参数: &默认目标函数参数) -> (分组指标, f64) {
67        let partial_weights = &self.partial_weights;
68        let 键位分布信息 = &参数.键位分布信息;
69        // 初始化返回值和标量化的损失函数
70        let mut 分组指标 = 分组指标 {
71            tiers: None,
72            key_distribution: None,
73            key_distribution_loss: None,
74            pair_equivalence: None,
75            extended_pair_equivalence: None,
76            fingering: None,
77            duplication: None,
78            levels: None,
79        };
80        let mut 损失函数 = 0.0;
81        // 一、全局指标
82        // 1. 按键分布
83        if let Some(key_distribution_weight) = partial_weights.key_distribution {
84            // 首先归一化
85            let 总频率: i64 = self.distribution.iter().sum();
86            let 分布 = self
87                .distribution
88                .iter()
89                .map(|x| *x as f64 / 总频率 as f64)
90                .collect();
91            let 距离 = 缓存::计算键位分布距离(&分布, 键位分布信息);
92            let mut 分布映射 = HashMap::new();
93            for (i, x) in 分布.iter().enumerate() {
94                if let Some(键) = 参数.数字转键.get(&(i as u64)) {
95                    分布映射.insert(*键, *x);
96                }
97            }
98            分组指标.key_distribution = Some(分布映射);
99            分组指标.key_distribution_loss = Some(距离);
100            损失函数 += 距离 * key_distribution_weight;
101        }
102        // 2. 组合当量
103        if let Some(equivalence_weight) = partial_weights.pair_equivalence {
104            let equivalence = self.total_pair_equivalence / self.total_pairs as f64;
105            分组指标.pair_equivalence = Some(equivalence);
106            损失函数 += equivalence * equivalence_weight;
107        }
108        // 3. 词间当量
109        if let Some(equivalence_weight) = partial_weights.extended_pair_equivalence {
110            let equivalence =
111                self.total_extended_pair_equivalence / self.total_extended_pairs as f64;
112            分组指标.extended_pair_equivalence = Some(equivalence);
113            损失函数 += equivalence * equivalence_weight;
114        }
115        // 4. 差指法
116        if let Some(fingering_weight) = &partial_weights.fingering {
117            let mut fingering = FingeringMetric::default();
118            for (i, weight) in fingering_weight.iter().enumerate() {
119                if let Some(weight) = weight {
120                    fingering[i] = Some(self.total_fingering[i] as f64 / self.total_pairs as f64);
121                    损失函数 += self.total_fingering[i] as f64 * weight;
122                }
123            }
124            分组指标.fingering = Some(fingering);
125        }
126        // 5. 重码
127        if let Some(duplication_weight) = partial_weights.duplication {
128            let duplication = self.total_duplication as f64 / self.total_frequency as f64;
129            分组指标.duplication = Some(duplication);
130            损失函数 += duplication * duplication_weight;
131        }
132        // 6. 简码
133        if let Some(levels_weight) = &partial_weights.levels {
134            let mut levels: Vec<键长指标> = Vec::new();
135            for (ilevel, level) in levels_weight.iter().enumerate() {
136                let value = self.total_levels[ilevel] as f64 / self.total_frequency as f64;
137                损失函数 += value * level.frequency;
138                levels.push(键长指标 {
139                    length: level.length,
140                    frequency: value,
141                });
142            }
143            分组指标.levels = Some(levels);
144        }
145        // 二、分级指标
146        if let Some(tiers_weight) = &partial_weights.tiers {
147            let mut tiers: Vec<层级指标> = tiers_weight
148                .iter()
149                .map(|x| 层级指标 {
150                    top: x.top,
151                    duplication: None,
152                    levels: None,
153                    fingering: None,
154                })
155                .collect();
156            for (itier, tier_weights) in tiers_weight.iter().enumerate() {
157                let count = tier_weights.top.unwrap_or(self.total_count) as f64;
158                // 1. 重码
159                if let Some(duplication_weight) = tier_weights.duplication {
160                    let duplication = self.tiers_duplication[itier];
161                    损失函数 += duplication as f64 / count * duplication_weight;
162                    tiers[itier].duplication = Some(duplication as u64);
163                }
164                // 2. 简码
165                if let Some(level_weight) = &tier_weights.levels {
166                    for (ilevel, level) in level_weight.iter().enumerate() {
167                        损失函数 +=
168                            self.tiers_levels[itier][ilevel] as f64 / count * level.frequency;
169                    }
170                    tiers[itier].levels = Some(
171                        level_weight
172                            .iter()
173                            .enumerate()
174                            .map(|(i, v)| LevelMetricUniform {
175                                length: v.length,
176                                frequency: self.tiers_levels[itier][i] as u64,
177                            })
178                            .collect(),
179                    );
180                }
181                // 3. 差指法
182                if let Some(fingering_weight) = &tier_weights.fingering {
183                    let mut fingering = FingeringMetricUniform::default();
184                    for (i, weight) in fingering_weight.iter().enumerate() {
185                        if let Some(weight) = weight {
186                            let value = self.tiers_fingering[itier][i];
187                            fingering[i] = Some(value as u64);
188                            损失函数 += value as f64 / count * weight;
189                        }
190                    }
191                    tiers[itier].fingering = Some(fingering);
192                }
193            }
194            分组指标.tiers = Some(tiers);
195        }
196        (分组指标, 损失函数)
197    }
198}
199
200impl 缓存 {
201    pub fn new(
202        partial_weights: &PartialWeights,
203        radix: u64,
204        total_count: usize,
205        max_index: u64,
206    ) -> Self {
207        let total_frequency = 0;
208        let total_pairs = 0;
209        let total_extended_pairs = 0;
210        // 初始化全局指标的变量
211        // 1. 只有加权指标,没有计数指标
212        let distribution = vec![0; radix as usize];
213        let total_pair_equivalence = 0.0;
214        let total_extended_pair_equivalence = 0.0;
215        // 2. 有加权指标,也有计数指标
216        let total_duplication = 0;
217        let total_fingering = [0; 8];
218        let nlevel = partial_weights.levels.as_ref().map_or(0, |v| v.len());
219        let total_levels = vec![0; nlevel];
220        // 初始化分级指标的变量
221        let ntier = partial_weights.tiers.as_ref().map_or(0, |v| v.len());
222        let tiers_duplication = vec![0; ntier];
223        let mut tiers_levels = vec![];
224        if let Some(tiers) = &partial_weights.tiers {
225            for tier in tiers {
226                let vec = vec![0; tier.levels.as_ref().map_or(0, |v| v.len())];
227                tiers_levels.push(vec);
228            }
229        }
230        let tiers_fingering = vec![[0; 8]; ntier];
231        let segment = radix.pow((最大按键组合长度 - 1) as u32);
232        let length_breakpoints: Vec<u64> = (0..=8).map(|x| radix.pow(x)).collect();
233
234        Self {
235            partial_weights: partial_weights.clone(),
236            total_count,
237            total_frequency,
238            total_pairs,
239            total_extended_pairs,
240            distribution,
241            total_pair_equivalence,
242            total_extended_pair_equivalence,
243            total_duplication,
244            total_fingering,
245            total_levels,
246            tiers_duplication,
247            tiers_levels,
248            tiers_fingering,
249            max_index,
250            segment,
251            length_breakpoints,
252            radix,
253        }
254    }
255
256    /// 用指分布偏差
257    /// 计算按键使用率与理想使用率之间的偏差。对于每个按键,偏差是实际频率与理想频率之间的差值乘以一个惩罚系数。用户可以根据自己的喜好自定义理想频率和惩罚系数。
258    fn 计算键位分布距离(
259        distribution: &Vec<f64>,
260        ideal_distribution: &Vec<键位分布损失函数>,
261    ) -> f64 {
262        let mut distance = 0.0;
263        for (frequency, loss) in zip(distribution, ideal_distribution) {
264            let diff = frequency - loss.ideal;
265            if diff > 0.0 {
266                distance += loss.gt_penalty * diff;
267            } else {
268                distance -= loss.lt_penalty * diff;
269            }
270        }
271        distance
272    }
273
274    #[inline(always)]
275    pub fn 增减(
276        &mut self,
277        index: usize,
278        frequency: u64,
279        code: 编码,
280        duplicate: bool,
281        parameters: &默认目标函数参数,
282        sign: i64,
283    ) {
284        let frequency = frequency as i64 * sign;
285        let radix = self.radix;
286        let length = self
287            .length_breakpoints
288            .iter()
289            .position(|&x| code < x)
290            .unwrap() as u64;
291        self.total_frequency += frequency;
292        self.total_pairs += (length - 1) as i64 * frequency;
293        let partial_weights = &self.partial_weights;
294        // 一、全局指标
295        // 1. 按键分布
296        if partial_weights.key_distribution.is_some() {
297            let mut current = code;
298            while current > 0 {
299                let key = current % self.radix;
300                if let Some(x) = self.distribution.get_mut(key as usize) {
301                    *x += frequency;
302                }
303                current /= self.radix;
304            }
305        }
306        // 2. 组合当量
307        if partial_weights.pair_equivalence.is_some() {
308            let mut code = code;
309            while code > self.radix {
310                let partial_code = (code % self.max_index) as usize;
311                self.total_pair_equivalence += parameters.当量信息[partial_code] * frequency as f64;
312                code /= self.segment;
313            }
314        }
315        // 4. 差指法
316        if let Some(fingering) = &partial_weights.fingering {
317            let mut code = code;
318            while code > radix {
319                let label = parameters.指法计数[(code % self.max_index) as usize];
320                for (i, weight) in fingering.iter().enumerate() {
321                    if weight.is_some() {
322                        self.total_fingering[i] += frequency * label[i] as i64;
323                    }
324                }
325                code /= self.segment;
326            }
327        }
328        // 5. 重码
329        if duplicate {
330            self.total_duplication += frequency;
331        }
332        // 6. 简码
333        if let Some(levels) = &partial_weights.levels {
334            for (ilevel, level) in levels.iter().enumerate() {
335                if level.length == length as usize {
336                    self.total_levels[ilevel] += frequency;
337                }
338            }
339        }
340        // 二、分级指标
341        if let Some(tiers) = &partial_weights.tiers {
342            for (itier, tier) in tiers.iter().enumerate() {
343                if index >= tier.top.unwrap_or(self.total_count) {
344                    continue;
345                }
346                // 1. 重码
347                if duplicate {
348                    self.tiers_duplication[itier] += sign;
349                }
350                // 2. 简码
351                if let Some(levels) = &tier.levels {
352                    for (ilevel, level) in levels.iter().enumerate() {
353                        if level.length == length as usize {
354                            self.tiers_levels[itier][ilevel] += sign;
355                        }
356                    }
357                }
358                // 3. 差指法
359                if let Some(fingering) = &tier.fingering {
360                    let mut code = code;
361                    while code > radix {
362                        let label = parameters.指法计数[(code % self.max_index) as usize];
363                        for (i, weight) in fingering.iter().enumerate() {
364                            if weight.is_some() {
365                                self.tiers_fingering[itier][i] += sign * label[i] as i64;
366                            }
367                        }
368                        code /= self.segment;
369                    }
370                }
371            }
372        }
373    }
374}