use super::{Analyzer, Finding, Severity, Certainty, Location};
use crate::models::Program;
use anyhow::Result;
use syn::{visit::Visit, ItemStruct, ItemFn, Type, ExprBinary, BinOp};
use quote::ToTokens;
use std::collections::HashMap;
pub struct DuplicateMutableAccounts;
impl Analyzer for DuplicateMutableAccounts {
fn name(&self) -> &'static str {
"Duplicate Mutable Accounts"
}
fn description(&self) -> &'static str {
"When there are two or more accounts with mutable data, a check must be in place to ensure \
mutation of each account is differentiated properly, to avoid unintended data modification of other accounts."
}
fn analyze(&self, program: &Program) -> Result<Vec<Finding>> {
let mut findings = Vec::new();
let mut account_structs = HashMap::new();
for (_path, ast) in &program.asts {
let mut collector = AccountStructCollector {
account_structs: &mut account_structs,
};
syn::visit::visit_file(&mut collector, ast);
}
for (path, ast) in &program.asts {
let mut visitor = DuplicateMutableAccountsVisitor {
checked_structs: HashMap::new(),
};
syn::visit::visit_file(&mut visitor, ast);
for (name, is_checked) in &visitor.checked_structs {
if !is_checked {
if let Some(item_struct) = account_structs.get(name) {
let span = item_struct.ident.span();
findings.push(Finding {
severity: Severity::Medium,
certainty: Certainty::Medium,
message: format!("Struct '{}' has multiple Account fields without constraints to prevent duplicate accounts", name),
location: Location {
file: path.to_string_lossy().to_string(),
line: span.start().line,
column: span.start().column,
},
});
}
}
}
}
Ok(findings)
}
}
/// First-pass visitor: records every struct that carries a `derive` attribute
/// whose tokens mention `Accounts`, keyed by the struct's identifier.
struct AccountStructCollector<'a> {
    /// Destination map; `insert` means a later struct with the same name
    /// replaces an earlier one.
    account_structs: &'a mut HashMap<String, ItemStruct>,
}

