use std::fmt;
use crate::config::DatabaseType;
use crate::database::{Database, require_db};
use crate::error::{Error, Result};
use crate::internal::ConnectionTrait;
use crate::tide_info;
pub use async_trait::async_trait;
/// A named unit of seed data that a `Seeder` can run and, optionally, roll back.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait Seed: Send + Sync {
    /// Identifier of this seed; it is the value stored in the `name` column of `_seeds`.
    fn name(&self) -> &str;

    /// Ordering hint: seeds are sorted by this value in ascending order (subject to
    /// `depends_on`), so smaller numbers run earlier. Defaults to `100`.
    fn priority(&self) -> u32 {
        100
    }

    /// Names of seeds that must already have been executed before this one may run.
    ///
    /// Empty by default.
    fn depends_on(&self) -> Vec<&str> {
        vec![]
    }

    /// Applies this seed against `db`.
    async fn run(&self, db: &Database) -> Result<()>;

    /// Undoes `run`. The default implementation does nothing and succeeds.
    async fn rollback(&self, _db: &Database) -> Result<()> {
        Ok(())
    }
}
pub struct Seeder {
seeds: Vec<Box<dyn Seed>>,
}
impl Seeder {
pub fn new() -> Self {
Self { seeds: Vec::new() }
}
#[allow(clippy::should_implement_trait)]
pub fn add<S: Seed + 'static>(mut self, seed: S) -> Self {
self.seeds.push(Box::new(seed));
self
}
#[doc(hidden)]
pub fn add_boxed(mut self, seed: Box<dyn Seed>) -> Self {
self.seeds.push(seed);
self
}
pub async fn run(&self) -> Result<SeedResult> {
self.ensure_seeds_table().await?;
let executed = self.get_executed_seeds().await?;
let mut result = SeedResult::new();
let database = require_db()?;
let sorted_seeds = self.sort_seeds_by_priority_and_deps()?;
for seed in sorted_seeds {
let name = seed.name();
if executed.contains(&name.to_string()) {
result.skipped.push(SeedInfo {
name: name.to_string(),
});
continue;
}
for dep in seed.depends_on() {
if !executed.contains(&dep.to_string())
&& !result.executed.iter().any(|s| s.name == dep)
{
return Err(Error::configuration(format!(
"Seed '{}' depends on '{}' which has not been executed",
name, dep
)));
}
}
log_seed_start(name);
seed.run(&database).await?;
self.record_seed(name).await?;
result.executed.push(SeedInfo {
name: name.to_string(),
});
log_seed_complete(name);
}
Ok(result)
}
pub async fn run_seed(&self, seed_name: &str) -> Result<SeedResult> {
self.ensure_seeds_table().await?;
let database = require_db()?;
let mut result = SeedResult::new();
for seed in &self.seeds {
if seed.name() == seed_name {
log_seed_start(seed_name);
seed.run(&database).await?;
let executed = self.get_executed_seeds().await?;
if !executed.contains(&seed_name.to_string()) {
self.record_seed(seed_name).await?;
}
result.executed.push(SeedInfo {
name: seed_name.to_string(),
});
log_seed_complete(seed_name);
return Ok(result);
}
}
Err(Error::not_found(format!("Seed '{}' not found", seed_name)))
}
pub async fn rollback(&self) -> Result<SeedResult> {
self.ensure_seeds_table().await?;
let executed = self.get_executed_seeds().await?;
let mut result = SeedResult::new();
if executed.is_empty() {
return Ok(result);
}
let last_name = match executed.last() {
Some(n) => n,
None => return Ok(result),
};
let database = require_db()?;
for seed in &self.seeds {
if seed.name() == last_name {
log_seed_rollback(last_name);
seed.rollback(&database).await?;
self.remove_seed_record(last_name).await?;
result.rolled_back.push(SeedInfo {
name: seed.name().to_string(),
});
break;
}
}
Ok(result)
}
pub async fn rollback_seed(&self, seed_name: &str) -> Result<SeedResult> {
self.ensure_seeds_table().await?;
let database = require_db()?;
let mut result = SeedResult::new();
for seed in &self.seeds {
if seed.name() == seed_name {
log_seed_rollback(seed_name);
seed.rollback(&database).await?;
self.remove_seed_record(seed_name).await?;
result.rolled_back.push(SeedInfo {
name: seed_name.to_string(),
});
return Ok(result);
}
}
Err(Error::not_found(format!("Seed '{}' not found", seed_name)))
}
pub async fn rollback_steps(&self, steps: usize) -> Result<SeedResult> {
let mut result = SeedResult::new();
for _ in 0..steps {
let step_result = self.rollback().await?;
if step_result.rolled_back.is_empty() {
break;
}
result.rolled_back.extend(step_result.rolled_back);
}
Ok(result)
}
pub async fn reset(&self) -> Result<SeedResult> {
let executed = self.get_executed_seeds().await?;
self.rollback_steps(executed.len()).await
}
pub async fn refresh(&self) -> Result<SeedResult> {
let reset_result = self.reset().await?;
let run_result = self.run().await?;
Ok(SeedResult {
executed: run_result.executed,
skipped: run_result.skipped,
rolled_back: reset_result.rolled_back,
})
}
pub async fn status(&self) -> Result<Vec<SeedStatus>> {
self.ensure_seeds_table().await?;
let executed = self.get_executed_seeds().await?;
let mut status = Vec::new();
let sorted_seeds = self.sort_seeds_by_priority_and_deps()?;
for seed in sorted_seeds {
let is_executed = executed.contains(&seed.name().to_string());
status.push(SeedStatus {
name: seed.name().to_string(),
executed: is_executed,
priority: seed.priority(),
});
}
Ok(status)
}
fn sort_seeds_by_priority_and_deps(&self) -> Result<Vec<&dyn Seed>> {
use std::collections::{HashMap, HashSet, VecDeque};
let seeds: Vec<_> = self.seeds.iter().collect();
let name_to_idx: HashMap<String, usize> = seeds
.iter()
.enumerate()
.map(|(i, s)| (s.name().to_string(), i))
.collect();
let n = seeds.len();
let mut in_degree = vec![0usize; n];
let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
for (i, seed) in seeds.iter().enumerate() {
for dep in seed.depends_on() {
if let Some(&dep_idx) = name_to_idx.get(dep) {
adj[dep_idx].push(i);
in_degree[i] += 1;
}
}
}
let mut queue: VecDeque<usize> = VecDeque::new();
let mut roots: Vec<usize> = (0..n).filter(|&i| in_degree[i] == 0).collect();
roots.sort_by_key(|&i| seeds[i].priority());
for r in roots {
queue.push_back(r);
}
let mut sorted_indices: Vec<usize> = Vec::with_capacity(n);
let mut visited = HashSet::new();
while let Some(idx) = queue.pop_front() {
if !visited.insert(idx) {
continue;
}
sorted_indices.push(idx);
let mut next: Vec<usize> = Vec::new();
for &neighbor in &adj[idx] {
in_degree[neighbor] -= 1;
if in_degree[neighbor] == 0 {
next.push(neighbor);
}
}
next.sort_by_key(|&i| seeds[i].priority());
for n in next {
queue.push_back(n);
}
}
if sorted_indices.len() < n {
let mut remaining: Vec<usize> = (0..n).filter(|i| !visited.contains(i)).collect();
remaining.sort_by_key(|&i| seeds[i].priority());
let cycle_names = remaining
.into_iter()
.map(|i| seeds[i].name().to_string())
.collect::<Vec<_>>()
.join(", ");
return Err(Error::configuration(format!(
"Circular seed dependency detected involving: {}",
cycle_names
)));
}
Ok(sorted_indices
.into_iter()
.map(|i| seeds[i].as_ref())
.collect())
}
async fn ensure_seeds_table(&self) -> Result<()> {
let database = require_db()?;
let db_type = detect_database_type(&database);
let sql = match db_type {
DatabaseType::Postgres => {
r#"
CREATE TABLE IF NOT EXISTS "_seeds" (
"id" SERIAL PRIMARY KEY,
"name" VARCHAR(255) NOT NULL UNIQUE,
"executed_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
)
"#
}
DatabaseType::MySQL | DatabaseType::MariaDB => {
r#"
CREATE TABLE IF NOT EXISTS `_seeds` (
`id` INT AUTO_INCREMENT PRIMARY KEY,
`name` VARCHAR(255) NOT NULL UNIQUE,
`executed_at` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
)
"#
}
DatabaseType::SQLite => {
r#"
CREATE TABLE IF NOT EXISTS "_seeds" (
"id" INTEGER PRIMARY KEY AUTOINCREMENT,
"name" TEXT NOT NULL UNIQUE,
"executed_at" TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
)
"#
}
};
database
.__internal_connection()?
.execute_unprepared(sql)
.await
.map_err(|e| Error::query(e.to_string()))?;
Ok(())
}
async fn get_executed_seeds(&self) -> Result<Vec<String>> {
let database = require_db()?;
use crate::internal::Statement;
let backend = database.__internal_connection()?.get_database_backend();
let q = |id: &str| quote_identifier(id, backend);
let sql = format!(
"SELECT {} FROM {} ORDER BY {} ASC",
q("name"),
q("_seeds"),
q("executed_at")
);
let stmt = Statement::from_string(backend, sql);
let results = database
.__internal_connection()?
.query_all_raw(stmt)
.await
.map_err(|e| Error::query(e.to_string()))?;
let mut names = Vec::new();
for row in results {
let name: String = row
.try_get("", "name")
.map_err(|e| Error::query(e.to_string()))?;
names.push(name);
}
Ok(names)
}
async fn record_seed(&self, name: &str) -> Result<()> {
let database = require_db()?;
let backend = database.__internal_connection()?.get_database_backend();
let q = |id: &str| quote_identifier(id, backend);
let sql = format!(
"INSERT INTO {} ({}) VALUES ('{}')",
q("_seeds"),
q("name"),
name.replace('\'', "''")
);
database
.__internal_connection()?
.execute_unprepared(&sql)
.await
.map_err(|e| Error::query(e.to_string()))?;
Ok(())
}
async fn remove_seed_record(&self, name: &str) -> Result<()> {
let database = require_db()?;
let backend = database.__internal_connection()?.get_database_backend();
let q = |id: &str| quote_identifier(id, backend);
let sql = format!(
"DELETE FROM {} WHERE {} = '{}'",
q("_seeds"),
q("name"),
name.replace('\'', "''")
);
database
.__internal_connection()?
.execute_unprepared(&sql)
.await
.map_err(|e| Error::query(e.to_string()))?;
Ok(())
}
}
impl Default for Seeder {
fn default() -> Self {
Self::new()
}
}
/// Outcome of a seeder operation: which seeds ran, which were skipped, which were rolled back.
#[derive(Debug, Clone)]
pub struct SeedResult {
    /// Seeds whose `run` completed during the operation.
    pub executed: Vec<SeedInfo>,
    /// Seeds not run because they were already recorded as executed.
    pub skipped: Vec<SeedInfo>,
    /// Seeds whose `rollback` completed during the operation.
    pub rolled_back: Vec<SeedInfo>,
}

