use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, HashMap, VecDeque};
use std::time::Duration;
use tor_netdir::params::NetParameters;
use super::Action;
use tor_persist::JsonValue;
/// Maximum number of circuit build-time observations to retain at once.
const TIME_HISTORY_LEN: usize = 1000;
/// Default number of success/timeout outcomes to retain (adjustable at
/// runtime via `set_success_history_len`).
const SUCCESS_HISTORY_DEFAULT_LEN: usize = 20;
/// Width of each build-time histogram bucket, in milliseconds.
const BUCKET_WIDTH_MSEC: u32 = 10;
/// A duration in milliseconds, stored compactly as a `u32`.
///
/// Durations too long to represent are saturated to `u32::MAX` milliseconds.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(transparent)]
struct MsecDuration(u32);

impl MsecDuration {
    /// Convert a `Duration` into an `MsecDuration`, saturating at
    /// `u32::MAX` milliseconds instead of overflowing.
    fn new_saturating(d: &Duration) -> Self {
        let capped_msec = d.as_millis().min(u128::from(u32::MAX));
        MsecDuration(capped_msec as u32)
    }
}
// Compile-time guarantee that `u16` histogram counts cannot overflow: we
// never hold more than TIME_HISTORY_LEN observations at a time.
#[allow(clippy::checked_conversions)]
mod assertion {
    // Fails the build if TIME_HISTORY_LEN no longer fits in a u16.
    const _: () = assert!(super::TIME_HISTORY_LEN <= u16::MAX as usize);
}
/// Accumulated observations used to estimate circuit build timeouts.
#[derive(Debug, Clone)]
struct History {
    /// Build-time observations, oldest first, bounded at `TIME_HISTORY_LEN`.
    time_history: BoundedDeque<MsecDuration>,
    /// Histogram of `time_history`, keyed by bucket center, with one bucket
    /// per `BUCKET_WIDTH_MSEC` of time.  Kept in sync with `time_history`
    /// by `add_time`.
    time_histogram: BTreeMap<MsecDuration, u16>,
    /// Recent success (`true`) / timeout (`false`) outcomes, oldest first.
    success_history: BoundedDeque<bool>,
}
impl History {
    /// Return a new, empty History with default capacities.
    fn new_empty() -> Self {
        History {
            time_history: BoundedDeque::new(TIME_HISTORY_LEN),
            time_histogram: BTreeMap::new(),
            success_history: BoundedDeque::new(SUCCESS_HISTORY_DEFAULT_LEN),
        }
    }
    /// Remove all observations from this History.
    fn clear(&mut self) {
        self.time_history.clear();
        self.time_histogram.clear();
        self.success_history.clear();
    }
    /// Change the number of success/timeout outcomes we retain, dropping
    /// the oldest entries if `n` is smaller than the current count.
    fn set_success_history_len(&mut self, n: usize) {
        self.success_history.set_max_len(n);
    }
    /// (Test-only.) Change the number of build-time observations we retain.
    #[cfg(test)]
    fn set_time_history_len(&mut self, n: usize) {
        self.time_history.set_max_len(n);
    }
    /// Rebuild a History from a sparse histogram of (duration, count) pairs.
    ///
    /// If there are more than `TIME_HISTORY_LEN` observations, a uniformly
    /// random subset is kept.
    fn from_sparse_histogram<I>(iter: I) -> Self
    where
        I: Iterator<Item = (MsecDuration, u16)>,
    {
        use rand::seq::{IteratorRandom, SliceRandom};
        let mut rng = rand::rng();
        // Bound the number of bins considered, expand each bin into `n`
        // copies of its duration, and sample at most TIME_HISTORY_LEN of
        // the resulting observations.
        let mut observations = iter
            .take(TIME_HISTORY_LEN)
            .flat_map(|(dur, n)| std::iter::repeat_n(dur, n as usize))
            .choose_multiple(&mut rng, TIME_HISTORY_LEN);
        // `choose_multiple` does not guarantee a random order, so shuffle:
        // otherwise the bounded deque's eviction order would be biased.
        observations.shuffle(&mut rng);
        let mut result = History::new_empty();
        for obs in observations {
            result.add_time(obs);
        }
        result
    }
    /// Return an iterator over (bucket center, count) for every nonempty
    /// bucket, in increasing order of duration.
    fn sparse_histogram(&self) -> impl Iterator<Item = (MsecDuration, u16)> + '_ {
        self.time_histogram.iter().map(|(d, n)| (*d, *n))
    }
    /// Return the center of the histogram bucket containing `time`.
    fn bucket_center(time: MsecDuration) -> MsecDuration {
        let idx = time.0 / BUCKET_WIDTH_MSEC;
        // Cannot overflow: idx * BUCKET_WIDTH_MSEC <= time.0, and the half-
        // width added back stays within u32 (u32::MAX maps to itself).
        let msec = (idx * BUCKET_WIDTH_MSEC) + (BUCKET_WIDTH_MSEC) / 2;
        MsecDuration(msec)
    }
    /// Increment the histogram bucket containing `time`.
    fn inc_bucket(&mut self, time: MsecDuration) {
        let center = History::bucket_center(time);
        *self.time_histogram.entry(center).or_insert(0) += 1;
    }
    /// Decrement the histogram bucket containing `time`, removing the
    /// bucket entirely when its count reaches zero.  Decrementing an empty
    /// bucket is a no-op.
    fn dec_bucket(&mut self, time: MsecDuration) {
        use std::collections::btree_map::Entry;
        let center = History::bucket_center(time);
        match self.time_histogram.entry(center) {
            Entry::Vacant(_) => {
                // Nothing to decrement; ignore.
            }
            Entry::Occupied(e) if e.get() <= &1 => {
                e.remove();
            }
            Entry::Occupied(mut e) => {
                *e.get_mut() -= 1;
            }
        }
    }
    /// Add a new build-time observation, un-counting the oldest observation
    /// from the histogram if the history was already full.
    fn add_time(&mut self, time: MsecDuration) {
        match self.time_history.push_back(time) {
            None => {}
            Some(removed_time) => {
                // The deque evicted its oldest entry: keep the histogram in
                // sync by removing that observation from its bucket.
                self.dec_bucket(removed_time);
            }
        }
        self.inc_bucket(time);
    }
    /// Return the number of build-time observations we have.
    fn n_times(&self) -> usize {
        self.time_history.len()
    }
    /// Record a circuit build success (`true`) or timeout (`false`).
    fn add_success(&mut self, succeeded: bool) {
        self.success_history.push_back(succeeded);
    }
    /// Return the number of timeouts in the retained success history.
    fn n_recent_timeouts(&self) -> usize {
        self.success_history.iter().filter(|x| !**x).count()
    }
    /// Return up to `n` histogram bins, most frequent first; ties are
    /// broken in favor of the smaller bucket center.
    fn n_most_frequent_bins(&self, n: usize) -> Vec<(MsecDuration, u16)> {
        use itertools::Itertools;
        use std::cmp::Reverse;
        // `k_smallest` over (Reverse(count), center) yields the largest
        // counts first without sorting the whole histogram.
        self.sparse_histogram()
            .map(|(center, count)| (Reverse(count), center))
            .k_smallest(n)
            .map(|(Reverse(count), center)| (center, count))
            .collect()
    }
    /// Estimate the `Xm` (mode) parameter as the count-weighted mean of the
    /// centers of the `n_modes` most frequent bins, or None if we have no
    /// observations.
    fn estimate_xm(&self, n_modes: usize) -> Option<u32> {
        let bins = self.n_most_frequent_bins(n_modes);
        // Cannot overflow u16: the total number of observations is at most
        // TIME_HISTORY_LEN, checked at compile time in `assertion`.
        let n_observations: u16 = bins.iter().map(|(_, n)| n).sum();
        // Widen each factor to u64 *before* multiplying: bucket centers can
        // be as large as u32::MAX (new_saturating saturates there), so a
        // u32 product could overflow.
        let total_observations: u64 = bins
            .iter()
            .map(|(d, n)| u64::from(d.0) * u64::from(*n))
            .sum();
        if n_observations == 0 {
            None
        } else {
            // A mean of u32-sized values always fits back into u32.
            Some((total_observations / u64::from(n_observations)) as u32)
        }
    }
    /// Estimate a Pareto distribution from the observed build times by
    /// maximum likelihood, treating every observation below `xm` as equal
    /// to `xm`.  Returns None when there are no observations.
    fn pareto_estimate(&self, n_modes: usize) -> Option<ParetoDist> {
        let xm = self.estimate_xm(n_modes)?;
        // `n >= 1` here: estimate_xm returned Some, so at least one
        // observation exists.
        let n = self.time_history.len();
        let sum_of_log_observations: f64 = self
            .time_history
            .iter()
            .map(|m| f64::from(std::cmp::max(m.0, xm)).ln())
            .sum();
        let sum_of_log_xm = (n as f64) * f64::from(xm).ln();
        // 1/alpha is the mean of ln(x_i / xm); storing the inverse avoids
        // dividing by zero when every observation equals xm.
        let inv_alpha = (sum_of_log_observations - sum_of_log_xm) / (n as f64);
        Some(ParetoDist {
            x_m: f64::from(xm),
            inv_alpha,
        })
    }
}
/// A Pareto distribution fitted to our build-time observations.
#[derive(Debug)]
struct ParetoDist {
    /// The lower bound (mode) of the distribution.
    x_m: f64,
    /// The reciprocal of the shape parameter alpha.
    inv_alpha: f64,
}
impl ParetoDist {
    /// Compute the inverse CDF: the value below which a fraction `q` of
    /// samples is expected to fall.
    ///
    /// `q` is clamped into [0.0, 1.0]; `q == 1.0` yields infinity.
    fn quantile(&self, q: f64) -> f64 {
        let clamped = q.clamp(0.0, 1.0);
        let tail_mass = 1.0 - clamped;
        self.x_m / tail_mass.powf(self.inv_alpha)
    }
}
/// Parameters controlling a ParetoTimeoutEstimator.
///
/// In production these are derived from the consensus via
/// `From<&NetParameters>`; the defaults below mirror that conversion's
/// fallbacks.
#[derive(Clone, Debug)]
pub(crate) struct Params {
    /// Are we using our timing estimates, or only the default thresholds?
    use_estimates: bool,
    /// How many build-time observations do we need before trusting an
    /// estimate?
    min_observations: u16,
    /// Which hop's completion time counts as a build-time observation?
    significant_hop: u8,
    /// Quantile of the fitted distribution used as the timeout threshold.
    timeout_quantile: f64,
    /// Quantile of the fitted distribution used as the abandon threshold.
    abandon_quantile: f64,
    /// (timeout, abandon) thresholds used when estimates are disabled or
    /// unavailable.
    default_thresholds: (Duration, Duration),
    /// Number of most-frequent histogram bins used when estimating Xm.
    n_modes_for_xm: usize,
    /// How many success/timeout outcomes to retain.
    success_history_len: usize,
    /// How many recent timeouts we tolerate before resetting our estimates.
    reset_after_timeouts: usize,
    /// Lower bound applied to every computed threshold.
    min_timeout: Duration,
}
impl Default for Params {
    fn default() -> Self {
        // One minute: default for both the timeout and abandon thresholds.
        let one_minute = Duration::from_secs(60);
        Params {
            use_estimates: true,
            min_observations: 100,
            significant_hop: 2,
            timeout_quantile: 0.80,
            abandon_quantile: 0.99,
            default_thresholds: (one_minute, one_minute),
            n_modes_for_xm: 10,
            success_history_len: SUCCESS_HISTORY_DEFAULT_LEN,
            reset_after_timeouts: 18,
            min_timeout: Duration::from_millis(10),
        }
    }
}
impl From<&NetParameters> for Params {
    /// Build a Params from the consensus values in `p`, substituting
    /// hard-coded defaults for any value that cannot be converted.
    fn from(p: &NetParameters) -> Params {
        // Initial (timeout, abandon) threshold; one minute if the
        // consensus value is unrepresentable as a Duration.
        let initial_timeout: Duration = p
            .cbt_initial_timeout
            .try_into()
            .unwrap_or_else(|_| Duration::from_secs(60));
        let use_estimates = !bool::from(p.cbt_learning_disabled);
        Params {
            use_estimates,
            min_observations: p.cbt_min_circs_for_estimate.get() as u16,
            significant_hop: 2,
            timeout_quantile: p.cbt_timeout_quantile.as_fraction(),
            abandon_quantile: p.cbt_abandon_quantile.as_fraction(),
            default_thresholds: (initial_timeout, initial_timeout),
            n_modes_for_xm: p.cbt_num_xm_modes.get() as usize,
            success_history_len: p.cbt_success_count.get() as usize,
            reset_after_timeouts: p.cbt_max_timeouts.get() as usize,
            min_timeout: p
                .cbt_min_timeout
                .try_into()
                .unwrap_or_else(|_| Duration::from_millis(10)),
        }
    }
}
/// A timeout estimator based on fitting a Pareto distribution to observed
/// circuit build times.
pub(crate) struct ParetoTimeoutEstimator {
    /// Observed build times and success/timeout outcomes.
    history: History,
    /// Cached (timeout, abandon) thresholds; cleared whenever a new
    /// observation arrives and recomputed lazily by `base_timeouts`.
    timeouts: Option<(Duration, Duration)>,
    /// Thresholds used before we have enough observations (or when no
    /// distribution can be estimated); doubled after repeated
    /// timeout-triggered resets in `note_circ_timeout`.
    fallback_timeouts: (Duration, Duration),
    /// Current parameters, from defaults or the consensus.
    p: Params,
}
impl Default for ParetoTimeoutEstimator {
    /// Construct an estimator with no observations and default parameters.
    fn default() -> Self {
        let no_observations = History::new_empty();
        Self::from_history(no_observations)
    }
}
/// The persistable state of a ParetoTimeoutEstimator, as stored on disk.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
#[serde(default)]
pub(crate) struct ParetoTimeoutState {
    /// Format version of this state object; written as 1 by `build_state`.
    #[allow(dead_code)]
    version: usize,
    /// Sparse histogram of observed build times as (bucket center, count)
    /// pairs.
    histogram: Vec<(MsecDuration, u16)>,
    /// The timeout threshold in effect when this state was saved, if any.
    current_timeout: Option<MsecDuration>,
    /// Fields we don't recognize (e.g. from a newer version); preserved so
    /// they survive a load/store round trip.
    #[serde(flatten)]
    unknown_fields: HashMap<String, JsonValue>,
}
impl ParetoTimeoutState {
    /// Return the timeout estimate recorded in this state, if any, as a
    /// `Duration`.
    pub(crate) fn latest_estimate(&self) -> Option<Duration> {
        let msec = self.current_timeout?;
        Some(Duration::from_millis(msec.0.into()))
    }
}
impl ParetoTimeoutEstimator {
    /// Construct an estimator from pre-existing observations, using default
    /// parameters.
    fn from_history(history: History) -> Self {
        let p = Params::default();
        let fallback_timeouts = p.default_thresholds;
        ParetoTimeoutEstimator {
            history,
            timeouts: None,
            fallback_timeouts,
            p,
        }
    }
    /// Construct an estimator from a persisted state object.
    pub(crate) fn from_state(state: ParetoTimeoutState) -> Self {
        let history = History::from_sparse_histogram(state.histogram.into_iter());
        Self::from_history(history)
    }
    /// Compute (and cache) the current base (timeout, abandon) thresholds,
    /// before any per-action scaling.
    ///
    /// Falls back to `fallback_timeouts` when we don't yet have enough
    /// observations or no distribution can be estimated.
    fn base_timeouts(&mut self) -> (Duration, Duration) {
        if let Some(cached) = self.timeouts {
            // The cache is still valid from an earlier call.
            return cached;
        }
        if self.history.n_times() < usize::from(self.p.min_observations) {
            // Not enough data yet to trust an estimate.
            return self.fallback_timeouts;
        }
        let Some(dist) = self.history.pareto_estimate(self.p.n_modes_for_xm) else {
            return self.fallback_timeouts;
        };
        // Thresholds in milliseconds; never let the abandon threshold fall
        // below the timeout threshold.
        let timeout_msec = dist.quantile(self.p.timeout_quantile);
        let abandon_msec = dist.quantile(self.p.abandon_quantile).max(timeout_msec);
        let thresholds = (
            Duration::from_secs_f64(timeout_msec / 1000.0).max(self.p.min_timeout),
            Duration::from_secs_f64(abandon_msec / 1000.0).max(self.p.min_timeout),
        );
        self.timeouts = Some(thresholds);
        thresholds
    }
}
impl super::TimeoutEstimator for ParetoTimeoutEstimator {
    /// Replace our parameters with ones derived from the latest consensus,
    /// resizing the success history to match.
    fn update_params(&mut self, p: &NetParameters) {
        let parameters = p.into();
        self.p = parameters;
        let new_success_len = self.p.success_history_len;
        self.history.set_success_history_len(new_success_len);
    }
    /// Record that hop number `hop` of a circuit completed after `delay`;
    /// `is_last` indicates it was the circuit's final hop.
    ///
    /// Only the `significant_hop`th hop contributes a build-time
    /// observation; completing the last hop counts as a success.
    fn note_hop_completed(&mut self, hop: u8, delay: Duration, is_last: bool) {
        if hop == self.p.significant_hop {
            let time = MsecDuration::new_saturating(&delay);
            self.history.add_time(time);
            // New data arrived: invalidate the cached thresholds.
            self.timeouts.take();
        }
        if is_last {
            tracing::trace!(%hop, ?delay, "Circuit creation success");
            self.history.add_success(true);
        }
    }
    /// Record that a circuit timed out after `delay`, having completed
    /// `hop` hops.
    fn note_circ_timeout(&mut self, hop: u8, delay: Duration) {
        // Only count the timeout if traffic arrived more recently than the
        // timeout interval — otherwise, presumably, the whole network was
        // unreachable and the circuit is not to blame (TODO: confirm the
        // exact semantics of time_since_last_incoming_traffic).
        let have_seen_recent_activity =
            if let Some(last_traffic) = tor_proto::time_since_last_incoming_traffic() {
                last_traffic < delay
            } else {
                // Traffic timestamping appears unavailable; assume activity.
                true
            };
        tracing::trace!(%hop, ?delay, %have_seen_recent_activity, "Circuit timeout");
        if hop > 0 && have_seen_recent_activity {
            self.history.add_success(false);
            if self.history.n_recent_timeouts() > self.p.reset_after_timeouts {
                tracing::debug!("Multiple connections failed, resetting timeouts...");
                // Snapshot the thresholds *before* clearing the history:
                // clearing would change what base_timeouts() returns.
                let base_timeouts = self.base_timeouts();
                self.history.clear();
                self.timeouts.take();
                if base_timeouts.0 >= self.fallback_timeouts.0 {
                    // Our estimates had grown past the fallback; double the
                    // fallback thresholds, capped at two hours.
                    const MAX_FALLBACK_TIMEOUT: Duration = Duration::from_secs(7200);
                    self.fallback_timeouts.0 =
                        (self.fallback_timeouts.0 * 2).min(MAX_FALLBACK_TIMEOUT);
                    self.fallback_timeouts.1 =
                        (self.fallback_timeouts.1 * 2).min(MAX_FALLBACK_TIMEOUT);
                }
            }
        }
    }
    /// Return the (timeout, abandon) thresholds to use for `action`, scaled
    /// from the base thresholds by the action's relative timeout scale.
    fn timeouts(&mut self, action: &Action) -> (Duration, Duration) {
        let (base_t, base_a) = if self.p.use_estimates {
            self.base_timeouts()
        } else {
            // Learning disabled: use the configured defaults unscaled.
            return self.p.default_thresholds;
        };
        // The base estimate corresponds to building a circuit through the
        // significant hop; scale it to the requested action.
        let reference_action = Action::BuildCircuit {
            length: self.p.significant_hop as usize + 1,
        };
        debug_assert!(reference_action.timeout_scale() > 0);
        let multiplier =
            (action.timeout_scale() as f64) / (reference_action.timeout_scale() as f64);
        use super::mul_duration_f64_saturating as mul;
        (mul(base_t, multiplier), mul(base_a, multiplier))
    }
    /// Return true if we are still gathering observations and not yet
    /// trusting our estimates.
    fn learning_timeouts(&self) -> bool {
        self.p.use_estimates && self.history.n_times() < usize::from(self.p.min_observations)
    }
    /// Capture our current state (histogram plus current timeout) for
    /// persistence.
    fn build_state(&mut self) -> Option<ParetoTimeoutState> {
        let cur_timeout = MsecDuration::new_saturating(&self.base_timeouts().0);
        Some(ParetoTimeoutState {
            version: 1,
            histogram: self.history.sparse_histogram().collect(),
            current_timeout: Some(cur_timeout),
            unknown_fields: Default::default(),
        })
    }
}
/// A FIFO queue that retains at most `limit` elements, evicting the oldest
/// element whenever a new one is pushed while full.
#[derive(Clone, Debug)]
struct BoundedDeque<T> {
    /// The stored elements, oldest at the front.
    inner: VecDeque<T>,
    /// Maximum number of elements to retain.
    limit: usize,
}
impl<T> BoundedDeque<T> {
    /// Construct a new, empty queue holding at most `limit` elements.
    fn new(limit: usize) -> Self {
        Self {
            inner: VecDeque::with_capacity(limit),
            limit,
        }
    }
    /// Remove every element.
    fn clear(&mut self) {
        self.inner.clear();
    }
    /// Return the number of stored elements.
    fn len(&self) -> usize {
        self.inner.len()
    }
    /// Append `item` at the back, returning the evicted front element if
    /// the queue was already at its limit.
    ///
    /// With a limit of zero nothing can be stored: the pushed item is
    /// itself the evictee and is handed back to the caller, rather than
    /// being silently dropped.  This preserves the invariant that every
    /// pushed element is either retained or returned.
    fn push_back(&mut self, item: T) -> Option<T> {
        if self.limit == 0 {
            return Some(item);
        }
        let removed = if self.len() == self.limit {
            self.inner.pop_front()
        } else {
            None
        };
        self.inner.push_back(item);
        removed
    }
    /// Iterate over the elements, oldest first.
    fn iter(&self) -> impl Iterator<Item = &T> {
        self.inner.iter()
    }
    /// Change the limit to `new_limit`.  When shrinking, the oldest excess
    /// elements are discarded and spare capacity is released.
    fn set_max_len(&mut self, new_limit: usize) {
        if new_limit < self.limit {
            let n_to_drain = self.inner.len().saturating_sub(new_limit);
            self.inner.drain(0..n_to_drain);
            self.inner.shrink_to_fit();
        }
        self.limit = new_limit;
    }
}
#[cfg(test)]
mod test {
    //! Unit tests for the Pareto timeout estimator.
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;
    use crate::timeouts::TimeoutEstimator;
    use tor_basic_utils::RngExt as _;
    use tor_basic_utils::test_rng::testing_rng;

