Skip to main content

good_ormning/pg/
mod.rs

1use {
2    good_ormning_core::{
3        pg::{
4            graph::utils::PgMigrateCtx,
5            query::utils::{
6                PgFieldInfo,
7                PgTableInfo,
8            },
9            schema::{
10                field::FieldRef,
11                table::TableRef,
12            },
13        },
14        utils::Errs,
15    },
16    convert_case::{
17        Casing,
18        Case,
19    },
20    quote::{
21        format_ident,
22        quote,
23    },
24    std::{
25        collections::HashMap,
26        env,
27        fs,
28        path::Path,
29    },
30};
31pub use {
32    good_ormning_core::pg::*,
33    good_ormning_macros::{
34        good_query_many_pg as good_query_many,
35        good_query_one_pg as good_query_one,
36        good_query_opt_pg as good_query_opt,
37        good_query_pg as good_query,
38    },
39};
40
41pub struct GenerateArgs {
42    /// If you have multiple databases, use this to disambiguate them. You'll also need
43    /// to use it in `good_module!` and `good_query!`.
44    pub db_name: Option<String>,
45    /// A list of database version ids and schema versions. The ids must be consecutive
46    /// but can start from any number. Once a version has been applied to a production
47    /// database it shouldn't be modified again (modifications should be done in a new
48    /// version).
49    ///
50    /// These will be turned into migrations as part of the `migrate` function.
51    pub versions: Vec<(usize, Version)>,
52    /// A list of queries to generate type-safe functions for. Use instead of or to
53    /// supplement `good_query!` if you want to programmatically generate queries.
54    pub queries: Vec<Query>,
55}
56
57impl Default for GenerateArgs {
58    fn default() -> Self {
59        Self {
60            db_name: None,
61            versions: vec![],
62            queries: vec![],
63        }
64    }
65}
66
67/// Generate Rust code for migrations and queries, also saves schema type
68/// information for use by queries.
69///
70/// # Returns
71///
72/// * Error - a list of validation or generation errors that occurred
73pub fn generate(args: GenerateArgs) -> Result<(), Vec<String>> {
74    let db_name = args.db_name.as_deref().unwrap_or(good_ormning_core::utils::DEFAULT_DB_NAME);
75    let out_dir = env::var("OUT_DIR").map_err(|e| vec![format!("OUT_DIR not set: {:?}", e)])?;
76    let out_dir = Path::new(&out_dir);
77    let output = out_dir.join(good_ormning_core::utils::rs_file_name(db_name));
78    let json_dir = out_dir.join("good_ormning");
79    if let Err(e) = fs::create_dir_all(&json_dir) {
80        return Err(vec![format!("Error creating directory {:?}: {:?}", json_dir, e)]);
81    }
82    let json_path = json_dir.join(good_ormning_core::utils::json_file_name(db_name));
83
84    // Serialize versions for proc macro
85    {
86        let mut versions_map: HashMap<usize, Version> = if json_path.exists() {
87            serde_json::from_str(&fs::read_to_string(&json_path).unwrap()).unwrap_or_default()
88        } else {
89            HashMap::new()
90        };
91        for (version_i, version) in args.versions.iter() {
92            let entry = versions_map.entry(*version_i).or_insert_with(|| Version::default());
93            for (k, v) in &version.tables {
94                entry.tables.insert(k.clone(), v.clone());
95            }
96            for (k, v) in &version.custom_types {
97                entry.custom_types.insert(k.clone(), v.clone());
98            }
99        }
100        let _ = fs::write(json_path, serde_json::to_string(&versions_map).unwrap());
101    }
102    let mut errs = Errs::new();
103    let mut migrations = vec![];
104    let mut prev_version: Option<Version> = None;
105    let mut prev_version_i: Option<i64> = None;
106    let mut field_lookup: HashMap<TableRef, PgTableInfo> = HashMap::new();
107    for (version_i, version) in &args.versions {
108        let path = rpds::vector![format!("Migration to {}", version_i)];
109        let mut migration = vec![];
110
111        // Prep for current version
112        field_lookup.clear();
113        for (table_id, table) in &version.tables {
114            let mut fields: HashMap<FieldRef, PgFieldInfo> = HashMap::new();
115            for (field_id, field) in &table.fields {
116                fields.insert(FieldRef {
117                    table_id: table_id.clone(),
118                    field_id: field_id.clone(),
119                }, PgFieldInfo {
120                    sql_name: field.id.clone(),
121                    type_: field.type_.type_.clone(),
122                });
123            }
124            field_lookup.insert(TableRef(table_id.clone()), PgTableInfo {
125                sql_name: table.id.clone(),
126                fields: fields,
127            });
128        }
129        let version_i = *version_i as i64;
130        if let Some(i) = prev_version_i {
131            if version_i != i as i64 + 1 {
132                errs.err(
133                    &path,
134                    format!(
135                        "Version numbers are not consecutive ({} to {}) - was an intermediate version deleted?",
136                        i,
137                        version_i
138                    ),
139                );
140            }
141        }
142
143        // Main migrations
144        {
145            let mut table_sql_names = HashMap::new();
146            for (table_id, table) in &version.tables {
147                table_sql_names.insert(table_id.clone(), table.id.clone());
148            }
149            let mut state = PgMigrateCtx::new(errs.clone(), table_sql_names, version.clone());
150            let current_nodes = version.to_migrate_nodes();
151            let prev_nodes = prev_version.take().map(|s| s.to_migrate_nodes());
152            good_ormning_core::graphmigrate::migrate(&mut state, prev_nodes, &current_nodes);
153            for statement in &state.statements {
154                migration.push(quote!{
155                    {
156                        let query = #statement;
157                        txn.execute(query, &[]).await.to_good_error_query(query)?;
158                    };
159                });
160            }
161            errs = state.errs.clone();
162        }
163
164        // Build migration
165        let pascal_db_name: String = db_name.to_case(Case::Pascal);
166        let enum_name = format_ident!("Db{}Versions", pascal_db_name);
167        let newtype_name = format_ident!("Db{}{}", pascal_db_name, version_i as usize);
168        let enum_variant = format_ident!("V{}", version_i as usize);
169        migrations.push(quote!{
170            if version < #version_i {
171                #(#migration) * {
172                    let query = "update __good_version set version = $1";
173                    good_ormning:: runtime:: pg:: PgConnection:: execute(
174                        &mut txn,
175                        query,
176                        &[& #version_i]
177                    ).await.to_good_error_query(query) ?;
178                }
179                if let Some(callback) = & callback {
180                    callback(#enum_name::#enum_variant(#newtype_name(&mut txn))).await ?;
181                }
182            }
183        });
184
185        // Next iter prep
186        prev_version = Some(version.clone());
187        prev_version_i = Some(version_i);
188    }
189
190    // Compile, output
191    let last_version_i = prev_version_i.unwrap() as i64;
192    let pascal_db_name: String = db_name.to_case(Case::Pascal);
193    let enum_name = format_ident!("Db{}Versions", pascal_db_name);
194    let mut enum_variants = vec![];
195    let mut db_types = vec![];
196    for (version_i, _) in &args.versions {
197        let newtype_name = format_ident!("Db{}{}", pascal_db_name, version_i);
198        let enum_variant = format_ident!("V{}", version_i);
199        enum_variants.push(quote!(#enum_variant(#newtype_name <'a >)));
200        db_types.push(quote!{
201            pub struct #newtype_name <'a >(pub &'a mut dyn good_ormning:: runtime:: pg:: PgConnection);
202        });
203    }
204    let latest_newtype_name = format_ident!("Db{}{}", pascal_db_name, last_version_i as usize);
205    let db_alias_name = format_ident!("Db{}", pascal_db_name);
206    let db_others =
207        good_ormning_core::pg::query::generate::generate_query_functions(
208            &mut errs,
209            field_lookup,
210            args.queries,
211            "",
212            quote!(#latest_newtype_name),
213        );
214    let tokens = quote!{
215        use good_ormning::runtime::GoodError;
216        use good_ormning::runtime::ToGoodError;
217        #(#db_types) * pub enum #enum_name <'a > {
218            #(#enum_variants,) *
219        }
220        pub use #latest_newtype_name as #db_alias_name;
221        async fn init_db(db: & mut impl good_ormning:: runtime:: pg:: PgConnection) -> Result <(),
222        GoodError > {
223            {
224                let query =
225                    "create table if not exists __good_version (rid int primary key, version bigint not null, lock int not null);";
226                good_ormning::runtime::pg::PgConnection::execute(db, query, &[]).await.to_good_error_query(query)?;
227            }
228            {
229                let query =
230                    "insert into __good_version (rid, version, lock) values (0, -1, 0) on conflict do nothing;";
231                good_ormning::runtime::pg::PgConnection::execute(db, query, &[]).await.to_good_error_query(query)?;
232            }
233            Ok(())
234        }
235        #[
236            doc =
237                "(Initialize and) migrate the database to the latest schema version. Optionally takes a callback which is run after each version, so custom post-schema change code can be run. Use `good_query!` macros with the version parameter to do migrations."
238        ] pub async fn migrate(
239            db: & mut tokio_postgres:: Client,
240            callback: Option <&(
241                dyn for <'b > Fn(
242                    #enum_name <'b >
243                ) -> std:: pin:: Pin < Box < dyn std:: future:: Future < Output = Result <(),
244                GoodError >> + Send + 'b >> + Send + Sync
245            ) >
246        ) -> Result <(),
247        GoodError > {
248            init_db(db).await?;
249            loop {
250                let mut txn = db.transaction().await.to_good_error(|| "Failed to start transaction".to_string())?;
251                let migrated = {
252                    let query = "update __good_version set lock = 1 where rid = 0 and lock = 0 returning version";
253                    let res =
254                        good_ormning::runtime::pg::PgConnection::query(&mut txn, query, &[])
255                            .await
256                            .to_good_error_query(query)?;
257                    let version = match res.first() {
258                        Some(r) => {
259                            let ver: i64 = r.get(0usize);
260                            Some(ver)
261                        },
262                        None => {
263                            None
264                        },
265                    };
266                    if let Some(version) = version {
267                        if version > #last_version_i {
268                            return Err(
269                                GoodError(
270                                    format!(
271                                        "The latest known version is {}, but the schema is at unknown version {}",
272                                        #last_version_i,
273                                        version
274                                    ),
275                                ),
276                            );
277                        }
278                        #(#migrations) * {
279                            let query = "update __good_version set lock = 0";
280                            good_ormning::runtime::pg::PgConnection::execute(&mut txn, query, &[])
281                                .await
282                                .to_good_error_query(query)?;
283                        }
284                        true
285                    }
286                    else {
287                        false
288                    }
289                };
290                if migrated {
291                    txn.commit().await.to_good_error(|| "Failed to commit transaction".to_string())?;
292                    return Ok(());
293                }
294                else {
295                    txn.rollback().await.to_good_error(|| "Failed to rollback transaction".to_string())?;
296                    tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
297                }
298            }
299        }
300        pub async fn get_schema_version(
301            db: & mut impl good_ormning:: runtime:: pg:: PgConnection
302        ) -> Result < Option < i64 >,
303        GoodError > {
304            init_db(db).await?;
305            let query = "select version from __good_version where rid = 0";
306            let res = db.query(query, &[]).await.to_good_error_query(query)?;
307            if let Some(r) = res.first() {
308                let x: i64 = r.get(0usize);
309                if x == -1 {
310                    return Ok(None);
311                } else {
312                    return Ok(Some(x));
313                }
314            }
315            Ok(None)
316        }
317        #(#db_others) *
318    };
319    match genemichaels_lib::format_str(&tokens.to_string(), &genemichaels_lib::FormatConfig::default()) {
320        Ok(src) => {
321            match fs::write(&output, src.rendered.as_bytes()) {
322                Ok(_) => { },
323                Err(e) => errs.err(
324                    &rpds::vector![],
325                    format!("Failed to write generated code to {:?}: {:?}", output, e),
326                ),
327            };
328        },
329        Err(e) => {
330            errs.err(&rpds::vector![], format!("Error formatting generated code: {:?}\n{}", e, tokens));
331        },
332    };
333    errs.raise()?;
334    Ok(())
335}
336
#[cfg(test)]
mod test {
    use super::{
        generate,
        query::expr::SerialExpr,
        schema::field::{
            field_auto,
            field_i32,
            field_str,
        },
        GenerateArgs,
        Version,
    };

