use serde::{Deserialize, Serialize};
use socket2::TcpKeepalive;
use std::collections::BTreeMap;
use std::path::PathBuf;
use std::time::{Duration, Instant};
/// Pooling mode selectable per user.
///
/// Variants are (de)serialized using their `snake_case` names;
/// [`PoolMode::Transaction`] is the default.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PoolMode {
    /// Default mode.
    #[default]
    Transaction,
    Session,
    Statement,
}
impl std::fmt::Display for PoolMode {
    /// Writes the lowercase name of the pool mode (`transaction`, `session`
    /// or `statement`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            PoolMode::Transaction => "transaction",
            PoolMode::Session => "session",
            PoolMode::Statement => "statement",
        };
        f.write_str(name)
    }
}
/// Role assigned to a configured server.
///
/// Variants are (de)serialized using their `snake_case` names;
/// [`ServerRole::Replica`] is the default.
///
/// `Eq` is derived alongside `Hash` so the type can be used as a key in
/// `HashMap`/`HashSet`, which require `Eq + Hash`.
#[allow(dead_code)]
#[derive(Deserialize, Serialize, Debug, Copy, Clone, Default, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum ServerRole {
    Primary,
    /// Default role.
    #[default]
    Replica,
    Standby,
}
impl std::fmt::Display for ServerRole {
    /// Writes the lowercase name of the role (`primary`, `replica` or
    /// `standby`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            ServerRole::Primary => "primary",
            ServerRole::Replica => "replica",
            ServerRole::Standby => "standby",
        };
        f.write_str(name)
    }
}
/// Log output format.
///
/// Variants are (de)serialized using their `snake_case` names;
/// [`LogFormat::Plain`] is the default.
///
/// `Eq` is derived alongside `Hash` so the type can be used as a key in
/// `HashMap`/`HashSet`, which require `Eq + Hash`.
#[derive(Deserialize, Serialize, Debug, Copy, Clone, Default, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum LogFormat {
    /// Default format.
    #[default]
    Plain,
    Json,
}
/// Top-level configuration.
///
/// Every section is optional when deserializing; a missing section is filled
/// in by the default function named in its `serde(default = ...)` attribute.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Config {
    /// Never (de)serialized; set to `Instant::now()` when the value is built.
    // NOTE(review): this field takes part in the derived `PartialEq`, so two
    // configs built at different instants compare unequal - confirm intended.
    #[serde(skip, default = "std::time::Instant::now")]
    pub created_at: Instant,

    /// Never (de)serialized; starts at `0`.
    /// Presumably counts configuration reloads - confirm against callers.
    #[serde(skip, default)]
    pub reloads: usize,

    /// General settings.
    #[serde(default = "General::default")]
    pub general: General,

    /// Admin database settings.
    #[serde(default = "Admin::default")]
    pub admin: Admin,

    /// Databases keyed by name; defaults to a single `postgres` entry.
    #[serde(default = "Config::databases_default")]
    pub databases: BTreeMap<String, Database>,

    /// Network settings keyed by name; defaults to `clients` and `servers`
    /// entries.
    #[serde(default = "General::default_network")]
    pub network: BTreeMap<String, Network>,
}
impl Config {
    /// Returns a copy of the `clients` network settings, or
    /// `Network::default()` when that entry is absent.
    pub fn clients_network(&self) -> Network {
        self.network.get("clients").cloned().unwrap_or_default()
    }

    /// Returns a copy of the `servers` network settings, or
    /// `Network::default()` when that entry is absent.
    pub fn servers_network(&self) -> Network {
        self.network.get("servers").cloned().unwrap_or_default()
    }