    /// Shorthand for the 3-hop BuildCircuit action used throughout.
    fn b3() -> Action {
        Action::BuildCircuit { length: 3 }
    }

    // Test-only convenience for writing MsecDuration values as literals.
    impl From<u32> for MsecDuration {
        fn from(v: u32) -> Self {
            Self(v)
        }
    }

    // The derived ordering on MsecDuration orders by numeric value.
    #[test]
    fn ms_partial_cmp() {
        #![allow(clippy::eq_op)]
        let myriad: MsecDuration = 10_000.into();
        let lakh: MsecDuration = 100_000.into();
        let crore: MsecDuration = 10_000_000.into();
        assert!(myriad < lakh);
        assert!(myriad == myriad);
        assert!(crore > lakh);
        assert!(crore >= crore);
        assert!(crore <= crore);
    }

    // Low-level History operations: bucketing, inc/dec, and clearing.
    #[test]
    fn history_lowlev() {
        assert_eq!(History::bucket_center(1.into()), 5.into());
        assert_eq!(History::bucket_center(903.into()), 905.into());
        assert_eq!(History::bucket_center(0.into()), 5.into());
        assert_eq!(History::bucket_center(u32::MAX.into()), 4294967295.into());
        let mut h = History::new_empty();
        h.inc_bucket(7.into());
        h.inc_bucket(8.into());
        h.inc_bucket(9.into());
        h.inc_bucket(10.into());
        h.inc_bucket(11.into());
        h.inc_bucket(12.into());
        h.inc_bucket(13.into());
        h.inc_bucket(299.into());
        assert_eq!(h.time_histogram.get(&5.into()), Some(&3));
        assert_eq!(h.time_histogram.get(&15.into()), Some(&4));
        assert_eq!(h.time_histogram.get(&25.into()), None);
        assert_eq!(h.time_histogram.get(&295.into()), Some(&1));
        h.dec_bucket(299.into());
        h.dec_bucket(24.into());
        h.dec_bucket(12.into());
        assert_eq!(h.time_histogram.get(&15.into()), Some(&3));
        assert_eq!(h.time_histogram.get(&25.into()), None);
        assert_eq!(h.time_histogram.get(&295.into()), None);
        h.add_success(true);
        h.add_success(false);
        assert_eq!(h.success_history.len(), 2);
        h.clear();
        assert_eq!(h.time_histogram.len(), 0);
        assert_eq!(h.time_history.len(), 0);
        assert_eq!(h.success_history.len(), 0);
    }

