ic_dbms_canister/dbms/integrity/insert.rs

1use ic_dbms_api::prelude::{
2    ColumnDef, Database as _, Filter, ForeignFetcher, ForeignKeyDef, IcDbmsError, IcDbmsResult,
3    Query, QueryError, TableSchema, Value,
4};
5
6use crate::dbms::IcDbmsDatabase;
7
8/// Integrity validator for insert operations.
9pub struct InsertIntegrityValidator<'a, T>
10where
11    T: TableSchema,
12{
13    database: &'a IcDbmsDatabase,
14    _marker: std::marker::PhantomData<T>,
15}
16
17impl<'a, T> InsertIntegrityValidator<'a, T>
18where
19    T: TableSchema,
20{
21    /// Creates a new insert integrity validator.
22    pub fn new(dbms: &'a IcDbmsDatabase) -> Self {
23        Self {
24            database: dbms,
25            _marker: std::marker::PhantomData,
26        }
27    }
28}
29
30impl<T> InsertIntegrityValidator<'_, T>
31where
32    T: TableSchema,
33{
34    /// Verify whether the given insert record is valid.
35    ///
36    /// An insert is valid when:
37    /// - No primary key conflicts with existing records.
38    /// - All foreign keys reference existing records.
39    /// - All non-nullable columns are provided.
40    pub fn validate(&self, record_values: &[(ColumnDef, Value)]) -> IcDbmsResult<()> {
41        // check validations for each column
42        for (col, value) in record_values {
43            self.check_column_validate(col, value)?;
44        }
45        self.check_primary_key_conflict(record_values)?;
46        self.check_foreign_keys(record_values)?;
47        self.check_non_nullable_fields(record_values)?;
48
49        Ok(())
50    }
51
52    /// Checks whether the given column value is valid according to its validator.
53    fn check_column_validate(&self, column: &ColumnDef, value: &Value) -> IcDbmsResult<()> {
54        let Some(validator) = T::validator(column.name) else {
55            return Ok(());
56        };
57
58        validator.validate(value)
59    }
60
61    /// Checks for primary key conflicts.
62    fn check_primary_key_conflict(&self, record_values: &[(ColumnDef, Value)]) -> IcDbmsResult<()> {
63        let pk_name = T::primary_key();
64        let pk = record_values
65            .iter()
66            .find(|(col_def, _)| col_def.name == pk_name)
67            .map(|(_, value)| value.clone())
68            .ok_or(IcDbmsError::Query(QueryError::MissingNonNullableField(
69                pk_name.to_string(),
70            )))?;
71
72        // select
73        let query: Query<T> = Query::builder()
74            .field(pk_name)
75            .and_where(Filter::Eq(pk_name.to_string(), pk))
76            .build();
77
78        let res = self.database.select(query)?;
79        if res.is_empty() {
80            Ok(())
81        } else {
82            Err(IcDbmsError::Query(QueryError::PrimaryKeyConflict))
83        }
84    }
85
86    /// Checks whether all the foreign keys reference existing records.
87    fn check_foreign_keys(&self, record_values: &[(ColumnDef, Value)]) -> IcDbmsResult<()> {
88        record_values
89            .iter()
90            .filter_map(|(col, value)| col.foreign_key.as_ref().map(|fk| (fk, value)))
91            .try_for_each(|(col, value)| self.check_foreign_key_existence(col, value))
92    }
93
94    /// Checks whether a foreign key references an existing record.
95    fn check_foreign_key_existence(
96        &self,
97        foreign_key: &ForeignKeyDef,
98        value: &Value,
99    ) -> IcDbmsResult<()> {
100        let res = T::foreign_fetcher().fetch(
101            self.database,
102            foreign_key.foreign_table,
103            foreign_key.local_column,
104            value.clone(),
105        )?;
106        if res.is_empty() {
107            Err(IcDbmsError::Query(
108                QueryError::ForeignKeyConstraintViolation {
109                    field: foreign_key.local_column.to_string(),
110                    referencing_table: foreign_key.foreign_table.to_string(),
111                },
112            ))
113        } else {
114            Ok(())
115        }
116    }
117
118    /// Check whether all non-nullable fields are provided.
119    fn check_non_nullable_fields(&self, record_values: &[(ColumnDef, Value)]) -> IcDbmsResult<()> {
120        for column in T::columns().iter().filter(|col| !col.nullable) {
121            if !record_values
122                .iter()
123                .any(|(col_def, _)| col_def.name == column.name)
124            {
125                return Err(IcDbmsError::Query(QueryError::MissingNonNullableField(
126                    column.name.to_string(),
127                )));
128            }
129        }
130
131        Ok(())
132    }
133}
134
#[cfg(test)]
mod tests {

    use ic_dbms_api::prelude::DateTime;

    use super::*;
    use crate::tests::{Message, Post, TestDatabaseSchema, User, load_fixtures};

    #[test]
    fn test_should_not_pass_email_validation() {
        load_fixtures();
        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);

        let values = User::columns()
            .iter()
            .cloned()
            .zip(vec![
                Value::Uint32(10.into()),
                Value::Text("Bob".to_string().into()),
                Value::Text("invalid-email".to_string().into()),
            ])
            .collect::<Vec<(ColumnDef, Value)>>();