impl<'a, 'ast> Visit<'ast> for AccountStructCollector<'a> {
    fn visit_item_struct(&mut self, item_struct: &'ast ItemStruct) {
        let derives_accounts = item_struct.attrs.iter().any(|attr| {
            let is_derive = matches!(
                attr.path().segments.first(),
                Some(segment) if segment.ident == "derive"
            );
            is_derive && attr.to_token_stream().to_string().contains("Accounts")
        });
        if derives_accounts {
            self.account_structs
                .insert(item_struct.ident.to_string(), item_struct.clone());
        }
    }
}
struct DuplicateMutableAccountsVisitor {
checked_structs: HashMap<String, bool>,
}
impl<'ast> Visit<'ast> for DuplicateMutableAccountsVisitor {
fn visit_item_struct(&mut self, item_struct: &'ast ItemStruct) {
let has_accounts_derive = item_struct.attrs.iter().any(|attr| {
if let Some(path) = attr.path().segments.first() {
if path.ident == "derive" {
let tokens = attr.to_token_stream().to_string();
return tokens.contains("Accounts");
}
}
false
});
if has_accounts_derive {
let struct_name = item_struct.ident.to_string();
if let Some(true) = self.checked_structs.get(&struct_name) {
return;
}
let mut account_fields = Vec::new();
for field in &item_struct.fields {
if let Type::Path(type_path) = &field.ty {
if let Some(segment) = type_path.path.segments.first() {
if segment.ident == "Account" {
if let Some(name) = &field.ident {
account_fields.push(name.clone());
}
}
}
}
}
if account_fields.len() >= 2 {
let mut has_constraint = false;
for field in &item_struct.fields {
for attr in &field.attrs {
let attr_str = attr.to_token_stream().to_string();
if attr_str.contains("constraint") &&
(attr_str.contains("key()") || attr_str.contains("key ()")) &&
(attr_str.contains("!=") || attr_str.contains("==") ||
attr_str.contains("<") || attr_str.contains(">")) {
has_constraint = true;
break;
}
}
if has_constraint {
break;
}
}
if !has_constraint {
for attr in &item_struct.attrs {
let attr_str = attr.to_token_stream().to_string();
if attr_str.contains("constraint") &&
(attr_str.contains("key()") || attr_str.contains("key ()")) &&
(attr_str.contains("!=") || attr_str.contains("==") ||
attr_str.contains("<") || attr_str.contains(">")) {
has_constraint = true;
break;
}
}
}
if !has_constraint {
self.checked_structs.insert(struct_name, false);
} else {
self.checked_structs.insert(struct_name, true);
}
}
}
}
fn visit_expr_binary(&mut self, expr: &'ast ExprBinary) {
if matches!(expr.op, BinOp::Eq(_) | BinOp::Ne(_) | BinOp::Lt(_) | BinOp::Gt(_) | BinOp::Le(_) | BinOp::Ge(_)) {
let left_str = expr.left.to_token_stream().to_string();
let right_str = expr.right.to_token_stream().to_string();
if (left_str.contains(".key()") || left_str.contains(". key ()")) &&
(right_str.contains(".key()") || right_str.contains(". key ()")) {
if let Some(struct_name) = self.find_context_struct_for_expr(expr) {
self.checked_structs.insert(struct_name, true);
}
}
}
syn::visit::visit_expr_binary(self, expr);
}
fn visit_item_fn(&mut self, func: &'ast ItemFn) {
let mut context_struct = None;
for input in &func.sig.inputs {
if let syn::FnArg::Typed(pat_type) = input {
let ty_str = pat_type.ty.to_token_stream().to_string();
if ty_str.contains("Context") {
if let Some(start) = ty_str.find('<') {
if let Some(end) = ty_str.find('>') {
let struct_name = ty_str[start+1..end].trim().to_string();
context_struct = Some(struct_name);
}
}
}
}
}
if let Some(struct_name) = &context_struct {
CURRENT_CONTEXT_STRUCT.with(|cell| {
*cell.borrow_mut() = Some(struct_name.clone());
});
}
syn::visit::visit_block(self, &func.block);
CURRENT_CONTEXT_STRUCT.with(|cell| {
*cell.borrow_mut() = None;
});
}
fn visit_file(&mut self, file: &'ast syn::File) {
for item in &file.items {
match item {
syn::Item::Fn(item_fn) => self.visit_item_fn(item_fn),
syn::Item::Impl(item_impl) => {
for impl_item in &item_impl.items {
if let syn::ImplItem::Fn(impl_fn) = impl_item {
let mut context_struct = None;
for input in &impl_fn.sig.inputs {
if let syn::FnArg::Typed(pat_type) = input {
let ty_str = pat_type.ty.to_token_stream().to_string();
if ty_str.contains("Context") {
if let Some(start) = ty_str.find('<') {
if let Some(end) = ty_str.find('>') {
let struct_name = ty_str[start+1..end].trim().to_string();
context_struct = Some(struct_name);
}
}
}
}
}
if let Some(struct_name) = &context_struct {
CURRENT_CONTEXT_STRUCT.with(|cell| {
*cell.borrow_mut() = Some(struct_name.clone());
});
}
syn::visit::visit_block(self, &impl_fn.block);
CURRENT_CONTEXT_STRUCT.with(|cell| {
*cell.borrow_mut() = None;
});
}
}
},
_ => {}
}
}
for item in &file.items {
if let syn::Item::Struct(item_struct) = item {
self.visit_item_struct(item_struct);
}
}
}
}
thread_local! {
static CURRENT_CONTEXT_STRUCT: std::cell::RefCell<Option<String>> = std::cell::RefCell::new(None);
}
impl DuplicateMutableAccountsVisitor {
fn find_context_struct_for_expr(&self, _expr: &ExprBinary) -> Option<String> {
CURRENT_CONTEXT_STRUCT.with(|cell| {
cell.borrow().clone()
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::analyzers::test_utils::create_program;

    /// Builds a one-file program from `code` and returns the analyzer's findings.
    fn findings_for(code: &'static str) -> Vec<Finding> {
        let program = create_program(code);
        DuplicateMutableAccounts.analyze(&program).unwrap()
    }

    #[test]
    fn test_duplicate_mutable_accounts_vulnerable() {
        const SOURCE: &str = r#"
#[derive(Accounts)]
pub struct DuplicateMutableAccounts<'info> {
#[account(mut)]
pub account1: Account<'info, MyAccount>,
#[account(mut)]
pub account2: Account<'info, MyAccount>,
}
pub fn update(ctx: Context<DuplicateMutableAccounts>) -> Result<()> {
// No check for duplicate accounts
Ok(())
}
"#;
        let findings = findings_for(SOURCE);
        assert_eq!(findings.len(), 1);
        let expected = "Struct 'DuplicateMutableAccounts' has multiple Account fields without constraints";
        assert!(findings[0].message.contains(expected));
    }

    #[test]
    fn test_duplicate_mutable_accounts_secure_constraint() {
        const SOURCE: &str = r#"
#[derive(Accounts)]
pub struct SecureAccountsConstraint<'info> {
#[account(mut, constraint = account1.key() != account2.key())]
pub account1: Account<'info, MyAccount>,
#[account(mut)]
pub account2: Account<'info, MyAccount>,
}
pub fn update(ctx: Context<SecureAccountsConstraint>) -> Result<()> {
Ok(())
}
"#;
        let findings = findings_for(SOURCE);
        assert!(findings.is_empty());
    }

    #[test]
    fn test_duplicate_mutable_accounts_secure_manual() {
        const SOURCE: &str = r#"
#[derive(Accounts)]
pub struct SecureAccountsManual<'info> {
#[account(mut)]
pub account1: Account<'info, MyAccount>,
#[account(mut)]
pub account2: Account<'info, MyAccount>,
}
pub fn update(ctx: Context<SecureAccountsManual>) -> Result<()> {
if ctx.accounts.account1.key() == ctx.accounts.account2.key() {
return Err(ErrorCode::DuplicateAccounts.into());
}
Ok(())
}
"#;
        let findings = findings_for(SOURCE);
        assert!(findings.is_empty());
    }
}