    /// Default `databases` map: one default [`Database`] under the key
    /// `postgres`.
    fn databases_default() -> BTreeMap<String, Database> {
        BTreeMap::from([("postgres".to_owned(), Database::default())])
    }
}
impl Default for Config {
    /// Builds a configuration where every section holds its default value,
    /// `reloads` is zero and `created_at` is the current instant.
    fn default() -> Self {
        Self {
            created_at: Instant::now(),
            reloads: 0,
            general: General::default(),
            admin: Admin::default(),
            databases: Self::databases_default(),
            network: General::default_network(),
        }
    }
}
/// General settings.
///
/// Each field falls back to the matching `General::default_*` function when
/// it is missing from the deserialized input.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct General {
    /// Host; defaults to `0.0.0.0`.
    #[serde(default = "General::default_host")]
    pub host: String,

    /// Port; defaults to `6432`.
    #[serde(default = "General::default_port")]
    pub port: u16,

    /// Connection limit; defaults to `32_768`.
    #[serde(default = "General::default_max_connections")]
    pub max_connections: u64,

    /// Path to the TLS certificate; defaults to `server.cert`.
    #[serde(default = "General::default_tls_certificate")]
    pub tls_certificate: PathBuf,

    /// Path to the TLS private key; defaults to `server.key`.
    #[serde(default = "General::default_tls_private_key")]
    pub tls_private_key: PathBuf,

    /// Number of worker threads; defaults to `4`.
    #[serde(default = "General::default_worker_threads")]
    pub worker_threads: usize,

    /// Shutdown timeout in milliseconds; defaults to `60_000`.
    #[serde(default = "General::default_shutdown_timeout")]
    pub shutdown_timeout: u64,

    /// Client login timeout in milliseconds; defaults to `i64::MAX as u64`.
    #[serde(default = "General::default_client_login_timeout")]
    pub client_login_timeout: u64,
}
impl Default for General {
fn default() -> General {
General {
host: General::default_host(),
port: General::default_port(),
max_connections: General::default_max_connections(),
tls_certificate: General::default_tls_certificate(),
tls_private_key: General::default_tls_private_key(),
worker_threads: General::default_worker_threads(),
shutdown_timeout: General::default_shutdown_timeout(),
client_login_timeout: General::default_client_login_timeout(),
}
}
}
impl General {
    // ---- Accessors ----

    /// Returns the configured host.
    pub fn host(&self) -> &str {
        self.host.as_str()
    }

    /// Returns the configured port.
    pub fn port(&self) -> u16 {
        self.port
    }

    /// Returns the configured connection limit.
    pub fn max_connections(&self) -> u64 {
        self.max_connections
    }

    /// Returns `shutdown_timeout` (stored in milliseconds) as a [`Duration`].
    pub fn shutdown_timeout(&self) -> Duration {
        Duration::from_millis(self.shutdown_timeout)
    }

    /// Returns `client_login_timeout` (stored in milliseconds) as a
    /// [`Duration`].
    pub fn client_login_timeout(&self) -> Duration {
        Duration::from_millis(self.client_login_timeout)
    }

    // ---- Defaults used by serde and by `Default` ----

    /// Default host: `0.0.0.0`.
    pub fn default_host() -> String {
        String::from("0.0.0.0")
    }

    /// Default port: `6432`.
    pub fn default_port() -> u16 {
        6432
    }

    /// Default connection limit: `32_768`.
    pub fn default_max_connections() -> u64 {
        32_768
    }

    /// Default network map: default [`Network`] settings under the keys
    /// `clients` and `servers`.
    pub fn default_network() -> BTreeMap<String, Network> {
        BTreeMap::from([
            ("clients".to_owned(), Network::default()),
            ("servers".to_owned(), Network::default()),
        ])
    }

    /// Default TLS certificate path: `server.cert`.
    pub fn default_tls_certificate() -> PathBuf {
        PathBuf::from("server.cert")
    }

    /// Default TLS private key path: `server.key`.
    pub fn default_tls_private_key() -> PathBuf {
        PathBuf::from("server.key")
    }

    /// Default number of worker threads: `4`.
    pub fn default_worker_threads() -> usize {
        4
    }

    /// Default shutdown timeout in milliseconds: `60_000`.
    pub fn default_shutdown_timeout() -> u64 {
        60_000
    }

