use crate::common::{build_test_db_url_with_db, TestDatabaseManager};
use anyhow::Result;
use codex_memory::database::{create_pool, run_migrations};
use codex_memory::Storage;
use sqlx::Row;
use std::sync::Arc;
use std::time::Duration;
use tokio::time::timeout;
/// Every malformed, unreachable, or empty connection string must make
/// `create_pool` return an error rather than a usable pool.
#[tokio::test]
async fn test_invalid_connection_string() -> Result<()> {
    const INVALID_URLS: [&str; 6] = [
        "invalid://url",
        "postgresql://nonexistent:pass@localhost:5432/db",
        "postgresql://user:pass@invalid_host:5432/db",
        "postgresql://user:pass@localhost:99999/db",
        "",
        "not a url at all",
    ];

    for url in INVALID_URLS {
        let outcome = create_pool(url).await;
        assert!(outcome.is_err(), "Should fail for invalid URL: {}", url);
    }
    Ok(())
}
/// Connecting to 192.0.2.1 (a documentation-only address) must never yield a pool.
///
/// The attempt is capped at five seconds. Hitting the cap counts as a pass,
/// because the only requirement is that no successful connection is produced.
#[tokio::test]
async fn test_connection_timeout() -> Result<()> {
    let unreachable_url = "postgresql://user:pass@192.0.2.1:5432/db";

    // `Err(Elapsed)` (timed out) -> true; `Ok(attempt)` -> whether the attempt failed.
    let never_connected = timeout(Duration::from_secs(5), create_pool(unreachable_url))
        .await
        .map_or(true, |attempt| attempt.is_err());

    assert!(never_connected, "Should fail to connect to non-routable IP");
    Ok(())
}
/// Returns `url` with its credentials replaced by `user:password`.
///
/// Only the authority section (after `://`, up to the first `/` or `?`) is
/// rewritten, so host, port, database path and query string are preserved.
/// If the URL carries no `user:password@` part, the credentials are inserted.
/// A string without `://` is returned unchanged. `user` and `password` are
/// inserted verbatim, so they must already be URL-safe.
fn replace_url_credentials(url: &str, user: &str, password: &str) -> String {
    let authority_start = match url.find("://") {
        Some(idx) => idx + "://".len(),
        None => return url.to_string(),
    };
    let rest = &url[authority_start..];
    let authority_len = rest
        .find(|c: char| c == '/' || c == '?')
        .unwrap_or(rest.len());
    // The userinfo ends at the last '@' inside the authority; what follows is host[:port].
    let host_start = rest[..authority_len].rfind('@').map_or(0, |at| at + 1);
    format!(
        "{}{}:{}@{}",
        &url[..authority_start],
        user,
        password,
        &rest[host_start..]
    )
}

