1use std::fs;
4use std::path::{Path, PathBuf};
5
6use ruest_db_runtime::RuestDb;
7use sqlx::Executor;
8use thiserror::Error;
9
/// Directory, relative to the project root, that holds one sub-directory per migration.
pub const MIGRATIONS_DIR: &str = "ruestdb/migrations";
/// File name of the RuestDB schema, resolved relative to the project root.
pub const SCHEMA_FILE: &str = "schema.ruest";
12
13#[derive(Debug, Error)]
14pub enum MigrateError {
15 #[error("io error: {0}")]
16 Io(#[from] std::io::Error),
17
18 #[error("parse error: {0}")]
19 Parse(#[from] ruest_db_parser::ParseError),
20
21 #[error("database error: {0}")]
22 Db(#[from] sqlx::Error),
23
24 #[error("{0}")]
25 Message(String),
26}
27
28pub fn db_init(project_root: &Path) -> Result<(), MigrateError> {
30 let schema_path = project_root.join(SCHEMA_FILE);
31 if !schema_path.exists() {
32 fs::write(&schema_path, DEFAULT_SCHEMA)?;
33 println!("Created {}", schema_path.display());
34 }
35
36 let migrations = project_root.join(MIGRATIONS_DIR);
37 fs::create_dir_all(&migrations)?;
38 println!("Created {}", migrations.display());
39 Ok(())
40}
41
42pub fn generate_client(project_root: &Path) -> Result<(), MigrateError> {
44 let schema_src = fs::read_to_string(project_root.join(SCHEMA_FILE))?;
45 let schema = ruest_db_parser::parse_schema(&schema_src)?;
46 let generated = ruest_db_codegen::generate_client(&schema);
47
48 let out = project_root.join("generated/ruestdb");
49 fs::create_dir_all(&out)?;
50 fs::write(out.join("mod.rs"), generated.root)?;
51
52 for (name, src) in generated.modules {
53 fs::write(out.join(format!("{name}.rs")), src)?;
54 }
55
56 println!("Generated RuestDB client in {}", out.display());
57 Ok(())
58}
59
60pub fn create_migration(project_root: &Path, name: &str) -> Result<PathBuf, MigrateError> {
62 let schema_src = fs::read_to_string(project_root.join(SCHEMA_FILE))?;
63 let schema = ruest_db_parser::parse_schema(&schema_src)?;
64 let sql = ruest_db_codegen::generate_migration_sql(&schema);
65
66 let stamp = chrono_lite_timestamp();
67 let dir = project_root.join(MIGRATIONS_DIR).join(format!("{stamp}_{name}"));
68 fs::create_dir_all(&dir)?;
69 let file = dir.join("migration.sql");
70 fs::write(&file, sql)?;
71 println!("Created migration {}", dir.display());
72 Ok(dir)
73}
74
75pub async fn migrate_apply(project_root: &Path) -> Result<(), MigrateError> {
77 let db = RuestDb::connect_from_env()
78 .await
79 .map_err(|e| MigrateError::Message(e.to_string()))?;
80
81 ensure_migrations_table(db.pool()).await?;
82
83 let applied = applied_migrations(db.pool()).await?;
84 let mut pending = list_migrations(project_root)?;
85 pending.sort();
86
87 for dir in pending {
88 let name = dir
89 .file_name()
90 .and_then(|n| n.to_str())
91 .ok_or_else(|| MigrateError::Message("invalid migration dir".into()))?;
92 if applied.iter().any(|a| a == name) {
93 continue;
94 }
95 let sql_path = dir.join("migration.sql");
96 let sql = fs::read_to_string(&sql_path)?;
97 tracing::info!(migration = name, "applying");
98 db.pool().execute(sql.as_str()).await?;
99 sqlx::query("INSERT INTO _ruestdb_migrations (name) VALUES ($1)")
100 .bind(name)
101 .execute(db.pool())
102 .await?;
103 println!("Applied {name}");
104 }
105
106 Ok(())
107}
108
109pub async fn migrate_reset(project_root: &Path) -> Result<(), MigrateError> {
111 let db = RuestDb::connect_from_env()
112 .await
113 .map_err(|e| MigrateError::Message(e.to_string()))?;
114
115 let schema_src = fs::read_to_string(project_root.join(SCHEMA_FILE))?;
116 let schema = ruest_db_parser::parse_schema(&schema_src)?;
117
118 for model in schema.models.iter().rev() {
119 let table = ruest_db_codegen::table_name(&model.name);
120 let sql = format!("DROP TABLE IF EXISTS \"{table}\" CASCADE");
121 db.pool().execute(sql.as_str()).await.ok();
122 }
123
124 sqlx::query("DROP TABLE IF EXISTS _ruestdb_migrations CASCADE")
125 .execute(db.pool())
126 .await?;
127
128 create_migration(project_root, "init")?;
129 migrate_apply(project_root).await
130}
131
132async fn ensure_migrations_table(pool: &sqlx::PgPool) -> Result<(), sqlx::Error> {
133 sqlx::query(
134 r#"
135 CREATE TABLE IF NOT EXISTS _ruestdb_migrations (
136 name TEXT PRIMARY KEY,
137 applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
138 )
139 "#,
140 )
141 .execute(pool)
142 .await?;
143 Ok(())
144}
145
146async fn applied_migrations(pool: &sqlx::PgPool) -> Result<Vec<String>, sqlx::Error> {
147 let rows = sqlx::query_scalar::<_, String>("SELECT name FROM _ruestdb_migrations ORDER BY name")
148 .fetch_all(pool)
149 .await?;
150 Ok(rows)
151}
152
153fn list_migrations(project_root: &Path) -> Result<Vec<PathBuf>, MigrateError> {
154 let dir = project_root.join(MIGRATIONS_DIR);
155 if !dir.exists() {
156 return Ok(Vec::new());
157 }
158 let mut out = Vec::new();
159 for entry in fs::read_dir(dir)? {
160 let entry = entry?;
161 if entry.file_type()?.is_dir() {
162 out.push(entry.path());
163 }
164 }
165 Ok(out)
166}
167
/// Returns the current Unix time in whole seconds, rendered as a decimal string.
///
/// # Panics
/// Panics if the system clock reports a time before the Unix epoch.
fn chrono_lite_timestamp() -> String {
    let since_epoch = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap();
    since_epoch.as_secs().to_string()
}
176
/// Starter schema that `db_init` writes when the project has no schema file yet.
const DEFAULT_SCHEMA: &str = r#"// RuestDB schema — https://github.com/hardhacklife/ruest
model User {
  id String @id @default(uuid())
  email String @unique
  name String
}
"#;