use std::collections::HashMap;
use std::net::IpAddr;
use ipnetwork::IpNetwork;
use rand::Rng;
use stem_rs::descriptor::router_status::RouterStatusEntry;
use crate::error::{Error, Result};
/// Returns `true` if `s` is exactly 40 ASCII hexadecimal digits (either case).
pub fn is_valid_fingerprint(s: &str) -> bool {
    const EXPECTED_LEN: usize = 40;
    if s.len() != EXPECTED_LEN {
        return false;
    }
    // Any non-ASCII byte fails `is_ascii_hexdigit`, so a byte scan is equivalent
    // to checking characters.
    s.bytes().all(|byte| byte.is_ascii_hexdigit())
}
/// Returns `true` if `s` parses either as a bare IP address (IPv4 or IPv6) or
/// as an [`IpNetwork`] (e.g. CIDR notation such as `10.0.0.0/8`).
pub fn is_valid_ip_or_network(s: &str) -> bool {
    if s.parse::<IpAddr>().is_ok() {
        return true;
    }
    matches!(s.parse::<IpNetwork>(), Ok(_))
}
/// Returns `true` if `s` consists of exactly two ASCII letters (either case).
pub fn is_valid_country_code(s: &str) -> bool {
    let bytes = s.as_bytes();
    // Non-ASCII bytes are never ASCII-alphabetic, so a byte check suffices.
    bytes.len() == 2 && bytes.iter().all(u8::is_ascii_alphabetic)
}
/// A predicate deciding whether a router is acceptable.
///
/// Implementors must be `Send + Sync` (required by the trait bound).
pub trait NodeRestriction: Send + Sync {
    /// Returns `true` if `router` satisfies this restriction.
    fn r_is_ok(&self, router: &RouterStatusEntry) -> bool;
}
/// Restriction that filters routers by their flags.
///
/// A router passes when it carries every flag in `mandatory` and none of the
/// flags in `forbidden`. Flags are compared as exact, case-sensitive strings.
#[derive(Debug, Clone)]
pub struct FlagsRestriction {
    /// Flags a router must have.
    pub mandatory: Vec<String>,
    /// Flags that disqualify a router when present.
    pub forbidden: Vec<String>,
}

impl FlagsRestriction {
    /// Builds a restriction from the required and disallowed flag lists.
    pub fn new(mandatory: Vec<String>, forbidden: Vec<String>) -> Self {
        FlagsRestriction {
            mandatory,
            forbidden,
        }
    }
}

impl NodeRestriction for FlagsRestriction {
    fn r_is_ok(&self, router: &RouterStatusEntry) -> bool {
        let present = &router.flags;
        let has_every_required = self.mandatory.iter().all(|wanted| present.contains(wanted));
        // Short-circuits exactly like the sequential checks: mandatory first.
        has_every_required && !self.forbidden.iter().any(|banned| present.contains(banned))
    }
}
/// Conjunction of [`NodeRestriction`]s: a router passes only if every member
/// restriction accepts it.
pub struct NodeRestrictionList {
    restrictions: Vec<Box<dyn NodeRestriction>>,
}

impl NodeRestrictionList {
    /// Wraps the given restrictions. An empty list accepts every router.
    pub fn new(restrictions: Vec<Box<dyn NodeRestriction>>) -> Self {
        NodeRestrictionList { restrictions }
    }

    /// Returns `true` when all restrictions accept `router`. Restrictions are
    /// evaluated in order and evaluation stops at the first rejection.
    pub fn r_is_ok(&self, router: &RouterStatusEntry) -> bool {
        for restriction in &self.restrictions {
            if !restriction.r_is_ok(router) {
                return false;
            }
        }
        true
    }
}
/// Circuit position that routers are being weighted for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Position {
    /// Guard position.
    Guard,
    /// Middle position.
    Middle,
    /// Exit position.
    Exit,
}

