1use std::{
2 fs::{self},
3 io::Write,
4 path::Path,
5 process::{Command, Stdio},
6};
7
8use proc_macro2::TokenStream;
9use quote::quote;
10use serde::{Deserialize, Serialize};
11use sha2::Digest;
12use syn::{Ident, Type};
13
14#[derive(Debug, Serialize, Deserialize)]
15struct Idl {
16 version: String,
17 name: String,
18 instructions: Vec<Instruction>,
19 types: Vec<TypeDef>,
20 accounts: Vec<AccountDef>,
21 events: Vec<EventDef>,
22 errors: Vec<ErrorDef>,
23}
24
/// One IDL instruction: the accounts it takes and its serialized arguments.
#[derive(Debug, Serialize, Deserialize)]
struct Instruction {
    /// Instruction name; snake_cased for the `global:` discriminator and
    /// capitalized for the generated struct names.
    name: String,
    /// Accounts in the order they are emitted as `AccountMeta`s.
    accounts: Vec<Account>,
    /// Arguments in declaration order (fields of the instruction-data struct).
    args: Vec<Arg>,
}

/// An account required by an instruction.
///
/// NOTE(review): both flags are required; an IDL entry without `isMut`/`isSigner`
/// (e.g. a nested account group, if the IDL format has them — confirm) fails to
/// deserialize.
#[derive(Debug, Serialize, Deserialize)]
struct Account {
    name: String,
    /// JSON `isMut`; emitted as `AccountMeta::is_writable`.
    #[serde(rename = "isMut")]
    is_mut: bool,
    /// JSON `isSigner`; emitted as `AccountMeta::is_signer`.
    #[serde(rename = "isSigner")]
    is_signer: bool,
}

/// A named instruction argument.
#[derive(Debug, Serialize, Deserialize)]
struct Arg {
    name: String,
    /// JSON `type`; mapped to Rust through `ArgType::to_rust_type`.
    #[serde(rename = "type")]
    arg_type: ArgType,
}
47
/// An IDL type expression.
///
/// `#[serde(untagged)]` tries the variants in declaration order, so the order
/// below is significant: a bare JSON string becomes `Simple`; objects are matched
/// by the key they carry (`defined`, `array`, `option` or `vec`).
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
enum ArgType {
    /// A primitive type name such as `"u64"`, `"publicKey"`, `"string"`.
    Simple(String),
    /// A reference to a type from the IDL `types`/`accounts` lists, by name.
    Defined { defined: String },
    /// Fixed-size array, encoded in JSON as `[element_type, length]`.
    Array { array: (Box<ArgType>, usize) },
    /// Optional value of the inner type.
    Option { option: Box<ArgType> },
    /// Variable-length vector of the inner type.
    Vec { vec: Box<ArgType> },
}
57
58impl ArgType {
59 fn to_rust_type(&self) -> String {
60 match self {
61 ArgType::Simple(t) => {
62 if t == "publicKey" {
64 "Pubkey".to_string()
65 } else if t == "bytes" {
66 "Vec<u8>".to_string()
67 } else if t == "string" {
68 "String".to_string()
69 } else {
70 t.clone()
71 }
72 }
73 ArgType::Defined { defined } => defined.clone(),
74 ArgType::Array { array: (t, len) } => {
75 let rust_type = t.to_rust_type();
76 if *len == 64_usize && rust_type == "u8" {
78 "Signature".into()
80 } else {
81 format!("[{}; {}]", t.to_rust_type(), len)
82 }
83 }
84 ArgType::Option { option } => format!("Option<{}>", option.to_rust_type()),
85 ArgType::Vec { vec } => format!("Vec<{}>", vec.to_rust_type()),
86 }
87 }
88}
89
/// A named entry of the IDL `types` list.
#[derive(Debug, Serialize, Deserialize)]
struct TypeDef {
    name: String,
    /// JSON `type`: the body of the definition.
    #[serde(rename = "type")]
    type_def: TypeData,
}

