use crate::timeouts::{
pareto::{ParetoTimeoutEstimator, ParetoTimeoutState},
readonly::ReadonlyTimeoutEstimator,
Action, TimeoutEstimator,
};
use crate::TimeoutStateHandle;
use std::sync::Mutex;
use std::time::Duration;
use tor_netdir::params::NetParameters;
use tracing::{debug, warn};
/// Wrapper around a boxed, dynamically dispatched [`TimeoutEstimator`].
///
/// The wrapped estimator sits behind a [`Mutex`], so every method on this
/// type takes `&self`, and the estimator can be replaced with a different
/// implementation without needing exclusive access to the wrapper.
pub(crate) struct Estimator {
    /// The estimator implementation currently in use.
    inner: Mutex<Box<dyn TimeoutEstimator + Send + 'static>>,
}
impl Estimator {
#[cfg(test)]
pub(crate) fn new(est: impl TimeoutEstimator + Send + 'static) -> Self {
Self {
inner: Mutex::new(Box::new(est)),
}
}
pub(crate) fn from_storage(storage: &TimeoutStateHandle) -> Self {
let (_, est) = estimator_from_storage(storage);
Self {
inner: Mutex::new(est),
}
}
pub(crate) fn upgrade_to_owning_storage(&self, storage: &TimeoutStateHandle) {
let (readonly, est) = estimator_from_storage(storage);
if readonly {
warn!("Unable to upgrade to owned persistent storage.");
return;
}
*self.inner.lock().expect("Timeout estimator lock poisoned") = est;
}
pub(crate) fn reload_readonly_from_storage(&self, storage: &TimeoutStateHandle) {
if let Ok(Some(v)) = storage.load() {
let est = ReadonlyTimeoutEstimator::from_state(&v);
*self.inner.lock().expect("Timeout estimator lock poisoned") = Box::new(est);
} else {
debug!("Unable to reload timeout state.");
}
}
pub(crate) fn note_hop_completed(&self, hop: u8, delay: Duration, is_last: bool) {
let mut inner = self.inner.lock().expect("Timeout estimator lock poisoned.");
inner.note_hop_completed(hop, delay, is_last);
}
pub(crate) fn note_circ_timeout(&self, hop: u8, delay: Duration) {
let mut inner = self.inner.lock().expect("Timeout estimator lock poisoned.");
inner.note_circ_timeout(hop, delay);
}
pub(crate) fn timeouts(&self, action: &Action) -> (Duration, Duration) {
let mut inner = self.inner.lock().expect("Timeout estimator lock poisoned.");
inner.timeouts(action)
}
pub(crate) fn learning_timeouts(&self) -> bool {
let inner = self.inner.lock().expect("Timeout estimator lock poisoned.");
inner.learning_timeouts()
}
pub(crate) fn update_params(&self, params: &NetParameters) {
let mut inner = self.inner.lock().expect("Timeout estimator lock poisoned.");
inner.update_params(params);
}
pub(crate) fn save_state(&self, storage: &TimeoutStateHandle) -> crate::Result<()> {
let state = {
let mut inner = self.inner.lock().expect("Timeout estimator lock poisoned.");
inner.build_state()
};
if let Some(state) = state {
storage.store(&state)?;
}
Ok(())
}
}
/// Build a boxed estimator from the state held in `storage`.
///
/// Returns `(readonly, estimator)`:
///
/// * If loading fails, a warning is logged and the result is `true` with a
///   [`ReadonlyTimeoutEstimator`] created by `new()`.
/// * Otherwise the loaded state is used, or a default
///   [`ParetoTimeoutState`] when nothing was stored. The result is then
///   `false` with a [`ParetoTimeoutEstimator`] when `storage.can_store()`
///   is true, and `true` with a [`ReadonlyTimeoutEstimator`] when it is not.
fn estimator_from_storage(
    storage: &TimeoutStateHandle,
) -> (bool, Box<dyn TimeoutEstimator + Send + 'static>) {
    let loaded = match storage.load() {
        Err(err) => {
            warn!("Unable to load timeout state: {}", err);
            return (true, Box::new(ReadonlyTimeoutEstimator::new()));
        }
        // Nothing stored yet: start from the default state.
        Ok(maybe_state) => maybe_state.unwrap_or_else(ParetoTimeoutState::default),
    };

    let readonly = !storage.can_store();
    let estimator: Box<dyn TimeoutEstimator + Send + 'static> = if readonly {
        Box::new(ReadonlyTimeoutEstimator::from_state(&loaded))
    } else {
        Box::new(ParetoTimeoutEstimator::from_state(loaded))
    };
    (readonly, estimator)
}
#[cfg(test)]
mod test {
    // Unwrapping is fine in tests: a panic is simply a test failure.
    #![allow(clippy::unwrap_used)]
    use super::*;
    use tor_persist::StateMgr;

