use std::path::Path;

use anyhow::Context as _;
use anyhow::Result as AnyResult;
use semantic_search::Model;
use serde::Deserialize;
/// Top-level application configuration, deserialized from a TOML document.
///
/// Only the `[api]` table is mandatory. The `[server]` and `[bot]` tables
/// fall back to their `Default` implementations when absent.
#[derive(Debug, Deserialize)]
pub struct Config {
    /// The `[server]` table; `Server::default()` when the table is missing.
    #[serde(default)]
    pub server: Server,
    /// The `[api]` table; required (parsing fails with "missing field `api`").
    pub api: ApiConfig,
    /// The `[bot]` table; `BotConfig::default()` when the table is missing.
    #[serde(default)]
    pub bot: BotConfig,
}
/// HTTP server settings (the `[server]` table).
///
/// Any field omitted from the table is taken from [`Server::default()`].
#[derive(Deserialize, Debug)]
#[serde(default)]
pub struct Server {
    /// Listening port; `defaults::server_port()` when not specified.
    pub port: u16,
}

impl Default for Server {
    fn default() -> Self {
        let port = defaults::server_port();
        Server { port }
    }
}
/// API credentials and embedding-model selection (the `[api]` table).
///
/// `Debug` is implemented by hand rather than derived so that the secret
/// `key` is never written out when the configuration is logged with `{:?}`.
#[derive(Deserialize)]
pub struct ApiConfig {
    /// API key; required (parsing fails with "missing field `key`" if absent).
    pub key: String,
    /// Embedding model; `Model::default()` when omitted.
    #[serde(default)]
    pub model: Model,
}

impl std::fmt::Debug for ApiConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Reveal only whether a key was provided, never the key itself.
        let key = if self.key.is_empty() { "<empty>" } else { "<redacted>" };
        f.debug_struct("ApiConfig")
            .field("key", &format_args!("{key}"))
            .field("model", &self.model)
            .finish()
    }
}
/// Bot settings (the `[bot]` table). Every field is optional.
///
/// `Debug` is implemented by hand rather than derived so that the secret
/// `token` is never written out when the configuration is logged with `{:?}`.
#[derive(Deserialize)]
pub struct BotConfig {
    /// Bot token; empty string when omitted.
    #[serde(default)]
    pub token: String,
    /// Owner id; `0` when omitted.
    #[serde(default)]
    pub owner: u64,
    /// Whitelisted ids; empty when omitted.
    #[serde(default)]
    pub whitelist: Vec<u64>,
    /// Sticker set name; `defaults::sticker_set()` when omitted.
    #[serde(default = "defaults::sticker_set")]
    pub sticker_set: String,
    /// Number of results; `defaults::num_results()` when omitted.
    #[serde(default = "defaults::num_results")]
    pub num_results: usize,
}

impl std::fmt::Debug for BotConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Reveal only whether a token was provided, never the token itself.
        let token = if self.token.is_empty() { "<empty>" } else { "<redacted>" };
        f.debug_struct("BotConfig")
            .field("token", &format_args!("{token}"))
            .field("owner", &self.owner)
            .field("whitelist", &self.whitelist)
            .field("sticker_set", &self.sticker_set)
            .field("num_results", &self.num_results)
            .finish()
    }
}
impl Default for BotConfig {
fn default() -> Self {
Self {
token: String::new(),
owner: 0,
whitelist: Vec::new(),
num_results: defaults::num_results(),
sticker_set: defaults::sticker_set(),
}
}
}
/// Deserializes a [`Config`] from TOML text.
///
/// # Errors
///
/// Returns the `toml` deserialization error when the text is not valid TOML
/// or is missing a required field (e.g. the `[api]` table or `api.key`).
fn parse_config_from_str(content: &str) -> Result<Config, toml::de::Error> {
    toml::from_str::<Config>(content)
}
pub fn parse_config<T>(path: T) -> AnyResult<Config>
where
T: AsRef<Path>,
{
let content = std::fs::read_to_string(path)?;
Ok(parse_config_from_str(&content)?)
}
mod defaults {
    //! Fallback values referenced by the serde `default` attributes and the
    //! `Default` impls of the configuration structs.

