Skip to main content

pina_cli/parse/
mod.rs

1pub mod account_state;
2pub mod accounts_struct;
3pub mod discriminator;
4pub mod doc_comments;
5pub mod entrypoint;
6pub mod error_enum;
7pub mod instruction_data;
8pub mod program_id;
9pub mod seeds;
10pub mod types;
11pub mod validation;
12
13use std::collections::HashMap;
14use std::path::Path;
15
16use heck::ToSnakeCase;
17
18use crate::error::IdlError;
19use crate::ir::AccountIr;
20use crate::ir::DiscriminatorIr;
21use crate::ir::InstructionAccountIr;
22use crate::ir::InstructionIr;
23use crate::ir::PdaIr;
24use crate::ir::ProgramIr;
25
26/// Parse a program crate directory and assemble a `ProgramIr`.
27pub fn parse_program(
28	program_path: &Path,
29	name_override: Option<&str>,
30) -> Result<ProgramIr, IdlError> {
31	let cargo_toml = program_path.join("Cargo.toml");
32	let cargo_contents =
33		std::fs::read_to_string(&cargo_toml).map_err(|e| IdlError::io(&cargo_toml, e))?;
34
35	let package_name =
36		extract_package_name(&cargo_contents).unwrap_or_else(|| "unknown_program".to_owned());
37
38	let src_path = program_path.join("src/lib.rs");
39	let source = std::fs::read_to_string(&src_path).map_err(|e| IdlError::io(&src_path, e))?;
40	let file = syn::parse_file(&source).map_err(|e| IdlError::parse(&src_path, &e))?;
41
42	assemble_program_ir(&file, name_override.unwrap_or(&package_name))
43}
44
45/// Assemble a `ProgramIr` from a parsed syn `File`.
46pub fn assemble_program_ir(file: &syn::File, program_name: &str) -> Result<ProgramIr, IdlError> {
47	// Step 1: Extract all pieces.
48	let public_key = program_id::extract_program_id(file).ok_or(IdlError::NoProgramId)?;
49	let disc_enums = discriminator::extract_discriminator_enums(file);
50	let account_structs = account_state::extract_account_structs(file);
51	let instruction_structs = instruction_data::extract_instruction_structs(file);
52	let ix_accounts_structs = accounts_struct::extract_accounts_structs(file);
53	let errors = error_enum::extract_error_enums(file);
54	let dispatch = entrypoint::extract_dispatch_map(file);
55	let validation_props = validation::extract_validation_properties(file);
56	let seed_constants = seeds::extract_seed_constants(file);
57	let pdas_ir = seeds::extract_pda_from_seed_macros(file, &seed_constants);
58
59	// Step 2: Build accounts IR.
60	let accounts: Vec<AccountIr> = account_structs
61		.iter()
62		.map(|acct| {
63			let disc_value =
64				find_discriminator_value(&disc_enums, &acct.discriminator_enum, &acct.name);
65			AccountIr {
66				name: acct.name.clone(),
67				fields: acct.fields.clone(),
68				discriminator: disc_value,
69				docs: acct.docs.clone(),
70			}
71		})
72		.collect();
73
74	// Step 3: Build instructions IR by connecting dispatch, accounts structs,
75	// instruction data, and validation properties.
76	let instructions: Vec<InstructionIr> = dispatch
77		.iter()
78		.filter_map(|entry| {
79			// Find the instruction data struct for this variant.
80			let ix_struct = instruction_structs
81				.iter()
82				.find(|ix| ix.variant == entry.variant)?;
83
84			// Find the accounts struct.
85			let accts_struct = ix_accounts_structs
86				.iter()
87				.find(|a| a.name == entry.accounts_struct)?;
88
89			// Find validation properties for this accounts struct.
90			let val_props = validation_props.get(&entry.accounts_struct);
91
92			// Find discriminator value.
93			let disc_value = find_discriminator_value_by_variant(
94				&disc_enums,
95				&ix_struct.discriminator_enum,
96				&entry.variant,
97			);
98
99			// Build instruction accounts with merged validation properties.
100			let instruction_accounts: Vec<InstructionAccountIr> = accts_struct
101				.fields
102				.iter()
103				.map(|field| {
104					let props = val_props
105						.and_then(|m| m.get(&field.name))
106						.cloned()
107						.unwrap_or_default();
108
109					// Check if this field is a PDA from the pdas we found.
110					let pda_name = if props.is_pda {
111						infer_pda_name_for_field(&field.name, &pdas_ir)
112					} else {
113						None
114					};
115
116					InstructionAccountIr {
117						name: field.name.clone(),
118						is_writable: props.is_writable,
119						is_signer: props.is_signer,
120						is_optional: false,
121						default_value: props.default_value,
122						is_pda: props.is_pda,
123						pda_name,
124						docs: field.docs.clone(),
125					}
126				})
127				.collect();
128
129			// Use the variant name (snake_case) as the instruction name.
130			let instruction_name = entry.variant.to_snake_case();
131
132			Some(InstructionIr {
133				name: instruction_name,
134				accounts: instruction_accounts,
135				arguments: ix_struct.fields.clone(),
136				discriminator: disc_value,
137				docs: ix_struct.docs.clone(),
138			})
139		})
140		.collect();
141
142	Ok(ProgramIr {
143		name: program_name.to_owned(),
144		public_key,
145		accounts,
146		instructions,
147		errors,
148		pdas: pdas_ir,
149	})
150}
151
152/// Find the discriminator value for an account struct by matching the struct
153/// name to a variant in the discriminator enum.
154fn find_discriminator_value(
155	disc_enums: &[discriminator::DiscriminatorEnum],
156	enum_name: &str,
157	struct_name: &str,
158) -> DiscriminatorIr {
159	for disc in disc_enums {
160		if disc.name == enum_name {
161			for variant in &disc.variants {
162				if variant.name == struct_name {
163					return DiscriminatorIr {
164						value: variant.value,
165						repr_size: disc.repr_size,
166					};
167				}
168			}
169		}
170	}
171	// Fallback
172	DiscriminatorIr {
173		value: 0,
174		repr_size: 1,
175	}
176}
177
178/// Find the discriminator value by variant name.
179fn find_discriminator_value_by_variant(
180	disc_enums: &[discriminator::DiscriminatorEnum],
181	enum_name: &str,
182	variant_name: &str,
183) -> DiscriminatorIr {
184	for disc in disc_enums {
185		if disc.name == enum_name {
186			for variant in &disc.variants {
187				if variant.name == variant_name {
188					return DiscriminatorIr {
189						value: variant.value,
190						repr_size: disc.repr_size,
191					};
192				}
193			}
194		}
195	}
196	DiscriminatorIr {
197		value: 0,
198		repr_size: 1,
199	}
200}
201
/// Extract the `name` value from the `[package]` table of a `Cargo.toml`.
///
/// This is a small line-based reader, not a full TOML parser. It understands:
///
/// - table headers with surrounding whitespace or a trailing `#` comment;
/// - basic (`"..."`) and literal (`'...'`) string values;
/// - trailing `#` comments after the value.
///
/// Keys that merely begin with `name` (for example `namespace`) are skipped
/// rather than aborting the search.
///
/// Returns `None` when there is no `[package]` table, when that table has no
/// `name` key, or when the name is empty.
fn extract_package_name(cargo_contents: &str) -> Option<String> {
	let mut in_package = false;
	for line in cargo_contents.lines() {
		let trimmed = line.trim();

		if trimmed.starts_with('[') {
			// Table header such as `[package]`, `[ package ] # note` or
			// `[[bin]]`; only the plain `package` table is of interest.
			let header = strip_toml_comment(trimmed).trim();
			in_package = header
				.strip_prefix('[')
				.and_then(|h| h.strip_suffix(']'))
				.map(str::trim)
				== Some("package");
			continue;
		}

		if !in_package {
			continue;
		}

		let Some((key, value)) = trimmed.split_once('=') else {
			continue;
		};
		// Accept the bare key `name` and its quoted forms; anything else is a
		// different key (or a comment line) and is skipped.
		if key.trim().trim_matches(|c| c == '"' || c == '\'') != "name" {
			continue;
		}

		let name = toml_string_value(value.trim());
		return (!name.is_empty()).then(|| name.to_owned());
	}
	None
}

