use crate::config::{Protocol, ServicePort};
use crate::proxy::{proxy_handler, udp_handler};
use crate::server::{InnisfreeServer, Provider, ServerSpec};
use crate::ssh::SshKeypair;
use crate::state::{
remove_state_for_service, TunnelConfig, TunnelIdentity, TunnelStateDir, TunnelStatus,
};
use crate::wg::{LocalWg, WireguardManager};
use anyhow::{anyhow, bail, Context, Result};
use futures::future::{join_all, BoxFuture};
use std::net::{IpAddr, SocketAddr};
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::signal::unix::{signal, SignalKind};
/// Owns everything needed to run one innisfree tunnel: the remote server,
/// the Wireguard configuration for both ends, the SSH keypairs used to talk
/// to the server, and the on-disk state directory for the service.
pub struct TunnelManager {
    /// Service ports forwarded through the tunnel.
    pub(crate) services: Vec<ServicePort>,
    /// Handle to the remote server created by a `Provider`.
    pub(crate) server: Box<dyn InnisfreeServer>,
    /// Tunnel/service name; also keys the local state directory.
    pub(crate) name: String,
    /// Per-service state directory (config, identity, keys, known_hosts, status).
    pub(crate) state: TunnelStateDir,
    /// Wireguard configuration for the local and remote devices.
    pub(crate) wg: WireguardManager,
    /// Keypair used to authenticate SSH sessions to the server.
    pub(crate) ssh_client_keypair: SshKeypair,
    /// Server host keypair; its public half is pinned in known_hosts.
    pub(crate) ssh_server_keypair: SshKeypair,
    /// Running local Wireguard interface: `None` until `up` succeeds, reset to
    /// `None` by `clean`. NOTE(review): declared last, so it is dropped last;
    /// confirm whether that drop ordering is relied upon before reordering.
    local_wg: Option<LocalWg>,
}
impl TunnelManager {
    /// Returns the public IPv4 address of the remote server, as reported by
    /// the server handle.
    pub fn server_ipv4(&self) -> Result<IpAddr> {
        let addr = self.server.ipv4_address()?;
        Ok(addr)
    }

    /// Returns the address assigned to the local end of the Wireguard tunnel.
    pub fn local_wg_address(&self) -> IpAddr {
        let iface = &self.wg.local_device.interface;
        iface.address
    }

    /// Returns the service ports this tunnel forwards.
    pub fn services(&self) -> &[ServicePort] {
        self.services.as_slice()
    }
}
impl TunnelManager {
    /// Provisions a new remote server for the tunnel described by `config`.
    ///
    /// This opens the per-service state directory, builds the Wireguard
    /// configuration and fresh SSH client/server keypairs, and persists the
    /// tunnel config and Wireguard identity. It then asks `provider` to create
    /// the server. If `config.floating_ip` is set, that IP is assigned to the
    /// new server. The tunnel itself is not brought up; call
    /// [`TunnelManager::up`] for that.
    ///
    /// # Errors
    ///
    /// Fails if any state or key setup step, server creation, or floating-IP
    /// assignment fails. When floating-IP assignment fails, the freshly
    /// created server is destroyed (best effort) before returning. Otherwise it
    /// would be orphaned, because the caller never receives a `TunnelManager`
    /// whose [`TunnelManager::clean`] could remove it.
    pub async fn new(provider: Box<dyn Provider>, config: TunnelConfig) -> Result<TunnelManager> {
        let state = TunnelStateDir::for_service(&config.name)?;
        let wg = WireguardManager::new(&config.name)?;
        let ssh_client_keypair = SshKeypair::new()?;
        let ssh_server_keypair = SshKeypair::new()?;
        state
            .write_config(&config)
            .context("persisting tunnel config")?;
        state
            .write_identity(&TunnelIdentity {
                wireguard: wg.clone(),
            })
            .context("persisting tunnel identity")?;
        let spec = ServerSpec {
            name: config.name.clone(),
            services: config.services.clone(),
            wg_mgr: wg.clone(),
            ssh_client_keypair: ssh_client_keypair.clone(),
            ssh_server_keypair: ssh_server_keypair.clone(),
        };
        // `mut` so the server can be torn down below if a follow-up step fails.
        let mut server = provider.create(&spec).await?;
        if let Some(ip) = config.floating_ip {
            if let Err(e) = server.assign_floating_ip(ip).await {
                // Nothing else tracks this server yet, so remove it here.
                if let Err(destroy_err) = server.destroy().await {
                    tracing::error!(
                        "failed to destroy server after floating IP assignment failed \
                         (manual cleanup may be required): {destroy_err:#}"
                    );
                }
                return Err(e).context("assigning floating IP to server");
            }
        }
        Ok(TunnelManager {
            name: config.name,
            services: config.services,
            server,
            ssh_client_keypair,
            ssh_server_keypair,
            state,
            wg,
            local_wg: None,
        })
    }