impl SeedResult {
    /// Creates a result with all three lists empty.
    fn new() -> Self {
        Self {
            executed: vec![],
            skipped: vec![],
            rolled_back: vec![],
        }
    }

    /// True when at least one seed was executed.
    pub fn has_executed(&self) -> bool {
        !self.executed.is_empty()
    }

    /// True when at least one seed was rolled back.
    pub fn has_rolled_back(&self) -> bool {
        !self.rolled_back.is_empty()
    }

    /// Combined number of executed, skipped and rolled-back entries.
    pub fn total(&self) -> usize {
        [&self.executed, &self.skipped, &self.rolled_back]
            .iter()
            .map(|list| list.len())
            .sum()
    }
}

impl fmt::Display for SeedResult {
    /// Prints one headed section per non-empty list, one marked line per seed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let sections: [(&str, &str, &[SeedInfo]); 3] = [
            ("Executed seeds:", "✓", self.executed.as_slice()),
            ("Skipped seeds (already executed):", "-", self.skipped.as_slice()),
            ("Rolled back seeds:", "↩", self.rolled_back.as_slice()),
        ];
        for &(heading, marker, entries) in sections.iter() {
            if entries.is_empty() {
                continue;
            }
            writeln!(f, "{}", heading)?;
            for entry in entries {
                writeln!(f, " {} {}", marker, entry.name)?;
            }
        }
        Ok(())
    }
}

