use std::collections::VecDeque;
use nakamoto_common::bitcoin::blockdata::constants::WITNESS_SCALE_FACTOR;
use nakamoto_common::bitcoin::{Block, OutPoint, Transaction, TxOut};
use nakamoto_common::collections::HashMap;
use nakamoto_common::nonempty::NonEmpty;
use super::Height;
/// Maximum number of UTXO set snapshots the fee estimator keeps around.
/// When a new snapshot pushes the queue past this size, the oldest one is dropped.
pub const MAX_UTXO_SNAPSHOTS: usize = 12;
/// Integer fee rate of a transaction.
///
/// In this file a rate is computed as the transaction fee divided by
/// `weight / WITNESS_SCALE_FACTOR`, rounded to the nearest integer.
/// The fee is expressed in the units of `TxOut::value` (presumably satoshis).
pub type FeeRate = u64;
/// Summary of the fee rates observed in a set of transactions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FeeEstimate {
/// Smallest fee rate in the sample.
pub low: FeeRate,
/// Median fee rate of the sample.
pub median: FeeRate,
/// Largest fee rate in the sample.
pub high: FeeRate,
}
impl FeeEstimate {
    /// Summarize a list of fee rates as a low / median / high estimate.
    ///
    /// Returns `None` when `fees` is empty. Otherwise `low` and `high` are the
    /// smallest and largest rates, and `median` is the middle rate. When the
    /// number of rates is even, the median is the mean of the two middle
    /// rates, with a fractional half rounded up.
    pub fn from(mut fees: Vec<FeeRate>) -> Option<Self> {
        fees.sort_unstable();

        NonEmpty::from_vec(fees).map(|fees| {
            let count = fees.len();
            let mid = count / 2;

            let median = if count % 2 == 1 {
                fees[mid]
            } else {
                // `count` is even and non-zero here, so `mid >= 1`.
                let left = fees[mid - 1];
                let right = fees[mid];
                // The rates are sorted, so `right >= left`. Staying in integer
                // arithmetic keeps the result exact for every `u64` (an `f64`
                // round-trip loses precision above 2^53) and cannot overflow,
                // since the result never exceeds `right`.
                let gap = right - left;

                left + gap / 2 + gap % 2
            };

            Self {
                low: *fees.first(),
                median,
                high: *fees.last(),
            }
        })
    }
}
/// Set of unspent transaction outputs, keyed by the outpoint that identifies each output.
type UtxoSet = HashMap<OutPoint, TxOut>;
/// Computes per-block fee estimates by tracking the outputs created and spent
/// by the transactions of the blocks it is given.
#[derive(Debug, Default)]
pub struct FeeEstimator {
/// Outputs created by processed transactions that have not been consumed as inputs yet.
utxos: UtxoSet,
/// Height of the most recently processed block, or the height restored by a rollback.
height: Height,
/// Copies of the UTXO set taken before each processed block, oldest first,
/// each paired with the estimator height at the time the copy was taken.
snapshots: VecDeque<(Height, UtxoSet)>,
}
impl FeeEstimator {
    /// Process the transactions of `block`, found at `height`, and return a
    /// fee estimate for that block.
    ///
    /// Blocks whose height is not strictly greater than the last processed
    /// height are ignored and yield `None`. `None` is also returned when no
    /// fee rate could be computed for any transaction of the block.
    ///
    /// Before the block is applied, a snapshot of the current UTXO set is
    /// recorded so that [`FeeEstimator::rollback`] can undo it later. At most
    /// [`MAX_UTXO_SNAPSHOTS`] snapshots are retained.
    ///
    /// # Panics
    ///
    /// Panics if a transaction with fully known inputs spends more than its
    /// inputs provide.
    pub fn process(&mut self, block: Block, height: Height) -> Option<FeeEstimate> {
        if height <= self.height {
            return None;
        }
        // Only copy the UTXO set once we know the block is going to be
        // applied; the copy is wasted work for blocks rejected above.
        let snapshot = self.utxos.clone();

        let fees: Vec<FeeRate> = block
            .txdata
            .iter()
            .filter_map(|tx| self.apply(tx))
            .collect();

        // The snapshot is tagged with the height *before* this block.
        self.snapshots.push_back((self.height, snapshot));
        if self.snapshots.len() > MAX_UTXO_SNAPSHOTS {
            self.snapshots.pop_front();
        }
        self.height = height;

        FeeEstimate::from(fees)
    }

    /// Roll the estimator back to its state as of `height`.
    ///
    /// The most recent snapshot taken at or below `height` is restored, and
    /// every snapshot above `height` is discarded. If the estimator is already
    /// at or below `height`, nothing happens.
    ///
    /// If no snapshot at or below `height` is left (for instance because it
    /// was pruned), the UTXO set is cleared and the estimator restarts from
    /// `height`. Fee rates are then only computed for transactions whose
    /// inputs were created after that point.
    pub fn rollback(&mut self, height: Height) {
        // Nothing to undo: restoring an older snapshot here would rewind
        // further than the caller asked for.
        if height >= self.height {
            return;
        }
        self.snapshots.retain(|(h, _)| *h <= height);

        match self.snapshots.pop_back() {
            Some((h, snapshot)) => {
                assert!(h <= height);

                self.utxos = snapshot;
                self.height = h;
            }
            None => {
                // All usable snapshots are gone. Keeping the current state
                // would leave a stale height that makes `process` reject
                // every block of the new chain up to the old height.
                self.utxos.clear();
                self.height = height;
            }
        }
    }

