cita_state/
state.rs

1use std::cell::RefCell;
2use std::sync::Arc;
3
4use cita_trie::DB;
5use cita_trie::{PatriciaTrie, Trie};
6use ethereum_types::{Address, H256, U256};
7use hashbrown::hash_map::Entry;
8use hashbrown::{HashMap, HashSet};
9
10use super::account::StateObject;
11use super::account_db::AccountDB;
12use super::err::Error;
13use super::hashlib;
14use super::object_entry::{ObjectStatus, StateObjectEntry};
15use log::debug;
16use rayon::prelude::{IntoParallelRefMutIterator, ParallelIterator};
17/// State is the one who managers all accounts and states in Ethereum's system.
18pub struct State<B> {
19    pub db: Arc<B>,
20    pub root: H256,
21    pub cache: RefCell<HashMap<Address, StateObjectEntry>>,
22    /// Checkpoints are used to revert to history.
23    pub checkpoints: RefCell<Vec<HashMap<Address, Option<StateObjectEntry>>>>,
24}
25
26impl<B: DB> State<B> {
27    /// Creates empty state for test.
28    pub fn new(db: Arc<B>) -> Result<State<B>, Error> {
29        let mut trie = PatriciaTrie::<_, cita_trie::Keccak256Hash>::new(Arc::clone(&db));
30        let root = trie.root()?;
31
32        Ok(State {
33            db,
34            root: From::from(&root[..]),
35            cache: RefCell::new(HashMap::new()),
36            checkpoints: RefCell::new(Vec::new()),
37        })
38    }
39
40    /// Creates new state with existing state root
41    pub fn from_existing(db: Arc<B>, root: H256) -> Result<State<B>, Error> {
42        if !db.contains(&root.0[..]).or_else(|e| Err(Error::DB(format!("{}", e))))? {
43            return Err(Error::NotFound);
44        }
45        Ok(State {
46            db,
47            root,
48            cache: RefCell::new(HashMap::new()),
49            checkpoints: RefCell::new(Vec::new()),
50        })
51    }
52
53    /// Create a contract account with code or not
54    /// Overwrite the code if the contract already exists
55    pub fn new_contract(&mut self, contract: &Address, balance: U256, nonce: U256, code: Vec<u8>) -> StateObject {
56        debug!(
57            "state.new_contract contract={:?} balance={:?}, nonce={:?} code={:?}",
58            contract, balance, nonce, code
59        );
60        let mut state_object = StateObject::new(balance, nonce);
61        state_object.init_code(code);
62
63        self.insert_cache(contract, StateObjectEntry::new_dirty(Some(state_object.clone_dirty())));
64        state_object
65    }
66
67    /// Kill a contract.
68    pub fn kill_contract(&mut self, contract: &Address) {
69        debug!("state.kill_contract contract={:?}", contract);
70        self.insert_cache(contract, StateObjectEntry::new_dirty(None));
71    }
72
73    /// Remove any touched empty or dust accounts.
74    pub fn kill_garbage(&mut self, inused: &HashSet<Address>) {
75        for a in inused {
76            if let Some(state_object_entry) = self.cache.borrow().get(a) {
77                if state_object_entry.state_object.is_none() {
78                    continue;
79                }
80            }
81            if self.is_empty(a).unwrap_or(false) {
82                self.kill_contract(a)
83            }
84        }
85    }
86
87    /// Clear cache
88    /// Note that the cache is just a HashMap, so memory explosion will be
89    /// happend if you never call `clear()`. You should decide for yourself
90    /// when to call this function.
91    pub fn clear(&mut self) {
92        assert!(self.checkpoints.borrow().is_empty());
93        self.cache.borrow_mut().clear();
94    }
95
96    /// Use a callback function to avoid clone data in caches.
97    fn call_with_cached<F, U>(&self, address: &Address, f: F) -> Result<U, Error>
98    where
99        F: Fn(Option<&StateObject>) -> U,
100    {
101        if let Some(state_object_entry) = self.cache.borrow().get(address) {
102            if let Some(state_object) = &state_object_entry.state_object {
103                return Ok(f(Some(state_object)));
104            } else {
105                return Ok(f(None));
106            }
107        }
108        let trie = PatriciaTrie::<_, cita_trie::Keccak256Hash>::from(Arc::clone(&self.db), &self.root.0)?;
109        match trie.get(hashlib::summary(&address[..]).as_slice())? {
110            Some(rlp) => {
111                let mut state_object = StateObject::from_rlp(&rlp)?;
112                state_object.read_code(self.db.clone())?;
113                self.insert_cache(address, StateObjectEntry::new_clean(Some(state_object.clone_clean())));
114                Ok(f(Some(&state_object)))
115            }
116            None => Ok(f(None)),
117        }
118    }
119
120    /// Get state object.
121    pub fn get_state_object(&self, address: &Address) -> Result<Option<StateObject>, Error> {
122        if let Some(state_object_entry) = self.cache.borrow().get(address) {
123            if let Some(state_object) = &state_object_entry.state_object {
124                return Ok(Some((*state_object).clone_dirty()));
125            }
126        }
127        let trie = PatriciaTrie::<_, cita_trie::Keccak256Hash>::from(Arc::clone(&self.db), &self.root.0)?;
128        match trie.get(hashlib::summary(&address[..]).as_slice())? {
129            Some(rlp) => {
130                let mut state_object = StateObject::from_rlp(&rlp)?;
131                state_object.read_code(self.db.clone())?;
132                self.insert_cache(address, StateObjectEntry::new_clean(Some(state_object.clone_clean())));
133                Ok(Some(state_object))
134            }
135            None => Ok(None),
136        }
137    }
138
139    /// Get state object. If not exists, create a fresh one.
140    pub fn get_state_object_or_default(&mut self, address: &Address) -> Result<StateObject, Error> {
141        match self.get_state_object(address)? {
142            Some(state_object) => Ok(state_object),
143            None => {
144                let state_object = self.new_contract(address, U256::zero(), U256::zero(), vec![]);
145                Ok(state_object)
146            }
147        }
148    }
149
150    /// Get the merkle proof for a given account.
151    pub fn get_account_proof(&self, address: &Address) -> Result<Vec<Vec<u8>>, Error> {
152        let trie = PatriciaTrie::<_, cita_trie::Keccak256Hash>::from(Arc::clone(&self.db), &self.root.0)?;
153        let proof = trie.get_proof(hashlib::summary(&address[..]).as_slice())?;
154        Ok(proof)
155    }
156
157    /// Get the storage proof for given account and key.
158    pub fn get_storage_proof(&self, address: &Address, key: &H256) -> Result<Vec<Vec<u8>>, Error> {
159        self.call_with_cached(address, |a| match a {
160            Some(data) => {
161                let accdb = Arc::new(AccountDB::new(*address, self.db.clone()));
162                data.get_storage_proof(accdb, key)
163            }
164            None => Ok(vec![]),
165        })?
166    }
167
168    /// Check if an account exists.
169    pub fn exist(&mut self, address: &Address) -> Result<bool, Error> {
170        self.call_with_cached(address, |a| Ok(a.is_some()))?
171    }
172
173    /// Check if an account is empty. Empty is defined according to
174    /// EIP161 (balance = nonce = code = 0).
175    #[allow(clippy::wrong_self_convention)]
176    pub fn is_empty(&mut self, address: &Address) -> Result<bool, Error> {
177        self.call_with_cached(address, |a| match a {
178            Some(data) => Ok(data.is_empty()),
179            None => Ok(true),
180        })?
181    }
182
183    /// Set (key, value) in storage cache.
184    pub fn set_storage(&mut self, address: &Address, key: H256, value: H256) -> Result<(), Error> {
185        debug!(
186            "state.set_storage address={:?} key={:?} value={:?}",
187            address, key, value
188        );
189        let state_object = self.get_state_object_or_default(address)?;
190        let accdb = Arc::new(AccountDB::new(*address, self.db.clone()));
191        if state_object.get_storage(accdb, &key)? == Some(value) {
192            return Ok(());
193        }
194
195        self.add_checkpoint(address);
196        if let Some(ref mut state_object_entry) = self.cache.borrow_mut().get_mut(address) {
197            match state_object_entry.state_object {
198                Some(ref mut state_object) => {
199                    state_object.set_storage(key, value);
200                    state_object_entry.status = ObjectStatus::Dirty;
201                }
202                None => panic!("state object always exist in cache."),
203            }
204        }
205        Ok(())
206    }
207
208    /// Set code for an account.
209    pub fn set_code(&mut self, address: &Address, code: Vec<u8>) -> Result<(), Error> {
210        debug!("state.set_code address={:?} code={:?}", address, code);
211        let mut state_object = self.get_state_object_or_default(address)?;
212        state_object.init_code(code);
213        self.insert_cache(address, StateObjectEntry::new_dirty(Some(state_object)));
214        Ok(())
215    }
216
217    /// Add balance by incr for an account.
218    pub fn add_balance(&mut self, address: &Address, incr: U256) -> Result<(), Error> {
219        debug!("state.add_balance a={:?} incr={:?}", address, incr);
220        if incr.is_zero() {
221            return Ok(());
222        }
223        let mut state_object = self.get_state_object_or_default(address)?;
224        if state_object.balance.overflowing_add(incr).1 {
225            return Err(Error::BalanceError);
226        }
227        state_object.add_balance(incr);
228        self.insert_cache(address, StateObjectEntry::new_dirty(Some(state_object)));
229        Ok(())
230    }
231
232    /// Sub balance by decr for an account.
233    pub fn sub_balance(&mut self, a: &Address, decr: U256) -> Result<(), Error> {
234        debug!("state.sub_balance a={:?} decr={:?}", a, decr);
235        if decr.is_zero() {
236            return Ok(());
237        }
238        let mut state_object = self.get_state_object_or_default(a)?;
239        if state_object.balance.overflowing_sub(decr).1 {
240            return Err(Error::BalanceError);
241        }
242        state_object.sub_balance(decr);
243        self.insert_cache(a, StateObjectEntry::new_dirty(Some(state_object)));
244        Ok(())
245    }
246
247    /// Transfer balance from `from` to `to` by `by`.
248    pub fn transfer_balance(&mut self, from: &Address, to: &Address, by: U256) -> Result<(), Error> {
249        self.sub_balance(from, by)?;
250        self.add_balance(to, by)?;
251        Ok(())
252    }
253
254    /// Increase nonce for an account.
255    pub fn inc_nonce(&mut self, address: &Address) -> Result<(), Error> {
256        debug!("state.inc_nonce a={:?}", address);
257        let mut state_object = self.get_state_object_or_default(address)?;
258        state_object.inc_nonce();
259        self.insert_cache(address, StateObjectEntry::new_dirty(Some(state_object)));
260        Ok(())
261    }
262
263    /// Insert a state object entry into cache.
264    fn insert_cache(&self, address: &Address, state_object_entry: StateObjectEntry) {
265        let is_dirty = state_object_entry.is_dirty();
266        let old_entry = self
267            .cache
268            .borrow_mut()
269            .insert(*address, state_object_entry.clone_dirty());
270
271        if is_dirty {
272            if let Some(checkpoint) = self.checkpoints.borrow_mut().last_mut() {
273                checkpoint.entry(*address).or_insert(old_entry);
274            }
275        }
276    }
277
278    /// Flush the data from cache to database.
279    pub fn commit(&mut self) -> Result<(), Error> {
280        assert!(self.checkpoints.borrow().is_empty());
281
282        // Firstly, update account storage tree
283        let db = Arc::clone(&self.db);
284        self.cache
285            .borrow_mut()
286            .par_iter_mut()
287            .map(|(address, entry)| {
288                if !entry.is_dirty() {
289                    return Ok(());
290                }
291
292                if let Some(ref mut state_object) = entry.state_object {
293                    let accdb = Arc::new(AccountDB::new(*address, Arc::clone(&db)));
294                    state_object.commit_storage(Arc::clone(&accdb))?;
295                    state_object.commit_code(Arc::clone(&db))?;
296                }
297                Ok(())
298            })
299            .collect::<Result<(), Error>>()?;
300
301        // Secondly, update the world state tree
302        let mut trie = PatriciaTrie::<_, cita_trie::Keccak256Hash>::from(Arc::clone(&self.db), &self.root.0)?;
303        let key_values = self
304            .cache
305            .borrow_mut()
306            .par_iter_mut()
307            .filter(|&(_, ref a)| a.is_dirty())
308            .map(|(address, entry)| {
309                entry.status = ObjectStatus::Committed;
310
311                match entry.state_object {
312                    Some(ref mut state_object) => {
313                        (hashlib::summary(&address[..]), rlp::encode(&state_object.account()))
314                    }
315                    None => (hashlib::summary(&address[..]), vec![]),
316                }
317            })
318            .collect::<Vec<(Vec<u8>, Vec<u8>)>>();
319
320        for (key, value) in key_values.into_iter() {
321            trie.insert(key, value)?;
322        }
323
324        self.root = From::from(&trie.root()?[..]);
325        self.db.flush().or_else(|e| Err(Error::DB(format!("{}", e))))
326    }
327
328    /// Create a recoverable checkpoint of this state. Return the checkpoint index.
329    pub fn checkpoint(&mut self) -> usize {
330        debug!("state.checkpoint");
331        let mut checkpoints = self.checkpoints.borrow_mut();
332        let index = checkpoints.len();
333        checkpoints.push(HashMap::new());
334        index
335    }
336
337    fn add_checkpoint(&self, address: &Address) {
338        if let Some(ref mut checkpoint) = self.checkpoints.borrow_mut().last_mut() {
339            checkpoint
340                .entry(*address)
341                .or_insert_with(|| self.cache.borrow().get(address).map(StateObjectEntry::clone_dirty));
342        }
343    }
344
345    /// Merge last checkpoint with previous.
346    pub fn discard_checkpoint(&mut self) {
347        debug!("state.discard_checkpoint");
348        let last = self.checkpoints.borrow_mut().pop();
349        if let Some(mut checkpoint) = last {
350            if let Some(prev) = self.checkpoints.borrow_mut().last_mut() {
351                if prev.is_empty() {
352                    *prev = checkpoint;
353                } else {
354                    for (k, v) in checkpoint.drain() {
355                        prev.entry(k).or_insert(v);
356                    }
357                }
358            }
359        }
360    }
361
362    /// Revert to the last checkpoint and discard it.
363    pub fn revert_checkpoint(&mut self) {
364        debug!("state.revert_checkpoint");
365        if let Some(mut last) = self.checkpoints.borrow_mut().pop() {
366            for (k, v) in last.drain() {
367                match v {
368                    Some(v) => match self.cache.borrow_mut().entry(k) {
369                        Entry::Occupied(mut e) => {
370                            // Merge checkpointed changes back into the main account
371                            // storage preserving the cache.
372                            e.get_mut().merge(v);
373                        }
374                        Entry::Vacant(e) => {
375                            e.insert(v);
376                        }
377                    },
378                    None => {
379                        if let Entry::Occupied(e) = self.cache.borrow_mut().entry(k) {
380                            if e.get().is_dirty() {
381                                e.remove();
382                            }
383                        }
384                    }
385                }
386            }
387        }
388    }
389}
390
391pub trait StateObjectInfo {
392    fn nonce(&mut self, a: &Address) -> Result<U256, Error>;
393
394    fn balance(&mut self, a: &Address) -> Result<U256, Error>;
395
396    fn get_storage(&mut self, a: &Address, key: &H256) -> Result<H256, Error>;
397
398    fn code(&mut self, a: &Address) -> Result<Vec<u8>, Error>;
399
400    fn code_hash(&mut self, a: &Address) -> Result<H256, Error>;
401
402    fn code_size(&mut self, a: &Address) -> Result<usize, Error>;
403}
404
405impl<B: DB> StateObjectInfo for State<B> {
406    fn nonce(&mut self, address: &Address) -> Result<U256, Error> {
407        self.call_with_cached(address, |a| Ok(a.map_or(U256::zero(), |e| e.nonce)))?
408    }
409
410    fn balance(&mut self, address: &Address) -> Result<U256, Error> {
411        self.call_with_cached(address, |a| Ok(a.map_or(U256::zero(), |e| e.balance)))?
412    }
413
414    fn get_storage(&mut self, address: &Address, key: &H256) -> Result<H256, Error> {
415        self.call_with_cached(address, |a| match a {
416            Some(state_object) => {
417                let accdb = Arc::new(AccountDB::new(*address, self.db.clone()));
418                match state_object.get_storage(accdb, key)? {
419                    Some(v) => Ok(v),
420                    None => Ok(H256::zero()),
421                }
422            }
423            None => Ok(H256::zero()),
424        })?
425    }
426
427    fn code(&mut self, address: &Address) -> Result<Vec<u8>, Error> {
428        self.call_with_cached(address, |a| Ok(a.map_or(vec![], |e| e.code.clone())))?
429    }
430
431    fn code_hash(&mut self, address: &Address) -> Result<H256, Error> {
432        self.call_with_cached(address, |a| Ok(a.map_or(H256::zero(), |e| e.code_hash)))?
433    }
434
435    fn code_size(&mut self, address: &Address) -> Result<usize, Error> {
436        self.call_with_cached(address, |a| Ok(a.map_or(0, |e| e.code_size)))?
437    }
438}
439
#[cfg(test)]
mod tests {
    use super::*;
    use cita_trie::MemoryDB;
    use std::sync::Arc;

