use json::{object, JsonValue};
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use std::fs;
use std::path::PathBuf;
/// Top-level database settings: the name of the default connection plus every
/// named connection. Connections live in a `BTreeMap`, so iteration and
/// serialisation are ordered by connection name.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Config {
/// Name of the connection to use by default; intended as a key into
/// `connections` (NOTE(review): nothing here checks that the key exists).
pub default: String,
/// Connection settings keyed by connection name.
pub connections: BTreeMap<String, Connection>,
}
impl Default for Config {
fn default() -> Self {
Self::new()
}
}
impl Config {
pub fn new() -> Config {
let mut connections = BTreeMap::new();
connections.insert("sqlite".to_string(), Connection::new("sqlite"));
connections.insert("mysql".to_string(), Connection::new("mysql"));
connections.insert("pgsql".to_string(), Connection::new("pgsql"));
connections.insert("mssql".to_string(), Connection::new("mssql"));
Self {
default: "sqlite".to_string(),
connections,
}
}
pub fn from(data: JsonValue) -> Config {
let default = data["default"].to_string();
let mut connections = BTreeMap::new();
for (key, value) in data["connections"].entries() {
let connection = Connection::from(value.clone()).clone();
connections.insert(key.to_string(), connection.clone());
}
Self {
default,
connections,
}
}
pub fn create(config_file: PathBuf, pkg_name: bool) -> Config {
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ConfigPkg {
pub br_db: Config,
}
impl ConfigPkg {
pub fn new() -> ConfigPkg {
let mut connections = BTreeMap::new();
connections.insert("sqlite".to_string(), Connection::new("sqlite"));
connections.insert("mysql".to_string(), Connection::new("mysql"));
connections.insert("pgsql".to_string(), Connection::new("pgsql"));
connections.insert("mssql".to_string(), Connection::new("mssql"));
Self {
br_db: Config {
default: "sqlite".to_string(),
connections,
},
}
}
}
match fs::read_to_string(config_file.clone()) {
Ok(e) => {
if pkg_name {
toml::from_str::<ConfigPkg>(&e)
.map(|c| c.br_db)
.unwrap_or_else(|_| Config::new())
} else {
toml::from_str::<Config>(&e).unwrap_or_else(|_| Config::new())
}
}
Err(_) => {
if pkg_name {
let data = ConfigPkg::new();
if let Some(parent) = config_file.parent() {
let _ = fs::create_dir_all(parent);
}
if let Ok(toml) = toml::to_string(&data) {
if let Some(path) = config_file.to_str() {
let _ = fs::write(path, toml);
}
}
data.br_db
} else {
let data = Config::new();
if let Some(parent) = config_file.parent() {
let _ = fs::create_dir_all(parent);
}
if let Ok(toml) = toml::to_string(&data) {
if let Some(path) = config_file.to_str() {
let _ = fs::write(path, toml);
}
}
data
}
}
}
}
pub fn set_connection(&mut self, name: &str, connection: JsonValue) {
let connection = Connection::from(connection);
self.connections.insert(name.to_string(), connection);
}
pub fn set_default(&mut self, name: &str) {
self.default = name.to_string();
}
}
/// Database backend of a connection.
///
/// `None` represents an unrecognised backend name (see `Mode::from`). Under
/// serde's default representation, variants are stored by their exact names,
/// for example `"Mysql"`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub enum Mode {
Mysql,
Mssql,
Sqlite,
Pgsql,
None,
}
impl Mode {
    /// Returns the lowercase backend name (`"mysql"`, `"sqlite"`, `"mssql"`,
    /// `"pgsql"`), or `""` for [`Mode::None`].
    ///
    /// Only reads the variant, so a shared reference is enough.
    pub fn str(&self) -> String {
        match self {
            Mode::Mysql => "mysql",
            Mode::Sqlite => "sqlite",
            Mode::Mssql => "mssql",
            Mode::Pgsql => "pgsql",
            Mode::None => "",
        }
        .to_string()
    }