    /// Default client login timeout in milliseconds: `i64::MAX as u64`.
    pub fn default_client_login_timeout() -> u64 {
        i64::MAX as u64
    }
}
/// Configuration of a single database.
///
/// Each field falls back to the matching `Database::default_*` function when
/// it is missing from the deserialized input.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Database {
    /// Database name; defaults to `postgres`.
    #[serde(default = "Database::default_database_name")]
    pub database_name: String,

    /// Sharding settings.
    #[serde(default = "Database::default_sharding")]
    pub sharding: Sharding,

    /// Load-balancing settings.
    #[serde(default = "Database::default_load_balancing")]
    pub load_balancing: LoadBalancing,

    /// Users keyed by name; defaults to a single `postgres` user.
    #[serde(default = "Database::default_users")]
    pub users: BTreeMap<String, User>,

    /// Shards keyed by name; defaults to a single shard named `0`.
    #[serde(default = "Database::default_shards")]
    pub shards: BTreeMap<String, Shard>,
}
impl Database {
    /// Default database name: `postgres`.
    pub fn default_database_name() -> String {
        String::from("postgres")
    }

    /// Default sharding settings.
    fn default_sharding() -> Sharding {
        Sharding::default()
    }

    /// Default load-balancing settings.
    fn default_load_balancing() -> LoadBalancing {
        LoadBalancing::default()
    }

    /// Default users map: one default [`User`] under the key `postgres`.
    fn default_users() -> BTreeMap<String, User> {
        BTreeMap::from([("postgres".to_owned(), User::default())])
    }

    /// Default shards map: one default [`Shard`] under the key `0`.
    fn default_shards() -> BTreeMap<String, Shard> {
        BTreeMap::from([("0".to_owned(), Shard::default())])
    }

    /// Builds a single-shard, single-server database from a connection URL.
    ///
    /// * The database name is the URL path with every `/` removed.
    /// * The URL username becomes the only user; its password is the URL
    ///   password, or the username itself when the URL carries no password.
    /// * The only server is a primary at the URL host (default `localhost`)
    ///   and port (default `5432`), stored under the database name inside
    ///   shard `0`.
    /// * Everything else keeps its default value.
    ///
    /// # Errors
    ///
    /// Returns the [`url::ParseError`] produced when `url` cannot be parsed.
    pub fn from_url(url: &str) -> Result<Self, url::ParseError> {
        let parsed = url::Url::parse(url)?;

        // NOTE(review): per the `url` crate docs, username/password are
        // returned percent-encoded; confirm whether decoding is wanted here.
        let database_name = parsed.path().replace('/', "");
        let username = parsed.username();
        let password = parsed.password().unwrap_or(username);

        let primary = Server {
            host: parsed.host_str().unwrap_or("localhost").to_owned(),
            port: parsed.port().unwrap_or(5432),
            role: ServerRole::Primary,
            ..Server::default()
        };
        let shard = Shard {
            servers: BTreeMap::from([(database_name.clone(), primary)]),
        };
        let user = User {
            password: password.to_owned(),
            ..User::default()
        };

        Ok(Self {
            users: BTreeMap::from([(username.to_owned(), user)]),
            shards: BTreeMap::from([("0".to_owned(), shard)]),
            database_name,
            ..Self::default()
        })
    }
}
impl Default for Database {
fn default() -> Database {
Database {
database_name: Database::default_database_name(),
sharding: Database::default_sharding(),
load_balancing: Database::default_load_balancing(),
users: Database::default_users(),
shards: Database::default_shards(),
}
}
}
/// Sharding function selector.
///
/// Variants are (de)serialized using their `snake_case` names; the only
/// variant, [`ShardingFunction::PgBigintHash`], is also the default.
#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ShardingFunction {
    #[default]
    PgBigintHash,
}
impl std::fmt::Display for ShardingFunction {
    /// Writes the `snake_case` name of the sharding function.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            ShardingFunction::PgBigintHash => "pg_bigint_hash",
        };
        f.write_str(name)
    }
}
/// Sharding settings of a database.
///
/// Each field falls back to the matching `Sharding::default_*` function when
/// it is missing from the deserialized input.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct Sharding {
    /// Sharding function; defaults to `PgBigintHash`.
    #[serde(default = "Sharding::default_sharding_function")]
    pub sharding_function: ShardingFunction,

    /// Table name; defaults to `users`.
    #[serde(default = "Sharding::default_table")]
    pub table: String,

    /// Column name; defaults to `id`.
    #[serde(default = "Sharding::default_column")]
    pub column: String,

    /// Foreign key name; defaults to `user_id`.
    #[serde(default = "Sharding::default_foreign_key")]
    pub foreign_key: String,
}
impl Sharding {
    /// Default sharding function: [`ShardingFunction::PgBigintHash`].
    pub fn default_sharding_function() -> ShardingFunction {
        ShardingFunction::PgBigintHash
    }

