semantic_search_cli/
config.rs

1//! Configuration file parser.
2
3use anyhow::Result as AnyResult;
4use std::path::Path;
5
6use semantic_search::Model;
7use serde::Deserialize;
8
9/// Structure of the configuration file.
10#[derive(Deserialize, Debug)]
11pub struct Config {
12    /// Server configuration.
13    #[serde(default)]
14    pub server: Server,
15    /// API configuration.
16    pub api: ApiConfig,
17    /// Telegram bot configuration.
18    #[serde(default)]
19    pub bot: BotConfig,
20}
21
22/// Server configuration.
23#[derive(Deserialize, Debug)]
24pub struct Server {
25    /// Port for the server. Default is 8080.
26    #[serde(default = "defaults::server_port")]
27    pub port: u16,
28}
29
30impl Default for Server {
31    fn default() -> Self {
32        Self {
33            port: defaults::server_port(),
34        }
35    }
36}
37
38/// API configuration.
39#[derive(Deserialize, Debug)]
40pub struct ApiConfig {
41    /// API key for Silicon Cloud.
42    pub key: String,
43    /// Model to use for embedding.
44    #[serde(default)]
45    pub model: Model,
46}
47
48/// Telegram bot configuration.
49#[derive(Deserialize, Debug)]
50pub struct BotConfig {
51    /// Token for the Telegram bot.
52    #[serde(default)]
53    pub token: String,
54    /// Telegram user ID of the bot owner.
55    #[serde(default)]
56    pub owner: u64,
57    /// White list of user IDs that can use the bot.
58    #[serde(default)]
59    pub whitelist: Vec<u64>,
60    /// Sticker set id for the bot (Optional).
61    #[serde(default = "defaults::sticker_set")]
62    pub sticker_set: String,
63    /// Number of results to return.
64    #[serde(default = "defaults::num_results")]
65    pub num_results: usize,
66}
67
68impl Default for BotConfig {
69    fn default() -> Self {
70        Self {
71            token: String::new(),
72            owner: 0,
73            whitelist: Vec::new(),
74            num_results: defaults::num_results(),
75            sticker_set: defaults::sticker_set(),
76        }
77    }
78}
79
80/// Parse the configuration into a `Config` structure.
81///
82/// # Errors
83///
84/// Returns an [`Error`](toml::de::Error) if the configuration file is not valid, like missing fields.
85fn parse_config_from_str(content: &str) -> Result<Config, toml::de::Error> {
86    toml::from_str(content)
87}
88
89/// Parse the configuration file into a `Config` structure.
90///
91/// # Errors
92///
93/// Returns an [IO error](std::io::Error) if reading fails, or a [TOML error](toml::de::Error) if parsing fails.
94pub fn parse_config<T>(path: T) -> AnyResult<Config>
95where
96    T: AsRef<Path>,
97{
98    let content = std::fs::read_to_string(path)?;
99    Ok(parse_config_from_str(&content)?)
100}
101
/// Fallback values for configuration keys that are left out of the file.
mod defaults {
    /// Server port used when none is configured.
    const SERVER_PORT: u16 = 8080;
    /// Result count used when none is configured.
    const NUM_RESULTS: usize = 8;
    /// Sticker set id used when none is configured.
    const STICKER_SET: &str = "meme";

    /// Default port for the server.
    pub const fn server_port() -> u16 {
        SERVER_PORT
    }

    /// Default number of results to return.
    pub const fn num_results() -> usize {
        NUM_RESULTS
    }

    /// Default sticker set id for the bot, as an owned string.
    pub fn sticker_set() -> String {
        String::from(STICKER_SET)
    }
}
117
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `content`, which must be valid, and compare the server port, API key, model and
    /// bot token against the expected values.
    fn assert_parsed(content: &str, port: u16, key: &str, model: Model, bot_token: &str) {
        let config = parse_config_from_str(content).expect("configuration should parse");
        assert_eq!(config.server.port, port);
        assert_eq!(config.api.key, key);
        assert_eq!(config.api.model, model);
        assert_eq!(config.bot.token, bot_token);
    }

    /// Parse `content`, which must be rejected, and check that the error text mentions
    /// `needle`.
    fn assert_rejected(content: &str, needle: &str) {
        let err = parse_config_from_str(content).expect_err("configuration should be rejected");
        let message = err.to_string();
        assert!(
            message.contains(needle),
            "error message {:?} does not mention {:?}",
            message,
            needle
        );
    }

    #[test]
    fn parse_config_1() {
        let content = r#"
            [server]
            port = 8081

            [api]
            key = "test_key"

            [bot]
            token = "test_token"
        "#;
        assert_parsed(
            content,
            8081,
            "test_key",
            Model::BgeLargeZhV1_5,
            "test_token",
        );
    }

    #[test]
    fn parse_config_2() {
        let content = r#"
            [server]
            port = 8080

            [api]
            key = "test_key"
            model = "BAAI/bge-large-zh-v1.5"
        "#;
        assert_parsed(content, 8080, "test_key", Model::BgeLargeZhV1_5, "");
    }

    #[test]
    fn parse_config_3() {
        let content = r#"
            [server]

            [api]
            key = "test_key"
            model = "BAAI/bge-large-en-v1.5"
        "#;
        assert_parsed(content, 8080, "test_key", Model::BgeLargeEnV1_5, "");
    }

    #[test]
    fn parse_config_4() {
        let content = r#"
            [api]
            key = "test_key"
        "#;
        assert_parsed(content, 8080, "test_key", Model::BgeLargeZhV1_5, "");
    }

    #[test]
    fn parse_config_5() {
        let content = r#"
            [server]
            port = 8081

            [api]
            key = "test_key"

            [bot]
        "#;
        assert_parsed(content, 8081, "test_key", Model::BgeLargeZhV1_5, "");
    }

    /// Keys omitted from `[bot]` must take their documented defaults.
    #[test]
    fn parse_config_bot_defaults() {
        let content = r#"
            [api]
            key = "test_key"

            [bot]
            token = "test_token"
        "#;
        let config = parse_config_from_str(content).expect("configuration should parse");
        assert_eq!(config.bot.token, "test_token");
        assert_eq!(config.bot.owner, 0);
        assert!(config.bot.whitelist.is_empty());
        assert_eq!(config.bot.sticker_set, "meme");
        assert_eq!(config.bot.num_results, 8);
    }

    #[test]
    fn parse_config_fail_1() {
        let content = r"
            [server]
            port = 8080
        ";
        assert_rejected(content, "missing field `api`");
    }

    #[test]
    fn parse_config_fail_2() {
        let content = r"
            [api]
        ";
        assert_rejected(content, "missing field `key`");
    }
}