use core::marker::PhantomData;
use super::StorageMapper;
use crate::{
abi::{TypeAbi, TypeName},
api::{EndpointFinishApi, ManagedTypeApi, StorageMapperApi},
finish_all,
io::EndpointResult,
storage::{storage_get, storage_get_len, storage_set, StorageKey},
types::{ManagedAddress, ManagedType, ManagedVec, MultiResultVec},
};
const ADDRESS_TO_ID_SUFFIX: &[u8] = b"_address_to_id";
const ID_TO_ADDRESS_SUFFIX: &[u8] = b"_id_to_address";
const COUNT_SUFFIX: &[u8] = b"_count";
/// Storage mapper that assigns sequential numeric IDs to addresses and keeps
/// both directions of the mapping (address -> ID and ID -> address), plus a
/// running count of registered users.
///
/// All entries live under `base_key`, distinguished by fixed suffixes.
pub struct UserMapper<SA>
where
    SA: StorageMapperApi,
{
    /// Common prefix for every storage entry owned by this mapper.
    base_key: StorageKey<SA>,
    /// Ties the mapper to its API type without storing a value of it.
    _phantom_api: PhantomData<SA>,
}
impl<SA> StorageMapper<SA> for UserMapper<SA>
where
    SA: StorageMapperApi,
{
    /// Creates a mapper rooted at `base_key`; no storage is read or written.
    fn new(base_key: StorageKey<SA>) -> Self {
        Self {
            base_key,
            _phantom_api: PhantomData,
        }
    }
}
impl<SA> UserMapper<SA>
where
    SA: StorageMapperApi,
{
    /// Key of the entry holding the ID assigned to `address`
    /// (`base_key ++ "_address_to_id" ++ address`).
    fn get_user_id_key(&self, address: &ManagedAddress<SA>) -> StorageKey<SA> {
        let mut user_id_key = self.base_key.clone();
        user_id_key.append_bytes(ADDRESS_TO_ID_SUFFIX);
        user_id_key.append_item(address);
        user_id_key
    }

    /// Key of the entry holding the address registered under `id`
    /// (`base_key ++ "_id_to_address" ++ id`).
    fn get_user_address_key(&self, id: usize) -> StorageKey<SA> {
        let mut user_address_key = self.base_key.clone();
        user_address_key.append_bytes(ID_TO_ADDRESS_SUFFIX);
        user_address_key.append_item(&id);
        user_address_key
    }

    /// Key of the entry holding the number of registered users
    /// (`base_key ++ "_count"`).
    fn get_user_count_key(&self) -> StorageKey<SA> {
        let mut user_count_key = self.base_key.clone();
        user_count_key.append_bytes(COUNT_SUFFIX);
        user_count_key
    }

    /// Returns the ID assigned to `address`.
    ///
    /// A result of `0` means the address has not been registered: IDs handed
    /// out by this mapper start at 1.
    pub fn get_user_id(&self, address: &ManagedAddress<SA>) -> usize {
        storage_get(self.get_user_id_key(address).as_ref())
    }

    /// Writes the address -> ID direction of the mapping.
    fn set_user_id(&self, address: &ManagedAddress<SA>, id: usize) {
        storage_set(self.get_user_id_key(address).as_ref(), &id);
    }

    /// Returns the address registered under `id`, or `None` if the
    /// corresponding storage entry is empty.
    pub fn get_user_address(&self, id: usize) -> Option<ManagedAddress<SA>> {
        let key = self.get_user_address_key(id);
        if storage_get_len(key.as_ref()) > 0 {
            Some(storage_get(key.as_ref()))
        } else {
            None
        }
    }

    /// Returns the address registered under `id` without checking that the
    /// entry exists; the result for a missing entry is whatever `storage_get`
    /// decodes from empty storage.
    pub fn get_user_address_unchecked(&self, id: usize) -> ManagedAddress<SA> {
        storage_get(self.get_user_address_key(id).as_ref())
    }

    /// Returns the address registered under `id`, or the zero address if the
    /// entry is empty.
    pub fn get_user_address_or_zero(&self, id: usize) -> ManagedAddress<SA> {
        // Reuse the existence check of `get_user_address` instead of duplicating it.
        self.get_user_address(id)
            .unwrap_or_else(ManagedAddress::zero)
    }

    /// Writes the ID -> address direction of the mapping.
    fn set_user_address(&self, id: usize, address: &ManagedAddress<SA>) {
        storage_set(self.get_user_address_key(id).as_ref(), address);
    }

    /// Returns the number of registered users, which is also the highest
    /// ID assigned so far.
    pub fn get_user_count(&self) -> usize {
        storage_get(self.get_user_count_key().as_ref())
    }

    /// Overwrites the stored user count.
    fn set_user_count(&self, user_count: usize) {
        storage_set(self.get_user_count_key().as_ref(), &user_count);
    }

    /// Returns the ID of `address`, registering it first if it has none.
    ///
    /// A new user receives ID `count + 1`, and both directions of the mapping
    /// plus the count are updated.
    pub fn get_or_create_user(&self, address: &ManagedAddress<SA>) -> usize {
        let mut user_id = self.get_user_id(address);
        if user_id == 0 {
            let mut user_count = self.get_user_count();
            user_count += 1;
            self.set_user_count(user_count);
            user_id = user_count;
            self.set_user_id(address, user_id);
            self.set_user_address(user_id, address);
        }
        user_id
    }

    /// Batch version of [`Self::get_or_create_user`].
    ///
    /// For every address yielded by `address_iter`, calls
    /// `user_id_lambda(user_id, is_new)`. `is_new` is `true` when the address
    /// was registered by this call and `false` when it already had an ID.
    /// The user count is read once up front. It is written back only if at
    /// least one user was created, which avoids a redundant storage write.
    pub fn get_or_create_users<AddressIter, F>(
        &self,
        address_iter: AddressIter,
        mut user_id_lambda: F,
    ) where
        AddressIter: Iterator<Item = ManagedAddress<SA>>,
        F: FnMut(usize, bool),
    {
        let initial_user_count = self.get_user_count();
        let mut user_count = initial_user_count;
        for address in address_iter {
            let existing_id = self.get_user_id(&address);
            if existing_id > 0 {
                user_id_lambda(existing_id, false);
            } else {
                user_count += 1;
                let new_id = user_count;
                // The ID is persisted immediately, so a repeated address later
                // in the same iterator is reported as existing, not re-created.
                self.set_user_id(&address, new_id);
                self.set_user_address(new_id, &address);
                user_id_lambda(new_id, true);
            }
        }
        if user_count != initial_user_count {
            self.set_user_count(user_count);
        }
    }

    /// Collects the addresses for IDs `1..=count` in ID order.
    ///
    /// An ID whose address entry is empty contributes the zero address, so
    /// the result always has exactly `count` elements.
    pub fn get_all_addresses(&self) -> ManagedVec<SA, ManagedAddress<SA>> {
        let user_count = self.get_user_count();
        let mut result = ManagedVec::new();
        for id in 1..=user_count {
            result.push(self.get_user_address_or_zero(id));
        }
        result
    }
}
impl<SA> EndpointResult for UserMapper<SA>
where
    SA: StorageMapperApi,
{
    /// When returned from an endpoint, the mapper is serialized as a
    /// multi-result list of addresses.
    type DecodeAs = MultiResultVec<ManagedAddress<SA>>;

    /// Emits every registered address (ID order, zero address for gaps)
    /// as a separate endpoint result.
    fn finish<FA>(&self)
    where
        FA: ManagedTypeApi + EndpointFinishApi,
    {
        finish_all::<FA, _, _>(self.get_all_addresses().into_iter());
    }
}
impl<SA> TypeAbi for UserMapper<SA>
where
SA: StorageMapperApi,
{
fn type_name() -> TypeName {
crate::types::MultiResultVec::<ManagedAddress<SA>>::type_name()
}
fn is_multi_arg_or_result() -> bool {
true
}
}