//! bat-cli 0.11.0
//!
//! Blockchain Auditor Toolkit (BAT)
use error_stack::{IntoReport, Report, Result};
use lazy_regex::regex;
use once_cell::sync::Lazy;
use regex::Regex;
use serde::{Deserialize, Serialize};

use crate::batbelt::parser::solana_account_parser::SolanaAccountType;
use crate::batbelt::parser::{ParserError, ParserResult};

/// Information extracted from the `#[account(...)]` attribute of a context
/// account entry. Every field keeps its "empty" value (`false`, `""`, `vec![]`)
/// when the entry has no `#[account(` attribute.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CAAccountAttributeInfo {
    /// `true` when at least one seed was found in a `seeds = [...]` array.
    pub is_pda: bool,
    /// `true` when the attribute contains `init` or `init_if_necessary`.
    pub is_init: bool,
    /// `true` when the attribute contains the `mut` constraint.
    pub is_mut: bool,
    /// `true` when the attribute contains a `close = <target>` constraint.
    pub is_close: bool,
    /// Value of `payer = ...` if present, otherwise the value of `close = ...`,
    /// otherwise an empty string.
    pub rent_exemption_account: String,
    /// Individual seed expressions found inside `seeds = [...]`.
    pub seeds: Vec<String>,
    /// Raw `constraint = ...`, `has_one = ...` and `address = ...` entries.
    pub validations: Vec<String>,
}

/// Information extracted from the field declaration line
/// (`pub <name>: <Wrapper><'lifetime, Struct>`) of a context account entry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CAAccountTypeInfo {
    /// The full context account snippet this information was parsed from.
    pub content: String,
    /// Account kind as resolved by `SolanaAccountType::from_context_account_content`.
    pub solana_account_type: SolanaAccountType,
    /// Inner type of the wrapper (e.g. `MyStruct` in `Account<'info, MyStruct>`);
    /// falls back to the wrapper name when the wrapper has no inner type.
    pub account_struct_name: String,
    /// Wrapper type name (e.g. `Account`, `Signer`); an outer `Box<...>` is unwrapped.
    pub account_wrapper_name: String,
    /// Lifetime used by the wrapper (e.g. `'info`).
    pub lifetime_name: String,
    /// Name of the field inside the context struct.
    pub account_name: String,
}

/// Flattened view of a single context account entry: the type information
/// (`CAAccountTypeInfo`) and the attribute information
/// (`CAAccountAttributeInfo`) merged into one serializable record.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CAAccountParser {
    /// The raw context account snippet (attribute plus field line).
    pub content: String,
    /// Account kind as resolved by `SolanaAccountType::from_context_account_content`.
    pub solana_account_type: SolanaAccountType,
    /// Inner type of the wrapper, or the wrapper name when there is none.
    pub account_struct_name: String,
    /// Wrapper type name (e.g. `Account`, `Signer`).
    pub account_wrapper_name: String,
    /// Lifetime used by the wrapper (e.g. `'info`).
    pub lifetime_name: String,
    /// Name of the field inside the context struct.
    pub account_name: String,
    /// `true` when at least one seed was found.
    pub is_pda: bool,
    /// `true` when the attribute contains `init` / `init_if_necessary`.
    pub is_init: bool,
    /// `true` when the attribute contains `mut`.
    pub is_mut: bool,
    /// `true` when the attribute contains `close = ...`.
    pub is_close: bool,
    /// Seed expressions found inside `seeds = [...]`.
    pub seeds: Vec<String>,
    /// `payer` value if present, otherwise the `close` value, otherwise empty.
    pub rent_exemption_account: String,
    /// Raw `constraint`, `has_one` and `address` entries.
    pub validations: Vec<String>,
    // The fields below default when missing from serialized data and are
    // initialised to `None` / `false` by `CAAccountParser::new`; nothing in
    // this file populates them from the account content.
    #[serde(default)]
    pub owner: Option<String>,
    #[serde(default)]
    pub token_mint: Option<String>,
    #[serde(default)]
    pub space: Option<String>,
    #[serde(default)]
    pub rent_exempt: bool,
    #[serde(default)]
    pub realloc: Option<String>,
    #[serde(default)]
    pub bump: Option<String>,
}

