ic_dbms_canister/memory/
schema_registry.rs

1use std::cell::RefCell;
2use std::collections::HashMap;
3
4use ic_dbms_api::prelude::{TableFingerprint, TableSchema};
5
6use crate::memory::{DataSize, Encode, MEMORY_MANAGER, MSize, MemoryError, MemoryResult, Page};
7
8thread_local! {
9    /// The global schema registry.
10    ///
11    /// We allow failing because on first initialization the schema registry might not be present yet.
12    pub static SCHEMA_REGISTRY: RefCell<SchemaRegistry> = RefCell::new(SchemaRegistry::load().unwrap_or_default());
13}
14
15/// Data regarding the table registry page.
16#[derive(Debug, Clone, Copy, PartialEq, Eq)]
17pub struct TableRegistryPage {
18    pub pages_list_page: Page,
19    pub free_segments_page: Page,
20}
21
22/// The schema registry takes care of storing and retrieving table schemas from memory.
23#[derive(Debug, Default, Clone, PartialEq, Eq)]
24pub struct SchemaRegistry {
25    tables: HashMap<TableFingerprint, TableRegistryPage>,
26}
27
28impl SchemaRegistry {
29    /// Load the schema registry from memory.
30    pub fn load() -> MemoryResult<Self> {
31        let page = MEMORY_MANAGER.with_borrow(|m| m.schema_page());
32        let registry: Self = MEMORY_MANAGER.with_borrow(|m| m.read_at(page, 0))?;
33        Ok(registry)
34    }
35
36    /// Registers a table and allocates it registry page.
37    ///
38    /// The [`TableSchema`] type parameter is used to get the [`TableSchema::fingerprint`] of the table schema.
39    pub fn register_table<TS>(&mut self) -> MemoryResult<TableRegistryPage>
40    where
41        TS: TableSchema,
42    {
43        // check if already registered
44        let fingerprint = TS::fingerprint();
45        if let Some(pages) = self.tables.get(&fingerprint) {
46            return Ok(*pages);
47        }
48
49        // allocate table registry page
50        let (pages_list_page, free_segments_page) = MEMORY_MANAGER.with_borrow_mut(|m| {
51            Ok::<(Page, Page), MemoryError>((m.allocate_page()?, m.allocate_page()?))
52        })?;
53
54        // insert into tables map
55        let pages = TableRegistryPage {
56            pages_list_page,
57            free_segments_page,
58        };
59        self.tables.insert(fingerprint, pages);
60
61        // get schema page
62        let page = MEMORY_MANAGER.with_borrow(|m| m.schema_page());
63        // write self to schema page
64        MEMORY_MANAGER.with_borrow_mut(|m| m.write_at(page, 0, self))?;
65
66        Ok(pages)
67    }
68
69    /// Returns the table registry page for a given table schema.
70    pub fn table_registry_page<TS>(&self) -> Option<TableRegistryPage>
71    where
72        TS: TableSchema,
73    {
74        self.tables.get(&TS::fingerprint()).copied()
75    }
76}
77
78impl Encode for SchemaRegistry {
79    const SIZE: DataSize = DataSize::Dynamic;
80
81    fn size(&self) -> MSize {
82        // 8 bytes for len + (8 + (4 * 2)) bytes for each entry
83        8 + (self.tables.len() as MSize * (4 * 2 + 8))
84    }
85
86    fn encode(&'_ self) -> std::borrow::Cow<'_, [u8]> {
87        // prepare buffer; size is 8 bytes for len + (8 + (4 * 2)) bytes for each entry
88        let mut buffer = Vec::with_capacity(self.size() as usize);
89        // write 8 bytes len of map
90        buffer.extend_from_slice(&(self.tables.len() as u64).to_le_bytes());
91        // write each entry
92        for (fingerprint, page) in &self.tables {
93            buffer.extend_from_slice(&fingerprint.to_le_bytes());
94            buffer.extend_from_slice(&page.pages_list_page.to_le_bytes());
95            buffer.extend_from_slice(&page.free_segments_page.to_le_bytes());
96        }
97        std::borrow::Cow::Owned(buffer)
98    }
99
100    fn decode(data: std::borrow::Cow<[u8]>) -> MemoryResult<Self>
101    where
102        Self: Sized,
103    {
104        let mut offset = 0;
105        // read len
106        let len = u64::from_le_bytes(
107            data[offset..offset + 8]
108                .try_into()
109                .expect("failed to read length"),
110        ) as usize;
111        offset += 8;
112        let mut tables = HashMap::with_capacity(len);
113        // read each entry
114        for _ in 0..len {
115            let fingerprint = u64::from_le_bytes(data[offset..offset + 8].try_into()?);
116            offset += 8;
117            let pages_list_page = Page::from_le_bytes(data[offset..offset + 4].try_into()?);
118            offset += 4;
119            let deleted_records_page = Page::from_le_bytes(data[offset..offset + 4].try_into()?);
120            offset += 4;
121            tables.insert(
122                fingerprint,
123                TableRegistryPage {
124                    pages_list_page,
125                    free_segments_page: deleted_records_page,
126                },
127            );
128        }
129        Ok(Self { tables })
130    }
131}
132
#[cfg(test)]
mod tests {
    use candid::CandidType;
    use ic_dbms_api::prelude::{
        ColumnDef, IcDbmsResult, InsertRecord, NoForeignFetcher, TableColumns, TableRecord,
        UpdateRecord,
    };
    use serde::{Deserialize, Serialize};

    use super::*;
    use crate::tests::User;

    /// Registers two tables and checks that the registry survives both an in-memory
    /// encode/decode round trip and a reload from memory.
    #[test]
    fn test_should_encode_and_decode_schema_registry() {
        let mut registry = SchemaRegistry::load().expect("failed to load init schema registry");

        // first table: the returned pages must match what the registry reports back
        let user_pages = registry
            .register_table::<User>()
            .expect("failed to register table");
        let user_lookup = registry
            .table_registry_page::<User>()
            .expect("failed to get table registry page");
        assert_eq!(user_pages, user_lookup);

        // the encoding must round-trip to an equal registry
        let round_tripped = SchemaRegistry::decode(registry.encode()).expect("failed to decode");
        assert_eq!(registry, round_tripped);

        // second table goes through the same checks
        let another_pages = registry
            .register_table::<AnotherTable>()
            .expect("failed to register another table");
        let another_lookup = registry
            .table_registry_page::<AnotherTable>()
            .expect("failed to get another table registry page");
        assert_eq!(another_pages, another_lookup);

        // a fresh load must observe both registrations with unchanged pages
        let reloaded = SchemaRegistry::load().expect("failed to reload schema registry");
        assert_eq!(registry, reloaded);
        assert_eq!(reloaded.tables.len(), 2);

        let reloaded_user = reloaded
            .table_registry_page::<User>()
            .expect("failed to get first table registry page after reload");
        assert_eq!(reloaded_user, user_pages);

        let reloaded_another = reloaded
            .table_registry_page::<AnotherTable>()
            .expect("failed to get second table registry page after reload");
        assert_eq!(reloaded_another, another_pages);
    }

    /// Registering the same schema twice hands back the same pages and keeps one entry.
    #[test]
    fn test_should_not_register_same_table_twice() {
        let mut registry = SchemaRegistry::default();

        let first = registry
            .register_table::<User>()
            .expect("failed to register table first time");
        let second = registry
            .register_table::<User>()
            .expect("failed to register table second time");

        assert_eq!(first, second);
        assert_eq!(registry.tables.len(), 1);
    }

    /// Minimal second table schema, used only to get a fingerprint distinct from `User`.
    #[derive(Clone, CandidType)]
    struct AnotherTable;

    impl TableSchema for AnotherTable {
        type Record = AnotherTableRecord;
        type Insert = AnotherTableInsert;
        type Update = AnotherTableUpdate;
        type ForeignFetcher = NoForeignFetcher;

        fn table_name() -> &'static str {
            "another_table"
        }

        fn columns() -> &'static [ic_dbms_api::prelude::ColumnDef] {
            &[]
        }

        fn primary_key() -> &'static str {
            ""
        }

        fn to_values(self) -> Vec<(ColumnDef, ic_dbms_api::prelude::Value)> {
            Vec::new()
        }
    }

    /// The dummy table carries no data, so it encodes to an empty buffer.
    impl Encode for AnotherTable {
        const SIZE: DataSize = DataSize::Dynamic;

        fn size(&self) -> MSize {
            0
        }

        fn encode(&'_ self) -> std::borrow::Cow<'_, [u8]> {
            std::borrow::Cow::Owned(Vec::new())
        }

        fn decode(_data: std::borrow::Cow<[u8]>) -> MemoryResult<Self>
        where
            Self: Sized,
        {
            Ok(AnotherTable)
        }
    }

    /// Record type for [`AnotherTable`]; has no columns.
    #[derive(Clone, CandidType, Deserialize)]
    struct AnotherTableRecord;

    impl TableRecord for AnotherTableRecord {
        type Schema = AnotherTable;

        fn from_values(_values: TableColumns) -> Self {
            AnotherTableRecord
        }

        fn to_values(&self) -> Vec<(ColumnDef, ic_dbms_api::prelude::Value)> {
            Vec::new()
        }
    }

    /// Insert type for [`AnotherTable`]; has no columns.
    #[derive(Clone, CandidType, Serialize)]
    struct AnotherTableInsert;

    impl InsertRecord for AnotherTableInsert {
        type Record = AnotherTableRecord;
        type Schema = AnotherTable;

        fn from_values(_values: &[(ColumnDef, ic_dbms_api::prelude::Value)]) -> IcDbmsResult<Self> {
            Ok(AnotherTableInsert)
        }

        fn into_values(self) -> Vec<(ColumnDef, ic_dbms_api::prelude::Value)> {
            Vec::new()
        }

        fn into_record(self) -> Self::Schema {
            AnotherTable
        }
    }

    /// Update type for [`AnotherTable`]; updates nothing and has no filter.
    #[derive(CandidType, Serialize)]
    struct AnotherTableUpdate;

    impl UpdateRecord for AnotherTableUpdate {
        type Record = AnotherTableRecord;
        type Schema = AnotherTable;

        fn from_values(
            _values: &[(ColumnDef, ic_dbms_api::prelude::Value)],
            _where_clause: Option<ic_dbms_api::prelude::Filter>,
        ) -> Self {
            AnotherTableUpdate
        }

        fn update_values(&self) -> Vec<(ColumnDef, ic_dbms_api::prelude::Value)> {
            Vec::new()
        }

        fn where_clause(&self) -> Option<ic_dbms_api::prelude::Filter> {
            None
        }
    }
}