1use crate::types::{WorkerConfig, WorkerId};
9use anyhow::{Context, Result};
10use openssh::{ControlPersist, KnownHosts, Session, SessionBuilder, Stdio};
11use std::collections::HashMap;
12use std::num::NonZeroUsize;
13use std::path::Path;
14use std::sync::Arc;
15use std::time::Duration;
16use tokio::io::{AsyncBufReadExt, AsyncReadExt, BufReader};
17use tokio::sync::{RwLock, mpsc};
18use tracing::{debug, error, info, warn};
19
20pub use crate::ssh_utils::{
22 CommandResult, EnvPrefix, build_env_prefix, is_retryable_transport_error,
23 is_retryable_transport_error_text, is_valid_env_key, shell_escape_value,
24};
25
/// How long establishing an SSH connection may take before it is abandoned.
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(10);

/// Wall-clock limit for a single remote command (five minutes).
const DEFAULT_COMMAND_TIMEOUT: Duration = Duration::from_secs(5 * 60);

/// Per-stream cap, in bytes, on captured command output (10 MiB).
const MAX_OUTPUT_SIZE: u64 = 10 << 20;

/// Remote command run by health checks; its output must end with `ok`.
const HEALTH_CHECK_COMMAND: &str = "echo ok";
36
/// Returns `true` when the last non-blank line of `stdout` is the health-check sentinel.
///
/// Leading output (e.g. login banners) is tolerated; anything printed *after*
/// the sentinel makes the check fail.
fn is_expected_health_check_output(stdout: &str) -> bool {
    let final_line = stdout.trim().lines().next_back();
    matches!(final_line, Some(line) if is_health_check_sentinel(line))
}

/// Returns `true` if `line`, ignoring surrounding whitespace, is exactly `ok`.
fn is_health_check_sentinel(line: &str) -> bool {
    line.trim() == "ok"
}
48
49#[derive(Debug, Clone)]
51pub struct SshOptions {
52 pub connect_timeout: Duration,
54 pub command_timeout: Duration,
56 pub server_alive_interval: Option<Duration>,
60 pub control_persist_idle: Option<Duration>,
65 pub control_master: bool,
67 pub known_hosts: KnownHostsPolicy,
69}
70
71impl Default for SshOptions {
72 fn default() -> Self {
73 Self {
74 connect_timeout: DEFAULT_CONNECT_TIMEOUT,
75 command_timeout: DEFAULT_COMMAND_TIMEOUT,
76 server_alive_interval: None,
77 control_persist_idle: None,
78 control_master: false,
82 known_hosts: KnownHostsPolicy::Add,
83 }
84 }
85}
86
87#[derive(Debug, Clone, Copy, PartialEq, Eq)]
89pub enum KnownHostsPolicy {
90 Strict,
92 Add,
94 AcceptAll,
96}
97
#[cfg(test)]
mod retry_tests {
    use super::*;
    use crate::test_guard;

    /// Transient network failures must be classified as retryable.
    #[test]
    fn test_retryable_transport_error_text() {
        let _guard = test_guard!();
        let transient = [
            "ssh: connect to host 1.2.3.4 port 22: Connection timed out",
            "kex_exchange_identification: Connection reset by peer",
            "Broken pipe",
            "Network is unreachable",
        ];
        for message in transient {
            assert!(is_retryable_transport_error_text(message));
        }
    }