    // NOTE(review): `generate` returns `Err` immediately when the `OUT_DIR`
    // environment variable is unset, which may be the case while these tests run
    // (Cargo normally sets it only during compilation). If so, the `is_err` and
    // `should_panic` expectations below hold without any schema validation being
    // exercised — confirm.

    /// Runs `generate` for the default database with `versions` and no queries.
    fn generate_versions(versions: Vec<(usize, Version)>) -> Result<(), Vec<String>> {
        generate(GenerateArgs {
            db_name: None,
            versions,
            ..Default::default()
        })
    }

    /// A version containing one table, `table_id`, with a single string field.
    fn one_str_table(table_id: &'static str) -> Version {
        let v = Version::new();
        let table = v.table(table_id);
        table.field("z437INV6D", field_str().build());
        v.build()
    }

    #[test]
    fn test_add_field_serial_bad() {
        let v0 = one_str_table("zMOY9YMCK");
        let v1 = {
            let v = Version::new();
            let table = v.table("zMOY9YMCK");
            table.field("z437INV6D", field_str().build());
            table.field("zPREUVAOD", field_auto().migrate_fill(SerialExpr::LitAuto(0)).build());
            v.build()
        };
        assert!(generate_versions(vec![(0, v0), (1, v1)]).is_err());
    }

