1use crate::data::{
2 account::UserAccountArgs,
3 user::{UserData, UserDataArgs},
4};
5use crate::error::UserStateError;
6use std::collections::{BTreeMap, HashMap};
7
8#[cfg(test)]
9use crate::mocks::ic_caller;
10
11#[cfg(not(test))]
12use ic_cdk::caller as ic_caller;
13
14use ic_cdk::export::{
15 candid::{CandidType, Principal},
16 serde::Deserialize,
17};
18
/// Configuration carried inside [`UserState`] (see its `config` field).
#[derive(Debug, CandidType, Deserialize, Default, Clone)]
pub struct UserStateConfig {
    /// Optional key name. NOTE(review): presumably the name of a signing key — confirm.
    pub key_name: Option<String>,
    /// Minimum cycles required. NOTE(review): where this is enforced is not visible here.
    pub min_cycles_required: u64,
    /// Maximum cycles per user. NOTE(review): where this is enforced is not visible here.
    pub max_cycles_per_user: u64,
    /// Free-form string key/value settings.
    pub general_settings: HashMap<String, String>,
}
26
/// State holding the owner, the trusted wallet canister, configuration and all user records.
#[derive(CandidType, Deserialize, Debug)]
pub struct UserState {
    /// Principal allowed to call owner-only operations (`change_owner`, `change_wallet_canister`).
    pub owner: Principal,
    /// Principal allowed to create users and obtain mutable access to them.
    pub wallet_canister: Principal,
    /// Configuration values for this state.
    pub config: UserStateConfig,
    /// User records keyed by the user's principal.
    pub users: BTreeMap<Principal, UserData>,
}
35
36impl UserState {
37 pub fn default() -> Self {
39 Self {
40 users: BTreeMap::default(),
41 owner: ic_caller(),
42 config: UserStateConfig::default(),
43 wallet_canister: Principal::anonymous(),
44 }
45 }
46
47 pub fn init(&mut self, wallet_canister: Principal) {
50 self.wallet_canister = wallet_canister;
51 }
52
53 fn is_caller_wallet_canister(&self) -> bool {
55 ic_caller() == self.wallet_canister
56 }
57
58 fn is_caller_owner(&self) -> bool {
60 ic_caller() == self.owner
61 }
62
63 fn is_caller_user(&self, user: &Principal) -> bool {
65 ic_caller() == *user
66 }
67
68 pub fn validate_caller_owner(&self) -> Result<(), UserStateError> {
70 if !self.is_caller_owner() {
71 return Err(UserStateError::CallerIsNotOwner);
72 }
73
74 Ok(())
75 }
76
77 pub fn validate_caller_wallet_canister(&self) -> Result<(), UserStateError> {
79 if !self.is_caller_wallet_canister() {
80 return Err(UserStateError::CallerIsNotWalletCanister);
81 }
82
83 Ok(())
84 }
85
86 pub fn validate_caller_wallet_canister_or_user(
88 &self,
89 user: &Principal,
90 ) -> Result<(), UserStateError> {
91 if !self.is_caller_wallet_canister() && !self.is_caller_user(user) {
92 return Err(UserStateError::CallerNotAuthorized);
93 }
94
95 Ok(())
96 }
97
98 pub fn change_owner(&mut self, new_owner: Principal) -> Result<Principal, UserStateError> {
101 self.validate_caller_owner()?;
102
103 self.owner = new_owner;
104
105 Ok(new_owner)
106 }
107
108 pub fn change_wallet_canister(
111 &mut self,
112 new_wallet_canister: Principal,
113 ) -> Result<Principal, UserStateError> {
114 self.validate_caller_owner()?;
115
116 self.wallet_canister = new_wallet_canister;
117
118 Ok(new_wallet_canister)
119 }
120
121 pub fn create_user(
124 &mut self,
125 user: Principal,
126 user_args: UserDataArgs,
127 account_args: UserAccountArgs,
128 ) -> Result<UserData, UserStateError> {
129 self.validate_caller_wallet_canister()?;
130
131 if self.users.contains_key(&user) {
132 return Err(UserStateError::UserAlreadyExists);
133 }
134
135 let user_data = UserData::new(user_args, account_args);
136 self.users.insert(user, user_data);
137
138 Ok(self.users.get(&user).unwrap().clone())
139 }
140
141 pub fn get_user_mut(&mut self, user: &Principal) -> Result<&mut UserData, UserStateError> {
144 self.validate_caller_wallet_canister()?;
145
146 self.users.get_mut(user).ok_or(UserStateError::UserNotFound)
147 }
148
149 pub fn get_user(&self, user: &Principal) -> Result<&UserData, UserStateError> {
152 self.validate_caller_wallet_canister_or_user(user)?;
153
154 self.users
155 .get(user)
156 .map(|user_data| user_data)
157 .ok_or(UserStateError::UserNotFound)
158 }
159
160 pub fn get_user_derivation_path(
162 &self,
163 user: &Principal,
164 key: u8,
165 ) -> Result<Vec<u8>, UserStateError> {
166 self.validate_caller_wallet_canister_or_user(user)?;
167
168 let user_data = self.users.get(user);
169
170 match user_data {
171 Some(user_data) => {
172 let derivation_path = user_data.get_derivation_path(*user, key)?;
173
174 Ok(derivation_path)
175 }
176 None => Err(UserStateError::UserNotFound),
177 }
178 }
179}
180
#[cfg(test)]
mod tests {
    use crate::{mocks::*, state::UserState};

