use super::{Analyzer, Finding, Severity, Certainty, Location};
use crate::models::Program;
use anyhow::Result;
use syn::{visit::Visit, ItemStruct, ItemFn, Attribute};
use quote::ToTokens;
use std::collections::HashMap;
/// Detector for Anchor singleton accounts (PDAs derived from static seeds) whose
/// initialization does not appear to be gated by an authority check.
pub struct InitializationFrontrunning;

impl Analyzer for InitializationFrontrunning {
    fn name(&self) -> &'static str {
        "Initialization Frontrunning"
    }

    fn description(&self) -> &'static str {
        "Global accounts (singletons) initialized with static seeds should be protected by authority checks \
         to prevent front-running attacks where an attacker initializes the account with their own values."
    }

    /// Runs the detector in two passes over every parsed file of `program`.
    ///
    /// Pass one gathers, from all files, the `#[derive(Accounts)]` structs recorded by
    /// `SingletonStructVisitor`. Pass two hands that complete map to `InitializationVisitor`,
    /// which reports functions whose `Context<...>` names one of those structs. The map is
    /// finished before pass two begins so a handler can be matched against an accounts struct
    /// declared in any file.
    fn analyze(&self, program: &Program) -> Result<Vec<Finding>> {
        // Pass 1: collect candidate accounts structs from every file.
        let mut singleton_structs = HashMap::new();
        for (path, ast) in &program.asts {
            let mut collector = SingletonStructVisitor {
                vulnerable_structs: HashMap::new(),
                file_path: path.to_string_lossy().into_owned(),
            };
            collector.visit_file(ast);
            singleton_structs.extend(collector.vulnerable_structs);
        }

        // Pass 2: match handler functions against the collected structs.
        let mut findings = Vec::new();
        for (path, ast) in &program.asts {
            let mut checker = InitializationVisitor {
                findings: &mut findings,
                file_path: path.to_string_lossy().into_owned(),
                vulnerable_structs: &singleton_structs,
            };
            checker.visit_file(ast);
        }
        Ok(findings)
    }
}
#[derive(Clone)]
struct VulnerableStructInfo {
struct_name: String,
account_field: String,
location: Location,
}
struct SingletonStructVisitor {
vulnerable_structs: HashMap<String, VulnerableStructInfo>,
file_path: String,
}
impl<'ast> Visit<'ast> for SingletonStructVisitor {
fn visit_item_struct(&mut self, item_struct: &'ast ItemStruct) {
let is_accounts = item_struct.attrs.iter().any(|attr| {
let s = attr.to_token_stream().to_string();
s.contains("derive") && s.contains("Accounts")
});
if !is_accounts {
return;
}
let struct_name = item_struct.ident.to_string();
for field in &item_struct.fields {
let account_attr = field.attrs.iter().find(|attr| {
attr.path().is_ident("account")
});
if let Some(attr) = account_attr {
let attr_str = attr.to_token_stream().to_string();
if !attr_str.contains("init") {
continue;
}
if let Some(seeds) = parse_seeds_from_attr(attr) {
if is_static_seeds(&seeds) {
if !has_authority_constraint(&attr_str) {
let field_name = field.ident.as_ref().map(|i| i.to_string()).unwrap_or_default();
let span = field.ident.as_ref().map(|i| i.span()).unwrap_or_else(|| item_struct.ident.span());
self.vulnerable_structs.insert(struct_name.clone(), VulnerableStructInfo {
struct_name: struct_name.clone(),
account_field: field_name,
location: Location {
file: self.file_path.clone(),
line: span.start().line,
column: span.start().column,
},
});
}
}
}
}
}
}
}
/// Second-pass visitor: reports functions whose `Context<...>` accounts struct is in
/// `vulnerable_structs`, unless the function body appears to perform its own authority check.
struct InitializationVisitor<'a> {
    /// Findings are appended here.
    findings: &'a mut Vec<Finding>,
    /// Path of the file being visited. Not read by this visitor: findings are reported at the
    /// accounts field's stored location instead.
    file_path: String,
    /// Accounts-struct name -> its unprotected singleton field.
    vulnerable_structs: &'a HashMap<String, VulnerableStructInfo>,
}

/// Anchor check macros whose presence in a handler body counts as a manual validation.
const CHECK_MACROS: [&str; 5] = [
    "require",
    "require_eq",
    "require_neq",
    "require_keys_eq",
    "require_keys_neq",
];

/// The subset of `CHECK_MACROS` that performs an (in)equality comparison on its own.
const COMPARISON_MACROS: [&str; 4] = [
    "require_eq",
    "require_neq",
    "require_keys_eq",
    "require_keys_neq",
];