impl CAAccountParser {
    fn new(
        acc_type_info: CAAccountTypeInfo,
        acc_attribute: CAAccountAttributeInfo,
        content: &str,
    ) -> Self {
        Self {
            content: content.to_string(),
            solana_account_type: acc_type_info.solana_account_type,
            account_struct_name: acc_type_info.account_struct_name,
            account_wrapper_name: acc_type_info.account_wrapper_name,
            lifetime_name: acc_type_info.lifetime_name,
            account_name: acc_type_info.account_name,
            is_pda: acc_attribute.is_pda,
            is_init: acc_attribute.is_init,
            is_mut: acc_attribute.is_mut,
            is_close: acc_attribute.is_close,
            seeds: acc_attribute.seeds,
            rent_exemption_account: acc_attribute.rent_exemption_account,
            validations: acc_attribute.validations,
            owner: None,
            token_mint: None,
            space: None,
            rent_exempt: false,
            realloc: None,
            bump: None,
        }
    }

    /// Parses a context account snippet (optional `#[account(...)]` attribute
    /// plus the `pub name: Type` field line) into a `CAAccountParser`.
    ///
    /// # Errors
    ///
    /// Returns a `ParserError` report when the snippet does not match the
    /// context account regex, or when extracting the attribute or type
    /// information fails.
    pub fn new_from_context_account_content(
        context_account_content: &str,
    ) -> Result<Self, ParserError> {
        let looks_like_account =
            Self::get_context_account_lazy_regex().is_match(context_account_content);
        if !looks_like_account {
            let message = format!("Incorrect context account content\n{context_account_content}");
            return Err(Report::new(ParserError).attach_printable(message));
        }

        // Attributes are parsed before the type so the first failure reported
        // stays the same as it has always been.
        let attributes = Self::get_account_attribute_info(context_account_content)?;
        let type_info = Self::get_account_type_info(context_account_content)?;
        Ok(Self::new(type_info, attributes, context_account_content))
    }

    /// Returns the lazily compiled regex that recognises a context account
    /// entry: an optional, space-indented `#[account(...)]` attribute followed
    /// by a space-indented `pub <name>: <type>` field line, where the type may
    /// be nested in generic wrappers (e.g. `Box<Account<'info, Foo>>`).
    pub fn get_context_account_lazy_regex<'a>() -> &'a Lazy<Regex, fn() -> Regex> {
        regex!(
            r#"([ ]+#\[account\([\s\w,()?.= @:><!&{};\*\[\]+|]+\)\][\s]*)?[ ]+pub [\w]+: (\w+<)*([\w ,']+)(>)*"#
        )
    }

