1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
// Copyright (c) Aptos
// SPDX-License-Identifier: Apache-2.0

use aptos_crypto::HashValue;
use aptos_logger::info;
use aptos_types::{
    state_store::{state_key::StateKey, state_value::StateValue},
    transaction::Version,
};
use executor_types::StateSnapshotDelta;
use scratchpad::SparseMerkleTree;
use std::{
    collections::VecDeque,
    sync::{mpsc, mpsc::TryRecvError, Arc},
    time,
};
use storage_interface::DbWriter;

/// While the incoming channel is momentarily empty, pending snapshot deltas are only committed
/// once at least this many have accumulated; below it the committer sleeps and polls again.
/// (The final flush performed when the channel disconnects is not subject to this threshold.)
const NUM_MIN_COMMITS_TO_BATCH: usize = 2;
/// Upper bound on how many recently committed trees `StateCommitter` keeps in its
/// `cache_queue`; once the queue holds this many, the oldest is evicted before a new push.
const NUM_COMMITTED_SMTS_TO_CACHE: usize = 5;

/// Consumes [`StateSnapshotDelta`]s from a channel, batches them, and persists them through a
/// [`DbWriter`] as state snapshots.
pub struct StateCommitter {
    // ---- inputs / outputs ----
    /// Channel the snapshot deltas arrive on.
    commit_receiver: mpsc::Receiver<StateSnapshotDelta>,
    /// Writer that `save_state_snapshot` is called on.
    db: Arc<dyn DbWriter>,

    // ---- latest received state, not yet persisted ----
    /// Version carried by the most recently received delta (initialized to 0 by `new`).
    version: Version,
    /// Tree carried by the most recently received delta.
    smt: SparseMerkleTree<StateValue>,
    /// JMT updates from every delta received since the last commit, in arrival order.
    /// NOTE(review): presumably `(key hash, (value hash, key))` — confirm against
    /// `StateSnapshotDelta::jmt_updates`.
    updates: Vec<(HashValue, (HashValue, StateKey))>,
    /// Number of deltas received since the last commit.
    num_pending_commits: usize,

    // ---- last persisted state ----
    /// Tree as of the last commit (or the one passed to `new`).
    committed_smt: SparseMerkleTree<StateValue>,
    /// Version as of the last commit (or the one passed to `new`).
    /// NOTE(review): `None` presumably means nothing has been committed yet — confirm.
    committed_version: Option<Version>,
    /// Keeps some recently committed SMTs alive in memory as a naive cache; holds at most
    /// `NUM_COMMITTED_SMTS_TO_CACHE` trees, oldest at the front.
    cache_queue: VecDeque<SparseMerkleTree<StateValue>>,
}

impl StateCommitter {
    /// Creates a committer that consumes [`StateSnapshotDelta`]s from `commit_receiver` and
    /// persists them through `db`.
    ///
    /// `committed_smt` / `committed_version` describe the state the caller reports as already
    /// committed. The "latest" tree starts out equal to `committed_smt`, and the cache is
    /// seeded with it.
    pub fn new(
        commit_receiver: mpsc::Receiver<StateSnapshotDelta>,
        db: Arc<dyn DbWriter>,
        committed_smt: SparseMerkleTree<StateValue>,
        committed_version: Option<Version>,
    ) -> Self {
        let mut cache_queue = VecDeque::with_capacity(NUM_COMMITTED_SMTS_TO_CACHE);
        cache_queue.push_back(committed_smt.clone());

        Self {
            commit_receiver,
            db,

            cache_queue,
            // Placeholder. It is overwritten by the first received delta, and `run` never
            // calls `commit` before at least one delta has arrived.
            version: 0,
            smt: committed_smt.clone(),
            committed_smt,
            committed_version,
            updates: Vec::new(),
            num_pending_commits: 0,
        }
    }

    /// Runs the commit loop until every sender of `commit_receiver` has been dropped.
    ///
    /// Each received delta replaces the latest `version`/`smt` and appends its JMT updates to
    /// the pending batch. Whenever the channel is momentarily empty:
    /// - with at least [`NUM_MIN_COMMITS_TO_BATCH`] pending deltas, the batch is committed;
    /// - otherwise the thread sleeps for one second and polls again.
    ///
    /// Once the channel is disconnected and drained, any still-pending deltas are committed
    /// before returning. If nothing is pending, the loop returns without writing.
    ///
    /// # Panics
    ///
    /// Panics if a DB write fails (see `commit`).
    pub fn run(mut self) {
        loop {
            match self.commit_receiver.try_recv() {
                Ok(StateSnapshotDelta {
                    version,
                    smt,
                    jmt_updates,
                }) => {
                    self.version = version;
                    self.smt = smt;
                    self.updates.extend(jmt_updates);
                    self.num_pending_commits += 1;
                }
                Err(TryRecvError::Empty) => {
                    if self.num_pending_commits < NUM_MIN_COMMITS_TO_BATCH {
                        std::thread::sleep(time::Duration::from_secs(1));
                    } else {
                        self.commit();
                    }
                }
                Err(TryRecvError::Disconnected) => {
                    // Only flush if something arrived since the last commit. Otherwise we
                    // would either re-save the already-committed version, or (if no delta
                    // was ever received) save a snapshot at the placeholder version 0 that
                    // no delta produced.
                    if self.num_pending_commits > 0 {
                        info!("Final state commit...");
                        self.commit();
                    }
                    return;
                }
            }
        }
    }

    /// Saves all pending updates as one state snapshot at `self.version`.
    ///
    /// After saving, it marks `self.smt` as the committed tree and appends it to the bounded
    /// `cache_queue`.
    ///
    /// # Panics
    ///
    /// Panics if `DbWriter::save_state_snapshot` returns an error. `run` returns `()`, so there
    /// is no caller to propagate the error to.
    fn commit(&mut self) {
        info!(
            num_pending_commits = self.num_pending_commits,
            version = self.version,
            "Committing state.",
        );
        // Move the batch out, leaving an empty Vec for the next round.
        let to_commit = std::mem::take(&mut self.updates);
        // Hashes of nodes that are new relative to the last committed tree.
        // `freeze` consumes its receiver, hence the clones.
        let node_hashes = self
            .smt
            .clone()
            .freeze()
            .new_node_hashes_since(&self.committed_smt.clone().freeze());
        self.db
            .save_state_snapshot(
                to_commit,
                Some(&node_hashes),
                self.version,
                self.committed_version,
                self.smt.clone(),
            )
            .expect("Failed to save state snapshot.");
        info!("Committing state. Saved.");

        // The pending batch has been written; advance the committed markers.
        self.num_pending_commits = 0;
        self.committed_smt = self.smt.clone();
        self.committed_version = Some(self.version);

        // Bounded FIFO: evict the oldest cached tree before pushing the newest.
        if self.cache_queue.len() >= NUM_COMMITTED_SMTS_TO_CACHE {
            self.cache_queue.pop_front();
        }
        self.cache_queue.push_back(self.smt.clone());
    }
}