1use std::cmp::Ordering;
7use std::collections::HashSet;
8
9use wasm_dbms_api::prelude::{
10 CandidColumnDef, ColumnDef, DataTypeKind, Database, DbmsError, DbmsResult, DeleteBehavior,
11 Filter, ForeignFetcher, ForeignKeyDef, InsertRecord, OrderDirection, Query, QueryError,
12 TableColumns, TableError, TableRecord, TableSchema, TransactionError, TransactionId,
13 UpdateRecord, Value, ValuesSource,
14};
15use wasm_dbms_memory::prelude::{
16 AccessControl, AccessControlList, MemoryProvider, NextRecord, TableRegistry,
17};
18
19use crate::context::DbmsContext;
20use crate::schema::DatabaseSchema;
21use crate::transaction::journal::{Journal, JournaledWriter};
22use crate::transaction::{DatabaseOverlay, Transaction, TransactionOp};
23
/// Initial capacity reserved for `select` result buffers when the query does
/// not specify a `limit`.
const DEFAULT_SELECT_CAPACITY: usize = 128;

/// Database handle that executes queries against the memory and schemas held
/// by a [`DbmsContext`].
///
/// A handle created with [`WasmDbmsDatabase::oneshot`] writes directly to
/// memory. Each write is wrapped in a journal, so a failing operation can be
/// rolled back.
///
/// A handle created with [`WasmDbmsDatabase::from_transaction`] instead records
/// writes in the referenced transaction until `commit` is called.
pub struct WasmDbmsDatabase<'ctx, M, A = AccessControlList>
where
    M: MemoryProvider,
    A: AccessControl,
{
    /// Shared DBMS state: memory, schema registry, transaction session and journal.
    ctx: &'ctx DbmsContext<M, A>,
    /// Type-erased schema used to dispatch table operations by table name.
    schema: Box<dyn DatabaseSchema<M, A> + 'ctx>,
    /// Active transaction, or `None` for one-shot (auto-committed) operations.
    transaction: Option<TransactionId>,
}
44
45impl<'ctx, M, A> WasmDbmsDatabase<'ctx, M, A>
46where
47 M: MemoryProvider,
48 A: AccessControl,
49{
50 pub fn oneshot(ctx: &'ctx DbmsContext<M, A>, schema: impl DatabaseSchema<M, A> + 'ctx) -> Self {
52 Self {
53 ctx,
54 schema: Box::new(schema),
55 transaction: None,
56 }
57 }
58
59 pub fn from_transaction(
61 ctx: &'ctx DbmsContext<M, A>,
62 schema: impl DatabaseSchema<M, A> + 'ctx,
63 transaction_id: TransactionId,
64 ) -> Self {
65 Self {
66 ctx,
67 schema: Box::new(schema),
68 transaction: Some(transaction_id),
69 }
70 }
71
72 fn with_transaction_mut<F, R>(&self, f: F) -> DbmsResult<R>
74 where
75 F: FnOnce(&mut Transaction) -> DbmsResult<R>,
76 {
77 let txid = self.transaction.as_ref().ok_or(DbmsError::Transaction(
78 TransactionError::NoActiveTransaction,
79 ))?;
80
81 let mut ts = self.ctx.transaction_session.borrow_mut();
82 let tx = ts.get_transaction_mut(txid)?;
83 f(tx)
84 }
85
86 fn with_transaction<F, R>(&self, f: F) -> DbmsResult<R>
88 where
89 F: FnOnce(&Transaction) -> DbmsResult<R>,
90 {
91 let txid = self.transaction.as_ref().ok_or(DbmsError::Transaction(
92 TransactionError::NoActiveTransaction,
93 ))?;
94
95 let ts = self.ctx.transaction_session.borrow();
96 let tx = ts.get_transaction(txid)?;
97 f(tx)
98 }
99
100 fn atomic<F, R>(&self, f: F) -> DbmsResult<R>
114 where
115 F: FnOnce(&WasmDbmsDatabase<'ctx, M, A>) -> DbmsResult<R>,
116 {
117 let nested = self.ctx.journal.borrow().is_some();
118 if !nested {
119 *self.ctx.journal.borrow_mut() = Some(Journal::new());
120 }
121 match f(self) {
122 Ok(res) => {
123 if !nested && let Some(journal) = self.ctx.journal.borrow_mut().take() {
124 journal.commit();
125 }
126 Ok(res)
127 }
128 Err(err) => {
129 if !nested && let Some(journal) = self.ctx.journal.borrow_mut().take() {
130 journal
131 .rollback(&mut self.ctx.mm.borrow_mut())
132 .expect("critical: failed to rollback journal");
133 }
134 Err(err)
135 }
136 }
137 }
138
139 fn has_foreign_key_references<T>(
143 &self,
144 record_values: &[(ColumnDef, Value)],
145 ) -> DbmsResult<bool>
146 where
147 T: TableSchema,
148 {
149 let pk = Self::extract_pk::<T>(record_values)?;
150
151 for (table, columns) in self.schema.referenced_tables(T::table_name()) {
152 for column in columns.iter() {
153 let filter = Filter::eq(column, pk.clone());
154 let query = Query::builder().field(column).filter(Some(filter)).build();
155 let rows = self.schema.select(self, table, query)?;
156 if !rows.is_empty() {
157 return Ok(true);
158 }
159 }
160 }
161 Ok(false)
162 }
163
164 fn delete_foreign_keys_cascade<T>(
166 &self,
167 record_values: &[(ColumnDef, Value)],
168 ) -> DbmsResult<u64>
169 where
170 T: TableSchema,
171 {
172 let pk = Self::extract_pk::<T>(record_values)?;
173
174 let mut count = 0;
175 for (table, columns) in self.schema.referenced_tables(T::table_name()) {
176 for column in columns.iter() {
177 let filter = Filter::eq(column, pk.clone());
178 let res = self
179 .schema
180 .delete(self, table, DeleteBehavior::Cascade, Some(filter))?;
181 count += res;
182 }
183 }
184 Ok(count)
185 }
186
187 fn extract_pk<T>(record_values: &[(ColumnDef, Value)]) -> DbmsResult<Value>
189 where
190 T: TableSchema,
191 {
192 record_values
193 .iter()
194 .find(|(col_def, _)| col_def.primary_key)
195 .ok_or(DbmsError::Query(QueryError::UnknownColumn(
196 T::primary_key().to_string(),
197 )))
198 .map(|(_, v)| v.clone())
199 }
200
201 fn overlay(&self) -> DbmsResult<DatabaseOverlay> {
203 self.with_transaction(|tx| Ok(tx.overlay().clone()))
204 }
205
206 fn record_matches_filter(
208 &self,
209 record_values: &[(ColumnDef, Value)],
210 filter: &Filter,
211 ) -> DbmsResult<bool> {
212 filter.matches(record_values).map_err(DbmsError::from)
213 }
214
215 fn apply_column_selection<T>(&self, results: &mut [TableColumns], query: &Query)
217 where
218 T: TableSchema,
219 {
220 if query.all_selected() {
221 return;
222 }
223 let selected_columns = query.columns::<T>();
224 results
225 .iter_mut()
226 .flat_map(|record| record.iter_mut())
227 .filter(|(source, _)| *source == ValuesSource::This)
228 .for_each(|(_, cols)| {
229 cols.retain(|(col_def, _)| selected_columns.contains(&col_def.name.to_string()));
230 });
231 }
232
233 fn batch_load_eager_relations<T>(
235 &self,
236 results: &mut [TableColumns],
237 query: &Query,
238 ) -> DbmsResult<()>
239 where
240 T: TableSchema,
241 {
242 if query.eager_relations.is_empty() {
243 return Ok(());
244 }
245
246 let fetcher = T::foreign_fetcher();
247
248 for relation in &query.eager_relations {
249 let fk_columns = Self::collect_fk_values::<T>(results, relation)?;
250
251 for (local_column, pk_values) in &fk_columns {
252 let batch_map = fetcher.fetch_batch(self, relation, pk_values)?;
253
254 Self::verify_fk_batch(&batch_map, pk_values, relation)?;
255 Self::attach_foreign_data(results, &batch_map, relation, local_column);
256 }
257 }
258
259 Ok(())
260 }
261
262 fn collect_fk_values<T>(
264 results: &[TableColumns],
265 relation: &str,
266 ) -> DbmsResult<Vec<(&'static str, Vec<Value>)>>
267 where
268 T: TableSchema,
269 {
270 let mut fk_columns: Vec<(&'static str, HashSet<Value>)> = vec![];
271
272 for record_columns in results {
273 let Some(cols) = Self::this_columns(record_columns) else {
274 continue;
275 };
276
277 let mut found_fk = false;
278 for (col_def, value) in cols {
279 let Some(fk) = &col_def.foreign_key else {
280 continue;
281 };
282 if *fk.foreign_table != *relation {
283 continue;
284 }
285
286 found_fk = true;
287 match fk_columns.iter_mut().find(|(lc, _)| *lc == fk.local_column) {
288 Some((_, values)) => {
289 values.insert(value.clone());
290 }
291 None => {
292 let mut set = HashSet::new();
293 set.insert(value.clone());
294 fk_columns.push((fk.local_column, set));
295 }
296 }
297 }
298
299 if !found_fk {
300 return Err(DbmsError::Query(QueryError::InvalidQuery(format!(
301 "Cannot load relation '{relation}' for table '{}': no foreign key found",
302 T::table_name()
303 ))));
304 }
305 }
306
307 Ok(fk_columns
308 .into_iter()
309 .map(|(col, set)| (col, set.into_iter().collect()))
310 .collect())
311 }
312
313 fn verify_fk_batch(
315 batch_map: &std::collections::HashMap<Value, Vec<(ColumnDef, Value)>>,
316 pk_values: &[Value],
317 relation: &str,
318 ) -> DbmsResult<()> {
319 if let Some(missing) = pk_values.iter().find(|v| !batch_map.contains_key(v)) {
320 return Err(DbmsError::Query(QueryError::BrokenForeignKeyReference {
321 table: relation.to_string(),
322 key: missing.clone(),
323 }));
324 }
325 Ok(())
326 }
327
328 fn attach_foreign_data(
330 results: &mut [TableColumns],
331 batch_map: &std::collections::HashMap<Value, Vec<(ColumnDef, Value)>>,
332 relation: &str,
333 local_column: &str,
334 ) {
335 for record_columns in results.iter_mut() {
336 let fk_value = Self::this_columns(record_columns).and_then(|cols| {
337 cols.iter().find_map(|(col_def, value)| {
338 let fk = col_def.foreign_key.as_ref()?;
339 (fk.foreign_table == relation && fk.local_column == local_column)
340 .then(|| value.clone())
341 })
342 });
343
344 let Some(fk_val) = fk_value else { continue };
345 let Some(foreign_values) = batch_map.get(&fk_val) else {
346 continue;
347 };
348
349 record_columns.push((
350 ValuesSource::Foreign {
351 table: relation.to_string(),
352 column: local_column.to_string(),
353 },
354 foreign_values.clone(),
355 ));
356 }
357 }
358
359 fn this_columns(
361 record: &[(ValuesSource, Vec<(ColumnDef, Value)>)],
362 ) -> Option<&Vec<(ColumnDef, Value)>> {
363 record
364 .iter()
365 .find(|(src, _)| *src == ValuesSource::This)
366 .map(|(_, cols)| cols)
367 }
368
369 fn existing_primary_keys_for_filter<T>(&self, filter: Option<Filter>) -> DbmsResult<Vec<Value>>
371 where
372 T: TableSchema,
373 {
374 let pk = T::primary_key();
375 let query = Query::builder().field(pk).filter(filter).build();
376 let fields = self.select::<T>(query)?;
377 let pks = fields
378 .into_iter()
379 .map(|record| {
380 record
381 .to_values()
382 .into_iter()
383 .find(|(col_def, _value)| col_def.name == pk)
384 .expect("primary key not found")
385 .1
386 })
387 .collect::<Vec<Value>>();
388
389 Ok(pks)
390 }
391
392 fn load_table_registry<T>(&self) -> DbmsResult<TableRegistry>
394 where
395 T: TableSchema,
396 {
397 let sr = self.ctx.schema_registry.borrow();
398 let registry_pages = sr
399 .table_registry_page::<T>()
400 .ok_or(DbmsError::Table(TableError::TableNotFound))?;
401
402 let mm = self.ctx.mm.borrow();
403 TableRegistry::load(registry_pages, &*mm).map_err(DbmsError::from)
404 }
405
406 fn sort_query_results(
408 &self,
409 results: &mut [TableColumns],
410 column: &str,
411 direction: OrderDirection,
412 ) {
413 results.sort_by(|a, b| {
414 fn get_value<'a>(
415 values: &'a [(ValuesSource, Vec<(ColumnDef, Value)>)],
416 column: &str,
417 ) -> Option<&'a Value> {
418 values
419 .iter()
420 .find(|(source, _)| *source == ValuesSource::This)
421 .and_then(|(_, cols)| {
422 cols.iter()
423 .find(|(col_def, _)| col_def.name == column)
424 .map(|(_, value)| value)
425 })
426 }
427
428 let a_value = get_value(a, column);
429 let b_value = get_value(b, column);
430
431 sort_values_with_direction(a_value, b_value, direction)
432 });
433 }
434
435 #[doc(hidden)]
437 pub fn select_columns<T>(&self, query: Query) -> DbmsResult<Vec<TableColumns>>
438 where
439 T: TableSchema,
440 {
441 let table_registry = self.load_table_registry::<T>()?;
442 let mut table_overlay = if self.transaction.is_some() {
443 self.overlay()?
444 } else {
445 DatabaseOverlay::default()
446 };
447
448 let mut results = Vec::with_capacity(query.limit.unwrap_or(DEFAULT_SELECT_CAPACITY));
449 let mut count = 0;
450
451 {
452 let mm = self.ctx.mm.borrow();
453 let table_reader = table_registry.read::<T, _>(&*mm);
454 let mut table_reader = table_overlay.reader(table_reader);
455
456 while let Some(values) = table_reader.try_next()? {
457 if let Some(filter) = &query.filter
458 && !self.record_matches_filter(&values, filter)?
459 {
460 continue;
461 }
462 count += 1;
463 if query.offset.is_some_and(|offset| count <= offset) {
464 continue;
465 }
466 results.push(vec![(ValuesSource::This, values)]);
467 if query.limit.is_some_and(|limit| results.len() >= limit) {
468 break;
469 }
470 }
471 }
472
473 self.batch_load_eager_relations::<T>(&mut results, &query)?;
474 self.apply_column_selection::<T>(&mut results, &query);
475
476 for (column, direction) in query.order_by.into_iter().rev() {
477 self.sort_query_results(&mut results, &column, direction);
478 }
479
480 Ok(results)
481 }
482
483 #[doc(hidden)]
485 pub fn select_join(
486 &self,
487 table: &str,
488 query: Query,
489 ) -> DbmsResult<Vec<Vec<(CandidColumnDef, Value)>>> {
490 self.schema.select_join(self, table, query)
491 }
492
493 fn update_pk_referencing_updated_table<T>(
495 &self,
496 old_pk: Value,
497 new_pk: Value,
498 data_type: DataTypeKind,
499 pk_name: &'static str,
500 ) -> DbmsResult<u64>
501 where
502 T: TableSchema,
503 {
504 let mut count = 0;
505 for (ref_table, ref_col) in self
506 .schema
507 .referenced_tables(T::table_name())
508 .into_iter()
509 .flat_map(|(ref_table, ref_cols)| {
510 ref_cols
511 .into_iter()
512 .map(move |ref_col| (ref_table, ref_col))
513 })
514 {
515 let ref_patch_value = (
516 ColumnDef {
517 name: ref_col,
518 data_type,
519 nullable: false,
520 primary_key: false,
521 foreign_key: Some(ForeignKeyDef {
522 foreign_table: T::table_name(),
523 foreign_column: pk_name,
524 local_column: ref_col,
525 }),
526 },
527 new_pk.clone(),
528 );
529 let filter = Filter::eq(ref_col, old_pk.clone());
530
531 count += self
532 .schema
533 .update(self, ref_table, &[ref_patch_value], Some(filter))?;
534 }
535
536 Ok(count)
537 }
538
539 fn sanitize_values<T>(
541 &self,
542 values: Vec<(ColumnDef, Value)>,
543 ) -> DbmsResult<Vec<(ColumnDef, Value)>>
544 where
545 T: TableSchema,
546 {
547 let mut sanitized_values = Vec::with_capacity(values.len());
548 for (col_def, value) in values.into_iter() {
549 let value = match T::sanitizer(col_def.name) {
550 Some(sanitizer) => sanitizer.sanitize(value)?,
551 None => value,
552 };
553 sanitized_values.push((col_def, value));
554 }
555 Ok(sanitized_values)
556 }
557
558 #[allow(clippy::type_complexity)]
560 fn collect_matching_records<T>(
561 &self,
562 table_registry: &TableRegistry,
563 filter: &Option<Filter>,
564 ) -> DbmsResult<Vec<(NextRecord<T>, Vec<(ColumnDef, Value)>)>>
565 where
566 T: TableSchema,
567 {
568 let mm = self.ctx.mm.borrow();
569 let mut table_reader = table_registry.read::<T, _>(&*mm);
570 let mut records = vec![];
571 while let Some(values) = table_reader.try_next()? {
572 let record_values = values.record.clone().to_values();
573 if let Some(filter) = filter
574 && !self.record_matches_filter(&record_values, filter)?
575 {
576 continue;
577 }
578 records.push((values, record_values));
579 }
580 Ok(records)
581 }
582}
583
584pub fn sort_values_with_direction(
586 a: Option<&Value>,
587 b: Option<&Value>,
588 direction: OrderDirection,
589) -> Ordering {
590 match (a, b) {
591 (Some(a_val), Some(b_val)) => match direction {
592 OrderDirection::Ascending => a_val.cmp(b_val),
593 OrderDirection::Descending => b_val.cmp(a_val),
594 },
595 (Some(_), None) => std::cmp::Ordering::Greater,
596 (None, Some(_)) => std::cmp::Ordering::Less,
597 (None, None) => std::cmp::Ordering::Equal,
598 }
599}
600
601fn values_to_schema_entity<T>(values: Vec<(ColumnDef, Value)>) -> DbmsResult<T>
603where
604 T: TableSchema,
605{
606 let record = T::Insert::from_values(&values)?.into_record();
607 Ok(record)
608}
609
610impl<M, A> Database for WasmDbmsDatabase<'_, M, A>
611where
612 M: MemoryProvider,
613 A: AccessControl,
614{
615 fn select<T>(&self, query: Query) -> DbmsResult<Vec<T::Record>>
616 where
617 T: TableSchema,
618 {
619 if !query.joins.is_empty() {
620 return Err(DbmsError::Query(QueryError::JoinInsideTypedSelect));
621 }
622 let results = self.select_columns::<T>(query)?;
623 Ok(results.into_iter().map(T::Record::from_values).collect())
624 }
625
626 fn select_raw(&self, table: &str, query: Query) -> DbmsResult<Vec<Vec<(ColumnDef, Value)>>> {
627 self.schema.select(self, table, query)
628 }
629
630 fn insert<T>(&self, record: T::Insert) -> DbmsResult<()>
631 where
632 T: TableSchema,
633 T::Insert: InsertRecord<Schema = T>,
634 {
635 let record_values = record.clone().into_values();
636 let sanitized_values = self.sanitize_values::<T>(record_values)?;
637 self.schema
638 .validate_insert(self, T::table_name(), &sanitized_values)?;
639 if self.transaction.is_some() {
640 self.with_transaction_mut(|tx| tx.insert::<T>(sanitized_values))?;
641 } else {
642 self.atomic(|db| {
643 let mut table_registry = db.load_table_registry::<T>()?;
644 let record = T::Insert::from_values(&sanitized_values)?;
645 let mut mm = db.ctx.mm.borrow_mut();
646 let mut journal_ref = db.ctx.journal.borrow_mut();
647 let journal = journal_ref
648 .as_mut()
649 .expect("journal must be active inside atomic");
650 let mut writer = JournaledWriter::new(&mut *mm, journal);
651 table_registry
652 .insert(record.into_record(), &mut writer)
653 .map_err(DbmsError::from)?;
654 Ok(())
655 })?;
656 }
657
658 Ok(())
659 }
660
661 fn update<T>(&self, patch: T::Update) -> DbmsResult<u64>
662 where
663 T: TableSchema,
664 T::Update: UpdateRecord<Schema = T>,
665 {
666 let filter = patch.where_clause().clone();
667 if self.transaction.is_some() {
668 let pks = self.existing_primary_keys_for_filter::<T>(filter.clone())?;
669 let count = pks.len() as u64;
670 self.with_transaction_mut(|tx| tx.update::<T>(patch, filter, pks))?;
671
672 return Ok(count);
673 }
674
675 let patch = patch.update_values();
676
677 let pk_in_patch = patch.iter().find_map(|(col_def, value)| {
678 if col_def.primary_key {
679 Some((col_def, value))
680 } else {
681 None
682 }
683 });
684
685 self.atomic(|db| {
686 let mut count = 0;
687
688 let mut table_registry = db.load_table_registry::<T>()?;
689 let records = db.collect_matching_records::<T>(&table_registry, &filter)?;
690
691 for (record, record_values) in records {
692 let current_pk_value = record_values
693 .iter()
694 .find(|(col_def, _)| col_def.primary_key)
695 .expect("primary key not found")
696 .1
697 .clone();
698
699 let previous_record = values_to_schema_entity::<T>(record_values.clone())?;
700 let mut record_values = record_values;
701
702 for (patch_col_def, patch_value) in &patch {
703 if let Some((_, record_value)) = record_values
704 .iter_mut()
705 .find(|(record_col_def, _)| record_col_def.name == patch_col_def.name)
706 {
707 *record_value = patch_value.clone();
708 }
709 }
710 let record_values = db.sanitize_values::<T>(record_values)?;
711 db.schema.validate_update(
712 db,
713 T::table_name(),
714 &record_values,
715 current_pk_value.clone(),
716 )?;
717 let updated_record = values_to_schema_entity::<T>(record_values)?;
718 {
719 let mut mm = db.ctx.mm.borrow_mut();
720 let mut journal_ref = db.ctx.journal.borrow_mut();
721 let journal = journal_ref
722 .as_mut()
723 .expect("journal must be active inside atomic");
724 let mut writer = JournaledWriter::new(&mut *mm, journal);
725 table_registry
726 .update(
727 updated_record,
728 previous_record,
729 record.page,
730 record.offset,
731 &mut writer,
732 )
733 .map_err(DbmsError::from)?;
734 }
735 count += 1;
736
737 if let Some((pk_column, new_pk_value)) = pk_in_patch {
738 count += db.update_pk_referencing_updated_table::<T>(
739 current_pk_value,
740 new_pk_value.clone(),
741 pk_column.data_type,
742 pk_column.name,
743 )?;
744 }
745 }
746
747 Ok(count)
748 })
749 }
750
751 fn delete<T>(&self, behaviour: DeleteBehavior, filter: Option<Filter>) -> DbmsResult<u64>
752 where
753 T: TableSchema,
754 {
755 if self.transaction.is_some() {
756 let pks = self.existing_primary_keys_for_filter::<T>(filter.clone())?;
757 let count = pks.len() as u64;
758
759 self.with_transaction_mut(|tx| tx.delete::<T>(behaviour, filter, pks))?;
760
761 return Ok(count);
762 }
763
764 self.atomic(|db| {
765 let mut table_registry = db.load_table_registry::<T>()?;
766 let records = db.collect_matching_records::<T>(&table_registry, &filter)?;
767 let mut count = records.len() as u64;
768 for (record, record_values) in records {
769 match behaviour {
770 DeleteBehavior::Cascade => {
771 count += db.delete_foreign_keys_cascade::<T>(&record_values)?;
772 }
773 DeleteBehavior::Restrict => {
774 if db.has_foreign_key_references::<T>(&record_values)? {
775 return Err(DbmsError::Query(
776 QueryError::ForeignKeyConstraintViolation {
777 referencing_table: T::table_name().to_string(),
778 field: T::primary_key().to_string(),
779 },
780 ));
781 }
782 }
783 }
784 let mut mm = db.ctx.mm.borrow_mut();
785 let mut journal_ref = db.ctx.journal.borrow_mut();
786 let journal = journal_ref
787 .as_mut()
788 .expect("journal must be active inside atomic");
789 let mut writer = JournaledWriter::new(&mut *mm, journal);
790 table_registry
791 .delete(record.record, record.page, record.offset, &mut writer)
792 .map_err(DbmsError::from)?;
793 }
794
795 Ok(count)
796 })
797 }
798
799 fn commit(&mut self) -> DbmsResult<()> {
800 let Some(txid) = self.transaction.take() else {
801 return Err(DbmsError::Transaction(
802 TransactionError::NoActiveTransaction,
803 ));
804 };
805 let transaction = {
806 let mut ts = self.ctx.transaction_session.borrow_mut();
807 ts.take_transaction(&txid)?
808 };
809
810 *self.ctx.journal.borrow_mut() = Some(Journal::new());
811
812 for op in transaction.operations {
813 let result = match op {
814 TransactionOp::Insert { table, values } => self
815 .schema
816 .validate_insert(self, table, &values)
817 .and_then(|()| self.schema.insert(self, table, &values)),
818 TransactionOp::Delete {
819 table,
820 behaviour,
821 filter,
822 } => self
823 .schema
824 .delete(self, table, behaviour, filter)
825 .map(|_| ()),
826 TransactionOp::Update {
827 table,
828 patch,
829 filter,
830 } => self.schema.update(self, table, &patch, filter).map(|_| ()),
831 };
832
833 if let Err(err) = result {
834 if let Some(journal) = self.ctx.journal.borrow_mut().take() {
835 journal
836 .rollback(&mut self.ctx.mm.borrow_mut())
837 .expect("critical: failed to rollback journal");
838 }
839 return Err(err);
840 }
841 }
842
843 if let Some(journal) = self.ctx.journal.borrow_mut().take() {
844 journal.commit();
845 }
846 Ok(())
847 }
848
849 fn rollback(&mut self) -> DbmsResult<()> {
850 let Some(txid) = self.transaction.take() else {
851 return Err(DbmsError::Transaction(
852 TransactionError::NoActiveTransaction,
853 ));
854 };
855
856 let mut ts = self.ctx.transaction_session.borrow_mut();
857 ts.close_transaction(&txid);
858 Ok(())
859 }
860}
861
// Unit tests for `WasmDbmsDatabase`. They run against an in-heap memory
// provider with a two-table schema (`users` <- `posts.user_id`).
#[cfg(test)]
mod tests {

    use std::cmp::Ordering;

    use wasm_dbms_api::prelude::{
        Database as _, DeleteBehavior, Filter, InsertRecord as _, OrderDirection, Query,
        TableSchema as _, Text, Uint32, UpdateRecord as _, Value,
    };
    use wasm_dbms_macros::{DatabaseSchema, Table};
    use wasm_dbms_memory::prelude::HeapMemoryProvider;

    use super::sort_values_with_direction;
    use crate::prelude::{DbmsContext, WasmDbmsDatabase};
    use crate::schema::DatabaseSchema as _;

    // Parent table, referenced by `Post::user_id`.
    #[derive(Debug, Table, Clone, PartialEq, Eq)]
    #[table = "users"]
    pub struct User {
        #[primary_key]
        pub id: Uint32,
        pub name: Text,
    }

    // Child table; `user_id` is a foreign key to `users.id`.
    #[derive(Debug, Table, Clone, PartialEq, Eq)]
    #[table = "posts"]
    pub struct Post {
        #[primary_key]
        pub id: Uint32,
        pub title: Text,
        #[foreign_key(entity = "User", table = "users", column = "id")]
        pub user_id: Uint32,
    }

    // Schema registering both tables.
    #[derive(DatabaseSchema)]
    #[tables(User = "users", Post = "posts")]
    pub struct TestSchema;

    // Builds a fresh context with `TestSchema`'s tables registered.
    fn setup() -> DbmsContext<HeapMemoryProvider> {
        let ctx = DbmsContext::new(HeapMemoryProvider::default());
        TestSchema::register_tables(&ctx).unwrap();
        ctx
    }

    // Inserts a user row, panicking on failure.
    fn insert_user(db: &WasmDbmsDatabase<'_, HeapMemoryProvider>, id: u32, name: &str) {
        let insert = UserInsertRequest::from_values(&[
            (User::columns()[0], Value::Uint32(Uint32(id))),
            (User::columns()[1], Value::Text(Text(name.to_string()))),
        ])
        .unwrap();
        db.insert::<User>(insert).unwrap();
    }

    // Inserts a post row owned by `user_id`, panicking on failure.
    fn insert_post(
        db: &WasmDbmsDatabase<'_, HeapMemoryProvider>,
        id: u32,
        title: &str,
        user_id: u32,
    ) {
        let insert = PostInsertRequest::from_values(&[
            (Post::columns()[0], Value::Uint32(Uint32(id))),
            (Post::columns()[1], Value::Text(Text(title.to_string()))),
            (Post::columns()[2], Value::Uint32(Uint32(user_id))),
        ])
        .unwrap();
        db.insert::<Post>(insert).unwrap();
    }

    // --- sort_values_with_direction ---

    #[test]
    fn test_sort_values_ascending() {
        let a = Value::Uint32(Uint32(1));
        let b = Value::Uint32(Uint32(2));
        assert_eq!(
            sort_values_with_direction(Some(&a), Some(&b), OrderDirection::Ascending),
            Ordering::Less
        );
    }

    #[test]
    fn test_sort_values_descending() {
        let a = Value::Uint32(Uint32(1));
        let b = Value::Uint32(Uint32(2));
        assert_eq!(
            sort_values_with_direction(Some(&a), Some(&b), OrderDirection::Descending),
            Ordering::Greater
        );
    }

    // A missing value sorts before a present one.
    #[test]
    fn test_sort_values_some_none() {
        let a = Value::Uint32(Uint32(1));
        assert_eq!(
            sort_values_with_direction(Some(&a), None, OrderDirection::Ascending),
            Ordering::Greater
        );
        assert_eq!(
            sort_values_with_direction(None, Some(&a), OrderDirection::Ascending),
            Ordering::Less
        );
    }

    #[test]
    fn test_sort_values_none_none() {
        assert_eq!(
            sort_values_with_direction(None, None, OrderDirection::Ascending),
            Ordering::Equal
        );
    }

    // --- select: ordering ---

    // Rows are inserted out of order so the assertion checks real sorting.
    #[test]
    fn test_select_with_order_by_ascending() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_user(&db, 3, "charlie");
        insert_user(&db, 1, "alice");
        insert_user(&db, 2, "bob");

        let rows = db
            .select::<User>(Query::builder().all().order_by_asc("name").build())
            .unwrap();
        assert_eq!(rows.len(), 3);
        assert_eq!(rows[0].name, Some(Text("alice".to_string())));
        assert_eq!(rows[1].name, Some(Text("bob".to_string())));
        assert_eq!(rows[2].name, Some(Text("charlie".to_string())));
    }

    #[test]
    fn test_select_with_order_by_descending() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_user(&db, 1, "alice");
        insert_user(&db, 2, "bob");
        insert_user(&db, 3, "charlie");

        let rows = db
            .select::<User>(Query::builder().all().order_by_desc("name").build())
            .unwrap();
        assert_eq!(rows.len(), 3);
        assert_eq!(rows[0].name, Some(Text("charlie".to_string())));
        assert_eq!(rows[1].name, Some(Text("bob".to_string())));
        assert_eq!(rows[2].name, Some(Text("alice".to_string())));
    }

    // --- select: pagination ---

    #[test]
    fn test_select_with_limit() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_user(&db, 1, "alice");
        insert_user(&db, 2, "bob");
        insert_user(&db, 3, "charlie");

        let rows = db
            .select::<User>(Query::builder().all().limit(2).build())
            .unwrap();
        assert_eq!(rows.len(), 2);
    }

    #[test]
    fn test_select_with_offset() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_user(&db, 1, "alice");
        insert_user(&db, 2, "bob");
        insert_user(&db, 3, "charlie");

        let rows = db
            .select::<User>(Query::builder().all().offset(1).build())
            .unwrap();
        assert_eq!(rows.len(), 2);
    }

    #[test]
    fn test_select_with_offset_and_limit() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_user(&db, 1, "alice");
        insert_user(&db, 2, "bob");
        insert_user(&db, 3, "charlie");

        let rows = db
            .select::<User>(Query::builder().all().offset(1).limit(1).build())
            .unwrap();
        assert_eq!(rows.len(), 1);
    }

    // --- select: filtering and projection ---

    #[test]
    fn test_select_with_filter() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_user(&db, 1, "alice");
        insert_user(&db, 2, "bob");

        let rows = db
            .select::<User>(
                Query::builder()
                    .all()
                    .and_where(Filter::eq("name", Value::Text(Text("alice".to_string()))))
                    .build(),
            )
            .unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].name, Some(Text("alice".to_string())));
    }

    // Selecting a single field through the schema keeps only that column.
    #[test]
    fn test_select_with_column_selection() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_user(&db, 1, "alice");

        let rows = TestSchema
            .select(&db, "users", Query::builder().field("name").build())
            .unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].len(), 1);
        assert_eq!(rows[0][0].0.name, "name");
    }

    // --- update ---

    #[test]
    fn test_update_record() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_user(&db, 1, "alice");

        let patch = UserUpdateRequest::from_values(
            &[(User::columns()[1], Value::Text(Text("alicia".to_string())))],
            Some(Filter::eq("id", Value::Uint32(Uint32(1)))),
        );
        let count = db.update::<User>(patch).unwrap();
        assert_eq!(count, 1);

        let rows = db.select::<User>(Query::builder().build()).unwrap();
        assert_eq!(rows[0].name, Some(Text("alicia".to_string())));
    }

    #[test]
    fn test_update_no_matching_records() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_user(&db, 1, "alice");

        let patch = UserUpdateRequest::from_values(
            &[(User::columns()[1], Value::Text(Text("bob".to_string())))],
            Some(Filter::eq("id", Value::Uint32(Uint32(999)))),
        );
        let count = db.update::<User>(patch).unwrap();
        assert_eq!(count, 0);
    }

    // --- delete ---

    #[test]
    fn test_delete_with_filter() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_user(&db, 1, "alice");
        insert_user(&db, 2, "bob");

        let count = db
            .delete::<User>(
                DeleteBehavior::Restrict,
                Some(Filter::eq("id", Value::Uint32(Uint32(1)))),
            )
            .unwrap();
        assert_eq!(count, 1);

        let rows = db.select::<User>(Query::builder().build()).unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].id, Some(Uint32(2)));
    }

    // A user still referenced by a post cannot be deleted with `Restrict`.
    #[test]
    fn test_delete_restrict_with_fk_reference_fails() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_user(&db, 1, "alice");
        insert_post(&db, 10, "post1", 1);

        let result = db.delete::<User>(DeleteBehavior::Restrict, None);
        assert!(result.is_err());
    }

    // `Cascade` also removes the referencing post; the count covers both rows.
    #[test]
    fn test_delete_cascade_removes_referencing_records() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_user(&db, 1, "alice");
        insert_post(&db, 10, "post1", 1);

        let count = db.delete::<User>(DeleteBehavior::Cascade, None).unwrap();
        assert_eq!(count, 2);

        let users = db.select::<User>(Query::builder().build()).unwrap();
        assert!(users.is_empty());
        let posts = db.select::<Post>(Query::builder().build()).unwrap();
        assert!(posts.is_empty());
    }

    // --- transactions ---

    #[test]
    fn test_commit_without_transaction_returns_error() {
        let ctx = setup();
        let mut db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        let result = db.commit();
        assert!(result.is_err());
    }

    // An update made inside a transaction becomes visible to a one-shot
    // handle after commit.
    #[test]
    fn test_transaction_update_and_commit() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_user(&db, 1, "alice");

        let owner = vec![1, 2, 3];
        let tx_id = ctx.begin_transaction(owner);
        let mut db = WasmDbmsDatabase::from_transaction(&ctx, TestSchema, tx_id);

        let patch = UserUpdateRequest::from_values(
            &[(User::columns()[1], Value::Text(Text("alicia".to_string())))],
            Some(Filter::eq("id", Value::Uint32(Uint32(1)))),
        );
        db.update::<User>(patch).unwrap();
        db.commit().unwrap();

        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        let rows = db.select::<User>(Query::builder().build()).unwrap();
        assert_eq!(rows[0].name, Some(Text("alicia".to_string())));
    }

    #[test]
    fn test_transaction_delete_and_commit() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_user(&db, 1, "alice");
        insert_user(&db, 2, "bob");

        let owner = vec![1, 2, 3];
        let tx_id = ctx.begin_transaction(owner);
        let mut db = WasmDbmsDatabase::from_transaction(&ctx, TestSchema, tx_id);

        db.delete::<User>(
            DeleteBehavior::Restrict,
            Some(Filter::eq("id", Value::Uint32(Uint32(1)))),
        )
        .unwrap();
        db.commit().unwrap();

        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        let rows = db.select::<User>(Query::builder().build()).unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].id, Some(Uint32(2)));
    }

    // --- raw and join selects ---

    #[test]
    fn test_select_raw() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_user(&db, 1, "alice");

        let rows = db.select_raw("users", Query::builder().build()).unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0][0].1, Value::Uint32(Uint32(1)));
    }

    // Typed selects reject join queries.
    #[test]
    fn test_typed_select_with_join_returns_error() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);

        let query = Query::builder()
            .all()
            .inner_join("posts", "id", "user_id")
            .build();
        let result = db.select::<User>(query);
        assert!(result.is_err());
    }
}