Skip to main content

wasm_dbms/
database.rs

1// Rust guideline compliant 2026-03-01
2// X-WHERE-CLAUSE, M-CANONICAL-DOCS, M-PANIC-ON-BUG
3
4//! Core DBMS database struct providing CRUD and transaction operations.
5
6use std::cmp::Ordering;
7use std::collections::HashSet;
8
9use wasm_dbms_api::prelude::{
10    CandidColumnDef, ColumnDef, DataTypeKind, Database, DbmsError, DbmsResult, DeleteBehavior,
11    Filter, ForeignFetcher, ForeignKeyDef, InsertRecord, OrderDirection, Query, QueryError,
12    TableColumns, TableError, TableRecord, TableSchema, TransactionError, TransactionId,
13    UpdateRecord, Value, ValuesSource,
14};
15use wasm_dbms_memory::prelude::{
16    AccessControl, AccessControlList, MemoryProvider, NextRecord, TableRegistry,
17};
18
19use crate::context::DbmsContext;
20use crate::schema::DatabaseSchema;
21use crate::transaction::journal::{Journal, JournaledWriter};
22use crate::transaction::{DatabaseOverlay, Transaction, TransactionOp};
23
/// Default initial capacity of the result buffer allocated by SELECT queries.
const DEFAULT_SELECT_CAPACITY: usize = 128;
26
27/// The main DBMS database struct, generic over `MemoryProvider` and
28/// `AccessControl`.
29///
30/// This struct borrows from a [`DbmsContext`] and provides all CRUD
31/// operations, transaction management, and query execution.
32pub struct WasmDbmsDatabase<'ctx, M, A = AccessControlList>
33where
34    M: MemoryProvider,
35    A: AccessControl,
36{
37    /// Reference to the DBMS context owning all state.
38    ctx: &'ctx DbmsContext<M, A>,
39    /// Schema for dynamic dispatch of table operations.
40    schema: Box<dyn DatabaseSchema<M, A> + 'ctx>,
41    /// Active transaction ID, if any.
42    transaction: Option<TransactionId>,
43}
44
45impl<'ctx, M, A> WasmDbmsDatabase<'ctx, M, A>
46where
47    M: MemoryProvider,
48    A: AccessControl,
49{
50    /// Creates a one-shot (non-transactional) database instance.
51    pub fn oneshot(ctx: &'ctx DbmsContext<M, A>, schema: impl DatabaseSchema<M, A> + 'ctx) -> Self {
52        Self {
53            ctx,
54            schema: Box::new(schema),
55            transaction: None,
56        }
57    }
58
59    /// Creates a transactional database instance.
60    pub fn from_transaction(
61        ctx: &'ctx DbmsContext<M, A>,
62        schema: impl DatabaseSchema<M, A> + 'ctx,
63        transaction_id: TransactionId,
64    ) -> Self {
65        Self {
66            ctx,
67            schema: Box::new(schema),
68            transaction: Some(transaction_id),
69        }
70    }
71
72    /// Executes a closure with a mutable reference to the current transaction.
73    fn with_transaction_mut<F, R>(&self, f: F) -> DbmsResult<R>
74    where
75        F: FnOnce(&mut Transaction) -> DbmsResult<R>,
76    {
77        let txid = self.transaction.as_ref().ok_or(DbmsError::Transaction(
78            TransactionError::NoActiveTransaction,
79        ))?;
80
81        let mut ts = self.ctx.transaction_session.borrow_mut();
82        let tx = ts.get_transaction_mut(txid)?;
83        f(tx)
84    }
85
86    /// Executes a closure with a reference to the current transaction.
87    fn with_transaction<F, R>(&self, f: F) -> DbmsResult<R>
88    where
89        F: FnOnce(&Transaction) -> DbmsResult<R>,
90    {
91        let txid = self.transaction.as_ref().ok_or(DbmsError::Transaction(
92            TransactionError::NoActiveTransaction,
93        ))?;
94
95        let ts = self.ctx.transaction_session.borrow();
96        let tx = ts.get_transaction(txid)?;
97        f(tx)
98    }
99
100    /// Executes a closure atomically using a write-ahead journal.
101    ///
102    /// All writes performed inside `f` are recorded. On success the journal
103    /// is committed (entries discarded). On error the journal is rolled back,
104    /// restoring every modified byte to its pre-call state.
105    ///
106    /// When a journal is already active (e.g., inside [`Database::commit`]),
107    /// this method delegates to the outer journal and does not manage its own.
108    ///
109    /// # Panics
110    ///
111    /// Panics if the rollback itself fails, because a failed rollback leaves
112    /// memory in an irrecoverably corrupt state (M-PANIC-ON-BUG).
113    fn atomic<F, R>(&self, f: F) -> DbmsResult<R>
114    where
115        F: FnOnce(&WasmDbmsDatabase<'ctx, M, A>) -> DbmsResult<R>,
116    {
117        let nested = self.ctx.journal.borrow().is_some();
118        if !nested {
119            *self.ctx.journal.borrow_mut() = Some(Journal::new());
120        }
121        match f(self) {
122            Ok(res) => {
123                if !nested && let Some(journal) = self.ctx.journal.borrow_mut().take() {
124                    journal.commit();
125                }
126                Ok(res)
127            }
128            Err(err) => {
129                if !nested && let Some(journal) = self.ctx.journal.borrow_mut().take() {
130                    journal
131                        .rollback(&mut self.ctx.mm.borrow_mut())
132                        .expect("critical: failed to rollback journal");
133                }
134                Err(err)
135            }
136        }
137    }
138
139    /// Checks whether any foreign key references exist for the given record.
140    ///
141    /// Returns `true` if at least one referencing row exists in any table.
142    fn has_foreign_key_references<T>(
143        &self,
144        record_values: &[(ColumnDef, Value)],
145    ) -> DbmsResult<bool>
146    where
147        T: TableSchema,
148    {
149        let pk = Self::extract_pk::<T>(record_values)?;
150
151        for (table, columns) in self.schema.referenced_tables(T::table_name()) {
152            for column in columns.iter() {
153                let filter = Filter::eq(column, pk.clone());
154                let query = Query::builder().field(column).filter(Some(filter)).build();
155                let rows = self.schema.select(self, table, query)?;
156                if !rows.is_empty() {
157                    return Ok(true);
158                }
159            }
160        }
161        Ok(false)
162    }
163
164    /// Deletes foreign key related records recursively for cascade deletes.
165    fn delete_foreign_keys_cascade<T>(
166        &self,
167        record_values: &[(ColumnDef, Value)],
168    ) -> DbmsResult<u64>
169    where
170        T: TableSchema,
171    {
172        let pk = Self::extract_pk::<T>(record_values)?;
173
174        let mut count = 0;
175        for (table, columns) in self.schema.referenced_tables(T::table_name()) {
176            for column in columns.iter() {
177                let filter = Filter::eq(column, pk.clone());
178                let res = self
179                    .schema
180                    .delete(self, table, DeleteBehavior::Cascade, Some(filter))?;
181                count += res;
182            }
183        }
184        Ok(count)
185    }
186
187    /// Extracts the primary key value from a record's column-value pairs.
188    fn extract_pk<T>(record_values: &[(ColumnDef, Value)]) -> DbmsResult<Value>
189    where
190        T: TableSchema,
191    {
192        record_values
193            .iter()
194            .find(|(col_def, _)| col_def.primary_key)
195            .ok_or(DbmsError::Query(QueryError::UnknownColumn(
196                T::primary_key().to_string(),
197            )))
198            .map(|(_, v)| v.clone())
199    }
200
201    /// Retrieves the current overlay from the active transaction.
202    fn overlay(&self) -> DbmsResult<DatabaseOverlay> {
203        self.with_transaction(|tx| Ok(tx.overlay().clone()))
204    }
205
206    /// Returns whether the record matches the provided filter.
207    fn record_matches_filter(
208        &self,
209        record_values: &[(ColumnDef, Value)],
210        filter: &Filter,
211    ) -> DbmsResult<bool> {
212        filter.matches(record_values).map_err(DbmsError::from)
213    }
214
215    /// Filters record columns down to only the selected fields.
216    fn apply_column_selection<T>(&self, results: &mut [TableColumns], query: &Query)
217    where
218        T: TableSchema,
219    {
220        if query.all_selected() {
221            return;
222        }
223        let selected_columns = query.columns::<T>();
224        results
225            .iter_mut()
226            .flat_map(|record| record.iter_mut())
227            .filter(|(source, _)| *source == ValuesSource::This)
228            .for_each(|(_, cols)| {
229                cols.retain(|(col_def, _)| selected_columns.contains(&col_def.name.to_string()));
230            });
231    }
232
233    /// Batch-fetches eager relations for collected results.
234    fn batch_load_eager_relations<T>(
235        &self,
236        results: &mut [TableColumns],
237        query: &Query,
238    ) -> DbmsResult<()>
239    where
240        T: TableSchema,
241    {
242        if query.eager_relations.is_empty() {
243            return Ok(());
244        }
245
246        let fetcher = T::foreign_fetcher();
247
248        for relation in &query.eager_relations {
249            let fk_columns = Self::collect_fk_values::<T>(results, relation)?;
250
251            for (local_column, pk_values) in &fk_columns {
252                let batch_map = fetcher.fetch_batch(self, relation, pk_values)?;
253
254                Self::verify_fk_batch(&batch_map, pk_values, relation)?;
255                Self::attach_foreign_data(results, &batch_map, relation, local_column);
256            }
257        }
258
259        Ok(())
260    }
261
262    /// Collects distinct FK values across all records for a given relation.
263    fn collect_fk_values<T>(
264        results: &[TableColumns],
265        relation: &str,
266    ) -> DbmsResult<Vec<(&'static str, Vec<Value>)>>
267    where
268        T: TableSchema,
269    {
270        let mut fk_columns: Vec<(&'static str, HashSet<Value>)> = vec![];
271
272        for record_columns in results {
273            let Some(cols) = Self::this_columns(record_columns) else {
274                continue;
275            };
276
277            let mut found_fk = false;
278            for (col_def, value) in cols {
279                let Some(fk) = &col_def.foreign_key else {
280                    continue;
281                };
282                if *fk.foreign_table != *relation {
283                    continue;
284                }
285
286                found_fk = true;
287                match fk_columns.iter_mut().find(|(lc, _)| *lc == fk.local_column) {
288                    Some((_, values)) => {
289                        values.insert(value.clone());
290                    }
291                    None => {
292                        let mut set = HashSet::new();
293                        set.insert(value.clone());
294                        fk_columns.push((fk.local_column, set));
295                    }
296                }
297            }
298
299            if !found_fk {
300                return Err(DbmsError::Query(QueryError::InvalidQuery(format!(
301                    "Cannot load relation '{relation}' for table '{}': no foreign key found",
302                    T::table_name()
303                ))));
304            }
305        }
306
307        Ok(fk_columns
308            .into_iter()
309            .map(|(col, set)| (col, set.into_iter().collect()))
310            .collect())
311    }
312
313    /// Verifies all FK values were found in the batch result.
314    fn verify_fk_batch(
315        batch_map: &std::collections::HashMap<Value, Vec<(ColumnDef, Value)>>,
316        pk_values: &[Value],
317        relation: &str,
318    ) -> DbmsResult<()> {
319        if let Some(missing) = pk_values.iter().find(|v| !batch_map.contains_key(v)) {
320            return Err(DbmsError::Query(QueryError::BrokenForeignKeyReference {
321                table: relation.to_string(),
322                key: missing.clone(),
323            }));
324        }
325        Ok(())
326    }
327
328    /// Attaches batch-fetched foreign data to each record.
329    fn attach_foreign_data(
330        results: &mut [TableColumns],
331        batch_map: &std::collections::HashMap<Value, Vec<(ColumnDef, Value)>>,
332        relation: &str,
333        local_column: &str,
334    ) {
335        for record_columns in results.iter_mut() {
336            let fk_value = Self::this_columns(record_columns).and_then(|cols| {
337                cols.iter().find_map(|(col_def, value)| {
338                    let fk = col_def.foreign_key.as_ref()?;
339                    (fk.foreign_table == relation && fk.local_column == local_column)
340                        .then(|| value.clone())
341                })
342            });
343
344            let Some(fk_val) = fk_value else { continue };
345            let Some(foreign_values) = batch_map.get(&fk_val) else {
346                continue;
347            };
348
349            record_columns.push((
350                ValuesSource::Foreign {
351                    table: relation.to_string(),
352                    column: local_column.to_string(),
353                },
354                foreign_values.clone(),
355            ));
356        }
357    }
358
359    /// Extracts the `ValuesSource::This` columns from a record.
360    fn this_columns(
361        record: &[(ValuesSource, Vec<(ColumnDef, Value)>)],
362    ) -> Option<&Vec<(ColumnDef, Value)>> {
363        record
364            .iter()
365            .find(|(src, _)| *src == ValuesSource::This)
366            .map(|(_, cols)| cols)
367    }
368
369    /// Retrieves existing primary keys matching a filter.
370    fn existing_primary_keys_for_filter<T>(&self, filter: Option<Filter>) -> DbmsResult<Vec<Value>>
371    where
372        T: TableSchema,
373    {
374        let pk = T::primary_key();
375        let query = Query::builder().field(pk).filter(filter).build();
376        let fields = self.select::<T>(query)?;
377        let pks = fields
378            .into_iter()
379            .map(|record| {
380                record
381                    .to_values()
382                    .into_iter()
383                    .find(|(col_def, _value)| col_def.name == pk)
384                    .expect("primary key not found")
385                    .1
386            })
387            .collect::<Vec<Value>>();
388
389        Ok(pks)
390    }
391
392    /// Loads the table registry for a given table schema.
393    fn load_table_registry<T>(&self) -> DbmsResult<TableRegistry>
394    where
395        T: TableSchema,
396    {
397        let sr = self.ctx.schema_registry.borrow();
398        let registry_pages = sr
399            .table_registry_page::<T>()
400            .ok_or(DbmsError::Table(TableError::TableNotFound))?;
401
402        let mm = self.ctx.mm.borrow();
403        TableRegistry::load(registry_pages, &*mm).map_err(DbmsError::from)
404    }
405
406    /// Sorts query results by a column.
407    fn sort_query_results(
408        &self,
409        results: &mut [TableColumns],
410        column: &str,
411        direction: OrderDirection,
412    ) {
413        results.sort_by(|a, b| {
414            fn get_value<'a>(
415                values: &'a [(ValuesSource, Vec<(ColumnDef, Value)>)],
416                column: &str,
417            ) -> Option<&'a Value> {
418                values
419                    .iter()
420                    .find(|(source, _)| *source == ValuesSource::This)
421                    .and_then(|(_, cols)| {
422                        cols.iter()
423                            .find(|(col_def, _)| col_def.name == column)
424                            .map(|(_, value)| value)
425                    })
426            }
427
428            let a_value = get_value(a, column);
429            let b_value = get_value(b, column);
430
431            sort_values_with_direction(a_value, b_value, direction)
432        });
433    }
434
435    /// Core select logic returning intermediate `TableColumns`.
436    #[doc(hidden)]
437    pub fn select_columns<T>(&self, query: Query) -> DbmsResult<Vec<TableColumns>>
438    where
439        T: TableSchema,
440    {
441        let table_registry = self.load_table_registry::<T>()?;
442        let mut table_overlay = if self.transaction.is_some() {
443            self.overlay()?
444        } else {
445            DatabaseOverlay::default()
446        };
447
448        let mut results = Vec::with_capacity(query.limit.unwrap_or(DEFAULT_SELECT_CAPACITY));
449        let mut count = 0;
450
451        {
452            let mm = self.ctx.mm.borrow();
453            let table_reader = table_registry.read::<T, _>(&*mm);
454            let mut table_reader = table_overlay.reader(table_reader);
455
456            while let Some(values) = table_reader.try_next()? {
457                if let Some(filter) = &query.filter
458                    && !self.record_matches_filter(&values, filter)?
459                {
460                    continue;
461                }
462                count += 1;
463                if query.offset.is_some_and(|offset| count <= offset) {
464                    continue;
465                }
466                results.push(vec![(ValuesSource::This, values)]);
467                if query.limit.is_some_and(|limit| results.len() >= limit) {
468                    break;
469                }
470            }
471        }
472
473        self.batch_load_eager_relations::<T>(&mut results, &query)?;
474        self.apply_column_selection::<T>(&mut results, &query);
475
476        for (column, direction) in query.order_by.into_iter().rev() {
477            self.sort_query_results(&mut results, &column, direction);
478        }
479
480        Ok(results)
481    }
482
483    /// Executes a join query.
484    #[doc(hidden)]
485    pub fn select_join(
486        &self,
487        table: &str,
488        query: Query,
489    ) -> DbmsResult<Vec<Vec<(CandidColumnDef, Value)>>> {
490        self.schema.select_join(self, table, query)
491    }
492
493    /// Updates primary key references in tables referencing the updated table.
494    fn update_pk_referencing_updated_table<T>(
495        &self,
496        old_pk: Value,
497        new_pk: Value,
498        data_type: DataTypeKind,
499        pk_name: &'static str,
500    ) -> DbmsResult<u64>
501    where
502        T: TableSchema,
503    {
504        let mut count = 0;
505        for (ref_table, ref_col) in self
506            .schema
507            .referenced_tables(T::table_name())
508            .into_iter()
509            .flat_map(|(ref_table, ref_cols)| {
510                ref_cols
511                    .into_iter()
512                    .map(move |ref_col| (ref_table, ref_col))
513            })
514        {
515            let ref_patch_value = (
516                ColumnDef {
517                    name: ref_col,
518                    data_type,
519                    nullable: false,
520                    primary_key: false,
521                    foreign_key: Some(ForeignKeyDef {
522                        foreign_table: T::table_name(),
523                        foreign_column: pk_name,
524                        local_column: ref_col,
525                    }),
526                },
527                new_pk.clone(),
528            );
529            let filter = Filter::eq(ref_col, old_pk.clone());
530
531            count += self
532                .schema
533                .update(self, ref_table, &[ref_patch_value], Some(filter))?;
534        }
535
536        Ok(count)
537    }
538
539    /// Sanitizes values using the table schema's sanitizers.
540    fn sanitize_values<T>(
541        &self,
542        values: Vec<(ColumnDef, Value)>,
543    ) -> DbmsResult<Vec<(ColumnDef, Value)>>
544    where
545        T: TableSchema,
546    {
547        let mut sanitized_values = Vec::with_capacity(values.len());
548        for (col_def, value) in values.into_iter() {
549            let value = match T::sanitizer(col_def.name) {
550                Some(sanitizer) => sanitizer.sanitize(value)?,
551                None => value,
552            };
553            sanitized_values.push((col_def, value));
554        }
555        Ok(sanitized_values)
556    }
557
558    /// Collects all records matching a filter from the table registry.
559    #[allow(clippy::type_complexity)]
560    fn collect_matching_records<T>(
561        &self,
562        table_registry: &TableRegistry,
563        filter: &Option<Filter>,
564    ) -> DbmsResult<Vec<(NextRecord<T>, Vec<(ColumnDef, Value)>)>>
565    where
566        T: TableSchema,
567    {
568        let mm = self.ctx.mm.borrow();
569        let mut table_reader = table_registry.read::<T, _>(&*mm);
570        let mut records = vec![];
571        while let Some(values) = table_reader.try_next()? {
572            let record_values = values.record.clone().to_values();
573            if let Some(filter) = filter
574                && !self.record_matches_filter(&record_values, filter)?
575            {
576                continue;
577            }
578            records.push((values, record_values));
579        }
580        Ok(records)
581    }
582}
583
584/// Provides ordering for two optional values by direction.
585pub fn sort_values_with_direction(
586    a: Option<&Value>,
587    b: Option<&Value>,
588    direction: OrderDirection,
589) -> Ordering {
590    match (a, b) {
591        (Some(a_val), Some(b_val)) => match direction {
592            OrderDirection::Ascending => a_val.cmp(b_val),
593            OrderDirection::Descending => b_val.cmp(a_val),
594        },
595        (Some(_), None) => std::cmp::Ordering::Greater,
596        (None, Some(_)) => std::cmp::Ordering::Less,
597        (None, None) => std::cmp::Ordering::Equal,
598    }
599}
600
601/// Converts column-value pairs to a schema entity.
602fn values_to_schema_entity<T>(values: Vec<(ColumnDef, Value)>) -> DbmsResult<T>
603where
604    T: TableSchema,
605{
606    let record = T::Insert::from_values(&values)?.into_record();
607    Ok(record)
608}
609
610impl<M, A> Database for WasmDbmsDatabase<'_, M, A>
611where
612    M: MemoryProvider,
613    A: AccessControl,
614{
615    fn select<T>(&self, query: Query) -> DbmsResult<Vec<T::Record>>
616    where
617        T: TableSchema,
618    {
619        if !query.joins.is_empty() {
620            return Err(DbmsError::Query(QueryError::JoinInsideTypedSelect));
621        }
622        let results = self.select_columns::<T>(query)?;
623        Ok(results.into_iter().map(T::Record::from_values).collect())
624    }
625
626    fn select_raw(&self, table: &str, query: Query) -> DbmsResult<Vec<Vec<(ColumnDef, Value)>>> {
627        self.schema.select(self, table, query)
628    }
629
630    fn insert<T>(&self, record: T::Insert) -> DbmsResult<()>
631    where
632        T: TableSchema,
633        T::Insert: InsertRecord<Schema = T>,
634    {
635        let record_values = record.clone().into_values();
636        let sanitized_values = self.sanitize_values::<T>(record_values)?;
637        self.schema
638            .validate_insert(self, T::table_name(), &sanitized_values)?;
639        if self.transaction.is_some() {
640            self.with_transaction_mut(|tx| tx.insert::<T>(sanitized_values))?;
641        } else {
642            self.atomic(|db| {
643                let mut table_registry = db.load_table_registry::<T>()?;
644                let record = T::Insert::from_values(&sanitized_values)?;
645                let mut mm = db.ctx.mm.borrow_mut();
646                let mut journal_ref = db.ctx.journal.borrow_mut();
647                let journal = journal_ref
648                    .as_mut()
649                    .expect("journal must be active inside atomic");
650                let mut writer = JournaledWriter::new(&mut *mm, journal);
651                table_registry
652                    .insert(record.into_record(), &mut writer)
653                    .map_err(DbmsError::from)?;
654                Ok(())
655            })?;
656        }
657
658        Ok(())
659    }
660
661    fn update<T>(&self, patch: T::Update) -> DbmsResult<u64>
662    where
663        T: TableSchema,
664        T::Update: UpdateRecord<Schema = T>,
665    {
666        let filter = patch.where_clause().clone();
667        if self.transaction.is_some() {
668            let pks = self.existing_primary_keys_for_filter::<T>(filter.clone())?;
669            let count = pks.len() as u64;
670            self.with_transaction_mut(|tx| tx.update::<T>(patch, filter, pks))?;
671
672            return Ok(count);
673        }
674
675        let patch = patch.update_values();
676
677        let pk_in_patch = patch.iter().find_map(|(col_def, value)| {
678            if col_def.primary_key {
679                Some((col_def, value))
680            } else {
681                None
682            }
683        });
684
685        self.atomic(|db| {
686            let mut count = 0;
687
688            let mut table_registry = db.load_table_registry::<T>()?;
689            let records = db.collect_matching_records::<T>(&table_registry, &filter)?;
690
691            for (record, record_values) in records {
692                let current_pk_value = record_values
693                    .iter()
694                    .find(|(col_def, _)| col_def.primary_key)
695                    .expect("primary key not found")
696                    .1
697                    .clone();
698
699                let previous_record = values_to_schema_entity::<T>(record_values.clone())?;
700                let mut record_values = record_values;
701
702                for (patch_col_def, patch_value) in &patch {
703                    if let Some((_, record_value)) = record_values
704                        .iter_mut()
705                        .find(|(record_col_def, _)| record_col_def.name == patch_col_def.name)
706                    {
707                        *record_value = patch_value.clone();
708                    }
709                }
710                let record_values = db.sanitize_values::<T>(record_values)?;
711                db.schema.validate_update(
712                    db,
713                    T::table_name(),
714                    &record_values,
715                    current_pk_value.clone(),
716                )?;
717                let updated_record = values_to_schema_entity::<T>(record_values)?;
718                {
719                    let mut mm = db.ctx.mm.borrow_mut();
720                    let mut journal_ref = db.ctx.journal.borrow_mut();
721                    let journal = journal_ref
722                        .as_mut()
723                        .expect("journal must be active inside atomic");
724                    let mut writer = JournaledWriter::new(&mut *mm, journal);
725                    table_registry
726                        .update(
727                            updated_record,
728                            previous_record,
729                            record.page,
730                            record.offset,
731                            &mut writer,
732                        )
733                        .map_err(DbmsError::from)?;
734                }
735                count += 1;
736
737                if let Some((pk_column, new_pk_value)) = pk_in_patch {
738                    count += db.update_pk_referencing_updated_table::<T>(
739                        current_pk_value,
740                        new_pk_value.clone(),
741                        pk_column.data_type,
742                        pk_column.name,
743                    )?;
744                }
745            }
746
747            Ok(count)
748        })
749    }
750
751    fn delete<T>(&self, behaviour: DeleteBehavior, filter: Option<Filter>) -> DbmsResult<u64>
752    where
753        T: TableSchema,
754    {
755        if self.transaction.is_some() {
756            let pks = self.existing_primary_keys_for_filter::<T>(filter.clone())?;
757            let count = pks.len() as u64;
758
759            self.with_transaction_mut(|tx| tx.delete::<T>(behaviour, filter, pks))?;
760
761            return Ok(count);
762        }
763
764        self.atomic(|db| {
765            let mut table_registry = db.load_table_registry::<T>()?;
766            let records = db.collect_matching_records::<T>(&table_registry, &filter)?;
767            let mut count = records.len() as u64;
768            for (record, record_values) in records {
769                match behaviour {
770                    DeleteBehavior::Cascade => {
771                        count += db.delete_foreign_keys_cascade::<T>(&record_values)?;
772                    }
773                    DeleteBehavior::Restrict => {
774                        if db.has_foreign_key_references::<T>(&record_values)? {
775                            return Err(DbmsError::Query(
776                                QueryError::ForeignKeyConstraintViolation {
777                                    referencing_table: T::table_name().to_string(),
778                                    field: T::primary_key().to_string(),
779                                },
780                            ));
781                        }
782                    }
783                }
784                let mut mm = db.ctx.mm.borrow_mut();
785                let mut journal_ref = db.ctx.journal.borrow_mut();
786                let journal = journal_ref
787                    .as_mut()
788                    .expect("journal must be active inside atomic");
789                let mut writer = JournaledWriter::new(&mut *mm, journal);
790                table_registry
791                    .delete(record.record, record.page, record.offset, &mut writer)
792                    .map_err(DbmsError::from)?;
793            }
794
795            Ok(count)
796        })
797    }
798
799    fn commit(&mut self) -> DbmsResult<()> {
800        let Some(txid) = self.transaction.take() else {
801            return Err(DbmsError::Transaction(
802                TransactionError::NoActiveTransaction,
803            ));
804        };
805        let transaction = {
806            let mut ts = self.ctx.transaction_session.borrow_mut();
807            ts.take_transaction(&txid)?
808        };
809
810        *self.ctx.journal.borrow_mut() = Some(Journal::new());
811
812        for op in transaction.operations {
813            let result = match op {
814                TransactionOp::Insert { table, values } => self
815                    .schema
816                    .validate_insert(self, table, &values)
817                    .and_then(|()| self.schema.insert(self, table, &values)),
818                TransactionOp::Delete {
819                    table,
820                    behaviour,
821                    filter,
822                } => self
823                    .schema
824                    .delete(self, table, behaviour, filter)
825                    .map(|_| ()),
826                TransactionOp::Update {
827                    table,
828                    patch,
829                    filter,
830                } => self.schema.update(self, table, &patch, filter).map(|_| ()),
831            };
832
833            if let Err(err) = result {
834                if let Some(journal) = self.ctx.journal.borrow_mut().take() {
835                    journal
836                        .rollback(&mut self.ctx.mm.borrow_mut())
837                        .expect("critical: failed to rollback journal");
838                }
839                return Err(err);
840            }
841        }
842
843        if let Some(journal) = self.ctx.journal.borrow_mut().take() {
844            journal.commit();
845        }
846        Ok(())
847    }
848
849    fn rollback(&mut self) -> DbmsResult<()> {
850        let Some(txid) = self.transaction.take() else {
851            return Err(DbmsError::Transaction(
852                TransactionError::NoActiveTransaction,
853            ));
854        };
855
856        let mut ts = self.ctx.transaction_session.borrow_mut();
857        ts.close_transaction(&txid);
858        Ok(())
859    }
860}
861
#[cfg(test)]
mod tests {
    // Unit tests for `WasmDbmsDatabase`: value ordering, SELECT modifiers,
    // UPDATE/DELETE (including foreign-key behaviours), and transaction
    // commit/rollback.

    use std::cmp::Ordering;

    use wasm_dbms_api::prelude::{
        Database as _, DbmsError, DeleteBehavior, Filter, InsertRecord as _, OrderDirection,
        Query, TableSchema as _, Text, TransactionError, Uint32, UpdateRecord as _, Value,
    };
    use wasm_dbms_macros::{DatabaseSchema, Table};
    use wasm_dbms_memory::prelude::HeapMemoryProvider;

    use super::sort_values_with_direction;
    use crate::prelude::{DbmsContext, WasmDbmsDatabase};
    use crate::schema::DatabaseSchema as _;

    #[derive(Debug, Table, Clone, PartialEq, Eq)]
    #[table = "users"]
    pub struct User {
        #[primary_key]
        pub id: Uint32,
        pub name: Text,
    }

    #[derive(Debug, Table, Clone, PartialEq, Eq)]
    #[table = "posts"]
    pub struct Post {
        #[primary_key]
        pub id: Uint32,
        pub title: Text,
        #[foreign_key(entity = "User", table = "users", column = "id")]
        pub user_id: Uint32,
    }

    #[derive(DatabaseSchema)]
    #[tables(User = "users", Post = "posts")]
    pub struct TestSchema;

    /// Builds a heap-backed context with the `users` and `posts` tables registered.
    fn setup() -> DbmsContext<HeapMemoryProvider> {
        let ctx = DbmsContext::new(HeapMemoryProvider::default());
        TestSchema::register_tables(&ctx).unwrap();
        ctx
    }

    /// Inserts a `users` row, panicking if the insert fails.
    fn insert_user(db: &WasmDbmsDatabase<'_, HeapMemoryProvider>, id: u32, name: &str) {
        let insert = UserInsertRequest::from_values(&[
            (User::columns()[0], Value::Uint32(Uint32(id))),
            (User::columns()[1], Value::Text(Text(name.to_string()))),
        ])
        .unwrap();
        db.insert::<User>(insert).unwrap();
    }

    /// Inserts a `posts` row referencing `user_id`, panicking if the insert fails.
    fn insert_post(
        db: &WasmDbmsDatabase<'_, HeapMemoryProvider>,
        id: u32,
        title: &str,
        user_id: u32,
    ) {
        let insert = PostInsertRequest::from_values(&[
            (Post::columns()[0], Value::Uint32(Uint32(id))),
            (Post::columns()[1], Value::Text(Text(title.to_string()))),
            (Post::columns()[2], Value::Uint32(Uint32(user_id))),
        ])
        .unwrap();
        db.insert::<Post>(insert).unwrap();
    }

    // -- sort_values_with_direction tests --

    #[test]
    fn test_sort_values_ascending() {
        let a = Value::Uint32(Uint32(1));
        let b = Value::Uint32(Uint32(2));
        assert_eq!(
            sort_values_with_direction(Some(&a), Some(&b), OrderDirection::Ascending),
            Ordering::Less
        );
    }

    #[test]
    fn test_sort_values_descending() {
        let a = Value::Uint32(Uint32(1));
        let b = Value::Uint32(Uint32(2));
        assert_eq!(
            sort_values_with_direction(Some(&a), Some(&b), OrderDirection::Descending),
            Ordering::Greater
        );
    }

    #[test]
    fn test_sort_values_some_none() {
        let a = Value::Uint32(Uint32(1));
        assert_eq!(
            sort_values_with_direction(Some(&a), None, OrderDirection::Ascending),
            Ordering::Greater
        );
        assert_eq!(
            sort_values_with_direction(None, Some(&a), OrderDirection::Ascending),
            Ordering::Less
        );
    }

    #[test]
    fn test_sort_values_none_none() {
        assert_eq!(
            sort_values_with_direction(None, None, OrderDirection::Ascending),
            Ordering::Equal
        );
    }

    // -- select with ordering --

    #[test]
    fn test_select_with_order_by_ascending() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_user(&db, 3, "charlie");
        insert_user(&db, 1, "alice");
        insert_user(&db, 2, "bob");

        let rows = db
            .select::<User>(Query::builder().all().order_by_asc("name").build())
            .unwrap();
        assert_eq!(rows.len(), 3);
        assert_eq!(rows[0].name, Some(Text("alice".to_string())));
        assert_eq!(rows[1].name, Some(Text("bob".to_string())));
        assert_eq!(rows[2].name, Some(Text("charlie".to_string())));
    }

    #[test]
    fn test_select_with_order_by_descending() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_user(&db, 1, "alice");
        insert_user(&db, 2, "bob");
        insert_user(&db, 3, "charlie");

        let rows = db
            .select::<User>(Query::builder().all().order_by_desc("name").build())
            .unwrap();
        assert_eq!(rows.len(), 3);
        assert_eq!(rows[0].name, Some(Text("charlie".to_string())));
        assert_eq!(rows[1].name, Some(Text("bob".to_string())));
        assert_eq!(rows[2].name, Some(Text("alice".to_string())));
    }

    // -- select with offset and limit --

    #[test]
    fn test_select_with_limit() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_user(&db, 1, "alice");
        insert_user(&db, 2, "bob");
        insert_user(&db, 3, "charlie");

        let rows = db
            .select::<User>(Query::builder().all().limit(2).build())
            .unwrap();
        assert_eq!(rows.len(), 2);
    }

    #[test]
    fn test_select_with_offset() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_user(&db, 1, "alice");
        insert_user(&db, 2, "bob");
        insert_user(&db, 3, "charlie");

        let rows = db
            .select::<User>(Query::builder().all().offset(1).build())
            .unwrap();
        assert_eq!(rows.len(), 2);
    }

    #[test]
    fn test_select_with_offset_and_limit() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_user(&db, 1, "alice");
        insert_user(&db, 2, "bob");
        insert_user(&db, 3, "charlie");

        let rows = db
            .select::<User>(Query::builder().all().offset(1).limit(1).build())
            .unwrap();
        assert_eq!(rows.len(), 1);
    }

    // -- select with filter --

    #[test]
    fn test_select_with_filter() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_user(&db, 1, "alice");
        insert_user(&db, 2, "bob");

        let rows = db
            .select::<User>(
                Query::builder()
                    .all()
                    .and_where(Filter::eq("name", Value::Text(Text("alice".to_string()))))
                    .build(),
            )
            .unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].name, Some(Text("alice".to_string())));
    }

    // -- select with column selection --

    #[test]
    fn test_select_with_column_selection() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_user(&db, 1, "alice");

        let rows = TestSchema
            .select(&db, "users", Query::builder().field("name").build())
            .unwrap();
        assert_eq!(rows.len(), 1);
        // Should only have the "name" column
        assert_eq!(rows[0].len(), 1);
        assert_eq!(rows[0][0].0.name, "name");
    }

    // -- update operations --

    #[test]
    fn test_update_record() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_user(&db, 1, "alice");

        let patch = UserUpdateRequest::from_values(
            &[(User::columns()[1], Value::Text(Text("alicia".to_string())))],
            Some(Filter::eq("id", Value::Uint32(Uint32(1)))),
        );
        let count = db.update::<User>(patch).unwrap();
        assert_eq!(count, 1);

        let rows = db.select::<User>(Query::builder().build()).unwrap();
        assert_eq!(rows[0].name, Some(Text("alicia".to_string())));
    }

    #[test]
    fn test_update_no_matching_records() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_user(&db, 1, "alice");

        let patch = UserUpdateRequest::from_values(
            &[(User::columns()[1], Value::Text(Text("bob".to_string())))],
            Some(Filter::eq("id", Value::Uint32(Uint32(999)))),
        );
        let count = db.update::<User>(patch).unwrap();
        assert_eq!(count, 0);
    }

    // -- delete operations --

    #[test]
    fn test_delete_with_filter() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_user(&db, 1, "alice");
        insert_user(&db, 2, "bob");

        let count = db
            .delete::<User>(
                DeleteBehavior::Restrict,
                Some(Filter::eq("id", Value::Uint32(Uint32(1)))),
            )
            .unwrap();
        assert_eq!(count, 1);

        let rows = db.select::<User>(Query::builder().build()).unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].id, Some(Uint32(2)));
    }

    #[test]
    fn test_delete_restrict_with_fk_reference_fails() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_user(&db, 1, "alice");
        insert_post(&db, 10, "post1", 1);

        let result = db.delete::<User>(DeleteBehavior::Restrict, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_delete_cascade_removes_referencing_records() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_user(&db, 1, "alice");
        insert_post(&db, 10, "post1", 1);

        let count = db.delete::<User>(DeleteBehavior::Cascade, None).unwrap();
        // 1 user + 1 cascaded post
        assert_eq!(count, 2);

        let users = db.select::<User>(Query::builder().build()).unwrap();
        assert!(users.is_empty());
        let posts = db.select::<Post>(Query::builder().build()).unwrap();
        assert!(posts.is_empty());
    }

    // -- commit without transaction --

    #[test]
    fn test_commit_without_transaction_returns_error() {
        let ctx = setup();
        let mut db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        let result = db.commit();
        assert!(result.is_err());
    }

    // -- transaction commit with update --

    #[test]
    fn test_transaction_update_and_commit() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_user(&db, 1, "alice");

        let owner = vec![1, 2, 3];
        let tx_id = ctx.begin_transaction(owner);
        let mut db = WasmDbmsDatabase::from_transaction(&ctx, TestSchema, tx_id);

        let patch = UserUpdateRequest::from_values(
            &[(User::columns()[1], Value::Text(Text("alicia".to_string())))],
            Some(Filter::eq("id", Value::Uint32(Uint32(1)))),
        );
        db.update::<User>(patch).unwrap();
        db.commit().unwrap();

        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        let rows = db.select::<User>(Query::builder().build()).unwrap();
        assert_eq!(rows[0].name, Some(Text("alicia".to_string())));
    }

    // -- transaction delete and commit --

    #[test]
    fn test_transaction_delete_and_commit() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_user(&db, 1, "alice");
        insert_user(&db, 2, "bob");

        let owner = vec![1, 2, 3];
        let tx_id = ctx.begin_transaction(owner);
        let mut db = WasmDbmsDatabase::from_transaction(&ctx, TestSchema, tx_id);

        db.delete::<User>(
            DeleteBehavior::Restrict,
            Some(Filter::eq("id", Value::Uint32(Uint32(1)))),
        )
        .unwrap();
        db.commit().unwrap();

        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        let rows = db.select::<User>(Query::builder().build()).unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].id, Some(Uint32(2)));
    }

    // -- rollback --

    #[test]
    fn test_rollback_without_transaction_returns_error() {
        let ctx = setup();
        let mut db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        let result = db.rollback();
        assert!(matches!(
            result,
            Err(DbmsError::Transaction(
                TransactionError::NoActiveTransaction
            ))
        ));
    }

    #[test]
    fn test_transaction_update_and_rollback_discards_changes() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_user(&db, 1, "alice");

        let owner = vec![1, 2, 3];
        let tx_id = ctx.begin_transaction(owner);
        let mut db = WasmDbmsDatabase::from_transaction(&ctx, TestSchema, tx_id);

        let patch = UserUpdateRequest::from_values(
            &[(User::columns()[1], Value::Text(Text("alicia".to_string())))],
            Some(Filter::eq("id", Value::Uint32(Uint32(1)))),
        );
        db.update::<User>(patch).unwrap();
        db.rollback().unwrap();

        // Rollback detaches the transaction, so a second rollback has nothing to act on.
        assert!(matches!(
            db.rollback(),
            Err(DbmsError::Transaction(
                TransactionError::NoActiveTransaction
            ))
        ));

        // The staged update must never have reached storage.
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        let rows = db.select::<User>(Query::builder().build()).unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].name, Some(Text("alice".to_string())));
    }

    // -- select_raw --

    #[test]
    fn test_select_raw() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_user(&db, 1, "alice");

        let rows = db.select_raw("users", Query::builder().build()).unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0][0].1, Value::Uint32(Uint32(1)));
    }

    // -- select with join returns error on typed select --

    #[test]
    fn test_typed_select_with_join_returns_error() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);

        let query = Query::builder()
            .all()
            .inner_join("posts", "id", "user_id")
            .build();
        let result = db.select::<User>(query);
        assert!(result.is_err());
    }
}