    /// Default table name: `users`.
    pub fn default_table() -> String {
        String::from("users")
    }

    /// Default column name: `id`.
    pub fn default_column() -> String {
        String::from("id")
    }

    /// Default foreign key name: `user_id`.
    pub fn default_foreign_key() -> String {
        String::from("user_id")
    }
}
impl Default for Sharding {
fn default() -> Self {
Sharding {
sharding_function: Sharding::default_sharding_function(),
table: Sharding::default_table(),
column: Sharding::default_column(),
foreign_key: Sharding::default_foreign_key(),
}
}
}
/// Load-balancing algorithm selector.
///
/// Variants are (de)serialized using their `snake_case` names;
/// [`LoadBalancingAlgorithm::Random`] is the default.
#[derive(Clone, Debug, Default, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum LoadBalancingAlgorithm {
    /// Default algorithm.
    #[default]
    Random,
    RoundRobin,
    /// Also accepted as `lac` when deserializing.
    #[serde(alias = "lac")]
    LeastActiveConnections,
}
impl std::fmt::Display for LoadBalancingAlgorithm {
    /// Writes the `snake_case` name of the algorithm.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            LoadBalancingAlgorithm::Random => "random",
            LoadBalancingAlgorithm::RoundRobin => "round_robin",
            LoadBalancingAlgorithm::LeastActiveConnections => "least_active_connections",
        };
        f.write_str(name)
    }
}
/// Load-balancing settings of a database.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct LoadBalancing {
    /// Selected algorithm; defaults to `Random`.
    #[serde(default = "LoadBalancing::default_algorithm")]
    pub algorithm: LoadBalancingAlgorithm,

    /// Function names; defaults to `pgml.train` and `pgml.deploy`.
    /// Presumably functions that must be treated as writes - confirm against
    /// callers.
    #[serde(default = "LoadBalancing::default_write_functions")]
    pub write_functions: Vec<String>,
}
impl LoadBalancing {
    /// Default algorithm: [`LoadBalancingAlgorithm::Random`].
    pub fn default_algorithm() -> LoadBalancingAlgorithm {
        LoadBalancingAlgorithm::Random
    }

    /// Default write functions: `pgml.train` and `pgml.deploy`.
    fn default_write_functions() -> Vec<String> {
        ["pgml.train", "pgml.deploy"]
            .iter()
            .map(|name| name.to_string())
            .collect()
    }

    /// Returns the configured algorithm.
    pub fn algorithm(&self) -> &LoadBalancingAlgorithm {
        &self.algorithm
    }

    /// Returns the configured write functions.
    pub fn write_functions(&self) -> &[String] {
        self.write_functions.as_slice()
    }

    /// Returns `true` when the algorithm is `Random`.
    pub fn random(&self) -> bool {
        matches!(self.algorithm, LoadBalancingAlgorithm::Random)
    }

    /// Returns `true` when the algorithm is `RoundRobin`.
    pub fn round_robin(&self) -> bool {
        matches!(self.algorithm, LoadBalancingAlgorithm::RoundRobin)
    }

