Skip to main content

good_ormning/sqlite/
mod.rs

1use {
2    good_ormning_core::{
3        sqlite::{
4            graph::utils::SqliteMigrateCtx,
5            query::utils::{
6                SqliteFieldInfo,
7                SqliteTableInfo,
8            },
9            schema::{
10                field::FieldRef,
11                table::TableRef,
12            },
13        },
14        utils::Errs,
15    },
16    convert_case::{
17        Casing,
18        Case,
19    },
20    quote::{
21        format_ident,
22        quote,
23    },
24    std::{
25        collections::HashMap,
26        env,
27        fs,
28        path::Path,
29    },
30};
31pub use {
32    good_ormning_core::sqlite::*,
33    good_ormning_macros::{
34        good_query_many_sqlite as good_query_many,
35        good_query_one_sqlite as good_query_one,
36        good_query_opt_sqlite as good_query_opt,
37        good_query_sqlite as good_query,
38    },
39};
40
41pub struct GenerateArgs {
42    /// If you have multiple databases, use this to disambiguate them. You'll also need
43    /// to use it in `good_module!` and `good_query!`.
44    pub db_name: Option<String>,
45    /// A list of database version ids and schema versions. The ids must be consecutive
46    /// but can start from any number. Once a version has been applied to a production
47    /// database it shouldn't be modified again (modifications should be done in a new
48    /// version).
49    ///
50    /// These will be turned into migrations as part of the `migrate` function.
51    pub versions: Vec<(usize, Version)>,
52    /// A list of queries to generate type-safe functions for.
53    pub queries: Vec<Query>,
54}
55
56impl Default for GenerateArgs {
57    fn default() -> Self {
58        Self {
59            db_name: None,
60            versions: vec![],
61            queries: vec![],
62        }
63    }
64}
65
66/// Generate Rust code for migrations and queries. Also saves schema type info for
67/// proc_macros to refer to.
68///
69/// # Returns
70///
71/// * Error - a list of validation or generation errors that occurred
72pub fn generate(args: GenerateArgs) -> Result<(), Vec<String>> {
73    let db_name = args.db_name.as_deref().unwrap_or(good_ormning_core::utils::DEFAULT_DB_NAME);
74    let out_dir = env::var("OUT_DIR").map_err(|e| vec![format!("OUT_DIR not set: {:?}", e)])?;
75    let out_dir = Path::new(&out_dir);
76    let output = out_dir.join(good_ormning_core::utils::rs_file_name(db_name));
77    let json_dir = out_dir.join("good_ormning");
78    if let Err(e) = fs::create_dir_all(&json_dir) {
79        return Err(vec![format!("Error creating directory {:?}: {:?}", json_dir, e)]);
80    }
81    let json_path = json_dir.join(good_ormning_core::utils::json_file_name(db_name));
82
83    // Serialize versions for proc macro
84    {
85        let mut versions_map: HashMap<usize, Version> = if json_path.exists() {
86            serde_json::from_str(&fs::read_to_string(&json_path).unwrap()).unwrap_or_default()
87        } else {
88            HashMap::new()
89        };
90        for (version_i, version) in args.versions.iter() {
91            let entry = versions_map.entry(*version_i).or_insert_with(|| Version::default());
92            for (k, v) in &version.tables {
93                entry.tables.insert(k.clone(), v.clone());
94            }
95            for (k, v) in &version.custom_types {
96                entry.custom_types.insert(k.clone(), v.clone());
97            }
98        }
99        let _ = fs::write(json_path, serde_json::to_string(&versions_map).unwrap());
100    }
101    let mut errs = Errs::new();
102    let mut migrations = vec![];
103    let mut prev_version: Option<Version> = None;
104    let mut prev_version_i: Option<i64> = None;
105    let mut field_lookup: HashMap<TableRef, SqliteTableInfo> = HashMap::new();
106    for (version_i, version) in &args.versions {
107        let path = rpds::vector![format!("Migration to {}", version_i)];
108        let mut migration = vec![];
109
110        // Prep for current version
111        field_lookup.clear();
112        for (table_id, table) in &version.tables {
113            let mut fields: HashMap<FieldRef, SqliteFieldInfo> = HashMap::new();
114            for (field_id, field) in &table.fields {
115                fields.insert(FieldRef {
116                    table_id: table_id.clone(),
117                    field_id: field_id.clone(),
118                }, SqliteFieldInfo {
119                    sql_name: field.id.clone(),
120                    type_: field.type_.type_.clone(),
121                });
122            }
123            field_lookup.insert(TableRef(table_id.clone()), SqliteTableInfo {
124                sql_name: table.id.clone(),
125                fields: fields,
126            });
127        }
128        let version_i = *version_i as i64;
129        if let Some(i) = prev_version_i {
130            if version_i != i as i64 + 1 {
131                errs.err(
132                    &path,
133                    format!(
134                        "Version numbers are not consecutive ({} to {}) - was an intermediate version deleted?",
135                        i,
136                        version_i
137                    ),
138                );
139            }
140        }
141
142        // Main migrations
143        {
144            let mut table_sql_names = HashMap::new();
145            for (table_id, table) in &version.tables {
146                table_sql_names.insert(table_id.clone(), table.id.clone());
147            }
148            let mut state = SqliteMigrateCtx::new(errs.clone(), table_sql_names, version.clone());
149            let current_nodes = version.to_migrate_nodes();
150            let prev_nodes = prev_version.take().map(|s| s.to_migrate_nodes());
151            good_ormning_core::graphmigrate::migrate(&mut state, prev_nodes, &current_nodes);
152            for statement in &state.statements {
153                migration.push(quote!{
154                    {
155                        let query = #statement;
156                        db.execute(query, ()).to_good_error_query(query)?;
157                    };
158                });
159            }
160            errs = state.errs.clone();
161        }
162
163        // Build migration
164        let pascal_db_name: String = db_name.to_case(Case::Pascal);
165        let enum_name = format_ident!("Db{}Versions", pascal_db_name);
166        let newtype_name = format_ident!("Db{}{}", pascal_db_name, version_i as usize);
167        let enum_variant = format_ident!("V{}", version_i as usize);
168        migrations.push(quote!{
169            if version < #version_i {
170                #(#migration) * {
171                    let query = "update __good_version set version = ?";
172                    db.execute(query, (#version_i,)).to_good_error_query(query) ?;
173                }
174                if let Some(callback) = & callback {
175                    callback(#enum_name::#enum_variant(#newtype_name(db))) ?;
176                }
177            }
178        });
179
180        // Next iter prep
181        prev_version = Some(version.clone());
182        prev_version_i = Some(version_i);
183    }
184
185    // Compile, output
186    let last_version_i = prev_version_i.unwrap() as i64;
187    let pascal_db_name: String = db_name.to_case(Case::Pascal);
188    let enum_name = format_ident!("Db{}Versions", pascal_db_name);
189    let mut enum_variants = vec![];
190    let mut db_types = vec![];
191    for (version_i, _) in &args.versions {
192        let newtype_name = format_ident!("Db{}{}", pascal_db_name, version_i);
193        let enum_variant = format_ident!("V{}", version_i);
194        enum_variants.push(quote!(#enum_variant(#newtype_name <'a, C >)));
195        db_types.push(quote!{
196            pub struct #newtype_name <'a,
197            C: good_ormning:: runtime:: sqlite:: SqliteConnection >(pub &'a mut C);
198        });
199    }
200    let latest_newtype_name = format_ident!("Db{}{}", pascal_db_name, last_version_i as usize);
201    let db_alias_name = format_ident!("Db{}", pascal_db_name);
202    let db_others =
203        good_ormning_core::sqlite::query::generate::generate_query_functions(
204            &mut errs,
205            field_lookup,
206            args.queries,
207            "",
208            quote!(#latest_newtype_name <'_, C >),
209        );
210    let tokens = quote!{
211        use good_ormning::runtime::GoodError;
212        use good_ormning::runtime::ToGoodError;
213        #(#db_types) * pub enum #enum_name <'a,
214        C: good_ormning:: runtime:: sqlite:: SqliteConnection > {
215            #(#enum_variants,) *
216        }
217        pub use #latest_newtype_name as #db_alias_name;
218        fn init_db(db: & mut impl good_ormning:: runtime:: sqlite:: SqliteConnection) -> Result <(),
219        GoodError > {
220            db.load_array_module().to_good_error(|| "Error loading array extension for array values".to_string())?;
221            {
222                let query =
223                    "create table if not exists __good_version (rid int primary key, version bigint not null, lock int not null);";
224                db.execute(query, ()).to_good_error_query(query)?;
225            }
226            {
227                let query =
228                    "insert into __good_version (rid, version, lock) values (0, -1, 0) on conflict do nothing;";
229                db.execute(query, ()).to_good_error_query(query)?;
230            }
231            Ok(())
232        }
233        pub fn migrate < C: good_ormning:: runtime:: sqlite:: SqliteConnection >(
234            db: & mut C,
235            callback: Option <&(dyn Fn(#enum_name <'_, C >) -> Result <(), GoodError >) >
236        ) -> Result <(),
237        GoodError > {
238            init_db(db)?;
239            loop {
240                let query = "update __good_version set lock = 1 where rid = 0 and lock = 0 returning version";
241                let version = match db.query(query, (), |r| {
242                    let ver: i64 = r.get("version")?;
243                    Ok(ver)
244                }).to_good_error_query(query)?.pop() {
245                    Some(v) => v,
246                    None => {
247                        std::thread::sleep(std::time::Duration::from_millis(100));
248                        continue;
249                    },
250                };
251                if version > #last_version_i {
252                    return Err(
253                        GoodError(
254                            format!(
255                                "The latest known version is {}, but the schema is at unknown version {}",
256                                #last_version_i,
257                                version
258                            ),
259                        ),
260                    );
261                }
262                #(#migrations) * {
263                    let query = "update __good_version set lock = 0";
264                    db.execute(query, ()).to_good_error_query(query)?;
265                }
266                return Ok(());
267            }
268        }
269        pub fn get_schema_version(
270            db: & mut impl good_ormning:: runtime:: sqlite:: SqliteConnection
271        ) -> Result < Option < i64 >,
272        GoodError > {
273            init_db(db)?;
274            let query = "select version from __good_version where rid = 0";
275            let mut res = db.query(query, (), |r| -> rusqlite::Result<i64> {
276                let x: i64 = r.get(0usize)?;
277                Ok(x)
278            }).to_good_error_query(query)?;
279            if let Some(v) = res.pop() {
280                if v == -1 {
281                    Ok(None)
282                } else {
283                    Ok(Some(v))
284                }
285            } else {
286                Ok(None)
287            }
288        }
289        #(#db_others) *
290    };
291    match genemichaels_lib::format_str(&tokens.to_string(), &genemichaels_lib::FormatConfig::default()) {
292        Ok(src) => {
293            match fs::write(&output, src.rendered.as_bytes()) {
294                Ok(_) => { },
295                Err(e) => errs.err(
296                    &rpds::vector![],
297                    format!("Failed to write generated code to {:?}: {:?}", output, e),
298                ),
299            };
300        },
301        Err(e) => {
302            errs.err(&rpds::vector![], format!("Error formatting generated code: {:?}\n{}", e, tokens));
303        },
304    };
305    errs.raise()?;
306    Ok(())
307}
308
309#[cfg(test)]
310mod test {
311    use {
312        super::{
313            generate,
314            GenerateArgs,
315            query::expr::SerialExpr,
316            schema::field::{
317                field_auto,
318                field_i32,
319                field_str,
320            },
321            Version,
322        },
323    };
324
325    #[test]
326    fn test_add_field_serial_bad() {
327        assert!(generate(GenerateArgs {
328            db_name: None,
329            versions: vec![
330                // Versions (previous)
331                (0usize, {
332                    let v = Version::new();
333                    v.table("bananna").field("hizat", field_str().build());
334                    v.build()
335                }),
336                (1usize, {
337                    let v = Version::new();
338                    let bananna = v.table("bananna");
339                    bananna.field("hizat", field_str().build());
340                    bananna.field("zomzom", field_auto().migrate_fill(SerialExpr::LitAuto(0)).build(),);
341                    v.build()
342                }),
343            ],
344            ..Default::default()
345        }).is_err());
346    }
347
348    #[test]
349    #[should_panic]
350    fn test_add_field_dup_bad() {
351        generate(GenerateArgs {
352            db_name: None,
353            versions: vec![
354                // Versions (previous)
355                (0usize, {
356                    let v = Version::new();
357                    v.table("bananna").field("hizat", field_str().build());
358                    v.build()
359                }),
360                (1usize, {
361                    let v = Version::new();
362                    let bananna = v.table("bananna");
363                    bananna.field("hizat", field_str().build());
364                    bananna.field("zomzom", field_i32().build());
365                    v.build()
366                }),
367            ],
368            ..Default::default()
369        }).unwrap();
370    }
371
372    #[test]
373    #[should_panic]
374    fn test_add_table_dup_bad() {
375        generate(GenerateArgs {
376            db_name: None,
377            versions: vec![
378                // Versions (previous)
379                (0usize, {
380                    let v = Version::new();
381                    v.table("bananna").field("hizat", field_str().build());
382                    v.build()
383                }),
384                (1usize, {
385                    let v = Version::new();
386                    v.table("bananna").field("hizat", field_str().build());
387                    v.table("bananna").field("hizat", field_str().build());
388                    v.build()
389                }),
390            ],
391            ..Default::default()
392        }).unwrap();
393    }
394
395    #[test]
396    fn test_res_count_none_bad() {
397        let v = Version::new();
398        let bananna = v.table("bananna");
399        bananna.field("hizat", field_str().build());
400        assert!(generate(GenerateArgs {
401            db_name: None,
402            versions: vec![(0usize, v.build())],
403            ..Default::default()
404        }).is_err());
405    }
406
407    #[test]
408    fn test_select_nothing_bad() {
409        let v = Version::new();
410        v.table("bananna").field("hizat", field_str().build());
411        assert!(generate(GenerateArgs {
412            db_name: None,
413            versions: vec![(0usize, v.build())],
414            ..Default::default()
415        }).is_err());
416    }
417
418    #[test]
419    fn test_returning_none_bad() {
420        let v = Version::new();
421        let bananna = v.table("bananna");
422        bananna.field("hizat", field_str().build());
423        assert!(generate(GenerateArgs {
424            db_name: None,
425            versions: vec![(0usize, v.build())],
426            ..Default::default()
427        }).is_err());
428    }
429}