    /// Expected `code_hash` of an account whose code is `[1, 2, 3]`.
    const CODE_123_HASH: &str = "0xf1885eda54b7a053318cd41e2093220dab15d65381b1157a3633a83bfd5c9239";

    /// Builds a fresh, empty state on an in-memory database.
    fn get_temp_state() -> State<MemoryDB> {
        State::new(Arc::new(MemoryDB::new(false))).unwrap()
    }

    /// Reopens the state at `root` on `db`, panicking if the root is unknown.
    fn reopen(db: Arc<MemoryDB>, root: H256) -> State<MemoryDB> {
        State::from_existing(db, root).unwrap()
    }

    /// Asserts that `addr` holds the code `[1, 2, 3]` with matching hash and size.
    fn assert_code_123(state: &mut State<MemoryDB>, addr: &Address) {
        assert_eq!(state.code(addr).unwrap(), vec![1, 2, 3]);
        assert_eq!(state.code_hash(addr).unwrap(), CODE_123_HASH.into());
        assert_eq!(state.code_size(addr).unwrap(), 3);
    }

    #[test]
    fn test_code_from_database() {
        let addr = Address::zero();
        let (root, db) = {
            let mut state = get_temp_state();
            state.set_code(&addr, vec![1, 2, 3]).unwrap();
            assert_code_123(&mut state, &addr);
            state.commit().unwrap();
            assert_code_123(&mut state, &addr);
            (state.root, state.db)
        };

        let mut reopened = reopen(db, root);
        assert_code_123(&mut reopened, &addr);
    }