    /// Returns `true` when the algorithm is `LeastActiveConnections`.
    pub fn least_active_connections(&self) -> bool {
        matches!(
            self.algorithm,
            LoadBalancingAlgorithm::LeastActiveConnections
        )
    }
}
impl Default for LoadBalancing {
fn default() -> Self {
LoadBalancing {
algorithm: LoadBalancing::default_algorithm(),
write_functions: LoadBalancing::default_write_functions(),
}
}
}
/// Authentication method selector.
///
/// Variants are (de)serialized using their lowercase names (`plain`, `md5`,
/// `trust`).
///
/// `Eq` is derived alongside `Hash` so the type can be used as a key in
/// `HashMap`/`HashSet`, which require `Eq + Hash`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)]
#[serde(rename_all = "lowercase")]
pub enum AuthMethod {
    Plain,
    Md5,
    Trust,
}
impl std::fmt::Display for AuthMethod {
    /// Writes the lowercase name of the method (`plain`, `md5` or `trust`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            AuthMethod::Plain => "plain",
            AuthMethod::Md5 => "md5",
            AuthMethod::Trust => "trust",
        };
        f.write_str(name)
    }
}
/// Per-user settings of a database.
///
/// Fields with a `serde(default = ...)` attribute fall back to the matching
/// `User::default_*` function when missing. All `*_timeout`, `server_lifetime`
/// and `stats_interval` values are stored in milliseconds (see the
/// `Duration`-returning accessors on `User`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct User {
    /// Password; defaults to `postgres`.
    #[serde(default = "User::default_password")]
    pub password: String,

    /// Authentication method; defaults to `Plain`.
    #[serde(default = "User::default_auth_method")]
    pub auth_method: AuthMethod,

    /// Optional server-side user name.
    pub server_user: Option<String>,

    /// Optional server-side password; `User::server_password()` falls back to
    /// `password` when this is `None`.
    pub server_password: Option<String>,

    /// Pool size; defaults to `10`.
    #[serde(default = "User::default_pool_size")]
    pub pool_size: usize,

    /// Minimum pool size; defaults to `0`.
    #[serde(default = "User::default_min_pool_size")]
    pub min_pool_size: usize,

    /// Pool mode; defaults to `Transaction`.
    #[serde(default = "User::default_pool_mode")]
    pub pool_mode: PoolMode,

    /// Defaults to `true`.
    #[serde(default = "User::default_primary_reads")]
    pub primary_reads: bool,

    /// Milliseconds; defaults to `60_000`.
    #[serde(default = "User::default_server_idle_timeout")]
    pub server_idle_timeout: u64,

    /// Milliseconds; defaults to `i64::MAX as u64`.
    #[serde(default = "User::default_idle_client_timeout")]
    pub idle_client_timeout: u64,

    /// Milliseconds; defaults to `60_000`.
    #[serde(default = "User::default_ban_timeout")]
    pub ban_timeout: u64,

    /// Milliseconds; defaults to `60_000`.
    #[serde(default = "User::default_healthcheck_timeout")]
    pub healthcheck_timeout: u64,

    /// Milliseconds; defaults to `60_000`.
    #[serde(default = "User::default_checkout_timeout")]
    pub checkout_timeout: u64,

    /// Milliseconds; defaults to `60_000`.
    #[serde(default = "User::default_connect_timeout")]
    pub connect_timeout: u64,

    /// Milliseconds; defaults to `i64::MAX as u64`.
    #[serde(default = "User::default_statement_timeout")]
    pub statement_timeout: u64,

    /// Milliseconds; defaults to `i64::MAX as u64`.
    #[serde(default = "User::default_idle_transaction_timeout")]
    pub idle_transaction_timeout: u64,

    /// Milliseconds; defaults to 24 hours.
    #[serde(default = "User::default_server_lifetime")]
    pub server_lifetime: u64,

    /// Milliseconds; defaults to `15_000`.
    #[serde(default = "User::default_stats_interval")]
    pub stats_interval: u64,

    /// Defaults to `100`.
    #[serde(default = "User::default_prepared_statements_cache")]
    pub prepared_statements_cache: usize,

    /// Defaults to `10`.
    #[serde(default = "User::default_queue_priority_factor")]
    pub queue_priority_factor: usize,

    /// Defaults to `true`.
    #[serde(default = "User::default_query_parser_enabled")]
    pub query_parser_enabled: bool,
}
impl Default for User {
fn default() -> User {
User {
password: User::default_password(),
auth_method: User::default_auth_method(),
server_user: None,
server_password: None,
pool_size: User::default_pool_size(),
min_pool_size: User::default_min_pool_size(),
pool_mode: User::default_pool_mode(),
primary_reads: User::default_primary_reads(),
server_idle_timeout: User::default_server_idle_timeout(),
idle_client_timeout: User::default_idle_client_timeout(),
ban_timeout: User::default_ban_timeout(),
healthcheck_timeout: User::default_healthcheck_timeout(),
checkout_timeout: User::default_checkout_timeout(),
connect_timeout: User::default_connect_timeout(),
statement_timeout: User::default_statement_timeout(),
idle_transaction_timeout: User::default_idle_transaction_timeout(),
server_lifetime: User::default_server_lifetime(),
stats_interval: User::default_stats_interval(),
prepared_statements_cache: User::default_prepared_statements_cache(),
queue_priority_factor: User::default_queue_priority_factor(),
query_parser_enabled: User::default_query_parser_enabled(),
}
}
}
impl User {
fn default_password() -> String {
"postgres".to_owned()
}
pub fn default_auth_method() -> AuthMethod {
AuthMethod::Plain
}
pub fn default_pool_size() -> usize {
10
}
pub fn default_min_pool_size() -> usize {
0
}
pub fn default_pool_mode() -> PoolMode {
PoolMode::Transaction
}
pub fn default_primary_reads() -> bool {
true
}
pub fn default_server_idle_timeout() -> u64 {
60_000
}
pub fn default_idle_client_timeout() -> u64 {
i64::MAX as u64
}
pub fn default_ban_timeout() -> u64 {
60_000
}
pub fn default_healthcheck_timeout() -> u64 {
60_000
}
pub fn default_checkout_timeout() -> u64 {
60_000
}
pub fn default_connect_timeout() -> u64 {
60_000
}
pub fn default_statement_timeout() -> u64 {
i64::MAX as u64 }
pub fn default_idle_transaction_timeout() -> u64 {
i64::MAX as u64
}
pub fn default_server_lifetime() -> u64 {
3_600 * 24 * 1_000
}
pub fn default_stats_interval() -> u64 {
15_000
}
pub fn default_prepared_statements_cache() -> usize {
100
}
pub fn default_queue_priority_factor() -> usize {
10
}
pub fn default_query_parser_enabled() -> bool {
true
}
pub fn server_idle_timeout(&self) -> Duration {
Duration::from_millis(self.server_idle_timeout)
}
pub fn idle_client_timeout(&self) -> Duration {
Duration::from_millis(self.idle_client_timeout)
}
pub fn ban_timeout(&self) -> Duration {
Duration::from_millis(self.ban_timeout)
}
pub fn healthcheck_timeout(&self) -> Duration {
Duration::from_millis(self.healthcheck_timeout)
}
pub fn checkout_timeout(&self) -> Duration {
Duration::from_millis(self.checkout_timeout)
}
pub fn connect_timeout(&self) -> Duration {
Duration::from_millis(self.connect_timeout)
}
pub fn statement_timeout(&self) -> Duration {
Duration::from_millis(self.statement_timeout)
}
pub fn idle_transaction_timeout(&self) -> Duration {
Duration::from_millis(self.idle_transaction_timeout)
}
pub fn server_lifetime(&self) -> Duration {
Duration::from_millis(self.server_lifetime)
}
pub fn stats_interval(&self) -> Duration {
Duration::from_millis(self.stats_interval)
}
pub fn server_password(&self) -> &str {
if let Some(ref server_password) = self.server_password {
server_password
} else {
&self.password
}
}
pub fn query_parser_enabled(&self) -> bool {
self.query_parser_enabled
}
}
/// A shard: a named set of servers.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Shard {
    /// Servers keyed by name; defaults to a single `postgres` entry.
    #[serde(default = "Shard::default_servers")]
    pub servers: BTreeMap<String, Server>,
}
impl Shard {
    /// Default servers map: one default [`Server`] under the key `postgres`.
    fn default_servers() -> BTreeMap<String, Server> {
        BTreeMap::from([("postgres".to_owned(), Server::default())])
    }
}
impl Default for Shard {
fn default() -> Shard {
Shard {
servers: Shard::default_servers(),
}
}
}
/// A single server inside a shard.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Server {
    /// Role; defaults to `Replica`.
    #[serde(default = "Server::default_role")]
    pub role: ServerRole,

    /// Host; defaults to `localhost`.
    #[serde(default = "Server::default_host")]
    pub host: String,

    /// Port; defaults to `5432`.
    #[serde(default = "Server::default_port")]
    pub port: u16,

    /// Optional database name for this server.
    pub database_name: Option<String>,
}
impl Server {
    /// Default role: [`ServerRole::Replica`].
    fn default_role() -> ServerRole {
        ServerRole::Replica
    }

