chai/objectives/
cache.rs

1use super::default::Parameters;
2use super::metric::FingeringMetric;
3use super::metric::FingeringMetricUniform;
4use super::metric::LevelMetric;
5use super::metric::LevelMetricUniform;
6use super::metric::PartialMetric;
7use super::metric::TierMetric;
8use crate::config::PartialWeights;
9use crate::data::最大按键组合长度;
10use crate::data::编码;
11use crate::data::部分编码信息;
12use crate::data::键位分布损失函数;
13use std::iter::zip;
14
15#[derive(Debug, Clone)]
16pub struct Cache {
17    partial_weights: PartialWeights,
18    total_count: usize,
19    total_frequency: i64,
20    total_pairs: i64,
21    total_extended_pairs: i64,
22    distribution: Vec<i64>,
23    total_pair_equivalence: f64,
24    total_extended_pair_equivalence: f64,
25    total_duplication: i64,
26    total_fingering: [i64; 8],
27    total_levels: Vec<i64>,
28    tiers_duplication: Vec<i64>,
29    tiers_levels: Vec<Vec<i64>>,
30    tiers_fingering: Vec<[i64; 8]>,
31    max_index: u64,
32    segment: u64,
33    length_breakpoints: Vec<u64>,
34    radix: u64,
35}
36
37impl Cache {
38    #[inline(always)]
39    pub fn process(
40        &mut self,
41        index: usize,
42        frequency: u64,
43        c: &mut 部分编码信息,
44        parameters: &Parameters,
45    ) {
46        if !c.有变化 {
47            return;
48        }
49        c.有变化 = false;
50        self.accumulate(index, frequency, c.实际编码, c.选重标记, parameters, 1);
51        if c.上一个实际编码 == 0 {
52            return;
53        }
54        self.accumulate(
55            index,
56            frequency,
57            c.上一个实际编码,
58            c.上一个选重标记,
59            parameters,
60            -1,
61        );
62    }
63
64    pub fn finalize(&self, parameters: &Parameters) -> (PartialMetric, f64) {
65        let partial_weights = &self.partial_weights;
66        let ideal_distribution = &parameters.ideal_distribution;
67        // 初始化返回值和标量化的损失函数
68        let mut partial_metric = PartialMetric {
69            tiers: None,
70            key_distribution: None,
71            pair_equivalence: None,
72            extended_pair_equivalence: None,
73            fingering: None,
74            duplication: None,
75            levels: None,
76        };
77        let mut loss = 0.0;
78        // 一、全局指标
79        // 1. 按键分布
80        if let Some(key_distribution_weight) = partial_weights.key_distribution {
81            // 首先归一化
82            let total: i64 = self.distribution.iter().sum();
83            let distribution = self
84                .distribution
85                .iter()
86                .map(|x| *x as f64 / total as f64)
87                .collect();
88            let distance = Cache::get_distribution_distance(&distribution, ideal_distribution);
89            partial_metric.key_distribution = Some(distance);
90            loss += distance * key_distribution_weight;
91        }
92        // 2. 组合当量
93        if let Some(equivalence_weight) = partial_weights.pair_equivalence {
94            let equivalence = self.total_pair_equivalence / self.total_pairs as f64;
95            partial_metric.pair_equivalence = Some(equivalence);
96            loss += equivalence * equivalence_weight;
97        }
98        // 3. 词间当量
99        if let Some(equivalence_weight) = partial_weights.extended_pair_equivalence {
100            let equivalence =
101                self.total_extended_pair_equivalence / self.total_extended_pairs as f64;
102            partial_metric.extended_pair_equivalence = Some(equivalence);
103            loss += equivalence * equivalence_weight;
104        }
105        // 4. 差指法
106        if let Some(fingering_weight) = &partial_weights.fingering {
107            let mut fingering = FingeringMetric::default();
108            for (i, weight) in fingering_weight.iter().enumerate() {
109                if let Some(weight) = weight {
110                    fingering[i] = Some(self.total_fingering[i] as f64 / self.total_pairs as f64);
111                    loss += self.total_fingering[i] as f64 * weight;
112                }
113            }
114            partial_metric.fingering = Some(fingering);
115        }
116        // 5. 重码
117        if let Some(duplication_weight) = partial_weights.duplication {
118            let duplication = self.total_duplication as f64 / self.total_frequency as f64;
119            partial_metric.duplication = Some(duplication);
120            loss += duplication * duplication_weight;
121        }
122        // 6. 简码
123        if let Some(levels_weight) = &partial_weights.levels {
124            let mut levels: Vec<LevelMetric> = Vec::new();
125            for (ilevel, level) in levels_weight.iter().enumerate() {
126                let value = self.total_levels[ilevel] as f64 / self.total_frequency as f64;
127                loss += value * level.frequency;
128                levels.push(LevelMetric {
129                    length: level.length,
130                    frequency: value,
131                });
132            }
133            partial_metric.levels = Some(levels);
134        }
135        // 二、分级指标
136        if let Some(tiers_weight) = &partial_weights.tiers {
137            let mut tiers: Vec<TierMetric> = tiers_weight
138                .iter()
139                .map(|x| TierMetric {
140                    top: x.top,
141                    duplication: None,
142                    levels: None,
143                    fingering: None,
144                })
145                .collect();
146            for (itier, tier_weights) in tiers_weight.iter().enumerate() {
147                let count = tier_weights.top.unwrap_or(self.total_count) as f64;
148                // 1. 重码
149                if let Some(duplication_weight) = tier_weights.duplication {
150                    let duplication = self.tiers_duplication[itier];
151                    loss += duplication as f64 / count * duplication_weight;
152                    tiers[itier].duplication = Some(duplication as u64);
153                }
154                // 2. 简码
155                if let Some(level_weight) = &tier_weights.levels {
156                    for (ilevel, level) in level_weight.iter().enumerate() {
157                        loss += self.tiers_levels[itier][ilevel] as f64 / count * level.frequency;
158                    }
159                    tiers[itier].levels = Some(
160                        level_weight
161                            .iter()
162                            .enumerate()
163                            .map(|(i, v)| LevelMetricUniform {
164                                length: v.length,
165                                frequency: self.tiers_levels[itier][i] as u64,
166                            })
167                            .collect(),
168                    );
169                }
170                // 3. 差指法
171                if let Some(fingering_weight) = &tier_weights.fingering {
172                    let mut fingering = FingeringMetricUniform::default();
173                    for (i, weight) in fingering_weight.iter().enumerate() {
174                        if let Some(weight) = weight {
175                            let value = self.tiers_fingering[itier][i];
176                            fingering[i] = Some(value as u64);
177                            loss += value as f64 / count * weight;
178                        }
179                    }
180                    tiers[itier].fingering = Some(fingering);
181                }
182            }
183            partial_metric.tiers = Some(tiers);
184        }
185        (partial_metric, loss)
186    }
187}
188
189impl Cache {
190    pub fn new(
191        partial_weights: &PartialWeights,
192        radix: u64,
193        total_count: usize,
194        max_index: u64,
195    ) -> Self {
196        let total_frequency = 0;
197        let total_pairs = 0;
198        let total_extended_pairs = 0;
199        // 初始化全局指标的变量
200        // 1. 只有加权指标,没有计数指标
201        let distribution = vec![0; radix as usize];
202        let total_pair_equivalence = 0.0;
203        let total_extended_pair_equivalence = 0.0;
204        // 2. 有加权指标,也有计数指标
205        let total_duplication = 0;
206        let total_fingering = [0; 8];
207        let nlevel = partial_weights.levels.as_ref().map_or(0, |v| v.len());
208        let total_levels = vec![0; nlevel];
209        // 初始化分级指标的变量
210        let ntier = partial_weights.tiers.as_ref().map_or(0, |v| v.len());
211        let tiers_duplication = vec![0; ntier];
212        let mut tiers_levels = vec![];
213        if let Some(tiers) = &partial_weights.tiers {
214            for tier in tiers {
215                let vec = vec![0; tier.levels.as_ref().map_or(0, |v| v.len())];
216                tiers_levels.push(vec);
217            }
218        }
219        let tiers_fingering = vec![[0; 8]; ntier];
220        let segment = radix.pow((最大按键组合长度 - 1) as u32);
221        let length_breakpoints: Vec<u64> = (0..=8).map(|x| radix.pow(x)).collect();
222
223        Self {
224            partial_weights: partial_weights.clone(),
225            total_count,
226            total_frequency,
227            total_pairs,
228            total_extended_pairs,
229            distribution,
230            total_pair_equivalence,
231            total_extended_pair_equivalence,
232            total_duplication,
233            total_fingering,
234            total_levels,
235            tiers_duplication,
236            tiers_levels,
237            tiers_fingering,
238            max_index,
239            segment,
240            length_breakpoints,
241            radix,
242        }
243    }
244
245    /// 用指分布偏差
246    /// 计算按键使用率与理想使用率之间的偏差。对于每个按键,偏差是实际频率与理想频率之间的差值乘以一个惩罚系数。用户可以根据自己的喜好自定义理想频率和惩罚系数。
247    fn get_distribution_distance(
248        distribution: &Vec<f64>,
249        ideal_distribution: &Vec<键位分布损失函数>,
250    ) -> f64 {
251        let mut distance = 0.0;
252        for (frequency, loss) in zip(distribution, ideal_distribution) {
253            let diff = frequency - loss.ideal;
254            if diff > 0.0 {
255                distance += loss.gt_penalty * diff;
256            } else {
257                distance -= loss.lt_penalty * diff;
258            }
259        }
260        distance
261    }
262
263    #[inline(always)]
264    pub fn accumulate(
265        &mut self,
266        index: usize,
267        frequency: u64,
268        code: 编码,
269        duplicate: bool,
270        parameters: &Parameters,
271        sign: i64,
272    ) {
273        let frequency = frequency as i64 * sign;
274        let radix = self.radix;
275        let length = self
276            .length_breakpoints
277            .iter()
278            .position(|&x| code < x)
279            .unwrap() as u64;
280        self.total_frequency += frequency;
281        self.total_pairs += (length - 1) as i64 * frequency;
282        let partial_weights = &self.partial_weights;
283        // 一、全局指标
284        // 1. 按键分布
285        if partial_weights.key_distribution.is_some() {
286            let mut current = code;
287            while current > 0 {
288                let key = current % self.radix;
289                if let Some(x) = self.distribution.get_mut(key as usize) {
290                    *x += frequency;
291                }
292                current /= self.radix;
293            }
294        }
295        // 2. 组合当量
296        if partial_weights.pair_equivalence.is_some() {
297            let mut code = code;
298            while code > self.radix {
299                let partial_code = (code % self.max_index) as usize;
300                self.total_pair_equivalence +=
301                    parameters.pair_equivalence[partial_code] * frequency as f64;
302                code /= self.segment;
303            }
304        }
305        // 4. 差指法
306        if let Some(fingering) = &partial_weights.fingering {
307            let mut code = code;
308            while code > radix {
309                let label = parameters.fingering_types[(code % self.max_index) as usize];
310                for (i, weight) in fingering.iter().enumerate() {
311                    if weight.is_some() {
312                        self.total_fingering[i] += frequency * label[i] as i64;
313                    }
314                }
315                code /= self.segment;
316            }
317        }
318        // 5. 重码
319        if duplicate {
320            self.total_duplication += frequency;
321        }
322        // 6. 简码
323        if let Some(levels) = &partial_weights.levels {
324            for (ilevel, level) in levels.iter().enumerate() {
325                if level.length == length as usize {
326                    self.total_levels[ilevel] += frequency;
327                }
328            }
329        }
330        // 二、分级指标
331        if let Some(tiers) = &partial_weights.tiers {
332            for (itier, tier) in tiers.iter().enumerate() {
333                if index >= tier.top.unwrap_or(self.total_count) {
334                    continue;
335                }
336                // 1. 重码
337                if duplicate {
338                    self.tiers_duplication[itier] += sign;
339                }
340                // 2. 简码
341                if let Some(levels) = &tier.levels {
342                    for (ilevel, level) in levels.iter().enumerate() {
343                        if level.length == length as usize {
344                            self.tiers_levels[itier][ilevel] += sign;
345                        }
346                    }
347                }
348                // 3. 差指法
349                if let Some(fingering) = &tier.fingering {
350                    let mut code = code;
351                    while code > radix {
352                        let label = parameters.fingering_types[(code % self.max_index) as usize];
353                        for (i, weight) in fingering.iter().enumerate() {
354                            if weight.is_some() {
355                                self.tiers_fingering[itier][i] += sign * label[i] as i64;
356                            }
357                        }
358                        code /= self.segment;
359                    }
360                }
361            }
362        }
363    }
364}