chai/objectives/
default.rs1use super::cache::Cache;
2use super::metric::Metric;
3use super::目标函数;
4use crate::config::PartialWeights;
5use crate::data::{数据, 用指标记, 编码信息, 键位分布损失函数};
6use crate::错误;
7
8#[derive(Clone)]
9pub struct 默认目标函数 {
10 parameters: Parameters,
11 buckets: Vec<[Option<Cache>; 2]>,
12}
13
14#[derive(Clone)]
15pub struct Parameters {
16 pub ideal_distribution: Vec<键位分布损失函数>,
17 pub pair_equivalence: Vec<f64>,
18 pub fingering_types: Vec<用指标记>,
19}
20
/// A list of `f64` values.
///
/// NOTE(review): presumably frequency values, judging by the name; nothing in
/// this part of the file uses the alias, so confirm against its callers.
pub type Frequencies = Vec<f64>;
22
/// Tag for one of the four partial objectives: characters or words, each
/// with a full and a short variant.
pub enum PartialType {
    CharactersFull,
    CharactersShort,
    WordsFull,
    WordsShort,
}

impl PartialType {
    /// Returns `true` for the two `Characters*` variants and `false` for the
    /// two `Words*` variants.
    pub fn is_characters(&self) -> bool {
        // Spelled out exhaustively so that adding a variant forces this
        // method to be revisited.
        match self {
            Self::CharactersFull | Self::CharactersShort => true,
            Self::WordsFull | Self::WordsShort => false,
        }
    }
}
35
36impl 默认目标函数 {
38 pub fn 新建(数据: &数据) -> Result<Self, 错误> {
40 let ideal_distribution = 数据.键位分布信息.clone();
41 let pair_equivalence = 数据.当量信息.clone();
42 let fingering_types = 数据.预处理指法标记();
43 let config = 数据
44 .配置
45 .optimization
46 .as_ref()
47 .ok_or("优化配置不存在")?
48 .objective
49 .clone();
50 let radix = 数据.进制;
51 let max_index = pair_equivalence.len() as u64;
52 let make_cache = |x: &PartialWeights| Cache::new(x, radix, 数据.词列表.len(), max_index);
53 let cf = config.characters_full.as_ref().map(make_cache);
54 let cs = config.characters_short.as_ref().map(make_cache);
55 let wf = config.words_full.as_ref().map(make_cache);
56 let ws = config.words_short.as_ref().map(make_cache);
57 let buckets = vec![[cf, cs], [wf, ws]];
58 let parameters = Parameters {
59 ideal_distribution,
60 pair_equivalence,
61 fingering_types,
62 };
63 let objective = Self {
64 parameters,
65 buckets,
66 };
67 Ok(objective)
68 }
69}
70
71impl 目标函数 for 默认目标函数 {
72 fn 计算(&mut self, 编码结果: &mut [编码信息]) -> (Metric, f64) {
74 let parameters = &self.parameters;
75
76 for (index, code_info) in 编码结果.iter_mut().enumerate() {
78 let frequency = code_info.频率;
79 let bucket = if code_info.词长 == 1 {
80 &mut self.buckets[0]
81 } else {
82 &mut self.buckets[1]
83 };
84 if let Some(cache) = &mut bucket[0] {
85 cache.process(index, frequency, &mut code_info.全码, parameters);
86 }
87 if let Some(cache) = &mut bucket[1] {
88 cache.process(index, frequency, &mut code_info.简码, parameters);
89 }
90 }
91
92 let mut loss = 0.0;
93 let mut metric = Metric {
94 characters_full: None,
95 words_full: None,
96 characters_short: None,
97 words_short: None,
98 };
99 for (index, bucket) in self.buckets.iter().enumerate() {
100 let _ = &bucket[0].as_ref().map(|x| {
101 let (partial, accum) = x.finalize(parameters);
102 loss += accum;
103 if index == 0 {
104 metric.characters_full = Some(partial);
105 } else {
106 metric.words_full = Some(partial);
107 }
108 });
109 let _ = &bucket[1].as_ref().map(|x| {
110 let (partial, accum) = x.finalize(parameters);
111 loss += accum;
112 if index == 0 {
113 metric.characters_short = Some(partial);
114 } else {
115 metric.words_short = Some(partial);
116 }
117 });
118 }
119
120 (metric, loss)
121 }
122}