/// Body of a type definition, internally tagged by the JSON `kind` key
/// (`"struct"` or `"enum"`).
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "kind")]
enum TypeData {
    #[serde(rename = "struct")]
    Struct { fields: Vec<StructField> },
    #[serde(rename = "enum")]
    Enum { variants: Vec<EnumVariant> },
}

/// A named, typed field of a struct, an enum variant, or an account.
#[derive(Debug, Serialize, Deserialize)]
struct StructField {
    name: String,
    /// JSON `type`.
    #[serde(rename = "type")]
    field_type: ArgType,
}

/// One variant of an IDL enum.
///
/// `#[serde(untagged)]` tries `Complex` first (it needs a `fields` list of named
/// fields); anything else carrying a `name` becomes `Simple`. The order of the
/// variants below is therefore significant.
///
/// NOTE(review): unknown keys are ignored, so a variant whose `fields` are not
/// named struct fields (e.g. tuple fields, if the IDL uses them — confirm) falls
/// through to `Simple` and its payload is silently dropped.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
enum EnumVariant {
    /// Variant with named fields.
    Complex {
        name: String,
        fields: Vec<StructField>,
    },
    /// Unit variant.
    Simple {
        name: String,
    },
}
125
/// An entry of the IDL `accounts` list: a named on-chain account layout.
#[derive(Debug, Serialize, Deserialize)]
struct AccountDef {
    name: String,
    /// JSON `type`: the account's layout.
    #[serde(rename = "type")]
    account_type: AccountType,
}

/// Layout of an IDL account.
#[derive(Debug, Serialize, Deserialize)]
struct AccountType {
    /// Type kind as written in the IDL; read but not otherwise used by the generator.
    kind: String,
    /// Fields in declaration order.
    fields: Vec<StructField>,
}
138
139#[derive(Debug, Serialize, Deserialize)]
140struct ErrorDef {
141 code: u32,
142 name: String,
143 msg: String,
144}
145
/// An entry of the IDL `events` list; generated as an `#[event]` struct.
#[derive(Debug, Serialize, Deserialize)]
struct EventDef {
    name: String,
    fields: Vec<EventField>,
}