    // add_time keeps the histogram in sync as old observations are evicted,
    // and from_sparse_histogram round-trips the histogram.
    #[test]
    fn time_observation_management() {
        let mut h = History::new_empty();
        h.set_time_history_len(8);
        h.add_time(300.into());
        h.add_time(500.into());
        h.add_time(542.into());
        h.add_time(305.into());
        h.add_time(543.into());
        h.add_time(307.into());
        assert_eq!(h.n_times(), 6);
        let v = h.n_most_frequent_bins(10);
        assert_eq!(&v[..], [(305.into(), 3), (545.into(), 2), (505.into(), 1)]);
        let v = h.n_most_frequent_bins(2);
        assert_eq!(&v[..], [(305.into(), 3), (545.into(), 2)]);
        let v: Vec<_> = h.sparse_histogram().collect();
        assert_eq!(&v[..], [(305.into(), 3), (505.into(), 1), (545.into(), 2)]);
        h.add_time(212.into());
        h.add_time(203.into());
        h.add_time(617.into());
        h.add_time(413.into());
        assert_eq!(h.n_times(), 8);
        let v: Vec<_> = h.sparse_histogram().collect();
        assert_eq!(
            &v[..],
            [
                (205.into(), 1),
                (215.into(), 1),
                (305.into(), 2),
                (415.into(), 1),
                (545.into(), 2),
                (615.into(), 1)
            ]
        );
        let h2 = History::from_sparse_histogram(v.clone().into_iter());
        let v2: Vec<_> = h2.sparse_histogram().collect();
        assert_eq!(v, v2);
    }

