use rustc_hash::FxHashMap;
use super::default::默认目标函数参数;
use super::metric::FingeringMetric;
use super::metric::FingeringMetricUniform;
use super::metric::LevelMetricUniform;
use super::metric::分组指标;
use super::metric::层级指标;
use super::metric::键长指标;
use crate::config::部分权重;
use crate::{
最大按键组合长度, 编码, 部分编码信息, 键位分布损失函数
};
use std::iter::zip;
/// Running totals used to evaluate one group's objective incrementally.
///
/// `增减` adds (or removes) a single code's contribution to these totals, and
/// `汇总` converts them into reported metrics plus a weighted loss.
#[derive(Debug, Clone)]
pub struct 缓存 {
/// Metric weights this cache was built for; only metrics whose weight is set
/// are accumulated by `增减` and reported by `汇总`.
partial_weights: 部分权重,
/// Tier size used when a tier has no explicit `top`.
total_count: usize,
/// Sum of the signed frequencies of all codes currently added.
total_frequency: i64,
/// Sum over added codes of `(length - 1) * frequency`.
total_pairs: i64,
/// Denominator of the extended pair equivalence metric.
/// NOTE(review): nothing in this file updates it — confirm it is accumulated elsewhere.
total_extended_pairs: i64,
/// Frequency-weighted count of each key digit, indexed by the digit value
/// (`code % radix`); created with `radix` slots.
distribution: Vec<i64>,
/// Frequency-weighted sum of `当量信息` lookups over each code's sub-codes.
total_pair_equivalence: f64,
/// Numerator of the extended pair equivalence metric.
/// NOTE(review): nothing in this file updates it — confirm it is accumulated elsewhere.
total_extended_pair_equivalence: f64,
/// Total frequency of codes whose duplicate mark is nonzero.
total_duplication: i64,
/// Frequency-weighted sums of the 8 `指法计数` categories.
total_fingering: [i64; 8],
/// Frequency of codes whose length matches each configured level
/// (same order as `partial_weights.levels`).
total_levels: Vec<i64>,
/// Per tier: number of entries (not frequency-weighted) with a nonzero duplicate mark.
tiers_duplication: Vec<i64>,
/// Per tier: sum of `2 * duplicate - 1` over duplicated entries
/// (presumably so that marks 1..k in one group add up to k² — confirm).
tiers_duplication_squared: Vec<i64>,
/// Per tier and level: number of entries whose code length matches the level.
tiers_levels: Vec<Vec<i64>>,
/// Per tier: unweighted sums of the 8 `指法计数` categories.
tiers_fingering: Vec<[i64; 8]>,
/// Modulus that extracts the sub-code used to index the equivalence and
/// fingering tables.
max_index: u64,
/// `radix^(最大按键组合长度 - 1)`; divisor used to step to the next sub-code.
segment: u64,
/// `radix^0 ..= radix^8`; a code's length is the index of the first
/// breakpoint greater than the code.
length_breakpoints: Vec<u64>,
/// Number base of codes; each base-`radix` digit is one key.
radix: u64,
}
impl 缓存 {
    /// Incrementally re-counts one entry after its encoding may have changed.
    ///
    /// Does nothing unless `编码信息.有变化` is set. Otherwise:
    /// - the flag is cleared;
    /// - the entry's current code (`实际编码` / `选重标记`) is added to the
    ///   running totals;
    /// - unless `上一个实际编码` is 0, the previously counted code
    ///   (`上一个实际编码` / `上一个选重标记`) is subtracted.
    ///
    /// * `序号` – the entry's index, checked against each tier's `top`.
    /// * `频率` – the entry's frequency.
    /// * `编码信息` – the entry's encoding state; its change flag is reset here.
    /// * `参数` – lookup tables forwarded to `增减`.
    #[inline(always)]
    pub fn 处理(
        &mut self,
        序号: usize,
        频率: u64,
        编码信息: &mut 部分编码信息,
        参数: &默认目标函数参数,
    ) {
        if !编码信息.有变化 {
            return;
        }
        编码信息.有变化 = false;
        self.增减(序号, 频率, 编码信息.实际编码, 编码信息.选重标记, 参数, 1);
        // A previous code of 0 is treated as "nothing counted before": skip the removal.
        if 编码信息.上一个实际编码 == 0 {
            return;
        }
        self.增减(
            序号,
            频率,
            编码信息.上一个实际编码,
            编码信息.上一个选重标记,
            参数,
            -1,
        );
    }

