1use std::{convert::TryInto, marker::PhantomData};
2
3use kvdb::KeyValueDB;
4use kvdb_memorydb::InMemory as MemoryDatabase;
5#[cfg(feature = "web")]
6use kvdb_web::Database as WebDatabase;
7use libzeropool::{
8 constants,
9 fawkes_crypto::{
10 ff_uint::{Num, PrimeField},
11 BorshDeserialize, BorshSerialize,
12 },
13 native::{
14 account::{Account, Account as NativeAccount},
15 note::{Note, Note as NativeNote},
16 params::PoolParams,
17 },
18};
19
20use crate::{merkle::MerkleTree, sparse_array::SparseArray};
21
/// Storage of [`Transaction`]s addressed by a `u64` index, backed by database `D`.
pub type TxStorage<D, Fr> = SparseArray<D, Transaction<Fr>>;
23
/// A single entry kept in [`TxStorage`]: either an account or a note.
///
/// NOTE: this type is Borsh-(de)serialized, and Borsh encodes the variant
/// position, so reordering or inserting variants changes the stored format.
#[derive(BorshSerialize, BorshDeserialize, PartialEq, Debug)]
pub enum Transaction<Fr: PrimeField> {
    /// An account stored at some index.
    Account(NativeAccount<Fr>),
    /// A note stored at some index.
    Note(NativeNote<Fr>),
}
29
/// Pool state: a Merkle tree of hashes plus the accounts and notes stored by
/// index, together with cached "latest" values derived from that storage.
pub struct State<D: KeyValueDB, P: PoolParams> {
    /// Merkle tree that hashes are added to.
    pub tree: MerkleTree<D, P>,
    /// Stored accounts and notes, keyed by index.
    pub(crate) txs: TxStorage<D, P::Fr>,
    /// Cached account stored at `latest_account_index`, if any.
    pub(crate) latest_account: Option<NativeAccount<P::Fr>>,
    /// Cached highest index holding an account; `None` when no account is stored.
    pub latest_account_index: Option<u64>,
    /// Cached highest index holding a note; `0` when no note is stored.
    pub latest_note_index: u64,
    _params: PhantomData<P>,
}
40
#[cfg(feature = "web")]
impl<P> State<WebDatabase, P>
where
    P: PoolParams,
    P::Fr: 'static,
{
    /// Builds a state on top of two `kvdb_web` databases named after `db_id`.
    ///
    /// The Merkle tree uses `zeropool.<db_id>.smt` and the transaction storage
    /// uses `zeropool.<db_id>.txs`.
    pub async fn init_web(db_id: String, params: P) -> Self {
        let merkle_db_name = format!("zeropool.{}.smt", &db_id);
        let tx_db_name = format!("zeropool.{}.txs", &db_id);
        // `params` is not needed afterwards, so move it instead of cloning.
        let tree = MerkleTree::new_web(&merkle_db_name, params).await;
        let txs = TxStorage::new_web(&tx_db_name).await;

        Self::new(tree, txs)
    }
}
56
57impl<P> State<MemoryDatabase, P>
58where
59 P: PoolParams,
60 P::Fr: 'static,
61{
62 pub fn init_test(params: P) -> Self {
63 let tree = MerkleTree::new_test(params);
64 let txs = TxStorage::new_test();
65
66 Self::new(tree, txs)
67 }
68}
69
70impl<D, P> State<D, P>
71where
72 D: KeyValueDB,
73 P: PoolParams,
74 P::Fr: 'static,
75{
76 pub fn new(tree: MerkleTree<D, P>, txs: TxStorage<D, P::Fr>) -> Self {
77 let (latest_account_index, latest_note_index, latest_account) =
79 latest_indices::<D, P>(&txs);
80
81 State {
82 tree,
83 txs,
84 latest_account_index,
85 latest_note_index,
86 latest_account,
87 _params: Default::default(),
88 }
89 }
90
91 pub fn add_hashes(&mut self, at_index: u64, hashes: &[Num<P::Fr>]) {
93 assert_eq!(
95 at_index % (constants::OUT as u64 + 1),
96 0,
97 "index must be divisible by {}",
98 constants::OUT + 1
99 );
100
101 self.tree.add_hashes(at_index, hashes.iter().copied());
102 }
103
104 pub fn add_full_tx(
106 &mut self,
107 at_index: u64,
108 hashes: &[Num<P::Fr>],
109 account: Option<Account<P::Fr>>,
110 notes: &[(u64, Note<P::Fr>)],
111 ) {
112 self.add_hashes(at_index, hashes);
113
114 if let Some(acc) = account {
115 self.add_account(at_index, acc);
116 }
117
118 for (index, note) in notes {
120 self.add_note(*index, *note);
121 }
122 }
123
124 pub fn add_account(&mut self, at_index: u64, account: Account<P::Fr>) {
126 self.txs.set(at_index, &Transaction::Account(account));
128
129 if at_index >= self.latest_account_index.unwrap_or(0) {
130 self.latest_account_index = Some(at_index);
131 self.latest_account = Some(account);
132 }
133 }
134
135 pub fn add_note(&mut self, at_index: u64, note: Note<P::Fr>) {
137 if self.txs.get(at_index).is_some() {
138 return;
139 }
140
141 self.txs.set(at_index, &Transaction::Note(note));
142
143 if at_index > self.latest_note_index {
144 self.latest_note_index = at_index;
145 }
146 }
147
148 pub fn get_all_txs(&self) -> Vec<(u64, Transaction<P::Fr>)> {
149 self.txs.iter().collect()
150 }
151
152 pub fn get_usable_notes(&self) -> Vec<(u64, Note<P::Fr>)> {
153 let next_usable_index = self.earliest_usable_index();
154
155 self.txs
157 .iter_slice(next_usable_index..=self.latest_note_index)
158 .filter_map(|(index, tx)| match tx {
159 Transaction::Note(note) => Some((index, note)),
160 _ => None,
161 })
162 .collect()
163 }
164
165 pub fn earliest_usable_index(&self) -> u64 {
167 let latest_account_index = self
168 .latest_account
169 .map(|acc| acc.i.to_num())
170 .unwrap_or(Num::ZERO)
171 .try_into()
172 .unwrap();
173
174 self.txs
175 .iter_slice(latest_account_index..=self.latest_note_index)
176 .filter_map(|(index, tx)| match tx {
177 Transaction::Note(_) => Some(index),
178 _ => None,
179 })
180 .next()
181 .unwrap_or(latest_account_index)
182 }
183
184 pub fn earliest_usable_index_optimistic(
186 &self,
187 optimistic_accounts: &[(u64, Account<P::Fr>)],
188 optimistic_notes: &[(u64, Note<P::Fr>)],
189 ) -> u64 {
190 let latest_account_index = optimistic_accounts
191 .last()
192 .map(|indexed_acc| indexed_acc.1)
193 .or(self.latest_account)
194 .map(|acc| acc.i.to_num())
195 .unwrap_or(Num::ZERO)
196 .try_into()
197 .unwrap();
198
199 let latest_note_index_optimistic = optimistic_notes
200 .last()
201 .map(|indexed_note| indexed_note.0)
202 .unwrap_or(self.latest_note_index);
203
204 let optimistic_note_indices = optimistic_notes
205 .iter()
206 .map(|indexed_note| indexed_note.0)
207 .filter(move |index| {
208 (latest_account_index..=latest_note_index_optimistic).contains(index)
209 });
210
211 self.txs
212 .iter_slice(latest_account_index..=latest_note_index_optimistic)
213 .filter_map(|(index, tx)| match tx {
214 Transaction::Note(_) => Some(index),
215 _ => None,
216 })
217 .chain(optimistic_note_indices)
218 .next()
219 .unwrap_or(latest_account_index)
220 }
221
222 pub fn total_balance(&self) -> Num<P::Fr> {
224 self.account_balance() + self.note_balance()
225 }
226
227 pub fn account_balance(&self) -> Num<P::Fr> {
228 self.latest_account
229 .map(|acc| acc.b.to_num())
230 .unwrap_or(Num::ZERO)
231 }
232
233 pub fn note_balance(&self) -> Num<P::Fr> {
234 let starting_index = self
235 .latest_account
236 .map(|acc| acc.i.to_num().try_into().unwrap())
237 .unwrap_or(0);
238 let mut note_balance = Num::ZERO;
239 for (_, tx) in self.txs.iter_slice(starting_index..=self.latest_note_index) {
240 if let Transaction::Note(note) = tx {
241 note_balance += note.b.to_num();
242 }
243 }
244
245 note_balance
246 }
247
248 pub fn rollback(&mut self, to_index: u64) {
249 self.txs.remove_all_after(to_index);
250 self.tree.rollback(to_index);
251 let (latest_account_index, latest_note_index, latest_account) =
252 latest_indices::<D, P>(&self.txs);
253 self.latest_account_index = latest_account_index;
254 self.latest_note_index = latest_note_index;
255 self.latest_account = latest_account;
256 }
257}
258
259fn latest_indices<D, P>(
260 txs: &TxStorage<D, P::Fr>,
261) -> (Option<u64>, u64, Option<NativeAccount<P::Fr>>)
262where
263 D: KeyValueDB,
264 P: PoolParams,
265 P::Fr: 'static,
266{
267 let mut latest_account_index = None;
268 let mut latest_note_index = 0;
269 let mut latest_account = None;
270 for (index, tx) in txs.iter() {
271 match tx {
272 Transaction::Account(acc) => {
273 if index >= latest_account_index.unwrap_or(0) {
274 latest_account_index = Some(index);
275 latest_account = Some(acc);
276 }
277 }
278 Transaction::Note(_) => {
279 if index >= latest_note_index {
280 latest_note_index = index;
281 }
282 }
283 }
284 }
285
286 (latest_account_index, latest_note_index, latest_account)
287}