    /// Default host: `localhost`.
    fn default_host() -> String {
        String::from("localhost")
    }

    /// Default port: `5432`.
    fn default_port() -> u16 {
        5432
    }
}
impl Default for Server {
fn default() -> Server {
Server {
role: Server::default_role(),
host: Server::default_host(),
port: Server::default_port(),
database_name: None,
}
}
}
/// TCP-level network settings.
///
/// Each field falls back to the matching `Network::default_*` function when
/// it is missing from the deserialized input.
///
/// `Eq` is derived alongside `Hash` (every field is an integer or `bool`) so
/// the type can be used as a key in `HashMap`/`HashSet`, which require
/// `Eq + Hash`.
#[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Eq, Hash)]
pub struct Network {
    /// Defaults to `true`.
    #[serde(default = "Network::default_tcp_nodelay")]
    pub tcp_nodelay: bool,

    /// Keepalive interval in seconds; defaults to `60`.
    #[serde(default = "Network::default_tcp_keepalives_interval")]
    pub tcp_keepalives_interval: u64,

    /// Keepalive retry count; defaults to `10`.
    #[serde(default = "Network::default_tcp_keepalives_count")]
    pub tcp_keepalives_count: u32,

    /// Keepalive idle time in seconds; defaults to `60`.
    #[serde(default = "Network::default_tcp_keepalives_idle")]
    pub tcp_keepalives_idle: u64,

