icarus_derive/
lib.rs

1// Copyright (c) 2025 Icarus Team. All Rights Reserved.
2// Licensed under BSL-1.1. See LICENSE and NOTICE files.
3// Signature verification and telemetry must remain intact.
4
5// #![warn(missing_docs)] // TODO: Enable after adding all documentation
6
7//! Procedural macros for the Icarus SDK
8//!
9//! This crate provides derive macros and attribute macros to reduce
10//! boilerplate when building MCP servers for ICP.
11
12use proc_macro::TokenStream;
13use proc_macro2::TokenStream as TokenStream2;
14use quote::quote;
15use syn::{parse_macro_input, DeriveInput};
16
17mod server;
18mod tools;
19mod validation;
20
21/// Derive macro for creating MCP tools
22#[proc_macro_derive(IcarusTool, attributes(icarus_tool))]
23pub fn derive_icarus_tool(input: TokenStream) -> TokenStream {
24    let input = parse_macro_input!(input as DeriveInput);
25
26    // Extract tool metadata from attributes
27    let mut name = None;
28    let mut description = None;
29
30    for attr in &input.attrs {
31        if attr.path().is_ident("icarus_tool") {
32            attr.parse_nested_meta(|meta| {
33                if meta.path.is_ident("name") {
34                    name = Some(meta.value()?.parse::<syn::LitStr>()?.value());
35                    Ok(())
36                } else if meta.path.is_ident("description") {
37                    description = Some(meta.value()?.parse::<syn::LitStr>()?.value());
38                    Ok(())
39                } else {
40                    Err(meta.error("unsupported icarus_tool attribute"))
41                }
42            })
43            .expect("Failed to parse icarus_tool attribute");
44        }
45    }
46
47    let struct_name = &input.ident;
48    let tool_name = name.unwrap_or_else(|| struct_name.to_string());
49    let tool_desc = description.unwrap_or_else(|| format!("{} tool", tool_name));
50
51    // Generate the implementation
52    let expanded = quote! {
53        #[async_trait::async_trait]
54        impl icarus_core::tool::IcarusTool for #struct_name {
55            fn info(&self) -> icarus_core::tool::ToolInfo {
56                icarus_core::tool::ToolInfo {
57                    name: #tool_name.to_string(),
58                    description: #tool_desc.to_string(),
59                    input_schema: serde_json::json!({
60                        "type": "object",
61                        "properties": {},
62                        "required": []
63                    }),
64                }
65            }
66
67            fn to_rmcp_tool(&self) -> rmcp::model::Tool {
68                use std::borrow::Cow;
69                use std::sync::Arc;
70
71                let schema = serde_json::json!({
72                    "type": "object",
73                    "properties": {},
74                    "required": []
75                });
76
77                rmcp::model::Tool {
78                    name: Cow::Borrowed(#tool_name),
79                    description: Some(Cow::Borrowed(#tool_desc)),
80                    input_schema: Arc::new(schema.as_object().unwrap().clone()),
81                    annotations: None,
82                }
83            }
84
85            async fn execute(&self, args: serde_json::Value) -> icarus_core::error::Result<serde_json::Value> {
86                // Default implementation - override in your tool
87                Ok(serde_json::json!({
88                    "error": "Tool execution not implemented"
89                }))
90            }
91        }
92    };
93
94    TokenStream::from(expanded)
95}
96
97/// Attribute macro for MCP server setup
98/// ```ignore
99/// pub struct MyServer {
100///     tools: Vec<Box<dyn IcarusTool>>,
101/// }
102/// ```
103#[proc_macro_attribute]
104pub fn icarus_server(args: TokenStream, input: TokenStream) -> TokenStream {
105    let args = TokenStream2::from(args);
106    let input = parse_macro_input!(input as DeriveInput);
107    server::expand_icarus_server(args, input).into()
108}
109
110/// Derive macro for common Icarus type patterns
111///
112/// This is a convenience macro that combines IcarusStorable with sensible defaults.
113/// You still need to derive the standard traits manually.
114///
115/// # Examples
116/// ```ignore
117/// #[derive(Debug, Clone, Serialize, Deserialize, CandidType, IcarusType)]
118/// struct MemoryEntry {
119///     id: String,
120///     content: String,
121///     created_at: u64,
122/// }
123/// ```
124///
125/// This is equivalent to:
126/// ```ignore
127/// #[derive(Debug, Clone, Serialize, Deserialize, CandidType, IcarusStorable)]
128/// #[icarus_storable(unbounded)]
129/// struct MemoryEntry { ... }
130/// ```
131#[proc_macro_derive(IcarusType, attributes(icarus_storable))]
132pub fn derive_icarus_type(input: TokenStream) -> TokenStream {
133    let input = parse_macro_input!(input as DeriveInput);
134    let struct_name = &input.ident;
135
136    // Extract generics if any
137    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
138
139    // Parse attributes for storage configuration
140    let mut unbounded = true; // Default to unbounded for convenience
141    let mut max_size_bytes = 1024 * 1024; // 1MB default
142
143    for attr in &input.attrs {
144        if attr.path().is_ident("icarus_storable") {
145            attr.parse_nested_meta(|meta| {
146                if meta.path.is_ident("unbounded") {
147                    unbounded = true;
148                    Ok(())
149                } else if meta.path.is_ident("bounded") {
150                    unbounded = false;
151                    Ok(())
152                } else if meta.path.is_ident("max_size") {
153                    let value = meta.value()?;
154                    let lit_str: syn::LitStr = value.parse()?;
155                    let size_str = lit_str.value();
156                    max_size_bytes = parse_size_string(&size_str);
157                    unbounded = false;
158                    Ok(())
159                } else {
160                    Ok(()) // Ignore other attributes
161                }
162            })
163            .unwrap_or(()); // Ignore parse errors
164        }
165    }
166
167    let bound = if unbounded {
168        quote! { ic_stable_structures::storable::Bound::Unbounded }
169    } else {
170        quote! {
171            ic_stable_structures::storable::Bound::Bounded {
172                max_size: #max_size_bytes,
173                is_fixed_size: false,
174            }
175        }
176    };
177
178    // Generate all the common trait implementations
179    let expanded = quote! {
180        // Note: We expect the user to add #[derive(Debug, Clone, Serialize, Deserialize, CandidType)]
181        // This macro just adds the IcarusStorable functionality
182
183        // Implement Storable for ICP
184        impl #impl_generics ic_stable_structures::Storable for #struct_name #ty_generics #where_clause {
185            fn to_bytes(&self) -> std::borrow::Cow<[u8]> {
186                std::borrow::Cow::Owned(
187                    candid::encode_one(self).expect("Failed to encode to Candid")
188                )
189            }
190
191            fn from_bytes(bytes: std::borrow::Cow<[u8]>) -> Self {
192                candid::decode_one(&bytes).expect("Failed to decode from Candid")
193            }
194
195            const BOUND: ic_stable_structures::storable::Bound = #bound;
196        }
197    };
198
199    TokenStream::from(expanded)
200}
201
202/// Derive macro for simplified storage declaration
203///
204/// Generates stable storage declarations from a simple struct definition.
205/// Automatically assigns memory IDs and handles initialization.
206///
207/// # Examples
208/// ```ignore
209/// #[derive(IcarusStorage)]
210/// struct Storage {
211///     memories: StableBTreeMap<String, MemoryEntry>,
212///     counter: u64,
213///     users: StableBTreeMap<Principal, User>,
214/// }
215/// ```
216///
217/// This generates:
218/// - Thread-local storage declarations
219/// - Memory manager initialization  
220/// - Accessor methods for each field
221#[proc_macro_derive(IcarusStorage)]
222pub fn derive_icarus_storage(input: TokenStream) -> TokenStream {
223    let input = parse_macro_input!(input as DeriveInput);
224
225    if let syn::Data::Struct(data_struct) = &input.data {
226        if let syn::Fields::Named(fields_named) = &data_struct.fields {
227            let struct_name = &input.ident;
228            let mut storage_declarations = vec![];
229            let mut accessor_methods = vec![];
230            let mut memory_id = 0u8;
231
232            for field in &fields_named.named {
233                if let Some(field_name) = &field.ident {
234                    let field_type = &field.ty;
235                    let field_name_upper =
236                        syn::Ident::new(&field_name.to_string().to_uppercase(), field_name.span());
237
238                    // Generate storage declaration based on field type
239                    let storage_decl = if is_stable_map_type(field_type) {
240                        quote! {
241                            #field_name_upper: #field_type =
242                                ::ic_stable_structures::StableBTreeMap::init(
243                                    MEMORY_MANAGER.with(|m| m.borrow().get(
244                                        ::ic_stable_structures::memory_manager::MemoryId::new(#memory_id)
245                                    ))
246                                );
247                        }
248                    } else if is_stable_cell_type(field_type) {
249                        quote! {
250                            #field_name_upper: ::ic_stable_structures::StableCell<#field_type, ::ic_stable_structures::memory_manager::VirtualMemory<::ic_stable_structures::DefaultMemoryImpl>> =
251                                ::ic_stable_structures::StableCell::init(
252                                    MEMORY_MANAGER.with(|m| m.borrow().get(
253                                        ::ic_stable_structures::memory_manager::MemoryId::new(#memory_id)
254                                    )),
255                                    Default::default()
256                                ).expect("Failed to initialize StableCell");
257                        }
258                    } else {
259                        // For simple types, wrap in StableCell
260                        quote! {
261                            #field_name_upper: ::ic_stable_structures::StableCell<#field_type, ::ic_stable_structures::memory_manager::VirtualMemory<::ic_stable_structures::DefaultMemoryImpl>> =
262                                ::ic_stable_structures::StableCell::init(
263                                    MEMORY_MANAGER.with(|m| m.borrow().get(
264                                        ::ic_stable_structures::memory_manager::MemoryId::new(#memory_id)
265                                    )),
266                                    Default::default()
267                                ).expect("Failed to initialize StableCell");
268                        }
269                    };
270
271                    storage_declarations.push(storage_decl);
272
273                    // Generate accessor method
274                    let accessor = if is_stable_map_type(field_type) {
275                        quote! {
276                            pub fn #field_name() -> impl std::ops::Deref<Target = #field_type> {
277                                #field_name_upper.with(|storage| storage.borrow())
278                            }
279                        }
280                    } else {
281                        let setter_name =
282                            syn::Ident::new(&format!("{}_set", field_name), field_name.span());
283
284                        quote! {
285                            pub fn #field_name() -> #field_type
286                            where
287                                #field_type: Clone + Default
288                            {
289                                #field_name_upper.with(|cell| cell.borrow().get().clone())
290                            }
291
292                            pub fn #setter_name(value: #field_type)
293                            where
294                                #field_type: Clone
295                            {
296                                #field_name_upper.with(|cell| {
297                                    cell.borrow_mut().set(value)
298                                        .expect("Failed to set value in StableCell");
299                                });
300                            }
301                        }
302                    };
303
304                    accessor_methods.push(accessor);
305                    memory_id += 1;
306                }
307            }
308
309            let expanded = quote! {
310                thread_local! {
311                    static MEMORY_MANAGER: ::std::cell::RefCell<
312                        ::ic_stable_structures::memory_manager::MemoryManager<
313                            ::ic_stable_structures::DefaultMemoryImpl
314                        >
315                    > = ::std::cell::RefCell::new(
316                        ::ic_stable_structures::memory_manager::MemoryManager::init(
317                            ::ic_stable_structures::DefaultMemoryImpl::default()
318                        )
319                    );
320
321                    #(static #storage_declarations)*
322                }
323
324                impl #struct_name {
325                    #(#accessor_methods)*
326                }
327            };
328
329            TokenStream::from(expanded)
330        } else {
331            syn::Error::new_spanned(
332                &input,
333                "IcarusStorage can only be used on structs with named fields",
334            )
335            .to_compile_error()
336            .into()
337        }
338    } else {
339        syn::Error::new_spanned(&input, "IcarusStorage can only be used on structs")
340            .to_compile_error()
341            .into()
342    }
343}
344
345/// Derive macro for ICP storable types
346///
347/// # Examples
348/// ```ignore
349/// #[derive(IcarusStorable)]
350/// struct MyData { ... } // Uses default 1MB bound
351///
352/// #[derive(IcarusStorable)]
353/// #[icarus_storable(unbounded)]
354/// struct LargeData { ... } // Uses unbounded storage
355///
356/// #[derive(IcarusStorable)]
357/// #[icarus_storable(max_size = "2MB")]
358/// struct CustomData { ... } // Uses custom 2MB bound
359/// ```
360#[proc_macro_derive(IcarusStorable, attributes(icarus_storable))]
361pub fn derive_icarus_storable(input: TokenStream) -> TokenStream {
362    let input = parse_macro_input!(input as DeriveInput);
363    let struct_name = &input.ident;
364
365    // Extract generics if any
366    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
367
368    // Parse attributes
369    let mut unbounded = false;
370    let mut max_size_bytes = 1024 * 1024; // 1MB default
371
372    for attr in &input.attrs {
373        if attr.path().is_ident("icarus_storable") {
374            attr.parse_nested_meta(|meta| {
375                if meta.path.is_ident("unbounded") {
376                    unbounded = true;
377                    Ok(())
378                } else if meta.path.is_ident("max_size") {
379                    let value = meta.value()?;
380                    let lit_str: syn::LitStr = value.parse()?;
381                    let size_str = lit_str.value();
382                    max_size_bytes = parse_size_string(&size_str);
383                    Ok(())
384                } else {
385                    Err(meta.error("unsupported icarus_storable attribute"))
386                }
387            })
388            .unwrap_or_else(|e| panic!("Failed to parse icarus_storable attribute: {}", e));
389        }
390    }
391
392    let bound = if unbounded {
393        quote! { ic_stable_structures::storable::Bound::Unbounded }
394    } else {
395        quote! {
396            ic_stable_structures::storable::Bound::Bounded {
397                max_size: #max_size_bytes,
398                is_fixed_size: false,
399            }
400        }
401    };
402
403    // Generate implementation
404    let expanded = quote! {
405        impl #impl_generics ic_stable_structures::Storable for #struct_name #ty_generics #where_clause {
406            fn to_bytes(&self) -> std::borrow::Cow<[u8]> {
407                std::borrow::Cow::Owned(
408                    candid::encode_one(self).expect("Failed to encode to Candid")
409                )
410            }
411
412            fn into_bytes(self) -> std::vec::Vec<u8> {
413                candid::encode_one(&self).expect("Failed to encode to Candid")
414            }
415
416            fn from_bytes(bytes: std::borrow::Cow<[u8]>) -> Self {
417                candid::decode_one(&bytes).expect("Failed to decode from Candid")
418            }
419
420            const BOUND: ic_stable_structures::storable::Bound = #bound;
421        }
422
423    };
424
425    TokenStream::from(expanded)
426}
427
428/// Attribute macro for marking impl blocks that contain tool methods
429#[proc_macro_attribute]
430pub fn icarus_tools(attr: TokenStream, item: TokenStream) -> TokenStream {
431    let attr = TokenStream2::from(attr);
432    let input = parse_macro_input!(item as syn::ItemImpl);
433    tools::expand_icarus_tools(attr, input).into()
434}
435
436/// Attribute macro for individual tool methods
437/// Usage: #[icarus_tool("Tool description")]
438///
439/// This attribute marks functions as tools and stores their description.
440/// The icarus_module macro will collect these to generate metadata.
441#[proc_macro_attribute]
442pub fn icarus_tool(attr: TokenStream, item: TokenStream) -> TokenStream {
443    let input_fn = parse_macro_input!(item as syn::ItemFn);
444
445    // Parse the description from the attribute
446    let description = if attr.is_empty() {
447        format!("{} tool", input_fn.sig.ident)
448    } else {
449        let lit_str = parse_macro_input!(attr as syn::LitStr);
450        lit_str.value()
451    };
452
453    // Preserve the function with the description as a doc comment
454    // The module macro will look for this pattern
455    let expanded = quote! {
456        #[doc = #description]
457        #input_fn
458    };
459
460    TokenStream::from(expanded)
461}
462
463/// Module-level attribute macro that collects all icarus_tool functions
464/// and generates the list_tools query function automatically.
465///
466/// Usage:
467/// ```ignore
468/// #[icarus_module]
469/// mod my_module {
470///     #[update]
471///     #[icarus_tool("Store data")]
472///     pub fn store(data: String) -> Result<(), String> { ... }
473/// }
474/// ```
475///
476/// The name and version are automatically taken from Cargo.toml
477#[proc_macro_attribute]
478pub fn icarus_module(attr: TokenStream, item: TokenStream) -> TokenStream {
479    let input = parse_macro_input!(item as syn::ItemMod);
480
481    // Parse attributes
482    let module_config = if attr.is_empty() {
483        tools::ModuleConfig::default()
484    } else {
485        parse_macro_input!(attr as tools::ModuleConfig)
486    };
487
488    // Process the module to collect tools and generate metadata
489    let expanded = tools::expand_icarus_module(input, module_config);
490    TokenStream::from(expanded)
491}
492
493/// Crate-level attribute macro that scans for all icarus_tool functions
494/// and generates the list_tools query function automatically.
495///
496/// Usage:
497/// ```ignore
498/// #![icarus_canister(name = "my-server", version = "1.0.0")]
499///
500/// #[update]
501/// #[icarus_tool("Store data")]
502/// pub fn store(data: String) -> Result<(), String> { ... }
503/// ```
504#[proc_macro_attribute]
505pub fn icarus_canister(_attr: TokenStream, item: TokenStream) -> TokenStream {
506    // Parse the crate content
507    let input = parse_macro_input!(item as syn::File);
508
509    // Process the crate to collect tools and generate metadata
510    let expanded = tools::expand_icarus_canister(input);
511    TokenStream::from(expanded)
512}
513
514// Helper function to extract type as string
515#[allow(dead_code)]
516fn extract_type_string(ty: &syn::Type) -> String {
517    quote!(#ty).to_string()
518}
519
/// Maps a rendered Rust type (for example the output of `extract_type_string`)
/// to a JSON Schema primitive type name.
///
/// Matching is substring-based and runs on a whitespace-free copy of the input.
/// This is needed because `quote!` renders `Vec<String>` as `"Vec < String >"`.
///
/// Checks run in this order, first match wins:
/// 1. arrays (`Vec<..>`)
/// 2. strings (`String`, `&str`)
/// 3. integers (all primitive integer widths)
/// 4. floats (`f32`, `f64`)
/// 5. booleans
///
/// Anything else maps to `"string"`.
// No caller is visible in this file, hence the allow.
// NOTE(review): confirm whether another module uses it before removing.
#[allow(dead_code)]
fn rust_type_to_json_type(rust_type: &str) -> &'static str {
    const INTEGER_TYPES: [&str; 12] = [
        "i8", "i16", "i32", "i64", "i128", "isize", "u8", "u16", "u32", "u64", "u128", "usize",
    ];

    let compact: String = rust_type.chars().filter(|c| !c.is_whitespace()).collect();
    let ty = compact.as_str();

    // Arrays are checked first so that `Vec<String>` / `Vec<u32>` are not
    // classified by their element type.
    if ty.contains("Vec<") {
        "array"
    } else if ty.contains("String") || ty.contains("&str") {
        "string"
    } else if INTEGER_TYPES.iter().any(|int| ty.contains(*int)) {
        "integer"
    } else if ty.contains("f32") || ty.contains("f64") {
        "number"
    } else if ty.contains("bool") {
        "boolean"
    } else {
        "string" // Default to string for unknown types
    }
}
539
540// Helper function to check if a type is StableBTreeMap
541fn is_stable_map_type(ty: &syn::Type) -> bool {
542    let type_string = quote!(#ty).to_string();
543    type_string.contains("StableBTreeMap")
544}
545
546// Helper function to check if a type is StableCell
547fn is_stable_cell_type(ty: &syn::Type) -> bool {
548    let type_string = quote!(#ty).to_string();
549    type_string.contains("StableCell")
550}
551
/// Parses a human-readable size such as `"1MB"`, `"512 KB"`, `"64B"` or
/// `"4096"` into a byte count.
///
/// - Units `GB`, `MB`, `KB` and `B` are binary multiples (1 KB = 1024 B). They
///   are matched case-insensitively, and surrounding whitespace is ignored.
/// - A bare number is taken as a byte count.
/// - Results that would exceed `u32::MAX` saturate to `u32::MAX` instead of
///   overflowing. Overflow would panic in debug builds of the macro.
/// - Unparseable numbers keep the historical fallbacks:
///   - 1 of the given unit for `GB`/`MB`/`KB`;
///   - 1024 bytes for `B`;
///   - 1 MB when there is no unit.
fn parse_size_string(size: &str) -> u32 {
    const KIB: u32 = 1024;

    let normalized = size.trim().to_ascii_uppercase();
    let scaled = |digits: &str, unit: u32| -> u32 {
        digits.trim().parse::<u32>().unwrap_or(1).saturating_mul(unit)
    };

    // Longer suffixes are checked before the bare `B`, which every unit ends with.
    if let Some(num) = normalized.strip_suffix("GB") {
        scaled(num, KIB * KIB * KIB)
    } else if let Some(num) = normalized.strip_suffix("MB") {
        scaled(num, KIB * KIB)
    } else if let Some(num) = normalized.strip_suffix("KB") {
        scaled(num, KIB)
    } else if let Some(num) = normalized.strip_suffix('B') {
        num.trim().parse::<u32>().unwrap_or(KIB)
    } else {
        // Try to parse as raw bytes
        normalized.parse::<u32>().unwrap_or(KIB * KIB)
    }
}