1pub mod account_state;
2pub mod accounts_struct;
3pub mod discriminator;
4pub mod doc_comments;
5pub mod entrypoint;
6pub mod error_enum;
7pub mod instruction_data;
8pub mod program_id;
9pub mod seeds;
10pub mod types;
11pub mod validation;
12
13use std::collections::HashMap;
14use std::path::Path;
15
16use heck::ToSnakeCase;
17
18use crate::error::IdlError;
19use crate::ir::AccountIr;
20use crate::ir::DiscriminatorIr;
21use crate::ir::InstructionAccountIr;
22use crate::ir::InstructionIr;
23use crate::ir::PdaIr;
24use crate::ir::ProgramIr;
25
26pub fn parse_program(
28 program_path: &Path,
29 name_override: Option<&str>,
30) -> Result<ProgramIr, IdlError> {
31 let cargo_toml = program_path.join("Cargo.toml");
32 let cargo_contents =
33 std::fs::read_to_string(&cargo_toml).map_err(|e| IdlError::io(&cargo_toml, e))?;
34
35 let package_name =
36 extract_package_name(&cargo_contents).unwrap_or_else(|| "unknown_program".to_owned());
37
38 let src_path = program_path.join("src/lib.rs");
39 let source = std::fs::read_to_string(&src_path).map_err(|e| IdlError::io(&src_path, e))?;
40 let file = syn::parse_file(&source).map_err(|e| IdlError::parse(&src_path, &e))?;
41
42 assemble_program_ir(&file, name_override.unwrap_or(&package_name))
43}
44
45pub fn assemble_program_ir(file: &syn::File, program_name: &str) -> Result<ProgramIr, IdlError> {
47 let public_key = program_id::extract_program_id(file).ok_or(IdlError::NoProgramId)?;
49 let disc_enums = discriminator::extract_discriminator_enums(file);
50 let account_structs = account_state::extract_account_structs(file);
51 let instruction_structs = instruction_data::extract_instruction_structs(file);
52 let ix_accounts_structs = accounts_struct::extract_accounts_structs(file);
53 let errors = error_enum::extract_error_enums(file);
54 let dispatch = entrypoint::extract_dispatch_map(file);
55 let validation_props = validation::extract_validation_properties(file);
56 let seed_constants = seeds::extract_seed_constants(file);
57 let pdas_ir = seeds::extract_pda_from_seed_macros(file, &seed_constants);
58
59 let accounts: Vec<AccountIr> = account_structs
61 .iter()
62 .map(|acct| {
63 let disc_value =
64 find_discriminator_value(&disc_enums, &acct.discriminator_enum, &acct.name);
65 AccountIr {
66 name: acct.name.clone(),
67 fields: acct.fields.clone(),
68 discriminator: disc_value,
69 docs: acct.docs.clone(),
70 }
71 })
72 .collect();
73
74 let instructions: Vec<InstructionIr> = dispatch
77 .iter()
78 .filter_map(|entry| {
79 let ix_struct = instruction_structs
81 .iter()
82 .find(|ix| ix.variant == entry.variant)?;
83
84 let accts_struct = ix_accounts_structs
86 .iter()
87 .find(|a| a.name == entry.accounts_struct)?;
88
89 let val_props = validation_props.get(&entry.accounts_struct);
91
92 let disc_value = find_discriminator_value_by_variant(
94 &disc_enums,
95 &ix_struct.discriminator_enum,
96 &entry.variant,
97 );
98
99 let instruction_accounts: Vec<InstructionAccountIr> = accts_struct
101 .fields
102 .iter()
103 .map(|field| {
104 let props = val_props
105 .and_then(|m| m.get(&field.name))
106 .cloned()
107 .unwrap_or_default();
108
109 let pda_name = if props.is_pda {
111 infer_pda_name_for_field(&field.name, &pdas_ir)
112 } else {
113 None
114 };
115
116 InstructionAccountIr {
117 name: field.name.clone(),
118 is_writable: props.is_writable,
119 is_signer: props.is_signer,
120 is_optional: false,
121 default_value: props.default_value,
122 is_pda: props.is_pda,
123 pda_name,
124 docs: field.docs.clone(),
125 }
126 })
127 .collect();
128
129 let instruction_name = entry.variant.to_snake_case();
131
132 Some(InstructionIr {
133 name: instruction_name,
134 accounts: instruction_accounts,
135 arguments: ix_struct.fields.clone(),
136 discriminator: disc_value,
137 docs: ix_struct.docs.clone(),
138 })
139 })
140 .collect();
141
142 Ok(ProgramIr {
143 name: program_name.to_owned(),
144 public_key,
145 accounts,
146 instructions,
147 errors,
148 pdas: pdas_ir,
149 })
150}
151
152fn find_discriminator_value(
155 disc_enums: &[discriminator::DiscriminatorEnum],
156 enum_name: &str,
157 struct_name: &str,
158) -> DiscriminatorIr {
159 for disc in disc_enums {
160 if disc.name == enum_name {
161 for variant in &disc.variants {
162 if variant.name == struct_name {
163 return DiscriminatorIr {
164 value: variant.value,
165 repr_size: disc.repr_size,
166 };
167 }
168 }
169 }
170 }
171 DiscriminatorIr {
173 value: 0,
174 repr_size: 1,
175 }
176}
177
178fn find_discriminator_value_by_variant(
180 disc_enums: &[discriminator::DiscriminatorEnum],
181 enum_name: &str,
182 variant_name: &str,
183) -> DiscriminatorIr {
184 for disc in disc_enums {
185 if disc.name == enum_name {
186 for variant in &disc.variants {
187 if variant.name == variant_name {
188 return DiscriminatorIr {
189 value: variant.value,
190 repr_size: disc.repr_size,
191 };
192 }
193 }
194 }
195 }
196 DiscriminatorIr {
197 value: 0,
198 repr_size: 1,
199 }
200}
201
/// Extracts the `name` value from the `[package]` table of a Cargo manifest.
///
/// This is a deliberately small line scanner rather than a TOML parser. It
/// tracks which table the current line belongs to and returns the value of
/// the first `name = ...` key found inside `[package]`.
///
/// Handled details:
///
/// * the table header may carry inner whitespace or a trailing comment
///   (`[ package ]  # main crate`);
/// * only a key that is exactly `name` matches; keys that merely start with
///   `name` are skipped instead of aborting the scan;
/// * both basic (`"..."`) and literal (`'...'`) strings are unquoted, and
///   anything after the closing quote (such as an inline `# comment`) is
///   ignored.
///
/// Returns `None` when the manifest has no `[package]` table or that table
/// has no `name` key.
fn extract_package_name(cargo_contents: &str) -> Option<String> {
    let mut in_package = false;
    for line in cargo_contents.lines() {
        let trimmed = line.trim();

        // Any `[...]` header switches tables; `[[...]]` and dotted tables
        // such as `[package.metadata]` do not compare equal to "package".
        if let Some(header) = trimmed.strip_prefix('[') {
            let table = header.split(']').next().unwrap_or("").trim();
            in_package = table == "package";
            continue;
        }
        if !in_package {
            continue;
        }

        // Lines without `=` (blank lines, comments, array continuations) are not keys.
        let (key, value) = match trimmed.split_once('=') {
            Some(pair) => pair,
            None => continue,
        };
        if key.trim() != "name" {
            continue;
        }
        return Some(unquote_manifest_value(value).to_owned());
    }
    None
}

