use crate::GenericClient;
use crate::error::{OrmError, OrmResult};
#[cfg(feature = "check")]
use crate::{RowStream, StreamingClient};
use std::sync::Arc;
use tokio_postgres::Row;
use tokio_postgres::types::ToSql;
#[cfg(feature = "check")]
use crate::check::SchemaRegistry;
/// A registration hook that contributes to a `crate::check::SchemaRegistry`.
///
/// Values of this type are gathered through `inventory`. With the `check` feature
/// enabled, `CheckedClient::new` walks every collected registration and invokes its
/// `register_fn` exactly once against the registry it is building.
///
/// NOTE(review): unlike the `SchemaRegistry` import in this file, this item is not
/// behind `#[cfg(feature = "check")]` — confirm `crate::check` is always compiled.
pub struct ModelRegistration {
    /// Callback invoked with a mutable reference to the registry under construction.
    pub register_fn: fn(registry: &mut crate::check::SchemaRegistry),
}

// Make `ModelRegistration` an `inventory` collection so that
// `inventory::iter::<ModelRegistration>` can enumerate every submitted entry.
inventory::collect!(ModelRegistration);
/// How a `CheckedClient` reacts to problems found while checking SQL.
///
/// The default is [`CheckMode::WarnOnly`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum CheckMode {
    /// SQL is not inspected at all; statements go straight to the wrapped client.
    Disabled,
    /// Every reported issue is printed to stderr, but the statement is still executed.
    #[default]
    WarnOnly,
    /// Error-level issues reject the statement with a validation error before it
    /// reaches the wrapped client.
    Strict,
}
/// A client wrapper that checks each SQL string against a `SchemaRegistry`
/// before forwarding the call to the wrapped client `C`.
///
/// How check results are acted upon is controlled by [`CheckMode`].
#[cfg(feature = "check")]
pub struct CheckedClient<C> {
    /// The wrapped client that actually runs the statements.
    client: C,
    /// Registry consulted by `check_sql`; stored behind an `Arc`.
    registry: Arc<SchemaRegistry>,
    /// Policy deciding whether issues are ignored, printed, or turned into errors.
    check_mode: CheckMode,
}
#[cfg(feature = "check")]
impl<C> CheckedClient<C> {
    /// Wraps `client` with a registry populated from every [`ModelRegistration`]
    /// collected through `inventory`, using [`CheckMode::WarnOnly`].
    pub fn new(client: C) -> Self {
        let mut registry = SchemaRegistry::new();
        for reg in inventory::iter::<ModelRegistration> {
            (reg.register_fn)(&mut registry);
        }
        Self::with_registry(client, registry)
    }

    /// Wraps `client` with an empty registry, using [`CheckMode::WarnOnly`].
    pub fn new_empty(client: C) -> Self {
        Self::with_registry(client, SchemaRegistry::new())
    }

    /// Wraps `client` with the caller-supplied `registry`, using [`CheckMode::WarnOnly`].
    ///
    /// All other constructors funnel through here so the initial state is defined once.
    pub fn with_registry(client: C, registry: SchemaRegistry) -> Self {
        Self {
            client,
            registry: Arc::new(registry),
            check_mode: CheckMode::WarnOnly,
        }
    }

    /// Returns the client reconfigured to use `mode`.
    #[must_use = "this consumes the client and returns the reconfigured one"]
    pub fn check_mode(mut self, mode: CheckMode) -> Self {
        self.check_mode = mode;
        self
    }

    /// Shorthand for `check_mode(CheckMode::Strict)`.
    #[must_use = "this consumes the client and returns the reconfigured one"]
    pub fn strict(self) -> Self {
        self.check_mode(CheckMode::Strict)
    }

    /// Shorthand for `check_mode(CheckMode::Disabled)`.
    #[must_use = "this consumes the client and returns the reconfigured one"]
    pub fn disabled(self) -> Self {
        self.check_mode(CheckMode::Disabled)
    }

    /// The registry SQL is checked against.
    pub fn registry(&self) -> &SchemaRegistry {
        &self.registry
    }

    /// Borrows the wrapped client.
    pub fn inner(&self) -> &C {
        &self.client
    }

    /// Consumes the wrapper and returns the wrapped client.
    pub fn into_inner(self) -> C {
        self.client
    }

    /// Checks `sql` against the registry according to the configured [`CheckMode`].
    ///
    /// - `Disabled`: the registry is not consulted.
    /// - `WarnOnly`: every issue is printed to stderr; never fails.
    /// - `Strict`: non-error issues are printed to stderr exactly as in `WarnOnly`;
    ///   if any issue has level `Error`, their messages are joined with `"; "`.
    ///
    /// # Errors
    ///
    /// Returns [`OrmError::validation`] only in `Strict` mode, when at least one
    /// error-level issue is reported.
    fn check_sql(&self, sql: &str) -> OrmResult<()> {
        match self.check_mode {
            CheckMode::Disabled => Ok(()),
            CheckMode::WarnOnly => {
                for issue in self.registry.check_sql(sql) {
                    eprintln!("[pgorm warn] {issue}");
                }
                Ok(())
            }
            CheckMode::Strict => {
                let issues = self.registry.check_sql(sql);
                let (errors, warnings): (Vec<_>, Vec<_>) = issues
                    .iter()
                    .partition(|issue| issue.level == crate::SchemaIssueLevel::Error);
                // Strict must not be quieter than WarnOnly: surface the non-fatal
                // issues as well instead of silently discarding them.
                for issue in &warnings {
                    eprintln!("[pgorm warn] {issue}");
                }
                if errors.is_empty() {
                    return Ok(());
                }
                // Borrow the messages for joining; no per-message clone is needed.
                let messages: Vec<&str> =
                    errors.iter().map(|issue| issue.message.as_str()).collect();
                Err(OrmError::validation(format!(
                    "SQL check failed: {}",
                    messages.join("; ")
                )))
            }
        }
    }
}
/// Checked forwarding of every [`GenericClient`] operation.
///
/// Each SQL-bearing method first runs `check_sql` on the statement text. If that returns
/// an error (possible only in `CheckMode::Strict`), the error is propagated and the
/// wrapped client is never called. Otherwise the call is handed to `C` with the same
/// tag, SQL and parameters.
#[cfg(feature = "check")]
impl<C: GenericClient> GenericClient for CheckedClient<C> {
    async fn query(&self, sql: &str, params: &[&(dyn ToSql + Sync)]) -> OrmResult<Vec<Row>> {
        self.check_sql(sql)?;
        C::query(&self.client, sql, params).await
    }

