use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use std::path::PathBuf;
use std::time::Duration;
use thiserror::Error;
/// Environment variables copied from this process into the spawned daemon.
///
/// Only operational, non-secret settings belong here; the tests reject any name
/// that looks like a secret (contains `secret`, `password`, `token`, ...).
pub const FORWARDED_ENV_ALLOWLIST: &[&str] = &[
    // Windows/WSL interop: lists the variables that cross the WSL boundary.
    "WSLENV",
    // Project-specific WSL setting (presumably the target distro — confirm).
    "TC_WSL_DISTRO",
];
/// Allow-listed environment variables taken from [`crate::paths::ProcessEnv`]
/// (presumably the live process environment), keyed by variable name.
///
/// See [`build_forward_env_with`] for the filtering rules.
#[must_use]
pub fn build_forward_env() -> BTreeMap<String, String> {
    let process_env = crate::paths::ProcessEnv;
    build_forward_env_with(&process_env)
}
/// Collects every variable from [`FORWARDED_ENV_ALLOWLIST`] that `env` defines.
///
/// Only allow-listed names are ever looked up, so nothing outside the list can
/// leak into the result; names that `env` does not define are simply skipped.
#[must_use]
pub fn build_forward_env_with(env: &impl crate::paths::EnvSource) -> BTreeMap<String, String> {
    let mut forwarded = BTreeMap::new();
    for &name in FORWARDED_ENV_ALLOWLIST {
        if let Some(value) = env.get(name) {
            forwarded.insert(name.to_owned(), value);
        }
    }
    forwarded
}
/// Local IPC endpoint a daemon listens on.
///
/// Serialized as an internally tagged object whose `"kind"` is the snake_case
/// variant name: `{"kind": "unix_socket", "path": ...}` or
/// `{"kind": "windows_pipe", "name": ...}`.
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum Endpoint {
    /// Unix domain socket at `path`; `probe_endpoint` only connects to it on Unix targets.
    UnixSocket { path: PathBuf },
    /// Named pipe such as `\\.\pipe\name`; `probe_endpoint` only connects to it on Windows.
    WindowsPipe { name: String },
}
/// Context attached to [`EnsureDaemonStatus::Unavailable`] describing how far
/// `ensure_daemon` got before giving up.
#[derive(Debug, Clone, Serialize)]
pub struct Diagnostics {
    /// Endpoint that was probed.
    pub endpoint: Endpoint,
    /// Daemon log file, once `ensure_daemon` got far enough to choose one.
    pub log_path: Option<PathBuf>,
    /// Human-readable description of the failure.
    pub last_error: Option<String>,
    /// Whether spawning the daemon binary was attempted.
    pub startup_attempted: bool,
    /// Milliseconds from the start of `ensure_daemon` until the failure was reported.
    pub startup_elapsed_ms: u64,
}
/// Outcome of [`ensure_daemon`].
///
/// Serialized as an internally tagged object whose `"status"` field is the
/// snake_case variant name (`"already_running"`, `"started"`, `"unavailable"`).
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "status", rename_all = "snake_case")]
pub enum EnsureDaemonStatus {
    /// A daemon already answered the health probe; nothing was spawned.
    AlreadyRunning {
        endpoint: Endpoint,
        /// The probe does not report a PID, so `ensure_daemon` sets this to `None`.
        pid: Option<u32>,
    },
    /// A daemon was spawned and its endpoint then answered the health probe.
    Started {
        endpoint: Endpoint,
        /// Process id of the spawned child, when known.
        pid: Option<u32>,
        /// File that receives the spawned process's stdout and stderr.
        log_path: PathBuf,
    },
    /// No reachable daemon; `reason` classifies the failure and `diagnostics`
    /// carries the details.
    Unavailable {
        reason: DaemonUnavailableReason,
        diagnostics: Diagnostics,
    },
}
/// Why [`ensure_daemon`] could not provide a reachable daemon.
///
/// Serialized as a snake_case string (e.g. `"startup_timeout"`). Implements
/// `PartialEq`/`Eq` so callers can compare reasons with `==`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum DaemonUnavailableReason {
    /// The daemon could not be launched, e.g. the log file could not be opened
    /// or the process failed to spawn.
    SpawnFailed,
    /// The daemon was spawned but its endpoint did not answer a health probe
    /// within the configured startup timeout.
    StartupTimeout,
    /// The endpoint is not reachable. `ensure_daemon` reports this when the probe
    /// fails and spawning is disabled.
    EndpointBindFailed,
    /// `daemon_binary` was given as an explicit path that does not exist.
    BinaryNotFound,
}
/// Errors from daemon-ensure operations.
///
/// NOTE(review): `ensure_daemon` in this module never returns this type; it
/// reports a missing binary as `DaemonUnavailableReason::BinaryNotFound`
/// instead. Confirm whether anything else still constructs it.
#[derive(Debug, Error)]
pub enum EnsureError {
    /// The daemon executable does not exist at the given path.
    #[error("daemon binary not found at {0}")]
    BinaryNotFound(PathBuf),
}
/// Inputs to [`ensure_daemon`].
#[derive(Debug, Clone)]
pub struct EnsureDaemonOptions {
    /// Daemon executable. An explicit path (absolute, or with more than one
    /// component) is checked for existence before spawning; a bare program name
    /// is left for `std::process::Command` to resolve at spawn time.
    pub daemon_binary: PathBuf,
    /// Passed to the daemon as `--data-dir`.
    pub state_dir: PathBuf,
    /// Directory holding `terminal-commanderd.log`; created if missing (best effort).
    pub log_dir: PathBuf,
    /// Endpoint to probe; also handed to the daemon through the `TC_SOCKET` variable.
    pub endpoint: Endpoint,
    /// How long to keep polling the endpoint after spawning before giving up.
    pub startup_timeout: Duration,
    /// When `false`, never spawn; an unreachable endpoint is reported as unavailable.
    pub allow_spawn: bool,
}
pub async fn ensure_daemon(opts: EnsureDaemonOptions) -> EnsureDaemonStatus {
let start = std::time::Instant::now();
if probe_endpoint(&opts.endpoint).await {
return EnsureDaemonStatus::AlreadyRunning {
endpoint: opts.endpoint,
pid: None,
};
}
if !opts.allow_spawn {
return EnsureDaemonStatus::Unavailable {
reason: DaemonUnavailableReason::EndpointBindFailed,
diagnostics: Diagnostics {
endpoint: opts.endpoint,
log_path: None,
last_error: Some("endpoint unreachable; spawn disabled".into()),
startup_attempted: false,
startup_elapsed_ms: start.elapsed().as_millis() as u64,
},
};
}
let binary_has_separator =
opts.daemon_binary.components().nth(1).is_some() || opts.daemon_binary.is_absolute();
if binary_has_separator && !opts.daemon_binary.exists() {
return EnsureDaemonStatus::Unavailable {
reason: DaemonUnavailableReason::BinaryNotFound,
diagnostics: Diagnostics {
endpoint: opts.endpoint,
log_path: None,
last_error: Some(format!(
"daemon binary not found: {}",
opts.daemon_binary.display()
)),
startup_attempted: false,
startup_elapsed_ms: start.elapsed().as_millis() as u64,
},
};
}
let _ = std::fs::create_dir_all(&opts.log_dir);
let log_path = opts.log_dir.join("terminal-commanderd.log");
let log_file = match std::fs::OpenOptions::new()
.create(true)
.append(true)
.open(&log_path)
{
Ok(f) => f,
Err(e) => {
return EnsureDaemonStatus::Unavailable {
reason: DaemonUnavailableReason::SpawnFailed,
diagnostics: Diagnostics {
endpoint: opts.endpoint,
log_path: Some(log_path),
last_error: Some(format!("open log: {e}")),
startup_attempted: false,
startup_elapsed_ms: start.elapsed().as_millis() as u64,
},
};
}
};
let log_file_err = match log_file.try_clone() {
Ok(f) => f,
Err(e) => {
return EnsureDaemonStatus::Unavailable {
reason: DaemonUnavailableReason::SpawnFailed,
diagnostics: Diagnostics {
endpoint: opts.endpoint,
log_path: Some(log_path),
last_error: Some(format!("clone log fd: {e}")),
startup_attempted: false,
startup_elapsed_ms: start.elapsed().as_millis() as u64,
},
};
}
};
let tc_socket_val: std::ffi::OsString = match &opts.endpoint {
Endpoint::UnixSocket { path } => path.as_os_str().into(),
Endpoint::WindowsPipe { name } => name.into(),
};
let mut cmd = std::process::Command::new(&opts.daemon_binary);
cmd.arg("--data-dir")
.arg(&opts.state_dir)
.arg("start")
.arg("--mode")
.arg("ipc-server")
.env("TC_SOCKET", &tc_socket_val)
.envs(build_forward_env())
.stdin(std::process::Stdio::null())
.stdout(std::process::Stdio::from(log_file))
.stderr(std::process::Stdio::from(log_file_err));
let child = match cmd.spawn() {
Ok(c) => c,
Err(e) => {
return EnsureDaemonStatus::Unavailable {
reason: DaemonUnavailableReason::SpawnFailed,
diagnostics: Diagnostics {
endpoint: opts.endpoint,
log_path: Some(log_path),
last_error: Some(format!("spawn: {e}")),
startup_attempted: true,
startup_elapsed_ms: start.elapsed().as_millis() as u64,
},
};
}
};
let pid = Some(child.id());
drop(child);
let deadline = std::time::Instant::now() + opts.startup_timeout;
while std::time::Instant::now() < deadline {
if probe_endpoint(&opts.endpoint).await {
return EnsureDaemonStatus::Started {
endpoint: opts.endpoint,
pid,
log_path,
};
}
tokio::time::sleep(Duration::from_millis(50)).await;
}
EnsureDaemonStatus::Unavailable {
reason: DaemonUnavailableReason::StartupTimeout,
diagnostics: Diagnostics {
endpoint: opts.endpoint,
log_path: Some(log_path),
last_error: Some(format!(
"endpoint did not bind within {}ms",
opts.startup_timeout.as_millis()
)),
startup_attempted: true,
startup_elapsed_ms: start.elapsed().as_millis() as u64,
},
}
}
/// Upper bound on a single probe: connecting plus the full health round-trip.
const PROBE_TIMEOUT: Duration = Duration::from_millis(500);
/// Largest reply frame (256 KiB) the probe accepts; a longer length prefix fails
/// the probe instead of being allocated.
const PROBE_MAX_FRAME_BYTES: usize = 256 << 10;
/// Minimal view of a daemon reply frame, decoded by the health probe.
///
/// Only `result` is read; any other fields in the reply are ignored (serde's
/// default for structs without `deny_unknown_fields`).
#[derive(Deserialize)]
struct ProbeResponseEnvelope {
    result: ProbeResult,
}
/// The `result` object of a reply, discriminated by its `"kind"` field.
#[derive(Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
enum ProbeResult {
    /// `"kind": "ok"`: a successful reply carrying a `response` object.
    Ok { response: ProbeResponse },
    /// Any other `kind`; the probe treats it as failure.
    #[serde(other)]
    Other,
}
/// The `response` object of a successful reply, discriminated by its `"method"` field.
#[derive(Deserialize)]
#[serde(tag = "method", rename_all = "snake_case")]
enum ProbeResponse {
    /// `"method": "health"`. The fields are never read; they are declared so a
    /// reply must have this shape to count as healthy (`uptime_secs` is
    /// required, `idle_secs` may be absent).
    Health {
        // Declared only to validate the reply shape, hence `dead_code` is allowed.
        #[allow(dead_code)]
        uptime_secs: u64,
        #[serde(default)]
        #[allow(dead_code)]
        idle_secs: Option<u64>,
    },
    /// Any other `method`; the probe treats it as failure.
    #[serde(other)]
    Other,
}
/// Performs one `health` request/response exchange over `stream`.
///
/// Frames on the wire are a 4-byte big-endian length prefix followed by that many
/// bytes of JSON. Returns `true` only when every write/read succeeds, the reply
/// frame is non-empty and at most [`PROBE_MAX_FRAME_BYTES`] long, and the reply
/// decodes as an `ok` result carrying a `health` response. Any I/O error, bad
/// frame length, or unexpected payload yields `false`; nothing is retried.
async fn health_handshake<S>(mut stream: S) -> bool
where
    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
{
    /// Sends the request frame and reads back one reply frame, or `None` on any
    /// I/O error or out-of-range reply length.
    async fn round_trip<T>(io: &mut T) -> Option<Vec<u8>>
    where
        T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
    {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};

        const REQUEST_JSON: &[u8] = br#"{"correlation_id":0,"request":{"method":"health"}}"#;

        let prefix = u32::try_from(REQUEST_JSON.len()).ok()?.to_be_bytes();
        io.write_all(&prefix).await.ok()?;
        io.write_all(REQUEST_JSON).await.ok()?;
        io.flush().await.ok()?;

        let mut header = [0_u8; 4];
        io.read_exact(&mut header).await.ok()?;
        let frame_len = usize::try_from(u32::from_be_bytes(header)).ok()?;
        if !(1..=PROBE_MAX_FRAME_BYTES).contains(&frame_len) {
            return None;
        }
        let mut frame = vec![0_u8; frame_len];
        io.read_exact(&mut frame).await.ok()?;
        Some(frame)
    }

    let Some(frame) = round_trip(&mut stream).await else {
        return false;
    };
    match serde_json::from_slice::<ProbeResponseEnvelope>(&frame) {
        Ok(envelope) => matches!(
            envelope.result,
            ProbeResult::Ok {
                response: ProbeResponse::Health { .. },
            }
        ),
        Err(_) => false,
    }
}
/// Returns `true` when a daemon at `endpoint` completes a health handshake.
///
/// The whole attempt (connect plus handshake) is bounded by [`PROBE_TIMEOUT`];
/// connection errors, protocol errors, and timeouts all yield `false`. Endpoint
/// kinds that do not exist on the current platform always yield `false`.
pub async fn probe_endpoint(endpoint: &Endpoint) -> bool {
    let handshake = async {
        match endpoint {
            #[cfg(unix)]
            Endpoint::UnixSocket { path } => match tokio::net::UnixStream::connect(path).await {
                Ok(stream) => health_handshake(stream).await,
                Err(_) => false,
            },
            #[cfg(not(unix))]
            Endpoint::UnixSocket { .. } => false,
            #[cfg(windows)]
            Endpoint::WindowsPipe { name } => {
                use tokio::net::windows::named_pipe::ClientOptions;
                // Win32 ERROR_PIPE_BUSY: every server instance of the pipe is
                // currently connected. The daemon *is* running, so wait for a free
                // instance instead of reporting it absent (which would make
                // `ensure_daemon` spawn a duplicate). The outer timeout bounds
                // these retries.
                const ERROR_PIPE_BUSY: i32 = 231;
                loop {
                    match ClientOptions::new().open(name.as_str()) {
                        Ok(stream) => break health_handshake(stream).await,
                        Err(e) if e.raw_os_error() == Some(ERROR_PIPE_BUSY) => {
                            tokio::time::sleep(Duration::from_millis(50)).await;
                        }
                        Err(_) => break false,
                    }
                }
            }
            #[cfg(not(windows))]
            Endpoint::WindowsPipe { .. } => false,
        }
    };
    tokio::time::timeout(PROBE_TIMEOUT, handshake)
        .await
        .unwrap_or(false)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// With spawning disabled and nothing listening, the result must be
    /// `Unavailable` with no startup attempt.
    #[tokio::test]
    async fn stub_returns_unavailable() {
        let opts = EnsureDaemonOptions {
            daemon_binary: PathBuf::from("nonexistent"),
            state_dir: PathBuf::from("."),
            log_dir: PathBuf::from("."),
            endpoint: Endpoint::WindowsPipe {
                name: r"\\.\pipe\unused".into(),
            },
            startup_timeout: Duration::from_millis(10),
            allow_spawn: false,
        };
        match ensure_daemon(opts).await {
            EnsureDaemonStatus::Unavailable {
                reason,
                diagnostics,
            } => {
                assert!(
                    matches!(reason, DaemonUnavailableReason::EndpointBindFailed),
                    "expected EndpointBindFailed, got {reason:?}"
                );
                assert!(
                    !diagnostics.startup_attempted,
                    "spawn is disabled, so no startup may be attempted"
                );
            }
            other => panic!("expected Unavailable, got {other:?}"),
        }
    }

    /// An explicit path that does not exist is rejected before any spawn or
    /// log-file handling.
    #[tokio::test]
    async fn missing_binary_path_fails_fast_without_spawning() {
        let dir = tempfile::TempDir::new().unwrap();
        let opts = EnsureDaemonOptions {
            daemon_binary: dir.path().join("no-such-daemon"),
            state_dir: dir.path().into(),
            log_dir: dir.path().into(),
            endpoint: Endpoint::WindowsPipe {
                name: r"\\.\pipe\unused".into(),
            },
            startup_timeout: Duration::from_millis(10),
            allow_spawn: true,
        };
        match ensure_daemon(opts).await {
            EnsureDaemonStatus::Unavailable {
                reason,
                diagnostics,
            } => {
                assert!(
                    matches!(reason, DaemonUnavailableReason::BinaryNotFound),
                    "expected BinaryNotFound, got {reason:?}"
                );
                assert!(
                    !diagnostics.startup_attempted,
                    "a missing binary path must be rejected before spawning"
                );
                assert!(
                    diagnostics.log_path.is_none(),
                    "no log file is chosen before the binary check"
                );
            }
            other => panic!("expected Unavailable, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn bare_binary_name_does_not_fail_fast_on_missing_check() {
        let dir = tempfile::TempDir::new().unwrap();
        let opts = EnsureDaemonOptions {
            daemon_binary: PathBuf::from("definitely-not-installed-xyz"),
            state_dir: dir.path().into(),
            log_dir: dir.path().into(),
            endpoint: Endpoint::WindowsPipe {
                name: r"\\.\pipe\unused".into(),
            },
            startup_timeout: Duration::from_millis(10),
            allow_spawn: true,
        };
        let status = ensure_daemon(opts).await;
        match status {
            EnsureDaemonStatus::Unavailable {
                reason,
                diagnostics,
            } => {
                assert!(
                    matches!(reason, DaemonUnavailableReason::SpawnFailed),
                    "expected SpawnFailed, got {reason:?}"
                );
                assert!(
                    diagnostics.startup_attempted,
                    "startup must have been attempted (spawn was called)"
                );
            }
            other => panic!("expected Unavailable, got {other:?}"),
        }
    }

    #[test]
    fn forward_env_allowlist_is_operational_non_secret() {
        for k in FORWARDED_ENV_ALLOWLIST {
            let lk = k.to_ascii_lowercase();
            assert!(
                !lk.contains("secret")
                    && !lk.contains("password")
                    && !lk.contains("credential")
                    && !lk.contains("token")
                    && !lk.contains("key"),
                "allowlist must be operational-only; '{k}' looks secret"
            );
        }
        assert!(FORWARDED_ENV_ALLOWLIST.contains(&"WSLENV"));
    }

    #[test]
    fn build_forward_env_forwards_only_allowlisted_vars() {
        struct FakeEnv(std::collections::HashMap<String, String>);
        impl crate::paths::EnvSource for FakeEnv {
            fn get(&self, key: &str) -> Option<String> {
                self.0.get(key).cloned()
            }
        }
        let secret = "TC_F6_TEST_SECRET_THING";
        let mut map = std::collections::HashMap::new();
        map.insert("WSLENV".to_owned(), "TC_F6/u".to_owned());
        map.insert(secret.to_owned(), "nope".to_owned());
        let env = build_forward_env_with(&FakeEnv(map));
        assert_eq!(env.get("WSLENV").map(String::as_str), Some("TC_F6/u"));
        assert!(
            !env.contains_key(secret),
            "non-allowlisted var must not be forwarded"
        );
    }
}