    // Success history is bounded, and n_recent_timeouts counts only the
    // retained window.
    #[test]
    fn success_observation_mechanism() {
        let mut h = History::new_empty();
        h.set_success_history_len(20);
        assert_eq!(h.n_recent_timeouts(), 0);
        h.add_success(true);
        assert_eq!(h.n_recent_timeouts(), 0);
        h.add_success(false);
        assert_eq!(h.n_recent_timeouts(), 1);
        for _ in 0..200 {
            h.add_success(false);
        }
        assert_eq!(h.n_recent_timeouts(), 20);
        h.add_success(true);
        h.add_success(true);
        h.add_success(true);
        assert_eq!(h.n_recent_timeouts(), 20 - 3);
        h.set_success_history_len(10);
        assert_eq!(h.n_recent_timeouts(), 10 - 3);
    }

    // Xm is the count-weighted mean of the most frequent bucket centers.
    #[test]
    fn xm_calculation() {
        let mut h = History::new_empty();
        assert_eq!(h.estimate_xm(2), None);
        for n in &[300, 500, 542, 305, 543, 307, 212, 203, 617, 413] {
            h.add_time(MsecDuration(*n));
        }
        let v = h.n_most_frequent_bins(2);
        assert_eq!(&v[..], [(305.into(), 3), (545.into(), 2)]);
        let est = (305 * 3 + 545 * 2) / 5;
        assert_eq!(h.estimate_xm(2), Some(est));
        assert_eq!(est, 401);
    }