/// Connecting with a wrong user name and password must fail with an
/// authentication-related error.
#[tokio::test]
async fn test_authentication_failure() -> Result<()> {
    // Rewrite the credentials structurally rather than string-replacing the known
    // user/password. This keeps the real test password out of source, and a change
    // to the configured credentials cannot silently leave the URL valid.
    let wrong_creds_url = replace_url_credentials(
        &build_test_db_url_with_db("/codex_db"),
        "wrong_user",
        "wrong_pass",
    );
    let result = create_pool(&wrong_creds_url).await;
    assert!(result.is_err(), "Should fail with wrong credentials");
    let error_msg = result.unwrap_err().to_string().to_lowercase();
    assert!(
        error_msg.contains("authentication")
            || error_msg.contains("password")
            || error_msg.contains("role")
            || error_msg.contains("does not exist"),
        "Error should be authentication-related: {}",
        error_msg
    );
    Ok(())
}
/// Holds up to five pooled connections at once, probes the pool for one more,
/// then checks that the pool recovers after the held connections are released.
///
/// Whether the extra acquire succeeds, errors, or times out depends on the
/// pool's configured size, so all three outcomes are accepted. The test does
/// require that a freshly set-up pool hands out at least one connection, and
/// that it hands one out again after the held connections are dropped.
#[tokio::test]
async fn test_connection_pool_exhaustion() -> Result<()> {
    const HELD_CONNECTIONS: usize = 5;

    let mut manager = TestDatabaseManager::new()?;
    let pool = manager.setup_test_database().await?;

    let mut connections = Vec::with_capacity(HELD_CONNECTIONS);
    for i in 0..HELD_CONNECTIONS {
        match pool.acquire().await {
            Ok(conn) => {
                connections.push(conn);
                println!("Acquired connection {}", i + 1);
            }
            Err(e) => {
                println!("Failed to acquire connection {}: {}", i + 1, e);
                break;
            }
        }
    }
    // Without this, the test would pass even if the pool could not connect at all.
    assert!(
        !connections.is_empty(),
        "Should acquire at least one connection from a freshly set-up pool"
    );

    let extra_connection = timeout(Duration::from_secs(2), pool.acquire()).await;
    match extra_connection {
        // Binding (not `_`) moves the connection into the arm so it is released here,
        // instead of staying checked out until the end of the test.
        Ok(Ok(_extra)) => {
            println!("Extra connection acquired - pool management working");
        }
        Ok(Err(e)) => {
            println!("Pool exhausted as expected: {}", e);
        }
        Err(_) => {
            println!("Connection acquisition timed out as expected");
        }
    }

    drop(connections);

    // Released connections must return to the pool and be usable again.
    let recovered = timeout(Duration::from_secs(5), pool.acquire()).await;
    assert!(
        matches!(recovered, Ok(Ok(_))),
        "Pool should hand out connections again after they are released"
    );
    drop(recovered);

    manager.cleanup().await?;
    Ok(())
}
/// Pointing at a database name that does not exist must fail, and the error
/// text must say that the database does not exist.
#[tokio::test]
async fn test_database_does_not_exist() -> Result<()> {
    let missing_db_url = build_test_db_url_with_db("/nonexistent_database_12345");

    let message = match create_pool(&missing_db_url).await {
        Ok(_) => panic!("Should fail when database doesn't exist"),
        Err(err) => err.to_string().to_lowercase(),
    };

    // "not exist" also covers the "does not exist" wording.
    let names_database = message.contains("database");
    let reports_missing = message.contains("not exist");
    assert!(
        names_database && reports_missing,
        "Error should indicate database doesn't exist: {}",
        message
    );
    Ok(())
}
/// Drops the `memories` table and checks that `run_migrations` brings it back.
#[tokio::test]
async fn test_migration_on_corrupted_schema() -> Result<()> {
    let mut manager = TestDatabaseManager::new()?;
    let pool = manager.setup_test_database().await?;

    // Simulate a damaged schema by dropping the `memories` table.
    sqlx::query("DROP TABLE IF EXISTS memories CASCADE")
        .execute(&pool)
        .await?;

    // Report the migration error itself; a bare `is_ok()` check would hide why it failed.
    if let Err(e) = run_migrations(&pool).await {
        panic!("Migrations should handle missing tables: {:?}", e);
    }

    // NOTE(review): not filtered by `table_schema`; assumes the test database has
    // only one `memories` table — confirm if extra schemas are ever added.
    let table_exists: bool = sqlx::query(
        "SELECT EXISTS (
SELECT FROM information_schema.tables
WHERE table_name = 'memories'
) as exists",
    )
    .fetch_one(&pool)
    .await?
    .get("exists");
    assert!(
        table_exists,
        "Memories table should be recreated after migration"
    );

    manager.cleanup().await?;
    Ok(())
}
#[tokio::test]
async fn test_concurrent_database_operations() -> Result<()> {
let mut manager = TestDatabaseManager::new()?;
let pool = manager.setup_test_database().await?;
let storage = Arc::new(Storage::new(pool));
let mut handles = vec![];
for i in 0..20 {
let storage_clone = storage.clone();
let handle = tokio::spawn(async move {
let content = format!("Concurrent content #{}", i);
storage_clone
.store(
&content,
"Test context".to_string(),
"Test summary".to_string(),
None,
)
.await
});
handles.push(handle);
}
let mut successes = 0;
let mut failures = 0;
for handle in handles {
match handle.await {
Ok(Ok(_)) => successes += 1,
Ok(Err(e)) => {
println!("Storage operation failed: {}", e);
failures += 1;
}
Err(e) => {
println!("Task failed: {}", e);
failures += 1;
}
}
}
println!(
"Concurrent operations: {} succeeded, {} failed",
successes, failures
);
assert!(successes > 15, "Most concurrent operations should succeed");
manager.cleanup().await?;
Ok(())
}
/// Inserts a row inside an open transaction, then counts it from a separate pool
/// connection.
///
/// If the count query returns a row, the uncommitted insert must be invisible
/// (count 0). A query error or a 2-second timeout is also accepted. The
/// transaction is rolled back afterwards.
#[tokio::test]
async fn test_database_transaction_timeout() -> Result<()> {
    let mut manager = TestDatabaseManager::new()?;
    let pool = manager.setup_test_database().await?;

    let insert_sql = "INSERT INTO memories (content, content_hash, context, summary) VALUES ('test', 'testhash', 'test context', 'test summary')";
    let count_sql = "SELECT COUNT(*) as count FROM memories WHERE content_hash = 'testhash'";

    let mut tx = pool.begin().await?;
    sqlx::query(insert_sql).execute(&mut *tx).await?;

    // Runs on a different connection than `tx`, so it observes only committed data.
    let count_from_outside = sqlx::query(count_sql).fetch_one(&pool);
    match timeout(Duration::from_secs(2), count_from_outside).await {
        Ok(Ok(row)) => {
            let count: i64 = row.get("count");
            assert_eq!(count, 0, "Uncommitted data should not be visible");
        }
        Ok(Err(e)) => println!("Query failed as expected: {}", e),
        Err(_) => println!("Query timed out as expected due to lock"),
    }

    tx.rollback().await?;
    manager.cleanup().await?;
    Ok(())
}