icarus_derive/
lib.rs

1// Copyright (c) 2025 Icarus Team. All Rights Reserved.
2// Licensed under BSL-1.1. See LICENSE and NOTICE files.
3// Signature verification and telemetry must remain intact.
4
5// #![warn(missing_docs)] // TODO: Enable after adding all documentation
6
7//! Procedural macros for the Icarus SDK
8//!
9//! This crate provides derive macros and attribute macros to reduce
10//! boilerplate when building MCP servers for ICP.
11
12use proc_macro::TokenStream;
13use proc_macro2::TokenStream as TokenStream2;
14use quote::quote;
15use syn::{parse_macro_input, DeriveInput};
16
17mod server;
18mod tools;
19
20/// Derive macro for creating MCP tools
21#[proc_macro_derive(IcarusTool, attributes(icarus_tool))]
22pub fn derive_icarus_tool(input: TokenStream) -> TokenStream {
23    let input = parse_macro_input!(input as DeriveInput);
24
25    // Extract tool metadata from attributes
26    let mut name = None;
27    let mut description = None;
28
29    for attr in &input.attrs {
30        if attr.path().is_ident("icarus_tool") {
31            attr.parse_nested_meta(|meta| {
32                if meta.path.is_ident("name") {
33                    name = Some(meta.value()?.parse::<syn::LitStr>()?.value());
34                    Ok(())
35                } else if meta.path.is_ident("description") {
36                    description = Some(meta.value()?.parse::<syn::LitStr>()?.value());
37                    Ok(())
38                } else {
39                    Err(meta.error("unsupported icarus_tool attribute"))
40                }
41            })
42            .expect("Failed to parse icarus_tool attribute");
43        }
44    }
45
46    let struct_name = &input.ident;
47    let tool_name = name.unwrap_or_else(|| struct_name.to_string());
48    let tool_desc = description.unwrap_or_else(|| format!("{} tool", tool_name));
49
50    // Generate the implementation
51    let expanded = quote! {
52        #[async_trait::async_trait]
53        impl icarus_core::tool::IcarusTool for #struct_name {
54            fn info(&self) -> icarus_core::tool::ToolInfo {
55                icarus_core::tool::ToolInfo {
56                    name: #tool_name.to_string(),
57                    description: #tool_desc.to_string(),
58                    input_schema: serde_json::json!({
59                        "type": "object",
60                        "properties": {},
61                        "required": []
62                    }),
63                }
64            }
65
66            fn to_rmcp_tool(&self) -> rmcp::model::Tool {
67                use std::borrow::Cow;
68                use std::sync::Arc;
69
70                let schema = serde_json::json!({
71                    "type": "object",
72                    "properties": {},
73                    "required": []
74                });
75
76                rmcp::model::Tool {
77                    name: Cow::Borrowed(#tool_name),
78                    description: Some(Cow::Borrowed(#tool_desc)),
79                    input_schema: Arc::new(schema.as_object().unwrap().clone()),
80                    annotations: None,
81                }
82            }
83
84            async fn execute(&self, args: serde_json::Value) -> icarus_core::error::Result<serde_json::Value> {
85                // Default implementation - override in your tool
86                Ok(serde_json::json!({
87                    "error": "Tool execution not implemented"
88                }))
89            }
90        }
91    };
92
93    TokenStream::from(expanded)
94}
95
96/// Attribute macro for MCP server setup
97/// pub struct MyServer {
98///     tools: Vec<Box<dyn IcarusTool>>,
99/// }
100/// ```
101#[proc_macro_attribute]
102pub fn icarus_server(args: TokenStream, input: TokenStream) -> TokenStream {
103    let args = TokenStream2::from(args);
104    let input = parse_macro_input!(input as DeriveInput);
105    server::expand_icarus_server(args, input).into()
106}
107
108/// Derive macro for common Icarus type patterns
109///
110/// This is a convenience macro that combines IcarusStorable with sensible defaults.
111/// You still need to derive the standard traits manually.
112///
113/// # Examples
114/// ```
115/// #[derive(Debug, Clone, Serialize, Deserialize, CandidType, IcarusType)]
116/// struct MemoryEntry {
117///     id: String,
118///     content: String,
119///     created_at: u64,
120/// }
121/// ```
122///
123/// This is equivalent to:
124/// ```
125/// #[derive(Debug, Clone, Serialize, Deserialize, CandidType, IcarusStorable)]
126/// #[icarus_storable(unbounded)]
127/// struct MemoryEntry { ... }
128/// ```
129#[proc_macro_derive(IcarusType, attributes(icarus_storable))]
130pub fn derive_icarus_type(input: TokenStream) -> TokenStream {
131    let input = parse_macro_input!(input as DeriveInput);
132    let struct_name = &input.ident;
133
134    // Extract generics if any
135    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
136
137    // Parse attributes for storage configuration
138    let mut unbounded = true; // Default to unbounded for convenience
139    let mut max_size_bytes = 1024 * 1024; // 1MB default
140
141    for attr in &input.attrs {
142        if attr.path().is_ident("icarus_storable") {
143            attr.parse_nested_meta(|meta| {
144                if meta.path.is_ident("unbounded") {
145                    unbounded = true;
146                    Ok(())
147                } else if meta.path.is_ident("bounded") {
148                    unbounded = false;
149                    Ok(())
150                } else if meta.path.is_ident("max_size") {
151                    let value = meta.value()?;
152                    let lit_str: syn::LitStr = value.parse()?;
153                    let size_str = lit_str.value();
154                    max_size_bytes = parse_size_string(&size_str);
155                    unbounded = false;
156                    Ok(())
157                } else {
158                    Ok(()) // Ignore other attributes
159                }
160            })
161            .unwrap_or(()); // Ignore parse errors
162        }
163    }
164
165    let bound = if unbounded {
166        quote! { ic_stable_structures::storable::Bound::Unbounded }
167    } else {
168        quote! {
169            ic_stable_structures::storable::Bound::Bounded {
170                max_size: #max_size_bytes,
171                is_fixed_size: false,
172            }
173        }
174    };
175
176    // Generate all the common trait implementations
177    let expanded = quote! {
178        // Note: We expect the user to add #[derive(Debug, Clone, Serialize, Deserialize, CandidType)]
179        // This macro just adds the IcarusStorable functionality
180
181        // Implement Storable for ICP
182        impl #impl_generics ic_stable_structures::Storable for #struct_name #ty_generics #where_clause {
183            fn to_bytes(&self) -> std::borrow::Cow<[u8]> {
184                std::borrow::Cow::Owned(
185                    candid::encode_one(self).expect("Failed to encode to Candid")
186                )
187            }
188
189            fn from_bytes(bytes: std::borrow::Cow<[u8]>) -> Self {
190                candid::decode_one(&bytes).expect("Failed to decode from Candid")
191            }
192
193            const BOUND: ic_stable_structures::storable::Bound = #bound;
194        }
195    };
196
197    TokenStream::from(expanded)
198}
199
200/// Derive macro for simplified storage declaration
201///
202/// Generates stable storage declarations from a simple struct definition.
203/// Automatically assigns memory IDs and handles initialization.
204///
205/// # Examples
206/// ```
207/// #[derive(IcarusStorage)]
208/// struct Storage {
209///     memories: StableBTreeMap<String, MemoryEntry>,
210///     counter: u64,
211///     users: StableBTreeMap<Principal, User>,
212/// }
213/// ```
214///
215/// This generates:
216/// - Thread-local storage declarations
217/// - Memory manager initialization  
218/// - Accessor methods for each field
219#[proc_macro_derive(IcarusStorage)]
220pub fn derive_icarus_storage(input: TokenStream) -> TokenStream {
221    let input = parse_macro_input!(input as DeriveInput);
222
223    if let syn::Data::Struct(data_struct) = &input.data {
224        if let syn::Fields::Named(fields_named) = &data_struct.fields {
225            let struct_name = &input.ident;
226            let mut storage_declarations = vec![];
227            let mut accessor_methods = vec![];
228            let mut memory_id = 0u8;
229
230            for field in &fields_named.named {
231                if let Some(field_name) = &field.ident {
232                    let field_type = &field.ty;
233                    let field_name_upper =
234                        syn::Ident::new(&field_name.to_string().to_uppercase(), field_name.span());
235
236                    // Generate storage declaration based on field type
237                    let storage_decl = if is_stable_map_type(field_type) {
238                        quote! {
239                            #field_name_upper: #field_type =
240                                ::ic_stable_structures::StableBTreeMap::init(
241                                    MEMORY_MANAGER.with(|m| m.borrow().get(
242                                        ::ic_stable_structures::memory_manager::MemoryId::new(#memory_id)
243                                    ))
244                                );
245                        }
246                    } else if is_stable_cell_type(field_type) {
247                        quote! {
248                            #field_name_upper: ::ic_stable_structures::StableCell<#field_type, ::ic_stable_structures::memory_manager::VirtualMemory<::ic_stable_structures::DefaultMemoryImpl>> =
249                                ::ic_stable_structures::StableCell::init(
250                                    MEMORY_MANAGER.with(|m| m.borrow().get(
251                                        ::ic_stable_structures::memory_manager::MemoryId::new(#memory_id)
252                                    )),
253                                    Default::default()
254                                ).expect("Failed to initialize StableCell");
255                        }
256                    } else {
257                        // For simple types, wrap in StableCell
258                        quote! {
259                            #field_name_upper: ::ic_stable_structures::StableCell<#field_type, ::ic_stable_structures::memory_manager::VirtualMemory<::ic_stable_structures::DefaultMemoryImpl>> =
260                                ::ic_stable_structures::StableCell::init(
261                                    MEMORY_MANAGER.with(|m| m.borrow().get(
262                                        ::ic_stable_structures::memory_manager::MemoryId::new(#memory_id)
263                                    )),
264                                    Default::default()
265                                ).expect("Failed to initialize StableCell");
266                        }
267                    };
268
269                    storage_declarations.push(storage_decl);
270
271                    // Generate accessor method
272                    let accessor = if is_stable_map_type(field_type) {
273                        quote! {
274                            pub fn #field_name() -> impl std::ops::Deref<Target = #field_type> {
275                                #field_name_upper.with(|storage| storage.borrow())
276                            }
277                        }
278                    } else {
279                        let setter_name =
280                            syn::Ident::new(&format!("{}_set", field_name), field_name.span());
281
282                        quote! {
283                            pub fn #field_name() -> #field_type
284                            where
285                                #field_type: Clone + Default
286                            {
287                                #field_name_upper.with(|cell| cell.borrow().get().clone())
288                            }
289
290                            pub fn #setter_name(value: #field_type)
291                            where
292                                #field_type: Clone
293                            {
294                                #field_name_upper.with(|cell| {
295                                    cell.borrow_mut().set(value)
296                                        .expect("Failed to set value in StableCell");
297                                });
298                            }
299                        }
300                    };
301
302                    accessor_methods.push(accessor);
303                    memory_id += 1;
304                }
305            }
306
307            let expanded = quote! {
308                thread_local! {
309                    static MEMORY_MANAGER: ::std::cell::RefCell<
310                        ::ic_stable_structures::memory_manager::MemoryManager<
311                            ::ic_stable_structures::DefaultMemoryImpl
312                        >
313                    > = ::std::cell::RefCell::new(
314                        ::ic_stable_structures::memory_manager::MemoryManager::init(
315                            ::ic_stable_structures::DefaultMemoryImpl::default()
316                        )
317                    );
318
319                    #(static #storage_declarations)*
320                }
321
322                impl #struct_name {
323                    #(#accessor_methods)*
324                }
325            };
326
327            TokenStream::from(expanded)
328        } else {
329            syn::Error::new_spanned(
330                &input,
331                "IcarusStorage can only be used on structs with named fields",
332            )
333            .to_compile_error()
334            .into()
335        }
336    } else {
337        syn::Error::new_spanned(&input, "IcarusStorage can only be used on structs")
338            .to_compile_error()
339            .into()
340    }
341}
342
343/// Derive macro for ICP storable types
344///
345/// # Examples
346/// ```
347/// #[derive(IcarusStorable)]
348/// struct MyData { ... } // Uses default 1MB bound
349///
350/// #[derive(IcarusStorable)]
351/// #[icarus_storable(unbounded)]
352/// struct LargeData { ... } // Uses unbounded storage
353///
354/// #[derive(IcarusStorable)]
355/// #[icarus_storable(max_size = "2MB")]
356/// struct CustomData { ... } // Uses custom 2MB bound
357/// ```
358#[proc_macro_derive(IcarusStorable, attributes(icarus_storable))]
359pub fn derive_icarus_storable(input: TokenStream) -> TokenStream {
360    let input = parse_macro_input!(input as DeriveInput);
361    let struct_name = &input.ident;
362
363    // Extract generics if any
364    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
365
366    // Parse attributes
367    let mut unbounded = false;
368    let mut max_size_bytes = 1024 * 1024; // 1MB default
369
370    for attr in &input.attrs {
371        if attr.path().is_ident("icarus_storable") {
372            attr.parse_nested_meta(|meta| {
373                if meta.path.is_ident("unbounded") {
374                    unbounded = true;
375                    Ok(())
376                } else if meta.path.is_ident("max_size") {
377                    let value = meta.value()?;
378                    let lit_str: syn::LitStr = value.parse()?;
379                    let size_str = lit_str.value();
380                    max_size_bytes = parse_size_string(&size_str);
381                    Ok(())
382                } else {
383                    Err(meta.error("unsupported icarus_storable attribute"))
384                }
385            })
386            .unwrap_or_else(|e| panic!("Failed to parse icarus_storable attribute: {}", e));
387        }
388    }
389
390    let bound = if unbounded {
391        quote! { ic_stable_structures::storable::Bound::Unbounded }
392    } else {
393        quote! {
394            ic_stable_structures::storable::Bound::Bounded {
395                max_size: #max_size_bytes,
396                is_fixed_size: false,
397            }
398        }
399    };
400
401    // Generate implementation
402    let expanded = quote! {
403        impl #impl_generics ic_stable_structures::Storable for #struct_name #ty_generics #where_clause {
404            fn to_bytes(&self) -> std::borrow::Cow<[u8]> {
405                std::borrow::Cow::Owned(
406                    candid::encode_one(self).expect("Failed to encode to Candid")
407                )
408            }
409
410            fn from_bytes(bytes: std::borrow::Cow<[u8]>) -> Self {
411                candid::decode_one(&bytes).expect("Failed to decode from Candid")
412            }
413
414            const BOUND: ic_stable_structures::storable::Bound = #bound;
415        }
416
417    };
418
419    TokenStream::from(expanded)
420}
421
422/// Attribute macro for marking impl blocks that contain tool methods
423#[proc_macro_attribute]
424pub fn icarus_tools(attr: TokenStream, item: TokenStream) -> TokenStream {
425    let attr = TokenStream2::from(attr);
426    let input = parse_macro_input!(item as syn::ItemImpl);
427    tools::expand_icarus_tools(attr, input).into()
428}
429
430/// Attribute macro for individual tool methods
431/// Usage: #[icarus_tool("Tool description")]
432///
433/// This attribute marks functions as tools and stores their description.
434/// The icarus_module macro will collect these to generate metadata.
435#[proc_macro_attribute]
436pub fn icarus_tool(attr: TokenStream, item: TokenStream) -> TokenStream {
437    let input_fn = parse_macro_input!(item as syn::ItemFn);
438
439    // Parse the description from the attribute
440    let description = if attr.is_empty() {
441        format!("{} tool", input_fn.sig.ident)
442    } else {
443        let lit_str = parse_macro_input!(attr as syn::LitStr);
444        lit_str.value()
445    };
446
447    // Preserve the function with the description as a doc comment
448    // The module macro will look for this pattern
449    let expanded = quote! {
450        #[doc = #description]
451        #input_fn
452    };
453
454    TokenStream::from(expanded)
455}
456
457/// Module-level attribute macro that collects all icarus_tool functions
458/// and generates the get_metadata query function automatically.
459///
460/// Usage:
461/// ```
462/// #[icarus_module]
463/// mod my_module {
464///     #[update]
465///     #[icarus_tool("Store data")]
466///     pub fn store(data: String) -> Result<(), String> { ... }
467/// }
468/// ```
469///
470/// The name and version are automatically taken from Cargo.toml
471#[proc_macro_attribute]
472pub fn icarus_module(attr: TokenStream, item: TokenStream) -> TokenStream {
473    let input = parse_macro_input!(item as syn::ItemMod);
474
475    // Parse attributes
476    let module_config = if attr.is_empty() {
477        tools::ModuleConfig::default()
478    } else {
479        parse_macro_input!(attr as tools::ModuleConfig)
480    };
481
482    // Process the module to collect tools and generate metadata
483    let expanded = tools::expand_icarus_module(input, module_config);
484    TokenStream::from(expanded)
485}
486
487/// Crate-level attribute macro that scans for all icarus_tool functions
488/// and generates the get_metadata query function automatically.
489///
490/// Usage:
491/// ```
492/// #![icarus_canister(name = "my-server", version = "1.0.0")]
493///
494/// #[update]
495/// #[icarus_tool("Store data")]
496/// pub fn store(data: String) -> Result<(), String> { ... }
497/// ```
498#[proc_macro_attribute]
499pub fn icarus_canister(_attr: TokenStream, item: TokenStream) -> TokenStream {
500    // Parse the crate content
501    let input = parse_macro_input!(item as syn::File);
502
503    // Process the crate to collect tools and generate metadata
504    let expanded = tools::expand_icarus_canister(input);
505    TokenStream::from(expanded)
506}
507
508// Helper function to extract type as string
509#[allow(dead_code)]
510fn extract_type_string(ty: &syn::Type) -> String {
511    quote!(#ty).to_string()
512}
513
/// Maps a stringified Rust type to a JSON Schema `type` keyword.
///
/// Whitespace is ignored, so both hand-written (`Vec<String>`) and
/// token-rendered (`Vec < String >`) spellings are accepted. Matching is by
/// substring and the checks run in this order:
///
/// 1. `Vec<` → `"array"`. It is checked first so `Vec<String>` is not
///    reported as a string.
/// 2. `String` / `&str` → `"string"`.
/// 3. Any primitive integer → `"integer"`.
/// 4. `f32` / `f64` → `"number"`.
/// 5. `bool` → `"boolean"`.
///
/// Anything else falls back to `"string"`.
#[allow(dead_code)] // Unused helper; allowed to avoid a dead-code warning.
fn rust_type_to_json_type(rust_type: &str) -> &'static str {
    const INTEGER_TYPES: [&str; 12] = [
        "i8", "i16", "i32", "i64", "i128", "isize", "u8", "u16", "u32", "u64", "u128", "usize",
    ];

    // `quote!(#ty).to_string()` inserts spaces between tokens; normalise them away.
    let compact: String = rust_type.chars().filter(|c| !c.is_whitespace()).collect();
    let has = |needle: &str| compact.contains(needle);

    if has("Vec<") {
        "array"
    } else if has("String") || has("&str") {
        "string"
    } else if INTEGER_TYPES.iter().any(|&int| has(int)) {
        "integer"
    } else if has("f32") || has("f64") {
        "number"
    } else if has("bool") {
        "boolean"
    } else {
        "string" // Default to string for unknown types
    }
}
533
534// Helper function to check if a type is StableBTreeMap
535fn is_stable_map_type(ty: &syn::Type) -> bool {
536    let type_string = quote!(#ty).to_string();
537    type_string.contains("StableBTreeMap")
538}
539
540// Helper function to check if a type is StableCell
541fn is_stable_cell_type(ty: &syn::Type) -> bool {
542    let type_string = quote!(#ty).to_string();
543    type_string.contains("StableCell")
544}
545
/// Parses a human-readable size into bytes. Accepted inputs include `"1MB"`,
/// `"2 KB"`, `"512B"`, `"1GB"`, or a raw byte count such as `"4096"`.
///
/// Suffixes (`GB`, `MB`, `KB`, `B`) are case-insensitive and use binary
/// multiples (1KB = 1024 bytes). Results that do not fit in `u32` saturate at
/// `u32::MAX` instead of overflowing.
///
/// The function never fails. For backward compatibility, unparsable input
/// falls back to a default:
/// - `GB`/`MB`/`KB` with an invalid number → one unit of that suffix;
/// - `B` with an invalid number → 1024 bytes;
/// - no recognised suffix and not a number → 1MB.
fn parse_size_string(size: &str) -> u32 {
    const KIB: u32 = 1024;
    const MIB: u32 = 1024 * KIB;
    const GIB: u32 = 1024 * MIB;
    // Each entry is (suffix, bytes per unit, count assumed when the number is
    // invalid). Longer suffixes come first so "MB" is not read as a bare "B".
    const UNITS: [(&str, u32, u32); 4] =
        [("GB", GIB, 1), ("MB", MIB, 1), ("KB", KIB, 1), ("B", 1, 1024)];

    let normalized = size.trim().to_ascii_uppercase();
    for &(suffix, unit_bytes, fallback_count) in &UNITS {
        if let Some(number) = normalized.strip_suffix(suffix) {
            let count = number.trim().parse::<u32>().unwrap_or(fallback_count);
            return count.saturating_mul(unit_bytes);
        }
    }

    // Try to parse as raw bytes
    normalized.parse::<u32>().unwrap_or(MIB)
}