    /// Parses a backend name case-insensitively. Unknown names, including the
    /// empty string, map to [`Mode::None`].
    pub fn from(name: &str) -> Self {
        match name.to_lowercase().as_str() {
            "mysql" => Mode::Mysql,
            "sqlite" => Mode::Sqlite,
            "mssql" => Mode::Mssql,
            "pgsql" => Mode::Pgsql,
            _ => Mode::None,
        }
    }
}
/// Connection-pool tuning for a single connection. Default values come from
/// `PoolConfig::default`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct PoolConfig {
/// Minimum number of pooled connections.
pub min_connections: u32,
/// Maximum number of pooled connections.
pub max_connections: u32,
/// Connect timeout, in seconds.
pub connect_timeout_secs: u64,
/// Read timeout, in seconds.
pub read_timeout_secs: u64,
/// Write timeout, in seconds.
pub write_timeout_secs: u64,
/// Keep-alive setting, in milliseconds (how it is applied is decided by the
/// pool code outside this file — TODO confirm).
pub keepalive_ms: u64,
}
impl Default for PoolConfig {
fn default() -> Self {
Self {
min_connections: 0,
max_connections: 400,
connect_timeout_secs: 5,
read_timeout_secs: 15,
write_timeout_secs: 20,
keepalive_ms: 5000,
}
}
}
impl PoolConfig {
    /// Builds pool settings from a JSON object such as
    /// `{ "max_connections": 50, ... }`.
    ///
    /// Any key that is missing, or that `as_u32`/`as_u64` cannot read, falls
    /// back to the matching value from [`PoolConfig::default`]. Reusing that
    /// value keeps the two sources of defaults from drifting apart.
    pub fn from(data: &JsonValue) -> Self {
        let fallback = Self::default();
        Self {
            min_connections: data["min_connections"]
                .as_u32()
                .unwrap_or(fallback.min_connections),
            max_connections: data["max_connections"]
                .as_u32()
                .unwrap_or(fallback.max_connections),
            connect_timeout_secs: data["connect_timeout_secs"]
                .as_u64()
                .unwrap_or(fallback.connect_timeout_secs),
            read_timeout_secs: data["read_timeout_secs"]
                .as_u64()
                .unwrap_or(fallback.read_timeout_secs),
            write_timeout_secs: data["write_timeout_secs"]
                .as_u64()
                .unwrap_or(fallback.write_timeout_secs),
            keepalive_ms: data["keepalive_ms"]
                .as_u64()
                .unwrap_or(fallback.keepalive_ms),
        }
    }
}
/// Settings for a single database connection.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Connection {
/// Backend type.
pub mode: Mode,
/// Server host name or address (empty for the SQLite preset).
pub hostname: String,
/// Server port, stored as a string (for example "3306").
pub hostport: String,
/// Database name; for SQLite, the path of the database file.
pub database: String,
/// Login user name.
pub username: String,
/// Login password, stored in plain text.
pub userpass: String,
/// Extra driver parameters. NOTE(review): `get_dsn` in this file does not
/// use them — confirm where they are applied.
pub params: Vec<String>,
/// Character set of the connection.
pub charset: Charset,
/// Table-name prefix (presumably prepended to table names by query code
/// elsewhere — confirm).
pub prefix: String,
/// Debug flag.
pub debug: bool,
/// Pool settings. `#[serde(default)]` allows a serialised connection to omit
/// the `pool` table.
#[serde(default)]
pub pool: PoolConfig,
}
impl Default for Connection {
fn default() -> Self {
Self::new("sqlite")
}
}
impl Connection {
    /// Creates a preset connection for `mode`, which is parsed case-insensitively
    /// with [`Mode::from`].
    ///
    /// Presets:
    /// - `mysql`: `127.0.0.1:3306`; database, user and password all `test`.
    /// - `pgsql`: `127.0.0.1:5432`; user and password `test`; empty database.
    /// - `sqlite`: database file `db/app.db`.
    /// - `mssql` and unknown modes: every string field is left empty.
    ///
    /// Every preset uses `utf8mb4`, no prefix, no extra params, `debug = false`
    /// and the default pool settings.
    pub fn new(mode: &str) -> Connection {
        let mut that = Self {
            mode: Mode::from(mode),
            hostname: String::new(),
            hostport: String::new(),
            database: String::new(),
            username: String::new(),
            userpass: String::new(),
            params: Vec::new(),
            charset: Charset::Utf8mb4,
            prefix: String::new(),
            debug: false,
            pool: PoolConfig::default(),
        };
        match that.mode {
            Mode::Mysql => {
                that.hostname = "127.0.0.1".to_string();
                that.hostport = "3306".to_string();
                that.database = "test".to_string();
                that.username = "test".to_string();
                that.userpass = "test".to_string();
            }
            Mode::Pgsql => {
                that.hostname = "127.0.0.1".to_string();
                that.hostport = "5432".to_string();
                that.username = "test".to_string();
                that.userpass = "test".to_string();
            }
            Mode::Sqlite => {
                that.database = "db/app.db".to_string();
            }
            Mode::Mssql | Mode::None => {}
        }
        that
    }

