pgorm 0.2.0

A model-definition-first, AI-friendly PostgreSQL ORM for Rust
Documentation
use super::*;
use crate::client::GenericClient;
use crate::error::{OrmError, OrmResult};
use std::sync::Arc;
use std::time::Duration;
use tokio_postgres::Row;
use tokio_postgres::types::ToSql;

#[test]
fn test_query_type_detection() {
    // (raw SQL, classification `QueryType::from_sql` must produce).
    // Leading whitespace, lowercase keywords and a CTE prefix must still be
    // recognised; DDL falls through to `Other`.
    let cases = [
        ("SELECT * FROM users", QueryType::Select),
        ("  select * FROM users", QueryType::Select),
        ("WITH cte AS (SELECT 1) SELECT * FROM cte", QueryType::Select),
        ("INSERT INTO users (name) VALUES ($1)", QueryType::Insert),
        ("UPDATE users SET name = $1", QueryType::Update),
        ("DELETE FROM users WHERE id = $1", QueryType::Delete),
        ("CREATE TABLE users (id INT)", QueryType::Other),
    ];

    for (sql, expected) in cases {
        assert_eq!(QueryType::from_sql(sql), expected);
    }
}

#[test]
fn test_logging_monitor_truncation() {
    let monitor = LoggingMonitor::new().max_sql_length(10);

    // Longer than the 10-char limit: cut to the limit and suffixed with "...".
    let truncated = monitor.truncate_sql("SELECT * FROM users");
    assert_eq!(truncated, "SELECT * F...");

    // Within the limit: returned unchanged.
    let untouched = monitor.truncate_sql("SELECT 1");
    assert_eq!(untouched, "SELECT 1");
}

#[test]
fn test_stats_monitor() {
    let monitor = StatsMonitor::new();
    let ctx = QueryContext::new("SELECT * FROM users", 0);

    // Record two completed SELECTs: (elapsed millis, rows returned).
    for (millis, rows) in [(10, 5), (20, 3)] {
        monitor.on_query_complete(&ctx, Duration::from_millis(millis), &QueryResult::Rows(rows));
    }

    let stats = monitor.stats();

    // Query counters and accumulated wall time.
    assert_eq!(stats.total_queries, 2);
    assert_eq!(stats.select_count, 2);
    assert_eq!(stats.total_duration, Duration::from_millis(30));

    // No statement-cache activity was reported, so those counters stay at zero.
    assert_eq!(stats.stmt_cache_hits, 0);
    assert_eq!(stats.stmt_cache_misses, 0);
    assert_eq!(stats.stmt_prepare_count, 0);
    assert_eq!(stats.stmt_prepare_duration, Duration::ZERO);
}

#[test]
fn test_composite_hook_modify() {
    /// Hook that prefixes every statement with an instrumentation comment.
    struct AddCommentHook;

    impl QueryHook for AddCommentHook {
        fn before_query(&self, ctx: &QueryContext) -> HookAction {
            let exec_sql = format!("/* instrumented */ {}", ctx.exec_sql);
            HookAction::ModifySql {
                exec_sql,
                canonical_sql: None,
            }
        }
    }

    let ctx = QueryContext::new("SELECT 1", 0);
    let composite = CompositeHook::new().add(AddCommentHook);

    // The composite must forward the single hook's rewrite unchanged.
    let HookAction::ModifySql {
        exec_sql,
        canonical_sql,
    } = composite.before_query(&ctx)
    else {
        panic!("Expected ModifySql");
    };

    assert_eq!(exec_sql, "/* instrumented */ SELECT 1");
    assert!(canonical_sql.is_none());
}

#[test]
fn test_composite_hook_abort() {
    /// Hook that vetoes DELETE statements and lets everything else through.
    struct BlockDeleteHook;

    impl QueryHook for BlockDeleteHook {
        fn before_query(&self, ctx: &QueryContext) -> HookAction {
            match ctx.query_type {
                QueryType::Delete => HookAction::Abort("DELETE not allowed".to_string()),
                _ => HookAction::Continue,
            }
        }
    }

    let ctx = QueryContext::new("DELETE FROM users", 0);
    let composite = CompositeHook::new().add(BlockDeleteHook);

    // The composite must surface the veto together with its reason.
    let HookAction::Abort(reason) = composite.before_query(&ctx) else {
        panic!("Expected Abort");
    };
    assert_eq!(reason, "DELETE not allowed");
}

