use crate::common::TestDatabaseManager;
use anyhow::Result;
use codex_memory::database::{create_pool, run_migrations};
use serial_test::serial;
use sqlx::Row;
use std::env;
#[tokio::test]
#[serial]
async fn test_connection_pool_creation() -> Result<()> {
let mut manager = TestDatabaseManager::new()?;
let _ = manager.setup_test_database().await?;
let test_url = std::env::var("TEST_DATABASE_URL")
.expect("TEST_DATABASE_URL must be set for database tests");
let pool = create_pool(&test_url).await?;
let result = sqlx::query("SELECT 1 as test").fetch_one(&pool).await?;
assert_eq!(result.get::<i32, _>("test"), 1);
Ok(())
}
/// Applying the migrations again (three extra passes on an already-migrated database) must
/// succeed every time and leave the `memories` table in place.
#[tokio::test]
#[serial]
async fn test_migrations_idempotent() -> Result<()> {
    const MEMORIES_TABLE_EXISTS: &str = "SELECT EXISTS (\nSELECT FROM information_schema.tables\nWHERE table_name = 'memories'\n) as exists";

    let mut manager = TestDatabaseManager::new()?;
    let pool = manager.setup_test_database().await?;

    for _pass in 1..=3 {
        run_migrations(&pool).await?;
    }

    let table_exists = sqlx::query(MEMORIES_TABLE_EXISTS).fetch_one(&pool).await?;
    assert!(table_exists.get::<bool, _>("exists"));

    manager.cleanup().await?;
    Ok(())
}
/// Checks the `memories` table schema against `information_schema.columns`.
///
/// Every expected column must exist with the expected nullability, and its reported
/// `data_type` must contain the expected type string (substring match, so e.g. `ARRAY`
/// matches array columns).
#[tokio::test]
#[serial]
async fn test_table_structure() -> Result<()> {
    let mut manager = TestDatabaseManager::new()?;
    let pool = manager.setup_test_database().await?;
    let columns: Vec<(String, String, bool)> = sqlx::query(
        "SELECT column_name, data_type, is_nullable\nFROM information_schema.columns\nWHERE table_name = 'memories'\nORDER BY ordinal_position",
    )
    .fetch_all(&pool)
    .await?
    .into_iter()
    .map(|row| {
        (
            row.get("column_name"),
            row.get("data_type"),
            // information_schema reports nullability as the text 'YES' / 'NO'.
            row.get::<String, _>("is_nullable") == "YES",
        )
    })
    .collect();

    // (column name, data_type substring, is nullable)
    let expected_columns = [
        ("id", "uuid", false),
        ("content", "text", false),
        ("content_hash", "character varying", false),
        ("tags", "ARRAY", true),
        ("context", "text", false),
        ("summary", "text", false),
        ("chunk_index", "integer", true),
        ("total_chunks", "integer", true),
        ("parent_id", "uuid", true),
        ("created_at", "timestamp with time zone", true),
        ("updated_at", "timestamp with time zone", true),
    ];
    for (name, dtype, nullable) in expected_columns {
        let found = columns
            .iter()
            .any(|(n, d, null)| n == name && d.contains(dtype) && *null == nullable);
        // Report the nullability expectation and the actual schema: a nullability mismatch
        // otherwise looks like a missing column.
        assert!(
            found,
            "Column {} with type {} (nullable: {}) not found; actual columns: {:?}",
            name, dtype, nullable, columns
        );
    }
    manager.cleanup().await?;
    Ok(())
}
/// Checks that the migrations created indexes on `memories` covering `content_hash`, `tags`
/// and `created_at`.
///
/// Index names are matched by substring, so this relies on the migrations embedding the
/// column name in each index name.
#[tokio::test]
#[serial]
async fn test_indexes_created() -> Result<()> {
    let mut manager = TestDatabaseManager::new()?;
    let pool = manager.setup_test_database().await?;
    let indexes: Vec<String> =
        sqlx::query("SELECT indexname FROM pg_indexes WHERE tablename = 'memories'")
            .fetch_all(&pool)
            .await?
            .into_iter()
            .map(|row| row.get("indexname"))
            .collect();
    for column in ["content_hash", "tags", "created_at"] {
        assert!(
            indexes.iter().any(|index| index.contains(column)),
            "No index on `memories` mentions `{}`; found indexes: {:?}",
            column,
            indexes
        );
    }
    manager.cleanup().await?;
    Ok(())
}
/// The `memory_stats` view must be queryable on a freshly set-up database.
///
/// It should report zero memories, and its `table_size` column must decode as a `String`.
#[tokio::test]
#[serial]
async fn test_stats_view() -> Result<()> {
    let mut manager = TestDatabaseManager::new()?;
    let pool = manager.setup_test_database().await?;

    let stats = sqlx::query("SELECT * FROM memory_stats").fetch_one(&pool).await?;
    let total_memories: i64 = stats.get("total_memories");
    assert_eq!(total_memories, 0);
    assert!(stats.try_get::<String, _>("table_size").is_ok());

    manager.cleanup().await?;
    Ok(())
}
#[tokio::test]
#[serial]
async fn test_concurrent_connections() -> Result<()> {
let mut manager = TestDatabaseManager::new()?;
let _ = manager.setup_test_database().await?;
let test_url = std::env::var("TEST_DATABASE_URL")
.expect("TEST_DATABASE_URL must be set for database tests");
let pool = create_pool(&test_url).await?;
let mut handles = vec![];
for i in 0..5 {
let pool = pool.clone();
let handle = tokio::spawn(async move {
let result = sqlx::query("SELECT $1::int as num")
.bind(i)
.fetch_one(&pool)
.await?;
Ok::<i32, sqlx::Error>(result.get("num"))
});
handles.push(handle);
}
for (i, handle) in handles.into_iter().enumerate() {
let result = handle.await??;
assert_eq!(result, i as i32);
}
Ok(())
}
/// A row inserted inside a transaction is visible within that transaction and disappears
/// once the transaction is rolled back.
#[tokio::test]
#[serial]
async fn test_transaction_rollback() -> Result<()> {
    const COUNT_MEMORIES: &str = "SELECT COUNT(*) as count FROM memories";
    const INSERT_MEMORY: &str = "INSERT INTO memories (id, content, content_hash, context, summary)\nVALUES (gen_random_uuid(), 'test', 'testhash', 'test context', 'test summary')";

    let mut manager = TestDatabaseManager::new()?;
    let pool = manager.setup_test_database().await?;

    let mut tx = pool.begin().await?;
    sqlx::query(INSERT_MEMORY).execute(&mut *tx).await?;

    let row_in_tx = sqlx::query(COUNT_MEMORIES).fetch_one(&mut *tx).await?;
    let visible_inside: i64 = row_in_tx.get("count");
    assert_eq!(visible_inside, 1);

    tx.rollback().await?;

    let row_after = sqlx::query(COUNT_MEMORIES).fetch_one(&pool).await?;
    let visible_after: i64 = row_after.get("count");
    assert_eq!(visible_after, 0);

    manager.cleanup().await?;
    Ok(())
}
/// `create_pool` must return an error, not panic, for each malformed or incomplete
/// connection string in the list below.
#[tokio::test]
#[serial]
async fn test_invalid_database_url_error_handling() -> Result<()> {
    const BAD_URLS: [&str; 4] = [
        "invalid://url",
        "postgresql://",
        "postgresql:///nodatabase",
        "",
    ];
    for candidate in BAD_URLS {
        let outcome = create_pool(candidate).await;
        assert!(
            outcome.is_err(),
            "Expected error for invalid URL: {}",
            candidate
        );
    }
    Ok(())
}
#[allow(dead_code)]
async fn test_setup_database_error_handling() -> Result<()> {
use codex_memory::database::setup_local_database;
env::remove_var("DATABASE_URL");
let result = setup_local_database().await;
assert!(
result.is_err(),
"Expected error when DATABASE_URL is not set"
);
match result.unwrap_err() {
codex_memory::error::Error::Config(msg) => {
assert!(msg.contains("DATABASE_URL not set"));
}
other => panic!("Expected Config error, got: {:?}", other),
}
Ok(())
}
#[tokio::test]
#[serial]
async fn test_connection_pool_error_recovery() -> Result<()> {
let bad_url = "postgresql://user:pass@192.0.2.1:5432/db";
let timeout_result =
tokio::time::timeout(std::time::Duration::from_secs(5), create_pool(bad_url)).await;
match timeout_result {
Ok(pool_result) => {
assert!(
pool_result.is_err(),
"Expected error connecting to non-existent server"
);
match pool_result.unwrap_err() {
codex_memory::error::Error::Database(sqlx_error) => {
println!("Properly caught database error: {:?}", sqlx_error);
}
other => {
println!("Got different error type (also acceptable): {:?}", other);
}
}
}
Err(_) => {
println!("Connection timed out as expected for non-routable IP");
}
}
Ok(())
}
/// Bad statements (a syntax error, a missing table) must surface as errors, and the same
/// pool must still serve a valid query afterwards.
#[tokio::test]
#[serial]
async fn test_database_setup_resilience() -> Result<()> {
    let mut manager = TestDatabaseManager::new()?;
    let pool = manager.setup_test_database().await?;

    let failing_statements = [
        ("INVALID SQL COMMAND", "Expected error for invalid SQL"),
        (
            "SELECT * FROM nonexistent_table",
            "Expected error for non-existent table",
        ),
    ];
    for (statement, failure_message) in failing_statements {
        let outcome = sqlx::query(statement).execute(&pool).await;
        assert!(outcome.is_err(), "{}", failure_message);
    }

    let probe = sqlx::query("SELECT 1 as test").fetch_one(&pool).await?;
    let probe_value: i32 = probe.get("test");
    assert_eq!(probe_value, 1);

    manager.cleanup().await?;
    Ok(())
}
/// Static source scan of `src/database/core.rs`.
///
/// The file must contain no `.unwrap()` calls. Every `.expect(` line not tagged `// Test`
/// must mention `TEST_DATABASE_URL` or be annotated as a `documented invariant`.
///
/// The scan is purely textual, so matches inside comments or string literals also count.
#[tokio::test]
#[serial]
async fn test_no_unwrap_calls_in_database_module() -> Result<()> {
    use std::fs;
    let source_path = concat!(env!("CARGO_MANIFEST_DIR"), "/src/database/core.rs");
    let source_content =
        fs::read_to_string(source_path).expect("Should be able to read database/core.rs source");

    let unwrap_count = source_content.matches(".unwrap()").count();
    assert_eq!(
        unwrap_count, 0,
        "Found {} unwrap() calls in database/core.rs - all should be replaced with proper error handling",
        unwrap_count
    );

    // Collect every offending line, so a single run reports all of them rather than only
    // the first.
    let unsafe_expects: Vec<&str> = source_content
        .lines()
        .filter(|line| line.contains(".expect(") && !line.contains("// Test"))
        .filter(|line| {
            !line.contains("TEST_DATABASE_URL") && !line.contains("documented invariant")
        })
        .collect();
    assert!(
        unsafe_expects.is_empty(),
        "Found unsafe expect() call(s): {:?}",
        unsafe_expects
    );

    println!("✅ Verified no unwrap() calls remain in database/core.rs production code");
    Ok(())
}
/// `setup_local_database` must fail, within 5 seconds or by timing out, for a malformed
/// `DATABASE_URL` and for an unreachable host.
///
/// For the malformed URL, the error text must mention URL validation. `DATABASE_URL` is
/// restored on exit even if an assertion fails.
#[tokio::test]
#[serial]
async fn test_database_setup_error_diagnostics() -> Result<()> {
    /// Puts `DATABASE_URL` back to its prior value (or removes it) when dropped. This way a
    /// panicking assertion cannot leak this test's bogus URL into later tests.
    struct DatabaseUrlGuard(Option<String>);
    impl Drop for DatabaseUrlGuard {
        fn drop(&mut self) {
            match self.0.take() {
                Some(url) => env::set_var("DATABASE_URL", url),
                None => env::remove_var("DATABASE_URL"),
            }
        }
    }
    let _restore_database_url = DatabaseUrlGuard(env::var("DATABASE_URL").ok());

    // Case 1: syntactically malformed URL.
    env::set_var("DATABASE_URL", "postgresql://invalid::url::format");
    let result = tokio::time::timeout(
        std::time::Duration::from_secs(5),
        codex_memory::database::setup_local_database(),
    )
    .await;
    match result {
        Ok(Ok(_)) => panic!("Expected error for malformed URL"),
        Ok(Err(e)) => {
            let error_msg = format!("{:?}", e);
            assert!(
                error_msg.contains("Invalid URL") || error_msg.contains("Invalid DATABASE_URL"),
                "Error should contain URL validation message: {}",
                error_msg
            );
        }
        Err(_timeout) => {
            println!("⚠️ Setup timed out on malformed URL (acceptable)");
        }
    }

    // Case 2: well-formed URL pointing at a host in the RFC 5737 documentation range.
    env::set_var(
        "DATABASE_URL",
        "postgresql://user:pass@192.0.2.1:5432/testdb",
    );
    let result = tokio::time::timeout(
        std::time::Duration::from_secs(5),
        codex_memory::database::setup_local_database(),
    )
    .await;
    match &result {
        Ok(setup_result) => assert!(setup_result.is_err(), "Expected error for unreachable host"),
        Err(_timeout) => {
            println!("⚠️ Setup timed out on unreachable host (acceptable)");
        }
    }

    // `_restore_database_url` restores the original environment when it goes out of scope.
    Ok(())
}
/// Runs `psql` against a nonexistent host and checks what a failed run reports.
///
/// A failing run must expose at least some diagnostic information: non-empty stderr or a
/// known exit code. Missing `psql`, join errors, timeouts and unexpected success are only
/// logged.
#[tokio::test]
#[serial]
async fn test_command_failure_error_propagation() -> Result<()> {
    use std::process::Command;

    let psql_probe = tokio::task::spawn_blocking(|| {
        Command::new("psql")
            .args(["-h", "nonexistent.host", "-c", "SELECT 1"])
            .output()
    });
    let outcome = tokio::time::timeout(std::time::Duration::from_secs(10), psql_probe).await;

    let output = match outcome {
        Err(_) => {
            println!("⚠️ Command timed out (acceptable for unreachable host)");
            return Ok(());
        }
        Ok(Err(_)) => {
            println!("⚠️ Task join error");
            return Ok(());
        }
        Ok(Ok(Err(e))) => {
            println!("⚠️ Command failed to execute: {}", e);
            return Ok(());
        }
        Ok(Ok(Ok(output))) => output,
    };

    if output.status.success() {
        println!("⚠️ Command succeeded unexpectedly (network might be different)");
        return Ok(());
    }

    let exit_code = match output.status.code() {
        Some(code) => code.to_string(),
        None => "unknown".to_string(),
    };
    let stderr = String::from_utf8_lossy(&output.stderr);
    let has_stderr = !stderr.trim().is_empty();
    let has_exit_code = exit_code != "unknown";
    assert!(
        has_stderr || has_exit_code,
        "Command failure should provide either stderr or exit code information"
    );
    println!(
        "✅ Command failure provides diagnostic info: exit_code={}, stderr_len={}",
        exit_code,
        stderr.len()
    );
    Ok(())
}
/// Static source scan of `src/database/core.rs` for production-safety patterns.
///
/// The file must contain no `.unwrap()` and no `panic!`. It must also contain the stderr
/// and exit-code handling fragments listed below. Counts of `unwrap_or` patterns are
/// printed for information only.
#[tokio::test]
#[serial]
async fn test_production_safety_verification() -> Result<()> {
    let source_path = concat!(env!("CARGO_MANIFEST_DIR"), "/src/database/core.rs");
    let source_content = std::fs::read_to_string(source_path)
        .expect("Should be able to read database/core.rs source");

    let unwrap_count = source_content.matches(".unwrap()").count();
    assert_eq!(
        unwrap_count, 0,
        "All unwrap() calls should be eliminated from database/core.rs"
    );

    // (required source fragment, failure message)
    let required_fragments = [
        (
            "stderr = String::from_utf8_lossy",
            "Should include stderr handling in error messages",
        ),
        (
            "exit_code = fallback_result.status.code()",
            "Should include proper exit code handling",
        ),
        (
            "stderr.trim()",
            "Should include stderr in error messages for diagnostics",
        ),
    ];
    for (fragment, failure_message) in required_fragments {
        assert!(source_content.contains(fragment), "{}", failure_message);
    }

    let unwrap_or_count = source_content.matches("unwrap_or").count();
    println!(
        "✅ Found {} safe unwrap_or patterns (with fallbacks)",
        unwrap_or_count
    );

    let panic_count = source_content.matches("panic!").count();
    assert_eq!(
        panic_count, 0,
        "No panic! calls should exist in database/core.rs"
    );

    println!(
        "✅ CODEX-RUST-001 safety verification passed - production unwrap() violations eliminated"
    );
    Ok(())
}