use anyhow::{anyhow, bail, Context, Result};
use directories::BaseDirs;
use fs2::FileExt;
use serde::{Deserialize, Serialize};
use std::fs::{File, OpenOptions};
use std::io::Write;
use std::net::IpAddr;
use std::os::unix::fs::OpenOptionsExt;
use std::path::{Path, PathBuf};
use time::OffsetDateTime;
use crate::config::ServicePort;
use crate::wg::WireguardManager;
/// File name of the per-service lock file that `acquire_lock` creates inside a
/// state directory. `dir_has_leftover_state` skips an entry with this name when
/// deciding whether the directory still holds state from an earlier run.
const LOCK_FILENAME: &str = ".lock";
/// Settings describing one tunnel.
///
/// `TunnelStateDir::write_config` serializes this as pretty-printed TOML and
/// `TunnelStateDir::read_config` parses it back.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TunnelConfig {
/// Name of the tunnel. NOTE(review): presumably the same service name used to
/// locate the state directory — confirm against callers.
pub name: String,
/// Ports carried by the tunnel (the tests build this from `80/TCP,443/TCP`).
pub services: Vec<ServicePort>,
/// NOTE(review): presumably the address traffic is forwarded to — confirm.
pub dest_ip: IpAddr,
/// Optional floating IP; `None` when no floating IP is configured.
pub floating_ip: Option<IpAddr>,
/// Provider identifier (the tests use `digitalocean`).
pub provider: String,
/// Provider region (the tests use `sfo3`).
pub region: String,
/// Provider instance size (the tests use `s-1vcpu-1gb`).
pub size: String,
/// Provider image name (the tests use `debian-13-x64`).
pub image: String,
}
/// Secret material belonging to one tunnel.
///
/// Stored as JSON in `identity.json` by `TunnelStateDir::write_identity`, which
/// opens that file with mode `0o600`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TunnelIdentity {
/// WireGuard state; the tests read both public and private key bytes back
/// out of it, so a serialized identity contains private keys.
pub wireguard: WireguardManager,
}
/// Runtime status of a tunnel, stored as JSON in `status.json` by
/// `TunnelStateDir::write_status`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TunnelStatus {
/// Tunnel IP address; `write_status` also writes it on its own to the `ip`
/// marker file.
pub ip: IpAddr,
/// NOTE(review): presumably the provider's identifier for the remote host —
/// confirm against the provider code.
pub droplet_id: String,
/// Timestamp, serialized as an RFC 3339 string.
/// NOTE(review): presumably the moment the tunnel became ready — confirm.
#[serde(with = "time::serde::rfc3339")]
pub ready_at: OffsetDateTime,
}
/// Handle to the on-disk state directory of a single service.
///
/// Build one with `TunnelStateDir::for_service` (creates the directory and
/// takes the lock) or `TunnelStateDir::open` (requires an existing directory
/// and takes no lock).
#[derive(Debug)]
pub struct TunnelStateDir {
/// The state directory itself: the state base directory joined with the
/// service name.
path: PathBuf,
/// Service name; also used to name the config file and the WireGuard conf.
name: String,
/// Open lock file when built by `for_service`, `None` when built by `open`.
/// Never read; it is stored only to keep the file handle alive for as long
/// as this value exists. The tests show the lock can be taken again once
/// the value is dropped.
_lock: Option<File>,
}
impl TunnelStateDir {
    /// Claims the state directory for `service_name` ahead of a fresh run.
    ///
    /// In order: migrates a legacy directory if one exists, creates the state
    /// directory, takes the exclusive lock, and then refuses to continue if
    /// the directory holds anything besides the lock file or if a config file
    /// for the service already exists.
    ///
    /// # Errors
    ///
    /// Fails when the base directories cannot be resolved, when the directory
    /// cannot be created, when the lock cannot be taken, or when leftover
    /// state or config is found.
    pub fn for_service(service_name: &str) -> Result<Self> {
        let path = base_dir()?.join(service_name);
        migrate_legacy_dir(service_name, &path)?;
        std::fs::create_dir_all(&path)
            .with_context(|| format!("creating state dir {}", path.display()))?;
        // Lock first so that the leftover check below cannot race with another
        // process claiming the same service.
        let lock = acquire_lock(service_name, &path)?;
        let cfg_file = config_file_for(service_name)?;
        if dir_has_leftover_state(&path)? || cfg_file.exists() {
            bail!(
                "found leftover state at {} (or config at {}) from a previous \
                 `innisfree up --name {service_name}` that did not clean up.\n\
                 Run `innisfree clean --name {service_name}` to wipe it, \
                 or pass `--force` to overwrite.",
                path.display(),
                cfg_file.display(),
            );
        }
        Ok(Self {
            path,
            name: service_name.to_string(),
            _lock: Some(lock),
        })
    }

