use std::{
collections::HashSet,
fs::{self, File},
future::Future,
io::{Read, Write},
net::{IpAddr, Ipv6Addr, SocketAddr},
path::{Path, PathBuf},
sync::{Arc, LazyLock, atomic::AtomicBool},
time::Duration,
};
use anyhow::Context;
use directories::ProjectDirs;
use either::Either;
use serde::{Deserialize, Serialize};
use tokio::runtime::Runtime;
use crate::{
dev_tool::PeerId,
local_node::OperationMode,
tracing::tracer::get_log_dir,
transport::{CongestionControlAlgorithm, CongestionControlConfig, TransportKeypair},
};
mod secret;
pub use secret::*;
/// Default maximum number of connections (mirrors `Ring::DEFAULT_MAX_CONNECTIONS`);
/// used by `ConfigArgs::build` when `max-number-of-connections` is not set.
pub const DEFAULT_MAX_CONNECTIONS: usize = crate::ring::Ring::DEFAULT_MAX_CONNECTIONS;
/// Default minimum number of connections (mirrors `Ring::DEFAULT_MIN_CONNECTIONS`);
/// used by `ConfigArgs::build` when `min-number-of-connections` is not set.
pub const DEFAULT_MIN_CONNECTIONS: usize = crate::ring::Ring::DEFAULT_MIN_CONNECTIONS;
/// Default random-peer connection threshold (consumers are outside this chunk).
pub const DEFAULT_RANDOM_PEER_CONN_THRESHOLD: usize = 7;
/// Default maximum hops-to-live (consumers are outside this chunk).
pub const DEFAULT_MAX_HOPS_TO_LIVE: usize = 10;
/// Operation time-to-live of 60 seconds (consumers are outside this chunk).
pub(crate) const OPERATION_TTL: Duration = Duration::from_secs(60);
/// This crate's version, captured at compile time; returned by
/// `ConfigArgs::current_version`.
pub(crate) const PCK_VERSION: &str = env!("CARGO_PKG_VERSION");
/// Minimum compatible version, read at compile time from the
/// `FREENET_MIN_COMPATIBLE_VERSION` environment variable (the build fails if unset).
pub(crate) const MIN_COMPATIBLE_VERSION: &str = env!("FREENET_MIN_COMPATIBLE_VERSION");
/// Process-wide async runtime, created lazily on first access by
/// `GlobalExecutor::initialize_async_rt`. It may be `None`; see that initializer.
static ASYNC_RT: LazyLock<Option<Runtime>> = LazyLock::new(GlobalExecutor::initialize_async_rt);
/// Default `transient-budget` (used by `ConfigArgs::default` and as the fallback in `build`).
const DEFAULT_TRANSIENT_BUDGET: usize = 2048;
/// Default `transient-ttl-secs`, in seconds.
const DEFAULT_TRANSIENT_TTL_SECS: u64 = 30;
// The three values below are passed to `ProjectDirs::from` to locate the
// platform-specific config/data directories.
const QUALIFIER: &str = "";
const ORGANIZATION: &str = "The Freenet Project Inc";
const APPLICATION: &str = "Freenet";
/// Remote gateway index fetched by `ConfigArgs::build`, unless the node is in
/// local mode or `--skip-load-from-network` is set.
const FREENET_GATEWAYS_INDEX: &str = "https://freenet.org/keys/gateways.toml";
// CLI / environment arguments from which `ConfigArgs::build` produces a
// `Config`. Values left as `None` are filled from an existing config file,
// then from built-in defaults.
// NOTE: plain `//` comments are used on this clap-derived struct on purpose:
// `///` doc comments would be turned into `--help` text.
#[derive(clap::Parser, Debug, Clone)]
pub struct ConfigArgs {
// Operation mode (env `MODE`). `build` falls back to the config file's mode,
// then to `OperationMode::Network`.
#[arg(value_enum, env = "MODE")]
pub mode: Option<OperationMode>,
// WebSocket API options.
#[command(flatten)]
pub ws_api: WebsocketApiArgs,
// Network / transport options.
#[command(flatten)]
pub network_api: NetworkArgs,
// Secret-material options; merged with the config file's secrets in `build`.
#[command(flatten)]
pub secrets: SecretArgs,
// Log level (env `LOG_LEVEL`); `build` defaults it to `Info`.
#[arg(long, env = "LOG_LEVEL")]
pub log_level: Option<tracing::log::LevelFilter>,
// Config / data / log directory overrides.
#[command(flatten)]
pub config_paths: ConfigPathsArgs,
// Hidden instance id. When set, default directories live under the system
// temp dir as `freenet-{id}` and telemetry is flagged as a test environment.
#[arg(long, hide = true)]
pub id: Option<String>,
// `--version` / `-v` flag; not consumed by `build`.
#[arg(long, short)]
pub version: bool,
// Override for `Config::max_blocking_threads` (env `MAX_BLOCKING_THREADS`);
// falls back to `default_max_blocking_threads()`.
#[arg(long, env = "MAX_BLOCKING_THREADS")]
pub max_blocking_threads: Option<usize>,
// Telemetry options.
#[command(flatten)]
pub telemetry: TelemetryArgs,
}
impl Default for ConfigArgs {
    /// Baseline arguments used when no CLI parsing takes place.
    ///
    /// The baseline is:
    /// - network mode, with wildcard listen addresses and default ports;
    /// - `Info` logging and a `bandwidth_limit` of `3_000_000`;
    /// - `skip_load_from_network` enabled, so `build` does not fetch the remote
    ///   gateway index.
    ///
    /// Everything else is left unset, to be filled by `build`.
    fn default() -> Self {
        // Network arguments are built first so the port lookup runs before
        // anything else, as in the original field order.
        let network_api = NetworkArgs {
            address: Some(default_listening_address()),
            network_port: Some(default_network_api_port()),
            public_address: None,
            public_port: None,
            is_gateway: false,
            skip_load_from_network: true,
            ignore_protocol_checking: false,
            gateways: None,
            gateway: None,
            location: None,
            bandwidth_limit: Some(3_000_000),
            total_bandwidth_limit: None,
            min_bandwidth_per_connection: None,
            blocked_addresses: None,
            transient_budget: Some(DEFAULT_TRANSIENT_BUDGET),
            transient_ttl_secs: Some(DEFAULT_TRANSIENT_TTL_SECS),
            min_connections: None,
            max_connections: None,
            streaming_threshold: None,
            ledbat_min_ssthresh: None,
            congestion_control: None,
            bbr_startup_rate: None,
        };
        let ws_api = WebsocketApiArgs {
            address: Some(default_listening_address()),
            ws_api_port: Some(default_ws_api_port()),
            token_ttl_seconds: None,
            token_cleanup_interval_seconds: None,
            allowed_host: None,
            allowed_source_cidrs: None,
        };
        Self {
            mode: Some(OperationMode::Network),
            network_api,
            ws_api,
            secrets: Default::default(),
            log_level: Some(tracing::log::LevelFilter::Info),
            config_paths: Default::default(),
            id: None,
            version: false,
            max_blocking_threads: None,
            telemetry: Default::default(),
        }
    }
}
impl ConfigArgs {
    /// Returns this build's crate version (`CARGO_PKG_VERSION`).
    pub fn current_version(&self) -> &str {
        PCK_VERSION
    }

