use std::{
    collections::HashSet,
    net::{Ipv6Addr, SocketAddr},
    path::{Path, PathBuf},
    sync::Arc,
};
use anyhow::{anyhow, bail, Context as _, Result};
use clap::Parser;
use http::StatusCode;
use iroh_base::NodeId;
use iroh_relay::{
    defaults::{
        DEFAULT_HTTPS_PORT, DEFAULT_HTTP_PORT, DEFAULT_METRICS_PORT, DEFAULT_RELAY_QUIC_PORT,
        DEFAULT_STUN_PORT,
    },
    server::{self as relay, ClientRateLimit, QuicConfig},
};
use n0_future::FutureExt;
use serde::{Deserialize, Serialize};
use tokio_rustls_acme::{caches::DirCache, AcmeConfig};
use tracing::{debug, warn};
use tracing_subscriber::{prelude::*, EnvFilter};
use url::Url;
/// HTTP port used in `--dev` mode when no `http_bind_addr` is configured.
const DEV_MODE_HTTP_PORT: u16 = 3340;
/// Request header that carries the connecting node's id in HTTP access checks.
const X_IROH_NODE_ID: &str = "X-Iroh-NodeId";
/// Environment variable whose value, when set, overrides the configured bearer token of the
/// HTTP access check.
const ENV_HTTP_BEARER_TOKEN: &str = "IROH_RELAY_HTTP_BEARER_TOKEN";
// Command line options of the relay server binary.
// Plain `//` comments are used on purpose: `///` doc comments would become clap help text.
#[derive(Parser, Debug, Clone)]
#[clap(version, about, long_about = None)]
struct Cli {
    // Development mode: forces `tls.dangerous_http_only = true` (if a `[tls]` section exists)
    // and, unless `http_bind_addr` is configured, binds HTTP to `[::]:DEV_MODE_HTTP_PORT`.
    #[clap(long, default_value_t = false)]
    dev: bool,
    // Path to the TOML config file. When absent, the default configuration is used.
    #[clap(long, short)]
    config_path: Option<PathBuf>,
}
// How the relay obtains its TLS certificate (consumed by `maybe_load_tls`).
// There is no serde rename attribute, so config files spell variants exactly as written here
// (e.g. `cert_mode = "LetsEncrypt"`).
// `//` comments are used because `///` on a `clap::ValueEnum` variant becomes help text.
#[derive(clap::ValueEnum, Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
enum CertMode {
    // Load the certificate chain and key from `TlsConfig::cert_path` / `key_path` at startup.
    Manual,
    // Obtain certificates via ACME (Let's Encrypt) for `TlsConfig::hostname`.
    LetsEncrypt,
    // Read certificate and key from the same paths as `Manual`, via a reloading resolver.
    // Only available with the `server` feature.
    #[cfg(feature = "server")]
    Reloading,
}
fn load_certs(
filename: impl AsRef<Path>,
) -> Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
let certfile = std::fs::File::open(filename).context("cannot open certificate file")?;
let mut reader = std::io::BufReader::new(certfile);
let certs: Result<Vec<_>, std::io::Error> = rustls_pemfile::certs(&mut reader).collect();
let certs = certs?;
Ok(certs)
}
/// Returns the first unencrypted private key (PKCS#1, PKCS#8 or SEC1) found in the PEM file
/// at `filename`.
///
/// PEM items of any other kind are skipped.
///
/// # Errors
///
/// Fails if the file cannot be opened or parsed, or if it contains no supported key.
fn load_secret_key(
    filename: impl AsRef<Path>,
) -> Result<rustls::pki_types::PrivateKeyDer<'static>> {
    let path = filename.as_ref();
    let file = std::fs::File::open(path)
        .with_context(|| format!("cannot open secret key file {}", path.display()))?;
    let mut pem_reader = std::io::BufReader::new(file);

    while let Some(item) =
        rustls_pemfile::read_one(&mut pem_reader).context("cannot parse secret key .pem file")?
    {
        let key = match item {
            rustls_pemfile::Item::Pkcs1Key(der) => rustls::pki_types::PrivateKeyDer::Pkcs1(der),
            rustls_pemfile::Item::Pkcs8Key(der) => rustls::pki_types::PrivateKeyDer::Pkcs8(der),
            rustls_pemfile::Item::Sec1Key(der) => rustls::pki_types::PrivateKeyDer::Sec1(der),
            // Not a supported private key: keep scanning.
            _ => continue,
        };
        return Ok(key);
    }

    bail!(
        "no keys found in {} (encrypted keys not supported)",
        path.display()
    );
}
/// Relay server configuration, deserialized from the TOML file passed via `--config-path`.
///
/// Unset values fall back to the functions in `cfg_defaults` or to the fallbacks computed by
/// the `*_bind_addr` accessor methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Config {
    /// Whether to enable the relay service. Defaults to `true`.
    #[serde(default = "cfg_defaults::enable_relay")]
    enable_relay: bool,
    /// Bind address of the relay's HTTP server. Defaults to `[::]:DEFAULT_HTTP_PORT`.
    http_bind_addr: Option<SocketAddr>,
    /// TLS settings. Required when `enable_quic_addr_discovery` is set.
    tls: Option<TlsConfig>,
    /// Whether to run the STUN server. Defaults to `true`.
    #[serde(default = "cfg_defaults::enable_stun")]
    enable_stun: bool,
    /// STUN bind address. Defaults to the IP of `http_bind_addr` with `DEFAULT_STUN_PORT`.
    stun_bind_addr: Option<SocketAddr>,
    /// Whether to run a QUIC endpoint for address discovery. Defaults to `false`.
    /// Requires `tls` to be configured.
    #[serde(default = "cfg_defaults::enable_quic_addr_discovery")]
    enable_quic_addr_discovery: bool,
    /// Optional connection and per-client rate limits.
    limits: Option<Limits>,
    /// Whether to serve metrics. Defaults to `true`. Only used when built with the `metrics`
    /// feature.
    #[serde(default = "cfg_defaults::enable_metrics")]
    enable_metrics: bool,
    /// Metrics bind address. Defaults to the IP of `http_bind_addr` with `DEFAULT_METRICS_PORT`.
    metrics_bind_addr: Option<SocketAddr>,
    /// Passed through unchanged to the relay's `key_cache_capacity`.
    /// NOTE(review): the meaning of `None` is defined by `iroh_relay`; confirm there.
    key_cache_capacity: Option<usize>,
    /// Who may connect to the relay. Defaults to everyone.
    #[serde(default)]
    access: AccessConfig,
}
/// Access policy as written in the config file.
///
/// Variants are spelled in lowercase in TOML: `"everyone"`, `access.allowlist = [...]`,
/// `access.denylist = [...]` and `access.http = { url = ..., bearer_token = ... }`.
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
enum AccessConfig {
    /// Every node may connect (the default).
    #[default]
    Everyone,
    /// Only the listed nodes may connect.
    Allowlist(Vec<NodeId>),
    /// All nodes except the listed ones may connect.
    Denylist(Vec<NodeId>),
    /// Each connection is authorized by an HTTP POST to an external service.
    Http(HttpAccessConfig),
}
/// Settings of the HTTP access check.
///
/// For every connecting node, a POST is sent to `url` with the node id in the `X-Iroh-NodeId`
/// header. Access is granted only if the response is `200 OK` with a body of exactly `true`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
struct HttpAccessConfig {
    /// Endpoint that is POSTed to for every access decision.
    url: Url,
    /// Optional token sent as `Authorization: Bearer <token>`. The
    /// `IROH_RELAY_HTTP_BEARER_TOKEN` environment variable overrides it when set.
    bearer_token: Option<String>,
}
impl From<AccessConfig> for iroh_relay::server::AccessConfig {
fn from(cfg: AccessConfig) -> Self {
match cfg {
AccessConfig::Everyone => iroh_relay::server::AccessConfig::Everyone,
AccessConfig::Allowlist(allow_list) => {
let allow_list = Arc::new(allow_list);
iroh_relay::server::AccessConfig::Restricted(Box::new(move |node_id| {
let allow_list = allow_list.clone();
async move {
if allow_list.contains(&node_id) {
iroh_relay::server::Access::Allow
} else {
iroh_relay::server::Access::Deny
}
}
.boxed()
}))
}
AccessConfig::Denylist(deny_list) => {
let deny_list = Arc::new(deny_list);
iroh_relay::server::AccessConfig::Restricted(Box::new(move |node_id| {
let deny_list = deny_list.clone();
async move {
if deny_list.contains(&node_id) {
iroh_relay::server::Access::Deny
} else {
iroh_relay::server::Access::Allow
}
}
.boxed()
}))
}
AccessConfig::Http(mut config) => {
let client = reqwest::Client::default();
if let Ok(token) = std::env::var(ENV_HTTP_BEARER_TOKEN) {
config.bearer_token = Some(token);
}
let config = Arc::new(config);
iroh_relay::server::AccessConfig::Restricted(Box::new(move |node_id| {
let client = client.clone();
let config = config.clone();
async move { http_access_check(&client, &config, node_id).await }.boxed()
}))
}
}
}
}
/// Asks the configured HTTP endpoint whether `node_id` may use the relay.
///
/// Any failure (network error, non-200 status, unexpected body) results in
/// [`iroh_relay::server::Access::Deny`]. The reason is logged at debug level.
#[tracing::instrument("http-access-check", skip_all, fields(node_id=%node_id.fmt_short()))]
async fn http_access_check(
    client: &reqwest::Client,
    config: &HttpAccessConfig,
    node_id: NodeId,
) -> iroh_relay::server::Access {
    debug!(url=%config.url, "Check relay access via HTTP POST");
    let outcome = http_access_check_inner(client, config, node_id).await;
    if let Err(err) = outcome {
        debug!("HTTP access check failed: Deny access (reason: {err:#})");
        return iroh_relay::server::Access::Deny;
    }
    debug!("HTTP access check OK: Allow access");
    iroh_relay::server::Access::Allow
}
async fn http_access_check_inner(
client: &reqwest::Client,
config: &HttpAccessConfig,
node_id: NodeId,
) -> Result<()> {
let mut request = client
.post(config.url.clone())
.header(X_IROH_NODE_ID, node_id.to_string());
if let Some(token) = config.bearer_token.as_ref() {
request = request.header(http::header::AUTHORIZATION, format!("Bearer {token}"));
}
match request.send().await {
Err(err) => {
warn!("Failed to retrieve response for HTTP access check: {err:#}");
Err(err).context("Failed to fetch response")
}
Ok(res) if res.status() == StatusCode::OK => match res.text().await {
Ok(text) if text == "true" => Ok(()),
Ok(_) => Err(anyhow!("Invalid response text (must be 'true')")),
Err(err) => Err(err).context("Failed to read response"),
},
Ok(res) => Err(anyhow!("Received invalid status code ({})", res.status())),
}
}
impl Config {
    /// Address of the relay's HTTP server: the configured value, otherwise the unspecified
    /// IPv6 address `::` on `DEFAULT_HTTP_PORT`.
    fn http_bind_addr(&self) -> SocketAddr {
        match self.http_bind_addr {
            Some(addr) => addr,
            None => SocketAddr::from((Ipv6Addr::UNSPECIFIED, DEFAULT_HTTP_PORT)),
        }
    }

