#![doc = include_str!("../README.md")]
use std::future::Future;
use std::pin::Pin;
use thiserror::Error;
use turso::{Connection, Error};
/// Boxed, pinned future produced by a function-based migration.
///
/// The future resolves to `Ok(())` on success or to a `turso::Error` on failure.
/// It is required to be `Send` and may borrow data for the lifetime `'a`.
pub type MigrationFuture<'a> = Pin<Box<dyn Future<Output = Result<(), Error>> + Send + 'a>>;
/// Errors reported while bringing a database up to date.
#[derive(Error, Debug)]
pub enum MigrationError {
/// An error returned by the `turso` crate; converted automatically via `From`.
#[error("Turso error: {0}")]
Turso(#[from] Error),
/// The database's `user_version` cannot be handled by the migration set.
///
/// The first field is the `user_version` read from the database; the second is
/// the upper bound reported in the message (the number of known migrations).
#[error("Database has user_version {0}, but it must be between 0 and {1}")]
InvalidUserVersion(i32, usize),
}
/// A single schema migration step.
///
/// Both variants carry a `name`; the runner in this file only uses it for log
/// output when the `tracing` feature is enabled.
#[derive(Clone, Copy)]
pub enum Migration<'a> {
/// A migration expressed as SQL text, executed as a batch.
Sql {
/// Name of the migration.
name: &'a str,
/// SQL text to execute.
sql: &'a str,
},
/// A migration expressed as a function returning a [`MigrationFuture`].
Fn {
/// Name of the migration.
name: &'a str,
/// Function pointer that receives the connection to operate on and returns
/// the future performing the migration.
f: fn(&Connection) -> MigrationFuture,
},
}
impl<'a> Migration<'a> {
pub const fn up(name: &'a str, sql: &'a str) -> Self {
Self::Sql { name, sql }
}
pub const fn up_fn(name: &'a str, f: fn(&Connection) -> MigrationFuture) -> Self {
Self::Fn { name, f }
}
}
/// Builds a SQL [`Migration`] from a file whose contents are embedded at compile time.
///
/// Expands to `Migration::up($path, include_str!($path))`: the path literal is used
/// as the migration name and the file contents as the SQL text. The path is handed
/// to `include_str!` unchanged, so it is resolved by that macro's rules.
#[macro_export]
macro_rules! up_file {
($path:literal) => {
$crate::Migration::up($path, include_str!($path))
};
}
/// Builds a function-based [`Migration`] from a name and a path to a function.
///
/// The expansion defines a local `wrapper` function that calls `$func(conn)` and
/// boxes and pins the value it returns, then passes `$name` and `wrapper` to
/// `Migration::up_fn`.
///
/// The expansion refers to `turso::Connection` by that exact path, so `turso` must
/// be resolvable as a crate name wherever the macro is invoked.
#[macro_export]
macro_rules! up_fn {
($name:expr, $func:path) => {{
// Adapter with the exact `fn(&Connection) -> MigrationFuture` shape.
fn wrapper(conn: &turso::Connection) -> $crate::MigrationFuture<'_> {
Box::pin($func(conn))
}
$crate::Migration::up_fn($name, wrapper)
}};
}
/// An ordered set of migrations.
///
/// The runner treats the migration at slice index `i` as the step that moves the
/// database to `user_version` `i + 1`.
pub struct Migrations<'a> {
// Borrowed list of steps, applied in slice order.
migrations: &'a [Migration<'a>],
}
impl<'a> Migrations<'a> {
pub const fn new(migrations: &'a [Migration<'a>]) -> Self {
Self { migrations }
}
pub async fn to_latest(&self, conn: &mut Connection) -> Result<usize, (usize, MigrationError)> {
let current_version = match get_user_version(conn).await {
Ok(v) => v,
Err(e) => return Err((0, e.into())),
};
let target_version = self.migrations.len() as i32;
if current_version == target_version {
return Ok(0);
}
if current_version < 0 || current_version > target_version {
return Err((
0,
MigrationError::InvalidUserVersion(current_version, self.migrations.len()),
));
}
let mut applied_count = 0;
for (i, migration) in self
.migrations
.iter()
.enumerate()
.skip(current_version as usize)
{
let version = (i + 1) as i32;
let tx = match conn.transaction().await {
Ok(tx) => tx,
Err(e) => return Err((applied_count, e.into())),
};
let result = match migration {
Migration::Sql { name: _name, sql } => tx.execute_batch(sql).await.map(|_| ()),
Migration::Fn { name: _name, f } => f(&tx).await,
};
#[cfg(feature = "tracing")]
let migration_name_log = match migration {
Migration::Sql { name, .. } => *name,
Migration::Fn { name, .. } => *name,
};
if let Err(e) = result {
#[cfg(feature = "tracing")]
tracing::error!(error = ?e, migration_name = %migration_name_log, "Migration failed");
return Err((applied_count, e.into()));
}
if let Err(e) = set_user_version(&tx, version).await {
#[cfg(feature = "tracing")]
tracing::error!(error = ?e, migration_name = %migration_name_log, "Failed to update user_version");
return Err((applied_count, e.into()));
}
if let Err(e) = tx.commit().await {
#[cfg(feature = "tracing")]
tracing::error!(error = ?e, migration_name = %migration_name_log, "Failed to commit transaction");
return Err((applied_count, e.into()));
}
#[cfg(feature = "tracing")]
tracing::debug!(migration_name = %migration_name_log, "Migration applied successfully");
applied_count += 1;
}
Ok(applied_count)
}
pub async fn run_all_in_memory(&self) -> Result<Connection, MigrationError> {
let db = turso::Builder::new_local(":memory:").build().await?;
let mut conn = db.connect()?;
if let Err((_, e)) = self.to_latest(&mut conn).await {
return Err(e);
}
Ok(conn)
}
}
/// Reads the database's `user_version` pragma.
///
/// Returns `0` if the pragma query invokes the callback for no rows.
///
/// # Errors
///
/// Returns the error from the pragma query itself, or the error produced when
/// column 0 of the returned row cannot be read as an `i32`. The latter used to be
/// a panic (`unwrap`) inside the callback.
async fn get_user_version(conn: &Connection) -> Result<i32, Error> {
    // The callback can only signal success or failure, so the decoded value (or the
    // decoding error) is handed out through a write-once cell. Only the first row's
    // outcome is kept; `user_version` is expected to yield a single row.
    let decoded = std::sync::OnceLock::new();
    conn.pragma_query("user_version", |row| {
        // `set` only fails if a value is already stored; ignoring that keeps the first row.
        let _ = decoded.set(row.get::<i32>(0));
        Ok(())
    })
    .await?;
    decoded.into_inner().unwrap_or(Ok(0))
}
/// Sets the database's `user_version` pragma to `version`.
///
/// Whatever value the pragma update yields is discarded; only success or failure
/// is reported.
async fn set_user_version(conn: &Connection, version: i32) -> Result<(), Error> {
    conn.pragma_update("user_version", version)
        .await
        .map(|_| ())
}
#[cfg(test)]
mod test {
    use super::*;
    use turso::Builder;

