use crate::{account::AccountView, account_wrappers::Signer, error::ProgramError};
/// Upper bound on the index of a remaining account that strict-mode checks
/// will accept; indices at or beyond this value are reported as
/// `RemainingError::Overflow`.
pub const MAX_REMAINING_ACCOUNTS: usize = 64;
/// Validation failures raised while resolving remaining accounts.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RemainingError {
    /// The account's address repeats a declared account or an earlier
    /// remaining account.
    DuplicateAccount,
    /// An index or account count exceeded the permitted limit.
    Overflow,
}
/// Lowers a remaining-accounts validation failure into a `ProgramError`:
/// duplicates surface as `InvalidAccountData`, overflows as `InvalidArgument`.
impl From<RemainingError> for ProgramError {
    fn from(err: RemainingError) -> Self {
        match err {
            RemainingError::Overflow => Self::InvalidArgument,
            RemainingError::DuplicateAccount => Self::InvalidAccountData,
        }
    }
}
/// Validation policy used when handing out remaining accounts.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RemainingMode {
    /// Reject addresses that repeat a declared account or an earlier
    /// remaining account, and enforce `MAX_REMAINING_ACCOUNTS`.
    Strict,
    /// Skip the duplicate-address checks.
    Passthrough,
}
/// View over a set of extra ("remaining") accounts supplied alongside a set
/// of declared accounts.
///
/// `declared` is consulted only to detect duplicates; `remaining` holds the
/// accounts this view hands out. `mode` selects how much validation `get`,
/// `account_views`, `signers` and `iter` apply.
pub struct RemainingAccounts<'a> {
    declared: &'a [AccountView],
    remaining: &'a [AccountView],
    mode: RemainingMode,
}

impl<'a> RemainingAccounts<'a> {
    /// Creates a view in `Strict` mode: every account handed out by `get` is
    /// checked against the index cap and for duplicate addresses.
    #[inline(always)]
    pub fn strict(declared: &'a [AccountView], remaining: &'a [AccountView]) -> Self {
        Self {
            declared,
            remaining,
            mode: RemainingMode::Strict,
        }
    }

    /// Creates a view in `Passthrough` mode: `get` returns accounts without
    /// any checks.
    #[inline(always)]
    pub fn passthrough(declared: &'a [AccountView], remaining: &'a [AccountView]) -> Self {
        Self {
            declared,
            remaining,
            mode: RemainingMode::Passthrough,
        }
    }

    /// Number of remaining accounts (unvalidated).
    #[inline(always)]
    pub fn len(&self) -> usize {
        self.remaining.len()
    }

    /// `true` when there are no remaining accounts.
    #[inline(always)]
    pub fn is_empty(&self) -> bool {
        self.remaining.is_empty()
    }

    /// The validation mode this view was created with.
    #[inline(always)]
    pub fn mode(&self) -> RemainingMode {
        self.mode
    }