    /// TCP user timeout in milliseconds; defaults to `60_000`.
    #[serde(default = "Network::default_tcp_user_timeout")]
    pub tcp_user_timeout: u64,
}
impl Network {
    /// Returns `tcp_keepalives_interval` (stored in seconds) as a
    /// [`Duration`].
    fn tcp_keepalives_interval(&self) -> Duration {
        Duration::from_secs(self.tcp_keepalives_interval)
    }

    /// Returns the configured keepalive retry count.
    fn tcp_keepalives_count(&self) -> u32 {
        self.tcp_keepalives_count
    }

    /// Returns `tcp_keepalives_idle` (stored in seconds) as a [`Duration`].
    fn tcp_keepalives_idle(&self) -> Duration {
        Duration::from_secs(self.tcp_keepalives_idle)
    }

    /// Returns `tcp_user_timeout` (stored in milliseconds) as a [`Duration`];
    /// the result is always `Some`.
    pub fn tcp_user_timeout(&self) -> Option<Duration> {
        let timeout = Duration::from_millis(self.tcp_user_timeout);
        Some(timeout)
    }

    /// Builds a [`TcpKeepalive`] whose interval, retries and time are taken
    /// from the keepalive interval, count and idle settings respectively.
    pub fn keepalive(&self) -> TcpKeepalive {
        let interval = self.tcp_keepalives_interval();
        let retries = self.tcp_keepalives_count();
        let idle = self.tcp_keepalives_idle();
        TcpKeepalive::new()
            .with_interval(interval)
            .with_retries(retries)
            .with_time(idle)
    }

    /// Default for `tcp_nodelay`: `true`.
    pub fn default_tcp_nodelay() -> bool {
        true
    }

    /// Default keepalive interval in seconds: `60`.
    pub fn default_tcp_keepalives_interval() -> u64 {
        60
    }

    /// Default keepalive retry count: `10`.
    pub fn default_tcp_keepalives_count() -> u32 {
        10
    }

    /// Default keepalive idle time in seconds: `60`.
    pub fn default_tcp_keepalives_idle() -> u64 {
        60
    }