    #[test]
    fn get_storage_from_datebase() {
        let addr = Address::zero();
        let key = H256::from(&U256::from(1u64));
        let value = H256::from(&U256::from(69u64));
        let (root, db) = {
            let mut state = get_temp_state();
            state.set_storage(&addr, key, value).unwrap();
            state.commit().unwrap();
            (state.root, state.db)
        };

        let mut reopened = reopen(db, root);
        assert_eq!(reopened.get_storage(&addr, &key).unwrap(), value);
    }

    #[test]
    fn get_from_database() {
        let addr = Address::zero();
        let (root, db) = {
            let mut state = get_temp_state();
            state.inc_nonce(&addr).unwrap();
            state.add_balance(&addr, U256::from(69u64)).unwrap();
            state.commit().unwrap();
            assert_eq!(state.balance(&addr).unwrap(), U256::from(69u64));
            assert_eq!(state.nonce(&addr).unwrap(), U256::from(1u64));
            (state.root, state.db)
        };

        let mut reopened = reopen(db, root);
        assert_eq!(reopened.balance(&addr).unwrap(), U256::from(69u64));
        assert_eq!(reopened.nonce(&addr).unwrap(), U256::from(1u64));
    }

    #[test]
    fn remove() {
        let addr = Address::zero();
        let mut state = get_temp_state();
        assert!(!state.exist(&addr).unwrap());
        state.inc_nonce(&addr).unwrap();
        assert!(state.exist(&addr).unwrap());
        assert_eq!(state.nonce(&addr).unwrap(), U256::from(1u64));
        state.kill_contract(&addr);
        assert!(!state.exist(&addr).unwrap());
        assert_eq!(state.nonce(&addr).unwrap(), U256::from(0u64));
    }

