use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
use crate::arena::Arena;
use crate::codec::Encode;
use crate::conn::Connection;
use crate::types::{Config, PgDataRow, QueryResult, SimpleRow};
use crate::DriverError;
#[cfg(feature = "async")]
use crate::async_conn::AsyncConnection;
/// A connection owned by the pool: either a blocking [`Connection`] or, with
/// the `async` feature enabled, an [`AsyncConnection`].
pub(crate) enum PoolSlot {
/// Blocking connection; every connection opened by `Pool::acquire` is this variant.
Sync(Connection),
/// Async connection; opened by `Pool::acquire_async` when the host is not a Unix socket.
#[cfg(feature = "async")]
Async(AsyncConnection),
}
/// Per-guard tracker of consecutive executions of the same statement hash,
/// used to log a possible N+1 query pattern.
#[cfg(feature = "detect-n-plus-one")]
pub(crate) struct NPlusOneDetector {
/// Hash of the statement in the current run; `0` means nothing tracked yet.
last_query_hash: u64,
/// Number of consecutive executions of `last_query_hash` (saturating).
repeat_count: u16,
/// A run is reported only when `repeat_count` is strictly greater than this.
threshold: u16,
}
#[cfg(feature = "detect-n-plus-one")]
impl NPlusOneDetector {
    /// Creates an empty detector that reports runs of more than `threshold`
    /// consecutive executions of the same statement hash.
    pub(crate) fn new(threshold: u16) -> Self {
        Self {
            threshold,
            last_query_hash: 0,
            repeat_count: 0,
        }
    }

    /// Records one execution of the statement identified by `sql_hash`.
    ///
    /// A different hash ends the current run: the finished run is reported
    /// (if it exceeded the threshold) and a new run of length 1 begins.
    #[inline]
    pub(crate) fn track(&mut self, sql_hash: u64) {
        if sql_hash != self.last_query_hash {
            self.emit_warning();
            self.last_query_hash = sql_hash;
            self.repeat_count = 1;
            return;
        }
        // Saturate instead of wrapping so a very long run still compares as large.
        self.repeat_count = self.repeat_count.saturating_add(1);
    }

    /// Returns `(hash, count)` for the current run when it is longer than the
    /// threshold. Hash `0` is treated as "nothing tracked yet" and never reported.
    pub(crate) fn check_final(&self) -> Option<(u64, u16)> {
        let over_threshold = self.repeat_count > self.threshold;
        let has_statement = self.last_query_hash != 0;
        (over_threshold && has_statement).then_some((self.last_query_hash, self.repeat_count))
    }

    /// Logs the current run via `log::warn!` if [`Self::check_final`] flags it.
    #[cold]
    #[inline(never)]
    fn emit_warning(&self) {
        let Some((hash, count)) = self.check_final() else {
            return;
        };
        log::warn!(
            "[bsql] potential N+1 detected: sql_hash={:#018x} repeated {} times (threshold: {})",
            hash,
            count,
            self.threshold,
        );
    }

    /// Reports the run still in progress; invoked when the owning guard is dropped.
    #[cold]
    #[inline(never)]
    pub(crate) fn emit_final_warning(&self) {
        self.emit_warning();
    }
}
/// Cloneable handle to a shared connection pool.
///
/// All clones share one [`PoolInner`] through an `Arc`, so closing or
/// acquiring through any clone affects every other clone.
pub struct Pool {
/// Shared pool state; also held by every outstanding [`PoolGuard`].
inner: Arc<PoolInner>,
}
/// State shared by every [`Pool`] handle and every [`PoolGuard`].
struct PoolInner {
/// Idle connections; pushed on release and popped on acquire (LIFO).
stack: std::sync::Mutex<Vec<PoolSlot>>,
/// Upper bound on `open_count`.
max_size: usize,
/// Connections counted against `max_size`: idle ones, checked-out ones, and
/// slots reserved while a new connection is being opened.
open_count: AtomicUsize,
/// Connection settings parsed from the pool URL by `PoolBuilder::build`.
config: Arc<Config>,
/// Set by `Pool::close`; acquires then fail and released connections are discarded.
closed: AtomicBool,
/// Mutex/condvar pair that acquirers wait on while the pool is exhausted.
release_pair: (std::sync::Mutex<()>, std::sync::Condvar),
/// Idle connections older than this (measured from creation) are discarded on pop.
max_lifetime: Option<Duration>,
/// How long an acquire may wait when exhausted; `None` fails immediately.
acquire_timeout: Option<Duration>,
/// Idle-connection target for the background thread (`0` = no thread is spawned).
min_idle: usize,
/// Statements prepared on newly opened sync connections; replaced wholesale by
/// `Pool::set_warmup_sqls`. The inner `Arc` lets readers clone it cheaply.
warmup_sqls: std::sync::Mutex<Arc<Vec<Box<str>>>>,
/// Value passed to `set_max_stmt_cache_size` on newly opened connections.
max_stmt_cache_size: usize,
/// Idle connections whose idle duration reaches this are discarded on pop.
stale_timeout: Duration,
/// Run length above which a guard's N+1 detector logs a warning.
#[cfg(feature = "detect-n-plus-one")]
n_plus_one_threshold: u16,
}
impl Pool {
    /// Idle time after which a popped sync connection is pinged before reuse.
    const HEALTH_CHECK_AFTER_IDLE: Duration = Duration::from_secs(5);

    /// Creates a pool for `url` using the default [`PoolBuilder`] settings.
    ///
    /// # Errors
    /// Returns an error if `url` cannot be parsed into a connection config.
    pub fn connect(url: &str) -> Result<Self, DriverError> {
        PoolBuilder::new().url(url).build()
    }

    /// Returns a [`PoolBuilder`] initialised with default settings.
    pub fn builder() -> PoolBuilder {
        PoolBuilder::new()
    }

    /// Checks out a blocking connection.
    ///
    /// An idle connection is reused when one passes the lifetime, staleness
    /// and (after [`Self::HEALTH_CHECK_AFTER_IDLE`]) liveness checks. Otherwise
    /// a new connection is opened if fewer than `max_size` are open. When the
    /// pool is exhausted this fails immediately if `acquire_timeout` is `None`;
    /// otherwise it waits up to `acquire_timeout` in total for a release.
    ///
    /// # Errors
    /// `DriverError::Pool` if the pool is closed or stays exhausted, or any
    /// error from opening a new connection.
    #[inline]
    pub fn acquire(&self) -> Result<PoolGuard, DriverError> {
        if let Some(guard) = self.pop_idle_or_reserve()? {
            return Ok(guard);
        }
        // A slot was reserved for this caller; fill it with a fresh connection.
        let conn = self.connect_reserved_sync()?;
        Ok(self.new_guard(PoolSlot::Sync(conn)))
    }

    /// Hands out an idle connection (`Ok(Some(_))`) or reserves a slot for a
    /// new one (`Ok(None)`, with `open_count` already incremented — the caller
    /// must either fill it or call [`Self::release_reserved_slot`]).
    ///
    /// Waits for releases while exhausted. The timeout is an overall deadline,
    /// so wakeups that lose the race don't restart it. Availability is
    /// re-checked once more after the deadline before giving up.
    fn pop_idle_or_reserve(&self) -> Result<Option<PoolGuard>, DriverError> {
        // `None`: fail fast when exhausted. `Some(None)`: the timeout is too large
        // to add to `Instant::now()`, so wait indefinitely. `Some(Some(t))`: give up at `t`.
        let deadline = self
            .inner
            .acquire_timeout
            .map(|timeout| std::time::Instant::now().checked_add(timeout));
        loop {
            if self.inner.closed.load(Ordering::Acquire) {
                return Err(DriverError::Pool("pool is closed".into()));
            }
            if let Some(guard) = self.try_pop_idle()? {
                return Ok(Some(guard));
            }
            if self.try_reserve_slot() {
                return Ok(None);
            }
            let Some(deadline) = deadline else {
                return Err(DriverError::Pool(
                    "pool exhausted: all connections in use".into(),
                ));
            };
            if matches!(deadline, Some(limit) if std::time::Instant::now() >= limit) {
                return Err(DriverError::Pool(
                    "pool exhausted: acquire timeout expired".into(),
                ));
            }
            self.wait_for_release(deadline);
        }
    }