    /// Serialises the connection to a JSON object using the keys that
    /// [`Connection::from`] reads, including the nested `pool` object.
    /// As a result, `Connection::from(conn.json())` also keeps the pool settings.
    pub fn json(&mut self) -> JsonValue {
        let mut value = object! {
            mode: self.mode.str(),
            hostname: self.hostname.clone(),
            hostport: self.hostport.clone(),
            database: self.database.clone(),
            username: self.username.clone(),
            userpass: self.userpass.clone(),
            params: self.params.clone(),
            charset: self.charset.str(),
            prefix: self.prefix.clone(),
            debug: self.debug
        };
        // `Connection::from` reads `pool` via `PoolConfig::from`. Without this
        // key, a round trip would silently reset the pool to its defaults.
        value["pool"] = object! {
            min_connections: self.pool.min_connections,
            max_connections: self.pool.max_connections,
            connect_timeout_secs: self.pool.connect_timeout_secs,
            read_timeout_secs: self.pool.read_timeout_secs,
            write_timeout_secs: self.pool.write_timeout_secs,
            keepalive_ms: self.pool.keepalive_ms
        };
        value
    }

    /// Parses a connection from a JSON object.
    ///
    /// - `mode` and `charset` are parsed by name. A missing `mode` becomes
    ///   [`Mode::None`] and a missing `charset` becomes `utf8mb4`.
    /// - `hostname`, `hostport`, `database`, `username` and `userpass` are read
    ///   with `JsonValue::to_string`, so a missing key becomes the string
    ///   `"null"`.
    /// - A missing `prefix` becomes `""`.
    /// - `debug` is `true` only when its textual form parses as `true`.
    /// - `pool` is read with [`PoolConfig::from`].
    pub fn from(data: JsonValue) -> Connection {
        Self {
            mode: Mode::from(data["mode"].as_str().unwrap_or("none")),
            hostname: data["hostname"].to_string(),
            hostport: data["hostport"].to_string(),
            database: data["database"].to_string(),
            username: data["username"].to_string(),
            userpass: data["userpass"].to_string(),
            params: data["params"].members().map(|x| x.to_string()).collect(),
            charset: Charset::from(data["charset"].as_str().unwrap_or("utf8mb4")),
            prefix: data["prefix"].as_str().unwrap_or("").to_string(),
            debug: data["debug"].to_string().parse::<bool>().unwrap_or(false),
            pool: PoolConfig::from(&data["pool"]),
        }
    }

    /// Builds the driver connection string, consuming the connection.
    ///
    /// - MySQL: `mysql://user:pass@host:port/db`
    /// - MSSQL: `sqlsrv://user:pass@host:port/db`
    /// - PostgreSQL: `host=... user=... password=... dbname=...`
    /// - SQLite: the database path itself. If that file does not exist yet,
    ///   its parent directory is created first (best effort; errors ignored).
    /// - Unknown mode: an empty string.
    ///
    /// Credentials are interpolated verbatim, without URL-encoding.
    pub fn get_dsn(self) -> String {
        match self.mode {
            Mode::Mysql => format!(
                "mysql://{}:{}@{}:{}/{}",
                self.username, self.userpass, self.hostname, self.hostport, self.database
            ),
            Mode::Sqlite => {
                let path = PathBuf::from(&self.database);
                if !path.is_file() {
                    // `parent()` is `Some("")` for a bare file name; there is
                    // nothing to create in that case.
                    if let Some(dir) = path.parent().filter(|dir| !dir.as_os_str().is_empty()) {
                        let _ = fs::create_dir_all(dir);
                    }
                }
                self.database
            }
            Mode::Mssql => format!(
                "sqlsrv://{}:{}@{}:{}/{}",
                self.username, self.userpass, self.hostname, self.hostport, self.database
            ),
            // NOTE(review): `hostport` is not part of this DSN, so a
            // non-default PostgreSQL port is ignored — confirm whether
            // `port={}` should be added.
            Mode::Pgsql => format!(
                "host={} user={} password={} dbname={}",
                self.hostname, self.username, self.userpass, self.database
            ),
            Mode::None => String::new(),
        }
    }
}
/// Character set of a connection.
///
/// `None` represents an unrecognised charset name (see `Charset::from`) and
/// renders as `""` via `Charset::str`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub enum Charset {
Utf8mb4,
Utf8,
None,
}
impl Charset {
    /// Parses a charset name case-insensitively. `"utf8"` and `"utf8mb4"` are
    /// recognised; anything else, including `""`, maps to [`Charset::None`].
    pub fn from(name: &str) -> Charset {
        let lowered = name.to_lowercase();
        if lowered == "utf8mb4" {
            Charset::Utf8mb4
        } else if lowered == "utf8" {
            Charset::Utf8
        } else {
            Charset::None
        }
    }

