use crate::{
LazyLock,
application::{self, Agent, Application, ServerTag},
crypto,
encoding::base64,
error::Error,
extension::TomlTableExt,
helper,
};
use convert_case::{Case, Casing};
use serde::de::DeserializeOwned;
use std::{
borrow::Cow,
net::{IpAddr, Ipv4Addr, SocketAddr},
};
use toml::value::{Table, Value};
mod config;
mod data;
mod env;
pub use data::{Data, SharedData};
pub use env::Env;
/// Application state: the runtime environment, its configuration table,
/// and user-defined data of type `T` (unit by default).
#[derive(Clone, Debug)]
pub struct State<T = ()> {
    /// Environment the application runs in.
    env: Env,
    /// Configuration table; empty until [`State::load_config`] populates it.
    config: Table,
    /// Application-specific data.
    data: T,
}
impl<T> State<T> {
#[inline]
pub fn new(env: Env, data: T) -> Self {
Self {
env,
config: Table::new(),
data,
}
}
pub fn load_config(&mut self) {
let env = self.env.as_str();
let mut config_table = if let Ok(config_url) = std::env::var("ZINO_APP_CONFIG_URL") {
#[cfg(feature = "http-client")]
{
config::fetch_config_url(&config_url, env).unwrap_or_else(|err| {
tracing::error!("fail to fetch the config url `{config_url}`: {err}");
Table::new()
})
}
#[cfg(not(feature = "http-client"))]
{
tracing::error!("cannot fetch the config url `{config_url}`");
Table::new()
}
} else {
let format = std::env::var("ZINO_APP_CONFIG_FORMAT")
.map(|s| s.to_ascii_lowercase())
.unwrap_or_else(|_| "toml".to_owned());
let config_dir = Agent::config_dir();
if config_dir.exists() {
let config_file = format!("config.{env}.{format}");
let config_file_path = config_dir.join(&config_file);
config::read_config_file(&config_file_path, env).unwrap_or_else(|err| {
tracing::error!("fail to read the config file `{config_file}`: {err}");
Table::new()
})
} else {
tracing::warn!("no config file found in `{}`", config_dir.display());
Table::new()
}
};
for (key, value) in config_table.iter_mut() {
let Value::Table(table) = value else {
continue;
};
let prefix = key.to_ascii_uppercase();
for (key, value) in table.iter_mut() {
let name = format!("ZINO_{}_{}", &prefix, key.to_case(Case::Constant));
let Ok(s) = std::env::var(&name) else {
continue;
};
let result = match value.type_str() {
"string" => Ok(Value::String(s)),
"integer" => s.parse().map(Value::Integer).map_err(Error::from),
"float" => s.parse().map(Value::Float).map_err(Error::from),
"boolean" => s.parse().map(Value::Boolean).map_err(Error::from),
"datetime" => s.parse().map(Value::Datetime).map_err(Error::from),
_ => Err(Error::new(format!("cannot parse non-primitive value: {s}"))),
};
match result {
Ok(v) => *value = v,
Err(err) => tracing::error!("invalid environment variable `{name}`: {err}"),
}
}
}
self.config = config_table;
}
#[inline]
pub fn set_data(&mut self, data: T) {
self.data = data;
}
#[inline]
pub fn env(&self) -> &Env {
&self.env
}
#[inline]
pub fn config(&self) -> &Table {
&self.config
}
#[inline]
pub fn get_config(&self, key: &str) -> Option<&Table> {
self.config().get_table(key)
}
#[inline]
pub fn get_extension_config(&self, extension: &str) -> Option<&Table> {
self.config().get_table("extensions")?.get_table(extension)
}
#[inline]
pub fn parse_config<C: DeserializeOwned>(&self, key: &str) -> Option<Result<C, Error>> {
self.get_config(key)
.map(|t| serde_json::from_value(t.to_map().into()).map_err(Error::from))
}
#[inline]
pub fn parse_extension_config<C: DeserializeOwned>(
&self,
extension: &str,
) -> Option<Result<C, Error>> {
self.get_extension_config(extension)
.map(|t| serde_json::from_value(t.to_map().into()).map_err(Error::from))
}
#[inline]
pub fn data(&self) -> &T {
&self.data
}
#[inline]
pub fn data_mut(&mut self) -> &mut T {
&mut self.data
}
pub fn listeners(&self) -> Vec<(ServerTag, SocketAddr)> {
let config = self.config();
let mut listeners = Vec::new();
if let Some(debug_server) = config.get_table("debug") {
let debug_host = debug_server
.get_str("host")
.and_then(|s| s.parse::<IpAddr>().ok())
.expect("field `debug.host` should be a str");
let debug_port = debug_server
.get_u16("port")
.expect("field `debug.port` should be an integer");
listeners.push((ServerTag::Debug, (debug_host, debug_port).into()));
}
if let Some(main_server) = config.get_table("main") {
let main_host = main_server
.get_str("host")
.and_then(|s| s.parse::<IpAddr>().ok())
.expect("field `main.host` should be a str");
let main_port = main_server
.get_u16("port")
.expect("field `main.port` should be an integer");
listeners.push((ServerTag::Main, (main_host, main_port).into()));
}
if config.contains_key("standby") {
let standbys = config
.get_array("standby")
.expect("field `standby` should be an array of tables");
for standby in standbys.iter().filter_map(|v| v.as_table()) {
let server_tag = standby.get_str("tag").unwrap_or("standby");
let standby_host = standby
.get_str("host")
.and_then(|s| s.parse::<IpAddr>().ok())
.expect("field `standby.host` should be a str");
let standby_port = standby
.get_u16("port")
.expect("field `standby.port` should be an integer");
listeners.push((server_tag.into(), (standby_host, standby_port).into()));
}
}
if listeners.is_empty() {
listeners.push((ServerTag::Main, (Ipv4Addr::LOCALHOST, 6080).into()));
}
listeners
}
}
impl State {
#[inline]
pub fn shared() -> &'static Self {
&SHARED_STATE
}
pub fn encrypt_password(config: &Table) -> Option<Cow<'_, str>> {
let password = config.get_str("password")?;
application::SECRET_KEY.get().and_then(|key| {
if base64::decode(password).is_ok_and(|data| crypto::decrypt(&data, key).is_ok()) {
Some(password.into())
} else {
crypto::encrypt(password.as_bytes(), key)
.ok()
.map(|bytes| base64::encode(bytes).into())
}
})
}
pub fn decrypt_password(config: &Table) -> Option<Cow<'_, str>> {
let password = config.get_str("password")?;
if let Ok(data) = base64::decode(password)
&& let Some(key) = application::SECRET_KEY.get()
&& let Ok(plaintext) = crypto::decrypt(&data, key)
{
return Some(String::from_utf8_lossy(&plaintext).into_owned().into());
}
if let Some(encrypted_password) = Self::encrypt_password(config).as_deref() {
let num_chars = password.len() / 4;
let masked_password = helper::mask_text(password, num_chars, num_chars);
tracing::warn!(
encrypted_password,
"raw password `{masked_password}` should be encypted"
);
}
Some(password.into())
}
pub fn format_authority(config: &Table, default_port: Option<u16>) -> String {
let mut authority = String::new();
let username = config.get_str("username").unwrap_or_default();
authority += username;
if let Some(password) = Self::decrypt_password(config) {
authority += &format!(":{password}@");
}
let host = config.get_str("host").unwrap_or("localhost");
authority += host;
if let Some(port) = config.get_u16("port").or(default_port) {
authority += &format!(":{port}");
}
authority
}
}
impl<T: Default> Default for State<T> {
    /// Builds a state for the environment held in `DEFAULT_ENV`, with
    /// `T::default()` as its data and an empty configuration.
    #[inline]
    fn default() -> Self {
        let env = *DEFAULT_ENV;
        Self::new(env, T::default())
    }
}
/// Environment used by `State::default`.
///
/// Resolution order:
/// 1. the first `--env=<value>` command-line argument (the program name is skipped);
/// 2. the `ZINO_APP_ENV` environment variable;
/// 3. `Env::Dev` when built with debug assertions, `Env::Prod` otherwise.
///
/// The chosen string is leaked to obtain the `&'static str` that `Env` is built from.
/// `LazyLock` runs this initializer at most once, so the leak happens at most once.
static DEFAULT_ENV: LazyLock<Env> = LazyLock::new(|| {
    let from_args = std::env::args()
        .skip(1)
        .find_map(|arg| arg.strip_prefix("--env=").map(str::to_owned));
    // The environment variable is only consulted when no CLI argument matched.
    if let Some(value) = from_args.or_else(|| std::env::var("ZINO_APP_ENV").ok()) {
        // `value` is already an owned `String`; leak it directly without cloning.
        let env: &'static str = value.leak();
        return env.into();
    }
    if cfg!(debug_assertions) {
        Env::Dev
    } else {
        Env::Prod
    }
});
/// Process-wide state returned by `State::shared`.
///
/// It is built on first access from `State::default()`, and its configuration
/// is then loaded via `load_config`.
static SHARED_STATE: LazyLock<State> = LazyLock::new(|| {
    let mut shared_state = State::default();
    shared_state.load_config();
    shared_state
});