impl Position {
    /// Letter that follows the leading `W` in a bandwidth-weight key for this
    /// position: `'g'`, `'m'` or `'e'`.
    fn weight_key_suffix(&self) -> char {
        let letter = match *self {
            Self::Guard => 'g',
            Self::Middle => 'm',
            Self::Exit => 'e',
        };
        letter
    }
}
/// Bandwidth-weighted random router selector.
///
/// Keeps only the routers accepted by a [`NodeRestrictionList`]. Each one is
/// weighted by its bandwidth (`measured` if present, else `bandwidth`, else 0)
/// times the bandwidth weight for its flag class at the configured
/// [`Position`], divided by `WEIGHT_SCALE`.
pub struct BwWeightedGenerator {
    /// Routers that passed the restrictions, in input order.
    rstr_routers: Vec<RouterStatusEntry>,
    /// Selection weight of each router; index-aligned with `rstr_routers`.
    node_weights: Vec<f64>,
    /// Sum of `node_weights`; the exclusive upper bound of the draw in `generate`.
    weight_total: f64,
    /// Sum of the exit-position weights of `Exit`-flagged routers, as last
    /// computed by `repair_exits` (0 until then).
    exit_total: f64,
    /// Position whose weight keys are used by `rebuild_weights`.
    position: Position,
    /// Bandwidth weights keyed by name, e.g. `"Wgd"`.
    bw_weights: HashMap<String, i64>,
}

impl BwWeightedGenerator {
    /// Divisor that turns a raw bandwidth-weight value into a multiplier.
    const WEIGHT_SCALE: f64 = 10000.0;

    /// Raw weight used when a key is absent from `bw_weights` (multiplier 1.0).
    const DEFAULT_WEIGHT: i64 = 10000;

    /// Filters `sorted_routers` through `restrictions` and precomputes the
    /// selection weight of every surviving router for `position`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::NoNodesRemain`] if no router passes the restrictions.
    pub fn new(
        sorted_routers: Vec<RouterStatusEntry>,
        restrictions: NodeRestrictionList,
        bw_weights: HashMap<String, i64>,
        position: Position,
    ) -> Result<Self> {
        let rstr_routers: Vec<RouterStatusEntry> = sorted_routers
            .into_iter()
            .filter(|r| restrictions.r_is_ok(r))
            .collect();
        if rstr_routers.is_empty() {
            return Err(Error::NoNodesRemain);
        }
        let mut generator = Self {
            rstr_routers,
            node_weights: Vec::new(),
            weight_total: 0.0,
            exit_total: 0.0,
            position,
            bw_weights,
        };
        generator.rebuild_weights();
        Ok(generator)
    }

    /// Recomputes every router's weight for `self.position` and resets
    /// `weight_total` to their sum.
    fn rebuild_weights(&mut self) {
        self.node_weights.clear();
        self.weight_total = 0.0;
        for router in &self.rstr_routers {
            let weight = Self::router_bandwidth(router) * self.flag_to_weight(router);
            self.node_weights.push(weight);
            self.weight_total += weight;
        }
    }

    /// Bandwidth used for weighting: `measured` if set, else `bandwidth`, else 0.
    fn router_bandwidth(router: &RouterStatusEntry) -> f64 {
        router.measured.or(router.bandwidth).unwrap_or(0) as f64
    }

    /// Whether `router` carries exactly the flag `flag`. Avoids allocating a
    /// `String` per lookup.
    fn has_flag(router: &RouterStatusEntry, flag: &str) -> bool {
        router.flags.iter().any(|f| f == flag)
    }

    /// Scaled bandwidth-weight multiplier for `router` at `self.position`.
    fn flag_to_weight(&self, router: &RouterStatusEntry) -> f64 {
        self.weight_at(router, self.position)
    }

