chai/objectives/
default.rs

1use super::cache::Cache;
2use super::metric::Metric;
3use super::目标函数;
4use crate::config::PartialWeights;
5use crate::data::{数据, 用指标记, 编码信息, 键位分布损失函数};
6use crate::错误;
7
/// Default objective function.
///
/// Holds the scoring parameters together with the per-partial caches that
/// `计算` feeds and finalizes.
#[derive(Clone)]
pub struct 默认目标函数 {
    /// Parameters passed to every cache when processing and finalizing.
    parameters: Parameters,
    /// Cache slots, laid out by `新建` as
    /// `[[characters_full, characters_short], [words_full, words_short]]`.
    /// A slot is `None` when the matching objective is absent from the config.
    buckets: Vec<[Option<Cache>; 2]>,
}
13
/// Read-only inputs shared by all caches of the default objective function.
#[derive(Clone)]
pub struct Parameters {
    /// Copied from `数据.键位分布信息` when the objective is built.
    pub ideal_distribution: Vec<键位分布损失函数>,
    /// Copied from `数据.当量信息`; its length is also handed to `Cache::new`
    /// as the maximum index.
    pub pair_equivalence: Vec<f64>,
    /// Result of `数据.预处理指法标记()` at construction time.
    pub fingering_types: Vec<用指标记>,
}
20
/// A list of `f64` values.
///
/// NOTE(review): not referenced anywhere in this file; presumably per-item
/// frequencies — confirm the indexing convention with its users.
pub type Frequencies = Vec<f64>;
22
/// Names the four partial objectives: full or short codes, for single
/// characters or for words.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PartialType {
    /// Full codes of single characters.
    CharactersFull,
    /// Short codes of single characters.
    CharactersShort,
    /// Full codes of words.
    WordsFull,
    /// Short codes of words.
    WordsShort,
}

impl PartialType {
    /// Returns `true` for the two character variants and `false` for the
    /// two word variants.
    pub fn is_characters(&self) -> bool {
        match self {
            Self::CharactersFull | Self::CharactersShort => true,
            Self::WordsFull | Self::WordsShort => false,
        }
    }
}
35
36/// 目标函数
37impl 默认目标函数 {
38    /// 通过传入配置表示、编码器和共用资源来构造一个目标函数
39    pub fn 新建(数据: &数据) -> Result<Self, 错误> {
40        let ideal_distribution = 数据.键位分布信息.clone();
41        let pair_equivalence = 数据.当量信息.clone();
42        let fingering_types = 数据.预处理指法标记();
43        let config = 数据
44            .配置
45            .optimization
46            .as_ref()
47            .ok_or("优化配置不存在")?
48            .objective
49            .clone();
50        let radix = 数据.进制;
51        let max_index = pair_equivalence.len() as u64;
52        let make_cache = |x: &PartialWeights| Cache::new(x, radix, 数据.词列表.len(), max_index);
53        let cf = config.characters_full.as_ref().map(make_cache);
54        let cs = config.characters_short.as_ref().map(make_cache);
55        let wf = config.words_full.as_ref().map(make_cache);
56        let ws = config.words_short.as_ref().map(make_cache);
57        let buckets = vec![[cf, cs], [wf, ws]];
58        let parameters = Parameters {
59            ideal_distribution,
60            pair_equivalence,
61            fingering_types,
62        };
63        let objective = Self {
64            parameters,
65            buckets,
66        };
67        Ok(objective)
68    }
69}
70
71impl 目标函数 for 默认目标函数 {
72    /// 计算各个部分编码的指标,然后将它们合并成一个指标输出
73    fn 计算(&mut self, 编码结果: &mut [编码信息]) -> (Metric, f64) {
74        let parameters = &self.parameters;
75
76        // 开始计算指标
77        for (index, code_info) in 编码结果.iter_mut().enumerate() {
78            let frequency = code_info.频率;
79            let bucket = if code_info.词长 == 1 {
80                &mut self.buckets[0]
81            } else {
82                &mut self.buckets[1]
83            };
84            if let Some(cache) = &mut bucket[0] {
85                cache.process(index, frequency, &mut code_info.全码, parameters);
86            }
87            if let Some(cache) = &mut bucket[1] {
88                cache.process(index, frequency, &mut code_info.简码, parameters);
89            }
90        }
91
92        let mut loss = 0.0;
93        let mut metric = Metric {
94            characters_full: None,
95            words_full: None,
96            characters_short: None,
97            words_short: None,
98        };
99        for (index, bucket) in self.buckets.iter().enumerate() {
100            let _ = &bucket[0].as_ref().map(|x| {
101                let (partial, accum) = x.finalize(parameters);
102                loss += accum;
103                if index == 0 {
104                    metric.characters_full = Some(partial);
105                } else {
106                    metric.words_full = Some(partial);
107                }
108            });
109            let _ = &bucket[1].as_ref().map(|x| {
110                let (partial, accum) = x.finalize(parameters);
111                loss += accum;
112                if index == 0 {
113                    metric.characters_short = Some(partial);
114                } else {
115                    metric.words_short = Some(partial);
116                }
117            });
118        }
119
120        (metric, loss)
121    }
122}