    // Maximum-likelihood fit: observations below Xm count as equal to Xm.
    #[test]
    fn pareto_estimate() {
        let mut h = History::new_empty();
        assert!(h.pareto_estimate(2).is_none());
        for n in &[300, 500, 542, 305, 543, 307, 212, 203, 617, 413] {
            h.add_time(MsecDuration(*n));
        }
        let expected_log_sum: f64 = [401, 500, 542, 401, 543, 401, 401, 401, 617, 413]
            .iter()
            .map(|x| f64::from(*x).ln())
            .sum();
        let expected_log_xm: f64 = (401_f64).ln() * 10.0;
        let expected_alpha = 10.0 / (expected_log_sum - expected_log_xm);
        let expected_inv_alpha = 1.0 / expected_alpha;
        let p = h.pareto_estimate(2).unwrap();
        assert!((401.0 - p.x_m).abs() < 1.0e-9);
        assert!((expected_inv_alpha - p.inv_alpha).abs() < 1.0e-9);
        let q60 = p.quantile(0.60);
        let q99 = p.quantile(0.99);
        assert!((q60 - 451.127) < 0.001);
        assert!((q99 - 724.841) < 0.001);
    }

    // End-to-end timeout computation, including caching and per-action
    // scaling.
    #[test]
    fn pareto_estimate_timeout() {
        let mut est = ParetoTimeoutEstimator::default();
        assert_eq!(
            est.timeouts(&b3()),
            (Duration::from_secs(60), Duration::from_secs(60))
        );
        est.p.min_observations = 0;
        est.p.n_modes_for_xm = 2;
        assert_eq!(
            est.timeouts(&b3()),
            (Duration::from_secs(60), Duration::from_secs(60))
        );
        for msec in &[300, 500, 542, 305, 543, 307, 212, 203, 617, 413] {
            let d = Duration::from_millis(*msec);
            est.note_hop_completed(2, d, true);
        }
        let t = est.timeouts(&b3());
        assert_eq!(t.0.as_micros(), 493_169);
        assert_eq!(t.1.as_micros(), 724_841);
        let t2 = est.timeouts(&b3());
        assert_eq!(t2, t);
        let t2 = est.timeouts(&Action::BuildCircuit { length: 4 });
        assert_eq!(t2.0, t.0.mul_f64(10.0 / 6.0));
        assert_eq!(t2.1, t.1.mul_f64(10.0 / 6.0));
    }

