use std::fmt::{self, Display};
use num_traits::ToPrimitive;
use crate::partial_trie::{Node, PartialTrie};
/// Summary statistics gathered for a single trie.
///
/// Bundles an optional label with the per-node-type tallies and the depth
/// figures. The `Display` impl prints all three.
#[derive(Clone, Debug, Default)]
pub struct TrieStats {
/// Optional label; when present it is printed in parentheses on the header line.
name: Option<String>,
/// How many nodes of each type were counted.
counts: NodeCounts,
/// Deepest leaf/hash depth seen plus the average leaf and hash depths.
depth_stats: DepthStats,
}
impl Display for TrieStats {
    /// Prints a `Trie Stats:` header (followed by ` (<name>)` when a name is
    /// set), then the node counts section and the depth stats section.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Trie Stats:")?;

        // The header line is always terminated, with or without a label.
        if let Some(label) = &self.name {
            writeln!(f, " ({})", label)?;
        } else {
            writeln!(f)?;
        }

        writeln!(f, "Counts:\n{}", self.counts)?;
        writeln!(f, "Depth stats:\n{}", self.depth_stats)
    }
}
impl TrieStats {
    /// Compares these stats against `other`.
    ///
    /// The node counts and the depth stats are compared independently, each
    /// with `self` as the left-hand side and `other` as the right-hand side.
    pub const fn compare(&self, other: &Self) -> TrieComparison {
        let node_comp = self.counts.compare(&other.counts);
        let depth_comp = self.depth_stats.compare(&other.depth_stats);

        TrieComparison {
            node_comp,
            depth_comp,
        }
    }
}
/// Tally of how many nodes of each type were counted in a trie.
#[derive(Clone, Debug, Default, Hash)]
struct NodeCounts {
/// Number of `Node::Empty` nodes.
empty: usize,
/// Number of `Node::Hash` nodes.
hash: usize,
/// Number of `Node::Branch` nodes.
branch: usize,
/// Number of `Node::Extension` nodes.
extension: usize,
/// Number of `Node::Leaf` nodes.
leaf: usize,
}
impl Display for NodeCounts {
    /// Prints one line per node type (empty, hash, branch, extension, leaf),
    /// each showing the count and its share of the total node count.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let total = self.total_nodes();

        // Rows are emitted in exactly this order; the first failure aborts.
        let rows = [
            ("Empty", self.empty),
            ("Hash", self.hash),
            ("Branch", self.branch),
            ("Extension", self.extension),
            ("Leaf", self.leaf),
        ];

        rows.iter().try_for_each(|&(label, count)| {
            Self::write_node_count_stats(f, label, count, total)
        })
    }
}
impl NodeCounts {
    /// Writes a single `"<name>: <count> (<percent>%)"` line to `f`.
    ///
    /// * `node_t_name` - label printed at the start of the line.
    /// * `count` - number of nodes of this type.
    /// * `tot_count` - total number of nodes; `count` is shown as a percentage
    ///   of it, with two decimals.
    ///
    /// When `tot_count` is zero the percentage is reported as `0.00%` instead
    /// of dividing by zero.
    fn write_node_count_stats(
        f: &mut fmt::Formatter<'_>,
        node_t_name: &str,
        count: usize,
        tot_count: usize,
    ) -> fmt::Result {
        // `0.0 / 0.0` is NaN, which would be rendered as "NaN%" for stats
        // that contain no nodes at all (e.g. default-constructed stats).
        let perc = if tot_count == 0 {
            0.0
        } else {
            (count as f32 / tot_count as f32) * 100.0
        };

        writeln!(f, "{}: {} ({:.2}%)", node_t_name, count, perc)
    }
}
impl NodeCounts {
    /// Number of nodes of every type, empty nodes included.
    const fn total_nodes(&self) -> usize {
        self.total_node_non_empty() + self.empty
    }

    /// Number of nodes that are not `Empty` (branch, extension, hash and leaf).
    const fn total_node_non_empty(&self) -> usize {
        let terminal = self.hash_and_leaf_node_count();
        terminal + self.branch + self.extension
    }

    /// Combined number of hash and leaf nodes.
    const fn hash_and_leaf_node_count(&self) -> usize {
        self.leaf + self.hash
    }