    /// Finds and parses a configuration file inside `dir`.
    ///
    /// The file used is the first non-directory entry whose name starts with
    /// `config` and whose extension is `toml` or `json`. When several match,
    /// which one wins depends on `read_dir` ordering. Secret paths named in the
    /// file are resolved through `Self::read_secrets`.
    ///
    /// Returns `Ok(None)` if `dir` does not exist or contains no such file.
    ///
    /// # Errors
    /// Returns errors from:
    /// - listing the directory or reading the file;
    /// - parsing the file (TOML parse errors are mapped to `InvalidData`);
    /// - `read_secrets`.
    fn read_config(dir: &Path) -> std::io::Result<Option<Config>> {
        if !dir.exists() {
            return Ok(None);
        }
        // Entries that fail to read are skipped rather than aborting the search.
        let found = fs::read_dir(dir)?
            .filter_map(Result::ok)
            .find_map(|entry| {
                if entry.path().is_dir() {
                    return None;
                }
                let filename = entry.file_name().to_string_lossy().into_owned();
                let ext = filename.rsplit('.').next()?.to_owned();
                if !filename.starts_with("config") || !matches!(ext.as_str(), "toml" | "json")
                {
                    return None;
                }
                tracing::debug!(filename = %filename, "Found configuration file");
                Some((filename, ext))
            });
        let Some((filename, ext)) = found else {
            return Ok(None);
        };

        let path = dir.join(&filename);
        tracing::debug!(path = ?path, "Reading configuration file");
        let content = fs::read_to_string(&path)?;
        let mut config = match ext.as_str() {
            "toml" => toml::from_str::<Config>(&content).map_err(|e| {
                std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string())
            })?,
            "json" => serde_json::from_str::<Config>(&content)?,
            other => {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    format!("Invalid configuration file extension: {other}"),
                ));
            }
        };
        // The file stores secret *paths*; replace them with the loaded secrets.
        let secrets = Self::read_secrets(
            config.secrets.transport_keypair_path,
            config.secrets.nonce_path,
            config.secrets.cipher_path,
        )?;
        config.secrets = secrets;
        Ok(Some(config))
    }

    /// Resolves these arguments into a complete [`Config`].
    ///
    /// Steps:
    /// 1. Validates the network arguments and locates the config directory
    ///    (`--config-dir`, platform defaults, or a temp dir in debug/test runs).
    /// 2. Merges an existing config file under the CLI values
    ///    (see `merge_file_config`).
    /// 3. Builds paths and secrets, and derives the `PeerId` when the public
    ///    address and port are known.
    /// 4. Resolves the gateway list from, in order of use: the remote index,
    ///    hidden `--gateways` entries, the local `gateways.toml` cache, and
    ///    `--gateway` entries.
    /// 5. Persists `gateways.toml` (unless `skip_load_from_network`), and
    ///    persists `config.toml` when no config file was found.
    ///
    /// # Errors
    /// Fails on:
    /// - invalid arguments, or a missing `--config-dir` directory;
    /// - I/O or parse errors;
    /// - a malformed `--gateway`/`--gateways` entry;
    /// - a network-mode node with no gateways and no public address.
    pub async fn build(mut self) -> anyhow::Result<Config> {
        self.network_api.validate()?;

        let cfg = if let Some(path) = self.config_paths.config_dir.as_ref() {
            if !path.exists() {
                return Err(anyhow::Error::new(std::io::Error::new(
                    std::io::ErrorKind::NotFound,
                    "Configuration directory not found",
                )));
            }
            Self::read_config(path)?
        } else {
            let (config, data, is_temp_dir) =
                match ConfigPathsArgs::default_dirs(self.id.as_deref())? {
                    Either::Left(defaults) => (
                        defaults.config_local_dir().to_path_buf(),
                        defaults.data_local_dir().to_path_buf(),
                        false,
                    ),
                    Either::Right(dir) => (dir.clone(), dir, true),
                };
            self.config_paths.config_dir = Some(config.clone());
            if self.config_paths.data_dir.is_none() {
                self.config_paths.data_dir = Some(data);
            }
            if is_temp_dir {
                // `default_dirs` clears temp dirs (best effort) before returning
                // them, so there is no config file to read.
                None
            } else {
                Self::read_config(&config)?.inspect(|_| {
                    tracing::debug!("Found configuration file in default directory");
                })
            }
        };

        // Only write a fresh `config.toml` when none was found.
        let should_persist = cfg.is_none();
        if let Some(cfg) = cfg {
            self.merge_file_config(cfg);
        }

        // `ConfigPathsArgs::build` can only derive a missing data dir from the
        // platform project dirs, and it hits `unreachable!` when `default_dirs`
        // picks a temp dir (debug/test builds, or when `id` is set). That case
        // occurs with `--config-dir` but no `--data-dir` and no config file in
        // that directory, so the temp dir is filled in here instead.
        if self.config_paths.data_dir.is_none() {
            if let Either::Right(dir) = ConfigPathsArgs::default_dirs(self.id.as_deref())? {
                self.config_paths.data_dir = Some(dir);
            }
        }

        let mode = self.mode.unwrap_or(OperationMode::Network);
        let config_paths = self.config_paths.build(self.id.as_deref())?;
        let secrets = self.secrets.build(Some(&config_paths.secrets_dir(mode)))?;
        // A peer id can only be derived when the public endpoint is fully known.
        let peer_id = self
            .network_api
            .public_address
            .zip(self.network_api.public_port)
            .map(|(addr, port)| {
                PeerId::new(
                    secrets.transport_keypair.public().clone(),
                    (addr, port).into(),
                )
            });

        let gateways_file = config_paths.config_dir.join("gateways.toml");
        let remotely_loaded_gateways = if mode == OperationMode::Local {
            Gateways::default()
        } else if !self.network_api.skip_load_from_network {
            // A failed fetch is logged and treated as "no remote gateways".
            load_gateways_from_index(FREENET_GATEWAYS_INDEX, &config_paths.secrets_dir)
                .await
                .inspect_err(|error| {
                    tracing::error!(
                        error = %error,
                        index = FREENET_GATEWAYS_INDEX,
                        "Failed to load gateways from index"
                    );
                })
                .unwrap_or_default()
        } else if let Some(gateways) = self.network_api.gateways {
            // Hidden `--gateways`: each entry is an inline JSON `InlineGwConfig`.
            let gateways = gateways
                .into_iter()
                .map(|cfg| {
                    let cfg = serde_json::from_str::<InlineGwConfig>(&cfg)?;
                    Ok::<_, anyhow::Error>(GatewayConfig {
                        address: Address::HostAddress(cfg.address),
                        public_key_path: cfg.public_key_path,
                        location: cfg.location,
                    })
                })
                .collect::<Result<Vec<_>, _>>()?;
            Gateways { gateways }
        } else {
            Gateways::default()
        };
        let has_cli_gateways = self
            .network_api
            .gateway
            .as_ref()
            .is_some_and(|v| !v.is_empty());
        let gateways = if mode == OperationMode::Local {
            Gateways { gateways: vec![] }
        } else if !self.network_api.skip_load_from_network
            && !remotely_loaded_gateways.gateways.is_empty()
        {
            tracing::info!(
                gateway_count = remotely_loaded_gateways.gateways.len(),
                "Replacing local gateways with gateways from remote index"
            );
            if let Err(e) = remotely_loaded_gateways.save_to_file(&gateways_file) {
                tracing::warn!(
                    error = %e,
                    file = ?gateways_file,
                    "Failed to save updated gateways to file"
                );
            }
            remotely_loaded_gateways
        } else if self.network_api.skip_load_from_network && self.network_api.is_gateway {
            if remotely_loaded_gateways.gateways.is_empty() {
                tracing::info!(
                    "Gateway running in isolated mode (skip_load_from_network), not connecting to other gateways"
                );
                Gateways { gateways: vec![] }
            } else {
                remotely_loaded_gateways
            }
        } else {
            let remote_fetch_failed = !self.network_api.skip_load_from_network
                && remotely_loaded_gateways.gateways.is_empty();
            if remote_fetch_failed {
                tracing::warn!(
                    file = ?gateways_file,
                    "Remote gateway fetch failed, falling back to local cache"
                );
            }
            let mut gateways = match File::open(&*gateways_file) {
                Ok(mut file) => {
                    let mut content = String::new();
                    file.read_to_string(&mut content)?;
                    toml::from_str::<Gateways>(&content).map_err(|e| {
                        std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string())
                    })?
                }
                Err(err) => {
                    // A network-mode node with no identity and no gateway source
                    // at all cannot join the network.
                    if peer_id.is_none()
                        && mode == OperationMode::Network
                        && remotely_loaded_gateways.gateways.is_empty()
                        && !has_cli_gateways
                    {
                        let hint = if remote_fetch_failed {
                            "Cannot initialize node without gateways. \
                            The remote gateway index could not be reached and no \
                            local cache exists yet. Check your network connection \
                            and firewall settings, then try again."
                        } else {
                            "Cannot initialize node without gateways"
                        };
                        tracing::error!(
                            file = ?gateways_file,
                            error = %err,
                            remote_fetch_failed,
                            "{hint}"
                        );
                        return Err(anyhow::Error::new(std::io::Error::new(
                            std::io::ErrorKind::NotFound,
                            hint,
                        )));
                    }
                    if remotely_loaded_gateways.gateways.is_empty() {
                        tracing::warn!("No gateways file found, initializing disjoint gateway");
                    }
                    Gateways { gateways: vec![] }
                }
            };
            if !remotely_loaded_gateways.gateways.is_empty() {
                gateways.merge_and_deduplicate(remotely_loaded_gateways);
            }
            gateways
        };

        // `--gateway` entries are layered on top of whatever was resolved above.
        let gateways = match self.network_api.gateway.as_deref() {
            Some(cli_entries) => {
                let secrets_dir = config_paths.secrets_dir(mode);
                Self::apply_cli_gateways(cli_entries, &secrets_dir, gateways)?
            }
            None => gateways,
        };

        let this = Config {
            mode,
            peer_id,
            network_api: NetworkApiConfig {
                address: self.network_api.address.unwrap_or_else(|| match mode {
                    OperationMode::Local => default_local_address(),
                    OperationMode::Network => default_listening_address(),
                }),
                port: self
                    .network_api
                    .network_port
                    .unwrap_or_else(default_network_api_port),
                public_address: self.network_api.public_address,
                public_port: self.network_api.public_port,
                ignore_protocol_version: self.network_api.ignore_protocol_checking,
                bandwidth_limit: self.network_api.bandwidth_limit,
                total_bandwidth_limit: self.network_api.total_bandwidth_limit,
                min_bandwidth_per_connection: self.network_api.min_bandwidth_per_connection,
                blocked_addresses: self
                    .network_api
                    .blocked_addresses
                    .map(|addrs| addrs.into_iter().collect()),
                transient_budget: self
                    .network_api
                    .transient_budget
                    .unwrap_or(DEFAULT_TRANSIENT_BUDGET),
                transient_ttl_secs: self
                    .network_api
                    .transient_ttl_secs
                    .unwrap_or(DEFAULT_TRANSIENT_TTL_SECS),
                min_connections: self
                    .network_api
                    .min_connections
                    .unwrap_or(DEFAULT_MIN_CONNECTIONS),
                max_connections: self
                    .network_api
                    .max_connections
                    .unwrap_or(DEFAULT_MAX_CONNECTIONS),
                streaming_threshold: self
                    .network_api
                    .streaming_threshold
                    .unwrap_or_else(default_streaming_threshold),
                ledbat_min_ssthresh: self
                    .network_api
                    .ledbat_min_ssthresh
                    .or_else(default_ledbat_min_ssthresh),
                congestion_control: self
                    .network_api
                    .congestion_control
                    .unwrap_or_else(default_congestion_control),
                bbr_startup_rate: self.network_api.bbr_startup_rate,
                skip_load_from_network: self.network_api.skip_load_from_network,
            },
            ws_api: WebsocketApiConfig {
                address: self.ws_api.address.unwrap_or_else(|| match mode {
                    OperationMode::Local => default_local_address(),
                    OperationMode::Network => default_listening_address(),
                }),
                port: self.ws_api.ws_api_port.unwrap_or(default_ws_api_port()),
                token_ttl_seconds: self
                    .ws_api
                    .token_ttl_seconds
                    .unwrap_or(default_token_ttl_seconds()),
                token_cleanup_interval_seconds: self
                    .ws_api
                    .token_cleanup_interval_seconds
                    .unwrap_or(default_token_cleanup_interval_seconds()),
                allowed_hosts: self.ws_api.allowed_host.unwrap_or_default(),
                // Every CIDR must parse and pass `validate_source_cidr`.
                allowed_source_cidrs: self
                    .ws_api
                    .allowed_source_cidrs
                    .as_ref()
                    .map(|cidrs| {
                        cidrs
                            .iter()
                            .map(|s| {
                                let net = s.parse::<ipnet::IpNet>().map_err(|e| {
                                    anyhow::anyhow!(
                                        "invalid CIDR `{s}` in allowed-source-cidrs: {e}"
                                    )
                                })?;
                                crate::server::validate_source_cidr(&net).map_err(|msg| {
                                    anyhow::anyhow!("allowed-source-cidrs: {msg}")
                                })?;
                                Ok::<_, anyhow::Error>(net)
                            })
                            .collect::<Result<Vec<_>, _>>()
                    })
                    .transpose()?
                    .unwrap_or_default(),
            },
            secrets,
            log_level: self.log_level.unwrap_or(tracing::log::LevelFilter::Info),
            config_paths: Arc::new(config_paths),
            gateways: gateways.gateways.clone(),
            is_gateway: self.network_api.is_gateway,
            location: self.network_api.location,
            max_blocking_threads: self
                .max_blocking_threads
                .unwrap_or_else(default_max_blocking_threads),
            telemetry: TelemetryConfig {
                enabled: self.telemetry.enabled,
                endpoint: self
                    .telemetry
                    .endpoint
                    .unwrap_or_else(|| DEFAULT_TELEMETRY_ENDPOINT.to_string()),
                transport_snapshot_interval_secs: self
                    .telemetry
                    .transport_snapshot_interval_secs
                    .unwrap_or_else(default_transport_snapshot_interval_secs),
                is_test_environment: self.id.is_some(),
            },
        };

        fs::create_dir_all(this.config_dir())?;
        if !self.network_api.skip_load_from_network {
            gateways.save_to_file(&gateways_file)?;
        }
        if should_persist {
            let path = this.config_dir().join("config.toml");
            tracing::info!(path = ?path, "Persisting configuration");
            let mut file = File::create(path)?;
            file.write_all(
                toml::to_string(&this)
                    .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?
                    .as_bytes(),
            )?;
        }
        Ok(this)
    }

    /// Folds values from a previously persisted [`Config`] into these arguments.
    ///
    /// Values already set (CLI or environment) are kept; the file only fills
    /// fields that are still `None`. Some fields are merged conditionally:
    /// - `allowed-host` and `allowed-source-cidrs` only when non-empty in the file;
    /// - `streaming-threshold` and `congestion-control` only when the file
    ///   value differs from the built-in default;
    /// - telemetry can only be switched *off* by the file.
    ///
    /// NOTE(review): `max_blocking_threads`, `is_gateway` and `location` are not
    /// taken from the file. The persisted `max_blocking_threads` is the
    /// host-dependent default computed on first run, so merging it would pin
    /// that value. Confirm whether file overrides are wanted.
    fn merge_file_config(&mut self, cfg: Config) {
        self.secrets.merge(cfg.secrets);
        self.mode.get_or_insert(cfg.mode);

        // WebSocket API.
        self.ws_api.address.get_or_insert(cfg.ws_api.address);
        self.ws_api.ws_api_port.get_or_insert(cfg.ws_api.port);
        self.ws_api
            .token_ttl_seconds
            .get_or_insert(cfg.ws_api.token_ttl_seconds);
        self.ws_api
            .token_cleanup_interval_seconds
            .get_or_insert(cfg.ws_api.token_cleanup_interval_seconds);
        if !cfg.ws_api.allowed_hosts.is_empty() {
            self.ws_api
                .allowed_host
                .get_or_insert(cfg.ws_api.allowed_hosts);
        }
        if !cfg.ws_api.allowed_source_cidrs.is_empty() {
            let cidrs = &cfg.ws_api.allowed_source_cidrs;
            self.ws_api
                .allowed_source_cidrs
                .get_or_insert_with(|| cidrs.iter().map(|net| net.to_string()).collect());
        }

        // Network API.
        let file_net = cfg.network_api;
        let args = &mut self.network_api;
        args.address.get_or_insert(file_net.address);
        args.network_port.get_or_insert(file_net.port);
        args.public_address = args.public_address.or(file_net.public_address);
        args.public_port = args.public_port.or(file_net.public_port);
        args.bandwidth_limit = args.bandwidth_limit.or(file_net.bandwidth_limit);
        // Merged like `bandwidth_limit` so limits set only in the file are honored.
        args.total_bandwidth_limit = args
            .total_bandwidth_limit
            .or(file_net.total_bandwidth_limit);
        args.min_bandwidth_per_connection = args
            .min_bandwidth_per_connection
            .or(file_net.min_bandwidth_per_connection);
        if let Some(addrs) = file_net.blocked_addresses {
            args.blocked_addresses
                .get_or_insert_with(|| addrs.into_iter().collect());
        }
        args.transient_budget
            .get_or_insert(file_net.transient_budget);
        args.transient_ttl_secs
            .get_or_insert(file_net.transient_ttl_secs);
        args.min_connections.get_or_insert(file_net.min_connections);
        args.max_connections.get_or_insert(file_net.max_connections);
        if file_net.streaming_threshold != default_streaming_threshold() {
            args.streaming_threshold
                .get_or_insert(file_net.streaming_threshold);
        }
        args.ledbat_min_ssthresh = args.ledbat_min_ssthresh.or(file_net.ledbat_min_ssthresh);
        if file_net.congestion_control != default_congestion_control() {
            args.congestion_control
                .get_or_insert(file_net.congestion_control);
        }
        args.bbr_startup_rate = args.bbr_startup_rate.or(file_net.bbr_startup_rate);

        // Logging, paths and telemetry.
        self.log_level.get_or_insert(cfg.log_level);
        self.config_paths.merge(cfg.config_paths.as_ref().clone());
        if !cfg.telemetry.enabled {
            self.telemetry.enabled = false;
        }
        self.telemetry.endpoint.get_or_insert(cfg.telemetry.endpoint);
        self.telemetry
            .transport_snapshot_interval_secs
            .get_or_insert(cfg.telemetry.transport_snapshot_interval_secs);
    }

    /// Parses `--gateway` entries and merges the already-resolved `gateways`
    /// into the result (via `merge_and_deduplicate`).
    ///
    /// Stale `cli_gw_*.pub` key files in `secrets_dir` are removed first, so
    /// only keys for the current entries remain. A repeated address is skipped
    /// with a warning.
    ///
    /// # Errors
    /// Fails on the first entry `parse_gateway` rejects.
    fn apply_cli_gateways(
        cli_entries: &[String],
        secrets_dir: &Path,
        gateways: Gateways,
    ) -> anyhow::Result<Gateways> {
        Self::remove_stale_cli_gateway_keys(secrets_dir);
        let mut cli_gateways = Gateways { gateways: vec![] };
        let mut seen_addrs = HashSet::new();
        for entry in cli_entries {
            let gw = parse_gateway(entry, secrets_dir)
                .map_err(|e| anyhow::anyhow!("Failed to parse --gateway \"{entry}\": {e}"))?;
            if !seen_addrs.insert(gw.address.clone()) {
                tracing::warn!(
                    address = ?gw.address,
                    "Skipping duplicate --gateway address"
                );
                continue;
            }
            tracing::info!(
                address = ?gw.address,
                "Adding user-specified gateway via --gateway"
            );
            cli_gateways.gateways.push(gw);
        }
        cli_gateways.merge_and_deduplicate(gateways);
        Ok(cli_gateways)
    }

    /// Best-effort removal of `cli_gw_*.pub` files left by earlier runs.
    ///
    /// An unreadable directory is ignored, and removal failures are only
    /// logged at debug level.
    fn remove_stale_cli_gateway_keys(secrets_dir: &Path) {
        let Ok(entries) = fs::read_dir(secrets_dir) else {
            return;
        };
        for entry in entries.flatten() {
            let is_cli_key = entry
                .file_name()
                .to_str()
                .is_some_and(|n| n.starts_with("cli_gw_") && n.ends_with(".pub"));
            if !is_cli_key {
                continue;
            }
            if let Err(e) = fs::remove_file(entry.path()) {
                tracing::debug!(
                    error = %e,
                    file = ?entry.path(),
                    "Failed to remove stale CLI gateway key file"
                );
            }
        }
    }
}
mod serde_log_level_filter {
use serde::{Deserialize, Deserializer, Serializer};
use tracing::log::LevelFilter;
pub fn parse_log_level_str<'a, D>(level: &str) -> Result<LevelFilter, D::Error>
where
D: serde::Deserializer<'a>,
{
Ok(match level.trim() {
"off" | "Off" | "OFF" => LevelFilter::Off,
"error" | "Error" | "ERROR" => LevelFilter::Error,
"warn" | "Warn" | "WARN" => LevelFilter::Warn,
"info" | "Info" | "INFO" => LevelFilter::Info,
"debug" | "Debug" | "DEBUG" => LevelFilter::Debug,
"trace" | "Trace" | "TRACE" => LevelFilter::Trace,
s => return Err(serde::de::Error::custom(format!("unknown log level: {s}"))),
})
}
pub fn serialize<S>(level: &LevelFilter, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let level = match level {
LevelFilter::Off => "off",
LevelFilter::Error => "error",
LevelFilter::Warn => "warn",
LevelFilter::Info => "info",
LevelFilter::Debug => "debug",
LevelFilter::Trace => "trace",
};
serializer.serialize_str(level)
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<LevelFilter, D::Error>
where
D: Deserializer<'de>,
{
let level = String::deserialize(deserializer)?;
parse_log_level_str::<D>(level.as_str())
}
}
/// Fully resolved node configuration, produced by `ConfigArgs::build`.
///
/// When no config file exists it is serialized to `config.toml` in the config
/// directory. Later runs read it back (TOML or JSON) and merge it under the CLI
/// values.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Config {
/// Local or network operation.
pub mode: OperationMode,
/// Transport / peer-network settings (flattened into the top level).
#[serde(flatten)]
pub network_api: NetworkApiConfig,
/// WebSocket API settings (flattened into the top level).
#[serde(flatten)]
pub ws_api: WebsocketApiConfig,
/// Secret material. The file stores paths, which `ConfigArgs::read_config`
/// resolves through `read_secrets`.
#[serde(flatten)]
pub secrets: Secrets,
/// Log level, serialized as a lowercase name.
#[serde(with = "serde_log_level_filter")]
pub log_level: tracing::log::LevelFilter,
/// Resolved directories, shared via `Arc` (see `Config::paths`).
#[serde(flatten)]
config_paths: Arc<ConfigPaths>,
/// Set only when both a public address and a public port are configured.
/// Not serialized.
#[serde(skip)]
pub(crate) peer_id: Option<PeerId>,
/// Resolved gateway list. Not serialized here; `ConfigArgs::build` persists
/// gateways separately in `gateways.toml`.
#[serde(skip)]
pub(crate) gateways: Vec<GatewayConfig>,
/// Whether this node runs as a gateway (from `--is-gateway`).
pub(crate) is_gateway: bool,
/// Optional location, taken from the hidden `--location` argument.
pub(crate) location: Option<f64>,
/// Blocking-thread limit; `default_max_blocking_threads()` when absent from
/// the file.
#[serde(default = "default_max_blocking_threads")]
pub max_blocking_threads: usize,
/// Telemetry settings (flattened into the top level).
#[serde(flatten)]
pub telemetry: TelemetryConfig,
}
/// Default blocking-thread limit: twice the available parallelism, clamped to
/// `4..=32`, or `8` when parallelism cannot be queried.
fn default_max_blocking_threads() -> usize {
    const MIN_THREADS: usize = 4;
    const MAX_THREADS: usize = 32;
    const FALLBACK_THREADS: usize = 8;
    match std::thread::available_parallelism() {
        Ok(cores) => (cores.get() * 2).clamp(MIN_THREADS, MAX_THREADS),
        Err(_) => FALLBACK_THREADS,
    }
}
impl Config {
    /// Borrows the transport keypair held in this config's secrets.
    pub fn transport_keypair(&self) -> &TransportKeypair {
        let secrets = &self.secrets;
        secrets.transport_keypair()
    }