    /// Returns the lowercase name understood by [`Charset::from`], or `""` for
    /// [`Charset::None`].
    pub fn str(&self) -> String {
        let name = match *self {
            Charset::Utf8mb4 => "utf8mb4",
            Charset::Utf8 => "utf8",
            Charset::None => "",
        };
        String::from(name)
    }
}
// Unit tests for Config, Mode, PoolConfig, Connection and Charset.
#[cfg(test)]
mod tests {
use super::*;
use json::object;
#[test]
fn config_new_has_four_connections() {
let cfg = Config::new();
assert_eq!(cfg.connections.len(), 4);
assert!(cfg.connections.contains_key("sqlite"));
assert!(cfg.connections.contains_key("mysql"));
assert!(cfg.connections.contains_key("pgsql"));
assert!(cfg.connections.contains_key("mssql"));
}
#[test]
fn config_new_default_is_sqlite() {
let cfg = Config::new();
assert_eq!(cfg.default, "sqlite");
}
#[test]
fn config_default_equals_new() {
let a = Config::new();
let b = Config::default();
assert_eq!(a.default, b.default);
assert_eq!(a.connections.len(), b.connections.len());
for key in a.connections.keys() {
assert!(b.connections.contains_key(key));
}
}
#[test]
fn config_from_valid_json() {
let data = object! {
default: "mysql",
connections: {
myconn: {
mode: "mysql",
hostname: "10.0.0.1",
hostport: "3307",
database: "mydb",
username: "admin",
userpass: "secret",
params: [],
charset: "utf8",
prefix: "app_",
debug: true
}
}
};
let cfg = Config::from(data);
assert_eq!(cfg.default, "mysql");
assert_eq!(cfg.connections.len(), 1);
let conn = cfg.connections.get("myconn").expect("myconn should exist");
assert_eq!(conn.hostname, "10.0.0.1");
assert_eq!(conn.hostport, "3307");
assert_eq!(conn.database, "mydb");
assert_eq!(conn.username, "admin");
assert_eq!(conn.userpass, "secret");
assert_eq!(conn.prefix, "app_");
assert!(conn.debug);
}
// Config::from reads `default` with JsonValue::to_string, so a missing key
// yields the literal string "null"; this test pins that behaviour.
#[test]
fn config_from_empty_json() {
let data = object! {};
let cfg = Config::from(data);
assert_eq!(cfg.default, "null");
assert_eq!(cfg.connections.len(), 0);
}
#[test]
fn config_from_missing_connections() {
let data = object! { default: "pgsql" };
let cfg = Config::from(data);
assert_eq!(cfg.default, "pgsql");
assert_eq!(cfg.connections.len(), 0);
}
// The create tests each use a dedicated directory under the system temp dir,
// removed before and after the test so reruns start clean. The second
// `create` call reads back the file written by the first.
#[test]
fn config_create_nonexistent_file_no_pkg_name() {
let dir = std::env::temp_dir().join("br_db_test_create_no_pkg");
let _ = fs::remove_dir_all(&dir);
let file = dir.join("config.toml");
let cfg = Config::create(file.clone(), false);
assert_eq!(cfg.default, "sqlite");
assert_eq!(cfg.connections.len(), 4);
assert!(file.is_file());
let cfg2 = Config::create(file.clone(), false);
assert_eq!(cfg2.default, "sqlite");
assert_eq!(cfg2.connections.len(), 4);
let _ = fs::remove_dir_all(&dir);
}
#[test]
fn config_create_nonexistent_file_with_pkg_name() {
let dir = std::env::temp_dir().join("br_db_test_create_pkg");
let _ = fs::remove_dir_all(&dir);
let file = dir.join("config.toml");
let cfg = Config::create(file.clone(), true);
assert_eq!(cfg.default, "sqlite");
assert_eq!(cfg.connections.len(), 4);
assert!(file.is_file());
let cfg2 = Config::create(file.clone(), true);
assert_eq!(cfg2.default, "sqlite");
assert_eq!(cfg2.connections.len(), 4);
let _ = fs::remove_dir_all(&dir);
}
// Writes a hand-made TOML file. Enum values use serde's variant names
// ("Mysql", "Utf8mb4"), not the lowercase names used by Mode::from.
#[test]
fn config_create_existing_file_without_pkg_name() {
let dir = std::env::temp_dir().join("br_db_test_create_existing");
let _ = fs::remove_dir_all(&dir);
fs::create_dir_all(&dir).expect("create temp dir");
let file = dir.join("config.toml");
let content = r#"
default = "mysql"
[connections.only]
mode = "Mysql"
hostname = "1.2.3.4"
hostport = "3306"
database = "custom"
username = "u"
userpass = "p"
params = []
charset = "Utf8mb4"
prefix = ""
debug = false
[connections.only.pool]
min_connections = 0
max_connections = 400
connect_timeout_secs = 5
read_timeout_secs = 15
write_timeout_secs = 20
keepalive_ms = 5000
"#;
fs::write(&file, content).expect("write temp config");
let cfg = Config::create(file.clone(), false);
assert_eq!(cfg.default, "mysql");
assert_eq!(cfg.connections.len(), 1);
assert!(cfg.connections.contains_key("only"));
let _ = fs::remove_dir_all(&dir);
}
#[test]
fn config_set_connection_adds_new() {
let mut cfg = Config::new();
let conn_json = object! {
mode: "pgsql",
hostname: "db.example.com",
hostport: "5433",
database: "prod",
username: "admin",
userpass: "pw",
params: [],
charset: "utf8",
prefix: "v2_",
debug: false
};
cfg.set_connection("production", conn_json);
assert_eq!(cfg.connections.len(), 5);
let conn = cfg
.connections
.get("production")
.expect("production should exist");
assert_eq!(conn.hostname, "db.example.com");
assert_eq!(conn.hostport, "5433");
assert_eq!(conn.prefix, "v2_");
}
#[test]
fn config_set_connection_overwrites_existing() {
let mut cfg = Config::new();
let conn_json = object! {
mode: "mysql",
hostname: "new-host",
hostport: "3307",
database: "newdb",
username: "newuser",
userpass: "newpass",
params: [],
charset: "utf8mb4",
prefix: "",
debug: true
};
cfg.set_connection("mysql", conn_json);
assert_eq!(cfg.connections.len(), 4);
let conn = cfg.connections.get("mysql").expect("mysql should exist");
assert_eq!(conn.hostname, "new-host");
assert!(conn.debug);
}
#[test]
fn config_set_default_changes_default() {
let mut cfg = Config::new();
assert_eq!(cfg.default, "sqlite");
cfg.set_default("mysql");
assert_eq!(cfg.default, "mysql");
cfg.set_default("pgsql");
assert_eq!(cfg.default, "pgsql");
}
#[test]
fn mode_str_all_variants() {
assert_eq!(Mode::Mysql.str(), "mysql");
assert_eq!(Mode::Sqlite.str(), "sqlite");
assert_eq!(Mode::Mssql.str(), "mssql");
assert_eq!(Mode::Pgsql.str(), "pgsql");
assert_eq!(Mode::None.str(), "");
}
#[test]
fn mode_from_all_variants() {
assert!(matches!(Mode::from("mysql"), Mode::Mysql));
assert!(matches!(Mode::from("sqlite"), Mode::Sqlite));
assert!(matches!(Mode::from("mssql"), Mode::Mssql));
assert!(matches!(Mode::from("pgsql"), Mode::Pgsql));
}
#[test]
fn mode_from_case_insensitive() {
assert!(matches!(Mode::from("MYSQL"), Mode::Mysql));
assert!(matches!(Mode::from("Sqlite"), Mode::Sqlite));
assert!(matches!(Mode::from("PGSQL"), Mode::Pgsql));
}
#[test]
fn mode_from_unknown_returns_none() {
assert!(matches!(Mode::from("oracle"), Mode::None));
assert!(matches!(Mode::from(""), Mode::None));
assert!(matches!(Mode::from("redis"), Mode::None));
}
#[test]
fn pool_config_default_values() {
let pc = PoolConfig::default();
assert_eq!(pc.min_connections, 0);
assert_eq!(pc.max_connections, 400);
assert_eq!(pc.connect_timeout_secs, 5);
assert_eq!(pc.read_timeout_secs, 15);
assert_eq!(pc.write_timeout_secs, 20);
assert_eq!(pc.keepalive_ms, 5000);
}
#[test]
fn pool_config_from_full_data() {
let data = object! {
min_connections: 5,
max_connections: 100,
connect_timeout_secs: 10,
read_timeout_secs: 30,
write_timeout_secs: 60,
keepalive_ms: 10000
};
let pc = PoolConfig::from(&data);
assert_eq!(pc.min_connections, 5);
assert_eq!(pc.max_connections, 100);
assert_eq!(pc.connect_timeout_secs, 10);
assert_eq!(pc.read_timeout_secs, 30);
assert_eq!(pc.write_timeout_secs, 60);
assert_eq!(pc.keepalive_ms, 10000);
}
// Keys missing from the JSON fall back to the PoolConfig defaults.
#[test]
fn pool_config_from_partial_data() {
let data = object! {
max_connections: 50
};
let pc = PoolConfig::from(&data);
assert_eq!(pc.min_connections, 0);
assert_eq!(pc.max_connections, 50);
assert_eq!(pc.connect_timeout_secs, 5);
assert_eq!(pc.read_timeout_secs, 15);
assert_eq!(pc.write_timeout_secs, 20);
assert_eq!(pc.keepalive_ms, 5000);
}
#[test]
fn pool_config_from_empty_data() {
let data = object! {};
let pc = PoolConfig::from(&data);
assert_eq!(pc.min_connections, 0);
assert_eq!(pc.max_connections, 400);
assert_eq!(pc.connect_timeout_secs, 5);
assert_eq!(pc.read_timeout_secs, 15);
assert_eq!(pc.write_timeout_secs, 20);
assert_eq!(pc.keepalive_ms, 5000);
}
#[test]
fn connection_new_mysql() {
let conn = Connection::new("mysql");
assert!(matches!(conn.mode, Mode::Mysql));
assert_eq!(conn.hostname, "127.0.0.1");
assert_eq!(conn.hostport, "3306");
assert_eq!(conn.database, "test");
assert_eq!(conn.username, "test");
assert_eq!(conn.userpass, "test");
assert!(!conn.debug);
}
#[test]
fn connection_new_sqlite() {
let conn = Connection::new("sqlite");
assert!(matches!(conn.mode, Mode::Sqlite));
assert_eq!(conn.database, "db/app.db");
assert_eq!(conn.hostname, "");
assert_eq!(conn.hostport, "");
}
#[test]
fn connection_new_pgsql() {
let conn = Connection::new("pgsql");
assert!(matches!(conn.mode, Mode::Pgsql));
assert_eq!(conn.hostname, "127.0.0.1");
assert_eq!(conn.hostport, "5432");
assert_eq!(conn.username, "test");
assert_eq!(conn.userpass, "test");
assert_eq!(conn.database, "");
}
#[test]
fn connection_new_mssql() {
let conn = Connection::new("mssql");
assert!(matches!(conn.mode, Mode::Mssql));
assert_eq!(conn.hostname, "");
assert_eq!(conn.hostport, "");
}
#[test]
fn connection_new_unknown() {
let conn = Connection::new("oracle");
assert!(matches!(conn.mode, Mode::None));
assert_eq!(conn.hostname, "");
assert_eq!(conn.database, "");
}
#[test]
fn connection_default_is_sqlite() {
let conn = Connection::default();
assert!(matches!(conn.mode, Mode::Sqlite));
assert_eq!(conn.database, "db/app.db");
}
#[test]
fn connection_from_full_json() {
let data = object! {
mode: "pgsql",
hostname: "pg.local",
hostport: "5433",
database: "appdb",
username: "pguser",
userpass: "pgpass",
params: ["sslmode=require"],
charset: "utf8",
prefix: "t_",
debug: true,
pool: {
min_connections: 2,
max_connections: 50,
connect_timeout_secs: 3,
read_timeout_secs: 10,
write_timeout_secs: 10,
keepalive_ms: 3000
}
};
let conn = Connection::from(data);
assert!(matches!(conn.mode, Mode::Pgsql));
assert_eq!(conn.hostname, "pg.local");
assert_eq!(conn.hostport, "5433");
assert_eq!(conn.database, "appdb");
assert_eq!(conn.username, "pguser");
assert_eq!(conn.userpass, "pgpass");
assert_eq!(conn.params, vec!["sslmode=require".to_string()]);
assert!(matches!(conn.charset, Charset::Utf8));
assert_eq!(conn.prefix, "t_");
assert!(conn.debug);
assert_eq!(conn.pool.min_connections, 2);
assert_eq!(conn.pool.max_connections, 50);
}
// Missing string keys read via to_string() become "null" (see hostport),
// while `prefix` falls back to "" and `debug` to false.
#[test]
fn connection_from_partial_json() {
let data = object! {
mode: "mysql",
hostname: "db.host"
};
let conn = Connection::from(data);
assert!(matches!(conn.mode, Mode::Mysql));
assert_eq!(conn.hostname, "db.host");
assert_eq!(conn.hostport, "null");
assert_eq!(conn.prefix, "");
assert!(!conn.debug);
}
// Round-trips through json() and from(); pool settings are not compared here.
#[test]
fn connection_json_roundtrip() {
let mut original = Connection::new("mysql");
original.prefix = "pre_".to_string();
original.debug = true;
let json_val = original.json();
let mut restored = Connection::from(json_val);
assert_eq!(original.hostname, restored.hostname);
assert_eq!(original.hostport, restored.hostport);
assert_eq!(original.database, restored.database);
assert_eq!(original.username, restored.username);
assert_eq!(original.userpass, restored.userpass);
assert_eq!(original.prefix, restored.prefix);
assert_eq!(original.debug, restored.debug);
assert_eq!(original.charset.str(), restored.charset.str());
assert_eq!(original.mode.str(), restored.mode.str());
}
#[test]
fn connection_get_dsn_mysql() {
let conn = Connection::new("mysql");
let dsn = conn.get_dsn();
assert_eq!(dsn, "mysql://test:test@127.0.0.1:3306/test");
}
// get_dsn may create the parent directory (/tmp) when the file does not
// exist; the returned DSN is the path unchanged.
#[test]
fn connection_get_dsn_sqlite() {
let mut conn = Connection::new("sqlite");
conn.database = "/tmp/br_db_test_dsn.db".to_string();
let dsn = conn.get_dsn();
assert_eq!(dsn, "/tmp/br_db_test_dsn.db");
}
#[test]
fn connection_get_dsn_mssql() {
let mut conn = Connection::new("mssql");
conn.hostname = "mssql.local".to_string();
conn.hostport = "1433".to_string();
conn.username = "sa".to_string();
conn.userpass = "pass".to_string();
conn.database = "master".to_string();
let dsn = conn.get_dsn();
assert_eq!(dsn, "sqlsrv://sa:pass@mssql.local:1433/master");
}
// The pgsql DSN contains no port, so hostport does not appear in the
// expected string.
#[test]
fn connection_get_dsn_pgsql() {
let mut conn = Connection::new("pgsql");
conn.database = "mydb".to_string();
let dsn = conn.get_dsn();
assert_eq!(dsn, "host=127.0.0.1 user=test password=test dbname=mydb");
}
#[test]
fn connection_get_dsn_none() {
let conn = Connection::new("unknown");
let dsn = conn.get_dsn();
assert_eq!(dsn, "");
}
#[test]
fn charset_from_utf8() {
assert!(matches!(Charset::from("utf8"), Charset::Utf8));
}
#[test]
fn charset_from_utf8mb4() {
assert!(matches!(Charset::from("utf8mb4"), Charset::Utf8mb4));
}
#[test]
fn charset_from_case_insensitive() {
assert!(matches!(Charset::from("UTF8"), Charset::Utf8));
assert!(matches!(Charset::from("UTF8MB4"), Charset::Utf8mb4));
}
#[test]
fn charset_from_unknown() {
assert!(matches!(Charset::from("latin1"), Charset::None));
assert!(matches!(Charset::from(""), Charset::None));
}
#[test]
fn charset_str_all_variants() {
assert_eq!(Charset::Utf8.str(), "utf8");
assert_eq!(Charset::Utf8mb4.str(), "utf8mb4");
assert_eq!(Charset::None.str(), "");
}
}