ic_dbms_macros/
lib.rs

1#![crate_name = "ic_dbms_macros"]
2#![crate_type = "lib"]
3#![cfg_attr(docsrs, feature(doc_cfg))]
4#![deny(clippy::print_stdout)]
5#![deny(clippy::print_stderr)]
6
//! Procedural macros and derive macros for `ic-dbms-canister`.
8//!
9//! This crate provides procedural macros to automatically implement traits
10//! required by the `ic-dbms-canister`.
11//!
12//! ## Provided Derive Macros
13//!
14//! - `Encode`: Automatically implements the `Encode` trait for structs.
15//! - `Table`: Automatically implements the `TableSchema` trait and associated types.
16//! - `DbmsCanister`: Automatically implements the API for the ic-dbms-canister.
17//!
18
19#![doc(html_playground_url = "https://play.rust-lang.org")]
20#![doc(
21    html_favicon_url = "https://raw.githubusercontent.com/veeso/ic-dbms/main/assets/images/cargo/logo-128.png"
22)]
23#![doc(
24    html_logo_url = "https://raw.githubusercontent.com/veeso/ic-dbms/main/assets/images/cargo/logo-512.png"
25)]
26
27use proc_macro::TokenStream;
28use syn::{DeriveInput, parse_macro_input};
29
30mod dbms_canister;
31mod encode;
32mod table;
33mod utils;
34
/// Automatically implements the `Encode` trait for a struct.
///
/// This derive macro generates the items required by the `Encode` trait:
///
/// - `const DATA_SIZE: DataSize`: the static size of the encoded type.
///   If every field's `DATA_SIZE` is `DataSize::Fixed(n)`, the type is
///   considered fixed-size as well; otherwise it is `DataSize::Dynamic`.
///
/// - `fn size(&self) -> MSize`: computes the runtime encoded size of the
///   value by summing the sizes of all fields.
///
/// - `fn encode(&self) -> Cow<'_, [u8]>`: concatenates the encodings of all
///   fields in declaration order.
///
/// - `fn decode(data: Cow<[u8]>) -> MemoryResult<Self>`: decodes each field in
///   declaration order, advancing past each field by its encoded size.
48///
49/// # What the macro generates
50///
51/// Given a struct like:
52///
53/// ```rust,ignore
54/// #[derive(Encode)]
55/// struct User {
56///     id: Uint32,
57///     name: Text,
58/// }
59/// ```
60///
61/// The macro expands into:
62///
63/// ```rust,ignore
64/// impl Encode for User {
65///     const DATA_SIZE: DataSize = DataSize::Dynamic; // or DataSize::Fixed(n) if applicable
66///
67///     fn size(&self) -> MSize {
68///         self.id.size() + self.name.size()
69///     }
70///
71///     fn encode(&'_ self) -> std::borrow::Cow<'_, [u8]> {
72///         let mut encoded = Vec::with_capacity(self.size() as usize);
73///         encoded.extend_from_slice(&self.id.encode());
74///         encoded.extend_from_slice(&self.name.encode());
75///         std::borrow::Cow::Owned(encoded)
76///     }
77///
78///     fn decode(data: std::borrow::Cow<[u8]>) -> ::ic_dbms_api::prelude::MemoryResult<Self> {
79///         let mut offset = 0;
///         let id = Uint32::decode(std::borrow::Cow::Borrowed(&data[offset..]))?;
///         offset += id.size() as usize;
///         let name = Text::decode(std::borrow::Cow::Borrowed(&data[offset..]))?;
83///         offset += name.size() as usize;
84///         Ok(Self { id, name })
85///     }
86/// }
87/// ```
88/// # Requirements
89///
90/// - Each field type must implement `Encode`.
91/// - Only works on `struct`s; enums and unions are not supported.
92/// - All field identifiers must be valid Rust identifiers (no tuple structs).
93///
94/// # Notes
95///
96/// - It is intended for internal use within the `ic-dbms-canister` DBMS memory
97///   system.
98///
99/// # Errors
100///
/// Compilation will fail if:
102///
103/// - The struct has unnamed fields (tuple struct)
104/// - A field type does not implement `Encode`
105/// - The macro is applied to a non-struct item.
106///
107/// # Example
108///
109/// ```rust,ignore
110/// #[derive(Encode, Debug, PartialEq, Eq)]
111/// struct Position {
112///     x: Int32,
113///     y: Int32,
114/// }
115///
116/// let pos = Position { x: 10.into(), y: 20.into() };
/// assert_eq!(Position::DATA_SIZE, DataSize::Fixed(8));
118/// assert_eq!(pos.size(), 8);
119/// let encoded = pos.encode();
120/// let decoded = Position::decode(encoded).unwrap();
121/// assert_eq!(pos, decoded);
122/// ```
123#[proc_macro_derive(Encode)]
124pub fn derive_encode(input: TokenStream) -> TokenStream {
125    let input = parse_macro_input!(input as DeriveInput);
126    self::encode::encode(input, None)
127        .expect("Failed to derive `Encode`")
128        .into()
129}
130
131/// Given a struct representing a database table, automatically implements
132/// the `TableSchema` trait with all the necessary types to work with the ic-dbms-canister.
133/// So given this struct:
134///
135/// ```rust,ignore
136/// #[derive(Table, Encode)]
137/// #[table = "posts"]
138/// struct Post {
139///     #[primary_key]
140///     id: Uint32,
141///     title: Text,
142///     content: Text,
143///     #[foreign_key(entity = "User", table = "users", column = "id")]
///     user_id: Uint32,
145/// }
146/// ```
147///
148/// What we expect as output is:
149///
150/// - To implement the `TableSchema` trait for the struct as follows:
151///
152///     ```rust,ignore
153///     impl TableSchema for Post {
154///         type Insert = PostInsertRequest;
155///         type Record = PostRecord;
156///         type Update = PostUpdateRequest;
157///         type ForeignFetcher = PostForeignFetcher;
158///
159///         fn columns() -> &'static [ColumnDef] {
160///             &[
161///                 ColumnDef {
162///                     name: "id",
163///                     data_type: DataTypeKind::Uint32,
164///                     nullable: false,
165///                     primary_key: true,
166///                     foreign_key: None,
167///                 },
168///                 ColumnDef {
169///                     name: "title",
170///                     data_type: DataTypeKind::Text,
171///                     nullable: false,
172///                     primary_key: false,
173///                     foreign_key: None,
174///                 },
175///                 ColumnDef {
176///                     name: "content",
177///                     data_type: DataTypeKind::Text,
178///                     nullable: false,
179///                     primary_key: false,
180///                     foreign_key: None,
181///                 },
182///                 ColumnDef {
183///                     name: "user_id",
184///                     data_type: DataTypeKind::Uint32,
185///                     nullable: false,
186///                     primary_key: false,
187///                     foreign_key: Some(ForeignKeyDef {
188///                         local_column: "user_id",
189///                         foreign_table: "users",
190///                         foreign_column: "id",
191///                     }),
192///                 },
193///             ]
194///         }
195///
196///         fn table_name() -> &'static str {
197///             "posts"
198///         }
199///
200///         fn primary_key() -> &'static str {
201///             "id"
202///         }
203///
204///         fn to_values(self) -> Vec<(ColumnDef, Value)> {
205///             vec![
206///                 (Self::columns()[0], Value::Uint32(self.id)),
207///                 (Self::columns()[1], Value::Text(self.title)),
208///                 (Self::columns()[2], Value::Text(self.content)),
209///                 (Self::columns()[3], Value::Uint32(self.user_id)),
210///             ]
211///         }
212///     }
213///     ```
214///
215/// - Implement the associated `Record` type
216///
217///     ```rust,ignore
218///     pub struct PostRecord {
219///         pub id: Option<Uint32>,
220///         pub title: Option<Text>,
221///         pub content: Option<Text>,
222///         pub user: Option<UserRecord>,
223///     }
224///
225///     impl TableRecord for PostRecord {
226///         type Schema = Post;
227///     
228///         fn from_values(values: TableColumns) -> Self {
229///             let mut id: Option<Uint32> = None;
230///             let mut title: Option<Text> = None;
231///             let mut content: Option<Text> = None;
232///     
233///             let post_values = values
234///                 .iter()
235///                 .find(|(table_name, _)| *table_name == ValuesSource::This)
236///                 .map(|(_, cols)| cols);
237///             
238///             for (column, value) in post_values.unwrap_or(&vec![]) {
239///                 match column.name {
240///                     "id" => {
241///                         if let Value::Uint32(v) = value {
242///                             id = Some(*v);
243///                         }
244///                     }
245///                     "title" => {
246///                         if let Value::Text(v) = value {
247///                             title = Some(v.clone());
248///                         }
249///                     }
250///                     "content" => {
251///                         if let Value::Text(v) = value {
252///                             content = Some(v.clone());
253///                         }
254///                     }
255///                     _ => { /* Ignore unknown columns */ }
256///                 }
257///             }
258///     
259///             let has_user = values.iter().any(|(source, _)| {
260///                 *source
261///                     == ValuesSource::Foreign {
262///                         table: User::table_name(),
263///                         column: "user_id",
264///                     }
265///             });
266///             let user = if has_user {
267///                 Some(UserRecord::from_values(self_reference_values(
268///                     &values,
269///                     User::table_name(),
270///                     "user_id",
271///                 )))
272///             } else {
273///                 None
274///             };
275///     
276///             Self {
277///                 id,
278///                 title,
279///                 content,
280///                 user,
281///             }
282///         }
283///     
284///         fn to_values(&self) -> Vec<(ColumnDef, Value)> {
285///             Self::Schema::columns()
286///                 .iter()
287///                 .zip(vec![
288///                     match self.id {
289///                         Some(v) => Value::Uint32(v),
290///                         None => Value::Null,
291///                     },
292///                     match &self.title {
293///                         Some(v) => Value::Text(v.clone()),
294///                         None => Value::Null,
295///                     },
296///                     match &self.content {
297///                         Some(v) => Value::Text(v.clone()),
298///                         None => Value::Null,
299///                     },
300///                 ])
301///                 .map(|(col_def, value)| (*col_def, value))
302///                 .collect()
303///         }
304///     }
305///     ```
306///
307/// - Implement the associated `InsertRecord` type
308///
309///     ```rust,ignore
310///     #[derive(Clone)]
311///     pub struct PostInsertRequest {
312///         pub id: Uint32,
313///         pub title: Text,
314///         pub content: Text,
315///         pub user_id: Uint32,
316///     }
317///
318///     impl InsertRecord for PostInsertRequest {
319///         type Record = PostRecord;
320///         type Schema = Post;
321///
322///         fn from_values(values: &[(ColumnDef, Value)]) -> ic_dbms_api::prelude::IcDbmsResult<Self> {
323///             let mut id: Option<Uint32> = None;
324///             let mut title: Option<Text> = None;
325///             let mut content: Option<Text> = None;
326///             let mut user_id: Option<Uint32> = None;
327///
328///             for (column, value) in values {
329///                 match column.name {
330///                     "id" => {
331///                         if let Value::Uint32(v) = value {
332///                             id = Some(*v);
333///                         }
334///                     }
335///                     "title" => {
336///                         if let Value::Text(v) = value {
337///                             title = Some(v.clone());
338///                         }
339///                     }
340///                     "content" => {
341///                         if let Value::Text(v) = value {
342///                             content = Some(v.clone());
343///                         }
344///                     }
345///                     "user_id" => {
346///                         if let Value::Uint32(v) = value {
347///                             user_id = Some(*v);
348///                         }
349///                     }
350///                     _ => { /* Ignore unknown columns */ }
351///                 }
352///             }
353///
354///             Ok(Self {
355///                 id: id.ok_or(IcDbmsError::Query(QueryError::MissingNonNullableField(
356///                     "id".to_string(),
357///                 )))?,
358///                 title: title.ok_or(IcDbmsError::Query(QueryError::MissingNonNullableField(
359///                     "title".to_string(),
360///                 )))?,
361///                 content: content.ok_or(IcDbmsError::Query(QueryError::MissingNonNullableField(
362///                     "content".to_string(),
363///                 )))?,
364///                 user_id: user_id.ok_or(IcDbmsError::Query(QueryError::MissingNonNullableField(
365///                     "user_id".to_string(),
366///                 )))?,
367///             })
368///         }
369///
370///         fn into_values(self) -> Vec<(ColumnDef, Value)> {
371///             vec![
372///                 (Self::Schema::columns()[0], Value::Uint32(self.id)),
373///                 (Self::Schema::columns()[1], Value::Text(self.title)),
374///                 (Self::Schema::columns()[2], Value::Text(self.content)),
375///                 (Self::Schema::columns()[3], Value::Uint32(self.user_id)),
376///             ]
377///         }
378///
379///         fn into_record(self) -> Self::Schema {
380///             Post {
381///                 id: self.id,
382///                 title: self.title,
383///                 content: self.content,
384///                 user_id: self.user_id,
385///             }
386///         }
387///     }
388///     ```
389///
390/// - Implement the associated `UpdateRecord` type
391///
392///     ```rust,ignore
393///     pub struct PostUpdateRequest {
394///         pub id: Option<Uint32>,
395///         pub title: Option<Text>,
396///         pub content: Option<Text>,
397///         pub user_id: Option<Uint32>,
398///         pub where_clause: Option<Filter>,
399///     }
400///
401///     impl UpdateRecord for PostUpdateRequest {
402///         type Record = PostRecord;
403///         type Schema = Post;
404///
405///         fn from_values(values: &[(ColumnDef, Value)], where_clause: Option<Filter>) -> Self {
406///             let mut id: Option<Uint32> = None;
407///             let mut title: Option<Text> = None;
408///             let mut content: Option<Text> = None;
409///             let mut user_id: Option<Uint32> = None;
410///
411///             for (column, value) in values {
412///                 match column.name {
413///                     "id" => {
414///                         if let Value::Uint32(v) = value {
415///                             id = Some(*v);
416///                         }
417///                     }
418///                     "title" => {
419///                         if let Value::Text(v) = value {
420///                             title = Some(v.clone());
421///                         }
422///                     }
423///                     "content" => {
424///                         if let Value::Text(v) = value {
425///                             content = Some(v.clone());
426///                         }
427///                     }
428///                     "user_id" => {
429///                         if let Value::Uint32(v) = value {
430///                             user_id = Some(*v);
431///                         }
432///                     }
433///                     _ => { /* Ignore unknown columns */ }
434///                 }
435///             }
436///
437///             Self {
438///                 id,
439///                 title,
440///                 content,
441///                 user_id,
442///                 where_clause,
443///             }
444///         }
445///
446///         fn update_values(&self) -> Vec<(ColumnDef, Value)> {
447///             let mut updates = Vec::new();
448///
449///             if let Some(id) = self.id {
450///                 updates.push((Self::Schema::columns()[0], Value::Uint32(id)));
451///             }
452///             if let Some(title) = &self.title {
453///                 updates.push((Self::Schema::columns()[1], Value::Text(title.clone())));
454///             }
455///             if let Some(content) = &self.content {
456///                 updates.push((Self::Schema::columns()[2], Value::Text(content.clone())));
457///             }
458///             if let Some(user_id) = self.user_id {
459///                 updates.push((Self::Schema::columns()[3], Value::Uint32(user_id)));
460///             }
461///
462///             updates
463///         }
464///
465///         fn where_clause(&self) -> Option<Filter> {
466///             self.where_clause.clone()
467///         }
468///     }
469///     ```
470///
/// - If the table has foreign keys, implement the associated `ForeignFetcher` type (otherwise `NoForeignFetcher` is used):
472///
473///     ```rust,ignore
474///     pub struct PostForeignFetcher;
475///
476///     impl ForeignFetcher for PostForeignFetcher {
477///         fn fetch(
478///             &self,
479///             database: &impl Database,
480///             table: &'static str,
481///             local_column: &'static str,
482///             pk_value: Value,
483///         ) -> ic_dbms_api::prelude::IcDbmsResult<TableColumns> {
484///             if table != User::table_name() {
485///                 return Err(IcDbmsError::Query(QueryError::InvalidQuery(format!(
486///                     "ForeignFetcher: unknown table '{table}' for {table_name} foreign fetcher",
487///                     table_name = Post::table_name()
488///                 ))));
489///             }
490///
///             // query the referenced record from the foreign table by primary key
492///             let mut users = database.select(
493///                 Query::<User>::builder()
494///                     .all()
495///                     .limit(1)
496///                     .and_where(Filter::Eq(User::primary_key(), pk_value.clone()))
497///                     .build(),
498///             )?;
499///             let user = match users.pop() {
500///                 Some(user) => user,
501///                 None => {
502///                     return Err(IcDbmsError::Query(QueryError::BrokenForeignKeyReference {
503///                         table: User::table_name(),
504///                         key: pk_value,
505///                     }));
506///                 }
507///             };
508///
509///             let values = user.to_values();
510///             Ok(vec![(
511///                 ValuesSource::Foreign {
512///                     table,
513///                     column: local_column,
514///                 },
515///                 values,
516///             )])
517///         }
518///     }
519///     ```
520///
/// So for each struct deriving `Table`, the following types are generated. Given `${StructName}`, we generate:
522///
523/// - `${StructName}Record` - implementing `TableRecord`
524/// - `${StructName}InsertRequest` - implementing `InsertRecord`
525/// - `${StructName}UpdateRequest` - implementing `UpdateRecord`
526/// - `${StructName}ForeignFetcher` (only if foreign keys are present)
527///
528/// Also, we will implement the `TableSchema` trait for the struct itself and derive `Encode` for `${StructName}`.
529///
530/// ## Attributes
531///
532/// The `Table` derive macro supports the following attributes:
533///
534/// - `#[table = "table_name"]`: Specifies the name of the table in the database.
535/// - `#[alignment = N]`: (optional) Specifies the alignment for the table records. Use only if you know what you are doing.
536/// - `#[primary_key]`: Marks a field as the primary key of the table.
537/// - `#[foreign_key(entity = "EntityName", table = "table_name", column = "column_name")]`: Defines a foreign key relationship.
/// - `#[sanitizer(SanitizerType)]`: Specifies a sanitizer for the field.
539/// - `#[validate(ValidatorType)]`: Specifies a validator for the field.
540///
541#[proc_macro_derive(
542    Table,
543    attributes(alignment, table, primary_key, foreign_key, sanitizer, validate)
544)]
545pub fn derive_table(input: TokenStream) -> TokenStream {
546    let input = parse_macro_input!(input as DeriveInput);
547    self::table::table(input)
548        .expect("failed to derive `Table`")
549        .into()
550}
551
552/// Automatically implements the api for the ic-dbms-canister with all the required methods to interact with the ACL and
553/// the defined tables.
554#[proc_macro_derive(DbmsCanister, attributes(tables))]
555pub fn derive_dbms_canister(input: TokenStream) -> TokenStream {
556    let input = parse_macro_input!(input as DeriveInput);
557    self::dbms_canister::dbms_canister(input)
558        .expect("failed to derive `DbmsCanister`")
559        .into()
560}