    #[test]
    #[should_panic]
    fn test_add_field_dup_bad() {
        let v0 = one_str_table("zPAO2PJU4");
        let v1 = {
            let v = Version::new();
            let table = v.table("zQZQ8E2WD");
            table.field("z437INV6D", field_str().build());
            table.field("z437INV6D", field_i32().build());
            v.build()
        };
        generate_versions(vec![(0, v0), (1, v1)]).unwrap();
    }

    #[test]
    #[should_panic]
    fn test_add_table_dup_bad() {
        let v0 = one_str_table("zSNS34DYI");
        let v1 = {
            let v = Version::new();
            // Declare the same table (and field) twice.
            for _ in 0..2 {
                v.table("zSNS34DYI").field("z437INV6D", field_str().build());
            }
            v.build()
        };
        generate_versions(vec![(0, v0), (1, v1)]).unwrap();
    }

    #[test]
    fn test_res_count_none_bad() {
        assert!(generate_versions(vec![(0, one_str_table("z5S18LWQE"))]).is_err());
    }

    #[test]
    fn test_select_nothing_bad() {
        assert!(generate_versions(vec![(0, one_str_table("zOOR88EQ9"))]).is_err());
    }

    #[test]
    fn test_returning_none_bad() {
        assert!(generate_versions(vec![(0, one_str_table("zZPD1I2EF"))]).is_err());
    }
}