/// One field of an IDL event.
#[derive(Debug, Serialize, Deserialize)]
struct EventField {
    name: String,
    /// JSON `type`.
    #[serde(rename = "type")]
    field_type: ArgType,
    /// The IDL `index` flag; required when parsing but not used by the generator.
    index: bool,
}
159
160fn generate_idl_types(idl: &Idl) -> String {
161 let mut instructions_tokens = quote! {};
162 let mut types_tokens = quote! {};
163 let mut accounts_tokens = quote! {};
164 let mut errors_tokens = quote! {};
165 let mut events_tokens = quote! {};
166 let idl_version = syn::LitStr::new(&idl.version, proc_macro2::Span::call_site());
167
168 for type_def in &idl.types {
170 let type_name = Ident::new(
171 &capitalize_first_letter(&type_def.name),
172 proc_macro2::Span::call_site(),
173 );
174 let type_tokens = match &type_def.type_def {
175 TypeData::Enum { variants } => {
176 let has_complex_variant = variants.iter().any(|v| match v {
177 EnumVariant::Complex { .. } => true,
178 _ => false,
179 });
180
181 let variant_tokens =
182 variants
183 .iter()
184 .enumerate()
185 .map(|(i, variant)| match variant {
186 EnumVariant::Simple { name } => {
187 let variant_name = Ident::new(name, proc_macro2::Span::call_site());
188 if i == 0 {
189 quote! {
190 #[default]
191 #variant_name,
192 }
193 } else {
194 quote! {
195 #variant_name,
196 }
197 }
198 }
199 EnumVariant::Complex { name, fields } => {
200 let variant_name = Ident::new(name, proc_macro2::Span::call_site());
201 let field_tokens = fields.iter().map(|field| {
202 let field_name = Ident::new(
203 &to_snake_case(&field.name),
204 proc_macro2::Span::call_site(),
205 );
206 let field_type: Type =
207 syn::parse_str(&field.field_type.to_rust_type()).unwrap();
208 quote! {
209 #field_name: #field_type,
210 }
211 });
212 quote! {
213 #variant_name {
214 #(#field_tokens)*
215 },
216 }
217 }
218 });
219
220 if has_complex_variant {
221 quote! {
222 #[derive(AnchorSerialize, AnchorDeserialize, InitSpace, Serialize, Deserialize, Copy, Clone, Debug, PartialEq)]
223 pub enum #type_name {
224 #(#variant_tokens)*
225 }
226 }
227 } else {
228 quote! {
230 #[derive(AnchorSerialize, AnchorDeserialize, InitSpace, Serialize, Deserialize, Copy, Clone, Default, Debug, PartialEq)]
231 pub enum #type_name {
232 #(#variant_tokens)*
233 }
234 }
235 }
236 }
237 TypeData::Struct { fields } => {
238 let struct_name =
239 Ident::new(type_def.name.as_str(), proc_macro2::Span::call_site());
240 let struct_fields = fields.iter().map(|field| {
241 let field_name =
242 Ident::new(&to_snake_case(&field.name), proc_macro2::Span::call_site());
243 let field_type: syn::Type =
244 syn::parse_str(&field.field_type.to_rust_type()).unwrap();
245 quote! {
246 pub #field_name: #field_type,
247 }
248 });
249
250 quote! {
251 #[repr(C)]
252 #[derive(AnchorSerialize, AnchorDeserialize, InitSpace, Serialize, Deserialize, Copy, Clone, Default, Debug, PartialEq)]
253 pub struct #struct_name {
254 #(#struct_fields)*
255 }
256 }
257 }
258 };
259
260 types_tokens = quote! {
261 #types_tokens
262 #type_tokens
263 };
264 }
265
266 for account in &idl.accounts {
268 let struct_name = Ident::new(&account.name, proc_macro2::Span::call_site());
269
270 let mut has_vec_field = false;
271 let struct_fields: Vec<TokenStream> = account
272 .account_type
273 .fields
274 .iter()
275 .map(|field| {
276 let field_name =
277 Ident::new(&to_snake_case(&field.name), proc_macro2::Span::call_site());
278 if let ArgType::Vec { .. } = field.field_type {
279 has_vec_field = true;
280 }
281 let mut serde_decorator = TokenStream::new();
282 let mut field_type: Type =
283 syn::parse_str(&field.field_type.to_rust_type()).unwrap();
284 if field_name == "padding" {
286 if let ArgType::Array { array: (_t, len) } = &field.field_type {
287 field_type = syn::parse_str(&format!("Padding<{len}>")).unwrap();
288 serde_decorator = quote! {
289 #[serde(skip)]
290 };
291 }
292 }
293
294 quote! {
295 #serde_decorator
296 pub #field_name: #field_type,
297 }
298 })
299 .collect();
300
301 let derive_tokens = if !has_vec_field {
302 quote! {
303 #[derive(AnchorSerialize, AnchorDeserialize, InitSpace, Serialize, Deserialize, Copy, Clone, Default, Debug, PartialEq)]
304 }
305 } else {
306 quote! {
309 #[derive(AnchorSerialize, AnchorDeserialize, Serialize, Deserialize, Clone, Default, Debug, PartialEq)]
310 }
311 };
312
313 let zc_tokens = if !has_vec_field {
314 quote! {
316 #[automatically_derived]
317 unsafe impl anchor_lang::__private::bytemuck::Pod for #struct_name {}
318 #[automatically_derived]
319 unsafe impl anchor_lang::__private::bytemuck::Zeroable for #struct_name {}
320 #[automatically_derived]
321 impl anchor_lang::ZeroCopy for #struct_name {}
322 }
323 } else {
324 Default::default()
325 };
326
327 let discriminator: TokenStream = format!("{:?}", sighash("account", &account.name))
328 .parse()
329 .unwrap();
330 let struct_def = quote! {
331 #[repr(C)]
332 #derive_tokens
333 pub struct #struct_name {
334 #(#struct_fields)*
335 }
336 #[automatically_derived]
337 impl anchor_lang::Discriminator for #struct_name {
338 const DISCRIMINATOR: &[u8] = &#discriminator;
339 }
340 #zc_tokens
341 #[automatically_derived]
342 impl anchor_lang::AccountSerialize for #struct_name {
343 fn try_serialize<W: std::io::Write>(&self, writer: &mut W) -> anchor_lang::Result<()> {
344 if writer.write_all(Self::DISCRIMINATOR).is_err() {
345 return Err(anchor_lang::error::ErrorCode::AccountDidNotSerialize.into());
346 }
347
348 if AnchorSerialize::serialize(self, writer).is_err() {
349 return Err(anchor_lang::error::ErrorCode::AccountDidNotSerialize.into());
350 }
351
352 Ok(())
353 }
354 }
355 #[automatically_derived]
356 impl anchor_lang::AccountDeserialize for #struct_name {
357 fn try_deserialize(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
358 let given_disc = &buf[..8];
359 if Self::DISCRIMINATOR != given_disc {
360 return Err(anchor_lang::error!(anchor_lang::error::ErrorCode::AccountDiscriminatorMismatch));
361 }
362 Self::try_deserialize_unchecked(buf)
363 }
364
365 fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
366 let mut data: &[u8] = &buf[8..];
367 AnchorDeserialize::deserialize(&mut data)
368 .map_err(|_| anchor_lang::error::ErrorCode::AccountDidNotDeserialize.into())
369 }
370 }
371 };
372
373 accounts_tokens = quote! {
374 #accounts_tokens
375 #struct_def
376 };
377 }
378
379 for instr in &idl.instructions {
381 let name = capitalize_first_letter(&instr.name);
382 let fn_name = to_snake_case(&instr.name);
383 let struct_name = Ident::new(&name, proc_macro2::Span::call_site());
384 let fields = instr.args.iter().map(|arg| {
385 let field_name = Ident::new(&to_snake_case(&arg.name), proc_macro2::Span::call_site());
386 let field_type: Type = syn::parse_str(&arg.arg_type.to_rust_type()).unwrap();
387 quote! {
388 pub #field_name: #field_type,
389 }
390 });
391 let discriminator: TokenStream = format!("{:?}", sighash("global", &fn_name))
393 .parse()
394 .unwrap();
395 let struct_def = quote! {
396 #[derive(AnchorSerialize, AnchorDeserialize, Clone, Default)]
397 pub struct #struct_name {
398 #(#fields)*
399 }
400 #[automatically_derived]
401 impl anchor_lang::Discriminator for #struct_name {
402 const DISCRIMINATOR: &[u8] = &#discriminator;
403 }
404 #[automatically_derived]
405 impl anchor_lang::InstructionData for #struct_name {}
406 };
407
408 instructions_tokens = quote! {
409 #instructions_tokens
410 #struct_def
411 };
412
413 let accounts = instr.accounts.iter().map(|acc| {
414 let account_name =
415 Ident::new(&to_snake_case(&acc.name), proc_macro2::Span::call_site());
416 quote! {
417 pub #account_name: Pubkey,
418 }
419 });
420
421 let to_account_metas = instr.accounts.iter().map(|acc| {
422 let account_name_str = to_snake_case(&acc.name);
423 let account_name =
424 Ident::new(&account_name_str, proc_macro2::Span::call_site());
425 let is_mut: TokenStream = acc.is_mut.to_string().parse().unwrap();
426 let is_signer: TokenStream = acc.is_signer.to_string().parse().unwrap();
427 quote! {
428 AccountMeta { pubkey: self.#account_name, is_signer: #is_signer, is_writable: #is_mut },
429 }
430 });
431
432 let discriminator: TokenStream =
433 format!("{:?}", sighash("account", &name)).parse().unwrap();
434 let account_struct_def = quote! {
435 #[repr(C)]
436 #[derive(Copy, Clone, Default, AnchorSerialize, AnchorDeserialize, Serialize, Deserialize)]
437 pub struct #struct_name {
438 #(#accounts)*
439 }
440 #[automatically_derived]
441 impl anchor_lang::Discriminator for #struct_name {
442 const DISCRIMINATOR: &[u8] = &#discriminator;
443 }
444 #[automatically_derived]
445 unsafe impl anchor_lang::__private::bytemuck::Pod for #struct_name {}
446 #[automatically_derived]
447 unsafe impl anchor_lang::__private::bytemuck::Zeroable for #struct_name {}
448 #[automatically_derived]
449 impl anchor_lang::ZeroCopy for #struct_name {}
450 #[automatically_derived]
451 impl anchor_lang::InstructionData for #struct_name {}
452 #[automatically_derived]
453 impl ToAccountMetas for #struct_name {
454 fn to_account_metas(
455 &self,
456 ) -> Vec<AccountMeta> {
457 vec![
458 #(#to_account_metas)*
459 ]
460 }
461 }
462 #[automatically_derived]
463 impl anchor_lang::AccountSerialize for #struct_name {
464 fn try_serialize<W: std::io::Write>(&self, writer: &mut W) -> anchor_lang::Result<()> {
465 if writer.write_all(Self::DISCRIMINATOR).is_err() {
466 return Err(anchor_lang::error::ErrorCode::AccountDidNotSerialize.into());
467 }
468
469 if AnchorSerialize::serialize(self, writer).is_err() {
470 return Err(anchor_lang::error::ErrorCode::AccountDidNotSerialize.into());
471 }
472
473 Ok(())
474 }
475 }
476 #[automatically_derived]
477 impl anchor_lang::AccountDeserialize for #struct_name {
478 fn try_deserialize(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
479 let given_disc = &buf[..8];
480 if Self::DISCRIMINATOR != given_disc {
481 return Err(anchor_lang::error!(anchor_lang::error::ErrorCode::AccountDiscriminatorMismatch));
482 }
483 Self::try_deserialize_unchecked(buf)
484 }
485
486 fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
487 let mut data: &[u8] = &buf[8..];
488 AnchorDeserialize::deserialize(&mut data)
489 .map_err(|_| anchor_lang::error::ErrorCode::AccountDidNotDeserialize.into())
490 }
491 }
492 };
493
494 accounts_tokens = quote! {
495 #accounts_tokens
496 #account_struct_def
497 };
498 }
499
500 let error_variants = idl.errors.iter().map(|error| {
502 let variant_name = Ident::new(&error.name, proc_macro2::Span::call_site());
503 let error_msg = &error.msg;
504 quote! {
505 #[msg(#error_msg)]
506 #variant_name,
507 }
508 });
509
510 let error_enum = quote! {
511 #[derive(PartialEq)]
512 #[error_code]
513 pub enum ErrorCode {
514 #(#error_variants)*
515 }
516 };
517
518 errors_tokens = quote! {
519 #errors_tokens
520 #error_enum
521 };
522
523 for event in &idl.events {
525 let struct_name = Ident::new(&event.name, proc_macro2::Span::call_site());
526 let fields = event.fields.iter().map(|field| {
527 let field_name =
528 Ident::new(&to_snake_case(&field.name), proc_macro2::Span::call_site());
529 let field_type: Type = syn::parse_str(&field.field_type.to_rust_type()).unwrap();
530 quote! {
531 pub #field_name: #field_type,
532 }
533 });
534
535 let struct_def = quote! {
536 #[event]
538 pub struct #struct_name {
539 #(#fields)*
540 }
541 };
542
543 events_tokens = quote! {
544 #events_tokens
545 #struct_def
546 };
547 }
548
549 let custom_types: TokenStream = include_str!("custom_types.rs")
550 .parse()
551 .expect("custom_types valid rust");
552
553 let output = quote! {
555 #![allow(unused_imports)]
556 use anchor_lang::{prelude::{account, AnchorSerialize, AnchorDeserialize, InitSpace, event, error_code, msg, borsh::{self}}, Discriminator};
560 use solana_sdk::{instruction::AccountMeta, pubkey::Pubkey};
562 use serde::{Serialize, Deserialize};
563
564 pub const IDL_VERSION: &str = #idl_version;
565
566 use self::traits::ToAccountMetas;
567 pub mod traits {
568 use solana_sdk::instruction::AccountMeta;
569
570 pub trait ToAccountMetas {
573 fn to_account_metas(&self) -> Vec<AccountMeta>;
574 }
575 }
576
577 pub mod instructions {
578 use super::{*, types::*};
580
581 #instructions_tokens
582 }
583
584 pub mod types {
585 use std::ops::Mul;
587
588 use super::*;
589 #custom_types
590
591 #types_tokens
592 }
593
594 pub mod accounts {
595 use super::{*, types::*};
597
598 #accounts_tokens
599 }
600
601 pub mod errors {
602 use super::{*, types::*};
604
605 #errors_tokens
606 }
607
608 pub mod events {
609 use super::{*, types::*};
611 #events_tokens
612 }
613 };
614
615 output.to_string()
616}
617
618fn sighash(namespace: &str, name: &str) -> [u8; 8] {
619 let preimage = format!("{namespace}:{name}");
620 let mut hasher = sha2::Sha256::default();
621 let mut sighash = <[u8; 8]>::default();
622 hasher.update(preimage.as_bytes());
623 let digest = hasher.finalize();
624 sighash.copy_from_slice(&digest.as_slice()[..8]);
625
626 sighash
627}
628
/// Converts a camelCase or PascalCase identifier to snake_case.
///
/// Every uppercase character except one in the first position is preceded by
/// `_`, and all uppercase characters are lowercased (Unicode-aware). Runs of
/// capitals are split per character: `"HTTPServer"` → `"h_t_t_p_server"`.
fn to_snake_case(s: &str) -> String {
    let mut snake_case = String::with_capacity(s.len());
    for (i, c) in s.chars().enumerate() {
        if c.is_uppercase() {
            if i != 0 {
                snake_case.push('_');
            }
            // `to_lowercase` (not `to_ascii_lowercase`) so that non-ASCII capitals
            // accepted by `is_uppercase` are actually lowercased.
            snake_case.extend(c.to_lowercase());
        } else {
            snake_case.push(c);
        }
    }
    snake_case
}
643
/// Returns `s` with its first character uppercased (Unicode-aware, so a
/// single character may expand, e.g. `ß` → `SS`); the empty string maps to
/// itself.
fn capitalize_first_letter(s: &str) -> String {
    let mut rest = s.chars();
    rest.next()
        .map(|head| head.to_uppercase().chain(rest).collect())
        .unwrap_or_default()
}
651
/// Pipes `code` through `rustfmt` and returns the formatted source.
///
/// # Panics
///
/// Panics in any of these cases:
/// * `rustfmt` cannot be spawned.
/// * Writing to its stdin or reading its output fails.
/// * It exits unsuccessfully (its diagnostics go to the inherited stderr).
/// * Its output is not valid UTF-8.
fn format_rust_code(code: &str) -> String {
    let mut rustfmt = Command::new("rustfmt")
        .stdin(Stdio::piped())
        .stdout(Stdio::piped())
        .spawn()
        .expect("Failed to run rustfmt");
    {
        // Scoped so the borrow of the child's stdin ends before waiting;
        // `wait_with_output` closes stdin so rustfmt sees EOF.
        let stdin = rustfmt.stdin.as_mut().expect("Failed to open stdin");
        stdin
            .write_all(code.as_bytes())
            .expect("Failed to write to stdin");
    }

    let output = rustfmt
        .wait_with_output()
        .expect("Failed to read rustfmt output");
    // Without this check a rustfmt failure (e.g. a syntax error in the
    // generated code) would silently yield an empty string.
    assert!(
        output.status.success(),
        "rustfmt failed with {}",
        output.status
    );

    String::from_utf8(output.stdout).expect("rustfmt output is not valid UTF-8")
}
671
672pub fn generate_rust_types(idl_path: &Path) -> Result<String, Box<dyn std::error::Error>> {
676 let data = fs::read_to_string(idl_path)?;
678 let idl: Idl = serde_json::from_str(&data)?;
679
680 let rust_idl_types = format_rust_code(&generate_idl_types(&idl));
682 Ok(rust_idl_types)
683}