use crate::error::BoxDynError;
use crate::migrate::{migration, Migration, MigrationType};
use crate::sql_str::{AssertSqlSafe, SqlSafeStr};
use futures_core::future::BoxFuture;
use std::borrow::Cow;
use std::collections::BTreeSet;
use std::fmt::Debug;
use std::fs;
use std::io;
use std::path::{Path, PathBuf};
/// A source of migrations that can be turned into a list of [`Migration`]s.
///
/// Resolution is asynchronous: `resolve` consumes the source and returns a boxed
/// future whose lifetime is tied to `'s`.
pub trait MigrationSource<'s>: Debug {
/// Consume this source and produce its migrations, or an error if they could not
/// be loaded.
fn resolve(self) -> BoxFuture<'s, Result<Vec<Migration>, BoxDynError>>;
}
impl<'s> MigrationSource<'s> for &'s Path {
fn resolve(self) -> BoxFuture<'s, Result<Vec<Migration>, BoxDynError>> {
self.to_owned().resolve()
}
}
impl MigrationSource<'static> for PathBuf {
fn resolve(self) -> BoxFuture<'static, Result<Vec<Migration>, BoxDynError>> {
Box::pin(async move {
crate::rt::spawn_blocking(move || {
let migrations_with_paths = resolve_blocking(&self)?;
Ok(migrations_with_paths.into_iter().map(|(m, _p)| m).collect())
})
.await
})
}
}
/// A [`MigrationSource`] that pairs a directory path with a [`ResolveConfig`].
///
/// The first field is anything convertible into a [`PathBuf`] naming the migrations
/// directory. The second field controls how the files in it are resolved.
#[derive(Debug)]
pub struct ResolveWith<S>(pub S, pub ResolveConfig);

impl<'s, S: Debug + Into<PathBuf> + Send + 's> MigrationSource<'s> for ResolveWith<S> {
    /// Scan the directory on a blocking task using the attached config.
    ///
    /// Only the migrations are returned; the paths of the files they came from are
    /// discarded.
    fn resolve(self) -> BoxFuture<'s, Result<Vec<Migration>, BoxDynError>> {
        let ResolveWith(source, config) = self;
        Box::pin(async move {
            let dir: PathBuf = source.into();
            // Directory I/O is synchronous, so keep it off the async executor.
            let resolved = crate::rt::spawn_blocking(move || {
                resolve_blocking_with_config(&dir, &config)
            })
            .await?;
            Ok(resolved
                .into_iter()
                .map(|(migration, _file)| migration)
                .collect())
        })
    }
}
/// Error produced when migrations cannot be resolved from a directory.
///
/// The message describes which step failed. `source` carries the underlying I/O
/// error when there is one; it is `None` when the failure is not I/O, for example an
/// unparseable version prefix in a filename.
#[derive(thiserror::Error, Debug)]
#[error("{message}")]
pub struct ResolveError {
/// Human-readable description, used verbatim as the `Display` output.
message: String,
/// Underlying I/O error, reported as this error's `source()`; `None` for non-I/O failures.
#[source]
source: Option<io::Error>,
}
/// Options controlling how migration files are resolved into [`Migration`]s.
#[derive(Debug, Default)]
pub struct ResolveConfig {
    /// Characters skipped when computing each migration's checksum.
    ignored_chars: BTreeSet<char>,
}

impl ResolveConfig {
    /// Create a config that ignores no characters (same as [`Default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Add a single character to ignore when checksumming; returns `self` for chaining.
    pub fn ignore_char(&mut self, c: char) -> &mut Self {
        self.ignore_chars(std::iter::once(c))
    }

    /// Add every character from `chars` to the ignored set; duplicates are harmless.
    /// Returns `self` for chaining.
    pub fn ignore_chars(&mut self, chars: impl IntoIterator<Item = char>) -> &mut Self {
        for ch in chars {
            self.ignored_chars.insert(ch);
        }
        self
    }

