Skip to main content

lift_migration/schema/
blueprint.rs

1use crate::{
2    BlueprintExecutor, MigrationError,
3    schema::render::{
4        default_index_name, infer_referenced_table, render_column, render_constraint,
5        render_foreign_key,
6    },
7};
8
/// SQL dialect that the DDL produced by this module targets.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SchemaDialect {
    /// SQLite.
    Sqlite,
    /// PostgreSQL.
    Postgres,
    /// MariaDB.
    MariaDb,
}

/// Logical column type; the concrete SQL spelling is produced by `render_column`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ColumnType {
    /// 32-bit style integer (also used by `increments`).
    Integer,
    /// 64-bit style integer (used by `id`, `big_increments` and `foreign_id`).
    BigInt,
    /// Boolean.
    Bool,
    /// Fixed-length character column with the given length.
    Char(u32),
    /// Variable-length character column with the given maximum length.
    Varchar(u32),
    /// Unbounded text.
    Text,
    /// Calendar date.
    Date,
    /// Time of day.
    Time,
    /// Date and time.
    DateTime,
    /// Timestamp.
    Timestamp,
    /// Fixed-point number as `(precision, scale)`.
    Decimal(u32, u32),
    /// Single-precision floating point.
    Float,
    /// Double-precision floating point.
    Double,
    /// JSON document.
    Json,
    /// UUID value.
    Uuid,
    /// Caller-supplied type name (NOTE(review): presumably emitted verbatim — confirm in `render_column`).
    Custom(String),
}
35
36#[derive(Debug, Clone, PartialEq, Eq)]
37pub(crate) enum TableAlterOperation {
38    DropColumn(String),
39    AddColumn(ColumnDef),
40}
41
42#[derive(Debug, Clone, PartialEq, Eq)]
43pub(crate) enum ConstraintDef {
44    Primary {
45        columns: Vec<String>,
46    },
47    Unique {
48        name: Option<String>,
49        columns: Vec<String>,
50    },
51    Check {
52        name: Option<String>,
53        expression: String,
54    },
55}
56
57#[derive(Debug, Clone, PartialEq, Eq)]
58pub(crate) struct ColumnDef {
59    pub(crate) name: String,
60    pub(crate) ty: ColumnType,
61    pub(crate) nullable: bool,
62    pub(crate) primary_key: bool,
63    pub(crate) auto_increment: bool,
64    pub(crate) unique: bool,
65    pub(crate) default_raw: Option<String>,
66}
67
/// A column default expressed as raw SQL text.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct DefaultValue {
    /// The SQL text placed in the column's default clause.
    pub(crate) sql: String,
}

impl DefaultValue {
    /// Wraps `sql` unchanged; no quoting or escaping is applied.
    pub fn raw(sql: impl Into<String>) -> Self {
        let sql: String = sql.into();
        DefaultValue { sql }
    }
}

/// Default of `current_timestamp`, as used by `TableBlueprint::timestamps`.
pub fn current_timestamp() -> DefaultValue {
    DefaultValue {
        sql: String::from("current_timestamp"),
    }
}
82
/// A foreign key from one local column to `references_table(references_column)`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) struct ForeignKeyDef {
    /// Local column holding the reference.
    pub(crate) column: String,
    /// Referenced table name.
    pub(crate) references_table: String,
    /// Referenced column name.
    pub(crate) references_column: String,
    /// Requested `on delete` action; `None` when none was requested.
    pub(crate) on_delete: Option<ForeignKeyAction>,
    /// Requested `on update` action; `None` when none was requested.
    pub(crate) on_update: Option<ForeignKeyAction>,
}

