Skip to main content

wasm_dbms/
database.rs

1//! Core DBMS database struct providing CRUD and transaction operations.
2
3mod aggregate;
4mod filter_analyzer;
5mod index_reader;
6mod migration;
7
8use std::cmp::Ordering;
9use std::collections::HashSet;
10
11use wasm_dbms_api::prelude::{
12    AggregateFunction, AggregatedRow, ColumnDef, DataTypeKind, Database, DbmsError, DbmsResult,
13    DeleteBehavior, Filter, ForeignFetcher, ForeignKeyDef, InsertRecord, JoinColumnDef,
14    MigrationError, MigrationOp, MigrationPolicy, OrderDirection, Query, QueryError, TableColumns,
15    TableError, TableRecord, TableSchema, TransactionError, TransactionId, UpdateRecord, Value,
16    ValuesSource,
17};
18use wasm_dbms_memory::RecordAddress;
19use wasm_dbms_memory::prelude::{
20    AccessControl, AccessControlList, MemoryAccess, MemoryProvider, NextRecord, TableRegistry,
21};
22
23use self::filter_analyzer::{IndexPlan, analyze_filter};
24use self::index_reader::{IndexReader, IndexSearchResult};
25use crate::context::DbmsContext;
26use crate::database::migration::snapshots;
27use crate::schema::DatabaseSchema;
28use crate::transaction::journal::{Journal, JournaledWriter};
29use crate::transaction::{DatabaseOverlay, Transaction, TransactionOp};
30
/// Initial capacity hint for the result vector built by SELECT queries.
///
/// This only sizes the first allocation; a result set may grow beyond it.
const DEFAULT_SELECT_CAPACITY: usize = 128;
33
34fn prime_drift_cache<M, A>(ctx: &DbmsContext<M, A>, schema: &dyn DatabaseSchema<M, A>)
35where
36    M: MemoryProvider,
37    A: AccessControl,
38{
39    let compiled_hash = snapshots::compute_hash(schema.compiled_snapshots_dyn());
40    let drifted = ctx.schema_registry.borrow().schema_hash() != compiled_hash;
41    ctx.set_drift(compiled_hash, drifted);
42}
43
44/// The main DBMS database struct, generic over `MemoryProvider` and
45/// `AccessControl`.
46///
47/// This struct borrows from a [`DbmsContext`] and provides all CRUD
48/// operations, transaction management, and query execution.
49pub struct WasmDbmsDatabase<'ctx, M, A = AccessControlList>
50where
51    M: MemoryProvider,
52    A: AccessControl,
53{
54    /// Reference to the DBMS context owning all state.
55    ctx: &'ctx DbmsContext<M, A>,
56    /// Schema for dynamic dispatch of table operations.
57    schema: Box<dyn DatabaseSchema<M, A> + 'ctx>,
58    /// Active transaction ID, if any.
59    transaction: Option<TransactionId>,
60}
61
62impl<'ctx, M, A> WasmDbmsDatabase<'ctx, M, A>
63where
64    M: MemoryProvider,
65    A: AccessControl,
66{
67    /// Creates a one-shot (non-transactional) database instance.
68    pub fn oneshot(ctx: &'ctx DbmsContext<M, A>, schema: impl DatabaseSchema<M, A> + 'ctx) -> Self {
69        let schema = Box::new(schema);
70        prime_drift_cache(ctx, schema.as_ref());
71        Self {
72            ctx,
73            schema,
74            transaction: None,
75        }
76    }
77
78    /// Creates a transactional database instance.
79    pub fn from_transaction(
80        ctx: &'ctx DbmsContext<M, A>,
81        schema: impl DatabaseSchema<M, A> + 'ctx,
82        transaction_id: TransactionId,
83    ) -> Self {
84        let schema = Box::new(schema);
85        prime_drift_cache(ctx, schema.as_ref());
86        Self {
87            ctx,
88            schema,
89            transaction: Some(transaction_id),
90        }
91    }
92
93    /// Executes a closure with a mutable reference to the current transaction.
94    fn with_transaction_mut<F, R>(&self, f: F) -> DbmsResult<R>
95    where
96        F: FnOnce(&mut Transaction) -> DbmsResult<R>,
97    {
98        let txid = self.transaction.as_ref().ok_or(DbmsError::Transaction(
99            TransactionError::NoActiveTransaction,
100        ))?;
101
102        let mut ts = self.ctx.transaction_session.borrow_mut();
103        let tx = ts.get_transaction_mut(txid)?;
104        f(tx)
105    }
106
107    /// Executes a closure with a reference to the current transaction.
108    fn with_transaction<F, R>(&self, f: F) -> DbmsResult<R>
109    where
110        F: FnOnce(&Transaction) -> DbmsResult<R>,
111    {
112        let txid = self.transaction.as_ref().ok_or(DbmsError::Transaction(
113            TransactionError::NoActiveTransaction,
114        ))?;
115
116        let ts = self.ctx.transaction_session.borrow();
117        let tx = ts.get_transaction(txid)?;
118        f(tx)
119    }
120
121    // --- ACL surface ------------------------------------------------------
122
123    /// Returns whether `id` is granted `required` on `table`.
124    pub fn granted(
125        &self,
126        id: &A::Id,
127        table: wasm_dbms_api::prelude::TableFingerprint,
128        required: wasm_dbms_api::prelude::TablePerms,
129    ) -> bool {
130        self.ctx.granted(id, table, required)
131    }
132
133    /// Returns whether `id` carries the `admin` bypass flag.
134    pub fn granted_admin(&self, id: &A::Id) -> bool {
135        self.ctx.granted_admin(id)
136    }
137
138    /// Returns whether `id` carries the `manage_acl` flag.
139    pub fn granted_manage_acl(&self, id: &A::Id) -> bool {
140        self.ctx.granted_manage_acl(id)
141    }
142
143    /// Returns whether `id` carries the `migrate` flag.
144    pub fn granted_migrate(&self, id: &A::Id) -> bool {
145        self.ctx.granted_migrate(id)
146    }
147
148    /// Applies a grant on behalf of `caller`. Self-enforces `manage_acl`.
149    pub fn grant(
150        &self,
151        caller: &A::Id,
152        target: A::Id,
153        grant: wasm_dbms_api::prelude::PermGrant,
154    ) -> DbmsResult<()> {
155        if !self.ctx.granted_manage_acl(caller) {
156            return Err(DbmsError::AccessDenied {
157                table: None,
158                required: wasm_dbms_api::prelude::RequiredPerm::ManageAcl,
159            });
160        }
161        self.ctx.acl_grant(target, grant)
162    }
163
164    /// Applies a revoke on behalf of `caller`. Self-enforces `manage_acl`.
165    pub fn revoke(
166        &self,
167        caller: &A::Id,
168        target: &A::Id,
169        revoke: wasm_dbms_api::prelude::PermRevoke,
170    ) -> DbmsResult<()> {
171        if !self.ctx.granted_manage_acl(caller) {
172            return Err(DbmsError::AccessDenied {
173                table: None,
174                required: wasm_dbms_api::prelude::RequiredPerm::ManageAcl,
175            });
176        }
177        self.ctx.acl_revoke(target, revoke)
178    }
179
180    /// Removes `target` from the ACL on behalf of `caller`. Self-enforces
181    /// `manage_acl`.
182    pub fn remove_identity(&self, caller: &A::Id, target: &A::Id) -> DbmsResult<()> {
183        if !self.ctx.granted_manage_acl(caller) {
184            return Err(DbmsError::AccessDenied {
185                table: None,
186                required: wasm_dbms_api::prelude::RequiredPerm::ManageAcl,
187            });
188        }
189        self.ctx.acl_remove_identity(target)
190    }
191
192    /// Returns every identity in the ACL together with its perms, on
193    /// behalf of `caller`. Self-enforces `manage_acl`.
194    pub fn identities(
195        &self,
196        caller: &A::Id,
197    ) -> DbmsResult<Vec<(A::Id, wasm_dbms_api::prelude::IdentityPerms)>> {
198        if !self.ctx.granted_manage_acl(caller) {
199            return Err(DbmsError::AccessDenied {
200                table: None,
201                required: wasm_dbms_api::prelude::RequiredPerm::ManageAcl,
202            });
203        }
204        Ok(self.ctx.acl_identities())
205    }
206
207    /// Returns the cached drift flag, computing and caching it on first call.
208    ///
209    /// `O(tables × snapshot bytes)` on the first invocation; `O(1)` thereafter.
210    /// Cleared by [`Self::migrate`] on success so subsequent CRUD calls
211    /// recompute against the new state.
212    fn drift_check(&self) -> DbmsResult<bool> {
213        if self.ctx.is_migrating() {
214            return Ok(false);
215        }
216        let compiled_hash = snapshots::compute_hash(self.schema.compiled_snapshots_dyn());
217        if let Some(cached) = self.ctx.cached_drift_for(compiled_hash) {
218            return Ok(cached);
219        }
220        let drifted = self.ctx.schema_registry.borrow().schema_hash() != compiled_hash;
221        self.ctx.set_drift(compiled_hash, drifted);
222        Ok(drifted)
223    }
224
225    /// Returns `Err(SchemaDrift)` when the cached drift flag is set.
226    ///
227    /// Called as the first line of every CRUD entry point on the
228    /// [`Database`] trait so reads and writes refuse to proceed against a
229    /// schema the storage layer cannot interpret.
230    fn ensure_no_drift(&self) -> DbmsResult<()> {
231        if self.drift_check()? {
232            return Err(DbmsError::Migration(MigrationError::SchemaDrift));
233        }
234        Ok(())
235    }
236
237    /// Executes a closure atomically using a write-ahead journal.
238    ///
239    /// All writes performed inside `f` are recorded. On success the journal
240    /// is committed (entries discarded). On error the journal is rolled back,
241    /// restoring every modified byte to its pre-call state.
242    ///
243    /// When a journal is already active (e.g., inside [`Database::commit`]),
244    /// this method delegates to the outer journal and does not manage its own.
245    ///
246    /// # Panics
247    ///
248    /// Panics if the rollback itself fails, because a failed rollback leaves
249    /// memory in an irrecoverably corrupt state (M-PANIC-ON-BUG).
250    fn atomic<F, R>(&self, f: F) -> DbmsResult<R>
251    where
252        F: FnOnce(&WasmDbmsDatabase<'ctx, M, A>) -> DbmsResult<R>,
253    {
254        let nested = self.ctx.journal.borrow().is_some();
255        if !nested {
256            *self.ctx.journal.borrow_mut() = Some(Journal::new());
257        }
258        match f(self) {
259            Ok(res) => {
260                if !nested && let Some(journal) = self.ctx.journal.borrow_mut().take() {
261                    journal.commit();
262                }
263                Ok(res)
264            }
265            Err(err) => {
266                if !nested && let Some(journal) = self.ctx.journal.borrow_mut().take() {
267                    journal
268                        .rollback(&mut self.ctx.mm.borrow_mut())
269                        .expect("critical: failed to rollback journal");
270                }
271                Err(err)
272            }
273        }
274    }
275
276    /// Checks whether any foreign key references exist for the given record.
277    ///
278    /// Returns `true` if at least one referencing row exists in any table.
279    fn has_foreign_key_references<T>(
280        &self,
281        record_values: &[(ColumnDef, Value)],
282    ) -> DbmsResult<bool>
283    where
284        T: TableSchema,
285    {
286        let pk = Self::extract_pk::<T>(record_values)?;
287
288        for (table, columns) in self.schema.referenced_tables(T::table_name()) {
289            for column in columns.iter() {
290                let filter = Filter::eq(column, pk.clone());
291                let query = Query::builder().field(column).filter(Some(filter)).build();
292                let rows = self.schema.select(self, table, query)?;
293                if !rows.is_empty() {
294                    return Ok(true);
295                }
296            }
297        }
298        Ok(false)
299    }
300
301    /// Deletes foreign key related records recursively for cascade deletes.
302    fn delete_foreign_keys_cascade<T>(
303        &self,
304        record_values: &[(ColumnDef, Value)],
305    ) -> DbmsResult<u64>
306    where
307        T: TableSchema,
308    {
309        let pk = Self::extract_pk::<T>(record_values)?;
310
311        let mut count = 0;
312        for (table, columns) in self.schema.referenced_tables(T::table_name()) {
313            for column in columns.iter() {
314                let filter = Filter::eq(column, pk.clone());
315                let res = self
316                    .schema
317                    .delete(self, table, DeleteBehavior::Cascade, Some(filter))?;
318                count += res;
319            }
320        }
321        Ok(count)
322    }
323
324    /// Extracts the primary key value from a record's column-value pairs.
325    fn extract_pk<T>(record_values: &[(ColumnDef, Value)]) -> DbmsResult<Value>
326    where
327        T: TableSchema,
328    {
329        record_values
330            .iter()
331            .find(|(col_def, _)| col_def.primary_key)
332            .ok_or(DbmsError::Query(QueryError::UnknownColumn(
333                T::primary_key().to_string(),
334            )))
335            .map(|(_, v)| v.clone())
336    }
337
338    /// Retrieves the current overlay from the active transaction.
339    fn overlay(&self) -> DbmsResult<DatabaseOverlay> {
340        self.with_transaction(|tx| Ok(tx.overlay().clone()))
341    }
342
343    /// Returns whether the record matches the provided filter.
344    fn record_matches_filter(
345        &self,
346        record_values: &[(ColumnDef, Value)],
347        filter: &Filter,
348    ) -> DbmsResult<bool> {
349        filter.matches(record_values).map_err(DbmsError::from)
350    }
351
352    /// Removes duplicate records based on the values of the given columns.
353    ///
354    /// Keeps the first record encountered for each distinct combination of the
355    /// values of the specified columns. Columns missing from a record are
356    /// treated as [`Value::Null`]. When `distinct_by` is empty, no records are
357    /// removed.
358    fn apply_distinct(&self, results: &mut Vec<TableColumns>, distinct_by: &[String]) {
359        if distinct_by.is_empty() {
360            return;
361        }
362        let mut seen: HashSet<Vec<Value>> = HashSet::new();
363        results.retain(|record| {
364            let key: Vec<Value> = distinct_by
365                .iter()
366                .map(|col| {
367                    Self::this_columns(record)
368                        .and_then(|cols| {
369                            cols.iter()
370                                .find(|(cd, _)| cd.name == col.as_str())
371                                .map(|(_, v)| v.clone())
372                        })
373                        .unwrap_or(Value::Null)
374                })
375                .collect();
376            seen.insert(key)
377        });
378    }
379
380    /// Filters record columns down to only the selected fields.
381    fn apply_column_selection<T>(&self, results: &mut [TableColumns], query: &Query)
382    where
383        T: TableSchema,
384    {
385        if query.all_selected() {
386            return;
387        }
388        let selected_columns = query.columns::<T>();
389        results
390            .iter_mut()
391            .flat_map(|record| record.iter_mut())
392            .filter(|(source, _)| *source == ValuesSource::This)
393            .for_each(|(_, cols)| {
394                cols.retain(|(col_def, _)| selected_columns.contains(&col_def.name.to_string()));
395            });
396    }
397
398    /// Batch-fetches eager relations for collected results.
399    fn batch_load_eager_relations<T>(
400        &self,
401        results: &mut [TableColumns],
402        query: &Query,
403    ) -> DbmsResult<()>
404    where
405        T: TableSchema,
406    {
407        if query.eager_relations.is_empty() {
408            return Ok(());
409        }
410
411        let fetcher = T::foreign_fetcher();
412
413        for relation in &query.eager_relations {
414            let fk_columns = Self::collect_fk_values::<T>(results, relation)?;
415
416            for (local_column, pk_values) in &fk_columns {
417                let batch_map = fetcher.fetch_batch(self, relation, pk_values)?;
418
419                Self::verify_fk_batch(&batch_map, pk_values, relation)?;
420                Self::attach_foreign_data(results, &batch_map, relation, local_column);
421            }
422        }
423
424        Ok(())
425    }
426
427    /// Collects distinct FK values across all records for a given relation.
428    fn collect_fk_values<T>(
429        results: &[TableColumns],
430        relation: &str,
431    ) -> DbmsResult<Vec<(&'static str, Vec<Value>)>>
432    where
433        T: TableSchema,
434    {
435        let mut fk_columns: Vec<(&'static str, HashSet<Value>)> = vec![];
436
437        for record_columns in results {
438            let Some(cols) = Self::this_columns(record_columns) else {
439                continue;
440            };
441
442            let mut found_fk = false;
443            for (col_def, value) in cols {
444                let Some(fk) = &col_def.foreign_key else {
445                    continue;
446                };
447                if *fk.foreign_table != *relation {
448                    continue;
449                }
450
451                found_fk = true;
452                match fk_columns.iter_mut().find(|(lc, _)| *lc == fk.local_column) {
453                    Some((_, values)) => {
454                        values.insert(value.clone());
455                    }
456                    None => {
457                        let mut set = HashSet::new();
458                        set.insert(value.clone());
459                        fk_columns.push((fk.local_column, set));
460                    }
461                }
462            }
463
464            if !found_fk {
465                return Err(DbmsError::Query(QueryError::InvalidQuery(format!(
466                    "Cannot load relation '{relation}' for table '{}': no foreign key found",
467                    T::table_name()
468                ))));
469            }
470        }
471
472        Ok(fk_columns
473            .into_iter()
474            .map(|(col, set)| (col, set.into_iter().collect()))
475            .collect())
476    }
477
478    /// Verifies all FK values were found in the batch result.
479    fn verify_fk_batch(
480        batch_map: &std::collections::HashMap<Value, Vec<(ColumnDef, Value)>>,
481        pk_values: &[Value],
482        relation: &str,
483    ) -> DbmsResult<()> {
484        if let Some(missing) = pk_values.iter().find(|v| !batch_map.contains_key(v)) {
485            return Err(DbmsError::Query(QueryError::BrokenForeignKeyReference {
486                table: relation.to_string(),
487                key: missing.clone(),
488            }));
489        }
490        Ok(())
491    }
492
493    /// Attaches batch-fetched foreign data to each record.
494    fn attach_foreign_data(
495        results: &mut [TableColumns],
496        batch_map: &std::collections::HashMap<Value, Vec<(ColumnDef, Value)>>,
497        relation: &str,
498        local_column: &str,
499    ) {
500        for record_columns in results.iter_mut() {
501            let fk_value = Self::this_columns(record_columns).and_then(|cols| {
502                cols.iter().find_map(|(col_def, value)| {
503                    let fk = col_def.foreign_key.as_ref()?;
504                    (fk.foreign_table == relation && fk.local_column == local_column)
505                        .then(|| value.clone())
506                })
507            });
508
509            let Some(fk_val) = fk_value else { continue };
510            let Some(foreign_values) = batch_map.get(&fk_val) else {
511                continue;
512            };
513
514            record_columns.push((
515                ValuesSource::Foreign {
516                    table: relation.to_string(),
517                    column: local_column.to_string(),
518                },
519                foreign_values.clone(),
520            ));
521        }
522    }
523
524    /// Extracts the `ValuesSource::This` columns from a record.
525    fn this_columns(
526        record: &[(ValuesSource, Vec<(ColumnDef, Value)>)],
527    ) -> Option<&Vec<(ColumnDef, Value)>> {
528        record
529            .iter()
530            .find(|(src, _)| *src == ValuesSource::This)
531            .map(|(_, cols)| cols)
532    }
533
534    /// Retrieves existing rows matching a filter, returning `(primary_key, full_row)` pairs.
535    #[expect(
536        clippy::type_complexity,
537        reason = "complex return type is necessary for returning both PK and full row data"
538    )]
539    fn existing_rows_for_filter<T>(
540        &self,
541        filter: Option<Filter>,
542    ) -> DbmsResult<Vec<(Value, Vec<(ColumnDef, Value)>)>>
543    where
544        T: TableSchema,
545    {
546        let pk = T::primary_key();
547        let query = Query::builder().filter(filter).build();
548        let records = self.select::<T>(query)?;
549        let rows = records
550            .into_iter()
551            .map(|record| {
552                let values = record.to_values();
553                let pk_value = values
554                    .iter()
555                    .find(|(col_def, _)| col_def.name == pk)
556                    .expect("primary key not found")
557                    .1
558                    .clone();
559                (pk_value, values)
560            })
561            .collect();
562
563        Ok(rows)
564    }
565
566    /// Loads the table registry for a given table schema.
567    fn load_table_registry<T>(&self) -> DbmsResult<TableRegistry>
568    where
569        T: TableSchema,
570    {
571        let sr = self.ctx.schema_registry.borrow();
572        let registry_pages = sr
573            .table_registry_page::<T>()
574            .ok_or(DbmsError::Table(TableError::TableNotFound))?;
575
576        let mut mm = self.ctx.mm.borrow_mut();
577        TableRegistry::load(registry_pages, &mut *mm).map_err(DbmsError::from)
578    }
579
580    /// Sorts query results by a column.
581    fn sort_query_results(
582        &self,
583        results: &mut [TableColumns],
584        column: &str,
585        direction: OrderDirection,
586    ) {
587        results.sort_by(|a, b| {
588            fn get_value<'a>(
589                values: &'a [(ValuesSource, Vec<(ColumnDef, Value)>)],
590                column: &str,
591            ) -> Option<&'a Value> {
592                values
593                    .iter()
594                    .find(|(source, _)| *source == ValuesSource::This)
595                    .and_then(|(_, cols)| {
596                        cols.iter()
597                            .find(|(col_def, _)| col_def.name == column)
598                            .map(|(_, value)| value)
599                    })
600            }
601
602            let a_value = get_value(a, column);
603            let b_value = get_value(b, column);
604
605            sort_values_with_direction(a_value, b_value, direction)
606        });
607    }
608
609    fn execute_index_plan<MA>(
610        &self,
611        reader: &IndexReader<'_>,
612        plan: &IndexPlan,
613        mm: &mut MA,
614    ) -> DbmsResult<IndexSearchResult>
615    where
616        MA: MemoryAccess,
617    {
618        let columns = [plan.column()];
619        match plan {
620            IndexPlan::Eq { value, .. } => {
621                let key = [value.clone()];
622                reader
623                    .search_eq(&columns, &key, mm)
624                    .map_err(DbmsError::from)
625            }
626            IndexPlan::Range { start, end, .. } => {
627                let start_key = start.as_ref().map(|value| vec![value.clone()]);
628                let end_key = end.as_ref().map(|value| vec![value.clone()]);
629                reader
630                    .search_range(&columns, start_key.as_deref(), end_key.as_deref(), mm)
631                    .map_err(DbmsError::from)
632            }
633            IndexPlan::In { values, .. } => {
634                let keys: Vec<Vec<Value>> =
635                    values.iter().cloned().map(|value| vec![value]).collect();
636                reader
637                    .search_in(&columns, &keys, mm)
638                    .map_err(DbmsError::from)
639            }
640        }
641    }
642
    /// Tries to resolve the rows matching `query.filter` through an index.
    ///
    /// Returns `Ok(None)` when the query has no filter, or when
    /// [`analyze_filter`] finds no usable plan among `T::indexes()`. The
    /// caller then falls back to a full scan. Otherwise returns
    /// `Ok(Some(rows))`, where `rows` is built from:
    ///
    /// 1. rows read from table storage at the index hits, skipping PKs that
    ///    the search reports as removed or as present in the overlay;
    /// 2. overlay-inserted rows whose PK is among the search's overlay PKs;
    /// 3. for the overlay PKs still unmatched after step 2, the stored rows
    ///    found through an equality search on `T::primary_key()`, patched
    ///    with the table overlay (rows for which `patch_row` returns `None`
    ///    are skipped).
    ///
    /// Every candidate must also satisfy the part of the filter that the
    /// index plan does not cover (`remaining_filter`). No OFFSET/LIMIT is
    /// applied here.
    ///
    /// # Errors
    ///
    /// Propagates index-search, storage-read and filter-evaluation errors.
    #[expect(
        clippy::type_complexity,
        reason = "complex return type is necessary for returning addresses and overlay PKs"
    )]
    fn try_index_select<T>(
        &self,
        query: &Query,
        table_registry: &TableRegistry,
        table_overlay: &DatabaseOverlay,
    ) -> DbmsResult<Option<Vec<Vec<(ColumnDef, Value)>>>>
    where
        T: TableSchema,
    {
        let Some(filter) = &query.filter else {
            return Ok(None);
        };

        // No index can serve this filter: let the caller do a full scan.
        let Some(analyzed) = analyze_filter(filter, T::indexes()) else {
            return Ok(None);
        };

        // The memory manager stays mutably borrowed for the whole lookup,
        // including the overlay phase below.
        let mut mm = self.ctx.mm.borrow_mut();
        // The reader combines the stored index ledger with this table's
        // index overlay (if any).
        let reader = IndexReader::new(
            table_registry.index_ledger(),
            table_overlay.index_overlay(T::table_name()),
        );
        let search_result = self.execute_index_plan(&reader, &analyzed.plan, &mut *mm)?;

        let mut indexed_rows = Vec::new();
        let pk_name = T::primary_key();

        // Phase 1: rows read from storage at the index hits.
        for address in &search_result.addresses {
            let record: T = table_registry
                .read_at(*address, &mut *mm)
                .map_err(DbmsError::from)?;
            let values = record.to_values();
            let Some(pk) = values
                .iter()
                .find(|(column, _)| column.name == pk_name)
                .map(|(_, value)| value)
            else {
                continue;
            };

            // Removed rows are dropped; overlay-tracked rows are emitted
            // from the overlay phase instead, so they are not duplicated.
            if search_result.removed_pks.contains(pk) || search_result.overlay_pks.contains(pk) {
                continue;
            }

            if let Some(remaining_filter) = &analyzed.remaining_filter
                && !self.record_matches_filter(&values, remaining_filter)?
            {
                continue;
            }

            indexed_rows.push(values);
        }

        if let Some(overlay) = table_overlay.table_overlay(T::table_name()) {
            // Overlay PKs not yet emitted; shrinks as inserted rows match.
            let mut pending_overlay_pks = search_result.overlay_pks.clone();

            // Phase 2: rows inserted in the overlay that the index search hit.
            for row in overlay.iter_inserted() {
                let Some(pk) = row
                    .iter()
                    .find(|(column, _)| column.name == pk_name)
                    .map(|(_, value)| value)
                else {
                    continue;
                };

                if !pending_overlay_pks.remove(pk) {
                    continue;
                }
                if let Some(remaining_filter) = &analyzed.remaining_filter
                    && !self.record_matches_filter(&row, remaining_filter)?
                {
                    continue;
                }

                indexed_rows.push(row);
            }

            // Phase 3: remaining overlay PKs refer to stored rows; fetch them
            // by primary key and apply the overlay's changes.
            if !pending_overlay_pks.is_empty() {
                // No index overlay here: the lookup targets stored rows only.
                let pk_reader = IndexReader::new(table_registry.index_ledger(), None);
                let pk_columns = [T::primary_key()];

                for pk in pending_overlay_pks {
                    let pk_key = [pk];
                    let pk_lookup = pk_reader.search_eq(&pk_columns, &pk_key, &mut *mm)?;
                    for address in pk_lookup.addresses {
                        let record: T = table_registry
                            .read_at(address, &mut *mm)
                            .map_err(DbmsError::from)?;
                        let values = record.to_values();
                        // `None` from `patch_row` means the row is skipped.
                        let Some(patched_values) = overlay.patch_row(values) else {
                            continue;
                        };

                        if let Some(remaining_filter) = &analyzed.remaining_filter
                            && !self.record_matches_filter(&patched_values, remaining_filter)?
                        {
                            continue;
                        }

                        indexed_rows.push(patched_values);
                    }
                }
            }
        }

        Ok(Some(indexed_rows))
    }
