chai/objectives/
default.rs1use rustc_hash::FxHashMap;
2
3use super::cache::缓存;
4use super::metric::默认指标;
5use super::目标函数;
6use crate::config::部分权重;
7use crate::contexts::default::{默认上下文, 默认决策, 默认决策空间};
8use crate::encoders::编码器;
9use crate::错误;
10use crate::{元素, 指法向量, 编码信息, 键位分布损失函数};
11
12#[derive(Clone)]
13pub struct 默认目标函数<E: 编码器> {
14 pub 决策空间: 默认决策空间,
15 pub 参数: 默认目标函数参数,
16 pub 编码器: E,
17 pub 编码结果: Vec<编码信息>,
18 pub 计数桶列表: Vec<[Option<缓存>; 2]>,
19}
20
21#[derive(Clone)]
22pub struct 默认目标函数参数 {
23 pub 键位分布信息: Vec<键位分布损失函数>,
24 pub 当量信息: Vec<f64>,
25 pub 指法计数: Vec<指法向量>,
26 pub 数字转键: FxHashMap<u64, char>,
27 pub 正则化强度: f64,
28}
29
30pub type Frequencies = Vec<f64>;
31
32pub enum PartialType {
33 CharactersFull,
34 CharactersShort,
35 WordsFull,
36 WordsShort,
37}
38
39impl PartialType {
40 pub fn is_characters(&self) -> bool {
41 matches!(self, Self::CharactersFull | Self::CharactersShort)
42 }
43}
44
45impl<E: 编码器<决策 = 默认决策>> 默认目标函数<E> {
47 pub fn 新建(上下文: &默认上下文, 编码器: E) -> Result<Self, 错误> {
49 let 键位分布信息 = 上下文.键位分布信息.clone();
50 let 当量信息 = 上下文.当量信息.clone();
51 let 指法计数 = 上下文.棱镜.预处理指法标记(上下文.get_space());
52 let config = 上下文
53 .配置
54 .optimization
55 .as_ref()
56 .ok_or("优化配置不存在")?
57 .objective
58 .clone();
59 let 最大编码 = 当量信息.len() as u64;
60 let 构造缓存 =
61 |x: &部分权重| 缓存::new(x, 上下文.棱镜.进制, 上下文.词列表.len(), 最大编码);
62 let 一字全码 = config.characters_full.as_ref().map(构造缓存);
63 let 一字简码 = config.characters_short.as_ref().map(构造缓存);
64 let 多字全码 = config.words_full.as_ref().map(构造缓存);
65 let 多字简码 = config.words_short.as_ref().map(构造缓存);
66 let 计数桶列表 = vec![[一字全码, 一字简码], [多字全码, 多字简码]];
67 let 参数 = 默认目标函数参数 {
68 键位分布信息,
69 当量信息,
70 指法计数,
71 数字转键: 上下文.棱镜.数字转键.clone(),
72 正则化强度: config.regularization_strength.unwrap_or(1.0),
73 };
74 let 编码结果: Vec<_> = 上下文.词列表.iter().map(编码信息::new).collect();
75 Ok(Self {
76 参数,
77 编码器,
78 编码结果: 编码结果.clone(),
79 计数桶列表: 计数桶列表.clone(),
80 决策空间: 上下文.决策空间.clone(),
81 })
82 }
83
84 pub fn 计算复杂度(&self, 决策: &默认决策) -> f64 {
85 let mut 复杂度 = 0.0;
86 for (序号, 安排列表) in self.决策空间.元素.iter().enumerate() {
87 let 安排 = &决策.元素[序号];
88 let mut 分值 = 0.0;
89 for 条件安排 in 安排列表 {
90 if &条件安排.安排 == 安排 {
91 分值 = 条件安排.分数;
92 break;
93 }
94 }
95 复杂度 += 分值;
96 }
97 复杂度
98 }
99}
100
101impl<E: 编码器<决策 = 默认决策>> 目标函数 for 默认目标函数<E> {
102 type 目标值 = 默认指标;
103 type 决策 = 默认决策;
104
105 fn 计算(
107 &mut self, 决策: &默认决策, 变化: &Option<Vec<元素>>
108 ) -> (默认指标, f64) {
109 let 参数 = &self.参数;
110 self.编码器.编码(决策, 变化, &mut self.编码结果);
111 let mut 桶序号列表: Vec<_> = self.计数桶列表.iter().map(|_| 0).collect();
112 for 编码信息 in self.编码结果.iter_mut() {
114 let 频率 = 编码信息.频率;
115 let 桶索引 = if 编码信息.词长 == 1 { 0 } else { 1 };
116 let 桶 = &mut self.计数桶列表[桶索引];
117 let 桶序号 = 桶序号列表[桶索引];
118 if let Some(缓存) = &mut 桶[0] {
119 缓存.处理(桶序号, 频率, &mut 编码信息.全码, 参数);
120 }
121 if let Some(缓存) = &mut 桶[1] {
122 缓存.处理(桶序号, 频率, &mut 编码信息.简码, 参数);
123 }
124 桶序号列表[桶索引] += 1;
125 }
126
127 let mut 目标函数 = 0.0;
128 let mut 指标 = 默认指标 {
129 characters_full: None,
130 words_full: None,
131 characters_short: None,
132 words_short: None,
133 complexity: None,
134 };
135 for (桶索引, 桶) in self.计数桶列表.iter().enumerate() {
136 let _ = &桶[0].as_ref().map(|x| {
137 let (分组指标, 分组目标函数) = x.汇总(参数);
138 目标函数 += 分组目标函数;
139 if 桶索引 == 0 {
140 指标.characters_full = Some(分组指标);
141 } else {
142 指标.words_full = Some(分组指标);
143 }
144 });
145 let _ = &桶[1].as_ref().map(|x| {
146 let (分组指标, 分组目标函数) = x.汇总(参数);
147 目标函数 += 分组目标函数;
148 if 桶索引 == 0 {
149 指标.characters_short = Some(分组指标);
150 } else {
151 指标.words_short = Some(分组指标);
152 }
153 });
154 }
155 let 复杂度 = self.计算复杂度(决策);
156 指标.complexity = Some(复杂度);
157 目标函数 += 参数.正则化强度 * 复杂度;
158 (指标, 目标函数)
159 }
160}