    /// Default TCP user timeout in milliseconds: `60_000`.
    pub fn default_tcp_user_timeout() -> u64 {
        60_000
    }
}
impl Default for Network {
fn default() -> Network {
Network {
tcp_nodelay: Network::default_tcp_nodelay(),
tcp_keepalives_interval: Network::default_tcp_keepalives_interval(),
tcp_keepalives_count: Network::default_tcp_keepalives_count(),
tcp_keepalives_idle: Network::default_tcp_keepalives_idle(),
tcp_user_timeout: Network::default_tcp_user_timeout(),
}
}
}
/// Admin database settings.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct Admin {
    /// Admin database name; defaults to `pgcat`.
    #[serde(default = "Admin::default_database_name")]
    pub database_name: String,

    /// Admin users keyed by name; defaults to a single `pgcat` user.
    #[serde(default = "Admin::default_users")]
    pub users: BTreeMap<String, AdminUser>,
}
impl Admin {
    /// Returns the admin database name.
    pub fn database_name(&self) -> &str {
        &self.database_name
    }

    /// Returns `true` when `user` is a configured admin user whose stored
    /// password equals `password`; returns `false` for unknown users.
    ///
    /// The lookup result is used directly instead of being wrapped in `Some`
    /// and flattened again.
    // NOTE(review): this is a plain string comparison; consider a
    // constant-time comparison if timing side channels are a concern.
    pub fn exists(&self, user: &str, password: &str) -> bool {
        self.users
            .get(user)
            .map_or(false, |admin| admin.password == password)
    }

    /// Returns the stored password of `user`, or `None` when the user is not
    /// configured.
    pub fn password(&self, user: &str) -> Option<&str> {
        self.users.get(user).map(|user| user.password.as_str())
    }

    /// Default admin database name: `pgcat`.
    fn default_database_name() -> String {
        "pgcat".to_owned()
    }

    /// Default admin users map: one default [`AdminUser`] under the key
    /// `pgcat`.
    fn default_users() -> BTreeMap<String, AdminUser> {
        let mut users = BTreeMap::new();
        users.insert("pgcat".to_owned(), AdminUser::default());
        users
    }
}
impl Default for Admin {
fn default() -> Admin {
Admin {
database_name: Admin::default_database_name(),
users: Admin::default_users(),
}
}
}
/// An admin user.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct AdminUser {
    /// Password; when missing, a random 12-character alphanumeric string is
    /// generated.
    #[serde(default = "AdminUser::default_password")]
    pub password: String,
}
impl AdminUser {
    /// Generates a random password: 12 characters sampled from
    /// `rand::distributions::Alphanumeric` using the thread-local generator.
    fn default_password() -> String {
        use rand::distributions::Alphanumeric;
        use rand::Rng;

        const PASSWORD_LEN: usize = 12;

        let rng = rand::thread_rng();
        rng.sample_iter(&Alphanumeric)
            .take(PASSWORD_LEN)
            .map(|c| c as char)
            .collect()
    }
}
impl Default for AdminUser {
fn default() -> AdminUser {
AdminUser {
password: AdminUser::default_password(),
}
}
}
pub fn from_database_url(url: &str) -> Result<Config, url::ParseError> {
let mut config = Config::default();
let database = Database::from_url(url)?;
config.databases = BTreeMap::from_iter(vec![(database.database_name.clone(), database)]);
Ok(config)
}
#[cfg(test)]
mod test {
    use super::*;

    /// `Database::from_url` must take the database name, user, password, host
    /// and port from the URL and place the server in shard `0` under the
    /// database name.
    #[test]
    fn test_database_from_url() {
        let database =
            Database::from_url("postgres://test_user:test_password@postgresml.org:6661/pgml_db")
                .unwrap();
        assert_eq!(database.database_name, "pgml_db");

        // Exactly one user, carrying the password from the URL.
        assert_eq!(database.users.len(), 1);
        let user = database.users.get("test_user").unwrap();
        assert_eq!(user.password, "test_password");

        // Exactly one shard with exactly one server.
        assert_eq!(database.shards.len(), 1);
        let shard = database.shards.get("0").unwrap();
        assert_eq!(shard.servers.len(), 1);

        // The server is keyed by the database name and points at the URL host/port.
        let server = shard.servers.get("pgml_db").unwrap();
        assert_eq!(server.host, "postgresml.org");
        assert_eq!(server.port, 6661);
    }
}