hyperstack-macros 0.6.4

Proc-macros for defining HyperStack streams
Documentation
use std::collections::{HashMap, HashSet};

use syn::Error;

use crate::diagnostic::suggestion_or_available_suffix;

use super::idl::IdlSpec;
use super::pdas::{ParsedSeedKind, PdasBlock, ProgramPdas};

pub struct PdaValidationContext<'a> {
    pub idls: &'a HashMap<String, IdlSpec>,
}

impl<'a> PdaValidationContext<'a> {
    pub fn new(idls: &'a HashMap<String, IdlSpec>) -> Self {
        Self { idls }
    }

    pub fn validate(&self, block: &PdasBlock) -> Result<(), Error> {
        for program in &block.programs {
            self.validate_program(program)?;
        }
        Ok(())
    }

    fn validate_program(&self, program: &ProgramPdas) -> Result<(), Error> {
        let idl = self.idls.get(&program.program_name).ok_or_else(|| {
            let available: Vec<String> = self.idls.keys().cloned().collect();
            Error::new(
                program.program_name_span,
                format!(
                    "unknown program '{}' in pdas! block{}",
                    program.program_name,
                    suggestion_or_available_suffix(
                        &program.program_name,
                        &available,
                        "Available programs",
                    )
                ),
            )
        })?;

        let mut seen_names = HashSet::new();
        for pda in &program.pdas {
            if !seen_names.insert(&pda.name) {
                return Err(Error::new(
                    pda.name_span,
                    format!(
                        "duplicate PDA name '{}' in program '{}'",
                        pda.name, program.program_name
                    ),
                ));
            }

            for seed in &pda.seeds {
                self.validate_seed(seed, idl, &program.program_name)?;
            }
        }

        Ok(())
    }

    fn validate_seed(
        &self,
        seed: &super::pdas::ParsedSeed,
        idl: &IdlSpec,
        program_name: &str,
    ) -> Result<(), Error> {
        match &seed.kind {
            ParsedSeedKind::Literal(_) => Ok(()),

            ParsedSeedKind::Account(account_name) => {
                let account_exists = idl.instructions.iter().any(|ix| {
                    ix.accounts
                        .iter()
                        .any(|acc| &acc.name == account_name || acc.name == *account_name)
                });

                if !account_exists {
                    let available = self.collect_account_names(idl);
                    return Err(Error::new(
                        seed.span,
                        format!(
                            "unknown account '{}' in program '{}'{}",
                            account_name,
                            program_name,
                            suggestion_or_available_suffix(
                                account_name,
                                &available,
                                "Available accounts",
                            )
                        ),
                    ));
                }
                Ok(())
            }

            ParsedSeedKind::Arg { name, arg_type } => {
                let found_arg = idl
                    .instructions
                    .iter()
                    .find_map(|ix| ix.args.iter().find(|arg| &arg.name == name));

                match found_arg {
                    None => {
                        let available = self.collect_arg_names(idl);
                        Err(Error::new(
                            seed.span,
                            format!(
                                "unknown instruction argument '{}' in program '{}'{}",
                                name,
                                program_name,
                                suggestion_or_available_suffix(
                                    name,
                                    &available,
                                    "Available instruction arguments",
                                )
                            ),
                        ))
                    }
                    Some(idl_arg) => {
                        let actual_type = crate::parse::idl::to_rust_type_string(&idl_arg.type_);
                        if arg_type != &actual_type {
                            Err(Error::new(
                                seed.span,
                                format!(
                                    "invalid instruction argument type for '{}'. Expected '{}', found '{}'.",
                                    name, actual_type, arg_type
                                ),
                            ))
                        } else {
                            Ok(())
                        }
                    }
                }
            }
        }
    }

    fn collect_account_names(&self, idl: &IdlSpec) -> Vec<String> {
        let mut names: HashSet<String> = HashSet::new();
        for ix in &idl.instructions {
            for acc in &ix.accounts {
                names.insert(acc.name.clone());
            }
        }
        let mut result: Vec<_> = names.into_iter().collect();
        result.sort();
        result
    }