    #[test]
    fn remove_from_database() {
        let addr = Address::zero();
        let (root, db) = {
            let mut state = get_temp_state();
            state.add_balance(&addr, U256::from(69u64)).unwrap();
            state.commit().unwrap();
            (state.root, state.db)
        };

        let (root, db) = {
            let mut state = reopen(db, root);
            assert!(state.exist(&addr).unwrap());
            assert_eq!(state.balance(&addr).unwrap(), U256::from(69u64));
            state.kill_contract(&addr);
            state.commit().unwrap();
            assert!(!state.exist(&addr).unwrap());
            assert_eq!(state.balance(&addr).unwrap(), U256::from(0u64));
            (state.root, state.db)
        };

        let mut reopened = reopen(db, root);
        assert!(!reopened.exist(&addr).unwrap());
        assert_eq!(reopened.balance(&addr).unwrap(), U256::from(0u64));
    }

    #[test]
    fn alter_balance() {
        let mut state = get_temp_state();
        let alice = Address::zero();
        let bob: Address = 1u64.into();

        state.add_balance(&alice, U256::from(69u64)).unwrap();
        assert_eq!(state.balance(&alice).unwrap(), U256::from(69u64));
        state.commit().unwrap();
        assert_eq!(state.balance(&alice).unwrap(), U256::from(69u64));

        state.sub_balance(&alice, U256::from(42u64)).unwrap();
        assert_eq!(state.balance(&alice).unwrap(), U256::from(27u64));
        state.commit().unwrap();
        assert_eq!(state.balance(&alice).unwrap(), U256::from(27u64));

        state.transfer_balance(&alice, &bob, U256::from(18)).unwrap();
        assert_eq!(state.balance(&alice).unwrap(), U256::from(9u64));
        assert_eq!(state.balance(&bob).unwrap(), U256::from(18u64));
        state.commit().unwrap();
        assert_eq!(state.balance(&alice).unwrap(), U256::from(9u64));
        assert_eq!(state.balance(&bob).unwrap(), U256::from(18u64));
    }

