1use crate::level_iter::LevelIter;
2use crate::update_map::UpdateMap;
3use crate::utils::{Length, updated_length};
4use crate::{
5 Cow, Error, Value,
6 interface_iter::{InterfaceIter, InterfaceIterCow},
7 iter::Iter,
8};
9use std::collections::BTreeMap;
10use std::marker::PhantomData;
11use tree_hash::Hash256;
12
13pub trait ImmList<T: Value> {
14 fn get(&self, idx: usize) -> Option<&T>;
15
16 fn len(&self) -> Length;
17
18 fn is_empty(&self) -> bool {
19 self.len().as_usize() == 0
20 }
21
22 fn iter_from(&self, index: usize) -> Iter<'_, T>;
23
24 fn level_iter_from(&self, index: usize) -> LevelIter<'_, T>;
25}
26
27pub trait MutList<T: Value>: ImmList<T> {
28 fn validate_push(current_len: usize) -> Result<(), Error>;
29 fn replace(&mut self, index: usize, value: T) -> Result<(), Error>;
30 fn update<U: UpdateMap<T>>(
31 &mut self,
32 updates: U,
33 hash_updates: Option<BTreeMap<(usize, usize), Hash256>>,
34 ) -> Result<(), Error>;
35}
36
37#[derive(Debug, PartialEq, Clone)]
38#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
39pub struct Interface<T, B, U>
40where
41 T: Value,
42 B: MutList<T>,
43 U: UpdateMap<T>,
44{
45 pub(crate) backing: B,
46 pub(crate) updates: U,
47 pub(crate) _phantom: PhantomData<T>,
48}
49
50impl<T, B, U> Interface<T, B, U>
51where
52 T: Value,
53 B: MutList<T>,
54 U: UpdateMap<T>,
55{
56 pub fn new(backing: B) -> Self {
57 Self {
58 backing,
59 updates: U::default(),
60 _phantom: PhantomData,
61 }
62 }
63
64 pub fn get(&self, idx: usize) -> Option<&T> {
65 self.updates.get(idx).or_else(|| self.backing.get(idx))
66 }
67
68 pub fn get_mut(&mut self, idx: usize) -> Option<&mut T> {
69 self.updates
70 .get_mut_with(idx, |idx| self.backing.get(idx).cloned())
71 }
72
73 pub fn get_cow(&mut self, index: usize) -> Option<Cow<'_, T>> {
74 self.updates
75 .get_cow_with(index, |idx| self.backing.get(idx))
76 }
77
78 pub fn push(&mut self, value: T) -> Result<(), Error> {
79 let index = self.len();
80 B::validate_push(index)?;
81 self.updates.insert(index, value);
82
83 Ok(())
84 }
85
86 pub fn apply_updates(&mut self) -> Result<(), Error> {
87 if !self.updates.is_empty() {
88 let updates = std::mem::take(&mut self.updates);
89 self.backing.update(updates, None)
90 } else {
91 Ok(())
92 }
93 }
94
95 pub fn has_pending_updates(&self) -> bool {
96 !self.updates.is_empty()
97 }
98
99 pub fn iter(&self) -> InterfaceIter<'_, T, U> {
100 self.iter_from(0)
101 }
102
103 pub fn iter_from(&self, index: usize) -> InterfaceIter<'_, T, U> {
104 InterfaceIter {
105 tree_iter: self.backing.iter_from(index),
106 updates: &self.updates,
107 index,
108 length: self.len(),
109 }
110 }
111
112 pub fn iter_cow(&mut self) -> InterfaceIterCow<'_, T, U> {
113 let index = 0;
114 InterfaceIterCow {
115 tree_iter: self.backing.iter_from(index),
116 updates: &mut self.updates,
117 index,
118 }
119 }
120
121 pub fn level_iter_from(&self, index: usize) -> Result<LevelIter<'_, T>, Error> {
122 if self.has_pending_updates() {
123 Err(Error::LevelIterPendingUpdates)
124 } else {
125 Ok(self.backing.level_iter_from(index))
126 }
127 }
128
129 pub fn len(&self) -> usize {
130 updated_length(self.backing.len(), &self.updates).as_usize()
131 }
132
133 pub fn is_empty(&self) -> bool {
134 self.len() == 0
135 }
136
137 pub fn bulk_update(&mut self, updates: U) -> Result<(), Error> {
138 if !self.updates.is_empty() {
139 return Err(Error::BulkUpdateUnclean);
140 }
141 self.updates = updates;
142 Ok(())
143 }
144}
145
#[cfg(test)]
mod test {
    use crate::List;
    use typenum::U8;

    /// Builds a `List<u64, U8>` holding `values`, panicking if construction fails.
    fn list_of(values: Vec<u64>) -> List<u64, U8> {
        List::<u64, U8>::new(values).unwrap()
    }

    #[test]
    fn basic_mutation() {
        let mut list = list_of(vec![1, 2, 3, 4]);

        let first = list.get_mut(0).unwrap();
        assert_eq!(*first, 1);
        *first = 11;

        // A second mutable access must observe the value written through the first.
        assert_eq!(*list.get_mut(0).unwrap(), 11);

        assert!(list.has_pending_updates());
        list.apply_updates().unwrap();
        assert!(!list.has_pending_updates());

        // Applying again with nothing pending must succeed.
        list.apply_updates().unwrap();

        assert_eq!(*list.get(0).unwrap(), 11);
    }

    #[test]
    fn cow_mutate_twice() {
        let mut list = list_of(vec![1, 2, 3]);

        // Each round: check the current value through a cow handle, overwrite it, re-read it.
        for (before, after) in [(1, 10), (10, 11)] {
            let cow = list.get_cow(0).unwrap();
            assert_eq!(*cow, before);
            *cow.into_mut().unwrap() = after;
            assert_eq!(*list.get(0).unwrap(), after);
        }

        assert_eq!(list.iter().cloned().collect::<Vec<_>>(), vec![11, 2, 3]);
    }

    #[test]
    fn cow_iter() {
        let mut list = list_of(vec![1, 2, 3]);

        let mut cursor = list.iter_cow();
        while let Some((position, value)) = cursor.next_cow() {
            // Overwrite every element with its own index.
            *value.into_mut().unwrap() = position as u64;
        }

        assert_eq!(list.to_vec(), vec![0, 1, 2]);
    }
}