1use crate::config::配置;
2use crate::interfaces::{消息, 界面, 默认输入};
3use crate::objectives::目标函数;
4use crate::optimizers::优化结果;
5use crate::{
6 原始可编码对象, 原始当量信息, 原始键位分布信息, 码表项, 错误
7};
8use chrono::Local;
9use clap::{Parser, Subcommand};
10use csv::{ReaderBuilder, WriterBuilder};
11use serde::{Deserialize, Serialize};
12use std::fmt::Display;
13use std::fs::{create_dir_all, read_dir, read_to_string, write, File, OpenOptions};
14use std::io::Write;
15use std::iter::FromIterator;
16use std::path::{Path, PathBuf};
17
18pub trait 命令行参数: Clone {
19 fn 是否为多线程(&self) -> bool;
20}
21
22#[derive(Parser, Clone)]
24#[command(name = "汉字自动拆分系统")]
25#[command(author, version, about, long_about)]
26#[command(propagate_version = true)]
27pub struct 默认命令行参数 {
28 #[command(subcommand)]
29 pub command: 命令,
30}
31
32impl 命令行参数 for 默认命令行参数 {
33 fn 是否为多线程(&self) -> bool {
34 match &self.command {
35 命令::Optimize { threads, .. } => *threads != 1,
36 _ => false,
37 }
38 }
39}
40
41#[derive(Parser, Clone)]
43pub struct 数据参数 {
44 pub config: Option<PathBuf>,
46 #[arg(short, long, value_name = "FILE")]
48 pub encodables: Option<PathBuf>,
49 #[arg(short, long, value_name = "FILE")]
51 pub key_distribution: Option<PathBuf>,
52 #[arg(short, long, value_name = "FILE")]
54 pub pair_equivalence: Option<PathBuf>,
55}
56
57#[derive(Subcommand, Clone)]
59pub enum 命令 {
60 #[command(about = "使用方案文件和拆分表计算出字词编码并统计各类指标")]
61 Encode {
62 #[command(flatten)]
63 data: 数据参数,
64 },
65 #[command(about = "基于配置文件优化决策")]
66 Optimize {
67 #[command(flatten)]
68 data: 数据参数,
69 #[arg(short, long, default_value = "1")]
71 threads: usize,
72 #[arg(short, long)]
75 resume_from: Option<PathBuf>,
76 },
77 #[command(about = "启动 HTTP API 服务器")]
79 Server {
80 #[arg(short, long, default_value = "3200")]
82 port: u16,
83 },
84}
85
86#[derive(Debug, Clone)]
88pub struct 命令行<P: 命令行参数> {
89 pub 参数: P,
90 pub 输出目录: PathBuf,
91}
92
93pub fn 读取文本文件<I, T>(path: PathBuf) -> T
94where
95 I: for<'de> Deserialize<'de>,
96 T: FromIterator<I>,
97{
98 let mut reader = ReaderBuilder::new()
99 .delimiter(b'\t')
100 .has_headers(false)
101 .flexible(true)
102 .from_path(path)
103 .unwrap();
104 reader.deserialize().map(|x| x.unwrap()).collect()
105}
106
107impl<P: 命令行参数> 命令行<P> {
108 pub fn 新建(args: P, maybe_output_dir: Option<PathBuf>) -> Self {
109 let output_dir = maybe_output_dir.unwrap_or_else(|| {
110 let time = Local::now().format("%m-%d+%H_%M_%S").to_string();
111 PathBuf::from(format!("output-{time}"))
112 });
113 create_dir_all(output_dir.clone()).unwrap();
114 Self {
115 参数: args,
116 输出目录: output_dir,
117 }
118 }
119
120 pub fn 输出编码结果(&self, entries: Vec<码表项>) -> PathBuf {
121 let path = self.输出目录.join("code.txt");
122 let mut writer = WriterBuilder::new()
123 .delimiter(b'\t')
124 .has_headers(false)
125 .from_path(&path)
126 .unwrap();
127 for 码表项 {
128 词: name,
129 全码: full,
130 全码排名: full_rank,
131 简码: short,
132 简码排名: short_rank,
133 } in entries
134 {
135 writer
136 .serialize((&name, &full, &full_rank, &short, &short_rank))
137 .unwrap();
138 }
139 writer.flush().unwrap();
140 return path;
141 }
142
143 pub fn 输出指标<M: Display + Serialize>(&self, metric: &M, score: f64) -> PathBuf {
144 let path = self.输出目录.join("metric.txt");
145 let metric_str = format!("分数:{score:.4e};{metric}");
146 write(&path, metric_str).unwrap();
147 return path;
148 }
149
150 pub fn 输出总结<O: 目标函数>(
151 &self,
152 results: &Vec<(usize, 优化结果<O>, Self)>,
153 ) -> PathBuf {
154 let path = self.输出目录.join("summary.txt");
155 let mut f = File::create(&path).unwrap();
156 for (index, result, _) in results {
157 write!(
158 &mut f,
159 "线程 {index} 分数:{:.4e};{}",
160 result.分数, result.指标
161 )
162 .unwrap();
163 }
164 f.flush().unwrap();
165 return path;
166 }
167
168 pub fn 输出配置文件(&self, 配置内容: &str) -> PathBuf {
169 let path = self.输出目录.join("config.yaml");
170 write(&path, 配置内容).unwrap();
171 return path;
172 }
173
174 pub fn 生成子命令行(&self, index: usize) -> 命令行<P> {
175 if !self.参数.是否为多线程() {
176 return self.clone();
177 }
178 let child_dir = self.输出目录.join(format!("{index}"));
179 命令行::新建(self.参数.clone(), Some(child_dir))
180 }
181}
182
183pub fn 从命令行参数创建(参数: &默认命令行参数) -> 默认输入 {
184 let (config, encodables, key_distribution, pair_equivalence) = match &参数.command {
185 命令::Encode { data } | 命令::Optimize { data, .. } => (
186 data.config.clone(),
187 data.encodables.clone(),
188 data.key_distribution.clone(),
189 data.pair_equivalence.clone(),
190 ),
191 命令::Server { .. } => {
192 panic!("Server 命令不需要数据准备");
193 }
194 };
195 let config_path = config.unwrap_or(PathBuf::from("config.yaml"));
196 let config_content = read_to_string(&config_path)
197 .unwrap_or_else(|_| panic!("文件 {} 不存在", config_path.display()));
198 let config: 配置 = serde_yaml::from_str(&config_content).unwrap();
199 let elements_path = encodables.unwrap_or(PathBuf::from("elements.yaml"));
200 let elements_content = read_to_string(&elements_path)
201 .unwrap_or_else(|_| panic!("文件 {} 不存在", elements_path.display()));
202 let encodables: Vec<原始可编码对象> = serde_yaml::from_str(&elements_content).unwrap();
203 let assets_dir = Path::new("assets");
204 let keq_path = key_distribution.unwrap_or(assets_dir.join("distribution.txt"));
205 let key_distribution: 原始键位分布信息 = 读取文本文件(keq_path);
206 let peq_path = pair_equivalence.unwrap_or(assets_dir.join("equivalence.txt"));
207 let pair_equivalence: 原始当量信息 = 读取文本文件(peq_path);
208 默认输入 {
209 配置: config,
210 原始键位分布信息: key_distribution,
211 原始当量信息: pair_equivalence,
212 词列表: encodables,
213 }
214}
215
216impl<P: 命令行参数> 界面 for 命令行<P> {
217 fn 发送(&self, message: 消息) {
218 let mut writer: Box<dyn Write> = if self.参数.是否为多线程() {
219 let log_path = self.输出目录.join("log.txt");
220 let file = OpenOptions::new()
221 .create(true) .append(true) .open(log_path)
224 .expect("Failed to open file");
225 Box::new(file)
226 } else {
227 Box::new(std::io::stdout())
228 };
229 let result = match message {
230 消息::TrialMax {
231 temperature,
232 accept_rate,
233 } => writeln!(
234 &mut writer,
235 "若温度为 {temperature:.2e},接受率为 {:.2}%",
236 accept_rate * 100.0
237 ),
238 消息::TrialMin {
239 temperature,
240 improve_rate,
241 } => writeln!(
242 &mut writer,
243 "若温度为 {temperature:.2e},改进率为 {:.2}%",
244 improve_rate * 100.0
245 ),
246 消息::Parameters { t_max, t_min } => writeln!(
247 &mut writer,
248 "参数寻找完成,从最高温 {t_max} 降到最低温 {t_min}……"
249 ),
250 消息::Elapsed { time } => writeln!(&mut writer, "计算一次评测用时:{time} μs"),
251 消息::Progress {
252 steps,
253 temperature,
254 config,
255 metric,
256 score,
257 } => {
258 let 配置文件名 = format!("checkpoint-{steps}.yaml");
259 let 配置路径 = self.输出目录.join(&配置文件名);
260 write(&配置路径, config).unwrap();
261 writeln!(
262 &mut writer,
263 "已执行 {steps} 步,当前温度 {temperature:.2e},当前分数 {score:.4e},当前指标如下:\n{metric}",
264 )
265 }
266 消息::BetterSolution {
267 metric,
268 score,
269 config,
270 index,
271 } => {
272 if let Some(index) = index {
273 let 配置文件名 = format!("solution-{index}.yaml");
274 let 配置路径 = self.输出目录.join(&配置文件名);
275 let 指标文件名 = format!("solution-{index}.txt");
276 let 指标路径 = self.输出目录.join(&指标文件名);
277 writeln!(
278 &mut writer,
279 "方案文件保存于 {},指标保存于 {}",
280 配置路径.display(),
281 指标路径.display()
282 )
283 .unwrap();
284 write(指标路径, metric.clone()).unwrap();
285 write(配置路径, config).unwrap();
286 }
287 writeln!(
288 &mut writer,
289 "系统搜索到了一个更好的方案,分数为 {score:.4e},指标如下:\n{metric}"
290 )
291 }
292 };
293 result.unwrap()
294 }
295}
296
297pub fn 从目录恢复(目录: &PathBuf, 线程数: usize) -> Result<Vec<(usize, 配置)>, 错误> {
298 let mut 存档列表 = vec![None; 线程数];
299 let mut 目录列表 = vec![];
300 if 线程数 == 1 {
301 目录列表.push(目录.clone());
302 } else {
303 for i in 0..线程数 {
304 目录列表.push(目录.join(i.to_string()));
305 }
306 }
307 for (i, 子目录) in 目录列表.iter().enumerate() {
308 let 存档 = &mut 存档列表[i];
309 for entry in read_dir(子目录)? {
310 let entry = entry?;
311 let file_name_raw = entry.file_name();
312 let file_name = file_name_raw.to_str().ok_or("文件名不是有效的 UTF-8")?;
313 if file_name.starts_with("checkpoint-") && file_name.ends_with(".yaml") {
314 let step_str = &file_name["checkpoint-".len()..file_name.len() - ".yaml".len()];
315 if let Ok(step) = step_str.parse::<usize>() {
316 if let Some((current_step, _)) = 存档 {
317 if step <= *current_step {
318 continue; }
320 }
321 *存档 = Some((step, entry.path()));
322 }
323 }
324 }
325 }
326 let mut 结果 = vec![];
327 for (i, checkpoint) in 存档列表.iter().enumerate() {
328 if let Some((step, path)) = checkpoint {
329 let content = read_to_string(path)?;
330 let config: 配置 = serde_yaml::from_str(&content).unwrap();
331 结果.push((*step, config));
332 } else {
333 return Err(format!("线程 {i} 没有找到 checkpoint 文件").into());
334 }
335 }
336 Ok(结果)
337}