    /// Atomically increments `open_count` if it is below `max_size`.
    fn try_reserve_slot(&self) -> bool {
        let max = self.inner.max_size;
        self.inner
            .open_count
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |open| {
                (open < max).then_some(open + 1)
            })
            .is_ok()
    }

    /// Cheap availability probe used under the release lock before sleeping.
    ///
    /// Lock order is `release_pair` → `stack`; code holding `stack` must never
    /// take `release_pair`.
    fn slot_may_be_available(&self) -> bool {
        self.inner.open_count.load(Ordering::Acquire) < self.inner.max_size
            || !self
                .inner
                .stack
                .lock()
                .unwrap_or_else(|e| e.into_inner())
                .is_empty()
    }

    /// Blocks until notified, `deadline` passes, or a spurious wakeup occurs.
    ///
    /// State is re-checked while holding the condvar's mutex, so a notifier that
    /// takes that mutex before notifying cannot slip between check and wait.
    fn wait_for_release(&self, deadline: Option<std::time::Instant>) {
        let (lock, cvar) = &self.inner.release_pair;
        let guard = lock.lock().unwrap_or_else(|e| e.into_inner());
        if self.inner.closed.load(Ordering::Acquire) || self.slot_may_be_available() {
            return;
        }
        match deadline {
            Some(limit) => {
                let remaining = limit.saturating_duration_since(std::time::Instant::now());
                if !remaining.is_zero() {
                    drop(
                        cvar.wait_timeout(guard, remaining)
                            .unwrap_or_else(|e| e.into_inner()),
                    );
                }
            }
            None => drop(cvar.wait(guard).unwrap_or_else(|e| e.into_inner())),
        }
    }

    /// Wakes one thread blocked in [`Self::wait_for_release`].
    fn wake_one_waiter(&self) {
        let (lock, cvar) = &self.inner.release_pair;
        // Taking the lock orders this notify after any in-progress re-check.
        drop(lock.lock().unwrap_or_else(|e| e.into_inner()));
        cvar.notify_one();
    }

    /// Gives back a slot reserved by [`Self::pop_idle_or_reserve`] and wakes a waiter.
    fn release_reserved_slot(&self) {
        self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
        self.wake_one_waiter();
    }

    /// Opens, configures and warms a blocking connection for a reserved slot.
    /// On failure the slot is released before the error is returned.
    fn connect_reserved_sync(&self) -> Result<Connection, DriverError> {
        match Connection::connect_arc(self.inner.config.clone()) {
            Ok(mut conn) => {
                conn.set_max_stmt_cache_size(self.inner.max_stmt_cache_size);
                self.warmup_conn(&mut conn);
                Ok(conn)
            }
            Err(e) => {
                self.release_reserved_slot();
                Err(e)
            }
        }
    }

    /// Wraps `slot` in a guard that returns it to this pool on drop.
    fn new_guard(&self, slot: PoolSlot) -> PoolGuard {
        PoolGuard {
            conn: Some(slot),
            pool: self.inner.clone(),
            discard: false,
            #[cfg(feature = "detect-n-plus-one")]
            detector: NPlusOneDetector::new(self.inner.n_plus_one_threshold),
        }
    }

    /// Pops idle connections until one is usable, discarding (and un-counting)
    /// those past `max_lifetime`, idle for `stale_timeout`, or failing the
    /// liveness ping. Returns `Ok(None)` once the idle stack is empty.
    #[inline]
    fn try_pop_idle(&self) -> Result<Option<PoolGuard>, DriverError> {
        loop {
            let (mut slot, needs_health_check) = {
                let mut stack = self.inner.stack.lock().unwrap_or_else(|e| e.into_inner());
                loop {
                    let Some(slot) = stack.pop() else {
                        return Ok(None);
                    };
                    let (created_at, idle_dur) = match &slot {
                        PoolSlot::Sync(conn) => (conn.created_at(), conn.idle_duration()),
                        #[cfg(feature = "async")]
                        PoolSlot::Async(conn) => (conn.created_at(), conn.idle_duration()),
                    };
                    if let Some(max_lifetime) = self.inner.max_lifetime {
                        if created_at.elapsed() >= max_lifetime {
                            self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
                            continue;
                        }
                    }
                    if idle_dur >= self.inner.stale_timeout {
                        self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
                        continue;
                    }
                    break (slot, idle_dur > Self::HEALTH_CHECK_AFTER_IDLE);
                }
            };
            if needs_health_check {
                let alive = match &mut slot {
                    // An empty simple query is used as a round-trip liveness ping.
                    PoolSlot::Sync(conn) => conn.simple_query("").is_ok(),
                    // Async connections are not pinged here: this function is synchronous.
                    #[cfg(feature = "async")]
                    PoolSlot::Async(_) => true,
                };
                if !alive {
                    self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
                    continue;
                }
            }
            return Ok(Some(self.new_guard(slot)));
        }
    }

    /// Whether the configured host is a Unix-domain socket (always `false` off Unix).
    pub fn is_uds(&self) -> bool {
        #[cfg(unix)]
        {
            self.inner.config.host_is_uds()
        }
        #[cfg(not(unix))]
        {
            false
        }
    }

    /// Acquires a connection and sends `BEGIN`.
    ///
    /// # Errors
    /// Propagates acquire errors and any error from `BEGIN`.
    pub fn begin(&self) -> Result<Transaction, DriverError> {
        let mut guard = self.acquire()?;
        guard.simple_query("BEGIN")?;
        Ok(Transaction {
            guard,
            committed: false,
            deferred_buf: Vec::new(),
            deferred_count: 0,
        })
    }

    /// Number of connections currently counted against `max_size` (a snapshot).
    pub fn open_count(&self) -> usize {
        self.inner.open_count.load(Ordering::Relaxed)
    }

    /// Configured maximum number of open connections.
    pub fn max_size(&self) -> usize {
        self.inner.max_size
    }

    /// Snapshot of idle/active/open counts; fields may be mutually inconsistent
    /// under concurrent use because they are read separately.
    pub fn status(&self) -> PoolStatus {
        let idle = self
            .inner
            .stack
            .lock()
            .unwrap_or_else(|e| e.into_inner())
            .len();
        let open = self.inner.open_count.load(Ordering::Relaxed);
        let active = open.saturating_sub(idle);
        PoolStatus {
            idle,
            active,
            open,
            max_size: self.inner.max_size,
        }
    }

    /// Prepares the configured warmup statements on `conn` in one batch.
    fn warmup_conn(&self, conn: &mut Connection) {
        let sqls = self
            .inner
            .warmup_sqls
            .lock()
            .unwrap_or_else(|e| e.into_inner())
            .clone();
        if sqls.is_empty() {
            return;
        }
        let batch: Vec<(&str, u64)> = sqls
            .iter()
            .map(|sql| (sql.as_ref(), crate::types::hash_sql(sql)))
            .collect();
        // Best effort: a failed prepare only loses the warm statement cache.
        let _ = conn.prepare_batch(&batch);
    }

    /// Replaces the warmup statement list used for connections opened from now on.
    pub fn set_warmup_sqls(&self, sqls: &[&str]) {
        let boxed: Arc<Vec<Box<str>>> =
            Arc::new(sqls.iter().map(|s| (*s).into()).collect::<Vec<_>>());
        *self
            .inner
            .warmup_sqls
            .lock()
            .unwrap_or_else(|e| e.into_inner()) = boxed;
    }

    /// Marks the pool closed, closes every idle connection and wakes all waiters.
    ///
    /// Checked-out connections are discarded when their guards drop.
    pub fn close(&self) {
        self.inner.closed.store(true, Ordering::Release);
        let slots: Vec<PoolSlot> = {
            let mut stack = self.inner.stack.lock().unwrap_or_else(|e| e.into_inner());
            std::mem::take(&mut *stack)
        };
        for slot in slots {
            self.inner.open_count.fetch_sub(1, Ordering::AcqRel);
            match slot {
                PoolSlot::Sync(conn) => {
                    let _ = conn.close();
                }
                // Async connections are simply dropped here.
                #[cfg(feature = "async")]
                PoolSlot::Async(_conn) => {}
            }
        }
        let (lock, cvar) = &self.inner.release_pair;
        // Lock before notifying so a waiter mid re-check cannot miss the wakeup.
        drop(lock.lock().unwrap_or_else(|e| e.into_inner()));
        cvar.notify_all();
    }

    /// Whether [`Self::close`] has been called.
    pub fn is_closed(&self) -> bool {
        self.inner.closed.load(Ordering::Acquire)
    }

    /// Checks out a connection, opening an async connection for TCP hosts and
    /// a blocking one for Unix-socket hosts.
    ///
    /// NOTE: while the pool is exhausted this still blocks the calling thread
    /// on a `std::sync::Condvar` (up to `acquire_timeout`), which stalls an
    /// async executor thread.
    ///
    /// If this future is dropped during the async connect, the reserved slot
    /// is released rather than leaked.
    ///
    /// # Errors
    /// Same as [`Self::acquire`].
    #[cfg(feature = "async")]
    pub async fn acquire_async(&self) -> Result<PoolGuard, DriverError> {
        /// Releases the reserved slot if dropped while still armed.
        struct Reservation<'a>(Option<&'a Pool>);
        impl Drop for Reservation<'_> {
            fn drop(&mut self) {
                if let Some(pool) = self.0.take() {
                    pool.release_reserved_slot();
                }
            }
        }

        if let Some(guard) = self.pop_idle_or_reserve()? {
            return Ok(guard);
        }
        if self.inner.config.host_is_uds() {
            let conn = self.connect_reserved_sync()?;
            return Ok(self.new_guard(PoolSlot::Sync(conn)));
        }
        let mut reservation = Reservation(Some(self));
        let mut conn = AsyncConnection::connect_arc(self.inner.config.clone()).await?;
        // Connected: the slot now belongs to the guard being returned.
        reservation.0 = None;
        conn.set_max_stmt_cache_size(self.inner.max_stmt_cache_size);
        Ok(self.new_guard(PoolSlot::Async(conn)))
    }
}
impl Clone for Pool {
fn clone(&self) -> Self {
Pool {
inner: self.inner.clone(),
}
}
}
/// Point-in-time snapshot returned by `Pool::status`.
#[derive(Debug, Clone, Copy)]
pub struct PoolStatus {
/// Connections sitting on the idle stack.
pub idle: usize,
/// `open - idle`, saturating at zero: connections checked out or being opened.
pub active: usize,
/// Connections counted against `max_size`.
pub open: usize,
/// Configured maximum number of open connections.
pub max_size: usize,
}
/// Builder for [`Pool`]; obtain one with `Pool::builder()`.
pub struct PoolBuilder {
/// Connection URL; required by `build`.
url: Option<String>,
/// Maximum open connections (default 10).
max_size: usize,
/// Maximum connection age before discard (default 30 minutes; `None` = unlimited).
max_lifetime: Option<Duration>,
/// Maximum wait when exhausted (default 5 seconds; `None` = fail immediately).
acquire_timeout: Option<Duration>,
/// Idle connections kept open by a background thread (default 0 = no thread).
min_idle: usize,
/// Prepared-statement cache limit per connection (default 256).
max_stmt_cache_size: usize,
/// Idle duration after which a connection is discarded on pop (default 30 seconds).
stale_timeout: Duration,
/// N+1 warning threshold; `build` falls back to 10 when unset.
#[cfg(feature = "detect-n-plus-one")]
n_plus_one_threshold: Option<u16>,
}
impl PoolBuilder {
    /// Creates a builder with the default settings and no URL.
    fn new() -> Self {
        Self {
            url: None,
            max_size: 10,
            max_lifetime: Some(Duration::from_secs(30 * 60)),
            acquire_timeout: Some(Duration::from_secs(5)),
            min_idle: 0,
            max_stmt_cache_size: 256,
            stale_timeout: Duration::from_secs(30),
            #[cfg(feature = "detect-n-plus-one")]
            n_plus_one_threshold: None,
        }
    }