    /// Converts the running totals into reported metrics and a weighted loss.
    ///
    /// Only metrics whose weight is set in `partial_weights` are computed. Each
    /// metric is first normalised, by total frequency, total key pairs or tier
    /// size. The same normalised value is then both reported and multiplied by
    /// its weight, so every weight acts on a rate rather than a raw count.
    ///
    /// A zero denominator (e.g. no multi-key codes, or an empty tier) yields 0
    /// instead of NaN, which would otherwise make the whole loss NaN.
    ///
    /// Returns `(metrics, loss)`.
    pub fn 汇总(&self, 参数: &默认目标函数参数) -> (分组指标, f64) {
        // `numerator / denominator`, or 0.0 when the denominator is zero.
        fn 比值(numerator: f64, denominator: f64) -> f64 {
            if denominator == 0.0 {
                0.0
            } else {
                numerator / denominator
            }
        }

        let partial_weights = &self.partial_weights;
        let 键位分布信息 = &参数.键位分布信息;
        let mut 分组指标 = 分组指标 {
            tiers: None,
            key_distribution: None,
            key_distribution_loss: None,
            pair_equivalence: None,
            extended_pair_equivalence: None,
            fingering: None,
            duplication: None,
            levels: None,
        };
        let mut 损失函数 = 0.0;
        if let Some(key_distribution_weight) = partial_weights.key_distribution {
            let 总频率 = self.distribution.iter().sum::<i64>() as f64;
            let 分布: Vec<f64> = self
                .distribution
                .iter()
                .map(|&x| 比值(x as f64, 总频率))
                .collect();
            let 距离 = 缓存::计算键位分布距离(&分布, 键位分布信息);
            // Report the distribution keyed by the key itself rather than its digit value.
            let mut 分布映射 = FxHashMap::default();
            for (i, &x) in 分布.iter().enumerate() {
                if let Some(键) = 参数.数字转键.get(&(i as u64)) {
                    分布映射.insert(*键, x);
                }
            }
            分组指标.key_distribution = Some(分布映射);
            分组指标.key_distribution_loss = Some(距离);
            损失函数 += 距离 * key_distribution_weight;
        }
        if let Some(equivalence_weight) = partial_weights.pair_equivalence {
            let equivalence = 比值(self.total_pair_equivalence, self.total_pairs as f64);
            分组指标.pair_equivalence = Some(equivalence);
            损失函数 += equivalence * equivalence_weight;
        }
        if let Some(equivalence_weight) = partial_weights.extended_pair_equivalence {
            // NOTE(review): nothing in this file accumulates `total_extended_pair_equivalence`
            // or `total_extended_pairs`; unless they are updated elsewhere this metric is 0.
            let equivalence = 比值(
                self.total_extended_pair_equivalence,
                self.total_extended_pairs as f64,
            );
            分组指标.extended_pair_equivalence = Some(equivalence);
            损失函数 += equivalence * equivalence_weight;
        }
        if let Some(fingering_weight) = &partial_weights.fingering {
            let mut fingering = FingeringMetric::default();
            for (i, weight) in fingering_weight.iter().enumerate() {
                if let Some(weight) = weight {
                    // Penalise the same per-pair rate that is reported. The raw count
                    // would scale this term with corpus size, unlike every other term.
                    let value = 比值(self.total_fingering[i] as f64, self.total_pairs as f64);
                    fingering[i] = Some(value);
                    损失函数 += value * weight;
                }
            }
            分组指标.fingering = Some(fingering);
        }
        if let Some(duplication_weight) = partial_weights.duplication {
            let duplication = 比值(self.total_duplication as f64, self.total_frequency as f64);
            分组指标.duplication = Some(duplication);
            损失函数 += duplication * duplication_weight;
        }
        if let Some(levels_weight) = &partial_weights.levels {
            let mut levels: Vec<键长指标> = Vec::new();
            for (ilevel, level) in levels_weight.iter().enumerate() {
                let value = 比值(self.total_levels[ilevel] as f64, self.total_frequency as f64);
                损失函数 += value * level.frequency;
                levels.push(键长指标 {
                    length: level.length,
                    frequency: value,
                });
            }
            分组指标.levels = Some(levels);
        }
        if let Some(tiers_weight) = &partial_weights.tiers {
            let mut tiers: Vec<层级指标> = tiers_weight
                .iter()
                .map(|x| 层级指标 {
                    top: x.top,
                    duplication: None,
                    duplication_squared: None,
                    levels: None,
                    fingering: None,
                })
                .collect();
            for (itier, tier_weights) in tiers_weight.iter().enumerate() {
                // Tier totals count entries (not frequency), so normalise by tier size.
                let count = tier_weights.top.unwrap_or(self.total_count) as f64;
                if let Some(duplication_weight) = tier_weights.duplication {
                    let duplication = self.tiers_duplication[itier];
                    损失函数 += 比值(duplication as f64, count) * duplication_weight;
                    tiers[itier].duplication = Some(duplication as u64);
                }
                if let Some(duplication_squared_weight) = tier_weights.duplication_squared {
                    let duplication_squared = self.tiers_duplication_squared[itier];
                    损失函数 +=
                        比值(duplication_squared as f64, count) * duplication_squared_weight;
                    tiers[itier].duplication_squared = Some(duplication_squared as u64);
                }
                if let Some(level_weight) = &tier_weights.levels {
                    for (ilevel, level) in level_weight.iter().enumerate() {
                        损失函数 +=
                            比值(self.tiers_levels[itier][ilevel] as f64, count) * level.frequency;
                    }
                    tiers[itier].levels = Some(
                        level_weight
                            .iter()
                            .enumerate()
                            .map(|(i, v)| LevelMetricUniform {
                                length: v.length,
                                frequency: self.tiers_levels[itier][i] as u64,
                            })
                            .collect(),
                    );
                }
                if let Some(fingering_weight) = &tier_weights.fingering {
                    let mut fingering = FingeringMetricUniform::default();
                    for (i, weight) in fingering_weight.iter().enumerate() {
                        if let Some(weight) = weight {
                            let value = self.tiers_fingering[itier][i];
                            fingering[i] = Some(value as u64);
                            损失函数 += 比值(value as f64, count) * weight;
                        }
                    }
                    tiers[itier].fingering = Some(fingering);
                }
            }
            分组指标.tiers = Some(tiers);
        }
        (分组指标, 损失函数)
    }
}
impl 缓存 {
    /// Creates an empty cache sized for the metrics enabled in `partial_weights`.
    ///
    /// * `partial_weights` – selects which metrics, levels and tiers are tracked;
    ///   a copy is kept for `增减` and `汇总`.
    /// * `radix` – number base of codes; each base-`radix` digit is one key.
    /// * `total_count` – tier size used when a tier has no explicit `top`.
    /// * `max_index` – modulus that extracts the sub-code indexing the
    ///   equivalence and fingering tables.
    pub fn new(
        partial_weights: &部分权重,
        radix: u64,
        total_count: usize,
        max_index: u64,
    ) -> Self {
        let nlevel = partial_weights.levels.as_ref().map_or(0, |v| v.len());
        let ntier = partial_weights.tiers.as_ref().map_or(0, |v| v.len());
        let mut tiers_levels = vec![];
        if let Some(tiers) = &partial_weights.tiers {
            for tier in tiers {
                tiers_levels.push(vec![0; tier.levels.as_ref().map_or(0, |v| v.len())]);
            }
        }
        // `saturating_pow`: for radix > 256, radix^8 overflows u64. A saturated
        // breakpoint (u64::MAX) is still greater than every code below it.
        let length_breakpoints: Vec<u64> = (0..=8).map(|x| radix.saturating_pow(x)).collect();
        Self {
            partial_weights: partial_weights.clone(),
            total_count,
            total_frequency: 0,
            total_pairs: 0,
            total_extended_pairs: 0,
            distribution: vec![0; radix as usize],
            total_pair_equivalence: 0.0,
            total_extended_pair_equivalence: 0.0,
            total_duplication: 0,
            total_fingering: [0; 8],
            total_levels: vec![0; nlevel],
            tiers_duplication: vec![0; ntier],
            tiers_duplication_squared: vec![0; ntier],
            tiers_levels,
            tiers_fingering: vec![[0; 8]; ntier],
            max_index,
            segment: radix.pow((最大按键组合长度 - 1) as u32),
            length_breakpoints,
            radix,
        }
    }