    #[test]
    fn alter_nonce() {
        let mut state = get_temp_state();
        let addr = Address::zero();
        for expected in 1u64..=2 {
            state.inc_nonce(&addr).unwrap();
            assert_eq!(state.nonce(&addr).unwrap(), U256::from(expected));
        }
        state.commit().unwrap();
        assert_eq!(state.nonce(&addr).unwrap(), U256::from(2u64));
        state.inc_nonce(&addr).unwrap();
        assert_eq!(state.nonce(&addr).unwrap(), U256::from(3u64));
        state.commit().unwrap();
        assert_eq!(state.nonce(&addr).unwrap(), U256::from(3u64));
    }

    #[test]
    fn balance_nonce() {
        let mut state = get_temp_state();
        let addr = Address::zero();
        let zero = U256::from(0u64);
        assert_eq!(state.balance(&addr).unwrap(), zero);
        assert_eq!(state.nonce(&addr).unwrap(), zero);
        state.commit().unwrap();
        assert_eq!(state.balance(&addr).unwrap(), zero);
        assert_eq!(state.nonce(&addr).unwrap(), zero);
    }

    #[test]
    fn ensure_cached() {
        let mut state = get_temp_state();
        let addr = Address::zero();
        state.new_contract(&addr, U256::from(0u64), U256::from(0u64), vec![]);
        state.commit().unwrap();
        assert_eq!(
            state.root,
            "0ce23f3c809de377b008a4a3ee94a0834aac8bec1f86e28ffe4fdb5a15b0c785".into()
        );
    }