    pub fn get_account_type_info(
        context_account_content: &str,
    ) -> Result<CAAccountTypeInfo, ParserError> {
        let last_line = context_account_content
            .lines()
            .last()
            .unwrap()
            .trim()
            .trim_end_matches(',')
            .to_string();

        // Try syn-based extraction
        if let Some(type_info) =
            Self::try_syn_account_type_info(context_account_content, &last_line)?
        {
            return Ok(type_info);
        }

        // Fallback: string matching
        let account_name = last_line
            .trim()
            .trim_start_matches("pub ")
            .split(":")
            .next()
            .ok_or(ParserError)
            .into_report()?
            .to_string();

        let mut account_type_info = CAAccountTypeInfo {
            content: context_account_content.to_string(),
            solana_account_type: SolanaAccountType::from_context_account_content(
                context_account_content,
            )?,
            account_struct_name: "".to_string(),
            account_wrapper_name: "".to_string(),
            lifetime_name: "".to_string(),
            account_name: account_name.clone(),
        };

        account_type_info.account_wrapper_name = last_line
            .trim_start_matches(&format!("pub {}: ", account_name.clone()))
            .trim_start_matches("Box<")
            .split('<')
            .next()
            .unwrap()
            .to_string();

        let wrapper_content_regex = regex!(r"<[\w',_ ]+>");
        let wrapper_content = wrapper_content_regex
            .find(&last_line)
            .ok_or(ParserError)
            .into_report()?
            .as_str()
            .trim_start_matches("<")
            .trim_end_matches(">")
            .to_string();

        let (lifetime_name, account_struct_name) = if wrapper_content.contains(',') {
            let results = wrapper_content
                .split(',')
                .filter_map(|w_content| {
                    if w_content != "'_" {
                        Some(w_content.trim().to_string())
                    } else {
                        None
                    }
                })
                .collect::<Vec<_>>();
            (results[0].clone(), results[1].clone())
        } else {
            (
                wrapper_content.to_string(),
                account_type_info.account_wrapper_name.clone(),
            )
        };
        account_type_info.lifetime_name = lifetime_name;
        account_type_info.account_struct_name = account_struct_name;
        Ok(account_type_info)
    }

    /// Attempts to derive the account type information by letting `syn` parse
    /// the field line.
    ///
    /// Returns `Ok(None)` when the line cannot be parsed as a struct field or
    /// when no wrapper type can be identified, so the caller can fall back to
    /// string matching. Errors only come from resolving the Solana account
    /// type.
    fn try_syn_account_type_info(
        context_account_content: &str,
        last_line: &str,
    ) -> Result<Option<CAAccountTypeInfo>, ParserError> {
        // A lone field is not parseable on its own; embed it in a throwaway struct.
        let wrapped_field = format!("struct __Tmp {{ {last_line} }}");
        let parsed_struct = match syn::parse_str::<syn::ItemStruct>(&wrapped_field) {
            Ok(parsed) => parsed,
            Err(_) => return Ok(None),
        };
        let field = match parsed_struct.fields.iter().next() {
            Some(first_field) => first_field,
            None => return Ok(None),
        };
        let account_name = match &field.ident {
            Some(ident) => ident.to_string(),
            None => String::new(),
        };

        let solana_account_type =
            SolanaAccountType::from_context_account_content(context_account_content)?;

        // Walk the type tree for the wrapper, its lifetime and the inner struct.
        let (wrapper, lifetime, inner_struct) =
            Self::extract_type_parts(&field.ty, &account_name);

        match wrapper {
            None => Ok(None),
            Some(account_wrapper_name) => {
                let account_struct_name =
                    inner_struct.unwrap_or_else(|| account_wrapper_name.clone());
                Ok(Some(CAAccountTypeInfo {
                    content: context_account_content.to_string(),
                    solana_account_type,
                    account_struct_name,
                    account_wrapper_name,
                    lifetime_name: lifetime.unwrap_or_default(),
                    account_name,
                }))
            }
        }
    }