    /// Exercise creating, saving, reloading and upgrading estimators that
    /// share one piece of testing storage through two state managers.
    #[test]
    fn load_estimator() {
        let params = NetParameters::default();

        // First manager: it holds the lock, and its estimator starts out
        // learning.
        let storage = tor_persist::TestingStateMgr::new();
        assert!(storage.try_lock().unwrap().held());
        let handle = storage.clone().create_handle("paretorama");

        let est = Estimator::from_storage(&handle);
        assert!(est.learning_timeouts());
        est.save_state(&handle).unwrap();

        // Second manager over the same storage: it cannot take the lock,
        // and its estimator is not learning.
        let storage2 = storage.new_manager();
        assert!(!storage2.try_lock().unwrap().held());
        let handle2 = storage2.clone().create_handle("paretorama");

        let est2 = Estimator::from_storage(&handle2);
        assert!(!est2.learning_timeouts());

        est.update_params(&params);
        est2.update_params(&params);

        // Before any observations, both estimators report 60 seconds.
        let act = Action::BuildCircuit { length: 3 };
        assert_eq!(
            est.timeouts(&act),
            (Duration::from_secs(60), Duration::from_secs(60))
        );
        assert_eq!(
            est2.timeouts(&act),
            (Duration::from_secs(60), Duration::from_secs(60))
        );

        // Feed observations to both estimators.
        for _ in 0..500 {
            est.note_hop_completed(2, Duration::from_secs(7), true);
            est.note_hop_completed(2, Duration::from_secs(2), true);
            est2.note_hop_completed(2, Duration::from_secs(4), true);
        }
        assert!(!est.learning_timeouts());

        // The first estimator's timeouts have moved; the second's have not.
        est.save_state(&handle).unwrap();
        let to_1 = est.timeouts(&act);
        assert_ne!(
            est.timeouts(&act),
            (Duration::from_secs(60), Duration::from_secs(60))
        );
        assert_eq!(
            est2.timeouts(&act),
            (Duration::from_secs(60), Duration::from_secs(60))
        );

        // After reloading, both of the second estimator's durations match
        // the first element of the first estimator's timeouts.
        est2.reload_readonly_from_storage(&handle2);
        let to_1_secs = to_1.0.as_secs_f64();
        let timeouts = est2.timeouts(&act);
        assert!((timeouts.0.as_secs_f64() - to_1_secs).abs() < 0.001);
        assert!((timeouts.1.as_secs_f64() - to_1_secs).abs() < 0.001);

        // Once the first manager is gone, the second can take the lock and
        // the second estimator can be upgraded.
        drop(est);
        drop(handle);
        drop(storage);
        assert!(storage2.try_lock().unwrap().held());
        est2.upgrade_to_owning_storage(&handle2);
        let to_2 = est2.timeouts(&act);
        // The upgraded estimator's timeout stays within a second of the
        // saved one.
        assert!(to_2.0 > to_1.0 - Duration::from_secs(1));
        assert!(to_2.0 < to_1.0 + Duration::from_secs(1));

        // The upgraded estimator reacts to new observations: a run of
        // one-second completions lowers its timeout.
        for _ in 0..200 {
            est2.note_hop_completed(2, Duration::from_secs(1), true);
        }
        let to_3 = est2.timeouts(&act);
        assert!(to_3.0 < to_2.0);
    }
}