1use wasm_dbms_api::prelude::{
7 ColumnDef, Database as _, DbmsError, DbmsResult, Filter, Query, QueryError, TableRecord,
8 TableSchema, Value,
9};
10use wasm_dbms_memory::prelude::{AccessControl, AccessControlList, MemoryProvider};
11
12use super::common;
13use crate::database::WasmDbmsDatabase;
14
15pub struct UpdateIntegrityValidator<'a, T, M, A = AccessControlList>
20where
21 T: TableSchema,
22 M: MemoryProvider,
23 A: AccessControl,
24{
25 database: &'a WasmDbmsDatabase<'a, M, A>,
26 old_pk: Value,
28 _marker: std::marker::PhantomData<T>,
29}
30
31impl<'a, T, M, A> UpdateIntegrityValidator<'a, T, M, A>
32where
33 T: TableSchema,
34 M: MemoryProvider,
35 A: AccessControl,
36{
37 pub fn new(dbms: &'a WasmDbmsDatabase<'a, M, A>, old_pk: Value) -> Self {
39 Self {
40 database: dbms,
41 old_pk,
42 _marker: std::marker::PhantomData,
43 }
44 }
45}
46
47impl<T, M, A> UpdateIntegrityValidator<'_, T, M, A>
48where
49 T: TableSchema,
50 M: MemoryProvider,
51 A: AccessControl,
52{
53 pub fn validate(&self, record_values: &[(ColumnDef, Value)]) -> DbmsResult<()> {
55 for (col, value) in record_values {
56 common::check_column_validate::<T>(col, value)?;
57 }
58 self.check_primary_key_conflict(record_values)?;
59 self.check_unique_constraints(record_values)?;
60 common::check_foreign_keys::<T>(self.database, record_values)?;
61 common::check_non_nullable_fields::<T>(record_values)?;
62
63 Ok(())
64 }
65
66 fn check_primary_key_conflict(&self, record_values: &[(ColumnDef, Value)]) -> DbmsResult<()> {
68 let pk_name = T::primary_key();
69 let new_pk = record_values
70 .iter()
71 .find(|(col_def, _)| col_def.name == pk_name)
72 .map(|(_, value)| value.clone())
73 .ok_or(DbmsError::Query(QueryError::MissingNonNullableField(
74 pk_name.to_string(),
75 )))?;
76
77 let query = Query::builder()
78 .field(pk_name)
79 .and_where(Filter::Eq(pk_name.to_string(), new_pk.clone()))
80 .build();
81
82 let res = self.database.select::<T>(query)?;
83 match res.len() {
84 0 => Ok(()),
85 1 => {
86 if new_pk == self.old_pk {
87 Ok(())
88 } else {
89 Err(DbmsError::Query(QueryError::PrimaryKeyConflict))
90 }
91 }
92 _ => Err(DbmsError::Query(QueryError::PrimaryKeyConflict)),
93 }
94 }
95
96 fn check_unique_constraints(&self, record_values: &[(ColumnDef, Value)]) -> DbmsResult<()> {
101 let pk_name = T::primary_key();
102
103 for (col_def, value) in record_values.iter().filter(|(col_def, _)| col_def.unique) {
104 let query = Query::builder()
105 .field(pk_name)
106 .and_where(Filter::Eq(col_def.name.to_string(), value.clone()))
107 .build();
108
109 let res = self.database.select::<T>(query)?;
110 for record in &res {
111 let record_pk = record
112 .to_values()
113 .into_iter()
114 .find(|(c, _)| c.name == pk_name)
115 .map(|(_, v)| v);
116
117 if record_pk.as_ref() != Some(&self.old_pk) {
118 return Err(DbmsError::Query(QueryError::UniqueConstraintViolation {
119 field: col_def.name.to_string(),
120 }));
121 }
122 }
123 }
124
125 Ok(())
126 }
127}
128
#[cfg(test)]
mod tests {

