chai/objectives/
default.rs1use rustc_hash::FxHashMap;
2
3use super::cache::缓存;
4use super::metric::默认指标;
5use super::目标函数;
6use crate::config::PartialWeights;
7use crate::data::{
8    元素映射, 指法向量, 数据, 正则化, 编码信息, 键位分布损失函数
9};
10use crate::错误;
11
12#[derive(Clone)]
13pub struct 默认目标函数 {
14    pub 参数: 默认目标函数参数,
15    pub 计数桶列表: Vec<[Option<缓存>; 2]>,
16}
17
18#[derive(Clone)]
19pub struct 默认目标函数参数 {
20    pub 键位分布信息: Vec<键位分布损失函数>,
21    pub 当量信息: Vec<f64>,
22    pub 指法计数: Vec<指法向量>,
23    pub 数字转键: FxHashMap<u64, char>,
24    pub 正则化: 正则化,
25    pub 正则化强度: f64,
26}
27
28pub type Frequencies = Vec<f64>;
29
30pub enum PartialType {
31    CharactersFull,
32    CharactersShort,
33    WordsFull,
34    WordsShort,
35}
36
37impl PartialType {
38    pub fn is_characters(&self) -> bool {
39        matches!(self, Self::CharactersFull | Self::CharactersShort)
40    }
41}
42
43impl 默认目标函数 {
45    pub fn 新建(数据: &数据) -> Result<Self, 错误> {
47        let 键位分布信息 = 数据.键位分布信息.clone();
48        let 当量信息 = 数据.当量信息.clone();
49        let 正则化 = 数据.正则化.clone();
50        let 指法计数 = 数据.预处理指法标记();
51        let config = 数据
52            .配置
53            .optimization
54            .as_ref()
55            .ok_or("优化配置不存在")?
56            .objective
57            .clone();
58        let 最大编码 = 当量信息.len() as u64;
59        let 构造缓存 = |x: &PartialWeights| 缓存::new(x, 数据.进制, 数据.词列表.len(), 最大编码);
60        let 一字全码 = config.characters_full.as_ref().map(构造缓存);
61        let 一字简码 = config.characters_short.as_ref().map(构造缓存);
62        let 多字全码 = config.words_full.as_ref().map(构造缓存);
63        let 多字简码 = config.words_short.as_ref().map(构造缓存);
64        let 计数桶列表 = vec![[一字全码, 一字简码], [多字全码, 多字简码]];
65        let 参数 = 默认目标函数参数 {
66            键位分布信息,
67            当量信息,
68            指法计数,
69            数字转键: 数据.数字转键.clone(),
70            正则化,
71            正则化强度: config
72                .regularization
73                .and_then(|x| x.strength)
74                .unwrap_or(1.0),
75        };
76        Ok(Self {
77            参数, 计数桶列表
78        })
79    }
80}
81
82impl 目标函数 for 默认目标函数 {
83    type 目标值 = 默认指标;
84
85    fn 计算(
87        &mut self, 编码结果: &mut [编码信息], 映射: &元素映射
88    ) -> (默认指标, f64) {
89        let 参数 = &self.参数;
90
91        let mut 桶序号列表: Vec<_> = self.计数桶列表.iter().map(|_| 0).collect();
92        for 编码信息 in 编码结果.iter_mut() {
94            let 频率 = 编码信息.频率;
95            let 桶索引 = if 编码信息.词长 == 1 { 0 } else { 1 };
96            let 桶 = &mut self.计数桶列表[桶索引];
97            let 桶序号 = 桶序号列表[桶索引];
98            if let Some(缓存) = &mut 桶[0] {
99                缓存.处理(桶序号, 频率, &mut 编码信息.全码, 参数);
100            }
101            if let Some(缓存) = &mut 桶[1] {
102                缓存.处理(桶序号, 频率, &mut 编码信息.简码, 参数);
103            }
104            桶序号列表[桶索引] += 1;
105        }
106
107        let mut 目标函数 = 0.0;
108        let mut 指标 = 默认指标 {
109            characters_full: None,
110            words_full: None,
111            characters_short: None,
112            words_short: None,
113            memory: None,
114        };
115        for (桶索引, 桶) in self.计数桶列表.iter().enumerate() {
116            let _ = &桶[0].as_ref().map(|x| {
117                let (分组指标, 分组目标函数) = x.汇总(参数);
118                目标函数 += 分组目标函数;
119                if 桶索引 == 0 {
120                    指标.characters_full = Some(分组指标);
121                } else {
122                    指标.words_full = Some(分组指标);
123                }
124            });
125            let _ = &桶[1].as_ref().map(|x| {
126                let (分组指标, 分组目标函数) = x.汇总(参数);
127                目标函数 += 分组目标函数;
128                if 桶索引 == 0 {
129                    指标.characters_short = Some(分组指标);
130                } else {
131                    指标.words_short = Some(分组指标);
132                }
133            });
134        }
135
136        if !参数.正则化.is_empty() {
137            let mut 记忆量 = 映射.len() as f64;
138            for (元素, 键) in 映射.iter().enumerate() {
139                if 元素 as u64 == *键 {
140                    记忆量 -= 1.0;
141                    continue;
142                }
143                if let Some(归并列表) = 参数.正则化.get(&元素) {
144                    let mut 最大亲和度 = 0.0;
145                    for (目标元素, 亲和度) in 归并列表.iter() {
146                        if 映射[*目标元素] == *键 {
147                            最大亲和度 = 亲和度.max(最大亲和度);
148                        }
149                    }
150                    记忆量 -= 最大亲和度;
151                }
152            }
153            指标.memory = Some(记忆量);
154            let 归一化记忆量 = 记忆量 / 映射.len() as f64;
155            目标函数 += 归一化记忆量 * 参数.正则化强度;
156        }
157        (指标, 目标函数)
158    }
159}