impl<'a, 'ast> Visit<'ast> for InitializationVisitor<'a> {
    fn visit_item_fn(&mut self, func: &'ast ItemFn) {
        // If several parameters are `Context<...>`, the last one wins.
        let context_struct = func
            .sig
            .inputs
            .iter()
            .filter_map(|input| match input {
                syn::FnArg::Typed(pat_type) => context_accounts_struct(&pat_type.ty),
                syn::FnArg::Receiver(_) => None,
            })
            .last();
        let struct_name = match context_struct {
            Some(name) => name,
            None => return,
        };
        let info = match self.vulnerable_structs.get(&struct_name) {
            Some(info) => info,
            None => return,
        };

        // Heuristic scan of the body for an authority check. Token streams render with spaces
        // between tokens (`require!(..)` becomes `require ! (..)`), so macros are matched in
        // both spellings.
        let body = func.block.to_token_stream().to_string();
        let has_manual_check = CHECK_MACROS.iter().any(|name| invokes_macro(&body, name))
            || (body.contains("if")
                && (body.contains("return Err") || body.contains("return err")));
        let has_comparison = body.contains("==")
            || body.contains("!=")
            || COMPARISON_MACROS.iter().any(|name| invokes_macro(&body, name));
        let mentions_authority = ["authority", "admin", "program_data", "owner"]
            .iter()
            .any(|keyword| body.contains(keyword));
        if has_manual_check && has_comparison && mentions_authority {
            return;
        }

        self.findings.push(Finding {
            severity: Severity::Medium,
            certainty: Certainty::Medium,
            message: format!(
                "Global account '{}' initialized in function '{}' using static seeds without apparent authority validation. \
                 Ensure only authorized users can initialize this singleton to prevent front-running.",
                info.account_field,
                func.sig.ident
            ),
            location: info.location.clone(),
        });
    }
}

/// Returns the accounts-struct name `T` from a parameter type `Context<..., T>`.
///
/// Lifetime arguments (`Context<'_, '_, '_, 'info, T<'info>>`) are skipped, generic arguments
/// on `T` itself are ignored, and a reference `&Context<T>` is looked through. Any other type
/// (including `CpiContext`) yields `None`.
fn context_accounts_struct(ty: &syn::Type) -> Option<String> {
    match ty {
        syn::Type::Reference(reference) => context_accounts_struct(&reference.elem),
        syn::Type::Path(type_path) => {
            let segment = type_path.path.segments.last()?;
            if segment.ident != "Context" {
                return None;
            }
            match &segment.arguments {
                syn::PathArguments::AngleBracketed(generics) => generics
                    .args
                    .iter()
                    .filter_map(|arg| match arg {
                        syn::GenericArgument::Type(syn::Type::Path(inner)) => inner
                            .path
                            .segments
                            .last()
                            .map(|last| last.ident.to_string()),
                        _ => None,
                    })
                    .last(),
                _ => None,
            }
        }
        _ => None,
    }
}

/// Returns `true` if the stringified token stream `body` invokes macro `name`. Both the spaced
/// `name ! (` rendering and the compact `name!` form are accepted.
fn invokes_macro(body: &str, name: &str) -> bool {
    body.contains(&format!("{} !", name)) || body.contains(&format!("{}!", name))
}
fn parse_seeds_from_attr(attr: &Attribute) -> Option<String> {
let attr_str = attr.to_token_stream().to_string();
if let Some(idx) = attr_str.find("seeds") {
if let Some(start) = attr_str[idx..].find('[') {
let start_pos = idx + start;
let mut depth = 0;
for (i, c) in attr_str[start_pos..].char_indices() {
if c == '[' { depth += 1; }
else if c == ']' {
depth -= 1;
if depth == 0 {
return Some(attr_str[start_pos..=start_pos+i].to_string());
}
}
}
}
}
None
}
/// Returns `true` when every component of a stringified seed array (e.g.
/// `[b"config" , 1]`) is a compile-time constant, i.e. the derived PDA is a singleton.
///
/// Components are split naively on `,`, so nested arrays are treated as dynamic. An empty
/// seed list counts as static.
fn is_static_seeds(seeds: &str) -> bool {
    let inner = seeds.trim_start_matches('[').trim_end_matches(']');
    inner
        .split(',')
        .map(str::trim)
        .filter(|part| !part.is_empty())
        .all(is_static_seed_component)
}