    /// Shared fixture: one SQL migration followed by one function migration.
    const MIGRATIONS: Migrations = Migrations::new(&[
        Migration::up("001", "CREATE TABLE friend(name TEXT NOT NULL);"),
        up_fn!("002", my_complex_migration),
    ]);

    /// Function migration used by `MIGRATIONS`: creates `users` and inserts one row.
    async fn my_complex_migration(conn: &Connection) -> turso::Result<()> {
        let statements = [
            "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)",
            "INSERT INTO users (name) VALUES ('Alice')",
        ];
        for statement in statements {
            conn.execute(statement, ()).await?;
        }
        Ok(())
    }

    /// Opens a connection to a fresh in-memory database.
    async fn get_in_memory_conn() -> Connection {
        let database = Builder::new_local(":memory:").build().await.unwrap();
        database.connect().unwrap()
    }

    /// A fresh database receives both migrations and ends at version 2.
    #[tokio::test]
    async fn test_migrations_success() {
        let mut db = get_in_memory_conn().await;

        let applied = MIGRATIONS.to_latest(&mut db).await.unwrap();
        assert_eq!(applied, 2);
        assert_eq!(get_user_version(&db).await.unwrap(), 2);

        // Both tables created by the migrations must accept inserts.
        let inserts = [
            "INSERT INTO friend (name) VALUES ('test')",
            "INSERT INTO users (name) VALUES ('test')",
        ];
        for insert in inserts {
            db.execute(insert, ()).await.unwrap();
        }
    }

