#![allow(deprecated)]
use std::{
net::{Ipv6Addr, SocketAddr},
path::{Path, PathBuf},
};
use anyhow::{anyhow, bail, Context as _, Result};
use clap::Parser;
use iroh_net::{
defaults::{DEFAULT_HTTPS_PORT, DEFAULT_HTTP_PORT, DEFAULT_METRICS_PORT, DEFAULT_STUN_PORT},
relay::server as iroh_relay,
};
use serde::{Deserialize, Serialize};
use tokio_rustls_acme::{caches::DirCache, AcmeConfig};
use tracing::debug;
use tracing_subscriber::{prelude::*, EnvFilter};
/// HTTP port used by `--dev` mode when the config does not set `http_bind_addr`.
const DEV_MODE_HTTP_PORT: u16 = 3340;
// Command-line arguments for the relay server binary.
//
// NOTE: plain `//` comments are used on purpose here: clap's derive turns `///`
// doc comments into `--help` text, which would change the program's output.
#[derive(Parser, Debug, Clone)]
#[clap(version, about, long_about = None)]
struct Cli {
    // Development mode: `main` discards any TLS configuration and, when no
    // `http_bind_addr` is configured, binds HTTP on `[::]:DEV_MODE_HTTP_PORT`.
    #[clap(long, default_value_t = false)]
    dev: bool,
    // Path to a TOML config file. If the file does not exist, a default config
    // is written there; if the flag is omitted, built-in defaults are used.
    #[clap(long, short)]
    config_path: Option<PathBuf>,
}
// How the HTTPS listener obtains its certificate (see `TlsConfig::cert_mode`).
// `//` rather than `///`: `clap::ValueEnum` turns doc comments into help text.
#[derive(clap::ValueEnum, Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
enum CertMode {
    // Load the certificate chain and private key from PEM files on disk.
    Manual,
    // Obtain certificates via ACME from Let's Encrypt.
    LetsEncrypt,
}
fn load_certs(
filename: impl AsRef<Path>,
) -> Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
let certfile = std::fs::File::open(filename).context("cannot open certificate file")?;
let mut reader = std::io::BufReader::new(certfile);
let certs: Result<Vec<_>, std::io::Error> = rustls_pemfile::certs(&mut reader).collect();
let certs = certs?;
Ok(certs)
}
fn load_secret_key(
filename: impl AsRef<Path>,
) -> Result<rustls::pki_types::PrivateKeyDer<'static>> {
let filename = filename.as_ref();
let keyfile = std::fs::File::open(filename)
.with_context(|| format!("cannot open secret key file {}", filename.display()))?;
let mut reader = std::io::BufReader::new(keyfile);
loop {
match rustls_pemfile::read_one(&mut reader).context("cannot parse secret key .pem file")? {
Some(rustls_pemfile::Item::Pkcs1Key(key)) => {
return Ok(rustls::pki_types::PrivateKeyDer::Pkcs1(key));
}
Some(rustls_pemfile::Item::Pkcs8Key(key)) => {
return Ok(rustls::pki_types::PrivateKeyDer::Pkcs8(key));
}
Some(rustls_pemfile::Item::Sec1Key(key)) => {
return Ok(rustls::pki_types::PrivateKeyDer::Sec1(key));
}
None => break,
_ => {}
}
}
bail!(
"no keys found in {} (encrypted keys not supported)",
filename.display()
);
}
/// Relay server configuration, read from (and, for a fresh path, written to) a TOML file.
///
/// The `Option` address fields fall back to computed defaults through the
/// accessor methods of the same name. The `enable_*` flags default to `true`
/// when absent from the file.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Config {
    /// Requests the HTTP(S) relay service; `build_relay_config` is responsible for honoring it.
    #[serde(default = "cfg_defaults::enable_relay")]
    enable_relay: bool,
    /// Plain-HTTP listen address; the accessor defaults it to `[::]:DEFAULT_HTTP_PORT`.
    http_bind_addr: Option<SocketAddr>,
    /// TLS settings; when `None`, no HTTPS listener is configured.
    tls: Option<TlsConfig>,
    /// Whether the STUN server is included.
    #[serde(default = "cfg_defaults::enable_stun")]
    enable_stun: bool,
    /// STUN listen address; defaults to the HTTP IP on `DEFAULT_STUN_PORT`.
    stun_bind_addr: Option<SocketAddr>,
    /// Optional connection-acceptance limits for the relay.
    limits: Option<Limits>,
    /// Whether metrics are served (only effective when built with the `metrics` feature).
    #[serde(default = "cfg_defaults::enable_metrics")]
    enable_metrics: bool,
    /// Metrics listen address; defaults to the HTTP IP on `DEFAULT_METRICS_PORT`.
    metrics_bind_addr: Option<SocketAddr>,
}
impl Config {
    /// Plain-HTTP listen address: the configured one, or `[::]:DEFAULT_HTTP_PORT`.
    fn http_bind_addr(&self) -> SocketAddr {
        match self.http_bind_addr {
            Some(addr) => addr,
            None => SocketAddr::from((Ipv6Addr::UNSPECIFIED, DEFAULT_HTTP_PORT)),
        }
    }