    /// Returns a shared handle to the resolved paths (a cheap `Arc` clone).
    pub fn paths(&self) -> Arc<ConfigPaths> {
        Arc::clone(&self.config_paths)
    }
}
// Network / transport CLI arguments (flattened into `ConfigArgs`). Options
// left as `None` are filled from the config file or from defaults in
// `ConfigArgs::build`.
// NOTE: `//` comments are used on purpose; `///` would become `--help` text.
#[derive(clap::Parser, Debug, Default, Clone, Serialize, Deserialize)]
pub struct NetworkArgs {
// Local bind address. Defaults to `::` in network mode and `::1` in local mode.
#[arg(
name = "network_address",
long = "network-address",
env = "NETWORK_ADDRESS"
)]
#[serde(rename = "network-address", skip_serializing_if = "Option::is_none")]
pub address: Option<IpAddr>,
// Local bind port; defaults to `default_network_api_port()`.
#[arg(long, env = "NETWORK_PORT")]
#[serde(rename = "network-port", skip_serializing_if = "Option::is_none")]
pub network_port: Option<u16>,
// Publicly reachable address. Required for gateways (see `validate`).
// Together with `public_port` it lets `build` derive the node's `PeerId`.
#[arg(long = "public-network-address", env = "PUBLIC_NETWORK_ADDRESS")]
#[serde(
rename = "public-network-address",
skip_serializing_if = "Option::is_none"
)]
pub public_address: Option<IpAddr>,
// Publicly reachable port.
#[arg(long = "public-network-port", env = "PUBLIC_NETWORK_PORT")]
#[serde(
rename = "public-network-port",
skip_serializing_if = "Option::is_none"
)]
pub public_port: Option<u16>,
// Run this node as a gateway.
#[arg(long)]
pub is_gateway: bool,
// Do not fetch the remote gateway index (`FREENET_GATEWAYS_INDEX`).
#[arg(long)]
pub skip_load_from_network: bool,
// Hidden: gateways as inline JSON `InlineGwConfig` strings. Only consulted
// when `skip_load_from_network` is set and the mode is not local.
#[arg(long, hide = true)]
pub gateways: Option<Vec<String>>,
// `--gateway ip:port,hex-pubkey` entries (see `parse_gateway`), merged with
// the resolved gateway list in `build`.
#[arg(long)]
#[serde(rename = "gateway", skip_serializing_if = "Option::is_none")]
pub gateway: Option<Vec<String>>,
// Hidden location (env `LOCATION`), copied to `Config::location`.
#[arg(long, hide = true, env = "LOCATION")]
pub location: Option<f64>,
// Copied to `NetworkApiConfig::ignore_protocol_version`.
#[arg(long)]
pub ignore_protocol_checking: bool,
// Bandwidth limit (units are not established in this file).
#[arg(long)]
pub bandwidth_limit: Option<usize>,
// Total bandwidth limit.
#[arg(long)]
#[serde(
rename = "total-bandwidth-limit",
skip_serializing_if = "Option::is_none"
)]
pub total_bandwidth_limit: Option<usize>,
// Minimum bandwidth per connection.
#[arg(long)]
#[serde(
rename = "min-bandwidth-per-connection",
skip_serializing_if = "Option::is_none"
)]
pub min_bandwidth_per_connection: Option<usize>,
// Socket addresses to block; accepts zero or more values.
#[arg(long, num_args = 0..)]
pub blocked_addresses: Option<Vec<SocketAddr>>,
// Transient budget; defaults to `DEFAULT_TRANSIENT_BUDGET`.
#[arg(long, env = "TRANSIENT_BUDGET")]
#[serde(rename = "transient-budget", skip_serializing_if = "Option::is_none")]
pub transient_budget: Option<usize>,
// Transient TTL in seconds; defaults to `DEFAULT_TRANSIENT_TTL_SECS`.
#[arg(long, env = "TRANSIENT_TTL_SECS")]
#[serde(rename = "transient-ttl-secs", skip_serializing_if = "Option::is_none")]
pub transient_ttl_secs: Option<u64>,
// Minimum connection count; defaults to `DEFAULT_MIN_CONNECTIONS`.
#[arg(long = "min-number-of-connections", env = "MIN_NUMBER_OF_CONNECTIONS")]
#[serde(
rename = "min-number-of-connections",
skip_serializing_if = "Option::is_none"
)]
pub min_connections: Option<usize>,
// Maximum connection count; defaults to `DEFAULT_MAX_CONNECTIONS`.
#[arg(long = "max-number-of-connections", env = "MAX_NUMBER_OF_CONNECTIONS")]
#[serde(
rename = "max-number-of-connections",
skip_serializing_if = "Option::is_none"
)]
pub max_connections: Option<usize>,
// Streaming threshold; defaults to `default_streaming_threshold()` (64 KiB).
#[arg(long, env = "STREAMING_THRESHOLD")]
#[serde(
rename = "streaming-threshold",
skip_serializing_if = "Option::is_none"
)]
pub streaming_threshold: Option<usize>,
// LEDBAT minimum ssthresh; defaults to `default_ledbat_min_ssthresh()`.
#[arg(long, env = "LEDBAT_MIN_SSTHRESH")]
#[serde(
rename = "ledbat-min-ssthresh",
skip_serializing_if = "Option::is_none"
)]
pub ledbat_min_ssthresh: Option<usize>,
// Congestion-control algorithm name. `build_congestion_config` recognizes
// "bbr" and "ledbat" (case-insensitive); any other value selects fixed rate.
#[arg(long, env = "FREENET_CONGESTION_CONTROL")]
#[serde(rename = "congestion-control", skip_serializing_if = "Option::is_none")]
pub congestion_control: Option<String>,
// BBR startup pacing rate; only applied when BBR is selected.
#[arg(long, env = "FREENET_BBR_STARTUP_RATE")]
#[serde(rename = "bbr-startup-rate", skip_serializing_if = "Option::is_none")]
pub bbr_startup_rate: Option<u64>,
}
/// JSON shape of one entry of the hidden `--gateways` argument; each entry is
/// parsed with `serde_json` in `ConfigArgs::build`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InlineGwConfig {
/// Gateway socket address.
pub address: SocketAddr,
/// Path to the gateway's public key file (JSON key: `public_key`).
#[serde(rename = "public_key")]
pub public_key_path: PathBuf,
/// Optional gateway location.
pub location: Option<f64>,
}
fn parse_gateway(input: &str, secrets_dir: &Path) -> anyhow::Result<GatewayConfig> {
let (addr_str, key_hex) = input.split_once(',').ok_or_else(|| {
anyhow::anyhow!(
"Invalid --gateway format: expected \"ip:port,hex-pubkey\", got \"{input}\""
)
})?;
let addr: SocketAddr = addr_str
.trim()
.parse()
.map_err(|e| anyhow::anyhow!("Invalid socket address \"{addr_str}\" in --gateway: {e}"))?;
let key_bytes = hex::decode(key_hex.trim())
.map_err(|e| anyhow::anyhow!("Invalid hex public key in --gateway: {e}"))?;
if key_bytes.len() != 32 {
anyhow::bail!(
"Invalid public key length {} in --gateway (expected 32 bytes / 64 hex chars)",
key_bytes.len()
);
}
fs::create_dir_all(secrets_dir)?;
let key_filename = format!("cli_gw_{}.pub", hex::encode(addr.to_string()));
let key_path = secrets_dir.join(&key_filename);
#[cfg(unix)]
{
use std::io::Write;
use std::os::unix::fs::OpenOptionsExt;
let mut file = fs::OpenOptions::new()
.write(true)
.create(true)
.truncate(true)
.mode(0o600)
.open(&key_path)?;
file.write_all(key_hex.trim().as_bytes())?;
}
#[cfg(not(unix))]
{
fs::write(&key_path, key_hex.trim())?;
}
Ok(GatewayConfig {
address: Address::HostAddress(addr),
public_key_path: key_path,
location: None,
})
}
impl NetworkArgs {
    /// Checks CLI invariants that must hold before a config is built.
    ///
    /// Non-gateway nodes always pass. A gateway must set a public address and
    /// at least one of `public_port` / `network_port`.
    ///
    /// # Errors
    /// Returns an error naming the first missing gateway setting.
    pub(crate) fn validate(&self) -> anyhow::Result<()> {
        if !self.is_gateway {
            return Ok(());
        }
        if self.public_address.is_none() {
            anyhow::bail!("Gateway nodes must specify a public network address");
        }
        let has_port = self.public_port.is_some() || self.network_port.is_some();
        anyhow::ensure!(has_port, "Gateway nodes must specify a network port");
        Ok(())
    }
}
/// Resolved network / transport settings, flattened into [`Config`].
///
/// Fields with a `default` attribute fall back to that function when missing
/// from the config file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkApiConfig {
/// Local bind address.
#[serde(default = "default_listening_address", rename = "network-address")]
pub address: IpAddr,
/// Local bind port.
#[serde(default = "default_network_api_port", rename = "network-port")]
pub port: u16,
/// Public address.
/// NOTE(review): the key is spelled with underscores, unlike the hyphenated
/// keys around it. Renaming it would break existing config files.
#[serde(
rename = "public_network_address",
skip_serializing_if = "Option::is_none"
)]
pub public_address: Option<IpAddr>,
/// Public port (key `public_port`).
#[serde(rename = "public_port", skip_serializing_if = "Option::is_none")]
pub public_port: Option<u16>,
/// From `--ignore-protocol-checking`; never persisted.
#[serde(skip)]
pub ignore_protocol_version: bool,
/// Bandwidth limit (units are not established in this file).
#[serde(skip_serializing_if = "Option::is_none")]
pub bandwidth_limit: Option<usize>,
/// Total bandwidth limit.
#[serde(skip_serializing_if = "Option::is_none")]
pub total_bandwidth_limit: Option<usize>,
/// Minimum bandwidth per connection.
#[serde(skip_serializing_if = "Option::is_none")]
pub min_bandwidth_per_connection: Option<usize>,
/// Blocked socket addresses.
#[serde(skip_serializing_if = "Option::is_none")]
pub blocked_addresses: Option<HashSet<SocketAddr>>,
/// Transient budget.
#[serde(default = "default_transient_budget", rename = "transient-budget")]
pub transient_budget: usize,
/// Transient TTL, in seconds.
#[serde(default = "default_transient_ttl_secs", rename = "transient-ttl-secs")]
pub transient_ttl_secs: u64,
/// Minimum connection count.
#[serde(
default = "default_min_connections",
rename = "min-number-of-connections"
)]
pub min_connections: usize,
/// Maximum connection count.
#[serde(
default = "default_max_connections",
rename = "max-number-of-connections"
)]
pub max_connections: usize,
/// Streaming threshold (default 64 KiB).
#[serde(
default = "default_streaming_threshold",
rename = "streaming-threshold"
)]
pub streaming_threshold: usize,
/// LEDBAT minimum ssthresh; defaults to `Some(100 * 1024)` when absent.
#[serde(
default = "default_ledbat_min_ssthresh",
rename = "ledbat-min-ssthresh",
skip_serializing_if = "Option::is_none"
)]
pub ledbat_min_ssthresh: Option<usize>,
/// Congestion-control algorithm name (default `"fixedrate"`); interpreted by
/// `build_congestion_config`.
#[serde(default = "default_congestion_control", rename = "congestion-control")]
pub congestion_control: String,
/// Optional BBR startup pacing rate.
#[serde(
default = "default_bbr_startup_rate",
rename = "bbr-startup-rate",
skip_serializing_if = "Option::is_none"
)]
pub bbr_startup_rate: Option<u64>,
/// Mirrors `--skip-load-from-network`; `false` when absent.
#[serde(default)]
pub skip_load_from_network: bool,
}
impl NetworkApiConfig {
    /// Translates `congestion_control` into a [`CongestionControlConfig`].
    ///
    /// Matching is case-insensitive: `"bbr"` and `"ledbat"` select those
    /// algorithms, and any other value (including the default `"fixedrate"`)
    /// selects fixed-rate control. When BBR is selected and `bbr_startup_rate`
    /// is set, that rate is applied as the startup minimum pacing rate.
    pub fn build_congestion_config(&self) -> CongestionControlConfig {
        let name = self.congestion_control.to_lowercase();
        let algo = if name == "bbr" {
            CongestionControlAlgorithm::Bbr
        } else if name == "ledbat" {
            CongestionControlAlgorithm::Ledbat
        } else {
            CongestionControlAlgorithm::FixedRate
        };
        let config = CongestionControlConfig::new(algo);
        match self.bbr_startup_rate {
            Some(rate) if algo == CongestionControlAlgorithm::Bbr => {
                tracing::debug!("Using custom BBR startup pacing rate: {} bytes/sec", rate);
                config.with_startup_min_pacing_rate(rate)
            }
            _ => config,
        }
    }
}
mod port_allocation;
use port_allocation::find_available_port;
/// Default network port: whatever `find_available_port` yields, or `31337`
/// if it does not produce one.
pub fn default_network_api_port() -> u16 {
    const FALLBACK_PORT: u16 = 31337;
    find_available_port().unwrap_or(FALLBACK_PORT)
}
/// Serde default for `transient-budget`; mirrors [`DEFAULT_TRANSIENT_BUDGET`].
fn default_transient_budget() -> usize {
    DEFAULT_TRANSIENT_BUDGET
}

