// b3_users/data/account.rs

// account.rs
use super::{chain::UserChainData, transaction::UserTransactionData};
use crate::error::UserStateError;

use ic_cdk::export::{candid::CandidType, serde::Deserialize};
use std::collections::hash_map::Entry;
use std::collections::HashMap;

8/// Represents the user account data structure, including the chain_data.
9#[derive(Debug, CandidType, Deserialize, Clone)]
10pub struct UserAccountData {
11    pub name: String,
12    pub hidden: bool,
13    pub disabled: bool,
14    pub public_key: Vec<u8>,
15    pub chain_data: HashMap<u64, UserChainData>,
16}
17
18impl UserAccountData {
19    /// Creates a new AccountData with the given name and public_key.
20    pub fn new(public_key: Vec<u8>, name: String) -> Self {
21        Self {
22            name,
23            public_key,
24            hidden: false,
25            disabled: false,
26            chain_data: HashMap::default(),
27        }
28    }
29
30    /// Update the account data with the given args.
31    /// Returns an error if the public_key does not match.
32    pub fn update(&mut self, args: UserAccountArgs) -> Result<UserAccountArgs, UserStateError> {
33        if args.public_key != self.public_key {
34            return Err(UserStateError::PublicKeyMismatch);
35        }
36
37        if let Some(name) = args.name {
38            self.name = name;
39        }
40
41        if let Some(hidden) = args.hidden {
42            self.hidden = hidden;
43        }
44
45        if let Some(disabled) = args.disabled {
46            self.disabled = disabled;
47        }
48
49        Ok(UserAccountArgs {
50            name: Some(self.name.clone()),
51            public_key: self.public_key.clone(),
52            hidden: Some(self.hidden),
53            disabled: Some(self.disabled),
54        })
55    }
56
57    /// Gets the ChainData for a specific chain_id.
58    /// Returns an error if the chain_id is not found.
59    pub fn get_chain(&self, chain_id: u64) -> Result<&UserChainData, UserStateError> {
60        self.chain_data
61            .get(&chain_id)
62            .ok_or(UserStateError::ChainNotFound)
63    }
64
65    /// Adds a transaction to the specified chain and updates the ChainData.
66    /// Returns an error if the chain_id is not found.
67    pub fn add_transaction(
68        &mut self,
69        chain_id: u64,
70        nonce: u64,
71        transaction: UserTransactionData,
72    ) -> Result<&UserChainData, UserStateError> {
73        if let Some(chain_data) = self.chain_data.get_mut(&chain_id) {
74            chain_data.add(nonce, transaction);
75
76            Ok(chain_data)
77        } else {
78            Err(UserStateError::ChainNotFound)
79        }
80    }
81
82    /// Gets the transactions for a specific chain_id.
83    /// Returns an error if the chain_id is not found.
84    pub fn get_transactions(&self, chain_id: u64) -> Result<&UserChainData, UserStateError> {
85        self.chain_data
86            .get(&chain_id)
87            .ok_or(UserStateError::ChainNotFound)
88    }
89
90    /// Clears the transactions vector for the specified chain.
91    /// Returns an error if the chain_id is not found.
92    pub fn clear_transactions(&mut self, chain_id: u64) -> Result<&UserChainData, UserStateError> {
93        if let Some(chain_data) = self.chain_data.get_mut(&chain_id) {
94            chain_data.transactions.clear();
95
96            Ok(chain_data)
97        } else {
98            Err(UserStateError::ChainNotFound)
99        }
100    }
101
102    /// Adds a new chain to the account data.
103    /// Returns an error if the chain_id is already in use.
104    pub fn add_chain(
105        &mut self,
106        chain_id: u64,
107        chain_data: UserChainData,
108    ) -> Result<&UserChainData, UserStateError> {
109        if self.chain_data.contains_key(&chain_id) {
110            Err(UserStateError::ChainAlreadyExists)
111        } else {
112            self.chain_data.insert(chain_id, chain_data);
113
114            if let Some(chain_data) = self.chain_data.get(&chain_id) {
115                Ok(chain_data)
116            } else {
117                Err(UserStateError::ChainNotFound)
118            }
119        }
120    }
121
122    /// Removes a chain from the account data.
123    /// Returns an error if the chain_id is not found.
124    pub fn remove_chain(&mut self, chain_id: u64) -> Result<UserChainData, UserStateError> {
125        if let Some(chain_data) = self.chain_data.remove(&chain_id) {
126            Ok(chain_data)
127        } else {
128            Err(UserStateError::ChainNotFound)
129        }
130    }
131
132    /// Returns the number of chains for the account.
133    /// This is used for the UI to determine if the account is empty.
134    pub fn chain_count(&self) -> usize {
135        self.chain_data.len()
136    }
137
138    /// Returns the number of transactions for the account.
139    /// This is used for the UI to determine if the account is empty.
140    pub fn transaction_count(&self) -> usize {
141        self.chain_data
142            .values()
143            .map(|chain_data| chain_data.transactions.len())
144            .sum()
145    }
146
147    /// Returns the number of transactions for the specified chain.
148    /// This is used for the UI to determine if the chain is empty.
149    pub fn chain_transaction_count(&self, chain_id: u64) -> usize {
150        if let Some(chain_data) = self.chain_data.get(&chain_id) {
151            chain_data.transactions.len()
152        } else {
153            0
154        }
155    }
156
157    /// Set Nonce for a specific chain_id.
158    /// Returns an error if the chain_id is not found.
159    pub fn set_nonce(
160        &mut self,
161        chain_id: u64,
162        nonce: u64,
163    ) -> Result<&UserChainData, UserStateError> {
164        if let Some(chain_data) = self.chain_data.get_mut(&chain_id) {
165            chain_data.nonce = nonce;
166
167            Ok(chain_data)
168        } else {
169            Err(UserStateError::ChainNotFound)
170        }
171    }
172
173    /// Get Nonce for a specific chain_id.
174    /// Returns an error if the chain_id is not found.
175    pub fn get_nonce(&self, chain_id: u64) -> Result<u64, UserStateError> {
176        if let Some(chain_data) = self.chain_data.get(&chain_id) {
177            Ok(chain_data.nonce)
178        } else {
179            Err(UserStateError::ChainNotFound)
180        }
181    }
182
183    /// Get Transaction for a specific chain_id and index.
184    /// Returns an error if the chain_id is not found.
185    pub fn get_transaction(
186        &self,
187        chain_id: u64,
188        index: usize,
189    ) -> Result<&UserTransactionData, UserStateError> {
190        let chain_data = self.get_chain(chain_id)?;
191        chain_data.get_transaction(index)
192    }
193}
194
195#[derive(Clone, Debug, CandidType, Default, Deserialize)]
196pub struct UserAccountArgs {
197    pub public_key: Vec<u8>,
198    pub name: Option<String>,
199    pub hidden: Option<bool>,
200    pub disabled: Option<bool>,
201}
202
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        #[test]
        fn test_add_and_clear_transactions(
            public_key: Vec<u8>,
            name: String,
            chain_id: u64,
            nonce: u64,
            transactions: Vec<UserTransactionData>,
        ) {
            let mut account_data = UserAccountData::new(public_key, name);
            account_data.add_chain(chain_id, UserChainData::default()).unwrap();

            for (index, transaction) in transactions.iter().enumerate() {
                // `nonce` is arbitrary, so a plain `+` could overflow and panic
                // in debug builds when it is close to `u64::MAX`.
                account_data
                    .add_transaction(chain_id, nonce.wrapping_add(index as u64), transaction.clone())
                    .unwrap();
            }

            let chain_data = account_data.get_transactions(chain_id).unwrap();
            prop_assert_eq!(chain_data.transactions.len(), transactions.len());

            account_data.clear_transactions(chain_id).unwrap();

            let chain_data = account_data.get_transactions(chain_id).unwrap();
            prop_assert_eq!(chain_data.transactions.len(), 0);
        }

        #[test]
        fn test_add_chain_error(
            public_key: Vec<u8>,
            name: String,
            chain_id: u64,
            chain_data: UserChainData,
        ) {
            let mut account_data = UserAccountData::new(public_key, name);
            account_data.add_chain(chain_id, chain_data.clone()).unwrap();

            let result = account_data.add_chain(chain_id, chain_data);

            prop_assert!(
                matches!(result, Err(UserStateError::ChainAlreadyExists)),
                "Expected ChainAlreadyExists error, got {:?}",
                result
            );
        }

        #[test]
        fn test_update_account(
            name: String,
            hidden: bool,
            disabled: bool,
            args: UserAccountArgs,
        ) {
            let mut account_data = UserAccountData::new(args.public_key.clone(), name.clone());
            account_data.hidden = hidden;
            account_data.disabled = disabled;

            account_data.update(args.clone()).unwrap();

            prop_assert_eq!(&account_data.public_key, &args.public_key);
            prop_assert_eq!(&account_data.name, &args.name.unwrap_or(name));
            prop_assert_eq!(account_data.hidden, args.hidden.unwrap_or(hidden));
            prop_assert_eq!(account_data.disabled, args.disabled.unwrap_or(disabled));
        }

        #[test]
        fn test_update_account_public_key_mismatch(
            public_key: Vec<u8>,
            name: String,
            new_name: String,
        ) {
            // Appending a byte guarantees the key differs from the account's.
            let mut other_key = public_key.clone();
            other_key.push(0);

            let mut account_data = UserAccountData::new(public_key, name.clone());
            let result = account_data.update(UserAccountArgs {
                public_key: other_key,
                name: Some(new_name),
                hidden: Some(true),
                disabled: Some(true),
            });

            prop_assert!(
                matches!(result, Err(UserStateError::PublicKeyMismatch)),
                "Expected PublicKeyMismatch error, got {:?}",
                result
            );
            // A rejected update must leave the account untouched.
            prop_assert_eq!(&account_data.name, &name);
            prop_assert!(!account_data.hidden);
            prop_assert!(!account_data.disabled);
        }

        #[test]
        fn test_add_transaction_error(
            public_key: Vec<u8>,
            name: String,
            chain_id: u64,
            nonce: u64,
            transaction: UserTransactionData,
        ) {
            let mut account_data = UserAccountData::new(public_key, name);

            let result = account_data.add_transaction(chain_id, nonce, transaction);

            prop_assert!(
                matches!(result, Err(UserStateError::ChainNotFound)),
                "Expected ChainNotFound error, got {:?}",
                result
            );
        }

        #[test]
        fn test_clear_transactions_error(
            public_key: Vec<u8>,
            name: String,
            chain_id: u64,
        ) {
            let mut account_data = UserAccountData::new(public_key, name);

            let result = account_data.clear_transactions(chain_id);

            prop_assert!(
                matches!(result, Err(UserStateError::ChainNotFound)),
                "Expected ChainNotFound error, got {:?}",
                result
            );
        }

        #[test]
        fn test_get_transaction(
            public_key: Vec<u8>,
            name: String,
            chain_id: u64,
            nonce: u64,
            transactions: Vec<UserTransactionData>,
        ) {
            let mut account_data = UserAccountData::new(public_key, name);
            account_data.add_chain(chain_id, UserChainData::default()).unwrap();

            for (index, transaction) in transactions.iter().enumerate() {
                account_data
                    .add_transaction(chain_id, nonce.wrapping_add(index as u64), transaction.clone())
                    .unwrap();
            }

            for (index, transaction) in transactions.iter().enumerate() {
                let result = account_data.get_transaction(chain_id, index).unwrap();
                prop_assert_eq!(result, transaction);
            }
        }

        #[test]
        fn test_get_transaction_error(
            public_key: Vec<u8>,
            name: String,
            chain_id: u64,
            index: usize,
        ) {
            let account_data = UserAccountData::new(public_key, name);

            let result = account_data.get_transaction(chain_id, index);

            prop_assert!(
                matches!(result, Err(UserStateError::ChainNotFound)),
                "Expected ChainNotFound error, got {:?}",
                result
            );
        }

        #[test]
        fn test_get_transactions(
            public_key: Vec<u8>,
            name: String,
            chain_id: u64,
            nonce: u64,
            transactions: Vec<UserTransactionData>,
        ) {
            let mut account_data = UserAccountData::new(public_key, name);
            account_data.add_chain(chain_id, UserChainData::default()).unwrap();

            for (index, transaction) in transactions.iter().enumerate() {
                account_data
                    .add_transaction(chain_id, nonce.wrapping_add(index as u64), transaction.clone())
                    .unwrap();
            }

            let chain_data = account_data.get_transactions(chain_id).unwrap();
            prop_assert_eq!(chain_data.transactions.len(), transactions.len());
        }

        #[test]
        fn test_get_transactions_error(
            public_key: Vec<u8>,
            name: String,
            chain_id: u64,
        ) {
            let account_data = UserAccountData::new(public_key, name);

            let result = account_data.get_transactions(chain_id);

            prop_assert!(
                matches!(result, Err(UserStateError::ChainNotFound)),
                "Expected ChainNotFound error, got {:?}",
                result
            );
        }

        #[test]
        fn test_get_chain(
            public_key: Vec<u8>,
            name: String,
            chain_id: u64,
            chain_data: UserChainData,
        ) {
            let mut account_data = UserAccountData::new(public_key, name);
            account_data.add_chain(chain_id, chain_data.clone()).unwrap();

            let result = account_data.get_chain(chain_id).unwrap();

            prop_assert_eq!(result, &chain_data);
        }

        #[test]
        fn test_get_chain_error(
            public_key: Vec<u8>,
            name: String,
            chain_id: u64,
        ) {
            let account_data = UserAccountData::new(public_key, name);

            let result = account_data.get_chain(chain_id);

            prop_assert!(
                matches!(result, Err(UserStateError::ChainNotFound)),
                "Expected ChainNotFound error, got {:?}",
                result
            );
        }

        #[test]
        fn test_get_nonce(
            public_key: Vec<u8>,
            name: String,
            chain_id: u64,
            nonce: u64,
        ) {
            let mut account_data = UserAccountData::new(public_key, name);
            account_data.add_chain(chain_id, UserChainData::default()).unwrap();
            account_data.set_nonce(chain_id, nonce).unwrap();

            let result = account_data.get_nonce(chain_id).unwrap();
            prop_assert_eq!(result, nonce);
        }

        #[test]
        fn test_get_nonce_error(
            public_key: Vec<u8>,
            name: String,
            chain_id: u64,
        ) {
            let account_data = UserAccountData::new(public_key, name);

            let result = account_data.get_nonce(chain_id);

            prop_assert!(
                matches!(result, Err(UserStateError::ChainNotFound)),
                "Expected ChainNotFound error, got {:?}",
                result
            );
        }

        #[test]
        fn test_set_nonce(
            public_key: Vec<u8>,
            name: String,
            chain_id: u64,
            nonce: u64,
        ) {
            let mut account_data = UserAccountData::new(public_key, name);
            account_data.add_chain(chain_id, UserChainData::default()).unwrap();

            // Check the returned chain actually carries the new nonce, not just
            // that the call succeeded.
            let chain_data = account_data.set_nonce(chain_id, nonce).unwrap();
            prop_assert_eq!(chain_data.nonce, nonce);
        }

        #[test]
        fn test_set_nonce_error(
            public_key: Vec<u8>,
            name: String,
            chain_id: u64,
            nonce: u64,
        ) {
            let mut account_data = UserAccountData::new(public_key, name);

            let result = account_data.set_nonce(chain_id, nonce);

            prop_assert!(
                matches!(result, Err(UserStateError::ChainNotFound)),
                "Expected ChainNotFound error, got {:?}",
                result
            );
        }
    }
}