good_ormning/pg/
mod.rs

1use proc_macro2::{
2    TokenStream,
3    Ident,
4};
5use quote::{
6    quote,
7    format_ident,
8    ToTokens,
9};
10use std::{
11    collections::{
12        HashMap,
13        HashSet,
14        BTreeMap,
15    },
16    path::Path,
17    fs,
18    rc::Rc,
19};
20use crate::{
21    pg::{
22        types::{
23            Type,
24            to_rust_types,
25        },
26        query::expr::ExprValName,
27        graph::utils::PgMigrateCtx,
28    },
29    utils::{
30        Errs,
31        sanitize_ident,
32    },
33};
34use self::{
35    query::{
36        utils::{
37            PgQueryCtx,
38            QueryBody,
39        },
40        insert::{
41            Insert,
42            InsertConflict,
43        },
44        expr::Expr,
45        select::{
46            Returning,
47            Select,
48            NamedSelectSource,
49            JoinSource,
50            Join,
51            Order,
52        },
53        update::Update,
54        delete::Delete,
55    },
56    schema::{
57        field::{
58            Field,
59            Field_,
60            SchemaFieldId,
61            FieldType,
62        },
63        table::{
64            Table,
65            Table_,
66            SchemaTableId,
67        },
68        constraint::{
69            ConstraintType,
70            Constraint_,
71            Constraint,
72            SchemaConstraintId,
73        },
74        index::{
75            Index_,
76            Index,
77            SchemaIndexId,
78        },
79    },
80    graph::{
81        table::NodeTable_,
82        GraphId,
83        utils::MigrateNode,
84        Node,
85        field::NodeField_,
86        constraint::NodeConstraint_,
87        index::NodeIndex_,
88    },
89};
90
91pub mod types;
92pub mod query;
93pub mod schema;
94pub mod graph;
95
/// The number of results this query returns. This determines if the return type is
/// void, `Option`, the value directly, or a `Vec`. It must be a valid value per
/// the query body (e.g. select can't have `None` res count).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QueryResCount {
    /// No results; the generated function returns nothing (void).
    None,
    /// Zero or one result, returned as an `Option`.
    MaybeOne,
    /// Exactly one result, returned directly.
    One,
    /// Any number of results, returned as a `Vec`.
    Many,
}
106
107/// See Insert for field descriptions. Call `build()` to get a finished query
108/// object.
109pub struct InsertBuilder {
110    pub q: Insert,
111}
112
113impl InsertBuilder {
114    pub fn on_conflict_do_update(mut self, f: &[&Field], v: Vec<(Field, Expr)>) -> Self {
115        self.q.on_conflict = Some(InsertConflict::DoUpdate {
116            conflict: f.iter().map(|f| (*f).clone()).collect(),
117            set: v,
118        });
119        self
120    }
121
122    pub fn on_conflict_do_nothing(mut self) -> Self {
123        self.q.on_conflict = Some(InsertConflict::DoNothing);
124        self
125    }
126
127    pub fn return_(mut self, v: Expr) -> Self {
128        self.q.returning.push(Returning {
129            e: v,
130            rename: None,
131        });
132        self
133    }
134
135    pub fn return_named(mut self, name: impl ToString, v: Expr) -> Self {
136        self.q.returning.push(Returning {
137            e: v,
138            rename: Some(name.to_string()),
139        });
140        self
141    }
142
143    pub fn return_field(mut self, f: &Field) -> Self {
144        self.q.returning.push(Returning {
145            e: Expr::Field(f.clone()),
146            rename: None,
147        });
148        self
149    }
150
151    pub fn return_fields(mut self, f: &[&Field]) -> Self {
152        for f in f {
153            self.q.returning.push(Returning {
154                e: Expr::Field((*f).clone()),
155                rename: None,
156            });
157        }
158        self
159    }
160
161    pub fn returns_from_iter(mut self, f: impl Iterator<Item = Returning>) -> Self {
162        self.q.returning.extend(f);
163        self
164    }
165
166    /// Produce a migration for use in version pre/post-migration.
167    pub fn build_migration(self) -> Insert {
168        self.q
169    }
170
171    /// Produce a query object.
172    ///
173    /// # Arguments
174    ///
175    /// * `name` - This is used as the name of the rust function.
176    pub fn build_query(self, name: impl ToString, res_count: QueryResCount) -> Query {
177        Query {
178            name: name.to_string(),
179            body: Box::new(self.q),
180            res_count: res_count,
181            res_name: None,
182        }
183    }
184
185    /// Same as `build_query`, but specify a name for the result structure. Only valid
186    /// if result is a record (not a single value).
187    pub fn build_query_named_res(self, name: impl ToString, res_count: QueryResCount, res_name: impl ToString) -> Query {
188        Query {
189            name: name.to_string(),
190            body: Box::new(self.q),
191            res_count: res_count,
192            res_name: Some(res_name.to_string()),
193        }
194    }
195}
196
197/// See Select for field descriptions. Call `build()` to get a finished query
198/// object.
199pub struct SelectBuilder {
200    pub q: Select,
201}
202
203impl SelectBuilder {
204    pub fn return_(mut self, v: Expr) -> Self {
205        self.q.returning.push(Returning {
206            e: v,
207            rename: None,
208        });
209        self
210    }
211
212    pub fn return_named(mut self, name: impl ToString, v: Expr) -> Self {
213        self.q.returning.push(Returning {
214            e: v,
215            rename: Some(name.to_string()),
216        });
217        self
218    }
219
220    pub fn return_field(mut self, f: &Field) -> Self {
221        self.q.returning.push(Returning {
222            e: Expr::Field(f.clone()),
223            rename: None,
224        });
225        self
226    }
227
228    pub fn return_fields(mut self, f: &[&Field]) -> Self {
229        for f in f {
230            self.q.returning.push(Returning {
231                e: Expr::Field((*f).clone()),
232                rename: None,
233            });
234        }
235        self
236    }
237
238    pub fn returns_from_iter(mut self, f: impl Iterator<Item = Returning>) -> Self {
239        self.q.returning.extend(f);
240        self
241    }
242
243    pub fn join(mut self, join: Join) -> Self {
244        self.q.join.push(join);
245        self
246    }
247
248    pub fn where_(mut self, predicate: Expr) -> Self {
249        self.q.where_ = Some(predicate);
250        self
251    }
252
253    pub fn group(mut self, clauses: Vec<Expr>) -> Self {
254        self.q.group = clauses;
255        self
256    }
257
258    pub fn order(mut self, expr: Expr, order: Order) -> Self {
259        self.q.order.push((expr, order));
260        self
261    }
262
263    pub fn order_from_iter(mut self, clauses: impl Iterator<Item = (Expr, Order)>) -> Self {
264        self.q.order.extend(clauses);
265        self
266    }
267
268    /// Sets `LIMIT`. `v` must evaluate to a number.
269    pub fn limit(mut self, v: Expr) -> Self {
270        self.q.limit = Some(v);
271        self
272    }
273
274    /// Produce a migration for use in version pre/post-migration.
275    pub fn build_migration(self) -> Select {
276        self.q
277    }
278
279    /// Produce a query object.
280    ///
281    /// # Arguments
282    ///
283    /// * `name` - This is used as the name of the rust function.
284    pub fn build_query(self, name: impl ToString, res_count: QueryResCount) -> Query {
285        Query {
286            name: name.to_string(),
287            body: Box::new(self.q),
288            res_count: res_count,
289            res_name: None,
290        }
291    }
292
293    // Same as `build_query`, but specify a name for the result structure. Only valid
294    // if result is a record (not a single value).
295    pub fn build_query_named_res(self, name: impl ToString, res_count: QueryResCount, res_name: impl ToString) -> Query {
296        Query {
297            name: name.to_string(),
298            body: Box::new(self.q),
299            res_count: res_count,
300            res_name: Some(res_name.to_string()),
301        }
302    }
303}
304
305/// See Update for field descriptions. Call `build()` to get a finished query
306/// object.
307pub struct UpdateBuilder {
308    pub q: Update,
309}
310
311impl UpdateBuilder {
312    pub fn where_(mut self, v: Expr) -> Self {
313        self.q.where_ = Some(v);
314        self
315    }
316
317    pub fn return_(mut self, v: Expr) -> Self {
318        self.q.returning.push(Returning {
319            e: v,
320            rename: None,
321        });
322        self
323    }
324
325    pub fn return_named(mut self, name: impl ToString, v: Expr) -> Self {
326        self.q.returning.push(Returning {
327            e: v,
328            rename: Some(name.to_string()),
329        });
330        self
331    }
332
333    pub fn return_field(mut self, f: &Field) -> Self {
334        self.q.returning.push(Returning {
335            e: Expr::Field(f.clone()),
336            rename: None,
337        });
338        self
339    }
340
341    pub fn return_fields(mut self, f: &[&Field]) -> Self {
342        for f in f {
343            self.q.returning.push(Returning {
344                e: Expr::Field((*f).clone()),
345                rename: None,
346            });
347        }
348        self
349    }
350
351    pub fn returns_from_iter(mut self, f: impl Iterator<Item = Returning>) -> Self {
352        self.q.returning.extend(f);
353        self
354    }
355
356    // Produce a migration for use in version pre/post-migration.
357    pub fn build_migration(self) -> Update {
358        self.q
359    }
360
361    // Produce a query object.
362    //
363    // # Arguments
364    //
365    // * `name` - This is used as the name of the rust function.
366    pub fn build_query(self, name: impl ToString, res_count: QueryResCount) -> Query {
367        Query {
368            name: name.to_string(),
369            body: Box::new(self.q),
370            res_count: res_count,
371            res_name: None,
372        }
373    }
374
375    // Same as `build_query`, but specify a name for the result structure. Only valid
376    // if result is a record (not a single value).
377    pub fn build_query_named_res(self, name: impl ToString, res_count: QueryResCount, res_name: impl ToString) -> Query {
378        Query {
379            name: name.to_string(),
380            body: Box::new(self.q),
381            res_count: res_count,
382            res_name: Some(res_name.to_string()),
383        }
384    }
385}
386
387/// See Delete for field descriptions. Call `build()` to get a finished query
388/// object.
389pub struct DeleteBuilder {
390    pub q: Delete,
391}
392
393impl DeleteBuilder {
394    pub fn where_(mut self, v: Expr) -> Self {
395        self.q.where_ = Some(v);
396        self
397    }
398
399    pub fn return_(mut self, v: Expr) -> Self {
400        self.q.returning.push(Returning {
401            e: v,
402            rename: None,
403        });
404        self
405    }
406
407    pub fn return_named(mut self, name: impl ToString, v: Expr) -> Self {
408        self.q.returning.push(Returning {
409            e: v,
410            rename: Some(name.to_string()),
411        });
412        self
413    }
414
415    pub fn return_field(mut self, f: &Field) -> Self {
416        self.q.returning.push(Returning {
417            e: Expr::Field(f.clone()),
418            rename: None,
419        });
420        self
421    }
422
423    pub fn return_fields(mut self, f: &[&Field]) -> Self {
424        for f in f {
425            self.q.returning.push(Returning {
426                e: Expr::Field((*f).clone()),
427                rename: None,
428            });
429        }
430        self
431    }
432
433    pub fn returns_from_iter(mut self, f: impl Iterator<Item = Returning>) -> Self {
434        self.q.returning.extend(f);
435        self
436    }
437
438    // Produce a migration for use in version pre/post-migration.
439    pub fn build_migration(self) -> Delete {
440        self.q
441    }
442
443    // Produce a query object.
444    //
445    // # Arguments
446    //
447    // * `name` - This is used as the name of the rust function.
448    pub fn build_query(self, name: impl ToString, res_count: QueryResCount) -> Query {
449        Query {
450            name: name.to_string(),
451            body: Box::new(self.q),
452            res_count: res_count,
453            res_name: None,
454        }
455    }
456
457    // Same as `build_query`, but specify a name for the result structure. Only valid
458    // if result is a record (not a single value).
459    pub fn build_query_named_res(self, name: impl ToString, res_count: QueryResCount, res_name: impl ToString) -> Query {
460        Query {
461            name: name.to_string(),
462            body: Box::new(self.q),
463            res_count: res_count,
464            res_name: Some(res_name.to_string()),
465        }
466    }
467}
468
469/// This represents an SQL query. A function will be generated which accepts a db
470/// connection and query parameters, and returns the query results. Call the
471/// `new_*` functions to get a builder.
472pub struct Query {
473    pub name: String,
474    pub body: Box<dyn QueryBody>,
475    pub res_count: QueryResCount,
476    pub res_name: Option<String>,
477}
478
479/// Get a builder for an INSERT query.
480///
481/// # Arguments
482///
483/// * `values` - The fields to insert and their corresponding values
484pub fn new_insert(table: &Table, values: Vec<(Field, Expr)>) -> InsertBuilder {
485    let mut unique = HashSet::new();
486    for v in &values {
487        if !unique.insert(&v.0) {
488            panic!("Duplicate field {} in insert", v.0);
489        }
490    }
491    InsertBuilder { q: Insert {
492        table: table.clone(),
493        values: values,
494        on_conflict: None,
495        returning: vec![],
496    } }
497}
498
499/// Get a builder for a SELECT query.
500pub fn new_select(table: &Table) -> SelectBuilder {
501    SelectBuilder { q: Select {
502        table: NamedSelectSource {
503            source: JoinSource::Table(table.clone()),
504            alias: None,
505        },
506        returning: vec![],
507        join: vec![],
508        where_: None,
509        group: vec![],
510        order: vec![],
511        limit: None,
512    } }
513}
514
515/// Get a builder for a SELECT query. This allows advanced sources (like selecting
516/// from a synthetic table).
517pub fn new_select_from(source: NamedSelectSource) -> SelectBuilder {
518    SelectBuilder { q: Select {
519        table: source,
520        returning: vec![],
521        join: vec![],
522        where_: None,
523        group: vec![],
524        order: vec![],
525        limit: None,
526    } }
527}
528
529/// Get a builder for an UPDATE query.
530///
531/// # Arguments
532///
533/// * `values` - The fields to update and their corresponding values
534pub fn new_update(table: &Table, values: Vec<(Field, Expr)>) -> UpdateBuilder {
535    let mut unique = HashSet::new();
536    for v in &values {
537        if !unique.insert(&v.0) {
538            panic!("Duplicate field {} in update", v.0);
539        }
540    }
541    UpdateBuilder { q: Update {
542        table: table.clone(),
543        values: values,
544        where_: None,
545        returning: vec![],
546    } }
547}
548
549/// Get a builder for a DELETE query.
550///
551/// # Arguments
552///
553/// * `name` - This becomes the name of the generated rust function.
554pub fn new_delete(table: &Table) -> DeleteBuilder {
555    DeleteBuilder { q: Delete {
556        table: table.clone(),
557        returning: vec![],
558        where_: None,
559    } }
560}
561
562/// The version represents the state of a schema at a point in time.
563#[derive(Default)]
564pub struct Version {
565    schema: BTreeMap<GraphId, MigrateNode>,
566    pre_migration: Vec<Box<dyn QueryBody>>,
567    post_migration: Vec<Box<dyn QueryBody>>,
568}
569
570impl Version {
571    /// Define a table in this version
572    pub fn table(&mut self, schema_id: &str, id: &str) -> Table {
573        let out = Table(Rc::new(Table_ {
574            schema_id: SchemaTableId(schema_id.into()),
575            id: id.into(),
576        }));
577        if self.schema.insert(GraphId::Table(out.schema_id.clone()), MigrateNode::new(vec![], Node::table(NodeTable_ {
578            def: out.clone(),
579            fields: vec![],
580        }))).is_some() {
581            panic!("Table with schema id {} already exists", out.schema_id);
582        };
583        out
584    }
585
586    /// Add a query to execute before before migrating to this schema (applied
587    /// immediately before migration).  Note that these may not run on new databases or
588    /// if you later delete early migrations, so these should only modify existing data
589    /// and not create new data (singleton rows, etc).  If you need those, do it with a
590    /// normal query executed manually against the latest version.
591    pub fn pre_migration(&mut self, q: impl QueryBody + 'static) {
592        self.pre_migration.push(Box::new(q));
593    }
594
595    /// Add a query to execute after migrating to this schema version (applied
596    /// immediately after migration). See other warnings from `pre_migration`.
597    pub fn post_migration(&mut self, q: impl QueryBody + 'static) {
598        self.post_migration.push(Box::new(q));
599    }
600}
601
602impl Table {
603    /// Define a field
604    pub fn field(&self, v: &mut Version, schema_id: impl ToString, id: impl ToString, type_: FieldType) -> Field {
605        let out = Field(Rc::new(Field_ {
606            table: self.clone(),
607            schema_id: SchemaFieldId(schema_id.to_string()),
608            id: id.to_string(),
609            type_: type_,
610        }));
611        if v
612            .schema
613            .insert(
614                GraphId::Field(self.schema_id.clone(), out.schema_id.clone()),
615                MigrateNode::new(
616                    vec![GraphId::Table(self.schema_id.clone())],
617                    Node::field(NodeField_ { def: out.clone() }),
618                ),
619            )
620            .is_some() {
621            panic!("Field with schema id {}.{} already exists", self.schema_id, out.schema_id);
622        };
623        out
624    }
625
626    /// Define a constraint
627    pub fn constraint(&self, v: &mut Version, schema_id: impl ToString, id: impl ToString, type_: ConstraintType) {
628        let out = Constraint(Rc::new(Constraint_ {
629            table: self.clone(),
630            schema_id: SchemaConstraintId(schema_id.to_string()),
631            id: id.to_string(),
632            type_: type_,
633        }));
634        let mut deps = vec![GraphId::Table(self.schema_id.clone())];
635        match &out.type_ {
636            ConstraintType::PrimaryKey(x) => {
637                for f in &x.fields {
638                    if &f.table != self {
639                        panic!(
640                            "Field {} in primary key constraint {} is in table {}, but constraint is in table {}",
641                            f,
642                            out.id,
643                            f.table,
644                            self
645                        );
646                    }
647                    deps.push(GraphId::Field(self.schema_id.clone(), f.schema_id.clone()));
648                }
649            },
650            ConstraintType::ForeignKey(x) => {
651                let mut last_foreign_table: Option<Field> = None;
652                for f in &x.fields {
653                    if &f.0.table != self {
654                        panic!(
655                            "Local field {} in foreign key constraint {} is in table {}, but constraint is in table {}",
656                            f.0,
657                            out.id,
658                            f.0.table,
659                            self
660                        );
661                    }
662                    deps.push(GraphId::Field(f.0.table.schema_id.clone(), f.0.schema_id.clone()));
663                    if let Some(t) = last_foreign_table.take() {
664                        if t.table != f.1.table {
665                            panic!(
666                                "Foreign field {} in foreign key constraint {} is in table {}, but constraint is in table {}",
667                                f.1,
668                                out.id,
669                                f.1.table,
670                                self
671                            );
672                        }
673                    }
674                    last_foreign_table = Some(f.1.clone());
675                    deps.push(GraphId::Field(f.1.table.schema_id.clone(), f.1.schema_id.clone()));
676                }
677            },
678        }
679        if v
680            .schema
681            .insert(
682                GraphId::Constraint(self.schema_id.clone(), out.schema_id.clone()),
683                MigrateNode::new(deps, Node::table_constraint(NodeConstraint_ { def: out.clone() })),
684            )
685            .is_some() {
686            panic!("Constraint with schema id {}.{} aleady exists", self.schema_id, out.schema_id)
687        };
688    }
689
690    /// Define an index
691    pub fn index(&self, schema_id: impl ToString, id: impl ToString, fields: &[&Field]) -> IndexBuilder {
692        IndexBuilder {
693            table: self.clone(),
694            schema_id: schema_id.to_string(),
695            id: id.to_string(),
696            fields: fields.iter().map(|e| (*e).clone()).collect(),
697            unique: false,
698        }
699    }
700}
701
702pub struct IndexBuilder {
703    table: Table,
704    schema_id: String,
705    id: String,
706    fields: Vec<Field>,
707    unique: bool,
708}
709
710impl IndexBuilder {
711    pub fn unique(mut self) -> Self {
712        self.unique = true;
713        self
714    }
715
716    pub fn build(self, v: &mut Version) -> Index {
717        let mut deps = vec![GraphId::Table(self.table.schema_id.clone())];
718        for field in &self.fields {
719            deps.push(GraphId::Field(field.table.schema_id.clone(), field.schema_id.clone()));
720        }
721        let out = Index(Rc::new(Index_ {
722            table: self.table,
723            schema_id: SchemaIndexId(self.schema_id),
724            id: self.id,
725            fields: self.fields,
726            unique: self.unique,
727        }));
728        if v
729            .schema
730            .insert(
731                GraphId::Index(out.table.schema_id.clone(), out.schema_id.clone()),
732                MigrateNode::new(deps, Node::table_index(NodeIndex_ { def: out.clone() })),
733            )
734            .is_some() {
735            panic!("Index with schema id {}.{} already exists", out.table.schema_id, out.schema_id);
736        };
737        out
738    }
739}
740
741/// Generate Rust code for migrations and queries.
742///
743/// # Arguments
744///
745/// * `output` - the path to a single rust source file where the output will be written
746///
747/// * `versions` - a list of database version ids and schema versions. The ids must be
748///   consecutive but can start from any number. Once a version has been applied to a
749///   production database it shouldn't be modified again (modifications should be done
750///   in a new version).
751///
752///   These will be turned into migrations as part of the `migrate` function.
753///
754/// * `queries` - a list of queries against the schema in the latest version. These
755///   will be turned into functions.
756///
757/// # Returns
758///
759/// * Error - a list of validation or generation errors that occurred
760pub fn generate(output: &Path, versions: Vec<(usize, Version)>, queries: Vec<Query>) -> Result<(), Vec<String>> {
761    {
762        let mut prev_relations: HashMap<&String, String> = HashMap::new();
763        let mut prev_fields = HashMap::new();
764        let mut prev_constraints = HashMap::new();
765        for (v_i, v) in &versions {
766            let mut relations = HashMap::new();
767            let mut fields = HashMap::new();
768            let mut constraints = HashMap::new();
769            for n in v.schema.values() {
770                match &n.body {
771                    Node::Table(t) => {
772                        let id = &t.def.id;
773                        let comp_id = format!("table {}", t.def.schema_id);
774                        if relations.insert(id, comp_id.clone()).is_some() {
775                            panic!("Duplicate table id {} -- {}", t.def.id, t.def);
776                        }
777                        if let Some(schema_id) = prev_relations.get(id) {
778                            if schema_id != &comp_id {
779                                panic!(
780                                    "Table {} id in version {} swapped with another relation since previous version; unsupported",
781                                    t.def,
782                                    v_i
783                                );
784                            }
785                        }
786                    },
787                    Node::Field(f) => {
788                        let id = (&f.def.table.schema_id, &f.def.id);
789                        if fields.insert(id, f.def.schema_id.clone()).is_some() {
790                            panic!("Duplicate field id {} -- {}", f.def.id, f.def);
791                        }
792                        if let Some(schema_id) = prev_fields.get(&id) {
793                            if schema_id != &f.def.schema_id {
794                                panic!(
795                                    "Field {} id in version {} swapped with another field since previous version; unsupported",
796                                    f.def,
797                                    v_i
798                                );
799                            }
800                        }
801                    },
802                    Node::Constraint(c) => {
803                        let id = (&c.def.table.schema_id, &c.def.id);
804                        if constraints.insert(id, c.def.schema_id.clone()).is_some() {
805                            panic!("Duplicate constraint id {} -- {}", c.def.id, c.def);
806                        }
807                        if let Some(schema_id) = prev_constraints.get(&id) {
808                            if schema_id != &c.def.schema_id {
809                                panic!(
810                                    "Constraint {} id in version {} swapped with another constraint since previous version; unsupported",
811                                    c.def,
812                                    v_i
813                                );
814                            }
815                        }
816                    },
817                    Node::Index(i) => {
818                        let id = &i.def.id;
819                        let comp_id = format!("index {}", i.def.schema_id);
820                        if relations.insert(id, comp_id.clone()).is_some() {
821                            panic!("Duplicate index id {} -- {}", i.def.id, i.def);
822                        }
823                        if let Some(schema_id) = prev_relations.get(&id) {
824                            if schema_id != &comp_id {
825                                panic!(
826                                    "Index {} id in version {} swapped with another relation since previous version; unsupported",
827                                    i.def,
828                                    v_i
829                                );
830                            }
831                        }
832                    },
833                }
834            }
835            prev_relations = relations;
836            prev_fields = fields;
837            prev_constraints = constraints;
838        }
839    }
840    let mut errs = Errs::new();
841    let mut migrations = vec![];
842    let mut prev_version: Option<Version> = None;
843    let mut prev_version_i: Option<i64> = None;
844    let mut field_lookup = HashMap::new();
845    for (version_i, version) in versions {
846        let path = rpds::vector![format!("Migration to {}", version_i)];
847        let mut migration = vec![];
848
849        fn do_migration_query(
850            errs: &mut Errs,
851            path: &rpds::Vector<String>,
852            migration: &mut Vec<TokenStream>,
853            field_lookup: &HashMap<Table, HashMap<Field, Type>>,
854            q: &dyn QueryBody,
855        ) {
856            let mut qctx = PgQueryCtx::new(errs.clone(), &field_lookup);
857            let e_res = q.build(&mut qctx, path, QueryResCount::None);
858            if !qctx.rust_args.is_empty() {
859                qctx.errs.err(path, format!("Migration statements can't receive arguments"));
860            }
861            let statement = e_res.1.to_string();
862            let args = qctx.query_args;
863            migration.push(quote!{
864                {
865                    let query = #statement;
866                    txn.execute(query, &[#(& #args,) *]).await.to_good_error_query(query) ?;
867                };
868            });
869        }
870
871        // Do pre-migrations
872        for (i, q) in version.pre_migration.iter().enumerate() {
873            do_migration_query(
874                &mut errs,
875                &path.push_back(format!("Pre-migration statement {}", i)),
876                &mut migration,
877                &field_lookup,
878                q.as_ref(),
879            );
880        }
881
882        // Prep for current version
883        field_lookup.clear();
884        let version_i = version_i as i64;
885        if let Some(i) = prev_version_i {
886            if version_i != i as i64 + 1 {
887                errs.err(
888                    &path,
889                    format!(
890                        "Version numbers are not consecutive ({} to {}) - was an intermediate version deleted?",
891                        i,
892                        version_i
893                    ),
894                );
895            }
896        }
897
898        // Gather tables for lookup during query generation and check duplicates
899        for v in version.schema.values() {
900            match &v.body {
901                Node::Field(f) => {
902                    match field_lookup.entry(f.def.table.clone()) {
903                        std::collections::hash_map::Entry::Occupied(_) => { },
904                        std::collections::hash_map::Entry::Vacant(e) => {
905                            e.insert(HashMap::new());
906                        },
907                    };
908                    let table = field_lookup.get_mut(&f.def.table).unwrap();
909                    table.insert(f.def.clone(), f.def.type_.type_.clone());
910                },
911                _ => { },
912            };
913        }
914
915        // Main migrations
916        {
917            let mut state = PgMigrateCtx::new(errs.clone());
918            crate::graphmigrate::migrate(&mut state, prev_version.take().map(|s| s.schema), &version.schema);
919            for statement in &state.statements {
920                migration.push(quote!{
921                    {
922                        let query = #statement;
923                        txn.execute(query, &[]).await.to_good_error_query(query)?;
924                    };
925                });
926            }
927        }
928
929        // Post-migration
930        for (i, q) in version.post_migration.iter().enumerate() {
931            do_migration_query(
932                &mut errs,
933                &path.push_back(format!("Post-migration statement {}", i)),
934                &mut migration,
935                &field_lookup,
936                q.as_ref(),
937            );
938        }
939
940        // Build migration
941        migrations.push(quote!{
942            if version < #version_i {
943                #(#migration) *
944            }
945        });
946
947        // Next iter prep
948        prev_version = Some(version);
949        prev_version_i = Some(version_i);
950    }
951
952    // Generate queries
953    let mut db_others = Vec::new();
954    {
955        let mut res_type_idents: HashMap<String, Ident> = HashMap::new();
956        for q in queries {
957            let path = rpds::vector![format!("Query {}", q.name)];
958            let mut ctx = PgQueryCtx::new(errs.clone(), &field_lookup);
959            let res = QueryBody::build(q.body.as_ref(), &mut ctx, &path, q.res_count.clone());
960            let ident = format_ident!("{}", q.name);
961            let q_text = res.1.to_string();
962            let args = ctx.rust_args.split_off(0);
963            let args_forward = ctx.query_args.split_off(0);
964            drop(ctx);
965            let (res_ident, res_def, unforward_res) = {
966                fn convert_one_res(
967                    errs: &mut Errs,
968                    path: &rpds::Vector<String>,
969                    i: usize,
970                    k: &ExprValName,
971                    v: &Type,
972                ) -> Option<(Ident, TokenStream, TokenStream)> {
973                    if k.id.is_empty() {
974                        errs.err(
975                            path,
976                            format!("Result element {} has no name; name it using `rename` if this is intentional", i),
977                        );
978                        return None;
979                    }
980                    let rust_types = to_rust_types(&v.type_.type_);
981                    let custom_trait_ident = rust_types.custom_trait;
982                    let mut ident = rust_types.ret_type;
983                    if v.opt {
984                        ident = quote!(Option < #ident >);
985                    }
986                    let mut unforward = quote!{
987                        let x: #ident = r.get(#i);
988                    };
989                    if let Some(custom) = &v.type_.custom {
990                        ident = match syn::parse_str::<syn::Path>(&custom) {
991                            Ok(i) => i.to_token_stream(),
992                            Err(e) => {
993                                errs.err(
994                                    path,
995                                    format!(
996                                        "Couldn't parse provided custom type name [{}] as identifier path: {:?}",
997                                        custom,
998                                        e
999                                    ),
1000                                );
1001                                return None;
1002                            },
1003                        };
1004                        if v.opt {
1005                            unforward = quote!{
1006                                #unforward let x = if let Some(x) = x {
1007                                    Some(
1008                                        < #ident as #custom_trait_ident < #ident >>:: from_sql(
1009                                            x
1010                                        ).to_good_error(|| format!("Parsing result {}", #i)) ?
1011                                    )
1012                                }
1013                                else {
1014                                    None
1015                                };
1016                            };
1017                            ident = quote!(Option < #ident >);
1018                        } else {
1019                            unforward = quote!{
1020                                #unforward let x =< #ident as #custom_trait_ident < #ident >>:: from_sql(
1021                                    x
1022                                ).to_good_error(|| format!("Parsing result {}", #i)) ?;
1023                            };
1024                        }
1025                    }
1026                    return Some((format_ident!("{}", sanitize_ident(&k.id).1), ident, quote!({
1027                        #unforward x
1028                    })));
1029                }
1030
1031                if res.0.0.len() == 1 {
1032                    let e = &res.0.0[0];
1033                    let (_, type_ident, unforward) = match convert_one_res(&mut errs, &path, 0, &e.0, &e.1) {
1034                        None => {
1035                            continue;
1036                        },
1037                        Some(x) => x,
1038                    };
1039                    (type_ident, None, unforward)
1040                } else {
1041                    let mut fields = vec![];
1042                    let mut unforward_fields = vec![];
1043                    for (i, (k, v)) in res.0.0.into_iter().enumerate() {
1044                        let (k_ident, type_ident, unforward) = match convert_one_res(&mut errs, &path, i, &k, &v) {
1045                            Some(x) => x,
1046                            None => continue,
1047                        };
1048                        fields.push(quote!{
1049                            pub #k_ident: #type_ident
1050                        });
1051                        unforward_fields.push(quote!{
1052                            #k_ident: #unforward
1053                        });
1054                    }
1055                    let body = quote!({
1056                        #(#fields,) *
1057                    });
1058                    let res_type_count = res_type_idents.len();
1059                    let (res_ident, res_def) = match res_type_idents.entry(body.to_string()) {
1060                        std::collections::hash_map::Entry::Occupied(e) => {
1061                            (e.get().clone(), None)
1062                        },
1063                        std::collections::hash_map::Entry::Vacant(e) => {
1064                            let ident = if let Some(name) = q.res_name {
1065                                format_ident!("{}", name)
1066                            } else {
1067                                format_ident!("DbRes{}", res_type_count)
1068                            };
1069                            e.insert(ident.clone());
1070                            let res_def = quote!(pub struct #ident #body);
1071                            (ident, Some(res_def))
1072                        },
1073                    };
1074                    let unforward = quote!(#res_ident {
1075                        #(#unforward_fields,) *
1076                    });
1077                    (res_ident.to_token_stream(), res_def, unforward)
1078                }
1079            };
1080            let db_arg = quote!(db: &mut impl tokio_postgres::GenericClient);
1081            match q.res_count {
1082                QueryResCount::None => {
1083                    db_others.push(quote!{
1084                        pub async fn #ident(#db_arg, #(#args,) *) -> Result <(),
1085                        GoodError > {
1086                            let query = #q_text;
1087                            db.execute(query, &[#(& #args_forward,) *]).await.to_good_error_query(query) ?;
1088                            Ok(())
1089                        }
1090                    });
1091                },
1092                QueryResCount::MaybeOne => {
1093                    if let Some(res_def) = res_def {
1094                        db_others.push(res_def);
1095                    }
1096                    db_others.push(quote!{
1097                        pub async fn #ident(#db_arg, #(#args,) *) -> Result < Option < #res_ident >,
1098                        GoodError > {
1099                            let query = #q_text;
1100                            let r = db.query_opt(query, &[#(& #args_forward,) *]).await.to_good_error_query(query) ?;
1101                            if let Some(r) = r {
1102                                return Ok(Some(#unforward_res));
1103                            }
1104                            Ok(None)
1105                        }
1106                    });
1107                },
1108                QueryResCount::One => {
1109                    if let Some(res_def) = res_def {
1110                        db_others.push(res_def);
1111                    }
1112                    db_others.push(quote!{
1113                        pub async fn #ident(#db_arg, #(#args,) *) -> Result < #res_ident,
1114                        GoodError > {
1115                            let query = #q_text;
1116                            let r = db.query_one(query, &[#(& #args_forward,) *]).await.to_good_error_query(query) ?;
1117                            Ok(#unforward_res)
1118                        }
1119                    });
1120                },
1121                QueryResCount::Many => {
1122                    if let Some(res_def) = res_def {
1123                        db_others.push(res_def);
1124                    }
1125                    db_others.push(quote!{
1126                        pub async fn #ident(#db_arg, #(#args,) *) -> Result < Vec < #res_ident >,
1127                        GoodError > {
1128                            let mut out = vec![];
1129                            let query = #q_text;
1130                            for r in db.query(query, &[#(& #args_forward,) *]).await.to_good_error_query(query) ? {
1131                                out.push(#unforward_res);
1132                            }
1133                            Ok(out)
1134                        }
1135                    });
1136                },
1137            }
1138        }
1139    }
1140
1141    // Compile, output
1142    let last_version_i = prev_version_i.unwrap() as i64;
1143    let tokens = quote!{
1144        use good_ormning_runtime::GoodError;
1145        use good_ormning_runtime::ToGoodError;
1146        pub async fn migrate(db: &mut tokio_postgres::Client) -> Result <(),
1147        GoodError > {
1148            {
1149                let query =
1150                    "create table if not exists __good_version (rid int primary key, version bigint not null, lock int not null);";
1151                db.execute(query, &[]).await.to_good_error_query(query)?;
1152            }
1153            {
1154                let query =
1155                    "insert into __good_version (rid, version, lock) values (0, -1, 0) on conflict do nothing;";
1156                db.execute(query, &[]).await.to_good_error_query(query)?;
1157            }
1158            loop {
1159                let txn = db.transaction().await.to_good_error(|| "Failed to start transaction".to_string())?;
1160                match(|| {
1161                    async {
1162                        let query =
1163                            "update __good_version set lock = 1 where rid = 0 and lock = 0 returning version";
1164                        let version = match txn.query_opt(query, &[]).await.to_good_error_query(query)? {
1165                            Some(r) => {
1166                                let ver: i64 = r.get("version");
1167                                ver
1168                            },
1169                            None => {
1170                                return Ok(false);
1171                            },
1172                        };
1173                        if version > #last_version_i {
1174                            return Err(
1175                                GoodError(
1176                                    format!(
1177                                        "The latest known version is {}, but the schema is at unknown version {}",
1178                                        #last_version_i,
1179                                        version
1180                                    ),
1181                                ),
1182                            );
1183                        }
1184                        #(#migrations) * {
1185                            let query = "update __good_version set version = $1, lock = 0";
1186                            txn.execute(query, &[& #last_version_i]).await.to_good_error_query(query) ?;
1187                        }
1188                        let out: Result < bool,
1189                        GoodError >= Ok(true);
1190                        out
1191                    }
1192                })().await {
1193                    Err(e) => {
1194                        match txn.rollback().await {
1195                            Err(e1) => {
1196                                return Err(
1197                                    GoodError(
1198                                        format!(
1199                                            "{}\n\nRolling back the transaction due to the above also failed: {}",
1200                                            e,
1201                                            e1
1202                                        ),
1203                                    ),
1204                                );
1205                            },
1206                            Ok(_) => {
1207                                return Err(GoodError(e.to_string()));
1208                            },
1209                        };
1210                    }
1211                    Ok(migrated) => {
1212                        match txn.commit().await {
1213                            Err(e) => {
1214                                return Err(GoodError(format!("Error committing the migration transaction: {}", e)));
1215                            },
1216                            Ok(_) => {
1217                                if migrated {
1218                                    return Ok(())
1219                                } else {
1220                                    tokio::time::sleep(tokio::time::Duration::from_millis(5 * 1000)).await;
1221                                }
1222                            },
1223                        };
1224                    }
1225                }
1226            }
1227        }
1228        #(#db_others) *
1229    };
1230    if let Some(p) = output.parent() {
1231        if let Err(e) = fs::create_dir_all(&p) {
1232            errs.err(
1233                &rpds::vector![],
1234                format!("Error creating output parent directories {}: {:?}", p.to_string_lossy(), e),
1235            );
1236        }
1237    }
1238    match genemichaels_lib::format_str(&tokens.to_string(), &genemichaels_lib::FormatConfig::default()) {
1239        Ok(src) => {
1240            match fs::write(output, src.rendered.as_bytes()) {
1241                Ok(_) => { },
1242                Err(e) => errs.err(
1243                    &rpds::vector![],
1244                    format!("Failed to write generated code to {}: {:?}", output.to_string_lossy(), e),
1245                ),
1246            };
1247        },
1248        Err(e) => {
1249            errs.err(&rpds::vector![], format!("Error formatting generated code: {:?}\n{}", e, tokens));
1250        },
1251    };
1252    errs.raise()?;
1253    Ok(())
1254}
1255
#[cfg(test)]
mod test {
    //! Negative tests for `generate`: each case feeds it a schema history or a query that it
    //! must refuse, either by returning `Err` or by panicking.
    use std::path::PathBuf;
    use crate::pg::{
        new_select,
        QueryResCount,
        new_insert,
    };
    use super::{
        schema::field::{
            field_str,
            field_auto,
            field_i32,
        },
        generate,
        Version,
        query::expr::Expr,
    };

    /// Where generated code is written; these tests only inspect the validation outcome.
    fn discard_output() -> PathBuf {
        PathBuf::from("/dev/null")
    }

    #[test]
    fn test_add_field_serial_bad() {
        // Version 1 adds a `field_auto` column to an existing table; this must be an error even
        // though a `migrate_fill` value is supplied.
        let v0 = {
            let mut v = Version::default();
            let bananna = v.table("zMOY9YMCK", "bananna");
            bananna.field(&mut v, "z437INV6D", "hizat", field_str().build());
            v
        };
        let v1 = {
            let mut v = Version::default();
            let bananna = v.table("zMOY9YMCK", "bananna");
            bananna.field(&mut v, "z437INV6D", "hizat", field_str().build());
            bananna.field(&mut v, "zPREUVAOD", "zomzom", field_auto().migrate_fill(Expr::LitAuto(0)).build());
            v
        };
        let outcome = generate(&discard_output(), vec![(0usize, v0), (1usize, v1)], vec![]);
        assert!(outcome.is_err());
    }

    #[test]
    #[should_panic]
    fn test_add_field_dup_bad() {
        // Version 1 declares the field id `z437INV6D` twice on the same table.
        let v0 = {
            let mut v = Version::default();
            let bananna = v.table("zPAO2PJU4", "bananna");
            bananna.field(&mut v, "z437INV6D", "hizat", field_str().build());
            v
        };
        let v1 = {
            let mut v = Version::default();
            let bananna = v.table("zQZQ8E2WD", "bananna");
            bananna.field(&mut v, "z437INV6D", "hizat", field_str().build());
            bananna.field(&mut v, "z437INV6D", "zomzom", field_i32().build());
            v
        };
        generate(&discard_output(), vec![(0usize, v0), (1usize, v1)], vec![]).unwrap();
    }

    #[test]
    #[should_panic]
    fn test_add_table_dup_bad() {
        // Version 1 declares the table id `zSNS34DYI` twice.
        let v0 = {
            let mut v = Version::default();
            let bananna = v.table("zSNS34DYI", "bananna");
            bananna.field(&mut v, "z437INV6D", "hizat", field_str().build());
            v
        };
        let v1 = {
            let mut v = Version::default();
            let first = v.table("zSNS34DYI", "bananna");
            first.field(&mut v, "z437INV6D", "hizat", field_str().build());
            let second = v.table("zSNS34DYI", "bananna");
            second.field(&mut v, "z437INV6D", "hizat", field_str().build());
            v
        };
        generate(&discard_output(), vec![(0usize, v0), (1usize, v1)], vec![]).unwrap();
    }

    #[test]
    fn test_res_count_none_bad() {
        // A select that returns a column can't be declared with `QueryResCount::None`.
        let mut v = Version::default();
        let bananna = v.table("z5S18LWQE", "bananna");
        let hizat = bananna.field(&mut v, "z437INV6D", "hizat", field_str().build());
        let query = new_select(&bananna).return_field(&hizat).build_query("x", QueryResCount::None);
        assert!(generate(&discard_output(), vec![(0usize, v)], vec![query]).is_err());
    }

    #[test]
    fn test_select_nothing_bad() {
        // A select with no returned fields (declared with `QueryResCount::None`) is rejected.
        let mut v = Version::default();
        let bananna = v.table("zOOR88EQ9", "bananna");
        bananna.field(&mut v, "z437INV6D", "hizat", field_str().build());
        let query = new_select(&bananna).build_query("x", QueryResCount::None);
        assert!(generate(&discard_output(), vec![(0usize, v)], vec![query]).is_err());
    }

    #[test]
    fn test_returning_none_bad() {
        // An insert with a `returning` field can't be declared with `QueryResCount::None`.
        let mut v = Version::default();
        let bananna = v.table("zZPD1I2EF", "bananna");
        let hizat = bananna.field(&mut v, "z437INV6D", "hizat", field_str().build());
        let query =
            new_insert(&bananna, vec![(hizat.clone(), Expr::LitString("hoy".into()))])
                .return_field(&hizat)
                .build_query("x", QueryResCount::None);
        assert!(generate(&discard_output(), vec![(0usize, v)], vec![query]).is_err());
    }
}