/// Serde default for `transient-ttl-secs`; mirrors [`DEFAULT_TRANSIENT_TTL_SECS`].
fn default_transient_ttl_secs() -> u64 {
    DEFAULT_TRANSIENT_TTL_SECS
}

/// Serde default for `min-number-of-connections`; mirrors [`DEFAULT_MIN_CONNECTIONS`].
fn default_min_connections() -> usize {
    DEFAULT_MIN_CONNECTIONS
}

/// Serde default for `max-number-of-connections`; mirrors [`DEFAULT_MAX_CONNECTIONS`].
fn default_max_connections() -> usize {
    DEFAULT_MAX_CONNECTIONS
}
/// Default streaming threshold: 64 KiB.
fn default_streaming_threshold() -> usize {
    1 << 16
}

/// Default LEDBAT minimum ssthresh: 100 KiB.
fn default_ledbat_min_ssthresh() -> Option<usize> {
    let hundred_kib: usize = 102_400;
    Some(hundred_kib)
}

/// Default congestion-control algorithm name.
fn default_congestion_control() -> String {
    String::from("fixedrate")
}

/// No BBR startup-rate override by default.
fn default_bbr_startup_rate() -> Option<u64> {
    Option::None
}
// WebSocket API CLI arguments (flattened into `ConfigArgs`). Options left
// as `None` are filled from the config file or defaults in `ConfigArgs::build`.
// NOTE: `//` comments are used on purpose; `///` would become `--help` text.
#[derive(clap::Parser, Debug, Default, Clone, Serialize, Deserialize)]
pub struct WebsocketApiArgs {
// Bind address. Defaults to `::` in network mode and `::1` in local mode.
#[arg(
name = "ws_api_address",
long = "ws-api-address",
env = "WS_API_ADDRESS"
)]
#[serde(rename = "ws-api-address", skip_serializing_if = "Option::is_none")]
pub address: Option<IpAddr>,
// Bind port; defaults to `default_ws_api_port()` (7509).
#[arg(long, env = "WS_API_PORT")]
#[serde(rename = "ws-api-port", skip_serializing_if = "Option::is_none")]
pub ws_api_port: Option<u16>,
// Token TTL in seconds; defaults to `default_token_ttl_seconds()`.
#[arg(long, env = "TOKEN_TTL_SECONDS")]
#[serde(rename = "token-ttl-seconds", skip_serializing_if = "Option::is_none")]
pub token_ttl_seconds: Option<u64>,
// Token cleanup interval in seconds; defaults to
// `default_token_cleanup_interval_seconds()`.
#[arg(long, env = "TOKEN_CLEANUP_INTERVAL_SECONDS")]
#[serde(
rename = "token-cleanup-interval-seconds",
skip_serializing_if = "Option::is_none"
)]
pub token_cleanup_interval_seconds: Option<u64>,
// Allowed hosts. Taken from the config file's `allowed-host` only when unset
// here and non-empty there.
#[arg(long, env = "ALLOWED_HOST")]
#[serde(rename = "allowed-host", skip_serializing_if = "Option::is_none")]
pub allowed_host: Option<Vec<String>>,
// Comma-separated source CIDRs. Each must parse as an `IpNet` and pass
// `crate::server::validate_source_cidr`, or `build` fails.
#[arg(
long = "allowed-source-cidrs",
env = "ALLOWED_SOURCE_CIDRS",
value_delimiter = ','
)]
#[serde(
rename = "allowed-source-cidrs",
skip_serializing_if = "Option::is_none"
)]
pub allowed_source_cidrs: Option<Vec<String>>,
}
/// Telemetry endpoint used when neither the CLI nor the config file sets one.
pub const DEFAULT_TELEMETRY_ENDPOINT: &str = "http://nova.locut.us:4318";
// Telemetry CLI arguments (flattened into `ConfigArgs`).
// NOTE: `//` comments are used on purpose; `///` would become `--help` text.
#[derive(clap::Parser, Debug, Clone, Serialize, Deserialize)]
pub struct TelemetryArgs {
// Telemetry switch; the CLI default is "true". A config file with
// `telemetry-enabled = false` turns telemetry off (see `ConfigArgs::build`).
#[arg(
long = "telemetry-enabled",
env = "FREENET_TELEMETRY_ENABLED",
default_value = "true"
)]
#[serde(rename = "telemetry-enabled", default = "default_telemetry_enabled")]
pub enabled: bool,
// Endpoint override; falls back to the config file, then `DEFAULT_TELEMETRY_ENDPOINT`.
#[arg(long = "telemetry-endpoint", env = "FREENET_TELEMETRY_ENDPOINT")]
#[serde(rename = "telemetry-endpoint", skip_serializing_if = "Option::is_none")]
pub endpoint: Option<String>,
// Transport snapshot interval in seconds; defaults to
// `default_transport_snapshot_interval_secs()`.
#[arg(
long = "transport-snapshot-interval-secs",
env = "FREENET_TRANSPORT_SNAPSHOT_INTERVAL_SECS"
)]
#[serde(
rename = "transport-snapshot-interval-secs",
skip_serializing_if = "Option::is_none"
)]
pub transport_snapshot_interval_secs: Option<u64>,
}
impl Default for TelemetryArgs {
    /// Telemetry on, with no endpoint or snapshot-interval override.
    fn default() -> Self {
        Self {
            transport_snapshot_interval_secs: None,
            endpoint: None,
            enabled: true,
        }
    }
}
/// Serde default for `telemetry-enabled`: telemetry is on unless disabled.
fn default_telemetry_enabled() -> bool {
    const ENABLED_BY_DEFAULT: bool = true;
    ENABLED_BY_DEFAULT
}
/// Resolved telemetry settings, flattened into [`Config`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TelemetryConfig {
/// Whether telemetry is enabled (`true` when absent from the file).
#[serde(default = "default_telemetry_enabled", rename = "telemetry-enabled")]
pub enabled: bool,
/// Telemetry endpoint (`DEFAULT_TELEMETRY_ENDPOINT` when absent).
#[serde(default = "default_telemetry_endpoint", rename = "telemetry-endpoint")]
pub endpoint: String,
/// Transport snapshot interval in seconds (30 when absent).
#[serde(
default = "default_transport_snapshot_interval_secs",
rename = "transport-snapshot-interval-secs"
)]
pub transport_snapshot_interval_secs: u64,
/// Set by `ConfigArgs::build` when an instance `id` was given; never persisted.
#[serde(skip)]
pub is_test_environment: bool,
}
/// Default transport snapshot interval, in seconds.
fn default_transport_snapshot_interval_secs() -> u64 {
    const INTERVAL_SECS: u64 = 30;
    INTERVAL_SECS
}
/// Serde default for `telemetry-endpoint`: an owned copy of
/// [`DEFAULT_TELEMETRY_ENDPOINT`].
fn default_telemetry_endpoint() -> String {
    String::from(DEFAULT_TELEMETRY_ENDPOINT)
}
impl Default for TelemetryConfig {
    /// Telemetry on, pointed at the default endpoint, with the default snapshot
    /// interval, outside a test environment.
    fn default() -> Self {
        let endpoint = String::from(DEFAULT_TELEMETRY_ENDPOINT);
        let transport_snapshot_interval_secs = default_transport_snapshot_interval_secs();
        Self {
            enabled: true,
            endpoint,
            transport_snapshot_interval_secs,
            is_test_environment: false,
        }
    }
}
/// Resolved WebSocket API settings, flattened into [`Config`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WebsocketApiConfig {
/// Bind address (`::` when absent from the file).
#[serde(default = "default_listening_address", rename = "ws-api-address")]
pub address: IpAddr,
/// Bind port (7509 when absent).
#[serde(default = "default_ws_api_port", rename = "ws-api-port")]
pub port: u16,
/// Token TTL in seconds (86400 when absent).
#[serde(default = "default_token_ttl_seconds", rename = "token-ttl-seconds")]
pub token_ttl_seconds: u64,
/// Token cleanup interval in seconds (300 when absent).
#[serde(
default = "default_token_cleanup_interval_seconds",
rename = "token-cleanup-interval-seconds"
)]
pub token_cleanup_interval_seconds: u64,
/// Allowed hosts (empty when absent).
#[serde(default, rename = "allowed-host")]
pub allowed_hosts: Vec<String>,
/// Allowed source networks (empty when absent); validated in `ConfigArgs::build`.
#[serde(default, rename = "allowed-source-cidrs")]
pub allowed_source_cidrs: Vec<ipnet::IpNet>,
}
/// Default token TTL: 24 hours, in seconds.
#[inline]
const fn default_token_ttl_seconds() -> u64 {
    24 * 60 * 60
}

/// Default token cleanup interval: 5 minutes, in seconds.
#[inline]
const fn default_token_cleanup_interval_seconds() -> u64 {
    5 * 60
}
impl From<SocketAddr> for WebsocketApiConfig {
    /// Binds the API to `addr`, using the default token timings and no
    /// host/CIDR restrictions.
    fn from(addr: SocketAddr) -> Self {
        let (address, port) = (addr.ip(), addr.port());
        Self {
            address,
            port,
            token_ttl_seconds: default_token_ttl_seconds(),
            token_cleanup_interval_seconds: default_token_cleanup_interval_seconds(),
            allowed_hosts: Vec::new(),
            allowed_source_cidrs: Vec::new(),
        }
    }
}

impl Default for WebsocketApiConfig {
    /// Same as converting the default listen address and WebSocket port.
    #[inline]
    fn default() -> Self {
        Self::from(SocketAddr::new(
            default_listening_address(),
            default_ws_api_port(),
        ))
    }
}
/// IPv6 wildcard address `::`; the default bind address in network mode.
#[inline]
const fn default_listening_address() -> IpAddr {
    IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0))
}

/// IPv6 loopback `::1`; the default bind address in local mode.
#[inline]
const fn default_local_address() -> IpAddr {
    IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1))
}

