tallytree/
tally.rs

1use crate::hash::hash_node;
2use crate::node::{is_null_node, is_null_node_ref, is_wrapper_node, Node, NodeRef};
3use crate::Validation;
4
/// Vote tallies for a (sub)tree: a list of counts, one per position.
///
/// Tallies of sibling subtrees are summed position by position (see
/// `combine_tally`), so every tally in one tree is expected to have the same
/// length.
pub type TallyList = std::vec::Vec<u32>;
7
8/// Add two tally lists together.
9///
10/// Example:
11/// ```
12/// use tallytree::tally::{tally_node_ref, combine_tally};
13/// use tallytree::generate::generate_tree;
14/// use tallytree::Validation;
15/// let tree = generate_tree(vec![
16///     ([0xaa; 32], vec![1, 0]),
17///     ([0xbb; 32], vec![0, 1]),
18/// ], false).unwrap().unwrap();
19/// assert_eq!(Ok([1, 1].to_vec()), combine_tally(
20///     &tally_node_ref(&tree.left, &Validation::Strict).unwrap(),
21///     &tally_node_ref(&tree.right, &Validation::Strict).unwrap(),
22/// ));
23/// ```
24pub fn combine_tally(
25    left: &Option<TallyList>,
26    right: &Option<TallyList>,
27) -> Result<TallyList, String> {
28    if left.is_none() {
29        // Left node should never be Ø-node and always have a tally.
30        return Err("Left node does not contain a tally".to_string());
31    }
32    if right.is_none() {
33        return Ok(left.as_ref().unwrap().clone());
34    }
35
36    let left = left.as_ref().unwrap();
37    let right = right.as_ref().unwrap();
38
39    if left.len() != right.len() {
40        return Err(format!(
41            "Left tally length is not equal to right tally \
42            length ({} != {})",
43            left.len(),
44            right.len()
45        ));
46    }
47    Ok(left.iter().zip(right.iter()).map(|(l, r)| l + r).collect())
48}
49
/// Checks whether a tally contains exactly one vote in total.
///
/// The counts are summed with overflow checking. A tally whose true sum
/// exceeds `u32::MAX` is therefore never mistaken for a single vote. With a
/// wrapping sum, `[u32::MAX, 2]` would wrap around to `1`.
///
/// Example:
/// ```
/// use tallytree::tally::has_one_vote;
/// assert!(has_one_vote(&[1, 0, 0]));
/// assert!(!has_one_vote(&[1, 0, 1]));
/// assert!(!has_one_vote(&[u32::MAX, 2]));
/// ```
pub fn has_one_vote(tally: &[u32]) -> bool {
    // `None` signals overflow, which can never equal `Some(1)`.
    tally
        .iter()
        .try_fold(0u32, |total, &votes| total.checked_add(votes))
        == Some(1)
}
61
62/// Given a node in the merkle tally tree, prints the tree to stdout. Pass `0`
63/// for `level`argument.
64///
65/// Leaf nodes are prefixed with their `VoteReference`. Parent nodes are
66/// prefixed with their hash.
67///
68/// Example:
69/// ```
70/// use tallytree::generate::generate_tree;
71/// use tallytree::tally::pretty_print_tally;
72/// let tree = generate_tree(vec![
73///     ([0xaa; 32], vec![1, 0]),
74///     ([0xbb; 32], vec![0, 1]),
75/// ], false).unwrap();
76/// pretty_print_tally(&tree, 0);
77/// ```
78pub fn pretty_print_tally(node: &NodeRef, level: usize) {
79    match node {
80        Some(n) => {
81            if is_null_node(n) {
82                println!("{:indent$}Ø", "", indent = level * 4);
83                return;
84            }
85            let prefix = match n.vote {
86                Some((v, _)) => [v[0], v[1]],
87                None => {
88                    let h = hash_node(n, &Validation::Strict).unwrap().0;
89                    [h.as_slice()[0], h.as_slice()[1]]
90                }
91            };
92            if is_wrapper_node(node) {
93                print!(
94                    "{:indent$}{:x?}=>{:?} --> ",
95                    "",
96                    prefix,
97                    tally_node(n, &Validation::Strict).unwrap().unwrap(),
98                    indent = level * 4
99                );
100                pretty_print_tally(&n.left, 0);
101                assert!(n.right.is_none());
102                return;
103            }
104
105            pretty_print_tally(&n.right, level + 1);
106            println!(
107                "{:indent$}{:x?}=>{:?}",
108                "",
109                prefix,
110                tally_node(n, &Validation::Strict).unwrap().unwrap(),
111                indent = level * 4
112            );
113            pretty_print_tally(&n.left, level + 1);
114        }
115        None => {}
116    }
117}
118
119/// Tally all the votes in a tree.
120///
121/// Example:
122/// ```
123/// use tallytree::generate::generate_tree;
124/// use tallytree::tally::tally_node;
125/// use tallytree::Validation;
126/// let tree = generate_tree(vec![
127///     ([0xaa; 32], vec![1, 0]),
128///     ([0xbb; 32], vec![0, 1]),
129/// ], false).unwrap();
130/// assert_eq!(Some([1, 1].to_vec()), tally_node(&tree.unwrap(), &Validation::Strict).unwrap());
131/// ```
132pub fn tally_node(node: &Node, v: &Validation) -> Result<Option<TallyList>, String> {
133    if is_null_node(node) {
134        return Ok(None);
135    }
136    if let Some((_, tally)) = &node.vote {
137        if node.left.is_some() || node.right.is_some() {
138            return Err("Leaf has a child".to_string());
139        }
140        if !matches!(v, Validation::Relaxed) && !has_one_vote(tally) {
141            return Err("Leaf casts more than 1 vote".to_string());
142        }
143        return Ok(Some(tally.to_vec()));
144    }
145    if node.right.is_none() || is_null_node_ref(&node.right) {
146        return tally_node_ref(&node.left, v);
147    }
148    Ok(Some(combine_tally(
149        &tally_node_ref(&node.left, v)?,
150        &tally_node_ref(&node.right, v)?,
151    )?))
152}
153
154/// Tally all the votes in a tree.
155///
156/// Example:
157/// ```
158/// use tallytree::generate::generate_tree;
159/// use tallytree::tally::tally_node_ref;
160/// use tallytree::Validation;
161/// let tree = generate_tree(vec![
162///     ([0xaa; 32], vec![1, 0]),
163///     ([0xbb; 32], vec![0, 1]),
164/// ], false).unwrap();
165/// assert_eq!(Some([1, 1].to_vec()), tally_node_ref(&tree, &Validation::Strict).unwrap());
166/// ```
167pub fn tally_node_ref(node: &NodeRef, v: &Validation) -> Result<Option<TallyList>, String> {
168    if let Some(n) = node {
169        tally_node(n, v)
170    } else {
171        Ok(None)
172    }
173}
174
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generate::generate_tree;

    #[test]
    fn test_pretty_print() {
        let head = generate_tree(
            vec![
                ([0x11; 32], vec![1, 0, 0]),
                ([0x22; 32], vec![0, 0, 1]),
                ([0x33; 32], vec![0, 0, 1]),
                ([0x44; 32], vec![0, 0, 1]),
                ([0x55; 32], vec![0, 0, 1]),
            ],
            false,
        )
        .unwrap();
        // Printing must not panic on a non-trivial (odd-sized) tree.
        pretty_print_tally(&head, 0);
        // The tree that was printed must also tally to the expected totals.
        assert_eq!(
            Ok(Some(vec![1, 0, 4])),
            tally_node_ref(&head, &Validation::Strict)
        );
    }

    #[test]
    fn test_has_one_vote() {
        assert!(has_one_vote(&[1, 0, 0]));
        assert!(has_one_vote(&[0, 0, 1]));
        assert!(!has_one_vote(&[0, 0, 0]));
        assert!(!has_one_vote(&[1, 1, 0]));
        assert!(!has_one_vote(&[2]));
        assert!(!has_one_vote(&[]));
    }

    #[test]
    fn test_combine_tally() {
        assert_eq!(
            Ok(vec![2, 0, 1]),
            combine_tally(&Some(vec![1, 0, 0]), &Some(vec![1, 0, 1]))
        );

        assert_eq!(Ok(vec![2, 0]), combine_tally(&Some(vec![2, 0]), &None));

        // Left node must contain a tally.
        assert!(combine_tally(&None, &Some(vec![2, 0])).is_err());

        // Tally list must be equal length
        assert!(combine_tally(&Some(vec![0, 0, 1]), &Some(vec![2, 0])).is_err());
    }
}