/// Returns the string content of the raw text to the right of `=`.
///
/// Quoted values yield the text between the opening quote and the next
/// matching quote (or the rest of the line when the quote is unterminated).
/// Unquoted values are returned with any trailing `# comment` removed.
fn unquote_manifest_value(raw: &str) -> &str {
    let raw = raw.trim();
    for quote in ['"', '\''] {
        if let Some(rest) = raw.strip_prefix(quote) {
            return match rest.find(quote) {
                Some(end) => &rest[..end],
                None => rest,
            };
        }
    }
    // Not valid TOML for a string, but stay lenient like a best-effort scanner should.
    raw.split('#').next().unwrap_or(raw).trim().trim_matches('"')
}
225
226pub fn build_discriminator_map(
229 disc_enums: &[discriminator::DiscriminatorEnum],
230) -> HashMap<(String, String), DiscriminatorIr> {
231 let mut map = HashMap::new();
232 for disc in disc_enums {
233 for variant in &disc.variants {
234 map.insert(
235 (disc.name.clone(), variant.name.clone()),
236 DiscriminatorIr {
237 value: variant.value,
238 repr_size: disc.repr_size,
239 },
240 );
241 }
242 }
243 map
244}
245
246fn infer_pda_name_for_field(field_name: &str, pdas: &[PdaIr]) -> Option<String> {
247 let candidates = [
248 field_name.to_owned(),
249 field_name.trim_end_matches("_account").to_owned(),
250 field_name.trim_end_matches("_pda").to_owned(),
251 ];
252
253 for candidate in candidates {
254 if candidate.is_empty() {
255 continue;
256 }
257
258 if let Some(pda) = pdas.iter().find(|p| p.name == candidate) {
259 return Some(pda.name.clone());
260 }
261 }
262
263 None
264}
265
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::PdaSeedIr;

    /// Builds a PDA whose single constant seed is the bytes of its own name.
    fn constant_pda(name: &str) -> PdaIr {
        PdaIr {
            name: name.to_owned(),
            seeds: vec![PdaSeedIr::Constant {
                value: name.as_bytes().to_vec(),
            }],
        }
    }

    #[test]
    fn infer_pda_name_for_field_matches_exact_name() {
        let pdas = [constant_pda("counter"), constant_pda("vault")];

        // (field name, expected PDA name): exact hit, `_account` suffix hit, and a miss.
        let cases = [
            ("counter", Some("counter")),
            ("vault_account", Some("vault")),
            ("unknown", None),
        ];

        for (field, expected) in cases {
            assert_eq!(
                infer_pda_name_for_field(field, &pdas),
                expected.map(str::to_owned)
            );
        }
    }
}