    /// Address of the STUN server: the configured value, otherwise the HTTP bind IP on
    /// `DEFAULT_STUN_PORT`.
    fn stun_bind_addr(&self) -> SocketAddr {
        match self.stun_bind_addr {
            Some(addr) => addr,
            None => SocketAddr::new(self.http_bind_addr().ip(), DEFAULT_STUN_PORT),
        }
    }

    /// Address of the metrics server: the configured value, otherwise the HTTP bind IP on
    /// `DEFAULT_METRICS_PORT`.
    fn metrics_bind_addr(&self) -> SocketAddr {
        match self.metrics_bind_addr {
            Some(addr) => addr,
            None => SocketAddr::new(self.http_bind_addr().ip(), DEFAULT_METRICS_PORT),
        }
    }
}
impl Default for Config {
fn default() -> Self {
Self {
enable_relay: cfg_defaults::enable_relay(),
http_bind_addr: None,
tls: None,
enable_stun: cfg_defaults::enable_stun(),
stun_bind_addr: None,
enable_quic_addr_discovery: cfg_defaults::enable_quic_addr_discovery(),
limits: None,
enable_metrics: cfg_defaults::enable_metrics(),
metrics_bind_addr: None,
key_cache_capacity: Default::default(),
access: AccessConfig::Everyone,
}
}
}
/// Default values for boolean config fields, referenced by `#[serde(default = "...")]`.
mod cfg_defaults {
    /// The relay service is enabled by default.
    pub(crate) const fn enable_relay() -> bool {
        true
    }

