1pub mod data;
2pub mod error;
3pub mod mocks;
4pub mod state;
5
6use data::account::UserAccountArgs;
7use data::user::{UserData, UserDataArgs};
8use error::UserStateError;
9use ic_cdk::export::candid::Principal;
10use state::UserState;
11use std::cell::RefCell;
12
13thread_local! {
14 static USER_STATE: RefCell<UserState> = RefCell::new(UserState::default());
15}
16
17pub fn initialize(wallet_canister_id: String) -> Result<(), UserStateError> {
18 let wallet_canister = Principal::from_text(wallet_canister_id).unwrap();
19
20 USER_STATE.with(|s| {
21 let mut mut_state = s.borrow_mut();
22
23 mut_state.init(wallet_canister);
24
25 Ok(())
26 })
27}
28
29pub fn get_wallet_canister() -> Result<Principal, UserStateError> {
30 USER_STATE.with(|s| {
31 let state = s.borrow();
32
33 Ok(state.wallet_canister)
34 })
35}
36
37pub fn get_owner() -> Result<Principal, UserStateError> {
38 USER_STATE.with(|s| {
39 let state = s.borrow();
40
41 Ok(state.owner)
42 })
43}
44
45pub fn change_owner(new_owner: Principal) -> Result<Principal, UserStateError> {
46 USER_STATE.with(|s| {
47 let mut mut_state = s.borrow_mut();
48
49 mut_state.change_owner(new_owner)
50 })
51}
52
53pub fn change_wallet_canister(wallet_canister: Principal) -> Result<Principal, UserStateError> {
54 USER_STATE.with(|s| {
55 let mut mut_state = s.borrow_mut();
56
57 mut_state.change_wallet_canister(wallet_canister)
58 })
59}
60
61pub fn add_user(
62 user: Principal,
63 user_args: UserDataArgs,
64 account_args: UserAccountArgs,
65) -> Result<UserData, UserStateError> {
66 USER_STATE.with(|s| {
67 let mut mut_state = s.borrow_mut();
68
69 mut_state.create_user(user, user_args, account_args)
70 })
71}
72
73pub fn get_user(user: &Principal) -> Result<UserData, UserStateError> {
74 USER_STATE.with(|s| {
75 let state = s.borrow();
76
77 state.get_user(user).map(|user_data| user_data.clone())
78 })
79}
80
81pub fn with_user_state<T, F>(user: &Principal, callback: F) -> Result<T, UserStateError>
82where
83 F: FnOnce(&UserData) -> T,
84{
85 USER_STATE.with(|state| {
86 let state = state.borrow();
87
88 state.get_user(user).map(callback)
89 })
90}
91
92pub fn with_user_state_mut<T, F>(user: &Principal, callback: F) -> Result<T, UserStateError>
93where
94 F: FnOnce(&mut UserData) -> T,
95{
96 USER_STATE.with(|state| {
97 let mut state = state.borrow_mut();
98
99 state.get_user_mut(user).map(callback)
100 })
101}
102
103pub fn pre_upgrade() -> Result<(), candid::Error> {
104 USER_STATE.with(|s| {
105 let state = s.borrow();
106
107 ic_cdk::storage::stable_save((state.owner, state.wallet_canister, state.users.clone()))
108 })
109}
110
111pub fn post_upgrade() -> Result<(), candid::Error> {
112 USER_STATE.with(|s| {
113 let mut state = s.borrow_mut();
114
115 (state.owner, state.wallet_canister, state.users) =
116 ic_cdk::storage::stable_restore().unwrap();
117
118 Ok(())
119 })
120}
121
// Unit tests for the module-level user-state API.
//
// Principals and caller identity come from helpers in `crate::mocks` (`random_principal`,
// `random_caller`, `owner_caller`, `wallet_canister_caller`, `wallet_canister_principal`,
// `set_caller`, `ic_caller`). NOTE(review): their exact semantics (e.g. whether `ic_caller`
// returns the currently mocked caller) live in `mocks` — confirm there. The tests also appear to
// rely on each test seeing its own thread-local `USER_STATE`; verify under the test runner used.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{mocks::*, state::UserState};

    // `initialize` accepts a principal's text form and stores it as the wallet canister.
    #[test]
    fn test_initialize() {
        let principal = random_principal();

        assert!(initialize(principal.to_string()).is_ok());

        USER_STATE.with(|s| {
            let state = s.borrow();

            assert_eq!(state.wallet_canister, principal);
        });
    }

    // `change_owner` succeeds and the new owner is visible in `USER_STATE`.
    // NOTE(review): no caller is mocked and `initialize` is not called first, so this relies on
    // `UserState::change_owner` allowing the change from this starting state — confirm in `state`.
    #[test]
    fn test_change_owner() {
        let owner = random_principal();

        assert!(change_owner(owner.clone()).is_ok());

        USER_STATE.with(|s| {
            let state = s.borrow();

            assert_eq!(state.owner, owner);
        });
    }

    // With the owner as the mocked caller, `change_wallet_canister` updates the stored wallet
    // canister.
    #[test]
    fn test_change_wallet_canister() {
        owner_caller();

        let wallet_canister = wallet_canister_principal();

        assert!(change_wallet_canister(wallet_canister).is_ok());

        USER_STATE.with(|s| {
            let state = s.borrow();
            assert_eq!(state.wallet_canister, wallet_canister);
        });
    }

    // Exercises `UserState` directly rather than through `USER_STATE`. After `init`, the wallet
    // canister is the principal passed in. A later caller is then rejected by
    // `validate_caller_wallet_canister_or_user` when checked against `user`.
    #[test]
    fn test_user_state_init() {
        random_caller();

        let mut user_state = UserState::default();

        // NOTE(review): presumably the principal of the caller mocked just above.
        let principal = ic_caller();

        random_caller();

        user_state.init(principal);

        assert_eq!(user_state.wallet_canister, principal);

        // Presumably the second mocked caller; used as the "user" in the check below.
        let user = ic_caller();

        // NOTE(review): re-initializing with the same principal looks redundant — confirm intent.
        user_state.init(principal);

        // Switch to another caller, which is presumably neither the wallet canister nor `user`.
        random_caller();

        assert!(user_state
            .validate_caller_wallet_canister_or_user(&user)
            .is_err());
    }

    // `with_user_state` runs the callback against the stored `UserData` of an added user.
    #[test]
    fn test_with_user_state() {
        // The current mocked caller's principal becomes the wallet canister.
        let principal = ic_caller();

        initialize(principal.to_string()).unwrap();

        let user_principal = random_principal();

        let user_args = UserDataArgs {
            balance: Some(100),
            ..UserDataArgs::default()
        };

        let account_args = UserAccountArgs {
            name: Some("Account 1".to_owned()),
            ..UserAccountArgs::default()
        };

        // Act as the wallet canister when adding the user.
        set_caller(principal.clone());

        let _ = add_user(
            user_principal.clone(),
            user_args.clone(),
            account_args.clone(),
        )
        .unwrap();

        set_caller(principal.clone());
        let result = with_user_state(&user_principal, |user_data| user_data.balance).unwrap();
        assert_eq!(result, user_args.balance.unwrap_or_default());
    }

    // Mutations made inside a `with_user_state_mut` callback are persisted in `USER_STATE`.
    #[test]
    fn test_with_user_state_mut() {
        owner_caller();

        let principal = wallet_canister_principal();

        initialize(principal.to_string()).unwrap();

        // Users are added while the wallet canister is the mocked caller.
        wallet_canister_caller();

        let user_principal = random_principal();

        let user_args = UserDataArgs {
            balance: Some(100),
            ..UserDataArgs::default()
        };
        let account_args = UserAccountArgs {
            name: Some("Account 1".to_owned()),
            ..UserAccountArgs::default()
        };
        let user_data = add_user(
            user_principal.clone(),
            user_args.clone(),
            account_args.clone(),
        )
        .unwrap();

        assert_eq!(user_data.balance, user_args.balance.unwrap_or_default());

        let new_balance = 200;
        let result = with_user_state_mut(&user_principal, |user_data| {
            user_data.balance = new_balance;
            user_data.balance
        })
        .unwrap();

        assert_eq!(result, new_balance);

        // Read the state directly to confirm the change was written back, not just returned.
        USER_STATE.with(|s| {
            let state = s.borrow();
            let stored_user_data = state.get_user(&user_principal).unwrap();

            assert_eq!(stored_user_data.balance, new_balance);
        });
    }

    // `add_user` returns the created `UserData` (balance and one account named from the args)
    // and stores it in `USER_STATE`.
    #[test]
    fn test_add_user() {
        let principal = wallet_canister_principal();

        owner_caller();
        initialize(principal.to_string()).unwrap();

        let user_principal = random_principal();

        let user_args = UserDataArgs {
            balance: Some(100),
            ..UserDataArgs::default()
        };

        let account_args = UserAccountArgs {
            name: Some("Account 1".to_owned()),
            ..UserAccountArgs::default()
        };

        wallet_canister_caller();

        let user_data = add_user(
            user_principal.clone(),
            user_args.clone(),
            account_args.clone(),
        )
        .unwrap();

        assert_eq!(user_data.balance, user_args.balance.unwrap_or_default());
        assert_eq!(user_data.accounts.len(), 1);

        // `name` is `Some` here, so the "Account 0" fallback is not used.
        assert_eq!(
            user_data.accounts[0].name,
            account_args.name.unwrap_or("Account 0".to_owned())
        );

        USER_STATE.with(|s| {
            let state = s.borrow();
            let stored_user_data = state.get_user(&user_principal).unwrap();

            assert_eq!(
                stored_user_data.balance,
                user_args.balance.unwrap_or_default()
            );
        });
    }

    // `get_user` returns a copy of the stored user matching the creation arguments.
    #[test]
    fn test_get_user() {
        let principal = wallet_canister_principal();

        owner_caller();
        initialize(principal.to_string()).unwrap();

        let user_principal = random_principal();

        let user_args = UserDataArgs {
            balance: Some(100),
            ..UserDataArgs::default()
        };

        let account_args = UserAccountArgs {
            name: Some("Account 1".to_owned()),
            ..UserAccountArgs::default()
        };

        wallet_canister_caller();

        let _ = add_user(
            user_principal.clone(),
            user_args.clone(),
            account_args.clone(),
        )
        .unwrap();

        let user_data = get_user(&user_principal).unwrap();

        assert_eq!(user_data.balance, user_args.balance.unwrap_or_default());
        assert_eq!(user_data.accounts.len(), 1);

        // `name` is `Some` here, so the "Account 0" fallback is not used.
        assert_eq!(
            user_data.accounts[0].name,
            account_args.name.unwrap_or("Account 0".to_owned())
        );

        USER_STATE.with(|s| {
            let state = s.borrow();
            let stored_user_data = state.get_user(&user_principal).unwrap();

            assert_eq!(
                stored_user_data.balance,
                user_args.balance.unwrap_or_default()
            );
        });
    }
}