1use wasm_dbms_api::prelude::{
7 ColumnDef, Database as _, DbmsError, DbmsResult, Filter, Query, QueryError, TableSchema, Value,
8};
9use wasm_dbms_memory::prelude::{AccessControl, AccessControlList, MemoryProvider};
10
11use super::common;
12use crate::database::WasmDbmsDatabase;
13
14pub struct InsertIntegrityValidator<'a, T, M, A = AccessControlList>
16where
17 T: TableSchema,
18 M: MemoryProvider,
19 A: AccessControl,
20{
21 database: &'a WasmDbmsDatabase<'a, M, A>,
22 _marker: std::marker::PhantomData<T>,
23}
24
25impl<'a, T, M, A> InsertIntegrityValidator<'a, T, M, A>
26where
27 T: TableSchema,
28 M: MemoryProvider,
29 A: AccessControl,
30{
31 pub fn new(dbms: &'a WasmDbmsDatabase<'a, M, A>) -> Self {
33 Self {
34 database: dbms,
35 _marker: std::marker::PhantomData,
36 }
37 }
38}
39
40impl<T, M, A> InsertIntegrityValidator<'_, T, M, A>
41where
42 T: TableSchema,
43 M: MemoryProvider,
44 A: AccessControl,
45{
46 pub fn validate(&self, record_values: &[(ColumnDef, Value)]) -> DbmsResult<()> {
48 for (col, value) in record_values {
49 common::check_column_validate::<T>(col, value)?;
50 }
51 self.check_primary_key_conflict(record_values)?;
52 self.check_unique_constraints(record_values)?;
53 common::check_foreign_keys::<T>(self.database, record_values)?;
54 common::check_non_nullable_fields::<T>(record_values)?;
55
56 Ok(())
57 }
58
59 fn check_primary_key_conflict(&self, record_values: &[(ColumnDef, Value)]) -> DbmsResult<()> {
61 let pk_name = T::primary_key();
62 let pk = record_values
63 .iter()
64 .find(|(col_def, _)| col_def.name == pk_name)
65 .map(|(_, value)| value.clone())
66 .ok_or(DbmsError::Query(QueryError::MissingNonNullableField(
67 pk_name.to_string(),
68 )))?;
69
70 let query: Query = Query::builder()
71 .field(pk_name)
72 .and_where(Filter::Eq(pk_name.to_string(), pk))
73 .build();
74
75 let res = self.database.select::<T>(query)?;
76 if res.is_empty() {
77 Ok(())
78 } else {
79 Err(DbmsError::Query(QueryError::PrimaryKeyConflict))
80 }
81 }
82
83 fn check_unique_constraints(&self, record_values: &[(ColumnDef, Value)]) -> DbmsResult<()> {
87 for (col_def, value) in record_values.iter().filter(|(col_def, _)| col_def.unique) {
88 let query = Query::builder()
89 .field(T::primary_key())
90 .and_where(Filter::Eq(col_def.name.to_string(), value.clone()))
91 .build();
92
93 if !self.database.select::<T>(query)?.is_empty() {
94 return Err(DbmsError::Query(QueryError::UniqueConstraintViolation {
95 field: col_def.name.to_string(),
96 }));
97 }
98 }
99
100 Ok(())
101 }
102}
103
#[cfg(test)]
mod tests {

    use wasm_dbms_api::prelude::{
        Database as _, DbmsError, InsertRecord as _, QueryError, TableSchema as _, Text, Uint32,
        Value,
    };
    use wasm_dbms_macros::{DatabaseSchema, Table};
    use wasm_dbms_memory::prelude::HeapMemoryProvider;

    use crate::prelude::{DbmsContext, WasmDbmsDatabase};

    // Parent table referenced by `Contract::user_id`.
    #[derive(Debug, Table, Clone, PartialEq, Eq)]
    #[table = "users"]
    pub struct User {
        #[primary_key]
        pub id: Uint32,
        pub name: Text,
    }

    // Table under test: `code` carries the unique constraint.
    #[derive(Debug, Table, Clone, PartialEq, Eq)]
    #[table = "contracts"]
    pub struct Contract {
        #[primary_key]
        pub id: Uint32,
        #[unique]
        pub code: Text,
        #[foreign_key(entity = "User", table = "users", column = "id")]
        pub user_id: Uint32,
    }

    // Schema registering both tables used by these tests.
    #[derive(DatabaseSchema)]
    #[tables(User = "users", Contract = "contracts")]
    pub struct TestSchema;

    /// Creates a heap-backed context with the test schema's tables registered.
    fn setup() -> DbmsContext<HeapMemoryProvider> {
        let context = DbmsContext::new(HeapMemoryProvider::default());
        TestSchema::register_tables(&context).unwrap();
        context
    }

    /// Inserts a `users` row, panicking if the insert is rejected.
    fn insert_user(db: &WasmDbmsDatabase<'_, HeapMemoryProvider>, id: u32, name: &str) {
        let columns = User::columns();
        let request = UserInsertRequest::from_values(&[
            (columns[0], Value::Uint32(Uint32(id))),
            (columns[1], Value::Text(Text(name.to_string()))),
        ])
        .unwrap();
        db.insert::<User>(request).unwrap();
    }

    /// Builds the insert request for a `contracts` row from plain values.
    fn contract_request(id: u32, code: &str, user_id: u32) -> ContractInsertRequest {
        let columns = Contract::columns();
        ContractInsertRequest::from_values(&[
            (columns[0], Value::Uint32(Uint32(id))),
            (columns[1], Value::Text(Text(code.to_string()))),
            (columns[2], Value::Uint32(Uint32(user_id))),
        ])
        .unwrap()
    }

    /// Inserts a `contracts` row, panicking if the insert is rejected.
    fn insert_contract(
        db: &WasmDbmsDatabase<'_, HeapMemoryProvider>,
        id: u32,
        code: &str,
        user_id: u32,
    ) {
        db.insert::<Contract>(contract_request(id, code, user_id))
            .unwrap();
    }

    #[test]
    fn test_insert_with_unique_field_succeeds() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_user(&db, 1, "alice");
        // Distinct codes: both inserts must be accepted.
        insert_contract(&db, 1, "CONTRACT-001", 1);
        insert_contract(&db, 2, "CONTRACT-002", 1);
    }

    #[test]
    fn test_insert_with_duplicate_unique_field_fails() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_user(&db, 1, "alice");
        insert_contract(&db, 1, "CONTRACT-001", 1);

        // Fresh primary key, but `code` is already held by contract 1.
        let result = db.insert::<Contract>(contract_request(2, "CONTRACT-001", 1));
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            DbmsError::Query(QueryError::UniqueConstraintViolation { ref field })
                if field == "code"
        ));
    }

    #[test]
    fn test_insert_detects_conflict_on_each_unique_field_independently() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_user(&db, 1, "alice");
        insert_contract(&db, 1, "CONTRACT-001", 1);
        insert_contract(&db, 2, "CONTRACT-002", 1);

        // Fresh primary key whose `code` collides with the second stored contract.
        let result = db.insert::<Contract>(contract_request(3, "CONTRACT-002", 1));
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            DbmsError::Query(QueryError::UniqueConstraintViolation { ref field })
                if field == "code"
        ));
    }
}