    /// Pairs every count of `self` with the matching count of `other`.
    ///
    /// In each resulting ratio `self` supplies the first value and `other`
    /// the second.
    const fn compare(&self, other: &Self) -> NodeComparison {
        let tot_node_rat = RatioStat::new(self.total_nodes(), other.total_nodes());
        let non_empty_rat =
            RatioStat::new(self.total_node_non_empty(), other.total_node_non_empty());
        let empty_rat = RatioStat::new(self.empty, other.empty);
        let hash_rat = RatioStat::new(self.hash, other.hash);
        let branch_rat = RatioStat::new(self.branch, other.branch);
        let extension_rat = RatioStat::new(self.extension, other.extension);
        let leaf_rat = RatioStat::new(self.leaf, other.leaf);

        NodeComparison {
            tot_node_rat,
            non_empty_rat,
            empty_rat,
            hash_rat,
            branch_rat,
            extension_rat,
            leaf_rat,
        }
    }
}
/// Result of comparing two [`TrieStats`]; built by `TrieStats::compare`.
#[derive(Clone, Debug)]
pub struct TrieComparison {
/// Side-by-side node counts of the two tries.
node_comp: NodeComparison,
/// Side-by-side depth figures of the two tries.
depth_comp: DepthComparison,
}
impl Display for TrieComparison {
    /// Prints the node comparison followed by the depth comparison, each
    /// introduced by its own heading.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            node_comp: nodes,
            depth_comp: depths,
        } = self;

        writeln!(f, "Node comparison: {}", nodes)?;
        writeln!(f, "Depth comparison: {}", depths)
    }
}
/// Node counts of two tries placed side by side.
///
/// Each field pairs one figure from the first trie with the same figure from
/// the second (see `NodeCounts::compare`).
#[derive(Clone, Debug, Hash)]
struct NodeComparison {
/// Total node counts, empty nodes included.
tot_node_rat: RatioStat<usize>,
/// Counts of all nodes that are not empty.
non_empty_rat: RatioStat<usize>,
/// Empty node counts.
empty_rat: RatioStat<usize>,
/// Hash node counts.
hash_rat: RatioStat<usize>,
/// Branch node counts.
branch_rat: RatioStat<usize>,
/// Extension node counts.
extension_rat: RatioStat<usize>,
/// Leaf node counts.
leaf_rat: RatioStat<usize>,
}
impl Display for NodeComparison {
    /// Prints one labelled line per compared node count, totals first.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            tot_node_rat: total,
            non_empty_rat: non_empty,
            empty_rat: empty,
            hash_rat: hash,
            branch_rat: branch,
            extension_rat: extension,
            leaf_rat: leaf,
        } = self;

        writeln!(f, "Total nodes: {}", total)?;
        writeln!(f, "Non-empty: {}", non_empty)?;
        writeln!(f, "Total empty: {}", empty)?;
        writeln!(f, "Total hash: {}", hash)?;
        writeln!(f, "Total branch: {}", branch)?;
        writeln!(f, "Total extension: {}", extension)?;
        writeln!(f, "Total leaf: {}", leaf)
    }
}
/// Depth figures of two tries placed side by side (see `DepthStats::compare`).
#[derive(Clone, Debug)]
struct DepthComparison {
/// The `lowest_depth` values of both tries.
lowest_depth_rat: RatioStat<usize>,
/// The average leaf depths of both tries.
avg_leaf_depth_rat: RatioStat<f32>,
/// The average hash node depths of both tries.
avg_hash_depth_rat: RatioStat<f32>,
}
impl Display for DepthComparison {
    /// Prints the lowest-depth, average-leaf-depth and average-hash-depth
    /// comparisons, one per line.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            lowest_depth_rat: lowest,
            avg_leaf_depth_rat: avg_leaf,
            avg_hash_depth_rat: avg_hash,
        } = self;

        writeln!(f, "Lowest depth: {}", lowest)?;
        writeln!(f, "Avg leaf depth: {}", avg_leaf)?;
        writeln!(f, "Avg hash depth: {}", avg_hash)
    }
}
/// A pair of values of the same kind, shown as `a / b` together with `a` as a
/// percentage of `b`.
#[derive(Clone, Debug, Hash)]
struct RatioStat<T> {
/// First value; the numerator of the percentage.
a: T,
/// Second value; the denominator of the percentage.
b: T,
}
impl<T: Display + ToPrimitive> Display for RatioStat<T> {
    /// Prints `a / b (p%)`, where `p` is `a` as a percentage of `b`.
    ///
    /// Returns `fmt::Error` without writing anything when either value cannot
    /// be converted to `f32`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Resolve the percentage up front so a failed conversion emits nothing.
        let perc = match self.get_a_over_b_perc() {
            Ok(value) => value,
            Err(()) => return Err(fmt::Error),
        };

        write!(f, "{:.3} / {:.3} ({:.3}%)", self.a, self.b, perc)
    }
}
impl<T: ToPrimitive> RatioStat<T> {
    /// Pairs `a` with `b`.
    const fn new(a: T, b: T) -> Self {
        Self { a, b }
    }