#[tokio::test]
async fn tagged_queries_propagate_to_monitor() {
    /// Monitor that remembers the tag of the most recently completed query.
    #[derive(Default)]
    struct LastTag(std::sync::Mutex<Option<String>>);

    impl QueryMonitor for LastTag {
        fn on_query_complete(&self, ctx: &QueryContext, _: Duration, _: &QueryResult) {
            let mut slot = self.0.lock().unwrap();
            *slot = ctx.tag.clone();
        }
    }

    /// Client that never touches a database: empty results / `not_found`.
    struct NoopClient;

    impl GenericClient for NoopClient {
        async fn query(&self, _: &str, _: &[&(dyn ToSql + Sync)]) -> OrmResult<Vec<Row>> {
            Ok(vec![])
        }
        async fn query_one(&self, _: &str, _: &[&(dyn ToSql + Sync)]) -> OrmResult<Row> {
            Err(OrmError::not_found("no rows"))
        }
        async fn query_opt(&self, _: &str, _: &[&(dyn ToSql + Sync)]) -> OrmResult<Option<Row>> {
            Ok(None)
        }
        async fn execute(&self, _: &str, _: &[&(dyn ToSql + Sync)]) -> OrmResult<u64> {
            Ok(0)
        }
    }

    let monitor = Arc::new(LastTag::default());
    // `.clone()` (method form) keeps unsized coercion to the monitor arc working.
    let client = InstrumentedClient::new(NoopClient)
        .with_config(MonitorConfig::new().enable_monitoring())
        .with_monitor_arc(monitor.clone());

    client
        .query_tagged("test-tag", "SELECT 1", &[])
        .await
        .unwrap();

    let recorded = monitor.0.lock().unwrap();
    assert_eq!(recorded.as_deref(), Some("test-tag"));
}

/// Verifies that `MonitorConfig::with_query_timeout` bounds how long a query may run.
///
/// `HangingClient::query` sleeps for 60s, far longer than the configured 10ms
/// timeout, so the instrumented call must fail with `OrmError::Timeout`.
///
/// NOTE(review): despite the test name, nothing here observes a cancellation
/// attempt (e.g. a server-side cancel request); only the error variant is checked.
#[tokio::test]
async fn timeout_returns_error_and_attempts_cancellation() {
    struct HangingClient;
    impl GenericClient for HangingClient {
        async fn query(&self, _: &str, _: &[&(dyn ToSql + Sync)]) -> OrmResult<Vec<Row>> {
            tokio::time::sleep(Duration::from_secs(60)).await;
            Ok(vec![])
        }
        async fn query_one(&self, _: &str, _: &[&(dyn ToSql + Sync)]) -> OrmResult<Row> {
            Err(OrmError::not_found("unused"))
        }
        async fn query_opt(&self, _: &str, _: &[&(dyn ToSql + Sync)]) -> OrmResult<Option<Row>> {
            Ok(None)
        }
        async fn execute(&self, _: &str, _: &[&(dyn ToSql + Sync)]) -> OrmResult<u64> {
            Ok(0)
        }
    }

    let client = InstrumentedClient::new(HangingClient).with_config(
        MonitorConfig::new()
            .with_query_timeout(Duration::from_millis(10))
            .enable_monitoring(),
    );

    // Outer guard: if the instrumented timeout regresses, fail within a few
    // seconds instead of stalling the suite for the full 60s sleep.
    let outcome = tokio::time::timeout(
        Duration::from_secs(5),
        client.query("SELECT pg_sleep(60)", &[]),
    )
    .await
    .expect("InstrumentedClient did not enforce its 10ms query timeout");

    let err = outcome.unwrap_err();
    assert!(
        matches!(err, OrmError::Timeout(_)),
        "expected OrmError::Timeout, got {err:?}"
    );
}