use crate::{account::AccountView, error::ProgramError};
/// Cap on how many remaining accounts are handed out.
///
/// Iteration reports `RemainingError::Overflow` once this many accounts have
/// been yielded and more are still left.
pub const MAX_REMAINING_ACCOUNTS: usize = 64;
/// Reasons a remaining account can be rejected.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RemainingError {
    /// The account's address repeats a declared account or an earlier remaining account.
    DuplicateAccount,
    /// More remaining accounts were supplied than `MAX_REMAINING_ACCOUNTS` allows.
    Overflow,
}
impl From<RemainingError> for ProgramError {
fn from(e: RemainingError) -> Self {
match e {
RemainingError::DuplicateAccount => ProgramError::InvalidAccountData,
RemainingError::Overflow => ProgramError::InvalidArgument,
}
}
}
/// Validation policy applied when remaining accounts are handed out.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RemainingMode {
    /// Reject an account whose address repeats a declared or earlier remaining account.
    Strict,
    /// Hand accounts out without duplicate-address checks.
    Passthrough,
}
/// Pairs a slice of "remaining" accounts with the "declared" accounts they are
/// checked against, plus the policy that decides whether checks run.
pub struct RemainingAccounts<'a> {
    /// Accounts already bound by the caller; read only to detect duplicate addresses.
    declared: &'a [&'a AccountView],
    /// The extra accounts exposed through `get` and `iter`.
    remaining: &'a [&'a AccountView],
    /// Whether duplicate checks run (`Strict`) or are skipped (`Passthrough`).
    mode: RemainingMode,
}
impl<'a> RemainingAccounts<'a> {
    /// Builds a view that validates every remaining account it hands out
    /// (see [`RemainingMode::Strict`]).
    #[inline(always)]
    pub fn strict(declared: &'a [&'a AccountView], remaining: &'a [&'a AccountView]) -> Self {
        Self { declared, remaining, mode: RemainingMode::Strict }
    }

    /// Builds a view that skips duplicate-address checks
    /// (see [`RemainingMode::Passthrough`]).
    #[inline(always)]
    pub fn passthrough(
        declared: &'a [&'a AccountView],
        remaining: &'a [&'a AccountView],
    ) -> Self {
        Self { declared, remaining, mode: RemainingMode::Passthrough }
    }

    /// Number of remaining accounts supplied. This count is not capped by
    /// `MAX_REMAINING_ACCOUNTS`.
    #[inline(always)]
    pub fn len(&self) -> usize {
        self.remaining.len()
    }

    /// `true` when no remaining accounts were supplied.
    #[inline(always)]
    pub fn is_empty(&self) -> bool {
        self.remaining.is_empty()
    }

    /// The validation policy this view was built with.
    #[inline(always)]
    pub fn mode(&self) -> RemainingMode {
        self.mode
    }

    /// Looks up the remaining account at `index`.
    ///
    /// Returns `Ok(None)` when `index` is past the end. Otherwise it returns
    /// `Ok(Some(account))` once the account passes the mode's checks.
    ///
    /// # Errors
    /// In strict mode only:
    /// - `RemainingError::Overflow` when `index >= MAX_REMAINING_ACCOUNTS`. This is
    ///   the same limit iteration enforces.
    /// - `RemainingError::DuplicateAccount` when the account's address matches a
    ///   declared account or a remaining account at a lower index.
    ///
    /// NOTE(review): passthrough mode applies no cap here, while `RemainingIter`
    /// caps in both modes. Confirm which behaviour is intended.
    pub fn get(&self, index: usize) -> Result<Option<&'a AccountView>, ProgramError> {
        let candidate = match self.remaining.get(index) {
            Some(&account) => account,
            None => return Ok(None),
        };
        match self.mode {
            RemainingMode::Passthrough => Ok(Some(candidate)),
            RemainingMode::Strict => {
                // `>=`, not `>`: index MAX_REMAINING_ACCOUNTS would be account number
                // MAX + 1, which `RemainingIter::next` already rejects.
                if index >= MAX_REMAINING_ACCOUNTS {
                    return Err(RemainingError::Overflow.into());
                }
                let duplicate = self
                    .declared
                    .iter()
                    .chain(&self.remaining[..index])
                    .any(|seen| seen.address() == candidate.address());
                if duplicate {
                    return Err(RemainingError::DuplicateAccount.into());
                }
                Ok(Some(candidate))
            }
        }
    }

    /// Iterates the remaining accounts in order.
    ///
    /// In strict mode each account gets the duplicate checks described on `get`.
    /// In both modes a single `Overflow` error ends iteration once
    /// `MAX_REMAINING_ACCOUNTS` accounts have been yielded.
    #[inline(always)]
    pub fn iter(&self) -> RemainingIter<'a> {
        RemainingIter {
            declared: self.declared,
            remaining: self.remaining,
            mode: self.mode,
            index: 0,
        }
    }
}
/// Cursor over a set of remaining accounts, produced by `RemainingAccounts::iter`.
pub struct RemainingIter<'a> {
    /// Accounts already bound by the caller; consulted for duplicate addresses in strict mode.
    declared: &'a [&'a AccountView],
    /// Accounts being walked.
    remaining: &'a [&'a AccountView],
    /// Policy copied from the originating `RemainingAccounts`.
    mode: RemainingMode,
    /// Position of the next account to yield.
    index: usize,
}
impl<'a> Iterator for RemainingIter<'a> {
    type Item = Result<&'a AccountView, ProgramError>;

    /// Yields the next remaining account, checking it first when in strict mode.
    ///
    /// If accounts are still left when the cursor reaches `MAX_REMAINING_ACCOUNTS`,
    /// one `Overflow` error is produced and the iterator is then exhausted.
    /// A `DuplicateAccount` error does not end iteration; later accounts are still
    /// visited.
    fn next(&mut self) -> Option<Self::Item> {
        let remaining = self.remaining;
        // Past the end: nothing more to yield.
        let candidate: &'a AccountView = *remaining.get(self.index)?;

        if self.index >= MAX_REMAINING_ACCOUNTS {
            // Jump the cursor to the end so every later call returns `None`.
            self.index = remaining.len();
            return Some(Err(RemainingError::Overflow.into()));
        }

        let position = self.index;
        self.index += 1;

        if self.mode == RemainingMode::Strict {
            let already_seen = self
                .declared
                .iter()
                .chain(&remaining[..position])
                .any(|seen| seen.address() == candidate.address());
            if already_seen {
                return Some(Err(RemainingError::DuplicateAccount.into()));
            }
        }

        Some(Ok(candidate))
    }
}
/// Free-function shorthand for [`RemainingAccounts::strict`].
///
/// Wraps `remaining` so that each account is checked against `declared`, and
/// against earlier remaining accounts, before it is handed out.
#[inline(always)]
pub fn strict<'a>(
    declared: &'a [&'a AccountView],
    remaining: &'a [&'a AccountView],
) -> RemainingAccounts<'a> {
    RemainingAccounts::<'a>::strict(declared, remaining)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn error_variants_surface_as_program_error() {
        // Each local error variant paired with the program error it must lower to.
        let cases = [
            (RemainingError::DuplicateAccount, ProgramError::InvalidAccountData),
            (RemainingError::Overflow, ProgramError::InvalidArgument),
        ];
        for (err, expected) in cases {
            assert_eq!(ProgramError::from(err), expected);
        }
    }

    #[test]
    fn max_remaining_matches_quasar() {
        // Per the test name, this limit is meant to mirror Quasar's value.
        assert_eq!(MAX_REMAINING_ACCOUNTS, 64);
    }
}