    /// Brings the tunnel up end to end.
    ///
    /// The steps are:
    /// 1. Wait for SSH and cloud-init on the server.
    /// 2. Write the local Wireguard config, with the peer endpoint set to the
    ///    server's IPv4.
    /// 3. Bring up the remote interface, then the local interface.
    /// 4. Verify connectivity with a ping.
    /// 5. Write the ready marker.
    ///
    /// # Errors
    ///
    /// Fails at the first step that fails. A local interface that was already
    /// started is kept in `self`, so [`TunnelManager::clean`] can still tear it
    /// down.
    pub async fn up(&mut self) -> Result<()> {
        self.wait_for_ssh().await?;
        tracing::debug!("Configuring remote proxy...");
        self.wait_for_cloudinit()
            .await
            .context("failed while waiting for cloudinit")?;
        let ip = self.server.ipv4_address()?;
        tracing::debug!("Configuring tunnel...");
        self.wg.local_device.peer.endpoint = Some(ip);
        self.wg
            .local_device
            .write_to(&self.state.wg_conf(), &self.services)
            .context("failed to write wireguard configs")?;
        tracing::debug!("Bringing up remote Wireguard interface");
        self.bring_up_remote_wg()
            .await
            .context("failed to bring up remote wg interface")?;
        tracing::debug!("Bringing up local Wireguard interface");
        let local_wg = LocalWg::start(&self.wg.local_device)
            .await
            .context("failed to bring up local wg interface")?;
        self.local_wg = Some(local_wg);
        self.test_connection().await?;
        self.write_ready_marker()
            .context("writing ready marker after tunnel came up")?;
        Ok(())
    }

    /// Records the server IP, the server ID and the current UTC time as the
    /// tunnel status in the state directory.
    fn write_ready_marker(&self) -> Result<()> {
        let status = TunnelStatus {
            ip: self.server.ipv4_address()?,
            droplet_id: self.server.server_id(),
            ready_at: time::OffsetDateTime::now_utc(),
        };
        self.state
            .write_status(&status)
            .context("writing tunnel status")?;
        Ok(())
    }

    /// Runs `cloud-init status --long --wait` on the server over SSH, so this
    /// returns once that command exits successfully.
    async fn wait_for_cloudinit(&self) -> Result<()> {
        let cmd: Vec<&str> = vec!["cloud-init", "status", "--long", "--wait"];
        self.run_ssh_cmd(cmd).await
    }

    /// Polls TCP port 22 on the server until a connection succeeds.
    ///
    /// Each attempt is bounded by a 5 s connect timeout, and failed attempts
    /// are followed by a 10 s sleep. There is no overall deadline: this loops
    /// until the port opens. It only fails if the server IP cannot be read.
    async fn wait_for_ssh(&self) -> Result<()> {
        let dest_ip = SocketAddr::new(self.server.ipv4_address()?, 22);
        loop {
            match tokio::time::timeout(Duration::from_secs(5), TcpStream::connect(dest_ip)).await {
                Ok(Ok(_)) => {
                    tracing::debug!("SSH port is open, proceeding");
                    return Ok(());
                }
                // Refused connections and timeouts are both treated as "not yet".
                Ok(Err(_)) | Err(_) => {
                    tracing::debug!("Waiting for ssh...");
                    tokio::time::sleep(Duration::from_secs(10)).await;
                }
            }
        }
    }

