use crate::TestError;
use std::path::PathBuf;
use switchy_database::{Database, DatabaseError};
use switchy_schema::MigrationError;
#[cfg(feature = "snapshots")]
use insta::Settings;
#[cfg(feature = "snapshots")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "snapshots")]
use std::collections::BTreeMap;
#[cfg(feature = "snapshots")]
use std::sync::Arc;
#[cfg(feature = "snapshots")]
use std::{future::Future, pin::Pin};
#[cfg(feature = "snapshots")]
use switchy_database::schema::{ColumnInfo as DbColumnInfo, TableInfo};
#[cfg(feature = "snapshots")]
use switchy_database::{DatabaseValue, Row};
#[cfg(feature = "snapshots")]
use switchy_schema::discovery::directory::DirectoryMigrationSource;
#[cfg(feature = "snapshots")]
use switchy_schema::migration::{Migration, MigrationSource};
#[cfg(feature = "snapshots")]
use switchy_schema::runner::MigrationRunner;
/// Boxed future produced by the setup and verification hooks.
///
/// It may borrow the database for `'a` and must be `Send`.
#[cfg(feature = "snapshots")]
type HookFuture<'a> =
    Pin<Box<dyn Future<Output = std::result::Result<(), DatabaseError>> + Send + 'a>>;

/// Hook executed against the database after migrations are loaded but before they are
/// applied (registered with `MigrationSnapshotTest::with_setup`).
#[cfg(feature = "snapshots")]
type SetupFn = Box<dyn for<'a> Fn(&'a dyn Database) -> HookFuture<'a> + Send + Sync>;

/// Hook executed against the database after migrations are applied and before the snapshot
/// is captured (registered with `MigrationSnapshotTest::with_verification`).
#[cfg(feature = "snapshots")]
type VerificationFn = Box<dyn for<'a> Fn(&'a dyn Database) -> HookFuture<'a> + Send + Sync>;
/// [`MigrationSource`] backed by a fixed, already-loaded list of migrations.
///
/// `MigrationSnapshotTest::run` uses it to hand migrations loaded from disk to a
/// [`MigrationRunner`].
#[cfg(feature = "snapshots")]
struct VecMigrationSource<'a> {
    /// Migrations returned, in this order, by every call to `migrations()`.
    migrations: Vec<Arc<dyn Migration<'a> + 'a>>,
}

#[cfg(feature = "snapshots")]
impl<'a> VecMigrationSource<'a> {
    /// Wraps `migrations` as-is (no copying or reordering).
    #[must_use]
    fn new(migrations: Vec<Arc<dyn Migration<'a> + 'a>>) -> Self {
        Self { migrations }
    }
}