    /// The STUN server is enabled by default.
    pub(crate) const fn enable_stun() -> bool {
        true
    }

    /// QUIC address discovery is opt-in, because it requires a TLS configuration.
    pub(crate) const fn enable_quic_addr_discovery() -> bool {
        false
    }

    /// Metrics are enabled by default.
    pub(crate) const fn enable_metrics() -> bool {
        true
    }

    /// Defaults for fields of the `[tls]` section.
    pub(crate) mod tls_config {
        /// `prod_tls` defaults to `true`.
        pub(crate) const fn prod_tls() -> bool {
            true
        }

        /// Plain-HTTP serving must be requested explicitly.
        pub(crate) const fn dangerous_http_only() -> bool {
            false
        }
    }
}
/// The `[tls]` section of the config file.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct TlsConfig {
    /// HTTPS bind address. Defaults to the HTTP bind IP with `DEFAULT_HTTPS_PORT`.
    https_bind_addr: Option<SocketAddr>,
    /// QUIC bind address. Defaults to the HTTPS bind IP with `DEFAULT_RELAY_QUIC_PORT`.
    quic_bind_addr: Option<SocketAddr>,
    /// Hostname to request certificates for. Required for `CertMode::LetsEncrypt`.
    hostname: Option<String>,
    /// How the certificate is obtained.
    cert_mode: CertMode,
    /// Directory used for the default certificate/key paths and the ACME cache.
    /// Defaults to `.`.
    cert_dir: Option<PathBuf>,
    /// Certificate chain file. Defaults to `<cert_dir>/default.crt`.
    manual_cert_path: Option<PathBuf>,
    /// Private key file. Defaults to `<cert_dir>/default.key`.
    manual_key_path: Option<PathBuf>,
    /// Passed to the ACME config's `directory_lets_encrypt`. Defaults to `true`.
    /// NOTE(review): presumably `false` selects the Let's Encrypt staging directory — confirm.
    #[serde(default = "cfg_defaults::tls_config::prod_tls")]
    prod_tls: bool,
    /// Contact email, sent as `mailto:<contact>`. Required for `CertMode::LetsEncrypt`.
    contact: Option<String>,
    /// When `true`, the relay is served without TLS even though TLS is configured.
    /// The TLS settings are still used for QUIC. Defaults to `false`; forced on by `--dev`.
    #[serde(default = "cfg_defaults::tls_config::dangerous_http_only")]
    dangerous_http_only: bool,
}
impl TlsConfig {
    /// HTTPS bind address: the configured one, or the HTTP bind IP on `DEFAULT_HTTPS_PORT`.
    fn https_bind_addr(&self, cfg: &Config) -> SocketAddr {
        match self.https_bind_addr {
            Some(addr) => addr,
            None => SocketAddr::new(cfg.http_bind_addr().ip(), DEFAULT_HTTPS_PORT),
        }
    }

