1extern crate proc_macro;
2use std::collections::HashSet;
3
4use proc_macro::TokenStream;
5use quote::quote;
6use serde_json::Value;
7use syn::{parse_macro_input, LitStr};
8
/// Converts an underscore-separated name into `UpperCamelCase`.
///
/// The input is split on `_`. For every piece the first character is
/// upper-cased and the remaining characters are copied unchanged; empty
/// pieces (produced by leading, trailing or repeated underscores) contribute
/// nothing to the result.
fn to_camel_case(s: &str) -> String {
    let mut joined = String::new();
    for piece in s.split('_') {
        let mut rest = piece.chars();
        if let Some(first) = rest.next() {
            // `to_uppercase` may expand to more than one character.
            joined.extend(first.to_uppercase());
            joined.push_str(rest.as_str());
        }
    }
    joined
}
21
22fn map_idl_type(arg_type: &Value, generated_types: &HashSet<String>) -> proc_macro2::TokenStream {
26 if let Some(s) = arg_type.as_str() {
27 match s {
28 "u8" => quote! { u8 },
29 "u16" => quote! { u16 },
30 "u64" => quote! { u64 },
31 "i64" => quote! { i64 },
32 "bool" => quote! { bool },
33 "pubkey" => quote! { Pubkey },
34 "string" => quote! { String },
35 _ => quote! { () }, }
37 } else if let Some(obj) = arg_type.as_object() {
38 if let Some(array_val) = obj.get("array") {
39 if let Some(arr) = array_val.as_array() {
40 if arr.len() == 2 {
41 let inner = map_idl_type(&arr[0], generated_types);
42 if let Some(len) = arr[1].as_u64() {
43 let len_literal =
44 syn::LitInt::new(&len.to_string(), proc_macro2::Span::call_site());
45 return quote! { [#inner; #len_literal] };
46 }
47 }
48 }
49 } else if let Some(defined) = obj.get("defined") {
50 if let Some(defined_obj) = defined.as_object() {
51 if let Some(name) = defined_obj.get("name").and_then(|n| n.as_str()) {
52 let type_ident = syn::Ident::new(name, proc_macro2::Span::call_site());
53 if generated_types.contains(name) {
56 return quote! { #type_ident };
57 } else {
58 return quote! { ::crate::#type_ident };
59 }
60 }
61 }
62 }
63 quote! { () }
64 } else {
65 quote! { () }
66 }
67}
68
69#[proc_macro_attribute]
82pub fn anchor_idl(attr: TokenStream, _item: TokenStream) -> TokenStream {
83 let relative_path = parse_macro_input!(attr as LitStr).value();
85
86 let manifest_dir = std::env::var("CARGO_MANIFEST_DIR")
88 .expect("CARGO_MANIFEST_DIR environment variable not set");
89 let idl_path = std::path::Path::new(&manifest_dir)
90 .join(relative_path)
91 .canonicalize()
92 .unwrap_or_else(|e| panic!("Failed to resolve IDL path: {}", e));
93
94 let idl_json = std::fs::read_to_string(&idl_path)
96 .unwrap_or_else(|_| panic!("Unable to read IDL file at: {}", idl_path.display()));
97 let idl: Value = serde_json::from_str(&idl_json)
98 .unwrap_or_else(|_| panic!("Invalid JSON in IDL file: {}", idl_path.display()));
99
100 let generated_types: HashSet<String> =
102 if let Some(types) = idl.get("types").and_then(|v| v.as_array()) {
103 types
104 .iter()
105 .filter_map(|t| {
106 t.get("name")
107 .and_then(|v| v.as_str())
108 .map(|s| s.to_string())
109 })
110 .collect()
111 } else {
112 HashSet::new()
113 };
114
115 let mut struct_defs = Vec::new();
116
117 if let Some(types) = idl.get("types").and_then(|v| v.as_array()) {
119 for type_def in types {
120 if let (Some(name), Some(type_info)) = (
121 type_def.get("name").and_then(|v| v.as_str()),
122 type_def.get("type").and_then(|v| v.as_object()),
123 ) {
124 let type_ident = syn::Ident::new(name, proc_macro2::Span::call_site());
125
126 if let Some(kind) = type_info.get("kind").and_then(|v| v.as_str()) {
128 match kind {
129 "struct" => {
130 if let Some(fields) = type_info.get("fields").and_then(|v| v.as_array())
132 {
133 let mut field_defs = Vec::new();
134 for field in fields {
135 if let (Some(field_name), Some(field_type)) = (
136 field.get("name").and_then(|v| v.as_str()),
137 field.get("type"),
138 ) {
139 let field_ident = syn::Ident::new(
140 field_name,
141 proc_macro2::Span::call_site(),
142 );
143 let field_type = map_idl_type(field_type, &generated_types);
144 field_defs.push(quote! {
145 pub #field_ident: #field_type,
146 });
147 }
148 }
149 struct_defs.push(quote! {
150 #[derive(Debug, BorshSerialize, BorshDeserialize)]
151 pub struct #type_ident {
152 #( #field_defs )*
153 }
154 impl #type_ident {
155 pub fn decode(data: &[u8]) -> Result<Self, ::std::io::Error> {
156 <Self as BorshDeserialize>::try_from_slice(data)
157 }
158 }
159 });
160 }
161 }
162 "enum" => {
163 if let Some(variants) =
165 type_info.get("variants").and_then(|v| v.as_array())
166 {
167 let mut variant_tokens = Vec::new();
168 for variant in variants {
169 if let Some(variant_name) =
170 variant.get("name").and_then(|v| v.as_str())
171 {
172 let variant_ident = syn::Ident::new(
173 variant_name,
174 proc_macro2::Span::call_site(),
175 );
176 variant_tokens.push(quote! {
177 #variant_ident,
178 });
179 }
180 }
181 struct_defs.push(quote! {
182 #[derive(Debug, BorshSerialize, BorshDeserialize)]
183 pub enum #type_ident {
184 #( #variant_tokens )*
185 }
186 impl #type_ident {
187 pub fn decode(data: &[u8]) -> Result<Self, ::std::io::Error> {
188 <Self as BorshDeserialize>::try_from_slice(data)
189 }
190 }
191 });
192 }
193 }
194 _ => {}
195 }
196 }
197 }
198 }
199 }
200
201 let instructions = idl
202 .get("instructions")
203 .and_then(|v| v.as_array())
204 .expect("IDL JSON does not contain an 'instructions' array");
205
206 let mut enum_variants = Vec::new();
207 let mut match_arms = Vec::new();
208
209 for inst in instructions {
210 let name = inst.get("name").and_then(|v| v.as_str()).unwrap();
212 let discriminator = inst
213 .get("discriminator")
214 .and_then(|v| v.as_array())
215 .expect("Discriminator missing or not an array");
216 let args = inst
217 .get("args")
218 .and_then(|v| v.as_array())
219 .expect("Args missing or not an array");
220
221 let struct_name_str = to_camel_case(name);
223 let struct_name = syn::Ident::new(&struct_name_str, proc_macro2::Span::call_site());
224
225 let accounts_struct_name = syn::Ident::new(
227 &format!("{}Accounts", struct_name_str),
228 proc_macro2::Span::call_site(),
229 );
230
231 let disc_values: Vec<u8> = discriminator
233 .iter()
234 .map(|v| v.as_u64().unwrap() as u8)
235 .collect();
236 let disc_tokens = quote! { [ #( #disc_values ),* ] };
237
238 let mut account_consts = Vec::new();
240 let mut account_fields = Vec::new();
241 let mut account_indices = Vec::new();
242 let mut account_name_matches = Vec::new();
243 let mut account_tuples = Vec::new();
244 let mut account_index_matches = Vec::new();
245
246 if let Some(accounts) = inst.get("accounts").and_then(|v| v.as_array()) {
247 for (idx, account) in accounts.iter().enumerate() {
248 if let Some(account_name) = account.get("name").and_then(|v| v.as_str()) {
249 let const_name = account_name.to_uppercase();
250 let const_ident = syn::Ident::new(&const_name, proc_macro2::Span::call_site());
251 let idx_lit =
252 syn::LitInt::new(&idx.to_string(), proc_macro2::Span::call_site());
253
254 account_consts.push(quote! {
255 pub const #const_ident: usize = #idx_lit;
256 });
257
258 let field_ident = syn::Ident::new(account_name, proc_macro2::Span::call_site());
259 account_fields.push(quote! {
260 pub #field_ident: usize,
261 });
262
263 account_indices.push(quote! {
264 #field_ident: #idx_lit,
265 });
266
267 let account_name_str = account_name;
269 account_name_matches.push(quote! {
270 #idx_lit => Some(#account_name_str),
271 });
272
273 account_tuples.push(quote! {
275 (#account_name_str, Self::#const_ident)
276 });
277
278 account_index_matches.push(quote! {
280 #account_name_str => Some(Self::#const_ident),
281 });
282 }
283 }
284
285 struct_defs.push(quote! {
287 #[derive(Debug, Clone, Copy)]
288 pub struct #accounts_struct_name {
289 #( #account_fields )*
290 }
291
292 impl #accounts_struct_name {
293 #( #account_consts )*
294
295 pub const fn new() -> Self {
296 Self {
297 #( #account_indices )*
298 }
299 }
300
301 pub fn get_account_name(&self, index: usize) -> Option<&'static str> {
302 match index {
303 #( #account_name_matches )*
304 _ => None,
305 }
306 }
307
308 pub fn get_all_accounts(&self) -> &'static [(&'static str, usize)] {
309 &[
310 #( #account_tuples, )*
311 ]
312 }
313
314 pub fn get_account_index(&self, name: &str) -> Option<usize> {
315 match name {
316 #( #account_index_matches )*
317 _ => None,
318 }
319 }
320 }
321 });
322 }
323
324 if !args.is_empty() {
325 let mut fields = Vec::new();
327 for arg in args {
328 let arg_name = arg.get("name").and_then(|v| v.as_str()).unwrap();
329 let arg_type = arg.get("type").expect("Missing type in argument");
330 let field_ident = syn::Ident::new(arg_name, proc_macro2::Span::call_site());
331 let field_type = map_idl_type(arg_type, &generated_types);
332 fields.push(quote! {
333 pub #field_ident: #field_type,
334 });
335 }
336
337 struct_defs.push(quote! {
338 #[derive(Debug, BorshSerialize, BorshDeserialize)]
339 pub struct #struct_name {
340 #( #fields )*
341 }
342 impl #struct_name {
343 pub const DISCRIMINATOR: [u8; 8] = #disc_tokens;
344 pub const ACCOUNTS: #accounts_struct_name = #accounts_struct_name::new();
345
346 pub fn decode(data: &[u8]) -> Result<Self, ::std::io::Error> {
347 let payload = &data[8..];
349 <Self as BorshDeserialize>::try_from_slice(payload)
350 }
351
352 pub fn map_accounts<'a>(accounts: &'a [Pubkey]) -> std::collections::HashMap<&'static str, &'a Pubkey> {
354 let mut result = std::collections::HashMap::new();
355 for (i, account) in accounts.iter().enumerate() {
356 if let Some(name) = Self::ACCOUNTS.get_account_name(i) {
357 result.insert(name, account);
358 }
359 }
360 result
361 }
362 }
363 });
364
365 enum_variants.push(quote! {
366 #struct_name(#struct_name)
367 });
368 match_arms.push(quote! {
369 x if x == #struct_name::DISCRIMINATOR => {
370 return Some(DecodedInstruction::#struct_name(
371 #struct_name::decode(data).ok()?
372 ))
373 }
374 });
375 } else {
376 struct_defs.push(quote! {
378 #[derive(Debug)]
379 pub struct #struct_name;
380 impl #struct_name {
381 pub const DISCRIMINATOR: [u8; 8] = #disc_tokens;
382 pub const ACCOUNTS: #accounts_struct_name = #accounts_struct_name::new();
383
384 pub fn map_accounts<'a>(accounts: &'a [Pubkey]) -> std::collections::HashMap<&'static str, &'a Pubkey> {
386 let mut result = std::collections::HashMap::new();
387 for (i, account) in accounts.iter().enumerate() {
388 if let Some(name) = Self::ACCOUNTS.get_account_name(i) {
389 result.insert(name, account);
390 }
391 }
392 result
393 }
394 }
395 });
396 enum_variants.push(quote! {
397 #struct_name
398 });
399 match_arms.push(quote! {
400 x if x == #struct_name::DISCRIMINATOR => {
401 return Some(DecodedInstruction::#struct_name)
402 }
403 });
404 }
405 }
406
407 let mut account_enum_variants = Vec::new();
409 let mut account_match_arms = Vec::new();
410 if let Some(accounts) = idl.get("accounts").and_then(|v| v.as_array()) {
411 for account in accounts {
412 let name = account.get("name").and_then(|v| v.as_str()).unwrap();
413 let discriminator = account
414 .get("discriminator")
415 .and_then(|v| v.as_array())
416 .expect("Discriminator missing or not an array in accounts");
417 let type_ident = syn::Ident::new(name, proc_macro2::Span::call_site());
418 let disc_values: Vec<u8> = discriminator
419 .iter()
420 .map(|v| v.as_u64().unwrap() as u8)
421 .collect();
422 let disc_tokens = quote! { [ #( #disc_values ),* ] };
423
424 account_enum_variants.push(quote! {
425 #type_ident(#type_ident)
426 });
427 account_match_arms.push(quote! {
428 x if x == #disc_tokens => {
429 return Some(DecodedAccount::#type_ident(
430 #type_ident::decode(&data[8..]).ok()?
431 ))
432 }
433 });
434 }
435 }
436
437 let mut event_enum_variants = Vec::new();
439 let mut event_match_arms = Vec::new();
440 if let Some(events) = idl.get("events").and_then(|v| v.as_array()) {
441 for event in events {
442 let name = event.get("name").and_then(|v| v.as_str()).unwrap();
443 let discriminator = event
444 .get("discriminator")
445 .and_then(|v| v.as_array())
446 .expect("Discriminator missing or not an array in events");
447 let type_ident = syn::Ident::new(name, proc_macro2::Span::call_site());
448 let disc_values: Vec<u8> = discriminator
449 .iter()
450 .map(|v| v.as_u64().unwrap() as u8)
451 .collect();
452 let disc_tokens = quote! { [ #( #disc_values ),* ] };
453
454 event_enum_variants.push(quote! {
455 #type_ident(#type_ident)
456 });
457 event_match_arms.push(quote! {
458 x if x == #disc_tokens => {
459 return Some(DecodedEvent::#type_ident(
460 #type_ident::decode(&data[8..]).ok()?
461 ))
462 }
463 });
464 }
465 }
466
467 let program_address = idl
468 .get("address")
469 .and_then(|v| v.as_str())
470 .expect("IDL missing program address");
471
472 let expanded = quote! {
473 use ::borsh::{BorshDeserialize, BorshSerialize};
474 use ::solana_sdk::pubkey::Pubkey;
475 use std::collections::HashMap;
476
477 pub const ID: Pubkey = ::solana_sdk::pubkey!(#program_address);
478
479 #( #struct_defs )*
480
481 #[derive(Debug)]
482 pub enum DecodedInstruction {
483 #( #enum_variants, )*
484 EmitCpi(DecodedEvent)
485 }
486
487 pub fn decode_instruction(data: &[u8]) -> Option<DecodedInstruction> {
488 if data.len() < 8 { return None; }
489 let disc = &data[..8];
490 match disc {
491 #( #match_arms, )*
492 _ => {
493 if disc == EMIT_CPI_INSTRUCTION_DISCRIMINATOR {
494 let payload = &data[8..];
495 decode_event(payload).map(|event| DecodedInstruction::EmitCpi(event))
496 } else {
497 None
498 }
499 },
500 }
501 }
502
503 #[derive(Debug)]
504 pub enum DecodedAccount {
505 #( #account_enum_variants, )*
506 }
507
508 pub fn decode_account(data: &[u8]) -> Option<DecodedAccount> {
509 if data.len() < 8 { return None; }
510 let disc = &data[..8];
511 match disc {
512 #( #account_match_arms, )*
513 _ => {
514 None
515 },
516 }
517 }
518
519 #[derive(Debug)]
520 pub enum DecodedEvent {
521 #( #event_enum_variants, )*
522 }
523
524 const EMIT_CPI_INSTRUCTION_DISCRIMINATOR: [u8; 8] = [228, 69, 165, 46, 81, 203, 154, 29];
529
530 pub fn decode_event(data: &[u8]) -> Option<DecodedEvent> {
531 if data.len() < 8 { return None; }
532 let disc = &data[..8];
533
534 match disc {
535 #( #event_match_arms, )*
536 _ => {
537 None
538 }
539 }
540 }
541 };
542
543 expanded.into()
544}