    /// All remaining accounts as a raw slice, bypassing every check
    /// regardless of mode.
    #[inline(always)]
    pub fn as_slice(&self) -> &'a [AccountView] {
        self.remaining
    }

    /// Returns the remaining account at `index`.
    ///
    /// Returns `Ok(None)` when `index` is past the end. In `Passthrough` mode
    /// the account is returned unchecked.
    ///
    /// # Errors
    /// In `Strict` mode only:
    /// - `RemainingError::Overflow` if `index >= MAX_REMAINING_ACCOUNTS`;
    /// - `RemainingError::DuplicateAccount` if the account's address matches
    ///   any declared account or any remaining account before `index`.
    pub fn get(&self, index: usize) -> Result<Option<&'a AccountView>, ProgramError> {
        let candidate = match self.remaining.get(index) {
            Some(candidate) => candidate,
            None => return Ok(None),
        };
        match self.mode {
            RemainingMode::Passthrough => Ok(Some(candidate)),
            RemainingMode::Strict => {
                if index >= MAX_REMAINING_ACCOUNTS {
                    return Err(RemainingError::Overflow.into());
                }
                let address = candidate.address();
                // Declared accounts first, then the remaining accounts that
                // precede `index`; any address match is a duplicate.
                let duplicated = self
                    .declared
                    .iter()
                    .chain(&self.remaining[..index])
                    .any(|other| other.address() == address);
                if duplicated {
                    Err(RemainingError::DuplicateAccount.into())
                } else {
                    Ok(Some(candidate))
                }
            }
        }
    }

    /// Validates every remaining account through [`Self::get`] and collects
    /// them into a fixed-capacity set of at most `N` entries.
    ///
    /// # Errors
    /// `RemainingError::Overflow` if there are more than `N` remaining
    /// accounts, plus the first error `get` reports for any account.
    pub fn account_views<const N: usize>(
        &self,
    ) -> Result<RemainingAccountViews<'a, N>, ProgramError> {
        let count = self.remaining.len();
        if count > N {
            return Err(RemainingError::Overflow.into());
        }
        let mut items: [Option<&'a AccountView>; N] = [None; N];
        for (index, slot) in items.iter_mut().take(count).enumerate() {
            // `get` only yields `Ok(None)` past the end, which `index < count`
            // rules out; the fallback error is defensive.
            let account = self.get(index)?.ok_or(ProgramError::NotEnoughAccountKeys)?;
            *slot = Some(account);
        }
        Ok(RemainingAccountViews { items, len: count })
    }

    /// Validates every remaining account through [`Self::get`], wraps each
    /// one with `Signer::try_new`, and collects them into a fixed-capacity
    /// set of at most `N` entries.
    ///
    /// # Errors
    /// `RemainingError::Overflow` if there are more than `N` remaining
    /// accounts, the first error `get` reports, or the first error
    /// `Signer::try_new` reports.
    pub fn signers<const N: usize>(&self) -> Result<RemainingSigners<'a, N>, ProgramError> {
        let count = self.remaining.len();
        if count > N {
            return Err(RemainingError::Overflow.into());
        }
        let mut items: [Option<Signer<'a>>; N] = [None; N];
        for (index, slot) in items.iter_mut().take(count).enumerate() {
            // See `account_views`: the `None` branch is unreachable here.
            let account = self.get(index)?.ok_or(ProgramError::NotEnoughAccountKeys)?;
            *slot = Some(Signer::try_new(account)?);
        }
        Ok(RemainingSigners { items, len: count })
    }

    /// Returns a lazily validating iterator over the remaining accounts,
    /// starting at index 0 and carrying this view's mode.
    #[inline(always)]
    pub fn iter(&self) -> RemainingIter<'a> {
        RemainingIter {
            declared: self.declared,
            remaining: self.remaining,
            mode: self.mode,
            index: 0,
        }
    }
}
/// Lazily validating iterator over remaining accounts, created by
/// `RemainingAccounts::iter`.
///
/// Yields one `Result` per remaining account, applying the same checks as
/// `RemainingAccounts::get`:
/// - `Passthrough`: every account is yielded unchecked.
/// - `Strict`: an account whose address repeats a declared account or an
///   earlier remaining account yields `RemainingError::DuplicateAccount`
///   (iteration continues); reaching index `MAX_REMAINING_ACCOUNTS` yields a
///   single `RemainingError::Overflow`, after which iteration ends.
pub struct RemainingIter<'a> {
    declared: &'a [AccountView],
    remaining: &'a [AccountView],
    mode: RemainingMode,
    index: usize,
}

impl<'a> Iterator for RemainingIter<'a> {
    type Item = Result<&'a AccountView, ProgramError>;

    fn next(&mut self) -> Option<Self::Item> {
        let remaining = self.remaining;
        let position = self.index;
        let candidate = remaining.get(position)?;

        if matches!(self.mode, RemainingMode::Passthrough) {
            // Passthrough applies no checks at all, including the index cap,
            // so indexing (`get`) and iterating agree.
            self.index = position + 1;
            return Some(Ok(candidate));
        }

        if position >= MAX_REMAINING_ACCOUNTS {
            // Report the overflow once, then exhaust the iterator.
            self.index = remaining.len();
            return Some(Err(RemainingError::Overflow.into()));
        }

        // `position < remaining.len()`, so this cannot overflow.
        self.index = position + 1;
        let address = candidate.address();
        let duplicated = self
            .declared
            .iter()
            .chain(&remaining[..position])
            .any(|other| other.address() == address);
        Some(if duplicated {
            Err(RemainingError::DuplicateAccount.into())
        } else {
            Ok(candidate)
        })
    }
}
/// Fixed-capacity set of validated remaining accounts, built by
/// `RemainingAccounts::account_views`. Only the first `len` slots of `items`
/// hold accounts.
pub struct RemainingAccountViews<'a, const N: usize> {
    items: [Option<&'a AccountView>; N],
    len: usize,
}

