ic_dbms_canister/
dbms.rs

1//! This module exposes all the types related to the DBMS engine.
2
3pub mod integrity;
4pub mod referenced_tables;
5pub mod schema;
6pub mod transaction;
7
8use ic_dbms_api::prelude::{
9    ColumnDef, Database, DeleteBehavior, Filter, ForeignFetcher, IcDbmsError, IcDbmsResult,
10    InsertRecord, OrderDirection, Query, QueryError, TableColumns, TableError, TableRecord,
11    TableSchema, TransactionError, TransactionId, UpdateRecord, Value, ValuesSource,
12};
13
14use crate::dbms::transaction::{DatabaseOverlay, Transaction, TransactionOp};
15use crate::memory::{SCHEMA_REGISTRY, TableRegistry};
16use crate::prelude::{DatabaseSchema, TRANSACTION_SESSION};
17use crate::utils::trap;
18
/// Default number of result slots pre-allocated by SELECT queries.
///
/// This only sizes the initial result buffer; it does not cap the number of rows returned.
const DEFAULT_SELECT_LIMIT: usize = 128;
21
22/// The main DBMS struct.
23///
24/// This struct serves as the entry point for interacting with the DBMS engine.
25///
26/// It provides methods for executing queries.
27///
28/// - [`Database::select`] - Execute a SELECT query.
29/// - [`Database::insert`] - Execute an INSERT query.
30/// - [`Database::update`] - Execute an UPDATE query.
31/// - [`Database::delete`] - Execute a DELETE query.
32/// - [`Database::commit`] - Commit the current transaction.
33/// - [`Database::rollback`] - Rollback the current transaction.
34///
35/// The `transaction` field indicates whether the instance is operating within a transaction context.
36/// The [`Database`] can be instantiated for one-shot, with [`Database::oneshot`] operations (no transaction),
37/// or within a transaction context with [`Database::from_transaction`].
38///
39/// If a transaction is active, all operations will be part of that transaction until it is committed or rolled back.
40pub struct IcDbmsDatabase {
41    /// Database schema to perform generic operations, without knowing the concrete table schema at compile time.
42    schema: Box<dyn DatabaseSchema>,
43    /// Id of the loaded transaction, if any.
44    transaction: Option<TransactionId>,
45}
46
47impl IcDbmsDatabase {
48    /// Load an instance of the [`Database`] for one-shot operations (no transaction).
49    pub fn oneshot(schema: impl DatabaseSchema + 'static) -> Self {
50        Self {
51            schema: Box::new(schema),
52            transaction: None,
53        }
54    }
55
56    /// Load an instance of the [`Database`] within a transaction context.
57    pub fn from_transaction(
58        schema: impl DatabaseSchema + 'static,
59        transaction_id: TransactionId,
60    ) -> Self {
61        Self {
62            schema: Box::new(schema),
63            transaction: Some(transaction_id),
64        }
65    }
66
67    /// Executes a closure with a mutable reference to the current [`Transaction`].
68    fn with_transaction_mut<F, R>(&self, f: F) -> IcDbmsResult<R>
69    where
70        F: FnOnce(&mut Transaction) -> IcDbmsResult<R>,
71    {
72        let txid = self.transaction.as_ref().ok_or(IcDbmsError::Transaction(
73            TransactionError::NoActiveTransaction,
74        ))?;
75
76        TRANSACTION_SESSION.with_borrow_mut(|ts| {
77            let tx = ts.get_transaction_mut(txid)?;
78            f(tx)
79        })
80    }
81
82    /// Executes a closure with a reference to the current [`Transaction`].
83    fn with_transaction<F, R>(&self, f: F) -> IcDbmsResult<R>
84    where
85        F: FnOnce(&Transaction) -> IcDbmsResult<R>,
86    {
87        let txid = self.transaction.as_ref().ok_or(IcDbmsError::Transaction(
88            TransactionError::NoActiveTransaction,
89        ))?;
90
91        TRANSACTION_SESSION.with_borrow_mut(|ts| {
92            let tx = ts.get_transaction_mut(txid)?;
93            f(tx)
94        })
95    }
96
97    /// Executes a closure atomically within the database context.
98    ///
99    /// If the closure returns an error, the changes are rolled back by trapping the canister.
100    fn atomic<F, R>(&self, f: F) -> R
101    where
102        F: FnOnce(&IcDbmsDatabase) -> IcDbmsResult<R>,
103    {
104        match f(self) {
105            Ok(res) => res,
106            Err(err) => trap(err.to_string()),
107        }
108    }
109
110    /// Deletes foreign key related records recursively if the delete behavior is [`DeleteBehavior::Cascade`].
111    fn delete_foreign_keys_cascade<T>(
112        &self,
113        record_values: &[(ColumnDef, Value)],
114    ) -> IcDbmsResult<u64>
115    where
116        T: TableSchema,
117    {
118        let mut count = 0;
119        // verify referenced tables for foreign key constraints
120        for (table, columns) in self.schema.referenced_tables(T::table_name()) {
121            for column in columns.iter() {
122                // prepare filter
123                let pk = record_values
124                    .iter()
125                    .find(|(col_def, _)| col_def.primary_key)
126                    .ok_or(IcDbmsError::Query(QueryError::UnknownColumn(
127                        column.to_string(),
128                    )))?
129                    .1
130                    .clone();
131                // make filter to find records in the referenced table
132                let filter = Filter::eq(column, pk);
133                let res = self
134                    .schema
135                    .delete(self, table, DeleteBehavior::Cascade, Some(filter))?;
136                count += res;
137            }
138        }
139        Ok(count)
140    }
141
142    /// Retrieves the current [`DatabaseOverlay`].
143    fn overlay(&self) -> IcDbmsResult<DatabaseOverlay> {
144        self.with_transaction(|tx| Ok(tx.overlay().clone()))
145    }
146
147    /// Returns whether the read given record matches the provided filter.
148    fn record_matches_filter(
149        &self,
150        record_values: &[(ColumnDef, Value)],
151        filter: &Filter,
152    ) -> IcDbmsResult<bool> {
153        filter.matches(record_values).map_err(IcDbmsError::from)
154    }
155
156    /// Select only the queried fields from the given record values.
157    ///
158    /// It also loads eager relations if any.
159    fn select_queried_fields<T>(
160        &self,
161        mut record_values: Vec<(ColumnDef, Value)>,
162        query: &Query<T>,
163    ) -> IcDbmsResult<TableColumns>
164    where
165        T: TableSchema,
166    {
167        let mut queried_fields = vec![];
168
169        // handle eager relations
170        // FIXME: currently we fetch the FK for each record, which is shit.
171        // In the future, we should batch fetch foreign keys for all records in the result set.
172        for relation in &query.eager_relations {
173            let mut fetched = false;
174            // iter all foreign key with that table
175            for (fk, fk_value) in record_values
176                .iter()
177                .filter(|(col_def, _)| {
178                    col_def
179                        .foreign_key
180                        .is_some_and(|fk| fk.foreign_table == *relation)
181                })
182                .map(|(col, value)| {
183                    (
184                        col.foreign_key.as_ref().expect("cannot be empty"),
185                        value.clone(),
186                    )
187                })
188            {
189                // get foreign values
190                queried_fields.extend(T::foreign_fetcher().fetch(
191                    self,
192                    relation,
193                    fk.local_column,
194                    fk_value,
195                )?);
196                fetched = true;
197            }
198
199            if !fetched {
200                return Err(IcDbmsError::Query(QueryError::InvalidQuery(format!(
201                    "Cannot load relation '{}' for table '{}': no foreign key found",
202                    relation,
203                    T::table_name()
204                ))));
205            }
206        }
207
208        // short-circuit if all selected
209        if query.all_selected() {
210            queried_fields.extend(vec![(ValuesSource::This, record_values)]);
211            return Ok(queried_fields);
212        }
213        record_values.retain(|(col_def, _)| query.columns().contains(&col_def.name.to_string()));
214        queried_fields.extend(vec![(ValuesSource::This, record_values)]);
215        Ok(queried_fields)
216    }
217
218    /// Retrieves existing primary keys for records matching the given filter.
219    fn existing_primary_keys_for_filter<T>(
220        &self,
221        filter: Option<Filter>,
222    ) -> IcDbmsResult<Vec<Value>>
223    where
224        T: TableSchema,
225    {
226        let pk = T::primary_key();
227        let fields = self.select(Query::<T>::builder().filter(filter).build())?;
228        let pks = fields
229            .into_iter()
230            .map(|record| {
231                record
232                    .to_values()
233                    .into_iter()
234                    .find(|(col_def, _value)| col_def.name == pk)
235                    .expect("primary key not found") // this can't fail.
236                    .1
237            })
238            .collect::<Vec<Value>>();
239
240        Ok(pks)
241    }
242
243    /// Load the table registry for the given table schema.
244    fn load_table_registry<T>(&self) -> IcDbmsResult<TableRegistry>
245    where
246        T: TableSchema,
247    {
248        // get pages of the table registry from schema registry
249        let registry_pages = SCHEMA_REGISTRY
250            .with_borrow(|schema| schema.table_registry_page::<T>())
251            .ok_or(IcDbmsError::Table(TableError::TableNotFound))?;
252
253        TableRegistry::load(registry_pages).map_err(IcDbmsError::from)
254    }
255
256    /// Sorts the query results based on the specified column and order direction.
257    ///
258    /// We only sort values which have [`ValuesSource::This`].
259    #[allow(clippy::type_complexity)]
260    fn sort_query_results(
261        &self,
262        results: &mut [Vec<(ValuesSource, Vec<(ColumnDef, Value)>)>],
263        column: &str,
264        direction: OrderDirection,
265    ) {
266        results.sort_by(|a, b| {
267            let a_value = a
268                .iter()
269                .find(|(source, _)| *source == ValuesSource::This)
270                .and_then(|(_, cols)| {
271                    cols.iter()
272                        .find(|(col_def, _)| col_def.name == column)
273                        .map(|(_, value)| value)
274                });
275            let b_value = b
276                .iter()
277                .find(|(source, _)| *source == ValuesSource::This)
278                .and_then(|(_, cols)| {
279                    cols.iter()
280                        .find(|(col_def, _)| col_def.name == column)
281                        .map(|(_, value)| value)
282                });
283
284            match (a_value, b_value) {
285                (Some(a_val), Some(b_val)) => match direction {
286                    OrderDirection::Ascending => a_val.cmp(b_val),
287                    OrderDirection::Descending => b_val.cmp(a_val),
288                },
289                (Some(_), None) => std::cmp::Ordering::Greater,
290                (None, Some(_)) => std::cmp::Ordering::Less,
291                (None, None) => std::cmp::Ordering::Equal,
292            }
293        });
294    }
295}
296
297impl Database for IcDbmsDatabase {
298    /// Executes a SELECT query and returns the results.
299    ///
300    /// # Arguments
301    ///
302    /// - `query` - The SELECT [`Query`] to be executed.
303    ///
304    /// # Returns
305    ///
306    /// The returned results are a vector of [`table::TableRecord`] matching the query.
307    fn select<T>(&self, query: Query<T>) -> IcDbmsResult<Vec<T::Record>>
308    where
309        T: TableSchema,
310    {
311        // load table registry
312        let table_registry = self.load_table_registry::<T>()?;
313        // read table
314        let table_reader = table_registry.read::<T>();
315        // get database overlay
316        let mut table_overlay = if self.transaction.is_some() {
317            self.overlay()?
318        } else {
319            DatabaseOverlay::default()
320        };
321        // overlay table reader
322        let mut table_reader = table_overlay.reader(table_reader);
323
324        // prepare results vector
325        let mut results = Vec::with_capacity(query.limit.unwrap_or(DEFAULT_SELECT_LIMIT));
326        // iter and select
327        let mut count = 0;
328
329        while let Some(values) = table_reader.try_next()? {
330            // check whether it matches the filter
331            if let Some(filter) = &query.filter {
332                if !self.record_matches_filter(&values, filter)? {
333                    continue;
334                }
335            }
336            // filter matched, check limit and offset
337            count += 1;
338            // check whether is before offset
339            if query.offset.is_some_and(|offset| count <= offset) {
340                continue;
341            }
342            // get queried fields
343            let values = self.select_queried_fields::<T>(values, &query)?;
344            // push to results
345            results.push(values);
346            // check whether reached limit
347            if query.limit.is_some_and(|limit| results.len() >= limit) {
348                break;
349            }
350        }
351
352        // sort results if needed and map to records
353        for (column, direction) in query.order_by {
354            self.sort_query_results(&mut results, &column, direction);
355        }
356
357        Ok(results.into_iter().map(T::Record::from_values).collect())
358    }
359
360    /// Executes an INSERT query.
361    ///
362    /// # Arguments
363    ///
364    /// - `record` - The INSERT record to be executed.
365    fn insert<T>(&self, record: T::Insert) -> IcDbmsResult<()>
366    where
367        T: TableSchema,
368        T::Insert: InsertRecord<Schema = T>,
369    {
370        // check whether the insert is valid
371        let record_values = record.clone().into_values();
372        self.schema
373            .validate_insert(self, T::table_name(), &record_values)?;
374
375        if self.transaction.is_some() {
376            // insert a new `insert` into the transaction
377            self.with_transaction_mut(|tx| tx.insert::<T>(record_values))?;
378        } else {
379            // insert directly into the database
380            let mut table_registry = self.load_table_registry::<T>()?;
381            table_registry.insert(record.into_record())?;
382        }
383
384        Ok(())
385    }
386
387    /// Executes an UPDATE query.
388    ///
389    /// # Arguments
390    ///
391    /// - `patch` - The UPDATE patch to be applied.
392    /// - `filter` - An optional [`Filter`] to specify which records to update.
393    ///
394    /// # Returns
395    ///
396    /// The number of rows updated.
397    fn update<T>(&self, patch: T::Update) -> IcDbmsResult<u64>
398    where
399        T: TableSchema,
400        T::Update: UpdateRecord<Schema = T>,
401    {
402        // get all records matching the filter
403        let query = Query::<T>::builder().filter(patch.where_clause()).build();
404        let records = self.select::<T>(query)?;
405        let count = records.len() as u64;
406
407        if self.transaction.is_some() {
408            let filter = patch.where_clause().clone();
409            let pks = self.existing_primary_keys_for_filter::<T>(filter.clone())?;
410            // insert a new `update` into the transaction
411            self.with_transaction_mut(|tx| tx.update::<T>(patch, filter, pks))?;
412
413            return Ok(count);
414        }
415
416        let patch = patch.update_values();
417        // convert updates to values
418        // for each record apply update; delete and insert
419        let res = self.atomic(|db| {
420            for record in records {
421                let mut record_values = record.to_values();
422                // apply patch
423                for (col_def, value) in &patch {
424                    if let Some((_, record_value)) = record_values
425                        .iter_mut()
426                        .find(|(record_col_def, _)| record_col_def.name == col_def.name)
427                    {
428                        *record_value = value.clone();
429                    }
430                }
431                // create insert record
432                let insert_record = T::Insert::from_values(&record_values)?;
433                // delete old record
434                let pk = record_values
435                    .iter()
436                    .find(|(col_def, _)| col_def.primary_key)
437                    .expect("primary key not found") // this can't fail.
438                    .1
439                    .clone();
440                db.delete::<T>(
441                    DeleteBehavior::Break, // we just want to delete the old record
442                    Some(Filter::eq(T::primary_key(), pk)),
443                )?;
444                // insert new record
445                db.insert::<T>(insert_record)?;
446            }
447            Ok(count)
448        });
449
450        Ok(res)
451    }
452
453    /// Executes a DELETE query.
454    ///
455    /// # Arguments
456    ///
457    /// - `behaviour` - The [`DeleteBehavior`] to apply for foreign key constraints.
458    /// - `filter` - An optional [`Filter`] to specify which records to delete.
459    ///
460    /// # Returns
461    ///
462    /// The number of rows deleted.
463    fn delete<T>(&self, behaviour: DeleteBehavior, filter: Option<Filter>) -> IcDbmsResult<u64>
464    where
465        T: TableSchema,
466    {
467        if self.transaction.is_some() {
468            let pks = self.existing_primary_keys_for_filter::<T>(filter.clone())?;
469            let count = pks.len() as u64;
470
471            self.with_transaction_mut(|tx| tx.delete::<T>(behaviour, filter, pks))?;
472
473            return Ok(count);
474        }
475
476        // delete must be atomic
477        let res = self.atomic(|db| {
478            // delete directly from the database
479            // select all records matching the filter
480            // read table
481            let mut table_registry = db.load_table_registry::<T>()?;
482            let mut records = vec![];
483            // iter all records
484            // FIXME: this may be huge, we should do better
485            {
486                let mut table_reader = table_registry.read::<T>();
487                while let Some(values) = table_reader.try_next()? {
488                    let record_values = values.record.clone().to_values();
489                    if let Some(filter) = &filter {
490                        if !db.record_matches_filter(&record_values, filter)? {
491                            continue;
492                        }
493                    }
494                    records.push((values, record_values));
495                }
496            }
497            // deleted records
498            let mut count = records.len() as u64;
499            for (record, record_values) in records {
500                // match delete behaviour
501                match behaviour {
502                    DeleteBehavior::Cascade => {
503                        // delete recursively foreign keys if cascade
504                        count += self.delete_foreign_keys_cascade::<T>(&record_values)?;
505                    }
506                    DeleteBehavior::Restrict => {
507                        if self.delete_foreign_keys_cascade::<T>(&record_values)? > 0 {
508                            // it's okay; we panic here because we are in an atomic closure
509                            return Err(IcDbmsError::Query(
510                                QueryError::ForeignKeyConstraintViolation {
511                                    referencing_table: T::table_name().to_string(),
512                                    field: T::primary_key().to_string(),
513                                },
514                            ));
515                        }
516                    }
517                    DeleteBehavior::Break => {
518                        // do nothing
519                    }
520                }
521                // eventually delete the record
522                table_registry.delete(record.record, record.page, record.offset)?;
523            }
524
525            Ok(count)
526        });
527
528        Ok(res)
529    }
530
531    /// Commits the current transaction.
532    ///
533    /// The transaction is consumed.
534    ///
535    /// Any error during commit will trap the canister to ensure consistency.
536    fn commit(&mut self) -> IcDbmsResult<()> {
537        // take transaction out of self and get the transaction out of the storage
538        // this also invalidates the overlay, so we won't have conflicts during validation
539        let Some(txid) = self.transaction.take() else {
540            return Err(IcDbmsError::Transaction(
541                TransactionError::NoActiveTransaction,
542            ));
543        };
544        let transaction = TRANSACTION_SESSION.with_borrow_mut(|ts| ts.take_transaction(&txid))?;
545
546        // iterate over operations and apply them;
547        // for each operation, first validate, then apply
548        // using `self.atomic` when applying to ensure consistency
549        for op in transaction.operations {
550            match op {
551                TransactionOp::Insert { table, values } => {
552                    // validate
553                    self.schema.validate_insert(self, table, &values)?;
554                    // insert
555                    self.atomic(|db| db.schema.insert(db, table, &values));
556                }
557                TransactionOp::Delete {
558                    table,
559                    behaviour,
560                    filter,
561                } => {
562                    self.atomic(|db| db.schema.delete(db, table, behaviour, filter));
563                }
564                TransactionOp::Update {
565                    table,
566                    patch,
567                    filter,
568                } => {
569                    self.atomic(|db| db.schema.update(db, table, &patch, filter));
570                }
571            }
572        }
573
574        Ok(())
575    }
576
577    /// Rolls back the current transaction.
578    ///
579    /// The transaction is consumed.
580    fn rollback(&mut self) -> IcDbmsResult<()> {
581        let Some(txid) = self.transaction.take() else {
582            return Err(IcDbmsError::Transaction(
583                TransactionError::NoActiveTransaction,
584            ));
585        };
586
587        TRANSACTION_SESSION.with_borrow_mut(|ts| ts.close_transaction(&txid));
588        Ok(())
589    }
590}
591
592#[cfg(test)]
593mod tests {
594
595    use candid::{Nat, Principal};
596    use ic_dbms_api::prelude::{Text, Uint32};
597
598    use super::*;
599    use crate::tests::{
600        Message, POSTS_FIXTURES, Post, TestDatabaseSchema, USERS_FIXTURES, User, UserInsertRequest,
601        UserUpdateRequest, load_fixtures,
602    };
603
604    #[test]
605    fn test_should_init_dbms() {
606        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
607        assert!(dbms.transaction.is_none());
608
609        let tx_dbms = IcDbmsDatabase::from_transaction(TestDatabaseSchema, Nat::from(1u64));
610        assert!(tx_dbms.transaction.is_some());
611    }
612
613    #[test]
614    fn test_should_select_all_users() {
615        load_fixtures();
616        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
617        let query = Query::<User>::builder().all().build();
618        let users = dbms.select(query).expect("failed to select users");
619
620        assert_eq!(users.len(), USERS_FIXTURES.len());
621        // check if all users all loaded
622        for (i, user) in users.iter().enumerate() {
623            assert_eq!(user.id.expect("should have id").0 as usize, i);
624            assert_eq!(
625                user.name.as_ref().expect("should have name").0,
626                USERS_FIXTURES[i]
627            );
628        }
629    }
630
631    #[test]
632    fn test_should_select_user_in_overlay() {
633        load_fixtures();
634        // create a transaction
635        let transaction_id =
636            TRANSACTION_SESSION.with_borrow_mut(|ts| ts.begin_transaction(Principal::anonymous()));
637        // insert
638        TRANSACTION_SESSION.with_borrow_mut(|ts| {
639            let tx = ts
640                .get_transaction_mut(&transaction_id)
641                .expect("should have tx");
642            tx.overlay_mut()
643                .insert::<User>(vec![
644                    (
645                        ColumnDef {
646                            name: "id",
647                            data_type: ic_dbms_api::prelude::DataTypeKind::Uint32,
648                            nullable: false,
649                            primary_key: true,
650                            foreign_key: None,
651                        },
652                        Value::Uint32(999.into()),
653                    ),
654                    (
655                        ColumnDef {
656                            name: "name",
657                            data_type: ic_dbms_api::prelude::DataTypeKind::Text,
658                            nullable: false,
659                            primary_key: false,
660                            foreign_key: None,
661                        },
662                        Value::Text("OverlayUser".to_string().into()),
663                    ),
664                ])
665                .expect("failed to insert");
666        });
667
668        // select by pk
669        let dbms = IcDbmsDatabase::from_transaction(TestDatabaseSchema, transaction_id);
670        let query = Query::<User>::builder()
671            .and_where(Filter::eq("id", Value::Uint32(999.into())))
672            .build();
673        let users = dbms.select(query).expect("failed to select users");
674
675        assert_eq!(users.len(), 1);
676        let user = &users[0];
677        assert_eq!(user.id.expect("should have id").0, 999);
678        assert_eq!(
679            user.name.as_ref().expect("should have name").0,
680            "OverlayUser"
681        );
682    }
683
684    #[test]
685    fn test_should_select_users_with_offset_and_limit() {
686        load_fixtures();
687        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
688        let query = Query::<User>::builder().offset(2).limit(3).build();
689        let users = dbms.select(query).expect("failed to select users");
690
691        assert_eq!(users.len(), 3);
692        // check if correct users are loaded
693        for (i, user) in users.iter().enumerate() {
694            let expected_index = i + 2;
695            assert_eq!(user.id.expect("should have id").0 as usize, expected_index);
696            assert_eq!(
697                user.name.as_ref().expect("should have name").0,
698                USERS_FIXTURES[expected_index]
699            );
700        }
701    }
702
703    #[test]
704    fn test_should_select_users_with_offset_and_filter() {
705        load_fixtures();
706        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
707        let query = Query::<User>::builder()
708            .offset(1)
709            .and_where(Filter::gt("id", Value::Uint32(4.into())))
710            .build();
711        let users = dbms.select(query).expect("failed to select users");
712
713        assert_eq!(users.len(), 4);
714        // check if correct users are loaded
715        for (i, user) in users.iter().enumerate() {
716            let expected_index = i + 6;
717            assert_eq!(user.id.expect("should have id").0 as usize, expected_index);
718            assert_eq!(
719                user.name.as_ref().expect("should have name").0,
720                USERS_FIXTURES[expected_index]
721            );
722        }
723    }
724
725    #[test]
726    fn test_should_select_post_with_relation() {
727        load_fixtures();
728        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
729        let query = Query::<Post>::builder()
730            .all()
731            .with(User::table_name())
732            .build();
733        let posts = dbms.select(query).expect("failed to select posts");
734        assert_eq!(posts.len(), POSTS_FIXTURES.len());
735
736        for (id, post) in posts.into_iter().enumerate() {
737            let (expected_title, expected_content, expected_user_id) = &POSTS_FIXTURES[id];
738            assert_eq!(post.id.expect("should have id").0 as usize, id);
739            assert_eq!(
740                post.title.as_ref().expect("should have title").0,
741                *expected_title
742            );
743            assert_eq!(
744                post.content.as_ref().expect("should have content").0,
745                *expected_content
746            );
747            let user_query = Query::<User>::builder()
748                .and_where(Filter::eq("id", Value::Uint32((*expected_user_id).into())))
749                .build();
750            let author = dbms
751                .select(user_query)
752                .expect("failed to load user")
753                .pop()
754                .expect("should have user");
755            assert_eq!(
756                post.user.expect("should have loaded user"),
757                Box::new(author)
758            );
759        }
760    }
761
762    #[test]
763    fn test_should_fail_loading_unexisting_column_on_select() {
764        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
765        let query = Query::<User>::builder().field("unexisting_column").build();
766        let result = dbms.select(query);
767        assert!(result.is_err());
768    }
769
770    #[test]
771    fn test_should_select_queried_fields() {
772        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
773
774        let record_values = User::columns()
775            .iter()
776            .cloned()
777            .zip(vec![
778                Value::Uint32(1.into()),
779                Value::Text("Alice".to_string().into()),
780            ])
781            .collect::<Vec<(ColumnDef, Value)>>();
782
783        let query: Query<User> = Query::builder().field("name").build();
784        let selected_fields = dbms
785            .select_queried_fields::<User>(record_values, &query)
786            .expect("failed to select queried fields");
787        let user_fields = selected_fields
788            .into_iter()
789            .find(|(table_name, _)| *table_name == ValuesSource::This)
790            .map(|(_, cols)| cols)
791            .unwrap_or_default();
792
793        assert_eq!(user_fields.len(), 1);
794        assert_eq!(user_fields[0].0.name, "name");
795        assert_eq!(user_fields[0].1, Value::Text("Alice".to_string().into()));
796    }
797
798    #[test]
799    fn test_should_select_queried_fields_with_relations() {
800        load_fixtures();
801        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
802
803        let record_values = Post::columns()
804            .iter()
805            .cloned()
806            .zip(vec![
807                Value::Uint32(1.into()),
808                Value::Text("Title".to_string().into()),
809                Value::Text("Content".to_string().into()),
810                Value::Uint32(2.into()), // author_id
811            ])
812            .collect::<Vec<(ColumnDef, Value)>>();
813
814        let query: Query<Post> = Query::builder()
815            .field("title")
816            .with(User::table_name())
817            .build();
818        let selected_fields = dbms
819            .select_queried_fields::<Post>(record_values, &query)
820            .expect("failed to select queried fields");
821
822        // check post fields
823        let post_fields = selected_fields
824            .iter()
825            .find(|(table_name, _)| *table_name == ValuesSource::This)
826            .map(|(_, cols)| cols)
827            .cloned()
828            .unwrap_or_default();
829        assert_eq!(post_fields.len(), 1);
830        assert_eq!(post_fields[0].0.name, "title");
831        assert_eq!(post_fields[0].1, Value::Text("Title".to_string().into()));
832
833        // check user fields
834        let user_fields = selected_fields
835            .iter()
836            .find(|(table_name, _)| {
837                *table_name
838                    == ValuesSource::Foreign {
839                        table: User::table_name().to_string(),
840                        column: "user".to_string(),
841                    }
842            })
843            .map(|(_, cols)| cols)
844            .cloned()
845            .unwrap_or_default();
846
847        let expected_user = USERS_FIXTURES[2]; // author_id = 2
848
849        assert_eq!(user_fields.len(), 3);
850        assert_eq!(user_fields[0].0.name, "id");
851        assert_eq!(user_fields[0].1, Value::Uint32(2.into()));
852        assert_eq!(user_fields[1].0.name, "name");
853        assert_eq!(user_fields[2].0.name, "email");
854        assert_eq!(
855            user_fields[1].1,
856            Value::Text(expected_user.to_string().into())
857        );
858    }
859
860    #[test]
861    fn test_should_select_with_two_fk_on_the_same_table() {
862        load_fixtures();
863
864        let query: Query<Message> = Query::builder()
865            .all()
866            .and_where(Filter::Eq("id".to_string(), Value::Uint32(0.into())))
867            .with("users")
868            .build();
869
870        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
871        let messages = dbms.select(query).expect("failed to select messages");
872        assert_eq!(messages.len(), 1);
873        let message = &messages[0];
874        assert_eq!(message.id.expect("should have id").0, 0);
875        assert_eq!(
876            message
877                .sender
878                .as_ref()
879                .expect("should have sender")
880                .name
881                .as_ref()
882                .unwrap()
883                .0,
884            "Alice"
885        );
886        assert_eq!(
887            message
888                .recipient
889                .as_ref()
890                .expect("should have recipient")
891                .name
892                .as_ref()
893                .unwrap()
894                .0,
895            "Bob"
896        );
897    }
898
899    #[test]
900    fn test_should_select_users_sorted_by_name_descending() {
901        load_fixtures();
902        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
903        let query = Query::<User>::builder().all().order_by_desc("name").build();
904        let users = dbms.select(query).expect("failed to select users");
905
906        let mut sorted_usernames = USERS_FIXTURES.to_vec();
907        sorted_usernames.sort_by(|a, b| b.cmp(a)); // descending
908
909        assert_eq!(users.len(), USERS_FIXTURES.len());
910        // check if all users all loaded in sorted order
911        for (i, user) in users.iter().enumerate() {
912            assert_eq!(
913                user.name.as_ref().expect("should have name").0,
914                sorted_usernames[i]
915            );
916        }
917    }
918
919    #[test]
920    fn test_should_fail_loading_unexisting_relation() {
921        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
922
923        let record_values = Post::columns()
924            .iter()
925            .cloned()
926            .zip(vec![
927                Value::Uint32(1.into()),
928                Value::Text("Title".to_string().into()),
929                Value::Text("Content".to_string().into()),
930                Value::Uint32(2.into()), // author_id
931            ])
932            .collect::<Vec<(ColumnDef, Value)>>();
933
934        let query: Query<Post> = Query::builder()
935            .field("title")
936            .with("unexisting_relation")
937            .build();
938        let result = dbms.select_queried_fields::<Post>(record_values, &query);
939        assert!(result.is_err());
940    }
941
942    #[test]
943    fn test_should_get_whether_record_matches_filter() {
944        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
945
946        let record_values = User::columns()
947            .iter()
948            .cloned()
949            .zip(vec![
950                Value::Uint32(1.into()),
951                Value::Text("Alice".to_string().into()),
952            ])
953            .collect::<Vec<(ColumnDef, Value)>>();
954        let filter = Filter::eq("name", Value::Text("Alice".to_string().into()));
955
956        let matches = dbms
957            .record_matches_filter(&record_values, &filter)
958            .expect("failed to match");
959        assert!(matches);
960
961        let non_matching_filter = Filter::eq("name", Value::Text("Bob".to_string().into()));
962        let non_matches = dbms
963            .record_matches_filter(&record_values, &non_matching_filter)
964            .expect("failed to match");
965        assert!(!non_matches);
966    }
967
968    #[test]
969    fn test_should_load_table_registry() {
970        init_user_table();
971
972        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
973        let table_registry = dbms.load_table_registry::<User>();
974        assert!(table_registry.is_ok());
975    }
976
977    #[test]
978    fn test_should_insert_record_without_transaction() {
979        load_fixtures();
980
981        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
982        let new_user = UserInsertRequest {
983            id: Uint32(100u32),
984            name: Text("NewUser".to_string()),
985            email: "new_user@example.com".into(),
986        };
987
988        let result = dbms.insert::<User>(new_user);
989        assert!(result.is_ok());
990
991        // find user
992        let query = Query::<User>::builder()
993            .and_where(Filter::eq("id", Value::Uint32(100u32.into())))
994            .build();
995        let users = dbms.select(query).expect("failed to select users");
996        assert_eq!(users.len(), 1);
997        let user = &users[0];
998        assert_eq!(user.id.expect("should have id").0, 100);
999        assert_eq!(
1000            user.name.as_ref().expect("should have name").0,
1001            "NewUser".to_string()
1002        );
1003    }
1004
1005    #[test]
1006    fn test_should_validate_user_insert_conflict() {
1007        load_fixtures();
1008
1009        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
1010        let new_user = UserInsertRequest {
1011            id: Uint32(1u32),
1012            name: Text("NewUser".to_string()),
1013            email: "new_user@example.com".into(),
1014        };
1015
1016        let result = dbms.insert::<User>(new_user);
1017        assert!(result.is_err());
1018    }
1019
1020    #[test]
1021    fn test_should_insert_within_a_transaction() {
1022        load_fixtures();
1023
1024        // create a transaction
1025        let transaction_id =
1026            TRANSACTION_SESSION.with_borrow_mut(|ts| ts.begin_transaction(Principal::anonymous()));
1027        let mut dbms = IcDbmsDatabase::from_transaction(TestDatabaseSchema, transaction_id.clone());
1028
1029        let new_user = UserInsertRequest {
1030            id: Uint32(200u32),
1031            name: Text("TxUser".to_string()),
1032            email: "new_user@example.com".into(),
1033        };
1034
1035        let result = dbms.insert::<User>(new_user);
1036        assert!(result.is_ok());
1037
1038        // user should not be visible outside the transaction
1039        let oneshot_dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
1040        let query = Query::<User>::builder()
1041            .and_where(Filter::eq("id", Value::Uint32(200u32.into())))
1042            .build();
1043        let users = oneshot_dbms
1044            .select(query.clone())
1045            .expect("failed to select users");
1046        assert_eq!(users.len(), 0);
1047
1048        // commit transaction
1049        let commit_result = dbms.commit();
1050        assert!(commit_result.is_ok());
1051
1052        // now user should be visible
1053        let users_after_commit = oneshot_dbms.select(query).expect("failed to select users");
1054        assert_eq!(users_after_commit.len(), 1);
1055
1056        let user = &users_after_commit[0];
1057        assert_eq!(user.id.expect("should have id").0, 200);
1058        assert_eq!(
1059            user.name.as_ref().expect("should have name").0,
1060            "TxUser".to_string()
1061        );
1062
1063        // transaction should have been removed
1064        TRANSACTION_SESSION.with_borrow(|ts| {
1065            let tx_res = ts.get_transaction(&transaction_id);
1066            assert!(tx_res.is_err());
1067        });
1068    }
1069
1070    #[test]
1071    fn test_should_rollback_transaction() {
1072        load_fixtures();
1073
1074        // create a transaction
1075        let transaction_id =
1076            TRANSACTION_SESSION.with_borrow_mut(|ts| ts.begin_transaction(Principal::anonymous()));
1077        let mut dbms = IcDbmsDatabase::from_transaction(TestDatabaseSchema, transaction_id.clone());
1078        let new_user = UserInsertRequest {
1079            id: Uint32(300u32),
1080            name: Text("RollbackUser".to_string()),
1081            email: "new_user@example.com".into(),
1082        };
1083        let result = dbms.insert::<User>(new_user);
1084        assert!(result.is_ok());
1085
1086        // rollback transaction
1087        let rollback_result = dbms.rollback();
1088        assert!(rollback_result.is_ok());
1089
1090        // user should not be visible
1091        let oneshot_dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
1092        let query = Query::<User>::builder()
1093            .and_where(Filter::eq("id", Value::Uint32(300u32.into())))
1094            .build();
1095        let users = oneshot_dbms.select(query).expect("failed to select users");
1096        assert_eq!(users.len(), 0);
1097
1098        // transaction should have been removed
1099        TRANSACTION_SESSION.with_borrow(|ts| {
1100            let tx_res = ts.get_transaction(&transaction_id);
1101            assert!(tx_res.is_err());
1102        });
1103    }
1104
1105    #[test]
1106    fn test_should_delete_one_shot() {
1107        load_fixtures();
1108
1109        // insert user with id 100
1110        let new_user = UserInsertRequest {
1111            id: Uint32(100u32),
1112            name: Text("DeleteUser".to_string()),
1113            email: "new_user@example.com".into(),
1114        };
1115        assert!(
1116            IcDbmsDatabase::oneshot(TestDatabaseSchema)
1117                .insert::<User>(new_user)
1118                .is_ok()
1119        );
1120
1121        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
1122        let query = Query::<User>::builder()
1123            .and_where(Filter::eq("id", Value::Uint32(100u32.into())))
1124            .build();
1125        let delete_count = dbms
1126            .delete::<User>(
1127                DeleteBehavior::Restrict,
1128                Some(Filter::eq("id", Value::Uint32(100u32.into()))),
1129            )
1130            .expect("failed to delete user");
1131        assert_eq!(delete_count, 1);
1132
1133        // verify user is deleted
1134        let users = dbms.select(query).expect("failed to select users");
1135        assert_eq!(users.len(), 0);
1136    }
1137
1138    #[test]
1139    #[should_panic(expected = "Foreign key constraint violation")]
1140    fn test_should_not_delete_with_fk_restrict() {
1141        load_fixtures();
1142
1143        // user 1 has post and messages for sure.
1144        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
1145        let res = dbms
1146            .delete::<User>(
1147                DeleteBehavior::Restrict,
1148                Some(Filter::eq("id", Value::Uint32(1u32.into()))),
1149            )
1150            .expect("failed to delete user");
1151        println!("Delete result (should not show): {}", res);
1152    }
1153
1154    #[test]
1155    fn test_should_delete_with_fk_cascade() {
1156        load_fixtures();
1157
1158        // user 1 has posts and messages for sure.
1159        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
1160        let delete_count = dbms
1161            .delete::<User>(
1162                DeleteBehavior::Cascade,
1163                Some(Filter::eq("id", Value::Uint32(1u32.into()))),
1164            )
1165            .expect("failed to delete user");
1166        println!("Delete count: {}", delete_count);
1167        assert!(delete_count > 1); // at least user + posts + messages
1168
1169        // verify user is deleted
1170        let query = Query::<User>::builder()
1171            .and_where(Filter::eq("id", Value::Uint32(1u32.into())))
1172            .build();
1173        let users = dbms.select(query).expect("failed to select users");
1174        assert_eq!(users.len(), 0);
1175
1176        // check posts are deleted (post ID 2)
1177        let post_query = Query::<Post>::builder()
1178            .and_where(Filter::eq("user_id", Value::Uint32(1u32.into())))
1179            .build();
1180        let posts = dbms.select(post_query).expect("failed to select posts");
1181        assert_eq!(posts.len(), 0);
1182
1183        // check messages are deleted (message ID 1)
1184        let message_query = Query::<Message>::builder()
1185            .and_where(Filter::eq("sender_id", Value::Uint32(1u32.into())))
1186            .or_where(Filter::eq("recipient_id", Value::Uint32(1u32.into())))
1187            .build();
1188        let messages = dbms
1189            .select(message_query)
1190            .expect("failed to select messages");
1191        assert_eq!(messages.len(), 0);
1192    }
1193
1194    #[test]
1195    fn test_should_delete_within_transaction() {
1196        load_fixtures();
1197
1198        // create a transaction
1199        let transaction_id =
1200            TRANSACTION_SESSION.with_borrow_mut(|ts| ts.begin_transaction(Principal::anonymous()));
1201        let mut dbms = IcDbmsDatabase::from_transaction(TestDatabaseSchema, transaction_id.clone());
1202
1203        let delete_count = dbms
1204            .delete::<User>(
1205                DeleteBehavior::Cascade,
1206                Some(Filter::eq("id", Value::Uint32(2u32.into()))),
1207            )
1208            .expect("failed to delete user");
1209        assert!(delete_count > 0);
1210
1211        // user should not be visible outside the transaction
1212        let oneshot_dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
1213        let query = Query::<User>::builder()
1214            .and_where(Filter::eq("id", Value::Uint32(2u32.into())))
1215            .build();
1216        let users = oneshot_dbms
1217            .select(query.clone())
1218            .expect("failed to select users");
1219        assert_eq!(users.len(), 1);
1220
1221        // commit transaction
1222        let commit_result = dbms.commit();
1223        assert!(commit_result.is_ok());
1224
1225        // now user should be deleted
1226        let users_after_commit = oneshot_dbms.select(query).expect("failed to select users");
1227        assert_eq!(users_after_commit.len(), 0);
1228
1229        // check posts are deleted
1230        let post_query = Query::<Post>::builder()
1231            .and_where(Filter::eq("user_id", Value::Uint32(2u32.into())))
1232            .build();
1233        let posts = oneshot_dbms
1234            .select(post_query)
1235            .expect("failed to select posts");
1236        assert_eq!(posts.len(), 0);
1237
1238        // check messages are deleted
1239        let message_query = Query::<Message>::builder()
1240            .and_where(Filter::eq("sender_id", Value::Uint32(2u32.into())))
1241            .or_where(Filter::eq("recipient_id", Value::Uint32(2u32.into())))
1242            .build();
1243        let messages = oneshot_dbms
1244            .select(message_query)
1245            .expect("failed to select messages");
1246        assert_eq!(messages.len(), 0);
1247
1248        // transaction should have been removed
1249        TRANSACTION_SESSION.with_borrow(|ts| {
1250            let tx_res = ts.get_transaction(&transaction_id);
1251            assert!(tx_res.is_err());
1252        });
1253    }
1254
1255    #[test]
1256    fn test_should_update_one_shot() {
1257        load_fixtures();
1258
1259        let dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
1260        let filter = Filter::eq("id", Value::Uint32(3u32.into()));
1261
1262        let patch = UserUpdateRequest {
1263            id: None,
1264            name: Some(Text("UpdatedName".to_string())),
1265            email: None,
1266            where_clause: Some(filter.clone()),
1267        };
1268
1269        let update_count = dbms.update::<User>(patch).expect("failed to update user");
1270        assert_eq!(update_count, 1);
1271
1272        // verify user is updated
1273        let query = Query::<User>::builder().and_where(filter).build();
1274        let users = dbms.select(query).expect("failed to select users");
1275        assert_eq!(users.len(), 1);
1276        let user = &users[0];
1277        assert_eq!(user.id.expect("should have id").0, 3);
1278        assert_eq!(
1279            user.name.as_ref().expect("should have name").0,
1280            "UpdatedName".to_string()
1281        );
1282    }
1283
1284    #[test]
1285    fn test_should_update_within_transaction() {
1286        load_fixtures();
1287
1288        // create a transaction
1289        let transaction_id =
1290            TRANSACTION_SESSION.with_borrow_mut(|ts| ts.begin_transaction(Principal::anonymous()));
1291        let mut dbms = IcDbmsDatabase::from_transaction(TestDatabaseSchema, transaction_id.clone());
1292
1293        let filter = Filter::eq("id", Value::Uint32(4u32.into()));
1294        let patch = UserUpdateRequest {
1295            id: None,
1296            name: Some(Text("TxUpdatedName".to_string())),
1297            email: None,
1298            where_clause: Some(filter.clone()),
1299        };
1300
1301        let update_count = dbms.update::<User>(patch).expect("failed to update user");
1302        assert_eq!(update_count, 1);
1303
1304        // user should not be visible outside the transaction
1305        let oneshot_dbms = IcDbmsDatabase::oneshot(TestDatabaseSchema);
1306        let query = Query::<User>::builder().and_where(filter.clone()).build();
1307        let users = oneshot_dbms
1308            .select(query.clone())
1309            .expect("failed to select users");
1310        let user = &users[0];
1311        assert_eq!(
1312            user.name.as_ref().expect("should have name").0,
1313            USERS_FIXTURES[4]
1314        );
1315
1316        // commit transaction
1317        let commit_result = dbms.commit();
1318        assert!(commit_result.is_ok());
1319
1320        // now user should be updated
1321        let users_after_commit = oneshot_dbms.select(query).expect("failed to select users");
1322        assert_eq!(users_after_commit.len(), 1);
1323        let user = &users_after_commit[0];
1324        assert_eq!(
1325            user.name.as_ref().expect("should have name").0,
1326            "TxUpdatedName".to_string()
1327        );
1328
1329        // transaction should have been removed
1330        TRANSACTION_SESSION.with_borrow(|ts| {
1331            let tx_res = ts.get_transaction(&transaction_id);
1332            assert!(tx_res.is_err());
1333        });
1334    }
1335
1336    fn init_user_table() {
1337        SCHEMA_REGISTRY
1338            .with_borrow_mut(|sr| sr.register_table::<User>())
1339            .expect("failed to register `User` table");
1340    }
1341}