    /// Penalised distance between the observed key distribution and its targets.
    ///
    /// Pairs `distribution[i]` with `ideal_distribution[i]`, stopping at the
    /// shorter of the two, and sums the following per pair:
    /// - `高于惩罚 * (actual - 理想值)` when above target;
    /// - `低于惩罚 * (理想值 - actual)` otherwise.
    // `&Vec` is kept on purpose: a caller that builds `distribution` with an
    // un-annotated `.collect()` relies on it to infer `Vec<f64>`.
    #[allow(clippy::ptr_arg)]
    fn 计算键位分布距离(
        distribution: &Vec<f64>,
        ideal_distribution: &Vec<键位分布损失函数>,
    ) -> f64 {
        let mut distance = 0.0;
        for (frequency, loss) in zip(distribution, ideal_distribution) {
            let diff = frequency - loss.理想值;
            if diff > 0.0 {
                distance += loss.高于惩罚 * diff;
            } else {
                distance -= loss.低于惩罚 * diff;
            }
        }
        distance
    }

    /// Adds (`sign == 1`) or removes (`sign == -1`) one code's contribution to
    /// every enabled running total.
    ///
    /// * `index` – the entry's index. It is counted in a tier only when below
    ///   that tier's `top`, or below `total_count` if `top` is unset.
    /// * `frequency` – weight for the global totals; tier totals instead count
    ///   the entry once (by `sign`).
    /// * `code` – the code as a base-`radix` integer, one digit per key; its
    ///   length is its number of digits.
    /// * `duplicate` – the code's duplicate mark; nonzero counts as duplicated.
    /// * `parameters` – supplies the `当量信息` and `指法计数` lookup tables.
    /// * `sign` – `1` to add, `-1` to remove.
    ///
    /// # Panics
    /// If `code` has more digits than `length_breakpoints` covers (more than 8 keys).
    #[inline(always)]
    pub fn 增减(
        &mut self,
        index: usize,
        frequency: u64,
        code: 编码,
        duplicate: u8,
        parameters: &默认目标函数参数,
        sign: i64,
    ) {
        let frequency = frequency as i64 * sign;
        let radix = self.radix;
        // Number of base-`radix` digits: index of the first power of `radix` above `code`.
        let length = self
            .length_breakpoints
            .iter()
            .position(|&x| code < x)
            .expect("code has more keys than length_breakpoints covers") as u64;
        self.total_frequency += frequency;
        // A zero code has length 0 and contributes no pairs. Plain `length - 1`
        // would panic in debug builds and wrap to -1 in release builds.
        self.total_pairs += length.saturating_sub(1) as i64 * frequency;
        let partial_weights = &self.partial_weights;
        if partial_weights.key_distribution.is_some() {
            // Credit every digit (key) of the code.
            let mut current = code;
            while current > 0 {
                let key = current % radix;
                if let Some(x) = self.distribution.get_mut(key as usize) {
                    *x += frequency;
                }
                current /= radix;
            }
        }
        if partial_weights.pair_equivalence.is_some() {
            // Walk the sub-codes: look up `code % max_index`, then drop digits by
            // dividing by `segment`, until `code <= radix`.
            let mut code = code;
            while code > radix {
                let partial_code = (code % self.max_index) as usize;
                self.total_pair_equivalence += parameters.当量信息[partial_code] * frequency as f64;
                code /= self.segment;
            }
        }
        if let Some(fingering) = &partial_weights.fingering {
            // Same sub-code walk as above, against the fingering table.
            let mut code = code;
            while code > radix {
                let label = parameters.指法计数[(code % self.max_index) as usize];
                for (i, weight) in fingering.iter().enumerate() {
                    if weight.is_some() {
                        self.total_fingering[i] += frequency * label[i] as i64;
                    }
                }
                code /= self.segment;
            }
        }
        if duplicate > 0 {
            self.total_duplication += frequency;
        }
        if let Some(levels) = &partial_weights.levels {
            for (ilevel, level) in levels.iter().enumerate() {
                if level.length == length as usize {
                    self.total_levels[ilevel] += frequency;
                }
            }
        }
        if let Some(tiers) = &partial_weights.tiers {
            for (itier, tier) in tiers.iter().enumerate() {
                if index >= tier.top.unwrap_or(self.total_count) {
                    continue;
                }
                if duplicate > 0 {
                    self.tiers_duplication[itier] += sign;
                    self.tiers_duplication_squared[itier] += sign * (2 * duplicate as i64 - 1);
                }
                if let Some(levels) = &tier.levels {
                    for (ilevel, level) in levels.iter().enumerate() {
                        if level.length == length as usize {
                            self.tiers_levels[itier][ilevel] += sign;
                        }
                    }
                }
                if let Some(fingering) = &tier.fingering {
                    let mut code = code;
                    while code > radix {
                        let label = parameters.指法计数[(code % self.max_index) as usize];
                        for (i, weight) in fingering.iter().enumerate() {
                            if weight.is_some() {
                                self.tiers_fingering[itier][i] += sign * label[i] as i64;
                            }
                        }
                        code /= self.segment;
                    }
                }
            }
        }
    }
}