Skip to main content

wasm_dbms/
database.rs

1// Rust guideline compliant 2026-03-01
2// X-WHERE-CLAUSE, M-CANONICAL-DOCS, M-PANIC-ON-BUG
3
4//! Core DBMS database struct providing CRUD and transaction operations.
5
6mod filter_analyzer;
7mod index_reader;
8
9use std::cmp::Ordering;
10use std::collections::HashSet;
11
12use wasm_dbms_api::prelude::{
13    ColumnDef, DataTypeKind, Database, DbmsError, DbmsResult, DeleteBehavior, Filter,
14    ForeignFetcher, ForeignKeyDef, InsertRecord, JoinColumnDef, OrderDirection, Query, QueryError,
15    TableColumns, TableError, TableRecord, TableSchema, TransactionError, TransactionId,
16    UpdateRecord, Value, ValuesSource,
17};
18use wasm_dbms_memory::RecordAddress;
19use wasm_dbms_memory::prelude::{
20    AccessControl, AccessControlList, MemoryAccess, MemoryProvider, NextRecord, TableRegistry,
21};
22
23use self::filter_analyzer::{IndexPlan, analyze_filter};
24use self::index_reader::{IndexReader, IndexSearchResult};
25use crate::context::DbmsContext;
26use crate::schema::DatabaseSchema;
27use crate::transaction::journal::{Journal, JournaledWriter};
28use crate::transaction::{DatabaseOverlay, Transaction, TransactionOp};
29
/// Initial number of rows reserved in the result buffer of a SELECT query.
const DEFAULT_SELECT_CAPACITY: usize = 128;
32
33/// The main DBMS database struct, generic over `MemoryProvider` and
34/// `AccessControl`.
35///
36/// This struct borrows from a [`DbmsContext`] and provides all CRUD
37/// operations, transaction management, and query execution.
38pub struct WasmDbmsDatabase<'ctx, M, A = AccessControlList>
39where
40    M: MemoryProvider,
41    A: AccessControl,
42{
43    /// Reference to the DBMS context owning all state.
44    ctx: &'ctx DbmsContext<M, A>,
45    /// Schema for dynamic dispatch of table operations.
46    schema: Box<dyn DatabaseSchema<M, A> + 'ctx>,
47    /// Active transaction ID, if any.
48    transaction: Option<TransactionId>,
49}
50
51impl<'ctx, M, A> WasmDbmsDatabase<'ctx, M, A>
52where
53    M: MemoryProvider,
54    A: AccessControl,
55{
56    /// Creates a one-shot (non-transactional) database instance.
57    pub fn oneshot(ctx: &'ctx DbmsContext<M, A>, schema: impl DatabaseSchema<M, A> + 'ctx) -> Self {
58        Self {
59            ctx,
60            schema: Box::new(schema),
61            transaction: None,
62        }
63    }
64
65    /// Creates a transactional database instance.
66    pub fn from_transaction(
67        ctx: &'ctx DbmsContext<M, A>,
68        schema: impl DatabaseSchema<M, A> + 'ctx,
69        transaction_id: TransactionId,
70    ) -> Self {
71        Self {
72            ctx,
73            schema: Box::new(schema),
74            transaction: Some(transaction_id),
75        }
76    }
77
78    /// Executes a closure with a mutable reference to the current transaction.
79    fn with_transaction_mut<F, R>(&self, f: F) -> DbmsResult<R>
80    where
81        F: FnOnce(&mut Transaction) -> DbmsResult<R>,
82    {
83        let txid = self.transaction.as_ref().ok_or(DbmsError::Transaction(
84            TransactionError::NoActiveTransaction,
85        ))?;
86
87        let mut ts = self.ctx.transaction_session.borrow_mut();
88        let tx = ts.get_transaction_mut(txid)?;
89        f(tx)
90    }
91
92    /// Executes a closure with a reference to the current transaction.
93    fn with_transaction<F, R>(&self, f: F) -> DbmsResult<R>
94    where
95        F: FnOnce(&Transaction) -> DbmsResult<R>,
96    {
97        let txid = self.transaction.as_ref().ok_or(DbmsError::Transaction(
98            TransactionError::NoActiveTransaction,
99        ))?;
100
101        let ts = self.ctx.transaction_session.borrow();
102        let tx = ts.get_transaction(txid)?;
103        f(tx)
104    }
105
106    /// Executes a closure atomically using a write-ahead journal.
107    ///
108    /// All writes performed inside `f` are recorded. On success the journal
109    /// is committed (entries discarded). On error the journal is rolled back,
110    /// restoring every modified byte to its pre-call state.
111    ///
112    /// When a journal is already active (e.g., inside [`Database::commit`]),
113    /// this method delegates to the outer journal and does not manage its own.
114    ///
115    /// # Panics
116    ///
117    /// Panics if the rollback itself fails, because a failed rollback leaves
118    /// memory in an irrecoverably corrupt state (M-PANIC-ON-BUG).
119    fn atomic<F, R>(&self, f: F) -> DbmsResult<R>
120    where
121        F: FnOnce(&WasmDbmsDatabase<'ctx, M, A>) -> DbmsResult<R>,
122    {
123        let nested = self.ctx.journal.borrow().is_some();
124        if !nested {
125            *self.ctx.journal.borrow_mut() = Some(Journal::new());
126        }
127        match f(self) {
128            Ok(res) => {
129                if !nested && let Some(journal) = self.ctx.journal.borrow_mut().take() {
130                    journal.commit();
131                }
132                Ok(res)
133            }
134            Err(err) => {
135                if !nested && let Some(journal) = self.ctx.journal.borrow_mut().take() {
136                    journal
137                        .rollback(&mut self.ctx.mm.borrow_mut())
138                        .expect("critical: failed to rollback journal");
139                }
140                Err(err)
141            }
142        }
143    }
144
145    /// Checks whether any foreign key references exist for the given record.
146    ///
147    /// Returns `true` if at least one referencing row exists in any table.
148    fn has_foreign_key_references<T>(
149        &self,
150        record_values: &[(ColumnDef, Value)],
151    ) -> DbmsResult<bool>
152    where
153        T: TableSchema,
154    {
155        let pk = Self::extract_pk::<T>(record_values)?;
156
157        for (table, columns) in self.schema.referenced_tables(T::table_name()) {
158            for column in columns.iter() {
159                let filter = Filter::eq(column, pk.clone());
160                let query = Query::builder().field(column).filter(Some(filter)).build();
161                let rows = self.schema.select(self, table, query)?;
162                if !rows.is_empty() {
163                    return Ok(true);
164                }
165            }
166        }
167        Ok(false)
168    }
169
170    /// Deletes foreign key related records recursively for cascade deletes.
171    fn delete_foreign_keys_cascade<T>(
172        &self,
173        record_values: &[(ColumnDef, Value)],
174    ) -> DbmsResult<u64>
175    where
176        T: TableSchema,
177    {
178        let pk = Self::extract_pk::<T>(record_values)?;
179
180        let mut count = 0;
181        for (table, columns) in self.schema.referenced_tables(T::table_name()) {
182            for column in columns.iter() {
183                let filter = Filter::eq(column, pk.clone());
184                let res = self
185                    .schema
186                    .delete(self, table, DeleteBehavior::Cascade, Some(filter))?;
187                count += res;
188            }
189        }
190        Ok(count)
191    }
192
193    /// Extracts the primary key value from a record's column-value pairs.
194    fn extract_pk<T>(record_values: &[(ColumnDef, Value)]) -> DbmsResult<Value>
195    where
196        T: TableSchema,
197    {
198        record_values
199            .iter()
200            .find(|(col_def, _)| col_def.primary_key)
201            .ok_or(DbmsError::Query(QueryError::UnknownColumn(
202                T::primary_key().to_string(),
203            )))
204            .map(|(_, v)| v.clone())
205    }
206
207    /// Retrieves the current overlay from the active transaction.
208    fn overlay(&self) -> DbmsResult<DatabaseOverlay> {
209        self.with_transaction(|tx| Ok(tx.overlay().clone()))
210    }
211
212    /// Returns whether the record matches the provided filter.
213    fn record_matches_filter(
214        &self,
215        record_values: &[(ColumnDef, Value)],
216        filter: &Filter,
217    ) -> DbmsResult<bool> {
218        filter.matches(record_values).map_err(DbmsError::from)
219    }
220
221    /// Filters record columns down to only the selected fields.
222    fn apply_column_selection<T>(&self, results: &mut [TableColumns], query: &Query)
223    where
224        T: TableSchema,
225    {
226        if query.all_selected() {
227            return;
228        }
229        let selected_columns = query.columns::<T>();
230        results
231            .iter_mut()
232            .flat_map(|record| record.iter_mut())
233            .filter(|(source, _)| *source == ValuesSource::This)
234            .for_each(|(_, cols)| {
235                cols.retain(|(col_def, _)| selected_columns.contains(&col_def.name.to_string()));
236            });
237    }
238
239    /// Batch-fetches eager relations for collected results.
240    fn batch_load_eager_relations<T>(
241        &self,
242        results: &mut [TableColumns],
243        query: &Query,
244    ) -> DbmsResult<()>
245    where
246        T: TableSchema,
247    {
248        if query.eager_relations.is_empty() {
249            return Ok(());
250        }
251
252        let fetcher = T::foreign_fetcher();
253
254        for relation in &query.eager_relations {
255            let fk_columns = Self::collect_fk_values::<T>(results, relation)?;
256
257            for (local_column, pk_values) in &fk_columns {
258                let batch_map = fetcher.fetch_batch(self, relation, pk_values)?;
259
260                Self::verify_fk_batch(&batch_map, pk_values, relation)?;
261                Self::attach_foreign_data(results, &batch_map, relation, local_column);
262            }
263        }
264
265        Ok(())
266    }
267
268    /// Collects distinct FK values across all records for a given relation.
269    fn collect_fk_values<T>(
270        results: &[TableColumns],
271        relation: &str,
272    ) -> DbmsResult<Vec<(&'static str, Vec<Value>)>>
273    where
274        T: TableSchema,
275    {
276        let mut fk_columns: Vec<(&'static str, HashSet<Value>)> = vec![];
277
278        for record_columns in results {
279            let Some(cols) = Self::this_columns(record_columns) else {
280                continue;
281            };
282
283            let mut found_fk = false;
284            for (col_def, value) in cols {
285                let Some(fk) = &col_def.foreign_key else {
286                    continue;
287                };
288                if *fk.foreign_table != *relation {
289                    continue;
290                }
291
292                found_fk = true;
293                match fk_columns.iter_mut().find(|(lc, _)| *lc == fk.local_column) {
294                    Some((_, values)) => {
295                        values.insert(value.clone());
296                    }
297                    None => {
298                        let mut set = HashSet::new();
299                        set.insert(value.clone());
300                        fk_columns.push((fk.local_column, set));
301                    }
302                }
303            }
304
305            if !found_fk {
306                return Err(DbmsError::Query(QueryError::InvalidQuery(format!(
307                    "Cannot load relation '{relation}' for table '{}': no foreign key found",
308                    T::table_name()
309                ))));
310            }
311        }
312
313        Ok(fk_columns
314            .into_iter()
315            .map(|(col, set)| (col, set.into_iter().collect()))
316            .collect())
317    }
318
319    /// Verifies all FK values were found in the batch result.
320    fn verify_fk_batch(
321        batch_map: &std::collections::HashMap<Value, Vec<(ColumnDef, Value)>>,
322        pk_values: &[Value],
323        relation: &str,
324    ) -> DbmsResult<()> {
325        if let Some(missing) = pk_values.iter().find(|v| !batch_map.contains_key(v)) {
326            return Err(DbmsError::Query(QueryError::BrokenForeignKeyReference {
327                table: relation.to_string(),
328                key: missing.clone(),
329            }));
330        }
331        Ok(())
332    }
333
334    /// Attaches batch-fetched foreign data to each record.
335    fn attach_foreign_data(
336        results: &mut [TableColumns],
337        batch_map: &std::collections::HashMap<Value, Vec<(ColumnDef, Value)>>,
338        relation: &str,
339        local_column: &str,
340    ) {
341        for record_columns in results.iter_mut() {
342            let fk_value = Self::this_columns(record_columns).and_then(|cols| {
343                cols.iter().find_map(|(col_def, value)| {
344                    let fk = col_def.foreign_key.as_ref()?;
345                    (fk.foreign_table == relation && fk.local_column == local_column)
346                        .then(|| value.clone())
347                })
348            });
349
350            let Some(fk_val) = fk_value else { continue };
351            let Some(foreign_values) = batch_map.get(&fk_val) else {
352                continue;
353            };
354
355            record_columns.push((
356                ValuesSource::Foreign {
357                    table: relation.to_string(),
358                    column: local_column.to_string(),
359                },
360                foreign_values.clone(),
361            ));
362        }
363    }
364
365    /// Extracts the `ValuesSource::This` columns from a record.
366    fn this_columns(
367        record: &[(ValuesSource, Vec<(ColumnDef, Value)>)],
368    ) -> Option<&Vec<(ColumnDef, Value)>> {
369        record
370            .iter()
371            .find(|(src, _)| *src == ValuesSource::This)
372            .map(|(_, cols)| cols)
373    }
374
375    /// Retrieves existing rows matching a filter, returning `(primary_key, full_row)` pairs.
376    #[expect(
377        clippy::type_complexity,
378        reason = "complex return type is necessary for returning both PK and full row data"
379    )]
380    fn existing_rows_for_filter<T>(
381        &self,
382        filter: Option<Filter>,
383    ) -> DbmsResult<Vec<(Value, Vec<(ColumnDef, Value)>)>>
384    where
385        T: TableSchema,
386    {
387        let pk = T::primary_key();
388        let query = Query::builder().filter(filter).build();
389        let records = self.select::<T>(query)?;
390        let rows = records
391            .into_iter()
392            .map(|record| {
393                let values = record.to_values();
394                let pk_value = values
395                    .iter()
396                    .find(|(col_def, _)| col_def.name == pk)
397                    .expect("primary key not found")
398                    .1
399                    .clone();
400                (pk_value, values)
401            })
402            .collect();
403
404        Ok(rows)
405    }
406
407    /// Loads the table registry for a given table schema.
408    fn load_table_registry<T>(&self) -> DbmsResult<TableRegistry>
409    where
410        T: TableSchema,
411    {
412        let sr = self.ctx.schema_registry.borrow();
413        let registry_pages = sr
414            .table_registry_page::<T>()
415            .ok_or(DbmsError::Table(TableError::TableNotFound))?;
416
417        let mut mm = self.ctx.mm.borrow_mut();
418        TableRegistry::load(registry_pages, &mut *mm).map_err(DbmsError::from)
419    }
420
421    /// Sorts query results by a column.
422    fn sort_query_results(
423        &self,
424        results: &mut [TableColumns],
425        column: &str,
426        direction: OrderDirection,
427    ) {
428        results.sort_by(|a, b| {
429            fn get_value<'a>(
430                values: &'a [(ValuesSource, Vec<(ColumnDef, Value)>)],
431                column: &str,
432            ) -> Option<&'a Value> {
433                values
434                    .iter()
435                    .find(|(source, _)| *source == ValuesSource::This)
436                    .and_then(|(_, cols)| {
437                        cols.iter()
438                            .find(|(col_def, _)| col_def.name == column)
439                            .map(|(_, value)| value)
440                    })
441            }
442
443            let a_value = get_value(a, column);
444            let b_value = get_value(b, column);
445
446            sort_values_with_direction(a_value, b_value, direction)
447        });
448    }
449
450    fn execute_index_plan<MA>(
451        &self,
452        reader: &IndexReader<'_>,
453        plan: &IndexPlan,
454        mm: &mut MA,
455    ) -> DbmsResult<IndexSearchResult>
456    where
457        MA: MemoryAccess,
458    {
459        let columns = [plan.column()];
460        match plan {
461            IndexPlan::Eq { value, .. } => {
462                let key = [value.clone()];
463                reader
464                    .search_eq(&columns, &key, mm)
465                    .map_err(DbmsError::from)
466            }
467            IndexPlan::Range { start, end, .. } => {
468                let start_key = start.as_ref().map(|value| vec![value.clone()]);
469                let end_key = end.as_ref().map(|value| vec![value.clone()]);
470                reader
471                    .search_range(&columns, start_key.as_deref(), end_key.as_deref(), mm)
472                    .map_err(DbmsError::from)
473            }
474            IndexPlan::In { values, .. } => {
475                let keys: Vec<Vec<Value>> =
476                    values.iter().cloned().map(|value| vec![value]).collect();
477                reader
478                    .search_in(&columns, &keys, mm)
479                    .map_err(DbmsError::from)
480            }
481        }
482    }
483
    /// Tries to answer `query` through one of `T`'s indexes.
    ///
    /// Returns `Ok(None)` when the query has no filter, or when `analyze_filter`
    /// finds no index plan for it; the caller then falls back to a full scan.
    /// Otherwise returns `Ok(Some(rows))`, where `rows` is assembled from:
    ///
    /// 1. committed records at the addresses found by the index search,
    ///    skipping primary keys listed in the search result's `removed_pks`
    ///    or `overlay_pks`;
    /// 2. rows inserted in the table's transaction overlay whose primary key
    ///    is in `overlay_pks`;
    /// 3. any `overlay_pks` left over after step 2. These are looked up through
    ///    the primary-key index on committed storage and patched with the
    ///    overlay; rows for which `patch_row` returns `None` are skipped.
    ///
    /// Every candidate row must also satisfy `analyzed.remaining_filter`, when
    /// present. This is presumably the part of the filter the index plan does
    /// not cover.
    ///
    /// # Errors
    ///
    /// Propagates errors from the index searches, record reads, and filter
    /// evaluation.
    // NOTE(review): the `reason` below mentions addresses and overlay PKs, but
    // the function returns full rows; consider updating its wording.
    #[expect(
        clippy::type_complexity,
        reason = "complex return type is necessary for returning addresses and overlay PKs"
    )]
    fn try_index_select<T>(
        &self,
        query: &Query,
        table_registry: &TableRegistry,
        table_overlay: &DatabaseOverlay,
    ) -> DbmsResult<Option<Vec<Vec<(ColumnDef, Value)>>>>
    where
        T: TableSchema,
    {
        let Some(filter) = &query.filter else {
            return Ok(None);
        };

        let Some(analyzed) = analyze_filter(filter, T::indexes()) else {
            return Ok(None);
        };

        // The memory borrow is held for the rest of the function; every index
        // search and record read below goes through it.
        let mut mm = self.ctx.mm.borrow_mut();
        // Search the committed index ledger together with this table's index
        // overlay, if the transaction has one.
        let reader = IndexReader::new(
            table_registry.index_ledger(),
            table_overlay.index_overlay(T::table_name()),
        );
        let search_result = self.execute_index_plan(&reader, &analyzed.plan, &mut *mm)?;

        let mut indexed_rows = Vec::new();
        let pk_name = T::primary_key();

        // Step 1: committed records found through the index.
        for address in &search_result.addresses {
            let record: T = table_registry
                .read_at(*address, &mut *mm)
                .map_err(DbmsError::from)?;
            let values = record.to_values();
            // Records without a primary-key column are skipped.
            let Some(pk) = values
                .iter()
                .find(|(column, _)| column.name == pk_name)
                .map(|(_, value)| value)
            else {
                continue;
            };

            // Keys reported as removed, or as tracked by the overlay, must not be
            // served from committed storage; overlay-tracked keys are handled below.
            if search_result.removed_pks.contains(pk) || search_result.overlay_pks.contains(pk) {
                continue;
            }

            if let Some(remaining_filter) = &analyzed.remaining_filter
                && !self.record_matches_filter(&values, remaining_filter)?
            {
                continue;
            }

            indexed_rows.push(values);
        }

        if let Some(overlay) = table_overlay.table_overlay(T::table_name()) {
            let mut pending_overlay_pks = search_result.overlay_pks.clone();

            // Step 2: rows inserted in the overlay.
            for row in overlay.iter_inserted() {
                let Some(pk) = row
                    .iter()
                    .find(|(column, _)| column.name == pk_name)
                    .map(|(_, value)| value)
                else {
                    continue;
                };

                // Only rows whose key the index search matched are candidates.
                // Removing the key also marks it as served, so step 3 skips it.
                if !pending_overlay_pks.remove(pk) {
                    continue;
                }
                if let Some(remaining_filter) = &analyzed.remaining_filter
                    && !self.record_matches_filter(&row, remaining_filter)?
                {
                    continue;
                }

                indexed_rows.push(row);
            }

            // Step 3: matched overlay keys with no inserted row. Read the committed
            // record via the primary-key index (without the overlay), then apply
            // the overlay's changes to it.
            // NOTE(review): assumes the primary key has an entry in the index
            // ledger; if it does not, these rows would presumably be missed — confirm.
            if !pending_overlay_pks.is_empty() {
                let pk_reader = IndexReader::new(table_registry.index_ledger(), None);
                let pk_columns = [T::primary_key()];

                for pk in pending_overlay_pks {
                    let pk_key = [pk];
                    let pk_lookup = pk_reader.search_eq(&pk_columns, &pk_key, &mut *mm)?;
                    for address in pk_lookup.addresses {
                        let record: T = table_registry
                            .read_at(address, &mut *mm)
                            .map_err(DbmsError::from)?;
                        let values = record.to_values();
                        // NOTE(review): presumably `None` means the overlay deleted
                        // this row — confirm against `patch_row`.
                        let Some(patched_values) = overlay.patch_row(values) else {
                            continue;
                        };

                        if let Some(remaining_filter) = &analyzed.remaining_filter
                            && !self.record_matches_filter(&patched_values, remaining_filter)?
                        {
                            continue;
                        }

                        indexed_rows.push(patched_values);
                    }
                }
            }
        }

        Ok(Some(indexed_rows))
    }