/// Default WebSocket API port.
#[inline]
const fn default_ws_api_port() -> u16 {
    const WS_API_PORT: u16 = 7509;
    WS_API_PORT
}
// Directory overrides (flattened into `ConfigArgs`). Unset directories are
// filled from the config file or platform/temp defaults.
// NOTE: `//` comments are used on purpose; `///` would become `--help` text.
#[derive(clap::Parser, Default, Debug, Clone, Serialize, Deserialize)]
pub struct ConfigPathsArgs {
// Directory searched for `config.*` and `gateways.toml` (env `CONFIG_DIR`).
// `ConfigArgs::build` fails if it is given but does not exist.
#[arg(long, default_value = None, env = "CONFIG_DIR")]
pub config_dir: Option<PathBuf>,
// Data directory (env `DATA_DIR`).
#[arg(long, default_value = None, env = "DATA_DIR")]
pub data_dir: Option<PathBuf>,
// Log directory (env `LOG_DIR`).
#[arg(long, default_value = None, env = "LOG_DIR")]
pub log_dir: Option<PathBuf>,
}
impl ConfigPathsArgs {
fn merge(&mut self, other: ConfigPaths) {
self.config_dir.get_or_insert(other.config_dir);
self.data_dir.get_or_insert(other.data_dir);
self.log_dir = self.log_dir.take().or(other.log_dir);
}
fn default_dirs(id: Option<&str>) -> std::io::Result<Either<ProjectDirs, PathBuf>> {
let default_dir: Either<_, _> = if cfg!(any(test, debug_assertions)) || id.is_some() {
let base_name = if let Some(id) = id {
format!("freenet-{id}")
} else {
"freenet".into()
};
let temp_path = std::env::temp_dir().join(&base_name);
if temp_path.exists() && fs::remove_dir_all(&temp_path).is_err() {
let unique_path =
std::env::temp_dir().join(format!("{}-{}", base_name, std::process::id()));
let _cleanup = fs::remove_dir_all(&unique_path);
return Ok(Either::Right(unique_path));
}
Either::Right(temp_path)
} else {
Either::Left(
ProjectDirs::from(QUALIFIER, ORGANIZATION, APPLICATION)
.ok_or(std::io::ErrorKind::NotFound)?,
)
};
Ok(default_dir)
}
pub fn build(self, id: Option<&str>) -> std::io::Result<ConfigPaths> {
#[allow(unused_variables)]
let has_custom_data_dir = self.data_dir.is_some();
let app_data_dir = self
.data_dir
.map(Ok::<_, std::io::Error>)
.unwrap_or_else(|| {
let default_dirs = Self::default_dirs(id)?;
let Either::Left(defaults) = default_dirs else {
unreachable!("default_dirs should return Left if data_dir is None and id is not set for temp dir")
};
Ok(defaults.data_local_dir().to_path_buf())
})?;
#[cfg(target_os = "windows")]
if !has_custom_data_dir && id.is_none() {
if let Ok(Either::Left(ref proj)) = Self::default_dirs(None) {
let old_roaming = proj.data_dir().to_path_buf();
if old_roaming != app_data_dir
&& old_roaming.join("contracts").exists()
&& !app_data_dir.join("contracts").exists()
{
tracing::info!(
old = ?old_roaming,
new = ?app_data_dir,
"Migrating data from Roaming to Local AppData"
);
if let Some(parent) = app_data_dir.parent() {
let _ = fs::create_dir_all(parent);
}
if let Err(e) = fs::rename(&old_roaming, &app_data_dir) {
tracing::warn!(
error = %e,
"Failed to migrate data directory; starting fresh"
);
}
}
}
}
let contracts_dir = app_data_dir.join("contracts");
let delegates_dir = app_data_dir.join("delegates");
let secrets_dir = app_data_dir.join("secrets");
let db_dir = app_data_dir.join("db");
if !contracts_dir.exists() {
fs::create_dir_all(&contracts_dir)?;
fs::create_dir_all(contracts_dir.join("local"))?;
}
if !delegates_dir.exists() {
fs::create_dir_all(&delegates_dir)?;
fs::create_dir_all(delegates_dir.join("local"))?;
}
if !secrets_dir.exists() {
fs::create_dir_all(&secrets_dir)?;
fs::create_dir_all(secrets_dir.join("local"))?;
}
if !db_dir.exists() {
fs::create_dir_all(&db_dir)?;
fs::create_dir_all(db_dir.join("local"))?;
}
let event_log = app_data_dir.join("_EVENT_LOG");
if !event_log.exists() {
fs::write(&event_log, [])?;
let mut local_file = event_log.clone();
local_file.set_file_name("_EVENT_LOG_LOCAL");
fs::write(local_file, [])?;
}
let config_dir = self
.config_dir
.map(Ok::<_, std::io::Error>)
.unwrap_or_else(|| {
let default_dirs = Self::default_dirs(id)?;
let Either::Left(defaults) = default_dirs else {
unreachable!("default_dirs should return Left if config_dir is None and id is not set for temp dir")
};
Ok(defaults.config_dir().to_path_buf())
})?;
let log_dir = self.log_dir.or_else(get_log_dir);
Ok(ConfigPaths {
config_dir,
data_dir: app_data_dir,
contracts_dir,
delegates_dir,
secrets_dir,
db_dir,
event_log,
log_dir,
})
}
}
/// Fully resolved on-disk locations used by a node, produced by `ConfigPathsArgs::build`.
///
/// The contracts, delegates, secrets and db directories are the network-mode
/// locations; the accessors append `local` for `OperationMode::Local`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfigPaths {
/// Network-mode contracts directory (local mode: `<contracts_dir>/local`).
contracts_dir: PathBuf,
/// Network-mode delegates directory (local mode: `<delegates_dir>/local`).
delegates_dir: PathBuf,
/// Network-mode secrets directory (local mode: `<secrets_dir>/local`).
secrets_dir: PathBuf,
/// Network-mode database directory (local mode: `<db_dir>/local`).
db_dir: PathBuf,
/// Network-mode event log file; local mode uses the sibling `_EVENT_LOG_LOCAL`.
event_log: PathBuf,
/// Root data directory containing the directories above.
data_dir: PathBuf,
/// Configuration directory.
config_dir: PathBuf,
/// Optional log directory; filled from `get_log_dir()` when absent in serialized input.
#[serde(default = "get_log_dir")]
log_dir: Option<PathBuf>,
}
impl ConfigPaths {
    /// Resolves a per-mode location: `base` itself for network mode,
    /// `base/local` for local mode.
    fn mode_dir(base: &Path, mode: OperationMode) -> PathBuf {
        match mode {
            OperationMode::Network => base.to_path_buf(),
            OperationMode::Local => base.join("local"),
        }
    }

    /// Database directory for `mode`.
    pub fn db_dir(&self, mode: OperationMode) -> PathBuf {
        Self::mode_dir(&self.db_dir, mode)
    }

    /// Replaces the network-mode database directory.
    pub fn with_db_dir(self, db_dir: PathBuf) -> Self {
        Self { db_dir, ..self }
    }

    /// Contracts directory for `mode`.
    pub fn contracts_dir(&self, mode: OperationMode) -> PathBuf {
        Self::mode_dir(&self.contracts_dir, mode)
    }

    /// Replaces the network-mode contracts directory.
    pub fn with_contract_dir(self, contracts_dir: PathBuf) -> Self {
        Self {
            contracts_dir,
            ..self
        }
    }

    /// Delegates directory for `mode`.
    pub fn delegates_dir(&self, mode: OperationMode) -> PathBuf {
        Self::mode_dir(&self.delegates_dir, mode)
    }

    /// Replaces the network-mode delegates directory.
    pub fn with_delegates_dir(self, delegates_dir: PathBuf) -> Self {
        Self {
            delegates_dir,
            ..self
        }
    }

    /// Configuration directory (independent of operation mode).
    pub fn config_dir(&self) -> PathBuf {
        PathBuf::clone(&self.config_dir)
    }

    /// Root data directory (independent of operation mode).
    pub fn data_dir(&self) -> PathBuf {
        PathBuf::clone(&self.data_dir)
    }

    /// Secrets directory for `mode`.
    pub fn secrets_dir(&self, mode: OperationMode) -> PathBuf {
        Self::mode_dir(&self.secrets_dir, mode)
    }

    /// Replaces the network-mode secrets directory.
    pub fn with_secrets_dir(self, secrets_dir: PathBuf) -> Self {
        Self {
            secrets_dir,
            ..self
        }
    }

    /// Event log file for `mode`; local mode uses the sibling `_EVENT_LOG_LOCAL`.
    pub fn event_log(&self, mode: OperationMode) -> PathBuf {
        match mode {
            OperationMode::Network => self.event_log.clone(),
            OperationMode::Local => self.event_log.with_file_name("_EVENT_LOG_LOCAL"),
        }
    }

    /// Log directory, if one was configured or discovered.
    pub fn log_dir(&self) -> Option<&Path> {
        self.log_dir.as_deref()
    }

    /// Replaces the network-mode event log path.
    pub fn with_event_log(self, event_log: PathBuf) -> Self {
        Self { event_log, ..self }
    }

    /// Iterates over `(is_dir, path)` for every stored path except the log directory.
    pub fn iter(&self) -> ConfigPathsIter<'_> {
        ConfigPathsIter {
            config_paths: self,
            curr: 0,
        }
    }

    /// Returns the `index`-th stored path and whether it is a directory.
    ///
    /// # Panics
    /// Panics with "invalid path index" when `index > MAX_PATH_INDEX`.
    fn path_by_index(&self, index: usize) -> (bool, &PathBuf) {
        let table = [
            (true, &self.contracts_dir),
            (true, &self.delegates_dir),
            (true, &self.secrets_dir),
            (true, &self.db_dir),
            (true, &self.data_dir),
            (false, &self.event_log),
            (true, &self.config_dir),
        ];
        table
            .get(index)
            .copied()
            .unwrap_or_else(|| panic!("invalid path index"))
    }

    /// Highest valid index accepted by `path_by_index`.
    const MAX_PATH_INDEX: usize = 6;
}
/// Iterator over the paths stored in a [`ConfigPaths`], created by `ConfigPaths::iter`.
///
/// Yields `(is_dir, path)` pairs: `true` for directories, `false` for the event
/// log file. The optional log directory is not included.
pub struct ConfigPathsIter<'a> {
/// Index of the next entry to yield.
curr: usize,
/// Paths being iterated.
config_paths: &'a ConfigPaths,
}
impl<'a> Iterator for ConfigPathsIter<'a> {
    type Item = (bool, &'a PathBuf);

    fn next(&mut self) -> Option<Self::Item> {
        if self.curr > ConfigPaths::MAX_PATH_INDEX {
            return None;
        }
        let path = self.config_paths.path_by_index(self.curr);
        self.curr += 1;
        Some(path)
    }

    /// Exact count of remaining items.
    ///
    /// Indices run over `0..=MAX_PATH_INDEX`, i.e. `MAX_PATH_INDEX + 1` items in total.
    /// The previous upper bound of `MAX_PATH_INDEX` under-reported by one, which
    /// violates the `Iterator::size_hint` contract.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = (ConfigPaths::MAX_PATH_INDEX + 1).saturating_sub(self.curr);
        (remaining, Some(remaining))
    }
}

// `size_hint` is exact, so `len()` is available to callers.
impl ExactSizeIterator for ConfigPathsIter<'_> {}

// Once exhausted, `curr` stays past `MAX_PATH_INDEX`, so `next` keeps returning `None`.
impl core::iter::FusedIterator for ConfigPathsIter<'_> {}
impl Config {
    /// Database directory for this node's operation mode.
    pub fn db_dir(&self) -> PathBuf {
        let Self {
            config_paths, mode, ..
        } = self;
        config_paths.db_dir(*mode)
    }

    /// Contracts directory for this node's operation mode.
    pub fn contracts_dir(&self) -> PathBuf {
        let Self {
            config_paths, mode, ..
        } = self;
        config_paths.contracts_dir(*mode)
    }

    /// Delegates directory for this node's operation mode.
    pub fn delegates_dir(&self) -> PathBuf {
        let Self {
            config_paths, mode, ..
        } = self;
        config_paths.delegates_dir(*mode)
    }

    /// Secrets directory for this node's operation mode.
    pub fn secrets_dir(&self) -> PathBuf {
        let Self {
            config_paths, mode, ..
        } = self;
        config_paths.secrets_dir(*mode)
    }

    /// Event log file for this node's operation mode.
    pub fn event_log(&self) -> PathBuf {
        let Self {
            config_paths, mode, ..
        } = self;
        config_paths.event_log(*mode)
    }

    /// Configuration directory (mode independent).
    pub fn config_dir(&self) -> PathBuf {
        let Self { config_paths, .. } = self;
        config_paths.config_dir()
    }

    /// Root data directory (mode independent).
    pub fn data_dir(&self) -> PathBuf {
        let Self { config_paths, .. } = self;
        config_paths.data_dir()
    }
}
/// List of gateways, as stored in a gateways TOML file (`[[gateways]]` tables)
/// or served by the remote gateway index.
#[derive(Debug, Serialize, Deserialize, Default)]
struct Gateways {
/// Gateway entries, in file order.
pub gateways: Vec<GatewayConfig>,
}
impl Gateways {
    /// Appends `other`'s gateways and then keeps only the first entry for each
    /// address. Entries already in `self` therefore win over entries from `other`,
    /// and relative order is preserved.
    pub fn merge_and_deduplicate(&mut self, other: Gateways) {
        self.gateways.extend(other.gateways);
        let mut addresses_seen: HashSet<Address> = HashSet::with_capacity(self.gateways.len());
        self.gateways
            .retain(|gateway| addresses_seen.insert(gateway.address.clone()));
    }

    /// Writes the list as TOML to `path`, creating missing parent directories first.
    ///
    /// # Errors
    /// Fails if a directory cannot be created, serialization fails, or the write fails.
    pub fn save_to_file(&self, path: &Path) -> anyhow::Result<()> {
        path.parent().map_or(Ok(()), fs::create_dir_all)?;
        Ok(fs::write(path, toml::to_string(self)?)?)
    }
}
/// A single gateway entry from a gateways TOML file or the remote gateway index.
///
/// Equality and hashing consider only `address`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct GatewayConfig {
/// Where the gateway can be reached.
pub address: Address,
/// Path of the file holding the gateway's public key (TOML key `public_key`).
/// `load_gateways_from_index` only accepts keys stored as 32 hex-encoded bytes.
#[serde(rename = "public_key")]
pub public_key_path: PathBuf,
/// Optional location, omitted from serialized output when `None`.
/// NOTE(review): presumably the gateway's ring location — confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub location: Option<f64>,
}
impl PartialEq for GatewayConfig {
fn eq(&self, other: &Self) -> bool {
self.address == other.address
}
}
impl Eq for GatewayConfig {}
impl std::hash::Hash for GatewayConfig {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.address.hash(state);
}
}
/// How to reach a gateway: by host name or by a concrete socket address.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Clone)]
pub enum Address {
/// Host name (TOML: `address = { hostname = "..." }`); turned into a socket
/// address elsewhere (e.g. `NodeConfig::parse_socket_addr`).
#[serde(rename = "hostname")]
Hostname(String),
/// Explicit IP address and port (TOML key `host_address`).
#[serde(rename = "host_address")]
HostAddress(SocketAddr),
}
/// Spawns futures on the ambient Tokio runtime, falling back to a lazily built
/// runtime stored in `ASYNC_RT` when called outside of any runtime.
pub struct GlobalExecutor;
impl GlobalExecutor {
    /// Builds the fallback runtime stored in `ASYNC_RT`.
    ///
    /// Returns `None` if a Tokio runtime is already active on the calling thread.
    /// In debug builds the fallback is capped at two worker threads and two
    /// blocking threads.
    ///
    /// # Panics
    /// Panics if Tokio fails to build the runtime.
    pub(crate) fn initialize_async_rt() -> Option<Runtime> {
        if tokio::runtime::Handle::try_current().is_ok() {
            tracing::debug!(target: "freenet::diagnostics::thread_explosion", "GlobalExecutor: runtime exists");
            return None;
        }
        tracing::warn!(target: "freenet::diagnostics::thread_explosion", "GlobalExecutor: Creating fallback runtime");
        let mut builder = tokio::runtime::Builder::new_multi_thread();
        builder.enable_all().thread_name("freenet-node");
        if cfg!(debug_assertions) {
            builder.worker_threads(2).max_blocking_threads(2);
        }
        let runtime = builder.build().expect("failed to build tokio runtime");
        Some(runtime)
    }