    /// Iterate over the ignored characters in ascending order, without duplicates.
    pub fn ignored_chars(&self) -> impl Iterator<Item = char> + '_ {
        self.ignored_chars.iter().cloned()
    }
}
/// Synchronously resolve the migrations in `path` with the default [`ResolveConfig`],
/// returning each migration alongside the path of the file it was read from.
#[doc(hidden)]
pub fn resolve_blocking(path: &Path) -> Result<Vec<(Migration, PathBuf)>, ResolveError> {
    let default_config = ResolveConfig::default();
    resolve_blocking_with_config(path, &default_config)
}
/// Synchronously read every migration file in `path`.
///
/// Each migration is returned paired with the path of the file it was loaded from.
/// The result is sorted by version.
///
/// Only regular files are considered; `fs::metadata` follows symlinks, so a symlink to
/// a file counts. The filename must have the shape `<version>_<rest>.sql`, and other
/// entries are skipped. The version must parse as an `i64`. The migration type comes
/// from `<rest>`, and the description is `<rest>` without its type suffix, with `_`
/// replaced by spaces. A file whose contents begin with `-- no-transaction` (optionally
/// preceded by a UTF-8 BOM) is flagged as not running in a transaction. Characters
/// listed in `config` are ignored when computing each migration's checksum.
///
/// # Errors
///
/// Returns [`ResolveError`] in any of these cases:
/// - `path` cannot be canonicalized or read as a directory.
/// - A directory entry or its metadata cannot be read.
/// - A matching `.sql` file cannot be read.
/// - A matching file's version prefix is not an integer.
#[doc(hidden)]
pub fn resolve_blocking_with_config(
    path: &Path,
    config: &ResolveConfig,
) -> Result<Vec<(Migration, PathBuf)>, ResolveError> {
    let path = path.canonicalize().map_err(|e| ResolveError {
        message: format!("error canonicalizing path {}", path.display()),
        source: Some(e),
    })?;

    let entries = fs::read_dir(&path).map_err(|e| ResolveError {
        message: format!("error reading migration directory {}", path.display()),
        source: Some(e),
    })?;

    let mut migrations = Vec::new();

    for res in entries {
        let entry = res.map_err(|e| ResolveError {
            message: format!(
                "error reading contents of migration directory {}",
                path.display()
            ),
            source: Some(e),
        })?;

        let entry_path = entry.path();

        // `fs::metadata` follows symlinks, so a symlinked migration file is accepted.
        let metadata = fs::metadata(&entry_path).map_err(|e| ResolveError {
            message: format!(
                "error getting metadata of migration path {}",
                entry_path.display()
            ),
            source: Some(e),
        })?;

        if !metadata.is_file() {
            continue;
        }

        let file_name = entry.file_name();
        let file_name = file_name.to_string_lossy();

        // Expected shape: `<version>_<rest>.sql`; anything else is not a migration.
        let (version_part, rest) = match file_name.split_once('_') {
            Some((version_part, rest)) if rest.ends_with(".sql") => (version_part, rest),
            _ => continue,
        };

        let version: i64 = version_part.parse().map_err(|_e| ResolveError {
            message: format!("error parsing migration filename {file_name:?}; expected integer version prefix (e.g. `01_foo.sql`)"),
            source: None,
        })?;

        let migration_type = MigrationType::from_filename(rest);
        // `replace` already yields an owned `String`, so no further copy is needed.
        let description = rest
            .trim_end_matches(migration_type.suffix())
            .replace('_', " ");

        let sql = fs::read_to_string(&entry_path).map_err(|e| ResolveError {
            // The I/O error is attached as `source`; repeating it in the message would
            // make error-chain printers show it twice.
            message: format!(
                "error reading contents of migration {}",
                entry_path.display()
            ),
            source: Some(e),
        })?;

        // Editors on some platforms prepend a UTF-8 BOM, which would otherwise hide
        // the directive from a plain `starts_with`. The SQL itself is left untouched.
        let no_tx = sql
            .trim_start_matches('\u{FEFF}')
            .starts_with("-- no-transaction");
        let checksum = checksum_with(&sql, &config.ignored_chars);

        migrations.push((
            Migration::with_checksum(
                version,
                Cow::Owned(description),
                migration_type,
                AssertSqlSafe(sql).into_sql_str(),
                checksum.into(),
                no_tx,
            ),
            entry_path,
        ));
    }

    // Stable sort: migrations sharing a version keep directory-iteration order.
    migrations.sort_by_key(|(m, _)| m.version);
    Ok(migrations)
}
/// Compute a migration checksum for `sql`, skipping every character in `ignored_chars`.
///
/// With no ignored characters this is exactly `migration::checksum(sql)`. Otherwise
/// the SQL is split on the ignored characters and the resulting fragments are hashed
/// in order. This is intended to match the checksum of the SQL with those characters
/// stripped out.
fn checksum_with(sql: &str, ignored_chars: &BTreeSet<char>) -> Vec<u8> {
    if ignored_chars.is_empty() {
        migration::checksum(sql)
    } else {
        let is_ignored = |ch: char| ignored_chars.contains(&ch);
        migration::checksum_fragments(sql.split(is_ignored))
    }
}
/// Checksumming with ignored characters must equal checksumming the SQL after those
/// characters (whitespace, CR/LF line endings and a UTF-8 BOM) have been removed.
#[test]
fn checksum_with_ignored_chars() {
    let ignored: [char; 5] = [' ', '\t', '\r', '\n', '\u{FEFF}'];

    let sql = "\
        \u{FEFF}create table comment (\r\n\
        \tcomment_id uuid primary key default gen_random_uuid(),\r\n\
        \tpost_id uuid not null references post(post_id),\r\n\
        \tuser_id uuid not null references \"user\"(user_id),\r\n\
        \tcontent text not null,\r\n\
        \tcreated_at timestamptz not null default now()\r\n\
        );\r\n\
        \r\n\
        create index on comment(post_id, created_at);\r\n\
    ";

    let stripped_sql: String = sql.chars().filter(|c| !ignored.contains(c)).collect();
    let ignored_set: BTreeSet<char> = ignored.iter().copied().collect();

    assert_eq!(
        checksum_with(sql, &ignored_set),
        migration::checksum(&stripped_sql)
    );
}