use super::{chain::UserChainData, transaction::UserTransactionData};
use crate::error::UserStateError;
use ic_cdk::export::{candid::CandidType, serde::Deserialize};
use std::collections::HashMap;
/// Per-user account record: display name, `hidden`/`disabled` flags, the account's public key,
/// and per-chain data keyed by chain id.
#[derive(Debug, CandidType, Deserialize, Clone)]
pub struct UserAccountData {
    /// Display name; overwritten by `update` when `UserAccountArgs::name` is `Some`.
    pub name: String,
    /// Whether the account is marked hidden; `false` for newly created accounts.
    pub hidden: bool,
    /// Whether the account is marked disabled; `false` for newly created accounts.
    pub disabled: bool,
    /// Key the account was created with; `update` rejects arguments carrying a different key.
    pub public_key: Vec<u8>,
    /// Per-chain data, keyed by chain id.
    pub chain_data: HashMap<u64, UserChainData>,
}
impl UserAccountData {
    /// Creates an account that is neither hidden nor disabled and has no chains yet.
    ///
    /// * `public_key` - key stored on the account; [`update`](Self::update) requires callers
    ///   to present the same key.
    /// * `name` - initial display name.
    pub fn new(public_key: Vec<u8>, name: String) -> Self {
        Self {
            name,
            public_key,
            hidden: false,
            disabled: false,
            chain_data: HashMap::default(),
        }
    }

    /// Applies the `Some` fields of `args` to this account; `None` fields leave the current
    /// value untouched.
    ///
    /// # Returns
    /// The post-update state of the account, with every optional field populated.
    ///
    /// # Errors
    /// [`UserStateError::PublicKeyMismatch`] if `args.public_key` differs from the stored key.
    /// In that case nothing is modified.
    pub fn update(&mut self, args: UserAccountArgs) -> Result<UserAccountArgs, UserStateError> {
        if args.public_key != self.public_key {
            return Err(UserStateError::PublicKeyMismatch);
        }
        if let Some(name) = args.name {
            self.name = name;
        }
        if let Some(hidden) = args.hidden {
            self.hidden = hidden;
        }
        if let Some(disabled) = args.disabled {
            self.disabled = disabled;
        }
        Ok(UserAccountArgs {
            name: Some(self.name.clone()),
            public_key: self.public_key.clone(),
            hidden: Some(self.hidden),
            disabled: Some(self.disabled),
        })
    }

    /// Returns the data stored for `chain_id`.
    ///
    /// # Errors
    /// [`UserStateError::ChainNotFound`] if no chain with that id has been added.
    pub fn get_chain(&self, chain_id: u64) -> Result<&UserChainData, UserStateError> {
        self.chain_data
            .get(&chain_id)
            .ok_or(UserStateError::ChainNotFound)
    }

    /// Mutable lookup shared by every method that edits a chain in place.
    ///
    /// # Errors
    /// [`UserStateError::ChainNotFound`] if no chain with that id has been added.
    fn chain_mut(&mut self, chain_id: u64) -> Result<&mut UserChainData, UserStateError> {
        self.chain_data
            .get_mut(&chain_id)
            .ok_or(UserStateError::ChainNotFound)
    }

    /// Passes `nonce` and `transaction` to `UserChainData::add` on the chain `chain_id`.
    ///
    /// # Returns
    /// The chain's data after the call.
    ///
    /// # Errors
    /// [`UserStateError::ChainNotFound`] if the chain has not been added.
    pub fn add_transaction(
        &mut self,
        chain_id: u64,
        nonce: u64,
        transaction: UserTransactionData,
    ) -> Result<&UserChainData, UserStateError> {
        let chain_data = self.chain_mut(chain_id)?;
        chain_data.add(nonce, transaction);
        Ok(chain_data)
    }

    /// Returns the data (including its `transactions`) stored for `chain_id`.
    ///
    /// Equivalent to [`get_chain`](Self::get_chain).
    ///
    /// # Errors
    /// [`UserStateError::ChainNotFound`] if no chain with that id has been added.
    pub fn get_transactions(&self, chain_id: u64) -> Result<&UserChainData, UserStateError> {
        self.get_chain(chain_id)
    }

    /// Empties the `transactions` of chain `chain_id`; the rest of the chain data is kept.
    ///
    /// # Returns
    /// The chain's data after clearing.
    ///
    /// # Errors
    /// [`UserStateError::ChainNotFound`] if the chain has not been added.
    pub fn clear_transactions(&mut self, chain_id: u64) -> Result<&UserChainData, UserStateError> {
        let chain_data = self.chain_mut(chain_id)?;
        chain_data.transactions.clear();
        Ok(chain_data)
    }