    /// Sets the connection URL (required).
    pub fn url(self, url: &str) -> Self {
        Self {
            url: Some(url.to_owned()),
            ..self
        }
    }

    /// Sets the maximum number of open connections.
    pub fn max_size(self, size: usize) -> Self {
        Self {
            max_size: size,
            ..self
        }
    }

    /// Sets the maximum connection age; `None` disables the age limit.
    pub fn max_lifetime(self, lifetime: Option<Duration>) -> Self {
        Self {
            max_lifetime: lifetime,
            ..self
        }
    }

    /// Sets how long an acquire may wait; `None` fails immediately when exhausted.
    pub fn acquire_timeout(self, timeout: Option<Duration>) -> Self {
        Self {
            acquire_timeout: timeout,
            ..self
        }
    }

    /// Sets how many idle connections a background thread keeps open.
    pub fn min_idle(self, count: usize) -> Self {
        Self {
            min_idle: count,
            ..self
        }
    }

    /// Sets the per-connection prepared-statement cache limit.
    pub fn max_stmt_cache_size(self, size: usize) -> Self {
        Self {
            max_stmt_cache_size: size,
            ..self
        }
    }

    /// Sets the idle duration after which a pooled connection is discarded.
    pub fn stale_timeout(self, timeout: Duration) -> Self {
        Self {
            stale_timeout: timeout,
            ..self
        }
    }

    /// Sets the run length above which a possible N+1 pattern is logged.
    #[cfg(feature = "detect-n-plus-one")]
    pub fn n_plus_one_threshold(self, n: u16) -> Self {
        Self {
            n_plus_one_threshold: Some(n),
            ..self
        }
    }