    async fn query_tagged(
        &self,
        tag: &str,
        sql: &str,
        params: &[&(dyn ToSql + Sync)],
    ) -> OrmResult<Vec<Row>> {
        self.check_sql(sql)?;
        C::query_tagged(&self.client, tag, sql, params).await
    }

    async fn query_one(&self, sql: &str, params: &[&(dyn ToSql + Sync)]) -> OrmResult<Row> {
        self.check_sql(sql)?;
        C::query_one(&self.client, sql, params).await
    }

    async fn query_one_tagged(
        &self,
        tag: &str,
        sql: &str,
        params: &[&(dyn ToSql + Sync)],
    ) -> OrmResult<Row> {
        self.check_sql(sql)?;
        C::query_one_tagged(&self.client, tag, sql, params).await
    }

    async fn query_opt(
        &self,
        sql: &str,
        params: &[&(dyn ToSql + Sync)],
    ) -> OrmResult<Option<Row>> {
        self.check_sql(sql)?;
        C::query_opt(&self.client, sql, params).await
    }

    async fn query_opt_tagged(
        &self,
        tag: &str,
        sql: &str,
        params: &[&(dyn ToSql + Sync)],
    ) -> OrmResult<Option<Row>> {
        self.check_sql(sql)?;
        C::query_opt_tagged(&self.client, tag, sql, params).await
    }

    async fn execute(&self, sql: &str, params: &[&(dyn ToSql + Sync)]) -> OrmResult<u64> {
        self.check_sql(sql)?;
        C::execute(&self.client, sql, params).await
    }

    async fn execute_tagged(
        &self,
        tag: &str,
        sql: &str,
        params: &[&(dyn ToSql + Sync)],
    ) -> OrmResult<u64> {
        self.check_sql(sql)?;
        C::execute_tagged(&self.client, tag, sql, params).await
    }

    // No SQL is involved, so there is nothing to check; delegate as-is.
    fn cancel_token(&self) -> Option<tokio_postgres::CancelToken> {
        C::cancel_token(&self.client)
    }
}
/// Checked forwarding of the [`StreamingClient`] operations.
///
/// As with the `GenericClient` impl, `check_sql` runs first. An error from it
/// short-circuits the call; otherwise the stream is opened by the wrapped client.
#[cfg(feature = "check")]
impl<C: GenericClient + StreamingClient> StreamingClient for CheckedClient<C> {
    async fn query_stream(
        &self,
        sql: &str,
        params: &[&(dyn ToSql + Sync)],
    ) -> OrmResult<RowStream> {
        self.check_sql(sql)?;
        C::query_stream(&self.client, sql, params).await
    }

    async fn query_stream_tagged(
        &self,
        tag: &str,
        sql: &str,
        params: &[&(dyn ToSql + Sync)],
    ) -> OrmResult<RowStream> {
        self.check_sql(sql)?;
        C::query_stream_tagged(&self.client, tag, sql, params).await
    }
}
#[cfg(test)]
#[cfg(feature = "check")]
mod tests {
    use super::*;

    #[test]
    fn test_check_mode_default() {
        assert_eq!(CheckMode::default(), CheckMode::WarnOnly);
    }

    /// Every constructor starts out in `WarnOnly`, matching `CheckMode::default()`.
    #[test]
    fn test_constructors_start_in_warn_only() {
        let empty = CheckedClient::new_empty(());
        assert_eq!(empty.check_mode, CheckMode::WarnOnly);

        let explicit = CheckedClient::with_registry((), SchemaRegistry::new());
        assert_eq!(explicit.check_mode, CheckMode::WarnOnly);
    }

    #[test]
    fn test_builder_methods_set_mode() {
        let strict = CheckedClient::new_empty(()).strict();
        assert_eq!(strict.check_mode, CheckMode::Strict);

        let disabled = CheckedClient::new_empty(()).disabled();
        assert_eq!(disabled.check_mode, CheckMode::Disabled);

        // The generic setter overrides whatever mode was chosen earlier.
        let reset = CheckedClient::new_empty(()).strict().check_mode(CheckMode::WarnOnly);
        assert_eq!(reset.check_mode, CheckMode::WarnOnly);
    }

    /// `Disabled` must never consult the registry, so even garbage SQL passes.
    #[test]
    fn test_disabled_mode_accepts_anything() {
        let client = CheckedClient::new_empty(()).disabled();
        assert!(client.check_sql("this is definitely not sql").is_ok());
    }

    #[test]
    fn test_inner_accessors_return_wrapped_client() {
        let client = CheckedClient::new_empty(7_u32);
        assert_eq!(*client.inner(), 7);
        assert_eq!(client.into_inner(), 7);
    }
}