1use {
2 good_ormning_core::{
3 pg::{
4 graph::utils::PgMigrateCtx,
5 query::utils::{
6 PgFieldInfo,
7 PgTableInfo,
8 },
9 schema::{
10 field::FieldRef,
11 table::TableRef,
12 },
13 },
14 utils::Errs,
15 },
16 convert_case::{
17 Casing,
18 Case,
19 },
20 quote::{
21 format_ident,
22 quote,
23 },
24 std::{
25 collections::HashMap,
26 env,
27 fs,
28 path::Path,
29 },
30};
31pub use {
32 good_ormning_core::pg::*,
33 good_ormning_macros::{
34 good_query_many_pg as good_query_many,
35 good_query_one_pg as good_query_one,
36 good_query_opt_pg as good_query_opt,
37 good_query_pg as good_query,
38 },
39};
40
41pub struct GenerateArgs {
42 pub db_name: Option<String>,
45 pub versions: Vec<(usize, Version)>,
52 pub queries: Vec<Query>,
55}
56
57impl Default for GenerateArgs {
58 fn default() -> Self {
59 Self {
60 db_name: None,
61 versions: vec![],
62 queries: vec![],
63 }
64 }
65}
66
67pub fn generate(args: GenerateArgs) -> Result<(), Vec<String>> {
74 let db_name = args.db_name.as_deref().unwrap_or(good_ormning_core::utils::DEFAULT_DB_NAME);
75 let out_dir = env::var("OUT_DIR").map_err(|e| vec![format!("OUT_DIR not set: {:?}", e)])?;
76 let out_dir = Path::new(&out_dir);
77 let output = out_dir.join(good_ormning_core::utils::rs_file_name(db_name));
78 let json_dir = out_dir.join("good_ormning");
79 if let Err(e) = fs::create_dir_all(&json_dir) {
80 return Err(vec![format!("Error creating directory {:?}: {:?}", json_dir, e)]);
81 }
82 let json_path = json_dir.join(good_ormning_core::utils::json_file_name(db_name));
83
84 {
86 let mut versions_map: HashMap<usize, Version> = if json_path.exists() {
87 serde_json::from_str(&fs::read_to_string(&json_path).unwrap()).unwrap_or_default()
88 } else {
89 HashMap::new()
90 };
91 for (version_i, version) in args.versions.iter() {
92 let entry = versions_map.entry(*version_i).or_insert_with(|| Version::default());
93 for (k, v) in &version.tables {
94 entry.tables.insert(k.clone(), v.clone());
95 }
96 for (k, v) in &version.custom_types {
97 entry.custom_types.insert(k.clone(), v.clone());
98 }
99 }
100 let _ = fs::write(json_path, serde_json::to_string(&versions_map).unwrap());
101 }
102 let mut errs = Errs::new();
103 let mut migrations = vec![];
104 let mut prev_version: Option<Version> = None;
105 let mut prev_version_i: Option<i64> = None;
106 let mut field_lookup: HashMap<TableRef, PgTableInfo> = HashMap::new();
107 for (version_i, version) in &args.versions {
108 let path = rpds::vector![format!("Migration to {}", version_i)];
109 let mut migration = vec![];
110
111 field_lookup.clear();
113 for (table_id, table) in &version.tables {
114 let mut fields: HashMap<FieldRef, PgFieldInfo> = HashMap::new();
115 for (field_id, field) in &table.fields {
116 fields.insert(FieldRef {
117 table_id: table_id.clone(),
118 field_id: field_id.clone(),
119 }, PgFieldInfo {
120 sql_name: field.id.clone(),
121 type_: field.type_.type_.clone(),
122 });
123 }
124 field_lookup.insert(TableRef(table_id.clone()), PgTableInfo {
125 sql_name: table.id.clone(),
126 fields: fields,
127 });
128 }
129 let version_i = *version_i as i64;
130 if let Some(i) = prev_version_i {
131 if version_i != i as i64 + 1 {
132 errs.err(
133 &path,
134 format!(
135 "Version numbers are not consecutive ({} to {}) - was an intermediate version deleted?",
136 i,
137 version_i
138 ),
139 );
140 }
141 }
142
143 {
145 let mut table_sql_names = HashMap::new();
146 for (table_id, table) in &version.tables {
147 table_sql_names.insert(table_id.clone(), table.id.clone());
148 }
149 let mut state = PgMigrateCtx::new(errs.clone(), table_sql_names, version.clone());
150 let current_nodes = version.to_migrate_nodes();
151 let prev_nodes = prev_version.take().map(|s| s.to_migrate_nodes());
152 good_ormning_core::graphmigrate::migrate(&mut state, prev_nodes, ¤t_nodes);
153 for statement in &state.statements {
154 migration.push(quote!{
155 {
156 let query = #statement;
157 txn.execute(query, &[]).await.to_good_error_query(query)?;
158 };
159 });
160 }
161 errs = state.errs.clone();
162 }
163
164 let pascal_db_name: String = db_name.to_case(Case::Pascal);
166 let enum_name = format_ident!("Db{}Versions", pascal_db_name);
167 let newtype_name = format_ident!("Db{}{}", pascal_db_name, version_i as usize);
168 let enum_variant = format_ident!("V{}", version_i as usize);
169 migrations.push(quote!{
170 if version < #version_i {
171 #(#migration) * {
172 let query = "update __good_version set version = $1";
173 good_ormning:: runtime:: pg:: PgConnection:: execute(
174 &mut txn,
175 query,
176 &[& #version_i]
177 ).await.to_good_error_query(query) ?;
178 }
179 if let Some(callback) = & callback {
180 callback(#enum_name::#enum_variant(#newtype_name(&mut txn))).await ?;
181 }
182 }
183 });
184
185 prev_version = Some(version.clone());
187 prev_version_i = Some(version_i);
188 }
189
190 let last_version_i = prev_version_i.unwrap() as i64;
192 let pascal_db_name: String = db_name.to_case(Case::Pascal);
193 let enum_name = format_ident!("Db{}Versions", pascal_db_name);
194 let mut enum_variants = vec![];
195 let mut db_types = vec![];
196 for (version_i, _) in &args.versions {
197 let newtype_name = format_ident!("Db{}{}", pascal_db_name, version_i);
198 let enum_variant = format_ident!("V{}", version_i);
199 enum_variants.push(quote!(#enum_variant(#newtype_name <'a >)));
200 db_types.push(quote!{
201 pub struct #newtype_name <'a >(pub &'a mut dyn good_ormning:: runtime:: pg:: PgConnection);
202 });
203 }
204 let latest_newtype_name = format_ident!("Db{}{}", pascal_db_name, last_version_i as usize);
205 let db_alias_name = format_ident!("Db{}", pascal_db_name);
206 let db_others =
207 good_ormning_core::pg::query::generate::generate_query_functions(
208 &mut errs,
209 field_lookup,
210 args.queries,
211 "",
212 quote!(#latest_newtype_name),
213 );
214 let tokens = quote!{
215 use good_ormning::runtime::GoodError;
216 use good_ormning::runtime::ToGoodError;
217 #(#db_types) * pub enum #enum_name <'a > {
218 #(#enum_variants,) *
219 }
220 pub use #latest_newtype_name as #db_alias_name;
221 async fn init_db(db: & mut impl good_ormning:: runtime:: pg:: PgConnection) -> Result <(),
222 GoodError > {
223 {
224 let query =
225 "create table if not exists __good_version (rid int primary key, version bigint not null, lock int not null);";
226 good_ormning::runtime::pg::PgConnection::execute(db, query, &[]).await.to_good_error_query(query)?;
227 }
228 {
229 let query =
230 "insert into __good_version (rid, version, lock) values (0, -1, 0) on conflict do nothing;";
231 good_ormning::runtime::pg::PgConnection::execute(db, query, &[]).await.to_good_error_query(query)?;
232 }
233 Ok(())
234 }
235 #[
236 doc =
237 "(Initialize and) migrate the database to the latest schema version. Optionally takes a callback which is run after each version, so custom post-schema change code can be run. Use `good_query!` macros with the version parameter to do migrations."
238 ] pub async fn migrate(
239 db: & mut tokio_postgres:: Client,
240 callback: Option <&(
241 dyn for <'b > Fn(
242 #enum_name <'b >
243 ) -> std:: pin:: Pin < Box < dyn std:: future:: Future < Output = Result <(),
244 GoodError >> + Send + 'b >> + Send + Sync
245 ) >
246 ) -> Result <(),
247 GoodError > {
248 init_db(db).await?;
249 loop {
250 let mut txn = db.transaction().await.to_good_error(|| "Failed to start transaction".to_string())?;
251 let migrated = {
252 let query = "update __good_version set lock = 1 where rid = 0 and lock = 0 returning version";
253 let res =
254 good_ormning::runtime::pg::PgConnection::query(&mut txn, query, &[])
255 .await
256 .to_good_error_query(query)?;
257 let version = match res.first() {
258 Some(r) => {
259 let ver: i64 = r.get(0usize);
260 Some(ver)
261 },
262 None => {
263 None
264 },
265 };
266 if let Some(version) = version {
267 if version > #last_version_i {
268 return Err(
269 GoodError(
270 format!(
271 "The latest known version is {}, but the schema is at unknown version {}",
272 #last_version_i,
273 version
274 ),
275 ),
276 );
277 }
278 #(#migrations) * {
279 let query = "update __good_version set lock = 0";
280 good_ormning::runtime::pg::PgConnection::execute(&mut txn, query, &[])
281 .await
282 .to_good_error_query(query)?;
283 }
284 true
285 }
286 else {
287 false
288 }
289 };
290 if migrated {
291 txn.commit().await.to_good_error(|| "Failed to commit transaction".to_string())?;
292 return Ok(());
293 }
294 else {
295 txn.rollback().await.to_good_error(|| "Failed to rollback transaction".to_string())?;
296 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
297 }
298 }
299 }
300 pub async fn get_schema_version(
301 db: & mut impl good_ormning:: runtime:: pg:: PgConnection
302 ) -> Result < Option < i64 >,
303 GoodError > {
304 init_db(db).await?;
305 let query = "select version from __good_version where rid = 0";
306 let res = db.query(query, &[]).await.to_good_error_query(query)?;
307 if let Some(r) = res.first() {
308 let x: i64 = r.get(0usize);
309 if x == -1 {
310 return Ok(None);
311 } else {
312 return Ok(Some(x));
313 }
314 }
315 Ok(None)
316 }
317 #(#db_others) *
318 };
319 match genemichaels_lib::format_str(&tokens.to_string(), &genemichaels_lib::FormatConfig::default()) {
320 Ok(src) => {
321 match fs::write(&output, src.rendered.as_bytes()) {
322 Ok(_) => { },
323 Err(e) => errs.err(
324 &rpds::vector![],
325 format!("Failed to write generated code to {:?}: {:?}", output, e),
326 ),
327 };
328 },
329 Err(e) => {
330 errs.err(&rpds::vector![], format!("Error formatting generated code: {:?}\n{}", e, tokens));
331 },
332 };
333 errs.raise()?;
334 Ok(())
335}
336
#[cfg(test)]
mod test {
    // NOTE(review): `generate` returns `Err` immediately when the `OUT_DIR` environment
    // variable is unset. If it is unset when these tests run, every `is_err()` assertion
    // and every `#[should_panic]` below passes without exercising schema validation —
    // confirm the test environment provides `OUT_DIR`.
    use {
        super::{
            generate,
            GenerateArgs,
            query::expr::SerialExpr,
            schema::field::{
                field_auto,
                field_i32,
                field_str,
            },
            Version,
        },
    };

