use crate::types::{WorkerConfig, WorkerId};
use anyhow::{Context, Result};
use openssh::{ControlPersist, KnownHosts, Session, SessionBuilder, Stdio};
use std::collections::HashMap;
use std::num::NonZeroUsize;
use std::path::Path;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncBufReadExt, AsyncReadExt, BufReader};
use tokio::sync::{RwLock, mpsc};
use tracing::{debug, error, info, warn};
pub use crate::ssh_utils::{
CommandResult, EnvPrefix, build_env_prefix, is_retryable_transport_error,
is_retryable_transport_error_text, is_valid_env_key, shell_escape_value,
};
/// Connect timeout used by `SshOptions::default()` (10 seconds).
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
/// Per-command timeout used by `SshOptions::default()` (300 seconds).
const DEFAULT_COMMAND_TIMEOUT: Duration = Duration::from_secs(300);
/// Cap (10 MiB) on how much of each remote output stream is retained before
/// the truncation marker is appended.
const MAX_OUTPUT_SIZE: u64 = 10 * 1024 * 1024;
/// Remote command executed by `SshClient::health_check`.
const HEALTH_CHECK_COMMAND: &str = "echo ok";
fn is_expected_health_check_output(stdout: &str) -> bool {
stdout
.trim()
.lines()
.last()
.is_some_and(is_health_check_sentinel)
}
/// Returns `true` when `line`, ignoring surrounding whitespace, is exactly `ok`.
fn is_health_check_sentinel(line: &str) -> bool {
    line.trim() == "ok"
}
/// Tunables applied to every SSH session and remote command.
#[derive(Debug, Clone)]
pub struct SshOptions {
/// Passed to the session builder as the connection timeout.
pub connect_timeout: Duration,
/// Upper bound on a single `execute` / `execute_streaming` call.
pub command_timeout: Duration,
/// When set, forwarded to the session builder as the server-alive interval.
pub server_alive_interval: Option<Duration>,
/// Control-persist idle time; only consulted when `control_master` is true.
/// Values below one second close the master after the initial connection.
pub control_persist_idle: Option<Duration>,
/// When true, `connect` first applies the control-persist and
/// control-directory settings and retries without them if that attempt fails.
pub control_master: bool,
/// Host-key verification policy handed to the session builder.
pub known_hosts: KnownHostsPolicy,
}
impl Default for SshOptions {
    /// Builds the baseline options: default connect and command timeouts, no
    /// keep-alive or control-persist override, control master disabled, and
    /// the `Add` known-hosts policy.
    fn default() -> Self {
        let known_hosts = KnownHostsPolicy::Add;
        Self {
            known_hosts,
            control_master: false,
            control_persist_idle: None,
            server_alive_interval: None,
            command_timeout: DEFAULT_COMMAND_TIMEOUT,
            connect_timeout: DEFAULT_CONNECT_TIMEOUT,
        }
    }
}
/// Host-key verification policy; each variant is translated to the matching
/// `openssh::KnownHosts` value when a session is configured.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KnownHostsPolicy {
/// Translated to `KnownHosts::Strict`.
Strict,
/// Translated to `KnownHosts::Add`; this is the default policy.
Add,
/// Translated to `KnownHosts::Accept`.
AcceptAll,
}
#[cfg(test)]
mod retry_tests {
    use super::*;
    use crate::test_guard;

    /// Transient network failures must be classified as retryable.
    #[test]
    fn test_retryable_transport_error_text() {
        let _guard = test_guard!();
        let transient_messages = [
            "ssh: connect to host 1.2.3.4 port 22: Connection timed out",
            "kex_exchange_identification: Connection reset by peer",
            "Broken pipe",
            "Network is unreachable",
        ];
        for message in transient_messages.iter().copied() {
            assert!(is_retryable_transport_error_text(message));
        }
    }

