ic_dbms_canister/dbms/integrity/
insert.rs1use ic_dbms_api::prelude::{
2 ColumnDef, Database as _, Filter, ForeignFetcher, ForeignKeyDef, IcDbmsError, IcDbmsResult,
3 Query, QueryError, TableSchema, Value,
4};
5
6use crate::dbms::IcDbmsDatabase;
7
8pub struct InsertIntegrityValidator<'a, T>
10where
11 T: TableSchema,
12{
13 database: &'a IcDbmsDatabase,
14 _marker: std::marker::PhantomData<T>,
15}
16
17impl<'a, T> InsertIntegrityValidator<'a, T>
18where
19 T: TableSchema,
20{
21 pub fn new(dbms: &'a IcDbmsDatabase) -> Self {
23 Self {
24 database: dbms,
25 _marker: std::marker::PhantomData,
26 }
27 }
28}
29
30impl<T> InsertIntegrityValidator<'_, T>
31where
32 T: TableSchema,
33{
34 pub fn validate(&self, record_values: &[(ColumnDef, Value)]) -> IcDbmsResult<()> {
41 for (col, value) in record_values {
43 self.check_column_validate(col, value)?;
44 }
45 self.check_primary_key_conflict(record_values)?;
46 self.check_foreign_keys(record_values)?;
47 self.check_non_nullable_fields(record_values)?;
48
49 Ok(())
50 }
51
52 fn check_column_validate(&self, column: &ColumnDef, value: &Value) -> IcDbmsResult<()> {
54 let Some(validator) = T::validator(column.name) else {
55 return Ok(());
56 };
57
58 validator.validate(value)
59 }
60
61 fn check_primary_key_conflict(&self, record_values: &[(ColumnDef, Value)]) -> IcDbmsResult<()> {
63 let pk_name = T::primary_key();
64 let pk = record_values
65 .iter()
66 .find(|(col_def, _)| col_def.name == pk_name)
67 .map(|(_, value)| value.clone())
68 .ok_or(IcDbmsError::Query(QueryError::MissingNonNullableField(
69 pk_name.to_string(),
70 )))?;
71
72 let query: Query<T> = Query::builder()
74 .field(pk_name)
75 .and_where(Filter::Eq(pk_name.to_string(), pk))
76 .build();
77
78 let res = self.database.select(query)?;
79 if res.is_empty() {
80 Ok(())
81 } else {
82 Err(IcDbmsError::Query(QueryError::PrimaryKeyConflict))
83 }
84 }
85
86 fn check_foreign_keys(&self, record_values: &[(ColumnDef, Value)]) -> IcDbmsResult<()> {
88 record_values
89 .iter()
90 .filter_map(|(col, value)| col.foreign_key.as_ref().map(|fk| (fk, value)))
91 .try_for_each(|(col, value)| self.check_foreign_key_existence(col, value))
92 }
93
94 fn check_foreign_key_existence(
96 &self,
97 foreign_key: &ForeignKeyDef,
98 value: &Value,
99 ) -> IcDbmsResult<()> {
100 let res = T::foreign_fetcher().fetch(
101 self.database,
102 foreign_key.foreign_table,
103 foreign_key.local_column,
104 value.clone(),
105 )?;
106 if res.is_empty() {
107 Err(IcDbmsError::Query(
108 QueryError::ForeignKeyConstraintViolation {
109 field: foreign_key.local_column.to_string(),
110 referencing_table: foreign_key.foreign_table.to_string(),
111 },
112 ))
113 } else {
114 Ok(())
115 }
116 }
117
118 fn check_non_nullable_fields(&self, record_values: &[(ColumnDef, Value)]) -> IcDbmsResult<()> {
120 for column in T::columns().iter().filter(|col| !col.nullable) {
121 if !record_values
122 .iter()
123 .any(|(col_def, _)| col_def.name == column.name)
124 {
125 return Err(IcDbmsError::Query(QueryError::MissingNonNullableField(
126 column.name.to_string(),
127 )));
128 }
129 }
130
131 Ok(())
132 }
133}
134
#[cfg(test)]
mod tests {
    use ic_dbms_api::prelude::DateTime;

    use super::*;
    use crate::tests::{Message, Post, TestDatabaseSchema, User, load_fixtures};

    #[test]
    fn test_should_not_pass_email_validation() {
        load_fixtures();
        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);

        let values = User::columns()
            .iter()
            .cloned()
            .zip(vec![
                Value::Uint32(10.into()),
                Value::Text("Bob".to_string().into()),
                Value::Text("invalid-email".to_string().into()),
            ])
            .collect::<Vec<(ColumnDef, Value)>>();

