ic_dbms_canister/dbms/integrity/
insert.rs1use ic_dbms_api::prelude::{
2 ColumnDef, Database as _, Filter, ForeignFetcher, ForeignKeyDef, IcDbmsError, IcDbmsResult,
3 Query, QueryError, TableSchema, Value,
4};
5
6use crate::dbms::IcDbmsDatabase;
7
8pub struct InsertIntegrityValidator<'a, T>
10where
11 T: TableSchema,
12{
13 database: &'a IcDbmsDatabase,
14 _marker: std::marker::PhantomData<T>,
15}
16
17impl<'a, T> InsertIntegrityValidator<'a, T>
18where
19 T: TableSchema,
20{
21 pub fn new(dbms: &'a IcDbmsDatabase) -> Self {
23 Self {
24 database: dbms,
25 _marker: std::marker::PhantomData,
26 }
27 }
28}
29
30impl<T> InsertIntegrityValidator<'_, T>
31where
32 T: TableSchema,
33{
34 pub fn validate(&self, record_values: &[(ColumnDef, Value)]) -> IcDbmsResult<()> {
41 self.check_primary_key_conflict(record_values)?;
42 self.check_foreign_keys(record_values)?;
43 self.check_non_nullable_fields(record_values)?;
44
45 Ok(())
46 }
47
48 fn check_primary_key_conflict(&self, record_values: &[(ColumnDef, Value)]) -> IcDbmsResult<()> {
50 let pk_name = T::primary_key();
51 let pk = record_values
52 .iter()
53 .find(|(col_def, _)| col_def.name == pk_name)
54 .map(|(_, value)| value.clone())
55 .ok_or(IcDbmsError::Query(QueryError::MissingNonNullableField(
56 pk_name.to_string(),
57 )))?;
58
59 let query: Query<T> = Query::builder()
61 .field(pk_name)
62 .and_where(Filter::Eq(pk_name.to_string(), pk))
63 .build();
64
65 let res = self.database.select(query)?;
66 if res.is_empty() {
67 Ok(())
68 } else {
69 Err(IcDbmsError::Query(QueryError::PrimaryKeyConflict))
70 }
71 }
72
73 fn check_foreign_keys(&self, record_values: &[(ColumnDef, Value)]) -> IcDbmsResult<()> {
75 record_values
76 .iter()
77 .filter_map(|(col, value)| col.foreign_key.as_ref().map(|fk| (fk, value)))
78 .try_for_each(|(col, value)| self.check_foreign_key_existence(col, value))
79 }
80
81 fn check_foreign_key_existence(
83 &self,
84 foreign_key: &ForeignKeyDef,
85 value: &Value,
86 ) -> IcDbmsResult<()> {
87 let res = T::foreign_fetcher().fetch(
88 self.database,
89 foreign_key.foreign_table,
90 foreign_key.local_column,
91 value.clone(),
92 )?;
93 if res.is_empty() {
94 Err(IcDbmsError::Query(
95 QueryError::ForeignKeyConstraintViolation {
96 field: foreign_key.local_column.to_string(),
97 referencing_table: foreign_key.foreign_table.to_string(),
98 },
99 ))
100 } else {
101 Ok(())
102 }
103 }
104
105 fn check_non_nullable_fields(&self, record_values: &[(ColumnDef, Value)]) -> IcDbmsResult<()> {
107 for column in T::columns().iter().filter(|col| !col.nullable) {
108 if !record_values
109 .iter()
110 .any(|(col_def, _)| col_def.name == column.name)
111 {
112 return Err(IcDbmsError::Query(QueryError::MissingNonNullableField(
113 column.name.to_string(),
114 )));
115 }
116 }
117
118 Ok(())
119 }
120}
121
#[cfg(test)]
mod tests {
    //! Unit tests for [`InsertIntegrityValidator`] against the shared test schema and fixtures.

    use ic_dbms_api::prelude::DateTime;

    use super::*;
    use crate::tests::{Message, Post, TestDatabaseSchema, User, load_fixtures};

    #[test]
    fn test_should_not_pass_check_for_pk_conflict() {
        load_fixtures();
        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);

        let values = User::columns()
            .iter()
            .cloned()
            .zip(vec![
                Value::Uint32(1.into()),
                Value::Text("Alice".to_string().into()),
            ])
            .collect::<Vec<(ColumnDef, Value)>>();