    /// Opens the existing state directory for `service_name` without locking.
    ///
    /// A legacy directory is migrated first if one exists. Nothing is created.
    ///
    /// # Errors
    ///
    /// Fails when the base directories cannot be resolved, when migration
    /// fails, or when no directory exists for the service.
    pub fn open(service_name: &str) -> Result<Self> {
        let path = base_dir()?.join(service_name);
        migrate_legacy_dir(service_name, &path)?;
        if !path.is_dir() {
            return Err(anyhow!(
                "no state for service '{service_name}' at {} — was the tunnel ever brought up?",
                path.display()
            ));
        }
        Ok(Self {
            path,
            name: service_name.to_string(),
            _lock: None,
        })
    }

    /// Path of this service's TOML config file.
    ///
    /// The file lives under the config base directory, not inside the state
    /// directory.
    ///
    /// # Errors
    ///
    /// Fails when the base directories cannot be resolved.
    pub fn config_file(&self) -> Result<PathBuf> {
        config_file_for(&self.name)
    }

    /// Path of `identity.json` inside the state directory.
    pub fn identity_file(&self) -> PathBuf {
        self.path.join("identity.json")
    }

    /// Serializes `config` as pretty-printed TOML and writes it to
    /// [`config_file`](Self::config_file), creating the parent directory first.
    ///
    /// # Errors
    ///
    /// Fails when the path cannot be resolved, the parent directory cannot be
    /// created, serialization fails, or the write fails.
    pub fn write_config(&self, config: &TunnelConfig) -> Result<()> {
        let path = self.config_file()?;
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent)
                .with_context(|| format!("creating config dir {}", parent.display()))?;
        }
        let body = toml::to_string_pretty(config).context("serializing tunnel config to TOML")?;
        std::fs::write(&path, body).with_context(|| format!("writing {}", path.display()))?;
        Ok(())
    }

    /// Reads the config file back.
    ///
    /// Returns `Ok(None)` when the file does not exist.
    ///
    /// # Errors
    ///
    /// Fails when the path cannot be resolved, on any read error other than
    /// "not found", or when the contents do not parse.
    pub fn read_config(&self) -> Result<Option<TunnelConfig>> {
        let path = self.config_file()?;
        let body = match std::fs::read_to_string(&path) {
            Ok(s) => s,
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(None),
            Err(e) => return Err(e).with_context(|| format!("reading {}", path.display())),
        };
        let cfg: TunnelConfig =
            toml::from_str(&body).with_context(|| format!("parsing {}", path.display()))?;
        Ok(Some(cfg))
    }

    /// Serializes `identity` as pretty-printed JSON into
    /// [`identity_file`](Self::identity_file), leaving the file at mode `0o600`.
    ///
    /// # Errors
    ///
    /// Fails when serialization fails, or when the file cannot be opened,
    /// restricted to `0o600`, or written.
    pub fn write_identity(&self, identity: &TunnelIdentity) -> Result<()> {
        use std::os::unix::fs::PermissionsExt;

        let path = self.identity_file();
        let body =
            serde_json::to_vec_pretty(identity).context("serializing tunnel identity to JSON")?;
        let mut f = OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .mode(0o600)
            .open(&path)
            .with_context(|| format!("opening {} for writing", path.display()))?;
        // `mode` only takes effect when the file is newly created. A file that
        // already existed keeps its old permission bits, so tighten them
        // explicitly before any key material is written.
        f.set_permissions(std::fs::Permissions::from_mode(0o600))
            .with_context(|| format!("restricting permissions on {}", path.display()))?;
        f.write_all(&body)
            .with_context(|| format!("writing {}", path.display()))?;
        Ok(())
    }

    /// Reads `identity.json` back.
    ///
    /// Returns `Ok(None)` when the file does not exist.
    ///
    /// # Errors
    ///
    /// Fails on any read error other than "not found", or when the contents
    /// do not parse.
    pub fn read_identity(&self) -> Result<Option<TunnelIdentity>> {
        let path = self.identity_file();
        let bytes = match std::fs::read(&path) {
            Ok(b) => b,
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(None),
            Err(e) => return Err(e).with_context(|| format!("reading {}", path.display())),
        };
        let id: TunnelIdentity = serde_json::from_slice(&bytes)
            .with_context(|| format!("parsing {}", path.display()))?;
        Ok(Some(id))
    }

    /// Path of `status.json` inside the state directory.
    pub fn status_file(&self) -> PathBuf {
        self.path.join("status.json")
    }

    /// Path of the `ip` marker file inside the state directory.
    pub fn ip_marker(&self) -> PathBuf {
        self.path.join("ip")
    }

    /// Writes `status` as pretty-printed JSON to `status.json`, then writes
    /// the IP alone, followed by a newline, to the `ip` marker file.
    ///
    /// The two writes are separate; if the second fails, `status.json` has
    /// already been replaced.
    ///
    /// # Errors
    ///
    /// Fails when serialization fails or either write fails.
    pub fn write_status(&self, status: &TunnelStatus) -> Result<()> {
        let json_path = self.status_file();
        let json =
            serde_json::to_vec_pretty(status).context("serializing tunnel status to JSON")?;
        std::fs::write(&json_path, &json)
            .with_context(|| format!("writing {}", json_path.display()))?;
        let ip_path = self.ip_marker();
        std::fs::write(&ip_path, format!("{}\n", status.ip))
            .with_context(|| format!("writing {}", ip_path.display()))?;
        Ok(())
    }

    /// Reads `status.json` back.
    ///
    /// Returns `Ok(None)` when the file does not exist.
    ///
    /// # Errors
    ///
    /// Fails on any read error other than "not found", or when the contents
    /// do not parse.
    pub fn read_status(&self) -> Result<Option<TunnelStatus>> {
        let path = self.status_file();
        let bytes = match std::fs::read(&path) {
            Ok(b) => b,
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(None),
            Err(e) => {
                return Err(e).with_context(|| format!("reading {}", path.display()));
            }
        };
        let status: TunnelStatus = serde_json::from_slice(&bytes)
            .with_context(|| format!("parsing {}", path.display()))?;
        Ok(Some(status))
    }

    /// Path of the `known_hosts` file inside the state directory.
    pub fn known_hosts(&self) -> PathBuf {
        self.path.join("known_hosts")
    }

    /// Path of the `client_id_ed25519` file inside the state directory.
    pub fn client_key(&self) -> PathBuf {
        self.path.join("client_id_ed25519")
    }

    /// Path of `<service name>.conf` inside the state directory.
    pub fn wg_conf(&self) -> PathBuf {
        self.path.join(format!("{}.conf", self.name))
    }
}
/// Deletes everything stored for `service_name`: first the state directory
/// (recursively), then the service's config file.
///
/// A missing directory or a missing config file is not an error, so calling
/// this twice in a row succeeds.
///
/// # Errors
///
/// Fails when the base directories cannot be resolved or when a removal that
/// was attempted fails.
pub fn remove_state_for_service(service_name: &str) -> Result<()> {
    let state_dir = base_dir()?.join(service_name);
    let dir_removal = if state_dir.is_dir() {
        std::fs::remove_dir_all(&state_dir)
    } else {
        Ok(())
    };
    dir_removal.with_context(|| format!("removing {}", state_dir.display()))?;

    let config_path = config_file_for(service_name)?;
    let config_removal = if config_path.exists() {
        std::fs::remove_file(&config_path)
    } else {
        Ok(())
    };
    config_removal.with_context(|| format!("removing config {}", config_path.display()))?;

    Ok(())
}
/// Moves a service's directory from the legacy location to `target`.
///
/// Does nothing when `target` already exists, when the legacy base directory
/// cannot be determined, or when the service has no legacy directory.
/// Otherwise the parent of `target` is created, the legacy directory is
/// renamed into place, and the move is logged at info level.
///
/// # Errors
///
/// Fails when the parent directory cannot be created or the rename fails.
fn migrate_legacy_dir(service_name: &str, target: &Path) -> Result<()> {
    if target.exists() {
        return Ok(());
    }

    let legacy = match legacy_base_dir() {
        Some(base) => base.join(service_name),
        None => return Ok(()),
    };

    if legacy.is_dir() {
        if let Some(parent) = target.parent() {
            std::fs::create_dir_all(parent)
                .with_context(|| format!("creating state parent {}", parent.display()))?;
        }

        std::fs::rename(&legacy, target).with_context(|| {
            format!(
                "migrating legacy state {} -> {}",
                legacy.display(),
                target.display()
            )
        })?;

        tracing::info!(
            "migrated legacy state dir {} -> {}",
            legacy.display(),
            target.display()
        );
    }

    Ok(())
}
fn base_dir() -> Result<PathBuf> {
let bd = BaseDirs::new().ok_or_else(|| anyhow!("could not resolve home directory"))?;
let state_root = bd
.state_dir()
.map(Path::to_path_buf)
.unwrap_or_else(|| bd.home_dir().join(".local").join("state"));
Ok(state_root.join("innisfree"))
}
/// Directory holding the per-service config files: `<config dir>/innisfree`.
///
/// # Errors
///
/// Fails when `BaseDirs::new` cannot resolve the home directory.
fn config_base_dir() -> Result<PathBuf> {
    match BaseDirs::new() {
        Some(dirs) => Ok(dirs.config_dir().join("innisfree")),
        None => Err(anyhow!("could not resolve home directory")),
    }
}
/// Path of the config file for `service_name`: `<service_name>.toml` inside
/// the config base directory.
///
/// # Errors
///
/// Fails when the config base directory cannot be resolved.
fn config_file_for(service_name: &str) -> Result<PathBuf> {
    let mut config_path = config_base_dir()?;
    config_path.push(format!("{service_name}.toml"));
    Ok(config_path)
}
/// Legacy location of the state directories: `<home>/.config/innisfree`.
///
/// The path is built from the home directory directly rather than from the
/// configured config dir. Returns `None` when `BaseDirs::new` cannot resolve
/// the home directory.
fn legacy_base_dir() -> Option<PathBuf> {
    let dirs = BaseDirs::new()?;
    let mut legacy = dirs.home_dir().to_path_buf();
    legacy.push(".config");
    legacy.push("innisfree");
    Some(legacy)
}
/// Takes the exclusive lock for `service_name` on the lock file inside `dir`
/// and records this process's PID in that file.
///
/// The lock file is opened (and created if missing) without truncation, so a
/// PID written by the current holder stays readable. If the non-blocking lock
/// attempt fails, the error message names the PID found in the file when one
/// can be parsed, and otherwise names the lock path. On success the file is
/// emptied and the current PID is written, followed by a newline.
///
/// Returns the open file; the caller keeps it to hold on to the lock.
///
/// # Errors
///
/// Fails when the lock file cannot be opened, when the lock attempt fails,
/// or when truncating or writing the file fails.
fn acquire_lock(service_name: &str, dir: &Path) -> Result<File> {
    let lock_path = dir.join(LOCK_FILENAME);

    let mut options = OpenOptions::new();
    options.read(true).write(true).create(true).truncate(false);
    let file = options
        .open(&lock_path)
        .with_context(|| format!("opening lock file {}", lock_path.display()))?;

    if let Err(lock_err) = file.try_lock_exclusive() {
        // Best effort: whatever the file holds is only used to enrich the
        // error message.
        let detail = std::fs::read_to_string(&lock_path)
            .ok()
            .and_then(|raw| raw.trim().parse::<u32>().ok())
            .map_or_else(
                || format!("(lock at {} held)", lock_path.display()),
                |pid| format!("(pid {pid} holds the lock)"),
            );
        return Err(anyhow!(lock_err).context(format!(
            "another `innisfree up --name {service_name}` appears to be running {detail}"
        )));
    }

    // The lock is ours; replace any stale PID with our own.
    file.set_len(0)
        .with_context(|| format!("truncating {}", lock_path.display()))?;
    let mut handle = &file;
    writeln!(handle, "{}", std::process::id())
        .with_context(|| format!("writing pid to {}", lock_path.display()))?;

    Ok(file)
}
/// Reports whether `dir` contains any entry other than the lock file.
///
/// Scanning stops at the first such entry.
///
/// # Errors
///
/// Fails when the directory cannot be listed or when reading an entry fails
/// before a non-lock entry has been found.
fn dir_has_leftover_state(dir: &Path) -> Result<bool> {
    let entries =
        std::fs::read_dir(dir).with_context(|| format!("reading state dir {}", dir.display()))?;
    for file_name in entries.map(|entry| entry.map(|e| e.file_name())) {
        if file_name? != LOCK_FILENAME {
            return Ok(true);
        }
    }
    Ok(false)
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Scratch locations handed to a test body by [`in_sandbox`].
    struct Sandbox {
        /// Value of `XDG_STATE_HOME` while the test body runs.
        state_root: PathBuf,
        /// Value of `XDG_CONFIG_HOME` while the test body runs.
        cfg_root: PathBuf,
        /// Value of `HOME` while the test body runs; the directory exists.
        home: PathBuf,
    }

    /// Runs `body` with `XDG_STATE_HOME`, `XDG_CONFIG_HOME` and `HOME` all
    /// pointing into a fresh temporary directory.
    ///
    /// All three variables are always pinned, so no test consults the real
    /// config directory of whoever runs the suite.
    fn in_sandbox(body: impl FnOnce(&Sandbox)) {
        let td = TempDir::new().unwrap();
        let sandbox = Sandbox {
            state_root: td.path().join("state"),
            cfg_root: td.path().join("config"),
            home: td.path().join("home"),
        };
        std::fs::create_dir_all(&sandbox.home).unwrap();
        temp_env::with_vars(
            [
                ("XDG_STATE_HOME", Some(sandbox.state_root.to_str().unwrap())),
                ("XDG_CONFIG_HOME", Some(sandbox.cfg_root.to_str().unwrap())),
                ("HOME", Some(sandbox.home.to_str().unwrap())),
            ],
            || body(&sandbox),
        );
    }

    /// Builds a config with fixed provider settings and the given variable parts.
    fn sample_config(
        name: &str,
        services: Vec<ServicePort>,
        floating_ip: Option<IpAddr>,
    ) -> TunnelConfig {
        TunnelConfig {
            name: name.to_string(),
            services,
            dest_ip: "127.0.0.1".parse().unwrap(),
            floating_ip,
            provider: "digitalocean".to_string(),
            region: "sfo3".to_string(),
            size: "s-1vcpu-1gb".to_string(),
            image: "debian-13-x64".to_string(),
        }
    }

    #[test]
    fn for_service_uses_xdg_state_home() {
        in_sandbox(|sb| {
            let dir = TunnelStateDir::for_service("svc-fresh").unwrap();
            let expected = sb.state_root.join("innisfree").join("svc-fresh");
            assert_eq!(dir.path, expected);
            assert!(expected.is_dir());
        });
    }

    #[test]
    fn migrates_from_legacy_config_dir() {
        in_sandbox(|sb| {
            let legacy = sb.home.join(".config").join("innisfree").join("svc-mig");
            std::fs::create_dir_all(&legacy).unwrap();
            std::fs::write(legacy.join("ip"), "1.2.3.4\n").unwrap();

            let dir = TunnelStateDir::open("svc-mig").unwrap();
            let expected = sb.state_root.join("innisfree").join("svc-mig");
            assert_eq!(dir.path, expected);
            assert!(expected.join("ip").is_file());
            assert!(!legacy.exists());
        });
    }

    #[test]
    fn for_service_writes_pid_into_lock_file() {
        in_sandbox(|_sb| {
            let dir = TunnelStateDir::for_service("svc-pid").unwrap();
            let lock = dir.path.join(LOCK_FILENAME);
            let recorded: u32 = std::fs::read_to_string(&lock)
                .unwrap()
                .trim()
                .parse()
                .unwrap();
            assert_eq!(recorded, std::process::id());
        });
    }

    // Despite the name, the second call does not wait: it fails immediately
    // and the error names the PID of the current holder.
    #[test]
    fn second_for_service_call_blocks_on_lock() {
        in_sandbox(|_sb| {
            let first = TunnelStateDir::for_service("svc-busy").unwrap();
            let err = TunnelStateDir::for_service("svc-busy").unwrap_err();
            let chain: String = err
                .chain()
                .map(|c| c.to_string())
                .collect::<Vec<_>>()
                .join(" / ");
            let pid = std::process::id();
            assert!(
                chain.contains(&format!("pid {pid}")),
                "expected pid in error chain, got: {chain}"
            );
            drop(first);
            let _second = TunnelStateDir::for_service("svc-busy").unwrap();
        });
    }

    #[test]
    fn for_service_rejects_leftover_state() {
        in_sandbox(|sb| {
            let stale = sb.state_root.join("innisfree").join("svc-stale");
            std::fs::create_dir_all(&stale).unwrap();
            std::fs::write(stale.join("ip"), "9.9.9.9\n").unwrap();

            let err = TunnelStateDir::for_service("svc-stale").unwrap_err();
            let msg = format!("{err:#}");
            assert!(
                msg.contains("leftover state") && msg.contains("--force"),
                "expected leftover-state guidance, got: {msg}"
            );
        });
    }

    #[test]
    fn status_round_trips_through_status_json() {
        in_sandbox(|_sb| {
            let dir = TunnelStateDir::for_service("svc-status").unwrap();
            let written = TunnelStatus {
                ip: "203.0.113.7".parse().unwrap(),
                droplet_id: "12345".to_string(),
                ready_at: OffsetDateTime::from_unix_timestamp(1_700_000_000).unwrap(),
            };
            dir.write_status(&written).unwrap();

            let raw = std::fs::read_to_string(dir.path.join("status.json")).unwrap();
            assert!(raw.contains("\"ip\""), "status.json missing ip: {raw}");
            assert!(
                raw.contains("\"droplet_id\": \"12345\""),
                "status.json missing droplet_id: {raw}"
            );
            assert!(
                raw.contains("2023-11-14T22:13:20Z"),
                "status.json missing rfc3339 ready_at: {raw}"
            );

            let bare = std::fs::read_to_string(dir.path.join("ip")).unwrap();
            assert_eq!(bare.trim(), "203.0.113.7");

            let read = dir.read_status().unwrap().unwrap();
            assert_eq!(read.ip, written.ip);
            assert_eq!(read.droplet_id, written.droplet_id);
            assert_eq!(read.ready_at, written.ready_at);
        });
    }

    #[test]
    fn read_status_returns_none_when_file_absent() {
        in_sandbox(|_sb| {
            let dir = TunnelStateDir::for_service("svc-no-status").unwrap();
            assert!(dir.read_status().unwrap().is_none());
        });
    }

    #[test]
    fn config_round_trips_through_xdg_config_home() {
        in_sandbox(|sb| {
            let dir = TunnelStateDir::for_service("svc-cfg").unwrap();
            let written = sample_config(
                "svc-cfg",
                ServicePort::from_str_multi("80/TCP,443/TCP").unwrap(),
                Some("198.51.100.7".parse().unwrap()),
            );
            dir.write_config(&written).unwrap();

            let expected = sb.cfg_root.join("innisfree").join("svc-cfg.toml");
            assert_eq!(dir.config_file().unwrap(), expected);
            assert!(
                expected.is_file(),
                "config file not written at {expected:?}"
            );

            let raw = std::fs::read_to_string(&expected).unwrap();
            assert!(raw.contains("provider = \"digitalocean\""), "{raw}");
            assert!(raw.contains("region = \"sfo3\""), "{raw}");

            let read = dir.read_config().unwrap().unwrap();
            assert_eq!(read.name, written.name);
            assert_eq!(read.dest_ip, written.dest_ip);
            assert_eq!(read.floating_ip, written.floating_ip);
            assert_eq!(read.region, written.region);
            assert_eq!(read.services.len(), written.services.len());
        });
    }

    #[test]
    fn read_config_returns_none_when_file_absent() {
        in_sandbox(|_sb| {
            let dir = TunnelStateDir::for_service("svc-no-cfg").unwrap();
            assert!(dir.read_config().unwrap().is_none());
        });
    }

    #[test]
    fn identity_round_trips_and_is_mode_0600() {
        use std::os::unix::fs::PermissionsExt;

        in_sandbox(|_sb| {
            let dir = TunnelStateDir::for_service("svc-id").unwrap();
            let wg = WireguardManager::new("svc-id").unwrap();
            let written = TunnelIdentity {
                wireguard: wg.clone(),
            };
            dir.write_identity(&written).unwrap();

            let id_path = dir.identity_file();
            assert!(id_path.is_file());
            let mode = std::fs::metadata(&id_path).unwrap().permissions().mode();
            assert_eq!(
                mode & 0o777,
                0o600,
                "identity.json must be 0600, got {:o}",
                mode & 0o777,
            );

            let read = dir.read_identity().unwrap().unwrap();
            assert_eq!(
                read.wireguard
                    .local_device
                    .interface
                    .keypair
                    .public_bytes()
                    .unwrap(),
                wg.local_device.interface.keypair.public_bytes().unwrap(),
            );
            assert_eq!(
                read.wireguard
                    .local_device
                    .interface
                    .keypair
                    .private_bytes()
                    .unwrap(),
                wg.local_device.interface.keypair.private_bytes().unwrap(),
            );
            assert_eq!(
                read.wireguard.remote_device.interface.address,
                wg.remote_device.interface.address,
            );
        });
    }

    #[test]
    fn remove_state_drops_config_file_too() {
        in_sandbox(|_sb| {
            let svc = "svc-clean";
            let dir = TunnelStateDir::for_service(svc).unwrap();
            dir.write_config(&sample_config(svc, vec![], None)).unwrap();
            let cfg_path = dir.config_file().unwrap();
            let state_path = dir.path.clone();
            drop(dir);

            assert!(cfg_path.is_file());
            assert!(state_path.is_dir());

            remove_state_for_service(svc).unwrap();
            assert!(!cfg_path.exists(), "config file survived clean");
            assert!(!state_path.exists(), "state dir survived clean");

            // A second clean with nothing left must still succeed.
            remove_state_for_service(svc).unwrap();
        });
    }

    #[test]
    fn for_service_rejects_leftover_config() {
        in_sandbox(|sb| {
            let cfg_path = sb.cfg_root.join("innisfree").join("svc-cfg-stale.toml");
            std::fs::create_dir_all(cfg_path.parent().unwrap()).unwrap();
            std::fs::write(&cfg_path, "name = \"svc-cfg-stale\"\n").unwrap();

            let err = TunnelStateDir::for_service("svc-cfg-stale").unwrap_err();
            let msg = format!("{err:#}");
            assert!(
                msg.contains("leftover state") && msg.contains("--force"),
                "expected leftover guidance mentioning config, got: {msg}"
            );
        });
    }

    #[test]
    fn migration_is_noop_when_no_legacy_dir() {
        in_sandbox(|sb| {
            let dir = TunnelStateDir::for_service("svc-virgin").unwrap();
            let expected = sb.state_root.join("innisfree").join("svc-virgin");
            assert_eq!(dir.path, expected);
            assert!(expected.is_dir());
        });
    }
}