#[cfg(feature = "snapshots")]
#[async_trait::async_trait]
impl<'a> MigrationSource<'a> for VecMigrationSource<'a> {
    /// Returns a copy of the stored list; copying only bumps each `Arc`'s reference count.
    async fn migrations(&self) -> switchy_schema::Result<Vec<Arc<dyn Migration<'a> + 'a>>> {
        let copied = self.migrations.to_vec();
        Ok(copied)
    }
}
/// Errors that can occur while building or running a [`MigrationSnapshotTest`].
#[derive(Debug, thiserror::Error)]
pub enum SnapshotError {
/// A database call failed (e.g. a setup/verification hook, schema introspection, or a
/// data-sample query).
#[error("Database error: {0}")]
Database(#[from] DatabaseError),
/// Loading, applying, or reading the status of migrations failed.
#[error("Migration error: {0}")]
Migration(#[from] MigrationError),
/// An I/O operation failed.
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
/// The snapshot was rejected by a validation step; the payload is a human-readable reason.
#[error("Snapshot validation failed: {0}")]
Validation(String),
/// An error from the surrounding test utilities (e.g. creating the in-memory database).
#[error("Test error: {0}")]
Test(#[from] TestError),
/// JSON serialization or deserialization failed.
#[error("JSON error: {0}")]
Json(#[from] serde_json::Error),
}
/// Result alias whose error type is [`SnapshotError`].
pub type Result<T> = std::result::Result<T, SnapshotError>;
/// Top-level document serialized into the JSON snapshot by `MigrationSnapshotTest::run`.
///
/// Serde serializes fields in declaration order, so the field order here is part of the
/// on-disk snapshot format.
#[cfg(feature = "snapshots")]
#[derive(Debug, Serialize, Deserialize)]
struct MigrationSnapshot {
    /// Name of the test.
    test_name: String,
    /// Ids of completed migrations; empty when sequence capture is disabled.
    migration_sequence: Vec<String>,
    /// Captured schema; `None` when schema capture is disabled.
    schema: Option<DatabaseSchema>,
    /// Sampled rows keyed by table name; `None` when data capture is disabled.
    data_samples: Option<BTreeMap<String, Vec<serde_json::Value>>>,
}

/// Schema of every captured table, keyed (and therefore sorted) by table name.
#[cfg(feature = "snapshots")]
#[derive(Debug, Serialize, Deserialize)]
struct DatabaseSchema {
    tables: BTreeMap<String, TableSchema>,
}

/// Snapshot view of one table: its columns and the names of its indexes.
#[cfg(feature = "snapshots")]
#[derive(Debug, Serialize, Deserialize)]
struct TableSchema {
    columns: Vec<ColumnInfo>,
    indexes: Vec<String>,
}

/// Snapshot view of a single column.
#[cfg(feature = "snapshots")]
#[derive(Debug, Serialize, Deserialize)]
struct ColumnInfo {
    /// Column name.
    name: String,
    /// `Debug` rendering of the column's data type.
    data_type: String,
    /// Whether the column accepts NULL.
    nullable: bool,
    /// `Debug` rendering of the default value, if any.
    default_value: Option<String>,
    /// Whether the column is (part of) the primary key.
    primary_key: bool,
}
/// Empty placeholder type: it has no fields, and this file defines no methods for it.
///
/// NOTE(review): not referenced anywhere in this file — confirm whether it is still needed.
pub struct SnapshotTester {
}
/// Builder-style description of a migration snapshot test.
///
/// Configure it with the builder methods, then execute it with `run`.
// Each flag is an independent switch with its own builder method, so one `bool` per
// option is the simplest representation.
#[allow(clippy::struct_excessive_bools)]
pub struct MigrationSnapshotTest {
    /// Name of the test; used as the `insta` snapshot name and recorded in the snapshot.
    test_name: String,

    /// Directory migrations are loaded from; `None` means no migrations are applied.
    migrations_dir: Option<PathBuf>,
    /// Custom version-tracking table to read applied migrations from.
    migrations_table_name: Option<String>,
    /// Database to run against; `None` means a fresh in-memory database is created.
    db: Option<Box<dyn Database>>,

    /// Hook run before migrations are applied.
    setup_fn: Option<SetupFn>,
    /// Hook run after migrations are applied.
    verification_fn: Option<VerificationFn>,

    /// Tables whose schema is captured.
    expected_tables: Vec<String>,
    /// Table name -> maximum number of rows to sample.
    data_samples: std::collections::BTreeMap<String, usize>,

    /// Capture the schema of `expected_tables`.
    assert_schema: bool,
    /// Capture the ids of completed migrations.
    assert_sequence: bool,
    /// Capture the rows configured in `data_samples`.
    assert_data: bool,

    /// Replace dates and timestamps with placeholders.
    redact_timestamps: bool,
    /// Replace numeric `id` / `*_id` values with placeholders.
    redact_auto_ids: bool,
    /// Replace filesystem-looking paths with a placeholder.
    redact_paths: bool,
}
impl MigrationSnapshotTest {
/// Creates a snapshot test called `test_name` with the default configuration.
///
/// Defaults:
/// - no migrations directory, no custom migrations table, no database, no hooks;
/// - no expected tables and no data samples;
/// - schema and sequence capture enabled, data capture disabled;
/// - all redactions enabled.
#[must_use]
pub fn new(test_name: &str) -> Self {
    Self {
        test_name: test_name.to_owned(),
        migrations_dir: None,
        migrations_table_name: None,
        db: None,
        setup_fn: None,
        verification_fn: None,
        expected_tables: vec![],
        data_samples: std::collections::BTreeMap::new(),
        assert_schema: true,
        assert_sequence: true,
        assert_data: false,
        redact_timestamps: true,
        redact_auto_ids: true,
        redact_paths: true,
    }
}
#[must_use]
pub fn migrations_dir(mut self, path: impl Into<PathBuf>) -> Self {
self.migrations_dir = Some(path.into());
self
}
#[must_use]
pub const fn assert_schema(mut self, enabled: bool) -> Self {
self.assert_schema = enabled;
self
}
#[must_use]
pub const fn assert_sequence(mut self, enabled: bool) -> Self {
self.assert_sequence = enabled;
self
}
#[must_use]
pub fn expected_tables(mut self, tables: Vec<String>) -> Self {
self.expected_tables = tables;
self
}
#[must_use]
pub const fn redact_timestamps(mut self, enabled: bool) -> Self {
self.redact_timestamps = enabled;
self
}
#[must_use]
pub const fn redact_auto_ids(mut self, enabled: bool) -> Self {
self.redact_auto_ids = enabled;
self
}
#[must_use]
pub const fn redact_paths(mut self, enabled: bool) -> Self {
self.redact_paths = enabled;
self
}
#[must_use]
pub const fn assert_data(mut self, enabled: bool) -> Self {
self.assert_data = enabled;
self
}
#[must_use]
pub fn with_data_samples(mut self, table: &str, count: usize) -> Self {
self.data_samples.insert(table.to_string(), count);
self
}
/// Registers a hook that runs after migrations are loaded but before they are applied.
///
/// Calling this again replaces the previously registered setup hook.
#[must_use]
#[cfg(feature = "snapshots")]
pub fn with_setup<F>(self, f: F) -> Self
where
    F: for<'a> Fn(
            &'a dyn Database,
        ) -> Pin<
            Box<dyn Future<Output = std::result::Result<(), DatabaseError>> + Send + 'a>,
        > + Send
        + Sync
        + 'static,
{
    let hook: SetupFn = Box::new(f);
    Self {
        setup_fn: Some(hook),
        ..self
    }
}

/// Registers a hook that runs after migrations are applied and before the snapshot is
/// captured.
///
/// Calling this again replaces the previously registered verification hook.
#[must_use]
#[cfg(feature = "snapshots")]
pub fn with_verification<F>(self, f: F) -> Self
where
    F: for<'a> Fn(
            &'a dyn Database,
        ) -> Pin<
            Box<dyn Future<Output = std::result::Result<(), DatabaseError>> + Send + 'a>,
        > + Send
        + Sync
        + 'static,
{
    let hook: VerificationFn = Box::new(f);
    Self {
        verification_fn: Some(hook),
        ..self
    }
}
#[must_use]
pub fn with_database(mut self, db: Box<dyn Database>) -> Self {
self.db = Some(db);
self
}
#[must_use]
pub fn with_migrations_table(mut self, table_name: impl Into<String>) -> Self {
self.migrations_table_name = Some(table_name.into());
self
}
#[must_use]
pub const fn auto_discover_tables(self) -> Self {
self
}
#[must_use]
#[cfg(feature = "snapshots")]
pub fn with_test_builder(self, _builder: crate::MigrationTestBuilder<'_>) -> Self {
self
}
/// Captures the schema of every table listed in `expected_tables`.
///
/// Tables for which `get_table_info` returns `None` are skipped rather than reported as
/// errors, so a missing table shows up as an absent entry in the snapshot.
///
/// # Errors
///
/// Returns an error if introspecting any of the tables fails.
#[cfg(feature = "snapshots")]
async fn capture_schema(&self, db: &dyn Database) -> Result<DatabaseSchema> {
    let mut tables = BTreeMap::new();
    for table_name in &self.expected_tables {
        if let Some(table_info) = db.get_table_info(table_name).await? {
            // Delegate to the shared converter instead of duplicating the column/index
            // mapping here, so both code paths always produce identical snapshots.
            tables.insert(table_name.clone(), table_info_to_schema(table_info));
        }
    }
    Ok(DatabaseSchema { tables })
}
/// Placeholder for discovering table names from the migration files.
///
/// Always yields an empty list for now.
#[cfg(feature = "snapshots")]
// Not called yet, and has no awaits until it is implemented.
#[allow(unused, clippy::unused_async)]
async fn discover_tables_from_migrations(&self) -> Result<Vec<String>> {
    let discovered: Vec<String> = Vec::new();
    Ok(discovered)
}
/// Samples rows from every table registered with `with_data_samples`.
///
/// Each table contributes at most its configured number of rows, converted to JSON
/// objects. The tables are visited in name order.
///
/// NOTE(review): rows are selected with a `LIMIT` but no ordering, so which rows are
/// sampled depends on the backend's default row order — confirm this is deterministic for
/// the databases in use.
///
/// # Errors
///
/// Returns an error if any sample query fails.
#[cfg(feature = "snapshots")]
async fn capture_data_samples(
    &self,
    db: &dyn Database,
) -> Result<std::collections::BTreeMap<String, Vec<serde_json::Value>>> {
    let mut captured = BTreeMap::new();
    for (table_name, max_rows) in &self.data_samples {
        let rows = db.select(table_name).limit(*max_rows).execute(db).await?;
        let as_json: Vec<serde_json::Value> = rows.into_iter().map(row_to_json).collect();
        captured.insert(table_name.clone(), as_json);
    }
    Ok(captured)
}
/// Creates the empty in-memory database used when no database was supplied.
///
/// # Errors
///
/// Returns [`SnapshotError::Test`] if the database cannot be created.
#[cfg(feature = "snapshots")]
async fn create_test_database(&self) -> Result<Box<dyn Database>> {
    log::debug!("Creating test database");
    let created = crate::create_empty_in_memory().await;
    let db = created.map_err(|err| SnapshotError::Test(TestError::from(err)))?;
    Ok(db)
}
/// Loads migrations from the configured `migrations_dir`.
///
/// Returns an empty list when no directory is configured or the directory does not exist.
///
/// # Errors
///
/// Returns an error if reading the migrations from an existing directory fails.
#[cfg(feature = "snapshots")]
async fn load_migrations(&self) -> Result<Vec<Arc<dyn Migration<'static> + 'static>>> {
    match self.migrations_dir.as_ref() {
        Some(dir) if dir.exists() => {
            log::debug!("Loading migrations from directory: {}", dir.display());
            let source = DirectoryMigrationSource::from_path(dir.clone());
            let loaded = source.migrations().await?;
            log::debug!("Loaded {} migrations from directory", loaded.len());
            Ok(loaded)
        }
        Some(dir) => {
            log::debug!("Migrations directory does not exist: {}", dir.display());
            Ok(vec![])
        }
        None => {
            log::debug!("No migrations directory configured");
            Ok(vec![])
        }
    }
}
/// Reads the ids of migrations recorded with status `Completed`.
///
/// The ids are read from the table set via `with_migrations_table`, or from the version
/// tracker's default table otherwise.
///
/// # Errors
///
/// Returns [`SnapshotError::Migration`] if the version-tracking table cannot be read.
#[cfg(feature = "snapshots")]
async fn get_migration_sequence(&self, db: &dyn Database) -> Result<Vec<String>> {
    use switchy_schema::{migration::MigrationStatus, version::VersionTracker};

    let tracker = self
        .migrations_table_name
        .clone()
        .map_or_else(VersionTracker::new, VersionTracker::with_table_name);
    let completed_ids = tracker
        .get_applied_migration_ids(db, MigrationStatus::Completed)
        .await
        .map_err(SnapshotError::Migration)?;
    log::debug!("Found {} applied migrations in database", completed_ids.len());
    Ok(completed_ids)
}
/// Runs the snapshot test end to end and compares the result with the stored `insta`
/// snapshot.
///
/// Steps, in order:
/// 1. Use the database passed to `with_database`, or create an empty in-memory one.
/// 2. Load migrations from `migrations_dir` (if configured and present on disk).
/// 3. Run the setup hook, if any.
/// 4. Apply the loaded migrations with a [`MigrationRunner`].
/// 5. Run the verification hook, if any.
/// 6. Capture the schema, the completed-migration ids and the data samples, as enabled by
///    the `assert_*` flags.
/// 7. Apply the enabled redaction filters and assert the JSON snapshot.
///
/// # Errors
///
/// Returns an error if creating the database, loading or applying migrations, running a
/// hook, or capturing any part of the snapshot fails.
///
/// # Panics
///
/// `insta::assert_json_snapshot!` panics when the snapshot does not match, subject to
/// insta's own update/review settings.
#[cfg(feature = "snapshots")]
#[allow(clippy::cognitive_complexity)]
pub async fn run(mut self) -> Result<()> {
    let owned_db = if let Some(db) = self.db.take() {
        db
    } else {
        self.create_test_database().await?
    };
    let db = &*owned_db;
    let migrations_to_apply = self.load_migrations().await?;
    if let Some(setup_fn) = &self.setup_fn {
        log::debug!("run: executing setup function");
        setup_fn(db).await?;
    } else {
        log::debug!("run: no setup function provided");
    }
    if migrations_to_apply.is_empty() {
        log::debug!("run: no new migrations to apply");
    } else {
        log::debug!("run: executing {} migrations", migrations_to_apply.len());
        // The list is not used after this point, so move it instead of cloning it.
        let source = VecMigrationSource::new(migrations_to_apply);
        let runner = MigrationRunner::new(Box::new(source));
        runner.run(db).await?;
    }
    if let Some(verification_fn) = &self.verification_fn {
        log::debug!("run: executing verification function");
        verification_fn(db).await?;
    } else {
        log::debug!("run: no verification function provided");
    }
    let schema = if self.assert_schema {
        log::debug!("run: capturing schema");
        Some(self.capture_schema(db).await?)
    } else {
        log::debug!("run: no schema capture");
        None
    };
    let sequence = if self.assert_sequence {
        log::debug!("run: capturing migration sequence");
        self.get_migration_sequence(db).await?
    } else {
        log::debug!("run: no migration sequence capture");
        vec![]
    };
    log::debug!("run: migration sequence: {sequence:?}");
    let data_samples = if self.assert_data {
        log::debug!("run: capturing data samples");
        Some(self.capture_data_samples(db).await?)
    } else {
        log::debug!("run: no data samples capture");
        None
    };
    let snapshot = MigrationSnapshot {
        test_name: self.test_name.clone(),
        migration_sequence: sequence,
        schema,
        data_samples,
    };
    log::debug!("run: snapshot={snapshot:?}");
    let settings = self.snapshot_settings();
    settings.bind(|| {
        insta::assert_json_snapshot!(self.test_name, snapshot);
    });
    Ok(())
}

/// Builds the `insta` settings used by `run`.
///
/// Adds one group of regex filters per enabled redaction, so that volatile values do not
/// make snapshots flaky. Filters are added in the order they must be applied.
#[cfg(feature = "snapshots")]
fn snapshot_settings(&self) -> Settings {
    let mut settings = Settings::clone_current();
    if self.redact_timestamps {
        // Full timestamps (space or `T` separator) must be filtered before bare dates.
        // The optional fractional part is required: `DateTime` values are rendered with
        // `to_string()`. NOTE(review): that type is presumably chrono's `NaiveDateTime`,
        // whose `Display` appends `.fff…` when sub-second precision is non-zero.
        // Without the fraction in the pattern, `[TIMESTAMP].123` would leak into the
        // snapshot and make it non-deterministic.
        settings.add_filter(
            r"\d{4}-\d{2}-\d{2}[ T]\d{2}:\d{2}:\d{2}(?:\.\d+)?",
            "[TIMESTAMP]",
        );
        settings.add_filter(r"\d{4}-\d{2}-\d{2}", "[DATE]");
    }
    if self.redact_auto_ids {
        settings.add_filter(r#""id": \d+"#, r#""id": "[ID]""#);
        settings.add_filter(r#""user_id": \d+"#, r#""user_id": "[USER_ID]""#);
        settings.add_filter(r#""post_id": \d+"#, r#""post_id": "[POST_ID]""#);
        // Generic fallback for any other numeric `*_id` key; the specific keys above were
        // already replaced with strings, so they no longer match `\d+` here.
        settings.add_filter(r#""(\w+_id)": \d+"#, r#""$1": "[FK_ID]""#);
    }
    if self.redact_paths {
        settings.add_filter(r"/[\w/.-]+", "[PATH]");
        settings.add_filter(r"[A-Z]:\\[\w\\.-]+", "[PATH]");
    }
    settings
}
/// Fallback used when the `snapshots` feature is disabled.
///
/// Prints the test configuration to stdout instead of running anything.
///
/// NOTE(review): `SetupFn`/`VerificationFn` are only defined under `snapshots` but are
/// used by the struct's fields unconditionally, so this path presumably only builds once
/// those aliases are made unconditional — verify.
///
/// # Errors
///
/// Never returns an error; the `Result` keeps the signature aligned with the async `run`.
#[cfg(not(feature = "snapshots"))]
pub fn run(self) -> Result<()> {
    println!("Test: {}", self.test_name);
    // `migrations_dir` is an `Option<PathBuf>`, which has no `display()`; render the unset
    // case explicitly instead of calling it on the `Option`.
    let migrations = self
        .migrations_dir
        .as_deref()
        .map_or_else(|| "<none>".to_string(), |dir| dir.display().to_string());
    println!("Migrations: {migrations}");
    println!(
        "Schema: {}, Sequence: {}",
        self.assert_schema, self.assert_sequence
    );
    Ok(())
}
}
/// Converts introspected table metadata into its snapshot form.
///
/// Columns are converted with `db_column_info_to_column_info`; indexes are reduced to
/// their names. Both keep the iteration order of the source maps.
#[cfg(feature = "snapshots")]
#[allow(unused)]
fn table_info_to_schema(info: TableInfo) -> TableSchema {
    let TableInfo {
        columns, indexes, ..
    } = info;
    let columns = columns
        .into_values()
        .map(db_column_info_to_column_info)
        .collect();
    let indexes = indexes.into_values().map(|index| index.name).collect();
    TableSchema { columns, indexes }
}
/// Converts database column metadata into its snapshot form.
///
/// The data type and the default value are stored as their `Debug` renderings.
#[cfg(feature = "snapshots")]
fn db_column_info_to_column_info(col: DbColumnInfo) -> ColumnInfo {
    let DbColumnInfo {
        name,
        data_type,
        nullable,
        default_value,
        is_primary_key,
        ..
    } = col;
    ColumnInfo {
        name,
        data_type: format!("{data_type:?}"),
        nullable,
        default_value: default_value.map(|value| format!("{value:?}")),
        primary_key: is_primary_key,
    }
}
/// Converts a database row into a JSON object mapping column name to value.
///
/// If a column name repeats, the later value wins.
#[cfg(feature = "snapshots")]
#[allow(unused)]
fn row_to_json(row: Row) -> serde_json::Value {
    let mut object = serde_json::Map::new();
    for (column, value) in row.columns {
        object.insert(column, database_value_to_json(value));
    }
    serde_json::Value::Object(object)
}
/// Converts a single [`DatabaseValue`] into JSON for data-sample snapshots.
///
/// - Strings, booleans and every integer width map to JSON strings, booleans and numbers.
/// - Floats map to JSON numbers. Values JSON cannot represent (NaN, ±infinity) become
///   `null`, because `serde_json::Number::from_f64` rejects non-finite values.
/// - `Null` and every `…Opt(None)` variant map to `null`.
/// - `DateTime` is rendered through its `Display` impl.
/// - `Now` and `NowPlus` become the placeholder strings `"NOW"` and
///   `"NOW + <interval Debug>"`.
/// - Decimal and UUID values (behind the `decimal` / `uuid` features) are rendered as
///   strings.
///
/// NOTE(review): the `#[allow(unused)]` looks stale, since `row_to_json` calls this
/// function.
#[cfg(feature = "snapshots")]
#[allow(unused)]
fn database_value_to_json(value: DatabaseValue) -> serde_json::Value {
match value {
DatabaseValue::String(s) | DatabaseValue::StringOpt(Some(s)) => {
serde_json::Value::String(s)
}
DatabaseValue::Bool(b) | DatabaseValue::BoolOpt(Some(b)) => serde_json::Value::Bool(b),
// Integer variants: every width converts losslessly into `serde_json::Number`.
DatabaseValue::Int8(i) | DatabaseValue::Int8Opt(Some(i)) => {
serde_json::Value::Number(i.into())
}
DatabaseValue::UInt8(i) | DatabaseValue::UInt8Opt(Some(i)) => {
serde_json::Value::Number(i.into())
}
DatabaseValue::Int16(i) | DatabaseValue::Int16Opt(Some(i)) => {
serde_json::Value::Number(i.into())
}
DatabaseValue::UInt16(i) | DatabaseValue::UInt16Opt(Some(i)) => {
serde_json::Value::Number(i.into())
}
DatabaseValue::Int32(i) | DatabaseValue::Int32Opt(Some(i)) => {
serde_json::Value::Number(i.into())
}
DatabaseValue::UInt32(i) | DatabaseValue::UInt32Opt(Some(i)) => {
serde_json::Value::Number(i.into())
}
DatabaseValue::Int64(i) | DatabaseValue::Int64Opt(Some(i)) => {
serde_json::Value::Number(i.into())
}
DatabaseValue::UInt64(u) | DatabaseValue::UInt64Opt(Some(u)) => {
serde_json::Value::Number(u.into())
}
// Non-finite floats have no JSON number form, so they fall back to `null`.
DatabaseValue::Real64(f) | DatabaseValue::Real64Opt(Some(f)) => {
serde_json::Number::from_f64(f)
.map_or(serde_json::Value::Null, serde_json::Value::Number)
}
// f32 is widened to f64 first, because `from_f64` only accepts f64.
DatabaseValue::Real32(f) | DatabaseValue::Real32Opt(Some(f)) => {
serde_json::Number::from_f64(f64::from(f))
.map_or(serde_json::Value::Null, serde_json::Value::Number)
}
DatabaseValue::Null
| DatabaseValue::StringOpt(None)
| DatabaseValue::BoolOpt(None)
| DatabaseValue::Int8Opt(None)
| DatabaseValue::UInt8Opt(None)
| DatabaseValue::Int16Opt(None)
| DatabaseValue::UInt16Opt(None)
| DatabaseValue::Int32Opt(None)
| DatabaseValue::UInt32Opt(None)
| DatabaseValue::Int64Opt(None)
| DatabaseValue::UInt64Opt(None)
| DatabaseValue::Real64Opt(None)
| DatabaseValue::Real32Opt(None) => serde_json::Value::Null,
DatabaseValue::DateTime(dt) => serde_json::Value::String(dt.to_string()),
// `Now` / `NowPlus` are relative-time markers rather than concrete values, so they are
// rendered as descriptive placeholder strings.
DatabaseValue::NowPlus(interval) => {
serde_json::Value::String(format!("NOW + {interval:?}"))
}
DatabaseValue::Now => serde_json::Value::String("NOW".to_string()),
// Feature-gated arms: these keep the match exhaustive when the corresponding
// `DatabaseValue` variants are compiled in.
#[cfg(feature = "decimal")]
DatabaseValue::Decimal(d) | DatabaseValue::DecimalOpt(Some(d)) => {
serde_json::Value::String(d.to_string())
}
#[cfg(feature = "decimal")]
DatabaseValue::DecimalOpt(None) => serde_json::Value::Null,
#[cfg(feature = "uuid")]
DatabaseValue::Uuid(d) | DatabaseValue::UuidOpt(Some(d)) => {
serde_json::Value::String(d.to_string())
}
#[cfg(feature = "uuid")]
DatabaseValue::UuidOpt(None) => serde_json::Value::Null,
}
}