        let validator = InsertIntegrityValidator::<User>::new(&dbms);
        let result = validator.validate(&values);
        assert!(matches!(result, Err(IcDbmsError::Validation(_))));
    }

    #[test]
    fn test_should_not_pass_check_for_pk_conflict() {
        load_fixtures();
        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);

        let values = User::columns()
            .iter()
            .cloned()
            .zip(vec![
                Value::Uint32(1.into()),
                Value::Text("Alice".to_string().into()),
                Value::Text("alice@example.com".into()),
            ])
            .collect::<Vec<(ColumnDef, Value)>>();

        let validator = InsertIntegrityValidator::<User>::new(&dbms);
        let result = validator.validate(&values);
        assert!(matches!(
            result,
            Err(IcDbmsError::Query(QueryError::PrimaryKeyConflict))
        ));
    }

    #[test]
    fn test_should_pass_check_for_pk_conflict() {
        load_fixtures();
        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);

        let values = User::columns()
            .iter()
            .cloned()
            .zip(vec![
                Value::Uint32(1000.into()),
                Value::Text("Alice".to_string().into()),
                Value::Text("alice@example.com".into()),
            ])
            .collect::<Vec<(ColumnDef, Value)>>();

        let validator = InsertIntegrityValidator::<User>::new(&dbms);
        let result = validator.validate(&values);
        assert!(result.is_ok());
    }

    #[test]
    fn test_should_not_pass_check_for_fk_conflict() {
        load_fixtures();
        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);

        let values = Post::columns()
            .iter()
            .cloned()
            .zip(vec![
                Value::Uint32(1.into()),
                Value::Text("Title".to_string().into()),
                Value::Text("Content".to_string().into()),
                // Foreign key value that should not resolve to an existing user.
                Value::Uint32(9999.into()),
            ])
            .collect::<Vec<(ColumnDef, Value)>>();

        let validator = InsertIntegrityValidator::<Post>::new(&dbms);
        let result = validator.check_foreign_keys(&values);
        // NOTE(review): `check_foreign_key_existence` itself returns
        // `ForeignKeyConstraintViolation` on an empty fetch; this assertion relies on the
        // foreign fetcher reporting `BrokenForeignKeyReference` first — confirm.
        assert!(matches!(
            result,
            Err(IcDbmsError::Query(QueryError::BrokenForeignKeyReference {
                table,
                ..
            })) if table == User::table_name()
        ));
    }

    #[test]
    fn test_should_pass_check_for_fk_conflict() {
        load_fixtures();
        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);

        let values = Post::columns()
            .iter()
            .cloned()
            .zip(vec![
                Value::Uint32(1.into()),
                Value::Text("Title".to_string().into()),
                Value::Text("Content".to_string().into()),
                // Foreign key value expected to exist in the fixtures.
                Value::Uint32(1.into()),
            ])
            .collect::<Vec<(ColumnDef, Value)>>();

        let validator = InsertIntegrityValidator::<Post>::new(&dbms);
        let result = validator.check_foreign_keys(&values);
        assert!(result.is_ok());
    }

    #[test]
    fn test_should_not_pass_non_nullable_field_check() {
        load_fixtures();
        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);

        let values = Post::columns()
            .iter()
            .cloned()
            // Deliberately omit the non-nullable `title` column.
            .filter(|col| col.name != "title")
            .zip(vec![
                Value::Uint32(1.into()),
                Value::Text("Content".to_string().into()),
                Value::Uint32(1.into()),
            ])
            .collect::<Vec<(ColumnDef, Value)>>();

        let validator = InsertIntegrityValidator::<Post>::new(&dbms);
        let result = validator.check_non_nullable_fields(&values);
        assert!(matches!(
            result,
            Err(IcDbmsError::Query(QueryError::MissingNonNullableField(
                field_name
            ))) if field_name == "title"
        ));
    }

    #[test]
    fn test_should_pass_non_nullable_field_check() {
        load_fixtures();
        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);

        // Only the non-nullable columns are supplied.
        let values = Message::columns()
            .iter()
            .filter(|col| !col.nullable)
            .cloned()
            .zip(vec![
                Value::Uint32(100.into()),
                Value::Text("Hello".to_string().into()),
                Value::Uint32(1.into()),
                Value::Uint32(2.into()),
            ])
            .collect::<Vec<(ColumnDef, Value)>>();

        let validator = InsertIntegrityValidator::<Message>::new(&dbms);
        let result = validator.check_non_nullable_fields(&values);
        assert!(result.is_ok());

        // Every column is supplied, including the nullable one.
        let values = Message::columns()
            .iter()
            .cloned()
            .zip(vec![
                Value::Uint32(100.into()),
                Value::Text("Hello".to_string().into()),
                Value::Uint32(1.into()),
                Value::Uint32(2.into()),
                Value::DateTime(DateTime {
                    year: 2024,
                    month: 6,
                    day: 1,
                    hour: 12,
                    minute: 0,
                    second: 0,
                    microsecond: 0,
                    timezone_offset_minutes: 0,
                }),
            ])
            .collect::<Vec<(ColumnDef, Value)>>();

        let validator = InsertIntegrityValidator::<Message>::new(&dbms);
        let result = validator.check_non_nullable_fields(&values);
        assert!(result.is_ok());
    }
}