1extern crate proc_macro;
2use std::collections::HashSet;
3
4use proc_macro::TokenStream;
5use quote::quote;
6use serde_json::Value;
7use syn::{parse_macro_input, LitStr};
8
/// Converts a `snake_case` name (e.g. an IDL instruction name) into `UpperCamelCase`.
///
/// The input is split on `_`. Each non-empty segment contributes its first
/// character upper-cased (Unicode case mapping, so one char may expand to
/// several) followed by the rest of the segment unchanged. Empty segments,
/// from leading, trailing or doubled underscores, contribute nothing.
fn to_camel_case(s: &str) -> String {
    let mut camel = String::with_capacity(s.len());
    for segment in s.split('_') {
        let mut chars = segment.chars();
        if let Some(head) = chars.next() {
            camel.extend(head.to_uppercase());
            camel.push_str(chars.as_str());
        }
    }
    camel
}
21
22fn map_idl_type(arg_type: &Value, generated_types: &HashSet<String>) -> proc_macro2::TokenStream {
26 if let Some(s) = arg_type.as_str() {
27 match s {
28 "u8" => quote! { u8 },
29 "u16" => quote! { u16 },
30 "u64" => quote! { u64 },
31 "i64" => quote! { i64 },
32 "bool" => quote! { bool },
33 "pubkey" => quote! { Pubkey },
34 "string" => quote! { String },
35 _ => quote! { () }, }
37 } else if let Some(obj) = arg_type.as_object() {
38 if let Some(array_val) = obj.get("array") {
39 if let Some(arr) = array_val.as_array() {
40 if arr.len() == 2 {
41 let inner = map_idl_type(&arr[0], generated_types);
42 if let Some(len) = arr[1].as_u64() {
43 let len_literal =
44 syn::LitInt::new(&len.to_string(), proc_macro2::Span::call_site());
45 return quote! { [#inner; #len_literal] };
46 }
47 }
48 }
49 } else if let Some(defined) = obj.get("defined") {
50 if let Some(defined_obj) = defined.as_object() {
51 if let Some(name) = defined_obj.get("name").and_then(|n| n.as_str()) {
52 let type_ident = syn::Ident::new(name, proc_macro2::Span::call_site());
53 if generated_types.contains(name) {
56 return quote! { #type_ident };
57 } else {
58 return quote! { ::crate::#type_ident };
59 }
60 }
61 }
62 }
63 quote! { () }
64 } else {
65 quote! { () }
66 }
67}
68
69#[proc_macro_attribute]
82pub fn anchor_idl(attr: TokenStream, _item: TokenStream) -> TokenStream {
83 let relative_path = parse_macro_input!(attr as LitStr).value();
85
86 let manifest_dir = std::env::var("CARGO_MANIFEST_DIR")
88 .expect("CARGO_MANIFEST_DIR environment variable not set");
89 let idl_path = std::path::Path::new(&manifest_dir)
90 .join(relative_path)
91 .canonicalize()
92 .unwrap_or_else(|e| panic!("Failed to resolve IDL path: {}", e));
93
94 let idl_json = std::fs::read_to_string(&idl_path)
96 .unwrap_or_else(|_| panic!("Unable to read IDL file at: {}", idl_path.display()));
97 let idl: Value = serde_json::from_str(&idl_json)
98 .unwrap_or_else(|_| panic!("Invalid JSON in IDL file: {}", idl_path.display()));
99
100 let generated_types: HashSet<String> =
102 if let Some(types) = idl.get("types").and_then(|v| v.as_array()) {
103 types
104 .iter()
105 .filter_map(|t| {
106 t.get("name")
107 .and_then(|v| v.as_str())
108 .map(|s| s.to_string())
109 })
110 .collect()
111 } else {
112 HashSet::new()
113 };
114
115 let mut struct_defs = Vec::new();
116
117 if let Some(types) = idl.get("types").and_then(|v| v.as_array()) {
119 for type_def in types {
120 if let (Some(name), Some(type_info)) = (
121 type_def.get("name").and_then(|v| v.as_str()),
122 type_def.get("type").and_then(|v| v.as_object()),
123 ) {
124 let type_ident = syn::Ident::new(name, proc_macro2::Span::call_site());
125
126 if let Some(kind) = type_info.get("kind").and_then(|v| v.as_str()) {
128 match kind {
129 "struct" => {
130 let field_defs = if let Some(fields) = type_info.get("fields").and_then(|v| v.as_array()) {
132 let mut field_defs = Vec::new();
133 for field in fields {
134 if let (Some(field_name), Some(field_type)) = (
135 field.get("name").and_then(|v| v.as_str()),
136 field.get("type"),
137 ) {
138 let field_ident = syn::Ident::new(
139 field_name,
140 proc_macro2::Span::call_site(),
141 );
142 let field_type = map_idl_type(field_type, &generated_types);
143 field_defs.push(quote! {
144 pub #field_ident: #field_type,
145 });
146 }
147 }
148 field_defs
149 } else {
150 Vec::new()
152 };
153
154 struct_defs.push(quote! {
155 #[derive(Debug, BorshSerialize, BorshDeserialize)]
156 pub struct #type_ident {
157 #( #field_defs )*
158 }
159 impl #type_ident {
160 pub fn decode(data: &[u8]) -> Result<Self, ::std::io::Error> {
161 <Self as BorshDeserialize>::try_from_slice(data)
162 }
163 }
164 });
165 }
166 "enum" => {
167 if let Some(variants) =
169 type_info.get("variants").and_then(|v| v.as_array())
170 {
171 let mut variant_tokens = Vec::new();
172 for variant in variants {
173 if let Some(variant_name) =
174 variant.get("name").and_then(|v| v.as_str())
175 {
176 let variant_ident = syn::Ident::new(
177 variant_name,
178 proc_macro2::Span::call_site(),
179 );
180 variant_tokens.push(quote! {
181 #variant_ident,
182 });
183 }
184 }
185 struct_defs.push(quote! {
186 #[derive(Debug, BorshSerialize, BorshDeserialize)]
187 pub enum #type_ident {
188 #( #variant_tokens )*
189 }
190 impl #type_ident {
191 pub fn decode(data: &[u8]) -> Result<Self, ::std::io::Error> {
192 <Self as BorshDeserialize>::try_from_slice(data)
193 }
194 }
195 });
196 }
197 }
198 _ => {}
199 }
200 }
201 }
202 }
203 }
204
205 let instructions = idl
206 .get("instructions")
207 .and_then(|v| v.as_array())
208 .expect("IDL JSON does not contain an 'instructions' array");
209
210 let mut enum_variants = Vec::new();
211 let mut match_arms = Vec::new();
212
213 for inst in instructions {
214 let name = inst.get("name").and_then(|v| v.as_str()).unwrap();
216 let discriminator = inst
217 .get("discriminator")
218 .and_then(|v| v.as_array())
219 .expect("Discriminator missing or not an array");
220 let args = inst
221 .get("args")
222 .and_then(|v| v.as_array())
223 .expect("Args missing or not an array");
224
225 let struct_name_str = to_camel_case(name);
227 let struct_name = syn::Ident::new(&struct_name_str, proc_macro2::Span::call_site());
228
229 let accounts_struct_name = syn::Ident::new(
231 &format!("{}Accounts", struct_name_str),
232 proc_macro2::Span::call_site(),
233 );
234
235 let disc_values: Vec<u8> = discriminator
237 .iter()
238 .map(|v| v.as_u64().unwrap() as u8)
239 .collect();
240 let disc_tokens = quote! { [ #( #disc_values ),* ] };
241
242 let mut account_consts = Vec::new();
244 let mut account_fields = Vec::new();
245 let mut account_indices = Vec::new();
246 let mut account_name_matches = Vec::new();
247 let mut account_tuples = Vec::new();
248 let mut account_index_matches = Vec::new();
249
250 if let Some(accounts) = inst.get("accounts").and_then(|v| v.as_array()) {
251 for (idx, account) in accounts.iter().enumerate() {
252 if let Some(account_name) = account.get("name").and_then(|v| v.as_str()) {
253 let const_name = account_name.to_uppercase();
254 let const_ident = syn::Ident::new(&const_name, proc_macro2::Span::call_site());
255 let idx_lit =
256 syn::LitInt::new(&idx.to_string(), proc_macro2::Span::call_site());
257
258 account_consts.push(quote! {
259 pub const #const_ident: usize = #idx_lit;
260 });
261
262 let field_ident = syn::Ident::new(account_name, proc_macro2::Span::call_site());
263 account_fields.push(quote! {
264 pub #field_ident: usize,
265 });
266
267 account_indices.push(quote! {
268 #field_ident: #idx_lit,
269 });
270
271 let account_name_str = account_name;
273 account_name_matches.push(quote! {
274 #idx_lit => Some(#account_name_str),
275 });
276
277 account_tuples.push(quote! {
279 (#account_name_str, Self::#const_ident)
280 });
281
282 account_index_matches.push(quote! {
284 #account_name_str => Some(Self::#const_ident),
285 });
286 }
287 }
288
289 struct_defs.push(quote! {
291 #[derive(Debug, Clone, Copy)]
292 pub struct #accounts_struct_name {
293 #( #account_fields )*
294 }
295
296 impl #accounts_struct_name {
297 #( #account_consts )*
298
299 pub const fn new() -> Self {
300 Self {
301 #( #account_indices )*
302 }
303 }
304
305 pub fn get_account_name(&self, index: usize) -> Option<&'static str> {
306 match index {
307 #( #account_name_matches )*
308 _ => None,
309 }
310 }
311
312 pub fn get_all_accounts(&self) -> &'static [(&'static str, usize)] {
313 &[
314 #( #account_tuples, )*
315 ]
316 }
317
318 pub fn get_account_index(&self, name: &str) -> Option<usize> {
319 match name {
320 #( #account_index_matches )*
321 _ => None,
322 }
323 }
324 }
325 });
326 }
327
328 if !args.is_empty() {
329 let mut fields = Vec::new();
331 for arg in args {
332 let arg_name = arg.get("name").and_then(|v| v.as_str()).unwrap();
333 let arg_type = arg.get("type").expect("Missing type in argument");
334 let field_ident = syn::Ident::new(arg_name, proc_macro2::Span::call_site());
335 let field_type = map_idl_type(arg_type, &generated_types);
336 fields.push(quote! {
337 pub #field_ident: #field_type,
338 });
339 }
340
341 struct_defs.push(quote! {
342 #[derive(Debug, BorshSerialize, BorshDeserialize)]
343 pub struct #struct_name {
344 #( #fields )*
345 }
346 impl #struct_name {
347 pub const DISCRIMINATOR: [u8; 8] = #disc_tokens;
348 pub const ACCOUNTS: #accounts_struct_name = #accounts_struct_name::new();
349
350 pub fn decode(data: &[u8]) -> Result<Self, ::std::io::Error> {
351 let payload = &data[8..];
353 <Self as BorshDeserialize>::try_from_slice(payload)
354 }
355
356 pub fn map_accounts<'a>(accounts: &'a [Pubkey]) -> std::collections::HashMap<&'static str, &'a Pubkey> {
358 let mut result = std::collections::HashMap::new();
359 for (i, account) in accounts.iter().enumerate() {
360 if let Some(name) = Self::ACCOUNTS.get_account_name(i) {
361 result.insert(name, account);
362 }
363 }
364 result
365 }
366 }
367 });
368
369 enum_variants.push(quote! {
370 #struct_name(#struct_name)
371 });
372 match_arms.push(quote! {
373 x if x == #struct_name::DISCRIMINATOR => {
374 return Some(DecodedInstruction::#struct_name(
375 #struct_name::decode(data).ok()?
376 ))
377 }
378 });
379 } else {
380 struct_defs.push(quote! {
382 #[derive(Debug)]
383 pub struct #struct_name;
384 impl #struct_name {
385 pub const DISCRIMINATOR: [u8; 8] = #disc_tokens;
386 pub const ACCOUNTS: #accounts_struct_name = #accounts_struct_name::new();
387
388 pub fn map_accounts<'a>(accounts: &'a [Pubkey]) -> std::collections::HashMap<&'static str, &'a Pubkey> {
390 let mut result = std::collections::HashMap::new();
391 for (i, account) in accounts.iter().enumerate() {
392 if let Some(name) = Self::ACCOUNTS.get_account_name(i) {
393 result.insert(name, account);
394 }
395 }
396 result
397 }
398 }
399 });
400 enum_variants.push(quote! {
401 #struct_name
402 });
403 match_arms.push(quote! {
404 x if x == #struct_name::DISCRIMINATOR => {
405 return Some(DecodedInstruction::#struct_name)
406 }
407 });
408 }
409 }
410
411 let mut account_enum_variants = Vec::new();
413 let mut account_match_arms = Vec::new();
414 if let Some(accounts) = idl.get("accounts").and_then(|v| v.as_array()) {
415 for account in accounts {
416 let name = account.get("name").and_then(|v| v.as_str()).unwrap();
417 let discriminator = account
418 .get("discriminator")
419 .and_then(|v| v.as_array())
420 .expect("Discriminator missing or not an array in accounts");
421 let type_ident = syn::Ident::new(name, proc_macro2::Span::call_site());
422 let disc_values: Vec<u8> = discriminator
423 .iter()
424 .map(|v| v.as_u64().unwrap() as u8)
425 .collect();
426 let disc_tokens = quote! { [ #( #disc_values ),* ] };
427
428 account_enum_variants.push(quote! {
429 #type_ident(#type_ident)
430 });
431 account_match_arms.push(quote! {
432 x if x == #disc_tokens => {
433 return Some(DecodedAccount::#type_ident(
434 #type_ident::decode(&data[8..]).ok()?
435 ))
436 }
437 });
438 }
439 }
440
441 let mut event_enum_variants = Vec::new();
443 let mut event_match_arms = Vec::new();
444 if let Some(events) = idl.get("events").and_then(|v| v.as_array()) {
445 for event in events {
446 let name = event.get("name").and_then(|v| v.as_str()).unwrap();
447 let discriminator = event
448 .get("discriminator")
449 .and_then(|v| v.as_array())
450 .expect("Discriminator missing or not an array in events");
451 let type_ident = syn::Ident::new(name, proc_macro2::Span::call_site());
452 let disc_values: Vec<u8> = discriminator
453 .iter()
454 .map(|v| v.as_u64().unwrap() as u8)
455 .collect();
456 let disc_tokens = quote! { [ #( #disc_values ),* ] };
457
458 event_enum_variants.push(quote! {
459 #type_ident(#type_ident)
460 });
461 event_match_arms.push(quote! {
462 x if x == #disc_tokens => {
463 return Some(DecodedEvent::#type_ident(
464 #type_ident::decode(&data[8..]).ok()?
465 ))
466 }
467 });
468 }
469 }
470
471 let program_address = idl
472 .get("address")
473 .and_then(|v| v.as_str())
474 .expect("IDL missing program address");
475
476 let expanded = quote! {
477 use ::borsh::{BorshDeserialize, BorshSerialize};
478 use ::solana_sdk::pubkey::Pubkey;
479 use std::collections::HashMap;
480
481 pub const ID: Pubkey = ::solana_sdk::pubkey!(#program_address);
482
483 #( #struct_defs )*
484
485 #[derive(Debug)]
486 pub enum DecodedInstruction {
487 #( #enum_variants, )*
488 EmitCpi(DecodedEvent)
489 }
490
491 pub fn decode_instruction(data: &[u8]) -> Option<DecodedInstruction> {
492 if data.len() < 8 { return None; }
493 let disc = &data[..8];
494 match disc {
495 #( #match_arms, )*
496 _ => {
497 if disc == EMIT_CPI_INSTRUCTION_DISCRIMINATOR {
498 let payload = &data[8..];
499 decode_event(payload).map(|event| DecodedInstruction::EmitCpi(event))
500 } else {
501 None
502 }
503 },
504 }
505 }
506
507 #[derive(Debug)]
508 pub enum DecodedAccount {
509 #( #account_enum_variants, )*
510 }
511
512 pub fn decode_account(data: &[u8]) -> Option<DecodedAccount> {
513 if data.len() < 8 { return None; }
514 let disc = &data[..8];
515 match disc {
516 #( #account_match_arms, )*
517 _ => {
518 None
519 },
520 }
521 }
522
523 #[derive(Debug)]
524 pub enum DecodedEvent {
525 #( #event_enum_variants, )*
526 }
527
528 const EMIT_CPI_INSTRUCTION_DISCRIMINATOR: [u8; 8] = [228, 69, 165, 46, 81, 203, 154, 29];
533
534 pub fn decode_event(data: &[u8]) -> Option<DecodedEvent> {
535 if data.len() < 8 { return None; }
536 let disc = &data[..8];
537
538 match disc {
539 #( #event_match_arms, )*
540 _ => {
541 None
542 }
543 }
544 }
545 };
546
547 expanded.into()
548}