    /// QUIC bind address: the configured one, or the HTTPS bind IP on
    /// `DEFAULT_RELAY_QUIC_PORT`.
    fn quic_bind_addr(&self, cfg: &Config) -> SocketAddr {
        if let Some(addr) = self.quic_bind_addr {
            return addr;
        }
        let https_ip = self.https_bind_addr(cfg).ip();
        SocketAddr::new(https_ip, DEFAULT_RELAY_QUIC_PORT)
    }

    /// Certificate directory, defaulting to the current directory (`.`).
    fn cert_dir(&self) -> PathBuf {
        match &self.cert_dir {
            Some(dir) => dir.clone(),
            None => PathBuf::from("."),
        }
    }

    /// Certificate chain path, defaulting to `<cert_dir>/default.crt`.
    fn cert_path(&self) -> PathBuf {
        match &self.manual_cert_path {
            Some(path) => path.clone(),
            None => self.cert_dir().join("default.crt"),
        }
    }

    /// Private key path, defaulting to `<cert_dir>/default.key`.
    fn key_path(&self) -> PathBuf {
        match &self.manual_key_path {
            Some(path) => path.clone(),
            None => self.cert_dir().join("default.key"),
        }
    }
}
/// The `[limits]` section of the config file.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
struct Limits {
    /// Passed through unchanged to `relay::Limits::accept_conn_limit`.
    accept_conn_limit: Option<f64>,
    /// Passed through unchanged to `relay::Limits::accept_conn_burst`.
    accept_conn_burst: Option<usize>,
    /// Per-client rate limits (`[limits.client]`).
    client: Option<PerClientRateLimitConfig>,
}
/// Per-client rate limits.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
struct PerClientRateLimitConfig {
    /// Limit applied to data received from each client (`[limits.client.rx]`).
    rx: Option<RateLimitConfig>,
}
/// A byte-rate limit.
///
/// `bytes_per_second` must be non-zero. `max_burst_bytes`, if set, must also be non-zero and
/// is only valid together with `bytes_per_second` (checked in `build_relay_config`).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
struct RateLimitConfig {
    bytes_per_second: Option<u32>,
    max_burst_bytes: Option<u32>,
}
impl Config {
async fn load(opts: &Cli) -> Result<Self> {
let config_path = if let Some(config_path) = &opts.config_path {
config_path
} else {
return Ok(Config::default());
};
if config_path.exists() {
Self::read_from_file(&config_path).await
} else {
Ok(Config::default())
}
}
fn from_str(config: &str) -> Result<Self> {
toml::from_str(config).context("config must be valid toml")
}
async fn read_from_file(path: impl AsRef<Path>) -> Result<Self> {
if !path.as_ref().is_file() {
bail!("config-path must be a file");
}
let config_ser = tokio::fs::read_to_string(&path)
.await
.context("unable to read config")?;
Self::from_str(&config_ser)
}
}
/// Entry point of the relay server binary.
///
/// Sets up `tracing` (filtered via the standard `RUST_LOG`-style env filter and written to
/// stderr), loads the configuration, applies `--dev` overrides, and runs the relay until
/// Ctrl-C or until the server task finishes. Then it shuts the server down.
#[tokio::main]
async fn main() -> Result<()> {
    tracing_subscriber::registry()
        .with(tracing_subscriber::fmt::layer().with_writer(std::io::stderr))
        .with(EnvFilter::from_default_env())
        .init();

    let cli = Cli::parse();
    let mut cfg = Config::load(&cli).await?;

    // QUIC address discovery spawns a QUIC endpoint, which needs the TLS settings. The dev-mode
    // overrides below never add or remove `cfg.tls`, so this single check up front suffices.
    if cfg.enable_quic_addr_discovery && cfg.tls.is_none() {
        bail!("TLS must be configured in order to spawn a QUIC endpoint");
    }

    if cli.dev {
        if let Some(tls) = cfg.tls.as_mut() {
            tls.dangerous_http_only = true;
        }
        if cfg.http_bind_addr.is_none() {
            cfg.http_bind_addr = Some((Ipv6Addr::UNSPECIFIED, DEV_MODE_HTTP_PORT).into());
        }
    }

    let relay_config = build_relay_config(cfg).await?;
    debug!("{relay_config:#?}");

    let mut relay = relay::Server::spawn(relay_config).await?;
    tokio::select! {
        biased;
        _ = tokio::signal::ctrl_c() => (),
        _ = relay.task_handle() => (),
    }
    relay.shutdown().await
}
/// Builds the relay's TLS configuration from `cfg.tls`, or returns `Ok(None)` when no `[tls]`
/// section is configured.
///
/// The rustls server config uses the `ring` crypto provider with its safe default protocol
/// versions and no client authentication. The certificate source depends on `tls.cert_mode`:
/// - `Manual`: key and certificate chain are read from `tls.key_path()` / `tls.cert_path()`
///   and installed as a single fixed certificate.
/// - `LetsEncrypt`: certificates are obtained via ACME for `tls.hostname`, with
///   `mailto:<tls.contact>` as the contact and `tls.cert_dir()` as the cache directory.
/// - `Reloading` (only with the `server` feature): key and certificate are read from the same
///   paths as `Manual` through `relay::ReloadingResolver`, initialized with
///   `relay::DEFAULT_CERT_RELOAD_INTERVAL` (presumably the re-read period — confirm in
///   `iroh_relay`).
///
/// # Errors
///
/// Fails if the manual key/certificate cannot be loaded or rejected by rustls, if
/// `LetsEncrypt` is selected without `hostname` or `contact`, or if the reloading resolver
/// fails to initialize.
async fn maybe_load_tls(
    cfg: &Config,
) -> Result<Option<relay::TlsConfig<std::io::Error, std::io::Error>>> {
    let Some(ref tls) = cfg.tls else {
        return Ok(None);
    };
    let server_config = rustls::ServerConfig::builder_with_provider(std::sync::Arc::new(
        rustls::crypto::ring::default_provider(),
    ))
    .with_safe_default_protocol_versions()
    .expect("protocols supported by ring")
    .with_no_client_auth();
    let (cert_config, server_config) = match tls.cert_mode {
        CertMode::Manual => {
            let cert_path = tls.cert_path();
            let key_path = tls.key_path();
            // The loaders use synchronous `std::fs` I/O, so keep them off the async runtime.
            let (private_key, certs) = tokio::task::spawn_blocking(move || {
                let key = load_secret_key(key_path)?;
                let certs = load_certs(cert_path)?;
                anyhow::Ok((key, certs))
            })
            .await??;
            // `certs` is needed both by rustls and by the relay's `CertConfig`.
            let server_config = server_config.with_single_cert(certs.clone(), private_key)?;
            (relay::CertConfig::Manual { certs }, server_config)
        }
        CertMode::LetsEncrypt => {
            let hostname = tls
                .hostname
                .clone()
                .context("LetsEncrypt needs a hostname")?;
            let contact = tls
                .contact
                .clone()
                .context("LetsEncrypt needs a contact email")?;
            // NOTE(review): `directory_lets_encrypt(prod_tls)` presumably picks the production
            // (`true`) or staging (`false`) ACME directory — confirm in tokio-rustls-acme.
            let config = AcmeConfig::new(vec![hostname.clone()])
                .contact([format!("mailto:{}", contact)])
                .cache_option(Some(DirCache::new(tls.cert_dir())))
                .directory_lets_encrypt(tls.prod_tls);
            let state = config.state();
            // Certificates are served through the ACME state's resolver.
            let resolver = state.resolver().clone();
            let server_config = server_config.with_cert_resolver(resolver);
            (relay::CertConfig::LetsEncrypt { state }, server_config)
        }
        #[cfg(feature = "server")]
        CertMode::Reloading => {
            use rustls_cert_file_reader::FileReader;
            use rustls_cert_reloadable_resolver::{key_provider::Dyn, CertifiedKeyLoader};
            use webpki::types::{CertificateDer, PrivateKeyDer};
            let cert_path = tls.cert_path();
            let key_path = tls.key_path();
            let interval = relay::DEFAULT_CERT_RELOAD_INTERVAL;
            let key_reader = rustls_cert_file_reader::FileReader::new(
                key_path,
                rustls_cert_file_reader::Format::PEM,
            );
            let certs_reader = rustls_cert_file_reader::FileReader::new(
                cert_path,
                rustls_cert_file_reader::Format::PEM,
            );
            // Keys are parsed with the key provider of the same crypto provider (`ring`) that
            // the server config was built with.
            let loader: CertifiedKeyLoader<
                Dyn,
                FileReader<PrivateKeyDer<'_>>,
                FileReader<Vec<CertificateDer<'_>>>,
            > = CertifiedKeyLoader {
                key_provider: Dyn(server_config.crypto_provider().key_provider),
                key_reader,
                certs_reader,
            };
            let resolver = Arc::new(relay::ReloadingResolver::init(loader, interval).await?);
            let server_config = server_config.with_cert_resolver(resolver);
            (relay::CertConfig::Reloading, server_config)
        }
    };
    Ok(Some(relay::TlsConfig {
        https_bind_addr: tls.https_bind_addr(cfg),
        cert: cert_config,
        server_config,
        quic_bind_addr: tls.quic_bind_addr(cfg),
    }))
}
async fn build_relay_config(cfg: Config) -> Result<relay::ServerConfig<std::io::Error>> {
let dangerous_http_only = cfg.tls.as_ref().is_some_and(|tls| tls.dangerous_http_only);
let relay_tls = maybe_load_tls(&cfg).await?;
let mut quic_config = None;
if cfg.enable_quic_addr_discovery {
if let Some(ref tls) = relay_tls {
quic_config = Some(QuicConfig {
server_config: tls.server_config.clone(),
bind_addr: tls.quic_bind_addr,
});
} else {
bail!("Must have a valid TLS configuration to enable a QUIC server for QUIC address discovery")
}
};
let limits = match cfg.limits {
Some(ref limits) => {
let client_rx = match &limits.client {
Some(PerClientRateLimitConfig { rx: Some(rx) }) => {
if rx.bytes_per_second.is_none() && rx.max_burst_bytes.is_some() {
bail!("bytes_per_seconds must be specified to enable the rate-limiter");
}
match rx.bytes_per_second {
Some(bps) => Some(ClientRateLimit {
bytes_per_second: bps
.try_into()
.context("bytes_per_second must be non-zero u32")?,
max_burst_bytes: rx
.max_burst_bytes
.map(|v| {
v.try_into().context("max_burst_bytes must be non-zero u32")
})
.transpose()?,
}),
None => None,
}
}
Some(PerClientRateLimitConfig { rx: None }) | None => None,
};
relay::Limits {
accept_conn_limit: limits.accept_conn_limit,
accept_conn_burst: limits.accept_conn_burst,
client_rx,
}
}
None => Default::default(),
};
let relay_config = relay::RelayConfig {
http_bind_addr: cfg.http_bind_addr(),
tls: relay_tls.and_then(|tls| if dangerous_http_only { None } else { Some(tls) }),
limits,
key_cache_capacity: cfg.key_cache_capacity,
access: cfg.access.clone().into(),
};
let stun_config = relay::StunConfig {
bind_addr: cfg.stun_bind_addr(),
};
Ok(relay::ServerConfig {
relay: Some(relay_config),
stun: Some(stun_config).filter(|_| cfg.enable_stun),
quic: quic_config,
#[cfg(feature = "metrics")]
metrics_addr: Some(cfg.metrics_bind_addr()).filter(|_| cfg.enable_metrics),
})
}
#[cfg(test)]
mod tests {
    use std::num::NonZeroU32;
    use iroh_base::SecretKey;
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;
    use testresult::TestResult;
    use super::*;

