use crate::config::配置;
use crate::interfaces::{消息, 界面, 默认输入};
use crate::objectives::目标函数;
use crate::optimizers::优化结果;
use crate::{
原始可编码对象, 原始当量信息, 原始键位分布信息, 码表项, 错误
};
use chrono::Local;
use clap::{Parser, Subcommand};
use csv::{ReaderBuilder, WriterBuilder};
use serde::{Deserialize, Serialize};
use std::fmt::Display;
use std::fs::{create_dir_all, read_dir, read_to_string, write, File, OpenOptions};
use std::io::Write;
use std::iter::FromIterator;
use std::path::{Path, PathBuf};
/// Command-line arguments as seen by the rest of the program.
///
/// Implementors must be `Clone`: `命令行` clones its arguments when it derives
/// per-thread child command lines.
pub trait 命令行参数: Clone {
    /// Whether this run uses more than one thread. This decides whether output goes
    /// to per-thread subdirectories and whether messages are logged to a file
    /// instead of stdout.
    fn 是否为多线程(&self) -> bool;
}
// Top-level command-line arguments of the program, parsed by clap.
// NOTE: plain `//` comments on purpose — clap's derive turns `///` doc comments
// into help/about text, which would change the program's `--help` output.
#[derive(Parser, Clone)]
#[command(name = "汉字自动拆分系统")]
// `author`, `version` and `about` without values are filled in by clap from the
// Cargo package metadata.
#[command(author, version, about, long_about)]
// Make the version flag available on every subcommand as well.
#[command(propagate_version = true)]
pub struct 默认命令行参数 {
    // The subcommand to run (`encode`, `optimize` or `server`).
    #[command(subcommand)]
    pub command: 命令,
}
impl 命令行参数 for 默认命令行参数 {
    /// Only `optimize` can run several threads. It counts as multi-threaded whenever
    /// `--threads` is anything other than 1 (the default). `encode` and `server` are
    /// always single-threaded.
    fn 是否为多线程(&self) -> bool {
        // Every variant is listed (instead of a `_` catch-all) so that adding a new
        // subcommand forces this decision to be revisited.
        match &self.command {
            命令::Optimize { threads, .. } => *threads != 1,
            命令::Encode { .. } | 命令::Server { .. } => false,
        }
    }
}
// Input-file options shared by the `encode` and `optimize` subcommands (flattened
// into both). Every path is optional; the defaults are applied in `从命令行参数创建`.
// Plain `//` comments only: with clap's derive, `///` doc comments become help text.
#[derive(Parser, Clone)]
pub struct 数据参数 {
    // Scheme/config file (falls back to `config.yaml`). It has no `#[arg]`
    // attribute, so clap treats it as an (optional) positional argument.
    pub config: Option<PathBuf>,
    // `-e/--encodables`: YAML list of encodable items (falls back to `elements.yaml`).
    #[arg(short, long, value_name = "FILE")]
    pub encodables: Option<PathBuf>,
    // `-k/--key-distribution`: tab-separated key distribution table
    // (falls back to `assets/distribution.txt`).
    #[arg(short, long, value_name = "FILE")]
    pub key_distribution: Option<PathBuf>,
    // `-p/--pair-equivalence`: tab-separated key-pair equivalence table
    // (falls back to `assets/equivalence.txt`).
    #[arg(short, long, value_name = "FILE")]
    pub pair_equivalence: Option<PathBuf>,
}
// The subcommands of the program. clap derives each subcommand's name from the
// variant name in kebab-case (`encode`, `optimize`, `server`).
// Plain `//` comments only: with clap's derive, `///` doc comments become help text.
#[derive(Subcommand, Clone)]
pub enum 命令 {
    // Compute the codes of all items from the scheme and decomposition data, and
    // report the resulting metrics.
    #[command(about = "使用方案文件和拆分表计算出字词编码并统计各类指标")]
    Encode {
        #[command(flatten)]
        data: 数据参数,
    },
    // Optimize the scheme's decisions according to the config file.
    #[command(about = "基于配置文件优化决策")]
    Optimize {
        #[command(flatten)]
        data: 数据参数,
        // `-t/--threads`: number of optimizer threads (default 1); any value other
        // than 1 counts as multi-threaded (see `是否为多线程`).
        #[arg(short, long, default_value = "1")]
        threads: usize,
        // `-r/--resume-from`: presumably the output directory of an earlier run to
        // resume from (cf. `从目录恢复`) — confirm at the call site.
        #[arg(short, long)]
        resume_from: Option<PathBuf>,
    },
    // Start the HTTP API server.
    #[command(about = "启动 HTTP API 服务器")]
    Server {
        // `-p/--port`: presumably the port the server listens on (default 3200).
        #[arg(short, long, default_value = "3200")]
        port: u16,
    },
}
/// Command-line front end: the parsed arguments plus the directory that receives
/// this run's output files (code table, metrics, summary, checkpoints, logs).
#[derive(Debug, Clone)]
pub struct 命令行<P: 命令行参数> {
    /// Parsed command-line arguments.
    pub 参数: P,
    /// Directory into which this run's (or, when multi-threaded, this thread's)
    /// output files are written.
    pub 输出目录: PathBuf,
}
/// Reads a headerless, tab-separated text file and deserializes every row as `I`,
/// collecting the rows into `T`.
///
/// Rows may have differing numbers of fields (`flexible(true)`).
///
/// # Panics
///
/// Panics if the file cannot be opened or if any row fails to deserialize. Both
/// messages name the file and include the underlying csv error.
pub fn 读取文本文件<I, T>(path: PathBuf) -> T
where
    I: for<'de> Deserialize<'de>,
    T: FromIterator<I>,
{
    let mut reader = ReaderBuilder::new()
        .delimiter(b'\t')
        .has_headers(false)
        .flexible(true)
        .from_path(&path)
        .unwrap_or_else(|e| panic!("文件 {} 无法读取: {e}", path.display()));
    reader
        .deserialize()
        .map(|row| row.unwrap_or_else(|e| panic!("文件 {} 格式错误: {e}", path.display())))
        .collect()
}
impl<P: 命令行参数> 命令行<P> {
    /// Creates a command line for `args` whose output goes to `maybe_output_dir`.
    ///
    /// When no directory is given, a fresh relative directory named
    /// `output-<MM-DD+HH_MM_SS>` (local time) is used. The directory is created,
    /// including any missing parents.
    ///
    /// # Panics
    ///
    /// Panics if the output directory cannot be created.
    pub fn 新建(args: P, maybe_output_dir: Option<PathBuf>) -> Self {
        let output_dir = maybe_output_dir.unwrap_or_else(|| {
            let time = Local::now().format("%m-%d+%H_%M_%S").to_string();
            PathBuf::from(format!("output-{time}"))
        });
        create_dir_all(&output_dir).unwrap();
        Self {
            参数: args,
            输出目录: output_dir,
        }
    }