    /// Authentication, host-key, DNS and local-config failures will not fix
    /// themselves on retry and must not be classified as retryable.
    #[test]
    fn test_non_retryable_transport_error_text() {
        let _guard = test_guard!();
        let permanent = [
            "Permission denied (publickey).",
            "Host key verification failed.",
            "Could not resolve hostname worker.example.com: Name or service not known",
            "Identity file /nope/id_rsa not accessible: No such file or directory",
        ];
        for message in permanent {
            assert!(!is_retryable_transport_error_text(message));
        }
    }
}
133
134pub struct SshClient {
136 config: WorkerConfig,
138 options: SshOptions,
140 session: Option<Session>,
142}
143
144impl SshClient {
145 pub fn new(config: WorkerConfig, options: SshOptions) -> Self {
147 Self {
148 config,
149 options,
150 session: None,
151 }
152 }
153
154 pub fn worker_id(&self) -> &WorkerId {
156 &self.config.id
157 }
158
159 pub fn is_connected(&self) -> bool {
161 self.session.is_some()
162 }
163
164 fn is_configured_for(&self, config: &WorkerConfig) -> bool {
165 self.config.id == config.id
166 && self.config.host == config.host
167 && self.config.user == config.user
168 && self.config.identity_file == config.identity_file
169 }
170
171 pub async fn connect(&mut self) -> Result<()> {
173 if self.session.is_some() {
174 debug!("Already connected to {}", self.config.id);
175 return Ok(());
176 }
177
178 let destination = format!("{}@{}", self.config.user, self.config.host);
179 debug!("Connecting to {} via SSH...", destination);
180
181 let session = match self
182 .connect_with_mode(&destination, self.options.control_master)
183 .await
184 {
185 Ok(session) => session,
186 Err(primary_error) if self.options.control_master => {
187 warn!(
188 "SSH ControlMaster connection to {} failed ({}). Retrying without ControlMaster.",
189 destination, primary_error
190 );
191 self.connect_with_mode(&destination, false)
192 .await
193 .with_context(|| {
194 format!(
195 "Failed to connect to {} after retrying without ControlMaster",
196 destination
197 )
198 })?
199 }
200 Err(primary_error) => {
201 return Err(primary_error)
202 .with_context(|| format!("Failed to connect to {}", destination));
203 }
204 };
205
206 info!("Connected to {} ({})", self.config.id, self.config.host);
207 self.session = Some(session);
208 Ok(())
209 }
210
211 async fn connect_with_mode(&self, destination: &str, control_master: bool) -> Result<Session> {
212 let mut builder = SessionBuilder::default();
213 self.configure_builder(&mut builder, control_master);
214
215 builder.connect(destination).await.with_context(|| {
216 if control_master {
217 format!(
218 "Failed to connect to {} with ControlMaster enabled",
219 destination
220 )
221 } else {
222 format!(
223 "Failed to connect to {} with ControlMaster disabled",
224 destination
225 )
226 }
227 })
228 }
229
230 fn configure_builder(&self, builder: &mut SessionBuilder, control_master: bool) {
231 let known_hosts = match self.options.known_hosts {
232 KnownHostsPolicy::Strict => KnownHosts::Strict,
233 KnownHostsPolicy::Add => KnownHosts::Add,
234 KnownHostsPolicy::AcceptAll => KnownHosts::Accept,
235 };
236
237 builder
238 .known_hosts_check(known_hosts)
239 .connect_timeout(self.options.connect_timeout);
240
241 if let Some(interval) = self.options.server_alive_interval {
242 builder.server_alive_interval(interval);
243 }
244
245 let identity_path = shellexpand::tilde(&self.config.identity_file);
247 if Path::new(identity_path.as_ref()).exists() {
248 builder.keyfile(identity_path.as_ref());
249 }
250
251 if control_master {
253 if let Some(idle) = self.options.control_persist_idle {
254 if idle.is_zero() {
255 builder.control_persist(ControlPersist::ClosedAfterInitialConnection);
256 } else {
257 match usize::try_from(idle.as_secs()) {
258 Ok(secs) => {
259 if let Some(nonzero) = NonZeroUsize::new(secs) {
260 builder.control_persist(ControlPersist::IdleFor(nonzero));
261 } else {
262 builder
263 .control_persist(ControlPersist::ClosedAfterInitialConnection);
264 }
265 }
266 Err(_) => {
267 warn!(
268 "control_persist_idle too large ({}s); ignoring override",
269 idle.as_secs()
270 );
271 }
272 }
273 }
274 }
275
276 let control_dir = {
285 let home_ssh = dirs::home_dir().map(|h| h.join(".ssh").join("rch"));
286
287 if let Some(ref dir) = home_ssh {
288 dir.clone()
289 } else if let Some(runtime_dir) = dirs::runtime_dir() {
290 runtime_dir.join("rch-ssh")
291 } else {
292 let username = whoami::username().unwrap_or_else(|_| "unknown".to_string());
293 std::env::temp_dir().join(format!("rch-ssh-{}", username))
294 }
295 };
296
297 if let Err(e) = std::fs::create_dir_all(&control_dir) {
298 warn!(
299 "Failed to create SSH control directory {:?}: {}",
300 control_dir, e
301 );
302 } else {
303 #[cfg(unix)]
306 {
307 use std::os::unix::fs::PermissionsExt;
308 if let Err(e) = std::fs::set_permissions(
309 &control_dir,
310 std::fs::Permissions::from_mode(0o700),
311 ) {
312 warn!(
313 "Failed to set permissions on SSH control directory {:?}: {}",
314 control_dir, e
315 );
316 }
317 }
318 }
319 builder.control_directory(&control_dir);
320 }
321 }
322
323 pub async fn disconnect(&mut self) -> Result<()> {
325 if let Some(session) = self.session.take() {
326 debug!("Disconnecting from {}", self.config.id);
327 session.close().await?;
328 info!("Disconnected from {}", self.config.id);
329 }
330 Ok(())
331 }
332
333 pub async fn execute(&self, command: &str) -> Result<CommandResult> {
335 let session = self.session.as_ref().context("Not connected to worker")?;
336
337 let start = std::time::Instant::now();
338 debug!(
339 "Executing on {}: {}",
340 self.config.id,
341 crate::util::mask_sensitive_command(command)
342 );
343
344 let mut child = session
345 .command("sh")
346 .arg("-c")
347 .arg(command)
348 .stdout(Stdio::piped())
349 .stderr(Stdio::piped())
350 .spawn()
351 .await
352 .with_context(|| format!("Failed to spawn command on {}", self.config.id))?;
353
354 let execution_future = async {
355 let stdout_handle = child.stdout().take();
357 let stderr_handle = child.stderr().take();
358
359 let stdout_fut = async {
360 if let Some(out) = stdout_handle {
361 let reader = BufReader::new(out);
362 let mut take = reader.take(MAX_OUTPUT_SIZE);
363 let mut buf = String::new();
364 take.read_to_string(&mut buf).await?;
365 let mut reader = take.into_inner();
367 let mut sink = tokio::io::sink();
368 tokio::io::copy(&mut reader, &mut sink).await?;
369 if buf.len() >= MAX_OUTPUT_SIZE as usize {
370 buf.push_str("\n...[output truncated]...\n");
371 }
372 Ok::<String, anyhow::Error>(buf)
373 } else {
374 Ok(String::new())
375 }
376 };
377
378 let stderr_fut = async {
379 if let Some(err) = stderr_handle {
380 let reader = BufReader::new(err);
381 let mut take = reader.take(MAX_OUTPUT_SIZE);
382 let mut buf = String::new();
383 take.read_to_string(&mut buf).await?;
384 let mut reader = take.into_inner();
386 let mut sink = tokio::io::sink();
387 tokio::io::copy(&mut reader, &mut sink).await?;
388 if buf.len() >= MAX_OUTPUT_SIZE as usize {
389 buf.push_str("\n...[output truncated]...\n");
390 }
391 Ok::<String, anyhow::Error>(buf)
392 } else {
393 Ok(String::new())
394 }
395 };
396
397 let (stdout, stderr) = tokio::try_join!(stdout_fut, stderr_fut)?;
398
399 let status = child
400 .wait()
401 .await
402 .with_context(|| "Failed to wait for command completion")?;
403
404 Ok::<_, anyhow::Error>((status, stdout, stderr))
405 };
406
407 match tokio::time::timeout(self.options.command_timeout, execution_future).await {
408 Ok(result) => {
409 let (status, stdout, stderr) = result?;
410 let duration = start.elapsed();
411 let exit_code = status.code().unwrap_or(-1);
412
413 debug!(
414 "Command completed on {} (exit={}, duration={}ms)",
415 self.config.id,
416 exit_code,
417 duration.as_millis()
418 );
419
420 Ok(CommandResult {
421 exit_code,
422 stdout,
423 stderr,
424 duration_ms: duration.as_millis() as u64,
425 })
426 }
427 Err(_) => {
428 warn!(
435 "Command timed out on {} after {:?}",
436 self.config.id, self.options.command_timeout
437 );
438 anyhow::bail!("Command timed out after {:?}", self.options.command_timeout);
439 }
440 }
441 }
442
443 pub async fn execute_streaming<F, G>(
445 &self,
446 command: &str,
447 mut on_stdout: F,
448 mut on_stderr: G,
449 ) -> Result<CommandResult>
450 where
451 F: FnMut(&str),
452 G: FnMut(&str),
453 {
454 let session = self.session.as_ref().context("Not connected to worker")?;
455
456 let start = std::time::Instant::now();
457 debug!(
458 "Executing (streaming) on {}: {}",
459 self.config.id,
460 crate::util::mask_sensitive_command(command)
461 );
462
463 let mut child = session
464 .command("sh")
465 .arg("-c")
466 .arg(command)
467 .stdout(Stdio::piped())
468 .stderr(Stdio::piped())
469 .spawn()
470 .await
471 .with_context(|| format!("Failed to spawn command on {}", self.config.id))?;
472
473 let stdout = child.stdout().take();
474 let stderr = child.stderr().take();
475
476 let (tx, mut rx) = mpsc::channel(100);
479
480 if let Some(out) = stdout {
482 let tx = tx.clone();
483 tokio::spawn(async move {
484 let mut reader = BufReader::new(out);
485 let mut line = String::new();
486 loop {
487 line.clear();
488 match reader.read_line(&mut line).await {
489 Ok(0) => break, Ok(_) => {
491 if tx.send(StreamEvent::Stdout(line.clone())).await.is_err() {
492 break; }
494 }
495 Err(_) => break, }
497 }
498 });
499 }
500
501 if let Some(err) = stderr {
503 let tx = tx.clone();
504 tokio::spawn(async move {
505 let mut reader = BufReader::new(err);
506 let mut line = String::new();
507 loop {
508 line.clear();
509 match reader.read_line(&mut line).await {
510 Ok(0) => break, Ok(_) => {
512 if tx.send(StreamEvent::Stderr(line.clone())).await.is_err() {
513 break; }
515 }
516 Err(_) => break, }
518 }
519 });
520 }
521
522 drop(tx);
524
525 let mut stdout_acc = String::new();
526 let mut stderr_acc = String::new();
527
528 enum StreamEvent {
529 Stdout(String),
530 Stderr(String),
531 }
532
533 let streaming_future = async {
534 while let Some(event) = rx.recv().await {
536 match event {
537 StreamEvent::Stdout(line) => {
538 on_stdout(&line);
539 if stdout_acc.len() < MAX_OUTPUT_SIZE as usize {
540 stdout_acc.push_str(&line);
541 if stdout_acc.len() >= MAX_OUTPUT_SIZE as usize {
542 stdout_acc.push_str("\n...[output truncated]...\n");
543 }
544 }
545 }
546 StreamEvent::Stderr(line) => {
547 on_stderr(&line);
548 if stderr_acc.len() < MAX_OUTPUT_SIZE as usize {
549 stderr_acc.push_str(&line);
550 if stderr_acc.len() >= MAX_OUTPUT_SIZE as usize {
551 stderr_acc.push_str("\n...[output truncated]...\n");
552 }
553 }
554 }
555 }
556 }
557
558 let status = child.wait().await?;
559 Ok::<_, anyhow::Error>(status)
560 };
561
562 match tokio::time::timeout(self.options.command_timeout, streaming_future).await {
563 Ok(result) => {
564 let status = result?;
565 let duration = start.elapsed();
566 let exit_code = status.code().unwrap_or(-1);
567
568 Ok(CommandResult {
569 exit_code,
570 stdout: stdout_acc,
571 stderr: stderr_acc,
572 duration_ms: duration.as_millis() as u64,
573 })
574 }
575 Err(_) => {
576 warn!(
586 "Command (streaming) timed out on {} after {:?}, cleaning up",
587 self.config.id, self.options.command_timeout
588 );
589 anyhow::bail!("Command timed out after {:?}", self.options.command_timeout);
592 }
593 }
594 }
595
596 pub async fn health_check(&self) -> Result<bool> {
598 match self.execute(HEALTH_CHECK_COMMAND).await {
599 Ok(result) => Ok(result.success() && is_expected_health_check_output(&result.stdout)),
600 Err(e) => {
601 warn!("Health check failed for {}: {}", self.config.id, e);
602 Ok(false)
603 }
604 }
605 }
606}
607
608pub struct SshPool {
610 connections: Arc<RwLock<HashMap<WorkerId, Arc<RwLock<SshClient>>>>>,
612 options: SshOptions,
614}
615
616impl SshPool {
617 pub fn new(options: SshOptions) -> Self {
619 Self {
620 connections: Arc::new(RwLock::new(HashMap::new())),
621 options,
622 }
623 }
624
625 pub async fn get_or_connect(&self, config: &WorkerConfig) -> Result<Arc<RwLock<SshClient>>> {
627 let shared_client = self.get_or_create_client_entry(config).await;
628
629 let is_connected = {
630 let guard = shared_client.read().await;
631 guard.is_connected()
632 };
633 if is_connected {
634 debug!("Reusing existing connection to {}", config.id);
635 return Ok(shared_client);
636 }
637
638 let mut client_guard = shared_client.write().await;
640 if !client_guard.is_connected() {
642 client_guard.connect().await?;
643 }
644 drop(client_guard);
646
647 Ok(shared_client)
648 }
649
650 async fn get_or_create_client_entry(&self, config: &WorkerConfig) -> Arc<RwLock<SshClient>> {
651 let worker_id = config.id.clone();
652
653 loop {
654 let existing_client = {
655 let connections = self.connections.read().await;
656 connections.get(&worker_id).cloned()
657 };
658
659 if let Some(client) = existing_client {
660 let is_configured_for_worker = {
661 let guard = client.read().await;
662 guard.is_configured_for(config)
663 };
664 if is_configured_for_worker {
665 return client;
666 }
667
668 let replacement = Arc::new(RwLock::new(SshClient::new(
669 config.clone(),
670 self.options.clone(),
671 )));
672 let replaced = {
673 let mut connections = self.connections.write().await;
674 if connections
675 .get(&worker_id)
676 .is_some_and(|current| Arc::ptr_eq(current, &client))
677 {
678 connections.insert(worker_id.clone(), replacement.clone());
679 true
680 } else {
681 false
682 }
683 };
684
685 if replaced {
686 debug!(
687 "Replaced SSH connection entry for {} after endpoint config changed",
688 worker_id
689 );
690 return replacement;
691 }
692
693 continue;
694 }
695
696 let new_client = Arc::new(RwLock::new(SshClient::new(
697 config.clone(),
698 self.options.clone(),
699 )));
700 let inserted = {
701 let mut connections = self.connections.write().await;
702 if connections.contains_key(&worker_id) {
703 false
704 } else {
705 connections.insert(worker_id.clone(), new_client.clone());
706 true
707 }
708 };
709
710 if inserted {
711 return new_client;
712 }
713 }
714 }
715
716 pub async fn close(&self, worker_id: &WorkerId) -> Result<()> {
718 let client = {
719 let mut connections = self.connections.write().await;
720 connections.remove(worker_id)
721 };
722
723 if let Some(client) = client {
724 let mut client = client.write().await;
725 client.disconnect().await?;
726 }
727
728 Ok(())
729 }
730
731 pub async fn close_all(&self) -> Result<()> {
733 let clients: Vec<_> = {
734 let mut connections = self.connections.write().await;
735 connections.drain().map(|(_, v)| v).collect()
736 };
737
738 for client in clients {
739 let mut client = client.write().await;
740 if let Err(e) = client.disconnect().await {
741 error!("Error closing connection: {}", e);
742 }
743 }
744
745 Ok(())
746 }
747
748 pub async fn active_connections(&self) -> usize {
750 self.connections.read().await.len()
751 }
752}
753
754impl Default for SshPool {
755 fn default() -> Self {
756 Self::new(SshOptions::default())
757 }
758}
759
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_guard;

    #[test]
    fn test_command_result_success() {
        let _guard = test_guard!();
        let result = CommandResult {
            exit_code: 0,
            stdout: "output".to_string(),
            stderr: String::new(),
            duration_ms: 100,
        };
        assert!(result.success());

        let failed = CommandResult {
            exit_code: 1,
            stdout: String::new(),
            stderr: "error".to_string(),
            duration_ms: 50,
        };
        assert!(!failed.success());
    }

    #[test]
    fn test_ssh_options_default() {
        let _guard = test_guard!();
        let options = SshOptions::default();
        assert_eq!(options.connect_timeout, Duration::from_secs(10));
        assert_eq!(options.command_timeout, Duration::from_secs(300));
        assert!(options.server_alive_interval.is_none());
        assert!(options.control_persist_idle.is_none());
        assert!(!options.control_master);
        assert_eq!(options.known_hosts, KnownHostsPolicy::Add);
    }

    #[test]
    fn test_ssh_client_creation() {
        let _guard = test_guard!();
        let config = WorkerConfig {
            id: WorkerId::new("test-worker"),
            host: "192.168.1.100".to_string(),
            user: "ubuntu".to_string(),
            identity_file: "~/.ssh/id_rsa".to_string(),
            total_slots: 8,
            priority: 100,
            tags: vec!["rust".to_string()],
        };

        let client = SshClient::new(config, SshOptions::default());
        assert_eq!(client.worker_id().as_str(), "test-worker");
        assert!(!client.is_connected());
    }

    #[test]
    fn test_expected_health_check_output_accepts_sentinel_as_last_line() {
        let _guard = test_guard!();

        assert!(is_expected_health_check_output("ok\n"));
        assert!(is_expected_health_check_output("login banner\nok\n"));
        assert!(!is_expected_health_check_output(""));
        assert!(!is_expected_health_check_output("not ok\n"));
        assert!(!is_expected_health_check_output("ok\npost-command noise\n"));
    }

    /// Builds a worker config with fixed scheduling fields and the given endpoint.
    fn worker_config(id: &str, host: &str, user: &str, identity_file: &str) -> WorkerConfig {
        WorkerConfig {
            id: WorkerId::new(id),
            host: host.to_string(),
            user: user.to_string(),
            identity_file: identity_file.to_string(),
            total_slots: 8,
            priority: 100,
            tags: vec!["rust".to_string()],
        }
    }

    #[test]
    fn test_ssh_client_configured_for_ignores_scheduling_fields() {
        let _guard = test_guard!();
        let config = worker_config("worker-a", "192.168.1.100", "ubuntu", "~/.ssh/id_rsa");
        let client = SshClient::new(config.clone(), SshOptions::default());

        let mut scheduling_only_change = config;
        scheduling_only_change.total_slots = 16;
        scheduling_only_change.priority = 250;
        scheduling_only_change.tags = vec!["rust".to_string(), "gpu".to_string()];

        assert!(client.is_configured_for(&scheduling_only_change));
    }

    #[test]
    fn test_ssh_client_configured_for_detects_endpoint_changes() {
        let _guard = test_guard!();
        let config = worker_config("worker-a", "192.168.1.100", "ubuntu", "~/.ssh/id_rsa");
        let client = SshClient::new(config, SshOptions::default());

        assert!(!client.is_configured_for(&worker_config(
            "worker-a",
            "192.168.1.101",
            "ubuntu",
            "~/.ssh/id_rsa",
        )));
        assert!(!client.is_configured_for(&worker_config(
            "worker-a",
            "192.168.1.100",
            "admin",
            "~/.ssh/id_rsa",
        )));
        assert!(!client.is_configured_for(&worker_config(
            "worker-a",
            "192.168.1.100",
            "ubuntu",
            "~/.ssh/other_key",
        )));
    }

    #[tokio::test]
    async fn test_ssh_pool_reuses_matching_disconnected_entry() {
        let _guard = test_guard!();
        let pool = SshPool::default();
        let config = worker_config("worker-a", "192.168.1.100", "ubuntu", "~/.ssh/id_rsa");

        let first = pool.get_or_create_client_entry(&config).await;
        let second = pool.get_or_create_client_entry(&config).await;

        assert!(Arc::ptr_eq(&first, &second));
        assert_eq!(pool.active_connections().await, 1);
    }

    #[tokio::test]
    async fn test_ssh_pool_replaces_stale_entry_when_endpoint_changes() {
        let _guard = test_guard!();
        let pool = SshPool::default();
        let old_config = worker_config("worker-a", "192.168.1.100", "ubuntu", "~/.ssh/id_rsa");
        let new_config = worker_config("worker-a", "192.168.1.101", "admin", "~/.ssh/new_key");

        let stale = pool.get_or_create_client_entry(&old_config).await;
        let replacement = pool.get_or_create_client_entry(&new_config).await;

        assert!(!Arc::ptr_eq(&stale, &replacement));
        assert_eq!(pool.active_connections().await, 1);

        let replacement_guard = replacement.read().await;
        assert!(replacement_guard.is_configured_for(&new_config));
    }

    #[test]
    fn test_build_env_prefix_quotes_and_rejects() {
        let _guard = test_guard!();
        let mut env = HashMap::new();
        env.insert("RUSTFLAGS".to_string(), "-C target-cpu=native".to_string());
        env.insert("QUOTED".to_string(), "a'b".to_string());
        env.insert("BADVAL".to_string(), "line1\nline2".to_string());

        let allowlist = vec![
            "RUSTFLAGS".to_string(),
            "QUOTED".to_string(),
            "MISSING".to_string(),
            "BADVAL".to_string(),
            "BAD=KEY".to_string(),
        ];

        let prefix = build_env_prefix(&allowlist, |key| env.get(key).cloned());

        assert!(prefix.prefix.contains("RUSTFLAGS='-C target-cpu=native'"));
        // A single quote is closed, escaped, and reopened: a'b -> 'a'\''b'.
        assert!(prefix.prefix.contains("QUOTED='a'\\''b'"));
        assert!(!prefix.prefix.contains("MISSING="));
        assert!(!prefix.prefix.contains("BADVAL="));
        assert!(prefix.rejected.contains(&"BADVAL".to_string()));
        assert!(prefix.rejected.contains(&"BAD=KEY".to_string()));
        assert_eq!(
            prefix.applied,
            vec!["RUSTFLAGS".to_string(), "QUOTED".to_string()]
        );
    }

    /// Property-based and edge-case coverage for env-key validation and shell escaping.
    mod proptest_ssh_escaping {
        use super::*;
        use proptest::prelude::*;
        use std::collections::HashMap;

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(1000))]

            /// Arbitrary input must never panic the key validator.
            #[test]
            fn test_is_valid_env_key_no_panic(s in ".*") {
                let _guard = test_guard!();
                let _ = is_valid_env_key(&s);
            }

            /// `[A-Za-z_][A-Za-z0-9_]*` keys are accepted.
            #[test]
            fn test_is_valid_env_key_accepts_valid(
                first in "[a-zA-Z_]",
                rest in "[a-zA-Z0-9_]{0,50}"
            ) {
                let _guard = test_guard!();
                let key = format!("{}{}", first, rest);
                prop_assert!(is_valid_env_key(&key), "Should accept valid key: {}", key);
            }

            /// Keys starting with a digit are rejected.
            #[test]
            fn test_is_valid_env_key_rejects_digit_start(
                digit in "[0-9]",
                rest in "[a-zA-Z0-9_]{0,20}"
            ) {
                let _guard = test_guard!();
                let key = format!("{}{}", digit, rest);
                prop_assert!(!is_valid_env_key(&key), "Should reject digit-start key: {}", key);
            }

            /// Arbitrary input must never panic the escaper.
            #[test]
            fn test_shell_escape_value_no_panic(s in ".*") {
                let _guard = test_guard!();
                let _ = shell_escape_value(&s);
            }

            /// Newline, carriage return, and NUL anywhere in the value are rejected.
            #[test]
            fn test_shell_escape_value_rejects_unsafe(
                prefix in "[a-zA-Z0-9 ]{0,10}",
                bad_char in "[\n\r\0]",
                suffix in "[a-zA-Z0-9 ]{0,10}"
            ) {
                let _guard = test_guard!();
                let value = format!("{}{}{}", prefix, bad_char, suffix);
                prop_assert!(shell_escape_value(&value).is_none(),
                    "Should reject value with unsafe char: {:?}", value);
            }

            /// Printable ASCII (minus quotes) is accepted, and quoted if it has specials.
            #[test]
            fn test_shell_escape_value_accepts_safe(s in "[a-zA-Z0-9 !@#$%^&*()_+=\\-\\[\\]{}|;:,./<>?]{0,100}") {
                let _guard = test_guard!();
                let result = shell_escape_value(&s);
                prop_assert!(result.is_some(), "Should accept safe value: {:?}", s);

                // Checked above, so this is the escaped value.
                let escaped = result.unwrap_or_default();
                if s.chars().any(|c| !c.is_ascii_alphanumeric() && c != '_') {
                    prop_assert!(escaped.contains('\''),
                        "Value with special chars should be quoted: {:?} -> {:?}", s, escaped);
                }
            }

            /// An embedded single quote is encoded as `'\''`.
            #[test]
            fn test_shell_escape_value_escapes_quotes(
                prefix in "[a-zA-Z0-9]{0,10}",
                suffix in "[a-zA-Z0-9]{0,10}"
            ) {
                let _guard = test_guard!();
                let value = format!("{}'{}", prefix, suffix);
                let result = shell_escape_value(&value);
                prop_assert!(result.is_some(), "Should escape single quote: {}", value);

                let escaped = result.unwrap_or_default();
                prop_assert!(escaped.contains("'\\''"),
                    "Should escape single quote: {} -> {}", value, escaped);
            }

            /// Random keys and values never panic prefix construction.
            #[test]
            fn test_build_env_prefix_no_panic(
                keys in prop::collection::vec("[a-zA-Z_][a-zA-Z0-9_]{0,10}", 0..10),
                values in prop::collection::vec(".*", 0..10)
            ) {
                let _guard = test_guard!();
                let mut env = HashMap::new();
                for (key, val) in keys.iter().zip(values.iter()) {
                    env.insert(key.clone(), val.clone());
                }

                let allowlist: Vec<String> = keys;
                let _ = build_env_prefix(&allowlist, |k| env.get(k).cloned());
            }

            /// Invalid keys are reported as rejected and contribute nothing to the prefix.
            #[test]
            fn test_build_env_prefix_rejects_invalid_keys(
                invalid_key in "[0-9][a-zA-Z0-9_]{0,10}"
            ) {
                let _guard = test_guard!();
                let mut env = HashMap::new();
                env.insert(invalid_key.clone(), "value".to_string());

                let allowlist = vec![invalid_key.clone()];
                let prefix = build_env_prefix(&allowlist, |k| env.get(k).cloned());

                prop_assert!(!is_valid_env_key(&invalid_key),
                    "Key should be invalid: {}", invalid_key);
                prop_assert!(prefix.rejected.contains(&invalid_key),
                    "Should reject invalid key: {}", invalid_key);
                prop_assert!(prefix.prefix.is_empty());
            }

            /// Allowlisted keys with no value are silently skipped.
            #[test]
            fn test_build_env_prefix_missing_values(
                keys in prop::collection::vec("[A-Z_][A-Z0-9_]{0,10}", 1..5)
            ) {
                let _guard = test_guard!();
                let env: HashMap<String, String> = HashMap::new();
                let prefix = build_env_prefix(&keys, |k| env.get(k).cloned());

                prop_assert!(prefix.prefix.is_empty(), "Should be empty when no values");
                prop_assert!(prefix.applied.is_empty());
                prop_assert!(prefix.rejected.is_empty());
            }
        }

        #[test]
        fn test_shell_escape_edge_cases() {
            let _guard = test_guard!();
            // The empty string is quoted so it survives as an argument.
            let result = shell_escape_value("");
            assert_eq!(result, Some("''".to_string()));

            // A lone quote: open, escaped quote, close.
            let result = shell_escape_value("'");
            assert_eq!(result, Some("''\\'''".to_string()));

            // Every quote gets its own escape sequence.
            let result = shell_escape_value("'''");
            assert_eq!(
                result
                    .as_deref()
                    .map(|escaped| escaped.matches("'\\''").count()),
                Some(3)
            );

            // Non-ASCII text (CJK, emoji) is accepted.
            let result = shell_escape_value("日本語");
            assert!(result.is_some());

            let result = shell_escape_value("🔥🚀");
            assert!(result.is_some());

            // Mixed quoting and `$` expansion characters are accepted.
            let result = shell_escape_value("it's a \"test\" with $vars");
            assert!(result.is_some());
        }

        #[test]
        fn test_is_valid_env_key_edge_cases() {
            let _guard = test_guard!();
            // Empty keys are invalid.
            assert!(!is_valid_env_key(""));

            // A single underscore or letter is enough.
            assert!(is_valid_env_key("_"));
            assert!(is_valid_env_key("A"));

            // Typical real-world keys.
            assert!(is_valid_env_key("PATH"));
            assert!(is_valid_env_key("HOME"));
            assert!(is_valid_env_key("RUSTFLAGS"));
            assert!(is_valid_env_key("CC"));
            assert!(is_valid_env_key("_PRIVATE"));
            assert!(is_valid_env_key("MY_VAR_123"));

            // Leading digits are rejected.
            assert!(!is_valid_env_key("1VAR"));
            assert!(!is_valid_env_key("123"));

            // Punctuation and whitespace are rejected.
            assert!(!is_valid_env_key("MY-VAR"));
            assert!(!is_valid_env_key("MY.VAR"));
            assert!(!is_valid_env_key("MY VAR"));
            assert!(!is_valid_env_key("MY=VAR"));

            // Non-ASCII characters are rejected.
            assert!(!is_valid_env_key("日本語"));
            assert!(!is_valid_env_key("VAR🔥"));
        }

        #[test]
        fn test_build_env_prefix_integration() {
            let _guard = test_guard!();
            let mut env = HashMap::new();
            env.insert("VALID".to_string(), "simple".to_string());
            env.insert("WITH_QUOTE".to_string(), "it's here".to_string());
            env.insert("NEWLINE".to_string(), "line1\nline2".to_string());
            env.insert("UNICODE".to_string(), "日本語".to_string());
            env.insert("EMPTY".to_string(), String::new());
            env.insert("123INVALID".to_string(), "value".to_string());

            let allowlist = vec![
                "VALID".to_string(),
                "WITH_QUOTE".to_string(),
                "NEWLINE".to_string(),
                "UNICODE".to_string(),
                "EMPTY".to_string(),
                "123INVALID".to_string(),
                "MISSING".to_string(),
            ];

            let prefix = build_env_prefix(&allowlist, |k| env.get(k).cloned());

            // Plain values pass through unquoted.
            assert!(prefix.applied.contains(&"VALID".to_string()));
            assert!(prefix.prefix.contains("VALID=simple"));

            // Quotes are escaped, not rejected.
            assert!(prefix.applied.contains(&"WITH_QUOTE".to_string()));

            // Newlines make the value unsafe.
            assert!(prefix.rejected.contains(&"NEWLINE".to_string()));

            // Non-ASCII values are fine.
            assert!(prefix.applied.contains(&"UNICODE".to_string()));

            // Empty values are still applied.
            assert!(prefix.applied.contains(&"EMPTY".to_string()));

            // Invalid key names are rejected.
            assert!(prefix.rejected.contains(&"123INVALID".to_string()));

            // Missing values are neither applied nor rejected.
            assert!(!prefix.applied.contains(&"MISSING".to_string()));
            assert!(!prefix.rejected.contains(&"MISSING".to_string()));
        }

        #[test]
        fn test_shell_escape_roundtrip_safety() {
            let _guard = test_guard!();
            // Values that must all be accepted by the escaper.
            let test_values = [
                "simple",
                "with spaces",
                "with\ttab",
                "special!@#$%^&*()",
                "quoted\"value",
                "path/to/file",
                "-flag",
                "--long-flag=value",
                "",
            ];

            for value in &test_values {
                let escaped = shell_escape_value(value);
                assert!(escaped.is_some(), "Should escape: {:?}", value);
            }
        }
    }
}