    /// Scaled bandwidth-weight multiplier for `router` at `position`.
    ///
    /// The key is `W<position><class>`. The class is `d` for Guard+Exit, `e` for
    /// Exit only, `g` for Guard only, and `m` for neither. A missing key yields
    /// `DEFAULT_WEIGHT`.
    fn weight_at(&self, router: &RouterStatusEntry, position: Position) -> f64 {
        let class = match (Self::has_flag(router, "Guard"), Self::has_flag(router, "Exit")) {
            (true, true) => 'd',
            // NOTE(review): dir-spec defines no `Wge` key, so Exit-only routers in
            // the guard position fall back to DEFAULT_WEIGHT; confirm this is intended.
            (false, true) => 'e',
            (true, false) => 'g',
            // Non-flagged routers use their position's own key (`Wgm`, `Wmm`
            // or `Wem`), not always `Wmm`.
            (false, false) => 'm',
        };
        let key = format!("W{}{}", position.weight_key_suffix(), class);
        self.bw_weights
            .get(&key)
            .copied()
            .unwrap_or(Self::DEFAULT_WEIGHT) as f64
            / Self::WEIGHT_SCALE
    }

    /// Recomputes the weights of `Exit`-flagged routers using exit-position
    /// bandwidth weights and stores their sum in `exit_total`.
    ///
    /// Weights of other routers are left unchanged. `weight_total` is then
    /// recomputed from `node_weights`, so `generate` samples from the updated
    /// weights.
    pub fn repair_exits(&mut self) {
        let mut exit_total = 0.0;
        for (i, router) in self.rstr_routers.iter().enumerate() {
            if Self::has_flag(router, "Exit") {
                let weight =
                    Self::router_bandwidth(router) * self.weight_at(router, Position::Exit);
                self.node_weights[i] = weight;
                exit_total += weight;
            }
        }
        self.exit_total = exit_total;
        // The loop changed `node_weights`. Keep the sampling bound in sync;
        // otherwise `generate` draws against a stale total and skews selection.
        self.weight_total = self.node_weights.iter().sum();
    }

    /// Picks one router at random with probability proportional to its weight.
    ///
    /// # Errors
    ///
    /// Returns [`Error::NoNodesRemain`] if there are no routers, or if the total
    /// weight is not a positive finite number.
    pub fn generate(&self) -> Result<&RouterStatusEntry> {
        if self.rstr_routers.is_empty()
            || !self.weight_total.is_finite()
            || self.weight_total <= 0.0
        {
            return Err(Error::NoNodesRemain);
        }
        let mut rng = rand::thread_rng();
        let choice_val = rng.gen_range(0.0..self.weight_total);
        let mut cumulative = 0.0;
        for (router, weight) in self.rstr_routers.iter().zip(&self.node_weights) {
            cumulative += weight;
            if cumulative > choice_val {
                return Ok(router);
            }
        }
        // Normally reachable only through floating-point rounding (the running
        // sum ends a hair below `choice_val`). Prefer the last router that can
        // actually be chosen over one whose weight is zero.
        let idx = self
            .node_weights
            .iter()
            .rposition(|&w| w > 0.0)
            .unwrap_or(self.rstr_routers.len() - 1);
        Ok(&self.rstr_routers[idx])
    }

    /// Sum of all current selection weights.
    pub fn weight_total(&self) -> f64 {
        self.weight_total
    }

    /// Sum of exit-position weights of `Exit`-flagged routers from the last
    /// `repair_exits` call (0 if it was never called).
    pub fn exit_total(&self) -> f64 {
        self.exit_total
    }

    /// Number of routers that passed the restrictions.
    pub fn router_count(&self) -> usize {
        self.rstr_routers.len()
    }

    /// Routers that passed the restrictions, in input order.
    pub fn routers(&self) -> &[RouterStatusEntry] {
        &self.rstr_routers
    }

