1use {
2 good_ormning_core::{
3 sqlite::{
4 graph::utils::SqliteMigrateCtx,
5 query::utils::{
6 SqliteFieldInfo,
7 SqliteTableInfo,
8 },
9 schema::{
10 field::FieldRef,
11 table::TableRef,
12 },
13 },
14 utils::Errs,
15 },
16 convert_case::{
17 Casing,
18 Case,
19 },
20 quote::{
21 format_ident,
22 quote,
23 },
24 std::{
25 collections::HashMap,
26 env,
27 fs,
28 path::Path,
29 },
30};
31pub use {
32 good_ormning_core::sqlite::*,
33 good_ormning_macros::{
34 good_query_many_sqlite as good_query_many,
35 good_query_one_sqlite as good_query_one,
36 good_query_opt_sqlite as good_query_opt,
37 good_query_sqlite as good_query,
38 },
39};
40
41pub struct GenerateArgs {
42 pub db_name: Option<String>,
45 pub versions: Vec<(usize, Version)>,
52 pub queries: Vec<Query>,
54}
55
56impl Default for GenerateArgs {
57 fn default() -> Self {
58 Self {
59 db_name: None,
60 versions: vec![],
61 queries: vec![],
62 }
63 }
64}
65
66pub fn generate(args: GenerateArgs) -> Result<(), Vec<String>> {
73 let db_name = args.db_name.as_deref().unwrap_or(good_ormning_core::utils::DEFAULT_DB_NAME);
74 let out_dir = env::var("OUT_DIR").map_err(|e| vec![format!("OUT_DIR not set: {:?}", e)])?;
75 let out_dir = Path::new(&out_dir);
76 let output = out_dir.join(good_ormning_core::utils::rs_file_name(db_name));
77 let json_dir = out_dir.join("good_ormning");
78 if let Err(e) = fs::create_dir_all(&json_dir) {
79 return Err(vec![format!("Error creating directory {:?}: {:?}", json_dir, e)]);
80 }
81 let json_path = json_dir.join(good_ormning_core::utils::json_file_name(db_name));
82
83 {
85 let mut versions_map: HashMap<usize, Version> = if json_path.exists() {
86 serde_json::from_str(&fs::read_to_string(&json_path).unwrap()).unwrap_or_default()
87 } else {
88 HashMap::new()
89 };
90 for (version_i, version) in args.versions.iter() {
91 let entry = versions_map.entry(*version_i).or_insert_with(|| Version::default());
92 for (k, v) in &version.tables {
93 entry.tables.insert(k.clone(), v.clone());
94 }
95 for (k, v) in &version.custom_types {
96 entry.custom_types.insert(k.clone(), v.clone());
97 }
98 }
99 let _ = fs::write(json_path, serde_json::to_string(&versions_map).unwrap());
100 }
101 let mut errs = Errs::new();
102 let mut migrations = vec![];
103 let mut prev_version: Option<Version> = None;
104 let mut prev_version_i: Option<i64> = None;
105 let mut field_lookup: HashMap<TableRef, SqliteTableInfo> = HashMap::new();
106 for (version_i, version) in &args.versions {
107 let path = rpds::vector![format!("Migration to {}", version_i)];
108 let mut migration = vec![];
109
110 field_lookup.clear();
112 for (table_id, table) in &version.tables {
113 let mut fields: HashMap<FieldRef, SqliteFieldInfo> = HashMap::new();
114 for (field_id, field) in &table.fields {
115 fields.insert(FieldRef {
116 table_id: table_id.clone(),
117 field_id: field_id.clone(),
118 }, SqliteFieldInfo {
119 sql_name: field.id.clone(),
120 type_: field.type_.type_.clone(),
121 });
122 }
123 field_lookup.insert(TableRef(table_id.clone()), SqliteTableInfo {
124 sql_name: table.id.clone(),
125 fields: fields,
126 });
127 }
128 let version_i = *version_i as i64;
129 if let Some(i) = prev_version_i {
130 if version_i != i as i64 + 1 {
131 errs.err(
132 &path,
133 format!(
134 "Version numbers are not consecutive ({} to {}) - was an intermediate version deleted?",
135 i,
136 version_i
137 ),
138 );
139 }
140 }
141
142 {
144 let mut table_sql_names = HashMap::new();
145 for (table_id, table) in &version.tables {
146 table_sql_names.insert(table_id.clone(), table.id.clone());
147 }
148 let mut state = SqliteMigrateCtx::new(errs.clone(), table_sql_names, version.clone());
149 let current_nodes = version.to_migrate_nodes();
150 let prev_nodes = prev_version.take().map(|s| s.to_migrate_nodes());
151 good_ormning_core::graphmigrate::migrate(&mut state, prev_nodes, ¤t_nodes);
152 for statement in &state.statements {
153 migration.push(quote!{
154 {
155 let query = #statement;
156 db.execute(query, ()).to_good_error_query(query)?;
157 };
158 });
159 }
160 errs = state.errs.clone();
161 }
162
163 let pascal_db_name: String = db_name.to_case(Case::Pascal);
165 let enum_name = format_ident!("Db{}Versions", pascal_db_name);
166 let newtype_name = format_ident!("Db{}{}", pascal_db_name, version_i as usize);
167 let enum_variant = format_ident!("V{}", version_i as usize);
168 migrations.push(quote!{
169 if version < #version_i {
170 #(#migration) * {
171 let query = "update __good_version set version = ?";
172 db.execute(query, (#version_i,)).to_good_error_query(query) ?;
173 }
174 if let Some(callback) = & callback {
175 callback(#enum_name::#enum_variant(#newtype_name(db))) ?;
176 }
177 }
178 });
179
180 prev_version = Some(version.clone());
182 prev_version_i = Some(version_i);
183 }
184
185 let last_version_i = prev_version_i.unwrap() as i64;
187 let pascal_db_name: String = db_name.to_case(Case::Pascal);
188 let enum_name = format_ident!("Db{}Versions", pascal_db_name);
189 let mut enum_variants = vec![];
190 let mut db_types = vec![];
191 for (version_i, _) in &args.versions {
192 let newtype_name = format_ident!("Db{}{}", pascal_db_name, version_i);
193 let enum_variant = format_ident!("V{}", version_i);
194 enum_variants.push(quote!(#enum_variant(#newtype_name <'a, C >)));
195 db_types.push(quote!{
196 pub struct #newtype_name <'a,
197 C: good_ormning:: runtime:: sqlite:: SqliteConnection >(pub &'a mut C);
198 });
199 }
200 let latest_newtype_name = format_ident!("Db{}{}", pascal_db_name, last_version_i as usize);
201 let db_alias_name = format_ident!("Db{}", pascal_db_name);
202 let db_others =
203 good_ormning_core::sqlite::query::generate::generate_query_functions(
204 &mut errs,
205 field_lookup,
206 args.queries,
207 "",
208 quote!(#latest_newtype_name <'_, C >),
209 );
210 let tokens = quote!{
211 use good_ormning::runtime::GoodError;
212 use good_ormning::runtime::ToGoodError;
213 #(#db_types) * pub enum #enum_name <'a,
214 C: good_ormning:: runtime:: sqlite:: SqliteConnection > {
215 #(#enum_variants,) *
216 }
217 pub use #latest_newtype_name as #db_alias_name;
218 fn init_db(db: & mut impl good_ormning:: runtime:: sqlite:: SqliteConnection) -> Result <(),
219 GoodError > {
220 db.load_array_module().to_good_error(|| "Error loading array extension for array values".to_string())?;
221 {
222 let query =
223 "create table if not exists __good_version (rid int primary key, version bigint not null, lock int not null);";
224 db.execute(query, ()).to_good_error_query(query)?;
225 }
226 {
227 let query =
228 "insert into __good_version (rid, version, lock) values (0, -1, 0) on conflict do nothing;";
229 db.execute(query, ()).to_good_error_query(query)?;
230 }
231 Ok(())
232 }
233 pub fn migrate < C: good_ormning:: runtime:: sqlite:: SqliteConnection >(
234 db: & mut C,
235 callback: Option <&(dyn Fn(#enum_name <'_, C >) -> Result <(), GoodError >) >
236 ) -> Result <(),
237 GoodError > {
238 init_db(db)?;
239 loop {
240 let query = "update __good_version set lock = 1 where rid = 0 and lock = 0 returning version";
241 let version = match db.query(query, (), |r| {
242 let ver: i64 = r.get("version")?;
243 Ok(ver)
244 }).to_good_error_query(query)?.pop() {
245 Some(v) => v,
246 None => {
247 std::thread::sleep(std::time::Duration::from_millis(100));
248 continue;
249 },
250 };
251 if version > #last_version_i {
252 return Err(
253 GoodError(
254 format!(
255 "The latest known version is {}, but the schema is at unknown version {}",
256 #last_version_i,
257 version
258 ),
259 ),
260 );
261 }
262 #(#migrations) * {
263 let query = "update __good_version set lock = 0";
264 db.execute(query, ()).to_good_error_query(query)?;
265 }
266 return Ok(());
267 }
268 }
269 pub fn get_schema_version(
270 db: & mut impl good_ormning:: runtime:: sqlite:: SqliteConnection
271 ) -> Result < Option < i64 >,
272 GoodError > {
273 init_db(db)?;
274 let query = "select version from __good_version where rid = 0";
275 let mut res = db.query(query, (), |r| -> rusqlite::Result<i64> {
276 let x: i64 = r.get(0usize)?;
277 Ok(x)
278 }).to_good_error_query(query)?;
279 if let Some(v) = res.pop() {
280 if v == -1 {
281 Ok(None)
282 } else {
283 Ok(Some(v))
284 }
285 } else {
286 Ok(None)
287 }
288 }
289 #(#db_others) *
290 };
291 match genemichaels_lib::format_str(&tokens.to_string(), &genemichaels_lib::FormatConfig::default()) {
292 Ok(src) => {
293 match fs::write(&output, src.rendered.as_bytes()) {
294 Ok(_) => { },
295 Err(e) => errs.err(
296 &rpds::vector![],
297 format!("Failed to write generated code to {:?}: {:?}", output, e),
298 ),
299 };
300 },
301 Err(e) => {
302 errs.err(&rpds::vector![], format!("Error formatting generated code: {:?}\n{}", e, tokens));
303 },
304 };
305 errs.raise()?;
306 Ok(())
307}
308
#[cfg(test)]
mod test {
    use super::{
        generate,
        query::expr::SerialExpr,
        schema::field::{
            field_auto,
            field_i32,
            field_str,
        },
        GenerateArgs,
        Version,
    };

    /// Calls [`generate`] with the default database name, the given versions and no queries.
    ///
    /// NOTE(review): `generate` returns `Err` immediately when the `OUT_DIR` environment
    /// variable is missing, before any schema validation. If the test runner does not set
    /// `OUT_DIR`, the `is_err()` / `#[should_panic]` expectations below hold regardless of the
    /// schemas. Confirm these tests really exercise what their names describe.
    fn run(versions: Vec<(usize, Version)>) -> Result<(), Vec<String>> {
        generate(GenerateArgs {
            db_name: None,
            versions,
            ..Default::default()
        })
    }

    /// Version 1 adds an auto field (`zomzom`, filled via `migrate_fill`) to an existing table;
    /// generation is expected to fail.
    #[test]
    fn test_add_field_serial_bad() {
        let initial = {
            let v = Version::new();
            v.table("bananna").field("hizat", field_str().build());
            v.build()
        };
        let with_serial = {
            let v = Version::new();
            let bananna = v.table("bananna");
            bananna.field("hizat", field_str().build());
            bananna.field("zomzom", field_auto().migrate_fill(SerialExpr::LitAuto(0)).build());
            v.build()
        };
        assert!(run(vec![(0usize, initial), (1usize, with_serial)]).is_err());
    }

    /// Version 1 adds an i32 field (`zomzom`) to an existing table without a fill value;
    /// generation is expected to fail (the `unwrap` panics).
    #[test]
    #[should_panic]
    fn test_add_field_dup_bad() {
        let initial = {
            let v = Version::new();
            v.table("bananna").field("hizat", field_str().build());
            v.build()
        };
        let with_i32 = {
            let v = Version::new();
            let bananna = v.table("bananna");
            bananna.field("hizat", field_str().build());
            bananna.field("zomzom", field_i32().build());
            v.build()
        };
        run(vec![(0usize, initial), (1usize, with_i32)]).unwrap();
    }

    /// Version 1 declares the same table (and field) twice; generation is expected to fail
    /// (the `unwrap` panics).
    #[test]
    #[should_panic]
    fn test_add_table_dup_bad() {
        let initial = {
            let v = Version::new();
            v.table("bananna").field("hizat", field_str().build());
            v.build()
        };
        let duplicated = {
            let v = Version::new();
            v.table("bananna").field("hizat", field_str().build());
            v.table("bananna").field("hizat", field_str().build());
            v.build()
        };
        run(vec![(0usize, initial), (1usize, duplicated)]).unwrap();
    }

    /// Single-version schema; expected to fail.
    ///
    /// NOTE(review): the name suggests a query result-count check, but no queries are passed.
    #[test]
    fn test_res_count_none_bad() {
        let v = Version::new();
        let bananna = v.table("bananna");
        bananna.field("hizat", field_str().build());
        let result = run(vec![(0usize, v.build())]);
        assert!(result.is_err());
    }

    /// Single-version schema; expected to fail.
    ///
    /// NOTE(review): the name suggests an empty-select check, but no queries are passed.
    #[test]
    fn test_select_nothing_bad() {
        let v = Version::new();
        v.table("bananna").field("hizat", field_str().build());
        let result = run(vec![(0usize, v.build())]);
        assert!(result.is_err());
    }

    /// Single-version schema; expected to fail.
    ///
    /// NOTE(review): the name suggests a `returning` check, but no queries are passed.
    #[test]
    fn test_returning_none_bad() {
        let v = Version::new();
        let bananna = v.table("bananna");
        bananna.field("hizat", field_str().build());
        let result = run(vec![(0usize, v.build())]);
        assert!(result.is_err());
    }
}