    /// A fully specified `[limits.client.rx]` section is carried into the relay config.
    #[tokio::test]
    async fn test_rate_limit_config() -> TestResult {
        let config = "
            [limits.client.rx]
            bytes_per_second = 400
            max_burst_bytes = 800
            ";
        let config = Config::from_str(config)?;
        let relay_config = build_relay_config(config).await?;
        let relay = relay_config.relay.expect("no relay config");
        assert_eq!(
            relay.limits.client_rx.expect("ratelimit").bytes_per_second,
            NonZeroU32::try_from(400).unwrap()
        );
        assert_eq!(
            relay.limits.client_rx.expect("ratelimit").max_burst_bytes,
            Some(NonZeroU32::try_from(800).unwrap())
        );
        Ok(())
    }

    /// Without a `[limits]` section no client rate limit is applied.
    #[tokio::test]
    async fn test_rate_limit_default() -> TestResult {
        let config = Config::from_str("")?;
        let relay_config = build_relay_config(config).await?;
        let relay = relay_config.relay.expect("no relay config");
        assert!(relay.limits.client_rx.is_none());
        Ok(())
    }

    /// Invalid client rate limits are rejected instead of being silently ignored.
    #[tokio::test]
    async fn test_rate_limit_invalid() -> TestResult {
        // A burst size without a rate cannot configure the limiter.
        let config = Config::from_str(
            "
            [limits.client.rx]
            max_burst_bytes = 800
            ",
        )?;
        assert!(build_relay_config(config).await.is_err());

        // The rate must be non-zero.
        let config = Config::from_str(
            "
            [limits.client.rx]
            bytes_per_second = 0
            ",
        )?;
        assert!(build_relay_config(config).await.is_err());

        // The burst size must be non-zero, too.
        let config = Config::from_str(
            "
            [limits.client.rx]
            bytes_per_second = 400
            max_burst_bytes = 0
            ",
        )?;
        assert!(build_relay_config(config).await.is_err());
        Ok(())
    }

