use std::collections::HashMap;
use std::io::{BufRead, BufReader};
use std::path::Path;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use stem_rs::controller::{CircuitId, Controller};
use stem_rs::descriptor::router_status::RouterStatusEntry;
use stem_rs::events::ParsedEvent;
use stem_rs::version::Version;
use stem_rs::EventType;
use crate::bandguards::BandwidthStats;
use crate::cbtverify::TimeoutStats;
use crate::config::{Config, LogLevel};
use crate::error::{Error, Result};
use crate::logger::plog;
use crate::logguard::LogGuard;
use crate::node_selection::{BwWeightedGenerator, FlagsRestriction, NodeRestrictionList, Position};
use crate::pathverify::PathVerify;
use crate::vanguards::{ExcludeNodes, VanguardState};
/// Version of this crate (from `CARGO_PKG_VERSION`), reported in the
/// "connected to Tor" notice emitted by `authenticate_any`.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
/// Minimum Tor version for the bandwidth-based event subscriptions (CIRC_BW / CIRC_MINOR).
/// NOTE(review): not referenced in this file; `get_event_types` hard-codes the same
/// version (0.3.4.10) with `Version::new(..).with_patch(..)` — consider deriving it from here.
#[allow(dead_code)]
const MIN_TOR_VERSION_FOR_BW: &str = "0.3.4.10";
/// Presumably the minimum Tor version accepting the HSLayer2Nodes/HSLayer3Nodes options.
/// NOTE(review): not referenced in this file; the 0.3.3.x requirement only appears in the
/// error message logged by `configure_tor`.
#[allow(dead_code)]
const MIN_TOR_VERSION_FOR_VANGUARDS: &str = "0.3.3.0";
/// Process-wide switch consulted by `try_close_circuit`: when `false`, suspicious
/// circuits are only reported, never closed. Enabled by default.
static CLOSE_CIRCUITS: AtomicBool = AtomicBool::new(true);

/// Turns forced closing of suspicious circuits on (`true`) or off (`false`).
pub fn set_close_circuits(value: bool) {
    let switch: &AtomicBool = &CLOSE_CIRCUITS;
    switch.store(value, Ordering::SeqCst);
}