/// Referential action for `on delete` / `on update` clauses.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ForeignKeyAction {
    /// Propagate the change to referencing rows.
    Cascade,
    /// Reject the change while referencing rows exist.
    Restrict,
    /// Set the referencing column to NULL.
    SetNull,
    /// No action.
    NoAction,
}
99
100#[derive(Debug, Clone, PartialEq, Eq)]
101pub struct IndexBlueprint {
102    name: String,
103    table: String,
104    columns: Vec<String>,
105    unique: bool,
106}
107
108#[derive(Debug, Clone, PartialEq, Eq)]
109pub struct TableBlueprint {
110    name: String,
111    columns: Vec<ColumnDef>,
112    foreign_keys: Vec<ForeignKeyDef>,
113    constraints: Vec<ConstraintDef>,
114    indexes: Vec<IndexBlueprint>,
115}
116
117#[derive(Debug, Clone, PartialEq, Eq)]
118pub struct AlterTableBlueprint {
119    name: String,
120    operations: Vec<TableAlterOperation>,
121}
122
/// Conversion into an ordered list of column names.
///
/// Schema helpers (`TableBlueprint::unique`, `TableBlueprint::index`, `IndexBlueprint::new`, ...)
/// accept any `IntoSchemaColumns`, so callers may pass a single name (`&str`, `String`,
/// `&String`), an array or slice of string-like values, a `Vec<&str>` / `Vec<String>`, or a
/// tuple of 2 to 6 string-like values.
pub trait IntoSchemaColumns {
    /// Consumes `self` and returns the column names in their original order.
    fn into_schema_columns(self) -> Vec<String>;
}

impl IntoSchemaColumns for &str {
    fn into_schema_columns(self) -> Vec<String> {
        vec![self.to_owned()]
    }
}

impl IntoSchemaColumns for &String {
    fn into_schema_columns(self) -> Vec<String> {
        vec![self.clone()]
    }
}

impl IntoSchemaColumns for String {
    fn into_schema_columns(self) -> Vec<String> {
        vec![self]
    }
}

// A single generic impl covers `[&str; N]` and `[String; N]` (the previously supported
// forms) as well as any other `AsRef<str>` element type such as `[&String; N]`.
impl<T, const N: usize> IntoSchemaColumns for [T; N]
where
    T: AsRef<str>,
{
    fn into_schema_columns(self) -> Vec<String> {
        self.iter().map(|column| column.as_ref().to_owned()).collect()
    }
}

// Borrowed slices, e.g. `&columns[..]` built at runtime.
impl<T> IntoSchemaColumns for &[T]
where
    T: AsRef<str>,
{
    fn into_schema_columns(self) -> Vec<String> {
        self.iter().map(|column| column.as_ref().to_owned()).collect()
    }
}

impl IntoSchemaColumns for Vec<&str> {
    fn into_schema_columns(self) -> Vec<String> {
        self.into_iter().map(str::to_owned).collect()
    }
}

impl IntoSchemaColumns for Vec<String> {
    fn into_schema_columns(self) -> Vec<String> {
        // Already owned: hand the vector over without copying.
        self
    }
}

/// Implements `IntoSchemaColumns` for a tuple whose elements are all `AsRef<str>`.
macro_rules! impl_into_schema_columns_tuple {
    ($(($type_name:ident, $value_name:ident)),+ $(,)?) => {
        impl<$($type_name),+> IntoSchemaColumns for ($($type_name,)+)
        where
            $($type_name: AsRef<str>,)+
        {
            fn into_schema_columns(self) -> Vec<String> {
                let ($($value_name,)+) = self;
                vec![$($value_name.as_ref().to_owned(),)+]
            }
        }
    };
}