    /// Authentication, host-key, DNS and key-file failures must not be retried.
    #[test]
    fn test_non_retryable_transport_error_text() {
        let _guard = test_guard!();
        let permanent_messages = [
            "Permission denied (publickey).",
            "Host key verification failed.",
            "Could not resolve hostname worker.example.com: Name or service not known",
            "Identity file /nope/id_rsa not accessible: No such file or directory",
        ];
        for message in permanent_messages.iter().copied() {
            assert!(!is_retryable_transport_error_text(message));
        }
    }
}
/// SSH client bound to a single worker; holds at most one open session.
pub struct SshClient {
// Worker endpoint (id, host, user, identity file) this client targets.
config: WorkerConfig,
// Timeouts and session-builder settings used for connect and execute.
options: SshOptions,
// `Some` while connected; cleared by `disconnect`.
session: Option<Session>,
}
impl SshClient {
/// Creates a client for `config` that is not yet connected.
pub fn new(config: WorkerConfig, options: SshOptions) -> Self {
    let session = None;
    Self {
        config,
        options,
        session,
    }
}
/// Returns the id of the worker this client was configured for.
pub fn worker_id(&self) -> &WorkerId {
&self.config.id
}
/// Returns `true` while a session is held. This only reflects local state;
/// it does not probe the remote end.
pub fn is_connected(&self) -> bool {
self.session.is_some()
}
/// Returns `true` when `config` describes the same endpoint as this client.
///
/// Only the id, host, user and identity file are compared; other fields of
/// the worker configuration do not affect the result.
fn is_configured_for(&self, config: &WorkerConfig) -> bool {
    let current = &self.config;
    let ours = (
        &current.id,
        &current.host,
        &current.user,
        &current.identity_file,
    );
    let theirs = (&config.id, &config.host, &config.user, &config.identity_file);
    ours == theirs
}
/// Opens the SSH session for this worker if one is not already held.
///
/// When `control_master` is enabled and the first attempt fails, a second
/// attempt is made with it disabled before giving up.
///
/// # Errors
/// Returns the (contextualised) connection error of the last attempt.
pub async fn connect(&mut self) -> Result<()> {
    if self.is_connected() {
        debug!("Already connected to {}", self.config.id);
        return Ok(());
    }

    let destination = format!("{}@{}", self.config.user, self.config.host);
    debug!("Connecting to {} via SSH...", destination);

    let wants_control_master = self.options.control_master;
    let first_attempt = self
        .connect_with_mode(&destination, wants_control_master)
        .await;

    let session = match (first_attempt, wants_control_master) {
        (Ok(session), _) => session,
        (Err(primary_error), true) => {
            warn!(
                "SSH ControlMaster connection to {} failed ({}). Retrying without ControlMaster.",
                destination, primary_error
            );
            let fallback = self.connect_with_mode(&destination, false).await;
            fallback.with_context(|| {
                format!(
                    "Failed to connect to {} after retrying without ControlMaster",
                    destination
                )
            })?
        }
        (Err(primary_error), false) => {
            return Err(primary_error)
                .with_context(|| format!("Failed to connect to {}", destination));
        }
    };

    info!("Connected to {} ({})", self.config.id, self.config.host);
    self.session = Some(session);
    Ok(())
}
/// Builds a fresh session builder for the requested mode and connects to
/// `destination`, tagging any failure with the mode that was attempted.
async fn connect_with_mode(&self, destination: &str, control_master: bool) -> Result<Session> {
    let mut builder = SessionBuilder::default();
    self.configure_builder(&mut builder, control_master);

    let mode = if control_master { "enabled" } else { "disabled" };
    let attempt = builder.connect(destination).await;
    attempt.with_context(|| {
        format!(
            "Failed to connect to {} with ControlMaster {}",
            destination, mode
        )
    })
}
fn configure_builder(&self, builder: &mut SessionBuilder, control_master: bool) {
let known_hosts = match self.options.known_hosts {
KnownHostsPolicy::Strict => KnownHosts::Strict,
KnownHostsPolicy::Add => KnownHosts::Add,
KnownHostsPolicy::AcceptAll => KnownHosts::Accept,
};
builder
.known_hosts_check(known_hosts)
.connect_timeout(self.options.connect_timeout);
if let Some(interval) = self.options.server_alive_interval {
builder.server_alive_interval(interval);
}
let identity_path = shellexpand::tilde(&self.config.identity_file);
if Path::new(identity_path.as_ref()).exists() {
builder.keyfile(identity_path.as_ref());
}
if control_master {
if let Some(idle) = self.options.control_persist_idle {
if idle.is_zero() {
builder.control_persist(ControlPersist::ClosedAfterInitialConnection);
} else {
match usize::try_from(idle.as_secs()) {
Ok(secs) => {
if let Some(nonzero) = NonZeroUsize::new(secs) {
builder.control_persist(ControlPersist::IdleFor(nonzero));
} else {
builder
.control_persist(ControlPersist::ClosedAfterInitialConnection);
}
}
Err(_) => {
warn!(
"control_persist_idle too large ({}s); ignoring override",
idle.as_secs()
);
}
}
}
}
let control_dir = {
let home_ssh = dirs::home_dir().map(|h| h.join(".ssh").join("rch"));
if let Some(ref dir) = home_ssh {
dir.clone()
} else if let Some(runtime_dir) = dirs::runtime_dir() {
runtime_dir.join("rch-ssh")
} else {
let username = whoami::username().unwrap_or_else(|_| "unknown".to_string());
std::env::temp_dir().join(format!("rch-ssh-{}", username))
}
};
if let Err(e) = std::fs::create_dir_all(&control_dir) {
warn!(
"Failed to create SSH control directory {:?}: {}",
control_dir, e
);
} else {
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
if let Err(e) = std::fs::set_permissions(
&control_dir,
std::fs::Permissions::from_mode(0o700),
) {
warn!(
"Failed to set permissions on SSH control directory {:?}: {}",
control_dir, e
);
}
}
}
builder.control_directory(&control_dir);
}
}
/// Closes the current session, if any.
///
/// The session is removed from the client before it is closed, so the client
/// reports itself as disconnected even when closing returns an error.
pub async fn disconnect(&mut self) -> Result<()> {
    let Some(session) = self.session.take() else {
        return Ok(());
    };
    debug!("Disconnecting from {}", self.config.id);
    session.close().await?;
    info!("Disconnected from {}", self.config.id);
    Ok(())
}
/// Runs `command` through `sh -c` on the worker and captures its output.
///
/// Each of stdout and stderr is retained up to `MAX_OUTPUT_SIZE` bytes; any
/// excess is still read (and discarded) so the remote process never blocks on
/// a full pipe, and a truncation marker is appended only when bytes were
/// actually dropped. Output is decoded lossily, so invalid UTF-8 — including
/// a multi-byte character split by the size cap — becomes U+FFFD instead of
/// failing the whole command.
///
/// # Errors
/// Fails when the client is not connected, the command cannot be spawned,
/// reading output or waiting for the exit status fails, or the command does
/// not finish within `command_timeout`.
pub async fn execute(&self, command: &str) -> Result<CommandResult> {
    /// Reads `stream` to EOF, keeping at most `MAX_OUTPUT_SIZE` bytes.
    async fn capture_capped<R>(stream: Option<R>) -> Result<String>
    where
        R: tokio::io::AsyncRead + Unpin,
    {
        let Some(stream) = stream else {
            return Ok(String::new());
        };

        let mut limited = BufReader::new(stream).take(MAX_OUTPUT_SIZE);
        let mut captured = Vec::new();
        limited.read_to_end(&mut captured).await?;

        // Drain whatever exceeds the cap; the byte count tells us whether
        // anything was really cut off.
        let mut remainder = limited.into_inner();
        let mut sink = tokio::io::sink();
        let discarded = tokio::io::copy(&mut remainder, &mut sink).await?;

        let mut text = String::from_utf8_lossy(&captured).into_owned();
        if discarded > 0 {
            text.push_str("\n...[output truncated]...\n");
        }
        Ok(text)
    }

    let session = self.session.as_ref().context("Not connected to worker")?;
    let start = std::time::Instant::now();
    debug!(
        "Executing on {}: {}",
        self.config.id,
        crate::util::mask_sensitive_command(command)
    );

    let mut child = session
        .command("sh")
        .arg("-c")
        .arg(command)
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .spawn()
        .await
        .with_context(|| format!("Failed to spawn command on {}", self.config.id))?;

    let execution_future = async {
        let stdout_stream = child.stdout().take();
        let stderr_stream = child.stderr().take();
        // Both pipes are read concurrently so neither can fill up and stall
        // the remote process while the other is being consumed.
        let (stdout, stderr) = tokio::try_join!(
            capture_capped(stdout_stream),
            capture_capped(stderr_stream)
        )?;
        let status = child
            .wait()
            .await
            .context("Failed to wait for command completion")?;
        Ok::<_, anyhow::Error>((status, stdout, stderr))
    };

    match tokio::time::timeout(self.options.command_timeout, execution_future).await {
        Ok(result) => {
            let (status, stdout, stderr) = result?;
            let duration = start.elapsed();
            // A missing exit code is reported as -1.
            let exit_code = status.code().unwrap_or(-1);
            debug!(
                "Command completed on {} (exit={}, duration={}ms)",
                self.config.id,
                exit_code,
                duration.as_millis()
            );
            Ok(CommandResult {
                exit_code,
                stdout,
                stderr,
                duration_ms: u64::try_from(duration.as_millis()).unwrap_or(u64::MAX),
            })
        }
        Err(_) => {
            // NOTE(review): the local future is dropped here; whether the
            // remote process is also terminated depends on openssh — confirm.
            warn!(
                "Command timed out on {} after {:?}",
                self.config.id, self.options.command_timeout
            );
            anyhow::bail!("Command timed out after {:?}", self.options.command_timeout);
        }
    }
}
/// Runs `command` through `sh -c` on the worker, invoking `on_stdout` /
/// `on_stderr` for every line as it arrives.
///
/// Lines are passed to the callbacks including their trailing newline. The
/// returned `CommandResult` also accumulates both streams, each retained up
/// to roughly `MAX_OUTPUT_SIZE` bytes followed by a truncation marker; the
/// callbacks keep receiving lines after that point. Lines are decoded
/// lossily, so non-UTF-8 output is delivered with U+FFFD replacements rather
/// than silently ending the stream.
///
/// # Errors
/// Fails when the client is not connected, the command cannot be spawned,
/// waiting for the exit status fails, or the command does not finish within
/// `command_timeout` (in which case the reader tasks are aborted).
pub async fn execute_streaming<F, G>(
    &self,
    command: &str,
    mut on_stdout: F,
    mut on_stderr: G,
) -> Result<CommandResult>
where
    F: FnMut(&str),
    G: FnMut(&str),
{
    /// One line of output, tagged with the stream it came from.
    enum StreamEvent {
        Stdout(String),
        Stderr(String),
    }

    /// Spawns a task that forwards each line of `stream` as a `StreamEvent`
    /// until EOF, a read error, or the receiver going away.
    fn spawn_line_reader<R>(
        stream: R,
        events: mpsc::Sender<StreamEvent>,
        wrap: fn(String) -> StreamEvent,
    ) -> tokio::task::JoinHandle<()>
    where
        R: tokio::io::AsyncRead + Unpin + Send + 'static,
    {
        tokio::spawn(async move {
            let mut reader = BufReader::new(stream);
            let mut raw = Vec::new();
            loop {
                raw.clear();
                // Bytes are read raw and decoded lossily: `read_line` would
                // error on invalid UTF-8 and leave the pipe undrained.
                match reader.read_until(b'\n', &mut raw).await {
                    Ok(0) | Err(_) => break,
                    Ok(_) => {
                        let line = String::from_utf8_lossy(&raw).into_owned();
                        if events.send(wrap(line)).await.is_err() {
                            break;
                        }
                    }
                }
            }
        })
    }

    /// Appends `line` unless the cap was already reached, adding the
    /// truncation marker once the cap is crossed.
    fn append_capped(acc: &mut String, line: &str) {
        let cap = MAX_OUTPUT_SIZE as usize;
        if acc.len() < cap {
            acc.push_str(line);
            if acc.len() >= cap {
                acc.push_str("\n...[output truncated]...\n");
            }
        }
    }

    let session = self.session.as_ref().context("Not connected to worker")?;
    let start = std::time::Instant::now();
    debug!(
        "Executing (streaming) on {}: {}",
        self.config.id,
        crate::util::mask_sensitive_command(command)
    );

    let mut child = session
        .command("sh")
        .arg("-c")
        .arg(command)
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .spawn()
        .await
        .with_context(|| format!("Failed to spawn command on {}", self.config.id))?;

    let (tx, mut rx) = mpsc::channel::<StreamEvent>(100);
    let mut readers = Vec::with_capacity(2);
    if let Some(out) = child.stdout().take() {
        readers.push(spawn_line_reader(out, tx.clone(), StreamEvent::Stdout));
    }
    if let Some(err) = child.stderr().take() {
        readers.push(spawn_line_reader(err, tx.clone(), StreamEvent::Stderr));
    }
    // Only the reader tasks hold senders now, so `recv` yields `None` once
    // both streams have reached EOF.
    drop(tx);

    let mut stdout_acc = String::new();
    let mut stderr_acc = String::new();

    let streaming_future = async {
        while let Some(event) = rx.recv().await {
            match event {
                StreamEvent::Stdout(line) => {
                    on_stdout(&line);
                    append_capped(&mut stdout_acc, &line);
                }
                StreamEvent::Stderr(line) => {
                    on_stderr(&line);
                    append_capped(&mut stderr_acc, &line);
                }
            }
        }
        let status = child
            .wait()
            .await
            .context("Failed to wait for command completion")?;
        Ok::<_, anyhow::Error>(status)
    };

    match tokio::time::timeout(self.options.command_timeout, streaming_future).await {
        Ok(result) => {
            let status = result?;
            let duration = start.elapsed();
            // A missing exit code is reported as -1.
            let exit_code = status.code().unwrap_or(-1);
            debug!(
                "Command (streaming) completed on {} (exit={}, duration={}ms)",
                self.config.id,
                exit_code,
                duration.as_millis()
            );
            Ok(CommandResult {
                exit_code,
                stdout: stdout_acc,
                stderr: stderr_acc,
                duration_ms: u64::try_from(duration.as_millis()).unwrap_or(u64::MAX),
            })
        }
        Err(_) => {
            // Without this the reader tasks would stay parked on their pipes
            // for as long as the remote side keeps them open.
            for reader in &readers {
                reader.abort();
            }
            warn!(
                "Command (streaming) timed out on {} after {:?}, cleaning up",
                self.config.id, self.options.command_timeout
            );
            anyhow::bail!("Command timed out after {:?}", self.options.command_timeout);
        }
    }
}
/// Runs the health-check command and reports whether the worker answered
/// with a successful exit status and the expected sentinel output.
///
/// Execution errors are logged and reported as `Ok(false)`; this method does
/// not itself return `Err`.
pub async fn health_check(&self) -> Result<bool> {
    let outcome = self.execute(HEALTH_CHECK_COMMAND).await;
    let healthy = match outcome {
        Ok(result) => result.success() && is_expected_health_check_output(&result.stdout),
        Err(e) => {
            warn!("Health check failed for {}: {}", self.config.id, e);
            false
        }
    };
    Ok(healthy)
}
}
/// Pool of SSH clients keyed by worker id.
pub struct SshPool {
// One shared, individually locked client per worker id.
connections: Arc<RwLock<HashMap<WorkerId, Arc<RwLock<SshClient>>>>>,
// Options cloned into every client the pool creates.
options: SshOptions,
}
impl SshPool {
/// Creates an empty pool whose clients will all use `options`.
pub fn new(options: SshOptions) -> Self {
    let connections = Arc::new(RwLock::new(HashMap::new()));
    Self {
        connections,
        options,
    }
}
/// Returns the pooled client for `config`, connecting it first if needed.
///
/// The connected state is checked under a read lock first; only when the
/// client is disconnected is the write lock taken, and the state is checked
/// again there because another task may have connected in between.
///
/// # Errors
/// Propagates the error from `SshClient::connect`.
pub async fn get_or_connect(&self, config: &WorkerConfig) -> Result<Arc<RwLock<SshClient>>> {
    let shared_client = self.get_or_create_client_entry(config).await;

    let already_connected = shared_client.read().await.is_connected();
    if already_connected {
        debug!("Reusing existing connection to {}", config.id);
        return Ok(shared_client);
    }

    {
        let mut exclusive = shared_client.write().await;
        if !exclusive.is_connected() {
            exclusive.connect().await?;
        }
    }
    Ok(shared_client)
}
/// Returns the pool entry for `config.id`, creating or replacing it as needed.
///
/// An existing entry is reused when it targets the same endpoint (see
/// `SshClient::is_configured_for`); otherwise it is swapped for a fresh,
/// unconnected client. The map lock is never held while a client lock is
/// awaited, so each step re-validates the map and the loop retries when
/// another task changed the entry in the meantime.
async fn get_or_create_client_entry(&self, config: &WorkerConfig) -> Arc<RwLock<SshClient>> {
let worker_id = config.id.clone();
loop {
// Clone the Arc out so the map read lock is released before the client
// lock is awaited below.
let existing_client = {
let connections = self.connections.read().await;
connections.get(&worker_id).cloned()
};
if let Some(client) = existing_client {
let is_configured_for_worker = {
let guard = client.read().await;
guard.is_configured_for(config)
};
if is_configured_for_worker {
return client;
}
let replacement = Arc::new(RwLock::new(SshClient::new(
config.clone(),
self.options.clone(),
)));
let replaced = {
let mut connections = self.connections.write().await;
// Swap only if the map still holds the exact entry inspected above;
// otherwise another task changed it and the loop starts over.
if connections
.get(&worker_id)
.is_some_and(|current| Arc::ptr_eq(current, &client))
{
connections.insert(worker_id.clone(), replacement.clone());
true
} else {
false
}
};
if replaced {
debug!(
"Replaced SSH connection entry for {} after endpoint config changed",
worker_id
);
return replacement;
}
continue;
}
let new_client = Arc::new(RwLock::new(SshClient::new(
config.clone(),
self.options.clone(),
)));
let inserted = {
let mut connections = self.connections.write().await;
// Another task may have inserted an entry since the read above; in that
// case loop again and evaluate that entry instead of overwriting it.
if connections.contains_key(&worker_id) {
false
} else {
connections.insert(worker_id.clone(), new_client.clone());
true
}
};
if inserted {
return new_client;
}
}
}
/// Removes the entry for `worker_id` from the pool and disconnects it.
///
/// Unknown ids are a no-op. The map lock is released before the client is
/// disconnected.
///
/// # Errors
/// Propagates the error from `SshClient::disconnect`.
pub async fn close(&self, worker_id: &WorkerId) -> Result<()> {
    let removed = self.connections.write().await.remove(worker_id);
    if let Some(entry) = removed {
        let mut exclusive = entry.write().await;
        exclusive.disconnect().await?;
    }
    Ok(())
}
/// Empties the pool and disconnects every client that was in it.
///
/// Disconnect failures are logged and do not stop the remaining clients from
/// being closed; the method always returns `Ok(())`.
pub async fn close_all(&self) -> Result<()> {
    let drained: Vec<_> = self
        .connections
        .write()
        .await
        .drain()
        .map(|(_, entry)| entry)
        .collect();

    for entry in drained {
        let outcome = entry.write().await.disconnect().await;
        if let Err(e) = outcome {
            error!("Error closing connection: {}", e);
        }
    }
    Ok(())
}
/// Returns the number of entries in the pool. Entries are counted whether or
/// not their client currently holds a session.
pub async fn active_connections(&self) -> usize {
    let connections = self.connections.read().await;
    connections.len()
}
}
impl Default for SshPool {
    /// Creates an empty pool that uses `SshOptions::default()`.
    fn default() -> Self {
        let options = SshOptions::default();
        Self::new(options)
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::test_guard;
/// `success()` is true for exit code 0 and false for a non-zero exit code.
#[test]
fn test_command_result_success() {
    let _guard = test_guard!();

    let succeeded = CommandResult {
        exit_code: 0,
        stdout: String::from("output"),
        stderr: String::new(),
        duration_ms: 100,
    };
    assert!(succeeded.success());

    let failed = CommandResult {
        exit_code: 1,
        stdout: String::new(),
        stderr: String::from("error"),
        duration_ms: 50,
    };
    assert!(!failed.success());
}
/// Pins the documented defaults of `SshOptions`.
#[test]
fn test_ssh_options_default() {
    let _guard = test_guard!();
    let SshOptions {
        connect_timeout,
        command_timeout,
        server_alive_interval,
        control_persist_idle,
        control_master,
        ..
    } = SshOptions::default();

    assert_eq!(connect_timeout, Duration::from_secs(10));
    assert_eq!(command_timeout, Duration::from_secs(300));
    assert!(server_alive_interval.is_none());
    assert!(control_persist_idle.is_none());
    assert!(!control_master);
}
/// A freshly created client exposes its worker id and starts disconnected.
#[test]
fn test_ssh_client_creation() {
    let _guard = test_guard!();
    let config = worker_config("test-worker", "192.168.1.100", "ubuntu", "~/.ssh/id_rsa");

    let client = SshClient::new(config, SshOptions::default());

    assert_eq!(client.worker_id().as_str(), "test-worker");
    assert!(!client.is_connected());
}
/// The sentinel must be the final line: leading noise is accepted, trailing
/// noise and empty output are not.
#[test]
fn test_expected_health_check_output_accepts_sentinel_as_last_line() {
    let _guard = test_guard!();
    let cases = [
        ("ok\n", true),
        ("login banner\nok\n", true),
        ("", false),
        ("not ok\n", false),
        ("ok\npost-command noise\n", false),
    ];
    for (stdout, expected) in cases.iter().copied() {
        assert_eq!(
            is_expected_health_check_output(stdout),
            expected,
            "stdout: {:?}",
            stdout
        );
    }
}
/// Builds a `WorkerConfig` with the given endpoint fields and fixed
/// scheduling fields (8 slots, priority 100, the `rust` tag).
fn worker_config(id: &str, host: &str, user: &str, identity_file: &str) -> WorkerConfig {
    let tags = vec![String::from("rust")];
    WorkerConfig {
        id: WorkerId::new(id),
        host: host.to_owned(),
        user: user.to_owned(),
        identity_file: identity_file.to_owned(),
        total_slots: 8,
        priority: 100,
        tags,
    }
}
/// Changing only slots, priority or tags must not count as an endpoint change.
#[test]
fn test_ssh_client_configured_for_ignores_scheduling_fields() {
    let _guard = test_guard!();
    let original = worker_config("worker-a", "192.168.1.100", "ubuntu", "~/.ssh/id_rsa");
    let client = SshClient::new(original.clone(), SshOptions::default());

    let rescheduled = WorkerConfig {
        total_slots: 16,
        priority: 250,
        tags: vec!["rust".to_string(), "gpu".to_string()],
        ..original
    };

    assert!(client.is_configured_for(&rescheduled));
}
/// A different host, user or identity file each counts as an endpoint change.
#[test]
fn test_ssh_client_configured_for_detects_endpoint_changes() {
    let _guard = test_guard!();
    let baseline = worker_config("worker-a", "192.168.1.100", "ubuntu", "~/.ssh/id_rsa");
    let client = SshClient::new(baseline, SshOptions::default());

    // Each row differs from the baseline in exactly one endpoint field.
    let changed_endpoints = [
        ("192.168.1.101", "ubuntu", "~/.ssh/id_rsa"),
        ("192.168.1.100", "admin", "~/.ssh/id_rsa"),
        ("192.168.1.100", "ubuntu", "~/.ssh/other_key"),
    ];
    for (host, user, identity_file) in changed_endpoints.iter().copied() {
        let candidate = worker_config("worker-a", host, user, identity_file);
        assert!(!client.is_configured_for(&candidate));
    }
}
/// Requesting the same endpoint twice yields the same pool entry.
#[tokio::test]
async fn test_ssh_pool_reuses_matching_disconnected_entry() {
    let _guard = test_guard!();
    let pool = SshPool::default();
    let config = worker_config("worker-a", "192.168.1.100", "ubuntu", "~/.ssh/id_rsa");

    let first_lookup = pool.get_or_create_client_entry(&config).await;
    let second_lookup = pool.get_or_create_client_entry(&config).await;

    assert!(Arc::ptr_eq(&first_lookup, &second_lookup));
    let entries = pool.active_connections().await;
    assert_eq!(entries, 1);
}
/// A changed endpoint for the same worker id replaces the pool entry in place.
#[tokio::test]
async fn test_ssh_pool_replaces_stale_entry_when_endpoint_changes() {
    let _guard = test_guard!();
    let pool = SshPool::default();
    let old_config = worker_config("worker-a", "192.168.1.100", "ubuntu", "~/.ssh/id_rsa");
    let new_config = worker_config("worker-a", "192.168.1.101", "admin", "~/.ssh/new_key");

    let stale = pool.get_or_create_client_entry(&old_config).await;
    let replacement = pool.get_or_create_client_entry(&new_config).await;

    assert!(!Arc::ptr_eq(&stale, &replacement));
    let entries = pool.active_connections().await;
    assert_eq!(entries, 1);

    let now_targets_new_endpoint = replacement.read().await.is_configured_for(&new_config);
    assert!(now_targets_new_endpoint);
}
/// Allowed values are quoted, missing keys are skipped, and unsafe values or
/// invalid keys are reported as rejected.
#[test]
fn test_build_env_prefix_quotes_and_rejects() {
    let _guard = test_guard!();
    let env: HashMap<String, String> = [
        ("RUSTFLAGS", "-C target-cpu=native"),
        ("QUOTED", "a'b"),
        ("BADVAL", "line1\nline2"),
    ]
    .iter()
    .map(|(key, value)| (key.to_string(), value.to_string()))
    .collect();
    let allowlist: Vec<String> = ["RUSTFLAGS", "QUOTED", "MISSING", "BADVAL", "BAD=KEY"]
        .iter()
        .map(|key| key.to_string())
        .collect();

    let prefix = build_env_prefix(&allowlist, |key| env.get(key).cloned());

    let rendered = &prefix.prefix;
    assert!(rendered.contains("RUSTFLAGS='-C target-cpu=native'"));
    assert!(rendered.contains("QUOTED='a'\\''b'"));
    assert!(!rendered.contains("MISSING="));
    assert!(!rendered.contains("BADVAL="));
    assert!(prefix.rejected.contains(&"BADVAL".to_string()));
    assert!(prefix.rejected.contains(&"BAD=KEY".to_string()));
    assert_eq!(
        prefix.applied,
        vec!["RUSTFLAGS".to_string(), "QUOTED".to_string()]
    );
}
// Property-based and edge-case tests for the env-key validation, shell
// escaping and env-prefix helpers re-exported by this module.
mod proptest_ssh_escaping {
use super::*;
use proptest::prelude::*;
use std::collections::HashMap;
proptest! {
#![proptest_config(ProptestConfig::with_cases(1000))]
// Arbitrary strings must never make key validation panic.
#[test]
fn test_is_valid_env_key_no_panic(s in ".*") {
let _guard = test_guard!();
let _ = is_valid_env_key(&s);
}
// A letter or underscore followed by letters, digits or underscores is valid.
#[test]
fn test_is_valid_env_key_accepts_valid(
first in "[a-zA-Z_]",
rest in "[a-zA-Z0-9_]{0,50}"
) {
let _guard = test_guard!();
let key = format!("{}{}", first, rest);
prop_assert!(is_valid_env_key(&key), "Should accept valid key: {}", key);
}
// Keys that start with a digit are invalid.
#[test]
fn test_is_valid_env_key_rejects_digit_start(
digit in "[0-9]",
rest in "[a-zA-Z0-9_]{0,20}"
) {
let _guard = test_guard!();
let key = format!("{}{}", digit, rest);
prop_assert!(!is_valid_env_key(&key), "Should reject digit-start key: {}", key);
}
// Arbitrary strings must never make escaping panic.
#[test]
fn test_shell_escape_value_no_panic(s in ".*") {
let _guard = test_guard!();
let _ = shell_escape_value(&s);
}
// Values containing a newline, carriage return or NUL are refused.
#[test]
fn test_shell_escape_value_rejects_unsafe(
prefix in "[a-zA-Z0-9 ]{0,10}",
bad_char in "[\n\r\0]",
suffix in "[a-zA-Z0-9 ]{0,10}"
) {
let _guard = test_guard!();
let value = format!("{}{}{}", prefix, bad_char, suffix);
prop_assert!(shell_escape_value(&value).is_none(),
"Should reject value with unsafe char: {:?}", value);
}
// Printable values are accepted, and anything beyond [A-Za-z0-9_] comes back quoted.
// NOTE(review): unlike most cases in this block, this one does not take
// test_guard!() — confirm whether that is intentional.
#[test]
fn test_shell_escape_value_accepts_safe(s in "[a-zA-Z0-9 !@#$%^&*()_+=\\-\\[\\]{}|;:,./<>?]{0,100}") {
let result = shell_escape_value(&s);
prop_assert!(result.is_some(), "Should accept safe value: {:?}", s);
let escaped = match result {
Some(escaped) => escaped,
None => {
prop_assert!(false, "Should accept safe value: {:?}", s);
String::new()
}
};
if s.chars().any(|c| !c.is_ascii_alphanumeric() && c != '_') {
prop_assert!(escaped.starts_with('\'') || escaped.contains('\''),
"Value with special chars should be quoted: {:?} -> {:?}", s, escaped);
}
}
// An embedded single quote is emitted as the '\'' sequence.
#[test]
fn test_shell_escape_value_escapes_quotes(
prefix in "[a-zA-Z0-9]{0,10}",
suffix in "[a-zA-Z0-9]{0,10}"
) {
let _guard = test_guard!();
let value = format!("{}'{}", prefix, suffix);
let result = shell_escape_value(&value);
prop_assert!(result.is_some());
let escaped = match result {
Some(escaped) => escaped,
None => {
prop_assert!(false, "Should escape single quote: {}", value);
String::new()
}
};
prop_assert!(escaped.contains("'\\''"),
"Should escape single quote: {} -> {}", value, escaped);
}
// Arbitrary values for valid keys must never make prefix building panic.
// NOTE(review): no test_guard!() here either — confirm.
#[test]
fn test_build_env_prefix_no_panic(
keys in prop::collection::vec("[a-zA-Z_][a-zA-Z0-9_]{0,10}", 0..10),
values in prop::collection::vec(".*", 0..10)
) {
let mut env = HashMap::new();
for (i, key) in keys.iter().enumerate() {
if let Some(val) = values.get(i) {
env.insert(key.clone(), val.clone());
}
}
let allowlist: Vec<String> = keys;
let _ = build_env_prefix(&allowlist, |k| env.get(k).cloned());
}
// A digit-leading key is rejected and contributes nothing to the prefix.
#[test]
fn test_build_env_prefix_rejects_invalid_keys(
invalid_key in "[0-9][a-zA-Z0-9_]{0,10}" ) {
let _guard = test_guard!();
let mut env = HashMap::new();
env.insert(invalid_key.clone(), "value".to_string());
let allowlist = vec![invalid_key.clone()];
let prefix = build_env_prefix(&allowlist, |k| env.get(k).cloned());
prop_assert!(!is_valid_env_key(&invalid_key),
"Key should be invalid: {}", invalid_key);
prop_assert!(prefix.rejected.contains(&invalid_key),
"Should reject invalid key: {}", invalid_key);
prop_assert!(prefix.prefix.is_empty());
}
// Allowlisted keys without a value are neither applied nor rejected.
// NOTE(review): no test_guard!() here either — confirm.
#[test]
fn test_build_env_prefix_missing_values(
keys in prop::collection::vec("[A-Z_][A-Z0-9_]{0,10}", 1..5)
) {
let env: HashMap<String, String> = HashMap::new();
let prefix = build_env_prefix(&keys, |k| env.get(k).cloned());
prop_assert!(prefix.prefix.is_empty(), "Should be empty when no values");
prop_assert!(prefix.applied.is_empty());
prop_assert!(prefix.rejected.is_empty());
}
}
/// Pins escaping of the empty string, lone and repeated single quotes, and
/// checks that Unicode and mixed-quote values are accepted.
#[test]
fn test_shell_escape_edge_cases() {
let _guard = test_guard!();
let result = shell_escape_value("");
assert_eq!(result, Some("''".to_string()));
let result = shell_escape_value("'");
assert_eq!(result, Some("''\\'''".to_string()));
let result = shell_escape_value("'''");
// Each of the three quotes must appear as one '\'' sequence.
assert_eq!(
result
.as_deref()
.map(|escaped| escaped.matches("'\\''").count()),
Some(3)
);
let result = shell_escape_value("日本語");
assert!(result.is_some());
let result = shell_escape_value("🔥🚀");
assert!(result.is_some());
let result = shell_escape_value("it's a \"test\" with $vars");
assert!(result.is_some());
}
/// Pins key validation for empty, typical, punctuated and non-ASCII keys.
#[test]
fn test_is_valid_env_key_edge_cases() {
let _guard = test_guard!();
assert!(!is_valid_env_key(""));
assert!(is_valid_env_key("_"));
assert!(is_valid_env_key("A"));
assert!(is_valid_env_key("PATH"));
assert!(is_valid_env_key("HOME"));
assert!(is_valid_env_key("RUSTFLAGS"));
assert!(is_valid_env_key("CC"));
assert!(is_valid_env_key("_PRIVATE"));
assert!(is_valid_env_key("MY_VAR_123"));
assert!(!is_valid_env_key("1VAR"));
assert!(!is_valid_env_key("123"));
assert!(!is_valid_env_key("MY-VAR"));
assert!(!is_valid_env_key("MY.VAR"));
assert!(!is_valid_env_key("MY VAR"));
assert!(!is_valid_env_key("MY=VAR"));
assert!(!is_valid_env_key("日本語"));
assert!(!is_valid_env_key("VAR🔥"));
}
/// End-to-end check of which allowlisted keys end up applied, rejected or
/// silently skipped.
#[test]
fn test_build_env_prefix_integration() {
let _guard = test_guard!();
let mut env = HashMap::new();
env.insert("VALID".to_string(), "simple".to_string());
env.insert("WITH_QUOTE".to_string(), "it's here".to_string());
env.insert("NEWLINE".to_string(), "line1\nline2".to_string());
env.insert("UNICODE".to_string(), "日本語".to_string());
env.insert("EMPTY".to_string(), String::new());
env.insert("123INVALID".to_string(), "value".to_string());
let allowlist = vec![
"VALID".to_string(),
"WITH_QUOTE".to_string(),
"NEWLINE".to_string(),
"UNICODE".to_string(),
"EMPTY".to_string(),
"123INVALID".to_string(),
"MISSING".to_string(),
];
let prefix = build_env_prefix(&allowlist, |k| env.get(k).cloned());
assert!(prefix.applied.contains(&"VALID".to_string()));
assert!(prefix.prefix.contains("VALID=simple"));
assert!(prefix.applied.contains(&"WITH_QUOTE".to_string()));
assert!(prefix.rejected.contains(&"NEWLINE".to_string()));
assert!(prefix.applied.contains(&"UNICODE".to_string()));
assert!(prefix.applied.contains(&"EMPTY".to_string()));
assert!(prefix.rejected.contains(&"123INVALID".to_string()));
// A key with no value is neither applied nor rejected.
assert!(!prefix.applied.contains(&"MISSING".to_string()));
assert!(!prefix.rejected.contains(&"MISSING".to_string()));
}
/// Every listed value must be accepted by the escaper. Only acceptance is
/// checked; the escaped text is not round-tripped through a shell.
#[test]
fn test_shell_escape_roundtrip_safety() {
let _guard = test_guard!();
let test_values = [
"simple",
"with spaces",
"with\ttab",
"special!@#$%^&*()",
"quoted\"value",
"path/to/file",
"-flag",
"--long-flag=value",
"",
];
for value in &test_values {
let escaped = shell_escape_value(value);
assert!(escaped.is_some(), "Should escape: {:?}", value);
}
}
}
}