Skip to main content

bsql_driver_postgres/
pool.rs

1//! Connection pool — LIFO ordering, Condvar-based waiting.
2//!
3//! The pool maintains a stack of idle connections. `acquire()` pops the top
4//! (most recently used = warmest caches). On drop, the guard pushes the
5//! connection back. If the pool is exhausted, callers wait on a `Condvar`
6//! up to `acquire_timeout` (default: 5 seconds). Set `acquire_timeout` to
7//! `None` for fail-fast behavior (immediate error when exhausted).
8
9use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
10use std::sync::Arc;
11use std::time::Duration;
12
13use crate::arena::Arena;
14use crate::codec::Encode;
15use crate::conn::Connection;
16use crate::types::{Config, PgDataRow, QueryResult, SimpleRow};
17use crate::DriverError;
18
19#[cfg(feature = "async")]
20use crate::async_conn::AsyncConnection;
21
22// --- PoolSlot ---
23
24/// A connection slot — either sync (UDS/TCP) or async (TCP only).
25///
26/// The pool auto-detects: UDS hosts get sync `Connection`, TCP hosts get
27/// `AsyncConnection` (when the `async` feature is enabled). When `async`
28/// is disabled, all connections are sync.
29pub(crate) enum PoolSlot {
30    /// Sync connection (UDS or TCP without async feature).
31    Sync(Connection),
32    /// Async TCP connection (requires async feature + tokio runtime).
33    #[cfg(feature = "async")]
34    Async(AsyncConnection),
35}
36
37// --- N+1 Detection ---
38
/// Tracks sequential repeats of the same `sql_hash` on a single connection
/// checkout. When a run of one hash grows longer than `threshold`, a warning
/// is logged once the run ends (a different hash is tracked) or when
/// [`NPlusOneDetector::emit_final_warning`] is called. Fully `cfg`-gated —
/// zero cost when the `detect-n-plus-one` feature is disabled.
#[cfg(feature = "detect-n-plus-one")]
pub(crate) struct NPlusOneDetector {
    /// Hash of the most recently tracked query (0 until something is tracked).
    last_query_hash: u64,
    /// Length of the current run of `last_query_hash`; saturates at `u16::MAX`.
    repeat_count: u16,
    /// Runs strictly longer than this are reported.
    threshold: u16,
}

#[cfg(feature = "detect-n-plus-one")]
impl NPlusOneDetector {
    /// Create a new detector with the given warning threshold.
    pub(crate) fn new(threshold: u16) -> Self {
        Self {
            last_query_hash: 0,
            repeat_count: 0,
            threshold,
        }
    }

    /// Track a query execution. Call this at the start of every query method.
    ///
    /// A repeated `sql_hash` extends the current run. A different hash ends
    /// the run, reports it if it exceeded the threshold, and starts a new run
    /// of length 1.
    #[inline]
    pub(crate) fn track(&mut self, sql_hash: u64) {
        if sql_hash == self.last_query_hash {
            self.repeat_count = self.repeat_count.saturating_add(1);
            return;
        }
        // The threshold test stays inline so the common case (a different
        // query with no warning due) never calls the out-of-line #[cold] path.
        if self.check_final().is_some() {
            self.emit_warning();
        }
        self.last_query_hash = sql_hash;
        self.repeat_count = 1;
    }

    /// Inspect the current run (also used on drop / connection return).
    /// Returns `Some((hash, count))` if it exceeds the threshold, i.e. a
    /// warning should be emitted. Hash 0 is never reported.
    pub(crate) fn check_final(&self) -> Option<(u64, u16)> {
        if self.repeat_count > self.threshold && self.last_query_hash != 0 {
            Some((self.last_query_hash, self.repeat_count))
        } else {
            None
        }
    }

    /// Emit a log warning if the current run exceeds the threshold.
    #[cold]
    #[inline(never)]
    fn emit_warning(&self) {
        if let Some((hash, count)) = self.check_final() {
            log::warn!(
                "[bsql] potential N+1 detected: sql_hash={:#018x} repeated {} times (threshold: {})",
                hash,
                count,
                self.threshold,
            );
        }
    }