    /// Every `access` variant parses from its TOML spelling.
    #[tokio::test]
    async fn test_access_config() -> TestResult {
        let config = "
            access = \"everyone\"
            ";
        let config = Config::from_str(config)?;
        assert_eq!(config.access, AccessConfig::Everyone);

        let mut rng = ChaCha8Rng::seed_from_u64(0);
        let node_id = SecretKey::generate(&mut rng).public();

        let config = format!(
            "
            access.allowlist = [
                \"{node_id}\",
            ]
            "
        );
        let config = Config::from_str(&config)?;
        assert_eq!(config.access, AccessConfig::Allowlist(vec![node_id]));

        let config = format!(
            "
            access.denylist = [
                \"{node_id}\",
            ]
            "
        );
        let config = Config::from_str(&config)?;
        assert_eq!(config.access, AccessConfig::Denylist(vec![node_id]));

        let config = r#"
            access.http.url = "https://example.com/foo/bar?boo=baz"
            "#
        .to_string();
        let config = Config::from_str(&config)?;
        assert_eq!(
            config.access,
            AccessConfig::Http(HttpAccessConfig {
                url: "https://example.com/foo/bar?boo=baz".parse().unwrap(),
                bearer_token: None
            })
        );

        let config = r#"
            access.http.url = "https://example.com/foo/bar?boo=baz"
            access.http.bearer_token = "foo"
            "#
        .to_string();
        let config = Config::from_str(&config)?;
        assert_eq!(
            config.access,
            AccessConfig::Http(HttpAccessConfig {
                url: "https://example.com/foo/bar?boo=baz".parse().unwrap(),
                bearer_token: Some("foo".to_string())
            })
        );

        let config = r#"
            access.http = { url = "https://example.com/foo" }
            "#
        .to_string();
        let config = Config::from_str(&config)?;
        assert_eq!(
            config.access,
            AccessConfig::Http(HttpAccessConfig {
                url: "https://example.com/foo".parse().unwrap(),
                bearer_token: None
            })
        );

        let config = r#"
            access.http = { url = "https://example.com/foo", bearer_token = "foo" }
            "#
        .to_string();
        let config = Config::from_str(&config)?;
        assert_eq!(
            config.access,
            AccessConfig::Http(HttpAccessConfig {
                url: "https://example.com/foo".parse().unwrap(),
                bearer_token: Some("foo".to_string())
            })
        );
        Ok(())
    }
}