use crate::data::{
account::UserAccountArgs,
user::{UserData, UserDataArgs},
};
use crate::error::UserStateError;
use std::collections::{BTreeMap, HashMap};
#[cfg(test)]
use crate::mocks::ic_caller;
#[cfg(not(test))]
use ic_cdk::caller as ic_caller;
use ic_cdk::export::{
candid::{CandidType, Principal},
serde::Deserialize,
};
/// Tunable configuration stored alongside [`UserState`].
///
/// None of these fields are read by the visible `UserState` methods; their
/// meaning below is inferred from names only.
#[derive(Debug, CandidType, Deserialize, Default, Clone)]
pub struct UserStateConfig {
    /// Optional key name — presumably the signing/derivation key identifier;
    /// TODO confirm against the code that consumes it.
    pub key_name: Option<String>,
    /// NOTE(review): by name, a minimum cycles threshold — confirm units/usage.
    pub min_cycles_required: u64,
    /// NOTE(review): by name, a per-user cycles cap — confirm units/usage.
    pub max_cycles_per_user: u64,
    /// Free-form string key/value settings.
    pub general_settings: HashMap<String, String>,
}
/// Canister state holding every registered user and the principals allowed
/// to administer them.
#[derive(CandidType, Deserialize, Debug)]
pub struct UserState {
    /// Principal allowed to call `change_owner` / `change_wallet_canister`.
    pub owner: Principal,
    /// Principal allowed to create and mutate users (and to read any user).
    pub wallet_canister: Principal,
    /// Auxiliary configuration; not consulted by the access checks.
    pub config: UserStateConfig,
    /// Per-user data keyed by the user's principal.
    pub users: BTreeMap<Principal, UserData>,
}
impl UserState {
    /// Builds an empty state owned by the principal that is calling at
    /// construction time.
    ///
    /// The wallet canister starts as [`Principal::anonymous`] and must be set
    /// with [`UserState::init`] (or [`UserState::change_wallet_canister`])
    /// before any wallet-canister-gated method can succeed.
    // Kept as an inherent fn named `default` so existing callers keep
    // compiling; it reads the ambient caller, unlike a typical `Default`.
    #[allow(clippy::should_implement_trait)]
    pub fn default() -> Self {
        Self {
            users: BTreeMap::default(),
            owner: ic_caller(),
            config: UserStateConfig::default(),
            wallet_canister: Principal::anonymous(),
        }
    }

    /// Registers the wallet canister principal. Performs no caller check.
    pub fn init(&mut self, wallet_canister: Principal) {
        self.wallet_canister = wallet_canister;
    }

    /// True when the current caller is the configured wallet canister.
    fn is_caller_wallet_canister(&self) -> bool {
        let caller = ic_caller();
        // `wallet_canister` is the anonymous principal until `init` runs, so an
        // unauthenticated (anonymous) caller must never be treated as it.
        caller != Principal::anonymous() && caller == self.wallet_canister
    }

    /// True when the current caller is the owner.
    fn is_caller_owner(&self) -> bool {
        ic_caller() == self.owner
    }

    /// True when the current caller is `user` itself.
    fn is_caller_user(&self, user: &Principal) -> bool {
        ic_caller() == *user
    }

    /// Ensures the caller is the owner.
    ///
    /// # Errors
    /// [`UserStateError::CallerIsNotOwner`] otherwise.
    pub fn validate_caller_owner(&self) -> Result<(), UserStateError> {
        if self.is_caller_owner() {
            Ok(())
        } else {
            Err(UserStateError::CallerIsNotOwner)
        }
    }

    /// Ensures the caller is the wallet canister.
    ///
    /// # Errors
    /// [`UserStateError::CallerIsNotWalletCanister`] otherwise.
    pub fn validate_caller_wallet_canister(&self) -> Result<(), UserStateError> {
        if self.is_caller_wallet_canister() {
            Ok(())
        } else {
            Err(UserStateError::CallerIsNotWalletCanister)
        }
    }

    /// Ensures the caller is either the wallet canister or `user` itself.
    ///
    /// # Errors
    /// [`UserStateError::CallerNotAuthorized`] when it is neither.
    pub fn validate_caller_wallet_canister_or_user(
        &self,
        user: &Principal,
    ) -> Result<(), UserStateError> {
        if self.is_caller_wallet_canister() || self.is_caller_user(user) {
            Ok(())
        } else {
            Err(UserStateError::CallerNotAuthorized)
        }
    }