    /// Registers `chain_data` under `chain_id`.
    ///
    /// # Returns
    /// A reference to the freshly stored chain data.
    ///
    /// # Errors
    /// [`UserStateError::ChainAlreadyExists`] if `chain_id` is already present. In that case
    /// the existing data is left unchanged and `chain_data` is dropped.
    pub fn add_chain(
        &mut self,
        chain_id: u64,
        chain_data: UserChainData,
    ) -> Result<&UserChainData, UserStateError> {
        use std::collections::hash_map::Entry;

        // A single entry lookup replaces contains_key + insert + get. It also removes the
        // previously unreachable "inserted but not found" error branch.
        match self.chain_data.entry(chain_id) {
            Entry::Occupied(_) => Err(UserStateError::ChainAlreadyExists),
            Entry::Vacant(slot) => Ok(slot.insert(chain_data)),
        }
    }

    /// Removes chain `chain_id` and hands its data back to the caller.
    ///
    /// # Errors
    /// [`UserStateError::ChainNotFound`] if the chain has not been added.
    pub fn remove_chain(&mut self, chain_id: u64) -> Result<UserChainData, UserStateError> {
        self.chain_data
            .remove(&chain_id)
            .ok_or(UserStateError::ChainNotFound)
    }

    /// Number of chains registered on this account.
    pub fn chain_count(&self) -> usize {
        self.chain_data.len()
    }

    /// Total number of entries in `transactions` summed over every chain.
    pub fn transaction_count(&self) -> usize {
        self.chain_data
            .values()
            .map(|chain_data| chain_data.transactions.len())
            .sum()
    }

    /// Number of entries in `transactions` for chain `chain_id`, or `0` if the chain is unknown.
    pub fn chain_transaction_count(&self, chain_id: u64) -> usize {
        self.chain_data
            .get(&chain_id)
            .map_or(0, |chain_data| chain_data.transactions.len())
    }

    /// Overwrites the stored `nonce` of chain `chain_id`.
    ///
    /// # Returns
    /// The chain's data after the update.
    ///
    /// # Errors
    /// [`UserStateError::ChainNotFound`] if the chain has not been added.
    pub fn set_nonce(
        &mut self,
        chain_id: u64,
        nonce: u64,
    ) -> Result<&UserChainData, UserStateError> {
        let chain_data = self.chain_mut(chain_id)?;
        chain_data.nonce = nonce;
        Ok(chain_data)
    }

    /// Returns the stored `nonce` of chain `chain_id`.
    ///
    /// # Errors
    /// [`UserStateError::ChainNotFound`] if the chain has not been added.
    pub fn get_nonce(&self, chain_id: u64) -> Result<u64, UserStateError> {
        self.get_chain(chain_id).map(|chain_data| chain_data.nonce)
    }

    /// Looks up chain `chain_id` and delegates to `UserChainData::get_transaction(index)`.
    ///
    /// # Errors
    /// [`UserStateError::ChainNotFound`] if the chain has not been added. Any error from
    /// `UserChainData::get_transaction` is propagated unchanged.
    pub fn get_transaction(
        &self,
        chain_id: u64,
        index: usize,
    ) -> Result<&UserTransactionData, UserStateError> {
        self.get_chain(chain_id)?.get_transaction(index)
    }
}
/// Arguments accepted by [`UserAccountData::update`], which also returns this type.
///
/// `update` leaves fields that are `None` unchanged. The value it returns has every field set
/// to the account's post-update state.
#[derive(Clone, Debug, CandidType, Default, Deserialize)]
pub struct UserAccountArgs {
    /// Must equal the account's stored public key, otherwise `update` fails with
    /// `UserStateError::PublicKeyMismatch`.
    pub public_key: Vec<u8>,
    /// New display name, or `None` to keep the current one.
    pub name: Option<String>,
    /// New value for the account's `hidden` flag, or `None` to keep the current one.
    pub hidden: Option<bool>,
    /// New value for the account's `disabled` flag, or `None` to keep the current one.
    pub disabled: Option<bool>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Adds `transactions` to chain `chain_id` with consecutive nonces starting at `first_nonce`.
    fn add_all(
        account_data: &mut UserAccountData,
        chain_id: u64,
        first_nonce: u64,
        transactions: &[UserTransactionData],
    ) {
        for (offset, transaction) in (0u64..).zip(transactions) {
            // `first_nonce` is drawn from the full u64 range, so a plain `+` could overflow
            // and panic in debug builds. Wrapping keeps the nonces distinct instead.
            account_data
                .add_transaction(chain_id, first_nonce.wrapping_add(offset), transaction.clone())
                .unwrap();
        }
    }