        let validator = InsertIntegrityValidator::<User>::new(&dbms);
        let result = validator.validate(&values);
        assert!(matches!(
            result,
            Err(IcDbmsError::Query(QueryError::PrimaryKeyConflict))
        ));
    }

    #[test]
    fn test_should_pass_check_for_pk_conflict() {
        load_fixtures();
        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);

        let values = User::columns()
            .iter()
            .cloned()
            .zip(vec![
                Value::Uint32(1000.into()),
                Value::Text("Alice".to_string().into()),
            ])
            .collect::<Vec<(ColumnDef, Value)>>();

        let validator = InsertIntegrityValidator::<User>::new(&dbms);
        let result = validator.validate(&values);
        assert!(result.is_ok());
    }

    #[test]
    fn test_should_not_pass_check_for_fk_conflict() {
        load_fixtures();
        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);

        let values = Post::columns()
            .iter()
            .cloned()
            .zip(vec![
                Value::Uint32(1.into()),
                Value::Text("Title".to_string().into()),
                Value::Text("Content".to_string().into()),
                Value::Uint32(9999.into()),
            ])
            .collect::<Vec<(ColumnDef, Value)>>();

        let validator = InsertIntegrityValidator::<Post>::new(&dbms);
        let result = validator.check_foreign_keys(&values);
        assert!(matches!(
            result,
            Err(IcDbmsError::Query(QueryError::BrokenForeignKeyReference {
                table,
                ..
            })) if table == User::table_name()
        ));
    }

    #[test]
    fn test_should_pass_check_for_fk_conflict() {
        load_fixtures();
        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);

        let values = Post::columns()
            .iter()
            .cloned()
            .zip(vec![
                Value::Uint32(1.into()),
                Value::Text("Title".to_string().into()),
                Value::Text("Content".to_string().into()),
                Value::Uint32(1.into()),
            ])
            .collect::<Vec<(ColumnDef, Value)>>();

        let validator = InsertIntegrityValidator::<Post>::new(&dbms);
        let result = validator.check_foreign_keys(&values);
        assert!(result.is_ok());
    }

    #[test]
    fn test_should_not_pass_non_nullable_field_check() {
        load_fixtures();
        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);

        // Omit the non-nullable `title` column on purpose.
        let values = Post::columns()
            .iter()
            .cloned()
            .filter(|col| col.name != "title")
            .zip(vec![
                Value::Uint32(1.into()),
                Value::Text("Content".to_string().into()),
                Value::Uint32(1.into()),
            ])
            .collect::<Vec<(ColumnDef, Value)>>();

        let validator = InsertIntegrityValidator::<Post>::new(&dbms);
        let result = validator.check_non_nullable_fields(&values);
        assert!(matches!(
            result,
            Err(IcDbmsError::Query(QueryError::MissingNonNullableField(
                field_name
            ))) if field_name == "title"
        ));
    }

    #[test]
    fn test_should_pass_non_nullable_field_check() {
        load_fixtures();
        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);

        // Only the non-nullable columns are provided.
        let values = Message::columns()
            .iter()
            .filter(|col| !col.nullable)
            .cloned()
            .zip(vec![
                Value::Uint32(100.into()),
                Value::Text("Hello".to_string().into()),
                Value::Uint32(1.into()),
                Value::Uint32(2.into()),
            ])
            .collect::<Vec<(ColumnDef, Value)>>();

        let validator = InsertIntegrityValidator::<Message>::new(&dbms);
        let result = validator.check_non_nullable_fields(&values);
        assert!(result.is_ok());

        // Every column is provided, including the nullable one.
        let values = Message::columns()
            .iter()
            .cloned()
            .zip(vec![
                Value::Uint32(100.into()),
                Value::Text("Hello".to_string().into()),
                Value::Uint32(1.into()),
                Value::Uint32(2.into()),
                Value::DateTime(DateTime {
                    year: 2024,
                    month: 6,
                    day: 1,
                    hour: 12,
                    minute: 0,
                    second: 0,
                    microsecond: 0,
                    timezone_offset_minutes: 0,
                }),
            ])
            .collect::<Vec<(ColumnDef, Value)>>();

        let validator = InsertIntegrityValidator::<Message>::new(&dbms);
        let result = validator.check_non_nullable_fields(&values);
        assert!(result.is_ok());
    }
}