    /// STUN listen address: the configured one, or the HTTP IP on `DEFAULT_STUN_PORT`.
    fn stun_bind_addr(&self) -> SocketAddr {
        self.stun_bind_addr
            .unwrap_or_else(|| self.addr_beside_http(DEFAULT_STUN_PORT))
    }

    /// Metrics listen address: the configured one, or the HTTP IP on `DEFAULT_METRICS_PORT`.
    fn metrics_bind_addr(&self) -> SocketAddr {
        self.metrics_bind_addr
            .unwrap_or_else(|| self.addr_beside_http(DEFAULT_METRICS_PORT))
    }

    /// The HTTP listener's IP address paired with a different `port`.
    fn addr_beside_http(&self, port: u16) -> SocketAddr {
        SocketAddr::new(self.http_bind_addr().ip(), port)
    }
}
impl Default for Config {
fn default() -> Self {
Self {
enable_relay: true,
http_bind_addr: None,
tls: None,
enable_stun: true,
stun_bind_addr: None,
limits: None,
enable_metrics: true,
metrics_bind_addr: None,
}
}
}
/// Default-value functions referenced by `#[serde(default = "...")]` attributes.
///
/// serde's `default = "path"` needs a callable path, so each default is a tiny fn.
mod cfg_defaults {
    /// Service toggles are on unless the config file turns them off.
    const ENABLED: bool = true;

    /// Default for `Config::enable_relay`.
    pub(crate) fn enable_relay() -> bool {
        ENABLED
    }

    /// Default for `Config::enable_stun`.
    pub(crate) fn enable_stun() -> bool {
        ENABLED
    }

    /// Default for `Config::enable_metrics`.
    pub(crate) fn enable_metrics() -> bool {
        ENABLED
    }

    /// Defaults for `TlsConfig` fields.
    pub(crate) mod tls_config {
        /// Default for `TlsConfig::prod_tls` (the value handed to
        /// `AcmeConfig::directory_lets_encrypt`).
        pub(crate) fn prod_tls() -> bool {
            true
        }
    }
}
/// TLS settings for the relay's HTTPS listener.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct TlsConfig {
    /// HTTPS listen address; the accessor defaults it to the HTTP IP on `DEFAULT_HTTPS_PORT`.
    https_bind_addr: Option<SocketAddr>,
    /// Domain to obtain a certificate for; required when `cert_mode` is `LetsEncrypt`.
    hostname: Option<String>,
    /// Whether certificates are loaded from disk or obtained via Let's Encrypt.
    cert_mode: CertMode,
    /// Directory for the default cert/key files and the ACME cache; defaults to `.`.
    cert_dir: Option<PathBuf>,
    /// Certificate file for `Manual` mode; defaults to `<cert_dir>/default.crt`.
    manual_cert_path: Option<PathBuf>,
    /// Private key file for `Manual` mode; defaults to `<cert_dir>/default.key`.
    manual_key_path: Option<PathBuf>,
    /// Passed to `AcmeConfig::directory_lets_encrypt`; defaults to `true`
    /// (presumably the production directory, `false` = staging — confirm).
    #[serde(default = "cfg_defaults::tls_config::prod_tls")]
    prod_tls: bool,
    /// Contact email for Let's Encrypt (sent as a `mailto:` URI); required in `LetsEncrypt` mode.
    contact: Option<String>,
}
impl TlsConfig {
    /// HTTPS listen address: the configured one, or `cfg`'s HTTP IP on `DEFAULT_HTTPS_PORT`.
    fn https_bind_addr(&self, cfg: &Config) -> SocketAddr {
        match self.https_bind_addr {
            Some(addr) => addr,
            None => SocketAddr::new(cfg.http_bind_addr().ip(), DEFAULT_HTTPS_PORT),
        }
    }

    /// Base directory for certificate files and the ACME cache (`.` when unset).
    fn cert_dir(&self) -> PathBuf {
        match &self.cert_dir {
            Some(dir) => dir.clone(),
            None => PathBuf::from("."),
        }
    }

    /// Certificate file for manual mode: explicit path, or `default.crt` in `cert_dir`.
    fn cert_path(&self) -> PathBuf {
        match &self.manual_cert_path {
            Some(path) => path.clone(),
            None => self.cert_dir().join("default.crt"),
        }
    }