    /// Spawns `f` on the current runtime, or on the fallback runtime if there is none.
    ///
    /// # Panics
    /// Panics if no runtime is active and the fallback runtime is absent.
    #[inline]
    pub fn spawn<R: Send + 'static>(
        f: impl Future<Output = R> + Send + 'static,
    ) -> tokio::task::JoinHandle<R> {
        match tokio::runtime::Handle::try_current() {
            Ok(current) => current.spawn(f),
            Err(_) => {
                let Some(fallback) = &*ASYNC_RT else {
                    unreachable!("ASYNC_RT should be initialized if Handle::try_current fails")
                };
                tracing::warn!(target: "freenet::diagnostics::thread_explosion", "GlobalExecutor::spawn using fallback");
                fallback.spawn(f)
            }
        }
    }
}
use rand::rngs::SmallRng;
use rand::{Rng, RngCore, SeedableRng};
// Source of fresh per-thread indices handed out by `GlobalRng::thread_index`.
static THREAD_INDEX_COUNTER: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
std::thread_local! {
// This thread's seeded RNG, built lazily by `GlobalRng::with_rng`; reset to
// `None` whenever the seed is set or cleared.
static THREAD_RNG: std::cell::RefCell<Option<SmallRng>> = const { std::cell::RefCell::new(None) };
// This thread's index, mixed into the per-thread seed; `set_seed` pins it to 0.
static THREAD_INDEX: std::cell::Cell<Option<u64>> = const { std::cell::Cell::new(None) };
// Seed for this thread; `None` means `GlobalRng` falls back to `rand::rng()`.
static THREAD_SEED: std::cell::Cell<Option<u64>> = const { std::cell::Cell::new(None) };
}
/// Entry point for randomness that can be made deterministic per thread.
///
/// Without a seed, calls use `rand::rng()`. After `GlobalRng::set_seed`, calls on
/// that thread draw from a seeded `SmallRng` instead.
pub struct GlobalRng;

/// Guard returned by `GlobalRng::seed_guard`; clears this thread's seed when dropped.
///
/// Marked `#[must_use]` because discarding it drops it immediately, which clears
/// the seed that was just installed and silently makes the caller non-deterministic.
#[must_use = "dropping the guard immediately clears the seed it just installed"]
pub struct SeedGuard {
    _private: (),
}

impl Drop for SeedGuard {
    fn drop(&mut self) {
        GlobalRng::clear_seed();
    }
}
impl GlobalRng {
    /// Installs `seed` for the current thread and restarts its deterministic stream.
    ///
    /// Also pins this thread's index to `0`, so the derived per-thread seed does not
    /// depend on how many threads asked for an index before.
    pub fn set_seed(seed: u64) {
        THREAD_SEED.set(Some(seed));
        // Forget any RNG built from an earlier seed; `with_rng` rebuilds it lazily.
        THREAD_RNG.set(None);
        THREAD_INDEX.set(Some(0));
    }

    /// Removes this thread's seed, its RNG and its index.
    pub fn clear_seed() {
        THREAD_SEED.set(None);
        THREAD_RNG.set(None);
        THREAD_INDEX.set(None);
    }

    /// Returns this thread's index, assigning the next counter value on first use.
    pub fn thread_index() -> u64 {
        if let Some(assigned) = THREAD_INDEX.get() {
            return assigned;
        }
        let fresh = THREAD_INDEX_COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        THREAD_INDEX.set(Some(fresh));
        fresh
    }

    /// Whether a seed is installed on the current thread.
    pub fn is_seeded() -> bool {
        THREAD_SEED.get().is_some()
    }

    /// Installs `seed` and returns a guard that clears it again when dropped.
    pub fn seed_guard(seed: u64) -> SeedGuard {
        Self::set_seed(seed);
        SeedGuard { _private: () }
    }

    /// Runs `f` with `seed` installed, clearing the seed afterwards (also on unwind).
    pub fn scoped_seed<F, R>(seed: u64, f: F) -> R
    where
        F: FnOnce() -> R,
    {
        let _reset_on_exit = Self::seed_guard(seed);
        f()
    }

    /// Calls `f` with this thread's RNG.
    ///
    /// When seeded, the RNG is a `SmallRng` built on first use from the seed plus
    /// `thread_index() * 0x9E3779B97F4A7C15` (wrapping). Otherwise `rand::rng()` is used.
    ///
    /// # Panics
    /// Panics if `f` re-enters `GlobalRng` while seeded (the RNG cell is already borrowed).
    #[inline]
    pub fn with_rng<F, R>(f: F) -> R
    where
        F: FnOnce(&mut dyn RngCore) -> R,
    {
        let Some(seed) = THREAD_SEED.get() else {
            return f(&mut rand::rng());
        };
        THREAD_RNG.with_borrow_mut(|slot| {
            let rng = slot.get_or_insert_with(|| {
                let per_thread = Self::thread_index().wrapping_mul(0x9E3779B97F4A7C15);
                SmallRng::seed_from_u64(seed.wrapping_add(per_thread))
            });
            f(rng)
        })
    }

    /// Uniform sample from `range`.
    #[inline]
    pub fn random_range<T, R>(range: R) -> T
    where
        T: rand::distr::uniform::SampleUniform,
        R: rand::distr::uniform::SampleRange<T>,
    {
        Self::with_rng(|rng| rng.random_range(range))
    }

    /// `true` with the given probability.
    #[inline]
    pub fn random_bool(probability: f64) -> bool {
        Self::with_rng(|rng| rng.random_bool(probability))
    }

    /// Uniformly random element of `slice`, or `None` if it is empty.
    #[inline]
    pub fn choose<T>(slice: &[T]) -> Option<&T> {
        (!slice.is_empty()).then(|| &slice[Self::random_range(0..slice.len())])
    }

    /// Shuffles `slice` in place.
    #[inline]
    pub fn shuffle<T>(slice: &mut [T]) {
        use rand::seq::SliceRandom;
        Self::with_rng(|rng| slice.shuffle(rng))
    }

    /// Fills `dest` with random bytes.
    #[inline]
    pub fn fill_bytes(dest: &mut [u8]) {
        Self::with_rng(|rng| rng.fill_bytes(dest))
    }

    /// Random `u64`.
    #[inline]
    pub fn random_u64() -> u64 {
        Self::with_rng(|rng| rng.random::<u64>())
    }

    /// Random `u32`.
    #[inline]
    pub fn random_u32() -> u32 {
        Self::with_rng(|rng| rng.random::<u32>())
    }
}
std::thread_local! {
// Base time in milliseconds installed by `GlobalSimulationTime::set_time_ms`;
// `None` means the wall clock is used.
static SIMULATION_TIME_MS: std::cell::Cell<Option<u64>> = const { std::cell::Cell::new(None) };
// Ticks handed out by `current_time_ms` since the base time was last set.
static SIMULATION_TIME_COUNTER: std::cell::Cell<u64> = const { std::cell::Cell::new(0) };
}
/// Per-thread clock that can be switched from wall-clock time to a simulated base time.
pub struct GlobalSimulationTime;
impl GlobalSimulationTime {
    /// Switches this thread to simulated time starting at `time_ms`, resetting the tick counter.
    pub fn set_time_ms(time_ms: u64) {
        SIMULATION_TIME_MS.with(|t| t.set(Some(time_ms)));
        SIMULATION_TIME_COUNTER.with(|c| c.set(0));
    }

    /// Switches this thread back to wall-clock time.
    pub fn clear_time() {
        SIMULATION_TIME_MS.with(|t| t.set(None));
        SIMULATION_TIME_COUNTER.with(|c| c.set(0));
    }

    /// Current time in milliseconds since the Unix epoch.
    ///
    /// In simulation mode this is the base time plus a tick counter that advances by
    /// one on every call, so successive calls yield increasing timestamps (saturating).
    ///
    /// # Panics
    /// In wall-clock mode, panics if the system clock is before the Unix epoch.
    pub fn current_time_ms() -> u64 {
        SIMULATION_TIME_MS.with(|t| {
            if let Some(base_time) = t.get() {
                let counter = SIMULATION_TIME_COUNTER.with(|c| {
                    let val = c.get();
                    c.set(val + 1);
                    val
                });
                base_time.saturating_add(counter)
            } else {
                Self::wall_clock_ms()
            }
        })
    }

    /// Like [`Self::current_time_ms`] but without advancing the simulated tick counter.
    pub fn read_time_ms() -> u64 {
        SIMULATION_TIME_MS.with(|t| {
            if let Some(base_time) = t.get() {
                let counter = SIMULATION_TIME_COUNTER.with(|c| c.get());
                base_time.saturating_add(counter)
            } else {
                Self::wall_clock_ms()
            }
        })
    }

    /// Whether this thread is using simulated time.
    pub fn is_simulation_time() -> bool {
        SIMULATION_TIME_MS.with(|t| t.get().is_some())
    }

    /// Creates a ULID.
    ///
    /// When `GlobalRng` is seeded or simulated time is active, the ULID is built from
    /// [`Self::current_time_ms`] and 10 bytes from `GlobalRng`, so it is reproducible.
    /// Otherwise `Ulid::new()` is used.
    pub fn new_ulid() -> ulid::Ulid {
        use ulid::Ulid;
        if GlobalRng::is_seeded() || Self::is_simulation_time() {
            let timestamp_ms = Self::current_time_ms();
            let mut random_bytes = [0u8; 10];
            GlobalRng::fill_bytes(&mut random_bytes);
            Ulid(Self::ulid_bits(timestamp_ms, random_bytes))
        } else {
            Ulid::new()
        }
    }

    /// Packs a ULID: the low 48 bits of `timestamp_ms` in the top six bytes,
    /// followed by the 10 random bytes, all big-endian.
    ///
    /// The previous hand-rolled version shifted the last random byte by 56 instead of 0.
    /// That OR-ed it over the third random byte and left the lowest 8 bits always zero.
    fn ulid_bits(timestamp_ms: u64, random: [u8; 10]) -> u128 {
        let mut bytes = [0u8; 16];
        bytes[..6].copy_from_slice(&timestamp_ms.to_be_bytes()[2..]);
        bytes[6..].copy_from_slice(&random);
        u128::from_be_bytes(bytes)
    }

    /// Wall-clock milliseconds since the Unix epoch.
    ///
    /// # Panics
    /// Panics if the system clock is before the Unix epoch.
    fn wall_clock_ms() -> u64 {
        use std::time::{SystemTime, UNIX_EPOCH};
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system time before unix epoch")
            .as_millis() as u64
    }
}
std::thread_local! {
    static SIMULATION_TRANSPORT_OPT: std::cell::Cell<bool> = const { std::cell::Cell::new(false) };
    static SIMULATION_IDLE_TIMEOUT: std::cell::Cell<bool> = const { std::cell::Cell::new(false) };
}

/// Per-thread on/off switch (off by default) read by simulation-aware transport
/// code; what it changes is decided by its readers elsewhere.
pub struct SimulationTransportOpt;
impl SimulationTransportOpt {
    /// Turns the switch on for the current thread.
    pub fn enable() {
        SIMULATION_TRANSPORT_OPT.set(true);
    }

    /// Turns the switch off for the current thread.
    pub fn disable() {
        SIMULATION_TRANSPORT_OPT.set(false);
    }

    /// Current state of the switch on this thread.
    pub fn is_enabled() -> bool {
        SIMULATION_TRANSPORT_OPT.get()
    }
}

/// Per-thread on/off switch (off by default) for simulation idle-timeout behavior;
/// its effect is defined by the code that reads it.
pub struct SimulationIdleTimeout;
impl SimulationIdleTimeout {
    /// Turns the switch on for the current thread.
    pub fn enable() {
        SIMULATION_IDLE_TIMEOUT.set(true);
    }

    /// Turns the switch off for the current thread.
    pub fn disable() {
        SIMULATION_IDLE_TIMEOUT.set(false);
    }

    /// Current state of the switch on this thread.
    pub fn is_enabled() -> bool {
        SIMULATION_IDLE_TIMEOUT.get()
    }
}
std::thread_local! {
    static GLOBAL_RESYNC_REQUESTS: std::cell::Cell<u64> = const { std::cell::Cell::new(0) };
    static GLOBAL_DELTA_SENDS: std::cell::Cell<u64> = const { std::cell::Cell::new(0) };
    static GLOBAL_FULL_STATE_SENDS: std::cell::Cell<u64> = const { std::cell::Cell::new(0) };
    static GLOBAL_PENDING_OP_INSERTS: std::cell::Cell<u64> = const { std::cell::Cell::new(0) };
    static GLOBAL_PENDING_OP_REMOVES: std::cell::Cell<u64> = const { std::cell::Cell::new(0) };
    static GLOBAL_PENDING_OP_HWM: std::cell::Cell<u64> = const { std::cell::Cell::new(0) };
    static GLOBAL_NEIGHBOR_HOSTING_UPDATES: std::cell::Cell<u64> = const { std::cell::Cell::new(0) };
    static GLOBAL_ANTI_STARVATION_TRIGGERS: std::cell::Cell<u64> = const { std::cell::Cell::new(0) };
}

/// Per-thread counters for observing protocol behavior in tests and simulations.
pub struct GlobalTestMetrics;
impl GlobalTestMetrics {
    /// Adds one to a per-thread counter.
    fn bump(counter: &'static std::thread::LocalKey<std::cell::Cell<u64>>) {
        counter.set(counter.get() + 1);
    }

    /// Zeroes every counter on the current thread.
    pub fn reset() {
        let counters = [
            &GLOBAL_RESYNC_REQUESTS,
            &GLOBAL_DELTA_SENDS,
            &GLOBAL_FULL_STATE_SENDS,
            &GLOBAL_PENDING_OP_INSERTS,
            &GLOBAL_PENDING_OP_REMOVES,
            &GLOBAL_PENDING_OP_HWM,
            &GLOBAL_NEIGHBOR_HOSTING_UPDATES,
            &GLOBAL_ANTI_STARVATION_TRIGGERS,
        ];
        for counter in counters {
            counter.set(0);
        }
    }

    pub fn record_resync_request() {
        Self::bump(&GLOBAL_RESYNC_REQUESTS);
    }

    pub fn resync_requests() -> u64 {
        GLOBAL_RESYNC_REQUESTS.get()
    }

    pub fn record_delta_send() {
        Self::bump(&GLOBAL_DELTA_SENDS);
    }

    pub fn delta_sends() -> u64 {
        GLOBAL_DELTA_SENDS.get()
    }

    pub fn record_full_state_send() {
        Self::bump(&GLOBAL_FULL_STATE_SENDS);
    }

    pub fn full_state_sends() -> u64 {
        GLOBAL_FULL_STATE_SENDS.get()
    }

    pub fn record_pending_op_insert() {
        Self::bump(&GLOBAL_PENDING_OP_INSERTS);
    }

    pub fn pending_op_inserts() -> u64 {
        GLOBAL_PENDING_OP_INSERTS.get()
    }