    /// Splits a `syn::Type` into `(wrapper_name, lifetime, struct_name)`.
    ///
    /// * `Account<'info, MyStruct>` gives `("Account", "'info", "MyStruct")`
    /// * `Box<Account<'info, MyStruct>>` gives the same (the `Box` is unwrapped)
    /// * `Signer<'info>` gives `("Signer", "'info", None)`
    /// * a bare identifier gives `(ident, None, None)`
    /// * anything that is not a path type gives `(None, None, None)`
    ///
    /// The anonymous lifetime `'_` is never reported; when several type
    /// arguments are present the last one wins.
    fn extract_type_parts(
        ty: &syn::Type,
        _account_name: &str,
    ) -> (Option<String>, Option<String>, Option<String>) {
        use quote::ToTokens;

        let syn::Type::Path(type_path) = ty else {
            return (None, None, None);
        };
        let Some(segment) = type_path.path.segments.last() else {
            return (None, None, None);
        };
        let ident = segment.ident.to_string();

        // Without angle brackets there is nothing more to extract.
        let syn::PathArguments::AngleBracketed(generics) = &segment.arguments else {
            return (Some(ident), None, None);
        };

        // `Box<T>` is transparent: describe the first boxed type instead.
        if ident == "Box" {
            let boxed = generics.args.iter().find_map(|arg| match arg {
                syn::GenericArgument::Type(inner_ty) => Some(inner_ty),
                _ => None,
            });
            if let Some(inner_ty) = boxed {
                return Self::extract_type_parts(inner_ty, _account_name);
            }
        }

        // First named (non-anonymous) lifetime argument.
        let lifetime = generics
            .args
            .iter()
            .filter_map(|arg| match arg {
                syn::GenericArgument::Lifetime(lt) => Some(lt.to_token_stream().to_string()),
                _ => None,
            })
            .find(|name| name.as_str() != "'_");

        // Last type argument, normalised the same way function signatures are.
        let last_type = generics
            .args
            .iter()
            .filter_map(|arg| match arg {
                syn::GenericArgument::Type(inner_ty) => Some(
                    crate::batbelt::parser::function_parser::normalize_generic_type(
                        &inner_ty.to_token_stream().to_string(),
                    ),
                ),
                _ => None,
            })
            .last();

        (Some(ident), lifetime, last_type)
    }

    /// Extracts everything this parser understands from the `#[account(...)]`
    /// attribute of a context account snippet.
    ///
    /// When the snippet has no `#[account(` attribute, all flags are `false`
    /// and all collections and strings are empty. An account is considered a
    /// PDA exactly when at least one seed was found.
    pub fn get_account_attribute_info(
        context_account_content: &str,
    ) -> Result<CAAccountAttributeInfo, ParserError> {
        if !context_account_content.contains("#[account(") {
            return Ok(CAAccountAttributeInfo {
                is_pda: false,
                is_init: false,
                is_mut: false,
                is_close: false,
                rent_exemption_account: String::new(),
                seeds: Vec::new(),
                validations: Vec::new(),
            });
        }

        let seeds = Self::get_seeds(context_account_content)?;
        let is_pda = !seeds.is_empty();
        let is_mut = Self::get_is_mut(context_account_content)?;
        let is_init = Self::get_is_init(context_account_content)?;
        let is_close = Self::get_is_close(context_account_content)?;
        let rent_exemption_account = Self::get_rent_exemption_account(context_account_content)?;
        let validations = Self::get_validations(context_account_content)?;

        Ok(CAAccountAttributeInfo {
            is_pda,
            is_init,
            is_mut,
            is_close,
            rent_exemption_account,
            seeds,
            validations,
        })
    }

    /// Reports whether the content carries a `close = <target>` constraint.
    fn get_is_close(sonar_result_content: &str) -> ParserResult<bool> {
        let has_close = regex!(r"(close = [\w_:]+)").is_match(sonar_result_content);
        Ok(has_close)
    }