    /// Transfers ownership to `new_owner` and returns it.
    ///
    /// # Errors
    /// [`UserStateError::CallerIsNotOwner`] if the caller is not the owner.
    pub fn change_owner(&mut self, new_owner: Principal) -> Result<Principal, UserStateError> {
        self.validate_caller_owner()?;
        self.owner = new_owner;
        Ok(new_owner)
    }

    /// Replaces the wallet canister principal and returns the new value.
    ///
    /// # Errors
    /// [`UserStateError::CallerIsNotOwner`] if the caller is not the owner.
    pub fn change_wallet_canister(
        &mut self,
        new_wallet_canister: Principal,
    ) -> Result<Principal, UserStateError> {
        self.validate_caller_owner()?;
        self.wallet_canister = new_wallet_canister;
        Ok(new_wallet_canister)
    }

    /// Creates a new user record and returns a copy of it.
    ///
    /// # Errors
    /// - [`UserStateError::CallerIsNotWalletCanister`] if the caller is not the
    ///   wallet canister.
    /// - [`UserStateError::UserAlreadyExists`] if `user` is already registered.
    pub fn create_user(
        &mut self,
        user: Principal,
        user_args: UserDataArgs,
        account_args: UserAccountArgs,
    ) -> Result<UserData, UserStateError> {
        self.validate_caller_wallet_canister()?;
        if self.users.contains_key(&user) {
            return Err(UserStateError::UserAlreadyExists);
        }
        let user_data = UserData::new(user_args, account_args);
        // Return our own copy instead of re-looking-up (and unwrapping) the
        // entry we just inserted.
        self.users.insert(user, user_data.clone());
        Ok(user_data)
    }

    /// Mutable access to a user's data.
    ///
    /// # Errors
    /// - [`UserStateError::CallerIsNotWalletCanister`] if the caller is not the
    ///   wallet canister.
    /// - [`UserStateError::UserNotFound`] if `user` is not registered.
    pub fn get_user_mut(&mut self, user: &Principal) -> Result<&mut UserData, UserStateError> {
        self.validate_caller_wallet_canister()?;
        self.users.get_mut(user).ok_or(UserStateError::UserNotFound)
    }

    /// Shared access to a user's data.
    ///
    /// # Errors
    /// - [`UserStateError::CallerNotAuthorized`] unless the caller is the wallet
    ///   canister or `user` itself.
    /// - [`UserStateError::UserNotFound`] if `user` is not registered.
    pub fn get_user(&self, user: &Principal) -> Result<&UserData, UserStateError> {
        self.validate_caller_wallet_canister_or_user(user)?;
        self.users.get(user).ok_or(UserStateError::UserNotFound)
    }

    /// Derivation path for `user` and `key`, as produced by
    /// `UserData::get_derivation_path`.
    ///
    /// # Errors
    /// - [`UserStateError::CallerNotAuthorized`] unless the caller is the wallet
    ///   canister or `user` itself.
    /// - [`UserStateError::UserNotFound`] if `user` is not registered.
    /// - Any error from `UserData::get_derivation_path`, converted via `?`.
    pub fn get_user_derivation_path(
        &self,
        user: &Principal,
        key: u8,
    ) -> Result<Vec<u8>, UserStateError> {
        self.validate_caller_wallet_canister_or_user(user)?;
        let user_data = self.users.get(user).ok_or(UserStateError::UserNotFound)?;
        let derivation_path = user_data.get_derivation_path(*user, key)?;
        Ok(derivation_path)
    }
}
#[cfg(test)]
mod tests {
    use crate::{mocks::*, state::UserState};
    use super::*;
    use proptest::prelude::*;