    // Repeated timeouts reset the history and eventually double the
    // fallback thresholds.
    #[test]
    fn pareto_estimate_clear() {
        let mut est = ParetoTimeoutEstimator::default();
        let params = NetParameters::from_map(&"cbtmincircs=1 cbtnummodes=2".parse().unwrap());
        est.update_params(&params);
        assert_eq!(est.timeouts(&b3()).0.as_micros(), 60_000_000);
        assert!(est.learning_timeouts());
        for msec in &[300, 500, 542, 305, 543, 307, 212, 203, 617, 413] {
            let d = Duration::from_millis(*msec);
            est.note_hop_completed(2, d, true);
        }
        assert_ne!(est.timeouts(&b3()).0.as_micros(), 60_000_000);
        assert!(!est.learning_timeouts());
        assert_eq!(est.history.n_recent_timeouts(), 0);
        for _ in 0..18 {
            est.note_circ_timeout(2, Duration::from_secs(2000));
        }
        assert_ne!(est.timeouts(&b3()).0.as_micros(), 60_000_000);
        est.note_circ_timeout(2, Duration::from_secs(2000));
        assert_eq!(est.timeouts(&b3()).0.as_micros(), 60_000_000);
        for _ in 0..20 {
            est.note_circ_timeout(2, Duration::from_secs(2000));
        }
        assert_eq!(est.timeouts(&b3()).0.as_micros(), 120_000_000);
    }

