use std::fmt;
use crate::config::DatabaseType;
use crate::database::{db, Database};
use crate::error::{Error, Result};
use crate::internal::ConnectionTrait;
pub use async_trait::async_trait;
/// A unit of seed data that a [`Seeder`] can run against the database.
///
/// Only `name` and `run` have to be implemented; `rollback`, `priority` and
/// `depends_on` come with defaults.
#[async_trait]
pub trait Seed: Send + Sync {
/// Name identifying this seed. `Seeder` stores it in the `_seeds` table and
/// uses it for lookups by name and for dependency checks.
fn name(&self) -> &str;
/// Applies the seed using the given database handle.
async fn run(&self, db: &Database) -> Result<()>;
/// Undoes the seed. The default implementation does nothing and returns `Ok(())`.
async fn rollback(&self, _db: &Database) -> Result<()> {
Ok(())
}
/// Ordering key used by `Seeder`; lower values sort first. Defaults to `100`.
fn priority(&self) -> u32 {
100
}
/// Names of seeds that must already have been executed before `Seeder::run`
/// will run this one. Defaults to no dependencies.
fn depends_on(&self) -> Vec<&str> {
Vec::new()
}
}
/// Collects [`Seed`] implementations and runs or rolls them back, tracking
/// which ones have been executed in the `_seeds` table.
pub struct Seeder {
// Registered seeds, in registration order.
seeds: Vec<Box<dyn Seed>>,
}
impl Seeder {
/// Creates a seeder with no seeds registered.
pub fn new() -> Self {
    let seeds = Vec::new();
    Self { seeds }
}
/// Registers `seed` and returns the seeder so calls can be chained.
///
/// The seed is boxed and appended after any seeds registered earlier.
pub fn add<S: Seed + 'static>(self, seed: S) -> Self {
    self.add_boxed(Box::new(seed))
}
/// Registers an already boxed seed and returns the seeder for chaining.
#[doc(hidden)]
pub fn add_boxed(self, seed: Box<dyn Seed>) -> Self {
    let mut seeder = self;
    seeder.seeds.push(seed);
    seeder
}
/// Runs every registered seed that has not been executed yet.
///
/// Seeds are visited in the order produced by
/// `sort_seeds_by_priority_and_deps`. A seed whose name is already recorded in
/// the `_seeds` table, or that was already executed earlier in this same call
/// (for example because it was registered twice), is reported as skipped.
/// Each executed seed is recorded in `_seeds` right after its `run` succeeds.
///
/// # Errors
/// Returns a configuration error when a seed depends on a name that is neither
/// recorded nor executed earlier in this call, and propagates any error from
/// the table setup, the seed itself, or the bookkeeping queries. Seeds executed
/// before the failure stay recorded.
pub async fn run(&self) -> Result<SeedResult> {
    self.ensure_seeds_table().await?;

    // Names that count as "done": everything already recorded plus whatever
    // this call executes. A set avoids allocating a `String` for every
    // membership test and lets duplicate registrations be skipped instead of
    // being re-run and then failing on the UNIQUE constraint.
    let mut done: std::collections::HashSet<String> =
        self.get_executed_seeds().await?.into_iter().collect();

    let mut result = SeedResult::new();
    let database = db();

    for seed in self.sort_seeds_by_priority_and_deps() {
        let name = seed.name();

        if done.contains(name) {
            result.skipped.push(SeedInfo {
                name: name.to_string(),
            });
            continue;
        }

        // Refuse to run a seed before all of its dependencies are done.
        if let Some(missing) = seed
            .depends_on()
            .into_iter()
            .find(|dep| !done.contains(*dep))
        {
            return Err(Error::configuration(format!(
                "Seed '{}' depends on '{}' which has not been executed",
                name, missing
            )));
        }

        log_seed_start(name);
        seed.run(database).await?;
        self.record_seed(name).await?;
        done.insert(name.to_string());
        result.executed.push(SeedInfo {
            name: name.to_string(),
        });
        log_seed_complete(name);
    }

    Ok(result)
}
pub async fn run_seed(&self, seed_name: &str) -> Result<SeedResult> {
self.ensure_seeds_table().await?;
let database = db();
let mut result = SeedResult::new();
for seed in &self.seeds {
if seed.name() == seed_name {
log_seed_start(seed_name);
seed.run(database).await?;
let executed = self.get_executed_seeds().await?;
if !executed.contains(&seed_name.to_string()) {
self.record_seed(seed_name).await?;
}
result.executed.push(SeedInfo {
name: seed_name.to_string(),
});
log_seed_complete(seed_name);
return Ok(result);
}
}
Err(Error::not_found(format!("Seed '{}' not found", seed_name)))
}
/// Rolls back the most recently recorded seed.
///
/// The last name returned by `get_executed_seeds` is looked up among the
/// registered seeds; if found, its `rollback` is called and its record is
/// removed. The result is empty when nothing is recorded or when the last
/// recorded name does not match any registered seed.
///
/// # Errors
/// Propagates failures from the table setup, the seed's `rollback`, or the
/// bookkeeping queries.
pub async fn rollback(&self) -> Result<SeedResult> {
    self.ensure_seeds_table().await?;
    let history = self.get_executed_seeds().await?;
    let mut outcome = SeedResult::new();

    let latest = match history.last() {
        Some(name) => name.as_str(),
        None => return Ok(outcome),
    };

    let database = db();
    let registered = self.seeds.iter().find(|seed| seed.name() == latest);
    if let Some(seed) = registered {
        log_seed_rollback(latest);
        seed.rollback(database).await?;
        self.remove_seed_record(latest).await?;
        outcome.rolled_back.push(SeedInfo {
            name: seed.name().to_string(),
        });
    }

    Ok(outcome)
}
pub async fn rollback_seed(&self, seed_name: &str) -> Result<SeedResult> {
self.ensure_seeds_table().await?;
let database = db();
let mut result = SeedResult::new();
for seed in &self.seeds {
if seed.name() == seed_name {
log_seed_rollback(seed_name);
seed.rollback(database).await?;
self.remove_seed_record(seed_name).await?;
result.rolled_back.push(SeedInfo {
name: seed_name.to_string(),
});
return Ok(result);
}
}
Err(Error::not_found(format!("Seed '{}' not found", seed_name)))
}
/// Calls `rollback` up to `steps` times and merges the rolled-back seeds.
///
/// Stops early as soon as one `rollback` call rolls nothing back.
///
/// # Errors
/// Propagates the first error returned by `rollback`.
pub async fn rollback_steps(&self, steps: usize) -> Result<SeedResult> {
    let mut combined = SeedResult::new();
    let mut remaining = steps;

    while remaining > 0 {
        let mut step = self.rollback().await?;
        if step.rolled_back.is_empty() {
            break;
        }
        combined.rolled_back.append(&mut step.rolled_back);
        remaining -= 1;
    }

    Ok(combined)
}
/// Rolls back every recorded seed, most recent first.
///
/// Performs as many `rollback` steps as there are recorded seeds; it may stop
/// earlier if a step rolls nothing back (see `rollback_steps`).
///
/// # Errors
/// Propagates failures from the table setup, the bookkeeping query, or any
/// rollback step.
pub async fn reset(&self) -> Result<SeedResult> {
    // Make sure the tracking table exists before querying it, like every other
    // public entry point does; otherwise a fresh database yields a query error
    // instead of an empty result.
    self.ensure_seeds_table().await?;
    let executed = self.get_executed_seeds().await?;
    self.rollback_steps(executed.len()).await
}
/// Resets all seeds and then runs them again.
///
/// The returned result combines the `rolled_back` list from the reset with the
/// `executed` and `skipped` lists from the subsequent run.
///
/// # Errors
/// Propagates the first error from `reset` or `run`.
pub async fn refresh(&self) -> Result<SeedResult> {
    let SeedResult { rolled_back, .. } = self.reset().await?;
    let SeedResult {
        executed, skipped, ..
    } = self.run().await?;

    Ok(SeedResult {
        executed,
        skipped,
        rolled_back,
    })
}
/// Reports, for every registered seed, whether it is recorded as executed.
///
/// Entries come in the order produced by `sort_seeds_by_priority_and_deps`.
///
/// # Errors
/// Propagates failures from the table setup or the bookkeeping query.
pub async fn status(&self) -> Result<Vec<SeedStatus>> {
    self.ensure_seeds_table().await?;
    let recorded = self.get_executed_seeds().await?;

    let report = self
        .sort_seeds_by_priority_and_deps()
        .into_iter()
        .map(|seed| {
            let name = seed.name();
            SeedStatus {
                name: name.to_string(),
                executed: recorded.iter().any(|done| done == name),
                priority: seed.priority(),
            }
        })
        .collect();

    Ok(report)
}
/// Returns the registered seeds ordered by priority, with dependencies first.
///
/// Seeds are stably sorted by ascending `priority()`. Each seed is then emitted
/// only after the registered seeds named in its `depends_on()`, so a dependency
/// with a larger priority value is pulled ahead of its dependant. When no seed
/// declares dependencies the result is exactly the stable priority order.
///
/// Dependency names that match no registered seed are ignored here, and a
/// dependency cycle is broken at the edge that closes it; in both cases the
/// check in `run` decides whether the seed may actually execute.
fn sort_seeds_by_priority_and_deps(&self) -> Vec<&Box<dyn Seed>> {
    /// Progress of a seed during the depth-first placement.
    #[derive(Clone, Copy, PartialEq)]
    enum Mark {
        Pending,
        Visiting,
        Placed,
    }

    /// Emits the dependencies of `seeds[idx]` (recursively) and then the seed itself.
    fn place<'a>(
        idx: usize,
        seeds: &[&'a Box<dyn Seed>],
        marks: &mut [Mark],
        ordered: &mut Vec<&'a Box<dyn Seed>>,
    ) {
        // `Placed` means already emitted; `Visiting` means we came back to a
        // seed that is still being resolved, i.e. a cycle. Stop in both cases.
        if marks[idx] != Mark::Pending {
            return;
        }
        marks[idx] = Mark::Visiting;
        for dep in seeds[idx].depends_on() {
            if let Some(dep_idx) = seeds.iter().position(|s| s.name() == dep) {
                place(dep_idx, seeds, marks, ordered);
            }
        }
        marks[idx] = Mark::Placed;
        ordered.push(seeds[idx]);
    }

    let mut by_priority: Vec<&Box<dyn Seed>> = self.seeds.iter().collect();
    by_priority.sort_by_key(|s| s.priority());

    let mut marks = vec![Mark::Pending; by_priority.len()];
    let mut ordered = Vec::with_capacity(by_priority.len());
    for idx in 0..by_priority.len() {
        place(idx, &by_priority, &mut marks, &mut ordered);
    }
    ordered
}
/// Creates the `_seeds` tracking table if it does not exist yet.
///
/// The table has an auto-incrementing `id`, a unique `name`, and an
/// `executed_at` column defaulting to the current timestamp. The DDL is picked
/// per backend as reported by `detect_database_type`.
///
/// # Errors
/// Returns a query error carrying the driver's message when the statement fails.
async fn ensure_seeds_table(&self) -> Result<()> {
let database = db();
let db_type = detect_database_type(database);
// One `CREATE TABLE IF NOT EXISTS` statement per backend; they differ in
// identifier quoting, the auto-increment syntax and the column types.
let sql = match db_type {
DatabaseType::Postgres => {
r#"
CREATE TABLE IF NOT EXISTS "_seeds" (
"id" SERIAL PRIMARY KEY,
"name" VARCHAR(255) NOT NULL UNIQUE,
"executed_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
)
"#
}
DatabaseType::MySQL => {
r#"
CREATE TABLE IF NOT EXISTS `_seeds` (
`id` INT AUTO_INCREMENT PRIMARY KEY,
`name` VARCHAR(255) NOT NULL UNIQUE,
`executed_at` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
)
"#
}
DatabaseType::SQLite => {
r#"
CREATE TABLE IF NOT EXISTS "_seeds" (
"id" INTEGER PRIMARY KEY AUTOINCREMENT,
"name" TEXT NOT NULL UNIQUE,
"executed_at" TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
)
"#
}
};
// The statement has no parameters, so it is sent unprepared.
database
.__internal_connection()
.execute_unprepared(sql)
.await
.map_err(|e| Error::query(e.to_string()))?;
Ok(())
}
async fn get_executed_seeds(&self) -> Result<Vec<String>> {
let database = db();
use crate::internal::Statement;
let backend = database.__internal_connection().get_database_backend();
let sql = r#"SELECT "name" FROM "_seeds" ORDER BY "executed_at" ASC"#;
let stmt = Statement::from_string(backend, sql.to_string());
let results = database
.__internal_connection()
.query_all_raw(stmt)
.await
.map_err(|e| Error::query(e.to_string()))?;
let mut names = Vec::new();
for row in results {
let name: String = row
.try_get("", "name")
.map_err(|e| Error::query(e.to_string()))?;
names.push(name);
}
Ok(names)
}
async fn record_seed(&self, name: &str) -> Result<()> {
let database = db();
let sql = format!(
r#"INSERT INTO "_seeds" ("name") VALUES ('{}')"#,
name.replace('\'', "''")
);
database
.__internal_connection()
.execute_unprepared(&sql)
.await
.map_err(|e| Error::query(e.to_string()))?;
Ok(())
}
async fn remove_seed_record(&self, name: &str) -> Result<()> {
let database = db();
let sql = format!(
r#"DELETE FROM "_seeds" WHERE "name" = '{}'"#,
name.replace('\'', "''")
);
database
.__internal_connection()
.execute_unprepared(&sql)
.await
.map_err(|e| Error::query(e.to_string()))?;
Ok(())
}
}
impl Default for Seeder {
fn default() -> Self {
Self::new()
}
}
/// Outcome of a seeder operation, grouped by what happened to each seed.
#[derive(Debug, Clone)]
pub struct SeedResult {
/// Seeds that were run by the operation.
pub executed: Vec<SeedInfo>,
/// Seeds that were not run because they were already recorded as executed.
pub skipped: Vec<SeedInfo>,
/// Seeds that were rolled back by the operation.
pub rolled_back: Vec<SeedInfo>,
}
impl SeedResult {
    /// Creates a result whose three lists are all empty.
    fn new() -> Self {
        let (executed, skipped, rolled_back) = (Vec::new(), Vec::new(), Vec::new());
        Self {
            executed,
            skipped,
            rolled_back,
        }
    }

