drift_idl_gen/
lib.rs

1use std::{
2    fs::{self},
3    io::Write,
4    path::Path,
5    process::{Command, Stdio},
6};
7
8use proc_macro2::TokenStream;
9use quote::quote;
10use serde::{Deserialize, Serialize};
11use sha2::Digest;
12use syn::{Ident, Type};
13
14#[derive(Debug, Serialize, Deserialize)]
15struct Idl {
16    version: String,
17    name: String,
18    instructions: Vec<Instruction>,
19    types: Vec<TypeDef>,
20    accounts: Vec<AccountDef>,
21    events: Vec<EventDef>,
22    errors: Vec<ErrorDef>,
23}
24
25#[derive(Debug, Serialize, Deserialize)]
26struct Instruction {
27    name: String,
28    accounts: Vec<Account>,
29    args: Vec<Arg>,
30}
31
32#[derive(Debug, Serialize, Deserialize)]
33struct Account {
34    name: String,
35    #[serde(rename = "isMut")]
36    is_mut: bool,
37    #[serde(rename = "isSigner")]
38    is_signer: bool,
39}
40
41#[derive(Debug, Serialize, Deserialize)]
42struct Arg {
43    name: String,
44    #[serde(rename = "type")]
45    arg_type: ArgType,
46}
47
48#[derive(Debug, Serialize, Deserialize)]
49#[serde(untagged)]
50enum ArgType {
51    Simple(String),
52    Defined { defined: String },
53    Array { array: (Box<ArgType>, usize) },
54    Option { option: Box<ArgType> },
55    Vec { vec: Box<ArgType> },
56}
57
58impl ArgType {
59    fn to_rust_type(&self) -> String {
60        match self {
61            ArgType::Simple(t) => {
62                // special cases likely from manual edits to IDL
63                if t == "publicKey" {
64                    "Pubkey".to_string()
65                } else if t == "bytes" {
66                    "Vec<u8>".to_string()
67                } else if t == "string" {
68                    "String".to_string()
69                } else {
70                    t.clone()
71                }
72            }
73            ArgType::Defined { defined } => defined.clone(),
74            ArgType::Array { array: (t, len) } => {
75                let rust_type = t.to_rust_type();
76                // this is a common signature representation
77                if *len == 64_usize && rust_type == "u8" {
78                    // [u8; 64] does not have a Default impl
79                    "Signature".into()
80                } else {
81                    format!("[{}; {}]", t.to_rust_type(), len)
82                }
83            }
84            ArgType::Option { option } => format!("Option<{}>", option.to_rust_type()),
85            ArgType::Vec { vec } => format!("Vec<{}>", vec.to_rust_type()),
86        }
87    }
88}
89
90#[derive(Debug, Serialize, Deserialize)]
91struct TypeDef {
92    name: String,
93    #[serde(rename = "type")]
94    type_def: TypeData,
95}
96
97#[derive(Debug, Serialize, Deserialize)]
98#[serde(tag = "kind")]
99enum TypeData {
100    #[serde(rename = "struct")]
101    Struct { fields: Vec<StructField> },
102    #[serde(rename = "enum")]
103    Enum { variants: Vec<EnumVariant> },
104}
105
106#[derive(Debug, Serialize, Deserialize)]
107struct StructField {
108    name: String,
109    #[serde(rename = "type")]
110    field_type: ArgType,
111}
112
113#[derive(Debug, Serialize, Deserialize)]
114#[serde(untagged)]
115enum EnumVariant {
116    // NB: this must come before `Simple` (harder match -> easiest match)
117    Complex {
118        name: String,
119        fields: Vec<StructField>,
120    },
121    Simple {
122        name: String,
123    },
124}
125
126#[derive(Debug, Serialize, Deserialize)]
127struct AccountDef {
128    name: String,
129    #[serde(rename = "type")]
130    account_type: AccountType,
131}
132
133#[derive(Debug, Serialize, Deserialize)]
134struct AccountType {
135    kind: String, // Typically "struct"
136    fields: Vec<StructField>,
137}
138
139#[derive(Debug, Serialize, Deserialize)]
140struct ErrorDef {
141    code: u32,
142    name: String,
143    msg: String,
144}
145
146#[derive(Debug, Serialize, Deserialize)]
147struct EventDef {
148    name: String,
149    fields: Vec<EventField>,
150}
151
152#[derive(Debug, Serialize, Deserialize)]
153struct EventField {
154    name: String,
155    #[serde(rename = "type")]
156    field_type: ArgType,
157    index: bool,
158}
159
160fn generate_idl_types(idl: &Idl) -> String {
161    let mut instructions_tokens = quote! {};
162    let mut types_tokens = quote! {};
163    let mut accounts_tokens = quote! {};
164    let mut errors_tokens = quote! {};
165    let mut events_tokens = quote! {};
166    let idl_version = syn::LitStr::new(&idl.version, proc_macro2::Span::call_site());
167
168    // Generate enums and structs from the types section
169    for type_def in &idl.types {
170        let type_name = Ident::new(
171            &capitalize_first_letter(&type_def.name),
172            proc_macro2::Span::call_site(),
173        );
174        let type_tokens = match &type_def.type_def {
175            TypeData::Enum { variants } => {
176                let has_complex_variant = variants.iter().any(|v| match v {
177                    EnumVariant::Complex { .. } => true,
178                    _ => false,
179                });
180
181                let variant_tokens =
182                    variants
183                        .iter()
184                        .enumerate()
185                        .map(|(i, variant)| match variant {
186                            EnumVariant::Simple { name } => {
187                                let variant_name = Ident::new(name, proc_macro2::Span::call_site());
188                                if i == 0 {
189                                    quote! {
190                                        #[default]
191                                        #variant_name,
192                                    }
193                                } else {
194                                    quote! {
195                                        #variant_name,
196                                    }
197                                }
198                            }
199                            EnumVariant::Complex { name, fields } => {
200                                let variant_name = Ident::new(name, proc_macro2::Span::call_site());
201                                let field_tokens = fields.iter().map(|field| {
202                                    let field_name = Ident::new(
203                                        &to_snake_case(&field.name),
204                                        proc_macro2::Span::call_site(),
205                                    );
206                                    let field_type: Type =
207                                        syn::parse_str(&field.field_type.to_rust_type()).unwrap();
208                                    quote! {
209                                        #field_name: #field_type,
210                                    }
211                                });
212                                quote! {
213                                    #variant_name {
214                                        #(#field_tokens)*
215                                    },
216                                }
217                            }
218                        });
219
220                if has_complex_variant {
221                    quote! {
222                        #[derive(AnchorSerialize, AnchorDeserialize, InitSpace, Serialize, Deserialize, Copy, Clone, Debug, PartialEq)]
223                        pub enum #type_name {
224                            #(#variant_tokens)*
225                        }
226                    }
227                } else {
228                    // TODO: need more work to derive 'Default' on complex enums, not currently required
229                    quote! {
230                        #[derive(AnchorSerialize, AnchorDeserialize, InitSpace, Serialize, Deserialize, Copy, Clone, Default, Debug, PartialEq)]
231                        pub enum #type_name {
232                            #(#variant_tokens)*
233                        }
234                    }
235                }
236            }
237            TypeData::Struct { fields } => {
238                let struct_name =
239                    Ident::new(type_def.name.as_str(), proc_macro2::Span::call_site());
240                let struct_fields = fields.iter().map(|field| {
241                    let field_name =
242                        Ident::new(&to_snake_case(&field.name), proc_macro2::Span::call_site());
243                    let field_type: syn::Type =
244                        syn::parse_str(&field.field_type.to_rust_type()).unwrap();
245                    quote! {
246                        pub #field_name: #field_type,
247                    }
248                });
249
250                quote! {
251                    #[repr(C)]
252                    #[derive(AnchorSerialize, AnchorDeserialize, InitSpace, Serialize, Deserialize, Copy, Clone, Default, Debug, PartialEq)]
253                    pub struct #struct_name {
254                        #(#struct_fields)*
255                    }
256                }
257            }
258        };
259
260        types_tokens = quote! {
261            #types_tokens
262            #type_tokens
263        };
264    }
265
266    // Generate structs for accounts section
267    for account in &idl.accounts {
268        let struct_name = Ident::new(&account.name, proc_macro2::Span::call_site());
269
270        let mut has_vec_field = false;
271        let struct_fields: Vec<TokenStream> = account
272            .account_type
273            .fields
274            .iter()
275            .map(|field| {
276                let field_name =
277                    Ident::new(&to_snake_case(&field.name), proc_macro2::Span::call_site());
278                if let ArgType::Vec { .. } = field.field_type {
279                    has_vec_field = true;
280                }
281                let mut serde_decorator = TokenStream::new();
282                let mut field_type: Type =
283                    syn::parse_str(&field.field_type.to_rust_type()).unwrap();
284                // workaround for padding types preventing outertype from deriving 'Default'
285                if field_name == "padding" {
286                    if let ArgType::Array { array: (_t, len) } = &field.field_type {
287                        field_type = syn::parse_str(&format!("Padding<{len}>")).unwrap();
288                        serde_decorator = quote! {
289                            #[serde(skip)]
290                        };
291                    }
292                }
293
294                quote! {
295                    #serde_decorator
296                    pub #field_name: #field_type,
297                }
298            })
299            .collect();
300
301        let derive_tokens = if !has_vec_field {
302            quote! {
303                #[derive(AnchorSerialize, AnchorDeserialize, InitSpace, Serialize, Deserialize, Copy, Clone, Default, Debug, PartialEq)]
304            }
305        } else {
306            // can't derive `Copy` on accounts with `Vec` field
307            // `InitSpace` requires a 'max_len' but no point enforcing here if unset on program side
308            quote! {
309                #[derive(AnchorSerialize, AnchorDeserialize, Serialize, Deserialize, Clone, Default, Debug, PartialEq)]
310            }
311        };
312
313        let zc_tokens = if !has_vec_field {
314            // without copy can't derive the ZeroCopy trait
315            quote! {
316                #[automatically_derived]
317                unsafe impl anchor_lang::__private::bytemuck::Pod for #struct_name {}
318                #[automatically_derived]
319                unsafe impl anchor_lang::__private::bytemuck::Zeroable for #struct_name {}
320                #[automatically_derived]
321                impl anchor_lang::ZeroCopy for #struct_name {}
322            }
323        } else {
324            Default::default()
325        };
326
327        let discriminator: TokenStream = format!("{:?}", sighash("account", &account.name))
328            .parse()
329            .unwrap();
330        let struct_def = quote! {
331            #[repr(C)]
332            #derive_tokens
333            pub struct #struct_name {
334                #(#struct_fields)*
335            }
336            #[automatically_derived]
337            impl anchor_lang::Discriminator for #struct_name {
338                const DISCRIMINATOR: &[u8] = &#discriminator;
339            }
340            #zc_tokens
341            #[automatically_derived]
342            impl anchor_lang::AccountSerialize for #struct_name {
343                fn try_serialize<W: std::io::Write>(&self, writer: &mut W) -> anchor_lang::Result<()> {
344                    if writer.write_all(Self::DISCRIMINATOR).is_err() {
345                        return Err(anchor_lang::error::ErrorCode::AccountDidNotSerialize.into());
346                    }
347
348                    if AnchorSerialize::serialize(self, writer).is_err() {
349                        return Err(anchor_lang::error::ErrorCode::AccountDidNotSerialize.into());
350                    }
351
352                    Ok(())
353                }
354            }
355            #[automatically_derived]
356            impl anchor_lang::AccountDeserialize for #struct_name {
357                fn try_deserialize(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
358                    let given_disc = &buf[..8];
359                    if Self::DISCRIMINATOR != given_disc {
360                        return Err(anchor_lang::error!(anchor_lang::error::ErrorCode::AccountDiscriminatorMismatch));
361                    }
362                    Self::try_deserialize_unchecked(buf)
363                }
364
365                fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
366                    let mut data: &[u8] = &buf[8..];
367                    AnchorDeserialize::deserialize(&mut data)
368                        .map_err(|_| anchor_lang::error::ErrorCode::AccountDidNotDeserialize.into())
369                }
370            }
371        };
372
373        accounts_tokens = quote! {
374            #accounts_tokens
375            #struct_def
376        };
377    }
378
379    // Generate structs for instructions
380    for instr in &idl.instructions {
381        let name = capitalize_first_letter(&instr.name);
382        let fn_name = to_snake_case(&instr.name);
383        let struct_name = Ident::new(&name, proc_macro2::Span::call_site());
384        let fields = instr.args.iter().map(|arg| {
385            let field_name = Ident::new(&to_snake_case(&arg.name), proc_macro2::Span::call_site());
386            let field_type: Type = syn::parse_str(&arg.arg_type.to_rust_type()).unwrap();
387            quote! {
388                pub #field_name: #field_type,
389            }
390        });
391        // https://github.com/coral-xyz/anchor/blob/e48e7e60a64de77d878cdb063965cf125bec741a/lang/syn/src/codegen/program/instruction.rs#L32
392        let discriminator: TokenStream = format!("{:?}", sighash("global", &fn_name))
393            .parse()
394            .unwrap();
395        let struct_def = quote! {
396            #[derive(AnchorSerialize, AnchorDeserialize, Clone, Default)]
397            pub struct #struct_name {
398                #(#fields)*
399            }
400            #[automatically_derived]
401            impl anchor_lang::Discriminator for #struct_name {
402                const DISCRIMINATOR: &[u8] = &#discriminator;
403            }
404            #[automatically_derived]
405            impl anchor_lang::InstructionData for #struct_name {}
406        };
407
408        instructions_tokens = quote! {
409            #instructions_tokens
410            #struct_def
411        };
412
413        let accounts = instr.accounts.iter().map(|acc| {
414            let account_name =
415                Ident::new(&to_snake_case(&acc.name), proc_macro2::Span::call_site());
416            quote! {
417                pub #account_name: Pubkey,
418            }
419        });
420
421        let to_account_metas = instr.accounts.iter().map(|acc| {
422            let account_name_str = to_snake_case(&acc.name);
423            let account_name =
424                Ident::new(&account_name_str, proc_macro2::Span::call_site());
425            let is_mut: TokenStream = acc.is_mut.to_string().parse().unwrap();
426            let is_signer: TokenStream = acc.is_signer.to_string().parse().unwrap();
427            quote! {
428                AccountMeta { pubkey: self.#account_name, is_signer: #is_signer, is_writable: #is_mut },
429            }
430        });
431
432        let discriminator: TokenStream =
433            format!("{:?}", sighash("account", &name)).parse().unwrap();
434        let account_struct_def = quote! {
435            #[repr(C)]
436            #[derive(Copy, Clone, Default, AnchorSerialize, AnchorDeserialize, Serialize, Deserialize)]
437            pub struct #struct_name {
438                #(#accounts)*
439            }
440            #[automatically_derived]
441            impl anchor_lang::Discriminator for #struct_name {
442                const DISCRIMINATOR: &[u8] = &#discriminator;
443            }
444            #[automatically_derived]
445            unsafe impl anchor_lang::__private::bytemuck::Pod for #struct_name {}
446            #[automatically_derived]
447            unsafe impl anchor_lang::__private::bytemuck::Zeroable for #struct_name {}
448            #[automatically_derived]
449            impl anchor_lang::ZeroCopy for #struct_name {}
450            #[automatically_derived]
451            impl anchor_lang::InstructionData for #struct_name {}
452            #[automatically_derived]
453            impl ToAccountMetas for #struct_name {
454                fn to_account_metas(
455                    &self,
456                ) -> Vec<AccountMeta> {
457                   vec![
458                        #(#to_account_metas)*
459                    ]
460                }
461            }
462            #[automatically_derived]
463            impl anchor_lang::AccountSerialize for #struct_name {
464                fn try_serialize<W: std::io::Write>(&self, writer: &mut W) -> anchor_lang::Result<()> {
465                    if writer.write_all(Self::DISCRIMINATOR).is_err() {
466                        return Err(anchor_lang::error::ErrorCode::AccountDidNotSerialize.into());
467                    }
468
469                    if AnchorSerialize::serialize(self, writer).is_err() {
470                        return Err(anchor_lang::error::ErrorCode::AccountDidNotSerialize.into());
471                    }
472
473                    Ok(())
474                }
475            }
476            #[automatically_derived]
477            impl anchor_lang::AccountDeserialize for #struct_name {
478                fn try_deserialize(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
479                    let given_disc = &buf[..8];
480                    if Self::DISCRIMINATOR != given_disc {
481                        return Err(anchor_lang::error!(anchor_lang::error::ErrorCode::AccountDiscriminatorMismatch));
482                    }
483                    Self::try_deserialize_unchecked(buf)
484                }
485
486                fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
487                    let mut data: &[u8] = &buf[8..];
488                    AnchorDeserialize::deserialize(&mut data)
489                        .map_err(|_| anchor_lang::error::ErrorCode::AccountDidNotDeserialize.into())
490                }
491            }
492        };
493
494        accounts_tokens = quote! {
495            #accounts_tokens
496            #account_struct_def
497        };
498    }
499
500    // Generate enum for errors
501    let error_variants = idl.errors.iter().map(|error| {
502        let variant_name = Ident::new(&error.name, proc_macro2::Span::call_site());
503        let error_msg = &error.msg;
504        quote! {
505            #[msg(#error_msg)]
506            #variant_name,
507        }
508    });
509
510    let error_enum = quote! {
511        #[derive(PartialEq)]
512        #[error_code]
513        pub enum ErrorCode {
514            #(#error_variants)*
515        }
516    };
517
518    errors_tokens = quote! {
519        #errors_tokens
520        #error_enum
521    };
522
523    // Generate event structs from the events section
524    for event in &idl.events {
525        let struct_name = Ident::new(&event.name, proc_macro2::Span::call_site());
526        let fields = event.fields.iter().map(|field| {
527            let field_name =
528                Ident::new(&to_snake_case(&field.name), proc_macro2::Span::call_site());
529            let field_type: Type = syn::parse_str(&field.field_type.to_rust_type()).unwrap();
530            quote! {
531                pub #field_name: #field_type,
532            }
533        });
534
535        let struct_def = quote! {
536            //#[derive(InitSpace)]
537            #[event]
538            pub struct #struct_name {
539                #(#fields)*
540            }
541        };
542
543        events_tokens = quote! {
544            #events_tokens
545            #struct_def
546        };
547    }
548
549    let custom_types: TokenStream = include_str!("custom_types.rs")
550        .parse()
551        .expect("custom_types valid rust");
552
553    // Wrap generated code in modules with necessary imports
554    let output = quote! {
555        #![allow(unused_imports)]
556        //!
557        //! Auto-generated IDL types, manual edits do not persist (see `crates/drift-idl-gen`)
558        //!
559        use anchor_lang::{prelude::{account, AnchorSerialize, AnchorDeserialize, InitSpace, event, error_code, msg, borsh::{self}}, Discriminator};
560        // use solana-sdk Pubkey, the vendored anchor-lang Pubkey maybe behind
561        use solana_sdk::{instruction::AccountMeta, pubkey::Pubkey};
562        use serde::{Serialize, Deserialize};
563
564        pub const IDL_VERSION: &str = #idl_version;
565
566        use self::traits::ToAccountMetas;
567        pub mod traits {
568            use solana_sdk::instruction::AccountMeta;
569
570            /// This is distinct from the anchor_lang version of the trait
571            /// reimplemented to ensure the types used are from `solana`` crates _not_ the anchor_lang vendored versions which may be lagging behind
572            pub trait ToAccountMetas {
573                fn to_account_metas(&self) -> Vec<AccountMeta>;
574            }
575        }
576
577        pub mod instructions {
578            //! IDL instruction types
579            use super::{*, types::*};
580
581            #instructions_tokens
582        }
583
584        pub mod types {
585            //! IDL types
586            use std::ops::Mul;
587
588            use super::*;
589            #custom_types
590
591            #types_tokens
592        }
593
594        pub mod accounts {
595            //! IDL Account types
596            use super::{*, types::*};
597
598            #accounts_tokens
599        }
600
601        pub mod errors {
602            //! IDL error types
603            use super::{*, types::*};
604
605            #errors_tokens
606        }
607
608        pub mod events {
609            //! IDL event types
610            use super::{*, types::*};
611            #events_tokens
612        }
613    };
614
615    output.to_string()
616}
617
618fn sighash(namespace: &str, name: &str) -> [u8; 8] {
619    let preimage = format!("{namespace}:{name}");
620    let mut hasher = sha2::Sha256::default();
621    let mut sighash = <[u8; 8]>::default();
622    hasher.update(preimage.as_bytes());
623    let digest = hasher.finalize();
624    sighash.copy_from_slice(&digest.as_slice()[..8]);
625
626    sighash
627}
628
/// Convert a camelCase/PascalCase name to snake_case.
///
/// Every uppercase char except the first is preceded by `_`, so consecutive capitals split
/// (`userID` → `user_i_d`). Chars are lowercased with `to_ascii_lowercase`, which leaves
/// non-ASCII uppercase letters unchanged.
fn to_snake_case(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for (idx, ch) in s.chars().enumerate() {
        if !ch.is_uppercase() {
            out.push(ch);
            continue;
        }
        if idx > 0 {
            out.push('_');
        }
        out.push(ch.to_ascii_lowercase());
    }
    out
}
643
/// Uppercase the first char of `s` (Unicode-aware, so it may expand, e.g. `ß` → `SS`), keeping
/// the rest unchanged.
fn capitalize_first_letter(s: &str) -> String {
    let mut rest = s.chars();
    match rest.next() {
        Some(head) => head.to_uppercase().chain(rest).collect(),
        None => String::new(),
    }
}
651
/// Pretty-print generated Rust source by piping it through `rustfmt`.
///
/// rustfmt's stderr is inherited, so its diagnostics stay visible. If rustfmt exits
/// unsuccessfully (e.g. the generated code does not parse), `code` is returned unformatted with a
/// warning. Without this fallback an empty stdout would be handed back as the "formatted" file.
///
/// # Panics
/// Panics if `rustfmt` cannot be spawned, if piping to/from it fails, or if its output is not
/// valid UTF-8.
fn format_rust_code(code: &str) -> String {
    let mut rustfmt = Command::new("rustfmt")
        .stdin(Stdio::piped())
        .stdout(Stdio::piped())
        .spawn()
        .expect("Failed to run rustfmt");

    // `take()` moves the pipe out so it is dropped (EOF for rustfmt) at the end of the statement
    rustfmt
        .stdin
        .take()
        .expect("Failed to open stdin")
        .write_all(code.as_bytes())
        .expect("Failed to write to stdin");

    let output = rustfmt
        .wait_with_output()
        .expect("Failed to read rustfmt output");

    if output.status.success() {
        String::from_utf8(output.stdout).expect("rustfmt output is not valid UTF-8")
    } else {
        // stdout may be empty or partial on failure; keep the unformatted (complete) code instead
        eprintln!(
            "warning: rustfmt exited with {}; emitting unformatted code",
            output.status
        );
        code.to_string()
    }
}
671
672/// Generate rust types from IDL json
673///
674/// Returns (IDL Version, IDL rs code)
675pub fn generate_rust_types(idl_path: &Path) -> Result<String, Box<dyn std::error::Error>> {
676    // Load the JSON file
677    let data = fs::read_to_string(idl_path)?;
678    let idl: Idl = serde_json::from_str(&data)?;
679
680    // Generate Rust structs organized into modules
681    let rust_idl_types = format_rust_code(&generate_idl_types(&idl));
682    Ok(rust_idl_types)
683}