    /// Running the same set twice applies nothing the second time.
    #[tokio::test]
    async fn test_idempotency() {
        let mut db = get_in_memory_conn().await;

        let first_run = MIGRATIONS.to_latest(&mut db).await.unwrap();
        assert_eq!(first_run, 2);

        let second_run = MIGRATIONS.to_latest(&mut db).await.unwrap();
        assert_eq!(second_run, 0);
    }

    /// A database already at version 1 only receives the second migration.
    #[tokio::test]
    async fn test_partial_migration() {
        let mut db = get_in_memory_conn().await;
        db.pragma_update("user_version", 1).await.unwrap();

        let applied = MIGRATIONS.to_latest(&mut db).await.unwrap();
        assert_eq!(applied, 1);
        assert_eq!(get_user_version(&db).await.unwrap(), 2);
    }

    /// `run_all_in_memory` yields a connection that is at the latest version.
    #[tokio::test]
    async fn test_validate_helper() {
        let conn = MIGRATIONS.run_all_in_memory().await;
        assert!(conn.is_ok());

        let conn = conn.unwrap();
        assert_eq!(get_user_version(&conn).await.unwrap(), 2);
    }

    /// A negative `user_version` is rejected before anything is applied.
    #[tokio::test]
    async fn test_negative_user_version() {
        let mut db = get_in_memory_conn().await;
        db.pragma_update("user_version", -5).await.unwrap();

        let outcome = MIGRATIONS.to_latest(&mut db).await;
        match &outcome {
            Err((0, MigrationError::InvalidUserVersion(found, 2))) => assert_eq!(*found, -5),
            other => panic!("Expected InvalidUserVersion error, got {:?}", other),
        }
    }

    /// A `user_version` beyond the number of migrations is rejected.
    #[tokio::test]
    async fn test_user_version_too_high() {
        let mut db = get_in_memory_conn().await;
        db.pragma_update("user_version", 10).await.unwrap();

        let outcome = MIGRATIONS.to_latest(&mut db).await;
        match &outcome {
            Err((0, MigrationError::InvalidUserVersion(found, limit))) => {
                assert_eq!(*found, 10);
                assert_eq!(*limit, 2);
            }
            other => panic!("Expected InvalidUserVersion error, got {:?}", other),
        }
    }

    /// When the second migration fails, the first stays applied and is counted.
    #[tokio::test]
    async fn test_failing_migration() {
        const BROKEN_MIGRATIONS: Migrations = Migrations::new(&[
            Migration::up("001", "CREATE TABLE ok (id int)"),
            Migration::up("002", "SELECT * FROM non_existent_table"),
        ]);
        let mut db = get_in_memory_conn().await;

        let outcome = BROKEN_MIGRATIONS.to_latest(&mut db).await;
        match &outcome {
            Err((1, MigrationError::Turso(_))) => {}
            other => panic!("Expected Turso error, got {:?}", other),
        }

        assert_eq!(get_user_version(&db).await.unwrap(), 1);
    }

    /// `up_file!` embeds a SQL file and the resulting migration creates its tables.
    #[tokio::test]
    async fn test_up_file_macro() {
        const FILE_MIGRATIONS: Migrations =
            Migrations::new(&[up_file!("../tests/migration-files/001_test.sql")]);
        let mut db = get_in_memory_conn().await;

        let result = FILE_MIGRATIONS.to_latest(&mut db).await;
        assert!(result.is_ok(), "File migration failed: {:?}", result.err());
        assert_eq!(get_user_version(&db).await.unwrap(), 1);

        let inserts = [
            "INSERT INTO file_test (id) VALUES (1)",
            "INSERT INTO file_test_2 (id) VALUES (1)",
        ];
        for insert in inserts {
            db.execute(insert, ()).await.unwrap();
        }
    }

    /// Migrations can borrow their name and SQL from runtime-built strings.
    #[tokio::test]
    async fn test_dynamic_migration_name() {
        let mut db = get_in_memory_conn().await;

        let name = String::from("003_dynamic");
        let sql = String::from("CREATE TABLE dynamic(id int)");
        let steps = [Migration::up(&name, &sql)];
        let runner = Migrations::new(&steps);

        let applied = runner.to_latest(&mut db).await.unwrap();
        assert_eq!(applied, 1);
        assert_eq!(get_user_version(&db).await.unwrap(), 1);
    }
}