/// Return `s` up to (not including) its first `#`, i.e. without a trailing
/// TOML comment.
fn strip_toml_comment(s: &str) -> &str {
	s.split_once('#').map_or(s, |(before, _)| before)
}

/// Decode a single-line TOML string value.
///
/// For a `"..."` or `'...'` string this returns the quoted contents. For an
/// unquoted value it returns the trimmed token before any comment.
fn toml_string_value(value: &str) -> &str {
	for quote in ['"', '\''] {
		if let Some(rest) = value.strip_prefix(quote) {
			// An unterminated string falls back to the rest of the line.
			return rest
				.split_once(quote)
				.map_or(rest.trim_end(), |(inner, _)| inner);
		}
	}
	strip_toml_comment(value).trim()
}
225
226/// Build a lookup from discriminator enum name + variant name → (value,
227/// `repr_size`).
228pub fn build_discriminator_map(
229	disc_enums: &[discriminator::DiscriminatorEnum],
230) -> HashMap<(String, String), DiscriminatorIr> {
231	let mut map = HashMap::new();
232	for disc in disc_enums {
233		for variant in &disc.variants {
234			map.insert(
235				(disc.name.clone(), variant.name.clone()),
236				DiscriminatorIr {
237					value: variant.value,
238					repr_size: disc.repr_size,
239				},
240			);
241		}
242	}
243	map
244}
245
246fn infer_pda_name_for_field(field_name: &str, pdas: &[PdaIr]) -> Option<String> {
247	let candidates = [
248		field_name.to_owned(),
249		field_name.trim_end_matches("_account").to_owned(),
250		field_name.trim_end_matches("_pda").to_owned(),
251	];
252
253	for candidate in candidates {
254		if candidate.is_empty() {
255			continue;
256		}
257
258		if let Some(pda) = pdas.iter().find(|p| p.name == candidate) {
259			return Some(pda.name.clone());
260		}
261	}
262
263	None
264}
265
#[cfg(test)]
mod tests {
	use super::*;
	use crate::ir::PdaSeedIr;

