ic_dbms_canister/dbms/integrity/
insert.rs

1use ic_dbms_api::prelude::{
2    ColumnDef, Database as _, Filter, ForeignFetcher, ForeignKeyDef, IcDbmsError, IcDbmsResult,
3    Query, QueryError, TableSchema, Value,
4};
5
6use crate::dbms::IcDbmsDatabase;
7
8/// Integrity validator for insert operations.
9pub struct InsertIntegrityValidator<'a, T>
10where
11    T: TableSchema,
12{
13    database: &'a IcDbmsDatabase,
14    _marker: std::marker::PhantomData<T>,
15}
16
17impl<'a, T> InsertIntegrityValidator<'a, T>
18where
19    T: TableSchema,
20{
21    /// Creates a new insert integrity validator.
22    pub fn new(dbms: &'a IcDbmsDatabase) -> Self {
23        Self {
24            database: dbms,
25            _marker: std::marker::PhantomData,
26        }
27    }
28}
29
30impl<T> InsertIntegrityValidator<'_, T>
31where
32    T: TableSchema,
33{
34    /// Verify whether the given insert record is valid.
35    ///
36    /// An insert is valid when:
37    /// - No primary key conflicts with existing records.
38    /// - All foreign keys reference existing records.
39    /// - All non-nullable columns are provided.
40    pub fn validate(&self, record_values: &[(ColumnDef, Value)]) -> IcDbmsResult<()> {
41        self.check_primary_key_conflict(record_values)?;
42        self.check_foreign_keys(record_values)?;
43        self.check_non_nullable_fields(record_values)?;
44
45        Ok(())
46    }
47
48    /// Checks for primary key conflicts.
49    fn check_primary_key_conflict(&self, record_values: &[(ColumnDef, Value)]) -> IcDbmsResult<()> {
50        let pk_name = T::primary_key();
51        let pk = record_values
52            .iter()
53            .find(|(col_def, _)| col_def.name == pk_name)
54            .map(|(_, value)| value.clone())
55            .ok_or(IcDbmsError::Query(QueryError::MissingNonNullableField(
56                pk_name.to_string(),
57            )))?;
58
59        // select
60        let query: Query<T> = Query::builder()
61            .field(pk_name)
62            .and_where(Filter::Eq(pk_name.to_string(), pk))
63            .build();
64
65        let res = self.database.select(query)?;
66        if res.is_empty() {
67            Ok(())
68        } else {
69            Err(IcDbmsError::Query(QueryError::PrimaryKeyConflict))
70        }
71    }
72
73    /// Checks whether all the foreign keys reference existing records.
74    fn check_foreign_keys(&self, record_values: &[(ColumnDef, Value)]) -> IcDbmsResult<()> {
75        record_values
76            .iter()
77            .filter_map(|(col, value)| col.foreign_key.as_ref().map(|fk| (fk, value)))
78            .try_for_each(|(col, value)| self.check_foreign_key_existence(col, value))
79    }
80
81    /// Checks whether a foreign key references an existing record.
82    fn check_foreign_key_existence(
83        &self,
84        foreign_key: &ForeignKeyDef,
85        value: &Value,
86    ) -> IcDbmsResult<()> {
87        let res = T::foreign_fetcher().fetch(
88            self.database,
89            foreign_key.foreign_table,
90            foreign_key.local_column,
91            value.clone(),
92        )?;
93        if res.is_empty() {
94            Err(IcDbmsError::Query(
95                QueryError::ForeignKeyConstraintViolation {
96                    field: foreign_key.local_column.to_string(),
97                    referencing_table: foreign_key.foreign_table.to_string(),
98                },
99            ))
100        } else {
101            Ok(())
102        }
103    }
104
105    /// Check whether all non-nullable fields are provided.
106    fn check_non_nullable_fields(&self, record_values: &[(ColumnDef, Value)]) -> IcDbmsResult<()> {
107        for column in T::columns().iter().filter(|col| !col.nullable) {
108            if !record_values
109                .iter()
110                .any(|(col_def, _)| col_def.name == column.name)
111            {
112                return Err(IcDbmsError::Query(QueryError::MissingNonNullableField(
113                    column.name.to_string(),
114                )));
115            }
116        }
117
118        Ok(())
119    }
120}
121
#[cfg(test)]
mod tests {

    use ic_dbms_api::prelude::DateTime;

    use super::*;
    use crate::tests::{Message, Post, TestDatabaseSchema, User, load_fixtures};

    #[test]
    fn test_should_not_pass_check_for_pk_conflict() {
        load_fixtures();
        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);