    /// Waits for SIGTERM or SIGINT, then tears the tunnel down via
    /// [`TunnelManager::clean`].
    pub async fn block(&mut self) -> Result<()> {
        let mut sigterm = signal(SignalKind::terminate()).context("registering SIGTERM handler")?;
        let mut sigint = signal(SignalKind::interrupt()).context("registering SIGINT handler")?;
        tokio::select! {
            _ = sigterm.recv() => tracing::warn!("Received SIGTERM, exiting gracefully"),
            _ = sigint.recv() => tracing::warn!("Received SIGINT, exiting gracefully"),
        }
        self.clean().await?;
        tracing::info!("Clean up complete");
        Ok(())
    }

    /// Confirms that traffic flows across the tunnel by sending a single ping
    /// (`ping -c1 -w5`) to the remote Wireguard interface address.
    ///
    /// # Errors
    ///
    /// Fails if `ping` cannot be spawned or exits unsuccessfully.
    async fn test_connection(&self) -> Result<()> {
        let ip = self.wg.remote_device.interface.address;
        let status = tokio::process::Command::new("ping")
            .arg("-c1")
            .arg("-w5")
            .arg(ip.to_string())
            .stdout(std::process::Stdio::null())
            .stderr(std::process::Stdio::null())
            .status()
            .await
            .context("invoking ping to test tunnel connectivity")?;
        // `status()` only errors if ping could not be run at all. An
        // unanswered ping shows up as a non-zero exit status, so check it here.
        if !status.success() {
            bail!("Failed to ping remote Wireguard interface at {ip} ({status}), tunnel broken");
        }
        tracing::debug!("Confirmed tunnel is established, able to ping across it");
        Ok(())
    }

    /// Runs `wg-quick up /tmp/innisfree.conf` on the server over SSH.
    ///
    /// NOTE(review): assumes provisioning placed the config at that path; it
    /// is not written by this impl. Confirm against the server setup.
    async fn bring_up_remote_wg(&self) -> Result<()> {
        let cmd = vec!["wg-quick", "up", "/tmp/innisfree.conf"];
        tracing::trace!("Activating remote wg interface");
        self.run_ssh_cmd(cmd).await
    }

    /// Writes (overwrites) a known_hosts file that pins the server's IPv4
    /// address to the expected host public key, and returns its path.
    fn known_hosts(&self) -> Result<String> {
        let ipv4_address = &self.server.ipv4_address()?;
        let server_host_key = &self.ssh_server_keypair.public;
        let host_line = format!("{} {}", ipv4_address, server_host_key);
        let fpath = self.state.known_hosts();
        std::fs::write(&fpath, host_line).context("Failed to create known_hosts")?;
        Ok(fpath.display().to_string())
    }

    /// Runs `cmd` on the server over SSH and waits for it to finish.
    ///
    /// The client key and known_hosts file are written to the state directory
    /// first. Non-empty stdout and stderr are logged at debug level.
    ///
    /// # Errors
    ///
    /// Fails if the key or known_hosts cannot be written, if `ssh` cannot be
    /// spawned, or if it exits non-zero. In the last case the error includes
    /// the captured stderr and stdout.
    async fn run_ssh_cmd(&self, cmd: Vec<&str>) -> Result<()> {
        let key_path = self.state.client_key();
        self.ssh_client_keypair.write_to(&key_path)?;
        let known_hosts = self.known_hosts()?;
        let ip = self.server.ipv4_address()?.to_string();
        let pretty_cmd = cmd.join(" ");
        let mut cmd_args = ssh_base_args(&key_path.display().to_string(), &known_hosts, &ip);
        cmd_args.extend(cmd.iter().map(|s| (*s).to_string()));
        let output = tokio::process::Command::new("ssh")
            .args(&cmd_args)
            .output()
            .await
            .context("invoking ssh")?;
        let stdout = String::from_utf8_lossy(&output.stdout);
        let stderr = String::from_utf8_lossy(&output.stderr);
        if !output.status.success() {
            bail!(
                "remote command `{pretty_cmd}` failed (ssh exit {}):\n\
                 --- stderr ---\n{}\n--- stdout ---\n{}",
                output.status,
                stderr.trim_end(),
                stdout.trim_end(),
            );
        }
        if !stdout.trim().is_empty() {
            tracing::debug!("remote `{pretty_cmd}` stdout:\n{}", stdout.trim_end());
        }
        if !stderr.trim().is_empty() {
            tracing::debug!("remote `{pretty_cmd}` stderr:\n{}", stderr.trim_end());
        }
        Ok(())
    }