    /// Parses the URL and creates the pool without opening any connection.
    ///
    /// When `min_idle > 0`, a background thread running `maintain_min_idle`
    /// is spawned.
    ///
    /// # Errors
    /// `DriverError::Pool` if no URL was set, or the URL parse error.
    pub fn build(self) -> Result<Pool, DriverError> {
        let Some(url) = self.url else {
            return Err(DriverError::Pool("pool builder requires a URL".into()));
        };
        let config = Config::from_url(&url)?;
        let inner = Arc::new(PoolInner {
            stack: std::sync::Mutex::new(Vec::with_capacity(self.max_size)),
            max_size: self.max_size,
            open_count: AtomicUsize::new(0),
            config: Arc::new(config),
            closed: AtomicBool::new(false),
            release_pair: (std::sync::Mutex::new(()), std::sync::Condvar::new()),
            max_lifetime: self.max_lifetime,
            acquire_timeout: self.acquire_timeout,
            min_idle: self.min_idle,
            warmup_sqls: std::sync::Mutex::new(Arc::new(Vec::new())),
            max_stmt_cache_size: self.max_stmt_cache_size,
            stale_timeout: self.stale_timeout,
            #[cfg(feature = "detect-n-plus-one")]
            n_plus_one_threshold: self.n_plus_one_threshold.unwrap_or(10),
        });
        if self.min_idle > 0 {
            let background = Arc::clone(&inner);
            std::thread::spawn(move || maintain_min_idle(background));
        }
        Ok(Pool { inner })
    }
}
fn maintain_min_idle(inner: Arc<PoolInner>) {
loop {
if inner.closed.load(Ordering::Acquire) {
return;
}
let idle_count = inner.stack.lock().unwrap_or_else(|e| e.into_inner()).len();
let needed = inner.min_idle.saturating_sub(idle_count);
for _ in 0..needed {
if inner.closed.load(Ordering::Acquire) {
return;
}
let current = inner.open_count.load(Ordering::Acquire);
if current >= inner.max_size {
break;
}
if inner
.open_count
.compare_exchange(current, current + 1, Ordering::AcqRel, Ordering::Acquire)
.is_err()
{
continue;
}
match Connection::connect_arc(inner.config.clone()) {
Ok(conn) => {
let mut stack = inner.stack.lock().unwrap_or_else(|e| e.into_inner());
stack.push(PoolSlot::Sync(conn));
let (_, cvar) = &inner.release_pair;
cvar.notify_one();
}
Err(_) => {
inner.open_count.fetch_sub(1, Ordering::AcqRel);
}
}
}
std::thread::sleep(Duration::from_secs(1));
}
}
/// A checked-out connection; returned to (or discarded from) the pool on drop.
pub struct PoolGuard {
/// The held connection; `None` once it has been taken during teardown.
conn: Option<PoolSlot>,
/// Shared pool state the connection is returned to.
pool: Arc<PoolInner>,
/// When `true`, `Drop` discards the connection instead of pooling it.
discard: bool,
/// Tracks repeated statement hashes for N+1 warnings.
#[cfg(feature = "detect-n-plus-one")]
detector: NPlusOneDetector,
}
impl PoolGuard {
#[inline]
fn sync_conn(&self) -> Result<&Connection, DriverError> {
match self.conn.as_ref() {
Some(PoolSlot::Sync(conn)) => Ok(conn),
#[cfg(feature = "async")]
Some(PoolSlot::Async(_)) => Err(DriverError::Pool(
"expected sync connection, got async; use async methods".into(),
)),
None => Err(DriverError::Pool("connection already taken".into())),
}
}
#[inline]
fn sync_conn_mut(&mut self) -> Result<&mut Connection, DriverError> {
match self.conn.as_mut() {
Some(PoolSlot::Sync(conn)) => Ok(conn),
#[cfg(feature = "async")]
Some(PoolSlot::Async(_)) => Err(DriverError::Pool(
"expected sync connection, got async; use async methods".into(),
)),
None => Err(DriverError::Pool("connection already taken".into())),
}
}
pub fn mark_discard(&mut self) {
self.discard = true;
}
pub fn cancel(&self) -> Result<(), DriverError> {
self.sync_conn()?.cancel()
}
pub fn pid(&self) -> i32 {
match self.conn.as_ref().expect("connection returned to pool") {
PoolSlot::Sync(conn) => conn.pid(),
#[cfg(feature = "async")]
PoolSlot::Async(conn) => conn.pid(),
}
}
pub fn is_idle(&self) -> bool {
match self.conn.as_ref().expect("connection returned to pool") {
PoolSlot::Sync(conn) => conn.is_idle(),
#[cfg(feature = "async")]
PoolSlot::Async(conn) => conn.is_idle(),
}
}
pub fn is_in_transaction(&self) -> bool {
match self.conn.as_ref().expect("connection returned to pool") {
PoolSlot::Sync(conn) => conn.is_in_transaction(),
#[cfg(feature = "async")]
PoolSlot::Async(conn) => conn.is_in_transaction(),
}
}
#[inline]
pub fn query(
&mut self,
sql: &str,
sql_hash: u64,
params: &[&(dyn Encode + Sync)],
) -> Result<QueryResult, DriverError> {
#[cfg(feature = "detect-n-plus-one")]
self.detector.track(sql_hash);
self.sync_conn_mut()?.query(sql, sql_hash, params)
}
#[inline]
pub fn execute(
&mut self,
sql: &str,
sql_hash: u64,
params: &[&(dyn Encode + Sync)],
) -> Result<u64, DriverError> {
#[cfg(feature = "detect-n-plus-one")]
self.detector.track(sql_hash);
self.sync_conn_mut()?.execute(sql, sql_hash, params)
}
pub fn execute_pipeline(
&mut self,
sql: &str,
sql_hash: u64,
param_sets: &[&[&(dyn Encode + Sync)]],
) -> Result<Vec<u64>, DriverError> {
#[cfg(feature = "detect-n-plus-one")]
self.detector.track(sql_hash);
self.sync_conn_mut()?
.execute_pipeline(sql, sql_hash, param_sets)
}
pub fn simple_query(&mut self, sql: &str) -> Result<(), DriverError> {
self.sync_conn_mut()?.simple_query(sql)
}
pub fn simple_query_rows(&mut self, sql: &str) -> Result<Vec<SimpleRow>, DriverError> {
self.sync_conn_mut()?.simple_query_rows(sql)
}
pub fn for_each<F>(
&mut self,
sql: &str,
sql_hash: u64,
params: &[&(dyn Encode + Sync)],
f: F,
) -> Result<(), DriverError>
where
F: FnMut(PgDataRow<'_>) -> Result<(), DriverError>,
{
#[cfg(feature = "detect-n-plus-one")]
self.detector.track(sql_hash);
self.sync_conn_mut()?.for_each(sql, sql_hash, params, f)
}
pub fn for_each_raw<F>(
&mut self,
sql: &str,
sql_hash: u64,
params: &[&(dyn Encode + Sync)],
f: F,
) -> Result<(), DriverError>
where
F: FnMut(&[u8]) -> Result<(), DriverError>,
{
#[cfg(feature = "detect-n-plus-one")]
self.detector.track(sql_hash);
self.sync_conn_mut()?.for_each_raw(sql, sql_hash, params, f)
}
pub fn query_streaming_start(
&mut self,
sql: &str,
sql_hash: u64,
params: &[&(dyn Encode + Sync)],
chunk_size: i32,
) -> Result<(std::sync::Arc<[crate::types::ColumnDesc]>, bool), DriverError> {
#[cfg(feature = "detect-n-plus-one")]
self.detector.track(sql_hash);
self.sync_conn_mut()?
.query_streaming_start(sql, sql_hash, params, chunk_size)
}
pub fn streaming_send_execute(&mut self, chunk_size: i32) -> Result<(), DriverError> {
self.sync_conn_mut()?.streaming_send_execute(chunk_size)
}
pub fn streaming_next_chunk(
&mut self,
arena: &mut Arena,
all_col_offsets: &mut Vec<(usize, i32)>,
) -> Result<bool, DriverError> {
self.sync_conn_mut()?
.streaming_next_chunk(arena, all_col_offsets)
}
pub fn copy_in<'a, I>(
&mut self,
table: &str,
columns: &[&str],
rows: I,
) -> Result<u64, DriverError>
where
I: IntoIterator<Item = &'a str>,
{
self.sync_conn_mut()?.copy_in(table, columns, rows)
}
pub fn copy_out<W: std::io::Write>(
&mut self,
query: &str,
writer: &mut W,
) -> Result<u64, DriverError> {
self.sync_conn_mut()?.copy_out(query, writer)
}
pub fn is_sync(&self) -> bool {
matches!(self.conn.as_ref(), Some(PoolSlot::Sync(_)))
}
#[cfg(feature = "async")]
pub fn is_async(&self) -> bool {
matches!(self.conn.as_ref(), Some(PoolSlot::Async(_)))
}
#[cfg(feature = "async")]
pub async fn query_async(
&mut self,
sql: &str,
sql_hash: u64,
params: &[&(dyn Encode + Sync)],
) -> Result<QueryResult, DriverError> {
#[cfg(feature = "detect-n-plus-one")]
self.detector.track(sql_hash);
match self.conn.as_mut() {
Some(PoolSlot::Sync(conn)) => conn.query(sql, sql_hash, params),
Some(PoolSlot::Async(conn)) => conn.query(sql, sql_hash, params).await,
None => Err(DriverError::Pool("connection already taken".into())),
}
}
#[cfg(feature = "async")]
pub async fn execute_async(
&mut self,
sql: &str,
sql_hash: u64,
params: &[&(dyn Encode + Sync)],
) -> Result<u64, DriverError> {
#[cfg(feature = "detect-n-plus-one")]
self.detector.track(sql_hash);
match self.conn.as_mut() {
Some(PoolSlot::Sync(conn)) => conn.execute(sql, sql_hash, params),
Some(PoolSlot::Async(conn)) => conn.execute(sql, sql_hash, params).await,
None => Err(DriverError::Pool("connection already taken".into())),
}
}
#[cfg(feature = "async")]
pub async fn simple_query_async(&mut self, sql: &str) -> Result<(), DriverError> {
match self.conn.as_mut() {
Some(PoolSlot::Sync(conn)) => conn.simple_query(sql),
Some(PoolSlot::Async(conn)) => conn.simple_query(sql).await,
None => Err(DriverError::Pool("connection already taken".into())),
}
}
pub(crate) fn ensure_stmt_prepared(
&mut self,
sql: &str,
sql_hash: u64,
params: &[&(dyn Encode + Sync)],
) -> Result<[u8; 18], DriverError> {
self.sync_conn_mut()?
.ensure_stmt_prepared(sql, sql_hash, params)
}
pub(crate) fn write_deferred_bind_execute(
&self,
sql: &str,
sql_hash: u64,
params: &[&(dyn Encode + Sync)],
buf: &mut Vec<u8>,
) -> Result<(), DriverError> {
let conn = self.sync_conn()?;
conn.write_deferred_bind_execute(sql, sql_hash, params, buf)
}
pub(crate) fn flush_deferred_pipeline(
&mut self,
buf: &mut Vec<u8>,
count: usize,
) -> Result<Vec<u64>, DriverError> {
self.sync_conn_mut()?.flush_deferred_pipeline(buf, count)
}
}
impl Drop for PoolGuard {
fn drop(&mut self) {
#[cfg(feature = "detect-n-plus-one")]
self.detector.emit_final_warning();
if let Some(slot) = self.conn.take() {
let should_discard = self.discard
|| self.pool.closed.load(Ordering::Acquire)
|| match &slot {
PoolSlot::Sync(conn) => {
conn.is_in_failed_transaction()
|| conn.is_in_transaction()
|| conn.is_streaming()
}
#[cfg(feature = "async")]
PoolSlot::Async(conn) => {
conn.is_in_failed_transaction() || conn.is_in_transaction()
}
};
if should_discard {
self.pool.open_count.fetch_sub(1, Ordering::AcqRel);
return;
}
let mut slot = slot;
match &mut slot {
PoolSlot::Sync(conn) => {
if conn.query_counter() & 63 == 0 {
conn.touch();
}
}
#[cfg(feature = "async")]
PoolSlot::Async(conn) => {
if conn.query_counter() & 63 == 0 {
conn.touch();
}
}
}
{
let mut stack = self.pool.stack.lock().unwrap_or_else(|e| e.into_inner());
stack.push(slot);
}
if self.pool.open_count.load(Ordering::Relaxed) >= self.pool.max_size {
let (_, cvar) = &self.pool.release_pair;
cvar.notify_one();
}
}
}
}
/// A transaction opened with `BEGIN` on a pooled connection.
///
/// If it is dropped without a successful `commit` or `rollback`, the
/// connection is discarded rather than returned to the pool.
pub struct Transaction {
/// The pooled connection the transaction runs on.
guard: PoolGuard,
/// `true` once `COMMIT` or `ROLLBACK` succeeded.
committed: bool,
/// Executions queued by `defer_execute`, sent by `flush_deferred`.
deferred_buf: Vec<u8>,
/// Number of executions currently queued in `deferred_buf`.
deferred_count: usize,
}
impl Transaction {
    /// Flushes deferred statements, then sends `COMMIT`.
    ///
    /// # Errors
    /// Any flush or `COMMIT` error. The transaction then stays uncommitted,
    /// so dropping it discards the connection.
    pub fn commit(mut self) -> Result<(), DriverError> {
        self.flush_pending()?;
        self.guard.simple_query("COMMIT")?;
        self.committed = true;
        Ok(())
    }