    pub fn record_pending_op_remove() {
        Self::bump(&GLOBAL_PENDING_OP_REMOVES);
    }

    pub fn pending_op_removes() -> u64 {
        GLOBAL_PENDING_OP_REMOVES.get()
    }

    /// Raises the pending-op high-water mark to `len` if `len` exceeds it.
    pub fn record_pending_op_size(len: u64) {
        let highest = GLOBAL_PENDING_OP_HWM.get().max(len);
        GLOBAL_PENDING_OP_HWM.set(highest);
    }

    pub fn pending_op_high_water_mark() -> u64 {
        GLOBAL_PENDING_OP_HWM.get()
    }

    pub fn record_neighbor_hosting_update() {
        Self::bump(&GLOBAL_NEIGHBOR_HOSTING_UPDATES);
    }

    pub fn neighbor_hosting_updates() -> u64 {
        GLOBAL_NEIGHBOR_HOSTING_UPDATES.get()
    }

    pub fn record_anti_starvation_trigger() {
        Self::bump(&GLOBAL_ANTI_STARVATION_TRIGGERS);
    }

    pub fn anti_starvation_triggers() -> u64 {
        GLOBAL_ANTI_STARVATION_TRIGGERS.get()
    }
}
/// Installs the global tracing subscriber.
///
/// Only the first call in the process has any effect; later calls return immediately.
/// Without the `trace` feature this is a no-op.
///
/// # Panics
/// Panics if `init_tracer` fails during the first call.
pub fn set_logger(
    level: Option<tracing::level_filters::LevelFilter>,
    endpoint: Option<String>,
    log_dir: Option<&Path>,
) {
    #[cfg(feature = "trace")]
    {
        static LOGGER_SET: AtomicBool = AtomicBool::new(false);
        // `swap` returns the previous value: only the very first caller sees `false`
        // and goes on to install the tracer.
        if LOGGER_SET.swap(true, std::sync::atomic::Ordering::SeqCst) {
            return;
        }
        crate::tracing::tracer::init_tracer(level, endpoint, log_dir)
            .expect("failed tracing initialization");
    }
    // Without the `trace` feature the arguments are intentionally ignored; consume
    // them so such builds do not emit unused-variable warnings.
    #[cfg(not(feature = "trace"))]
    {
        let _ = (level, endpoint, log_dir);
    }
}
async fn load_gateways_from_index(url: &str, pub_keys_dir: &Path) -> anyhow::Result<Gateways> {
let client = reqwest::Client::builder()
.connect_timeout(std::time::Duration::from_secs(10))
.timeout(std::time::Duration::from_secs(30))
.build()?;
let response = client
.get(url)
.send()
.await?
.error_for_status()?
.text()
.await?;
let mut gateways: Gateways = toml::from_str(&response)?;
let mut base_url = reqwest::Url::parse(url)?;
base_url.set_path("");
let mut valid_gateways = Vec::new();
for gateway in &mut gateways.gateways {
gateway.location = None; let public_key_url = base_url.join(&gateway.public_key_path.to_string_lossy())?;
let public_key_response = client
.get(public_key_url)
.send()
.await?
.error_for_status()?;
let file_name = gateway
.public_key_path
.file_name()
.ok_or_else(|| anyhow::anyhow!("Invalid public key path"))?;
let local_path = pub_keys_dir.join(file_name);
let mut public_key_file = File::create(&local_path)?;
let content = public_key_response.bytes().await?;
std::io::copy(&mut content.as_ref(), &mut public_key_file)?;
let mut key_file = File::open(&local_path).with_context(|| {
format!(
"failed loading gateway pubkey from {:?}",
gateway.public_key_path
)
})?;
let mut buf = String::new();
key_file.read_to_string(&mut buf)?;
let buf = buf.trim();
if buf.starts_with("-----BEGIN") {
tracing::warn!(
public_key_path = ?gateway.public_key_path,
"Gateway uses legacy RSA PEM public key format. \
Gateway needs to be updated to X25519 format. Skipping."
);
continue;
}
if let Ok(key_bytes) = hex::decode(buf) {
if key_bytes.len() == 32 {
gateway.public_key_path = local_path;
valid_gateways.push(gateway.clone());
} else {
tracing::warn!(
public_key_path = ?gateway.public_key_path,
"Invalid public key length {} (expected 32), ignoring",
key_bytes.len()
);
}
} else {
tracing::warn!(
public_key_path = ?gateway.public_key_path,
"Invalid public key hex encoding in remote gateway file, ignoring"
);
}
}
gateways.gateways = valid_gateways;
Ok(gateways)
}
#[cfg(test)]
mod tests {
use httptest::{Expectation, Server, matchers::*, responders::*};
use crate::node::NodeConfig;
use crate::transport::TransportKeypair;
use super::*;
#[tokio::test]
async fn test_serde_config_args() {
let temp_dir = tempfile::tempdir().unwrap();
let args = ConfigArgs {
mode: Some(OperationMode::Local),
config_paths: ConfigPathsArgs {
config_dir: Some(temp_dir.path().to_path_buf()),
data_dir: Some(temp_dir.path().to_path_buf()),
log_dir: Some(temp_dir.path().to_path_buf()),
},
..Default::default()
};
let cfg = args.build().await.unwrap();
let serialized = toml::to_string(&cfg).unwrap();
let _: Config = toml::from_str(&serialized).unwrap();
}
async fn build_with_cidrs(cidrs: Option<Vec<String>>) -> anyhow::Result<Config> {
let temp_dir = tempfile::tempdir().unwrap();
let args = ConfigArgs {
mode: Some(OperationMode::Local),
config_paths: ConfigPathsArgs {
config_dir: Some(temp_dir.path().to_path_buf()),
data_dir: Some(temp_dir.path().to_path_buf()),
log_dir: Some(temp_dir.path().to_path_buf()),
},
ws_api: WebsocketApiArgs {
allowed_source_cidrs: cidrs,
..Default::default()
},
..Default::default()
};
args.build().await
}
#[tokio::test]
async fn allowed_source_cidrs_round_trip_through_build() {
let cfg = build_with_cidrs(Some(vec![
"100.64.0.0/10".to_string(),
"fd7a:115c:a1e0::/48".to_string(),
]))
.await
.unwrap();
assert_eq!(cfg.ws_api.allowed_source_cidrs.len(), 2);
assert_eq!(
cfg.ws_api.allowed_source_cidrs[0],
"100.64.0.0/10".parse::<ipnet::IpNet>().unwrap()
);
assert_eq!(
cfg.ws_api.allowed_source_cidrs[1],
"fd7a:115c:a1e0::/48".parse::<ipnet::IpNet>().unwrap()
);
}
#[tokio::test]
async fn allowed_source_cidrs_default_is_empty() {
let cfg = build_with_cidrs(None).await.unwrap();
assert!(cfg.ws_api.allowed_source_cidrs.is_empty());
}
#[tokio::test]
async fn allowed_source_cidrs_rejects_malformed() {
let err = build_with_cidrs(Some(vec!["not-a-cidr".to_string()]))
.await
.unwrap_err();
let msg = format!("{err:#}");
assert!(
msg.contains("allowed-source-cidrs") && msg.contains("not-a-cidr"),
"error should name the field and the offending value: {msg}"
);
}
#[tokio::test]
async fn allowed_source_cidrs_rejects_whole_internet_catchall() {
let err = build_with_cidrs(Some(vec!["0.0.0.0/0".to_string()]))
.await
.unwrap_err();
let msg = format!("{err:#}");
assert!(
msg.contains("0.0.0.0/0") && msg.contains("/8"),
"error should explain why and name the minimum: {msg}"
);
}
#[tokio::test]
async fn allowed_source_cidrs_rejects_ipv6_catchall() {
let err = build_with_cidrs(Some(vec!["::/0".to_string()]))
.await
.unwrap_err();
let msg = format!("{err:#}");
assert!(msg.contains("::/0") && msg.contains("/16"));
}
async fn write_config_toml_with_ws_api(dir: &Path, ws_api_patch: &WebsocketApiConfig) {
let base_args = ConfigArgs {
mode: Some(OperationMode::Local),
config_paths: ConfigPathsArgs {
config_dir: Some(dir.to_path_buf()),
data_dir: Some(dir.to_path_buf()),
log_dir: Some(dir.to_path_buf()),
},
..Default::default()
};
let mut base_cfg = base_args.build().await.unwrap();
base_cfg.ws_api = ws_api_patch.clone();
let toml_str = toml::to_string(&base_cfg).unwrap();
std::fs::write(dir.join("config.toml"), toml_str).unwrap();
}
#[tokio::test]
async fn file_config_cidrs_merged_into_build() {
let temp_dir = tempfile::tempdir().unwrap();
write_config_toml_with_ws_api(
temp_dir.path(),
&WebsocketApiConfig {
allowed_source_cidrs: vec![
"100.64.0.0/10".parse().unwrap(),
"fd7a:115c:a1e0::/48".parse().unwrap(),
],
allowed_hosts: vec!["my-tailscale-host".to_string()],
..Default::default()
},
)
.await;
let args = ConfigArgs {
mode: Some(OperationMode::Local),
config_paths: ConfigPathsArgs {
config_dir: Some(temp_dir.path().to_path_buf()),
data_dir: Some(temp_dir.path().to_path_buf()),
log_dir: Some(temp_dir.path().to_path_buf()),
},
..Default::default()
};
let cfg = args.build().await.unwrap();
assert_eq!(
cfg.ws_api.allowed_source_cidrs.len(),
2,
"CIDRs from config.toml must be present in built config"
);
assert_eq!(
cfg.ws_api.allowed_source_cidrs[0],
"100.64.0.0/10".parse::<ipnet::IpNet>().unwrap()
);
assert_eq!(
cfg.ws_api.allowed_source_cidrs[1],
"fd7a:115c:a1e0::/48".parse::<ipnet::IpNet>().unwrap()
);
assert_eq!(
cfg.ws_api.allowed_hosts,
vec!["my-tailscale-host".to_string()],
"allowed-host from config.toml must be present in built config"
);
}
#[tokio::test]
async fn cli_cidrs_override_file_config() {
let temp_dir = tempfile::tempdir().unwrap();
write_config_toml_with_ws_api(
temp_dir.path(),
&WebsocketApiConfig {
allowed_source_cidrs: vec!["10.0.0.0/8".parse().unwrap()],
allowed_hosts: vec!["file-host".to_string()],
..Default::default()
},
)
.await;
let args = ConfigArgs {
mode: Some(OperationMode::Local),
config_paths: ConfigPathsArgs {
config_dir: Some(temp_dir.path().to_path_buf()),
data_dir: Some(temp_dir.path().to_path_buf()),
log_dir: Some(temp_dir.path().to_path_buf()),
},
ws_api: WebsocketApiArgs {
allowed_source_cidrs: Some(vec!["172.16.0.0/12".to_string()]),
allowed_host: Some(vec!["cli-host".to_string()]),
..Default::default()
},
..Default::default()
};
let cfg = args.build().await.unwrap();
assert_eq!(cfg.ws_api.allowed_source_cidrs.len(), 1);
assert_eq!(
cfg.ws_api.allowed_source_cidrs[0],
"172.16.0.0/12".parse::<ipnet::IpNet>().unwrap(),
"CLI value must win over file config"
);
assert_eq!(
cfg.ws_api.allowed_hosts,
vec!["cli-host".to_string()],
"CLI value must win over file config"
);
}
#[tokio::test]
async fn test_load_gateways_from_index() {
let server = Server::run();
server.expect(
Expectation::matching(all_of!(request::method("GET"), request::path("/gateways")))
.respond_with(status_code(200).body(
r#"
[[gateways]]
address = { hostname = "example.com" }
public_key = "/path/to/public_key.pem"
"#,
)),
);
let url = server.url_str("/gateways");
let keypair = TransportKeypair::new();
let key_hex = hex::encode(keypair.public().as_bytes());
server.expect(
Expectation::matching(request::path("/path/to/public_key.pem"))
.respond_with(status_code(200).body(key_hex)),
);
let pub_keys_dir = tempfile::tempdir().unwrap();
let gateways = load_gateways_from_index(&url, pub_keys_dir.path())
.await
.unwrap();
assert_eq!(gateways.gateways.len(), 1);
assert_eq!(
gateways.gateways[0].address,
Address::Hostname("example.com".to_string())
);
assert_eq!(
gateways.gateways[0].public_key_path,
pub_keys_dir.path().join("public_key.pem")
);
assert!(pub_keys_dir.path().join("public_key.pem").exists());
}
#[test]
fn test_gateways() {
let gateways = Gateways {
gateways: vec![
GatewayConfig {
address: Address::HostAddress(
([127, 0, 0, 1], default_network_api_port()).into(),
),
public_key_path: PathBuf::from("path/to/key"),
location: None,
},
GatewayConfig {
address: Address::Hostname("technic.locut.us".to_string()),
public_key_path: PathBuf::from("path/to/key"),
location: None,
},
],
};
let serialized = toml::to_string(&gateways).unwrap();
let _: Gateways = toml::from_str(&serialized).unwrap();
}
#[tokio::test]
#[ignore = "Requires gateway keys to be updated to X25519 format (issue #2531)"]
async fn test_remote_freenet_gateways() {
let tmp_dir = tempfile::tempdir().unwrap();
let gateways = load_gateways_from_index(FREENET_GATEWAYS_INDEX, tmp_dir.path())
.await
.unwrap();
assert!(!gateways.gateways.is_empty());
for gw in gateways.gateways {
assert!(gw.public_key_path.exists());
let key_contents = std::fs::read_to_string(&gw.public_key_path).unwrap();
let key_bytes =
hex::decode(key_contents.trim()).expect("Gateway public key should be valid hex");
assert_eq!(
key_bytes.len(),
32,
"Gateway public key should be 32 bytes (X25519)"
);
let socket = NodeConfig::parse_socket_addr(&gw.address).await.unwrap();
assert!(socket.port() > 1024); }
}
#[test]
fn test_streaming_config_defaults_via_serde() {
let minimal_config = r#"
network-address = "127.0.0.1"
network-port = 8080
"#;
let network_api: NetworkApiConfig = toml::from_str(minimal_config).unwrap();
assert_eq!(
network_api.streaming_threshold,
64 * 1024,
"Default streaming threshold should be 64KB"
);
}
#[test]
fn test_streaming_config_serde() {
let config_str = r#"
network-address = "127.0.0.1"
network-port = 8080
streaming-threshold = 131072
"#;
let config: NetworkApiConfig = toml::from_str(config_str).unwrap();
assert_eq!(config.streaming_threshold, 128 * 1024);
let serialized = toml::to_string(&config).unwrap();
assert!(serialized.contains("streaming-threshold = 131072"));
}
#[test]
fn test_network_args_streaming_defaults() {
let args = NetworkArgs::default();
assert!(
args.streaming_threshold.is_none(),
"NetworkArgs.streaming_threshold should be None by default"
);
}
#[test]
fn test_congestion_control_config_defaults() {
let config_str = r#"
network-address = "127.0.0.1"
network-port = 8080
"#;
let network_api: NetworkApiConfig = toml::from_str(config_str).unwrap();
assert_eq!(
network_api.congestion_control, "fixedrate",
"Default congestion control should be fixedrate"
);
assert!(
network_api.bbr_startup_rate.is_none(),
"Default BBR startup rate should be None"
);
let cc_config = network_api.build_congestion_config();
assert_eq!(cc_config.algorithm, CongestionControlAlgorithm::FixedRate);
}
#[test]
fn test_congestion_control_config_bbr() {
let config_str = r#"
network-address = "127.0.0.1"
network-port = 8080
congestion-control = "bbr"
bbr-startup-rate = 10000000
"#;
let config: NetworkApiConfig = toml::from_str(config_str).unwrap();
assert_eq!(config.congestion_control, "bbr");
assert_eq!(config.bbr_startup_rate, Some(10_000_000));
let cc_config = config.build_congestion_config();
assert_eq!(cc_config.algorithm, CongestionControlAlgorithm::Bbr);
}
#[test]
fn test_congestion_control_config_ledbat() {
let config_str = r#"
network-address = "127.0.0.1"
network-port = 8080
congestion-control = "ledbat"
"#;
let config: NetworkApiConfig = toml::from_str(config_str).unwrap();
assert_eq!(config.congestion_control, "ledbat");
let cc_config = config.build_congestion_config();
assert_eq!(cc_config.algorithm, CongestionControlAlgorithm::Ledbat);
}
#[test]
fn test_congestion_control_config_serde_roundtrip() {
let config_str = r#"
network-address = "127.0.0.1"
network-port = 8080
congestion-control = "bbr"
bbr-startup-rate = 5000000
"#;
let config: NetworkApiConfig = toml::from_str(config_str).unwrap();
let serialized = toml::to_string(&config).unwrap();
assert!(serialized.contains("congestion-control = \"bbr\""));
assert!(serialized.contains("bbr-startup-rate = 5000000"));
let config2: NetworkApiConfig = toml::from_str(&serialized).unwrap();
assert_eq!(config2.congestion_control, "bbr");
assert_eq!(config2.bbr_startup_rate, Some(5_000_000));
}
#[test]
fn test_set_seed_pins_thread_index_to_zero() {
GlobalRng::clear_seed();
GlobalRng::set_seed(0xDEAD_BEEF);
assert_eq!(GlobalRng::thread_index(), 0);
let val1 = GlobalRng::random_u64();
GlobalRng::set_seed(0xDEAD_BEEF);
let val2 = GlobalRng::random_u64();
assert_eq!(val1, val2);
GlobalRng::clear_seed();
}
#[tokio::test]
async fn test_config_build_with_gateway_flag() {
let keypair = TransportKeypair::new();
let key_hex = hex::encode(keypair.public().as_bytes());
let temp_dir = tempfile::tempdir().unwrap();
let args = ConfigArgs {
mode: Some(OperationMode::Local),
config_paths: ConfigPathsArgs {
config_dir: Some(temp_dir.path().to_path_buf()),
data_dir: Some(temp_dir.path().to_path_buf()),
log_dir: Some(temp_dir.path().to_path_buf()),
},
network_api: NetworkArgs {
gateway: Some(vec![format!("192.168.1.1:31337,{key_hex}")]),
..Default::default()
},
..Default::default()
};
let cfg = args.build().await.unwrap();
assert_eq!(cfg.gateways.len(), 1);
assert_eq!(
cfg.gateways[0].address,
Address::HostAddress("192.168.1.1:31337".parse().unwrap())
);
}
#[test]
fn test_parse_gateway_valid() {
    // A well-formed `addr,hexkey` string yields a gateway whose key is written to disk.
    let encoded_key = hex::encode(TransportKeypair::new().public().as_bytes());
    let spec = format!("192.168.1.1:31337,{encoded_key}");
    let key_dir = tempfile::tempdir().unwrap();

    let parsed = parse_gateway(&spec, key_dir.path()).unwrap();

    let expected_addr: SocketAddr = "192.168.1.1:31337".parse().unwrap();
    assert_eq!(parsed.address, Address::HostAddress(expected_addr));

    // The file at the returned path holds exactly the hex key that was passed in.
    let key_path = &parsed.public_key_path;
    assert!(key_path.exists());
    assert_eq!(std::fs::read_to_string(key_path).unwrap(), encoded_key);

    // The CLI form carries no location, so it stays unset.
    assert_eq!(parsed.location, None);
}
#[test]
fn test_parse_gateway_invalid_format() {
    let scratch = tempfile::tempdir().unwrap();
    let dir = scratch.path();
    let full_length_key = "ab".repeat(32);

    let rejected = [
        // Missing the `,hexkey` suffix entirely.
        "192.168.1.1:31337".to_string(),
        // Key portion contains non-hex characters.
        "192.168.1.1:31337,not_hex_at_all!".to_string(),
        // Hex key of only 16 bytes (32 hex chars).
        format!("192.168.1.1:31337,{}", "ab".repeat(16)),
        // 64-hex-char key, but the address part does not parse.
        format!("not_an_addr,{full_length_key}"),
    ];

    for input in &rejected {
        assert!(
            parse_gateway(input, dir).is_err(),
            "expected parse_gateway to reject {input:?}"
        );
    }
}
#[test]
fn test_gateway_deduplication() {
    // When the CLI and the file-loaded set name the same address, the CLI entry is kept.
    let encoded_key = hex::encode(TransportKeypair::new().public().as_bytes());
    let key_dir = tempfile::tempdir().unwrap();
    let shared_addr: SocketAddr = "10.0.0.1:31337".parse().unwrap();

    let from_file = Gateways {
        gateways: vec![GatewayConfig {
            address: Address::HostAddress(shared_addr),
            public_key_path: PathBuf::from("old/key/path"),
            location: None,
        }],
    };

    let from_cli =
        parse_gateway(&format!("{shared_addr},{encoded_key}"), key_dir.path()).unwrap();
    let expected_key_path = from_cli.public_key_path.clone();

    let mut merged = Gateways {
        gateways: vec![from_cli],
    };
    merged.merge_and_deduplicate(from_file);

    assert_eq!(merged.gateways.len(), 1);
    assert_eq!(merged.gateways[0].public_key_path, expected_key_path);
}
#[tokio::test]
async fn test_config_build_network_mode_gateway_only() {
    // In Network mode with `skip_load_from_network` set, the `--gateway` flag is the only
    // source of gateways and must come through unchanged.
    let encoded_key = hex::encode(TransportKeypair::new().public().as_bytes());
    let scratch = tempfile::tempdir().unwrap();
    let dir = scratch.path().to_path_buf();

    let config_paths = ConfigPathsArgs {
        config_dir: Some(dir.clone()),
        data_dir: Some(dir.clone()),
        log_dir: Some(dir),
    };
    let network_api = NetworkArgs {
        gateway: Some(vec![format!("203.0.113.1:31337,{encoded_key}")]),
        skip_load_from_network: true,
        ..Default::default()
    };
    let args = ConfigArgs {
        mode: Some(OperationMode::Network),
        config_paths,
        network_api,
        ..Default::default()
    };

    let cfg = args.build().await.unwrap();
    assert_eq!(cfg.gateways.len(), 1);
    let expected = Address::HostAddress("203.0.113.1:31337".parse().unwrap());
    assert_eq!(cfg.gateways[0].address, expected);
}
#[tokio::test]
async fn test_config_build_multiple_gateways() {
    // Several distinct `--gateway` flags must all be present after building the config.
    let hosts = ["10.0.0.1:31337", "10.0.0.2:31337", "10.0.0.3:31337"];
    let gateway_flags: Vec<String> = hosts
        .iter()
        .map(|host| {
            let encoded_key = hex::encode(TransportKeypair::new().public().as_bytes());
            format!("{host},{encoded_key}")
        })
        .collect();

    let scratch = tempfile::tempdir().unwrap();
    let dir = scratch.path().to_path_buf();
    let args = ConfigArgs {
        mode: Some(OperationMode::Local),
        config_paths: ConfigPathsArgs {
            config_dir: Some(dir.clone()),
            data_dir: Some(dir.clone()),
            log_dir: Some(dir),
        },
        network_api: NetworkArgs {
            gateway: Some(gateway_flags),
            ..Default::default()
        },
        ..Default::default()
    };

    let cfg = args.build().await.unwrap();
    assert_eq!(cfg.gateways.len(), hosts.len());

    // Order is not asserted; only membership of every requested address.
    let addrs: Vec<_> = cfg.gateways.iter().map(|g| g.address.clone()).collect();
    for host in hosts {
        let expected = Address::HostAddress(host.parse().unwrap());
        assert!(addrs.contains(&expected));
    }
}
/// A gateway given on the CLI replaces a file-loaded entry for the same address.
///
/// The file-loaded entry points at a stale key path. After merging, only the CLI entry (with
/// its freshly written key file) may remain. The test is fully synchronous, so it uses a plain
/// `#[test]` instead of spinning up a Tokio runtime.
#[test]
fn test_gateway_overrides_file_loaded() {
    let keypair = TransportKeypair::new();
    let key_hex = hex::encode(keypair.public().as_bytes());
    let tmp_dir = tempfile::tempdir().unwrap();
    let addr: SocketAddr = "10.0.0.1:31337".parse().unwrap();
    let stale_key_path = PathBuf::from("old/stale/key.pub");

    let file_gateways = Gateways {
        gateways: vec![GatewayConfig {
            address: Address::HostAddress(addr),
            public_key_path: stale_key_path.clone(),
            location: None,
        }],
    };

    let gw = parse_gateway(&format!("{addr},{key_hex}"), tmp_dir.path()).unwrap();
    let cli_key_path = gw.public_key_path.clone();

    // Merge the file-loaded set into the CLI set; the CLI entry must win the duplicate.
    let mut merged = Gateways { gateways: vec![gw] };
    merged.merge_and_deduplicate(file_gateways);

    assert_eq!(merged.gateways.len(), 1);
    assert_eq!(merged.gateways[0].address, Address::HostAddress(addr));
    assert_eq!(merged.gateways[0].public_key_path, cli_key_path);
    assert_ne!(merged.gateways[0].public_key_path, stale_key_path);
}
#[tokio::test]
async fn test_config_build_network_mode_empty_gateway() {
    // An explicitly empty `--gateway` list in Network mode, with `skip_load_from_network`
    // set, must make `build` fail with the "no gateways" error.
    let scratch = tempfile::tempdir().unwrap();
    let dir = scratch.path().to_path_buf();
    let args = ConfigArgs {
        mode: Some(OperationMode::Network),
        config_paths: ConfigPathsArgs {
            config_dir: Some(dir.clone()),
            data_dir: Some(dir.clone()),
            log_dir: Some(dir),
        },
        network_api: NetworkArgs {
            gateway: Some(Vec::new()),
            skip_load_from_network: true,
            ..Default::default()
        },
        ..Default::default()
    };

    const EXPECTED: &str = "Cannot initialize node without gateways";
    let err = args.build().await.unwrap_err();
    let message = err.to_string();
    assert!(message.contains(EXPECTED), "Expected '{EXPECTED}', got: {err}");
}
#[tokio::test]
async fn test_config_build_invalid_gateway_error() {
    // A malformed `--gateway` value must abort `build` with a parse error naming the flag.
    let scratch = tempfile::tempdir().unwrap();
    let dir = scratch.path().to_path_buf();
    let network_api = NetworkArgs {
        gateway: Some(vec![String::from("not-valid")]),
        ..Default::default()
    };
    let args = ConfigArgs {
        mode: Some(OperationMode::Local),
        config_paths: ConfigPathsArgs {
            config_dir: Some(dir.clone()),
            data_dir: Some(dir.clone()),
            log_dir: Some(dir),
        },
        network_api,
        ..Default::default()
    };

    const EXPECTED: &str = "Failed to parse --gateway";
    let err = args.build().await.unwrap_err();
    let message = err.to_string();
    assert!(message.contains(EXPECTED), "Expected '{EXPECTED}', got: {err}");
}
#[tokio::test]
async fn test_config_build_duplicate_gateway_entries() {
    // Passing the identical `--gateway` entry twice must collapse to a single gateway.
    let encoded_key = hex::encode(TransportKeypair::new().public().as_bytes());
    let entry = format!("10.0.0.1:31337,{encoded_key}");
    let scratch = tempfile::tempdir().unwrap();
    let dir = scratch.path().to_path_buf();
    let args = ConfigArgs {
        mode: Some(OperationMode::Local),
        config_paths: ConfigPathsArgs {
            config_dir: Some(dir.clone()),
            data_dir: Some(dir.clone()),
            log_dir: Some(dir),
        },
        network_api: NetworkArgs {
            gateway: Some(vec![entry; 2]),
            ..Default::default()
        },
        ..Default::default()
    };

    let cfg = args.build().await.unwrap();
    assert_eq!(cfg.gateways.len(), 1);
}
#[test]
fn test_parse_gateway_key_file_permissions() {
    // The key file written by `parse_gateway` must exist and, on Unix, be owner-only (0600).
    let encoded_key = hex::encode(TransportKeypair::new().public().as_bytes());
    let key_dir = tempfile::tempdir().unwrap();
    let spec = format!("192.168.1.1:31337,{encoded_key}");
    let parsed = parse_gateway(&spec, key_dir.path()).unwrap();
    let key_path = &parsed.public_key_path;
    assert!(key_path.exists());

    // Permission bits are only checked on Unix; elsewhere existence is all that is verified.
    #[cfg(unix)]
    {
        use std::os::unix::fs::PermissionsExt;
        let perm_bits = std::fs::metadata(key_path).unwrap().permissions().mode() & 0o777;
        assert_eq!(perm_bits, 0o600, "Key file should have 0600 permissions");
    }
}
}