use core::ops::Range;
use alloc::rc::Rc;
use casper_types::{
bytesrepr::{FromBytes, ToBytes},
CLTyped
};
use crate::{
mapping::Mapping,
module::{ModuleComponent, ModulePrimitive},
var::Var,
CollectionError, ContractEnv, UnwrapOrRevert
};
/// A list of `T` values addressed by `u32` position.
///
/// Elements are kept in a `Mapping<u32, T>` and the element count in a
/// `Var<u32>`; both primitives are created under the child environment
/// `env.child(index)` by `ModuleComponent::instance`.
pub struct List<T> {
    /// Parent environment handed to `instance`.
    env: Rc<ContractEnv>,
    /// This list's position within its parent; used to derive the child env.
    index: u8,
    /// Elements keyed by position. Entries at positions `>= len` may be
    /// stale values left behind by `pop`, which only shrinks the length.
    values: Mapping<u32, T>,
    /// Stored element count, which is also the next free position.
    current_index: Var<u32>,
}
impl<T> List<T> {
    /// Returns the child environment derived for this list.
    ///
    /// The derivation (`child(index)` on the parent environment) matches the
    /// one `ModuleComponent::instance` uses for the list's storage primitives.
    pub fn env(&self) -> ContractEnv {
        let parent: &ContractEnv = &self.env;
        parent.child(self.index)
    }
}
impl<T> ModuleComponent for List<T> {
    /// Builds a list positioned at `index` inside `env`.
    ///
    /// Both primitives are placed under `env.child(index)`: slot 0 holds the
    /// element mapping and slot 1 holds the length counter.
    fn instance(env: Rc<ContractEnv>, index: u8) -> Self {
        let values = Mapping::instance(env.child(index).into(), 0);
        let current_index = Var::instance(env.child(index).into(), 1);
        Self {
            env,
            index,
            values,
            current_index,
        }
    }
}
// Opts `List` into `ModulePrimitive` (trait from `crate::module`); the impl is
// empty, so no trait items are overridden here.
impl<T> ModulePrimitive for List<T> {}
impl<T> List<T> {
    /// Returns `true` when the stored length is zero.
    pub fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }

    /// Returns the number of elements, read from the stored length counter.
    ///
    /// A counter that was never written reads as its default value.
    pub fn len(&self) -> u32 {
        self.current_index.get_or_default()
    }
}
impl<T: FromBytes + CLTyped> List<T> {
    /// Returns the element at `index`, or `None` if `index` is not below the
    /// current length.
    ///
    /// The length check is required because `pop` only decrements the stored
    /// length and leaves the popped value in the underlying mapping. Without
    /// the check, a popped element would still be readable here (and through
    /// `get_or_default`).
    pub fn get(&self, index: u32) -> Option<T> {
        if index >= self.len() {
            return None;
        }
        self.values.get(&index)
    }
}
impl<T: ToBytes + FromBytes + CLTyped> List<T> {
    /// Appends `value` at the end of the list.
    ///
    /// The value is written at position `len()`, then the stored length is
    /// bumped by one.
    pub fn push(&mut self, value: T) {
        let slot = self.len();
        self.values.set(&slot, value);
        self.current_index.set(slot + 1);
    }

    /// Overwrites the element at `index` with `value` and returns the old one.
    ///
    /// Reverts with `CollectionError::IndexOutOfBounds` when `index` is not
    /// below the current length. It also reverts (via `unwrap_or_revert`) if
    /// the slot unexpectedly holds no value.
    pub fn replace(&mut self, index: u32, value: T) -> T {
        let env = self.env();
        if self.len() <= index {
            env.revert(CollectionError::IndexOutOfBounds);
        }
        let previous = self.values.get(&index).unwrap_or_revert(&env);
        self.values.set(&index, value);
        previous
    }

    /// Removes and returns the last element, or `None` if the list is empty.
    ///
    /// Only the stored length is decremented; the value itself stays in the
    /// underlying mapping. It reverts (via `unwrap_or_revert`) if the last
    /// slot unexpectedly holds no value.
    pub fn pop(&mut self) -> Option<T> {
        let tail = self.len().checked_sub(1)?;
        let env = self.env();
        let popped = self.values.get(&tail).unwrap_or_revert(&env);
        self.current_index.set(tail);
        Some(popped)
    }

