//! multi-tunnel 0.3.3
//!
//! Open and manage multiple SSH tunnels
// SPDX-FileCopyrightText: 2021 Kunal Mehta <legoktm@debian.org>
//
// SPDX-License-Identifier: GPL-3.0-or-later OR copyleft-next-0.3.1

use prometheus_exporter::prometheus::{register_int_gauge, IntGauge};
use serde::Deserialize;
use std::collections::{HashMap, HashSet};
use thiserror::Error as ThisError;
use tokio::{
    fs,
    process::{Child, Command},
    signal::unix::{signal, SignalKind},
    sync::watch::{channel, Receiver},
    time::{sleep, Duration, Instant},
};
use tracing::{debug, error, info};
use tracing_subscriber::EnvFilter;

/// Help text printed to stderr when the program is invoked with the
/// wrong number of arguments (see `main`).
const USAGE: &str = r#"
multi-tunnel
Conveniently set up SSH tunnels

USAGE:
    multi-tunnel [config.toml]
"#;

/// Errors that can occur while loading and parsing the configuration file.
#[derive(ThisError, Debug)]
enum Error {
    /// Filesystem error while reading the configuration file.
    #[error("I/O error: {0}")]
    IoError(#[from] std::io::Error),
    /// The file was read, but its contents are not valid TOML
    /// (or do not match the [`Config`] schema).
    #[error("Invalid TOML in configuration file: {0}")]
    InvalidToml(#[from] toml::de::Error),
}

/// Convenience alias: fallible functions in this crate use [`Error`].
type Result<T> = std::result::Result<T, Error>;

/// Configuration file
///
/// Example:
/// ```toml
/// connection = "foo@bar.baz"
///
/// [httpd]
/// local = 80
/// remote = 8081
///
/// [another]
/// local = 8080
/// remote = 8082
/// ```
#[derive(Deserialize, Clone, Debug, Default)]
struct Config {
    /// SSH destination, e.g. foo@bar.baz
    connection: String,
    /// Mapping of tunnel name (the TOML table name) to its port pair.
    /// Because of `flatten`, every top-level table other than
    /// `connection` is treated as a tunnel definition.
    #[serde(flatten)]
    tunnels: HashMap<String, Ports>,
}

/// A local/remote port pair describing one reverse tunnel
/// (`ssh -R remote:localhost:local`, see `create_child`).
#[derive(Deserialize, Copy, Clone, Debug, Eq, PartialEq)]
struct Ports {
    /// Port to forward to on the machine running multi-tunnel.
    local: u32,
    /// Port the remote host listens on; also the port curl probes
    /// during health checks.
    remote: u32,
}

/// Run a health check over SSH using curl. Return value is whether
/// the health check passed or not.
///
/// The check passes only when curl (executed on the remote host) exits
/// with status 0, i.e. it could fetch `http://localhost:<remote_port>`
/// within 10 seconds. Failure to even run ssh counts as a failed check.
async fn run_health_check(connection: &str, remote_port: u32) -> bool {
    let result = Command::new("ssh")
        .args(&[
            connection,
            "curl",
            "--max-time",
            "10",
            &format!("http://localhost:{}", remote_port),
        ])
        .kill_on_drop(true)
        .output()
        .await;
    match result {
        Ok(output) => {
            // success() is false for a nonzero exit code and for
            // termination by signal, matching the old unwrap_or(1) logic
            let passed = output.status.success();
            if !passed {
                // Surface curl's stderr so failures are diagnosable
                info!(
                    "Health check failed ({}): {}",
                    output.status,
                    String::from_utf8_lossy(&output.stderr).trim()
                );
            }
            passed
        }
        Err(err) => {
            info!("Error running health check: {}", err);
            false
        }
    }
}

/// Spawn the `ssh` process that holds one reverse tunnel open.
///
/// `-NTC` runs no remote command, allocates no tty and enables
/// compression; `ServerAliveInterval` keeps the connection alive and
/// `ExitOnForwardFailure` makes ssh exit if the remote forward cannot
/// be established, so the supervisor notices and restarts it.
fn create_child(ports: Ports, connection: &str) -> Result<Child> {
    let forward = format!("{}:localhost:{}", ports.remote, ports.local);
    let mut cmd = Command::new("ssh");
    cmd.args(&[
        "-NTC",
        "-o",
        "ServerAliveInterval=60",
        "-o",
        "ExitOnForwardFailure=yes",
        "-R",
        &forward,
        connection,
    ])
    .kill_on_drop(true);
    let child = cmd.spawn()?;
    Ok(child)
}

/// Supervise a single ssh tunnel: spawn it, health-check it every 60s,
/// and restart it (with a 5s backoff) whenever it dies or fails the
/// health check.
///
/// Loops forever; callers stop it by aborting the enclosing task.
/// `active_counter` is incremented while the tunnel process is up and
/// decremented after it exits.
async fn run_child(
    name: &str,
    connection: &str,
    ports: Ports,
    active_counter: IntGauge,
) {
    loop {
        let mut child = match create_child(ports, connection) {
            Ok(child) => child,
            Err(err) => {
                info!(
                    "{}: unable to spawn ssh, restarting in 5: {}",
                    name, err
                );
                sleep(Duration::from_secs(5)).await;
                continue;
            }
        };
        info!("{}: started", name);
        active_counter.inc();
        let mut last_check = Instant::now();
        loop {
            // Only run the health check every 60s
            if last_check.elapsed() >= Duration::from_secs(60) {
                if !run_health_check(connection, ports.remote).await {
                    if let Err(err) = child.kill().await {
                        info!("{}: error killing ssh after health check failed: {}", name, err);
                    }
                    break;
                } else {
                    debug!("{}: passed health check", name);
                }
                last_check = Instant::now();
            }
            match child.try_wait() {
                Ok(Some(_)) => {
                    // Child has exited; don't log here — the wait()
                    // below reports it (previously this logged twice).
                    break;
                }
                Ok(None) => {
                    // Still alive, just wait 5s before looping
                    sleep(Duration::from_secs(5)).await;
                }
                Err(err) => {
                    info!(
                        "{}: exited with error, restarting in 5: {}",
                        name, err
                    );
                    break;
                }
            }
        }
        // Reap the child and log the restart exactly once, regardless
        // of which branch broke out of the inner loop.
        match child.wait().await {
            Ok(_) => {
                info!("{}: exited, restarting in 5", name)
            }
            Err(err) => {
                info!("{}: exited with error, restarting in 5: {}", name, err)
            }
        }
        active_counter.dec();
        // Backoff for 5s
        sleep(Duration::from_secs(5)).await;
    }
}

/// Per-tunnel supervisor: runs the tunnel via [`run_child`] and
/// restarts it with new settings whenever the watched config changes.
///
/// Returns when the tunnel disappears from the config, or when the
/// config sender is dropped.
async fn handle_tunnel(
    rx: &mut Receiver<Config>,
    name: &str,
    connection: &str,
    ports: &Ports,
    active_counter: IntGauge,
) {
    let mut current_ports = *ports;
    // Track the connection actually in use: comparing against the
    // original `connection` parameter forever would spuriously restart
    // the tunnel on every reload after the connection changes once.
    let mut current_connection = connection.to_string();

    let proc_connection = current_connection.clone();
    let proc_name = name.to_string();
    let active_counter2 = active_counter.clone();
    let mut handle = tokio::spawn(async move {
        run_child(&proc_name, &proc_connection, current_ports, active_counter2)
            .await
    });
    // Note: if this errors, then the sender is dead and we should probably exit
    while rx.changed().await.is_ok() {
        debug!("{}: received new configuration", name);
        let config = rx.borrow().clone();
        match config.tunnels.get(name) {
            Some(ports) => {
                if *ports != current_ports
                    || config.connection != current_connection
                {
                    debug!("{}: configuration has changed", name);
                    // Abort the currently running process
                    handle.abort();
                    info!("{}: aborted", name);
                    // The aborted task can no longer decrement the gauge
                    // itself, so do it here before the replacement task
                    // increments it again; otherwise the gauge drifts up
                    // on every config change. NOTE(review): if the task
                    // was aborted during its backoff (after its own dec)
                    // this double-decrements — confirm acceptable.
                    active_counter.dec();

                    current_ports = *ports;
                    current_connection = config.connection.clone();
                    let proc_connection = current_connection.clone();
                    let proc_name = name.to_string();
                    let active_counter = active_counter.clone();
                    handle = tokio::spawn(async move {
                        run_child(
                            &proc_name,
                            &proc_connection,
                            current_ports,
                            active_counter,
                        )
                        .await
                    });
                }
            }
            None => {
                // This tunnel is no longer configured, end this handler
                handle.abort();
                info!("{}: no longer configured, aborted", name);
                active_counter.dec();
                break;
            }
        }
    }
}

/// The main controlling task that:
/// * spawns a task for each tunnel, keeping track of them
/// * listens for SIGHUP to reload config
/// * spawns new tasks as needed
async fn handle_signal(config_path: &str) {
    // watch channel broadcasts each reloaded Config to all tunnel tasks
    let (tx, rx) = channel(Config::default());
    let mut stream = signal(SignalKind::hangup())
        .expect("unable to listen to SIGHUP signal");
    let active_counter =
        register_int_gauge!("multi_tunnel_active", "active tunnels open")
            .unwrap();
    let total_counter =
        register_int_gauge!("multi_tunnel_total", "total configured tunnels")
            .unwrap();
    // Names of tunnels that currently have a handle_tunnel task running
    let mut running = HashSet::new();
    loop {
        let config = match load_config(config_path).await {
            Ok(config) => {
                // Send the updated config to all tunnel tasks
                if let Err(err) = tx.send(config.clone()) {
                    error!("Error reloading configuration: {}", err);
                }
                config
            }
            Err(err) => {
                error!("Error reloading configuration: {}", err);
                // Wait until we get a SIGHUP
                stream.recv().await;
                debug!("Reloading configuration");
                continue;
            }
        };
        total_counter.set(config.tunnels.len() as i64);
        for (name, ports) in config.tunnels.clone() {
            // insert() returns false if the name was already present,
            // so this is one lookup instead of contains + insert
            if !running.insert(name.clone()) {
                // Already running in a task
                continue;
            }
            let mut rx = rx.clone();
            let connection = config.connection.clone();
            let active_counter = active_counter.clone();
            tokio::spawn(async move {
                handle_tunnel(
                    &mut rx,
                    &name,
                    &connection,
                    &ports,
                    active_counter,
                )
                .await
            });
        }
        // Remove anything in the hashset that's no longer configured;
        // those tasks end themselves (handle_tunnel's None branch)
        running.retain(|name| config.tunnels.contains_key(name));
        // Wait until we get a SIGHUP
        stream.recv().await;
        debug!("Reloading configuration");
    }
}

/// Load configuration from the filesystem
///
/// Reads the TOML file at `config_path` and deserializes it into a
/// [`Config`], returning an [`Error`] on I/O or parse failure.
async fn load_config(config_path: &str) -> Result<Config> {
    let contents = fs::read_to_string(config_path).await?;
    let config = toml::from_str(&contents)?;
    Ok(config)
}

/// Entry point: parse the single config-file argument, set up logging
/// and the Prometheus exporter, then run the reload loop forever.
#[tokio::main]
async fn main() {
    // XXX: consider using clap if this gets more complex
    let args: Vec<_> = std::env::args().collect();
    if args.len() != 2 {
        // Distinguish "no config given" from "too many arguments";
        // previously both printed "missing config file".
        if args.len() < 2 {
            eprintln!("Error: missing config file");
        } else {
            eprintln!("Error: expected exactly one config file");
        }
        eprintln!("{}", USAGE);
        std::process::exit(1);
    }
    tracing_subscriber::FmtSubscriber::builder()
        .with_env_filter(EnvFilter::new(
            // Default to RUST_LOG=info if not explicitly set
            std::env::var("RUST_LOG").unwrap_or_else(|_| "info".to_string()),
        ))
        .init();
    let config_path = args[1].to_string();
    info!("Starting multi-tunnel!");
    // The exporter blocks its thread serving HTTP, so keep it off the
    // async runtime's worker threads.
    tokio::task::spawn_blocking(|| {
        // TODO: make port configurable
        prometheus_exporter::start("0.0.0.0:46581".parse().unwrap()).unwrap();
    });
    handle_signal(&config_path).await;
}