/// Reports whether suspicious circuits are currently force-closed.
pub fn get_close_circuits() -> bool {
    let switch: &AtomicBool = &CLOSE_CIRCUITS;
    switch.load(Ordering::SeqCst)
}
pub async fn authenticate_any(controller: &mut Controller, password: Option<&str>) -> Result<()> {
let result = controller.authenticate(password).await;
match result {
Ok(()) => {
let version = controller.get_version().await?;
plog(
LogLevel::Notice,
&format!(
"Vanguards {} connected to Tor {} using stem-rs",
VERSION, version
),
);
Ok(())
}
Err(stem_rs::Error::Authentication(stem_rs::AuthError::MissingPassword)) => {
let passwd = prompt_password()?;
controller.authenticate(Some(&passwd)).await?;
let version = controller.get_version().await?;
plog(
LogLevel::Notice,
&format!(
"Vanguards {} connected to Tor {} using stem-rs",
VERSION, version
),
);
Ok(())
}
Err(e) => Err(Error::Control(e)),
}
}
/// Asks for the controller password on stderr and reads one line from stdin.
///
/// Only the trailing line terminator (`\n` / `\r\n`) is removed. Leading or trailing
/// spaces and tabs are kept, because they may legitimately be part of the password.
/// The previous `trim()` silently altered such passwords.
///
/// NOTE(review): the input is echoed to the terminal; there is no hidden-input support here.
///
/// # Errors
/// Returns `Error::Io` if reading from stdin fails.
fn prompt_password() -> Result<String> {
    eprint!("Controller password: ");
    let mut password = String::new();
    std::io::stdin()
        .read_line(&mut password)
        .map_err(Error::Io)?;
    // Drop only the newline that `read_line` keeps; everything the user typed is preserved.
    let kept = password.trim_end_matches(['\r', '\n']).len();
    password.truncate(kept);
    Ok(password)
}
pub fn get_consensus_weights(consensus_filename: &Path) -> Result<HashMap<String, i64>> {
let file = std::fs::File::open(consensus_filename).map_err(|e| {
Error::Consensus(format!(
"cannot read {}: {}",
consensus_filename.display(),
e
))
})?;
let reader = BufReader::new(file);
let mut weights = HashMap::new();
for line in reader.lines() {
let line = line.map_err(|e| Error::Consensus(format!("read error: {}", e)))?;
if line.starts_with("bandwidth-weights ") {
for part in line.split_whitespace().skip(1) {
if let Some((key, value)) = part.split_once('=') {
if let Ok(v) = value.parse::<i64>() {
weights.insert(key.to_string(), v);
}
}
}
break;
}
}
if weights.is_empty() {
return Err(Error::Consensus(
"no bandwidth-weights found in consensus".to_string(),
));
}
Ok(weights)
}
/// Reports and, if enabled, force-closes the circuit `circ_id`.
///
/// When a `LogGuard` is supplied, its queued log lines for this circuit are dumped
/// (tagged `"Pre"`) first. The close itself only happens when `get_close_circuits()`
/// is true. Success or failure of the close is logged at INFO level and never
/// propagated to the caller.
pub async fn try_close_circuit(
    controller: &mut Controller,
    circ_id: &str,
    logguard: Option<&mut LogGuard>,
) {
    if let Some(guard) = logguard {
        guard.dump_log_queue(circ_id, "Pre");
    }
    if !get_close_circuits() {
        return;
    }

    let outcome = controller.close_circuit(&CircuitId::new(circ_id)).await;
    let message = match outcome {
        Ok(()) => format!("We force-closed circuit {}", circ_id),
        Err(e) => format!("Failed to close circuit {}: {}", circ_id, e),
    };
    plog(LogLevel::Info, &message);
}
pub async fn configure_tor(
controller: &mut Controller,
state: &VanguardState,
config: &Config,
) -> Result<()> {
let vg_config = &config.vanguards;
if vg_config.num_layer1_guards > 0 {
controller
.set_conf("NumEntryGuards", &vg_config.num_layer1_guards.to_string())
.await?;
controller
.set_conf(
"NumDirectoryGuards",
&vg_config.num_layer1_guards.to_string(),
)
.await?;
}
if vg_config.layer1_lifetime_days > 0 {
controller
.set_conf(
"GuardLifetime",
&format!("{} days", vg_config.layer1_lifetime_days),
)
.await?;
}
let layer2_guardset = state.layer2_guardset();
controller
.set_conf("HSLayer2Nodes", &layer2_guardset)
.await
.inspect_err(|_e| {
plog(
LogLevel::Error,
"Vanguards requires Tor 0.3.3.x (and ideally 0.3.4.x or newer).",
);
})?;
if vg_config.num_layer3_guards > 0 {
let layer3_guardset = state.layer3_guardset();
controller
.set_conf("HSLayer3Nodes", &layer3_guardset)
.await?;
}
plog(
LogLevel::Info,
&format!("Layer2 guards: {}", layer2_guardset),
);
if vg_config.num_layer3_guards > 0 {
plog(
LogLevel::Info,
&format!("Layer3 guards: {}", state.layer3_guardset()),
);
}
Ok(())
}
pub async fn new_consensus_event(
controller: &mut Controller,
state: &mut VanguardState,
config: &Config,
) -> Result<()> {
let routers = get_network_statuses(controller).await?;
let exclude_nodes_conf = controller
.get_conf("ExcludeNodes")
.await
.ok()
.and_then(|v| v.first().cloned())
.unwrap_or_default();
let geoip_exclude = controller
.get_conf("GeoIPExcludeUnknown")
.await
.ok()
.and_then(|v| v.first().cloned());
let exclude = ExcludeNodes::parse(&exclude_nodes_conf, geoip_exclude.as_deref());
let data_dir = controller
.get_conf("DataDirectory")
.await?
.first()
.cloned()
.ok_or_else(|| {
Error::Config("You must set a DataDirectory location option in your torrc.".to_string())
})?;
let consensus_file = Path::new(&data_dir).join("cached-microdesc-consensus");
let weights = get_consensus_weights(&consensus_file)?;
consensus_update(state, &routers, &weights, &exclude, config)?;
if config.enable_vanguards {
configure_tor(controller, state, config).await?;
}
let state_path = Path::new(&state.state_file);
state.write_to_file(state_path).map_err(|e| {
plog(
LogLevel::Error,
&format!("Cannot write state to {}: {}", state.state_file, e),
);
e
})?;
Ok(())
}
/// Applies a fresh consensus to the vanguard and rendguard state.
///
/// Routers are ordered by descending bandwidth (the measured value when present, else the
/// advertised value, else 0); the sort is stable.
///
/// With vanguards enabled, both layers are pruned of guards that are:
/// - missing from the consensus,
/// - expired, or
/// - excluded by `exclude`.
///
/// The layers are then refilled from a bandwidth-weighted middle-position generator
/// restricted to Fast+Stable+Valid, non-Authority relays. Finally, rendezvous use counts
/// are transferred onto a generator restricted to Fast+Valid, non-Authority relays.
///
/// # Errors
/// Propagates generator construction and layer replenishment errors.
fn consensus_update(
    state: &mut VanguardState,
    routers: &[RouterStatusEntry],
    weights: &HashMap<String, i64>,
    exclude: &ExcludeNodes,
    config: &Config,
) -> Result<()> {
    let mut by_bandwidth: Vec<RouterStatusEntry> = routers.to_vec();
    by_bandwidth.sort_by_key(|r| std::cmp::Reverse(r.measured.or(r.bandwidth).unwrap_or(0)));

    let consensus_fps: std::collections::HashSet<String> = by_bandwidth
        .iter()
        .map(|r| r.fingerprint.clone())
        .collect();
    let router_map: HashMap<String, &RouterStatusEntry> = by_bandwidth
        .iter()
        .map(|r| (r.fingerprint.clone(), r))
        .collect();

    let owned = |flags: &[&str]| flags.iter().map(|f| f.to_string()).collect::<Vec<String>>();

    let layer_generator = BwWeightedGenerator::new(
        by_bandwidth.clone(),
        NodeRestrictionList::new(vec![Box::new(FlagsRestriction::new(
            owned(&["Fast", "Stable", "Valid"]),
            owned(&["Authority"]),
        ))]),
        weights.clone(),
        Position::Middle,
    )?;

    if state.enable_vanguards {
        for layer in [&mut state.layer2, &mut state.layer3] {
            VanguardState::remove_down_from_layer(layer, &consensus_fps);
        }
        for layer in [&mut state.layer2, &mut state.layer3] {
            VanguardState::remove_expired_from_layer(layer);
        }
        for layer in [&mut state.layer2, &mut state.layer3] {
            VanguardState::remove_excluded_from_layer(layer, &router_map, exclude);
        }
        state.replenish_layers(&layer_generator, exclude, &config.vanguards)?;
    }

    let mut rend_generator = BwWeightedGenerator::new(
        by_bandwidth,
        NodeRestrictionList::new(vec![Box::new(FlagsRestriction::new(
            owned(&["Fast", "Valid"]),
            owned(&["Authority"]),
        ))]),
        weights.clone(),
        Position::Middle,
    )?;
    rend_generator.repair_exits();
    state
        .rendguard
        .xfer_use_counts(&rend_generator, &config.rendguard);
    Ok(())
}
async fn get_network_statuses(controller: &mut Controller) -> Result<Vec<RouterStatusEntry>> {
let response = controller
.get_info("ns/all")
.await
.map_err(|e| Error::DescriptorUnavailable(format!("Cannot get network statuses: {}", e)))?;
parse_network_statuses(&response)
}
/// Parses the text of `GETINFO ns/all` into router status entries.
///
/// Only `r` (router), `s` (flags) and `w` (bandwidth) lines are interpreted; all other
/// lines are ignored. An `r` line opens a new entry and the `s`/`w` lines after it
/// update that entry.
///
/// `r` line layout: `r <nickname> <identity> [<digest>] <date> <time> <IP> <ORPort> <DirPort>`.
/// The descriptor digest is present in the 9-field form and absent in the 8-field
/// (microdescriptor-style) form, so IP and ORPort are located relative to the field count.
/// Previously `parts[5]`/`parts[6]` were used unconditionally. For the 9-field form those
/// are the publication time and the IP, so every relay ended up at `0.0.0.0:9001`.
///
/// `r` lines with fewer than 8 fields are skipped, together with their `s`/`w` lines.
/// An unparsable IP falls back to `0.0.0.0` and an unparsable ORPort to 9001.
///
/// NOTE(review): the publication date/time is not parsed; entries are stamped with the
/// current time instead.
///
/// # Errors
/// Currently never fails; the `Result` is kept for callers.
fn parse_network_statuses(response: &str) -> Result<Vec<RouterStatusEntry>> {
    use chrono::Utc;
    use stem_rs::descriptor::router_status::RouterStatusEntryType;
    let mut routers = Vec::new();
    let mut current_router: Option<RouterStatusEntry> = None;
    for line in response.lines() {
        if line.starts_with("r ") {
            if let Some(router) = current_router.take() {
                routers.push(router);
            }
            let parts: Vec<&str> = line.split_whitespace().collect();
            if parts.len() >= 8 {
                // With the digest column present, every field after the identity shifts by one.
                let shift = if parts.len() >= 9 { 1 } else { 0 };
                let nickname = parts[1].to_string();
                let fingerprint = decode_base64_fingerprint(parts[2]);
                let address = parts[5 + shift]
                    .parse()
                    .unwrap_or_else(|_| "0.0.0.0".parse().unwrap());
                let or_port = parts[6 + shift].parse().unwrap_or(9001);
                current_router = Some(RouterStatusEntry::new(
                    RouterStatusEntryType::V3,
                    nickname,
                    fingerprint,
                    Utc::now(),
                    address,
                    or_port,
                ));
            }
        } else if let Some(stripped) = line.strip_prefix("s ") {
            if let Some(ref mut router) = current_router {
                router.flags = stripped.split_whitespace().map(|s| s.to_string()).collect();
            }
        } else if let Some(stripped) = line.strip_prefix("w ") {
            if let Some(ref mut router) = current_router {
                for part in stripped.split_whitespace() {
                    if let Some((key, value)) = part.split_once('=') {
                        if let Ok(v) = value.parse::<u64>() {
                            match key {
                                "Bandwidth" => router.bandwidth = Some(v),
                                "Measured" => router.measured = Some(v),
                                _ => {}
                            }
                        }
                    }
                }
            }
        }
    }
    if let Some(router) = current_router {
        routers.push(router);
    }
    Ok(routers)
}
/// Converts a base64 relay identity (as found in `r` lines, usually unpadded) into an
/// uppercase hex fingerprint.
///
/// `base64_decode` strips and tolerates missing `=` padding, so the identity is
/// decoded as-is. Undecodable input yields an empty string.
fn decode_base64_fingerprint(b64: &str) -> String {
    let raw = base64_decode(b64).unwrap_or_default();
    let mut hex = String::with_capacity(raw.len() * 2);
    for byte in raw {
        hex.push_str(&format!("{:02X}", byte));
    }
    hex
}
/// Decodes standard-alphabet base64 (`A-Z a-z 0-9 + /`), with or without trailing `=`.
///
/// Returns `None` if any character other than trailing padding is outside the alphabet.
/// Leftover bits that do not fill a whole byte are discarded.
fn base64_decode(input: &str) -> Option<Vec<u8>> {
    /// Maps one base64 character to its 6-bit value.
    fn sextet(c: u8) -> Option<u32> {
        let value = match c {
            b'A'..=b'Z' => c - b'A',
            b'a'..=b'z' => c - b'a' + 26,
            b'0'..=b'9' => c - b'0' + 52,
            b'+' => 62,
            b'/' => 63,
            _ => return None,
        };
        Some(u32::from(value))
    }

    let data = input.trim_end_matches('=');
    let mut out = Vec::with_capacity(data.len() * 3 / 4);
    let mut acc: u32 = 0;
    let mut pending_bits: u32 = 0;
    for c in data.bytes() {
        acc = (acc << 6) | sextet(c)?;
        pending_bits += 6;
        if pending_bits >= 8 {
            pending_bits -= 8;
            out.push((acc >> pending_bits) as u8);
            // Keep only the bits not yet emitted so the accumulator stays small.
            acc &= (1u32 << pending_bits) - 1;
        }
    }
    Some(out)
}
/// Reacts to a Tor signal notification.
///
/// Only `"RELOAD"` (SIGHUP) matters: Tor may have reset its options from torrc, so the
/// vanguard configuration is pushed again. Every other signal is ignored.
///
/// # Errors
/// Propagates errors from `configure_tor`.
pub async fn signal_event(
    controller: &mut Controller,
    state: &VanguardState,
    config: &Config,
    signal: &str,
) -> Result<()> {
    if signal != "RELOAD" {
        return Ok(());
    }
    plog(LogLevel::Notice, "Tor got SIGHUP. Reapplying vanguards.");
    configure_tor(controller, state, config).await
}
/// Everything the event handlers mutate. `run_main` creates one instance and reuses it
/// across reconnects.
pub struct AppState {
    /// Vanguard layers and rendguard bookkeeping; written to the state file on new consensus.
    pub vanguard_state: VanguardState,
    /// Per-circuit bandwidth accounting used by bandguards.
    pub bandwidth_stats: BandwidthStats,
    /// Circuit build-timeout statistics used by cbtverify.
    pub timeout_stats: TimeoutStats,
    /// Log buffering; created by `control_loop` when logguard is enabled.
    pub logguard: Option<LogGuard>,
    /// Path verification; created by `control_loop` when pathverify is enabled.
    pub pathverify: Option<PathVerify>,
    /// Effective configuration.
    pub config: Config,
}