    const SERVER_PORT: u16 = 8080;
    const NUM_RESULTS: usize = 8;
    const STICKER_SET: &str = "meme";

    /// Default listening port for the server.
    pub const fn server_port() -> u16 {
        SERVER_PORT
    }

    /// Default number of results for the bot.
    pub const fn num_results() -> usize {
        NUM_RESULTS
    }

    /// Default sticker set name for the bot.
    pub fn sticker_set() -> String {
        String::from(STICKER_SET)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `content` and checks the server port, API key, model and bot token.
    fn test(content: &str, port: u16, key: &str, model: Model, bot_token: &str) {
        let config = parse_config_from_str(content).unwrap();
        assert_eq!(config.server.port, port);
        assert_eq!(config.api.key, key);
        assert_eq!(config.api.model, model);
        assert_eq!(config.bot.token, bot_token);
    }

    /// Asserts every `[bot]` field holds its documented fallback value.
    fn assert_default_bot(bot: &BotConfig) {
        assert_eq!(bot.token, "");
        assert_eq!(bot.owner, 0);
        assert!(bot.whitelist.is_empty());
        assert_eq!(bot.sticker_set, "meme");
        assert_eq!(bot.num_results, 8);
    }

    #[test]
    fn parse_config_1() {
        let content = r#"
[server]
port = 8081
[api]
key = "test_key"
[bot]
token = "test_token"
"#;
        test(
            content,
            8081,
            "test_key",
            Model::BgeLargeZhV1_5,
            "test_token",
        );
    }

    #[test]
    fn parse_config_2() {
        let content = r#"
[server]
port = 8080
[api]
key = "test_key"
model = "BAAI/bge-large-zh-v1.5"
"#;
        test(content, 8080, "test_key", Model::BgeLargeZhV1_5, "");
    }

    #[test]
    fn parse_config_3() {
        let content = r#"
[server]
[api]
key = "test_key"
model = "BAAI/bge-large-en-v1.5"
"#;
        test(content, 8080, "test_key", Model::BgeLargeEnV1_5, "");
    }

    #[test]
    fn parse_config_4() {
        let content = r#"
[api]
key = "test_key"
"#;
        test(content, 8080, "test_key", Model::BgeLargeZhV1_5, "");
    }

    #[test]
    fn parse_config_5() {
        let content = r#"
[server]
port = 8081
[api]
key = "test_key"
[bot]
"#;
        test(content, 8081, "test_key", Model::BgeLargeZhV1_5, "");
    }

    /// A missing `[bot]` table goes through `BotConfig::default()`.
    #[test]
    fn parse_bot_defaults_without_table() {
        let content = r#"
[api]
key = "test_key"
"#;
        let config = parse_config_from_str(content).unwrap();
        assert_default_bot(&config.bot);
    }

    /// An empty `[bot]` table goes through the per-field serde defaults, which
    /// must agree with `BotConfig::default()`.
    #[test]
    fn parse_bot_defaults_with_empty_table() {
        let content = r#"
[api]
key = "test_key"
[bot]
"#;
        let config = parse_config_from_str(content).unwrap();
        assert_default_bot(&config.bot);
    }

    #[test]
    fn parse_bot_explicit_values() {
        let content = r#"
[api]
key = "test_key"
[bot]
token = "test_token"
owner = 42
whitelist = [1, 2, 3]
sticker_set = "cats"
num_results = 3
"#;
        let config = parse_config_from_str(content).unwrap();
        assert_eq!(config.bot.token, "test_token");
        assert_eq!(config.bot.owner, 42);
        assert_eq!(config.bot.whitelist, vec![1, 2, 3]);
        assert_eq!(config.bot.sticker_set, "cats");
        assert_eq!(config.bot.num_results, 3);
    }

    #[test]
    #[should_panic(expected = "missing field `api`")]
    fn parse_config_fail_1() {
        let content = r"
[server]
port = 8080
";
        test(content, 8080, "test_key", Model::BgeLargeZhV1_5, "");
    }

    #[test]
    #[should_panic(expected = "missing field `key`")]
    fn parse_config_fail_2() {
        let content = r"
[api]
";
        test(content, 8080, "test_key", Model::BgeLargeZhV1_5, "");
    }
}