    /// Reports whether the content carries the `mut` constraint.
    ///
    /// `mut` is recognised when it is delimited like a standalone constraint:
    /// preceded by `(`, `,` or whitespace and followed (after optional
    /// whitespace) by `,` or `)`. This covers `(mut)`, `(mut,`, ` mut,` and
    /// also `mut` as the last constraint without a trailing comma
    /// (`..., mut)` or `mut` on its own line before the closing `)`), which
    /// the previous patterns missed.
    fn get_is_mut(sonar_result_content: &str) -> ParserResult<bool> {
        let mut_regex = regex!(r#"[(\s,]mut\s*[,)]"#);
        Ok(mut_regex.is_match(sonar_result_content))
    }

    fn get_seeds(sonar_result_content: &str) -> Result<Vec<String>, ParserError> {
        let seeds_array_regex = regex!(r"seeds = \[\s?[\w()._?,&:\s]+\s?\]");
        let seeds_separator_regex = regex!(r"\s*[\w()._?&:]+");
        if !seeds_array_regex.is_match(sonar_result_content) {
            return Ok(vec![]);
        };
        let seeds_array = seeds_array_regex
            .find(sonar_result_content)
            .ok_or(ParserError)
            .into_report()?
            .as_str()
            .replace("seeds = ", "")
            .to_string();
        let seeds = seeds_separator_regex
            .find_iter(&seeds_array)
            .map(|seed| seed.as_str().trim().to_string())
            .collect::<Vec<_>>();
        Ok(seeds)
    }

    /// Reports whether the content carries the `init` or `init_if_necessary`
    /// constraint.
    ///
    /// The keyword must be delimited like a standalone constraint: at the
    /// start of the content or preceded by `(`, `,` or whitespace, and
    /// followed (after optional whitespace) by `,` or `)`. The previous
    /// pattern made the leading delimiter optional, so any identifier ending
    /// in `init,` (for example `constraint = state.is_init,`) was reported as
    /// an `init` constraint.
    fn get_is_init(sonar_result_content: &str) -> ParserResult<bool> {
        let init_regex = regex!(r#"(?:^|[(\s,])init(_if_necessary)?\s*[,)]"#);
        Ok(init_regex.is_match(sonar_result_content))
    }

    /// Collects every `constraint = ...`, `has_one = ...` and `address = ...`
    /// entry found in the content, trimmed of surrounding whitespace.
    ///
    /// A match can run into the `)` that closes the `#[account(...)]`
    /// attribute. Only closing parentheses that have no opening partner inside
    /// the entry are removed, so `constraint = a.key() == b.key()` keeps its
    /// own call parentheses (the previous implementation stripped every
    /// trailing `)` and produced `... b.key(`).
    fn get_validations(sonar_result_content: &str) -> ParserResult<Vec<String>> {
        /// Trims `raw` and drops trailing `)` characters while the text has
        /// more closing than opening parentheses.
        fn strip_unbalanced_closing_parens(raw: &str) -> &str {
            let mut text = raw.trim();
            let opening = text.matches('(').count();
            let mut closing = text.matches(')').count();
            while closing > opening {
                let Some(rest) = text.strip_suffix(')') else {
                    break;
                };
                text = rest.trim_end();
                closing -= 1;
            }
            text
        }

        let validation_regex =
            regex!(r#"(constraint|has_one|address) = [\w()?.= @:><!&{}\*\s;|]+\n?"#);
        let matches = validation_regex
            .find_iter(sonar_result_content)
            .map(|reg_match| strip_unbalanced_closing_parens(reg_match.as_str()).to_string())
            .collect::<Vec<_>>();
        // Only log when something was found, as before.
        if !matches.is_empty() {
            log::debug!("validation_matches:\n{matches:#?}");
        }
        Ok(matches)
    }

    /// Returns the account named by `payer = ...`, or, when there is no payer,
    /// the account named by `close = ...`. Yields an empty string when neither
    /// constraint is present.
    fn get_rent_exemption_account(sonar_result_content: &str) -> ParserResult<String> {
        let rent_exemption_payer_regex = regex!(r#"payer = [A-Za-z0-9_.]+"#);
        let rent_exemption_close_regex = regex!(r#"close = [A-Za-z0-9_.]+"#);

        // `payer` takes precedence; `close` is only consulted when it is absent.
        let account = rent_exemption_payer_regex
            .find(sonar_result_content)
            .or_else(|| rent_exemption_close_regex.find(sonar_result_content))
            .map(|constraint| {
                constraint
                    .as_str()
                    .trim()
                    .split(" = ")
                    .last()
                    .unwrap_or_default()
                    .to_string()
            })
            .unwrap_or_default();
        Ok(account)
    }
}