    /// Apply a transaction to the UTXO set and return its fee rate.
    ///
    /// All outputs of the transaction are added to the UTXO set, and all of
    /// its inputs that are found in the set are removed from it.
    ///
    /// Returns `None` for coinbase transactions and for transactions with at
    /// least one input that is missing from the UTXO set, since the fee cannot
    /// be computed in either case.
    ///
    /// # Panics
    ///
    /// Panics if all inputs are known and their total value is lower than the
    /// total value of the outputs.
    fn apply(&mut self, tx: &Transaction) -> Option<FeeRate> {
        let txid = tx.txid();
        let mut sent = 0;

        // Register the outputs.
        for (vout, output) in tx.output.iter().enumerate() {
            let outpoint = OutPoint {
                txid,
                vout: vout as u32,
            };
            self.utxos.insert(outpoint, output.clone());
            sent += output.value;
        }

        // A coinbase transaction spends no previous outputs, so there is no
        // fee to compute.
        if tx.is_coin_base() {
            return None;
        }

        // Spend the inputs. Keep going after a missing input so that every
        // output spent by this transaction leaves the UTXO set; stopping at
        // the first miss would leave the remaining ones behind forever.
        let mut received = 0;
        let mut complete = true;

        for input in &tx.input {
            match self.utxos.remove(&input.previous_output) {
                Some(out) => received += out.value,
                None => complete = false,
            }
        }
        // Without the full set of inputs, the fee is unknown.
        if !complete {
            return None;
        }
        assert!(received >= sent, "you can't spend what you don't have",);

        let fee = received - sent;
        let weight = tx.weight();
        // Fee per virtual byte: the weight is scaled down by the witness factor.
        let rate = fee as f64 / (weight as f64 / WITNESS_SCALE_FACTOR as f64);

        Some(rate.round() as FeeRate)
    }
}
// Tests for the snapshot and rollback behavior of `FeeEstimator`.
#[cfg(test)]
mod tests {
use super::*;
use nakamoto_test::assert_matches;
use nakamoto_test::block::gen;
// Rolling back to a processed height and re-processing the next block must
// reproduce the estimate that was computed the first time around.
#[test]
fn test_rollback() {
let mut fe = FeeEstimator::default();
let mut rng = fastrand::Rng::new();
let genesis = gen::genesis(&mut rng);
let blocks = gen::blockchain(genesis, 21, &mut rng);
let mut estimates = HashMap::with_hasher(rng.into());
// Process every generated block except the first one, using its index as
// the height, and remember each estimate by height.
for (height, block) in blocks.iter().cloned().enumerate().skip(1) {
let estimate = fe.process(block, height as Height);
estimates.insert(height, estimate);
}
// The snapshot queue is expected to be at its cap, with the newest
// snapshot tagged one height below the last processed block.
assert_eq!(fe.snapshots.len(), MAX_UTXO_SNAPSHOTS as usize);
assert_eq!(fe.height, 21);
assert_matches!(fe.snapshots.back(), Some((20, _)));
// Rolling back to 18 consumes the snapshot tagged 18 and drops the newer ones.
fe.rollback(18);
assert_eq!(fe.snapshots.len(), 9);
assert_eq!(fe.height, 18);
assert_matches!(fe.snapshots.back(), Some((17, _)));
// Re-processing block 19 must yield the same estimate as before the rollback.
assert_eq!(
fe.process(blocks[19].clone(), 19).as_ref().unwrap(),
estimates[&19].as_ref().unwrap()
);
assert_eq!(fe.snapshots.len(), 10);
assert_eq!(fe.height, 19);
assert_matches!(fe.snapshots.back(), Some((18, _)));
}
// Rolling back to a height that was never processed must land on the
// closest processed height below it.
#[test]
fn test_rollback_missing_height() {
let mut fe = FeeEstimator::default();
let mut rng = fastrand::Rng::new();
let genesis = gen::genesis(&mut rng);
let blocks = gen::blockchain(genesis, 14, &mut rng);
// Only heights 8, 9, 13 and 14 are processed; everything else is skipped.
fe.process(blocks[8].clone(), 8);
fe.process(blocks[9].clone(), 9);
fe.process(blocks[13].clone(), 13);
fe.process(blocks[14].clone(), 14);
assert_eq!(fe.snapshots.len(), 4);
// Height 10 was never processed: the estimator is expected to end up at 9.
fe.rollback(10);
assert_eq!(fe.snapshots.len(), 2);
assert_eq!(fe.height, 9);
assert_matches!(fe.snapshots.back(), Some((8, _)));
// Rolling back to a processed height restores exactly that height.
fe.rollback(8);
assert_eq!(fe.snapshots.len(), 1);
assert_eq!(fe.height, 8);
assert_matches!(fe.snapshots.back(), Some((0, _)));
// Below the first processed height, only the snapshot tagged 0 is left to restore.
fe.rollback(4);
assert_eq!(fe.snapshots.len(), 0);
assert_eq!(fe.height, 0);
}
}