/// Classifies one seed expression as static. After stripping a leading `&`, it must be one
/// of the following:
/// - a string, byte-string or char literal (`"x"`, `b"x"`, `'x'`), possibly followed by
///   method calls such as `.as_ref()`;
/// - an integer literal, optionally type-suffixed (`1`, `1u8`), possibly followed by method
///   calls such as `.to_le_bytes()`;
/// - a path whose last segment is SCREAMING_SNAKE_CASE (`CONFIG_SEED`, `crate::ID`), which
///   by Rust naming convention denotes a `const`/`static`.
fn is_static_seed_component(part: &str) -> bool {
    let part = part.trim_start_matches('&').trim_start();
    if ["b\"", "b'", "\"", "'"].iter().any(|prefix| part.starts_with(prefix)) {
        return true;
    }
    // Receiver of any trailing method chain, e.g. `CONFIG_SEED` in `CONFIG_SEED . as_ref ()`.
    let head = part.split('.').next().unwrap_or("").trim();
    let is_int_literal = head.starts_with(|c: char| c.is_ascii_digit())
        && head.chars().all(|c| c.is_ascii_alphanumeric() || c == '_');
    let last_segment = head.rsplit("::").next().unwrap_or("").trim();
    let is_const_path = last_segment.starts_with(|c: char| c.is_ascii_uppercase())
        && last_segment
            .chars()
            .all(|c| c.is_ascii_uppercase() || c.is_ascii_digit() || c == '_');
    is_int_literal || is_const_path
}
/// Heuristically decides whether a stringified `#[account(...)]` attribute carries an
/// authority-style `constraint`.
///
/// Only the text from the first occurrence of `constraint` onward is examined. It qualifies
/// in either of two cases:
/// - it mentions `program_data`, `upgrade_authority` or `admin`;
/// - it mentions `signer` or `authority` together with an `==` / `!=` comparison.
fn has_authority_constraint(attr_str: &str) -> bool {
    let constraint_text = match attr_str.find("constraint") {
        Some(start) => &attr_str[start..],
        None => return false,
    };
    let compares = constraint_text.contains("==") || constraint_text.contains("!=");
    let names_privileged_source = ["program_data", "upgrade_authority", "admin"]
        .iter()
        .any(|keyword| constraint_text.contains(keyword));
    let names_checked_party = ["signer", "authority"]
        .iter()
        .any(|keyword| constraint_text.contains(keyword));
    names_privileged_source || (compares && names_checked_party)
}
#[cfg(test)]
mod tests {
    //! Fixture sources stay flush-left because raw-string contents are handed verbatim to the
    //! parser.
    use super::*;
    use crate::analyzers::test_utils::create_program;

    /// Builds a program from `code` and returns everything the analyzer reports for it.
    fn findings_for(code: &str) -> Vec<Finding> {
        let program = create_program(code);
        InitializationFrontrunning.analyze(&program).unwrap()
    }

    #[test]
    fn test_initialization_frontrunning_vulnerable() {
        let code = r#"
#[derive(Accounts)]
pub struct InitializeInsecure<'info> {
#[account(mut)]
pub signer: Signer<'info>,
#[account(
init,
payer = signer,
space = 8 + 8,
seeds = [b"config"],
bump
)]
pub global_config: Account<'info, GlobalConfig>,
pub system_program: Program<'info, System>,
}
pub fn initialize(ctx: Context<InitializeInsecure>) -> Result<()> {
Ok(())
}
"#;
        let findings = findings_for(code);
        assert_eq!(findings.len(), 1);
        assert!(findings[0].message.contains("Global account 'global_config' initialized in function 'initialize' using static seeds"));
    }

    #[test]
    fn test_initialization_frontrunning_secure_constraint() {
        let code = r#"
#[derive(Accounts)]
pub struct InitializeSecure<'info> {
#[account(mut)]
pub signer: Signer<'info>,
#[account(
init,
payer = signer,
space = 8 + 8,
seeds = [b"config"],
bump,
constraint = signer.key() == program_data.upgrade_authority_address.unwrap_or_default()
)]
pub global_config: Account<'info, GlobalConfig>,
#[account(seeds = [crate::ID.as_ref()], bump, seeds::program = bpf_loader_upgradeable::id())]
pub program_data: Account<'info, ProgramData>,
pub system_program: Program<'info, System>,
}
pub fn initialize(ctx: Context<InitializeSecure>) -> Result<()> {
Ok(())
}
"#;
        assert_eq!(findings_for(code).len(), 0);
    }

    #[test]
    fn test_initialization_frontrunning_dynamic_seeds() {
        let code = r#"
#[derive(Accounts)]
pub struct InitializeDynamic<'info> {
#[account(mut)]
pub signer: Signer<'info>,
#[account(
init,
payer = signer,
space = 8 + 8,
seeds = [b"user", signer.key().as_ref()],
bump
)]
pub user_account: Account<'info, UserAccount>,
pub system_program: Program<'info, System>,
}
pub fn initialize(ctx: Context<InitializeDynamic>) -> Result<()> {
Ok(())
}
"#;
        assert_eq!(findings_for(code).len(), 0);
    }

    #[test]
    fn test_initialization_frontrunning_numeric_seeds() {
        let code = r#"
#[derive(Accounts)]
pub struct InitializeNumeric<'info> {
#[account(mut)]
pub signer: Signer<'info>,
#[account(
init,
payer = signer,
space = 8 + 8,
seeds = [b"version", 1],
bump
)]
pub versioned_config: Account<'info, Config>,
pub system_program: Program<'info, System>,
}
pub fn initialize(ctx: Context<InitializeNumeric>) -> Result<()> {
Ok(())
}
"#;
        let findings = findings_for(code);
        assert_eq!(findings.len(), 1);
        assert!(findings[0].message.contains("Global account 'versioned_config'"));
    }
}