#[allow(unused_imports)]
use sqlx_oldapi::any::{AnyConnectOptions, AnyPoolOptions};
#[allow(unused_imports)]
use sqlx_oldapi::Executor;
#[allow(unused_imports)]
use std::sync::atomic::AtomicI32;
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
};
use std::time::Duration;
// Verifies that the `after_connect` hook runs when the pool opens a connection.
//
// The hook bumps a shared counter. After acquiring (and immediately releasing)
// four connections, the counter must show at least one invocation.
#[sqlx_macros::test]
async fn pool_should_invoke_after_connect() -> anyhow::Result<()> {
    let connects = Arc::new(AtomicUsize::new(0));
    let database_url = dotenvy::var("DATABASE_URL")?;

    // The hook is an `Fn` that may run many times. It owns one handle to the
    // counter and hands a fresh clone to each future it creates.
    let hook_counter = Arc::clone(&connects);
    let pool = AnyPoolOptions::new()
        .after_connect(move |_conn, _meta| {
            let counter = Arc::clone(&hook_counter);
            Box::pin(async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok(())
            })
        })
        .connect(&database_url)
        .await?;

    // Each acquired connection is dropped right away, returning it to the pool.
    for _ in 0..4 {
        drop(pool.acquire().await?);
    }

    // Only "at least once" is asserted. A later acquire may open a fresh
    // connection instead of reusing the previous one, so the hook may run
    // more than once.
    assert!(connects.load(Ordering::SeqCst) >= 1);
    Ok(())
}
// Verifies that a transaction whose statement failed still gives its
// connection back to the pool when dropped.
//
// The pool allows at most two connections and waits at most 3 s to acquire
// one. Beginning three transactions in a row is therefore only expected to
// succeed if each failed transaction's connection was returned.
#[sqlx_macros::test]
async fn pool_should_be_returned_failed_transactions() -> anyhow::Result<()> {
    let pool = AnyPoolOptions::new()
        .max_connections(2)
        .acquire_timeout(Duration::from_secs(3))
        .connect(&dotenvy::var("DATABASE_URL")?)
        .await?;

    // Deliberately invalid SQL so that every execution fails.
    let query = "blah blah";

    for _round in 0..3 {
        let mut tx = pool.begin().await?;
        let outcome = sqlx_oldapi::query(query).execute(&mut tx).await;
        assert!(outcome.is_err());
        // Drop the failed transaction explicitly before the next `begin`.
        drop(tx);
    }

    Ok(())
}
// Stress test: 1000 concurrently spawned tasks each run `SELECT 1` through a
// pool capped at two connections with a 3 s acquire timeout.
//
// Every task must get a connection and read back `1`. Task errors and panics
// are propagated through the `JoinHandle`s.
#[cfg(feature = "runtime-tokio-rustls")]
#[sqlx_macros::test]
async fn big_pool() -> anyhow::Result<()> {
    use sqlx_oldapi::Row;

    let database_url = dotenvy::var("DATABASE_URL")?;
    // `Pool` is already a cheaply clonable, reference-counted handle. It is
    // shared across tasks by cloning it directly; wrapping it in an `Arc`
    // would add a redundant second layer of reference counting.
    let pool = AnyPoolOptions::new()
        .max_connections(2)
        .acquire_timeout(Duration::from_secs(3))
        .connect(&database_url)
        .await?;

    let handles: Vec<_> = (0..1000)
        .map(|_| {
            let pool = pool.clone();
            tokio::spawn(async move {
                let row = sqlx_oldapi::query("SELECT 1").fetch_one(&pool).await?;
                let val: i32 = row.get(0);
                assert_eq!(val, 1);
                Ok::<_, sqlx_oldapi::Error>(())
            })
        })
        .collect();

    // The first `?` surfaces a panicked or cancelled task (`JoinError`). The
    // second surfaces a query error returned by the task.
    for handle in handles {
        handle.await??;
    }
    Ok(())
}
// End-to-end test of the `after_connect`, `before_acquire` and `after_release`
// pool hooks.
//
// Each newly opened connection gets a sequential id and a temporary
// `conn_stats` row counting how often the acquire/release hooks ran on it. The
// hooks return `false` at different thresholds for even and odd ids. The test
// then checks that successive queries observe exactly the
// (id, before_acquire_calls, after_release_calls) sequence in `pattern`.
#[sqlx_macros::test]
#[cfg(feature = "macros")]
async fn test_pool_callbacks() -> anyhow::Result<()> {
    // One row of the per-connection `conn_stats` table.
    #[derive(sqlx_oldapi::FromRow, Debug, PartialEq, Eq)]
    struct ConnStats {
        id: i32,
        before_acquire_calls: i32,
        after_release_calls: i32,
    }
    sqlx_test::setup_if_needed();
    let conn_options: AnyConnectOptions = std::env::var("DATABASE_URL")?.parse()?;
    // MSSQL is skipped entirely. NOTE(review): presumably because SQL Server
    // does not accept the `CREATE TEMPORARY TABLE` syntax used below (it uses
    // `#`-prefixed temp tables instead). Confirm before relying on this.
    #[cfg(feature = "mssql")]
    if conn_options.kind() == sqlx_oldapi::any::AnyKind::Mssql {
        return Ok(());
    }
    // Source of the sequential ids assigned to each newly opened connection.
    let current_id = AtomicI32::new(0);
    let pool = AnyPoolOptions::new()
        // A single connection, so every query below runs on whichever
        // connection the hooks have most recently kept alive.
        .max_connections(1)
        .acquire_timeout(Duration::from_secs(5))
        .after_connect(move |conn, meta| {
            // A freshly opened connection is expected to have no age and no
            // idle time yet.
            assert_eq!(meta.age, Duration::ZERO);
            assert_eq!(meta.idle_for, Duration::ZERO);
            let id = current_id.fetch_add(1, Ordering::AcqRel);
            Box::pin(async move {
                // `TEMPORARY` keeps the table local to this connection's
                // session, so each connection only ever sees its own stats row.
                let create_statement = r#"
CREATE TEMPORARY TABLE conn_stats(
id int primary key,
before_acquire_calls int default 0,
after_release_calls int default 0
)
"#;
                // `id` comes from the internal counter, not from external
                // input, so formatting it into the SQL text is safe here.
                let insert_statement = format!("INSERT INTO conn_stats(id) VALUES ({})", id);
                conn.execute(create_statement).await?;
                conn.execute(&insert_statement[..]).await?;
                Ok(())
            })
        })
        .before_acquire(|conn, meta| {
            // A connection being re-acquired is expected to have a nonzero age
            // and to have sat idle for a nonzero time.
            assert_ne!(meta.age, Duration::ZERO);
            assert_ne!(meta.idle_for, Duration::ZERO);
            Box::pin(async move {
                // Count this acquire, then read back the updated row.
                sqlx_oldapi::query(
                    r#"
UPDATE conn_stats
SET before_acquire_calls = before_acquire_calls + 1
"#,
                )
                .execute(&mut *conn)
                .await?;
                let stats: ConnStats = sqlx_oldapi::query_as("SELECT * FROM conn_stats")
                    .fetch_one(conn)
                    .await?;
                // Odd ids always pass this check. Even ids are rejected once
                // before_acquire_calls reaches 3. A `false` result is expected
                // to make the pool discard the connection and open a new one
                // (next id), which is what `pattern` below relies on.
                Ok((stats.id & 1) == 1 || stats.before_acquire_calls < 3)
            })
        })
        .after_release(|conn, meta| {
            // A connection being released is expected to have a nonzero age
            // but zero idle time.
            assert_ne!(meta.age, Duration::ZERO);
            assert_eq!(meta.idle_for, Duration::ZERO);
            Box::pin(async move {
                // Count this release, then read back the updated row.
                sqlx_oldapi::query(
                    r#"
UPDATE conn_stats
SET after_release_calls = after_release_calls + 1
"#,
                )
                .execute(&mut *conn)
                .await?;
                let stats: ConnStats = sqlx_oldapi::query_as("SELECT * FROM conn_stats")
                    .fetch_one(conn)
                    .await?;
                // Even ids always pass this check. Odd ids are rejected once
                // after_release_calls reaches 4.
                Ok((stats.id & 1) == 0 || stats.after_release_calls < 4)
            })
        })
        // Build the pool without opening a connection yet. The first query
        // below is what triggers `after_connect` for id 0.
        .connect_lazy_with(conn_options);
    // Expected (id, before_acquire_calls, after_release_calls) seen by each
    // successive query. Even ids serve three queries before `before_acquire`
    // retires them. Odd ids serve four before `after_release` retires them.
    let pattern = [
        (0, 0, 0),
        (0, 1, 1),
        (0, 2, 2),
        (1, 0, 0),
        (1, 1, 1),
        (1, 2, 2),
        (1, 3, 3),
        (2, 0, 0),
        (2, 1, 1),
        (2, 2, 2),
        (3, 0, 0),
    ];
    // Each `fetch_one(&pool)` acquires the single connection, reads its stats
    // row and releases it again, driving the hooks above.
    for (id, before_acquire_calls, after_release_calls) in pattern {
        let conn_stats: ConnStats = sqlx_oldapi::query_as("SELECT * FROM conn_stats")
            .fetch_one(&pool)
            .await?;
        assert_eq!(
            conn_stats,
            ConnStats {
                id,
                before_acquire_calls,
                after_release_calls
            }
        );
    }
    pool.close().await;
    Ok(())
}