    /// Returns `true` when at least one seed was executed.
    pub fn has_executed(&self) -> bool {
        !self.executed.is_empty()
    }

    /// Returns `true` when at least one seed was rolled back.
    pub fn has_rolled_back(&self) -> bool {
        !self.rolled_back.is_empty()
    }

    /// Number of seeds across the executed, skipped and rolled-back lists.
    pub fn total(&self) -> usize {
        [&self.executed, &self.skipped, &self.rolled_back]
            .iter()
            .map(|list| list.len())
            .sum()
    }
}
/// Prints each non-empty list under its own heading, one seed name per line.
impl fmt::Display for SeedResult {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // (heading, bullet, seeds) for each section, in output order.
        let sections = [
            ("Executed seeds:", '✓', &self.executed),
            ("Skipped seeds (already executed):", '-', &self.skipped),
            ("Rolled back seeds:", '↩', &self.rolled_back),
        ];

        for &(heading, bullet, seeds) in &sections {
            // Empty sections are omitted entirely, heading included.
            if seeds.is_empty() {
                continue;
            }
            writeln!(f, "{}", heading)?;
            for info in seeds {
                writeln!(f, " {} {}", bullet, info.name)?;
            }
        }
        Ok(())
    }
}
/// Identifies a seed inside a [`SeedResult`].
#[derive(Debug, Clone)]
pub struct SeedInfo {
/// Name of the seed.
pub name: String,
}
/// Status of one registered seed as reported by `Seeder::status`.
#[derive(Debug, Clone)]
pub struct SeedStatus {
/// Name of the seed.
pub name: String,
/// Whether the seed's name is recorded as executed.
pub executed: bool,
/// The seed's priority value.
pub priority: u32,
}
/// Formats the status as `[mark] name (priority: N)`, where the mark is `✓`
/// for an executed seed and `○` otherwise.
impl fmt::Display for SeedStatus {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mark = match self.executed {
            true => "✓",
            false => "○",
        };
        write!(f, "[{}] {} (priority: {})", mark, self.name, self.priority)
    }
}
/// Maps the connection's backend to the crate's [`DatabaseType`].
///
/// Any backend other than MySQL and SQLite, including ones not listed
/// explicitly, is reported as Postgres.
fn detect_database_type(database: &Database) -> DatabaseType {
    use crate::internal::DbBackend;

    let backend = database.__internal_connection().get_database_backend();
    match backend {
        DbBackend::Sqlite => DatabaseType::SQLite,
        DbBackend::MySql => DatabaseType::MySQL,
        DbBackend::Postgres => DatabaseType::Postgres,
        // Fallback for any backend not matched above.
        _ => DatabaseType::Postgres,
    }
}
/// Writes a "running" notice for the seed `name` to standard error.
fn log_seed_start(name: &str) {
    eprintln!("Running seed: {name}");
}
/// Writes a "completed" notice for the seed `name` to standard error.
fn log_seed_complete(name: &str) {
    eprintln!("Completed seed: {name}");
}
/// Writes a "rolling back" notice for the seed `name` to standard error.
fn log_seed_rollback(name: &str) {
    eprintln!("Rolling back seed: {name}");
}
// Unit tests for the plain data types in this module. None of them touch a
// database: they only exercise `SeedResult`, `SeedStatus`, `SeedInfo` and
// `Seeder::default`.
#[cfg(test)]
mod tests {
use super::*;
// A fresh result has three empty lists and both flags unset.
#[test]
fn test_seed_result_new() {
let result = SeedResult::new();
assert!(result.executed.is_empty());
assert!(result.skipped.is_empty());
assert!(result.rolled_back.is_empty());
assert!(!result.has_executed());
assert!(!result.has_rolled_back());
}
// Pushing an executed seed sets `has_executed` but not `has_rolled_back`.
#[test]
fn test_seed_result_has_executed() {
let mut result = SeedResult::new();
result.executed.push(SeedInfo {
name: "test_seed".to_string(),
});
assert!(result.has_executed());
assert!(!result.has_rolled_back());
}
// `total` counts entries across all three lists.
#[test]
fn test_seed_result_total() {
let mut result = SeedResult::new();
result.executed.push(SeedInfo {
name: "seed1".to_string(),
});
result.skipped.push(SeedInfo {
name: "seed2".to_string(),
});
result.rolled_back.push(SeedInfo {
name: "seed3".to_string(),
});
assert_eq!(result.total(), 3);
}
// The display output contains the section headings and the seed names.
#[test]
fn test_seed_result_display() {
let mut result = SeedResult::new();
result.executed.push(SeedInfo {
name: "user_seeder".to_string(),
});
result.skipped.push(SeedInfo {
name: "category_seeder".to_string(),
});
let display = format!("{}", result);
assert!(display.contains("user_seeder"));
assert!(display.contains("category_seeder"));
assert!(display.contains("Executed seeds"));
assert!(display.contains("Skipped seeds"));
}
// An executed status shows the check mark, the name and the priority.
#[test]
fn test_seed_status_display() {
let status = SeedStatus {
name: "user_seeder".to_string(),
executed: true,
priority: 100,
};
let display = format!("{}", status);
assert!(display.contains("[✓]"));
assert!(display.contains("user_seeder"));
assert!(display.contains("priority: 100"));
}
// A status that is not executed shows the hollow circle instead.
#[test]
fn test_seed_status_not_executed() {
let status = SeedStatus {
name: "product_seeder".to_string(),
executed: false,
priority: 50,
};
let display = format!("{}", status);
assert!(display.contains("[○]"));
assert!(display.contains("product_seeder"));
}
// `Seeder::default()` starts with no registered seeds.
#[test]
fn test_seeder_default() {
let seeder = Seeder::default();
assert!(seeder.seeds.is_empty());
}
// `SeedInfo` simply carries the name it was built with.
#[test]
fn test_seed_info() {
let info = SeedInfo {
name: "test_seed".to_string(),
};
assert_eq!(info.name, "test_seed");
}
}