    /// Private key file for manual mode: explicit path, or `default.key` in `cert_dir`.
    fn key_path(&self) -> PathBuf {
        match &self.manual_key_path {
            Some(path) => path.clone(),
            None => self.cert_dir().join("default.key"),
        }
    }
}
/// Optional limits on accepting new relay connections, copied field-for-field
/// into `iroh_relay::Limits` by `build_relay_config`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
struct Limits {
    /// Connection-acceptance rate limit; `None` leaves it unset.
    /// NOTE(review): presumably connections per second — confirm against `iroh_relay::Limits`.
    accept_conn_limit: Option<f64>,
    /// Burst allowance for the acceptance rate limit; `None` leaves it unset.
    accept_conn_burst: Option<usize>,
}
impl Config {
    /// Loads the configuration selected on the command line.
    ///
    /// * No `--config-path`: returns [`Config::default`].
    /// * Path exists: parses it as TOML via [`Config::read_from_file`].
    /// * Path does not exist: writes the default config there (creating parent
    ///   directories) and returns it.
    ///
    /// # Errors
    ///
    /// Propagates failures from reading, parsing, or writing the config file.
    async fn load(opts: &Cli) -> Result<Self> {
        let Some(config_path) = &opts.config_path else {
            return Ok(Config::default());
        };
        if config_path.exists() {
            Self::read_from_file(config_path).await
        } else {
            let config = Config::default();
            config.write_to_file(config_path).await?;
            Ok(config)
        }
    }

    /// Reads and parses the TOML config at `path`.
    ///
    /// # Errors
    ///
    /// Fails if `path` is not a regular file, cannot be read, or is not valid
    /// TOML for [`Config`].
    async fn read_from_file(path: impl AsRef<Path>) -> Result<Self> {
        let path = path.as_ref();
        if !path.is_file() {
            bail!("config-path must be a file: {}", path.display());
        }
        let config_ser = tokio::fs::read_to_string(path)
            .await
            .with_context(|| format!("unable to read config {}", path.display()))?;
        toml::from_str(&config_ser).context("config file must be valid toml")
    }