impl_into_schema_columns_tuple!((A, a), (B, b));
impl_into_schema_columns_tuple!((A, a), (B, b), (C, c));
impl_into_schema_columns_tuple!((A, a), (B, b), (C, c), (D, d));
impl_into_schema_columns_tuple!((A, a), (B, b), (C, c), (D, d), (E, e));
impl_into_schema_columns_tuple!((A, a), (B, b), (C, c), (D, d), (E, e), (F, f));
180
181impl TableBlueprint {
182    pub fn new(name: impl Into<String>) -> Self {
183        Self {
184            name: name.into(),
185            columns: Vec::new(),
186            foreign_keys: Vec::new(),
187            constraints: Vec::new(),
188            indexes: Vec::new(),
189        }
190    }
191
192    pub fn id(&mut self) {
193        self.big_increments("id");
194    }
195
196    pub fn increments(&mut self, name: &str) {
197        self.columns.push(ColumnDef {
198            name: name.to_owned(),
199            ty: ColumnType::Integer,
200            nullable: false,
201            primary_key: true,
202            auto_increment: true,
203            unique: false,
204            default_raw: None,
205        });
206    }
207
208    pub fn big_increments(&mut self, name: &str) {
209        self.columns.push(ColumnDef {
210            name: name.to_owned(),
211            ty: ColumnType::BigInt,
212            nullable: false,
213            primary_key: true,
214            auto_increment: true,
215            unique: false,
216            default_raw: None,
217        });
218    }
219
220    pub fn string(&mut self, name: &str) -> ColumnBuilder<'_> {
221        self.push_column(name, ColumnType::Varchar(255))
222    }
223
224    pub fn char(&mut self, name: &str, len: u32) -> ColumnBuilder<'_> {
225        self.push_column(name, ColumnType::Char(len))
226    }
227
228    pub fn varchar(&mut self, name: &str, len: u32) -> ColumnBuilder<'_> {
229        self.push_column(name, ColumnType::Varchar(len))
230    }
231
232    pub fn text(&mut self, name: &str) -> ColumnBuilder<'_> {
233        self.push_column(name, ColumnType::Text)
234    }
235
236    pub fn integer(&mut self, name: &str) -> ColumnBuilder<'_> {
237        self.push_column(name, ColumnType::Integer)
238    }
239
240    pub fn bigint(&mut self, name: &str) -> ColumnBuilder<'_> {
241        self.push_column(name, ColumnType::BigInt)
242    }
243
244    pub fn boolean(&mut self, name: &str) -> ColumnBuilder<'_> {
245        self.push_column(name, ColumnType::Bool)
246    }
247
248    pub fn date(&mut self, name: &str) -> ColumnBuilder<'_> {
249        self.push_column(name, ColumnType::Date)
250    }
251
252    pub fn time(&mut self, name: &str) -> ColumnBuilder<'_> {
253        self.push_column(name, ColumnType::Time)
254    }
255
256    pub fn datetime(&mut self, name: &str) -> ColumnBuilder<'_> {
257        self.push_column(name, ColumnType::DateTime)
258    }
259
260    pub fn timestamp(&mut self, name: &str) -> ColumnBuilder<'_> {
261        self.push_column(name, ColumnType::Timestamp)
262    }
263
264    pub fn decimal(&mut self, name: &str, precision: u32, scale: u32) -> ColumnBuilder<'_> {
265        self.push_column(name, ColumnType::Decimal(precision, scale))
266    }
267
268    pub fn float(&mut self, name: &str) -> ColumnBuilder<'_> {
269        self.push_column(name, ColumnType::Float)
270    }
271
272    pub fn double(&mut self, name: &str) -> ColumnBuilder<'_> {
273        self.push_column(name, ColumnType::Double)
274    }
275
276    pub fn json(&mut self, name: &str) -> ColumnBuilder<'_> {
277        self.push_column(name, ColumnType::Json)
278    }
279
280    pub fn uuid(&mut self, name: &str) -> ColumnBuilder<'_> {
281        self.push_column(name, ColumnType::Uuid)
282    }
283
284    pub(crate) fn custom(&mut self, name: &str, ty: ColumnType) -> ColumnBuilder<'_> {
285        self.push_column(name, ty)
286    }
287
288    pub fn timestamps(&mut self) {
289        self.timestamp("created_at").default(current_timestamp());
290        self.timestamp("updated_at").default(current_timestamp());
291    }
292
293    pub fn remember_token(&mut self) -> ColumnBuilder<'_> {
294        self.push_column("remember_token", ColumnType::Varchar(100))
295            .nullable()
296    }
297
298    pub fn foreign_id(&mut self, column: &str) -> ForeignKeyBuilder<'_> {
299        self.foreign(column, ColumnType::BigInt)
300    }
301
302    pub fn foreign(&mut self, column: &str, ty: ColumnType) -> ForeignKeyBuilder<'_> {
303        self.columns.push(ColumnDef {
304            name: column.to_owned(),
305            ty,
306            nullable: false,
307            primary_key: false,
308            auto_increment: false,
309            unique: false,
310            default_raw: None,
311        });
312        let index = self.columns.len() - 1;
313        ForeignKeyBuilder {
314            table: self,
315            index,
316            foreign_key: None,
317        }
318    }
319
320    pub fn unique<I>(&mut self, columns: I)
321    where
322        I: IntoSchemaColumns,
323    {
324        self.constraints.push(ConstraintDef::Unique {
325            name: None,
326            columns: columns.into_schema_columns(),
327        });
328    }
329
330    pub fn primary<I>(&mut self, columns: I)
331    where
332        I: IntoSchemaColumns,
333    {
334        self.constraints.push(ConstraintDef::Primary {
335            columns: columns.into_schema_columns(),
336        });
337    }
338
339    pub fn unique_named<I>(&mut self, name: &str, columns: I)
340    where
341        I: IntoSchemaColumns,
342    {
343        self.constraints.push(ConstraintDef::Unique {
344            name: Some(name.to_owned()),
345            columns: columns.into_schema_columns(),
346        });
347    }
348
349    pub fn check(&mut self, expression: &str) {
350        self.constraints.push(ConstraintDef::Check {
351            name: None,
352            expression: expression.to_owned(),
353        });
354    }
355
356    pub fn constraint(&mut self, name: &str) -> ConstraintBuilder<'_> {
357        ConstraintBuilder {
358            table: self,
359            name: name.to_owned(),
360        }
361    }
362
363    pub fn check_named(&mut self, name: &str, expression: &str) {
364        self.constraints.push(ConstraintDef::Check {
365            name: Some(name.to_owned()),
366            expression: expression.to_owned(),
367        });
368    }
369
370    pub fn index<I>(&mut self, name: &str, columns: I)
371    where
372        I: IntoSchemaColumns,
373    {
374        self.indexes
375            .push(IndexBlueprint::new(name, self.name.as_str(), columns));
376    }
377
378    pub fn unique_index<I>(&mut self, name: &str, columns: I)
379    where
380        I: IntoSchemaColumns,
381    {
382        self.indexes.push(IndexBlueprint::new_unique(
383            name,
384            self.name.as_str(),
385            columns,
386        ));
387    }
388
389    fn push_column(&mut self, name: &str, ty: ColumnType) -> ColumnBuilder<'_> {
390        self.columns.push(ColumnDef {
391            name: name.to_owned(),
392            ty,
393            nullable: false,
394            primary_key: false,
395            auto_increment: false,
396            unique: false,
397            default_raw: None,
398        });
399        let index = self.columns.len() - 1;
400        ColumnBuilder { table: self, index }
401    }
402
403    pub fn create_sql(&self, dialect: SchemaDialect) -> String {
404        self.create_statements(dialect)
405            .into_iter()
406            .next()
407            .expect("create table statements are never empty")
408    }
409
410    pub fn create_statements(&self, dialect: SchemaDialect) -> Vec<String> {
411        let name = dialect.quote_ident(&self.name);
412        let mut defs = self
413            .columns
414            .iter()
415            .map(|column| render_column(dialect, column))
416            .collect::<Vec<_>>();
417
418        defs.extend(
419            self.foreign_keys
420                .iter()
421                .map(|foreign| render_foreign_key(dialect, foreign)),
422        );
423        defs.extend(
424            self.constraints
425                .iter()
426                .map(|constraint| render_constraint(dialect, constraint)),
427        );
428
429        let mut statements = vec![format!("create table {name} ({});", defs.join(", "))];
430        statements.extend(self.indexes.iter().map(|index| index.create_sql(dialect)));
431        statements
432    }
433
434    pub fn drop_sql(&self, dialect: SchemaDialect) -> String {
435        format!("drop table if exists {};", dialect.quote_ident(&self.name))
436    }
437
438    pub async fn create<C>(self, ctx: &mut C) -> Result<(), MigrationError>
439    where
440        C: BlueprintExecutor,
441    {
442        for sql in self.create_statements(ctx.dialect()) {
443            ctx.execute_raw_blueprint(&sql).await?;
444        }
445        Ok(())
446    }
447}
448
449impl IndexBlueprint {
450    pub fn named(name: &str) -> Self {
451        Self {
452            name: name.to_owned(),
453            table: String::new(),
454            columns: Vec::new(),
455            unique: false,
456        }
457    }
458
459    pub fn new<I>(name: &str, table: &str, columns: I) -> Self
460    where
461        I: IntoSchemaColumns,
462    {
463        Self {
464            name: name.to_owned(),
465            table: table.to_owned(),
466            columns: columns.into_schema_columns(),
467            unique: false,
468        }
469    }
470
471    pub fn new_unique<I>(name: &str, table: &str, columns: I) -> Self
472    where
473        I: IntoSchemaColumns,
474    {
475        Self {
476            name: name.to_owned(),
477            table: table.to_owned(),
478            columns: columns.into_schema_columns(),
479            unique: true,
480        }
481    }
482
483    pub fn create_sql(&self, dialect: SchemaDialect) -> String {
484        let unique = if self.unique { "unique " } else { "" };
485        let name = dialect.quote_ident(&self.name);
486        let table = dialect.quote_ident(&self.table);
487        let columns = self
488            .columns
489            .iter()
490            .map(|column| dialect.quote_ident(column))
491            .collect::<Vec<_>>()
492            .join(", ");
493        format!("create {unique}index {name} on {table} ({columns});")
494    }
495
496    pub fn drop_sql(&self, dialect: SchemaDialect) -> String {
497        format!("drop index if exists {};", dialect.quote_ident(&self.name))
498    }
499}
500
501impl AlterTableBlueprint {
502    pub fn new(name: impl Into<String>) -> Self {
503        Self {
504            name: name.into(),
505            operations: Vec::new(),
506        }
507    }
508
509    pub fn drop_column(&mut self, name: &str) {
510        self.operations
511            .push(TableAlterOperation::DropColumn(name.to_owned()));
512    }
513
514    pub fn string(&mut self, name: &str) -> AlterColumnBuilder<'_> {
515        self.push_column(name, ColumnType::Varchar(255))
516    }
517
518    pub fn text(&mut self, name: &str) -> AlterColumnBuilder<'_> {
519        self.push_column(name, ColumnType::Text)
520    }
521
522    pub fn varchar(&mut self, name: &str, len: u32) -> AlterColumnBuilder<'_> {
523        self.push_column(name, ColumnType::Varchar(len))
524    }
525
526    pub fn integer(&mut self, name: &str) -> AlterColumnBuilder<'_> {
527        self.push_column(name, ColumnType::Integer)
528    }
529
530    pub fn bigint(&mut self, name: &str) -> AlterColumnBuilder<'_> {
531        self.push_column(name, ColumnType::BigInt)
532    }
533
534    pub fn boolean(&mut self, name: &str) -> AlterColumnBuilder<'_> {
535        self.push_column(name, ColumnType::Bool)
536    }
537
538    pub fn date(&mut self, name: &str) -> AlterColumnBuilder<'_> {
539        self.push_column(name, ColumnType::Date)
540    }
541
542    pub fn time(&mut self, name: &str) -> AlterColumnBuilder<'_> {
543        self.push_column(name, ColumnType::Time)
544    }
545
546    pub fn datetime(&mut self, name: &str) -> AlterColumnBuilder<'_> {
547        self.push_column(name, ColumnType::DateTime)
548    }
549
550    pub fn timestamp(&mut self, name: &str) -> AlterColumnBuilder<'_> {
551        self.push_column(name, ColumnType::Timestamp)
552    }
553
554    pub fn uuid(&mut self, name: &str) -> AlterColumnBuilder<'_> {
555        self.push_column(name, ColumnType::Uuid)
556    }
557
558    pub(crate) fn custom(&mut self, name: &str, ty: ColumnType) -> AlterColumnBuilder<'_> {
559        self.push_column(name, ty)
560    }
561
562    pub fn decimal(&mut self, name: &str, precision: u32, scale: u32) -> AlterColumnBuilder<'_> {
563        self.push_column(name, ColumnType::Decimal(precision, scale))
564    }
565
566    pub fn float(&mut self, name: &str) -> AlterColumnBuilder<'_> {
567        self.push_column(name, ColumnType::Float)
568    }
569
570    pub fn double(&mut self, name: &str) -> AlterColumnBuilder<'_> {
571        self.push_column(name, ColumnType::Double)
572    }
573
574    pub fn json(&mut self, name: &str) -> AlterColumnBuilder<'_> {
575        self.push_column(name, ColumnType::Json)
576    }
577
578    pub fn drop_columns<const N: usize>(&mut self, names: [&str; N]) {
579        for name in names {
580            self.drop_column(name);
581        }
582    }
583
584    pub fn drop_timestamps(&mut self) {
585        self.drop_columns(["created_at", "updated_at"]);
586    }
587
588    fn push_column(&mut self, name: &str, ty: ColumnType) -> AlterColumnBuilder<'_> {
589        self.operations
590            .push(TableAlterOperation::AddColumn(ColumnDef {
591                name: name.to_owned(),
592                ty,
593                nullable: false,
594                primary_key: false,
595                auto_increment: false,
596                unique: false,
597                default_raw: None,
598            }));
599        let index = self.operations.len() - 1;
600        AlterColumnBuilder { table: self, index }
601    }
602
603    pub(crate) fn sql_statements(&self, dialect: SchemaDialect) -> Vec<String> {
604        if self.operations.is_empty() {
605            return Vec::new();
606        }
607
608        let table_name = dialect.quote_ident(&self.name);
609
610        match dialect {
611            SchemaDialect::Sqlite => self
612                .operations
613                .iter()
614                .map(|operation| match operation {
615                    TableAlterOperation::DropColumn(name) => format!(
616                        "alter table {table_name} drop column {};",
617                        dialect.quote_ident(name)
618                    ),
619                    TableAlterOperation::AddColumn(column) => format!(
620                        "alter table {table_name} add column {};",
621                        render_column(dialect, column)
622                    ),
623                })
624                .collect(),
625            SchemaDialect::Postgres | SchemaDialect::MariaDb => {
626                let actions = self
627                    .operations
628                    .iter()
629                    .map(|operation| match operation {
630                        TableAlterOperation::DropColumn(name) => {
631                            format!("drop column {}", dialect.quote_ident(name))
632                        }
633                        TableAlterOperation::AddColumn(column) => {
634                            format!("add column {}", render_column(dialect, column))
635                        }
636                    })
637                    .collect::<Vec<_>>();
638
639                vec![format!("alter table {table_name} {};", actions.join(", "))]
640            }
641        }
642    }
643}
644
645pub struct ColumnBuilder<'a> {
646    table: &'a mut TableBlueprint,
647    index: usize,
648}
649
650pub struct AlterColumnBuilder<'a> {
651    table: &'a mut AlterTableBlueprint,
652    index: usize,
653}
654
655pub struct ConstraintBuilder<'a> {
656    table: &'a mut TableBlueprint,
657    name: String,
658}
659
660impl<'a> ColumnBuilder<'a> {
661    pub fn index(self) -> Self {
662        let table_name = self.table.name.clone();
663        let column_name = self.table.columns[self.index].name.clone();
664        let name = default_index_name(&table_name, &column_name);
665        self.table.indexes.push(IndexBlueprint::new(
666            &name,
667            &table_name,
668            [column_name.as_str()],
669        ));
670        self
671    }
672
673    pub fn default(self, value: DefaultValue) -> Self {
674        self.table.columns[self.index].default_raw = Some(value.sql);
675        self
676    }
677
678    pub fn nullable(self) -> Self {
679        self.table.columns[self.index].nullable = true;
680        self
681    }
682
683    pub fn unique(self) -> Self {
684        self.table.columns[self.index].unique = true;
685        self
686    }
687
688    pub fn default_raw(self, value: &str) -> Self {
689        self.table.columns[self.index].default_raw = Some(value.to_owned());
690        self
691    }
692}
693
694impl<'a> AlterColumnBuilder<'a> {
695    pub fn default(self, value: DefaultValue) -> Self {
696        if let TableAlterOperation::AddColumn(column) = &mut self.table.operations[self.index] {
697            column.default_raw = Some(value.sql);
698        }
699        self
700    }
701
702    pub fn nullable(self) -> Self {
703        if let TableAlterOperation::AddColumn(column) = &mut self.table.operations[self.index] {
704            column.nullable = true;
705        }
706        self
707    }
708
709    pub fn unique(self) -> Self {
710        if let TableAlterOperation::AddColumn(column) = &mut self.table.operations[self.index] {
711            column.unique = true;
712        }
713        self
714    }
715
716    pub fn default_raw(self, value: &str) -> Self {
717        if let TableAlterOperation::AddColumn(column) = &mut self.table.operations[self.index] {
718            column.default_raw = Some(value.to_owned());
719        }
720        self
721    }
722}
723
724impl<'a> ConstraintBuilder<'a> {
725    pub fn unique<I>(self, columns: I)
726    where
727        I: IntoSchemaColumns,
728    {
729        self.table.constraints.push(ConstraintDef::Unique {
730            name: Some(self.name),
731            columns: columns.into_schema_columns(),
732        });
733    }
734
735    pub fn check(self, expression: &str) {
736        self.table.constraints.push(ConstraintDef::Check {
737            name: Some(self.name),
738            expression: expression.to_owned(),
739        });
740    }
741}
742
743pub struct ForeignKeyBuilder<'a> {
744    table: &'a mut TableBlueprint,
745    index: usize,
746    foreign_key: Option<usize>,
747}
748
749impl<'a> ForeignKeyBuilder<'a> {
750    pub fn index(self) -> Self {
751        let table_name = self.table.name.clone();
752        let column_name = self.table.columns[self.index].name.clone();
753        let name = default_index_name(&table_name, &column_name);
754        self.table.indexes.push(IndexBlueprint::new(
755            &name,
756            &table_name,
757            [column_name.as_str()],
758        ));
759        self
760    }
761
762    pub fn default(self, value: DefaultValue) -> Self {
763        self.table.columns[self.index].default_raw = Some(value.sql);
764        self
765    }
766
767    pub fn nullable(self) -> Self {
768        self.table.columns[self.index].nullable = true;
769        self
770    }
771
772    pub fn unique(self) -> Self {
773        self.table.columns[self.index].unique = true;
774        self
775    }
776
777    pub fn default_raw(self, value: &str) -> Self {
778        self.table.columns[self.index].default_raw = Some(value.to_owned());
779        self
780    }
781
782    pub fn constrained(self) -> Self {
783        let column = self.table.columns[self.index].name.clone();
784        let referenced_table = infer_referenced_table(&column);
785        self.references(&referenced_table)
786    }
787
788    pub fn references(self, table: &str) -> Self {
789        self.references_column(table, "id")
790    }
791
792    pub fn references_column(mut self, table: &str, column: &str) -> Self {
793        let foreign_key = ForeignKeyDef {
794            column: self.table.columns[self.index].name.clone(),
795            references_table: table.to_owned(),
796            references_column: column.to_owned(),
797            on_delete: None,
798            on_update: None,
799        };
800
801        match self.foreign_key {
802            Some(index) => self.table.foreign_keys[index] = foreign_key,
803            None => {
804                self.table.foreign_keys.push(foreign_key);
805                self.foreign_key = Some(self.table.foreign_keys.len() - 1);
806            }
807        }
808
809        self
810    }
811
812    pub fn cascade_on_delete(self) -> Self {
813        self.with_on_delete(ForeignKeyAction::Cascade)
814    }
815
816    pub fn restrict_on_delete(self) -> Self {
817        self.with_on_delete(ForeignKeyAction::Restrict)
818    }
819
820    pub fn null_on_delete(self) -> Self {
821        self.with_on_delete(ForeignKeyAction::SetNull)
822    }
823
824    pub fn no_action_on_delete(self) -> Self {
825        self.with_on_delete(ForeignKeyAction::NoAction)
826    }
827
828    pub fn cascade_on_update(self) -> Self {
829        self.with_on_update(ForeignKeyAction::Cascade)
830    }
831
832    pub fn restrict_on_update(self) -> Self {
833        self.with_on_update(ForeignKeyAction::Restrict)
834    }
835
836    pub fn null_on_update(self) -> Self {
837        self.with_on_update(ForeignKeyAction::SetNull)
838    }
839
840    pub fn no_action_on_update(self) -> Self {
841        self.with_on_update(ForeignKeyAction::NoAction)
842    }
843
844    fn with_on_delete(mut self, action: ForeignKeyAction) -> Self {
845        let index = self.ensure_foreign_key();
846        self.table.foreign_keys[index].on_delete = Some(action);
847        self
848    }
849
850    fn with_on_update(mut self, action: ForeignKeyAction) -> Self {
851        let index = self.ensure_foreign_key();
852        self.table.foreign_keys[index].on_update = Some(action);
853        self
854    }
855
856    fn ensure_foreign_key(&mut self) -> usize {
857        if let Some(index) = self.foreign_key {
858            index
859        } else {
860            self.table.foreign_keys.push(ForeignKeyDef {
861                column: self.table.columns[self.index].name.clone(),
862                references_table: String::new(),
863                references_column: "id".to_owned(),
864                on_delete: None,
865                on_update: None,
866            });
867            let index = self.table.foreign_keys.len() - 1;
868            self.foreign_key = Some(index);
869            index
870        }
871    }
872}