    /// Drops any deferred statements without sending them, then sends `ROLLBACK`.
    ///
    /// # Errors
    /// Any `ROLLBACK` error; the connection is then discarded on drop.
    pub fn rollback(mut self) -> Result<(), DriverError> {
        self.deferred_buf.clear();
        self.deferred_count = 0;
        self.guard.simple_query("ROLLBACK")?;
        self.committed = true;
        Ok(())
    }

    /// Runs a query after flushing deferred statements, so it observes them.
    pub fn query(
        &mut self,
        sql: &str,
        sql_hash: u64,
        params: &[&(dyn Encode + Sync)],
    ) -> Result<QueryResult, DriverError> {
        self.flush_pending()?;
        self.guard.query(sql, sql_hash, params)
    }

    /// Executes a statement after flushing deferred statements, preserving the
    /// order in which writes were issued.
    pub fn execute(
        &mut self,
        sql: &str,
        sql_hash: u64,
        params: &[&(dyn Encode + Sync)],
    ) -> Result<u64, DriverError> {
        self.flush_pending()?;
        self.guard.execute(sql, sql_hash, params)
    }

    /// Pipelined execute after flushing deferred statements, preserving write order.
    pub fn execute_pipeline(
        &mut self,
        sql: &str,
        sql_hash: u64,
        param_sets: &[&[&(dyn Encode + Sync)]],
    ) -> Result<Vec<u64>, DriverError> {
        self.flush_pending()?;
        self.guard.execute_pipeline(sql, sql_hash, param_sets)
    }