    /// Serializes `self` as TOML and writes it to `path`, creating parent
    /// directories as needed.
    ///
    /// # Errors
    ///
    /// Fails if `path` has no parent, the directory cannot be created, or
    /// serialization or the write fails.
    async fn write_to_file(&self, path: impl AsRef<Path>) -> Result<()> {
        let path = path.as_ref();
        let parent = path
            .parent()
            .ok_or_else(|| anyhow!("invalid config file path, no parent"))?;
        // A bare file name has an empty parent; `create_dir_all("")` is a no-op.
        tokio::fs::create_dir_all(parent)
            .await
            .with_context(|| format!("unable to create config-path dir: {}", parent.display()))?;
        let config_ser = toml::to_string(self).context("unable to serialize configuration")?;
        tokio::fs::write(path, config_ser)
            .await
            .with_context(|| format!("unable to write config file {}", path.display()))?;
        Ok(())
    }
}
/// Entry point: installs stderr logging filtered by `EnvFilter::from_default_env`,
/// then loads the config and spawns the relay server.
///
/// Runs until Ctrl-C is received or the server task finishes, then shuts the
/// server down and returns its shutdown result.
#[tokio::main]
async fn main() -> Result<()> {
    let fmt_layer = tracing_subscriber::fmt::layer().with_writer(std::io::stderr);
    tracing_subscriber::registry()
        .with(fmt_layer)
        .with(EnvFilter::from_default_env())
        .init();

    let args = Cli::parse();
    let mut config = Config::load(&args).await?;
    if args.dev {
        // Dev mode never serves TLS, and falls back to a dedicated HTTP port
        // only when the config did not pick one.
        config.tls = None;
        config.http_bind_addr = config
            .http_bind_addr
            .or(Some((Ipv6Addr::UNSPECIFIED, DEV_MODE_HTTP_PORT).into()));
    }

    let relay_config = build_relay_config(config).await?;
    debug!("{relay_config:#?}");

    let mut server = iroh_relay::Server::spawn(relay_config).await?;
    tokio::select! {
        biased;
        _ = tokio::signal::ctrl_c() => (),
        _ = server.task_handle() => (),
    }
    server.shutdown().await
}
/// Translates the on-disk [`Config`] into the relay server's `ServerConfig`.
///
/// * The relay (HTTP/HTTPS) service is included only when `enable_relay` is set.
///   TLS material is prepared only in that case:
///   - `Manual`: key and certificate chain are read from disk on a blocking thread.
///   - `LetsEncrypt`: an ACME config is built that caches state in `cert_dir`.
/// * STUN is included when `enable_stun` is set.
/// * Metrics (only with the `metrics` feature) are included when `enable_metrics` is set.
///
/// # Errors
///
/// Fails if manual cert/key files cannot be loaded, or if Let's Encrypt mode is
/// missing `hostname` or `contact`.
async fn build_relay_config(cfg: Config) -> Result<iroh_relay::ServerConfig<std::io::Error>> {
    // With the relay disabled there is no HTTPS listener, so do not load (and
    // possibly fail on) TLS material the server would never use.
    let relay_tls = if cfg.enable_relay {
        cfg.tls.as_ref()
    } else {
        None
    };
    let tls = match relay_tls {
        Some(tls) => {
            let cert_config = match tls.cert_mode {
                CertMode::Manual => {
                    let cert_path = tls.cert_path();
                    let key_path = tls.key_path();
                    // PEM loading is blocking file I/O; keep it off the async runtime.
                    let (private_key, certs) = tokio::task::spawn_blocking(move || {
                        let key = load_secret_key(key_path)?;
                        let certs = load_certs(cert_path)?;
                        anyhow::Ok((key, certs))
                    })
                    .await??;
                    iroh_relay::CertConfig::Manual { private_key, certs }
                }
                CertMode::LetsEncrypt => {
                    let hostname = tls
                        .hostname
                        .clone()
                        .context("LetsEncrypt needs a hostname")?;
                    let contact = tls
                        .contact
                        .clone()
                        .context("LetsEncrypt needs a contact email")?;
                    let config = AcmeConfig::new(vec![hostname])
                        .contact([format!("mailto:{contact}")])
                        .cache_option(Some(DirCache::new(tls.cert_dir())))
                        .directory_lets_encrypt(tls.prod_tls);
                    iroh_relay::CertConfig::LetsEncrypt { config }
                }
            };
            Some(iroh_relay::TlsConfig {
                https_bind_addr: tls.https_bind_addr(&cfg),
                cert: cert_config,
            })
        }
        None => None,
    };
    let limits = iroh_relay::Limits {
        accept_conn_limit: cfg.limits.as_ref().and_then(|l| l.accept_conn_limit),
        accept_conn_burst: cfg.limits.as_ref().and_then(|l| l.accept_conn_burst),
    };
    let relay_config = iroh_relay::RelayConfig {
        http_bind_addr: cfg.http_bind_addr(),
        tls,
        limits,
    };
    let stun_config = iroh_relay::StunConfig {
        bind_addr: cfg.stun_bind_addr(),
    };
    Ok(iroh_relay::ServerConfig {
        // Previously always `Some`, which silently ignored `enable_relay = false`.
        relay: Some(relay_config).filter(|_| cfg.enable_relay),
        stun: Some(stun_config).filter(|_| cfg.enable_stun),
        #[cfg(feature = "metrics")]
        metrics_addr: Some(cfg.metrics_bind_addr()).filter(|_| cfg.enable_metrics),
    })
}
/// Metric definitions for the STUN server.
///
/// NOTE(review): nothing in the visible part of this file references this
/// module; presumably it is used elsewhere or is leftover — confirm.
mod metrics {
    use iroh_metrics::{
        core::{Counter, Metric},
        struct_iterable::Iterable,
    };
    /// Counters tracking STUN request handling. Field order is kept as-is
    /// because the `Iterable` derive iterates fields in declaration order.
    #[allow(missing_docs)]
    #[derive(Debug, Clone, Iterable)]
    pub struct StunMetrics {
        /// STUN requests made to the server.
        pub requests: Counter,
        /// Successful IPv4 STUN requests served.
        pub ipv4_success: Counter,
        /// Successful IPv6 STUN requests served.
        pub ipv6_success: Counter,
        /// Bad requests made to the STUN endpoint.
        pub bad_requests: Counter,
        /// STUN requests that ended in failure.
        pub failures: Counter,
    }
    /// Creates every counter with its help text.
    impl Default for StunMetrics {
        fn default() -> Self {
            Self {
                requests: Counter::new("Number of STUN requests made to the server."),
                ipv4_success: Counter::new("Number of successful ipv4 STUN requests served."),
                ipv6_success: Counter::new("Number of successful ipv6 STUN requests served."),
                bad_requests: Counter::new("Number of bad requests made to the STUN endpoint."),
                failures: Counter::new("Number of STUN requests that end in failure."),
            }
        }
    }
    /// Names this metric group `"stun"` (presumably used as the metrics prefix — confirm).
    impl Metric for StunMetrics {
        fn name() -> &'static str {
            "stun"
        }
    }
}