    #[test]
    fn checkpoint_basic() {
        let mut state = get_temp_state();
        let addr = Address::zero();

        // A discarded checkpoint keeps its changes.
        state.checkpoint();
        state.add_balance(&addr, U256::from(69u64)).unwrap();
        assert_eq!(state.balance(&addr).unwrap(), U256::from(69u64));
        state.discard_checkpoint();
        assert_eq!(state.balance(&addr).unwrap(), U256::from(69u64));

        // A reverted checkpoint drops its changes.
        state.checkpoint();
        state.add_balance(&addr, U256::from(1u64)).unwrap();
        assert_eq!(state.balance(&addr).unwrap(), U256::from(70u64));
        state.revert_checkpoint();
        assert_eq!(state.balance(&addr).unwrap(), U256::from(69u64));
    }

    #[test]
    fn checkpoint_nested() {
        let mut state = get_temp_state();
        let addr = Address::zero();
        state.checkpoint();
        state.checkpoint();
        state.add_balance(&addr, U256::from(69u64)).unwrap();
        assert_eq!(state.balance(&addr).unwrap(), U256::from(69u64));
        state.discard_checkpoint();
        assert_eq!(state.balance(&addr).unwrap(), U256::from(69u64));
        state.revert_checkpoint();
        assert_eq!(state.balance(&addr).unwrap(), U256::from(0));
    }