    /// Writes the code table to `code.txt` in the output directory and returns its
    /// path.
    ///
    /// Each entry becomes one tab-separated row without a header: word, full code,
    /// full-code rank, short code, short-code rank.
    ///
    /// # Panics
    ///
    /// Panics if the file cannot be created or written.
    pub fn 输出编码结果(&self, entries: Vec<码表项>) -> PathBuf {
        let path = self.输出目录.join("code.txt");
        let mut writer = WriterBuilder::new()
            .delimiter(b'\t')
            .has_headers(false)
            .from_path(&path)
            .unwrap();
        for 码表项 {
            词: name,
            全码: full,
            全码排名: full_rank,
            简码: short,
            简码排名: short_rank,
        } in entries
        {
            writer
                .serialize((&name, &full, &full_rank, &short, &short_rank))
                .unwrap();
        }
        writer.flush().unwrap();
        path
    }

    /// Writes the score (4 significant digits, scientific notation) followed by the
    /// metric's `Display` text to `metric.txt` in the output directory and returns
    /// its path.
    ///
    /// # Panics
    ///
    /// Panics if the file cannot be written.
    pub fn 输出指标<M: Display + Serialize>(&self, metric: &M, score: f64) -> PathBuf {
        let path = self.输出目录.join("metric.txt");
        let metric_str = format!("分数:{score:.4e};{metric}");
        write(&path, metric_str).unwrap();
        path
    }

    /// Writes one summary entry per optimization thread to `summary.txt` in the
    /// output directory and returns its path.
    ///
    /// Every entry is terminated by exactly one newline. The entries then stay
    /// separate whether or not the objective's metric `Display` text already ends
    /// with one.
    ///
    /// # Panics
    ///
    /// Panics if the file cannot be created or written.
    pub fn 输出总结<O: 目标函数>(&self, results: &[(usize, 优化结果<O>, Self)]) -> PathBuf {
        let path = self.输出目录.join("summary.txt");
        let mut f = File::create(&path).unwrap();
        for (index, result, _) in results {
            let mut 条目 = format!("线程 {index} 分数:{:.4e};{}", result.分数, result.指标);
            // Without this the entries of different threads could run together.
            if !条目.ends_with('\n') {
                条目.push('\n');
            }
            f.write_all(条目.as_bytes()).unwrap();
        }
        f.flush().unwrap();
        path
    }

