anchor_decoder/
lib.rs

1extern crate proc_macro;
2use std::collections::HashSet;
3
4use proc_macro::TokenStream;
5use quote::quote;
6use serde_json::Value;
7use syn::{parse_macro_input, LitStr};
8
/// Converts a `snake_case` name into `UpperCamelCase` (e.g. `"create_order"` -> `"CreateOrder"`).
///
/// Every `_`-separated segment has its first character upper-cased with full Unicode case
/// mapping (so one character may expand to several, e.g. `ß` -> `SS`). The rest of the
/// segment is copied unchanged. Empty segments, produced by leading, trailing, or repeated
/// underscores, contribute nothing.
fn to_camel_case(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for segment in s.split('_') {
        let mut rest = segment.chars();
        if let Some(head) = rest.next() {
            out.extend(head.to_uppercase());
            out.push_str(rest.as_str());
        }
    }
    out
}
21
22/// Maps an IDL type (which can be a string like "u8" or an object for arrays or defined types)
23/// into the corresponding Rust type as tokens. The `generated_types` set contains the names
24/// of custom types that will be generated by this macro.
25fn map_idl_type(arg_type: &Value, generated_types: &HashSet<String>) -> proc_macro2::TokenStream {
26    if let Some(s) = arg_type.as_str() {
27        match s {
28            "u8" => quote! { u8 },
29            "u16" => quote! { u16 },
30            "u64" => quote! { u64 },
31            "i64" => quote! { i64 },
32            "bool" => quote! { bool },
33            "pubkey" => quote! { Pubkey },
34            "string" => quote! { String },
35            _ => quote! { () }, // fallback for unsupported types
36        }
37    } else if let Some(obj) = arg_type.as_object() {
38        if let Some(array_val) = obj.get("array") {
39            if let Some(arr) = array_val.as_array() {
40                if arr.len() == 2 {
41                    let inner = map_idl_type(&arr[0], generated_types);
42                    if let Some(len) = arr[1].as_u64() {
43                        let len_literal =
44                            syn::LitInt::new(&len.to_string(), proc_macro2::Span::call_site());
45                        return quote! { [#inner; #len_literal] };
46                    }
47                }
48            }
49        } else if let Some(defined) = obj.get("defined") {
50            if let Some(defined_obj) = defined.as_object() {
51                if let Some(name) = defined_obj.get("name").and_then(|n| n.as_str()) {
52                    let type_ident = syn::Ident::new(name, proc_macro2::Span::call_site());
53                    // If the type was generated by our macro, reference it directly.
54                    // Otherwise assume it's external (and qualify it).
55                    if generated_types.contains(name) {
56                        return quote! { #type_ident };
57                    } else {
58                        return quote! { ::crate::#type_ident };
59                    }
60                }
61            }
62        }
63        quote! { () }
64    } else {
65        quote! { () }
66    }
67}
68
69/// Procedural macro attribute that generates decoding code from an Anchor IDL JSON file.
70/// The macro reads the file at compile time
71/// 
72/// For each instruction:
73///  - It generates a struct for the instruction's arguments (if any), with a constant discriminator.
74///  - It creates an enum variant for the instruction.
75///  - It produces a helper function (`decode_instruction`) to match and decode incoming data.
76///
77/// For each account:
78///  - It assumes the account type is defined under "types" (by matching name).
79///  - It uses the provided discriminator to generate a match arm that decodes the account data,
80///    skipping the first 8 bytes.
81#[proc_macro_attribute]
82pub fn anchor_idl(attr: TokenStream, _item: TokenStream) -> TokenStream {
83    // Get the relative IDL file path from the attribute
84    let relative_path = parse_macro_input!(attr as LitStr).value();
85
86    // Resolve path relative to crate root
87    let manifest_dir = std::env::var("CARGO_MANIFEST_DIR")
88        .expect("CARGO_MANIFEST_DIR environment variable not set");
89    let idl_path = std::path::Path::new(&manifest_dir)
90        .join(relative_path)
91        .canonicalize()
92        .unwrap_or_else(|e| panic!("Failed to resolve IDL path: {}", e));
93
94    // Read and parse the IDL JSON at compile time
95    let idl_json = std::fs::read_to_string(&idl_path)
96        .unwrap_or_else(|_| panic!("Unable to read IDL file at: {}", idl_path.display()));
97    let idl: Value = serde_json::from_str(&idl_json)
98        .unwrap_or_else(|_| panic!("Invalid JSON in IDL file: {}", idl_path.display()));
99
100    // Collect the names of all types defined in the IDL.
101    let generated_types: HashSet<String> =
102        if let Some(types) = idl.get("types").and_then(|v| v.as_array()) {
103            types
104                .iter()
105                .filter_map(|t| {
106                    t.get("name")
107                        .and_then(|v| v.as_str())
108                        .map(|s| s.to_string())
109                })
110                .collect()
111        } else {
112            HashSet::new()
113        };
114
115    let mut struct_defs = Vec::new();
116
117    // Process custom type definitions.
118    if let Some(types) = idl.get("types").and_then(|v| v.as_array()) {
119        for type_def in types {
120            if let (Some(name), Some(type_info)) = (
121                type_def.get("name").and_then(|v| v.as_str()),
122                type_def.get("type").and_then(|v| v.as_object()),
123            ) {
124                let type_ident = syn::Ident::new(name, proc_macro2::Span::call_site());
125
126                // Check the kind of the type.
127                if let Some(kind) = type_info.get("kind").and_then(|v| v.as_str()) {
128                    match kind {
129                        "struct" => {
130                            // Process struct definitions.
131                            if let Some(fields) = type_info.get("fields").and_then(|v| v.as_array())
132                            {
133                                let mut field_defs = Vec::new();
134                                for field in fields {
135                                    if let (Some(field_name), Some(field_type)) = (
136                                        field.get("name").and_then(|v| v.as_str()),
137                                        field.get("type"),
138                                    ) {
139                                        let field_ident = syn::Ident::new(
140                                            field_name,
141                                            proc_macro2::Span::call_site(),
142                                        );
143                                        let field_type = map_idl_type(field_type, &generated_types);
144                                        field_defs.push(quote! {
145                                            pub #field_ident: #field_type,
146                                        });
147                                    }
148                                }
149                                struct_defs.push(quote! {
150                                    #[derive(Debug, BorshSerialize, BorshDeserialize)]
151                                    pub struct #type_ident {
152                                        #( #field_defs )*
153                                    }
154                                    impl #type_ident {
155                                        pub fn decode(data: &[u8]) -> Result<Self, ::std::io::Error> {
156                                            <Self as BorshDeserialize>::try_from_slice(data)
157                                        }
158                                    }
159                                });
160                            }
161                        }
162                        "enum" => {
163                            // Process enum definitions.
164                            if let Some(variants) =
165                                type_info.get("variants").and_then(|v| v.as_array())
166                            {
167                                let mut variant_tokens = Vec::new();
168                                for variant in variants {
169                                    if let Some(variant_name) =
170                                        variant.get("name").and_then(|v| v.as_str())
171                                    {
172                                        let variant_ident = syn::Ident::new(
173                                            variant_name,
174                                            proc_macro2::Span::call_site(),
175                                        );
176                                        variant_tokens.push(quote! {
177                                            #variant_ident,
178                                        });
179                                    }
180                                }
181                                struct_defs.push(quote! {
182                                    #[derive(Debug, BorshSerialize, BorshDeserialize)]
183                                    pub enum #type_ident {
184                                        #( #variant_tokens )*
185                                    }
186                                    impl #type_ident {
187                                        pub fn decode(data: &[u8]) -> Result<Self, ::std::io::Error> {
188                                            <Self as BorshDeserialize>::try_from_slice(data)
189                                        }
190                                    }
191                                });
192                            }
193                        }
194                        _ => {}
195                    }
196                }
197            }
198        }
199    }
200
201    let instructions = idl
202        .get("instructions")
203        .and_then(|v| v.as_array())
204        .expect("IDL JSON does not contain an 'instructions' array");
205
206    let mut enum_variants = Vec::new();
207    let mut match_arms = Vec::new();
208
209    for inst in instructions {
210        // Get instruction name, discriminator, and args.
211        let name = inst.get("name").and_then(|v| v.as_str()).unwrap();
212        let discriminator = inst
213            .get("discriminator")
214            .and_then(|v| v.as_array())
215            .expect("Discriminator missing or not an array");
216        let args = inst
217            .get("args")
218            .and_then(|v| v.as_array())
219            .expect("Args missing or not an array");
220
221        // Convert the instruction name to CamelCase for the generated struct.
222        let struct_name_str = to_camel_case(name);
223        let struct_name = syn::Ident::new(&struct_name_str, proc_macro2::Span::call_site());
224
225        // Generate account info struct name
226        let accounts_struct_name = syn::Ident::new(
227            &format!("{}Accounts", struct_name_str),
228            proc_macro2::Span::call_site(),
229        );
230
231        // Generate a constant for the discriminator.
232        let disc_values: Vec<u8> = discriminator
233            .iter()
234            .map(|v| v.as_u64().unwrap() as u8)
235            .collect();
236        let disc_tokens = quote! { [ #( #disc_values ),* ] };
237
238        // Process accounts for this instruction
239        let mut account_consts = Vec::new();
240        let mut account_fields = Vec::new();
241        let mut account_indices = Vec::new();
242        let mut account_name_matches = Vec::new();
243        let mut account_tuples = Vec::new();
244        let mut account_index_matches = Vec::new();
245
246        if let Some(accounts) = inst.get("accounts").and_then(|v| v.as_array()) {
247            for (idx, account) in accounts.iter().enumerate() {
248                if let Some(account_name) = account.get("name").and_then(|v| v.as_str()) {
249                    let const_name = account_name.to_uppercase();
250                    let const_ident = syn::Ident::new(&const_name, proc_macro2::Span::call_site());
251                    let idx_lit =
252                        syn::LitInt::new(&idx.to_string(), proc_macro2::Span::call_site());
253
254                    account_consts.push(quote! {
255                        pub const #const_ident: usize = #idx_lit;
256                    });
257
258                    let field_ident = syn::Ident::new(account_name, proc_macro2::Span::call_site());
259                    account_fields.push(quote! {
260                        pub #field_ident: usize,
261                    });
262
263                    account_indices.push(quote! {
264                        #field_ident: #idx_lit,
265                    });
266
267                    // Create match arm for get_account_name
268                    let account_name_str = account_name;
269                    account_name_matches.push(quote! {
270                        #idx_lit => Some(#account_name_str),
271                    });
272
273                    // Create tuple for get_all_accounts
274                    account_tuples.push(quote! {
275                        (#account_name_str, Self::#const_ident)
276                    });
277
278                    // Create match arm for get_account_index
279                    account_index_matches.push(quote! {
280                        #account_name_str => Some(Self::#const_ident),
281                    });
282                }
283            }
284
285            // Generate the accounts struct
286            struct_defs.push(quote! {
287                #[derive(Debug, Clone, Copy)]
288                pub struct #accounts_struct_name {
289                    #( #account_fields )*
290                }
291
292                impl #accounts_struct_name {
293                    #( #account_consts )*
294
295                    pub const fn new() -> Self {
296                        Self {
297                            #( #account_indices )*
298                        }
299                    }
300
301                    pub fn get_account_name(&self, index: usize) -> Option<&'static str> {
302                        match index {
303                            #( #account_name_matches )*
304                            _ => None,
305                        }
306                    }
307
308                    pub fn get_all_accounts(&self) -> &'static [(&'static str, usize)] {
309                        &[
310                            #( #account_tuples, )*
311                        ]
312                    }
313                    
314                    pub fn get_account_index(&self, name: &str) -> Option<usize> {
315                        match name {
316                            #( #account_index_matches )*
317                            _ => None,
318                        }
319                    }
320                }
321            });
322        }
323
324        if !args.is_empty() {
325            // Generate struct fields by mapping each argument's type.
326            let mut fields = Vec::new();
327            for arg in args {
328                let arg_name = arg.get("name").and_then(|v| v.as_str()).unwrap();
329                let arg_type = arg.get("type").expect("Missing type in argument");
330                let field_ident = syn::Ident::new(arg_name, proc_macro2::Span::call_site());
331                let field_type = map_idl_type(arg_type, &generated_types);
332                fields.push(quote! {
333                    pub #field_ident: #field_type,
334                });
335            }
336
337            struct_defs.push(quote! {
338                #[derive(Debug, BorshSerialize, BorshDeserialize)]
339                pub struct #struct_name {
340                    #( #fields )*
341                }
342                impl #struct_name {
343                    pub const DISCRIMINATOR: [u8; 8] = #disc_tokens;
344                    pub const ACCOUNTS: #accounts_struct_name = #accounts_struct_name::new();
345                    
346                    pub fn decode(data: &[u8]) -> Result<Self, ::std::io::Error> {
347                        // Skip the first 8 bytes (discriminator)
348                        let payload = &data[8..];
349                        <Self as BorshDeserialize>::try_from_slice(payload)
350                    }
351                    
352                    /// Maps account indices to their semantic names
353                    pub fn map_accounts<'a>(accounts: &'a [Pubkey]) -> std::collections::HashMap<&'static str, &'a Pubkey> {
354                        let mut result = std::collections::HashMap::new();
355                        for (i, account) in accounts.iter().enumerate() {
356                            if let Some(name) = Self::ACCOUNTS.get_account_name(i) {
357                                result.insert(name, account);
358                            }
359                        }
360                        result
361                    }
362                }
363            });
364
365            enum_variants.push(quote! {
366                #struct_name(#struct_name)
367            });
368            match_arms.push(quote! {
369                x if x == #struct_name::DISCRIMINATOR => {
370                    return Some(DecodedInstruction::#struct_name(
371                        #struct_name::decode(data).ok()?
372                    ))
373                }
374            });
375        } else {
376            // For instructions with no arguments, generate a unit struct.
377            struct_defs.push(quote! {
378                #[derive(Debug)]
379                pub struct #struct_name;
380                impl #struct_name {
381                    pub const DISCRIMINATOR: [u8; 8] = #disc_tokens;
382                    pub const ACCOUNTS: #accounts_struct_name = #accounts_struct_name::new();
383                    
384                    /// Maps account indices to their semantic names
385                    pub fn map_accounts<'a>(accounts: &'a [Pubkey]) -> std::collections::HashMap<&'static str, &'a Pubkey> {
386                        let mut result = std::collections::HashMap::new();
387                        for (i, account) in accounts.iter().enumerate() {
388                            if let Some(name) = Self::ACCOUNTS.get_account_name(i) {
389                                result.insert(name, account);
390                            }
391                        }
392                        result
393                    }
394                }
395            });
396            enum_variants.push(quote! {
397                #struct_name
398            });
399            match_arms.push(quote! {
400                x if x == #struct_name::DISCRIMINATOR => {
401                    return Some(DecodedInstruction::#struct_name)
402                }
403            });
404        }
405    }
406
407    // Process accounts from the IDL.
408    let mut account_enum_variants = Vec::new();
409    let mut account_match_arms = Vec::new();
410    if let Some(accounts) = idl.get("accounts").and_then(|v| v.as_array()) {
411        for account in accounts {
412            let name = account.get("name").and_then(|v| v.as_str()).unwrap();
413            let discriminator = account
414                .get("discriminator")
415                .and_then(|v| v.as_array())
416                .expect("Discriminator missing or not an array in accounts");
417            let type_ident = syn::Ident::new(name, proc_macro2::Span::call_site());
418            let disc_values: Vec<u8> = discriminator
419                .iter()
420                .map(|v| v.as_u64().unwrap() as u8)
421                .collect();
422            let disc_tokens = quote! { [ #( #disc_values ),* ] };
423
424            account_enum_variants.push(quote! {
425                #type_ident(#type_ident)
426            });
427            account_match_arms.push(quote! {
428                x if x == #disc_tokens => {
429                    return Some(DecodedAccount::#type_ident(
430                        #type_ident::decode(&data[8..]).ok()?
431                    ))
432                }
433            });
434        }
435    }
436
437    // Process events from the IDL.
438    let mut event_enum_variants = Vec::new();
439    let mut event_match_arms = Vec::new();
440    if let Some(events) = idl.get("events").and_then(|v| v.as_array()) {
441        for event in events {
442            let name = event.get("name").and_then(|v| v.as_str()).unwrap();
443            let discriminator = event
444                .get("discriminator")
445                .and_then(|v| v.as_array())
446                .expect("Discriminator missing or not an array in events");
447            let type_ident = syn::Ident::new(name, proc_macro2::Span::call_site());
448            let disc_values: Vec<u8> = discriminator
449                .iter()
450                .map(|v| v.as_u64().unwrap() as u8)
451                .collect();
452            let disc_tokens = quote! { [ #( #disc_values ),* ] };
453
454            event_enum_variants.push(quote! {
455                #type_ident(#type_ident)
456            });
457            event_match_arms.push(quote! {
458                x if x == #disc_tokens => {
459                    return Some(DecodedEvent::#type_ident(
460                        #type_ident::decode(&data[8..]).ok()?
461                    ))
462                }
463            });
464        }
465    }
466
467    let program_address = idl
468        .get("address")
469        .and_then(|v| v.as_str())
470        .expect("IDL missing program address");
471
472    let expanded = quote! {
473        use ::borsh::{BorshDeserialize, BorshSerialize};
474        use ::solana_sdk::pubkey::Pubkey;
475        use std::collections::HashMap;
476
477        pub const ID: Pubkey = ::solana_sdk::pubkey!(#program_address);
478
479        #( #struct_defs )*
480
481        #[derive(Debug)]
482        pub enum DecodedInstruction {
483            #( #enum_variants, )*
484            EmitCpi(DecodedEvent)
485        }
486
487        pub fn decode_instruction(data: &[u8]) -> Option<DecodedInstruction> {
488            if data.len() < 8 { return None; }
489            let disc = &data[..8];
490            match disc {
491                #( #match_arms, )*
492                _ => {
493                    if disc == EMIT_CPI_INSTRUCTION_DISCRIMINATOR {
494                        let payload = &data[8..];
495                        decode_event(payload).map(|event| DecodedInstruction::EmitCpi(event))
496                    } else {
497                        None
498                    }
499                },
500            }
501        }
502
503        #[derive(Debug)]
504        pub enum DecodedAccount {
505            #( #account_enum_variants, )*
506        }
507
508        pub fn decode_account(data: &[u8]) -> Option<DecodedAccount> {
509            if data.len() < 8 { return None; }
510            let disc = &data[..8];
511            match disc {
512                #( #account_match_arms, )*
513                _ => {
514                    None
515                },
516            }
517        }
518
519        #[derive(Debug)]
520        pub enum DecodedEvent {
521            #( #event_enum_variants, )*
522        }
523
524        // Some programs might call anchor's emit_cpi instruction to emit events via self-cpi so that subscribed clients
525        // can see the events without risk of the RPC's truncating them (as with traditional event logging)
526        //
527        // Source: https://github.com/coral-xyz/anchor/blob/8b391aa278387b6f6ce3133453619a175544631e/lang/attribute/event/src/lib.rs#L111-L195
528        const EMIT_CPI_INSTRUCTION_DISCRIMINATOR: [u8; 8] = [228, 69, 165, 46, 81, 203, 154, 29];
529
530        pub fn decode_event(data: &[u8]) -> Option<DecodedEvent> {
531            if data.len() < 8 { return None; }
532            let disc = &data[..8];
533
534            match disc {
535                #( #event_match_arms, )*
536                _ => {
537                    None
538                }
539            }
540        }
541    };
542
543    expanded.into()
544}