    /// Generator arguments for the default database name with the given versions and
    /// no queries.
    fn args_for(versions: Vec<(usize, Version)>) -> GenerateArgs {
        GenerateArgs {
            db_name: None,
            versions,
            ..Default::default()
        }
    }

    /// A version containing one table, `table_id`, with a single string field `z437INV6D`.
    fn single_str_field_version(table_id: &'static str) -> Version {
        let schema = Version::new();
        schema.table(table_id).field("z437INV6D", field_str().build());
        schema.build()
    }

    /// Expects failure when version 1 adds a `field_auto` column to a table from version 0.
    #[test]
    fn test_add_field_serial_bad() {
        let initial = single_str_field_version("zMOY9YMCK");
        let extended = {
            let schema = Version::new();
            let table = schema.table("zMOY9YMCK");
            table.field("z437INV6D", field_str().build());
            table.field("zPREUVAOD", field_auto().migrate_fill(SerialExpr::LitAuto(0)).build());
            schema.build()
        };
        let outcome = generate(args_for(vec![(0usize, initial), (1usize, extended)]));
        assert!(outcome.is_err());
    }

    /// Expects a panic when version 1 declares the same field id twice on one table.
    #[test]
    #[should_panic]
    fn test_add_field_dup_bad() {
        let initial = single_str_field_version("zPAO2PJU4");
        let extended = {
            let schema = Version::new();
            let table = schema.table("zQZQ8E2WD");
            table.field("z437INV6D", field_str().build());
            table.field("z437INV6D", field_i32().build());
            schema.build()
        };
        generate(args_for(vec![(0usize, initial), (1usize, extended)])).unwrap();
    }

    /// Expects a panic when version 1 declares the same table id twice.
    #[test]
    #[should_panic]
    fn test_add_table_dup_bad() {
        let initial = single_str_field_version("zSNS34DYI");
        let extended = {
            let schema = Version::new();
            for _ in 0..2 {
                schema.table("zSNS34DYI").field("z437INV6D", field_str().build());
            }
            schema.build()
        };
        generate(args_for(vec![(0usize, initial), (1usize, extended)])).unwrap();
    }

    // NOTE(review): the three tests below pass no queries (the default empty list),
    // although their names refer to query problems — presumably queries were removed
    // from them at some point; verify what they are meant to cover.

    #[test]
    fn test_res_count_none_bad() {
        let only = single_str_field_version("z5S18LWQE");
        assert!(generate(args_for(vec![(0usize, only)])).is_err());
    }

    #[test]
    fn test_select_nothing_bad() {
        let only = single_str_field_version("zOOR88EQ9");
        assert!(generate(args_for(vec![(0usize, only)])).is_err());
    }

    #[test]
    fn test_returning_none_bad() {
        let only = single_str_field_version("zZPD1I2EF");
        assert!(generate(args_for(vec![(0usize, only)])).is_err());
    }
}