    fn collect_arg_names(&self, idl: &IdlSpec) -> Vec<String> {
        let mut names: HashSet<String> = HashSet::new();
        for ix in &idl.instructions {
            for arg in &ix.args {
                names.insert(arg.name.clone());
            }
        }
        let mut result: Vec<_> = names.into_iter().collect();
        result.sort();
        result
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn make_test_idl() -> IdlSpec {
        use crate::parse::idl::*;

        IdlSpec {
            name: Some("test".to_string()),
            address: Some("TestProgram111111111111111111111111111111".to_string()),
            version: Some("0.1.0".to_string()),
            metadata: None,
            accounts: vec![],
            constants: vec![],
            instructions: vec![IdlInstruction {
                name: "create".to_string(),
                discriminator: vec![0; 8],
                discriminant: None,
                docs: vec![],
                accounts: vec![
                    IdlAccountArg {
                        name: "authority".to_string(),
                        is_signer: true,
                        is_mut: false,
                        address: None,
                        pda: None,
                        optional: false,
                        docs: vec![],
                    },
                    IdlAccountArg {
                        name: "miner".to_string(),
                        is_signer: false,
                        is_mut: true,
                        address: None,
                        pda: None,
                        optional: false,
                        docs: vec![],
                    },
                ],
                args: vec![IdlField {
                    name: "roundId".to_string(),
                    type_: IdlType::Simple("u64".to_string()),
                }],
            }],
            types: vec![],
            events: vec![],
            errors: vec![],
        }
    }

    #[test]
    fn test_validate_valid_pda() {
        let idl = make_test_idl();
        let mut idls = HashMap::new();
        idls.insert("test".to_string(), idl);

        let ctx = PdaValidationContext::new(&idls);

        let tokens: proc_macro2::TokenStream = quote::quote! {
            test {
                miner_pda = [literal("miner"), account("authority")];
            }
        };
        let block: PdasBlock = syn::parse2(tokens).unwrap();

        assert!(ctx.validate(&block).is_ok());
    }

    #[test]
    fn test_validate_unknown_program() {
        let idls = HashMap::new();
        let ctx = PdaValidationContext::new(&idls);

        let tokens: proc_macro2::TokenStream = quote::quote! {
            unknown {
                pda = [literal("test")];
            }
        };
        let block: PdasBlock = syn::parse2(tokens).unwrap();

        let result = ctx.validate(&block);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("unknown program"));
    }

    #[test]
    fn test_validate_unknown_account() {
        let idl = make_test_idl();
        let mut idls = HashMap::new();
        idls.insert("test".to_string(), idl);

        let ctx = PdaValidationContext::new(&idls);

        let tokens: proc_macro2::TokenStream = quote::quote! {
            test {
                pda = [literal("test"), account("nonexistent")];
            }
        };
        let block: PdasBlock = syn::parse2(tokens).unwrap();

        let result = ctx.validate(&block);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("unknown account"));
    }

    #[test]
    fn test_validate_arg_type_mismatch() {
        let idl = make_test_idl();
        let mut idls = HashMap::new();
        idls.insert("test".to_string(), idl);

        let ctx = PdaValidationContext::new(&idls);

        let tokens: proc_macro2::TokenStream = quote::quote! {
            test {
                pda = [literal("round"), arg("roundId", u128)];
            }
        };
        let block: PdasBlock = syn::parse2(tokens).unwrap();

        let result = ctx.validate(&block);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("invalid instruction argument type"));
    }

    #[test]
    fn test_validate_duplicate_pda_name() {
        let idl = make_test_idl();
        let mut idls = HashMap::new();
        idls.insert("test".to_string(), idl);

        let ctx = PdaValidationContext::new(&idls);

        let tokens: proc_macro2::TokenStream = quote::quote! {
            test {
                miner = [literal("miner")];
                miner = [literal("other")];
            }
        };
        let block: PdasBlock = syn::parse2(tokens).unwrap();

        let result = ctx.validate(&block);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("duplicate"));
    }
}