use sqlx::postgres::PgPoolOptions;
use sqlx_pool_router::{PoolProvider, TestDbPools};
/// Repository for the `users` table, generic over the pool provider it
/// queries through.
///
/// The provider is stored by value; the repository's methods call
/// `write()` or `read()` on it to pick the executor for each statement.
struct UserRepository<P>
where
    P: PoolProvider,
{
    /// Source of the read and write executors used by the query methods.
    pools: P,
}
impl<P> UserRepository<P>
where
    P: PoolProvider,
{
    /// Inserts a new row into `users` with the given `name` and returns the
    /// `id` produced by the `RETURNING` clause.
    ///
    /// The statement is executed on `self.pools.write()`.
    ///
    /// # Errors
    ///
    /// Propagates any `sqlx::Error` raised while executing the statement.
    async fn create_user(&self, name: &str) -> Result<i64, sqlx::Error> {
        let sql = "INSERT INTO users (name) VALUES ($1) RETURNING id";
        let statement = sqlx::query_scalar(sql).bind(name);
        statement.fetch_one(self.pools.write()).await
    }

    /// Looks up the `name` column of the `users` row whose `id` matches.
    ///
    /// The statement is executed on `self.pools.read()`.
    ///
    /// # Errors
    ///
    /// Propagates any `sqlx::Error` raised while executing the statement;
    /// `fetch_one` requires a row, so a missing `id` is reported as an error
    /// rather than as an empty result.
    async fn get_user(&self, id: i64) -> Result<String, sqlx::Error> {
        let sql = "SELECT name FROM users WHERE id = $1";
        let statement = sqlx::query_scalar(sql).bind(id);
        statement.fetch_one(self.pools.read()).await
    }

    /// Returns the number of rows in `users` as reported by `COUNT(*)`.
    ///
    /// The statement is executed on `self.pools.read()`.
    ///
    /// # Errors
    ///
    /// Propagates any `sqlx::Error` raised while executing the statement.
    async fn count_users(&self) -> Result<i64, sqlx::Error> {
        let sql = "SELECT COUNT(*) FROM users";
        let statement = sqlx::query_scalar(sql);
        statement.fetch_one(self.pools.read()).await
    }
}
/// Walks through a `TestDbPools` demonstration against a live PostgreSQL
/// database.
///
/// Steps, in order:
/// 1. Connect using `DATABASE_URL` (falling back to a local default).
/// 2. Wrap the pool in `TestDbPools` and (re)create a scratch `users` table.
/// 3. Exercise `UserRepository` writes and reads (Tests 1-3).
/// 4. Attempt an `INSERT` and a `CREATE TEMP TABLE` through `.read()`; both
///    are expected to be rejected (Tests 4-5).
/// 5. Drop the scratch table.
///
/// # Errors
///
/// Returns an error if connecting, setup, any repository call, or cleanup
/// fails, and also if Test 4 or Test 5 unexpectedly succeeds. Those two
/// outcomes are collected rather than returned immediately so that the
/// scratch table is still dropped before the failure is reported.
///
/// # Panics
///
/// Panics if Tests 2-3 read back unexpected values, or if the Test 4
/// rejection message mentions neither "read-only" nor "cannot execute".
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let database_url = std::env::var("DATABASE_URL")
        .unwrap_or_else(|_| "postgresql://postgres:password@localhost/test".to_string());

    println!("🧪 TestDbPools Example");
    println!("====================");
    println!();
    // NOTE(review): this echoes the whole URL, including any embedded
    // credentials; acceptable for a local example, not for shared logs.
    println!("Connecting to: {}", database_url);
    let pool = PgPoolOptions::new()
        .max_connections(5)
        .connect(&database_url)
        .await?;

    println!("Creating TestDbPools...");
    let pools = TestDbPools::new(pool).await?;
    println!("✓ TestDbPools created (read pool is read-only)");
    println!();

    println!("📋 Setting up test table...");
    // Drop first so a table left behind by an aborted earlier run does not
    // make the CREATE fail.
    sqlx::query("DROP TABLE IF EXISTS users")
        .execute(pools.write())
        .await?;
    sqlx::query("CREATE TABLE users (id SERIAL PRIMARY KEY, name TEXT NOT NULL)")
        .execute(pools.write())
        .await?;
    println!("✓ Table created");
    println!();

    let repo = UserRepository {
        pools: pools.clone(),
    };

    // Expectation violations from Tests 4-5. They are reported after cleanup
    // instead of letting the run end with "All tests passed!".
    let mut failures: Vec<&'static str> = Vec::new();

    println!("Test 1: Writing through .write() pool");
    let user_id = repo.create_user("Alice").await?;
    println!(" ✓ Created user with ID: {}", user_id);
    println!();

    println!("Test 2: Reading through .read() pool");
    let name = repo.get_user(user_id).await?;
    println!(" ✓ Read user name: {}", name);
    assert_eq!(name, "Alice");
    println!();

    println!("Test 3: Aggregate queries through .read() pool");
    let count = repo.count_users().await?;
    println!(" ✓ User count: {}", count);
    assert_eq!(count, 1);
    println!();

    println!("Test 4: Writing through .read() pool (should fail)");
    let result = sqlx::query("INSERT INTO users (name) VALUES ($1)")
        .bind("Bob")
        .execute(pools.read())
        .await;
    match result {
        Ok(_) => {
            println!(" ✗ UNEXPECTED: Write succeeded on read pool!");
            println!(" This should have failed - the read pool should be read-only");
            failures.push("Test 4: INSERT succeeded on the read pool");
        }
        Err(e) => {
            println!(" ✓ Write correctly rejected on read pool");
            println!(" Error: {}", e);
            // Make sure the rejection is the read-only one and not an
            // unrelated failure such as a dropped connection.
            let message = e.to_string();
            assert!(
                message.contains("read-only") || message.contains("cannot execute"),
                "write on read pool failed for an unexpected reason: {message}"
            );
        }
    }
    println!();

    println!("Test 5: DDL through .read() pool (should fail)");
    let result = sqlx::query("CREATE TEMP TABLE temp_test (id INT)")
        .execute(pools.read())
        .await;
    match result {
        Ok(_) => {
            println!(" ✗ UNEXPECTED: DDL succeeded on read pool!");
            failures.push("Test 5: CREATE TEMP TABLE succeeded on the read pool");
        }
        Err(e) => {
            println!(" ✓ DDL correctly rejected on read pool");
            println!(" Error: {}", e);
        }
    }
    println!();

    println!("🧹 Cleaning up...");
    sqlx::query("DROP TABLE users")
        .execute(pools.write())
        .await?;
    println!(" ✓ Table dropped");
    println!();

    if !failures.is_empty() {
        println!("✗ {} check(s) failed", failures.len());
        return Err(format!(
            "{} check(s) failed: {}",
            failures.len(),
            failures.join("; ")
        )
        .into());
    }

    println!("✅ All tests passed!");
    println!();
    println!("💡 Key Takeaways:");
    println!(" - TestDbPools enforces read/write separation in tests");
    println!(" - Write operations through .read() fail immediately");
    println!(" - Catches routing bugs before they reach production");
    println!(" - Works seamlessly with #[sqlx::test] macro");
    println!();
    println!("📝 Use in your tests:");
    println!(" #[sqlx::test]");
    println!(" async fn test_my_feature(pool: PgPool) {{");
    println!(" let pools = TestDbPools::new(pool).await.unwrap();");
    println!(" let repo = MyRepository {{ pools }};");
    println!(" // Test will fail if repo routes incorrectly!");
    println!(" }}");
    Ok(())
}