impl AppState {
    /// Builds state with empty statistics. The optional components (`logguard`,
    /// `pathverify`) start out unset.
    pub fn new(vanguard_state: VanguardState, config: Config) -> Self {
        let bandwidth_stats = BandwidthStats::new();
        let timeout_stats = TimeoutStats::new();
        Self {
            config,
            vanguard_state,
            bandwidth_stats,
            timeout_stats,
            pathverify: None,
            logguard: None,
        }
    }
}
async fn connect_to_tor(config: &Config) -> Result<Controller> {
if let Some(ref socket_path) = config.control_socket {
match Controller::from_socket_file(socket_path.as_path()).await {
Ok(controller) => {
plog(
LogLevel::Notice,
&format!("Connected to Tor via socket {}", socket_path.display()),
);
return Ok(controller);
}
Err(e) => {
return Err(Error::Control(e));
}
}
}
if let Some(port) = config.control_port {
let addr = format!("{}:{}", config.control_ip, port);
match Controller::from_port(
addr.parse()
.map_err(|e| Error::Config(format!("Invalid control address: {}", e)))?,
)
.await
{
Ok(controller) => {
plog(
LogLevel::Notice,
&format!("Connected to Tor via control port {}", addr),
);
return Ok(controller);
}
Err(e) => {
return Err(Error::Control(e));
}
}
}
if let Ok(controller) = Controller::from_socket_file(Path::new("/run/tor/control")).await {
plog(
LogLevel::Notice,
"Connected to Tor via /run/tor/control socket",
);
return Ok(controller);
}
let addr = format!("{}:9051", config.control_ip);
match Controller::from_port(
addr.parse()
.map_err(|e| Error::Config(format!("Invalid control address: {}", e)))?,
)
.await
{
Ok(controller) => {
plog(
LogLevel::Notice,
&format!("Connected to Tor via {} control port", addr),
);
Ok(controller)
}
Err(e) => Err(Error::Control(e)),
}
}
/// Computes the deduplicated set of controller events needed by the enabled subsystems.
///
/// CIRC_BW / CIRC_MINOR for bandguards are only requested on Tor 0.3.4.10 or newer;
/// on older versions a notice is logged instead. Logguard additionally subscribes to
/// the log levels returned by `LogGuard::get_log_event_types`. The result is sorted
/// by the events' `Debug` names, then deduplicated.
fn get_event_types(config: &Config, tor_version: &Version) -> Vec<EventType> {
    let mut wanted: Vec<EventType> = Vec::new();

    if config.enable_vanguards || config.enable_rendguard {
        wanted.extend([EventType::NewConsensus, EventType::Signal]);
    }
    if config.enable_rendguard {
        wanted.push(EventType::Circ);
    }
    if config.enable_bandguards {
        wanted.extend([
            EventType::Circ,
            EventType::Bw,
            EventType::OrConn,
            EventType::NetworkLiveness,
        ]);
        // NOTE(review): duplicates MIN_TOR_VERSION_FOR_BW ("0.3.4.10").
        if *tor_version >= Version::new(0, 3, 4).with_patch(10) {
            wanted.extend([EventType::CircBw, EventType::CircMinor]);
        } else {
            plog(
                LogLevel::Notice,
                "In order for bandwidth-based protections to be enabled, you must use Tor 0.3.4.10 or newer.",
            );
        }
    }
    if config.enable_cbtverify {
        wanted.extend([EventType::Circ, EventType::BuildTimeoutSet]);
    }
    if config.enable_pathverify {
        wanted.extend([
            EventType::Circ,
            EventType::CircMinor,
            EventType::OrConn,
            EventType::Guard,
            EventType::ConfChanged,
        ]);
    }
    if config.enable_logguard {
        wanted.extend([EventType::Circ, EventType::Warn]);
        for level in LogGuard::get_log_event_types(config.logguard.dump_level) {
            let mapped = match level {
                "DEBUG" => Some(EventType::Debug),
                "INFO" => Some(EventType::Info),
                "NOTICE" => Some(EventType::Notice),
                "WARN" => Some(EventType::Warn),
                "ERR" => Some(EventType::Err),
                _ => None,
            };
            if let Some(event_type) = mapped {
                wanted.push(event_type);
            }
        }
    }

    wanted.sort_by_key(|e| format!("{:?}", e));
    wanted.dedup();
    wanted
}
fn handle_circ_event(state: &mut AppState, event: &stem_rs::events::CircuitEvent, arrived_at: f64) {
let circ_id = &event.id.0;
let status = format!("{:?}", event.status);
let purpose = event.purpose.as_ref().map(|p| format!("{:?}", p));
let hs_state = event.hs_state.as_ref().map(|s| format!("{:?}", s));
let reason = event.reason.as_ref().map(|r| format!("{:?}", r));
let path: Vec<String> = event.path.iter().map(|(fp, _)| fp.clone()).collect();
if state.config.enable_rendguard {
if let (Some(ref p), Some(ref hs)) = (&purpose, &hs_state) {
if p == "HS_SERVICE_REND" && hs == "HSSR_CONNECTING" {
if let Some(rp_fp) = path.last() {
let valid = state
.vanguard_state
.rendguard
.valid_rend_use(rp_fp, &state.config.rendguard);
if !valid {
let usage_rate = state.vanguard_state.rendguard.usage_rate(rp_fp);
let expected = state.vanguard_state.rendguard.expected_weight(rp_fp);
plog(
LogLevel::Warn,
&format!(
"Possible rendezvous point overuse attack: {} used {:.2}% vs expected {:.2}%",
rp_fp, usage_rate, expected
),
);
}
}
}
}
}
if state.config.enable_bandguards {
state.bandwidth_stats.circ_event(
circ_id,
&status,
purpose.as_deref().unwrap_or("GENERAL"),
hs_state.as_deref(),
&path,
reason.as_deref(),
arrived_at,
);
}
if state.config.enable_cbtverify {
state.timeout_stats.circ_event(
circ_id,
&status,
purpose.as_deref().unwrap_or("GENERAL"),
hs_state.as_deref(),
reason.as_deref(),
);
}
if state.config.enable_logguard {
if let Some(ref mut lg) = state.logguard {
lg.circ_event(circ_id, &status, reason.as_deref());
}
}
if state.config.enable_pathverify {
if let Some(ref mut pv) = state.pathverify {
pv.circ_event(
circ_id,
&status,
purpose.as_deref().unwrap_or("GENERAL"),
hs_state.as_deref(),
&event.path,
);
}
}
}
/// Forwards a CIRC_BW event to bandguards (when enabled). Absent delivered/overhead
/// counters are passed as 0.
fn handle_circbw_event(
    state: &mut AppState,
    event: &stem_rs::events::CircuitBandwidthEvent,
    arrived_at: f64,
) {
    if !state.config.enable_bandguards {
        return;
    }
    let (delivered_read, delivered_written) = (
        event.delivered_read.unwrap_or(0),
        event.delivered_written.unwrap_or(0),
    );
    let (overhead_read, overhead_written) = (
        event.overhead_read.unwrap_or(0),
        event.overhead_written.unwrap_or(0),
    );
    state.bandwidth_stats.circbw_event(
        &event.id.0,
        event.read,
        event.written,
        delivered_read,
        delivered_written,
        overhead_read,
        overhead_written,
        arrived_at,
    );
}
#[allow(dead_code)]
fn handle_circ_minor_event(state: &mut AppState, event: &stem_rs::events::CircuitEvent) {
let circ_id = &event.id.0;
let purpose = event.purpose.as_ref().map(|p| format!("{:?}", p));
let hs_state = event.hs_state.as_ref().map(|s| format!("{:?}", s));
let path: Vec<String> = event.path.iter().map(|(fp, _)| fp.clone()).collect();
if state.config.enable_bandguards {
state.bandwidth_stats.circ_minor_event(
circ_id,
"PURPOSE_CHANGED",
purpose.as_deref().unwrap_or("GENERAL"),
hs_state.as_deref(),
None, None, &path,
);
}
if state.config.enable_pathverify {
if let Some(ref mut pv) = state.pathverify {
pv.circ_minor_event(
circ_id,
purpose.as_deref().unwrap_or("GENERAL"),
None, &event.path,
);
}
}
}
/// Handles the raw body of a CIRC_MINOR event (delivered as an unparsed event).
///
/// Expected body: `<CircID> <Event> [<Path>] [KEY=VALUE ...]`. Of the key/value pairs,
/// only PURPOSE, HS_STATE, OLD_PURPOSE and OLD_HS_STATE are used. The path is a
/// comma-separated list of hops written as `$FP~nick`, `$FP=nick` or `$FP`.
///
/// Tokens starting with `$` are now recognised as path *before* the `KEY=VALUE` check.
/// Previously a `$FP=nick` hop was split as an unknown key and the whole path was
/// silently dropped. Bodies with fewer than two tokens are ignored. A missing PURPOSE
/// is reported as `"GENERAL"`.
fn handle_circ_minor_raw(state: &mut AppState, content: &str) {
    /// Splits one hop into `(fingerprint, nickname)`; empty hops are skipped.
    fn parse_hop(hop: &str) -> Option<(String, Option<String>)> {
        let hop = hop.trim_start_matches('$');
        if let Some((fp, nick)) = hop.split_once('~').or_else(|| hop.split_once('=')) {
            Some((fp.to_string(), Some(nick.to_string())))
        } else if hop.is_empty() {
            None
        } else {
            Some((hop.to_string(), None))
        }
    }

    let parts: Vec<&str> = content.split_whitespace().collect();
    if parts.len() < 2 {
        return;
    }
    let circ_id = parts[0];
    let event_type = parts[1];
    let mut path: Vec<(String, Option<String>)> = Vec::new();
    let mut purpose: Option<String> = None;
    let mut hs_state: Option<String> = None;
    let mut old_purpose: Option<String> = None;
    let mut old_hs_state: Option<String> = None;
    for part in parts.iter().skip(2) {
        if part.starts_with('$') {
            path.extend(part.split(',').filter_map(parse_hop));
        } else if let Some((key, value)) = part.split_once('=') {
            match key {
                "PURPOSE" => purpose = Some(value.to_string()),
                "HS_STATE" => hs_state = Some(value.to_string()),
                "OLD_PURPOSE" => old_purpose = Some(value.to_string()),
                "OLD_HS_STATE" => old_hs_state = Some(value.to_string()),
                _ => {}
            }
        } else if part.contains('~') || part.contains(',') {
            path.extend(part.split(',').filter_map(parse_hop));
        }
    }
    if state.config.enable_bandguards {
        let path_fps: Vec<String> = path.iter().map(|(fp, _)| fp.clone()).collect();
        state.bandwidth_stats.circ_minor_event(
            circ_id,
            event_type,
            purpose.as_deref().unwrap_or("GENERAL"),
            hs_state.as_deref(),
            old_purpose.as_deref(),
            old_hs_state.as_deref(),
            &path_fps,
        );
    }
    if state.config.enable_pathverify {
        if let Some(ref mut pv) = state.pathverify {
            pv.circ_minor_event(
                circ_id,
                purpose.as_deref().unwrap_or("GENERAL"),
                old_purpose.as_deref(),
                &path,
            );
        }
    }
}
fn handle_orconn_event(
state: &mut AppState,
event: &stem_rs::events::OrConnEvent,
arrived_at: f64,
) {
let status = format!("{:?}", event.status);
let reason = event.reason.as_ref().map(|r| format!("{:?}", r));
let conn_id = event.id.as_deref().unwrap_or("");
if state.config.enable_bandguards {
state.bandwidth_stats.orconn_event(
conn_id,
&event.target,
&status,
reason.as_deref(),
arrived_at,
);
}
if state.config.enable_pathverify {
if let Some(ref mut pv) = state.pathverify {
let guard_fp = if event.target.starts_with('$') {
event.target[1..].split(['~', '=']).next().unwrap_or("")
} else {
&event.target
};
pv.orconn_event(guard_fp, &status);
}
}
}
/// On every BW event, lets bandguards re-evaluate connectivity at `arrived_at`.
/// The event payload itself is unused.
fn handle_bw_event(
    state: &mut AppState,
    _event: &stem_rs::events::BandwidthEvent,
    arrived_at: f64,
) {
    if !state.config.enable_bandguards {
        return;
    }
    let limits = &state.config.bandguards;
    state.bandwidth_stats.check_connectivity(arrived_at, limits);
}
/// Forwards a NETWORK_LIVENESS status (rendered with `{:?}`) to bandguards when enabled.
fn handle_network_liveness_event(
    state: &mut AppState,
    event: &stem_rs::events::NetworkLivenessEvent,
    arrived_at: f64,
) {
    if !state.config.enable_bandguards {
        return;
    }
    let liveness = format!("{:?}", event.status);
    state
        .bandwidth_stats
        .network_liveness_event(&liveness, arrived_at);
}
/// Forwards a BUILDTIMEOUT_SET event (type rendered with `{:?}` plus timeout rate) to
/// cbtverify when enabled.
fn handle_buildtimeout_set_event(
    state: &mut AppState,
    event: &stem_rs::events::BuildTimeoutSetEvent,
) {
    if !state.config.enable_cbtverify {
        return;
    }
    let kind = format!("{:?}", event.set_type);
    state.timeout_stats.cbt_event(&kind, event.timeout_rate);
}
fn handle_guard_event(state: &mut AppState, event: &stem_rs::events::GuardEvent) {
if state.config.enable_pathverify {
if let Some(ref mut pv) = state.pathverify {
let status = format!("{:?}", event.status);
pv.guard_event(&event.endpoint_fingerprint, &status);
}
}
}
fn handle_conf_changed_event(state: &mut AppState, event: &stem_rs::events::ConfChangedEvent) {
if state.config.enable_pathverify {
if let Some(ref mut pv) = state.pathverify {
pv.conf_changed_event(&event.changed);
}
}
}
fn handle_log_event(state: &mut AppState, event: &stem_rs::events::LogEvent, arrived_at: f64) {
if state.config.enable_logguard {
if let Some(ref mut lg) = state.logguard {
let runlevel = format!("{:?}", event.runlevel);
lg.log_event_with_timestamp(&runlevel, &event.message, arrived_at);
if matches!(event.runlevel, stem_rs::Runlevel::Warn) {
lg.log_warn_event(&event.message);
}
}
}
}
/// Passes a SIGNAL event to `signal_event`, naming the signal with `{:?}`.
///
/// NOTE(review): `signal_event` compares that name against "RELOAD"; confirm the stem-rs
/// `Debug` output of the reload signal is exactly that spelling.
async fn handle_signal_event(
    controller: &mut Controller,
    state: &mut AppState,
    event: &stem_rs::events::SignalEvent,
) -> Result<()> {
    let (vanguard_state, config) = (&state.vanguard_state, &state.config);
    let signal_name = format!("{:?}", event.signal);
    signal_event(controller, vanguard_state, config, &signal_name).await
}
pub async fn control_loop(state: &mut AppState) -> String {
let mut controller = match connect_to_tor(&state.config).await {
Ok(c) => c,
Err(e) => return format!("failed: {}", e),
};
if let Err(e) = authenticate_any(&mut controller, state.config.control_pass.as_deref()).await {
return format!("failed: {}", e);
}
let tor_version = match controller.get_version().await {
Ok(v) => v,
Err(e) => return format!("failed: {}", e),
};
if state.config.enable_vanguards || state.config.enable_rendguard {
match new_consensus_event(&mut controller, &mut state.vanguard_state, &state.config).await {
Ok(()) => {}
Err(Error::DescriptorUnavailable(msg)) => {
plog(
LogLevel::Notice,
&format!("Tor needs descriptors: {}. Trying again...", msg),
);
return format!("failed: {}", msg);
}
Err(e) => return format!("failed: {}", e),
}
}
if state.config.one_shot_vanguards {
plog(
LogLevel::Notice,
"Updated vanguards. Exiting (one-shot mode).",
);
std::process::exit(0);
}
if state.config.enable_logguard {
state.logguard = Some(LogGuard::new(&state.config.logguard));
}
if state.config.enable_pathverify {
state.pathverify = Some(PathVerify::new(
state.config.enable_vanguards,
state.config.vanguards.num_layer1_guards,
state.config.vanguards.num_layer2_guards,
state.config.vanguards.num_layer3_guards,
));
if let Err(e) = controller.signal(stem_rs::Signal::Newnym).await {
plog(LogLevel::Warn, &format!("Failed to send NEWNYM: {}", e));
}
}
let event_types = get_event_types(&state.config, &tor_version);
if let Err(e) = controller.set_events(&event_types).await {
return format!("failed: {}", e);
}
loop {
match controller.recv_event().await {
Ok(event) => {
let arrived_at = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_secs_f64())
.unwrap_or(0.0);
match event {
ParsedEvent::Circuit(ref e) => {
handle_circ_event(state, e, arrived_at);
}
ParsedEvent::CircuitBandwidth(ref e) => {
handle_circbw_event(state, e, arrived_at);
}
ParsedEvent::OrConn(ref e) => {
handle_orconn_event(state, e, arrived_at);
}
ParsedEvent::Bandwidth(ref e) => {
handle_bw_event(state, e, arrived_at);
}
ParsedEvent::NetworkLiveness(ref e) => {
handle_network_liveness_event(state, e, arrived_at);
}
ParsedEvent::BuildTimeoutSet(ref e) => {
handle_buildtimeout_set_event(state, e);
}
ParsedEvent::Guard(ref e) => {
handle_guard_event(state, e);
}
ParsedEvent::ConfChanged(ref e) => {
handle_conf_changed_event(state, e);
}
ParsedEvent::Log(ref e) => {
handle_log_event(state, e, arrived_at);
}
ParsedEvent::Signal(ref e) => {
if let Err(err) = handle_signal_event(&mut controller, state, e).await {
plog(LogLevel::Warn, &format!("Signal event error: {}", err));
}
}
ParsedEvent::Unknown {
ref event_type,
ref content,
} => {
if event_type == "NEWCONSENSUS" {
if let Err(err) = new_consensus_event(
&mut controller,
&mut state.vanguard_state,
&state.config,
)
.await
{
plog(LogLevel::Warn, &format!("Consensus event error: {}", err));
}
} else if event_type == "CIRC_MINOR" {
handle_circ_minor_raw(state, content);
}
}
_ => {
}
}
if state.config.enable_bandguards {
let circs_to_check: Vec<String> =
state.bandwidth_stats.circs.keys().cloned().collect();
for circ_id in circs_to_check {
let limit_result = state
.bandwidth_stats
.check_circuit_limits(&circ_id, &state.config.bandguards);
match limit_result {
crate::bandguards::CircuitLimitResult::Ok => {}
crate::bandguards::CircuitLimitResult::TorBug {
bug_id,
dropped_cells,
} => {
plog(
LogLevel::Info,
&format!(
"Tor bug {} (dropped {} cells): {}",
bug_id, dropped_cells, circ_id
),
);
}
crate::bandguards::CircuitLimitResult::DroppedCells {
dropped_cells,
} => {
plog(
LogLevel::Warn,
&format!(
"Dropped cells attack ({} cells): {}",
dropped_cells, circ_id
),
);
try_close_circuit(
&mut controller,
&circ_id,
state.logguard.as_mut(),
)
.await;
}
crate::bandguards::CircuitLimitResult::MaxBytesExceeded {
bytes,
limit,
} => {
plog(
LogLevel::Warn,
&format!(
"Circuit {} exceeded max bytes ({} > {})",
circ_id, bytes, limit
),
);
try_close_circuit(
&mut controller,
&circ_id,
state.logguard.as_mut(),
)
.await;
}
crate::bandguards::CircuitLimitResult::HsdirBytesExceeded {
bytes,
limit,
} => {
plog(
LogLevel::Warn,
&format!(
"HSDIR circuit {} exceeded max bytes ({} > {})",
circ_id, bytes, limit
),
);
try_close_circuit(
&mut controller,
&circ_id,
state.logguard.as_mut(),
)
.await;
}
crate::bandguards::CircuitLimitResult::ServIntroBytesExceeded {
bytes,
limit,
} => {
plog(
LogLevel::Warn,
&format!(
"Service intro circuit {} exceeded max bytes ({} > {})",
circ_id, bytes, limit
),
);
try_close_circuit(
&mut controller,
&circ_id,
state.logguard.as_mut(),
)
.await;
}
}
}
}
}
Err(e) => {
plog(LogLevel::Debug, &format!("Event receive error: {}", e));
return "closed".to_string();
}
}
}
}
/// Entry point: loads or creates the vanguard state, then runs `control_loop`,
/// reconnecting once per second until the retry limit is reached.
///
/// CTRL+C terminates the process with exit status 0.
/// - Registering `tokio::signal::ctrl_c` replaces the default SIGINT behaviour.
/// - `control_loop` blocks in `recv_event()` for as long as Tor stays connected.
/// - Previously the handler only set a flag that was checked between reconnects, so
///   CTRL+C logged "Exiting." but the process kept running.
/// - The shutdown path never persisted state here either; the state file is only
///   written by `new_consensus_event`.
///
/// Reconnect messages are logged on a lost connection or every tenth attempt. They are
/// escalated to WARN once the disconnection exceeds `conn_max_disconnected_secs`.
///
/// # Errors
/// Returns `Error::Config("Failed to connect to Tor")` if no session ever reached the
/// event loop before retrying stopped.
pub async fn run_main(config: Config) -> Result<()> {
    let shutdown = Arc::new(AtomicBool::new(false));
    let shutdown_clone = shutdown.clone();
    tokio::spawn(async move {
        if let Ok(()) = tokio::signal::ctrl_c().await {
            plog(LogLevel::Notice, "Got CTRL+C. Exiting.");
            // Set the flag first so the reconnect loop below starts no further attempt
            // while the process is exiting.
            shutdown_clone.store(true, Ordering::SeqCst);
            std::process::exit(0);
        }
    });
    set_close_circuits(config.close_circuits);

    let state_path = &config.state_file;
    let vanguard_state = match VanguardState::read_from_file(state_path) {
        Ok(mut state) => {
            plog(
                LogLevel::Info,
                &format!("Current layer2 guards: {}", state.layer2_guardset()),
            );
            plog(
                LogLevel::Info,
                &format!("Current layer3 guards: {}", state.layer3_guardset()),
            );
            state.enable_vanguards = config.enable_vanguards;
            state
        }
        Err(_) => {
            plog(
                LogLevel::Notice,
                &format!(
                    "Creating new vanguard state file at: {}",
                    state_path.display()
                ),
            );
            let mut state = VanguardState::new(&state_path.to_string_lossy());
            state.enable_vanguards = config.enable_vanguards;
            state
        }
    };

    let mut app_state = AppState::new(vanguard_state, config.clone());
    let mut reconnects = 0u32;
    let mut last_connected_at: Option<f64> = None;
    let mut connected = false;
    loop {
        if shutdown.load(Ordering::SeqCst) {
            break;
        }
        if let Some(limit) = config.retry_limit {
            if reconnects >= limit {
                break;
            }
        }
        let result = control_loop(&mut app_state).await;
        if last_connected_at.is_none() {
            last_connected_at = Some(unix_now());
        }
        if result == "closed" {
            connected = true;
        }
        // Retry every second, but only tell the user on a lost connection or every 10th try.
        if result == "closed" || reconnects.is_multiple_of(10) {
            let now = unix_now();
            let disconnected_secs = now - last_connected_at.unwrap_or(now);
            let max_disconnected = config.bandguards.conn_max_disconnected_secs as f64;
            let level = if disconnected_secs > max_disconnected {
                LogLevel::Warn
            } else {
                LogLevel::Notice
            };
            plog(
                level,
                &format!("Tor daemon connection {}. Trying again...", result),
            );
        }
        reconnects += 1;
        tokio::time::sleep(Duration::from_secs(1)).await;
    }
    if !connected {
        return Err(Error::Config("Failed to connect to Tor".to_string()));
    }
    Ok(())
}

