Skip to main content

chai/interfaces/
command_line.rs

1use crate::config::配置;
2use crate::interfaces::{消息, 界面, 默认输入};
3use crate::{原始可编码对象, 原始当量信息, 原始键位分布信息, 码表项};
4use chrono::Local;
5use clap::{Parser, Subcommand};
6use csv::{ReaderBuilder, WriterBuilder};
7use serde::{Deserialize, Serialize};
8use std::fmt::Display;
9use std::fs::{create_dir_all, read_to_string, write, OpenOptions};
10use std::io::Write;
11use std::iter::FromIterator;
12use std::path::{Path, PathBuf};
13
14pub trait 命令行参数: Clone {
15    fn 是否为多线程(&self) -> bool;
16}
17
18/// 命令行参数的定义
19#[derive(Parser, Clone)]
20#[command(name = "汉字自动拆分系统")]
21#[command(author, version, about, long_about)]
22#[command(propagate_version = true)]
23pub struct 默认命令行参数 {
24    #[command(subcommand)]
25    pub command: 命令,
26}
27
28impl 命令行参数 for 默认命令行参数 {
29    fn 是否为多线程(&self) -> bool {
30        match &self.command {
31            命令::Optimize { threads, .. } => *threads != 1,
32            _ => false,
33        }
34    }
35}
36
37/// 编码和优化共用的数据参数
38#[derive(Parser, Clone)]
39pub struct 数据参数 {
40    /// 方案文件,默认为 config.yaml
41    pub config: Option<PathBuf>,
42    /// 频率序列表,默认为 elements.txt
43    #[arg(short, long, value_name = "FILE")]
44    pub encodables: Option<PathBuf>,
45    /// 单键用指分布表,默认为 assets 目录下的 key_distribution.txt
46    #[arg(short, long, value_name = "FILE")]
47    pub key_distribution: Option<PathBuf>,
48    /// 双键速度当量表,默认为 assets 目录下的 pair_equivalence.txt
49    #[arg(short, long, value_name = "FILE")]
50    pub pair_equivalence: Option<PathBuf>,
51}
52
53/// 命令行中所有可用的子命令
54#[derive(Subcommand, Clone)]
55pub enum 命令 {
56    #[command(about = "使用方案文件和拆分表计算出字词编码并统计各类评测指标")]
57    Encode {
58        #[command(flatten)]
59        data: 数据参数,
60    },
61    #[command(about = "基于配置文件优化决策")]
62    Optimize {
63        #[command(flatten)]
64        data: 数据参数,
65        /// 优化时使用的线程数
66        #[arg(short, long, default_value = "1")]
67        threads: usize,
68    },
69    /// 启动 Web API 服务器
70    #[command(about = "启动 HTTP API 服务器")]
71    Server {
72        /// 服务器端口号
73        #[arg(short, long, default_value = "3200")]
74        port: u16,
75    },
76}
77
78/// 通过命令行来使用 libchai 的入口,实现了界面特征
79pub struct 命令行<P: 命令行参数> {
80    pub 参数: P,
81    pub 输出目录: PathBuf,
82}
83
84pub fn 读取文本文件<I, T>(path: PathBuf) -> T
85where
86    I: for<'de> Deserialize<'de>,
87    T: FromIterator<I>,
88{
89    let mut reader = ReaderBuilder::new()
90        .delimiter(b'\t')
91        .has_headers(false)
92        .flexible(true)
93        .from_path(path)
94        .unwrap();
95    reader.deserialize().map(|x| x.unwrap()).collect()
96}
97
98impl<P: 命令行参数> 命令行<P> {
99    pub fn 新建(args: P, maybe_output_dir: Option<PathBuf>) -> Self {
100        let output_dir = maybe_output_dir.unwrap_or_else(|| {
101            let time = Local::now().format("%m-%d+%H_%M_%S").to_string();
102            PathBuf::from(format!("output-{time}"))
103        });
104        create_dir_all(output_dir.clone()).unwrap();
105        Self {
106            参数: args,
107            输出目录: output_dir,
108        }
109    }
110
111    pub fn 输出编码结果(&self, entries: Vec<码表项>) {
112        let path = self.输出目录.join("编码.txt");
113        let mut writer = WriterBuilder::new()
114            .delimiter(b'\t')
115            .has_headers(false)
116            .from_path(&path)
117            .unwrap();
118        for 码表项 {
119            词: name,
120            全码: full,
121            全码排名: full_rank,
122            简码: short,
123            简码排名: short_rank,
124        } in entries
125        {
126            writer
127                .serialize((&name, &full, &full_rank, &short, &short_rank))
128                .unwrap();
129        }
130        writer.flush().unwrap();
131        println!("已完成编码,结果保存在 {} 中", path.clone().display());
132    }
133
134    pub fn 输出评测指标<M: Display + Serialize>(&self, metric: M) {
135        let path = self.输出目录.join("评测指标.yaml");
136        print!("{metric}");
137        let metric_str = serde_yaml::to_string(&metric).unwrap();
138        write(&path, metric_str).unwrap();
139    }
140
141    pub fn 生成子命令行(&self, index: usize) -> 命令行<P> {
142        let child_dir = self.输出目录.join(format!("{index}"));
143        命令行::新建(self.参数.clone(), Some(child_dir))
144    }
145}
146
147pub fn 从命令行参数创建(参数: &默认命令行参数) -> 默认输入 {
148    let (config, encodables, key_distribution, pair_equivalence) = match &参数.command {
149        命令::Encode { data } | 命令::Optimize { data, .. } => (
150            data.config.clone(),
151            data.encodables.clone(),
152            data.key_distribution.clone(),
153            data.pair_equivalence.clone(),
154        ),
155        命令::Server { .. } => {
156            panic!("Server 命令不需要数据准备");
157        }
158    };
159    let config_path = config.unwrap_or(PathBuf::from("config.yaml"));
160    let config_content = read_to_string(&config_path)
161        .unwrap_or_else(|_| panic!("文件 {} 不存在", config_path.display()));
162    let config: 配置 = serde_yaml::from_str(&config_content).unwrap();
163    let elements_path = encodables.unwrap_or(PathBuf::from("elements.txt"));
164    let encodables: Vec<原始可编码对象> = 读取文本文件(elements_path);
165    let assets_dir = Path::new("assets");
166    let keq_path = key_distribution.unwrap_or(assets_dir.join("key_distribution.txt"));
167    let key_distribution: 原始键位分布信息 = 读取文本文件(keq_path);
168    let peq_path = pair_equivalence.unwrap_or(assets_dir.join("pair_equivalence.txt"));
169    let pair_equivalence: 原始当量信息 = 读取文本文件(peq_path);
170    默认输入 {
171        配置: config,
172        原始键位分布信息: key_distribution,
173        原始当量信息: pair_equivalence,
174        词列表: encodables,
175    }
176}
177
178impl<P: 命令行参数> 界面 for 命令行<P> {
179    fn 发送(&self, message: 消息) {
180        let mut writer: Box<dyn Write> = if self.参数.是否为多线程() {
181            let log_path = self.输出目录.join("log.txt");
182            let file = OpenOptions::new()
183                .create(true) // 如果文件不存在,则创建
184                .append(true) // 追加写入,不覆盖原有内容
185                .open(log_path)
186                .expect("Failed to open file");
187            Box::new(file)
188        } else {
189            Box::new(std::io::stdout())
190        };
191        let result = match message {
192            消息::TrialMax {
193                temperature,
194                accept_rate,
195            } => writeln!(
196                &mut writer,
197                "若温度为 {temperature:.2e},接受率为 {:.2}%",
198                accept_rate * 100.0
199            ),
200            消息::TrialMin {
201                temperature,
202                improve_rate,
203            } => writeln!(
204                &mut writer,
205                "若温度为 {temperature:.2e},改进率为 {:.2}%",
206                improve_rate * 100.0
207            ),
208            消息::Parameters { t_max, t_min } => writeln!(
209                &mut writer,
210                "参数寻找完成,从最高温 {t_max} 降到最低温 {t_min}……"
211            ),
212            消息::Elapsed { time } => writeln!(&mut writer, "计算一次评测用时:{time} μs"),
213            消息::Progress {
214                steps,
215                temperature,
216                metric,
217            } => writeln!(
218                &mut writer,
219                "已执行 {steps} 步,当前温度为 {temperature:.2e},当前评测指标如下:\n{metric}",
220            ),
221            消息::BetterSolution {
222                metric,
223                config,
224                save,
225            } => {
226                let 时刻 = Local::now();
227                let 时间戳 = 时刻.format("%m-%d+%H_%M_%S_%3f").to_string();
228                let 配置路径 = self.输出目录.join(format!("{时间戳}.yaml"));
229                let 指标路径 = self.输出目录.join(format!("{时间戳}.txt"));
230                if save {
231                    write(指标路径, metric.clone()).unwrap();
232                    write(配置路径, config).unwrap();
233                    writeln!(
234                        &mut writer,
235                        "方案文件保存于 {时间戳}.yaml 中,评测指标保存于 {时间戳}.metric.yaml 中",
236                    )
237                    .unwrap();
238                }
239                writeln!(
240                    &mut writer,
241                    "{} 系统搜索到了一个更好的方案,评测指标如下:\n{}",
242                    时刻.format("%H:%M:%S"),
243                    metric
244                )
245            }
246        };
247        result.unwrap()
248    }
249}