        // user id 1 is part of the fixtures, so inserting it again must conflict
        let values = User::columns()
            .iter()
            .cloned()
            .zip(vec![
                Value::Uint32(1.into()),
                Value::Text("Alice".to_string().into()),
            ])
            .collect::<Vec<(ColumnDef, Value)>>();

        let validator = InsertIntegrityValidator::<User>::new(&dbms);
        let result = validator.validate(&values);
        assert!(matches!(
            result,
            Err(IcDbmsError::Query(QueryError::PrimaryKeyConflict))
        ));
    }

    #[test]
    fn test_should_pass_check_for_pk_conflict() {
        load_fixtures();
        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);

        // no conflict case
        let values = User::columns()
            .iter()
            .cloned()
            .zip(vec![
                Value::Uint32(1000.into()),
                Value::Text("Alice".to_string().into()),
            ])
            .collect::<Vec<(ColumnDef, Value)>>();

        let validator = InsertIntegrityValidator::<User>::new(&dbms);
        let result = validator.validate(&values);
        assert!(result.is_ok());
    }

    #[test]
    fn test_should_not_pass_check_for_fk_conflict() {
        load_fixtures();
        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);

        let values = Post::columns()
            .iter()
            .cloned()
            .zip(vec![
                Value::Uint32(1.into()),
                Value::Text("Title".to_string().into()),
                Value::Text("Content".to_string().into()),
                Value::Uint32(9999.into()), // non-existing user_id
            ])
            .collect::<Vec<(ColumnDef, Value)>>();

        let validator = InsertIntegrityValidator::<Post>::new(&dbms);
        let result = validator.check_foreign_keys(&values);
        assert!(matches!(
            result,
            Err(IcDbmsError::Query(QueryError::BrokenForeignKeyReference {
                table,
                ..
            })) if table == User::table_name()
        ));
    }

    #[test]
    fn test_should_pass_check_for_fk_conflict() {
        load_fixtures();
        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);

        let values = Post::columns()
            .iter()
            .cloned()
            .zip(vec![
                Value::Uint32(1.into()),
                Value::Text("Title".to_string().into()),
                Value::Text("Content".to_string().into()),
                Value::Uint32(1.into()), // existing user_id
            ])
            .collect::<Vec<(ColumnDef, Value)>>();

        let validator = InsertIntegrityValidator::<Post>::new(&dbms);
        let result = validator.check_foreign_keys(&values);
        assert!(result.is_ok());
    }

    #[test]
    fn test_should_not_pass_non_nullable_field_check() {
        load_fixtures();
        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);

        let values = Post::columns()
            .iter()
            .cloned()
            .filter(|col| col.name != "title") // omit non-nullable field
            .zip(vec![
                Value::Uint32(1.into()),
                // Missing title
                Value::Text("Content".to_string().into()),
                Value::Uint32(1.into()),
            ])
            .collect::<Vec<(ColumnDef, Value)>>();

        let validator = InsertIntegrityValidator::<Post>::new(&dbms);
        let result = validator.check_non_nullable_fields(&values);
        assert!(matches!(
            result,
            Err(IcDbmsError::Query(QueryError::MissingNonNullableField(
                field_name
            ))) if field_name == "title"
        ));
    }

    #[test]
    fn test_should_pass_non_nullable_field_check() {
        load_fixtures();
        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);

        // only the non-nullable columns are provided
        let values = Message::columns()
            .iter()
            .filter(|col| !col.nullable)
            .cloned()
            .zip(vec![
                Value::Uint32(100.into()),
                Value::Text("Hello".to_string().into()),
                Value::Uint32(1.into()),
                Value::Uint32(2.into()),
            ])
            .collect::<Vec<(ColumnDef, Value)>>();

        let validator = InsertIntegrityValidator::<Message>::new(&dbms);
        let result = validator.check_non_nullable_fields(&values);
        assert!(result.is_ok());

        // should pass with nullable set

        let values = Message::columns()
            .iter()
            .cloned()
            .zip(vec![
                Value::Uint32(100.into()),
                Value::Text("Hello".to_string().into()),
                Value::Uint32(1.into()),
                Value::Uint32(2.into()),
                Value::DateTime(DateTime {
                    year: 2024,
                    month: 6,
                    day: 1,
                    hour: 12,
                    minute: 0,
                    second: 0,
                    microsecond: 0,
                    timezone_offset_minutes: 0,
                }),
            ])
            .collect::<Vec<(ColumnDef, Value)>>();

        let validator = InsertIntegrityValidator::<Message>::new(&dbms);
        let result = validator.check_non_nullable_fields(&values);
        assert!(result.is_ok());
    }
}