    use super::*;
    use proptest::prelude::*;

    // Every check below is an explicit `assert!`: calling `.ok()` / `.err()` on a `Result`
    // and dropping the value asserts nothing.

    /// Builds a state owned by the mock owner principal, with the mock wallet canister
    /// registered through `init`.
    fn initialize_state() -> UserState {
        owner_caller();

        let mut state = UserState::default();
        state.init(wallet_canister_principal());

        state
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(5))]

        #[test]
        fn test_user_state_validate_caller_owner(principal in principal_strategy()) {
            let mut user_state = initialize_state();

            owner_caller();
            assert!(user_state.validate_caller_owner().is_ok());

            random_caller();
            assert!(user_state.validate_caller_owner().is_err());

            owner_caller();
            let new_owner = user_state.change_owner(principal).unwrap();

            assert_eq!(user_state.owner, principal);
            assert_eq!(user_state.owner, new_owner);
        }

        #[test]
        fn test_user_state_validate_caller_wallet_canister(_ in ".*") {
            let mut user_state = initialize_state();

            let principal = wallet_canister_principal();

            random_caller();
            user_state.init(principal);
            assert!(user_state.validate_caller_wallet_canister().is_err());

            wallet_canister_caller();
            assert!(user_state.validate_caller_wallet_canister().is_ok());

            let random_data = UserDataArgs::default();
            let random_account = UserAccountArgs::default();

            assert!(user_state
                .create_user(principal, random_data, random_account)
                .is_ok());

            random_caller();
            assert!(user_state.get_user_mut(&principal).is_err());
        }

        #[test]
        fn test_user_state_validate_caller_wallet_canister_or_user(principal in principal_strategy()) {
            let mut user_state = initialize_state();

            set_caller(principal);
            assert!(user_state.validate_caller_wallet_canister_or_user(&principal).is_ok());

            let random_data = UserDataArgs::default();
            let random_account = UserAccountArgs::default();

            wallet_canister_caller();
            assert!(user_state
                .create_user(principal, random_data, random_account)
                .is_ok());

            assert!(user_state.get_user(&principal).is_ok());

            random_caller();
            assert!(user_state.get_user(&principal).is_err());

            owner_caller();
            assert!(user_state.get_user(&principal).is_err());

            wallet_canister_caller();
            assert!(user_state.get_user(&principal).is_ok());
        }

        #[test]
        fn test_user_state_create_user(principal in principal_strategy(), user_args: UserDataArgs, account_args: UserAccountArgs) {
            let mut user_state = initialize_state();

            wallet_canister_caller();

            let user_data = user_state.create_user(principal, user_args.clone(), account_args.clone()).unwrap();

            assert_eq!(user_data.balance, user_args.balance.unwrap_or_default());
            assert_eq!(user_data.accounts[0].name, account_args.name.unwrap_or("Account 0".to_owned()));
        }

        #[test]
        fn test_user_state_get_user(principal in principal_strategy(), user_args: UserDataArgs, account_args: UserAccountArgs) {
            let mut user_state = initialize_state();

            wallet_canister_caller();

            user_state.create_user(principal, user_args.clone(), account_args.clone()).unwrap();

            assert!(user_state.get_user(&principal).is_ok());

            set_caller(principal);

            let user_data = user_state.get_user(&principal).unwrap();

            assert_eq!(user_data.balance, user_args.balance.unwrap_or_default());
            assert_eq!(user_data.accounts[0].name, account_args.name.unwrap_or("Account 0".to_owned()));
        }

        #[test]
        fn test_user_state_get_user_mut(principal in principal_strategy(), user_args: UserDataArgs, account_args: UserAccountArgs) {
            let mut user_state = initialize_state();

            wallet_canister_caller();

            user_state.create_user(principal, user_args.clone(), account_args.clone()).unwrap();

            set_caller(principal);
            assert!(user_state.get_user_mut(&principal).is_err());

            wallet_canister_caller();

            let user_data = user_state.get_user_mut(&principal).unwrap();

            assert_eq!(user_data.balance, user_args.balance.unwrap_or_default());
            assert_eq!(user_data.accounts[0].name, account_args.name.unwrap_or("Account 0".to_owned()));
        }

        #[test]
        fn test_user_state_get_user_derivation_path(principal in principal_strategy(), user_args: UserDataArgs, account_args: UserAccountArgs) {
            let mut user_state = initialize_state();

            wallet_canister_caller();

            user_state.create_user(principal, user_args.clone(), account_args.clone()).unwrap();

            // NOTE(review): the expected outcome for key 20 depends on
            // `UserData::get_derivation_path`, which is not visible here, so it is not asserted.
            // Replace with `assert!(....is_err())` once the intended behavior is confirmed.
            let _ = user_state.get_user_derivation_path(&principal, 20);

            let derivation_path = user_state.get_user_derivation_path(&principal, 0).unwrap();
            assert_eq!(derivation_path.last(), Some(&0));

            let derivation_path_1 = user_state.get_user_derivation_path(&principal, 1).unwrap();
            assert_eq!(derivation_path_1.last(), Some(&1));

            // Expected layout: the principal's bytes followed by the key byte.
            let mut expected_path = principal.as_slice().to_vec();
            expected_path.push(0);
            assert_eq!(derivation_path, expected_path);

            let mut expected_path_1 = principal.as_slice().to_vec();
            expected_path_1.push(1);
            assert_eq!(derivation_path_1, expected_path_1);
        }
    }
}