    /// Builds a state owned by the mock owner principal, with the mock wallet
    /// canister registered through `init`.
    fn initialize_state() -> UserState {
        owner_caller();
        let mut state = UserState::default();
        state.init(wallet_canister_principal());
        state
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(5))]

        // Only the owner passes `validate_caller_owner`; the owner can hand
        // ownership to any principal.
        #[test]
        fn test_user_state_validate_caller_owner(principal in principal_strategy()) {
            let mut user_state = initialize_state();
            owner_caller();
            assert!(user_state.validate_caller_owner().is_ok());
            random_caller();
            assert!(user_state.validate_caller_owner().is_err());
            owner_caller();
            let new_owner = user_state.change_owner(principal).unwrap();
            assert_eq!(user_state.owner, principal);
            assert_eq!(user_state.owner, new_owner);
        }

        // Only the wallet canister passes `validate_caller_wallet_canister` and
        // may use `get_user_mut`.
        #[test]
        fn test_user_state_validate_caller_wallet_canister(_ in ".*") {
            let mut user_state = initialize_state();
            let principal = wallet_canister_principal();
            random_caller();
            user_state.init(principal);
            assert!(user_state.validate_caller_wallet_canister().is_err());
            wallet_canister_caller();
            assert!(user_state.validate_caller_wallet_canister().is_ok());
            let random_data = UserDataArgs::default();
            let random_account = UserAccountArgs::default();
            assert!(user_state
                .create_user(principal, random_data, random_account)
                .is_ok());
            random_caller();
            assert!(user_state.get_user_mut(&principal).is_err());
        }

        // `get_user` is open to the wallet canister and to the user itself,
        // but not to unrelated callers (including the owner).
        #[test]
        fn test_user_state_validate_caller_wallet_canister_or_user(principal in principal_strategy()) {
            let mut user_state = initialize_state();
            set_caller(principal);
            assert!(user_state.validate_caller_wallet_canister_or_user(&principal).is_ok());
            let random_data = UserDataArgs::default();
            let random_account = UserAccountArgs::default();
            wallet_canister_caller();
            assert!(user_state
                .create_user(principal, random_data, random_account)
                .is_ok());
            assert!(user_state.get_user(&principal).is_ok());
            random_caller();
            assert!(user_state.get_user(&principal).is_err());
            owner_caller();
            assert!(user_state.get_user(&principal).is_err());
            wallet_canister_caller();
            assert!(user_state.get_user(&principal).is_ok());
        }

        // `create_user` returns data built from the supplied args.
        #[test]
        fn test_user_state_create_user(principal in principal_strategy(), user_args: UserDataArgs, account_args: UserAccountArgs) {
            let mut user_state = initialize_state();
            wallet_canister_caller();
            let user_data = user_state.create_user(principal, user_args.clone(), account_args.clone()).unwrap();
            assert_eq!(user_data.balance, user_args.balance.unwrap_or_default());
            assert_eq!(user_data.accounts[0].name, account_args.name.unwrap_or("Account 0".to_owned()));
        }

        // Both the wallet canister and the user can read the stored data.
        #[test]
        fn test_user_state_get_user(principal in principal_strategy(), user_args: UserDataArgs, account_args: UserAccountArgs) {
            let mut user_state = initialize_state();
            wallet_canister_caller();
            user_state.create_user(principal, user_args.clone(), account_args.clone()).unwrap();
            assert!(user_state.get_user(&principal).is_ok());
            set_caller(principal);
            let user_data = user_state.get_user(&principal).unwrap();
            assert_eq!(user_data.balance, user_args.balance.unwrap_or_default());
            assert_eq!(user_data.accounts[0].name, account_args.name.unwrap_or("Account 0".to_owned()));
        }

        // Mutable access is reserved for the wallet canister, even against the
        // user's own record.
        #[test]
        fn test_user_state_get_user_mut(principal in principal_strategy(), user_args: UserDataArgs, account_args: UserAccountArgs) {
            let mut user_state = initialize_state();
            wallet_canister_caller();
            user_state.create_user(principal, user_args.clone(), account_args.clone()).unwrap();
            set_caller(principal);
            assert!(user_state.get_user_mut(&principal).is_err());
            wallet_canister_caller();
            let user_data = user_state.get_user_mut(&principal).unwrap();
            assert_eq!(user_data.balance, user_args.balance.unwrap_or_default());
            assert_eq!(user_data.accounts[0].name, account_args.name.unwrap_or("Account 0".to_owned()));
        }

        // The derivation path is the user's principal bytes followed by the key.
        #[test]
        fn test_user_state_get_user_derivation_path(principal in principal_strategy(), user_args: UserDataArgs, account_args: UserAccountArgs) {
            let mut user_state = initialize_state();
            wallet_canister_caller();
            user_state.create_user(principal, user_args.clone(), account_args.clone()).unwrap();
            // NOTE(review): the result is deliberately not asserted; whether key 20
            // is rejected depends on `UserData::get_derivation_path` (not visible
            // here). Confirm and tighten to `assert!(... .is_err())`.
            let _ = user_state.get_user_derivation_path(&principal, 20);
            let derivation_path = user_state.get_user_derivation_path(&principal, 0).unwrap();
            assert_eq!(derivation_path.last(), Some(&0));
            let derivation_path_1 = user_state.get_user_derivation_path(&principal, 1).unwrap();
            assert_eq!(derivation_path_1.last(), Some(&1));
            let mut expected_path = principal.as_slice().to_vec();
            expected_path.push(0);
            assert_eq!(derivation_path, expected_path);
            let mut expected_path_1 = principal.as_slice().to_vec();
            expected_path_1.push(1);
            assert_eq!(derivation_path_1, expected_path_1);
        }
    }
}