use borsh::{BorshDeserialize, BorshSchema, BorshSerialize};
use serde::Serialize;
use solana_program::pubkey::Pubkey;
use crate::error::LidoError;
use crate::util::serialize_b58;
#[derive(
Clone, Default, Debug, Eq, PartialEq, BorshSerialize, BorshDeserialize, BorshSchema, Serialize,
)]
pub struct PubkeyAndEntry<T> {
#[serde(serialize_with = "serialize_b58")]
pub pubkey: Pubkey,
pub entry: T,
}
#[derive(
Clone, Default, Debug, Eq, PartialEq, BorshSerialize, BorshDeserialize, BorshSchema, Serialize,
)]
pub struct AccountMap<T> {
pub entries: Vec<PubkeyAndEntry<T>>,
pub maximum_entries: u32,
}
pub trait EntryConstantSize {
const SIZE: usize;
}
pub type AccountSet = AccountMap<()>;
impl<T: Default + EntryConstantSize> AccountMap<T> {
pub fn new_fill_default(maximum_entries: u32) -> Self {
let mut v = Vec::with_capacity(maximum_entries as usize);
for _ in 0..maximum_entries {
v.push(PubkeyAndEntry {
pubkey: Pubkey::default(),
entry: T::default(),
});
}
AccountMap {
entries: v,
maximum_entries,
}
}
pub fn new(maximum_entries: u32) -> Self {
AccountMap {
entries: Vec::new(),
maximum_entries,
}
}
pub fn len(&self) -> usize {
self.entries.len()
}
pub fn is_empty(&self) -> bool {
self.entries.is_empty()
}
pub fn add(&mut self, address: Pubkey, value: T) -> Result<(), LidoError> {
if self.len() == self.maximum_entries as usize {
return Err(LidoError::MaximumNumberOfAccountsExceeded);
}
if !self.entries.iter().any(|pe| pe.pubkey == address) {
self.entries.push(PubkeyAndEntry {
pubkey: address,
entry: value,
});
} else {
return Err(LidoError::DuplicatedEntry);
}
Ok(())
}
pub fn remove(&mut self, address: &Pubkey) -> Result<T, LidoError> {
let idx = self
.entries
.iter()
.position(|pe| &pe.pubkey == address)
.ok_or(LidoError::InvalidAccountMember)?;
Ok(self.entries.swap_remove(idx).entry)
}
pub fn get(&self, address: &Pubkey) -> Result<&PubkeyAndEntry<T>, LidoError> {
self.entries
.iter()
.find(|pe| &pe.pubkey == address)
.ok_or(LidoError::InvalidAccountMember)
}
pub fn get_mut(&mut self, address: &Pubkey) -> Result<&mut PubkeyAndEntry<T>, LidoError> {
self.entries
.iter_mut()
.find(|pe| &pe.pubkey == address)
.ok_or(LidoError::InvalidAccountMember)
}
pub fn required_bytes(max_entries: usize) -> usize {
let key_size = std::mem::size_of::<Pubkey>();
let value_size = T::SIZE;
let entry_size = key_size + value_size;
8 + entry_size * max_entries as usize
}
pub fn maximum_entries(buffer_size: usize) -> usize {
let key_size = std::mem::size_of::<Pubkey>();
let value_size = T::SIZE;
let entry_size = key_size + value_size;
buffer_size.saturating_sub(8) / entry_size
}
pub fn iter_entries(&self) -> IterEntries<T> {
IterEntries {
iter: self.entries.iter(),
}
}
pub fn iter_entries_mut(&mut self) -> IterEntriesMut<T> {
IterEntriesMut {
iter: self.entries.iter_mut(),
}
}
}
pub struct IterEntries<'a, T: 'a> {
iter: std::slice::Iter<'a, PubkeyAndEntry<T>>,
}
impl<'a, T: 'a> std::iter::Iterator for IterEntries<'a, T> {
type Item = &'a T;
fn next(&mut self) -> Option<&'a T> {
self.iter.next().map(|pubkey_entry| &pubkey_entry.entry)
}
}
pub struct IterEntriesMut<'a, T: 'a> {
iter: std::slice::IterMut<'a, PubkeyAndEntry<T>>,
}
impl<'a, T: 'a> std::iter::Iterator for IterEntriesMut<'a, T> {
type Item = &'a mut T;
fn next(&mut self) -> Option<&'a mut T> {
self.iter.next().map(|pubkey_entry| &mut pubkey_entry.entry)
}
}
#[cfg(test)]
mod test {
use super::*;
impl EntryConstantSize for u32 {
const SIZE: usize = 4;
}
#[test]
fn test_account_map_limit() {
let mut map = AccountMap::new(1);
let result_0 = map.add(Pubkey::new_unique(), 0_u32);
let result_1 = map.add(Pubkey::new_unique(), 1_u32);
assert_eq!(result_0, Ok(()));
assert_eq!(result_1, Err(LidoError::MaximumNumberOfAccountsExceeded));
}
#[test]
fn test_account_map_duplicate() {
let mut map = AccountMap::new(2);
let key = Pubkey::new_unique();
let result_0 = map.add(key, 0_u32);
let result_1 = map.add(key, 1_u32);
assert_eq!(result_0, Ok(()));
assert_eq!(result_1, Err(LidoError::DuplicatedEntry));
}
#[test]
fn test_account_map_add_remove() {
let mut map = AccountMap::new(1);
let key = Pubkey::new_unique();
map.add(key, 0_u32).unwrap();
assert_eq!(map.get(&key).map(|pe| pe.entry), Ok(0));
assert_eq!(map.get_mut(&key).map(|pe| pe.entry), Ok(0));
assert_eq!(map.remove(&key), Ok(0));
assert_eq!(map.get(&key), Err(LidoError::InvalidAccountMember));
assert_eq!(map.get_mut(&key), Err(LidoError::InvalidAccountMember));
assert_eq!(map.remove(&key), Err(LidoError::InvalidAccountMember));
}
#[test]
fn test_account_map_iter_entries() {
let mut map: AccountMap<u32> = AccountMap::new(2);
map.add(Pubkey::new_unique(), 0).unwrap();
map.add(Pubkey::new_unique(), 1).unwrap();
for entry in map.iter_entries_mut() {
*entry = 2;
}
for entry in map.iter_entries() {
assert_eq!(*entry, 2);
}
}
}