    /// True when `result` is exactly `Err(UserStateError::ChainNotFound)`.
    fn is_chain_not_found<T>(result: &Result<T, UserStateError>) -> bool {
        matches!(result, Err(UserStateError::ChainNotFound))
    }

    proptest! {
        #[test]
        fn test_add_and_clear_transactions(
            public_key: Vec<u8>,
            name: String,
            chain_id: u64,
            nonce: u64,
            transactions: Vec<UserTransactionData>,
        ) {
            let mut account_data = UserAccountData::new(public_key, name);
            account_data.add_chain(chain_id, UserChainData::default()).unwrap();
            add_all(&mut account_data, chain_id, nonce, &transactions);
            let chain_data = account_data.get_transactions(chain_id).unwrap();
            prop_assert_eq!(chain_data.transactions.len(), transactions.len());
            account_data.clear_transactions(chain_id).unwrap();
            let chain_data = account_data.get_transactions(chain_id).unwrap();
            prop_assert_eq!(chain_data.transactions.len(), 0);
        }

        #[test]
        fn test_add_chain_error(
            public_key: Vec<u8>,
            name: String,
            chain_id: u64,
            chain_data: UserChainData,
        ) {
            let mut account_data = UserAccountData::new(public_key, name);
            let inserted = account_data.add_chain(chain_id, chain_data.clone()).unwrap();
            prop_assert_eq!(inserted, &chain_data);
            let result = account_data.add_chain(chain_id, chain_data);
            prop_assert!(
                matches!(result, Err(UserStateError::ChainAlreadyExists)),
                "Expected ChainAlreadyExists error"
            );
        }

        #[test]
        fn test_update_account(
            name: String,
            hidden: bool,
            disabled: bool,
            args: UserAccountArgs,
        ) {
            let mut account_data = UserAccountData::new(args.public_key.clone(), name.clone());
            account_data.hidden = hidden;
            account_data.disabled = disabled;
            let returned = account_data.update(args.clone()).unwrap();
            let expected_name = args.name.unwrap_or(name);
            prop_assert_eq!(&account_data.public_key, &args.public_key);
            prop_assert_eq!(&account_data.name, &expected_name);
            prop_assert_eq!(account_data.hidden, args.hidden.unwrap_or(hidden));
            prop_assert_eq!(account_data.disabled, args.disabled.unwrap_or(disabled));
            // The returned args mirror the post-update state with every field populated.
            prop_assert_eq!(returned.public_key, args.public_key);
            prop_assert_eq!(returned.name, Some(expected_name));
            prop_assert_eq!(returned.hidden, Some(account_data.hidden));
            prop_assert_eq!(returned.disabled, Some(account_data.disabled));
        }

        #[test]
        fn test_update_account_public_key_mismatch(
            public_key: Vec<u8>,
            name: String,
            new_name: String,
        ) {
            let mut account_data = UserAccountData::new(public_key.clone(), name.clone());
            // Appending a byte guarantees the presented key differs from the stored one.
            let mut other_key = public_key;
            other_key.push(0);
            let args = UserAccountArgs {
                public_key: other_key,
                name: Some(new_name),
                hidden: Some(true),
                disabled: Some(true),
            };
            let result = account_data.update(args);
            prop_assert!(
                matches!(result, Err(UserStateError::PublicKeyMismatch)),
                "Expected PublicKeyMismatch error"
            );
            // A rejected update must not modify the account.
            prop_assert_eq!(&account_data.name, &name);
            prop_assert!(!account_data.hidden);
            prop_assert!(!account_data.disabled);
        }

        #[test]
        fn test_add_transaction_error(
            public_key: Vec<u8>,
            name: String,
            chain_id: u64,
            nonce: u64,
            transaction: UserTransactionData,
        ) {
            let mut account_data = UserAccountData::new(public_key, name);
            let result = account_data.add_transaction(chain_id, nonce, transaction);
            prop_assert!(is_chain_not_found(&result), "Expected ChainNotFound error");
        }

        #[test]
        fn test_clear_transactions_error(
            public_key: Vec<u8>,
            name: String,
            chain_id: u64,
        ) {
            let mut account_data = UserAccountData::new(public_key, name);
            let result = account_data.clear_transactions(chain_id);
            prop_assert!(is_chain_not_found(&result), "Expected ChainNotFound error");
        }

        #[test]
        fn test_get_transaction(
            public_key: Vec<u8>,
            name: String,
            chain_id: u64,
            nonce: u64,
            transactions: Vec<UserTransactionData>,
        ) {
            let mut account_data = UserAccountData::new(public_key, name);
            account_data.add_chain(chain_id, UserChainData::default()).unwrap();
            add_all(&mut account_data, chain_id, nonce, &transactions);
            for (index, transaction) in transactions.iter().enumerate() {
                let result = account_data.get_transaction(chain_id, index).unwrap();
                prop_assert_eq!(result, transaction);
            }
        }

        #[test]
        fn test_get_transaction_error(
            public_key: Vec<u8>,
            name: String,
            chain_id: u64,
            index: usize,
        ) {
            let account_data = UserAccountData::new(public_key, name);
            let result = account_data.get_transaction(chain_id, index);
            prop_assert!(is_chain_not_found(&result), "Expected ChainNotFound error");
        }

        #[test]
        fn test_get_transactions(
            public_key: Vec<u8>,
            name: String,
            chain_id: u64,
            nonce: u64,
            transactions: Vec<UserTransactionData>,
        ) {
            let mut account_data = UserAccountData::new(public_key, name);
            account_data.add_chain(chain_id, UserChainData::default()).unwrap();
            add_all(&mut account_data, chain_id, nonce, &transactions);
            let chain_data = account_data.get_transactions(chain_id).unwrap();
            prop_assert_eq!(chain_data.transactions.len(), transactions.len());
            prop_assert_eq!(account_data.chain_transaction_count(chain_id), transactions.len());
            // Only one chain exists, so the account-wide total equals the per-chain count.
            prop_assert_eq!(account_data.transaction_count(), transactions.len());
        }

        #[test]
        fn test_get_transactions_error(
            public_key: Vec<u8>,
            name: String,
            chain_id: u64,
        ) {
            let account_data = UserAccountData::new(public_key, name);
            let result = account_data.get_transactions(chain_id);
            prop_assert!(is_chain_not_found(&result), "Expected ChainNotFound error");
            prop_assert_eq!(account_data.chain_transaction_count(chain_id), 0);
        }

        #[test]
        fn test_get_chain(
            public_key: Vec<u8>,
            name: String,
            chain_id: u64,
            chain_data: UserChainData,
        ) {
            let mut account_data = UserAccountData::new(public_key, name);
            account_data.add_chain(chain_id, chain_data.clone()).unwrap();
            let result = account_data.get_chain(chain_id).unwrap();
            prop_assert_eq!(result, &chain_data);
        }

        #[test]
        fn test_get_chain_error(
            public_key: Vec<u8>,
            name: String,
            chain_id: u64,
        ) {
            let account_data = UserAccountData::new(public_key, name);
            let result = account_data.get_chain(chain_id);
            prop_assert!(is_chain_not_found(&result), "Expected ChainNotFound error");
        }

        #[test]
        fn test_remove_chain(
            public_key: Vec<u8>,
            name: String,
            chain_id: u64,
            chain_data: UserChainData,
        ) {
            let mut account_data = UserAccountData::new(public_key, name);
            account_data.add_chain(chain_id, chain_data.clone()).unwrap();
            prop_assert_eq!(account_data.chain_count(), 1);
            let removed = account_data.remove_chain(chain_id).unwrap();
            prop_assert_eq!(removed, chain_data);
            prop_assert_eq!(account_data.chain_count(), 0);
            let result = account_data.remove_chain(chain_id);
            prop_assert!(is_chain_not_found(&result), "Expected ChainNotFound error");
        }

        #[test]
        fn test_get_nonce(
            public_key: Vec<u8>,
            name: String,
            chain_id: u64,
            nonce: u64,
        ) {
            let mut account_data = UserAccountData::new(public_key, name);
            account_data.add_chain(chain_id, UserChainData::default()).unwrap();
            account_data.set_nonce(chain_id, nonce).unwrap();
            let result = account_data.get_nonce(chain_id).unwrap();
            prop_assert_eq!(result, nonce);
        }

        #[test]
        fn test_get_nonce_error(
            public_key: Vec<u8>,
            name: String,
            chain_id: u64,
        ) {
            let account_data = UserAccountData::new(public_key, name);
            let result = account_data.get_nonce(chain_id);
            prop_assert!(is_chain_not_found(&result), "Expected ChainNotFound error");
        }

        #[test]
        fn test_set_nonce(
            public_key: Vec<u8>,
            name: String,
            chain_id: u64,
            nonce: u64,
        ) {
            let mut account_data = UserAccountData::new(public_key, name);
            account_data.add_chain(chain_id, UserChainData::default()).unwrap();
            let chain_data = account_data.set_nonce(chain_id, nonce).unwrap();
            prop_assert_eq!(chain_data.nonce, nonce);
        }

        #[test]
        fn test_set_nonce_error(
            public_key: Vec<u8>,
            name: String,
            chain_id: u64,
            nonce: u64,
        ) {
            let mut account_data = UserAccountData::new(public_key, name);
            let result = account_data.set_nonce(chain_id, nonce);
            prop_assert!(is_chain_not_found(&result), "Expected ChainNotFound error");
        }
    }
}