ic_dbms_macros/
lib.rs

1#![crate_name = "ic_dbms_macros"]
2#![crate_type = "lib"]
3#![cfg_attr(docsrs, feature(doc_cfg))]
4
5//! Macros and derive for ic-dbms-canister
6//!
7//! This crate provides procedural macros to automatically implement traits
8//! required by the `ic-dbms-canister`.
9//!
10//! ## Provided Derive Macros
11//!
12//! - `Encode`: Automatically implements the `Encode` trait for structs.
13//! - `Table`: Automatically implements the `TableSchema` trait and associated types.
14//! - `DbmsCanister`: Automatically implements the API for the ic-dbms-canister.
15//!
16
17#![doc(html_playground_url = "https://play.rust-lang.org")]
18#![doc(
19    html_favicon_url = "https://raw.githubusercontent.com/veeso/ic-dbms/main/assets/images/cargo/logo-128.png"
20)]
21#![doc(
22    html_logo_url = "https://raw.githubusercontent.com/veeso/ic-dbms/main/assets/images/cargo/logo-512.png"
23)]
24
25use proc_macro::TokenStream;
26use syn::{DeriveInput, parse_macro_input};
27
28mod dbms_canister;
29mod encode;
30mod table;
31mod utils;
32
/// Automatically implements the `Encode` trait for a struct.
34///
35/// This derive macro generates two methods required by the `Encode` trait:
36///
37/// - `fn data_size() -> DataSize`  
38///   Computes the static size of the encoded type.  
39///   If all fields implement `Encode::data_size()` returning  
40///   `DataSize::Fixed(n)`, then the type is also considered fixed-size.  
41///   Otherwise, the type is `DataSize::Dynamic`.
42///
43/// - `fn size(&self) -> MSize`  
44///   Computes the runtime-encoding size of the value by summing the
45///   sizes of all fields.
46///
47/// # What the macro generates
48///
49/// Given a struct like:
50///
51/// ```rust,ignore
52/// #[derive(Encode)]
53/// struct User {
54///     id: Uint32,
55///     name: Text,
56/// }
57/// ```
58///
59/// The macro expands into:
60///
61/// ```rust,ignore
62/// impl Encode for User {
63///     const DATA_SIZE: DataSize = DataSize::Dynamic; // or DataSize::Fixed(n) if applicable
64///
65///     fn size(&self) -> MSize {
66///         self.id.size() + self.name.size()
67///     }
68///
69///     fn encode(&'_ self) -> std::borrow::Cow<'_, [u8]> {
70///         let mut encoded = Vec::with_capacity(self.size() as usize);
71///         encoded.extend_from_slice(&self.id.encode());
72///         encoded.extend_from_slice(&self.name.encode());
73///         std::borrow::Cow::Owned(encoded)
74///     }
75///
76///     fn decode(data: std::borrow::Cow<[u8]>) -> ::ic_dbms_api::prelude::MemoryResult<Self> {
77///         let mut offset = 0;
///         let id = Uint32::decode(std::borrow::Cow::Borrowed(&data[offset..]))?;
///         offset += id.size() as usize;
///         let name = Text::decode(std::borrow::Cow::Borrowed(&data[offset..]))?;
81///         offset += name.size() as usize;
82///         Ok(Self { id, name })
83///     }
84/// }
85/// ```
86/// # Requirements
87///
88/// - Each field type must implement `Encode`.
89/// - Only works on `struct`s; enums and unions are not supported.
90/// - All field identifiers must be valid Rust identifiers (no tuple structs).
91///
92/// # Notes
93///
94/// - It is intended for internal use within the `ic-dbms-canister` DBMS memory
95///   system.
96///
97/// # Errors
98///
99/// The macro will fail to expand if:
100///
101/// - The struct has unnamed fields (tuple struct)
102/// - A field type does not implement `Encode`
103/// - The macro is applied to a non-struct item.
104///
105/// # Example
106///
107/// ```rust,ignore
108/// #[derive(Encode, Debug, PartialEq, Eq)]
109/// struct Position {
110///     x: Int32,
111///     y: Int32,
112/// }
113///
114/// let pos = Position { x: 10.into(), y: 20.into() };
115/// assert_eq!(Position::data_size(), DataSize::Fixed(8));
116/// assert_eq!(pos.size(), 8);
117/// let encoded = pos.encode();
118/// let decoded = Position::decode(encoded).unwrap();
119/// assert_eq!(pos, decoded);
120/// ```
121#[proc_macro_derive(Encode)]
122pub fn derive_encode(input: TokenStream) -> TokenStream {
123    let input = parse_macro_input!(input as DeriveInput);
124    self::encode::encode(input)
125        .expect("Failed to derive `Encode`")
126        .into()
127}
128
129/// Given a struct representing a database table, automatically implements
130/// the `TableSchema` trait with all the necessary types to work with the ic-dbms-canister.
131/// So given this struct:
132///
133/// ```rust,ignore
134/// #[derive(Table, Encode)]
135/// #[table = "posts"]
136/// struct Post {
137///     #[primary_key]
138///     id: Uint32,
139///     title: Text,
140///     content: Text,
141///     #[foreign_key(entity = "User", table = "users", column = "id")]
///     user_id: Uint32,
143/// }
144/// ```
145///
146/// What we expect as output is:
147///
148/// - To implement the `TableSchema` trait for the struct as follows:
149///
150///     ```rust,ignore
151///     impl TableSchema for Post {
152///         type Insert = PostInsertRequest;
153///         type Record = PostRecord;
154///         type Update = PostUpdateRequest;
155///         type ForeignFetcher = PostForeignFetcher;
156///
157///         fn columns() -> &'static [ColumnDef] {
158///             &[
159///                 ColumnDef {
160///                     name: "id",
161///                     data_type: DataTypeKind::Uint32,
162///                     nullable: false,
163///                     primary_key: true,
164///                     foreign_key: None,
165///                 },
166///                 ColumnDef {
167///                     name: "title",
168///                     data_type: DataTypeKind::Text,
169///                     nullable: false,
170///                     primary_key: false,
171///                     foreign_key: None,
172///                 },
173///                 ColumnDef {
174///                     name: "content",
175///                     data_type: DataTypeKind::Text,
176///                     nullable: false,
177///                     primary_key: false,
178///                     foreign_key: None,
179///                 },
180///                 ColumnDef {
181///                     name: "user_id",
182///                     data_type: DataTypeKind::Uint32,
183///                     nullable: false,
184///                     primary_key: false,
185///                     foreign_key: Some(ForeignKeyDef {
186///                         local_column: "user_id",
187///                         foreign_table: "users",
188///                         foreign_column: "id",
189///                     }),
190///                 },
191///             ]
192///         }
193///
194///         fn table_name() -> &'static str {
195///             "posts"
196///         }
197///
198///         fn primary_key() -> &'static str {
199///             "id"
200///         }
201///
202///         fn to_values(self) -> Vec<(ColumnDef, Value)> {
203///             vec![
204///                 (Self::columns()[0], Value::Uint32(self.id)),
205///                 (Self::columns()[1], Value::Text(self.title)),
206///                 (Self::columns()[2], Value::Text(self.content)),
207///                 (Self::columns()[3], Value::Uint32(self.user_id)),
208///             ]
209///         }
210///     }
211///     ```
212///
213/// - Implement the associated `Record` type
214///
215///     ```rust,ignore
216///     pub struct PostRecord {
217///         pub id: Option<Uint32>,
218///         pub title: Option<Text>,
219///         pub content: Option<Text>,
220///         pub user: Option<UserRecord>,
221///     }
222///
223///     impl TableRecord for PostRecord {
224///         type Schema = Post;
225///     
226///         fn from_values(values: TableColumns) -> Self {
227///             let mut id: Option<Uint32> = None;
228///             let mut title: Option<Text> = None;
229///             let mut content: Option<Text> = None;
230///     
231///             let post_values = values
232///                 .iter()
233///                 .find(|(table_name, _)| *table_name == ValuesSource::This)
234///                 .map(|(_, cols)| cols);
235///             
236///             for (column, value) in post_values.unwrap_or(&vec![]) {
237///                 match column.name {
238///                     "id" => {
239///                         if let Value::Uint32(v) = value {
240///                             id = Some(*v);
241///                         }
242///                     }
243///                     "title" => {
244///                         if let Value::Text(v) = value {
245///                             title = Some(v.clone());
246///                         }
247///                     }
248///                     "content" => {
249///                         if let Value::Text(v) = value {
250///                             content = Some(v.clone());
251///                         }
252///                     }
253///                     _ => { /* Ignore unknown columns */ }
254///                 }
255///             }
256///     
257///             let has_user = values.iter().any(|(source, _)| {
258///                 *source
259///                     == ValuesSource::Foreign {
260///                         table: User::table_name(),
261///                         column: "user_id",
262///                     }
263///             });
264///             let user = if has_user {
265///                 Some(UserRecord::from_values(self_reference_values(
266///                     &values,
267///                     User::table_name(),
268///                     "user_id",
269///                 )))
270///             } else {
271///                 None
272///             };
273///     
274///             Self {
275///                 id,
276///                 title,
277///                 content,
278///                 user,
279///             }
280///         }
281///     
282///         fn to_values(&self) -> Vec<(ColumnDef, Value)> {
283///             Self::Schema::columns()
284///                 .iter()
285///                 .zip(vec![
286///                     match self.id {
287///                         Some(v) => Value::Uint32(v),
288///                         None => Value::Null,
289///                     },
290///                     match &self.title {
291///                         Some(v) => Value::Text(v.clone()),
292///                         None => Value::Null,
293///                     },
294///                     match &self.content {
295///                         Some(v) => Value::Text(v.clone()),
296///                         None => Value::Null,
297///                     },
298///                 ])
299///                 .map(|(col_def, value)| (*col_def, value))
300///                 .collect()
301///         }
302///     }
303///     ```
304///
/// - Implement the associated `Insert` type, which implements `InsertRecord`
306///
307///     ```rust,ignore
308///     #[derive(Clone)]
309///     pub struct PostInsertRequest {
310///         pub id: Uint32,
311///         pub title: Text,
312///         pub content: Text,
313///         pub user_id: Uint32,
314///     }
315///
316///     impl InsertRecord for PostInsertRequest {
317///         type Record = PostRecord;
318///         type Schema = Post;
319///
320///         fn from_values(values: &[(ColumnDef, Value)]) -> ic_dbms_api::prelude::IcDbmsResult<Self> {
321///             let mut id: Option<Uint32> = None;
322///             let mut title: Option<Text> = None;
323///             let mut content: Option<Text> = None;
324///             let mut user_id: Option<Uint32> = None;
325///
326///             for (column, value) in values {
327///                 match column.name {
328///                     "id" => {
329///                         if let Value::Uint32(v) = value {
330///                             id = Some(*v);
331///                         }
332///                     }
333///                     "title" => {
334///                         if let Value::Text(v) = value {
335///                             title = Some(v.clone());
336///                         }
337///                     }
338///                     "content" => {
339///                         if let Value::Text(v) = value {
340///                             content = Some(v.clone());
341///                         }
342///                     }
343///                     "user_id" => {
344///                         if let Value::Uint32(v) = value {
345///                             user_id = Some(*v);
346///                         }
347///                     }
348///                     _ => { /* Ignore unknown columns */ }
349///                 }
350///             }
351///
352///             Ok(Self {
353///                 id: id.ok_or(IcDbmsError::Query(QueryError::MissingNonNullableField(
354///                     "id".to_string(),
355///                 )))?,
356///                 title: title.ok_or(IcDbmsError::Query(QueryError::MissingNonNullableField(
357///                     "title".to_string(),
358///                 )))?,
359///                 content: content.ok_or(IcDbmsError::Query(QueryError::MissingNonNullableField(
360///                     "content".to_string(),
361///                 )))?,
362///                 user_id: user_id.ok_or(IcDbmsError::Query(QueryError::MissingNonNullableField(
363///                     "user_id".to_string(),
364///                 )))?,
365///             })
366///         }
367///
368///         fn into_values(self) -> Vec<(ColumnDef, Value)> {
369///             vec![
370///                 (Self::Schema::columns()[0], Value::Uint32(self.id)),
371///                 (Self::Schema::columns()[1], Value::Text(self.title)),
372///                 (Self::Schema::columns()[2], Value::Text(self.content)),
373///                 (Self::Schema::columns()[3], Value::Uint32(self.user_id)),
374///             ]
375///         }
376///
377///         fn into_record(self) -> Self::Schema {
378///             Post {
379///                 id: self.id,
380///                 title: self.title,
381///                 content: self.content,
382///                 user_id: self.user_id,
383///             }
384///         }
385///     }
386///     ```
387///
/// - Implement the associated `Update` type, which implements `UpdateRecord`
389///
390///     ```rust,ignore
391///     pub struct PostUpdateRequest {
392///         pub id: Option<Uint32>,
393///         pub title: Option<Text>,
394///         pub content: Option<Text>,
395///         pub user_id: Option<Uint32>,
396///         pub where_clause: Option<Filter>,
397///     }
398///
399///     impl UpdateRecord for PostUpdateRequest {
400///         type Record = PostRecord;
401///         type Schema = Post;
402///
403///         fn from_values(values: &[(ColumnDef, Value)], where_clause: Option<Filter>) -> Self {
404///             let mut id: Option<Uint32> = None;
405///             let mut title: Option<Text> = None;
406///             let mut content: Option<Text> = None;
407///             let mut user_id: Option<Uint32> = None;
408///
409///             for (column, value) in values {
410///                 match column.name {
411///                     "id" => {
412///                         if let Value::Uint32(v) = value {
413///                             id = Some(*v);
414///                         }
415///                     }
416///                     "title" => {
417///                         if let Value::Text(v) = value {
418///                             title = Some(v.clone());
419///                         }
420///                     }
421///                     "content" => {
422///                         if let Value::Text(v) = value {
423///                             content = Some(v.clone());
424///                         }
425///                     }
426///                     "user_id" => {
427///                         if let Value::Uint32(v) = value {
428///                             user_id = Some(*v);
429///                         }
430///                     }
431///                     _ => { /* Ignore unknown columns */ }
432///                 }
433///             }
434///
435///             Self {
436///                 id,
437///                 title,
438///                 content,
439///                 user_id,
440///                 where_clause,
441///             }
442///         }
443///
444///         fn update_values(&self) -> Vec<(ColumnDef, Value)> {
445///             let mut updates = Vec::new();
446///
447///             if let Some(id) = self.id {
448///                 updates.push((Self::Schema::columns()[0], Value::Uint32(id)));
449///             }
450///             if let Some(title) = &self.title {
451///                 updates.push((Self::Schema::columns()[1], Value::Text(title.clone())));
452///             }
453///             if let Some(content) = &self.content {
454///                 updates.push((Self::Schema::columns()[2], Value::Text(content.clone())));
455///             }
456///             if let Some(user_id) = self.user_id {
457///                 updates.push((Self::Schema::columns()[3], Value::Uint32(user_id)));
458///             }
459///
460///             updates
461///         }
462///
463///         fn where_clause(&self) -> Option<Filter> {
464///             self.where_clause.clone()
465///         }
466///     }
467///     ```
468///
/// - If the table has foreign keys, implement the associated `ForeignFetcher` type (otherwise `NoForeignFetcher` is used):
470///
471///     ```rust,ignore
472///     pub struct PostForeignFetcher;
473///
474///     impl ForeignFetcher for PostForeignFetcher {
475///         fn fetch(
476///             &self,
477///             database: &impl Database,
478///             table: &'static str,
479///             local_column: &'static str,
480///             pk_value: Value,
481///         ) -> ic_dbms_api::prelude::IcDbmsResult<TableColumns> {
482///             if table != User::table_name() {
483///                 return Err(IcDbmsError::Query(QueryError::InvalidQuery(format!(
484///                     "ForeignFetcher: unknown table '{table}' for {table_name} foreign fetcher",
485///                     table_name = Post::table_name()
486///                 ))));
487///             }
488///
///             // fetch the referenced record from the foreign table by its primary key
490///             let mut users = database.select(
491///                 Query::<User>::builder()
492///                     .all()
493///                     .limit(1)
494///                     .and_where(Filter::Eq(User::primary_key(), pk_value.clone()))
495///                     .build(),
496///             )?;
497///             let user = match users.pop() {
498///                 Some(user) => user,
499///                 None => {
500///                     return Err(IcDbmsError::Query(QueryError::BrokenForeignKeyReference {
501///                         table: User::table_name(),
502///                         key: pk_value,
503///                     }));
504///                 }
505///             };
506///
507///             let values = user.to_values();
508///             Ok(vec![(
509///                 ValuesSource::Foreign {
510///                     table,
511///                     column: local_column,
512///                 },
513///                 values,
514///             )])
515///         }
516///     }
517///     ```
518///
/// So for each struct deriving `Table`, we will generate the following types. Given `${StructName}`, they are:
520///
521/// - `${StructName}Record` - implementing `TableRecord`
522/// - `${StructName}InsertRequest` - implementing `InsertRecord`
523/// - `${StructName}UpdateRequest` - implementing `UpdateRecord`
524/// - `${StructName}ForeignFetcher` (only if foreign keys are present)
525///
526/// Also, we will implement the `TableSchema` trait for the struct itself and derive `Encode` for `${StructName}`.
527#[proc_macro_derive(Table, attributes(table, primary_key, foreign_key))]
528pub fn derive_table(input: TokenStream) -> TokenStream {
529    let input = parse_macro_input!(input as DeriveInput);
530    self::table::table(input)
531        .expect("failed to derive `Table`")
532        .into()
533}
534
535/// Automatically implements the api for the ic-dbms-canister with all the required methods to interact with the ACL and
536/// the defined tables.
537#[proc_macro_derive(DbmsCanister, attributes(tables))]
538pub fn derive_dbms_canister(input: TokenStream) -> TokenStream {
539    let input = parse_macro_input!(input as DeriveInput);
540    self::dbms_canister::dbms_canister(input)
541        .expect("failed to derive `DbmsCanister`")
542        .into()
543}