    /// Returns `a / b` as a percentage.
    ///
    /// Fails with `Err(())` when either value cannot be converted to `f32`.
    fn get_a_over_b_perc(&self) -> Result<f32, ()> {
        let numerator = Self::try_to_f32(&self.a)?;
        let denominator = Self::try_to_f32(&self.b)?;

        Ok(numerator / denominator * 100.0)
    }

    /// Converts `v` to `f32`, mapping a failed conversion to `Err(())`.
    fn try_to_f32(v: &T) -> Result<f32, ()> {
        match v.to_f32() {
            Some(as_float) => Ok(as_float),
            None => Err(()),
        }
    }
}
/// Running totals accumulated while a trie is traversed by
/// `get_trie_stats_rec`.
#[derive(Debug, Default)]
struct CurrTrackingState {
/// Nodes counted so far, by type.
counts: NodeCounts,
/// Sum of the depths of every leaf node seen so far.
leaf_depth_sum: u64,
/// Sum of the depths of every hash node seen so far.
hash_depth_sum: u64,
/// Largest depth at which a leaf or hash node has been seen so far.
lowest_depth: usize,
}
impl CurrTrackingState {
    /// Records `curr_depth` as the lowest depth if it exceeds the value
    /// recorded so far; otherwise the recorded value is kept.
    fn update_lowest_depth_if_larger(&mut self, curr_depth: usize) {
        self.lowest_depth = self.lowest_depth.max(curr_depth);
    }
}
/// Depth figures for a trie.
#[derive(Clone, Debug, Default)]
struct DepthStats {
/// Largest depth at which a leaf or hash node was found.
lowest_depth: usize,
/// Sum of all leaf depths divided by the number of leaves.
avg_leaf_depth: f32,
/// Sum of all hash node depths divided by the number of hash nodes.
avg_hash_depth: f32,
}
impl Display for DepthStats {
    /// Prints the lowest depth and the two average depths (three decimals
    /// each), one per line.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            lowest_depth: lowest,
            avg_leaf_depth: avg_leaf,
            avg_hash_depth: avg_hash,
        } = self;

        writeln!(f, "Lowest depth: {}", lowest)?;
        writeln!(f, "Average leaf depth: {:.3}", avg_leaf)?;
        writeln!(f, "Average hash depth: {:.3}", avg_hash)
    }
}
impl DepthStats {
    /// Pairs each depth figure of `self` with the matching figure of `other`.
    ///
    /// In each resulting ratio `self` supplies the first value and `other`
    /// the second.
    const fn compare(&self, other: &Self) -> DepthComparison {
        let lowest_depth_rat = RatioStat::new(self.lowest_depth, other.lowest_depth);
        let avg_leaf_depth_rat = RatioStat::new(self.avg_leaf_depth, other.avg_leaf_depth);
        let avg_hash_depth_rat = RatioStat::new(self.avg_hash_depth, other.avg_hash_depth);

        DepthComparison {
            lowest_depth_rat,
            avg_leaf_depth_rat,
            avg_hash_depth_rat,
        }
    }
}
/// Collects statistics for `trie` without attaching a name to them.
pub fn get_trie_stats<T: PartialTrie>(trie: &T) -> TrieStats {
    let unnamed = None;
    get_trie_stats_common(trie, unnamed)
}
/// Collects statistics for `trie` and labels them with `name`.
pub fn get_trie_stats_with_name<T: PartialTrie>(trie: &T, name: String) -> TrieStats {
    let label = Some(name);
    get_trie_stats_common(trie, label)
}
/// Walks `trie` from its root (depth 0) and turns the accumulated totals into
/// a [`TrieStats`] labelled with `name`.
///
/// The average leaf/hash depth is the sum of the depths of those nodes divided
/// by how many of them were counted. If the trie has no node of a kind, that
/// average is reported as `0.0`.
fn get_trie_stats_common<T: PartialTrie>(trie: &T, name: Option<String>) -> TrieStats {
    /// Mean depth of `node_count` nodes whose depths add up to `depth_sum`.
    fn avg_depth(depth_sum: u64, node_count: usize) -> f32 {
        // Without this guard a trie lacking leaves (or hash nodes) would
        // evaluate `0.0 / 0.0` and store NaN as its average depth.
        if node_count == 0 {
            0.0
        } else {
            depth_sum as f32 / node_count as f32
        }
    }

    let mut state = CurrTrackingState::default();
    get_trie_stats_rec(trie, &mut state, 0);

    let depth_stats = DepthStats {
        lowest_depth: state.lowest_depth,
        avg_leaf_depth: avg_depth(state.leaf_depth_sum, state.counts.leaf),
        avg_hash_depth: avg_depth(state.hash_depth_sum, state.counts.hash),
    };

    TrieStats {
        name,
        counts: state.counts,
        depth_stats,
    }
}
/// Recursively visits `node` and everything below it, updating `state`.
///
/// * Every node bumps the counter for its type.
/// * Hash and leaf nodes add `curr_depth` to their respective depth sum and
///   may raise the recorded lowest depth.
/// * Branch and extension nodes recurse into their children one level deeper.
fn get_trie_stats_rec<T: PartialTrie>(
    node: &Node<T>,
    state: &mut CurrTrackingState,
    curr_depth: usize,
) {
    match node {
        Node::Empty => state.counts.empty += 1,
        Node::Hash(_) => {
            state.counts.hash += 1;
            state.hash_depth_sum += curr_depth as u64;
            state.update_lowest_depth_if_larger(curr_depth);
        }
        Node::Leaf { .. } => {
            state.counts.leaf += 1;
            state.leaf_depth_sum += curr_depth as u64;
            state.update_lowest_depth_if_larger(curr_depth);
        }
        Node::Extension { child, .. } => {
            state.counts.extension += 1;
            get_trie_stats_rec(child, state, curr_depth + 1);
        }
        Node::Branch { children, .. } => {
            state.counts.branch += 1;
            let child_depth = curr_depth + 1;
            for child in children {
                get_trie_stats_rec(child, state, child_depth);
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::get_trie_stats;
    use crate::{
        partial_trie::{HashedPartialTrie, PartialTrie},
        testing_utils::{
            generate_n_random_fixed_trie_hash_entries, generate_n_random_fixed_trie_value_entries,
            handmade_trie_1,
        },
        trie_ops::TrieOpResult,
    };

    /// Number of entries inserted by the "massive" trie tests.
    const MASSIVE_TRIE_SIZE: usize = 100_000;

    /// The handmade fixture trie must report the expected count for every
    /// node type.
    #[test]
    fn hand_made_trie_has_correct_node_stats() -> TrieOpResult<()> {
        let (handmade, _) = handmade_trie_1()?;
        let stats = get_trie_stats(&handmade);
        let counts = &stats.counts;

        assert_eq!(counts.leaf, 4);
        assert_eq!(counts.hash, 0);
        assert_eq!(counts.branch, 4);
        assert_eq!(counts.extension, 2);
        assert_eq!(counts.empty, 57);

        Ok(())
    }

    // Not implemented yet.
    #[test]
    #[ignore]
    fn perfectly_balanced_trie_has_correct_node_stats() {
        todo!()
    }

    /// A trie built only from value entries has one leaf per entry.
    #[test]
    fn massive_leaf_trie_has_correct_leaf_node_stats() -> TrieOpResult<()> {
        create_trie_and_stats_from_entries_and_assert(MASSIVE_TRIE_SIZE, 0, 9522)
    }

    /// A trie built only from hash entries has one hash node per entry.
    #[test]
    fn massive_hash_trie_has_correct_hash_node_stats() -> TrieOpResult<()> {
        create_trie_and_stats_from_entries_and_assert(0, MASSIVE_TRIE_SIZE, 9855)
    }

    /// Half value entries, half hash entries: both counts must match.
    #[test]
    fn massive_mixed_trie_has_correct_hash_node_stats() -> TrieOpResult<()> {
        let half = MASSIVE_TRIE_SIZE / 2;
        create_trie_and_stats_from_entries_and_assert(half, half, 1992)
    }

    /// Builds a trie from `n_leaf_nodes` generated value entries (seeded with
    /// `seed`) and `n_hash_nodes` generated hash entries (seeded with
    /// `seed + 1`), then checks that the reported leaf and hash counts equal
    /// the number of entries inserted.
    fn create_trie_and_stats_from_entries_and_assert(
        n_leaf_nodes: usize,
        n_hash_nodes: usize,
        seed: u64,
    ) -> TrieOpResult<()> {
        let leaf_entries = generate_n_random_fixed_trie_value_entries(n_leaf_nodes, seed);
        let hashed_entries = generate_n_random_fixed_trie_hash_entries(n_hash_nodes, seed + 1);

        let mut trie = HashedPartialTrie::default();
        trie.extend(leaf_entries)?;
        trie.extend(hashed_entries)?;

        let stats = get_trie_stats(&trie);
        let counts = &stats.counts;
        assert_eq!(counts.leaf, n_leaf_nodes);
        assert_eq!(counts.hash, n_hash_nodes);

        Ok(())
    }

    // Not implemented yet.
    #[test]
    #[ignore]
    fn depth_stats_work() {
        todo!()
    }
}