anchor_decoder/
lib.rs

1extern crate proc_macro;
2use std::collections::HashSet;
3
4use proc_macro::TokenStream;
5use quote::quote;
6use serde_json::Value;
7use syn::{parse_macro_input, LitStr};
8
/// Converts a snake_case identifier to CamelCase (e.g. "create_order" -> "CreateOrder").
///
/// Each `_`-separated segment has its first character upper-cased and the rest copied
/// verbatim; empty segments (from leading, trailing or doubled underscores) vanish.
fn to_camel_case(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for segment in s.split('_') {
        let mut chars = segment.chars();
        if let Some(first) = chars.next() {
            // `to_uppercase` may yield several chars (e.g. 'ß' -> "SS").
            out.extend(first.to_uppercase());
            out.push_str(chars.as_str());
        }
    }
    out
}
21
22/// Maps an IDL type (which can be a string like "u8" or an object for arrays or defined types)
23/// into the corresponding Rust type as tokens. The `generated_types` set contains the names
24/// of custom types that will be generated by this macro.
25fn map_idl_type(arg_type: &Value, generated_types: &HashSet<String>) -> proc_macro2::TokenStream {
26    if let Some(s) = arg_type.as_str() {
27        match s {
28            "u8" => quote! { u8 },
29            "u16" => quote! { u16 },
30            "u64" => quote! { u64 },
31            "i64" => quote! { i64 },
32            "bool" => quote! { bool },
33            "pubkey" => quote! { Pubkey },
34            "string" => quote! { String },
35            _ => quote! { () }, // fallback for unsupported types
36        }
37    } else if let Some(obj) = arg_type.as_object() {
38        if let Some(array_val) = obj.get("array") {
39            if let Some(arr) = array_val.as_array() {
40                if arr.len() == 2 {
41                    let inner = map_idl_type(&arr[0], generated_types);
42                    if let Some(len) = arr[1].as_u64() {
43                        let len_literal =
44                            syn::LitInt::new(&len.to_string(), proc_macro2::Span::call_site());
45                        return quote! { [#inner; #len_literal] };
46                    }
47                }
48            }
49        } else if let Some(defined) = obj.get("defined") {
50            if let Some(defined_obj) = defined.as_object() {
51                if let Some(name) = defined_obj.get("name").and_then(|n| n.as_str()) {
52                    let type_ident = syn::Ident::new(name, proc_macro2::Span::call_site());
53                    // If the type was generated by our macro, reference it directly.
54                    // Otherwise assume it's external (and qualify it).
55                    if generated_types.contains(name) {
56                        return quote! { #type_ident };
57                    } else {
58                        return quote! { ::crate::#type_ident };
59                    }
60                }
61            }
62        }
63        quote! { () }
64    } else {
65        quote! { () }
66    }
67}
68
69/// Procedural macro attribute that generates decoding code from an Anchor IDL JSON file.
70/// The macro reads the file at compile time
71/// 
72/// For each instruction:
73///  - It generates a struct for the instruction's arguments (if any), with a constant discriminator.
74///  - It creates an enum variant for the instruction.
75///  - It produces a helper function (`decode_instruction`) to match and decode incoming data.
76///
77/// For each account:
78///  - It assumes the account type is defined under "types" (by matching name).
79///  - It uses the provided discriminator to generate a match arm that decodes the account data,
80///    skipping the first 8 bytes.
81#[proc_macro_attribute]
82pub fn anchor_idl(attr: TokenStream, _item: TokenStream) -> TokenStream {
83    // Get the relative IDL file path from the attribute
84    let relative_path = parse_macro_input!(attr as LitStr).value();
85
86    // Resolve path relative to crate root
87    let manifest_dir = std::env::var("CARGO_MANIFEST_DIR")
88        .expect("CARGO_MANIFEST_DIR environment variable not set");
89    let idl_path = std::path::Path::new(&manifest_dir)
90        .join(relative_path)
91        .canonicalize()
92        .unwrap_or_else(|e| panic!("Failed to resolve IDL path: {}", e));
93
94    // Read and parse the IDL JSON at compile time
95    let idl_json = std::fs::read_to_string(&idl_path)
96        .unwrap_or_else(|_| panic!("Unable to read IDL file at: {}", idl_path.display()));
97    let idl: Value = serde_json::from_str(&idl_json)
98        .unwrap_or_else(|_| panic!("Invalid JSON in IDL file: {}", idl_path.display()));
99
100    // Collect the names of all types defined in the IDL.
101    let generated_types: HashSet<String> =
102        if let Some(types) = idl.get("types").and_then(|v| v.as_array()) {
103            types
104                .iter()
105                .filter_map(|t| {
106                    t.get("name")
107                        .and_then(|v| v.as_str())
108                        .map(|s| s.to_string())
109                })
110                .collect()
111        } else {
112            HashSet::new()
113        };
114
115    let mut struct_defs = Vec::new();
116
117    // Process custom type definitions.
118    if let Some(types) = idl.get("types").and_then(|v| v.as_array()) {
119        for type_def in types {
120            if let (Some(name), Some(type_info)) = (
121                type_def.get("name").and_then(|v| v.as_str()),
122                type_def.get("type").and_then(|v| v.as_object()),
123            ) {
124                let type_ident = syn::Ident::new(name, proc_macro2::Span::call_site());
125
126                // Check the kind of the type.
127                if let Some(kind) = type_info.get("kind").and_then(|v| v.as_str()) {
128                    match kind {
129                        "struct" => {
130                            // Process struct definitions.
131                            let field_defs = if let Some(fields) = type_info.get("fields").and_then(|v| v.as_array()) {
132                                let mut field_defs = Vec::new();
133                                for field in fields {
134                                    if let (Some(field_name), Some(field_type)) = (
135                                        field.get("name").and_then(|v| v.as_str()),
136                                        field.get("type"),
137                                    ) {
138                                        let field_ident = syn::Ident::new(
139                                            field_name,
140                                            proc_macro2::Span::call_site(),
141                                        );
142                                        let field_type = map_idl_type(field_type, &generated_types);
143                                        field_defs.push(quote! {
144                                            pub #field_ident: #field_type,
145                                        });
146                                    }
147                                }
148                                field_defs
149                            } else {
150                                // Handle empty struct (no fields property)
151                                Vec::new()
152                            };
153                            
154                            struct_defs.push(quote! {
155                                #[derive(Debug, BorshSerialize, BorshDeserialize)]
156                                pub struct #type_ident {
157                                    #( #field_defs )*
158                                }
159                                impl #type_ident {
160                                    pub fn decode(data: &[u8]) -> Result<Self, ::std::io::Error> {
161                                        <Self as BorshDeserialize>::try_from_slice(data)
162                                    }
163                                }
164                            });
165                        }
166                        "enum" => {
167                            // Process enum definitions.
168                            if let Some(variants) =
169                                type_info.get("variants").and_then(|v| v.as_array())
170                            {
171                                let mut variant_tokens = Vec::new();
172                                for variant in variants {
173                                    if let Some(variant_name) =
174                                        variant.get("name").and_then(|v| v.as_str())
175                                    {
176                                        let variant_ident = syn::Ident::new(
177                                            variant_name,
178                                            proc_macro2::Span::call_site(),
179                                        );
180                                        variant_tokens.push(quote! {
181                                            #variant_ident,
182                                        });
183                                    }
184                                }
185                                struct_defs.push(quote! {
186                                    #[derive(Debug, BorshSerialize, BorshDeserialize)]
187                                    pub enum #type_ident {
188                                        #( #variant_tokens )*
189                                    }
190                                    impl #type_ident {
191                                        pub fn decode(data: &[u8]) -> Result<Self, ::std::io::Error> {
192                                            <Self as BorshDeserialize>::try_from_slice(data)
193                                        }
194                                    }
195                                });
196                            }
197                        }
198                        _ => {}
199                    }
200                }
201            }
202        }
203    }
204
205    let instructions = idl
206        .get("instructions")
207        .and_then(|v| v.as_array())
208        .expect("IDL JSON does not contain an 'instructions' array");
209
210    let mut enum_variants = Vec::new();
211    let mut match_arms = Vec::new();
212
213    for inst in instructions {
214        // Get instruction name, discriminator, and args.
215        let name = inst.get("name").and_then(|v| v.as_str()).unwrap();
216        let discriminator = inst
217            .get("discriminator")
218            .and_then(|v| v.as_array())
219            .expect("Discriminator missing or not an array");
220        let args = inst
221            .get("args")
222            .and_then(|v| v.as_array())
223            .expect("Args missing or not an array");
224
225        // Convert the instruction name to CamelCase for the generated struct.
226        let struct_name_str = to_camel_case(name);
227        let struct_name = syn::Ident::new(&struct_name_str, proc_macro2::Span::call_site());
228
229        // Generate account info struct name
230        let accounts_struct_name = syn::Ident::new(
231            &format!("{}Accounts", struct_name_str),
232            proc_macro2::Span::call_site(),
233        );
234
235        // Generate a constant for the discriminator.
236        let disc_values: Vec<u8> = discriminator
237            .iter()
238            .map(|v| v.as_u64().unwrap() as u8)
239            .collect();
240        let disc_tokens = quote! { [ #( #disc_values ),* ] };
241
242        // Process accounts for this instruction
243        let mut account_consts = Vec::new();
244        let mut account_fields = Vec::new();
245        let mut account_indices = Vec::new();
246        let mut account_name_matches = Vec::new();
247        let mut account_tuples = Vec::new();
248        let mut account_index_matches = Vec::new();
249
250        if let Some(accounts) = inst.get("accounts").and_then(|v| v.as_array()) {
251            for (idx, account) in accounts.iter().enumerate() {
252                if let Some(account_name) = account.get("name").and_then(|v| v.as_str()) {
253                    let const_name = account_name.to_uppercase();
254                    let const_ident = syn::Ident::new(&const_name, proc_macro2::Span::call_site());
255                    let idx_lit =
256                        syn::LitInt::new(&idx.to_string(), proc_macro2::Span::call_site());
257
258                    account_consts.push(quote! {
259                        pub const #const_ident: usize = #idx_lit;
260                    });
261
262                    let field_ident = syn::Ident::new(account_name, proc_macro2::Span::call_site());
263                    account_fields.push(quote! {
264                        pub #field_ident: usize,
265                    });
266
267                    account_indices.push(quote! {
268                        #field_ident: #idx_lit,
269                    });
270
271                    // Create match arm for get_account_name
272                    let account_name_str = account_name;
273                    account_name_matches.push(quote! {
274                        #idx_lit => Some(#account_name_str),
275                    });
276
277                    // Create tuple for get_all_accounts
278                    account_tuples.push(quote! {
279                        (#account_name_str, Self::#const_ident)
280                    });
281
282                    // Create match arm for get_account_index
283                    account_index_matches.push(quote! {
284                        #account_name_str => Some(Self::#const_ident),
285                    });
286                }
287            }
288
289            // Generate the accounts struct
290            struct_defs.push(quote! {
291                #[derive(Debug, Clone, Copy)]
292                pub struct #accounts_struct_name {
293                    #( #account_fields )*
294                }
295
296                impl #accounts_struct_name {
297                    #( #account_consts )*
298
299                    pub const fn new() -> Self {
300                        Self {
301                            #( #account_indices )*
302                        }
303                    }
304
305                    pub fn get_account_name(&self, index: usize) -> Option<&'static str> {
306                        match index {
307                            #( #account_name_matches )*
308                            _ => None,
309                        }
310                    }
311
312                    pub fn get_all_accounts(&self) -> &'static [(&'static str, usize)] {
313                        &[
314                            #( #account_tuples, )*
315                        ]
316                    }
317                    
318                    pub fn get_account_index(&self, name: &str) -> Option<usize> {
319                        match name {
320                            #( #account_index_matches )*
321                            _ => None,
322                        }
323                    }
324                }
325            });
326        }
327
328        if !args.is_empty() {
329            // Generate struct fields by mapping each argument's type.
330            let mut fields = Vec::new();
331            for arg in args {
332                let arg_name = arg.get("name").and_then(|v| v.as_str()).unwrap();
333                let arg_type = arg.get("type").expect("Missing type in argument");
334                let field_ident = syn::Ident::new(arg_name, proc_macro2::Span::call_site());
335                let field_type = map_idl_type(arg_type, &generated_types);
336                fields.push(quote! {
337                    pub #field_ident: #field_type,
338                });
339            }
340
341            struct_defs.push(quote! {
342                #[derive(Debug, BorshSerialize, BorshDeserialize)]
343                pub struct #struct_name {
344                    #( #fields )*
345                }
346                impl #struct_name {
347                    pub const DISCRIMINATOR: [u8; 8] = #disc_tokens;
348                    pub const ACCOUNTS: #accounts_struct_name = #accounts_struct_name::new();
349                    
350                    pub fn decode(data: &[u8]) -> Result<Self, ::std::io::Error> {
351                        // Skip the first 8 bytes (discriminator)
352                        let payload = &data[8..];
353                        <Self as BorshDeserialize>::try_from_slice(payload)
354                    }
355                    
356                    /// Maps account indices to their semantic names
357                    pub fn map_accounts<'a>(accounts: &'a [Pubkey]) -> std::collections::HashMap<&'static str, &'a Pubkey> {
358                        let mut result = std::collections::HashMap::new();
359                        for (i, account) in accounts.iter().enumerate() {
360                            if let Some(name) = Self::ACCOUNTS.get_account_name(i) {
361                                result.insert(name, account);
362                            }
363                        }
364                        result
365                    }
366                }
367            });
368
369            enum_variants.push(quote! {
370                #struct_name(#struct_name)
371            });
372            match_arms.push(quote! {
373                x if x == #struct_name::DISCRIMINATOR => {
374                    return Some(DecodedInstruction::#struct_name(
375                        #struct_name::decode(data).ok()?
376                    ))
377                }
378            });
379        } else {
380            // For instructions with no arguments, generate a unit struct.
381            struct_defs.push(quote! {
382                #[derive(Debug)]
383                pub struct #struct_name;
384                impl #struct_name {
385                    pub const DISCRIMINATOR: [u8; 8] = #disc_tokens;
386                    pub const ACCOUNTS: #accounts_struct_name = #accounts_struct_name::new();
387                    
388                    /// Maps account indices to their semantic names
389                    pub fn map_accounts<'a>(accounts: &'a [Pubkey]) -> std::collections::HashMap<&'static str, &'a Pubkey> {
390                        let mut result = std::collections::HashMap::new();
391                        for (i, account) in accounts.iter().enumerate() {
392                            if let Some(name) = Self::ACCOUNTS.get_account_name(i) {
393                                result.insert(name, account);
394                            }
395                        }
396                        result
397                    }
398                }
399            });
400            enum_variants.push(quote! {
401                #struct_name
402            });
403            match_arms.push(quote! {
404                x if x == #struct_name::DISCRIMINATOR => {
405                    return Some(DecodedInstruction::#struct_name)
406                }
407            });
408        }
409    }
410
411    // Process accounts from the IDL.
412    let mut account_enum_variants = Vec::new();
413    let mut account_match_arms = Vec::new();
414    if let Some(accounts) = idl.get("accounts").and_then(|v| v.as_array()) {
415        for account in accounts {
416            let name = account.get("name").and_then(|v| v.as_str()).unwrap();
417            let discriminator = account
418                .get("discriminator")
419                .and_then(|v| v.as_array())
420                .expect("Discriminator missing or not an array in accounts");
421            let type_ident = syn::Ident::new(name, proc_macro2::Span::call_site());
422            let disc_values: Vec<u8> = discriminator
423                .iter()
424                .map(|v| v.as_u64().unwrap() as u8)
425                .collect();
426            let disc_tokens = quote! { [ #( #disc_values ),* ] };
427
428            account_enum_variants.push(quote! {
429                #type_ident(#type_ident)
430            });
431            account_match_arms.push(quote! {
432                x if x == #disc_tokens => {
433                    return Some(DecodedAccount::#type_ident(
434                        #type_ident::decode(&data[8..]).ok()?
435                    ))
436                }
437            });
438        }
439    }
440
441    // Process events from the IDL.
442    let mut event_enum_variants = Vec::new();
443    let mut event_match_arms = Vec::new();
444    if let Some(events) = idl.get("events").and_then(|v| v.as_array()) {
445        for event in events {
446            let name = event.get("name").and_then(|v| v.as_str()).unwrap();
447            let discriminator = event
448                .get("discriminator")
449                .and_then(|v| v.as_array())
450                .expect("Discriminator missing or not an array in events");
451            let type_ident = syn::Ident::new(name, proc_macro2::Span::call_site());
452            let disc_values: Vec<u8> = discriminator
453                .iter()
454                .map(|v| v.as_u64().unwrap() as u8)
455                .collect();
456            let disc_tokens = quote! { [ #( #disc_values ),* ] };
457
458            event_enum_variants.push(quote! {
459                #type_ident(#type_ident)
460            });
461            event_match_arms.push(quote! {
462                x if x == #disc_tokens => {
463                    return Some(DecodedEvent::#type_ident(
464                        #type_ident::decode(&data[8..]).ok()?
465                    ))
466                }
467            });
468        }
469    }
470
471    let program_address = idl
472        .get("address")
473        .and_then(|v| v.as_str())
474        .expect("IDL missing program address");
475
476    let expanded = quote! {
477        use ::borsh::{BorshDeserialize, BorshSerialize};
478        use ::solana_sdk::pubkey::Pubkey;
479        use std::collections::HashMap;
480
481        pub const ID: Pubkey = ::solana_sdk::pubkey!(#program_address);
482
483        #( #struct_defs )*
484
485        #[derive(Debug)]
486        pub enum DecodedInstruction {
487            #( #enum_variants, )*
488            EmitCpi(DecodedEvent)
489        }
490
491        pub fn decode_instruction(data: &[u8]) -> Option<DecodedInstruction> {
492            if data.len() < 8 { return None; }
493            let disc = &data[..8];
494            match disc {
495                #( #match_arms, )*
496                _ => {
497                    if disc == EMIT_CPI_INSTRUCTION_DISCRIMINATOR {
498                        let payload = &data[8..];
499                        decode_event(payload).map(|event| DecodedInstruction::EmitCpi(event))
500                    } else {
501                        None
502                    }
503                },
504            }
505        }
506
507        #[derive(Debug)]
508        pub enum DecodedAccount {
509            #( #account_enum_variants, )*
510        }
511
512        pub fn decode_account(data: &[u8]) -> Option<DecodedAccount> {
513            if data.len() < 8 { return None; }
514            let disc = &data[..8];
515            match disc {
516                #( #account_match_arms, )*
517                _ => {
518                    None
519                },
520            }
521        }
522
523        #[derive(Debug)]
524        pub enum DecodedEvent {
525            #( #event_enum_variants, )*
526        }
527
528        // Some programs might call anchor's emit_cpi instruction to emit events via self-cpi so that subscribed clients
529        // can see the events without risk of the RPC's truncating them (as with traditional event logging)
530        //
531        // Source: https://github.com/coral-xyz/anchor/blob/8b391aa278387b6f6ce3133453619a175544631e/lang/attribute/event/src/lib.rs#L111-L195
532        const EMIT_CPI_INSTRUCTION_DISCRIMINATOR: [u8; 8] = [228, 69, 165, 46, 81, 203, 154, 29];
533
534        pub fn decode_event(data: &[u8]) -> Option<DecodedEvent> {
535            if data.len() < 8 { return None; }
536            let disc = &data[..8];
537
538            match disc {
539                #( #event_match_arms, )*
540                _ => {
541                    None
542                }
543            }
544        }
545    };
546
547    expanded.into()
548}