    /// Per-router selection weights, index-aligned with [`Self::routers`].
    pub fn node_weights(&self) -> &[f64] {
        &self.node_weights
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use stem_rs::descriptor::router_status::RouterStatusEntryType;

    /// Converts borrowed flag names into the owned form used by the API.
    fn to_strings(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    /// Builds a V3 router status entry carrying exactly `flags`.
    fn router_with_flags(flags: &[&str]) -> RouterStatusEntry {
        let mut entry = RouterStatusEntry::new(
            RouterStatusEntryType::V3,
            "test".to_string(),
            "A".repeat(40),
            Utc::now(),
            "192.0.2.1".parse().unwrap(),
            9001,
        );
        entry.flags = to_strings(flags);
        entry
    }

    #[test]
    fn test_valid_fingerprints() {
        let cases = [
            "AABBCCDD00112233445566778899AABBCCDDEEFF",
            "aabbccdd00112233445566778899aabbccddeeff",
            "0123456789abcdefABCDEF0123456789abcdefAB",
        ];
        for fp in cases {
            assert!(is_valid_fingerprint(fp), "expected valid: {}", fp);
        }
    }

    #[test]
    fn test_invalid_fingerprints() {
        let cases = [
            "",
            "AABBCCDD",
            "AABBCCDD00112233445566778899AABBCCDDEEFFGG",
            "GGHHIIJJ00112233445566778899AABBCCDDEEFF",
            "AABBCCDD00112233445566778899AABBCCDDEEF",
            "AABBCCDD00112233445566778899AABBCCDDEEFFF",
        ];
        for fp in cases {
            assert!(!is_valid_fingerprint(fp), "expected invalid: {}", fp);
        }
    }

    #[test]
    fn test_valid_ip_addresses() {
        let cases = [
            "192.168.1.1",
            "10.0.0.0",
            "255.255.255.255",
            "0.0.0.0",
            "::1",
            "2001:db8::1",
            "fe80::1",
        ];
        for addr in cases {
            assert!(is_valid_ip_or_network(addr), "expected valid: {}", addr);
        }
    }

    #[test]
    fn test_valid_networks() {
        let cases = [
            "192.168.1.0/24",
            "10.0.0.0/8",
            "0.0.0.0/0",
            "2001:db8::/32",
            "::/0",
        ];
        for net in cases {
            assert!(is_valid_ip_or_network(net), "expected valid: {}", net);
        }
    }

    #[test]
    fn test_invalid_ip_or_network() {
        let cases = [
            "",
            "not-an-ip",
            "192.168.1.256",
            "192.168.1.1/33",
            "192.168.1",
            "example.com",
        ];
        for input in cases {
            assert!(!is_valid_ip_or_network(input), "expected invalid: {}", input);
        }
    }

    #[test]
    fn test_valid_country_codes() {
        for code in ["US", "us", "DE", "de", "GB", "JP"] {
            assert!(is_valid_country_code(code), "expected valid: {}", code);
        }
    }

    #[test]
    fn test_invalid_country_codes() {
        for code in ["", "U", "USA", "U1", "12", "U-"] {
            assert!(!is_valid_country_code(code), "expected invalid: {}", code);
        }
    }

    #[test]
    fn test_flags_restriction() {
        let mut router = router_with_flags(&["Fast", "Stable", "Valid"]);
        let restriction =
            FlagsRestriction::new(to_strings(&["Fast", "Stable"]), to_strings(&["Authority"]));

        // All mandatory flags present, no forbidden flag.
        assert!(restriction.r_is_ok(&router));

        // Adding a forbidden flag rejects the router.
        router.flags.push("Authority".to_string());
        assert!(!restriction.r_is_ok(&router));

        // Missing a mandatory flag rejects the router.
        router.flags = to_strings(&["Fast"]);
        assert!(!restriction.r_is_ok(&router));
    }

    #[test]
    fn test_node_restriction_list() {
        let mut router = router_with_flags(&["Fast", "Stable", "Valid"]);
        let list = NodeRestrictionList::new(vec![
            Box::new(FlagsRestriction::new(to_strings(&["Fast"]), Vec::new())),
            Box::new(FlagsRestriction::new(to_strings(&["Stable"]), Vec::new())),
        ]);

        assert!(list.r_is_ok(&router));

        // Failing any single member restriction fails the whole list.
        router.flags = to_strings(&["Fast"]);
        assert!(!list.r_is_ok(&router));
    }
}