ic_dbms_canister/
dbms.rs

1//! This module exposes all the types related to the DBMS engine.
2
3pub mod integrity;
4pub mod referenced_tables;
5pub mod schema;
6pub mod transaction;
7
8use ic_dbms_api::prelude::{
9    ColumnDef, Database, DeleteBehavior, Filter, ForeignFetcher, IcDbmsError, IcDbmsResult,
10    InsertRecord, OrderDirection, Query, QueryError, TableColumns, TableError, TableRecord,
11    TableSchema, TransactionError, TransactionId, UpdateRecord, Value, ValuesSource,
12};
13
14use crate::dbms::transaction::{DatabaseOverlay, Transaction, TransactionOp};
15use crate::memory::{SCHEMA_REGISTRY, TableRegistry};
16use crate::prelude::{DatabaseSchema, TRANSACTION_SESSION};
17use crate::utils::trap;
18
/// Number of result slots pre-allocated by a SELECT query when no limit is given.
///
/// This is an allocation hint only: it never caps the number of rows a query returns.
const DEFAULT_SELECT_LIMIT: usize = 128;
21
22/// The main DBMS struct.
23///
24/// This struct serves as the entry point for interacting with the DBMS engine.
25///
26/// It provides methods for executing queries.
27///
28/// - [`Database::select`] - Execute a SELECT query.
29/// - [`Database::insert`] - Execute an INSERT query.
30/// - [`Database::update`] - Execute an UPDATE query.
31/// - [`Database::delete`] - Execute a DELETE query.
32/// - [`Database::commit`] - Commit the current transaction.
33/// - [`Database::rollback`] - Rollback the current transaction.
34///
35/// The `transaction` field indicates whether the instance is operating within a transaction context.
36/// The [`Database`] can be instantiated for one-shot, with [`Database::oneshot`] operations (no transaction),
37/// or within a transaction context with [`Database::from_transaction`].
38///
39/// If a transaction is active, all operations will be part of that transaction until it is committed or rolled back.
40pub struct IcDbmsDatabase {
41    /// Database schema to perform generic operations, without knowing the concrete table schema at compile time.
42    schema: Box<dyn DatabaseSchema>,
43    /// Id of the loaded transaction, if any.
44    transaction: Option<TransactionId>,
45}
46
47impl IcDbmsDatabase {
48    /// Load an instance of the [`Database`] for one-shot operations (no transaction).
49    pub fn oneshot(schema: impl DatabaseSchema + 'static) -> Self {
50        Self {
51            schema: Box::new(schema),
52            transaction: None,
53        }
54    }
55
56    /// Load an instance of the [`Database`] within a transaction context.
57    pub fn from_transaction(
58        schema: impl DatabaseSchema + 'static,
59        transaction_id: TransactionId,
60    ) -> Self {
61        Self {
62            schema: Box::new(schema),
63            transaction: Some(transaction_id),
64        }
65    }
66
67    /// Executes a closure with a mutable reference to the current [`Transaction`].
68    fn with_transaction_mut<F, R>(&self, f: F) -> IcDbmsResult<R>
69    where
70        F: FnOnce(&mut Transaction) -> IcDbmsResult<R>,
71    {
72        let txid = self.transaction.as_ref().ok_or(IcDbmsError::Transaction(
73            TransactionError::NoActiveTransaction,
74        ))?;
75
76        TRANSACTION_SESSION.with_borrow_mut(|ts| {
77            let tx = ts.get_transaction_mut(txid)?;
78            f(tx)
79        })
80    }
81
82    /// Executes a closure with a reference to the current [`Transaction`].
83    fn with_transaction<F, R>(&self, f: F) -> IcDbmsResult<R>
84    where
85        F: FnOnce(&Transaction) -> IcDbmsResult<R>,
86    {
87        let txid = self.transaction.as_ref().ok_or(IcDbmsError::Transaction(
88            TransactionError::NoActiveTransaction,
89        ))?;
90
91        TRANSACTION_SESSION.with_borrow_mut(|ts| {
92            let tx = ts.get_transaction_mut(txid)?;
93            f(tx)
94        })
95    }
96
97    /// Executes a closure atomically within the database context.
98    ///
99    /// If the closure returns an error, the changes are rolled back by trapping the canister.
100    fn atomic<F, R>(&self, f: F) -> R
101    where
102        F: FnOnce(&IcDbmsDatabase) -> IcDbmsResult<R>,
103    {
104        match f(self) {
105            Ok(res) => res,
106            Err(err) => trap(err.to_string()),
107        }
108    }
109
110    /// Deletes foreign key related records recursively if the delete behavior is [`DeleteBehavior::Cascade`].
111    fn delete_foreign_keys_cascade<T>(
112        &self,
113        record_values: &[(ColumnDef, Value)],
114    ) -> IcDbmsResult<u64>
115    where
116        T: TableSchema,
117    {
118        let mut count = 0;
119        // verify referenced tables for foreign key constraints
120        for (table, columns) in self.schema.referenced_tables(T::table_name()) {
121            for column in columns.iter() {
122                // prepare filter
123                let pk = record_values
124                    .iter()
125                    .find(|(col_def, _)| col_def.primary_key)
126                    .ok_or(IcDbmsError::Query(QueryError::UnknownColumn(
127                        column.to_string(),
128                    )))?
129                    .1
130                    .clone();
131                // make filter to find records in the referenced table
132                let filter = Filter::eq(column, pk);
133                let res = self
134                    .schema
135                    .delete(self, table, DeleteBehavior::Cascade, Some(filter))?;
136                count += res;
137            }
138        }
139        Ok(count)
140    }
141
142    /// Retrieves the current [`DatabaseOverlay`].
143    fn overlay(&self) -> IcDbmsResult<DatabaseOverlay> {
144        self.with_transaction(|tx| Ok(tx.overlay().clone()))
145    }
146
147    /// Returns whether the read given record matches the provided filter.
148    fn record_matches_filter(
149        &self,
150        record_values: &[(ColumnDef, Value)],
151        filter: &Filter,
152    ) -> IcDbmsResult<bool> {
153        filter.matches(record_values).map_err(IcDbmsError::from)
154    }
155
156    /// Select only the queried fields from the given record values.
157    ///
158    /// It also loads eager relations if any.
159    fn select_queried_fields<T>(
160        &self,
161        mut record_values: Vec<(ColumnDef, Value)>,
162        query: &Query<T>,
163    ) -> IcDbmsResult<TableColumns>
164    where
165        T: TableSchema,
166    {
167        let mut queried_fields = vec![];
168
169        // handle eager relations
170        // FIXME: currently we fetch the FK for each record, which is shit.
171        // In the future, we should batch fetch foreign keys for all records in the result set.
172        for relation in &query.eager_relations {
173            let mut fetched = false;
174            // iter all foreign key with that table
175            for (fk, fk_value) in record_values
176                .iter()
177                .filter(|(col_def, _)| {
178                    col_def
179                        .foreign_key
180                        .is_some_and(|fk| fk.foreign_table == *relation)
181                })
182                .map(|(col, value)| {
183                    (
184                        col.foreign_key.as_ref().expect("cannot be empty"),
185                        value.clone(),
186                    )
187                })
188            {
189                // get foreign values
190                queried_fields.extend(T::foreign_fetcher().fetch(
191                    self,
192                    relation,
193                    fk.local_column,
194                    fk_value,
195                )?);
196                fetched = true;
197            }
198
199            if !fetched {
200                return Err(IcDbmsError::Query(QueryError::InvalidQuery(format!(
201                    "Cannot load relation '{}' for table '{}': no foreign key found",
202                    relation,
203                    T::table_name()
204                ))));
205            }
206        }
207
208        // short-circuit if all selected
209        if query.all_selected() {
210            queried_fields.extend(vec![(ValuesSource::This, record_values)]);
211            return Ok(queried_fields);
212        }
213        record_values.retain(|(col_def, _)| query.columns().contains(&col_def.name.to_string()));
214        queried_fields.extend(vec![(ValuesSource::This, record_values)]);
215        Ok(queried_fields)
216    }
217
218    /// Retrieves existing primary keys for records matching the given filter.
219    fn existing_primary_keys_for_filter<T>(
220        &self,
221        filter: Option<Filter>,
222    ) -> IcDbmsResult<Vec<Value>>
223    where
224        T: TableSchema,
225    {
226        let pk = T::primary_key();
227        let fields = self.select(Query::<T>::builder().filter(filter).build())?;
228        let pks = fields
229            .into_iter()
230            .map(|record| {
231                record
232                    .to_values()
233                    .into_iter()
234                    .find(|(col_def, _value)| col_def.name == pk)
235                    .expect("primary key not found") // this can't fail.
236                    .1
237            })
238            .collect::<Vec<Value>>();
239
240        Ok(pks)
241    }
242
243    /// Load the table registry for the given table schema.
244    fn load_table_registry<T>(&self) -> IcDbmsResult<TableRegistry>
245    where
246        T: TableSchema,
247    {
248        // get pages of the table registry from schema registry
249        let registry_pages = SCHEMA_REGISTRY
250            .with_borrow(|schema| schema.table_registry_page::<T>())
251            .ok_or(IcDbmsError::Table(TableError::TableNotFound))?;
252
253        TableRegistry::load(registry_pages).map_err(IcDbmsError::from)
254    }
255
256    /// Sorts the query results based on the specified column and order direction.
257    ///
258    /// We only sort values which have [`ValuesSource::This`].
259    #[allow(clippy::type_complexity)]
260    fn sort_query_results(
261        &self,
262        results: &mut [Vec<(ValuesSource, Vec<(ColumnDef, Value)>)>],
263        column: &str,
264        direction: OrderDirection,
265    ) {
266        results.sort_by(|a, b| {
267            let a_value = a
268                .iter()
269                .find(|(source, _)| *source == ValuesSource::This)
270                .and_then(|(_, cols)| {
271                    cols.iter()
272                        .find(|(col_def, _)| col_def.name == column)
273                        .map(|(_, value)| value)
274                });
275            let b_value = b
276                .iter()
277                .find(|(source, _)| *source == ValuesSource::This)
278                .and_then(|(_, cols)| {
279                    cols.iter()
280                        .find(|(col_def, _)| col_def.name == column)
281                        .map(|(_, value)| value)
282                });
283
284            match (a_value, b_value) {
285                (Some(a_val), Some(b_val)) => match direction {
286                    OrderDirection::Ascending => a_val.cmp(b_val),
287                    OrderDirection::Descending => b_val.cmp(a_val),
288                },
289                (Some(_), None) => std::cmp::Ordering::Greater,
290                (None, Some(_)) => std::cmp::Ordering::Less,
291                (None, None) => std::cmp::Ordering::Equal,
292            }
293        });
294    }
295}
296
297impl Database for IcDbmsDatabase {
298    /// Executes a SELECT query and returns the results.
299    ///
300    /// # Arguments
301    ///
302    /// - `query` - The SELECT [`Query`] to be executed.
303    ///
304    /// # Returns
305    ///
306    /// The returned results are a vector of [`table::TableRecord`] matching the query.
307    fn select<T>(&self, query: Query<T>) -> IcDbmsResult<Vec<T::Record>>
308    where
309        T: TableSchema,
310    {
311        // load table registry
312        let table_registry = self.load_table_registry::<T>()?;
313        // read table
314        let table_reader = table_registry.read::<T>();
315        // get database overlay
316        let mut table_overlay = if self.transaction.is_some() {
317            self.overlay()?
318        } else {
319            DatabaseOverlay::default()
320        };
321        // overlay table reader
322        let mut table_reader = table_overlay.reader(table_reader);
323
324        // prepare results vector
325        let mut results = Vec::with_capacity(query.limit.unwrap_or(DEFAULT_SELECT_LIMIT));
326        // iter and select
327        let mut count = 0;
328
329        while let Some(values) = table_reader.try_next()? {
330            // check whether it matches the filter
331            if let Some(filter) = &query.filter {
332                if !self.record_matches_filter(&values, filter)? {
333                    continue;
334                }
335            }
336            // filter matched, check limit and offset
337            count += 1;
338            // check whether is before offset
339            if query.offset.is_some_and(|offset| count <= offset) {
340                continue;
341            }
342            // get queried fields
343            let values = self.select_queried_fields::<T>(values, &query)?;
344            // push to results
345            results.push(values);
346            // check whether reached limit
347            if query.limit.is_some_and(|limit| results.len() >= limit) {
348                break;
349            }
350        }
351
352        // sort results if needed and map to records
353        for (column, direction) in query.order_by {
354            self.sort_query_results(&mut results, &column, direction);
355        }
356
357        Ok(results.into_iter().map(T::Record::from_values).collect())
358    }
359
360    /// Executes an INSERT query.
361    ///
362    /// # Arguments
363    ///
364    /// - `record` - The INSERT record to be executed.
365    fn insert<T>(&self, record: T::Insert) -> IcDbmsResult<()>
366    where
367        T: TableSchema,
368        T::Insert: InsertRecord<Schema = T>,
369    {
370        // check whether the insert is valid
371        let record_values = record.clone().into_values();
372        self.schema
373            .validate_insert(self, T::table_name(), &record_values)?;
374
375        if self.transaction.is_some() {
376            // insert a new `insert` into the transaction
377            self.with_transaction_mut(|tx| tx.insert::<T>(record_values))?;
378        } else {
379            // insert directly into the database
380            let mut table_registry = self.load_table_registry::<T>()?;
381            table_registry.insert(record.into_record())?;
382        }
383
384        Ok(())
385    }
386
387    /// Executes an UPDATE query.
388    ///
389    /// # Arguments
390    ///
391    /// - `patch` - The UPDATE patch to be applied.
392    /// - `filter` - An optional [`Filter`] to specify which records to update.
393    ///
394    /// # Returns
395    ///
396    /// The number of rows updated.
397    fn update<T>(&self, patch: T::Update) -> IcDbmsResult<u64>
398    where
399        T: TableSchema,
400        T::Update: UpdateRecord<Schema = T>,
401    {
402        // get all records matching the filter
403        let query = Query::<T>::builder().filter(patch.where_clause()).build();
404        let records = self.select::<T>(query)?;
405        let count = records.len() as u64;
406
407        if self.transaction.is_some() {
408            let filter = patch.where_clause().clone();
409            let pks = self.existing_primary_keys_for_filter::<T>(filter.clone())?;
410            // insert a new `update` into the transaction
411            self.with_transaction_mut(|tx| tx.update::<T>(patch, filter, pks))?;
412
413            return Ok(count);
414        }
415
416        let patch = patch.update_values();
417        // convert updates to values
418        // for each record apply update; delete and insert
419        let res = self.atomic(|db| {
420            for record in records {
421                let mut record_values = record.to_values();
422                // apply patch
423                for (col_def, value) in &patch {
424                    if let Some((_, record_value)) = record_values
425                        .iter_mut()
426                        .find(|(record_col_def, _)| record_col_def.name == col_def.name)
427                    {
428                        *record_value = value.clone();
429                    }
430                }
431                // create insert record
432                let insert_record = T::Insert::from_values(&record_values)?;
433                // delete old record
434                let pk = record_values
435                    .iter()
436                    .find(|(col_def, _)| col_def.primary_key)
437                    .expect("primary key not found") // this can't fail.
438                    .1
439                    .clone();
440                db.delete::<T>(
441                    DeleteBehavior::Break, // we just want to delete the old record
442                    Some(Filter::eq(T::primary_key(), pk)),
443                )?;
444                // insert new record
445                db.insert::<T>(insert_record)?;
446            }
447            Ok(count)
448        });
449
450        Ok(res)
451    }
452
453    /// Executes a DELETE query.
454    ///
455    /// # Arguments
456    ///
457    /// - `behaviour` - The [`DeleteBehavior`] to apply for foreign key constraints.
458    /// - `filter` - An optional [`Filter`] to specify which records to delete.
459    ///
460    /// # Returns
461    ///
462    /// The number of rows deleted.
463    fn delete<T>(&self, behaviour: DeleteBehavior, filter: Option<Filter>) -> IcDbmsResult<u64>
464    where
465        T: TableSchema,
466    {
467        if self.transaction.is_some() {
468            let pks = self.existing_primary_keys_for_filter::<T>(filter.clone())?;
469            let count = pks.len() as u64;
470
471            self.with_transaction_mut(|tx| tx.delete::<T>(behaviour, filter, pks))?;
472
473            return Ok(count);
474        }
475
476        // delete must be atomic
477        let res = self.atomic(|db| {
478            // delete directly from the database
479            // select all records matching the filter
480            // read table
481            let mut table_registry = db.load_table_registry::<T>()?;
482            let mut records = vec![];
483            // iter all records
484            // FIXME: this may be huge, we should do better
485            {
486                let mut table_reader = table_registry.read::<T>();
487                while let Some(values) = table_reader.try_next()? {
488                    let record_values = values.record.clone().to_values();
489                    if let Some(filter) = &filter {
490                        if !db.record_matches_filter(&record_values, filter)? {
491                            continue;
492                        }
493                    }
494                    records.push((values, record_values));
495                }
496            }
497            // deleted records
498            let mut count = records.len() as u64;
499            for (record, record_values) in records {
500                // match delete behaviour
501                match behaviour {
502                    DeleteBehavior::Cascade => {
503                        // delete recursively foreign keys if cascade
504                        count += self.delete_foreign_keys_cascade::<T>(&record_values)?;
505                    }
506                    DeleteBehavior::Restrict => {
507                        if self.delete_foreign_keys_cascade::<T>(&record_values)? > 0 {
508                            // it's okay; we panic here because we are in an atomic closure
509                            return Err(IcDbmsError::Query(
510                                QueryError::ForeignKeyConstraintViolation {
511                                    referencing_table: T::table_name().to_string(),
512                                    field: T::primary_key().to_string(),
513                                },
514                            ));
515                        }
516                    }
517                    DeleteBehavior::Break => {
518                        // do nothing
519                    }
520                }
521                // eventually delete the record
522                table_registry.delete(record.record, record.page, record.offset)?;
523            }
524
525            Ok(count)
526        });
527
528        Ok(res)
529    }
530
531    /// Commits the current transaction.
532    ///
533    /// The transaction is consumed.
534    ///
535    /// Any error during commit will trap the canister to ensure consistency.
536    fn commit(&mut self) -> IcDbmsResult<()> {
537        // take transaction out of self and get the transaction out of the storage
538        // this also invalidates the overlay, so we won't have conflicts during validation
539        let Some(txid) = self.transaction.take() else {
540            return Err(IcDbmsError::Transaction(
541                TransactionError::NoActiveTransaction,
542            ));
543        };
544        let transaction = TRANSACTION_SESSION.with_borrow_mut(|ts| ts.take_transaction(&txid))?;
545
546        // iterate over operations and apply them;
547        // for each operation, first validate, then apply
548        // using `self.atomic` when applying to ensure consistency
549        for op in transaction.operations {
550            match op {
551                TransactionOp::Insert { table, values } => {
552                    // validate
553                    self.schema.validate_insert(self, table, &values)?;
554                    // insert
555                    self.atomic(|db| db.schema.insert(db, table, &values));
556                }
557                TransactionOp::Delete {
558                    table,
559                    behaviour,
560                    filter,
561                } => {
562                    self.atomic(|db| db.schema.delete(db, table, behaviour, filter));
563                }
564                TransactionOp::Update {
565                    table,
566                    patch,
567                    filter,
568                } => {
569                    self.atomic(|db| db.schema.update(db, table, &patch, filter));
570                }
571            }
572        }
573
574        Ok(())
575    }
576
577    /// Rolls back the current transaction.
578    ///
579    /// The transaction is consumed.
580    fn rollback(&mut self) -> IcDbmsResult<()> {
581        let Some(txid) = self.transaction.take() else {
582            return Err(IcDbmsError::Transaction(
583                TransactionError::NoActiveTransaction,
584            ));
585        };
586
587        TRANSACTION_SESSION.with_borrow_mut(|ts| ts.close_transaction(&txid));
588        Ok(())
589    }
590}
591
592#[cfg(test)]
593mod tests {
594
595    use candid::{Nat, Principal};
596    use ic_dbms_api::prelude::{Text, Uint32};
597
598    use super::*;
599    use crate::tests::{
600        Message, POSTS_FIXTURES, Post, TestDatabaseSchema, USERS_FIXTURES, User, UserInsertRequest,
601        UserUpdateRequest, load_fixtures,
602    };
603
604    #[test]
605    fn test_should_init_dbms() {
606        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
607        assert!(dbms.transaction.is_none());
608
609        let tx_dbms = IcDbmsDatabase::from_transaction(TestDatabaseSchema, Nat::from(1u64));
610        assert!(tx_dbms.transaction.is_some());
611    }
612
613    #[test]
614    fn test_should_select_all_users() {
615        load_fixtures();
616        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
617        let query = Query::<User>::builder().all().build();
618        let users = dbms.select(query).expect("failed to select users");
619
620        assert_eq!(users.len(), USERS_FIXTURES.len());
621        // check if all users all loaded
622        for (i, user) in users.iter().enumerate() {
623            assert_eq!(user.id.expect("should have id").0 as usize, i);
624            assert_eq!(
625                user.name.as_ref().expect("should have name").0,
626                USERS_FIXTURES[i]
627            );
628        }
629    }
630
631    #[test]
632    fn test_should_select_user_in_overlay() {
633        load_fixtures();
634        // create a transaction
635        let transaction_id =
636            TRANSACTION_SESSION.with_borrow_mut(|ts| ts.begin_transaction(Principal::anonymous()));
637        // insert
638        TRANSACTION_SESSION.with_borrow_mut(|ts| {
639            let tx = ts
640                .get_transaction_mut(&transaction_id)
641                .expect("should have tx");
642            tx.overlay_mut()
643                .insert::<User>(vec![
644                    (
645                        ColumnDef {
646                            name: "id",
647                            data_type: ic_dbms_api::prelude::DataTypeKind::Uint32,
648                            nullable: false,
649                            primary_key: true,
650                            foreign_key: None,
651                        },
652                        Value::Uint32(999.into()),
653                    ),
654                    (
655                        ColumnDef {
656                            name: "name",
657                            data_type: ic_dbms_api::prelude::DataTypeKind::Text,
658                            nullable: false,
659                            primary_key: false,
660                            foreign_key: None,
661                        },
662                        Value::Text("OverlayUser".to_string().into()),
663                    ),
664                ])
665                .expect("failed to insert");
666        });
667
668        // select by pk
669        let dbms = IcDbmsDatabase::from_transaction(TestDatabaseSchema, transaction_id);
670        let query = Query::<User>::builder()
671            .and_where(Filter::eq("id", Value::Uint32(999.into())))
672            .build();
673        let users = dbms.select(query).expect("failed to select users");
674
675        assert_eq!(users.len(), 1);
676        let user = &users[0];
677        assert_eq!(user.id.expect("should have id").0, 999);
678        assert_eq!(
679            user.name.as_ref().expect("should have name").0,
680            "OverlayUser"
681        );
682    }
683
684    #[test]
685    fn test_should_select_users_with_offset_and_limit() {
686        load_fixtures();
687        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
688        let query = Query::<User>::builder().offset(2).limit(3).build();
689        let users = dbms.select(query).expect("failed to select users");
690
691        assert_eq!(users.len(), 3);
692        // check if correct users are loaded
693        for (i, user) in users.iter().enumerate() {
694            let expected_index = i + 2;
695            assert_eq!(user.id.expect("should have id").0 as usize, expected_index);
696            assert_eq!(
697                user.name.as_ref().expect("should have name").0,
698                USERS_FIXTURES[expected_index]
699            );
700        }
701    }
702
703    #[test]
704    fn test_should_select_users_with_offset_and_filter() {
705        load_fixtures();
706        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
707        let query = Query::<User>::builder()
708            .offset(1)
709            .and_where(Filter::gt("id", Value::Uint32(4.into())))
710            .build();
711        let users = dbms.select(query).expect("failed to select users");
712
713        assert_eq!(users.len(), 4);
714        // check if correct users are loaded
715        for (i, user) in users.iter().enumerate() {
716            let expected_index = i + 6;
717            assert_eq!(user.id.expect("should have id").0 as usize, expected_index);
718            assert_eq!(
719                user.name.as_ref().expect("should have name").0,
720                USERS_FIXTURES[expected_index]
721            );
722        }
723    }
724
725    #[test]
726    fn test_should_select_post_with_relation() {
727        load_fixtures();
728        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
729        let query = Query::<Post>::builder()
730            .all()
731            .with(User::table_name())
732            .build();
733        let posts = dbms.select(query).expect("failed to select posts");
734        assert_eq!(posts.len(), POSTS_FIXTURES.len());
735
736        for (id, post) in posts.into_iter().enumerate() {
737            let (expected_title, expected_content, expected_user_id) = &POSTS_FIXTURES[id];
738            assert_eq!(post.id.expect("should have id").0 as usize, id);
739            assert_eq!(
740                post.title.as_ref().expect("should have title").0,
741                *expected_title
742            );
743            assert_eq!(
744                post.content.as_ref().expect("should have content").0,
745                *expected_content
746            );
747            let user_query = Query::<User>::builder()
748                .and_where(Filter::eq("id", Value::Uint32((*expected_user_id).into())))
749                .build();
750            let author = dbms
751                .select(user_query)
752                .expect("failed to load user")
753                .pop()
754                .expect("should have user");
755            assert_eq!(
756                post.user.expect("should have loaded user"),
757                Box::new(author)
758            );
759        }
760    }
761
762    #[test]
763    fn test_should_fail_loading_unexisting_column_on_select() {
764        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
765        let query = Query::<User>::builder().field("unexisting_column").build();
766        let result = dbms.select(query);
767        assert!(result.is_err());
768    }
769
770    #[test]
771    fn test_should_select_queried_fields() {
772        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
773
774        let record_values = User::columns()
775            .iter()
776            .cloned()
777            .zip(vec![
778                Value::Uint32(1.into()),
779                Value::Text("Alice".to_string().into()),
780            ])
781            .collect::<Vec<(ColumnDef, Value)>>();
782
783        let query: Query<User> = Query::builder().field("name").build();
784        let selected_fields = dbms
785            .select_queried_fields::<User>(record_values, &query)
786            .expect("failed to select queried fields");
787        let user_fields = selected_fields
788            .into_iter()
789            .find(|(table_name, _)| *table_name == ValuesSource::This)
790            .map(|(_, cols)| cols)
791            .unwrap_or_default();
792
793        assert_eq!(user_fields.len(), 1);
794        assert_eq!(user_fields[0].0.name, "name");
795        assert_eq!(user_fields[0].1, Value::Text("Alice".to_string().into()));
796    }
797
798    #[test]
799    fn test_should_select_queried_fields_with_relations() {
800        load_fixtures();
801        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
802
803        let record_values = Post::columns()
804            .iter()
805            .cloned()
806            .zip(vec![
807                Value::Uint32(1.into()),
808                Value::Text("Title".to_string().into()),
809                Value::Text("Content".to_string().into()),
810                Value::Uint32(2.into()), // author_id
811            ])
812            .collect::<Vec<(ColumnDef, Value)>>();
813
814        let query: Query<Post> = Query::builder()
815            .field("title")
816            .with(User::table_name())
817            .build();
818        let selected_fields = dbms
819            .select_queried_fields::<Post>(record_values, &query)
820            .expect("failed to select queried fields");
821
822        // check post fields
823        let post_fields = selected_fields
824            .iter()
825            .find(|(table_name, _)| *table_name == ValuesSource::This)
826            .map(|(_, cols)| cols)
827            .cloned()
828            .unwrap_or_default();
829        assert_eq!(post_fields.len(), 1);
830        assert_eq!(post_fields[0].0.name, "title");
831        assert_eq!(post_fields[0].1, Value::Text("Title".to_string().into()));
832
833        // check user fields
834        let user_fields = selected_fields
835            .iter()
836            .find(|(table_name, _)| {
837                *table_name
838                    == ValuesSource::Foreign {
839                        table: User::table_name().to_string(),
840                        column: "user".to_string(),
841                    }
842            })
843            .map(|(_, cols)| cols)
844            .cloned()
845            .unwrap_or_default();
846
847        let expected_user = USERS_FIXTURES[2]; // author_id = 2
848
849        assert_eq!(user_fields.len(), 2);
850        assert_eq!(user_fields[0].0.name, "id");
851        assert_eq!(user_fields[0].1, Value::Uint32(2.into()));
852        assert_eq!(user_fields[1].0.name, "name");
853        assert_eq!(
854            user_fields[1].1,
855            Value::Text(expected_user.to_string().into())
856        );
857    }
858
859    #[test]
860    fn test_should_select_with_two_fk_on_the_same_table() {
861        load_fixtures();
862
863        let query: Query<Message> = Query::builder()
864            .all()
865            .and_where(Filter::Eq("id".to_string(), Value::Uint32(0.into())))
866            .with("users")
867            .build();
868
869        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
870        let messages = dbms.select(query).expect("failed to select messages");
871        assert_eq!(messages.len(), 1);
872        let message = &messages[0];
873        assert_eq!(message.id.expect("should have id").0, 0);
874        assert_eq!(
875            message
876                .sender
877                .as_ref()
878                .expect("should have sender")
879                .name
880                .as_ref()
881                .unwrap()
882                .0,
883            "Alice"
884        );
885        assert_eq!(
886            message
887                .recipient
888                .as_ref()
889                .expect("should have recipient")
890                .name
891                .as_ref()
892                .unwrap()
893                .0,
894            "Bob"
895        );
896    }
897
898    #[test]
899    fn test_should_select_users_sorted_by_name_descending() {
900        load_fixtures();
901        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
902        let query = Query::<User>::builder().all().order_by_desc("name").build();
903        let users = dbms.select(query).expect("failed to select users");
904
905        let mut sorted_usernames = USERS_FIXTURES.to_vec();
906        sorted_usernames.sort_by(|a, b| b.cmp(a)); // descending
907
908        assert_eq!(users.len(), USERS_FIXTURES.len());
909        // check if all users all loaded in sorted order
910        for (i, user) in users.iter().enumerate() {
911            assert_eq!(
912                user.name.as_ref().expect("should have name").0,
913                sorted_usernames[i]
914            );
915        }
916    }
917
918    #[test]
919    fn test_should_fail_loading_unexisting_relation() {
920        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
921
922        let record_values = Post::columns()
923            .iter()
924            .cloned()
925            .zip(vec![
926                Value::Uint32(1.into()),
927                Value::Text("Title".to_string().into()),
928                Value::Text("Content".to_string().into()),
929                Value::Uint32(2.into()), // author_id
930            ])
931            .collect::<Vec<(ColumnDef, Value)>>();
932
933        let query: Query<Post> = Query::builder()
934            .field("title")
935            .with("unexisting_relation")
936            .build();
937        let result = dbms.select_queried_fields::<Post>(record_values, &query);
938        assert!(result.is_err());
939    }
940
941    #[test]
942    fn test_should_get_whether_record_matches_filter() {
943        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
944
945        let record_values = User::columns()
946            .iter()
947            .cloned()
948            .zip(vec![
949                Value::Uint32(1.into()),
950                Value::Text("Alice".to_string().into()),
951            ])
952            .collect::<Vec<(ColumnDef, Value)>>();
953        let filter = Filter::eq("name", Value::Text("Alice".to_string().into()));
954
955        let matches = dbms
956            .record_matches_filter(&record_values, &filter)
957            .expect("failed to match");
958        assert!(matches);
959
960        let non_matching_filter = Filter::eq("name", Value::Text("Bob".to_string().into()));
961        let non_matches = dbms
962            .record_matches_filter(&record_values, &non_matching_filter)
963            .expect("failed to match");
964        assert!(!non_matches);
965    }
966
967    #[test]
968    fn test_should_load_table_registry() {
969        init_user_table();
970
971        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
972        let table_registry = dbms.load_table_registry::<User>();
973        assert!(table_registry.is_ok());
974    }
975
976    #[test]
977    fn test_should_insert_record_without_transaction() {
978        load_fixtures();
979
980        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
981        let new_user = UserInsertRequest {
982            id: Uint32(100u32),
983            name: Text("NewUser".to_string()),
984        };
985
986        let result = dbms.insert::<User>(new_user);
987        assert!(result.is_ok());
988
989        // find user
990        let query = Query::<User>::builder()
991            .and_where(Filter::eq("id", Value::Uint32(100u32.into())))
992            .build();
993        let users = dbms.select(query).expect("failed to select users");
994        assert_eq!(users.len(), 1);
995        let user = &users[0];
996        assert_eq!(user.id.expect("should have id").0, 100);
997        assert_eq!(
998            user.name.as_ref().expect("should have name").0,
999            "NewUser".to_string()
1000        );
1001    }
1002
1003    #[test]
1004    fn test_should_validate_user_insert_conflict() {
1005        load_fixtures();
1006
1007        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
1008        let new_user = UserInsertRequest {
1009            id: Uint32(1u32),
1010            name: Text("NewUser".to_string()),
1011        };
1012
1013        let result = dbms.insert::<User>(new_user);
1014        assert!(result.is_err());
1015    }
1016
1017    #[test]
1018    fn test_should_insert_within_a_transaction() {
1019        load_fixtures();
1020
1021        // create a transaction
1022        let transaction_id =
1023            TRANSACTION_SESSION.with_borrow_mut(|ts| ts.begin_transaction(Principal::anonymous()));
1024        let mut dbms = IcDbmsDatabase::from_transaction(TestDatabaseSchema, transaction_id.clone());
1025
1026        let new_user = UserInsertRequest {
1027            id: Uint32(200u32),
1028            name: Text("TxUser".to_string()),
1029        };
1030
1031        let result = dbms.insert::<User>(new_user);
1032        assert!(result.is_ok());
1033
1034        // user should not be visible outside the transaction
1035        let oneshot_dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
1036        let query = Query::<User>::builder()
1037            .and_where(Filter::eq("id", Value::Uint32(200u32.into())))
1038            .build();
1039        let users = oneshot_dbms
1040            .select(query.clone())
1041            .expect("failed to select users");
1042        assert_eq!(users.len(), 0);
1043
1044        // commit transaction
1045        let commit_result = dbms.commit();
1046        assert!(commit_result.is_ok());
1047
1048        // now user should be visible
1049        let users_after_commit = oneshot_dbms.select(query).expect("failed to select users");
1050        assert_eq!(users_after_commit.len(), 1);
1051
1052        let user = &users_after_commit[0];
1053        assert_eq!(user.id.expect("should have id").0, 200);
1054        assert_eq!(
1055            user.name.as_ref().expect("should have name").0,
1056            "TxUser".to_string()
1057        );
1058
1059        // transaction should have been removed
1060        TRANSACTION_SESSION.with_borrow(|ts| {
1061            let tx_res = ts.get_transaction(&transaction_id);
1062            assert!(tx_res.is_err());
1063        });
1064    }
1065
1066    #[test]
1067    fn test_should_rollback_transaction() {
1068        load_fixtures();
1069
1070        // create a transaction
1071        let transaction_id =
1072            TRANSACTION_SESSION.with_borrow_mut(|ts| ts.begin_transaction(Principal::anonymous()));
1073        let mut dbms = IcDbmsDatabase::from_transaction(TestDatabaseSchema, transaction_id.clone());
1074        let new_user = UserInsertRequest {
1075            id: Uint32(300u32),
1076            name: Text("RollbackUser".to_string()),
1077        };
1078        let result = dbms.insert::<User>(new_user);
1079        assert!(result.is_ok());
1080
1081        // rollback transaction
1082        let rollback_result = dbms.rollback();
1083        assert!(rollback_result.is_ok());
1084
1085        // user should not be visible
1086        let oneshot_dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
1087        let query = Query::<User>::builder()
1088            .and_where(Filter::eq("id", Value::Uint32(300u32.into())))
1089            .build();
1090        let users = oneshot_dbms.select(query).expect("failed to select users");
1091        assert_eq!(users.len(), 0);
1092
1093        // transaction should have been removed
1094        TRANSACTION_SESSION.with_borrow(|ts| {
1095            let tx_res = ts.get_transaction(&transaction_id);
1096            assert!(tx_res.is_err());
1097        });
1098    }
1099
1100    #[test]
1101    fn test_should_delete_one_shot() {
1102        load_fixtures();
1103
1104        // insert user with id 100
1105        let new_user = UserInsertRequest {
1106            id: Uint32(100u32),
1107            name: Text("DeleteUser".to_string()),
1108        };
1109        assert!(
1110            IcDbmsDatabase::oneshot(TestDatabaseSchema)
1111                .insert::<User>(new_user)
1112                .is_ok()
1113        );
1114
1115        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
1116        let query = Query::<User>::builder()
1117            .and_where(Filter::eq("id", Value::Uint32(100u32.into())))
1118            .build();
1119        let delete_count = dbms
1120            .delete::<User>(
1121                DeleteBehavior::Restrict,
1122                Some(Filter::eq("id", Value::Uint32(100u32.into()))),
1123            )
1124            .expect("failed to delete user");
1125        assert_eq!(delete_count, 1);
1126
1127        // verify user is deleted
1128        let users = dbms.select(query).expect("failed to select users");
1129        assert_eq!(users.len(), 0);
1130    }
1131
1132    #[test]
1133    #[should_panic(expected = "Foreign key constraint violation")]
1134    fn test_should_not_delete_with_fk_restrict() {
1135        load_fixtures();
1136
1137        // user 1 has post and messages for sure.
1138        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
1139        let res = dbms
1140            .delete::<User>(
1141                DeleteBehavior::Restrict,
1142                Some(Filter::eq("id", Value::Uint32(1u32.into()))),
1143            )
1144            .expect("failed to delete user");
1145        println!("Delete result (should not show): {}", res);
1146    }
1147
1148    #[test]
1149    fn test_should_delete_with_fk_cascade() {
1150        load_fixtures();
1151
1152        // user 1 has posts and messages for sure.
1153        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
1154        let delete_count = dbms
1155            .delete::<User>(
1156                DeleteBehavior::Cascade,
1157                Some(Filter::eq("id", Value::Uint32(1u32.into()))),
1158            )
1159            .expect("failed to delete user");
1160        println!("Delete count: {}", delete_count);
1161        assert!(delete_count > 1); // at least user + posts + messages
1162
1163        // verify user is deleted
1164        let query = Query::<User>::builder()
1165            .and_where(Filter::eq("id", Value::Uint32(1u32.into())))
1166            .build();
1167        let users = dbms.select(query).expect("failed to select users");
1168        assert_eq!(users.len(), 0);
1169
1170        // check posts are deleted (post ID 2)
1171        let post_query = Query::<Post>::builder()
1172            .and_where(Filter::eq("user_id", Value::Uint32(1u32.into())))
1173            .build();
1174        let posts = dbms.select(post_query).expect("failed to select posts");
1175        assert_eq!(posts.len(), 0);
1176
1177        // check messages are deleted (message ID 1)
1178        let message_query = Query::<Message>::builder()
1179            .and_where(Filter::eq("sender_id", Value::Uint32(1u32.into())))
1180            .or_where(Filter::eq("recipient_id", Value::Uint32(1u32.into())))
1181            .build();
1182        let messages = dbms
1183            .select(message_query)
1184            .expect("failed to select messages");
1185        assert_eq!(messages.len(), 0);
1186    }
1187
1188    #[test]
1189    fn test_should_delete_within_transaction() {
1190        load_fixtures();
1191
1192        // create a transaction
1193        let transaction_id =
1194            TRANSACTION_SESSION.with_borrow_mut(|ts| ts.begin_transaction(Principal::anonymous()));
1195        let mut dbms = IcDbmsDatabase::from_transaction(TestDatabaseSchema, transaction_id.clone());
1196
1197        let delete_count = dbms
1198            .delete::<User>(
1199                DeleteBehavior::Cascade,
1200                Some(Filter::eq("id", Value::Uint32(2u32.into()))),
1201            )
1202            .expect("failed to delete user");
1203        assert!(delete_count > 0);
1204
1205        // user should not be visible outside the transaction
1206        let oneshot_dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
1207        let query = Query::<User>::builder()
1208            .and_where(Filter::eq("id", Value::Uint32(2u32.into())))
1209            .build();
1210        let users = oneshot_dbms
1211            .select(query.clone())
1212            .expect("failed to select users");
1213        assert_eq!(users.len(), 1);
1214
1215        // commit transaction
1216        let commit_result = dbms.commit();
1217        assert!(commit_result.is_ok());
1218
1219        // now user should be deleted
1220        let users_after_commit = oneshot_dbms.select(query).expect("failed to select users");
1221        assert_eq!(users_after_commit.len(), 0);
1222
1223        // check posts are deleted
1224        let post_query = Query::<Post>::builder()
1225            .and_where(Filter::eq("user_id", Value::Uint32(2u32.into())))
1226            .build();
1227        let posts = oneshot_dbms
1228            .select(post_query)
1229            .expect("failed to select posts");
1230        assert_eq!(posts.len(), 0);
1231
1232        // check messages are deleted
1233        let message_query = Query::<Message>::builder()
1234            .and_where(Filter::eq("sender_id", Value::Uint32(2u32.into())))
1235            .or_where(Filter::eq("recipient_id", Value::Uint32(2u32.into())))
1236            .build();
1237        let messages = oneshot_dbms
1238            .select(message_query)
1239            .expect("failed to select messages");
1240        assert_eq!(messages.len(), 0);
1241
1242        // transaction should have been removed
1243        TRANSACTION_SESSION.with_borrow(|ts| {
1244            let tx_res = ts.get_transaction(&transaction_id);
1245            assert!(tx_res.is_err());
1246        });
1247    }
1248
1249    #[test]
1250    fn test_should_update_one_shot() {
1251        load_fixtures();
1252
1253        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
1254        let filter = Filter::eq("id", Value::Uint32(3u32.into()));
1255
1256        let patch = UserUpdateRequest {
1257            id: None,
1258            name: Some(Text("UpdatedName".to_string())),
1259            where_clause: Some(filter.clone()),
1260        };
1261
1262        let update_count = dbms.update::<User>(patch).expect("failed to update user");
1263        assert_eq!(update_count, 1);
1264
1265        // verify user is updated
1266        let query = Query::<User>::builder().and_where(filter).build();
1267        let users = dbms.select(query).expect("failed to select users");
1268        assert_eq!(users.len(), 1);
1269        let user = &users[0];
1270        assert_eq!(user.id.expect("should have id").0, 3);
1271        assert_eq!(
1272            user.name.as_ref().expect("should have name").0,
1273            "UpdatedName".to_string()
1274        );
1275    }
1276
1277    #[test]
1278    fn test_should_update_within_transaction() {
1279        load_fixtures();
1280
1281        // create a transaction
1282        let transaction_id =
1283            TRANSACTION_SESSION.with_borrow_mut(|ts| ts.begin_transaction(Principal::anonymous()));
1284        let mut dbms = IcDbmsDatabase::from_transaction(TestDatabaseSchema, transaction_id.clone());
1285
1286        let filter = Filter::eq("id", Value::Uint32(4u32.into()));
1287        let patch = UserUpdateRequest {
1288            id: None,
1289            name: Some(Text("TxUpdatedName".to_string())),
1290            where_clause: Some(filter.clone()),
1291        };
1292
1293        let update_count = dbms.update::<User>(patch).expect("failed to update user");
1294        assert_eq!(update_count, 1);
1295
1296        // user should not be visible outside the transaction
1297        let oneshot_dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
1298        let query = Query::<User>::builder().and_where(filter.clone()).build();
1299        let users = oneshot_dbms
1300            .select(query.clone())
1301            .expect("failed to select users");
1302        let user = &users[0];
1303        assert_eq!(
1304            user.name.as_ref().expect("should have name").0,
1305            USERS_FIXTURES[4]
1306        );
1307
1308        // commit transaction
1309        let commit_result = dbms.commit();
1310        assert!(commit_result.is_ok());
1311
1312        // now user should be updated
1313        let users_after_commit = oneshot_dbms.select(query).expect("failed to select users");
1314        assert_eq!(users_after_commit.len(), 1);
1315        let user = &users_after_commit[0];
1316        assert_eq!(
1317            user.name.as_ref().expect("should have name").0,
1318            "TxUpdatedName".to_string()
1319        );
1320
1321        // transaction should have been removed
1322        TRANSACTION_SESSION.with_borrow(|ts| {
1323            let tx_res = ts.get_transaction(&transaction_id);
1324            assert!(tx_res.is_err());
1325        });
1326    }
1327
1328    fn init_user_table() {
1329        SCHEMA_REGISTRY
1330            .with_borrow_mut(|sr| sr.register_table::<User>())
1331            .expect("failed to register `User` table");
1332    }
1333}