pub use crate::checked_client::CheckMode;
use crate::error::OrmError;
use std::time::Duration;
/// Configuration for the Postgres client.
///
/// Build one with [`PgClientConfig::new`] / `Default::default()` and then
/// chain the builder methods defined on this type.
#[derive(Clone, Debug)]
pub struct PgClientConfig {
    /// Query checking mode (`CheckMode::WarnOnly` by default).
    pub check_mode: CheckMode,
    /// Rules applied to potentially dangerous or unbounded statements.
    pub sql_policy: SqlPolicy,
    /// Optional per-query timeout; `None` by default.
    pub query_timeout: Option<Duration>,
    /// Optional duration above which a query is considered slow; `None` by default.
    pub slow_query_threshold: Option<Duration>,
    /// Statement cache settings (disabled by default).
    pub statement_cache: StatementCacheConfig,
    /// Capacity of the parse cache (256 by default).
    pub parse_cache_capacity: usize,
    /// Whether statistics collection is on (`true` by default).
    pub stats_enabled: bool,
    /// Whether query logging is on (`false` by default).
    pub logging_enabled: bool,
    /// Optional minimum duration for logged queries; set by `log_slow_queries`.
    pub log_min_duration: Option<Duration>,
}
/// Settings for the client's statement cache.
///
/// The derived `Default` produces `enabled: false` and `capacity: 0`.
#[derive(Clone, Debug, Default)]
pub struct StatementCacheConfig {
    /// Whether the statement cache is turned on.
    pub enabled: bool,
    /// Maximum number of entries; presumably counted in statements — confirm
    /// against the cache implementation.
    pub capacity: usize,
}
impl Default for PgClientConfig {
    /// Baseline configuration.
    ///
    /// Check mode is `WarnOnly` and the SQL policy is `SqlPolicy::default()`.
    /// No timeout or slow-query threshold is set and the statement cache is
    /// off. The parse cache holds 256 entries, stats are on and logging is off.
    fn default() -> Self {
        const DEFAULT_PARSE_CACHE_CAPACITY: usize = 256;

        let sql_policy = SqlPolicy::default();
        let statement_cache = StatementCacheConfig::default();

        Self {
            check_mode: CheckMode::WarnOnly,
            sql_policy,
            query_timeout: None,
            slow_query_threshold: None,
            statement_cache,
            parse_cache_capacity: DEFAULT_PARSE_CACHE_CAPACITY,
            stats_enabled: true,
            logging_enabled: false,
            log_min_duration: None,
        }
    }
}
/// Builder-style setters for [`PgClientConfig`].
///
/// Every method consumes `self` and returns the updated config. The methods are
/// `#[must_use]` because calling one without using the result (e.g.
/// `cfg.strict();`) silently discards the change.
impl PgClientConfig {
    /// Creates a configuration identical to `PgClientConfig::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the `check_mode` field to `mode`.
    #[must_use]
    pub fn check_mode(mut self, mode: CheckMode) -> Self {
        self.check_mode = mode;
        self
    }

    /// Replaces the entire SQL policy with `policy`.
    #[must_use]
    pub fn sql_policy(mut self, policy: SqlPolicy) -> Self {
        self.sql_policy = policy;
        self
    }

    /// Sets only `sql_policy.select_without_limit`.
    #[must_use]
    pub fn select_without_limit(mut self, policy: SelectWithoutLimitPolicy) -> Self {
        self.sql_policy.select_without_limit = policy;
        self
    }

    /// Sets only `sql_policy.delete_without_where`.
    #[must_use]
    pub fn delete_without_where(mut self, policy: DangerousDmlPolicy) -> Self {
        self.sql_policy.delete_without_where = policy;
        self
    }

    /// Sets only `sql_policy.update_without_where`.
    #[must_use]
    pub fn update_without_where(mut self, policy: DangerousDmlPolicy) -> Self {
        self.sql_policy.update_without_where = policy;
        self
    }

    /// Sets only `sql_policy.truncate`.
    #[must_use]
    pub fn truncate_policy(mut self, policy: DangerousDmlPolicy) -> Self {
        self.sql_policy.truncate = policy;
        self
    }

    /// Sets only `sql_policy.drop_table`.
    #[must_use]
    pub fn drop_table_policy(mut self, policy: DangerousDmlPolicy) -> Self {
        self.sql_policy.drop_table = policy;
        self
    }

    /// Shorthand for `check_mode(CheckMode::Strict)`.
    #[must_use]
    pub fn strict(mut self) -> Self {
        self.check_mode = CheckMode::Strict;
        self
    }

    /// Shorthand for `check_mode(CheckMode::Disabled)`.
    #[must_use]
    pub fn no_check(mut self) -> Self {
        self.check_mode = CheckMode::Disabled;
        self
    }