	/// Build a PDA with a single constant seed equal to its name.
	fn constant_pda(name: &str) -> PdaIr {
		PdaIr {
			name: name.to_owned(),
			seeds: vec![PdaSeedIr::Constant {
				value: name.as_bytes().to_vec(),
			}],
		}
	}

	#[test]
	fn infer_pda_name_for_field_matches_exact_name() {
		let pdas = vec![constant_pda("counter"), constant_pda("vault")];

		assert_eq!(
			infer_pda_name_for_field("counter", &pdas),
			Some("counter".to_owned())
		);
		assert_eq!(
			infer_pda_name_for_field("vault_account", &pdas),
			Some("vault".to_owned())
		);
		assert_eq!(infer_pda_name_for_field("unknown", &pdas), None);
	}

	#[test]
	fn infer_pda_name_for_field_strips_pda_suffix() {
		let pdas = vec![constant_pda("counter")];

		assert_eq!(
			infer_pda_name_for_field("counter_pda", &pdas),
			Some("counter".to_owned())
		);
	}

	#[test]
	fn infer_pda_name_for_field_without_pdas_is_none() {
		assert_eq!(infer_pda_name_for_field("counter", &[]), None);
	}

	#[test]
	fn extract_package_name_reads_package_table_only() {
		assert_eq!(
			extract_package_name("[package]\nname = \"my_program\"\nversion = \"0.1.0\"\n"),
			Some("my_program".to_owned())
		);
		assert_eq!(
			extract_package_name("[dependencies]\nname = \"not_the_package\"\n"),
			None
		);
		assert_eq!(extract_package_name("[workspace]\nmembers = []\n"), None);
	}
}