    /// Returns a double-ended iterator over the elements, front to back.
    pub fn iter(&self) -> ListIter<T> {
        ListIter::new(self)
    }
}
/// Iterator over the elements of a `List`, created by `List::iter`.
///
/// It walks the index range `0..len` captured at creation and reads each
/// element on demand through `List::get`; both ends can be consumed.
pub struct ListIter<'a, T> {
    /// The list being iterated, borrowed for the iterator's lifetime.
    list: &'a List<T>,
    /// Indices that have not been yielded yet from either end.
    range: Range<u32>,
}
impl<'a, T> ListIter<'a, T> {
fn new(list: &'a List<T>) -> Self {
Self {
list,
range: Range {
start: 0,
end: list.len()
}
}
}
fn remaining(&self) -> usize {
(self.range.end - self.range.start) as usize
}
}
impl<'a, T> core::iter::Iterator for ListIter<'a, T>
where
    T: ToBytes + FromBytes + CLTyped,
{
    type Item = T;

    /// Takes the next index off the front of the range and reads that element.
    fn next(&mut self) -> Option<Self::Item> {
        let position = self.range.next()?;
        self.list.get(position)
    }

    /// Both bounds equal the number of unvisited indices.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let left = self.remaining();
        (left, Some(left))
    }

    /// Counts the unvisited indices without reading any element.
    fn count(self) -> usize {
        self.remaining()
    }

    /// Skips `n` indices without reading them, then reads the following one.
    fn nth(&mut self, n: usize) -> Option<Self::Item> {
        self.range
            .nth(n)
            .and_then(|position| self.list.get(position))
    }
}
// Exact length: `size_hint` reports the number of unvisited indices as both
// bounds. NOTE(review): this assumes every index in the range yields a value
// from `List::get`; a missing entry would end iteration early.
impl<'a, T> core::iter::ExactSizeIterator for ListIter<'a, T> where T: ToBytes + FromBytes + CLTyped {}
// Fused: once the index range is exhausted it stays exhausted, so `next`
// keeps returning `None` (given every in-range index yields a value).
impl<'a, T> core::iter::FusedIterator for ListIter<'a, T> where T: ToBytes + FromBytes + CLTyped {}
impl<'a, T> core::iter::DoubleEndedIterator for ListIter<'a, T>
where
    T: ToBytes + FromBytes + CLTyped,
{
    /// Takes the last unvisited index off the back of the range and reads it.
    fn next_back(&mut self) -> Option<Self::Item> {
        self.range
            .next_back()
            .and_then(|position| self.list.get(position))
    }
}
impl<T: ToBytes + FromBytes + CLTyped + Default> List<T> {
    /// Returns the element at `index`, or `T::default()` when `List::get`
    /// returns `None`.
    pub fn get_or_default(&self, index: u32) -> T {
        self.get(index).unwrap_or_else(T::default)
    }
}
#[cfg(all(feature = "mock-vm", test))]
mod tests {
    use super::List;
    use crate::{instance::StaticInstance, test_env};
    use odra_types::{
        casper_types::bytesrepr::{FromBytes, ToBytes},
        CollectionError,
    };

    /// Creates a fresh list and pushes `0..count` into it, in order.
    fn list_of_range(count: u8) -> List<u8> {
        let mut list = List::<u8>::default();
        (0..count).for_each(|value| list.push(value));
        list
    }

    #[test]
    fn test_getting_items() {
        let mut list = List::<u8>::default();
        assert_eq!(list.len(), 0);

        list.push(0u8);
        assert_eq!(list.get(0).unwrap(), 0);

        for value in [1u8, 3u8] {
            list.push(value);
        }
        assert_eq!(list.get(1).unwrap(), 1);
        assert_eq!(list.get(2).unwrap(), 3);

        // Reading far past the end yields nothing.
        assert_eq!(list.get(100), None);
    }

    #[test]
    fn test_replace() {
        let mut list = list_of_range(5);

        let previous = list.replace(4, 10);
        assert_eq!(previous, 4);
        assert_eq!(list.get(4).unwrap(), 10);

        test_env::assert_exception(CollectionError::IndexOutOfBounds, || {
            list.replace(100, 99);
        });
    }

    #[test]
    fn test_list_len() {
        let mut list = List::<u8>::default();
        assert_eq!(list.len(), 0);
        for value in [0u8, 1, 3] {
            list.push(value);
        }
        assert_eq!(list.len(), 3);
    }

    #[test]
    fn test_list_is_empty() {
        let mut list = List::<u8>::default();
        assert!(list.is_empty());
        list.push(9u8);
        assert!(!list.is_empty());
    }

    #[test]
    fn test_pop() {
        let mut list = List::<u8>::default();
        list.push(1u8);
        list.push(2u8);

        // Elements come back last-in first-out and the length shrinks each time.
        for (expected, len_after) in [(2u8, 1u32), (1u8, 0u32)] {
            assert_eq!(list.pop(), Some(expected));
            assert_eq!(list.len(), len_after);
        }
        assert_eq!(list.pop(), None);
    }

    #[test]
    fn test_iter() {
        let list = list_of_range(5);
        let mut iter = list.iter();
        for expected in 0..5u8 {
            assert_eq!(iter.next(), Some(expected));
        }
        assert_eq!(iter.next(), None);
    }

    #[test]
    fn test_fuse_iter() {
        let list = list_of_range(3);
        let mut iter = list.iter().fuse();
        for _ in 0..3 {
            iter.next();
        }
        for _ in 0..3 {
            assert_eq!(iter.next(), None);
        }
    }

    #[test]
    fn test_double_ended_iter() {
        let list = list_of_range(10);
        let mut iter = list.iter();
        assert_eq!(iter.next(), Some(0));
        assert_eq!(iter.next(), Some(1));
        assert_eq!(iter.next_back(), Some(9));
        assert_eq!(iter.next_back(), Some(8));
        assert_eq!(iter.next(), Some(2));
        // Indices 3..8 are still unvisited.
        assert_eq!(iter.count(), 5);
    }

    impl<T: ToBytes + FromBytes> Default for List<T> {
        fn default() -> Self {
            StaticInstance::instance(&["list_val", "list_idx"]).0
        }
    }
}