    #[test]
    fn checkpoint_revert_to_get_storage() {
        let mut state = get_temp_state();
        let addr = Address::zero();
        let key = H256::from(U256::from(0));
        let one = H256::from(1u64);

        state.checkpoint();
        state.checkpoint();
        state.set_storage(&addr, key, one).unwrap();
        assert_eq!(state.get_storage(&addr, &key).unwrap(), one);
        state.revert_checkpoint();
        assert!(state.get_storage(&addr, &key).unwrap().is_zero());
    }

    #[test]
    fn checkpoint_kill_account() {
        let mut state = get_temp_state();
        let addr = Address::zero();
        let key = H256::from(U256::from(0));
        let one = H256::from(U256::from(1));
        state.checkpoint();
        state.set_storage(&addr, key, one).unwrap();
        state.checkpoint();
        state.kill_contract(&addr);
        assert!(state.get_storage(&addr, &key).unwrap().is_zero());
        state.revert_checkpoint();
        assert_eq!(state.get_storage(&addr, &key).unwrap(), one);
    }

    #[test]
    fn checkpoint_create_contract_fail() {
        let mut state = get_temp_state();
        let orig_root = state.root;
        let addr: Address = 1000.into();

        state.checkpoint(); // c1
        state.new_contract(&addr, U256::zero(), U256::zero(), vec![]);
        state.add_balance(&addr, U256::from(1)).unwrap();
        state.checkpoint(); // c2
        state.add_balance(&addr, U256::from(1)).unwrap();
        state.discard_checkpoint(); // fold c2 into c1
        state.revert_checkpoint(); // revert to c1
        assert!(!state.exist(&addr).unwrap());
        state.commit().unwrap();
        assert_eq!(orig_root, state.root);
    }

    #[test]
    fn create_contract_fail_previous_storage() {
        let mut state = get_temp_state();
        let addr: Address = 1000.into();
        let key = H256::from(U256::from(0));
        let stored = H256::from(U256::from(0xffff));

        state.set_storage(&addr, key, stored).unwrap();
        state.commit().unwrap();
        state.clear();

        let orig_root = state.root;
        assert_eq!(state.get_storage(&addr, &key).unwrap(), stored);
        state.clear();

        state.checkpoint(); // c1
        state.new_contract(&addr, U256::zero(), U256::zero(), vec![]);
        state.checkpoint(); // c2
        state.set_storage(&addr, key, H256::from(U256::from(2))).unwrap();
        state.revert_checkpoint(); // revert to c2
        assert_eq!(state.get_storage(&addr, &key).unwrap(), H256::from(U256::from(0)));
        state.revert_checkpoint(); // revert to c1
        assert_eq!(state.get_storage(&addr, &key).unwrap(), stored);

        state.commit().unwrap();
        assert_eq!(orig_root, state.root);
    }