/// Seconds since the Unix epoch as `f64`, or `0.0` if the clock reads before the epoch.
fn unix_now() -> f64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|d| d.as_secs_f64())
        .unwrap_or(0.0)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Creates a temporary file holding `contents` followed by a newline.
    fn temp_consensus(contents: &str) -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        writeln!(file, "{}", contents).unwrap();
        file
    }

    #[test]
    fn test_get_consensus_weights() {
        let file = temp_consensus(
            "network-status-version 3 microdesc\n\
             bandwidth-weights Wbd=0 Wbe=0 Wbg=4194 Wbm=10000 Wdb=10000 Wed=10000 Wee=10000 Weg=10000 Wem=10000 Wgb=10000 Wgd=0 Wgg=5806 Wgm=5806 Wmb=10000 Wmd=0 Wme=0 Wmg=4194 Wmm=10000",
        );
        let weights = get_consensus_weights(file.path()).unwrap();
        for (key, expected) in [("Wmm", 10000), ("Wgg", 5806), ("Wbd", 0)] {
            assert_eq!(weights.get(key), Some(&expected), "weight {}", key);
        }
    }

    #[test]
    fn test_get_consensus_weights_missing() {
        let file = temp_consensus("network-status-version 3 microdesc");
        assert!(get_consensus_weights(file.path()).is_err());
    }

    #[test]
    fn test_base64_decode() {
        for encoded in ["SGVsbG8=", "SGVsbG8"] {
            assert_eq!(base64_decode(encoded).unwrap(), b"Hello");
        }
    }

    #[test]
    fn test_decode_base64_fingerprint() {
        let hex = decode_base64_fingerprint("AAAAAAAAAAAAAAAAAAAAAAAAAAA");
        assert_eq!(hex.len(), 40);
        assert!(hex.chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[test]
    fn test_parse_network_statuses() {
        let response = "r relay1 AAAAAAAAAAAAAAAAAAAAAAAAAAAA BBBBBBBBBBBBBBBBBBBBBBBBBBBB 2024-01-01 00:00:00 192.168.1.1 9001 0\n\
                        s Fast Guard Running Stable Valid\n\
                        w Bandwidth=1000 Measured=900\n\
                        r relay2 CCCCCCCCCCCCCCCCCCCCCCCCCCCC DDDDDDDDDDDDDDDDDDDDDDDDDDDD 2024-01-01 00:00:00 192.168.1.2 9002 0\n\
                        s Fast Running Stable Valid Exit\n\
                        w Bandwidth=2000";
        let routers = parse_network_statuses(response).unwrap();
        assert_eq!(routers.len(), 2);

        let (first, second) = (&routers[0], &routers[1]);
        assert_eq!(first.nickname, "relay1");
        assert!(first.flags.contains(&"Guard".to_string()));
        assert_eq!(first.bandwidth, Some(1000));
        assert_eq!(first.measured, Some(900));

        assert_eq!(second.nickname, "relay2");
        assert!(second.flags.contains(&"Exit".to_string()));
        assert_eq!(second.bandwidth, Some(2000));
        assert_eq!(second.measured, None);
    }

    #[test]
    fn test_close_circuits_flag() {
        for value in [true, false] {
            set_close_circuits(value);
            assert_eq!(get_close_circuits(), value);
        }
        // Restore the default so other code observing the global is unaffected.
        set_close_circuits(true);
    }
}