        let validator = InsertIntegrityValidator::<User>::new(&dbms);
        let result = validator.validate(&values);
        assert!(matches!(result, Err(IcDbmsError::Validation(_))));
    }

    #[test]
    fn test_should_not_pass_check_for_pk_conflict() {
        load_fixtures();
        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);

        let values = User::columns()
            .iter()
            .cloned()
            .zip(vec![
                Value::Uint32(1.into()),
                Value::Text("Alice".to_string().into()),
                Value::Text("alice@example.com".into()),
            ])
            .collect::<Vec<(ColumnDef, Value)>>();

        let validator = InsertIntegrityValidator::<User>::new(&dbms);
        let result = validator.validate(&values);
        assert!(matches!(
            result,
            Err(IcDbmsError::Query(QueryError::PrimaryKeyConflict))
        ));
    }

    #[test]
    fn test_should_pass_check_for_pk_conflict() {
        load_fixtures();
        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);

        // no conflict case
        let values = User::columns()
            .iter()
            .cloned()
            .zip(vec![
                Value::Uint32(1000.into()),
                Value::Text("Alice".to_string().into()),
                Value::Text("alice@example.com".into()),
            ])
            .collect::<Vec<(ColumnDef, Value)>>();

        let validator = InsertIntegrityValidator::<User>::new(&dbms);
        let result = validator.validate(&values);
        assert!(result.is_ok());
    }

    #[test]
    fn test_should_report_missing_primary_key() {
        load_fixtures();
        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);

        let values = User::columns()
            .iter()
            .cloned()
            .filter(|col| col.name != User::primary_key()) // omit the primary key
            .zip(vec![
                Value::Text("Bob".to_string().into()),
                Value::Text("bob@example.com".to_string().into()),
            ])
            .collect::<Vec<(ColumnDef, Value)>>();

        let validator = InsertIntegrityValidator::<User>::new(&dbms);
        let result = validator.check_primary_key_conflict(&values);
        assert!(matches!(
            result,
            Err(IcDbmsError::Query(QueryError::MissingNonNullableField(
                field_name
            ))) if field_name == User::primary_key()
        ));
    }

    #[test]
    fn test_should_not_pass_check_for_fk_conflict() {
        load_fixtures();
        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);

        let values = Post::columns()
            .iter()
            .cloned()
            .zip(vec![
                Value::Uint32(1.into()),
                Value::Text("Title".to_string().into()),
                Value::Text("Content".to_string().into()),
                Value::Uint32(9999.into()), // non-existing user_id
            ])
            .collect::<Vec<(ColumnDef, Value)>>();

        let validator = InsertIntegrityValidator::<Post>::new(&dbms);
        let result = validator.check_foreign_keys(&values);
        assert!(matches!(
            result,
            Err(IcDbmsError::Query(QueryError::BrokenForeignKeyReference {
                table,
                ..
            })) if table == User::table_name()
        ));
    }

    #[test]
    fn test_should_pass_check_for_fk_conflict() {
        load_fixtures();
        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);

        let values = Post::columns()
            .iter()
            .cloned()
            .zip(vec![
                Value::Uint32(1.into()),
                Value::Text("Title".to_string().into()),
                Value::Text("Content".to_string().into()),
                Value::Uint32(1.into()), // existing user_id
            ])
            .collect::<Vec<(ColumnDef, Value)>>();

        let validator = InsertIntegrityValidator::<Post>::new(&dbms);
        let result = validator.check_foreign_keys(&values);
        assert!(result.is_ok());
    }

    #[test]
    fn test_should_not_pass_non_nullable_field_check() {
        load_fixtures();
        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);

        let values = Post::columns()
            .iter()
            .cloned()
            .filter(|col| col.name != "title") // omit non-nullable field
            .zip(vec![
                Value::Uint32(1.into()),
                // Missing title
                Value::Text("Content".to_string().into()),
                Value::Uint32(1.into()),
            ])
            .collect::<Vec<(ColumnDef, Value)>>();

        let validator = InsertIntegrityValidator::<Post>::new(&dbms);
        let result = validator.check_non_nullable_fields(&values);
        assert!(matches!(
            result,
            Err(IcDbmsError::Query(QueryError::MissingNonNullableField(
                field_name
            ))) if field_name == "title"
        ));
    }

    #[test]
    fn test_should_pass_non_nullable_field_check() {
        load_fixtures();
        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);

        let values = Message::columns()
            .iter()
            .filter(|col| !col.nullable)
            .cloned()
            .zip(vec![
                Value::Uint32(100.into()),
                Value::Text("Hello".to_string().into()),
                Value::Uint32(1.into()),
                Value::Uint32(2.into()),
            ])
            .collect::<Vec<(ColumnDef, Value)>>();

        let validator = InsertIntegrityValidator::<Message>::new(&dbms);
        let result = validator.check_non_nullable_fields(&values);
        assert!(result.is_ok());

        // should pass with nullable set

        let values = Message::columns()
            .iter()
            .cloned()
            .zip(vec![
                Value::Uint32(100.into()),
                Value::Text("Hello".to_string().into()),
                Value::Uint32(1.into()),
                Value::Uint32(2.into()),
                Value::DateTime(DateTime {
                    year: 2024,
                    month: 6,
                    day: 1,
                    hour: 12,
                    minute: 0,
                    second: 0,
                    microsecond: 0,
                    timezone_offset_minutes: 0,
                }),
            ])
            .collect::<Vec<(ColumnDef, Value)>>();

        let validator = InsertIntegrityValidator::<Message>::new(&dbms);
        let result = validator.check_non_nullable_fields(&values);
        assert!(result.is_ok());
    }
}