Skip to main content

chai/interfaces/
command_line.rs

1use crate::config::配置;
2use crate::interfaces::{消息, 界面, 默认输入};
3use crate::objectives::目标函数;
4use crate::optimizers::优化结果;
5use crate::{
6    原始可编码对象, 原始当量信息, 原始键位分布信息, 码表项, 错误
7};
8use chrono::Local;
9use clap::{Parser, Subcommand};
10use csv::{ReaderBuilder, WriterBuilder};
11use serde::{Deserialize, Serialize};
12use std::fmt::Display;
13use std::fs::{create_dir_all, read_dir, read_to_string, write, File, OpenOptions};
14use std::io::Write;
15use std::iter::FromIterator;
16use std::path::{Path, PathBuf};
17
18pub trait 命令行参数: Clone {
19    fn 是否为多线程(&self) -> bool;
20}
21
22/// 命令行参数的定义
23#[derive(Parser, Clone)]
24#[command(name = "汉字自动拆分系统")]
25#[command(author, version, about, long_about)]
26#[command(propagate_version = true)]
27pub struct 默认命令行参数 {
28    #[command(subcommand)]
29    pub command: 命令,
30}
31
32impl 命令行参数 for 默认命令行参数 {
33    fn 是否为多线程(&self) -> bool {
34        match &self.command {
35            命令::Optimize { threads, .. } => *threads != 1,
36            _ => false,
37        }
38    }
39}
40
41/// 编码和优化共用的数据参数
42#[derive(Parser, Clone)]
43pub struct 数据参数 {
44    /// 方案文件,默认为 config.yaml
45    pub config: Option<PathBuf>,
46    /// 频率序列表,默认为 elements.txt
47    #[arg(short, long, value_name = "FILE")]
48    pub encodables: Option<PathBuf>,
49    /// 单键用指分布表,默认为 assets 目录下的 distribution.txt
50    #[arg(short, long, value_name = "FILE")]
51    pub key_distribution: Option<PathBuf>,
52    /// 双键速度当量表,默认为 assets 目录下的 equivalence.txt
53    #[arg(short, long, value_name = "FILE")]
54    pub pair_equivalence: Option<PathBuf>,
55}
56
57/// 命令行中所有可用的子命令
58#[derive(Subcommand, Clone)]
59pub enum 命令 {
60    #[command(about = "使用方案文件和拆分表计算出字词编码并统计各类指标")]
61    Encode {
62        #[command(flatten)]
63        data: 数据参数,
64    },
65    #[command(about = "基于配置文件优化决策")]
66    Optimize {
67        #[command(flatten)]
68        data: 数据参数,
69        /// 优化时使用的线程数
70        #[arg(short, long, default_value = "1")]
71        threads: usize,
72        /// 是否要从某个输出目录恢复
73        /// 如果指定了这个参数,程序会在该目录寻找 checkpoint-*.yaml 来恢复优化进度
74        #[arg(short, long)]
75        resume_from: Option<PathBuf>,
76    },
77    /// 启动 Web API 服务器
78    #[command(about = "启动 HTTP API 服务器")]
79    Server {
80        /// 服务器端口号
81        #[arg(short, long, default_value = "3200")]
82        port: u16,
83    },
84}
85
86/// 通过命令行来使用 libchai 的入口,实现了界面特征
87#[derive(Debug, Clone)]
88pub struct 命令行<P: 命令行参数> {
89    pub 参数: P,
90    pub 输出目录: PathBuf,
91}
92
93pub fn 读取文本文件<I, T>(path: PathBuf) -> T
94where
95    I: for<'de> Deserialize<'de>,
96    T: FromIterator<I>,
97{
98    let mut reader = ReaderBuilder::new()
99        .delimiter(b'\t')
100        .has_headers(false)
101        .flexible(true)
102        .from_path(path)
103        .unwrap();
104    reader.deserialize().map(|x| x.unwrap()).collect()
105}
106
107impl<P: 命令行参数> 命令行<P> {
108    pub fn 新建(args: P, maybe_output_dir: Option<PathBuf>) -> Self {
109        let output_dir = maybe_output_dir.unwrap_or_else(|| {
110            let time = Local::now().format("%m-%d+%H_%M_%S").to_string();
111            PathBuf::from(format!("output-{time}"))
112        });
113        create_dir_all(output_dir.clone()).unwrap();
114        Self {
115            参数: args,
116            输出目录: output_dir,
117        }
118    }
119
120    pub fn 输出编码结果(&self, entries: Vec<码表项>) -> PathBuf {
121        let path = self.输出目录.join("code.txt");
122        let mut writer = WriterBuilder::new()
123            .delimiter(b'\t')
124            .has_headers(false)
125            .from_path(&path)
126            .unwrap();
127        for 码表项 {
128            词: name,
129            全码: full,
130            全码排名: full_rank,
131            简码: short,
132            简码排名: short_rank,
133        } in entries
134        {
135            writer
136                .serialize((&name, &full, &full_rank, &short, &short_rank))
137                .unwrap();
138        }
139        writer.flush().unwrap();
140        return path;
141    }
142
143    pub fn 输出指标<M: Display + Serialize>(&self, metric: &M, score: f64) -> PathBuf {
144        let path = self.输出目录.join("metric.txt");
145        let metric_str = format!("分数:{score:.4e};{metric}");
146        write(&path, metric_str).unwrap();
147        return path;
148    }
149
150    pub fn 输出总结<O: 目标函数>(
151        &self,
152        results: &Vec<(usize, 优化结果<O>, Self)>,
153    ) -> PathBuf {
154        let path = self.输出目录.join("summary.txt");
155        let mut f = File::create(&path).unwrap();
156        for (index, result, _) in results {
157            write!(
158                &mut f,
159                "线程 {index} 分数:{:.4e};{}",
160                result.分数, result.指标
161            )
162            .unwrap();
163        }
164        f.flush().unwrap();
165        return path;
166    }
167
168    pub fn 输出配置文件(&self, 配置内容: &str) -> PathBuf {
169        let path = self.输出目录.join("config.yaml");
170        write(&path, 配置内容).unwrap();
171        return path;
172    }
173
174    pub fn 生成子命令行(&self, index: usize) -> 命令行<P> {
175        if !self.参数.是否为多线程() {
176            return self.clone();
177        }
178        let child_dir = self.输出目录.join(format!("{index}"));
179        命令行::新建(self.参数.clone(), Some(child_dir))
180    }
181}
182
183pub fn 从命令行参数创建(参数: &默认命令行参数) -> 默认输入 {
184    let (config, encodables, key_distribution, pair_equivalence) = match &参数.command {
185        命令::Encode { data } | 命令::Optimize { data, .. } => (
186            data.config.clone(),
187            data.encodables.clone(),
188            data.key_distribution.clone(),
189            data.pair_equivalence.clone(),
190        ),
191        命令::Server { .. } => {
192            panic!("Server 命令不需要数据准备");
193        }
194    };
195    let config_path = config.unwrap_or(PathBuf::from("config.yaml"));
196    let config_content = read_to_string(&config_path)
197        .unwrap_or_else(|_| panic!("文件 {} 不存在", config_path.display()));
198    let config: 配置 = serde_yaml::from_str(&config_content).unwrap();
199    let elements_path = encodables.unwrap_or(PathBuf::from("elements.yaml"));
200    let elements_content = read_to_string(&elements_path)
201        .unwrap_or_else(|_| panic!("文件 {} 不存在", elements_path.display()));
202    let encodables: Vec<原始可编码对象> = serde_yaml::from_str(&elements_content).unwrap();
203    let assets_dir = Path::new("assets");
204    let keq_path = key_distribution.unwrap_or(assets_dir.join("distribution.txt"));
205    let key_distribution: 原始键位分布信息 = 读取文本文件(keq_path);
206    let peq_path = pair_equivalence.unwrap_or(assets_dir.join("equivalence.txt"));
207    let pair_equivalence: 原始当量信息 = 读取文本文件(peq_path);
208    默认输入 {
209        配置: config,
210        原始键位分布信息: key_distribution,
211        原始当量信息: pair_equivalence,
212        词列表: encodables,
213    }
214}
215
216impl<P: 命令行参数> 界面 for 命令行<P> {
217    fn 发送(&self, message: 消息) {
218        let mut writer: Box<dyn Write> = if self.参数.是否为多线程() {
219            let log_path = self.输出目录.join("log.txt");
220            let file = OpenOptions::new()
221                .create(true) // 如果文件不存在,则创建
222                .append(true) // 追加写入,不覆盖原有内容
223                .open(log_path)
224                .expect("Failed to open file");
225            Box::new(file)
226        } else {
227            Box::new(std::io::stdout())
228        };
229        let result = match message {
230            消息::TrialMax {
231                temperature,
232                accept_rate,
233            } => writeln!(
234                &mut writer,
235                "若温度为 {temperature:.2e},接受率为 {:.2}%",
236                accept_rate * 100.0
237            ),
238            消息::TrialMin {
239                temperature,
240                improve_rate,
241            } => writeln!(
242                &mut writer,
243                "若温度为 {temperature:.2e},改进率为 {:.2}%",
244                improve_rate * 100.0
245            ),
246            消息::Parameters { t_max, t_min } => writeln!(
247                &mut writer,
248                "参数寻找完成,从最高温 {t_max} 降到最低温 {t_min}……"
249            ),
250            消息::Elapsed { time } => writeln!(&mut writer, "计算一次评测用时:{time} μs"),
251            消息::Progress {
252                steps,
253                temperature,
254                config,
255                metric,
256                score,
257            } => {
258                let 配置文件名 = format!("checkpoint-{steps}.yaml");
259                let 配置路径 = self.输出目录.join(&配置文件名);
260                write(&配置路径, config).unwrap();
261                writeln!(
262                    &mut writer,
263                    "已执行 {steps} 步,当前温度 {temperature:.2e},当前分数 {score:.4e},当前指标如下:\n{metric}",
264                )
265            }
266            消息::BetterSolution {
267                metric,
268                score,
269                config,
270                index,
271            } => {
272                if let Some(index) = index {
273                    let 配置文件名 = format!("solution-{index}.yaml");
274                    let 配置路径 = self.输出目录.join(&配置文件名);
275                    let 指标文件名 = format!("solution-{index}.txt");
276                    let 指标路径 = self.输出目录.join(&指标文件名);
277                    writeln!(
278                        &mut writer,
279                        "方案文件保存于 {},指标保存于 {}",
280                        配置路径.display(),
281                        指标路径.display()
282                    )
283                    .unwrap();
284                    write(指标路径, metric.clone()).unwrap();
285                    write(配置路径, config).unwrap();
286                }
287                writeln!(
288                    &mut writer,
289                    "系统搜索到了一个更好的方案,分数为 {score:.4e},指标如下:\n{metric}"
290                )
291            }
292        };
293        result.unwrap()
294    }
295}
296
297pub fn 从目录恢复(目录: &PathBuf, 线程数: usize) -> Result<Vec<(usize, 配置)>, 错误> {
298    let mut 存档列表 = vec![None; 线程数];
299    let mut 目录列表 = vec![];
300    if 线程数 == 1 {
301        目录列表.push(目录.clone());
302    } else {
303        for i in 0..线程数 {
304            目录列表.push(目录.join(i.to_string()));
305        }
306    }
307    for (i, 子目录) in 目录列表.iter().enumerate() {
308        let 存档 = &mut 存档列表[i];
309        for entry in read_dir(子目录)? {
310            let entry = entry?;
311            let file_name_raw = entry.file_name();
312            let file_name = file_name_raw.to_str().ok_or("文件名不是有效的 UTF-8")?;
313            if file_name.starts_with("checkpoint-") && file_name.ends_with(".yaml") {
314                let step_str = &file_name["checkpoint-".len()..file_name.len() - ".yaml".len()];
315                if let Ok(step) = step_str.parse::<usize>() {
316                    if let Some((current_step, _)) = 存档 {
317                        if step <= *current_step {
318                            continue; // 已经有更大的 step 了,跳过这个文件
319                        }
320                    }
321                    *存档 = Some((step, entry.path()));
322                }
323            }
324        }
325    }
326    let mut 结果 = vec![];
327    for (i, checkpoint) in 存档列表.iter().enumerate() {
328        if let Some((step, path)) = checkpoint {
329            let content = read_to_string(path)?;
330            let config: 配置 = serde_yaml::from_str(&content).unwrap();
331            结果.push((*step, config));
332        } else {
333            return Err(format!("线程 {i} 没有找到 checkpoint 文件").into());
334        }
335    }
336    Ok(结果)
337}