    #[test]
    fn checkpoint_chores() {
        let mut state = get_temp_state();
        let acct_a: Address = 1000.into();
        let acct_b: Address = 2000.into();
        state.new_contract(&acct_a, 5.into(), 0.into(), vec![10u8, 20, 30, 40, 50]);
        state.add_balance(&acct_a, 5.into()).unwrap();
        state.set_storage(&acct_a, 10.into(), 10.into()).unwrap();
        assert_eq!(state.code(&acct_a).unwrap(), vec![10u8, 20, 30, 40, 50]);
        assert_eq!(state.balance(&acct_a).unwrap(), 10.into());
        assert_eq!(state.get_storage(&acct_a, &10.into()).unwrap(), 10.into());
        state.commit().unwrap();
        let orig_root = state.root;

        // Top (inside c2) => acct_a: balance=8, nonce=0, code=[10, 20, 30, 40, 50],
        //                 |      storage = { 10=15, 20=20 }
        //                 |  acct_b: balance=30, nonce=0, code=[]
        //
        // Checkpoint2     => acct_a: balance=8, nonce=0, code=[10, 20, 30, 40, 50],
        //                 |      storage = { 10=10, 20=20 }
        //                 |  acct_b: None
        //
        // Checkpoint1     => acct_a: balance=10, nonce=0, code=[10, 20, 30, 40, 50],
        //                 |      storage = { 10=10 }
        //                 |  acct_b: None

        state.checkpoint(); // c1
        state.sub_balance(&acct_a, 2.into()).unwrap();
        state.set_storage(&acct_a, 20.into(), 20.into()).unwrap();
        assert_eq!(state.balance(&acct_a).unwrap(), 8.into());
        assert_eq!(state.get_storage(&acct_a, &10.into()).unwrap(), 10.into());
        assert_eq!(state.get_storage(&acct_a, &20.into()).unwrap(), 20.into());

        state.checkpoint(); // c2
        state.new_contract(&acct_b, 30.into(), 0.into(), vec![]);
        state.set_storage(&acct_a, 10.into(), 15.into()).unwrap();
        assert_eq!(state.balance(&acct_b).unwrap(), 30.into());
        assert_eq!(state.code(&acct_b).unwrap(), vec![]);

        state.revert_checkpoint(); // revert c2
        assert_eq!(state.balance(&acct_a).unwrap(), 8.into());
        assert_eq!(state.get_storage(&acct_a, &10.into()).unwrap(), 10.into());
        assert_eq!(state.get_storage(&acct_a, &20.into()).unwrap(), 20.into());
        assert_eq!(state.balance(&acct_b).unwrap(), 0.into());
        assert_eq!(state.code(&acct_b).unwrap(), vec![]);
        assert!(!state.exist(&acct_b).unwrap());

        state.revert_checkpoint(); // revert c1
        assert_eq!(state.code(&acct_a).unwrap(), vec![10u8, 20, 30, 40, 50]);
        assert_eq!(state.balance(&acct_a).unwrap(), 10.into());
        assert_eq!(state.get_storage(&acct_a, &10.into()).unwrap(), 10.into());

        state.commit().unwrap();
        assert_eq!(orig_root, state.root);
    }

    #[test]
    fn get_account_proof() {
        let mut state = get_temp_state();
        let present: Address = 1000.into();
        let absent: Address = 2000.into();
        state.new_contract(&present, 5.into(), 0.into(), vec![10u8, 20, 30, 40, 50]);
        state.commit().unwrap();

        // The state only contains one account, so the trie is a single leaf node
        // and the proof length is 1.
        let proof_present = state.get_account_proof(&present).unwrap();
        assert_eq!(proof_present.len(), 1);

        // An account not in the state still gets a non-empty proof: the node on the
        // longest common prefix.
        let proof_absent = state.get_account_proof(&absent).unwrap();
        assert_eq!(proof_absent.len(), 1);

        assert_eq!(proof_present, proof_absent);
    }

    #[test]
    fn get_storage_proof() {
        let mut state = get_temp_state();
        let with_storage: Address = 1000.into();
        let without_storage: Address = 2000.into();
        let missing: Address = 3000.into();
        state.new_contract(&with_storage, 5.into(), 0.into(), vec![10u8, 20, 30, 40, 50]);
        state.set_storage(&with_storage, 10.into(), 10.into()).unwrap();
        state.new_contract(&without_storage, 5.into(), 0.into(), vec![10u8, 20, 30, 40, 50]);
        state.commit().unwrap();

        // Account does not exist.
        let proof = state.get_storage_proof(&missing, &10.into()).unwrap();
        assert_eq!(proof.len(), 0);

        // Account has an empty storage trie.
        let proof = state.get_storage_proof(&without_storage, &10.into()).unwrap();
        assert_eq!(proof.len(), 0);

        // Account and storage key both exist.
        let proof_hit = state.get_storage_proof(&with_storage, &10.into()).unwrap();
        assert_eq!(proof_hit.len(), 1);

        // Account exists but the storage key does not.
        let proof_miss = state.get_storage_proof(&with_storage, &20.into()).unwrap();
        assert_eq!(proof_miss.len(), 1);

        assert_eq!(proof_hit, proof_miss);
    }
}