    /// Writes `配置内容` to `config.yaml` in the output directory and returns its path.
    ///
    /// # Panics
    ///
    /// Panics if the file cannot be written.
    pub fn 输出配置文件(&self, 配置内容: &str) -> PathBuf {
        let path = self.输出目录.join("config.yaml");
        write(&path, 配置内容).unwrap();
        path
    }

    /// Returns the command line used by optimization thread `index`.
    ///
    /// Single-threaded runs share this command line (and its output directory).
    /// Multi-threaded runs get a fresh command line whose output directory is the
    /// subdirectory `<输出目录>/<index>`, created on the spot.
    pub fn 生成子命令行(&self, index: usize) -> 命令行<P> {
        if !self.参数.是否为多线程() {
            return self.clone();
        }
        let child_dir = self.输出目录.join(index.to_string());
        命令行::新建(self.参数.clone(), Some(child_dir))
    }
}
pub fn 从命令行参数创建(参数: &默认命令行参数) -> 默认输入 {
let (config, encodables, key_distribution, pair_equivalence) = match &参数.command {
命令::Encode { data } | 命令::Optimize { data, .. } => (
data.config.clone(),
data.encodables.clone(),
data.key_distribution.clone(),
data.pair_equivalence.clone(),
),
命令::Server { .. } => {
panic!("Server 命令不需要数据准备");
}
};
let config_path = config.unwrap_or(PathBuf::from("config.yaml"));
let config_content = read_to_string(&config_path)
.unwrap_or_else(|_| panic!("文件 {} 不存在", config_path.display()));
let config: 配置 = serde_yaml::from_str(&config_content).unwrap();
let elements_path = encodables.unwrap_or(PathBuf::from("elements.yaml"));
let elements_content = read_to_string(&elements_path)
.unwrap_or_else(|_| panic!("文件 {} 不存在", elements_path.display()));
let encodables: Vec<原始可编码对象> = serde_yaml::from_str(&elements_content).unwrap();
let assets_dir = Path::new("assets");
let keq_path = key_distribution.unwrap_or(assets_dir.join("distribution.txt"));
let key_distribution: 原始键位分布信息 = 读取文本文件(keq_path);
let peq_path = pair_equivalence.unwrap_or(assets_dir.join("equivalence.txt"));
let pair_equivalence: 原始当量信息 = 读取文本文件(peq_path);
默认输入 {
配置: config,
原始键位分布信息: key_distribution,
原始当量信息: pair_equivalence,
词列表: encodables,
}
}
impl<P: 命令行参数> 界面 for 命令行<P> {
    /// Reports one optimizer message to the user.
    ///
    /// In multi-threaded runs the text is appended to `log.txt` in this command
    /// line's output directory, presumably so that concurrent threads do not
    /// interleave on stdout. Otherwise it is printed to stdout.
    ///
    /// Some messages also persist files in the output directory:
    /// - `Progress` saves the current configuration as `checkpoint-<steps>.yaml`.
    /// - `BetterSolution` with an index saves `solution-<index>.yaml` and
    ///   `solution-<index>.txt`.
    ///
    /// # Panics
    ///
    /// Panics if the log file cannot be opened or if any write fails.
    fn 发送(&self, message: 消息) {
        let mut writer: Box<dyn Write> = if self.参数.是否为多线程() {
            let log_path = self.输出目录.join("log.txt");
            let file = OpenOptions::new()
                .create(true)
                .append(true)
                .open(&log_path)
                .unwrap_or_else(|e| panic!("Failed to open file {}: {e}", log_path.display()));
            Box::new(file)
        } else {
            Box::new(std::io::stdout())
        };
        let result = match message {
            消息::TrialMax {
                temperature,
                accept_rate,
            } => writeln!(
                &mut writer,
                "若温度为 {temperature:.2e},接受率为 {:.2}%",
                accept_rate * 100.0
            ),
            消息::TrialMin {
                temperature,
                improve_rate,
            } => writeln!(
                &mut writer,
                "若温度为 {temperature:.2e},改进率为 {:.2}%",
                improve_rate * 100.0
            ),
            消息::Parameters { t_max, t_min } => writeln!(
                &mut writer,
                "参数寻找完成,从最高温 {t_max} 降到最低温 {t_min}……"
            ),
            消息::Elapsed { time } => writeln!(&mut writer, "计算一次评测用时:{time} μs"),
            消息::Progress {
                steps,
                temperature,
                config,
                metric,
                score,
            } => {
                // Checkpoints named `checkpoint-<steps>.yaml` are what `从目录恢复` scans for.
                let 配置路径 = self.输出目录.join(format!("checkpoint-{steps}.yaml"));
                write(&配置路径, config).unwrap();
                writeln!(
                    &mut writer,
                    "已执行 {steps} 步,当前温度 {temperature:.2e},当前分数 {score:.4e},当前指标如下:\n{metric}",
                )
            }
            消息::BetterSolution {
                metric,
                score,
                config,
                index,
            } => {
                if let Some(index) = index {
                    let 配置路径 = self.输出目录.join(format!("solution-{index}.yaml"));
                    let 指标路径 = self.输出目录.join(format!("solution-{index}.txt"));
                    writeln!(
                        &mut writer,
                        "方案文件保存于 {},指标保存于 {}",
                        配置路径.display(),
                        指标路径.display()
                    )
                    .unwrap();
                    // Borrow rather than clone: `metric` is still printed below.
                    write(&指标路径, &metric).unwrap();
                    write(&配置路径, config).unwrap();
                }
                writeln!(
                    &mut writer,
                    "系统搜索到了一个更好的方案,分数为 {score:.4e},指标如下:\n{metric}"
                )
            }
        };
        result.unwrap()
    }
}
/// Restores each thread's latest checkpoint configuration from the output directory
/// of a previous optimization run.
///
/// Where the checkpoints are looked up:
/// - With `线程数 == 1`, directly in `目录`.
/// - Otherwise, in the subdirectories `目录/0`, `目录/1`, …. This matches the
///   layout `命令行::生成子命令行` creates for multi-threaded runs.
///
/// In each directory the `checkpoint-<step>.yaml` file with the largest step wins.
/// Files whose step does not parse as `usize` are ignored.
///
/// Returns one `(step, config)` pair per thread, in thread order.
///
/// # Errors
///
/// Returns an error if:
/// - a directory cannot be listed;
/// - a file name is not valid UTF-8;
/// - some thread has no checkpoint;
/// - a checkpoint file cannot be read or is not a valid configuration.
pub fn 从目录恢复(目录: &PathBuf, 线程数: usize) -> Result<Vec<(usize, 配置)>, 错误> {
    let 目录列表: Vec<PathBuf> = if 线程数 == 1 {
        vec![目录.clone()]
    } else {
        (0..线程数).map(|i| 目录.join(i.to_string())).collect()
    };

    // Pass 1: find the newest checkpoint in every directory. All directories are
    // scanned before any missing checkpoint is reported.
    let mut 存档列表: Vec<Option<(usize, PathBuf)>> = Vec::with_capacity(目录列表.len());
    for 子目录 in &目录列表 {
        let mut 最新存档: Option<(usize, PathBuf)> = None;
        for entry in read_dir(子目录)? {
            let entry = entry?;
            let file_name_raw = entry.file_name();
            let file_name = file_name_raw.to_str().ok_or("文件名不是有效的 UTF-8")?;
            let step = file_name
                .strip_prefix("checkpoint-")
                .and_then(|rest| rest.strip_suffix(".yaml"))
                .and_then(|digits| digits.parse::<usize>().ok());
            if let Some(step) = step {
                // Strictly greater: on a tie (e.g. `checkpoint-5` vs `checkpoint-05`)
                // the file listed first is kept.
                if 最新存档.as_ref().map_or(true, |(当前步数, _)| step > *当前步数) {
                    最新存档 = Some((step, entry.path()));
                }
            }
        }
        存档列表.push(最新存档);
    }

    // Pass 2: every thread must have a checkpoint; load them in thread order.
    let mut 结果 = Vec::with_capacity(存档列表.len());
    for (i, 存档) in 存档列表.into_iter().enumerate() {
        let (step, path) = match 存档 {
            Some(found) => found,
            None => return Err(format!("线程 {i} 没有找到 checkpoint 文件").into()),
        };
        let content = read_to_string(&path)?;
        // A malformed checkpoint is reported as an error instead of panicking, since
        // this function already returns `Result`.
        let config: 配置 = serde_yaml::from_str(&content).map_err(|e| -> 错误 {
            format!("checkpoint 文件 {} 解析失败: {e}", path.display()).into()
        })?;
        结果.push((step, config));
    }
    Ok(结果)
}