    /// Sets `query_timeout` to `Some(duration)`.
    #[must_use]
    pub fn timeout(mut self, duration: Duration) -> Self {
        self.query_timeout = Some(duration);
        self
    }

    /// Sets `slow_query_threshold` to `Some(duration)`.
    #[must_use]
    pub fn slow_threshold(mut self, duration: Duration) -> Self {
        self.slow_query_threshold = Some(duration);
        self
    }

    /// Replaces the statement cache settings with capacity `cap`.
    ///
    /// The cache is enabled exactly when `cap > 0`, so `statement_cache(0)`
    /// leaves it disabled.
    #[must_use]
    pub fn statement_cache(mut self, cap: usize) -> Self {
        self.statement_cache = StatementCacheConfig {
            enabled: cap > 0,
            capacity: cap,
        };
        self
    }

    /// Disables the statement cache.
    ///
    /// `statement_cache.capacity` is left unchanged; only the `enabled` flag is cleared.
    #[must_use]
    pub fn no_statement_cache(mut self) -> Self {
        self.statement_cache.enabled = false;
        self
    }

    /// Sets `parse_cache_capacity` (default 256).
    #[must_use]
    pub fn parse_cache_capacity(mut self, capacity: usize) -> Self {
        self.parse_cache_capacity = capacity;
        self
    }

    /// Turns statistics collection on (it is on by default).
    #[must_use]
    pub fn with_stats(mut self) -> Self {
        self.stats_enabled = true;
        self
    }

    /// Turns statistics collection off.
    #[must_use]
    pub fn no_stats(mut self) -> Self {
        self.stats_enabled = false;
        self
    }

    /// Turns query logging on without changing `log_min_duration`.
    #[must_use]
    pub fn with_logging(mut self) -> Self {
        self.logging_enabled = true;
        self
    }

    /// Turns query logging on and sets `log_min_duration` to `Some(min_duration)`.
    #[must_use]
    pub fn log_slow_queries(mut self, min_duration: Duration) -> Self {
        self.logging_enabled = true;
        self.log_min_duration = Some(min_duration);
        self
    }
}
/// Per-rule handling for statement shapes that are often mistakes.
///
/// `SqlPolicy::default()` sets every rule to `Allow`.
#[derive(Clone, Debug)]
pub struct SqlPolicy {
    /// Handling for `SELECT` statements that have no `LIMIT`.
    pub select_without_limit: SelectWithoutLimitPolicy,
    /// Handling for `DELETE` statements that have no `WHERE` clause.
    pub delete_without_where: DangerousDmlPolicy,
    /// Handling for `UPDATE` statements that have no `WHERE` clause.
    pub update_without_where: DangerousDmlPolicy,
    /// Handling for `TRUNCATE` statements.
    pub truncate: DangerousDmlPolicy,
    /// Handling for `DROP TABLE` statements.
    pub drop_table: DangerousDmlPolicy,
}
impl Default for SqlPolicy {
    /// Fully permissive policy: every rule is set to its `Allow` variant.
    fn default() -> Self {
        let allow = DangerousDmlPolicy::Allow;
        Self {
            select_without_limit: SelectWithoutLimitPolicy::Allow,
            delete_without_where: allow,
            update_without_where: allow,
            truncate: allow,
            drop_table: allow,
        }
    }
}
/// How a statement matching a dangerous-DML rule is handled.
///
/// `handle_dangerous_dml` maps each variant as follows:
/// - `Allow` passes silently.
/// - `Warn` emits a warning and passes.
/// - `Error` fails with a validation error.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DangerousDmlPolicy {
    /// Let the statement through with no diagnostics.
    Allow,
    /// Let the statement through but emit a warning.
    Warn,
    /// Reject the statement with an error.
    Error,
}
/// How a `SELECT` without a `LIMIT` clause is handled.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SelectWithoutLimitPolicy {
    /// Let the query through unchanged.
    Allow,
    /// Let the query through but flag it.
    Warn,
    /// Reject the query.
    Error,
    /// Apply the given row limit automatically.
    ///
    /// NOTE(review): the payload is `i32`, so zero or negative limits are
    /// representable; confirm whether callers validate this before use.
    AutoLimit(i32),
}
pub(crate) fn handle_dangerous_dml(
policy: DangerousDmlPolicy,
rule: &str,
sql: &str,
) -> Result<(), OrmError> {
match policy {
DangerousDmlPolicy::Allow => Ok(()),
DangerousDmlPolicy::Warn => {
crate::error::pgorm_warn(&format!("[pgorm warn] SQL policy: {rule}: {sql}"));
Ok(())
}
DangerousDmlPolicy::Error => Err(OrmError::validation(format!(
"SQL policy violation: {rule}: {sql}"
))),
}
}