Skip to main content

ic_dbms_macros/
lib.rs

1#![crate_name = "ic_dbms_macros"]
2#![crate_type = "lib"]
3#![cfg_attr(docsrs, feature(doc_cfg))]
4#![deny(clippy::print_stdout)]
5#![deny(clippy::print_stderr)]
6
7//! Macros and derive for ic-dbms-canister
8//!
9//! This crate provides procedural macros to automatically implement traits
10//! required by the `ic-dbms-canister`.
11//!
12//! ## Provided Derive Macros
13//!
14//! - `Encode`: Automatically implements the `Encode` trait for structs.
15//! - `Table`: Automatically implements the `TableSchema` trait and associated types.
16//! - `DatabaseSchema`: Generates `DatabaseSchema<M>` trait dispatch and `register_tables`.
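//! - `DbmsCanister`: Automatically implements the API for the ic-dbms-canister.
//! - `CustomDataType`: Implements the `CustomDataType` trait and a `From<T> for Value` conversion for user-defined types.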
18//!
19
20#![doc(html_playground_url = "https://play.rust-lang.org")]
21#![doc(
22    html_favicon_url = "https://raw.githubusercontent.com/veeso/wasm-dbms/main/assets/images/cargo/logo-128.png"
23)]
24#![doc(
25    html_logo_url = "https://raw.githubusercontent.com/veeso/wasm-dbms/main/assets/images/cargo/logo-512.png"
26)]
27
28use proc_macro::TokenStream;
29use syn::{DeriveInput, parse_macro_input};
30
31mod custom_data_type;
32mod database_schema;
33mod dbms_canister;
34mod encode;
35mod table;
36mod utils;
37
38/// Automatically implements the `Encode` trait for a struct.
39///
/// This derive macro implements the `Encode` trait (the full expansion, including `encode` and `decode`, is shown below). Its size computation works as follows:
41///
42/// - `fn data_size() -> DataSize`  
43///   Computes the static size of the encoded type.  
44///   If all fields implement `Encode::data_size()` returning  
45///   `DataSize::Fixed(n)`, then the type is also considered fixed-size.  
46///   Otherwise, the type is `DataSize::Dynamic`.
47///
48/// - `fn size(&self) -> MSize`  
49///   Computes the runtime-encoding size of the value by summing the
50///   sizes of all fields.
51///
52/// # What the macro generates
53///
54/// Given a struct like:
55///
56/// ```rust,ignore
57/// #[derive(Encode)]
58/// struct User {
59///     id: Uint32,
60///     name: Text,
61/// }
62/// ```
63///
64/// The macro expands into:
65///
66/// ```rust,ignore
67/// impl Encode for User {
68///     const DATA_SIZE: DataSize = DataSize::Dynamic; // or DataSize::Fixed(n) if applicable
69///
70///     fn size(&self) -> MSize {
71///         self.id.size() + self.name.size()
72///     }
73///
74///     fn encode(&'_ self) -> std::borrow::Cow<'_, [u8]> {
75///         let mut encoded = Vec::with_capacity(self.size() as usize);
76///         encoded.extend_from_slice(&self.id.encode());
77///         encoded.extend_from_slice(&self.name.encode());
78///         std::borrow::Cow::Owned(encoded)
79///     }
80///
81///     fn decode(data: std::borrow::Cow<[u8]>) -> ::ic_dbms_api::prelude::MemoryResult<Self> {
82///         let mut offset = 0;
///         let id = Uint32::decode(std::borrow::Cow::Borrowed(&data[offset..]))?;
///         offset += id.size() as usize;
///         let name = Text::decode(std::borrow::Cow::Borrowed(&data[offset..]))?;
86///         offset += name.size() as usize;
87///         Ok(Self { id, name })
88///     }
89/// }
90/// ```
91/// # Requirements
92///
93/// - Each field type must implement `Encode`.
94/// - Only works on `struct`s; enums and unions are not supported.
95/// - All field identifiers must be valid Rust identifiers (no tuple structs).
96///
97/// # Notes
98///
99/// - It is intended for internal use within the `ic-dbms-canister` DBMS memory
100///   system.
101///
102/// # Errors
103///
104/// The macro will fail to expand if:
105///
106/// - The struct has unnamed fields (tuple struct)
107/// - A field type does not implement `Encode`
108/// - The macro is applied to a non-struct item.
109///
110/// # Example
111///
112/// ```rust,ignore
113/// #[derive(Encode, Debug, PartialEq, Eq)]
114/// struct Position {
115///     x: Int32,
116///     y: Int32,
117/// }
118///
119/// let pos = Position { x: 10.into(), y: 20.into() };
120/// assert_eq!(Position::data_size(), DataSize::Fixed(8));
121/// assert_eq!(pos.size(), 8);
122/// let encoded = pos.encode();
123/// let decoded = Position::decode(encoded).unwrap();
124/// assert_eq!(pos, decoded);
125/// ```
126#[proc_macro_derive(Encode)]
127pub fn derive_encode(input: TokenStream) -> TokenStream {
128    let input = parse_macro_input!(input as DeriveInput);
129    self::encode::encode(input, None)
130        .expect("Failed to derive `Encode`")
131        .into()
132}
133
134/// Given a struct representing a database table, automatically implements
135/// the `TableSchema` trait with all the necessary types to work with the ic-dbms-canister.
136/// So given this struct:
137///
138/// ```rust,ignore
139/// #[derive(Table, Encode)]
140/// #[table = "posts"]
141/// struct Post {
142///     #[primary_key]
143///     id: Uint32,
144///     title: Text,
145///     content: Text,
146///     #[foreign_key(entity = "User", table = "users", column = "id")]
///     user_id: Uint32,
148/// }
149/// ```
150///
151/// What we expect as output is:
152///
153/// - To implement the `TableSchema` trait for the struct as follows:
154///
155///     ```rust,ignore
156///     impl TableSchema for Post {
157///         type Insert = PostInsertRequest;
158///         type Record = PostRecord;
159///         type Update = PostUpdateRequest;
160///         type ForeignFetcher = PostForeignFetcher;
161///
162///         fn columns() -> &'static [ColumnDef] {
163///             &[
164///                 ColumnDef {
165///                     name: "id",
166///                     data_type: DataTypeKind::Uint32,
167///                     nullable: false,
168///                     primary_key: true,
169///                     foreign_key: None,
170///                 },
171///                 ColumnDef {
172///                     name: "title",
173///                     data_type: DataTypeKind::Text,
174///                     nullable: false,
175///                     primary_key: false,
176///                     foreign_key: None,
177///                 },
178///                 ColumnDef {
179///                     name: "content",
180///                     data_type: DataTypeKind::Text,
181///                     nullable: false,
182///                     primary_key: false,
183///                     foreign_key: None,
184///                 },
185///                 ColumnDef {
186///                     name: "user_id",
187///                     data_type: DataTypeKind::Uint32,
188///                     nullable: false,
189///                     primary_key: false,
190///                     foreign_key: Some(ForeignKeyDef {
191///                         local_column: "user_id",
192///                         foreign_table: "users",
193///                         foreign_column: "id",
194///                     }),
195///                 },
196///             ]
197///         }
198///
199///         fn table_name() -> &'static str {
200///             "posts"
201///         }
202///
203///         fn primary_key() -> &'static str {
204///             "id"
205///         }
206///
207///         fn to_values(self) -> Vec<(ColumnDef, Value)> {
208///             vec![
209///                 (Self::columns()[0], Value::Uint32(self.id)),
210///                 (Self::columns()[1], Value::Text(self.title)),
211///                 (Self::columns()[2], Value::Text(self.content)),
212///                 (Self::columns()[3], Value::Uint32(self.user_id)),
213///             ]
214///         }
215///     }
216///     ```
217///
218/// - Implement the associated `Record` type
219///
220///     ```rust,ignore
221///     pub struct PostRecord {
222///         pub id: Option<Uint32>,
223///         pub title: Option<Text>,
224///         pub content: Option<Text>,
225///         pub user: Option<UserRecord>,
226///     }
227///
228///     impl TableRecord for PostRecord {
229///         type Schema = Post;
230///     
231///         fn from_values(values: TableColumns) -> Self {
232///             let mut id: Option<Uint32> = None;
233///             let mut title: Option<Text> = None;
234///             let mut content: Option<Text> = None;
235///     
236///             let post_values = values
237///                 .iter()
238///                 .find(|(table_name, _)| *table_name == ValuesSource::This)
239///                 .map(|(_, cols)| cols);
240///             
241///             for (column, value) in post_values.unwrap_or(&vec![]) {
242///                 match column.name {
243///                     "id" => {
244///                         if let Value::Uint32(v) = value {
245///                             id = Some(*v);
246///                         }
247///                     }
248///                     "title" => {
249///                         if let Value::Text(v) = value {
250///                             title = Some(v.clone());
251///                         }
252///                     }
253///                     "content" => {
254///                         if let Value::Text(v) = value {
255///                             content = Some(v.clone());
256///                         }
257///                     }
258///                     _ => { /* Ignore unknown columns */ }
259///                 }
260///             }
261///     
262///             let has_user = values.iter().any(|(source, _)| {
263///                 *source
264///                     == ValuesSource::Foreign {
265///                         table: User::table_name(),
266///                         column: "user_id",
267///                     }
268///             });
269///             let user = if has_user {
270///                 Some(UserRecord::from_values(self_reference_values(
271///                     &values,
272///                     User::table_name(),
273///                     "user_id",
274///                 )))
275///             } else {
276///                 None
277///             };
278///     
279///             Self {
280///                 id,
281///                 title,
282///                 content,
283///                 user,
284///             }
285///         }
286///     
287///         fn to_values(&self) -> Vec<(ColumnDef, Value)> {
288///             Self::Schema::columns()
289///                 .iter()
290///                 .zip(vec![
291///                     match self.id {
292///                         Some(v) => Value::Uint32(v),
293///                         None => Value::Null,
294///                     },
295///                     match &self.title {
296///                         Some(v) => Value::Text(v.clone()),
297///                         None => Value::Null,
298///                     },
299///                     match &self.content {
300///                         Some(v) => Value::Text(v.clone()),
301///                         None => Value::Null,
302///                     },
303///                 ])
304///                 .map(|(col_def, value)| (*col_def, value))
305///                 .collect()
306///         }
307///     }
308///     ```
309///
310/// - Implement the associated `InsertRecord` type
311///
312///     ```rust,ignore
313///     #[derive(Clone)]
314///     pub struct PostInsertRequest {
315///         pub id: Uint32,
316///         pub title: Text,
317///         pub content: Text,
318///         pub user_id: Uint32,
319///     }
320///
321///     impl InsertRecord for PostInsertRequest {
322///         type Record = PostRecord;
323///         type Schema = Post;
324///
325///         fn from_values(values: &[(ColumnDef, Value)]) -> ic_dbms_api::prelude::IcDbmsResult<Self> {
326///             let mut id: Option<Uint32> = None;
327///             let mut title: Option<Text> = None;
328///             let mut content: Option<Text> = None;
329///             let mut user_id: Option<Uint32> = None;
330///
331///             for (column, value) in values {
332///                 match column.name {
333///                     "id" => {
334///                         if let Value::Uint32(v) = value {
335///                             id = Some(*v);
336///                         }
337///                     }
338///                     "title" => {
339///                         if let Value::Text(v) = value {
340///                             title = Some(v.clone());
341///                         }
342///                     }
343///                     "content" => {
344///                         if let Value::Text(v) = value {
345///                             content = Some(v.clone());
346///                         }
347///                     }
348///                     "user_id" => {
349///                         if let Value::Uint32(v) = value {
350///                             user_id = Some(*v);
351///                         }
352///                     }
353///                     _ => { /* Ignore unknown columns */ }
354///                 }
355///             }
356///
357///             Ok(Self {
358///                 id: id.ok_or(IcDbmsError::Query(QueryError::MissingNonNullableField(
359///                     "id".to_string(),
360///                 )))?,
361///                 title: title.ok_or(IcDbmsError::Query(QueryError::MissingNonNullableField(
362///                     "title".to_string(),
363///                 )))?,
364///                 content: content.ok_or(IcDbmsError::Query(QueryError::MissingNonNullableField(
365///                     "content".to_string(),
366///                 )))?,
367///                 user_id: user_id.ok_or(IcDbmsError::Query(QueryError::MissingNonNullableField(
368///                     "user_id".to_string(),
369///                 )))?,
370///             })
371///         }
372///
373///         fn into_values(self) -> Vec<(ColumnDef, Value)> {
374///             vec![
375///                 (Self::Schema::columns()[0], Value::Uint32(self.id)),
376///                 (Self::Schema::columns()[1], Value::Text(self.title)),
377///                 (Self::Schema::columns()[2], Value::Text(self.content)),
378///                 (Self::Schema::columns()[3], Value::Uint32(self.user_id)),
379///             ]
380///         }
381///
382///         fn into_record(self) -> Self::Schema {
383///             Post {
384///                 id: self.id,
385///                 title: self.title,
386///                 content: self.content,
387///                 user_id: self.user_id,
388///             }
389///         }
390///     }
391///     ```
392///
393/// - Implement the associated `UpdateRecord` type
394///
395///     ```rust,ignore
396///     pub struct PostUpdateRequest {
397///         pub id: Option<Uint32>,
398///         pub title: Option<Text>,
399///         pub content: Option<Text>,
400///         pub user_id: Option<Uint32>,
401///         pub where_clause: Option<Filter>,
402///     }
403///
404///     impl UpdateRecord for PostUpdateRequest {
405///         type Record = PostRecord;
406///         type Schema = Post;
407///
408///         fn from_values(values: &[(ColumnDef, Value)], where_clause: Option<Filter>) -> Self {
409///             let mut id: Option<Uint32> = None;
410///             let mut title: Option<Text> = None;
411///             let mut content: Option<Text> = None;
412///             let mut user_id: Option<Uint32> = None;
413///
414///             for (column, value) in values {
415///                 match column.name {
416///                     "id" => {
417///                         if let Value::Uint32(v) = value {
418///                             id = Some(*v);
419///                         }
420///                     }
421///                     "title" => {
422///                         if let Value::Text(v) = value {
423///                             title = Some(v.clone());
424///                         }
425///                     }
426///                     "content" => {
427///                         if let Value::Text(v) = value {
428///                             content = Some(v.clone());
429///                         }
430///                     }
431///                     "user_id" => {
432///                         if let Value::Uint32(v) = value {
433///                             user_id = Some(*v);
434///                         }
435///                     }
436///                     _ => { /* Ignore unknown columns */ }
437///                 }
438///             }
439///
440///             Self {
441///                 id,
442///                 title,
443///                 content,
444///                 user_id,
445///                 where_clause,
446///             }
447///         }
448///
449///         fn update_values(&self) -> Vec<(ColumnDef, Value)> {
450///             let mut updates = Vec::new();
451///
452///             if let Some(id) = self.id {
453///                 updates.push((Self::Schema::columns()[0], Value::Uint32(id)));
454///             }
455///             if let Some(title) = &self.title {
456///                 updates.push((Self::Schema::columns()[1], Value::Text(title.clone())));
457///             }
458///             if let Some(content) = &self.content {
459///                 updates.push((Self::Schema::columns()[2], Value::Text(content.clone())));
460///             }
461///             if let Some(user_id) = self.user_id {
462///                 updates.push((Self::Schema::columns()[3], Value::Uint32(user_id)));
463///             }
464///
465///             updates
466///         }
467///
468///         fn where_clause(&self) -> Option<Filter> {
469///             self.where_clause.clone()
470///         }
471///     }
472///     ```
473///
/// - If the struct has foreign keys, implement the associated `ForeignFetcher` type (otherwise `NoForeignFetcher` is used):
475///
476///     ```rust,ignore
477///     pub struct PostForeignFetcher;
478///
479///     impl ForeignFetcher for PostForeignFetcher {
480///         fn fetch(
481///             &self,
482///             database: &impl Database,
483///             table: &'static str,
484///             local_column: &'static str,
485///             pk_value: Value,
486///         ) -> ic_dbms_api::prelude::IcDbmsResult<TableColumns> {
487///             if table != User::table_name() {
488///                 return Err(IcDbmsError::Query(QueryError::InvalidQuery(format!(
489///                     "ForeignFetcher: unknown table '{table}' for {table_name} foreign fetcher",
490///                     table_name = Post::table_name()
491///                 ))));
492///             }
493///
494///             // query all records from the foreign table
495///             let mut users = database.select(
496///                 Query::<User>::builder()
497///                     .all()
498///                     .limit(1)
499///                     .and_where(Filter::Eq(User::primary_key(), pk_value.clone()))
500///                     .build(),
501///             )?;
502///             let user = match users.pop() {
503///                 Some(user) => user,
504///                 None => {
505///                     return Err(IcDbmsError::Query(QueryError::BrokenForeignKeyReference {
506///                         table: User::table_name(),
507///                         key: pk_value,
508///                     }));
509///                 }
510///             };
511///
512///             let values = user.to_values();
513///             Ok(vec![(
514///                 ValuesSource::Foreign {
515///                     table,
516///                     column: local_column,
517///                 },
518///                 values,
519///             )])
520///         }
521///     }
522///     ```
523///
/// So for each struct deriving `Table`, given `${StructName}`, we will generate the following types:
525///
526/// - `${StructName}Record` - implementing `TableRecord`
527/// - `${StructName}InsertRequest` - implementing `InsertRecord`
528/// - `${StructName}UpdateRequest` - implementing `UpdateRecord`
529/// - `${StructName}ForeignFetcher` (only if foreign keys are present)
530///
531/// Also, we will implement the `TableSchema` trait for the struct itself and derive `Encode` for `${StructName}`.
532///
533/// ## Attributes
534///
535/// The `Table` derive macro supports the following attributes:
536///
537/// - `#[table = "table_name"]`: Specifies the name of the table in the database.
538/// - `#[alignment = N]`: (optional) Specifies the alignment for the table records. Use only if you know what you are doing.
539/// - `#[primary_key]`: Marks a field as the primary key of the table.
540/// - `#[foreign_key(entity = "EntityName", table = "table_name", column = "column_name")]`: Defines a foreign key relationship.
/// - `#[sanitizer(SanitizerType)]`: Specifies a sanitizer for the field.
542/// - `#[validate(ValidatorType)]`: Specifies a validator for the field.
543///
544#[proc_macro_derive(
545    Table,
546    attributes(
547        alignment,
548        table,
549        primary_key,
550        foreign_key,
551        sanitizer,
552        validate,
553        custom_type
554    )
555)]
556pub fn derive_table(input: TokenStream) -> TokenStream {
557    let input = parse_macro_input!(input as DeriveInput);
558    self::table::table(input)
559        .expect("failed to derive `Table`")
560        .into()
561}
562
563/// Generates a [`DatabaseSchema`] implementation for IC canister crates.
564///
565/// This macro uses `::ic_dbms_canister::prelude::` and
566/// `::ic_dbms_api::prelude::` paths so the generated code resolves
567/// correctly in crates that depend on `ic-dbms-canister` without
568/// requiring direct `wasm-dbms` dependencies.
569///
570/// # Example
571///
572/// ```rust,ignore
573/// #[derive(DatabaseSchema, DbmsCanister)]
574/// #[tables(User = "users", Post = "posts")]
575/// pub struct MyCanister;
576/// ```
577#[proc_macro_derive(DatabaseSchema, attributes(tables))]
578pub fn derive_database_schema(input: TokenStream) -> TokenStream {
579    let input = parse_macro_input!(input as DeriveInput);
580    self::database_schema::database_schema(input)
581        .expect("failed to derive `DatabaseSchema`")
582        .into()
583}
584
/// Automatically implements the API for the ic-dbms-canister, with all the methods required to interact with the ACL and
/// the defined tables.
587#[proc_macro_derive(DbmsCanister, attributes(tables))]
588pub fn derive_dbms_canister(input: TokenStream) -> TokenStream {
589    let input = parse_macro_input!(input as DeriveInput);
590    self::dbms_canister::dbms_canister(input)
591        .expect("failed to derive `DbmsCanister`")
592        .into()
593}
594
595/// Derives the [`CustomDataType`] trait and an `impl From<T> for Value` conversion
596/// for a user-defined enum or struct.
597///
598/// The type must also derive [`Encode`] (for binary serialization) and implement
599/// [`Display`](std::fmt::Display) (for the cached display string in [`CustomValue`]).
600///
601/// # Required attribute
602///
603/// - `#[type_tag = "..."]`: A unique string identifier for this custom data type.
604///
605/// # What the macro generates
606///
607/// Given a type like:
608///
609/// ```rust,ignore
610/// #[derive(Encode, CustomDataType)]
611/// #[type_tag = "status"]
612/// enum Status { Active, Inactive }
613/// ```
614///
615/// The macro expands into:
616///
617/// ```rust,ignore
618/// impl CustomDataType for Status {
619///     const TYPE_TAG: &'static str = "status";
620/// }
621///
622/// impl From<Status> for Value {
623///     fn from(val: Status) -> Value {
624///         Value::Custom(CustomValue {
625///             type_tag: "status".to_string(),
626///             encoded: Encode::encode(&val).into_owned(),
627///             display: val.to_string(),
628///         })
629///     }
630/// }
631/// ```
632///
633/// # Note
634///
635/// The user must also provide `Display`, `Default`, and `DataType` implementations
636/// for the type. This macro only bridges the custom type to the `Value` system.
637#[proc_macro_derive(CustomDataType, attributes(type_tag))]
638pub fn derive_custom_data_type(input: TokenStream) -> TokenStream {
639    let input = parse_macro_input!(input as DeriveInput);
640    custom_data_type::custom_data_type(&input)
641        .unwrap_or_else(|e| e.to_compile_error())
642        .into()
643}