    use wasm_dbms_api::prelude::{
        Database as _, DbmsError, Filter, InsertRecord as _, QueryError, TableSchema as _, Text,
        Uint32, UpdateRecord as _, Value,
    };
    use wasm_dbms_macros::{DatabaseSchema, Table};
    use wasm_dbms_memory::prelude::HeapMemoryProvider;

    use crate::prelude::{DbmsContext, WasmDbmsDatabase};

    #[derive(Debug, Table, Clone, PartialEq, Eq)]
    #[table = "users"]
    pub struct User {
        #[primary_key]
        pub id: Uint32,
        pub name: Text,
    }

    #[derive(Debug, Table, Clone, PartialEq, Eq)]
    #[table = "contracts"]
    pub struct Contract {
        #[primary_key]
        pub id: Uint32,
        #[unique]
        pub code: Text,
        #[foreign_key(entity = "User", table = "users", column = "id")]
        pub user_id: Uint32,
    }

    #[derive(DatabaseSchema)]
    #[tables(User = "users", Contract = "contracts")]
    pub struct TestSchema;

    /// Creates a heap-backed context with every table of `TestSchema` registered.
    fn setup() -> DbmsContext<HeapMemoryProvider> {
        let ctx = DbmsContext::new(HeapMemoryProvider::default());
        TestSchema::register_tables(&ctx).unwrap();
        ctx
    }

    /// Inserts a `users` row, panicking if the insert is rejected.
    fn insert_user(db: &WasmDbmsDatabase<'_, HeapMemoryProvider>, id: u32, name: &str) {
        let cols = User::columns();
        let values = [
            (cols[0], Value::Uint32(Uint32(id))),
            (cols[1], Value::Text(Text(name.to_string()))),
        ];
        let request = UserInsertRequest::from_values(&values).unwrap();
        db.insert::<User>(request).unwrap();
    }

    /// Inserts a `contracts` row, panicking if the insert is rejected.
    fn insert_contract(
        db: &WasmDbmsDatabase<'_, HeapMemoryProvider>,
        id: u32,
        code: &str,
        user_id: u32,
    ) {
        let cols = Contract::columns();
        let values = [
            (cols[0], Value::Uint32(Uint32(id))),
            (cols[1], Value::Text(Text(code.to_string()))),
            (cols[2], Value::Uint32(Uint32(user_id))),
        ];
        let request = ContractInsertRequest::from_values(&values).unwrap();
        db.insert::<Contract>(request).unwrap();
    }

    /// Builds a patch that sets the unique `code` column of the contract with `id`.
    fn code_patch(id: u32, code: &str) -> ContractUpdateRequest {
        ContractUpdateRequest::from_values(
            &[(Contract::columns()[1], Value::Text(Text(code.to_string())))],
            Some(Filter::eq("id", Value::Uint32(Uint32(id)))),
        )
    }

    #[test]
    fn test_update_unique_field_to_new_value_succeeds() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_user(&db, 1, "alice");
        insert_contract(&db, 1, "CONTRACT-001", 1);

        let outcome = db.update::<Contract>(code_patch(1, "CONTRACT-999"));
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_update_keeping_same_unique_value_succeeds() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_user(&db, 1, "alice");
        insert_contract(&db, 1, "CONTRACT-001", 1);

        // Re-writing a record's own unique value must not count as a collision.
        let outcome = db.update::<Contract>(code_patch(1, "CONTRACT-001"));
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_update_unique_field_to_existing_value_fails() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_user(&db, 1, "alice");
        insert_contract(&db, 1, "CONTRACT-001", 1);
        insert_contract(&db, 2, "CONTRACT-002", 1);

        let err = db
            .update::<Contract>(code_patch(2, "CONTRACT-001"))
            .expect_err("reusing another contract's code must be rejected");
        assert!(matches!(
            err,
            DbmsError::Query(QueryError::UniqueConstraintViolation { ref field }) if field == "code"
        ));
    }
}