Skip to main content

ruest_db_migrate/
lib.rs

//! Runs RuestDB migrations (Prisma-style layout: `prisma/migrations` → `ruestdb/migrations`).
2
3use std::fs;
4use std::path::{Path, PathBuf};
5
6use ruest_db_runtime::RuestDb;
7use sqlx::Executor;
8use thiserror::Error;
9
/// Directory, relative to the project root, that holds one sub-directory per migration.
pub const MIGRATIONS_DIR: &str = "ruestdb/migrations";
/// File name of the schema, resolved relative to the project root.
pub const SCHEMA_FILE: &str = "schema.ruest";
12
/// Errors returned by the migration commands in this module.
#[derive(Debug, Error)]
pub enum MigrateError {
    /// An I/O operation failed (wraps [`std::io::Error`]).
    #[error("io error: {0}")]
    Io(#[from] std::io::Error),

    /// The schema source could not be parsed (wraps `ruest_db_parser::ParseError`).
    #[error("parse error: {0}")]
    Parse(#[from] ruest_db_parser::ParseError),

    /// A database operation failed (wraps `sqlx::Error`).
    #[error("database error: {0}")]
    Db(#[from] sqlx::Error),

    /// Any other failure, carried as a free-form message
    /// (e.g. a connection error or an invalid migration directory name).
    #[error("{0}")]
    Message(String),
}
27
28/// Crée `schema.ruest` et le dossier migrations (projet neuf).
29pub fn db_init(project_root: &Path) -> Result<(), MigrateError> {
30    let schema_path = project_root.join(SCHEMA_FILE);
31    if !schema_path.exists() {
32        fs::write(&schema_path, DEFAULT_SCHEMA)?;
33        println!("Created {}", schema_path.display());
34    }
35
36    let migrations = project_root.join(MIGRATIONS_DIR);
37    fs::create_dir_all(&migrations)?;
38    println!("Created {}", migrations.display());
39    Ok(())
40}
41
42/// Génère `generated/ruestdb/` (client Rust type-safe).
43pub fn generate_client(project_root: &Path) -> Result<(), MigrateError> {
44    let schema_src = fs::read_to_string(project_root.join(SCHEMA_FILE))?;
45    let schema = ruest_db_parser::parse_schema(&schema_src)?;
46    let generated = ruest_db_codegen::generate_client(&schema);
47
48    let out = project_root.join("generated/ruestdb");
49    fs::create_dir_all(&out)?;
50    fs::write(out.join("mod.rs"), generated.root)?;
51
52    for (name, src) in generated.modules {
53        fs::write(out.join(format!("{name}.rs")), src)?;
54    }
55
56    println!("Generated RuestDB client in {}", out.display());
57    Ok(())
58}
59
60/// Génère une migration SQL depuis `schema.ruest`.
61pub fn create_migration(project_root: &Path, name: &str) -> Result<PathBuf, MigrateError> {
62    let schema_src = fs::read_to_string(project_root.join(SCHEMA_FILE))?;
63    let schema = ruest_db_parser::parse_schema(&schema_src)?;
64    let sql = ruest_db_codegen::generate_migration_sql(&schema);
65
66    let stamp = chrono_lite_timestamp();
67    let dir = project_root.join(MIGRATIONS_DIR).join(format!("{stamp}_{name}"));
68    fs::create_dir_all(&dir)?;
69    let file = dir.join("migration.sql");
70    fs::write(&file, sql)?;
71    println!("Created migration {}", dir.display());
72    Ok(dir)
73}
74
75/// Applique les migrations en attente (`ruest migrate dev` / `deploy`).
76pub async fn migrate_apply(project_root: &Path) -> Result<(), MigrateError> {
77    let db = RuestDb::connect_from_env()
78        .await
79        .map_err(|e| MigrateError::Message(e.to_string()))?;
80
81    ensure_migrations_table(db.pool()).await?;
82
83    let applied = applied_migrations(db.pool()).await?;
84    let mut pending = list_migrations(project_root)?;
85    pending.sort();
86
87    for dir in pending {
88        let name = dir
89            .file_name()
90            .and_then(|n| n.to_str())
91            .ok_or_else(|| MigrateError::Message("invalid migration dir".into()))?;
92        if applied.iter().any(|a| a == name) {
93            continue;
94        }
95        let sql_path = dir.join("migration.sql");
96        let sql = fs::read_to_string(&sql_path)?;
97        tracing::info!(migration = name, "applying");
98        db.pool().execute(sql.as_str()).await?;
99        sqlx::query("INSERT INTO _ruestdb_migrations (name) VALUES ($1)")
100            .bind(name)
101            .execute(db.pool())
102            .await?;
103        println!("Applied {name}");
104    }
105
106    Ok(())
107}
108
109/// Supprime les tables et réapplique (dangereux — dev uniquement).
110pub async fn migrate_reset(project_root: &Path) -> Result<(), MigrateError> {
111    let db = RuestDb::connect_from_env()
112        .await
113        .map_err(|e| MigrateError::Message(e.to_string()))?;
114
115    let schema_src = fs::read_to_string(project_root.join(SCHEMA_FILE))?;
116    let schema = ruest_db_parser::parse_schema(&schema_src)?;
117
118    for model in schema.models.iter().rev() {
119        let table = ruest_db_codegen::table_name(&model.name);
120        let sql = format!("DROP TABLE IF EXISTS \"{table}\" CASCADE");
121        db.pool().execute(sql.as_str()).await.ok();
122    }
123
124    sqlx::query("DROP TABLE IF EXISTS _ruestdb_migrations CASCADE")
125        .execute(db.pool())
126        .await?;
127
128    create_migration(project_root, "init")?;
129    migrate_apply(project_root).await
130}
131
132async fn ensure_migrations_table(pool: &sqlx::PgPool) -> Result<(), sqlx::Error> {
133    sqlx::query(
134        r#"
135        CREATE TABLE IF NOT EXISTS _ruestdb_migrations (
136            name TEXT PRIMARY KEY,
137            applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
138        )
139        "#,
140    )
141    .execute(pool)
142    .await?;
143    Ok(())
144}
145
146async fn applied_migrations(pool: &sqlx::PgPool) -> Result<Vec<String>, sqlx::Error> {
147    let rows = sqlx::query_scalar::<_, String>("SELECT name FROM _ruestdb_migrations ORDER BY name")
148        .fetch_all(pool)
149        .await?;
150    Ok(rows)
151}
152
153fn list_migrations(project_root: &Path) -> Result<Vec<PathBuf>, MigrateError> {
154    let dir = project_root.join(MIGRATIONS_DIR);
155    if !dir.exists() {
156        return Ok(Vec::new());
157    }
158    let mut out = Vec::new();
159    for entry in fs::read_dir(dir)? {
160        let entry = entry?;
161        if entry.file_type()?.is_dir() {
162            out.push(entry.path());
163        }
164    }
165    Ok(out)
166}
167
/// Returns the current time as whole seconds since the Unix epoch, formatted in decimal.
///
/// # Panics
///
/// Panics if the system clock reports a time earlier than the Unix epoch.
fn chrono_lite_timestamp() -> String {
    use std::time::{SystemTime, UNIX_EPOCH};

    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        // A pre-1970 clock is a broken environment; fail with a clear message.
        .expect("system clock is set before the Unix epoch")
        .as_secs()
        .to_string()
}
176
/// Starter schema written by `db_init` when the project has no schema file yet.
const DEFAULT_SCHEMA: &str = r#"// RuestDB schema — https://github.com/hardhacklife/ruest
model User {
  id    String @id @default(uuid())
  email String @unique
  name  String
}
"#;