/// Identifies a single seed inside a [`SeedResult`].
#[derive(Debug, Clone)]
pub struct SeedInfo {
    /// The seed's name.
    pub name: String,
}
/// Execution state of one registered seed.
#[derive(Debug, Clone)]
pub struct SeedStatus {
    /// The seed's name.
    pub name: String,
    /// Whether the seed is recorded as executed.
    pub executed: bool,
    /// The seed's ordering priority.
    pub priority: u32,
}

impl fmt::Display for SeedStatus {
    /// Renders as `[✓] name (priority: N)` when executed, `[○] name (priority: N)` otherwise.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let marker = match self.executed {
            true => "✓",
            false => "○",
        };
        write!(f, "[{}] {} (priority: {})", marker, self.name, self.priority)
    }
}
/// Returns the configured database kind, which selects the `_seeds` DDL dialect.
// Thin wrapper over `Database::backend()`; it returns the config-level `DatabaseType`,
// not the connection's `DbBackend` that `quote_identifier` takes.
fn detect_database_type(database: &Database) -> DatabaseType {
database.backend()
}
/// Quotes `name` as an SQL identifier for `backend`.
///
/// MySQL uses backticks; every other backend uses ANSI double quotes. Any embedded quote
/// character is doubled, which is the standard escape inside a quoted identifier, so the
/// result is always a single well-formed identifier.
fn quote_identifier(name: &str, backend: crate::internal::DbBackend) -> String {
    use crate::internal::DbBackend;

    let quote = match backend {
        DbBackend::MySql => '`',
        _ => '"',
    };
    let mut quoted = String::with_capacity(name.len() + 2);
    quoted.push(quote);
    for ch in name.chars() {
        if ch == quote {
            // Doubling is how both MySQL and ANSI SQL escape the quote character itself.
            quoted.push(quote);
        }
        quoted.push(ch);
    }
    quoted.push(quote);
    quoted
}
/// Returns `true` when seed progress logging is enabled.
///
/// Logging is enabled when `TIDE_LOG_QUERIES` or `TIDE_LOG_SEEDS` is set to any value,
/// including empty or non-UTF-8.
fn seed_logging_enabled() -> bool {
    // `var_os` rather than `var(..).is_ok()`: the latter treats a set-but-non-UTF-8 value as unset.
    std::env::var_os("TIDE_LOG_QUERIES").is_some() || std::env::var_os("TIDE_LOG_SEEDS").is_some()
}

/// Logs that the seed `name` is starting, if seed logging is enabled.
fn log_seed_start(name: &str) {
    if seed_logging_enabled() {
        tide_info!("Seed running: {}", name);
    }
}

/// Logs that the seed `name` finished, if seed logging is enabled.
fn log_seed_complete(name: &str) {
    if seed_logging_enabled() {
        tide_info!("Seed completed: {}", name);
    }
}

/// Logs that the seed `name` is being rolled back, if seed logging is enabled.
fn log_seed_rollback(name: &str) {
    if seed_logging_enabled() {
        tide_info!("Seed rolling back: {}", name);
    }
}
#[cfg(test)]
#[path = "testing/seeding_tests.rs"]
mod tests;