use prometheus_exporter::prometheus::{register_int_gauge, IntGauge};
use serde::Deserialize;
use std::collections::{HashMap, HashSet};
use thiserror::Error as ThisError;
use tokio::{
fs,
process::{Child, Command},
signal::unix::{signal, SignalKind},
sync::watch::{channel, Receiver},
time::{sleep, Duration, Instant},
};
use tracing::{debug, error, info};
use tracing_subscriber::EnvFilter;
/// Usage text printed to stderr by `main` when the program is not invoked
/// with exactly one argument (the configuration file path).
const USAGE: &str = r#"
multi-tunnel
Conveniently set up SSH tunnels
USAGE:
multi-tunnel [config.toml]
"#;
/// Error type shared by the fallible helpers in this file.
///
/// Both variants carry `#[from]`, so the `?` operator converts the underlying
/// error into this type automatically.
#[derive(ThisError, Debug)]
enum Error {
/// An I/O operation failed (reading the configuration file in `load_config`
/// or spawning `ssh` in `create_child`).
#[error("I/O error: {0}")]
IoError(#[from] std::io::Error),
/// The configuration file was read but could not be deserialized from TOML.
#[error("Invalid TOML in configuration file: {0}")]
InvalidToml(#[from] toml::de::Error),
}
/// File-local result alias whose error type is fixed to [`Error`].
type Result<T> = std::result::Result<T, Error>;
/// Deserialized contents of the TOML configuration file.
///
/// `Default` yields an empty connection string and no tunnels; that value is
/// used to seed the watch channel before the first real configuration is sent.
#[derive(Deserialize, Clone, Debug, Default)]
struct Config {
/// Destination passed to `ssh` as its host argument, both for the tunnel
/// process and for the health-check command.
connection: String,
/// Tunnels keyed by name. Because the field is flattened, every remaining
/// top-level key in the file is read as a tunnel name whose value must
/// deserialize as [`Ports`].
#[serde(flatten)]
tunnels: HashMap<String, Ports>,
}
/// Port pair describing one reverse tunnel (`ssh -R <remote>:localhost:<local>`).
#[derive(Deserialize, Copy, Clone, Debug, Eq, PartialEq)]
struct Ports {
/// Port on this machine (`localhost`) that forwarded connections are sent to.
local: u32,
/// Port requested on the remote side of the `-R` forward; it is also the
/// port the health check requests via `curl` on the remote host.
remote: u32,
}
/// Checks that a tunnel is usable by fetching its forwarded port from the
/// remote host.
///
/// Runs `ssh <connection> curl --max-time 10 http://localhost:<remote_port>`
/// and captures its output.
///
/// Returns `true` only when the command ran and exited with status code 0.
/// Returns `false` for any other exit status, for termination without an exit
/// code, and when the command could not be run at all (the latter is logged).
async fn run_health_check(connection: &str, remote_port: u32) -> bool {
    let url = format!("http://localhost:{}", remote_port);

    let mut probe = Command::new("ssh");
    probe
        .arg(connection)
        .args(&["curl", "--max-time", "10"])
        .arg(&url)
        // Make sure a cancelled health check does not leave ssh behind.
        .kill_on_drop(true);

    match probe.output().await {
        Ok(output) => output.status.code() == Some(0),
        Err(err) => {
            info!("Error running health check: {}", err);
            false
        }
    }
}
/// Spawns the `ssh` process that holds one reverse tunnel open.
///
/// The command line is
/// `ssh -NTC -o ServerAliveInterval=60 -o ExitOnForwardFailure=yes
/// -R <remote>:localhost:<local> <connection>`.
/// The child is configured with `kill_on_drop`, so dropping the returned
/// handle terminates the tunnel.
///
/// # Errors
///
/// Returns [`Error::IoError`] when the process cannot be spawned.
fn create_child(ports: Ports, connection: &str) -> Result<Child> {
    // Reverse-forward spec: remote port -> localhost:<local port>.
    let forward = format!("{}:localhost:{}", ports.remote, ports.local);

    let mut ssh = Command::new("ssh");
    ssh.arg("-NTC")
        .args(&["-o", "ServerAliveInterval=60"])
        .args(&["-o", "ExitOnForwardFailure=yes"])
        .args(&["-R", forward.as_str()])
        .arg(connection)
        .kill_on_drop(true);

    let child = ssh.spawn()?;
    Ok(child)
}
/// RAII marker for one live `ssh` tunnel process.
///
/// Increments the gauge when created and decrements it when dropped. Tying
/// the decrement to `Drop` keeps the gauge correct even when the supervising
/// task is cancelled with `JoinHandle::abort`, because cancellation drops the
/// task's future together with all of its locals.
struct ActiveTunnel(IntGauge);

impl ActiveTunnel {
    /// Counts one more active tunnel on `gauge` until the returned value is dropped.
    fn new(gauge: &IntGauge) -> Self {
        gauge.inc();
        Self(gauge.clone())
    }
}

impl Drop for ActiveTunnel {
    fn drop(&mut self) {
        self.0.dec();
    }
}

/// Supervises a single tunnel forever.
///
/// Spawns `ssh` via `create_child`, polls it every five seconds, runs
/// `run_health_check` once at least 60 seconds have passed since the previous
/// check, and restarts the process five seconds after it exits or fails a
/// health check. The function only ends when its task is cancelled.
///
/// * `name` - tunnel name, used as the log prefix.
/// * `connection` - ssh destination.
/// * `ports` - port pair to forward.
/// * `active_counter` - gauge counting live tunnel processes; it is held
///   incremented exactly while an `ssh` child spawned here is running.
async fn run_child(
    name: &str,
    connection: &str,
    ports: Ports,
    active_counter: IntGauge,
) {
    loop {
        let mut child = match create_child(ports, connection) {
            Ok(child) => child,
            Err(err) => {
                info!(
                    "{}: unable to spawn ssh, restarting in 5: {}",
                    name, err
                );
                sleep(Duration::from_secs(5)).await;
                continue;
            }
        };
        info!("{}: started", name);
        // Dropped explicitly below on a normal restart, or implicitly if this
        // task is aborted while the child is alive.
        let active = ActiveTunnel::new(&active_counter);
        let mut last_check = Instant::now();
        loop {
            if last_check.elapsed() >= Duration::from_secs(60) {
                if !run_health_check(connection, ports.remote).await {
                    info!("{}: health check failed, killing ssh", name);
                    if let Err(err) = child.kill().await {
                        info!("{}: error killing ssh after health check failed: {}", name, err);
                    }
                    break;
                }
                debug!("{}: passed health check", name);
                last_check = Instant::now();
            }
            match child.try_wait() {
                // The exit itself is reported once, after `wait` below.
                Ok(Some(_)) => break,
                Ok(None) => sleep(Duration::from_secs(5)).await,
                Err(err) => {
                    info!("{}: unable to poll ssh status: {}", name, err);
                    break;
                }
            }
        }
        // Reap the child so its final status is collected before restarting.
        match child.wait().await {
            Ok(_) => {
                info!("{}: exited, restarting in 5", name)
            }
            Err(err) => {
                info!("{}: exited with error, restarting in 5: {}", name, err)
            }
        }
        drop(active);
        sleep(Duration::from_secs(5)).await;
    }
}

/// Starts `run_child` on its own task, giving it owned copies of its inputs.
fn spawn_supervisor(
    name: &str,
    connection: &str,
    ports: Ports,
    active_counter: &IntGauge,
) -> tokio::task::JoinHandle<()> {
    let name = name.to_string();
    let connection = connection.to_string();
    let active_counter = active_counter.clone();
    tokio::spawn(async move {
        run_child(&name, &connection, ports, active_counter).await
    })
}

/// Runs one named tunnel and keeps it in sync with configuration reloads.
///
/// Starts a supervisor task for the tunnel, then waits for new configurations
/// on `rx`. When the tunnel's ports or the connection differ from the values
/// currently in use, the supervisor is cancelled and a new one is started with
/// the new settings. When the tunnel is no longer present in the
/// configuration, the supervisor is cancelled and this function returns.
///
/// * `rx` - watch receiver delivering each reloaded [`Config`].
/// * `name` - key of this tunnel in `Config::tunnels`.
/// * `connection` - ssh destination to start with.
/// * `ports` - port pair to start with.
/// * `active_counter` - gauge of live tunnel processes; it is maintained by
///   the supervisor task, not by this function.
async fn handle_tunnel(
    rx: &mut Receiver<Config>,
    name: &str,
    connection: &str,
    ports: &Ports,
    active_counter: IntGauge,
) {
    let mut current_ports = *ports;
    // Track the connection actually in use; comparing against the original
    // argument would misdetect changes after the first connection update.
    let mut current_connection = connection.to_string();
    let mut handle = spawn_supervisor(
        name,
        &current_connection,
        current_ports,
        &active_counter,
    );
    // If the sender is ever dropped the loop ends and the supervisor task is
    // left running detached.
    while rx.changed().await.is_ok() {
        debug!("{}: received new configuration", name);
        let config = rx.borrow().clone();
        match config.tunnels.get(name).copied() {
            Some(new_ports) => {
                if new_ports == current_ports
                    && config.connection == current_connection
                {
                    continue;
                }
                debug!("{}: configuration has changed", name);
                handle.abort();
                // Wait for the cancelled task to be torn down so the old ssh
                // child is killed and the gauge is decremented before the
                // replacement starts.
                let _ = (&mut handle).await;
                info!("{}: aborted", name);
                current_ports = new_ports;
                current_connection = config.connection.clone();
                handle = spawn_supervisor(
                    name,
                    &current_connection,
                    current_ports,
                    &active_counter,
                );
            }
            None => {
                handle.abort();
                let _ = (&mut handle).await;
                info!("{}: no longer configured, aborted", name);
                break;
            }
        }
    }
}
async fn handle_signal(config_path: &str) {
let (tx, rx) = channel(Config::default());
let mut stream = signal(SignalKind::hangup())
.expect("unable to listen to SIGHUP signal");
let active_counter =
register_int_gauge!("multi_tunnel_active", "active tunnels open")
.unwrap();
let total_counter =
register_int_gauge!("multi_tunnel_total", "total configured tunnel")
.unwrap();
let mut running = HashSet::new();
loop {
let config = match load_config(config_path).await {
Ok(config) => {
if let Err(err) = tx.send(config.clone()) {
error!("Error reloading configurations: {}", err);
}
config
}
Err(err) => {
error!("Error reloading configuration: {}", err);
stream.recv().await;
debug!("Reloading configuration");
continue;
}
};
total_counter.set(config.tunnels.len() as i64);
for (name, ports) in config.tunnels.clone() {
if running.contains(&name) {
continue;
}
running.insert(name.to_string());
let mut rx = rx.clone();
let connection = config.connection.clone();
let active_counter = active_counter.clone();
tokio::spawn(async move {
handle_tunnel(
&mut rx,
&name,
&connection,
&ports,
active_counter,
)
.await
});
}
for name in running.clone() {
if config.tunnels.get(&name).is_none() {
running.remove(&name);
}
}
stream.recv().await;
debug!("Reloading configuration");
}
}
/// Reads the file at `config_path` and deserializes it from TOML.
///
/// # Errors
///
/// Returns [`Error::IoError`] when the file cannot be read and
/// [`Error::InvalidToml`] when its contents fail to deserialize.
async fn load_config(config_path: &str) -> Result<Config> {
    let raw = fs::read_to_string(config_path).await?;
    let config = toml::from_str(&raw)?;
    Ok(config)
}
/// Entry point: validates the command line, sets up logging and the
/// Prometheus exporter, then hands control to `handle_signal`.
///
/// Exits with status 1 after printing the usage text unless exactly one
/// argument (the configuration file path) is supplied. The log filter is
/// taken from `RUST_LOG`, defaulting to `info`.
#[tokio::main]
async fn main() {
    let args: Vec<String> = std::env::args().collect();
    // Exactly two entries are expected: the program name and the config path.
    let config_path = match args.as_slice() {
        [_, path] => path.clone(),
        _ => {
            eprintln!("Error: missing config file");
            eprintln!("{}", USAGE);
            std::process::exit(1);
        }
    };

    let log_filter =
        std::env::var("RUST_LOG").unwrap_or_else(|_| "info".to_string());
    tracing_subscriber::FmtSubscriber::builder()
        .with_env_filter(EnvFilter::new(log_filter))
        .init();

    info!("Starting multi-tunnel!");
    tokio::task::spawn_blocking(|| {
        prometheus_exporter::start("0.0.0.0:46581".parse().unwrap()).unwrap();
    });
    handle_signal(&config_path).await;
}