    /// Calls `f` per decoded row, after flushing deferred statements.
    pub fn for_each<F>(
        &mut self,
        sql: &str,
        sql_hash: u64,
        params: &[&(dyn Encode + Sync)],
        f: F,
    ) -> Result<(), DriverError>
    where
        F: FnMut(crate::types::PgDataRow<'_>) -> Result<(), DriverError>,
    {
        self.flush_pending()?;
        self.guard.for_each(sql, sql_hash, params, f)
    }

    /// Calls `f` per raw row, after flushing deferred statements.
    pub fn for_each_raw<F>(
        &mut self,
        sql: &str,
        sql_hash: u64,
        params: &[&(dyn Encode + Sync)],
        f: F,
    ) -> Result<(), DriverError>
    where
        F: FnMut(&[u8]) -> Result<(), DriverError>,
    {
        self.flush_pending()?;
        self.guard.for_each_raw(sql, sql_hash, params, f)
    }

    /// Simple-query after flushing deferred statements.
    pub fn simple_query(&mut self, sql: &str) -> Result<(), DriverError> {
        self.flush_pending()?;
        self.guard.simple_query(sql)
    }

    /// Prepares `sql` if needed and queues one execution in the deferred
    /// buffer. Nothing is sent until the next flush.
    ///
    /// # Errors
    /// `DriverError::Protocol` if there are more than `i16::MAX` parameters
    /// (the protocol's parameter-count limit), or any prepare/encode error.
    pub fn defer_execute(
        &mut self,
        sql: &str,
        sql_hash: u64,
        params: &[&(dyn Encode + Sync)],
    ) -> Result<(), DriverError> {
        if params.len() > i16::MAX as usize {
            return Err(DriverError::Protocol(format!(
                "parameter count {} exceeds maximum {}",
                params.len(),
                i16::MAX
            )));
        }
        self.guard.ensure_stmt_prepared(sql, sql_hash, params)?;
        self.guard
            .write_deferred_bind_execute(sql, sql_hash, params, &mut self.deferred_buf)?;
        self.deferred_count += 1;
        Ok(())
    }

    /// Sends all deferred executions and returns their affected-row counts.
    ///
    /// The buffer is emptied even on error, keeping it consistent with the
    /// (already reset) count so stale bytes are never resent by a later flush.
    pub fn flush_deferred(&mut self) -> Result<Vec<u64>, DriverError> {
        let count = std::mem::take(&mut self.deferred_count);
        let result = self
            .guard
            .flush_deferred_pipeline(&mut self.deferred_buf, count);
        self.deferred_buf.clear();
        result
    }

    /// Number of executions queued and not yet flushed.
    pub fn deferred_count(&self) -> usize {
        self.deferred_count
    }

    /// Flushes deferred executions if any are queued, discarding their counts.
    fn flush_pending(&mut self) -> Result<(), DriverError> {
        if self.deferred_count > 0 {
            self.flush_deferred()?;
        }
        Ok(())
    }
}
impl Drop for Transaction {
fn drop(&mut self) {
if !self.committed {
if let Some(_slot) = self.guard.conn.take() {
self.guard.pool.open_count.fetch_sub(1, Ordering::AcqRel);
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn pool_builder_requires_url() {
let result = PoolBuilder::new().build();
assert!(result.is_err());
}
#[test]
fn pool_builder_validates_url() {
let result = PoolBuilder::new().url("not_a_url").build();
assert!(result.is_err());
}
#[test]
fn pool_builder_accepts_valid_url() {
let pool = PoolBuilder::new()
.url("postgres://user:pass@localhost/db")
.max_size(5)
.build()
.unwrap();
assert_eq!(pool.max_size(), 5);
assert_eq!(pool.open_count(), 0);
}
#[test]
fn pool_connect_validates_url() {
let result = Pool::connect("not_a_url");
assert!(result.is_err());
}
#[test]
fn pool_max_size_zero() {
let pool = PoolBuilder::new()
.url("postgres://user:pass@localhost/db")
.max_size(0)
.build()
.unwrap();
let result = pool.acquire();
assert!(result.is_err());
match result {
Err(DriverError::Pool(msg)) => assert!(msg.contains("exhausted")),
Err(e) => panic!("expected Pool error, got: {e:?}"),
Ok(_) => panic!("expected error, got Ok"),
}
}
#[test]
fn pool_clone_shares_state() {
let pool = PoolBuilder::new()
.url("postgres://user:pass@localhost/db")
.max_size(5)
.build()
.unwrap();
let pool2 = pool.clone();
assert_eq!(pool.max_size(), pool2.max_size());
}
#[test]
fn pool_builder_max_lifetime() {
let pool = PoolBuilder::new()
.url("postgres://user:pass@localhost/db")
.max_lifetime(Some(Duration::from_secs(60)))
.build()
.unwrap();
assert_eq!(pool.inner.max_lifetime, Some(Duration::from_secs(60)));
}
#[test]
fn pool_builder_max_lifetime_none() {
let pool = PoolBuilder::new()
.url("postgres://user:pass@localhost/db")
.max_lifetime(None)
.build()
.unwrap();
assert_eq!(pool.inner.max_lifetime, None);
}
#[test]
fn pool_builder_acquire_timeout_none() {
let pool = PoolBuilder::new()
.url("postgres://user:pass@localhost/db")
.acquire_timeout(None)
.build()
.unwrap();
assert_eq!(pool.inner.acquire_timeout, None);
}
#[test]
fn pool_builder_acquire_timeout_custom() {
let pool = PoolBuilder::new()
.url("postgres://user:pass@localhost/db")
.acquire_timeout(Some(Duration::from_secs(10)))
.build()
.unwrap();
assert_eq!(pool.inner.acquire_timeout, Some(Duration::from_secs(10)));
}
#[test]
fn pool_builder_min_idle() {
let pool = PoolBuilder::new()
.url("postgres://user:pass@localhost/db")
.min_idle(2)
.build()
.unwrap();
assert_eq!(pool.inner.min_idle, 2);
}
#[test]
fn pool_close_marks_closed() {
let pool = PoolBuilder::new()
.url("postgres://user:pass@localhost/db")
.max_size(5)
.build()
.unwrap();
assert!(!pool.is_closed());
pool.close();
assert!(pool.is_closed());
let result = pool.acquire();
assert!(result.is_err());
match result {
Err(DriverError::Pool(msg)) => assert!(msg.contains("closed")),
Err(e) => panic!("expected Pool(closed) error, got: {e:?}"),
Ok(_) => panic!("expected error, got Ok"),
}
}
#[test]
fn pool_status_initial() {
let pool = PoolBuilder::new()
.url("postgres://user:pass@localhost/db")
.max_size(10)
.build()
.unwrap();
let status = pool.status();
assert_eq!(status.idle, 0);
assert_eq!(status.active, 0);
assert_eq!(status.open, 0);
assert_eq!(status.max_size, 10);
}
#[test]
fn pool_builder_defaults() {
let pool = PoolBuilder::new()
.url("postgres://user:pass@localhost/db")
.build()
.unwrap();
assert_eq!(pool.max_size(), 10);
assert_eq!(pool.inner.max_lifetime, Some(Duration::from_secs(30 * 60)));
assert_eq!(pool.inner.acquire_timeout, Some(Duration::from_secs(5)));
assert_eq!(pool.inner.min_idle, 0);
}
#[test]
fn pool_open_count_initial() {
let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
assert_eq!(pool.open_count(), 0);
}
#[test]
fn pool_builder_max_stmt_cache_size_default() {
let pool = PoolBuilder::new()
.url("postgres://user:pass@localhost/db")
.build()
.unwrap();
assert_eq!(pool.inner.max_stmt_cache_size, 256);
}
#[test]
fn pool_builder_max_stmt_cache_size_custom() {
let pool = PoolBuilder::new()
.url("postgres://user:pass@localhost/db")
.max_stmt_cache_size(512)
.build()
.unwrap();
assert_eq!(pool.inner.max_stmt_cache_size, 512);
}
#[test]
fn pool_is_uds_false_for_tcp() {
let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
assert!(!pool.is_uds());
}
#[cfg(unix)]
#[test]
fn pool_is_uds_true_for_unix_socket() {
let pool = Pool::connect("postgres://user@localhost/db?host=/tmp").unwrap();
assert!(pool.is_uds());
}
#[cfg(unix)]
#[test]
fn pool_is_uds_true_for_var_run_socket() {
let pool = Pool::connect("postgres://user@localhost/db?host=/var/run/postgresql").unwrap();
assert!(pool.is_uds());
}
#[test]
fn pool_is_uds_false_for_ip_address() {
let pool = Pool::connect("postgres://user:pass@127.0.0.1/db").unwrap();
assert!(!pool.is_uds());
}
#[cfg(unix)]
#[test]
fn pool_slot_sync_created_for_uds_config() {
let config = Config::from_url("postgres://user@localhost/db?host=/tmp").unwrap();
assert!(config.host_is_uds());
}
#[test]
fn pool_slot_tcp_config() {
let config = Config::from_url("postgres://user:pass@localhost/db").unwrap();
assert!(!config.host_is_uds());
}
#[test]
fn pool_is_uds_false_for_hostname() {
let pool = Pool::connect("postgres://user:pass@db.example.com/db").unwrap();
assert!(!pool.is_uds());
}
#[cfg(unix)]
#[test]
fn pool_is_uds_true_for_tmp() {
let pool = Pool::connect("postgres://user@localhost/db?host=/tmp").unwrap();
assert!(pool.is_uds());
}
#[test]
fn pool_close_then_acquire_fails() {
let pool = PoolBuilder::new()
.url("postgres://user:pass@localhost/db")
.max_size(5)
.build()
.unwrap();
pool.close();
let result = pool.acquire();
assert!(result.is_err());
match result {
Err(DriverError::Pool(msg)) => {
assert!(msg.contains("closed"), "should say closed: {msg}")
}
Err(e) => panic!("expected Pool error, got: {e:?}"),
Ok(_) => panic!("expected error"),
}
}
#[test]
fn pool_is_closed_before_and_after() {
let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
assert!(!pool.is_closed());
pool.close();
assert!(pool.is_closed());
}
#[test]
fn pool_exhausted_no_timeout() {
let pool = PoolBuilder::new()
.url("postgres://user:pass@localhost/db")
.max_size(0)
.acquire_timeout(None) .build()
.unwrap();
let result = pool.acquire();
assert!(result.is_err());
match result {
Err(DriverError::Pool(msg)) => {
assert!(msg.contains("exhausted"), "should say exhausted: {msg}")
}
Err(e) => panic!("expected Pool error, got: {e:?}"),
Ok(_) => panic!("expected error"),
}
}
#[test]
fn pool_builder_no_url_error() {
let result = PoolBuilder::new().max_size(5).build();
assert!(result.is_err());
match result {
Err(DriverError::Pool(msg)) => {
assert!(msg.contains("URL"), "should mention URL: {msg}")
}
Err(e) => panic!("expected Pool error, got: {e:?}"),
Ok(_) => panic!("expected error"),
}
}
#[test]
fn pool_builder_invalid_url_error() {
let result = PoolBuilder::new().url("ftp://something").build();
assert!(result.is_err());
}
#[test]
fn pool_builder_stmt_cache_size_zero() {
let pool = PoolBuilder::new()
.url("postgres://user:pass@localhost/db")
.max_stmt_cache_size(0)
.build()
.unwrap();
assert_eq!(pool.inner.max_stmt_cache_size, 0);
}
#[test]
fn pool_builder_stale_timeout_default() {
let pool = PoolBuilder::new()
.url("postgres://user:pass@localhost/db")
.build()
.unwrap();
assert_eq!(pool.inner.stale_timeout, Duration::from_secs(30));
}
#[test]
fn pool_builder_stale_timeout_custom() {
let pool = PoolBuilder::new()
.url("postgres://user:pass@localhost/db")
.stale_timeout(Duration::from_secs(60))
.build()
.unwrap();
assert_eq!(pool.inner.stale_timeout, Duration::from_secs(60));
}
#[test]
fn pool_builder_stale_timeout_zero() {
let pool = PoolBuilder::new()
.url("postgres://user:pass@localhost/db")
.stale_timeout(Duration::from_secs(0))
.build()
.unwrap();
assert_eq!(pool.inner.stale_timeout, Duration::from_secs(0));
}
#[test]
fn pool_status_reflects_max_size() {
let pool = PoolBuilder::new()
.url("postgres://user:pass@localhost/db")
.max_size(20)
.build()
.unwrap();
let status = pool.status();
assert_eq!(status.max_size, 20);
assert_eq!(status.idle, 0);
assert_eq!(status.active, 0);
assert_eq!(status.open, 0);
}
#[test]
fn pool_clone_shares_config() {
let pool = PoolBuilder::new()
.url("postgres://user:pass@localhost/db")
.max_size(7)
.build()
.unwrap();
let p2 = pool.clone();
assert_eq!(pool.max_size(), 7);
assert_eq!(p2.max_size(), 7);
assert_eq!(pool.open_count(), p2.open_count());
}
#[test]
fn pool_set_warmup_sqls_empty() {
let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
pool.set_warmup_sqls(&[]);
let sqls = pool
.inner
.warmup_sqls
.lock()
.unwrap_or_else(|e| e.into_inner())
.clone();
assert!(sqls.is_empty());
}
#[test]
fn pool_set_warmup_sqls_multiple() {
let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
pool.set_warmup_sqls(&["SELECT 1", "SELECT 2", "SELECT 3"]);
let sqls = pool
.inner
.warmup_sqls
.lock()
.unwrap_or_else(|e| e.into_inner())
.clone();
assert_eq!(sqls.len(), 3);
assert_eq!(&*sqls[0], "SELECT 1");
assert_eq!(&*sqls[1], "SELECT 2");
assert_eq!(&*sqls[2], "SELECT 3");
}
#[test]
fn pool_set_warmup_sqls_overwrite() {
let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
pool.set_warmup_sqls(&["SELECT 1"]);
pool.set_warmup_sqls(&["SELECT 99"]);
let sqls = pool
.inner
.warmup_sqls
.lock()
.unwrap_or_else(|e| e.into_inner())
.clone();
assert_eq!(sqls.len(), 1);
assert_eq!(&*sqls[0], "SELECT 99");
}
#[test]
fn pool_status_debug() {
let pool = Pool::connect("postgres://user:pass@localhost/db").unwrap();
let status = pool.status();
let dbg = format!("{status:?}");
assert!(dbg.contains("PoolStatus"));
assert!(dbg.contains("idle"));
assert!(dbg.contains("active"));
assert!(dbg.contains("open"));
assert!(dbg.contains("max_size"));
}
// A `host=` query parameter holding a filesystem path selects a Unix socket.
#[test]
fn config_host_is_uds_returns_true_for_slash() {
    assert!(Config::from_url("postgres://user@localhost/db?host=/tmp")
        .unwrap()
        .host_is_uds());
}
// A plain hostname is a TCP target, not a Unix socket.
#[test]
fn config_host_is_uds_returns_false_for_tcp() {
    let is_uds = Config::from_url("postgres://user:pass@localhost/db")
        .unwrap()
        .host_is_uds();
    assert!(!is_uds);
}
// An IPv4 literal host is a TCP target, not a Unix socket.
#[test]
fn config_host_is_uds_returns_false_for_ip() {
    let is_uds = Config::from_url("postgres://user:pass@192.168.1.1/db")
        .unwrap()
        .host_is_uds();
    assert!(!is_uds);
}
// Every builder setter in the chain lands in the built pool's inner state.
#[test]
fn pool_builder_full_chain() {
    let lifetime = Duration::from_secs(600);
    let timeout = Duration::from_secs(5);
    let pool = PoolBuilder::new()
        .url("postgres://user:pass@localhost/db")
        .max_size(3)
        .max_lifetime(Some(lifetime))
        .acquire_timeout(Some(timeout))
        .min_idle(1)
        .max_stmt_cache_size(128)
        .build()
        .unwrap();
    let inner = &*pool.inner;
    assert_eq!(pool.max_size(), 3);
    assert_eq!(inner.max_lifetime, Some(lifetime));
    assert_eq!(inner.acquire_timeout, Some(timeout));
    assert_eq!(inner.min_idle, 1);
    assert_eq!(inner.max_stmt_cache_size, 128);
}
// Acquiring from a pool built with `max_size(0)` must fail with a
// `DriverError::Pool` whose message mentions "exhausted".
#[test]
fn pool_max_size_zero_rejects_all_acquires() {
    let pool = PoolBuilder::new()
        .url("postgres://user:pass@localhost/db")
        .max_size(0)
        .build()
        .unwrap();
    let result = pool.acquire();
    // Each failure arm reports what actually happened, instead of one generic
    // panic that hides a wrong error variant or an unexpected success.
    match &result {
        Err(DriverError::Pool(msg)) => assert!(
            msg.contains("exhausted"),
            "unexpected pool error message: {msg}"
        ),
        Err(other) => panic!("expected pool exhausted error, got {other:?}"),
        Ok(_) => panic!("expected pool exhausted error, got a connection"),
    }
}
// An unrecognised `sslmode` value is rejected with a descriptive error.
#[test]
fn url_parse_unknown_sslmode_returns_error() {
    // `unwrap_err` already fails the test if parsing unexpectedly succeeds.
    let msg = Config::from_url("postgres://u:p@h/d?sslmode=bogus")
        .unwrap_err()
        .to_string();
    assert!(
        msg.contains("unknown sslmode"),
        "unexpected error message: {msg}"
    );
}
// A non-numeric port is rejected with a descriptive error.
#[test]
fn url_parse_invalid_port_returns_error() {
    // `unwrap_err` already fails the test if parsing unexpectedly succeeds.
    let msg = Config::from_url("postgres://u:p@h:abc/d")
        .unwrap_err()
        .to_string();
    assert!(
        msg.contains("invalid port"),
        "unexpected error message: {msg}"
    );
}
// A URL without the `@` separating credentials from host is rejected.
#[test]
fn url_parse_missing_at_sign_returns_error() {
    // `unwrap_err` already fails the test if parsing unexpectedly succeeds.
    let msg = Config::from_url("postgres://u:plocalhost/d")
        .unwrap_err()
        .to_string();
    assert!(msg.contains("missing @"), "unexpected error message: {msg}");
}
// An empty host segment after `@` is a parse error.
#[test]
fn url_parse_empty_host_returns_error() {
    assert!(Config::from_url("postgres://u:p@/d").is_err());
}
// An empty user name before `:` is a parse error.
#[test]
fn url_parse_empty_user_returns_error() {
    assert!(Config::from_url("postgres://:p@h/d").is_err());
}
// A non-numeric `statement_timeout` parameter does not fail parsing; the
// value 30 is used instead.
#[test]
fn url_parse_statement_timeout_invalid_uses_default() {
    let timeout_secs = Config::from_url("postgres://u:p@h/d?statement_timeout=notnum")
        .unwrap()
        .statement_timeout_secs;
    assert_eq!(timeout_secs, 30);
}
// A `%` with no following hex digits is a parse error.
#[test]
fn url_parse_malformed_percent_encoding() {
    assert!(Config::from_url("postgres://u%:p@h/d").is_err());
}
// A `%` followed by non-hex characters is a parse error.
#[test]
fn url_parse_invalid_hex_in_percent_encoding() {
    assert!(Config::from_url("postgres://u%ZZ:p@h/d").is_err());
}
}
// Tests for `NPlusOneDetector`.
//
// The detector only counts *consecutive* repeats of the same SQL hash. A run
// is reported when its length strictly exceeds the threshold (`check_final`
// uses `>`), either when the hash changes (`track` -> `emit_warning`) or at
// the end (`check_final` / `emit_final_warning`). Hash 0 doubles as the
// "nothing tracked yet" sentinel, so runs of hash 0 are never reported.
#[cfg(all(test, feature = "detect-n-plus-one"))]
mod n_plus_one_tests {
    use super::NPlusOneDetector;

    #[test]
    fn below_threshold_no_warning() {
        let mut d = NPlusOneDetector::new(10);
        for _ in 0..10 {
            d.track(42);
        }
        assert!(d.check_final().is_none());
    }

    #[test]
    fn above_threshold_warns() {
        let mut d = NPlusOneDetector::new(10);
        for _ in 0..11 {
            d.track(42);
        }
        assert_eq!(d.check_final(), Some((42, 11)));
    }

    #[test]
    fn exact_threshold_no_warning() {
        let mut d = NPlusOneDetector::new(5);
        for _ in 0..5 {
            d.track(99);
        }
        assert!(d.check_final().is_none(), "> not >=");
    }

    #[test]
    fn threshold_plus_one_warns() {
        let mut d = NPlusOneDetector::new(5);
        for _ in 0..6 {
            d.track(99);
        }
        assert_eq!(d.check_final(), Some((99, 6)));
    }

    #[test]
    fn alternating_hashes_no_warning() {
        let mut d = NPlusOneDetector::new(2);
        for i in 0..100 {
            d.track(if i % 2 == 0 { 1 } else { 2 });
        }
        assert!(d.check_final().is_none());
    }

    #[test]
    fn single_query_no_warning() {
        let mut d = NPlusOneDetector::new(10);
        d.track(42);
        assert!(d.check_final().is_none());
    }

    #[test]
    fn no_queries_no_warning() {
        let d = NPlusOneDetector::new(10);
        assert!(d.check_final().is_none());
    }

    // With threshold 0 any run of length >= 1 exceeds the threshold, so the
    // very first query is already reported.
    #[test]
    fn threshold_zero_warns_on_first_query() {
        let mut d = NPlusOneDetector::new(0);
        d.track(42);
        assert_eq!(d.check_final(), Some((42, 1)));
    }

    #[test]
    fn threshold_max_never_warns() {
        let mut d = NPlusOneDetector::new(u16::MAX);
        for _ in 0..1000 {
            d.track(42);
        }
        assert!(d.check_final().is_none());
    }

    #[test]
    fn saturating_add_no_overflow() {
        let mut d = NPlusOneDetector::new(10);
        d.last_query_hash = 42;
        d.repeat_count = u16::MAX - 1;
        d.track(42);
        assert_eq!(d.repeat_count, u16::MAX);
        // One more repeat must saturate instead of wrapping back to 0.
        d.track(42);
        assert_eq!(d.repeat_count, u16::MAX);
    }

    #[test]
    fn different_hash_resets() {
        let mut d = NPlusOneDetector::new(100);
        for _ in 0..50 {
            d.track(1);
        }
        d.track(2);
        assert_eq!(d.repeat_count, 1);
        assert_eq!(d.last_query_hash, 2);
    }

    #[test]
    fn multiple_n_plus_one_sequences() {
        let mut d = NPlusOneDetector::new(3);
        for _ in 0..5 {
            d.track(1);
        }
        for _ in 0..4 {
            d.track(2);
        }
        assert_eq!(d.check_final(), Some((2, 4)));
    }

    #[test]
    fn warning_emitted_on_hash_switch() {
        let mut d = NPlusOneDetector::new(2);
        for _ in 0..3 {
            d.track(10);
        }
        // `emit_warning` logs exactly when `check_final` is `Some`. The log
        // output is not captured here, so assert that condition holds right
        // before the hash switch that triggers it.
        assert_eq!(d.check_final(), Some((10, 3)));
        d.track(20);
        assert_eq!(d.last_query_hash, 20);
        assert_eq!(d.repeat_count, 1);
        assert!(d.check_final().is_none());
    }

    // Hash 0 collides with the detector's initial `last_query_hash` sentinel,
    // and `check_final` requires `last_query_hash != 0`. A run of hash 0 is
    // therefore counted but never reported, even above the threshold.
    // NOTE(review): blind spot in the detector; pinned here so any change is
    // deliberate.
    #[test]
    fn hash_zero_is_never_reported() {
        let mut d = NPlusOneDetector::new(2);
        for _ in 0..3 {
            d.track(0);
        }
        assert_eq!(d.repeat_count, 3);
        assert!(d.check_final().is_none());
    }

    #[test]
    fn long_sequence_correct_count() {
        let mut d = NPlusOneDetector::new(10);
        for _ in 0..500 {
            d.track(42);
        }
        assert_eq!(d.check_final(), Some((42, 500)));
    }

    #[test]
    fn two_queries_below_threshold() {
        let mut d = NPlusOneDetector::new(10);
        d.track(1);
        d.track(1);
        assert!(d.check_final().is_none());
    }

    #[test]
    fn interleaved_then_burst() {
        let mut d = NPlusOneDetector::new(3);
        d.track(1);
        d.track(2);
        d.track(1);
        d.track(2);
        for _ in 0..5 {
            d.track(5);
        }
        assert_eq!(d.check_final(), Some((5, 5)));
    }

    #[test]
    fn pool_builder_n_plus_one_threshold_default() {
        let pool = super::PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .build()
            .unwrap();
        assert_eq!(pool.inner.n_plus_one_threshold, 10);
    }

    #[test]
    fn pool_builder_n_plus_one_threshold_custom() {
        let pool = super::PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .n_plus_one_threshold(5)
            .build()
            .unwrap();
        assert_eq!(pool.inner.n_plus_one_threshold, 5);
    }

    #[test]
    fn pool_builder_n_plus_one_threshold_zero() {
        let pool = super::PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .n_plus_one_threshold(0)
            .build()
            .unwrap();
        assert_eq!(pool.inner.n_plus_one_threshold, 0);
    }

    #[test]
    fn pool_builder_n_plus_one_threshold_max() {
        let pool = super::PoolBuilder::new()
            .url("postgres://user:pass@localhost/db")
            .n_plus_one_threshold(u16::MAX)
            .build()
            .unwrap();
        assert_eq!(pool.inner.n_plus_one_threshold, u16::MAX);
    }

    #[test]
    fn one_then_different_no_warning() {
        let mut d = NPlusOneDetector::new(10);
        d.track(1);
        d.track(2);
        assert!(d.check_final().is_none());
    }

    #[test]
    fn nonzero_hash_after_zero_init() {
        let mut d = NPlusOneDetector::new(0);
        d.track(42);
        assert_eq!(d.check_final(), Some((42, 1)));
    }

    #[test]
    fn independent_detectors_dont_interfere() {
        let mut d1 = NPlusOneDetector::new(5);
        let mut d2 = NPlusOneDetector::new(5);
        for _ in 0..10 {
            d1.track(42);
        }
        d2.track(1);
        d2.track(2);
        d2.track(3);
        assert_eq!(d1.check_final(), Some((42, 10)));
        assert!(d2.check_final().is_none());
    }

    #[test]
    fn rapid_hash_changes_dont_false_positive() {
        let mut d = NPlusOneDetector::new(2);
        for i in 0u64..1000 {
            d.track(i);
        }
        assert!(d.check_final().is_none());
    }

    #[test]
    fn detector_reset_state_after_warning() {
        let mut d = NPlusOneDetector::new(2);
        for _ in 0..3 {
            d.track(1);
        }
        assert_eq!(d.check_final(), Some((1, 3)));
        // Switching hashes reports the finished run and starts a fresh count.
        d.track(2);
        d.track(2);
        assert_eq!(d.repeat_count, 2);
        assert!(d.check_final().is_none());
    }

    #[test]
    fn detector_with_realistic_orm_pattern() {
        let mut d = NPlusOneDetector::new(5);
        // One parent query, then the same child query once per parent row.
        d.track(100);
        for _ in 0..20 {
            d.track(200);
        }
        assert_eq!(d.check_final(), Some((200, 20)));
    }

    // The detector is purely count-based and cannot tell an intentional batch
    // loop from an N+1 pattern. A legitimate batch above the threshold is
    // still flagged.
    #[test]
    fn detector_with_legitimate_batch_pattern() {
        let mut d = NPlusOneDetector::new(10);
        for _ in 0..15 {
            d.track(300);
        }
        assert_eq!(d.check_final(), Some((300, 15)));
    }

    #[test]
    fn detector_exactly_at_boundaries() {
        for threshold in [0u16, 1, 2, 5, 10, 100] {
            let mut d = NPlusOneDetector::new(threshold);
            for _ in 0..threshold {
                d.track(42);
            }
            assert!(
                d.check_final().is_none(),
                "threshold={threshold} must not warn at count={threshold}"
            );
            d.track(42);
            assert_eq!(
                d.check_final(),
                Some((42, threshold + 1)),
                "threshold={threshold} should warn at count={}",
                threshold + 1
            );
        }
    }

    #[test]
    fn detector_with_deterministic_random_sequences() {
        let mut d = NPlusOneDetector::new(5);
        // (7i + 3) mod 4 cycles 3, 2, 1, 0, so no hash repeats back-to-back
        // and every run has length 1.
        for h in (0u64..100).map(|i| (i * 7 + 3) % 4) {
            d.track(h);
        }
        assert_eq!(d.repeat_count, 1);
        assert!(d.check_final().is_none());
    }

    mod proptest_fuzz {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            // Model check: after any sequence, `check_final` reports exactly
            // the trailing run of identical hashes, and only when that run
            // exceeds the threshold and its hash is not the 0 sentinel.
            #[test]
            fn detector_never_panics(
                hashes in proptest::collection::vec(0u64..100, 0..500),
                threshold in 0u16..100,
            ) {
                let mut d = NPlusOneDetector::new(threshold);
                for h in &hashes {
                    d.track(*h);
                }
                let expected = hashes.last().and_then(|&last| {
                    // The vec has fewer than 500 elements, so the run fits in u16.
                    let run = hashes.iter().rev().take_while(|&&h| h == last).count() as u16;
                    if last != 0 && run > threshold {
                        Some((last, run))
                    } else {
                        None
                    }
                });
                prop_assert_eq!(d.check_final(), expected);
            }

            // A single run of `count` identical non-zero hashes is reported
            // with its exact count iff `count > threshold`.
            #[test]
            fn sequential_repeats_always_detected(
                hash in 1u64..u64::MAX,
                count in 1u16..1000,
                threshold in 0u16..100,
            ) {
                let mut d = NPlusOneDetector::new(threshold);
                for _ in 0..count {
                    d.track(hash);
                }
                let expected = if count > threshold { Some((hash, count)) } else { None };
                prop_assert_eq!(d.check_final(), expected);
            }
        }
    }
}