    // Params::default() must agree with the conversion from default
    // consensus parameters.
    #[test]
    fn default_params() {
        let p1 = Params::default();
        let p2 = Params::from(&tor_netdir::params::NetParameters::default());
        assert_eq!(format!("{:?}", p1), format!("{:?}", p2));
    }

    // Persisting and restoring state yields a similar timeout estimate.
    #[test]
    fn state_conversion() {
        let mut est = ParetoTimeoutEstimator::default();
        let mut rng = testing_rng();
        for _ in 0..1000 {
            let d = Duration::from_millis(rng.gen_range_checked(10..3_000).unwrap());
            est.note_hop_completed(2, d, true);
        }
        let state = est.build_state().unwrap();
        assert_eq!(state.version, 1);
        assert!(state.current_timeout.is_some());
        let mut est2 = ParetoTimeoutEstimator::from_state(state);
        let act = Action::BuildCircuit { length: 3 };
        let ms1 = est.timeouts(&act).0.as_millis() as i32;
        let ms2 = est2.timeouts(&act).0.as_millis() as i32;
        assert!((ms1 - ms2).abs() < 50);
    }

    // Sanity-check the rand API: choosing more elements than exist returns
    // them all.
    #[test]
    fn validate_iterator_choose_multiple() {
        use rand::seq::IteratorRandom as _;
        let mut rng = testing_rng();
        let mut ten_elements = (1..=10).choose_multiple(&mut rng, 100);
        ten_elements.sort();
        assert_eq!(ten_elements.len(), 10);
        assert_eq!(ten_elements, (1..=10).collect::<Vec<_>>());
    }
}