595
596    /// Core select logic returning intermediate `TableColumns`.
597    #[doc(hidden)]
598    pub fn select_columns<T>(&self, query: Query) -> DbmsResult<Vec<TableColumns>>
599    where
600        T: TableSchema,
601    {
602        let table_registry = self.load_table_registry::<T>()?;
603        let mut table_overlay = if self.transaction.is_some() {
604            self.overlay()?
605        } else {
606            DatabaseOverlay::default()
607        };
608
609        let mut results = Vec::with_capacity(query.limit.unwrap_or(DEFAULT_SELECT_CAPACITY));
610        // When ORDER BY is present, LIMIT and OFFSET must be applied after sorting
611        // to comply with standard SQL semantics (ORDER BY -> OFFSET -> LIMIT).
612        let has_order_by = !query.order_by.is_empty();
613        let mut count = 0;
614
615        if let Some(indexed_rows) =
616            self.try_index_select::<T>(&query, &table_registry, &table_overlay)?
617        {
618            for values in indexed_rows {
619                if !has_order_by {
620                    count += 1;
621                    if query.offset.is_some_and(|offset| count <= offset) {
622                        continue;
623                    }
624                }
625                results.push(vec![(ValuesSource::This, values)]);
626                if !has_order_by && query.limit.is_some_and(|limit| results.len() >= limit) {
627                    break;
628                }
629            }
630        } else {
631            let mut mm = self.ctx.mm.borrow_mut();
632            let table_reader = table_registry.read::<T, _>(&mut *mm);
633            let mut table_reader = table_overlay.reader(table_reader);
634
635            while let Some(values) = table_reader.try_next()? {
636                if let Some(filter) = &query.filter
637                    && !self.record_matches_filter(&values, filter)?
638                {
639                    continue;
640                }
641                if !has_order_by {
642                    count += 1;
643                    if query.offset.is_some_and(|offset| count <= offset) {
644                        continue;
645                    }
646                }
647                results.push(vec![(ValuesSource::This, values)]);
648                if !has_order_by && query.limit.is_some_and(|limit| results.len() >= limit) {
649                    break;
650                }
651            }
652        }
653
654        self.batch_load_eager_relations::<T>(&mut results, &query)?;
655        self.apply_column_selection::<T>(&mut results, &query);
656
657        for (column, direction) in query.order_by.into_iter().rev() {
658            self.sort_query_results(&mut results, &column, direction);
659        }
660
661        // Apply OFFSET and LIMIT after sorting when ORDER BY was present
662        if has_order_by {
663            let offset = query.offset.unwrap_or_default();
664            if offset > 0 {
665                if offset >= results.len() {
666                    results.clear();
667                } else {
668                    results.drain(..offset);
669                }
670            }
671            if let Some(limit) = query.limit {
672                results.truncate(limit);
673            }
674        }
675
676        Ok(results)
677    }
678
679    /// Executes a join query.
680    fn select_join_inner(
681        &self,
682        table: &str,
683        query: Query,
684    ) -> DbmsResult<Vec<Vec<(JoinColumnDef, Value)>>> {
685        self.schema.select_join(self, table, query)
686    }
687
688    /// Updates primary key references in tables referencing the updated table.
689    fn update_pk_referencing_updated_table<T>(
690        &self,
691        old_pk: Value,
692        new_pk: Value,
693        data_type: DataTypeKind,
694        pk_name: &'static str,
695    ) -> DbmsResult<u64>
696    where
697        T: TableSchema,
698    {
699        let mut count = 0;
700        for (ref_table, ref_col) in self
701            .schema
702            .referenced_tables(T::table_name())
703            .into_iter()
704            .flat_map(|(ref_table, ref_cols)| {
705                ref_cols
706                    .into_iter()
707                    .map(move |ref_col| (ref_table, ref_col))
708            })
709        {
710            let ref_patch_value = (
711                ColumnDef {
712                    name: ref_col,
713                    data_type,
714                    auto_increment: false,
715                    nullable: false,
716                    primary_key: false,
717                    unique: false,
718                    foreign_key: Some(ForeignKeyDef {
719                        foreign_table: T::table_name(),
720                        foreign_column: pk_name,
721                        local_column: ref_col,
722                    }),
723                },
724                new_pk.clone(),
725            );
726            let filter = Filter::eq(ref_col, old_pk.clone());
727
728            count += self
729                .schema
730                .update(self, ref_table, &[ref_patch_value], Some(filter))?;
731        }
732
733        Ok(count)
734    }
735
736    /// Sanitizes values using the table schema's sanitizers.
737    fn sanitize_values<T>(
738        &self,
739        values: Vec<(ColumnDef, Value)>,
740    ) -> DbmsResult<Vec<(ColumnDef, Value)>>
741    where
742        T: TableSchema,
743    {
744        let mut sanitized_values = Vec::with_capacity(values.len());
745        for (col_def, value) in values.into_iter() {
746            let value = match T::sanitizer(col_def.name) {
747                Some(sanitizer) => sanitizer.sanitize(value)?,
748                None => value,
749            };
750            sanitized_values.push((col_def, value));
751        }
752        Ok(sanitized_values)
753    }
754
755    /// Collects all records matching a filter from the table registry.
756    #[allow(clippy::type_complexity)]
757    fn collect_matching_records<T>(
758        &self,
759        table_registry: &TableRegistry,
760        filter: &Option<Filter>,
761    ) -> DbmsResult<Vec<(NextRecord<T>, Vec<(ColumnDef, Value)>)>>
762    where
763        T: TableSchema,
764    {
765        let mut mm = self.ctx.mm.borrow_mut();
766
767        // `collect_matching_records` is only used by the non-transactional update/delete paths.
768        // Transactional mutations first resolve rows via `existing_rows_for_filter`, which reads
769        // through `select()` and therefore includes the overlay. Using `overlay = None` here is
770        // intentional because the atomic write path is operating on committed storage only.
771        if let Some(filter) = filter
772            && let Some(analyzed) = analyze_filter(filter, T::indexes())
773        {
774            let reader = IndexReader::new(table_registry.index_ledger(), None);
775            let search_result = self.execute_index_plan(&reader, &analyzed.plan, &mut *mm)?;
776
777            let mut records = Vec::new();
778            for address in search_result.addresses {
779                let record: T = table_registry
780                    .read_at(address, &mut *mm)
781                    .map_err(DbmsError::from)?;
782                let record_values = record.clone().to_values();
783                if let Some(remaining_filter) = &analyzed.remaining_filter
784                    && !self.record_matches_filter(&record_values, remaining_filter)?
785                {
786                    continue;
787                }
788                records.push((
789                    NextRecord {
790                        record,
791                        page: address.page,
792                        offset: address.offset,
793                    },
794                    record_values,
795                ));
796            }
797
798            return Ok(records);
799        }
800
801        let mut table_reader = table_registry.read::<T, _>(&mut *mm);
802        let mut records = vec![];
803        while let Some(values) = table_reader.try_next()? {
804            let record_values = values.record.clone().to_values();
805            if let Some(filter) = filter
806                && !self.record_matches_filter(&record_values, filter)?
807            {
808                continue;
809            }
810            records.push((values, record_values));
811        }
812        Ok(records)
813    }
814
815    /// For each indexed column for the table, inserts the index for the given record address.
816    fn insert_index<T>(
817        &self,
818        table_registry: &mut TableRegistry,
819        record_address: RecordAddress,
820        values: &[(ColumnDef, Value)],
821        mm: &mut impl wasm_dbms_memory::MemoryAccess,
822    ) -> DbmsResult<()>
823    where
824        T: TableSchema,
825    {
826        let index_ledger = table_registry.index_ledger_mut();
827        for columns in T::indexes().iter().map(|index| index.columns()) {
828            let key = index_key(columns, values);
829            index_ledger.insert(columns, key, record_address, mm)?;
830        }
831
832        Ok(())
833    }
834
835    /// For each indexed column for the table, deletes the index for the given record address.
836    fn delete_index<T>(
837        &self,
838        table_registry: &mut TableRegistry,
839        record_address: RecordAddress,
840        values: &[(ColumnDef, Value)],
841        mm: &mut impl wasm_dbms_memory::MemoryAccess,
842    ) -> DbmsResult<()>
843    where
844        T: TableSchema,
845    {
846        let index_ledger = table_registry.index_ledger_mut();
847        for columns in T::indexes().iter().map(|index| index.columns()) {
848            let key = index_key(columns, values);
849            index_ledger.delete(columns, &key, record_address, mm)?;
850        }
851        Ok(())
852    }
853
854    /// For each indexed column for the table, updates the index for the given record address.
855    ///
856    /// When an indexed column's value changed, the old key is deleted and the new key is inserted.
857    /// When only the record address moved (same key), the pointer is updated in place.
858    fn update_index<T>(
859        &self,
860        table_registry: &mut TableRegistry,
861        old_record_address: RecordAddress,
862        new_record_address: RecordAddress,
863        old_values: &[(ColumnDef, Value)],
864        new_values: &[(ColumnDef, Value)],
865        mm: &mut impl wasm_dbms_memory::MemoryAccess,
866    ) -> DbmsResult<()>
867    where
868        T: TableSchema,
869    {
870        let index_ledger = table_registry.index_ledger_mut();
871        for columns in T::indexes().iter().map(|index| index.columns()) {
872            let old_key = index_key(columns, old_values);
873            let new_key = index_key(columns, new_values);
874            if old_key == new_key {
875                index_ledger.update(
876                    columns,
877                    &new_key,
878                    old_record_address,
879                    new_record_address,
880                    mm,
881                )?;
882            } else {
883                index_ledger.delete(columns, &old_key, old_record_address, mm)?;
884                index_ledger.insert(columns, new_key, new_record_address, mm)?;
885            }
886        }
887        Ok(())
888    }
889
890    /// Fills in auto-increment values for columns that are missing from the input.
891    fn fill_auto_increment_values<T>(
892        &self,
893        table_registry: &mut TableRegistry,
894        mut values: Vec<(ColumnDef, Value)>,
895    ) -> DbmsResult<Vec<(ColumnDef, Value)>>
896    where
897        T: TableSchema,
898    {
899        let mut mm = self.ctx.mm.borrow_mut();
900        // iter over auto-increment columns, for each of them check if the value is provided, if not get the next auto-increment value.
901        for auto_increment_column in T::columns().iter().filter(|col| col.auto_increment) {
902            if values
903                .iter()
904                .any(|(col_def, _)| col_def.name == auto_increment_column.name)
905            {
906                continue;
907            }
908            let next_value = table_registry
909                .next_autoincrement(auto_increment_column.name, &mut *mm)?
910                .ok_or(DbmsError::Table(TableError::SchemaMismatch))?;
911            values.push((*auto_increment_column, next_value));
912        }
913
914        Ok(values)
915    }
916}
917
918/// Provides ordering for two optional values by direction.
919pub fn sort_values_with_direction(
920    a: Option<&Value>,
921    b: Option<&Value>,
922    direction: OrderDirection,
923) -> Ordering {
924    match (a, b) {
925        (Some(a_val), Some(b_val)) => match direction {
926            OrderDirection::Ascending => a_val.cmp(b_val),
927            OrderDirection::Descending => b_val.cmp(a_val),
928        },
929        (Some(_), None) => std::cmp::Ordering::Greater,
930        (None, Some(_)) => std::cmp::Ordering::Less,
931        (None, None) => std::cmp::Ordering::Equal,
932    }
933}
934
935/// Converts column-value pairs to a schema entity.
936fn values_to_schema_entity<T>(values: Vec<(ColumnDef, Value)>) -> DbmsResult<T>
937where
938    T: TableSchema,
939{
940    let record = T::Insert::from_values(&values)?.into_record();
941    Ok(record)
942}
943
944/// Builds the index key for the given columns by extracting values from the record.
945///
946/// Columns not found in `values` default to [`Value::Null`].
947fn index_key(columns: &[&str], values: &[(ColumnDef, Value)]) -> Vec<Value> {
948    columns
949        .iter()
950        .map(|col| {
951            values
952                .iter()
953                .find(|(cd, _)| cd.name == *col)
954                .map(|(_, v)| v.clone())
955                .unwrap_or(Value::Null)
956        })
957        .collect()
958}
959
impl<M, A> Database for WasmDbmsDatabase<'_, M, A>
where
    M: MemoryProvider,
    A: AccessControl,
{
    /// Runs a typed SELECT on `T` and decodes every row into `T::Record`.
    ///
    /// # Errors
    ///
    /// Returns [`QueryError::JoinInsideTypedSelect`] when `query` contains
    /// joins (typed rows cannot carry joined columns; use `select_join`), and
    /// propagates errors from `select_columns`.
    fn select<T>(&self, query: Query) -> DbmsResult<Vec<T::Record>>
    where
        T: TableSchema,
    {
        if !query.joins.is_empty() {
            return Err(DbmsError::Query(QueryError::JoinInsideTypedSelect));
        }
        let results = self.select_columns::<T>(query)?;
        Ok(results.into_iter().map(T::Record::from_values).collect())
    }

    /// Runs a SELECT on the table named `table` and returns untyped rows.
    ///
    /// The table name is resolved through the schema's dynamic dispatch.
    fn select_raw(&self, table: &str, query: Query) -> DbmsResult<Vec<Vec<(ColumnDef, Value)>>> {
        self.schema.select(self, table, query)
    }

    /// Runs a SELECT, possibly with joins, starting from `table`.
    ///
    /// Delegates to `select_join_inner`.
    fn select_join(
        &self,
        table: &str,
        query: Query,
    ) -> DbmsResult<Vec<Vec<(JoinColumnDef, Value)>>> {
        self.select_join_inner(table, query)
    }

    /// Inserts `record` into `T`.
    ///
    /// The steps are:
    /// 1. Fill in missing auto-increment columns.
    /// 2. Sanitize the values.
    /// 3. Validate the insert through the schema.
    /// 4. Store the insert:
    ///    - inside an active transaction, the sanitized values are only
    ///      recorded in the transaction;
    ///    - otherwise, the record and its index entries are written inside
    ///      `atomic` through a [`JournaledWriter`].
    ///
    /// # Errors
    ///
    /// Propagates errors from auto-increment generation, sanitization,
    /// validation, and the table/index writes.
    ///
    /// # Panics
    ///
    /// Panics if no journal is active inside `atomic`, which is a bug.
    fn insert<T>(&self, record: T::Insert) -> DbmsResult<()>
    where
        T: TableSchema,
        T::Insert: InsertRecord<Schema = T>,
    {
        let mut table_registry = self.load_table_registry::<T>()?;
        let record_values = record.clone().into_values();
        // NOTE(review): auto-increment values are drawn before validation and
        // through the raw memory handle rather than a journaled writer, so an
        // insert that is later rejected may still consume ids. Confirm this is
        // intended.
        let record_values =
            self.fill_auto_increment_values::<T>(&mut table_registry, record_values)?;
        let sanitized_values = self.sanitize_values::<T>(record_values)?;
        self.schema
            .validate_insert(self, T::table_name(), &sanitized_values)?;
        if self.transaction.is_some() {
            // Deferred: the insert is recorded now and replayed on commit.
            self.with_transaction_mut(|tx| tx.insert::<T>(sanitized_values))?;
        } else {
            self.atomic(|db| {
                let record = T::Insert::from_values(&sanitized_values)?;
                let mut mm = db.ctx.mm.borrow_mut();
                // route writes through the journal so they can be undone if the atomic block fails
                let mut journal_ref = db.ctx.journal.borrow_mut();
                let journal = journal_ref
                    .as_mut()
                    .expect("journal must be active inside atomic");
                let mut writer = JournaledWriter::new(&mut *mm, journal);
                // insert the record in the table registry, then add its index entries
                let record_address = table_registry
                    .insert(record.into_record(), &mut writer)
                    .map_err(DbmsError::from)?;
                self.insert_index::<T>(
                    &mut table_registry,
                    record_address,
                    &sanitized_values,
                    &mut writer,
                )?;
                Ok(())
            })?;
        }

        Ok(())
    }

    /// Applies `patch` to every row of `T` matched by the patch's WHERE clause.
    ///
    /// Inside an active transaction, the matching rows are looked up, the
    /// update is only recorded in the transaction, and the number of matched
    /// rows is returned.
    ///
    /// Otherwise the update runs inside `atomic`. For each matched row:
    /// - the patched columns are overwritten;
    /// - the row is sanitized and validated;
    /// - it is rewritten through a [`JournaledWriter`], together with its
    ///   index entries.
    ///
    /// When the patch contains a primary-key column,
    /// `update_pk_referencing_updated_table` is also called for each row
    /// (presumably to propagate the key change to referencing rows). Its
    /// returned count is added to the result.
    ///
    /// # Errors
    ///
    /// Propagates lookup, sanitization, validation, and write errors.
    ///
    /// # Panics
    ///
    /// Panics if a stored row has no primary-key column or if no journal is
    /// active inside `atomic`; both indicate a bug.
    fn update<T>(&self, patch: T::Update) -> DbmsResult<u64>
    where
        T: TableSchema,
        T::Update: UpdateRecord<Schema = T>,
    {
        let filter = patch.where_clause().clone();
        if self.transaction.is_some() {
            // NOTE(review): this count covers only the directly matched rows,
            // whereas the non-transactional path below also counts rows
            // touched by the primary-key propagation. Confirm which is intended.
            let rows = self.existing_rows_for_filter::<T>(filter.clone())?;
            let count = rows.len() as u64;
            self.with_transaction_mut(|tx| tx.update::<T>(patch, filter, rows))?;

            return Ok(count);
        }

        let patch = patch.update_values();

        // first primary-key column present in the patch, if the patch changes the key
        let pk_in_patch = patch.iter().find_map(|(col_def, value)| {
            if col_def.primary_key {
                Some((col_def, value))
            } else {
                None
            }
        });

        self.atomic(|db| {
            let mut count = 0;

            let mut table_registry = db.load_table_registry::<T>()?;
            let records = db.collect_matching_records::<T>(&table_registry, &filter)?;

            for (record, record_values) in records {
                let current_pk_value = record_values
                    .iter()
                    .find(|(col_def, _)| col_def.primary_key)
                    .expect("primary key not found")
                    .1
                    .clone();

                // keep the pre-patch state: needed to locate the stored record and its index keys
                let previous_record = values_to_schema_entity::<T>(record_values.clone())?;
                let old_values_for_index = record_values.clone();
                let mut record_values = record_values;

                // overwrite each patched column by name; patch entries for unknown columns are ignored
                for (patch_col_def, patch_value) in &patch {
                    if let Some((_, record_value)) = record_values
                        .iter_mut()
                        .find(|(record_col_def, _)| record_col_def.name == patch_col_def.name)
                    {
                        *record_value = patch_value.clone();
                    }
                }
                let record_values = db.sanitize_values::<T>(record_values)?;
                db.schema.validate_update(
                    db,
                    T::table_name(),
                    &record_values,
                    current_pk_value.clone(),
                )?;
                let updated_record = values_to_schema_entity::<T>(record_values.clone())?;
                // scope the memory/journal borrows so they are released before the
                // primary-key propagation below, which accesses the database again
                {
                    let mut mm = db.ctx.mm.borrow_mut();
                    // route writes through the journal so they can be undone if the atomic block fails
                    let mut journal_ref = db.ctx.journal.borrow_mut();
                    let journal = journal_ref
                        .as_mut()
                        .expect("journal must be active inside atomic");
                    let mut writer = JournaledWriter::new(&mut *mm, journal);
                    // update table registry; the record may be moved to a new address
                    let old_address = RecordAddress::new(record.page, record.offset);
                    let new_address = table_registry
                        .update(updated_record, previous_record, old_address, &mut writer)
                        .map_err(DbmsError::from)?;
                    // update indexes for changed keys and/or the moved address
                    self.update_index::<T>(
                        &mut table_registry,
                        old_address,
                        new_address,
                        &old_values_for_index,
                        &record_values,
                        &mut writer,
                    )?;
                }
                count += 1;

                if let Some((pk_column, new_pk_value)) = pk_in_patch {
                    count += db.update_pk_referencing_updated_table::<T>(
                        current_pk_value,
                        new_pk_value.clone(),
                        pk_column.data_type,
                        pk_column.name,
                    )?;
                }
            }

            Ok(count)
        })
    }

    /// Deletes every row of `T` matched by `filter`; `None` presumably matches
    /// all rows (confirm in `collect_matching_records`).
    ///
    /// Inside an active transaction, the matching rows are looked up, the
    /// delete is only recorded in the transaction, and the number of matched
    /// rows is returned.
    ///
    /// Otherwise the delete runs inside `atomic`. For each matched row,
    /// `behaviour` is applied first:
    /// - [`DeleteBehavior::Cascade`] deletes referencing rows and adds their
    ///   count to the result;
    /// - [`DeleteBehavior::Restrict`] fails if any reference exists.
    ///
    /// The row and its index entries are then removed through a
    /// [`JournaledWriter`].
    ///
    /// # Errors
    ///
    /// Returns [`QueryError::ForeignKeyConstraintViolation`] under `Restrict`
    /// when references exist, and propagates lookup and write errors.
    ///
    /// # Panics
    ///
    /// Panics if no journal is active inside `atomic`, which is a bug.
    fn delete<T>(&self, behaviour: DeleteBehavior, filter: Option<Filter>) -> DbmsResult<u64>
    where
        T: TableSchema,
    {
        if self.transaction.is_some() {
            // NOTE(review): this count covers only the directly matched rows,
            // whereas the non-transactional path below also counts cascaded
            // deletions. Confirm which is intended.
            let rows = self.existing_rows_for_filter::<T>(filter.clone())?;
            let count = rows.len() as u64;

            self.with_transaction_mut(|tx| tx.delete::<T>(behaviour, filter, rows))?;

            return Ok(count);
        }

        self.atomic(|db| {
            let mut table_registry = db.load_table_registry::<T>()?;
            let records = db.collect_matching_records::<T>(&table_registry, &filter)?;
            let mut count = records.len() as u64;
            for (record, record_values) in records {
                // handle foreign-key references before taking the memory borrow below,
                // since these helpers access the database themselves
                match behaviour {
                    DeleteBehavior::Cascade => {
                        count += db.delete_foreign_keys_cascade::<T>(&record_values)?;
                    }
                    DeleteBehavior::Restrict => {
                        if db.has_foreign_key_references::<T>(&record_values)? {
                            // NOTE(review): `referencing_table`/`field` name this
                            // table and its primary key, not the table holding the
                            // reference. Confirm this is the intended payload.
                            return Err(DbmsError::Query(
                                QueryError::ForeignKeyConstraintViolation {
                                    referencing_table: T::table_name().to_string(),
                                    field: T::primary_key().to_string(),
                                },
                            ));
                        }
                    }
                }
                let mut mm = db.ctx.mm.borrow_mut();
                let mut journal_ref = db.ctx.journal.borrow_mut();
                let journal = journal_ref
                    .as_mut()
                    .expect("journal must be active inside atomic");
                // write table and index deletions to the journal before mutating memory
                let mut writer = JournaledWriter::new(&mut *mm, journal);
                let address = RecordAddress::new(record.page, record.offset);
                table_registry
                    .delete(record.record, address, &mut writer)
                    .map_err(DbmsError::from)?;
                self.delete_index::<T>(&mut table_registry, address, &record_values, &mut writer)?;
            }

            Ok(count)
        })
    }

    /// Commits the active transaction by replaying its recorded operations.
    ///
    /// The transaction is taken out of the session. Its operations are then
    /// re-executed in order through the schema dispatcher under one fresh
    /// journal:
    /// - inserts are validated before being applied;
    /// - if any operation fails, the journal is rolled back and that error is
    ///   returned;
    /// - otherwise the journal is committed.
    ///
    /// The active transaction id is cleared at the start of this call, so the
    /// transaction is no longer active afterwards even when replay fails.
    ///
    /// # Errors
    ///
    /// Returns [`TransactionError::NoActiveTransaction`] when no transaction is
    /// active. Also propagates errors from `take_transaction` and returns the
    /// first error raised while replaying an operation.
    ///
    /// # Panics
    ///
    /// Panics if rolling back the journal fails.
    fn commit(&mut self) -> DbmsResult<()> {
        let Some(txid) = self.transaction.take() else {
            return Err(DbmsError::Transaction(
                TransactionError::NoActiveTransaction,
            ));
        };
        let transaction = {
            let mut ts = self.ctx.transaction_session.borrow_mut();
            ts.take_transaction(&txid)?
        };

        // One journal spans the whole replay so a failure part-way through can
        // undo all earlier operations; presumably the nested `atomic` blocks
        // write through it.
        *self.ctx.journal.borrow_mut() = Some(Journal::new());

        for op in transaction.operations {
            let result = match op {
                TransactionOp::Insert { table, values } => self
                    .schema
                    .validate_insert(self, table, &values)
                    .and_then(|()| self.schema.insert(self, table, &values)),
                TransactionOp::Delete {
                    table,
                    behaviour,
                    filter,
                } => self
                    .schema
                    .delete(self, table, behaviour, filter)
                    .map(|_| ()),
                TransactionOp::Update {
                    table,
                    patch,
                    filter,
                } => self.schema.update(self, table, &patch, filter).map(|_| ()),
            };

            if let Err(err) = result {
                // The journal may already have been taken by a nested block, so
                // only roll back if it is still installed.
                if let Some(journal) = self.ctx.journal.borrow_mut().take() {
                    journal
                        .rollback(&mut self.ctx.mm.borrow_mut())
                        .expect("critical: failed to rollback journal");
                }
                return Err(err);
            }
        }

        if let Some(journal) = self.ctx.journal.borrow_mut().take() {
            journal.commit();
        }
        Ok(())
    }

    /// Abandons the active transaction.
    ///
    /// Clears the active transaction id and closes the transaction in the
    /// session. Its recorded operations are never replayed, and nothing in
    /// this method writes to memory.
    ///
    /// # Errors
    ///
    /// Returns [`TransactionError::NoActiveTransaction`] when no transaction is
    /// active.
    fn rollback(&mut self) -> DbmsResult<()> {
        let Some(txid) = self.transaction.take() else {
            return Err(DbmsError::Transaction(
                TransactionError::NoActiveTransaction,
            ));
        };

        let mut ts = self.ctx.transaction_session.borrow_mut();
        ts.close_transaction(&txid);
        Ok(())
    }
}
1239
// Unit tests for this module, compiled only for test builds.
#[cfg(test)]
mod tests;