    /// Emit the final warning (called on drop).
    #[cold]
    #[inline(never)]
    pub(crate) fn emit_final_warning(&self) {
        self.emit_warning();
    }
}
104
105// --- Pool ---
106
107/// A connection pool with LIFO ordering and fail-fast semantics.
108///
109/// # Example
110///
111/// ```no_run
112/// # fn example() -> Result<(), bsql_driver_postgres::DriverError> {
113/// let pool = bsql_driver_postgres::Pool::connect("postgres://user:pass@localhost/db")?;
114/// let mut conn = pool.acquire()?;
115/// conn.simple_query("SELECT 1")?;
116/// // conn is returned to pool on drop
117/// # Ok(())
118/// # }
119/// ```
120pub struct Pool {
121    inner: Arc<PoolInner>,
122}
123
124struct PoolInner {
125    /// Idle connections. Uses std::sync::Mutex because the critical section is
126    /// trivial (push/pop — no I/O). This lets PoolGuard::Drop return connections
127    /// synchronously.
128    stack: std::sync::Mutex<Vec<PoolSlot>>,
129    max_size: usize,
130    open_count: AtomicUsize,
131    config: Arc<Config>,
132    /// When true, no new acquires are accepted.
133    closed: AtomicBool,
134    /// Condvar pair for release notification. Waiters block on the Condvar
135    /// when the pool is exhausted and `acquire_timeout` is set.
136    release_pair: (std::sync::Mutex<()>, std::sync::Condvar),
137    /// Maximum lifetime of a connection. Connections older than this
138    /// are discarded when popped from the pool. Default: 30 minutes.
139    max_lifetime: Option<Duration>,
140    /// Maximum time to wait for a connection. Default: None (fail-fast).
141    acquire_timeout: Option<Duration>,
142    /// Minimum number of idle connections to maintain. Default: 0.
143    min_idle: usize,
144    /// SQL statements to PREPARE on new connections (warmup).
145    warmup_sqls: std::sync::Mutex<Arc<Vec<Box<str>>>>,
146    /// Maximum number of cached prepared statements per connection.
147    max_stmt_cache_size: usize,
148    /// Maximum idle duration before a connection is considered stale and discarded.
149    /// Connections idle longer than this are dropped on acquire. Default: 30 seconds.
150    stale_timeout: Duration,
151    /// Threshold for N+1 detection. When the same sql_hash fires more than
152    /// this many times sequentially on a single checkout, a warning is logged.
153    #[cfg(feature = "detect-n-plus-one")]
154    n_plus_one_threshold: u16,
155}
156
157impl Pool {
158    /// Create a pool from a connection URL with default settings (max_size = 10).
159    ///
160    /// Validates the URL but does not open any connections yet (lazy initialization).
161    pub fn connect(url: &str) -> Result<Self, DriverError> {
162        PoolBuilder::new().url(url).build()
163    }
164
165    /// Create a pool builder for custom configuration.
166    pub fn builder() -> PoolBuilder {
167        PoolBuilder::new()
168    }
169
170    /// Acquire a connection from the pool.
171    ///
172    /// Returns immediately with the most recently used idle connection (LIFO).
173    /// If no idle connections are available and the pool is below max_size, a new
174    /// connection is created. If the pool is at max_size and no `acquire_timeout`
175    /// is set, returns `DriverError::Pool` immediately. If `acquire_timeout` is
176    /// set, blocks until a connection is returned or the timeout expires.
177    #[inline]
178    pub fn acquire(&self) -> Result<PoolGuard, DriverError> {
179        if self.inner.closed.load(Ordering::Acquire) {
180            return Err(DriverError::Pool("pool is closed".into()));
181        }
182
183        // Try to pop an idle connection (fast path).
184        if let Some(guard) = self.try_pop_idle()? {
185            return Ok(guard);
186        }
187
188        // No idle connections — try to claim a slot with a proper CAS loop.
189        loop {
190            let current = self.inner.open_count.load(Ordering::Acquire);
191            if current >= self.inner.max_size {
192                if let Some(timeout) = self.inner.acquire_timeout {
193                    let (lock, cvar) = &self.inner.release_pair;
194                    let guard = lock.lock().unwrap_or_else(|e| e.into_inner());
195                    let (_guard, result) = cvar
196                        .wait_timeout(guard, timeout)
197                        .unwrap_or_else(|e| e.into_inner());
198                    if result.timed_out() {
199                        return Err(DriverError::Pool(
200                            "pool exhausted: acquire timeout expired".into(),
201                        ));
202                    }
203                    // A connection was returned — try again
204                    if let Some(guard) = self.try_pop_idle()? {
205                        return Ok(guard);
206                    }
207                    // Popped nothing — retry CAS
208                    continue;
209                }
210                return Err(DriverError::Pool(
211                    "pool exhausted: all connections in use".into(),
212                ));
213            }
214            if self
215                .inner
216                .open_count
217                .compare_exchange(current, current + 1, Ordering::AcqRel, Ordering::Acquire)
218                .is_ok()
219            {
220                break;
221            }
222            // CAS failed — another thread incremented. Retry.
223        }
224
225        // Open a new connection
226        let conn_result = Connection::connect_arc(self.inner.config.clone());
227        match conn_result {
228            Ok(mut conn) => {
229                // Configure statement cache size
230                conn.set_max_stmt_cache_size(self.inner.max_stmt_cache_size);
231                // Warmup: pre-PREPARE frequently used statements
232                self.warmup_conn(&mut conn);
233
234                Ok(PoolGuard {
235                    conn: Some(PoolSlot::Sync(conn)),
236                    pool: self.inner.clone(),
237                    discard: false,
238                    #[cfg(feature = "detect-n-plus-one")]
239                    detector: NPlusOneDetector::new(self.inner.n_plus_one_threshold),
240                })
241            }
242            Err(e) => {
243                // Give back the slot
244                self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
245                Err(e)
246            }
247        }
248    }
249
250    /// Try to pop a valid idle connection from the stack.
251    ///
252    /// Performs lifetime and stale checks. For connections idle > 5 seconds
253    /// (but within the stale timeout), sends an empty query as a health check
254    /// to verify the connection is still alive before returning it.
255    #[inline]
256    fn try_pop_idle(&self) -> Result<Option<PoolGuard>, DriverError> {
257        // Pop a candidate slot under the lock, performing only non-I/O checks
258        // (lifetime, stale timeout). The health check (network round-trip) happens
259        // AFTER the lock is released so other threads aren't blocked.
260        loop {
261            let (mut slot, needs_health_check) = {
262                let mut stack = self.inner.stack.lock().unwrap_or_else(|e| e.into_inner());
263                loop {
264                    let Some(slot) = stack.pop() else {
265                        return Ok(None);
266                    };
267                    let (created_at, idle_dur) = match &slot {
268                        PoolSlot::Sync(conn) => (conn.created_at(), conn.idle_duration()),
269                        #[cfg(feature = "async")]
270                        PoolSlot::Async(conn) => (conn.created_at(), conn.idle_duration()),
271                    };
272                    if let Some(max_lifetime) = self.inner.max_lifetime {
273                        if created_at.elapsed() >= max_lifetime {
274                            self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
275                            continue;
276                        }
277                    }
278                    if idle_dur >= self.inner.stale_timeout {
279                        // Stale connection — drop it, free the slot
280                        self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
281                        continue;
282                    }
283                    break (slot, idle_dur > Duration::from_secs(5));
284                }
285            };
286            // Lock is now released — health check happens outside the critical section.
287            // Sends an empty query — PG returns EmptyQueryResponse + ReadyForQuery.
288            // Fast: one round-trip, ~15us on UDS. Skip for hot connections.
289            if needs_health_check {
290                let alive = match &mut slot {
291                    PoolSlot::Sync(conn) => conn.simple_query("").is_ok(),
292                    #[cfg(feature = "async")]
293                    PoolSlot::Async(_) => true, // async connections are checked at I/O time
294                };
295                if !alive {
296                    self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
297                    continue; // retry — re-acquire lock and pop next slot
298                }
299            }
300            return Ok(Some(PoolGuard {
301                conn: Some(slot),
302                pool: self.inner.clone(),
303                discard: false,
304                #[cfg(feature = "detect-n-plus-one")]
305                detector: NPlusOneDetector::new(self.inner.n_plus_one_threshold),
306            }));
307        }
308    }
309
310    /// Whether this pool uses UDS connections.
311    ///
312    /// Returns `true` when the pool URL points to a Unix domain socket.
313    /// On non-Unix platforms, always returns `false`.
314    pub fn is_uds(&self) -> bool {
315        #[cfg(unix)]
316        {
317            self.inner.config.host_is_uds()
318        }
319        #[cfg(not(unix))]
320        {
321            false
322        }
323    }
324
325    /// Begin a transaction. Acquires a connection and sends BEGIN.
326    pub fn begin(&self) -> Result<Transaction, DriverError> {
327        let mut guard = self.acquire()?;
328        guard.simple_query("BEGIN")?;
329        Ok(Transaction {
330            guard,
331            committed: false,
332            deferred_buf: Vec::new(),
333            deferred_count: 0,
334        })
335    }
336
337    /// Current number of open connections (idle + in-use).
338    pub fn open_count(&self) -> usize {
339        self.inner.open_count.load(Ordering::Relaxed)
340    }
341
342    /// Maximum pool size.
343    pub fn max_size(&self) -> usize {
344        self.inner.max_size
345    }
346
347    /// Pool status metrics.
348    pub fn status(&self) -> PoolStatus {
349        let idle = self
350            .inner
351            .stack
352            .lock()
353            .unwrap_or_else(|e| e.into_inner())
354            .len();
355        let open = self.inner.open_count.load(Ordering::Relaxed);
356        let active = open.saturating_sub(idle);
357        PoolStatus {
358            idle,
359            active,
360            open,
361            max_size: self.inner.max_size,
362        }
363    }
364
365    /// Pre-PREPARE warmup statements on a new connection.
366    ///
367    /// Uses `prepare_batch()` to pipeline N × (Parse+Describe) + 1 × Sync
368    /// in a single round-trip, instead of N separate round-trips.
369    ///
370    /// Best-effort: errors are silently ignored.
371    /// The connection remains usable even if warmup fails.
372    fn warmup_conn(&self, conn: &mut Connection) {
373        let sqls = self
374            .inner
375            .warmup_sqls
376            .lock()
377            .unwrap_or_else(|e| e.into_inner())
378            .clone();
379
380        if sqls.is_empty() {
381            return;
382        }
383
384        let batch: Vec<(&str, u64)> = sqls
385            .iter()
386            .map(|sql| (sql.as_ref(), crate::types::hash_sql(sql)))
387            .collect();
388
389        let _ = conn.prepare_batch(&batch);
390    }
391
392    /// Set the SQL statements to pre-PREPARE on new connections.
393    ///
394    /// Each SQL string is PREPAREd (Parse+Describe+Sync) on new connections
395    /// before they are returned from `acquire()`. This eliminates the first-use
396    /// Parse overhead for frequently executed queries.
397    ///
398    /// Warmup errors are silently ignored — a bad warmup SQL must not prevent
399    /// the connection from being usable.
400    ///
401    /// # Example
402    ///
403    /// ```no_run
404    /// # fn example() -> Result<(), bsql_driver_postgres::DriverError> {
405    /// let pool = bsql_driver_postgres::Pool::connect("postgres://user:pass@localhost/db")?;
406    /// pool.set_warmup_sqls(&[
407    ///     "SELECT id, name FROM users WHERE id = $1::int4",
408    ///     "SELECT id, title FROM tickets WHERE status = ANY($1::text[])",
409    /// ]);
410    /// # Ok(())
411    /// # }
412    /// ```
413    pub fn set_warmup_sqls(&self, sqls: &[&str]) {
414        let boxed: Arc<Vec<Box<str>>> =
415            Arc::new(sqls.iter().map(|s| (*s).into()).collect::<Vec<_>>());
416        *self
417            .inner
418            .warmup_sqls
419            .lock()
420            .unwrap_or_else(|e| e.into_inner()) = boxed;
421    }
422
423    /// Close the pool. No new acquires are accepted. All idle connections
424    /// are sent Terminate and dropped.
425    pub fn close(&self) {
426        self.inner.closed.store(true, Ordering::Release);
427        // Drain and close all idle connections
428        let slots: Vec<PoolSlot> = {
429            let mut stack = self.inner.stack.lock().unwrap_or_else(|e| e.into_inner());
430            std::mem::take(&mut *stack)
431        };
432        for slot in slots {
433            self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
434            match slot {
435                PoolSlot::Sync(conn) => {
436                    let _ = conn.close();
437                }
438                #[cfg(feature = "async")]
439                PoolSlot::Async(_conn) => {
440                    // AsyncConnection::close() is async — we can't await in sync close().
441                    // Drop will close the TCP socket, PG auto-cleans up.
442                }
443            }
444        }
445        // Notify any waiters so they get the "pool is closed" error
446        let (_, cvar) = &self.inner.release_pair;
447        cvar.notify_all();
448    }
449
450    /// Whether the pool has been closed.
451    pub fn is_closed(&self) -> bool {
452        self.inner.closed.load(Ordering::Acquire)
453    }
454
455    /// Acquire a connection from the pool (async).
456    ///
457    /// Auto-detects transport: UDS hosts get a sync `Connection`, TCP hosts
458    /// get an `AsyncConnection`. If the `async` feature is disabled, always
459    /// creates sync connections.
460    ///
461    /// Returns immediately with the most recently used idle connection (LIFO).
462    /// If no idle connections are available and the pool is below max_size, a new
463    /// connection is created.
464    #[cfg(feature = "async")]
465    pub async fn acquire_async(&self) -> Result<PoolGuard, DriverError> {
466        if self.inner.closed.load(Ordering::Acquire) {
467            return Err(DriverError::Pool("pool is closed".into()));
468        }
469
470        // Try to pop an idle connection (fast path).
471        if let Some(guard) = self.try_pop_idle()? {
472            return Ok(guard);
473        }
474
475        // No idle connections — try to claim a slot with a proper CAS loop.
476        loop {
477            let current = self.inner.open_count.load(Ordering::Acquire);
478            if current >= self.inner.max_size {
479                if let Some(timeout) = self.inner.acquire_timeout {
480                    let (lock, cvar) = &self.inner.release_pair;
481                    let guard = lock.lock().unwrap_or_else(|e| e.into_inner());
482                    let (_guard, result) = cvar
483                        .wait_timeout(guard, timeout)
484                        .unwrap_or_else(|e| e.into_inner());
485                    if result.timed_out() {
486                        return Err(DriverError::Pool(
487                            "pool exhausted: acquire timeout expired".into(),
488                        ));
489                    }
490                    if let Some(guard) = self.try_pop_idle()? {
491                        return Ok(guard);
492                    }
493                    continue;
494                }
495                return Err(DriverError::Pool(
496                    "pool exhausted: all connections in use".into(),
497                ));
498            }
499            if self
500                .inner
501                .open_count
502                .compare_exchange(current, current + 1, Ordering::AcqRel, Ordering::Acquire)
503                .is_ok()
504            {
505                break;
506            }
507        }
508
509        // Open a new connection — auto-detect UDS vs TCP
510        if self.inner.config.host_is_uds() {
511            // UDS — use sync Connection
512            let conn_result = Connection::connect_arc(self.inner.config.clone());
513            match conn_result {
514                Ok(mut conn) => {
515                    conn.set_max_stmt_cache_size(self.inner.max_stmt_cache_size);
516                    self.warmup_conn(&mut conn);
517                    Ok(PoolGuard {
518                        conn: Some(PoolSlot::Sync(conn)),
519                        pool: self.inner.clone(),
520                        discard: false,
521                        #[cfg(feature = "detect-n-plus-one")]
522                        detector: NPlusOneDetector::new(self.inner.n_plus_one_threshold),
523                    })
524                }
525                Err(e) => {
526                    self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
527                    Err(e)
528                }
529            }
530        } else {
531            // TCP — use AsyncConnection
532            let conn_result = AsyncConnection::connect_arc(self.inner.config.clone()).await;
533            match conn_result {
534                Ok(mut conn) => {
535                    conn.set_max_stmt_cache_size(self.inner.max_stmt_cache_size);
536                    Ok(PoolGuard {
537                        conn: Some(PoolSlot::Async(conn)),
538                        pool: self.inner.clone(),
539                        discard: false,
540                        #[cfg(feature = "detect-n-plus-one")]
541                        detector: NPlusOneDetector::new(self.inner.n_plus_one_threshold),
542                    })
543                }
544                Err(e) => {
545                    self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
546                    Err(e)
547                }
548            }
549        }
550    }
551}
552
553impl Clone for Pool {
554    fn clone(&self) -> Self {
555        Pool {
556            inner: self.inner.clone(),
557        }
558    }
559}
560
561// --- PoolStatus ---
562
/// Pool status metrics — a point-in-time snapshot returned by `Pool::status()`.
///
/// Plain data, so it is `Copy` and can be compared for equality.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PoolStatus {
    /// Number of idle connections in the pool.
    pub idle: usize,
    /// Number of connections currently in use.
    pub active: usize,
    /// Total open connections (idle + active).
    pub open: usize,
    /// Maximum pool size.
    pub max_size: usize,
}
575
576// --- PoolBuilder ---
577
/// Fluent configuration for a [`Pool`]; obtain one with `Pool::builder()`
/// and finish with `build()`.
///
/// Intentionally not `Debug`: the URL may embed credentials.
pub struct PoolBuilder {
    /// Connection URL; `build()` fails when it is missing.
    url: Option<String>,
    /// Upper bound on simultaneously open connections.
    max_size: usize,
    /// Connections older than this are discarded; `None` means no limit.
    max_lifetime: Option<Duration>,
    /// How long `acquire()` may wait on an exhausted pool; `None` fails fast.
    acquire_timeout: Option<Duration>,
    /// Idle connections a background thread tries to keep open.
    min_idle: usize,
    /// Per-connection prepared-statement cache capacity.
    max_stmt_cache_size: usize,
    /// Idle time after which a pooled connection is treated as stale.
    stale_timeout: Duration,
    /// N+1 warning threshold; `None` falls back to the default at build time.
    #[cfg(feature = "detect-n-plus-one")]
    n_plus_one_threshold: Option<u16>,
}
596
597impl PoolBuilder {
598    fn new() -> Self {
599        Self {
600            url: None,
601            max_size: 10,
602            max_lifetime: Some(Duration::from_secs(30 * 60)), // 30 min default
603            acquire_timeout: Some(Duration::from_secs(5)), // 5s default (matches common pool defaults)
604            min_idle: 0,                                   // no minimum by default
605            max_stmt_cache_size: 256,                      // LRU eviction at 256 stmts
606            stale_timeout: Duration::from_secs(30),        // 30s default
607            #[cfg(feature = "detect-n-plus-one")]
608            n_plus_one_threshold: None,
609        }
610    }
611
612    /// Set the connection URL.
613    pub fn url(mut self, url: &str) -> Self {
614        self.url = Some(url.to_owned());
615        self
616    }
617
618    /// Set the maximum pool size. Default: 10.
619    ///
620    /// A max_size of 0 means the pool will reject all acquire() calls immediately.
621    pub fn max_size(mut self, size: usize) -> Self {
622        self.max_size = size;
623        self
624    }
625
626    /// Set the maximum lifetime of a connection. Default: 30 minutes.
627    /// Set to None for unlimited lifetime.
628    pub fn max_lifetime(mut self, lifetime: Option<Duration>) -> Self {
629        self.max_lifetime = lifetime;
630        self
631    }
632
633    /// Set the acquire timeout. Default: 5 seconds.
634    /// Set to None for fail-fast behavior when the pool is exhausted.
635    pub fn acquire_timeout(mut self, timeout: Option<Duration>) -> Self {
636        self.acquire_timeout = timeout;
637        self
638    }
639
640    /// Set the minimum number of idle connections. Default: 0.
641    /// When > 0, a background thread maintains this many idle connections.
642    pub fn min_idle(mut self, count: usize) -> Self {
643        self.min_idle = count;
644        self
645    }
646
647    /// Set the maximum number of cached prepared statements per connection.
648    /// Default: 256. When the cache exceeds this size, the least recently
649    /// used statement is evicted (Close sent to PG to free server memory).
650    pub fn max_stmt_cache_size(mut self, size: usize) -> Self {
651        self.max_stmt_cache_size = size;
652        self
653    }
654
655    /// Set the maximum idle duration before a connection is considered stale.
656    /// Default: 30 seconds. Connections idle longer than this are dropped on
657    /// acquire instead of being reused.
658    pub fn stale_timeout(mut self, timeout: Duration) -> Self {
659        self.stale_timeout = timeout;
660        self
661    }
662
663    /// Set the threshold for N+1 detection warnings.
664    ///
665    /// When the same `sql_hash` fires more than this many times sequentially
666    /// on a single connection checkout, a warning is logged. Default: 10.
667    #[cfg(feature = "detect-n-plus-one")]
668    pub fn n_plus_one_threshold(mut self, n: u16) -> Self {
669        self.n_plus_one_threshold = Some(n);
670        self
671    }
672
673    /// Build the pool. Validates the URL but does not open connections.
674    pub fn build(self) -> Result<Pool, DriverError> {
675        let url = self
676            .url
677            .ok_or_else(|| DriverError::Pool("pool builder requires a URL".into()))?;
678
679        let config = Arc::new(Config::from_url(&url)?);
680
681        let pool = Pool {
682            inner: Arc::new(PoolInner {
683                stack: std::sync::Mutex::new(Vec::with_capacity(self.max_size)),
684                max_size: self.max_size,
685                open_count: AtomicUsize::new(0),
686                config,
687                closed: AtomicBool::new(false),
688                release_pair: (std::sync::Mutex::new(()), std::sync::Condvar::new()),
689                max_lifetime: self.max_lifetime,
690                acquire_timeout: self.acquire_timeout,
691                min_idle: self.min_idle,
692                warmup_sqls: std::sync::Mutex::new(Arc::new(Vec::new())),
693                max_stmt_cache_size: self.max_stmt_cache_size,
694                stale_timeout: self.stale_timeout,
695                #[cfg(feature = "detect-n-plus-one")]
696                n_plus_one_threshold: self.n_plus_one_threshold.unwrap_or(10),
697            }),
698        };
699
700        if self.min_idle > 0 {
701            let inner = pool.inner.clone();
702            std::thread::spawn(move || {
703                maintain_min_idle(inner);
704            });
705        }
706
707        Ok(pool)
708    }
709}
710
711/// Background thread that maintains min_idle connections.
712fn maintain_min_idle(inner: Arc<PoolInner>) {
713    loop {
714        if inner.closed.load(Ordering::Acquire) {
715            return;
716        }
717
718        let idle_count = inner.stack.lock().unwrap_or_else(|e| e.into_inner()).len();
719        let needed = inner.min_idle.saturating_sub(idle_count);
720
721        for _ in 0..needed {
722            if inner.closed.load(Ordering::Acquire) {
723                return;
724            }
725            let current = inner.open_count.load(Ordering::Acquire);
726            if current >= inner.max_size {
727                break;
728            }
729            if inner
730                .open_count
731                .compare_exchange(current, current + 1, Ordering::AcqRel, Ordering::Acquire)
732                .is_err()
733            {
734                continue;
735            }
736
737            match Connection::connect_arc(inner.config.clone()) {
738                Ok(conn) => {
739                    let mut stack = inner.stack.lock().unwrap_or_else(|e| e.into_inner());
740                    stack.push(PoolSlot::Sync(conn));
741                    let (_, cvar) = &inner.release_pair;
742                    cvar.notify_one();
743                }
744                Err(_) => {
745                    inner.open_count.fetch_sub(1, Ordering::AcqRel);
746                }
747            }
748        }
749
750        // Check every 1 second. Shorter interval ensures the thread exits promptly
751        // when pool.closed is set (worst-case 1s delay instead of 5s).
752        std::thread::sleep(Duration::from_secs(1));
753    }
754}
755
756// --- PoolGuard ---
757
758/// A borrowed connection from the pool. Returns to the pool on drop.
759///
760/// If the connection is in a failed transaction state, broken, or marked for
761/// discard, it is dropped (decrements open_count) instead of returned.
762pub struct PoolGuard {
763    conn: Option<PoolSlot>,
764    pool: Arc<PoolInner>,
765    /// When true, the connection is dropped instead of returned to the pool.
766    discard: bool,
767    /// Tracks sequential repeats of the same sql_hash for N+1 detection.
768    #[cfg(feature = "detect-n-plus-one")]
769    detector: NPlusOneDetector,
770}
771
772impl PoolGuard {
773    /// Get a reference to the inner sync connection. Panics if the slot
774    /// holds an async connection.
775    #[inline]
776    fn sync_conn(&self) -> Result<&Connection, DriverError> {
777        match self.conn.as_ref() {
778            Some(PoolSlot::Sync(conn)) => Ok(conn),
779            #[cfg(feature = "async")]
780            Some(PoolSlot::Async(_)) => Err(DriverError::Pool(
781                "expected sync connection, got async; use async methods".into(),
782            )),
783            None => Err(DriverError::Pool("connection already taken".into())),
784        }
785    }
786
787    /// Get a mutable reference to the inner sync connection.
788    #[inline]
789    fn sync_conn_mut(&mut self) -> Result<&mut Connection, DriverError> {
790        match self.conn.as_mut() {
791            Some(PoolSlot::Sync(conn)) => Ok(conn),
792            #[cfg(feature = "async")]
793            Some(PoolSlot::Async(_)) => Err(DriverError::Pool(
794                "expected sync connection, got async; use async methods".into(),
795            )),
796            None => Err(DriverError::Pool("connection already taken".into())),
797        }
798    }
799
800    /// Mark this connection for discard — it will NOT be returned to the pool
801    /// on drop. The open_count is decremented and the TCP connection is closed.
802    pub fn mark_discard(&mut self) {
803        self.discard = true;
804    }
805
806    /// Cancel the currently running query on the underlying connection.
807    ///
808    /// Opens a new TCP connection and sends a CancelRequest to PG.
809    /// The cancel connection is closed immediately after.
810    pub fn cancel(&self) -> Result<(), DriverError> {
811        self.sync_conn()?.cancel()
812    }
813
814    // --- Introspection dispatch methods ---
815
816    /// Get the backend process ID for this connection.
817    pub fn pid(&self) -> i32 {
818        match self.conn.as_ref().expect("connection taken") {
819            PoolSlot::Sync(conn) => conn.pid(),
820            #[cfg(feature = "async")]
821            PoolSlot::Async(conn) => conn.pid(),
822        }
823    }
824
825    /// Whether the connection is idle (not in a transaction).
826    pub fn is_idle(&self) -> bool {
827        match self.conn.as_ref().expect("connection taken") {
828            PoolSlot::Sync(conn) => conn.is_idle(),
829            #[cfg(feature = "async")]
830            PoolSlot::Async(conn) => conn.is_idle(),
831        }
832    }
833
834    /// Whether the connection is inside a transaction.
835    pub fn is_in_transaction(&self) -> bool {
836        match self.conn.as_ref().expect("connection taken") {
837            PoolSlot::Sync(conn) => conn.is_in_transaction(),
838            #[cfg(feature = "async")]
839            PoolSlot::Async(conn) => conn.is_in_transaction(),
840        }
841    }
842
843    // --- Sync query dispatch methods ---
844
845    /// Execute a prepared query and return rows.
846    #[inline]
847    pub fn query(
848        &mut self,
849        sql: &str,
850        sql_hash: u64,
851        params: &[&(dyn Encode + Sync)],
852    ) -> Result<QueryResult, DriverError> {
853        #[cfg(feature = "detect-n-plus-one")]
854        self.detector.track(sql_hash);
855        self.sync_conn_mut()?.query(sql, sql_hash, params)
856    }
857
858    /// Execute a query without result rows (INSERT/UPDATE/DELETE).
859    #[inline]
860    pub fn execute(
861        &mut self,
862        sql: &str,
863        sql_hash: u64,
864        params: &[&(dyn Encode + Sync)],
865    ) -> Result<u64, DriverError> {
866        #[cfg(feature = "detect-n-plus-one")]
867        self.detector.track(sql_hash);
868        self.sync_conn_mut()?.execute(sql, sql_hash, params)
869    }
870
871    /// Execute the same statement N times with different params in one pipeline.
872    ///
873    /// Sends all N Bind+Execute messages + one Sync. One round-trip for N operations.
874    /// Returns the affected row count for each parameter set.
875    pub fn execute_pipeline(
876        &mut self,
877        sql: &str,
878        sql_hash: u64,
879        param_sets: &[&[&(dyn Encode + Sync)]],
880    ) -> Result<Vec<u64>, DriverError> {
881        #[cfg(feature = "detect-n-plus-one")]
882        self.detector.track(sql_hash);
883        self.sync_conn_mut()?
884            .execute_pipeline(sql, sql_hash, param_sets)
885    }
886
887    /// Execute a simple (unprepared) query.
888    pub fn simple_query(&mut self, sql: &str) -> Result<(), DriverError> {
889        self.sync_conn_mut()?.simple_query(sql)
890    }
891
892    /// Execute a simple query and return rows as text.
893    ///
894    /// Uses PostgreSQL's simple query protocol — all values are strings.
895    pub fn simple_query_rows(&mut self, sql: &str) -> Result<Vec<SimpleRow>, DriverError> {
896        self.sync_conn_mut()?.simple_query_rows(sql)
897    }
898
899    /// Process each row via a closure with zero-copy `PgDataRow`.
900    pub fn for_each<F>(
901        &mut self,
902        sql: &str,
903        sql_hash: u64,
904        params: &[&(dyn Encode + Sync)],
905        f: F,
906    ) -> Result<(), DriverError>
907    where
908        F: FnMut(PgDataRow<'_>) -> Result<(), DriverError>,
909    {
910        #[cfg(feature = "detect-n-plus-one")]
911        self.detector.track(sql_hash);
912        self.sync_conn_mut()?.for_each(sql, sql_hash, params, f)
913    }
914
915    /// Process each DataRow as raw bytes — fastest path.
916    pub fn for_each_raw<F>(
917        &mut self,
918        sql: &str,
919        sql_hash: u64,
920        params: &[&(dyn Encode + Sync)],
921        f: F,
922    ) -> Result<(), DriverError>
923    where
924        F: FnMut(&[u8]) -> Result<(), DriverError>,
925    {
926        #[cfg(feature = "detect-n-plus-one")]
927        self.detector.track(sql_hash);
928        self.sync_conn_mut()?.for_each_raw(sql, sql_hash, params, f)
929    }
930
931    // --- Streaming ---
932
933    /// Start a streaming query.
934    pub fn query_streaming_start(
935        &mut self,
936        sql: &str,
937        sql_hash: u64,
938        params: &[&(dyn Encode + Sync)],
939        chunk_size: i32,
940    ) -> Result<(std::sync::Arc<[crate::types::ColumnDesc]>, bool), DriverError> {
941        #[cfg(feature = "detect-n-plus-one")]
942        self.detector.track(sql_hash);
943        self.sync_conn_mut()?
944            .query_streaming_start(sql, sql_hash, params, chunk_size)
945    }
946
947    /// Send Execute+Flush for a streaming query (2nd+ chunks).
948    pub fn streaming_send_execute(&mut self, chunk_size: i32) -> Result<(), DriverError> {
949        self.sync_conn_mut()?.streaming_send_execute(chunk_size)
950    }
951
952    /// Read the next chunk of rows from an in-progress streaming query.
953    pub fn streaming_next_chunk(
954        &mut self,
955        arena: &mut Arena,
956        all_col_offsets: &mut Vec<(usize, i32)>,
957    ) -> Result<bool, DriverError> {
958        self.sync_conn_mut()?
959            .streaming_next_chunk(arena, all_col_offsets)
960    }
961
962    // --- COPY protocol ---
963
964    /// Bulk copy data INTO a table from an iterator of text rows.
965    ///
966    /// Each row is a tab-separated string (TSV format). Returns the row count.
967    pub fn copy_in<'a, I>(
968        &mut self,
969        table: &str,
970        columns: &[&str],
971        rows: I,
972    ) -> Result<u64, DriverError>
973    where
974        I: IntoIterator<Item = &'a str>,
975    {
976        self.sync_conn_mut()?.copy_in(table, columns, rows)
977    }
978
979    /// Bulk copy data OUT of a table/query to a writer.
980    ///
981    /// Writes TSV-formatted rows. Returns the row count.
982    pub fn copy_out<W: std::io::Write>(
983        &mut self,
984        query: &str,
985        writer: &mut W,
986    ) -> Result<u64, DriverError> {
987        self.sync_conn_mut()?.copy_out(query, writer)
988    }
989
990    /// Whether this guard holds a sync connection.
991    pub fn is_sync(&self) -> bool {
992        matches!(self.conn.as_ref(), Some(PoolSlot::Sync(_)))
993    }
994
995    /// Whether this guard holds an async connection.
996    #[cfg(feature = "async")]
997    pub fn is_async(&self) -> bool {
998        matches!(self.conn.as_ref(), Some(PoolSlot::Async(_)))
999    }
1000
1001    // --- Async query dispatch methods ---
1002
1003    /// Execute a prepared query and return rows (async).
1004    ///
1005    /// Auto-dispatches: sync connections use blocking I/O, async connections
1006    /// use tokio I/O. Returns an error if the guard holds a sync connection
1007    /// and this method is called.
1008    #[cfg(feature = "async")]
1009    pub async fn query_async(
1010        &mut self,
1011        sql: &str,
1012        sql_hash: u64,
1013        params: &[&(dyn Encode + Sync)],
1014    ) -> Result<QueryResult, DriverError> {
1015        #[cfg(feature = "detect-n-plus-one")]
1016        self.detector.track(sql_hash);
1017        match self.conn.as_mut() {
1018            Some(PoolSlot::Sync(conn)) => conn.query(sql, sql_hash, params),
1019            Some(PoolSlot::Async(conn)) => conn.query(sql, sql_hash, params).await,
1020            None => Err(DriverError::Pool("connection already taken".into())),
1021        }
1022    }
1023
1024    /// Execute without result rows (async).
1025    #[cfg(feature = "async")]
1026    pub async fn execute_async(
1027        &mut self,
1028        sql: &str,
1029        sql_hash: u64,
1030        params: &[&(dyn Encode + Sync)],
1031    ) -> Result<u64, DriverError> {
1032        #[cfg(feature = "detect-n-plus-one")]
1033        self.detector.track(sql_hash);
1034        match self.conn.as_mut() {
1035            Some(PoolSlot::Sync(conn)) => conn.execute(sql, sql_hash, params),
1036            Some(PoolSlot::Async(conn)) => conn.execute(sql, sql_hash, params).await,
1037            None => Err(DriverError::Pool("connection already taken".into())),
1038        }
1039    }
1040
1041    /// Execute a simple query (async).
1042    #[cfg(feature = "async")]
1043    pub async fn simple_query_async(&mut self, sql: &str) -> Result<(), DriverError> {
1044        match self.conn.as_mut() {
1045            Some(PoolSlot::Sync(conn)) => conn.simple_query(sql),
1046            Some(PoolSlot::Async(conn)) => conn.simple_query(sql).await,
1047            None => Err(DriverError::Pool("connection already taken".into())),
1048        }
1049    }
1050
1051    // --- Deferred pipeline support ---
1052
1053    /// Ensure a statement is prepared and cached.
1054    pub(crate) fn ensure_stmt_prepared(
1055        &mut self,
1056        sql: &str,
1057        sql_hash: u64,
1058        params: &[&(dyn Encode + Sync)],
1059    ) -> Result<[u8; 18], DriverError> {
1060        self.sync_conn_mut()?
1061            .ensure_stmt_prepared(sql, sql_hash, params)
1062    }
1063
1064    /// Write Bind+Execute bytes for a prepared statement into an external buffer.
1065    pub(crate) fn write_deferred_bind_execute(
1066        &self,
1067        sql: &str,
1068        sql_hash: u64,
1069        params: &[&(dyn Encode + Sync)],
1070        buf: &mut Vec<u8>,
1071    ) {
1072        let conn = self
1073            .sync_conn()
1074            .expect("sync_conn failed in write_deferred");
1075        conn.write_deferred_bind_execute(sql, sql_hash, params, buf);
1076    }
1077
1078    /// Flush a buffer of deferred Bind+Execute messages as a single pipeline.
1079    pub(crate) fn flush_deferred_pipeline(
1080        &mut self,
1081        buf: &mut Vec<u8>,
1082        count: usize,
1083    ) -> Result<Vec<u64>, DriverError> {
1084        self.sync_conn_mut()?.flush_deferred_pipeline(buf, count)
1085    }
1086}
1087
1088impl Drop for PoolGuard {
1089    fn drop(&mut self) {
1090        #[cfg(feature = "detect-n-plus-one")]
1091        self.detector.emit_final_warning();
1092
1093        if let Some(slot) = self.conn.take() {
1094            // Check discard conditions based on slot type.
1095            let should_discard = self.discard
1096                || self.pool.closed.load(Ordering::Acquire)
1097                || match &slot {
1098                    PoolSlot::Sync(conn) => {
1099                        conn.is_in_failed_transaction()
1100                            || conn.is_in_transaction()
1101                            || conn.is_streaming()
1102                    }
1103                    #[cfg(feature = "async")]
1104                    PoolSlot::Async(conn) => {
1105                        conn.is_in_failed_transaction() || conn.is_in_transaction()
1106                    }
1107                };
1108
1109            if should_discard {
1110                self.pool.open_count.fetch_sub(1, Ordering::AcqRel);
1111                return;
1112            }
1113
1114            // Stamp last-used time for idle connection tracking.
1115            // Amortized: only call Instant::now() every 64 returns.
1116            let mut slot = slot;
1117            match &mut slot {
1118                PoolSlot::Sync(conn) => {
1119                    if conn.query_counter() & 63 == 0 {
1120                        conn.touch();
1121                    }
1122                }
1123                #[cfg(feature = "async")]
1124                PoolSlot::Async(conn) => {
1125                    if conn.query_counter() & 63 == 0 {
1126                        conn.touch();
1127                    }
1128                }
1129            }
1130
1131            // Return to pool
1132            {
1133                let mut stack = self.pool.stack.lock().unwrap_or_else(|e| e.into_inner());
1134                stack.push(slot);
1135            }
1136
1137            // Notify waiters only if pool was exhausted (someone might be waiting).
1138            if self.pool.open_count.load(Ordering::Relaxed) >= self.pool.max_size {
1139                let (_, cvar) = &self.pool.release_pair;
1140                cvar.notify_one();
1141            }
1142        }
1143    }
1144}
1145
1146// --- Transaction ---
1147
1148/// A database transaction. Sends ROLLBACK on drop if not committed.
1149///
1150/// # Example
1151///
1152/// ```no_run
1153/// # fn example() -> Result<(), bsql_driver_postgres::DriverError> {
1154/// # let pool = bsql_driver_postgres::Pool::connect("postgres://user:pass@localhost/db")?;
1155/// let mut tx = pool.begin()?;
1156/// tx.simple_query("INSERT INTO t VALUES (1)")?;
1157/// tx.commit()?;
1158/// # Ok(())
1159/// # }
1160/// ```
1161pub struct Transaction {
1162    guard: PoolGuard,
1163    committed: bool,
1164    /// Accumulated Bind+Execute message bytes for deferred operations.
1165    deferred_buf: Vec<u8>,
1166    /// Number of deferred operations buffered.
1167    deferred_count: usize,
1168}
1169
1170impl Transaction {
1171    /// Commit the transaction.
1172    ///
1173    /// Automatically flushes any deferred operations before committing.
1174    pub fn commit(mut self) -> Result<(), DriverError> {
1175        if self.deferred_count > 0 {
1176            self.flush_deferred()?;
1177        }
1178        self.guard.simple_query("COMMIT")?;
1179        self.committed = true;
1180        Ok(())
1181    }
1182
1183    /// Rollback the transaction explicitly.
1184    ///
1185    /// Discards any deferred operations without sending them.
1186    pub fn rollback(mut self) -> Result<(), DriverError> {
1187        self.deferred_buf.clear();
1188        self.deferred_count = 0;
1189        self.guard.simple_query("ROLLBACK")?;
1190        self.committed = true; // prevent double rollback in drop
1191        Ok(())
1192    }
1193
1194    /// Execute a prepared query within the transaction.
1195    ///
1196    /// Automatically flushes any deferred operations before executing the query,
1197    /// ensuring read-your-writes consistency.
1198    pub fn query(
1199        &mut self,
1200        sql: &str,
1201        sql_hash: u64,
1202        params: &[&(dyn Encode + Sync)],
1203    ) -> Result<QueryResult, DriverError> {
1204        if self.deferred_count > 0 {
1205            self.flush_deferred()?;
1206        }
1207        self.guard.query(sql, sql_hash, params)
1208    }
1209
1210    /// Execute without result rows within the transaction.
1211    pub fn execute(
1212        &mut self,
1213        sql: &str,
1214        sql_hash: u64,
1215        params: &[&(dyn Encode + Sync)],
1216    ) -> Result<u64, DriverError> {
1217        self.guard.execute(sql, sql_hash, params)
1218    }
1219
1220    /// Execute the same statement N times with different params in one pipeline.
1221    pub fn execute_pipeline(
1222        &mut self,
1223        sql: &str,
1224        sql_hash: u64,
1225        param_sets: &[&[&(dyn Encode + Sync)]],
1226    ) -> Result<Vec<u64>, DriverError> {
1227        self.guard.execute_pipeline(sql, sql_hash, param_sets)
1228    }
1229
1230    /// Process each row directly from the wire buffer within a transaction.
1231    ///
1232    /// Automatically flushes any deferred operations first.
1233    pub fn for_each<F>(
1234        &mut self,
1235        sql: &str,
1236        sql_hash: u64,
1237        params: &[&(dyn Encode + Sync)],
1238        f: F,
1239    ) -> Result<(), DriverError>
1240    where
1241        F: FnMut(crate::types::PgDataRow<'_>) -> Result<(), DriverError>,
1242    {
1243        if self.deferred_count > 0 {
1244            self.flush_deferred()?;
1245        }
1246        self.guard.for_each(sql, sql_hash, params, f)
1247    }
1248
1249    /// Process each DataRow as raw bytes within a transaction.
1250    ///
1251    /// Automatically flushes any deferred operations first.
1252    pub fn for_each_raw<F>(
1253        &mut self,
1254        sql: &str,
1255        sql_hash: u64,
1256        params: &[&(dyn Encode + Sync)],
1257        f: F,
1258    ) -> Result<(), DriverError>
1259    where
1260        F: FnMut(&[u8]) -> Result<(), DriverError>,
1261    {
1262        if self.deferred_count > 0 {
1263            self.flush_deferred()?;
1264        }
1265        self.guard.for_each_raw(sql, sql_hash, params, f)
1266    }
1267
1268    /// Simple query within the transaction.
1269    ///
1270    /// Automatically flushes any deferred operations first.
1271    pub fn simple_query(&mut self, sql: &str) -> Result<(), DriverError> {
1272        if self.deferred_count > 0 {
1273            self.flush_deferred()?;
1274        }
1275        self.guard.simple_query(sql)
1276    }
1277
1278    // --- Deferred pipeline API ---
1279
1280    /// Buffer an execute for deferred pipeline flush.
1281    ///
1282    /// The operation is not sent to the server immediately. Instead, the
1283    /// Bind+Execute message bytes are buffered internally. The buffered
1284    /// operations are sent as a single pipeline on [`commit()`](Self::commit)
1285    /// or [`flush_deferred()`](Self::flush_deferred).
1286    ///
1287    /// # Example
1288    ///
1289    /// ```no_run
1290    /// # fn example() -> Result<(), bsql_driver_postgres::DriverError> {
1291    /// # let pool = bsql_driver_postgres::Pool::connect("postgres://u:p@localhost/db")?;
1292    /// let mut tx = pool.begin()?;
1293    /// let sql = "INSERT INTO t (v) VALUES ($1)";
1294    /// let hash = bsql_driver_postgres::hash_sql(sql);
1295    ///
1296    /// // These are buffered, not sent:
1297    /// tx.defer_execute(sql, hash, &[&1i32])?;
1298    /// tx.defer_execute(sql, hash, &[&2i32])?;
1299    /// tx.defer_execute(sql, hash, &[&3i32])?;
1300    ///
1301    /// // commit() flushes all 3 as one pipeline + COMMIT = 2 round-trips total
1302    /// tx.commit()?;
1303    /// # Ok(())
1304    /// # }
1305    /// ```
1306    pub fn defer_execute(
1307        &mut self,
1308        sql: &str,
1309        sql_hash: u64,
1310        params: &[&(dyn Encode + Sync)],
1311    ) -> Result<(), DriverError> {
1312        if params.len() > i16::MAX as usize {
1313            return Err(DriverError::Protocol(format!(
1314                "parameter count {} exceeds maximum {}",
1315                params.len(),
1316                i16::MAX
1317            )));
1318        }
1319
1320        // Ensure statement is prepared (may require one round-trip on first call)
1321        self.guard.ensure_stmt_prepared(sql, sql_hash, params)?;
1322
1323        // Buffer the Bind+Execute bytes — no I/O
1324        self.guard
1325            .write_deferred_bind_execute(sql, sql_hash, params, &mut self.deferred_buf);
1326        self.deferred_count += 1;
1327        Ok(())
1328    }
1329
1330    /// Flush all deferred operations as a single pipeline.
1331    ///
1332    /// Sends all buffered Bind+Execute messages + one Sync in a single TCP write.
1333    /// Returns the affected row count for each deferred operation.
1334    pub fn flush_deferred(&mut self) -> Result<Vec<u64>, DriverError> {
1335        let count = self.deferred_count;
1336        self.deferred_count = 0;
1337        self.guard
1338            .flush_deferred_pipeline(&mut self.deferred_buf, count)
1339    }
1340
1341    /// Number of operations currently buffered for deferred execution.
1342    pub fn deferred_count(&self) -> usize {
1343        self.deferred_count
1344    }
1345}
1346
1347impl Drop for Transaction {
1348    fn drop(&mut self) {
1349        if !self.committed {
1350            // Connection is in an uncommitted transaction — discard it from the pool.
1351            // Take the connection out of the guard and drop it, decrementing open_count.
1352            if let Some(_slot) = self.guard.conn.take() {
1353                self.guard.pool.open_count.fetch_sub(1, Ordering::AcqRel);
1354                // Connection dropped — PG server will auto-rollback when it sees disconnect
1355            }
1356        }
1357    }
1358}
1359
1360#[cfg(test)]
1361mod tests {
1362    use super::*;
1363
1364    #[test]
1365    fn pool_builder_requires_url() {
1366        let result = PoolBuilder::new().build();
1367        assert!(result.is_err());
1368    }
1369
1370    #[test]
1371    fn pool_builder_validates_url() {
1372        let result = PoolBuilder::new().url("not_a_url").build();
1373        assert!(result.is_err());
1374    }
1375
1376    #[test]
1377    fn pool_builder_accepts_valid_url() {
1378        let pool = PoolBuilder::new()
1379            .url("postgres://user:pass@localhost/db")
1380            .max_size(5)
1381            .build()
1382            .unwrap();
1383        assert_eq!(pool.max_size(), 5);
1384        assert_eq!(pool.open_count(), 0);
1385    }
1386
1387    #[test]
1388    fn pool_connect_validates_url() {
1389        let result = Pool::connect("not_a_url");
1390        assert!(result.is_err());
1391    }
1392
1393    #[test]
1394    fn pool_max_size_zero() {
1395        let pool = PoolBuilder::new()
1396            .url("postgres://user:pass@localhost/db")
1397            .max_size(0)
1398            .build()
1399            .unwrap();
1400
1401        let result = pool.acquire();
1402        assert!(result.is_err());
1403        match result {
1404            Err(DriverError::Pool(msg)) => assert!(msg.contains("exhausted")),
1405            Err(e) => panic!("expected Pool error, got: {e:?}"),
1406            Ok(_) => panic!("expected error, got Ok"),
1407        }
1408    }
1409
1410    #[test]
1411    fn pool_clone_shares_state() {
1412        let pool = PoolBuilder::new()
1413            .url("postgres://user:pass@localhost/db")
1414            .max_size(5)
1415            .build()
1416            .unwrap();
1417
1418        let pool2 = pool.clone();
1419        assert_eq!(pool.max_size(), pool2.max_size());
1420    }
1421
1422    // --- Audit gap tests ---
1423
1424    // #60: max_lifetime is configurable
1425    #[test]
1426    fn pool_builder_max_lifetime() {
1427        let pool = PoolBuilder::new()
1428            .url("postgres://user:pass@localhost/db")
1429            .max_lifetime(Some(Duration::from_secs(60)))
1430            .build()
1431            .unwrap();
1432        assert_eq!(pool.inner.max_lifetime, Some(Duration::from_secs(60)));
1433    }
1434
1435    // #60: max_lifetime None
1436    #[test]
1437    fn pool_builder_max_lifetime_none() {
1438        let pool = PoolBuilder::new()
1439            .url("postgres://user:pass@localhost/db")
1440            .max_lifetime(None)
1441            .build()
1442            .unwrap();
1443        assert_eq!(pool.inner.max_lifetime, None);
1444    }
1445
1446    // #62: acquire_timeout set to None (fail-fast)
1447    #[test]
1448    fn pool_builder_acquire_timeout_none() {
1449        let pool = PoolBuilder::new()
1450            .url("postgres://user:pass@localhost/db")
1451            .acquire_timeout(None)
1452            .build()
1453            .unwrap();
1454        assert_eq!(pool.inner.acquire_timeout, None);
1455    }
1456
1457    // #62: acquire_timeout custom value
1458    #[test]
1459    fn pool_builder_acquire_timeout_custom() {
1460        let pool = PoolBuilder::new()
1461            .url("postgres://user:pass@localhost/db")
1462            .acquire_timeout(Some(Duration::from_secs(10)))
1463            .build()
1464            .unwrap();
1465        assert_eq!(pool.inner.acquire_timeout, Some(Duration::from_secs(10)));
1466    }
1467
1468    // #63: min_idle setting
1469    #[test]
1470    fn pool_builder_min_idle() {
1471        let pool = PoolBuilder::new()
1472            .url("postgres://user:pass@localhost/db")
1473            .min_idle(2)
1474            .build()
1475            .unwrap();
1476        assert_eq!(pool.inner.min_idle, 2);
1477    }
1478
1479    // #64: Pool close marks pool as closed
1480    #[test]
1481    fn pool_close_marks_closed() {
1482        let pool = PoolBuilder::new()
1483            .url("postgres://user:pass@localhost/db")
1484            .max_size(5)
1485            .build()
1486            .unwrap();
1487
1488        assert!(!pool.is_closed());
1489        pool.close();
1490        assert!(pool.is_closed());
1491
1492        // New acquires should fail
1493        let result = pool.acquire();
1494        assert!(result.is_err());
1495        match result {
1496            Err(DriverError::Pool(msg)) => assert!(msg.contains("closed")),
1497            Err(e) => panic!("expected Pool(closed) error, got: {e:?}"),
1498            Ok(_) => panic!("expected error, got Ok"),
1499        }
1500    }
1501
1502    // #67: PoolStatus idle/active counts
1503    #[test]
1504    fn pool_status_initial() {
1505        let pool = PoolBuilder::new()
1506            .url("postgres://user:pass@localhost/db")
1507            .max_size(10)
1508            .build()
1509            .unwrap();
1510
1511        let status = pool.status();
1512        assert_eq!(status.idle, 0);
1513        assert_eq!(status.active, 0);
1514        assert_eq!(status.open, 0);
1515        assert_eq!(status.max_size, 10);
1516    }
1517
1518    // Default pool builder values
1519    #[test]
1520    fn pool_builder_defaults() {
1521        let pool = PoolBuilder::new()
1522            .url("postgres://user:pass@localhost/db")
1523            .build()
1524            .unwrap();
1525
1526        assert_eq!(pool.max_size(), 10);
1527        assert_eq!(pool.inner.max_lifetime, Some(Duration::from_secs(30 * 60)));
1528        assert_eq!(pool.inner.acquire_timeout, Some(Duration::from_secs(5)));
1529        assert_eq!(pool.inner.min_idle, 0);
1530    }
1531
1532    // Pool open_count starts at 0
1533    #[test]
1534    fn pool_open_count_initial() {
1535        let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1536        assert_eq!(pool.open_count(), 0);
1537    }
1538
1539    // --- Task 7: max_stmt_cache_size ---
1540
1541    #[test]
1542    fn pool_builder_max_stmt_cache_size_default() {
1543        let pool = PoolBuilder::new()
1544            .url("postgres://user:pass@localhost/db")
1545            .build()
1546            .unwrap();
1547        assert_eq!(pool.inner.max_stmt_cache_size, 256);
1548    }
1549
1550    #[test]
1551    fn pool_builder_max_stmt_cache_size_custom() {
1552        let pool = PoolBuilder::new()
1553            .url("postgres://user:pass@localhost/db")
1554            .max_stmt_cache_size(512)
1555            .build()
1556            .unwrap();
1557        assert_eq!(pool.inner.max_stmt_cache_size, 512);
1558    }
1559
1560    // --- Auto-UDS detection tests ---
1561
1562    #[test]
1563    fn pool_is_uds_false_for_tcp() {
1564        let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1565        assert!(!pool.is_uds());
1566    }
1567
1568    #[cfg(unix)]
1569    #[test]
1570    fn pool_is_uds_true_for_unix_socket() {
1571        let pool = Pool::connect("postgres://user@localhost/db?host=/tmp").unwrap();
1572        assert!(pool.is_uds());
1573    }
1574
1575    #[cfg(unix)]
1576    #[test]
1577    fn pool_is_uds_true_for_var_run_socket() {
1578        let pool = Pool::connect("postgres://user@localhost/db?host=/var/run/postgresql").unwrap();
1579        assert!(pool.is_uds());
1580    }
1581
1582    #[test]
1583    fn pool_is_uds_false_for_ip_address() {
1584        let pool = Pool::connect("postgres://user:pass@127.0.0.1/db").unwrap();
1585        assert!(!pool.is_uds());
1586    }
1587
1588    #[cfg(unix)]
1589    #[test]
1590    fn pool_slot_sync_created_for_uds_config() {
1591        let config = Config::from_url("postgres://user@localhost/db?host=/tmp").unwrap();
1592        assert!(config.host_is_uds());
1593    }
1594
1595    #[test]
1596    fn pool_slot_tcp_config() {
1597        let config = Config::from_url("postgres://user:pass@localhost/db").unwrap();
1598        assert!(!config.host_is_uds());
1599    }
1600
1601    // ===============================================================
1602    // Pool::is_uds — extended tests
1603    // ===============================================================
1604
1605    #[test]
1606    fn pool_is_uds_false_for_hostname() {
1607        let pool = Pool::connect("postgres://user:pass@db.example.com/db").unwrap();
1608        assert!(!pool.is_uds());
1609    }
1610
1611    #[cfg(unix)]
1612    #[test]
1613    fn pool_is_uds_true_for_tmp() {
1614        let pool = Pool::connect("postgres://user@localhost/db?host=/tmp").unwrap();
1615        assert!(pool.is_uds());
1616    }
1617
1618    // ===============================================================
1619    // Pool close semantics
1620    // ===============================================================
1621
1622    #[test]
1623    fn pool_close_then_acquire_fails() {
1624        let pool = PoolBuilder::new()
1625            .url("postgres://user:pass@localhost/db")
1626            .max_size(5)
1627            .build()
1628            .unwrap();
1629        pool.close();
1630        let result = pool.acquire();
1631        assert!(result.is_err());
1632        match result {
1633            Err(DriverError::Pool(msg)) => {
1634                assert!(msg.contains("closed"), "should say closed: {msg}")
1635            }
1636            Err(e) => panic!("expected Pool error, got: {e:?}"),
1637            Ok(_) => panic!("expected error"),
1638        }
1639    }
1640
1641    #[test]
1642    fn pool_is_closed_before_and_after() {
1643        let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1644        assert!(!pool.is_closed());
1645        pool.close();
1646        assert!(pool.is_closed());
1647    }
1648
1649    // ===============================================================
1650    // Pool exhaustion (fail-fast without timeout)
1651    // ===============================================================
1652
1653    #[test]
1654    fn pool_exhausted_no_timeout() {
1655        let pool = PoolBuilder::new()
1656            .url("postgres://user:pass@localhost/db")
1657            .max_size(0)
1658            .acquire_timeout(None) // fail-fast
1659            .build()
1660            .unwrap();
1661        let result = pool.acquire();
1662        assert!(result.is_err());
1663        match result {
1664            Err(DriverError::Pool(msg)) => {
1665                assert!(msg.contains("exhausted"), "should say exhausted: {msg}")
1666            }
1667            Err(e) => panic!("expected Pool error, got: {e:?}"),
1668            Ok(_) => panic!("expected error"),
1669        }
1670    }
1671
1672    // ===============================================================
1673    // PoolBuilder validation
1674    // ===============================================================
1675
1676    #[test]
1677    fn pool_builder_no_url_error() {
1678        let result = PoolBuilder::new().max_size(5).build();
1679        assert!(result.is_err());
1680        match result {
1681            Err(DriverError::Pool(msg)) => {
1682                assert!(msg.contains("URL"), "should mention URL: {msg}")
1683            }
1684            Err(e) => panic!("expected Pool error, got: {e:?}"),
1685            Ok(_) => panic!("expected error"),
1686        }
1687    }
1688
1689    #[test]
1690    fn pool_builder_invalid_url_error() {
1691        let result = PoolBuilder::new().url("ftp://something").build();
1692        assert!(result.is_err());
1693    }
1694
1695    #[test]
1696    fn pool_builder_stmt_cache_size_zero() {
1697        let pool = PoolBuilder::new()
1698            .url("postgres://user:pass@localhost/db")
1699            .max_stmt_cache_size(0)
1700            .build()
1701            .unwrap();
1702        assert_eq!(pool.inner.max_stmt_cache_size, 0);
1703    }
1704
1705    // --- Gap: stale_timeout builder config ---
1706
1707    #[test]
1708    fn pool_builder_stale_timeout_default() {
1709        let pool = PoolBuilder::new()
1710            .url("postgres://user:pass@localhost/db")
1711            .build()
1712            .unwrap();
1713        assert_eq!(pool.inner.stale_timeout, Duration::from_secs(30));
1714    }
1715
1716    #[test]
1717    fn pool_builder_stale_timeout_custom() {
1718        let pool = PoolBuilder::new()
1719            .url("postgres://user:pass@localhost/db")
1720            .stale_timeout(Duration::from_secs(60))
1721            .build()
1722            .unwrap();
1723        assert_eq!(pool.inner.stale_timeout, Duration::from_secs(60));
1724    }
1725
1726    #[test]
1727    fn pool_builder_stale_timeout_zero() {
1728        let pool = PoolBuilder::new()
1729            .url("postgres://user:pass@localhost/db")
1730            .stale_timeout(Duration::from_secs(0))
1731            .build()
1732            .unwrap();
1733        assert_eq!(pool.inner.stale_timeout, Duration::from_secs(0));
1734    }
1735
1736    // ===============================================================
1737    // PoolStatus
1738    // ===============================================================
1739
1740    #[test]
1741    fn pool_status_reflects_max_size() {
1742        let pool = PoolBuilder::new()
1743            .url("postgres://user:pass@localhost/db")
1744            .max_size(20)
1745            .build()
1746            .unwrap();
1747        let status = pool.status();
1748        assert_eq!(status.max_size, 20);
1749        assert_eq!(status.idle, 0);
1750        assert_eq!(status.active, 0);
1751        assert_eq!(status.open, 0);
1752    }
1753
1754    // ===============================================================
1755    // Pool clone
1756    // ===============================================================
1757
1758    #[test]
1759    fn pool_clone_shares_config() {
1760        let pool = PoolBuilder::new()
1761            .url("postgres://user:pass@localhost/db")
1762            .max_size(7)
1763            .build()
1764            .unwrap();
1765        let p2 = pool.clone();
1766        assert_eq!(pool.max_size(), 7);
1767        assert_eq!(p2.max_size(), 7);
1768        assert_eq!(pool.open_count(), p2.open_count());
1769    }
1770
1771    // ===============================================================
1772    // set_warmup_sqls
1773    // ===============================================================
1774
1775    #[test]
1776    fn pool_set_warmup_sqls_empty() {
1777        let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1778        pool.set_warmup_sqls(&[]);
1779        let sqls = pool
1780            .inner
1781            .warmup_sqls
1782            .lock()
1783            .unwrap_or_else(|e| e.into_inner())
1784            .clone();
1785        assert!(sqls.is_empty());
1786    }
1787
1788    #[test]
1789    fn pool_set_warmup_sqls_multiple() {
1790        let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1791        pool.set_warmup_sqls(&["SELECT 1", "SELECT 2", "SELECT 3"]);
1792        let sqls = pool
1793            .inner
1794            .warmup_sqls
1795            .lock()
1796            .unwrap_or_else(|e| e.into_inner())
1797            .clone();
1798        assert_eq!(sqls.len(), 3);
1799        assert_eq!(&*sqls[0], "SELECT 1");
1800        assert_eq!(&*sqls[1], "SELECT 2");
1801        assert_eq!(&*sqls[2], "SELECT 3");
1802    }
1803
1804    #[test]
1805    fn pool_set_warmup_sqls_overwrite() {
1806        let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1807        pool.set_warmup_sqls(&["SELECT 1"]);
1808        pool.set_warmup_sqls(&["SELECT 99"]);
1809        let sqls = pool
1810            .inner
1811            .warmup_sqls
1812            .lock()
1813            .unwrap_or_else(|e| e.into_inner())
1814            .clone();
1815        assert_eq!(sqls.len(), 1);
1816        assert_eq!(&*sqls[0], "SELECT 99");
1817    }
1818
1819    // ===============================================================
1820    // PoolStatus Debug
1821    // ===============================================================
1822
1823    #[test]
1824    fn pool_status_debug() {
1825        let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
1826        let status = pool.status();
1827        let dbg = format!("{status:?}");
1828        assert!(dbg.contains("PoolStatus"));
1829        assert!(dbg.contains("idle"));
1830        assert!(dbg.contains("active"));
1831        assert!(dbg.contains("open"));
1832        assert!(dbg.contains("max_size"));
1833    }
1834
1835    // ===============================================================
1836    // Config host_is_uds via pool (structural tests)
1837    // ===============================================================
1838
1839    #[test]
1840    fn config_host_is_uds_returns_true_for_slash() {
1841        let config = Config::from_url("postgres://user@localhost/db?host=/tmp").unwrap();
1842        assert!(config.host_is_uds());
1843    }
1844
1845    #[test]
1846    fn config_host_is_uds_returns_false_for_tcp() {
1847        let config = Config::from_url("postgres://user:pass@localhost/db").unwrap();
1848        assert!(!config.host_is_uds());
1849    }
1850
1851    #[test]
1852    fn config_host_is_uds_returns_false_for_ip() {
1853        let config = Config::from_url("postgres://user:pass@192.168.1.1/db").unwrap();
1854        assert!(!config.host_is_uds());
1855    }
1856
1857    // ===============================================================
1858    // PoolBuilder chaining
1859    // ===============================================================
1860
1861    #[test]
1862    fn pool_builder_full_chain() {
1863        let pool = PoolBuilder::new()
1864            .url("postgres://user:pass@localhost/db")
1865            .max_size(3)
1866            .max_lifetime(Some(Duration::from_secs(600)))
1867            .acquire_timeout(Some(Duration::from_secs(5)))
1868            .min_idle(1)
1869            .max_stmt_cache_size(128)
1870            .build()
1871            .unwrap();
1872        assert_eq!(pool.max_size(), 3);
1873        assert_eq!(pool.inner.max_lifetime, Some(Duration::from_secs(600)));
1874        assert_eq!(pool.inner.acquire_timeout, Some(Duration::from_secs(5)));
1875        assert_eq!(pool.inner.min_idle, 1);
1876        assert_eq!(pool.inner.max_stmt_cache_size, 128);
1877    }
1878
    // --- Audit: a max_size(0) pool rejects every acquire ---
1880
1881    #[test]
1882    fn pool_max_size_zero_rejects_all_acquires() {
1883        let pool = PoolBuilder::new()
1884            .url("postgres://user:pass@localhost/db")
1885            .max_size(0)
1886            .build()
1887            .unwrap();
1888        let result = pool.acquire();
1889        assert!(result.is_err());
1890        match &result {
1891            Err(DriverError::Pool(msg)) => assert!(msg.contains("exhausted")),
1892            _ => panic!("expected pool exhausted error"),
1893        }
1894    }
1895
1896    // --- Audit: URL parsing edge cases ---
1897
1898    #[test]
1899    fn url_parse_unknown_sslmode_returns_error() {
1900        let result = Config::from_url("postgres://u:p@h/d?sslmode=bogus");
1901        assert!(result.is_err());
1902        let msg = format!("{}", result.unwrap_err());
1903        assert!(msg.contains("unknown sslmode"));
1904    }
1905
1906    #[test]
1907    fn url_parse_invalid_port_returns_error() {
1908        let result = Config::from_url("postgres://u:p@h:abc/d");
1909        assert!(result.is_err());
1910        let msg = format!("{}", result.unwrap_err());
1911        assert!(msg.contains("invalid port"));
1912    }
1913
1914    #[test]
1915    fn url_parse_missing_at_sign_returns_error() {
1916        let result = Config::from_url("postgres://u:plocalhost/d");
1917        assert!(result.is_err());
1918        let msg = format!("{}", result.unwrap_err());
1919        assert!(msg.contains("missing @"));
1920    }
1921
1922    #[test]
1923    fn url_parse_empty_host_returns_error() {
1924        let result = Config::from_url("postgres://u:p@/d");
1925        assert!(result.is_err());
1926    }
1927
1928    #[test]
1929    fn url_parse_empty_user_returns_error() {
1930        let result = Config::from_url("postgres://:p@h/d");
1931        assert!(result.is_err());
1932    }
1933
1934    #[test]
1935    fn url_parse_statement_timeout_invalid_uses_default() {
1936        let config = Config::from_url("postgres://u:p@h/d?statement_timeout=notnum").unwrap();
1937        assert_eq!(config.statement_timeout_secs, 30);
1938    }
1939
1940    #[test]
1941    fn url_parse_malformed_percent_encoding() {
1942        let result = Config::from_url("postgres://u%:p@h/d");
1943        assert!(result.is_err());
1944    }
1945
1946    #[test]
1947    fn url_parse_invalid_hex_in_percent_encoding() {
1948        let result = Config::from_url("postgres://u%ZZ:p@h/d");
1949        assert!(result.is_err());
1950    }
1951}
1952
// --- N+1 detector tests ---

#[cfg(all(test, feature = "detect-n-plus-one"))]
mod n_plus_one_tests {
    //! Tests for `NPlusOneDetector` and the `PoolBuilder` wiring of its threshold.
    //!
    //! Contract pinned by these tests: `check_final()` reports `(hash, count)`
    //! only when the current run of identical hashes is strictly longer than
    //! the threshold (`>`, not `>=`) and the hash is non-zero.
    use super::NPlusOneDetector;

    #[test]
    fn below_threshold_no_warning() {
        let mut d = NPlusOneDetector::new(10);
        for _ in 0..10 {
            d.track(42);
        }
        assert!(d.check_final().is_none());
    }

    #[test]
    fn above_threshold_warns() {
        let mut d = NPlusOneDetector::new(10);
        for _ in 0..11 {
            d.track(42);
        }
        let w = d.check_final().unwrap();
        assert_eq!(w, (42, 11));
    }

    #[test]
    fn exact_threshold_no_warning() {
        let mut d = NPlusOneDetector::new(5);
        for _ in 0..5 {
            d.track(99);
        }
        assert!(d.check_final().is_none(), "> not >=");
    }

    #[test]
    fn threshold_plus_one_warns() {
        let mut d = NPlusOneDetector::new(5);
        for _ in 0..6 {
            d.track(99);
        }
        assert_eq!(d.check_final(), Some((99, 6)));
    }

    #[test]
    fn alternating_hashes_no_warning() {
        let mut d = NPlusOneDetector::new(2);
        for i in 0..100 {
            d.track(if i % 2 == 0 { 1 } else { 2 });
        }
        assert!(d.check_final().is_none());
    }

    #[test]
    fn single_query_no_warning() {
        let mut d = NPlusOneDetector::new(10);
        d.track(42);
        assert!(d.check_final().is_none());
    }

    #[test]
    fn no_queries_no_warning() {
        let d = NPlusOneDetector::new(10);
        assert!(d.check_final().is_none());
    }

    /// With threshold 0, a non-zero hash is reported after its very first
    /// query: count=1 already exceeds 0. (Despite the name, no second query
    /// is needed.)
    #[test]
    fn threshold_zero_warns_on_second() {
        let mut d = NPlusOneDetector::new(0);
        d.track(42);
        // count=1, threshold=0 -> 1 > 0 -> warn
        assert_eq!(d.check_final(), Some((42, 1)));
    }

    #[test]
    fn threshold_max_never_warns() {
        let mut d = NPlusOneDetector::new(u16::MAX);
        for _ in 0..1000 {
            d.track(42);
        }
        assert!(d.check_final().is_none());
    }

    #[test]
    fn saturating_add_no_overflow() {
        let mut d = NPlusOneDetector::new(10);
        d.last_query_hash = 42;
        d.repeat_count = u16::MAX - 1;
        d.track(42); // saturating_add -> MAX
        d.track(42); // saturating_add -> still MAX
        assert_eq!(d.repeat_count, u16::MAX);
    }

    #[test]
    fn different_hash_resets() {
        let mut d = NPlusOneDetector::new(100);
        for _ in 0..50 {
            d.track(1);
        }
        d.track(2); // resets
        assert_eq!(d.repeat_count, 1);
        assert_eq!(d.last_query_hash, 2);
    }

    #[test]
    fn multiple_n_plus_one_sequences() {
        let mut d = NPlusOneDetector::new(3);
        // First sequence: hash=1, 5 times (>3). Its warning is emitted on the
        // switch to hash=2 and is not observable through check_final().
        for _ in 0..5 {
            d.track(1);
        }
        // Second sequence: hash=2, 4 times (>3); check_final() catches it.
        for _ in 0..4 {
            d.track(2);
        }
        assert_eq!(d.check_final(), Some((2, 4)));
    }

    #[test]
    fn warning_emitted_on_hash_switch() {
        let mut d = NPlusOneDetector::new(2);
        d.track(10);
        d.track(10);
        d.track(10); // count=3 > 2
        // Switch hash; this internally calls emit_warning for hash=10.
        d.track(20);
        // Now tracking hash=20, count=1
        assert_eq!(d.last_query_hash, 20);
        assert_eq!(d.repeat_count, 1);
    }

    /// Hash 0 is NOT reported: check_final requires a non-zero hash, so a
    /// run of zeros never warns. (Despite the name, hash 0 is special-cased.)
    #[test]
    fn hash_zero_treated_normally() {
        let mut d = NPlusOneDetector::new(2);
        d.track(0);
        d.track(0);
        d.track(0);
        assert!(d.check_final().is_none());
    }

    #[test]
    fn long_sequence_correct_count() {
        let mut d = NPlusOneDetector::new(10);
        for _ in 0..500 {
            d.track(42);
        }
        assert_eq!(d.check_final(), Some((42, 500)));
    }

    #[test]
    fn two_queries_below_threshold() {
        let mut d = NPlusOneDetector::new(10);
        d.track(1);
        d.track(1);
        assert!(d.check_final().is_none());
    }

    #[test]
    fn interleaved_then_burst() {
        let mut d = NPlusOneDetector::new(3);
        // Interleaved: no trigger
        d.track(1);
        d.track(2);
        d.track(1);
        d.track(2);
        // Burst: hash=5, 5 times
        for _ in 0..5 {
            d.track(5);
        }
        assert_eq!(d.check_final(), Some((5, 5)));
    }

    // --- Builder threshold wiring ---

    #[test]
    fn pool_builder_n_plus_one_threshold_default() {
        let pool = super::PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .build()
            .unwrap();
        assert_eq!(pool.inner.n_plus_one_threshold, 10);
    }

    #[test]
    fn pool_builder_n_plus_one_threshold_custom() {
        let pool = super::PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .n_plus_one_threshold(5)
            .build()
            .unwrap();
        assert_eq!(pool.inner.n_plus_one_threshold, 5);
    }

    #[test]
    fn pool_builder_n_plus_one_threshold_zero() {
        let pool = super::PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .n_plus_one_threshold(0)
            .build()
            .unwrap();
        assert_eq!(pool.inner.n_plus_one_threshold, 0);
    }

    #[test]
    fn pool_builder_n_plus_one_threshold_max() {
        let pool = super::PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .n_plus_one_threshold(u16::MAX)
            .build()
            .unwrap();
        assert_eq!(pool.inner.n_plus_one_threshold, u16::MAX);
    }

    #[test]
    fn one_then_different_no_warning() {
        let mut d = NPlusOneDetector::new(10);
        d.track(1);
        d.track(2);
        // hash=1 had count=1 (below 10), hash=2 has count=1 (below 10)
        assert!(d.check_final().is_none());
    }

    #[test]
    fn nonzero_hash_after_zero_init() {
        // First call with nonzero hash: else branch (0 != hash),
        // emit_warning for old (hash=0, count=0) - nothing.
        // Set last=hash, count=1.
        let mut d = NPlusOneDetector::new(0);
        d.track(42);
        let w = d.check_final().unwrap();
        assert_eq!(w, (42, 1));
    }

    #[test]
    fn independent_detectors_dont_interfere() {
        // Each PoolGuard has its own detector -- verify independence
        let mut d1 = NPlusOneDetector::new(5);
        let mut d2 = NPlusOneDetector::new(5);

        // d1 gets N+1 pattern
        for _ in 0..10 {
            d1.track(42);
        }
        // d2 gets different pattern
        d2.track(1);
        d2.track(2);
        d2.track(3);

        // d1 should warn, d2 should not
        assert!(d1.check_final().is_some());
        assert!(d2.check_final().is_none());
    }

    #[test]
    fn rapid_hash_changes_dont_false_positive() {
        // Rapid switching between many different hashes should never trigger
        let mut d = NPlusOneDetector::new(2);
        for i in 0u64..1000 {
            d.track(i);
        }
        // Final hash (999) was only tracked once
        assert!(d.check_final().is_none());
    }

    #[test]
    fn detector_reset_state_after_warning() {
        // After a sequence triggers, the next sequence starts fresh
        let mut d = NPlusOneDetector::new(2);
        d.track(1);
        d.track(1);
        d.track(1); // count=3 > 2, would warn on switch
        d.track(2); // switch triggers warning for hash=1, resets to hash=2, count=1
        d.track(2); // count=2, not > 2
        assert!(d.check_final().is_none()); // hash=2, count=2, not > threshold=2
    }

    #[test]
    fn detector_with_realistic_orm_pattern() {
        // Simulate: fetch users, then for each user fetch orders (N+1)
        let mut d = NPlusOneDetector::new(5);
        d.track(100); // SELECT * FROM users
        // N+1 pattern: same query per user
        for _ in 0..20 {
            d.track(200); // SELECT * FROM orders WHERE user_id = ?
        }
        // Should detect the orders query
        assert_eq!(d.check_final(), Some((200, 20)));
    }

    #[test]
    fn detector_with_legitimate_batch_pattern() {
        // The same prepared statement executed repeatedly with different params.
        // The detector only sees sql_hash, not params, so this run IS flagged
        // as an N+1 once it exceeds the threshold.
        let mut d = NPlusOneDetector::new(10);
        for _ in 0..15 {
            d.track(300); // same sql_hash, different params
        }
        assert!(d.check_final().is_some());
    }

    #[test]
    fn detector_exactly_at_boundaries() {
        for threshold in [0u16, 1, 2, 5, 10, 100] {
            let mut d = NPlusOneDetector::new(threshold);
            for _ in 0..threshold {
                d.track(42);
            }
            // count == threshold: must NOT warn (> not >=)
            assert!(
                d.check_final().is_none(),
                "threshold={threshold} should not warn at count={threshold}"
            );
            d.track(42);
            // count == threshold + 1: must warn
            assert!(
                d.check_final().is_some(),
                "threshold={threshold} should warn at count={}",
                threshold + 1
            );
        }
    }

    #[test]
    fn detector_with_deterministic_random_sequences() {
        // Deterministic pseudo-random sequence: (7i + 3) mod 4 cycles through
        // 3, 2, 1, 0, so no hash ever repeats back-to-back and nothing may warn.
        let mut d = NPlusOneDetector::new(5);
        let hashes: Vec<u64> = (0u64..100).map(|i| (i * 7 + 3) % 4).collect();
        assert!(hashes.windows(2).all(|pair| pair[0] != pair[1]));
        for &h in &hashes {
            d.track(h);
        }
        assert!(d.check_final().is_none());
    }

    mod proptest_fuzz {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #[test]
            fn detector_never_panics(
                hashes in proptest::collection::vec(0u64..100, 0..500),
                threshold in 0u16..100,
            ) {
                let mut d = NPlusOneDetector::new(threshold);
                for h in &hashes {
                    d.track(*h);
                }
                let _ = d.check_final();
            }

            #[test]
            fn sequential_repeats_always_detected(
                hash in 1u64..u64::MAX,
                count in 1u16..1000,
                threshold in 0u16..100,
            ) {
                let mut d = NPlusOneDetector::new(threshold);
                for _ in 0..count {
                    d.track(hash);
                }
                // Warn if and only if the run length strictly exceeds the threshold.
                // prop_assert_eq! (not assert!) lets proptest shrink failures cleanly.
                prop_assert_eq!(
                    d.check_final().is_some(),
                    count > threshold,
                    "count={} threshold={}",
                    count,
                    threshold
                );
            }
        }
    }
}