impl<'a, const N: usize> RemainingAccountViews<'a, N> {
    /// Number of accounts stored in the set.
    #[inline(always)]
    pub const fn len(&self) -> usize {
        self.len
    }

    /// `true` when the set holds no accounts.
    #[inline(always)]
    pub const fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// The account at `index`, or `None` when `index` is not below `len`.
    #[inline(always)]
    pub fn get(&self, index: usize) -> Option<&'a AccountView> {
        if index < self.len {
            self.items[index]
        } else {
            None
        }
    }

    /// Borrowing iterator over the stored accounts, in order.
    #[inline(always)]
    pub fn iter(&self) -> RemainingAccountViewIter<'_, 'a, N> {
        RemainingAccountViewIter { index: 0, set: self }
    }
}
/// Borrowing iterator over a `RemainingAccountViews` set; walks slots
/// `0..len` in order.
pub struct RemainingAccountViewIter<'set, 'a, const N: usize> {
    set: &'set RemainingAccountViews<'a, N>,
    index: usize,
}

impl<'a, const N: usize> Iterator for RemainingAccountViewIter<'_, 'a, N> {
    type Item = &'a AccountView;

    fn next(&mut self) -> Option<Self::Item> {
        let position = self.index;
        if position < self.set.len {
            self.index = position + 1;
            self.set.items[position]
        } else {
            None
        }
    }
}
/// Fixed-capacity set of remaining accounts wrapped as `Signer`s, built by
/// `RemainingAccounts::signers`. Only the first `len` slots of `items` hold
/// signers.
pub struct RemainingSigners<'a, const N: usize> {
    items: [Option<Signer<'a>>; N],
    len: usize,
}

impl<'a, const N: usize> RemainingSigners<'a, N> {
    /// Number of signers stored in the set.
    #[inline(always)]
    pub const fn len(&self) -> usize {
        self.len
    }

    /// `true` when the set holds no signers.
    #[inline(always)]
    pub const fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// The signer at `index`, or `None` when `index` is not below `len`.
    #[inline(always)]
    pub fn get(&self, index: usize) -> Option<Signer<'a>> {
        if index < self.len {
            self.items[index]
        } else {
            None
        }
    }

    /// Borrowing iterator over the stored signers, in order.
    #[inline(always)]
    pub fn iter(&self) -> RemainingSignerIter<'_, 'a, N> {
        RemainingSignerIter { index: 0, set: self }
    }
}
/// Borrowing iterator over a `RemainingSigners` set; walks slots `0..len`
/// in order.
pub struct RemainingSignerIter<'set, 'a, const N: usize> {
    set: &'set RemainingSigners<'a, N>,
    index: usize,
}

impl<'a, const N: usize> Iterator for RemainingSignerIter<'_, 'a, N> {
    type Item = Signer<'a>;

    fn next(&mut self) -> Option<Self::Item> {
        let position = self.index;
        if position < self.set.len {
            self.index = position + 1;
            self.set.items[position]
        } else {
            None
        }
    }
}
/// Free-function shorthand for [`RemainingAccounts::strict`].
#[inline(always)]
pub fn strict<'a>(
    declared: &'a [AccountView],
    remaining: &'a [AccountView],
) -> RemainingAccounts<'a> {
    RemainingAccounts::<'a>::strict(declared, remaining)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Each `RemainingError` variant converts to its expected `ProgramError`.
    #[test]
    fn error_variants_surface_as_program_error() {
        let expectations = [
            (RemainingError::DuplicateAccount, ProgramError::InvalidAccountData),
            (RemainingError::Overflow, ProgramError::InvalidArgument),
        ];
        for (source, expected) in expectations {
            assert_eq!(ProgramError::from(source), expected);
        }
    }

    /// Guards the remaining-account cap against accidental changes.
    #[test]
    fn max_remaining_matches_quasar() {
        assert_eq!(MAX_REMAINING_ACCOUNTS, 64);
    }
}