    /// Tears the tunnel down.
    ///
    /// The steps are:
    /// 1. Drop the local Wireguard handle.
    /// 2. Destroy the remote server.
    /// 3. Remove the service's local state directory.
    ///
    /// Both teardown steps are attempted even if the first one fails. If both
    /// fail, the server-destroy error is returned, because an orphaned server
    /// is the costlier leak.
    pub async fn clean(&mut self) -> Result<()> {
        tracing::debug!("removing local Wireguard interface");
        self.local_wg = None;
        let destroy_result = self
            .server
            .destroy()
            .await
            .context("destroying remote server");
        if let Err(e) = &destroy_result {
            tracing::error!(
                "failed to destroy remote server (manual cleanup may be required): {e:#}"
            );
        }
        let config_result =
            remove_state_for_service(&self.name).context("removing local state dir");
        destroy_result?;
        config_result
    }
}
pub fn get_server_ip(service_name: &str) -> Result<IpAddr> {
tracing::trace!("Looking up server IP from ready marker");
let state = TunnelStateDir::open(service_name)?;
if let Some(status) = state.read_status().context("reading status.json")? {
return Ok(status.ip);
}
let fpath = state.ip_marker();
let content = std::fs::read_to_string(&fpath)
.with_context(|| format!("reading {} (tunnel may not be ready yet)", fpath.display()))?;
let ip: IpAddr = content
.trim()
.parse()
.with_context(|| format!("parsing IP from {}", fpath.display()))?;
Ok(ip)
}
/// Opens an interactive SSH session to the server behind `service_name`.
///
/// It uses the client key, the pinned known_hosts file and the server IP
/// stored in the service's state directory. The session's exit status is not
/// inspected: a remote shell's exit code is not treated as a failure.
///
/// # Errors
///
/// Fails if the state or IP cannot be read, or if `ssh` cannot be spawned.
pub async fn open_shell(service_name: &str) -> Result<()> {
    let dir = TunnelStateDir::open(service_name)?;
    let identity_file = dir.client_key().display().to_string();
    let known_hosts_file = dir.known_hosts().display().to_string();
    let server_ip = get_server_ip(service_name)?.to_string();
    let _exit_status = tokio::process::Command::new("ssh")
        .args(ssh_base_args(&identity_file, &known_hosts_file, &server_ip))
        .status()
        .await
        .context("SSH interactive session failed")?;
    Ok(())
}
/// Builds the common `ssh` argument list.
///
/// The arguments set login user `innisfree`, identity file `key_path`, a
/// dedicated known_hosts file, a 5 s connect timeout, and the target `ip`.
/// Any remote command should be appended after these.
fn ssh_base_args(key_path: &str, known_hosts_path: &str, ip: &str) -> Vec<String> {
    let user_known_hosts = format!("UserKnownHostsFile={known_hosts_path}");
    [
        "-l",
        "innisfree",
        "-i",
        key_path,
        "-o",
        user_known_hosts.as_str(),
        "-o",
        "ConnectTimeout=5",
        ip,
    ]
    .iter()
    .copied()
    .map(String::from)
    .collect()
}
pub async fn run_proxy(
local_ip: IpAddr,
dest_ip: IpAddr,
services: Vec<ServicePort>,
) -> Result<()> {
let mut tasks: Vec<BoxFuture<'static, Result<()>>> = vec![];
for s in services {
let listen_addr: SocketAddr = format!("{}:{}", local_ip, &s.local_port).parse()?;
let dest_addr: SocketAddr = format!("{}:{}", dest_ip, &s.port).parse()?;
let h: BoxFuture<'static, Result<()>> = match s.protocol {
Protocol::Tcp => Box::pin(proxy_handler(listen_addr, dest_addr)),
Protocol::Udp => Box::pin(udp_handler(listen_addr, dest_addr)),
};
tasks.push(h);
}
let proxy_tasks = join_all(tasks).await;
tracing::warn!("Proxy stopped unexpectedly, no longer forwarding traffic");
for t in proxy_tasks {
match t {
Ok(t) => {
tracing::debug!("Service proxy returned ok: {:?}", t);
}
Err(e) => {
return Err(anyhow!("Service proxy failed: {}", e));
}
}
}
Ok(())
}