754
755    /// Core select logic returning intermediate `TableColumns`.
756    #[doc(hidden)]
757    pub fn select_columns<T>(&self, query: Query) -> DbmsResult<Vec<TableColumns>>
758    where
759        T: TableSchema,
760    {
761        reject_aggregate_clauses(&query)?;
762        let table_registry = self.load_table_registry::<T>()?;
763        let mut table_overlay = if self.transaction.is_some() {
764            self.overlay()?
765        } else {
766            DatabaseOverlay::default()
767        };
768
769        let mut results = Vec::with_capacity(query.limit.unwrap_or(DEFAULT_SELECT_CAPACITY));
770        // When ORDER BY or DISTINCT is present, LIMIT and OFFSET must be applied after
771        // sorting and deduplication to comply with standard SQL semantics
772        // (WHERE -> DISTINCT -> ORDER BY -> OFFSET -> LIMIT).
773        let has_order_by = !query.order_by.is_empty();
774        let has_distinct = !query.distinct_by.is_empty();
775        let defer_pagination = has_order_by || has_distinct;
776        let mut count = 0;
777
778        if let Some(indexed_rows) =
779            self.try_index_select::<T>(&query, &table_registry, &table_overlay)?
780        {
781            for values in indexed_rows {
782                if !defer_pagination {
783                    count += 1;
784                    if query.offset.is_some_and(|offset| count <= offset) {
785                        continue;
786                    }
787                }
788                results.push(vec![(ValuesSource::This, values)]);
789                if !defer_pagination && query.limit.is_some_and(|limit| results.len() >= limit) {
790                    break;
791                }
792            }
793        } else {
794            let mut mm = self.ctx.mm.borrow_mut();
795            let table_reader = table_registry.read::<T, _>(&mut *mm);
796            let mut table_reader = table_overlay.reader(table_reader);
797
798            while let Some(values) = table_reader.try_next()? {
799                if let Some(filter) = &query.filter
800                    && !self.record_matches_filter(&values, filter)?
801                {
802                    continue;
803                }
804                if !defer_pagination {
805                    count += 1;
806                    if query.offset.is_some_and(|offset| count <= offset) {
807                        continue;
808                    }
809                }
810                results.push(vec![(ValuesSource::This, values)]);
811                if !defer_pagination && query.limit.is_some_and(|limit| results.len() >= limit) {
812                    break;
813                }
814            }
815        }
816
817        self.apply_distinct(&mut results, &query.distinct_by);
818        self.batch_load_eager_relations::<T>(&mut results, &query)?;
819        self.apply_column_selection::<T>(&mut results, &query);
820
821        for (column, direction) in query.order_by.into_iter().rev() {
822            self.sort_query_results(&mut results, &column, direction);
823        }
824
825        // Apply OFFSET and LIMIT after sorting/deduplication when deferred
826        if defer_pagination {
827            let offset = query.offset.unwrap_or_default();
828            if offset > 0 {
829                if offset >= results.len() {
830                    results.clear();
831                } else {
832                    results.drain(..offset);
833                }
834            }
835            if let Some(limit) = query.limit {
836                results.truncate(limit);
837            }
838        }
839
840        Ok(results)
841    }
842
843    /// Executes a join query.
844    fn select_join_inner(
845        &self,
846        table: &str,
847        query: Query,
848    ) -> DbmsResult<Vec<Vec<(JoinColumnDef, Value)>>> {
849        reject_aggregate_clauses(&query)?;
850        self.schema.select_join(self, table, query)
851    }
852
853    /// Updates primary key references in tables referencing the updated table.
854    fn update_pk_referencing_updated_table<T>(
855        &self,
856        old_pk: Value,
857        new_pk: Value,
858        data_type: DataTypeKind,
859        pk_name: &'static str,
860    ) -> DbmsResult<u64>
861    where
862        T: TableSchema,
863    {
864        let mut count = 0;
865        for (ref_table, ref_col) in self
866            .schema
867            .referenced_tables(T::table_name())
868            .into_iter()
869            .flat_map(|(ref_table, ref_cols)| {
870                ref_cols
871                    .into_iter()
872                    .map(move |ref_col| (ref_table, ref_col))
873            })
874        {
875            let ref_patch_value = (
876                ColumnDef {
877                    name: ref_col,
878                    data_type,
879                    auto_increment: false,
880                    nullable: false,
881                    primary_key: false,
882                    unique: false,
883                    foreign_key: Some(ForeignKeyDef {
884                        foreign_table: T::table_name(),
885                        foreign_column: pk_name,
886                        local_column: ref_col,
887                    }),
888                    default: None,
889                    renamed_from: &[],
890                },
891                new_pk.clone(),
892            );
893            let filter = Filter::eq(ref_col, old_pk.clone());
894
895            count += self
896                .schema
897                .update(self, ref_table, &[ref_patch_value], Some(filter))?;
898        }
899
900        Ok(count)
901    }
902
903    /// Sanitizes values using the table schema's sanitizers.
904    fn sanitize_values<T>(
905        &self,
906        values: Vec<(ColumnDef, Value)>,
907    ) -> DbmsResult<Vec<(ColumnDef, Value)>>
908    where
909        T: TableSchema,
910    {
911        let mut sanitized_values = Vec::with_capacity(values.len());
912        for (col_def, value) in values.into_iter() {
913            let value = match T::sanitizer(col_def.name) {
914                Some(sanitizer) => sanitizer.sanitize(value)?,
915                None => value,
916            };
917            sanitized_values.push((col_def, value));
918        }
919        Ok(sanitized_values)
920    }
921
922    /// Collects all records matching a filter from the table registry.
923    #[allow(clippy::type_complexity)]
924    fn collect_matching_records<T>(
925        &self,
926        table_registry: &TableRegistry,
927        filter: &Option<Filter>,
928    ) -> DbmsResult<Vec<(NextRecord<T>, Vec<(ColumnDef, Value)>)>>
929    where
930        T: TableSchema,
931    {
932        let mut mm = self.ctx.mm.borrow_mut();
933
934        // `collect_matching_records` is only used by the non-transactional update/delete paths.
935        // Transactional mutations first resolve rows via `existing_rows_for_filter`, which reads
936        // through `select()` and therefore includes the overlay. Using `overlay = None` here is
937        // intentional because the atomic write path is operating on committed storage only.
938        if let Some(filter) = filter
939            && let Some(analyzed) = analyze_filter(filter, T::indexes())
940        {
941            let reader = IndexReader::new(table_registry.index_ledger(), None);
942            let search_result = self.execute_index_plan(&reader, &analyzed.plan, &mut *mm)?;
943
944            let mut records = Vec::new();
945            for address in search_result.addresses {
946                let record: T = table_registry
947                    .read_at(address, &mut *mm)
948                    .map_err(DbmsError::from)?;
949                let record_values = record.clone().to_values();
950                if let Some(remaining_filter) = &analyzed.remaining_filter
951                    && !self.record_matches_filter(&record_values, remaining_filter)?
952                {
953                    continue;
954                }
955                records.push((
956                    NextRecord {
957                        record,
958                        page: address.page,
959                        offset: address.offset,
960                    },
961                    record_values,
962                ));
963            }
964
965            return Ok(records);
966        }
967
968        let mut table_reader = table_registry.read::<T, _>(&mut *mm);
969        let mut records = vec![];
970        while let Some(values) = table_reader.try_next()? {
971            let record_values = values.record.clone().to_values();
972            if let Some(filter) = filter
973                && !self.record_matches_filter(&record_values, filter)?
974            {
975                continue;
976            }
977            records.push((values, record_values));
978        }
979        Ok(records)
980    }
981
982    /// For each indexed column for the table, inserts the index for the given record address.
983    fn insert_index<T>(
984        &self,
985        table_registry: &mut TableRegistry,
986        record_address: RecordAddress,
987        values: &[(ColumnDef, Value)],
988        mm: &mut impl wasm_dbms_memory::MemoryAccess,
989    ) -> DbmsResult<()>
990    where
991        T: TableSchema,
992    {
993        let index_ledger = table_registry.index_ledger_mut();
994        for columns in T::indexes().iter().map(|index| index.columns()) {
995            let key = index_key(columns, values);
996            index_ledger.insert(columns, key, record_address, mm)?;
997        }
998
999        Ok(())
1000    }
1001
1002    /// For each indexed column for the table, deletes the index for the given record address.
1003    fn delete_index<T>(
1004        &self,
1005        table_registry: &mut TableRegistry,
1006        record_address: RecordAddress,
1007        values: &[(ColumnDef, Value)],
1008        mm: &mut impl wasm_dbms_memory::MemoryAccess,
1009    ) -> DbmsResult<()>
1010    where
1011        T: TableSchema,
1012    {
1013        let index_ledger = table_registry.index_ledger_mut();
1014        for columns in T::indexes().iter().map(|index| index.columns()) {
1015            let key = index_key(columns, values);
1016            index_ledger.delete(columns, &key, record_address, mm)?;
1017        }
1018        Ok(())
1019    }
1020
1021    /// For each indexed column for the table, updates the index for the given record address.
1022    ///
1023    /// When an indexed column's value changed, the old key is deleted and the new key is inserted.
1024    /// When only the record address moved (same key), the pointer is updated in place.
1025    fn update_index<T>(
1026        &self,
1027        table_registry: &mut TableRegistry,
1028        old_record_address: RecordAddress,
1029        new_record_address: RecordAddress,
1030        old_values: &[(ColumnDef, Value)],
1031        new_values: &[(ColumnDef, Value)],
1032        mm: &mut impl wasm_dbms_memory::MemoryAccess,
1033    ) -> DbmsResult<()>
1034    where
1035        T: TableSchema,
1036    {
1037        let index_ledger = table_registry.index_ledger_mut();
1038        for columns in T::indexes().iter().map(|index| index.columns()) {
1039            let old_key = index_key(columns, old_values);
1040            let new_key = index_key(columns, new_values);
1041            if old_key == new_key {
1042                index_ledger.update(
1043                    columns,
1044                    &new_key,
1045                    old_record_address,
1046                    new_record_address,
1047                    mm,
1048                )?;
1049            } else {
1050                index_ledger.delete(columns, &old_key, old_record_address, mm)?;
1051                index_ledger.insert(columns, new_key, new_record_address, mm)?;
1052            }
1053        }
1054        Ok(())
1055    }
1056
1057    /// Fills in auto-increment values for columns that are missing from the input.
1058    fn fill_auto_increment_values<T>(
1059        &self,
1060        table_registry: &mut TableRegistry,
1061        mut values: Vec<(ColumnDef, Value)>,
1062    ) -> DbmsResult<Vec<(ColumnDef, Value)>>
1063    where
1064        T: TableSchema,
1065    {
1066        let mut mm = self.ctx.mm.borrow_mut();
1067        // iter over auto-increment columns, for each of them check if the value is provided, if not get the next auto-increment value.
1068        for auto_increment_column in T::columns().iter().filter(|col| col.auto_increment) {
1069            if values
1070                .iter()
1071                .any(|(col_def, _)| col_def.name == auto_increment_column.name)
1072            {
1073                continue;
1074            }
1075            let next_value = table_registry
1076                .next_autoincrement(auto_increment_column.name, &mut *mm)?
1077                .ok_or(DbmsError::Table(TableError::SchemaMismatch))?;
1078            values.push((*auto_increment_column, next_value));
1079        }
1080
1081        Ok(values)
1082    }
1083}
1084
1085/// Rejects queries that carry `GROUP BY` or `HAVING` on a non-aggregate path.
1086///
1087/// `GROUP BY` and `HAVING` only have meaning under
1088/// [`Database::aggregate`](wasm_dbms_api::prelude::Database::aggregate). When
1089/// they appear on a regular `select`, `select_raw`, or `select_join`, returning
1090/// rows would silently drop the user's grouping intent — better to fail loudly
1091/// and steer the caller to the right method.
1092fn reject_aggregate_clauses(query: &Query) -> DbmsResult<()> {
1093    if !query.group_by.is_empty() || query.having.is_some() {
1094        return Err(DbmsError::Query(QueryError::AggregateClauseInSelect));
1095    }
1096    Ok(())
1097}
1098
1099/// Provides ordering for two optional values by direction.
1100pub fn sort_values_with_direction(
1101    a: Option<&Value>,
1102    b: Option<&Value>,
1103    direction: OrderDirection,
1104) -> Ordering {
1105    match (a, b) {
1106        (Some(a_val), Some(b_val)) => match direction {
1107            OrderDirection::Ascending => a_val.cmp(b_val),
1108            OrderDirection::Descending => b_val.cmp(a_val),
1109        },
1110        (Some(_), None) => std::cmp::Ordering::Greater,
1111        (None, Some(_)) => std::cmp::Ordering::Less,
1112        (None, None) => std::cmp::Ordering::Equal,
1113    }
1114}
1115
1116/// Converts column-value pairs to a schema entity.
1117fn values_to_schema_entity<T>(values: Vec<(ColumnDef, Value)>) -> DbmsResult<T>
1118where
1119    T: TableSchema,
1120{
1121    let record = T::Insert::from_values(&values)?.into_record();
1122    Ok(record)
1123}
1124
1125/// Builds the index key for the given columns by extracting values from the record.
1126///
1127/// Columns not found in `values` default to [`Value::Null`].
1128fn index_key(columns: &[&str], values: &[(ColumnDef, Value)]) -> Vec<Value> {
1129    columns
1130        .iter()
1131        .map(|col| {
1132            values
1133                .iter()
1134                .find(|(cd, _)| cd.name == *col)
1135                .map(|(_, v)| v.clone())
1136                .unwrap_or(Value::Null)
1137        })
1138        .collect()
1139}
1140
/// `Database` implementation for `WasmDbmsDatabase`.
///
/// These methods first call `ensure_no_drift` and fail while the stored
/// schema differs from the compiled one: every read method, `insert`,
/// `update`, `delete` and `commit`. `rollback`, `has_drift`,
/// `pending_migrations` and `migrate` skip that check.
///
/// Mutations follow one of two paths, depending on whether this handle has an
/// active transaction:
/// - inside a transaction, the operation is staged via `with_transaction_mut`
///   and applied only when `commit` replays it;
/// - otherwise, it runs inside `atomic`, and every memory write goes through
///   a `JournaledWriter` bound to the active journal.
impl<M, A> Database for WasmDbmsDatabase<'_, M, A>
where
    M: MemoryProvider,
    A: AccessControl,
{
    /// Typed `SELECT` over table `T`, returning one `T::Record` per row.
    ///
    /// # Errors
    ///
    /// - schema drift (via `ensure_no_drift`);
    /// - `QueryError::JoinInsideTypedSelect` if the query has joins: each
    ///   result row is mapped into a single `T::Record` (use `select_join`
    ///   for joins);
    /// - any error from `select_columns`.
    fn select<T>(&self, query: Query) -> DbmsResult<Vec<T::Record>>
    where
        T: TableSchema,
    {
        self.ensure_no_drift()?;
        if !query.joins.is_empty() {
            return Err(DbmsError::Query(QueryError::JoinInsideTypedSelect));
        }
        let results = self.select_columns::<T>(query)?;
        Ok(results.into_iter().map(T::Record::from_values).collect())
    }

    /// Untyped `SELECT` on the table named `table`.
    ///
    /// After the drift check, delegates to `self.schema.select`, which
    /// returns rows as column/value pairs.
    fn select_raw(&self, table: &str, query: Query) -> DbmsResult<Vec<Vec<(ColumnDef, Value)>>> {
        self.ensure_no_drift()?;
        self.schema.select(self, table, query)
    }

    /// Join query rooted at `table`.
    ///
    /// After the drift check, runs `select_join_inner`, which also rejects
    /// `GROUP BY` / `HAVING` clauses.
    fn select_join(
        &self,
        table: &str,
        query: Query,
    ) -> DbmsResult<Vec<Vec<(JoinColumnDef, Value)>>> {
        self.ensure_no_drift()?;
        self.select_join_inner(table, query)
    }

    /// Aggregate query (`aggregates`, optionally with `GROUP BY` / `HAVING`)
    /// over table `T`.
    ///
    /// After the drift check, delegates to `aggregate::run_aggregate`.
    fn aggregate<T>(
        &self,
        query: Query,
        aggregates: &[AggregateFunction],
    ) -> DbmsResult<Vec<AggregatedRow>>
    where
        T: TableSchema,
    {
        self.ensure_no_drift()?;
        aggregate::run_aggregate::<T, _, _>(self, query, aggregates)
    }

    /// Inserts `record` into table `T`.
    ///
    /// Missing auto-increment columns are filled, then values are sanitized
    /// and validated by the schema. After that the row is either staged on
    /// the active transaction, or written atomically together with its index
    /// entries.
    ///
    /// NOTE(review): auto-increment values are drawn (via
    /// `fill_auto_increment_values`, which talks to `mm` directly rather than
    /// through a `JournaledWriter`) before validation and before the
    /// transaction check. A rejected insert, or a staged insert later rolled
    /// back, may therefore still consume those values. Confirm this is
    /// intended.
    fn insert<T>(&self, record: T::Insert) -> DbmsResult<()>
    where
        T: TableSchema,
        T::Insert: InsertRecord<Schema = T>,
    {
        self.ensure_no_drift()?;
        let mut table_registry = self.load_table_registry::<T>()?;
        let record_values = record.clone().into_values();
        let record_values =
            self.fill_auto_increment_values::<T>(&mut table_registry, record_values)?;
        let sanitized_values = self.sanitize_values::<T>(record_values)?;
        self.schema
            .validate_insert(self, T::table_name(), &sanitized_values)?;
        if self.transaction.is_some() {
            // Staged only; applied (and re-validated) by `commit`.
            self.with_transaction_mut(|tx| tx.insert::<T>(sanitized_values))?;
        } else {
            self.atomic(|db| {
                let record = T::Insert::from_values(&sanitized_values)?;
                let mut mm = db.ctx.mm.borrow_mut();
                // Writes go through `JournaledWriter`, which logs them to the
                // active journal before memory is mutated.
                let mut journal_ref = db.ctx.journal.borrow_mut();
                let journal = journal_ref
                    .as_mut()
                    .expect("journal must be active inside atomic");
                let mut writer = JournaledWriter::new(&mut *mm, journal);
                // Store the record, then register its address in every index of `T`.
                let record_address = table_registry
                    .insert(record.into_record(), &mut writer)
                    .map_err(DbmsError::from)?;
                self.insert_index::<T>(
                    &mut table_registry,
                    record_address,
                    &sanitized_values,
                    &mut writer,
                )?;
                Ok(())
            })?;
        }

        Ok(())
    }

    /// Applies `patch` to every row of `T` matching its `WHERE` clause.
    ///
    /// Returns the number of updated rows. On the non-transactional path,
    /// when the patch changes the primary key, the count also includes rows
    /// in referencing tables whose foreign keys were rewritten.
    fn update<T>(&self, patch: T::Update) -> DbmsResult<u64>
    where
        T: TableSchema,
        T::Update: UpdateRecord<Schema = T>,
    {
        self.ensure_no_drift()?;
        let filter = patch.where_clause().clone();
        if self.transaction.is_some() {
            // Resolve the affected rows now (through `select`, so the overlay
            // is visible) to report the count, then stage the patch. `commit`
            // later replays it using the stored filter.
            let rows = self.existing_rows_for_filter::<T>(filter.clone())?;
            let count = rows.len() as u64;
            self.with_transaction_mut(|tx| tx.update::<T>(patch, filter, rows))?;

            return Ok(count);
        }

        let patch = patch.update_values();

        // The new primary-key value, if the patch rewrites the key.
        let pk_in_patch = patch.iter().find_map(|(col_def, value)| {
            if col_def.primary_key {
                Some((col_def, value))
            } else {
                None
            }
        });

        self.atomic(|db| {
            let mut count = 0;

            let mut table_registry = db.load_table_registry::<T>()?;
            let records = db.collect_matching_records::<T>(&table_registry, &filter)?;

            for (record, record_values) in records {
                // Key of the row as currently stored; used for validation and
                // for propagating a key change to referencing tables.
                let current_pk_value = record_values
                    .iter()
                    .find(|(col_def, _)| col_def.primary_key)
                    .expect("primary key not found")
                    .1
                    .clone();

                let previous_record = values_to_schema_entity::<T>(record_values.clone())?;
                let old_values_for_index = record_values.clone();
                let mut record_values = record_values;

                // Overlay patch values onto the stored row by column name.
                // Patch columns not present in the row are ignored.
                for (patch_col_def, patch_value) in &patch {
                    if let Some((_, record_value)) = record_values
                        .iter_mut()
                        .find(|(record_col_def, _)| record_col_def.name == patch_col_def.name)
                    {
                        *record_value = patch_value.clone();
                    }
                }
                // The merged row is sanitized and validated before any write.
                let record_values = db.sanitize_values::<T>(record_values)?;
                db.schema.validate_update(
                    db,
                    T::table_name(),
                    &record_values,
                    current_pk_value.clone(),
                )?;
                let updated_record = values_to_schema_entity::<T>(record_values.clone())?;
                // The `mm` / journal borrows are scoped to this block. They are
                // released before the primary-key propagation below, which
                // issues further updates through `self.schema` (presumably
                // borrowing them again).
                {
                    let mut mm = db.ctx.mm.borrow_mut();
                    // Writes go through `JournaledWriter`, which logs them to
                    // the active journal before memory is mutated.
                    let mut journal_ref = db.ctx.journal.borrow_mut();
                    let journal = journal_ref
                        .as_mut()
                        .expect("journal must be active inside atomic");
                    let mut writer = JournaledWriter::new(&mut *mm, journal);
                    // Rewrite the row. The registry returns its (possibly moved) address.
                    let old_address = RecordAddress::new(record.page, record.offset);
                    let new_address = table_registry
                        .update(updated_record, previous_record, old_address, &mut writer)
                        .map_err(DbmsError::from)?;
                    // Re-key and/or re-point index entries for the new values and address.
                    self.update_index::<T>(
                        &mut table_registry,
                        old_address,
                        new_address,
                        &old_values_for_index,
                        &record_values,
                        &mut writer,
                    )?;
                }
                count += 1;

                // A primary-key change is cascaded to every referencing column.
                if let Some((pk_column, new_pk_value)) = pk_in_patch {
                    count += db.update_pk_referencing_updated_table::<T>(
                        current_pk_value,
                        new_pk_value.clone(),
                        pk_column.data_type,
                        pk_column.name,
                    )?;
                }
            }

            Ok(count)
        })
    }

    /// Deletes every row of `T` matching `filter` (all rows when `None`).
    ///
    /// Foreign-key references to each row are handled per `behaviour`:
    /// - `Cascade`: referencing rows are deleted first, and their count is
    ///   added to the result;
    /// - `Restrict`: the whole operation fails with
    ///   `QueryError::ForeignKeyConstraintViolation` if any reference exists.
    ///
    /// Returns the number of deleted rows.
    fn delete<T>(&self, behaviour: DeleteBehavior, filter: Option<Filter>) -> DbmsResult<u64>
    where
        T: TableSchema,
    {
        self.ensure_no_drift()?;
        if self.transaction.is_some() {
            // Resolve the affected rows now (overlay included) for the count,
            // then stage the delete for replay at commit time.
            let rows = self.existing_rows_for_filter::<T>(filter.clone())?;
            let count = rows.len() as u64;

            self.with_transaction_mut(|tx| tx.delete::<T>(behaviour, filter, rows))?;

            return Ok(count);
        }

        self.atomic(|db| {
            let mut table_registry = db.load_table_registry::<T>()?;
            let records = db.collect_matching_records::<T>(&table_registry, &filter)?;
            let mut count = records.len() as u64;
            for (record, record_values) in records {
                // Referential handling runs before this row is removed, and
                // before the `mm` / journal borrows below are taken.
                match behaviour {
                    DeleteBehavior::Cascade => {
                        count += db.delete_foreign_keys_cascade::<T>(&record_values)?;
                    }
                    DeleteBehavior::Restrict => {
                        if db.has_foreign_key_references::<T>(&record_values)? {
                            // NOTE(review): `referencing_table` is filled with this
                            // table's own name (the referenced side), not the
                            // table that holds the reference. Confirm this is the
                            // intended error payload.
                            return Err(DbmsError::Query(
                                QueryError::ForeignKeyConstraintViolation {
                                    referencing_table: T::table_name().to_string(),
                                    field: T::primary_key().to_string(),
                                },
                            ));
                        }
                    }
                }
                let mut mm = db.ctx.mm.borrow_mut();
                let mut journal_ref = db.ctx.journal.borrow_mut();
                let journal = journal_ref
                    .as_mut()
                    .expect("journal must be active inside atomic");
                // Table and index deletions go through the journal before memory is mutated.
                let mut writer = JournaledWriter::new(&mut *mm, journal);
                let address = RecordAddress::new(record.page, record.offset);
                table_registry
                    .delete(record.record, address, &mut writer)
                    .map_err(DbmsError::from)?;
                self.delete_index::<T>(&mut table_registry, address, &record_values, &mut writer)?;
            }

            Ok(count)
        })
    }

    /// Commits the transaction attached to this handle.
    ///
    /// The staged operations are taken out of the transaction session and
    /// replayed in order under a freshly installed journal. If any operation
    /// fails, the journal is rolled back and that error is returned.
    /// Otherwise the journal is committed.
    ///
    /// The transaction is consumed either way. After a failed replay it is
    /// not restored, so it cannot be retried.
    ///
    /// # Errors
    ///
    /// - schema drift (checked before the transaction is taken, so the
    ///   transaction stays active in that case);
    /// - `TransactionError::NoActiveTransaction` if no transaction is active;
    /// - any error from `take_transaction` or from replaying an operation.
    ///
    /// # Panics
    ///
    /// Panics if rolling the journal back itself fails.
    fn commit(&mut self) -> DbmsResult<()> {
        self.ensure_no_drift()?;
        let Some(txid) = self.transaction.take() else {
            return Err(DbmsError::Transaction(
                TransactionError::NoActiveTransaction,
            ));
        };
        let transaction = {
            let mut ts = self.ctx.transaction_session.borrow_mut();
            ts.take_transaction(&txid)?
        };

        // Every replayed write is recorded here, so a failure part-way
        // through can undo the operations already applied.
        *self.ctx.journal.borrow_mut() = Some(Journal::new());

        // `self.transaction` is already `None` at this point, so (presumably)
        // the `self.schema.*` calls below take the non-transactional paths.
        for op in transaction.operations {
            let result = match op {
                // Inserts are validated again at replay time, presumably
                // because committed data may have changed since staging.
                TransactionOp::Insert { table, values } => self
                    .schema
                    .validate_insert(self, table, &values)
                    .and_then(|()| self.schema.insert(self, table, &values)),
                // Deletes and updates replay their stored filter, so the rows
                // they touch are re-resolved against current data.
                TransactionOp::Delete {
                    table,
                    behaviour,
                    filter,
                } => self
                    .schema
                    .delete(self, table, behaviour, filter)
                    .map(|_| ()),
                TransactionOp::Update {
                    table,
                    patch,
                    filter,
                } => self.schema.update(self, table, &patch, filter).map(|_| ()),
            };

            if let Err(err) = result {
                // Undo everything replayed so far before reporting the error.
                if let Some(journal) = self.ctx.journal.borrow_mut().take() {
                    journal
                        .rollback(&mut self.ctx.mm.borrow_mut())
                        .expect("critical: failed to rollback journal");
                }
                return Err(err);
            }
        }

        if let Some(journal) = self.ctx.journal.borrow_mut().take() {
            journal.commit();
        }
        Ok(())
    }

    /// Discards the transaction attached to this handle.
    ///
    /// The transaction is closed in the session, which drops its staged
    /// operations; they were never applied to storage. No drift check is
    /// performed.
    ///
    /// NOTE(review): auto-increment values already drawn by staged inserts
    /// (see `insert`) are presumably not given back here.
    ///
    /// # Errors
    ///
    /// `TransactionError::NoActiveTransaction` if no transaction is active.
    fn rollback(&mut self) -> DbmsResult<()> {
        let Some(txid) = self.transaction.take() else {
            return Err(DbmsError::Transaction(
                TransactionError::NoActiveTransaction,
            ));
        };

        let mut ts = self.ctx.transaction_session.borrow_mut();
        ts.close_transaction(&txid);
        Ok(())
    }

    /// Reports whether the stored schema differs from the compiled one
    /// (delegates to `drift_check`).
    fn has_drift(&self) -> DbmsResult<bool> {
        self.drift_check()
    }

    /// Lists the migration operations needed to move the stored schema
    /// snapshots to the compiled ones.
    ///
    /// The stored snapshots are read from the schema registry; the registry
    /// and `mm` borrows are released before the diff runs.
    fn pending_migrations(&self) -> DbmsResult<Vec<MigrationOp>> {
        let stored = {
            let sr = self.ctx.schema_registry.borrow();
            let mut mm = self.ctx.mm.borrow_mut();
            sr.stored_snapshots(&mut *mm)?
        };
        let compiled = self.schema.compiled_snapshots_dyn();
        migration::diff::diff(&stored, &compiled, self.schema.as_ref())
    }

    /// Computes the pending migration ops, validates them against `policy`,
    /// orders them, and applies them.
    ///
    /// No drift check is performed here, presumably because migrating is how
    /// drift gets resolved.
    fn migrate(&mut self, policy: MigrationPolicy) -> DbmsResult<()> {
        let mut ops = self.pending_migrations()?;
        migration::plan::validate(&ops, policy)?;
        migration::plan::order_ops(&mut ops);
        migration::apply::apply(self, ops)
    }
}
1460
1461#[cfg(test)]
1462mod tests;