Skip to main content

rch_common/
ssh.rs

1//! SSH client utilities for remote command execution.
2//!
3//! Provides connection management, command execution, and pooling support
4//! for the remote compilation pipeline.
5//!
6//! This module is only available on Unix platforms (requires openssh crate).
7
8use crate::types::{WorkerConfig, WorkerId};
9use anyhow::{Context, Result};
10use openssh::{ControlPersist, KnownHosts, Session, SessionBuilder, Stdio};
11use std::collections::HashMap;
12use std::num::NonZeroUsize;
13use std::path::Path;
14use std::sync::Arc;
15use std::time::Duration;
16use tokio::io::{AsyncBufReadExt, AsyncReadExt, BufReader};
17use tokio::sync::{RwLock, mpsc};
18use tracing::{debug, error, info, warn};
19
20// Re-export platform-independent utilities for backwards compatibility
21pub use crate::ssh_utils::{
22    CommandResult, EnvPrefix, build_env_prefix, is_retryable_transport_error,
23    is_retryable_transport_error_text, is_valid_env_key, shell_escape_value,
24};
25
/// Default SSH connection timeout (10 seconds), used by `SshOptions::default()`.
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(10);

/// Default command execution timeout (300 seconds), used by `SshOptions::default()`.
const DEFAULT_COMMAND_TIMEOUT: Duration = Duration::from_secs(300);

/// Maximum size for captured command output, applied to stdout and stderr
/// separately, to prevent OOM (10 MiB).
const MAX_OUTPUT_SIZE: u64 = 10 * 1024 * 1024;

/// Shell command run on the worker by `SshClient::health_check`; its output
/// is validated by `is_expected_health_check_output`.
const HEALTH_CHECK_COMMAND: &str = "echo ok";
36
/// Returns `true` when the last line of `stdout` is the health-check sentinel.
///
/// Surrounding whitespace is trimmed first, so trailing newlines are ignored.
/// Only the final line is inspected: earlier lines (for example a login
/// banner) are tolerated, while any output after the sentinel is rejected.
fn is_expected_health_check_output(stdout: &str) -> bool {
    match stdout.trim().lines().next_back() {
        Some(final_line) => is_health_check_sentinel(final_line),
        None => false,
    }
}

/// Returns `true` when `line`, ignoring surrounding whitespace, is exactly `ok`.
fn is_health_check_sentinel(line: &str) -> bool {
    line.trim() == "ok"
}
48
/// SSH connection options.
///
/// One `SshOptions` value is stored in each `SshClient` and consulted both
/// when the session is built and when commands are executed.
#[derive(Debug, Clone)]
pub struct SshOptions {
    /// Connection timeout, forwarded to the session builder when connecting.
    pub connect_timeout: Duration,
    /// Command execution timeout.
    ///
    /// `SshClient::execute` and `SshClient::execute_streaming` return an error
    /// once a command has run for longer than this.
    pub command_timeout: Duration,
    /// Server keepalive interval (`ssh -o ServerAliveInterval`).
    ///
    /// Defaults to `None` (OpenSSH default; keepalive disabled). The interval
    /// is only passed to the session builder when it is `Some`.
    pub server_alive_interval: Option<Duration>,
    /// How long the SSH ControlMaster should remain alive while idle.
    ///
    /// Only consulted when a session is built with ControlMaster enabled.
    ///
    /// `None` preserves the OpenSSH crate default (`ControlPersist=yes`).
    /// `Some(0s)` sets `ControlPersist=no` (close after initial connection).
    /// The duration is applied in whole seconds, so a non-zero value below
    /// one second behaves like `Some(0s)`.
    pub control_persist_idle: Option<Duration>,
    /// SSH control master mode for connection reuse.
    ///
    /// When enabled and the ControlMaster connection attempt fails,
    /// `SshClient::connect` retries once without ControlMaster.
    pub control_master: bool,
    /// Known hosts policy.
    pub known_hosts: KnownHostsPolicy,
}
70
71impl Default for SshOptions {
72    fn default() -> Self {
73        Self {
74            connect_timeout: DEFAULT_CONNECT_TIMEOUT,
75            command_timeout: DEFAULT_COMMAND_TIMEOUT,
76            server_alive_interval: None,
77            control_persist_idle: None,
78            // Default to a plain SSH session. ControlMaster is an optimization
79            // and stale local control sockets can poison otherwise healthy
80            // connections. Callers that explicitly want mux reuse can opt in.
81            control_master: false,
82            known_hosts: KnownHostsPolicy::Add,
83        }
84    }
85}
86
/// Known hosts policy for SSH connections.
///
/// Each variant is translated to the matching `openssh::KnownHosts` value
/// when a session is configured. `SshOptions::default()` selects [`Add`].
///
/// [`Add`]: KnownHostsPolicy::Add
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KnownHostsPolicy {
    /// Strictly verify known hosts (recommended for production).
    Strict,
    /// Add unknown hosts automatically (for development).
    Add,
    /// Accept all hosts without verification (INSECURE - testing only).
    AcceptAll,
}
97
#[cfg(test)]
mod retry_tests {
    use super::*;
    use crate::test_guard;

    /// Transient network-level failures must be classified as retryable.
    #[test]
    fn test_retryable_transport_error_text() {
        let _guard = test_guard!();
        let retryable_messages = [
            "ssh: connect to host 1.2.3.4 port 22: Connection timed out",
            "kex_exchange_identification: Connection reset by peer",
            "Broken pipe",
            "Network is unreachable",
        ];
        for message in retryable_messages {
            assert!(is_retryable_transport_error_text(message));
        }
    }

    /// Authentication, host-key, DNS and key-file failures are permanent and
    /// must not be classified as retryable.
    #[test]
    fn test_non_retryable_transport_error_text() {
        let _guard = test_guard!();
        let permanent_messages = [
            "Permission denied (publickey).",
            "Host key verification failed.",
            "Could not resolve hostname worker.example.com: Name or service not known",
            "Identity file /nope/id_rsa not accessible: No such file or directory",
        ];
        for message in permanent_messages {
            assert!(!is_retryable_transport_error_text(message));
        }
    }
}
133
/// SSH client for a single worker connection.
///
/// A client starts out disconnected; `connect` establishes the session and
/// `disconnect` closes it again. Commands can only be executed while a
/// session is present.
pub struct SshClient {
    /// Worker configuration (id, host, user and identity file are used to
    /// build the SSH destination and to label log messages).
    config: WorkerConfig,
    /// SSH options.
    options: SshOptions,
    /// Active SSH session (if connected); `None` while disconnected.
    session: Option<Session>,
}
143
144impl SshClient {
145    /// Create a new SSH client for a worker.
146    pub fn new(config: WorkerConfig, options: SshOptions) -> Self {
147        Self {
148            config,
149            options,
150            session: None,
151        }
152    }
153
154    /// Get the worker ID.
155    pub fn worker_id(&self) -> &WorkerId {
156        &self.config.id
157    }
158
159    /// Check if connected to the worker.
160    pub fn is_connected(&self) -> bool {
161        self.session.is_some()
162    }
163
164    fn is_configured_for(&self, config: &WorkerConfig) -> bool {
165        self.config.id == config.id
166            && self.config.host == config.host
167            && self.config.user == config.user
168            && self.config.identity_file == config.identity_file
169    }
170
171    /// Connect to the remote worker.
172    pub async fn connect(&mut self) -> Result<()> {
173        if self.session.is_some() {
174            debug!("Already connected to {}", self.config.id);
175            return Ok(());
176        }
177
178        let destination = format!("{}@{}", self.config.user, self.config.host);
179        debug!("Connecting to {} via SSH...", destination);
180
181        let session = match self
182            .connect_with_mode(&destination, self.options.control_master)
183            .await
184        {
185            Ok(session) => session,
186            Err(primary_error) if self.options.control_master => {
187                warn!(
188                    "SSH ControlMaster connection to {} failed ({}). Retrying without ControlMaster.",
189                    destination, primary_error
190                );
191                self.connect_with_mode(&destination, false)
192                    .await
193                    .with_context(|| {
194                        format!(
195                            "Failed to connect to {} after retrying without ControlMaster",
196                            destination
197                        )
198                    })?
199            }
200            Err(primary_error) => {
201                return Err(primary_error)
202                    .with_context(|| format!("Failed to connect to {}", destination));
203            }
204        };
205
206        info!("Connected to {} ({})", self.config.id, self.config.host);
207        self.session = Some(session);
208        Ok(())
209    }
210
211    async fn connect_with_mode(&self, destination: &str, control_master: bool) -> Result<Session> {
212        let mut builder = SessionBuilder::default();
213        self.configure_builder(&mut builder, control_master);
214
215        builder.connect(destination).await.with_context(|| {
216            if control_master {
217                format!(
218                    "Failed to connect to {} with ControlMaster enabled",
219                    destination
220                )
221            } else {
222                format!(
223                    "Failed to connect to {} with ControlMaster disabled",
224                    destination
225                )
226            }
227        })
228    }
229
230    fn configure_builder(&self, builder: &mut SessionBuilder, control_master: bool) {
231        let known_hosts = match self.options.known_hosts {
232            KnownHostsPolicy::Strict => KnownHosts::Strict,
233            KnownHostsPolicy::Add => KnownHosts::Add,
234            KnownHostsPolicy::AcceptAll => KnownHosts::Accept,
235        };
236
237        builder
238            .known_hosts_check(known_hosts)
239            .connect_timeout(self.options.connect_timeout);
240
241        if let Some(interval) = self.options.server_alive_interval {
242            builder.server_alive_interval(interval);
243        }
244
245        // Add identity file if specified
246        let identity_path = shellexpand::tilde(&self.config.identity_file);
247        if Path::new(identity_path.as_ref()).exists() {
248            builder.keyfile(identity_path.as_ref());
249        }
250
251        // Enable control master for connection reuse
252        if control_master {
253            if let Some(idle) = self.options.control_persist_idle {
254                if idle.is_zero() {
255                    builder.control_persist(ControlPersist::ClosedAfterInitialConnection);
256                } else {
257                    match usize::try_from(idle.as_secs()) {
258                        Ok(secs) => {
259                            if let Some(nonzero) = NonZeroUsize::new(secs) {
260                                builder.control_persist(ControlPersist::IdleFor(nonzero));
261                            } else {
262                                builder
263                                    .control_persist(ControlPersist::ClosedAfterInitialConnection);
264                            }
265                        }
266                        Err(_) => {
267                            warn!(
268                                "control_persist_idle too large ({}s); ignoring override",
269                                idle.as_secs()
270                            );
271                        }
272                    }
273                }
274            }
275
276            // Use a short control directory path to stay within the Unix domain
277            // socket path limit (104 bytes on macOS, 108 on Linux).  The openssh
278            // crate appends a `%C` hash (~32 chars) to form the socket filename,
279            // so the directory path itself must be short.
280            //
281            // On macOS `std::env::temp_dir()` returns a long path under
282            // /var/folders/…/T/ which, combined with the hash, exceeds 104 bytes.
283            // We therefore prefer `~/.ssh/rch` (short, stable, correct perms).
284            let control_dir = {
285                let home_ssh = dirs::home_dir().map(|h| h.join(".ssh").join("rch"));
286
287                if let Some(ref dir) = home_ssh {
288                    dir.clone()
289                } else if let Some(runtime_dir) = dirs::runtime_dir() {
290                    runtime_dir.join("rch-ssh")
291                } else {
292                    let username = whoami::username().unwrap_or_else(|_| "unknown".to_string());
293                    std::env::temp_dir().join(format!("rch-ssh-{}", username))
294                }
295            };
296
297            if let Err(e) = std::fs::create_dir_all(&control_dir) {
298                warn!(
299                    "Failed to create SSH control directory {:?}: {}",
300                    control_dir, e
301                );
302            } else {
303                // Set restrictive permissions (0700) to prevent symlink attacks
304                // and unauthorized access to SSH control sockets
305                #[cfg(unix)]
306                {
307                    use std::os::unix::fs::PermissionsExt;
308                    if let Err(e) = std::fs::set_permissions(
309                        &control_dir,
310                        std::fs::Permissions::from_mode(0o700),
311                    ) {
312                        warn!(
313                            "Failed to set permissions on SSH control directory {:?}: {}",
314                            control_dir, e
315                        );
316                    }
317                }
318            }
319            builder.control_directory(&control_dir);
320        }
321    }
322
323    /// Disconnect from the worker.
324    pub async fn disconnect(&mut self) -> Result<()> {
325        if let Some(session) = self.session.take() {
326            debug!("Disconnecting from {}", self.config.id);
327            session.close().await?;
328            info!("Disconnected from {}", self.config.id);
329        }
330        Ok(())
331    }
332
333    /// Execute a command on the remote worker.
334    pub async fn execute(&self, command: &str) -> Result<CommandResult> {
335        let session = self.session.as_ref().context("Not connected to worker")?;
336
337        let start = std::time::Instant::now();
338        debug!(
339            "Executing on {}: {}",
340            self.config.id,
341            crate::util::mask_sensitive_command(command)
342        );
343
344        let mut child = session
345            .command("sh")
346            .arg("-c")
347            .arg(command)
348            .stdout(Stdio::piped())
349            .stderr(Stdio::piped())
350            .spawn()
351            .await
352            .with_context(|| format!("Failed to spawn command on {}", self.config.id))?;
353
354        let execution_future = async {
355            // Read stdout and stderr concurrently to avoid deadlock if one pipe fills.
356            let stdout_handle = child.stdout().take();
357            let stderr_handle = child.stderr().take();
358
359            let stdout_fut = async {
360                if let Some(out) = stdout_handle {
361                    let reader = BufReader::new(out);
362                    let mut take = reader.take(MAX_OUTPUT_SIZE);
363                    let mut buf = String::new();
364                    take.read_to_string(&mut buf).await?;
365                    // Drain the rest to prevent SIGPIPE or blocking
366                    let mut reader = take.into_inner();
367                    let mut sink = tokio::io::sink();
368                    tokio::io::copy(&mut reader, &mut sink).await?;
369                    if buf.len() >= MAX_OUTPUT_SIZE as usize {
370                        buf.push_str("\n...[output truncated]...\n");
371                    }
372                    Ok::<String, anyhow::Error>(buf)
373                } else {
374                    Ok(String::new())
375                }
376            };
377
378            let stderr_fut = async {
379                if let Some(err) = stderr_handle {
380                    let reader = BufReader::new(err);
381                    let mut take = reader.take(MAX_OUTPUT_SIZE);
382                    let mut buf = String::new();
383                    take.read_to_string(&mut buf).await?;
384                    // Drain the rest to prevent SIGPIPE or blocking
385                    let mut reader = take.into_inner();
386                    let mut sink = tokio::io::sink();
387                    tokio::io::copy(&mut reader, &mut sink).await?;
388                    if buf.len() >= MAX_OUTPUT_SIZE as usize {
389                        buf.push_str("\n...[output truncated]...\n");
390                    }
391                    Ok::<String, anyhow::Error>(buf)
392                } else {
393                    Ok(String::new())
394                }
395            };
396
397            let (stdout, stderr) = tokio::try_join!(stdout_fut, stderr_fut)?;
398
399            let status = child
400                .wait()
401                .await
402                .with_context(|| "Failed to wait for command completion")?;
403
404            Ok::<_, anyhow::Error>((status, stdout, stderr))
405        };
406
407        match tokio::time::timeout(self.options.command_timeout, execution_future).await {
408            Ok(result) => {
409                let (status, stdout, stderr) = result?;
410                let duration = start.elapsed();
411                let exit_code = status.code().unwrap_or(-1);
412
413                debug!(
414                    "Command completed on {} (exit={}, duration={}ms)",
415                    self.config.id,
416                    exit_code,
417                    duration.as_millis()
418                );
419
420                Ok(CommandResult {
421                    exit_code,
422                    stdout,
423                    stderr,
424                    duration_ms: duration.as_millis() as u64,
425                })
426            }
427            Err(_) => {
428                // Timeout occurred. The inner future (which owns `child`) is dropped,
429                // but dropping an openssh RemoteChild does NOT kill the remote process
430                // if a ControlMaster is active. The remote process will only terminate
431                // when the caller disconnects the SshClient (closing the Session/
432                // ControlMaster). Callers should ensure disconnect() is called after
433                // a timeout to avoid leaked remote processes.
434                warn!(
435                    "Command timed out on {} after {:?}",
436                    self.config.id, self.options.command_timeout
437                );
438                anyhow::bail!("Command timed out after {:?}", self.options.command_timeout);
439            }
440        }
441    }
442
443    /// Execute a command and stream output in real-time.
444    pub async fn execute_streaming<F, G>(
445        &self,
446        command: &str,
447        mut on_stdout: F,
448        mut on_stderr: G,
449    ) -> Result<CommandResult>
450    where
451        F: FnMut(&str),
452        G: FnMut(&str),
453    {
454        let session = self.session.as_ref().context("Not connected to worker")?;
455
456        let start = std::time::Instant::now();
457        debug!(
458            "Executing (streaming) on {}: {}",
459            self.config.id,
460            crate::util::mask_sensitive_command(command)
461        );
462
463        let mut child = session
464            .command("sh")
465            .arg("-c")
466            .arg(command)
467            .stdout(Stdio::piped())
468            .stderr(Stdio::piped())
469            .spawn()
470            .await
471            .with_context(|| format!("Failed to spawn command on {}", self.config.id))?;
472
473        let stdout = child.stdout().take();
474        let stderr = child.stderr().take();
475
476        // Use a channel to aggregate stream events from reader tasks.
477        // This avoids cancellation safety issues with select! over AsyncBufReadExt::read_line.
478        let (tx, mut rx) = mpsc::channel(100);
479
480        // Spawn stdout reader
481        if let Some(out) = stdout {
482            let tx = tx.clone();
483            tokio::spawn(async move {
484                let mut reader = BufReader::new(out);
485                let mut line = String::new();
486                loop {
487                    line.clear();
488                    match reader.read_line(&mut line).await {
489                        Ok(0) => break, // EOF
490                        Ok(_) => {
491                            if tx.send(StreamEvent::Stdout(line.clone())).await.is_err() {
492                                break; // Receiver dropped
493                            }
494                        }
495                        Err(_) => break, // Read error
496                    }
497                }
498            });
499        }
500
501        // Spawn stderr reader
502        if let Some(err) = stderr {
503            let tx = tx.clone();
504            tokio::spawn(async move {
505                let mut reader = BufReader::new(err);
506                let mut line = String::new();
507                loop {
508                    line.clear();
509                    match reader.read_line(&mut line).await {
510                        Ok(0) => break, // EOF
511                        Ok(_) => {
512                            if tx.send(StreamEvent::Stderr(line.clone())).await.is_err() {
513                                break; // Receiver dropped
514                            }
515                        }
516                        Err(_) => break, // Read error
517                    }
518                }
519            });
520        }
521
522        // Drop original tx so rx closes when tasks finish
523        drop(tx);
524
525        let mut stdout_acc = String::new();
526        let mut stderr_acc = String::new();
527
528        enum StreamEvent {
529            Stdout(String),
530            Stderr(String),
531        }
532
533        let streaming_future = async {
534            // Process events until channel closes (EOF from both streams)
535            while let Some(event) = rx.recv().await {
536                match event {
537                    StreamEvent::Stdout(line) => {
538                        on_stdout(&line);
539                        if stdout_acc.len() < MAX_OUTPUT_SIZE as usize {
540                            stdout_acc.push_str(&line);
541                            if stdout_acc.len() >= MAX_OUTPUT_SIZE as usize {
542                                stdout_acc.push_str("\n...[output truncated]...\n");
543                            }
544                        }
545                    }
546                    StreamEvent::Stderr(line) => {
547                        on_stderr(&line);
548                        if stderr_acc.len() < MAX_OUTPUT_SIZE as usize {
549                            stderr_acc.push_str(&line);
550                            if stderr_acc.len() >= MAX_OUTPUT_SIZE as usize {
551                                stderr_acc.push_str("\n...[output truncated]...\n");
552                            }
553                        }
554                    }
555                }
556            }
557
558            let status = child.wait().await?;
559            Ok::<_, anyhow::Error>(status)
560        };
561
562        match tokio::time::timeout(self.options.command_timeout, streaming_future).await {
563            Ok(result) => {
564                let status = result?;
565                let duration = start.elapsed();
566                let exit_code = status.code().unwrap_or(-1);
567
568                Ok(CommandResult {
569                    exit_code,
570                    stdout: stdout_acc,
571                    stderr: stderr_acc,
572                    duration_ms: duration.as_millis() as u64,
573                })
574            }
575            Err(_) => {
576                // Timeout occurred - the spawned reader tasks will terminate when they
577                // try to send on rx (which is dropped when this scope exits).
578                // The child process is also dropped here, but openssh may not kill
579                // the remote process immediately. Log the situation for visibility.
580                //
581                // Note: The reader tasks are detached (tokio::spawn) so they continue
582                // briefly until they hit EOF or the send fails. This is acceptable
583                // because they're lightweight and will terminate quickly once the
584                // channel closes.
585                warn!(
586                    "Command (streaming) timed out on {} after {:?}, cleaning up",
587                    self.config.id, self.options.command_timeout
588                );
589                // rx is dropped here, which will cause senders to fail on next send
590                // child is dropped here, which signals termination to openssh
591                anyhow::bail!("Command timed out after {:?}", self.options.command_timeout);
592            }
593        }
594    }
595
596    /// Check if the worker is reachable via SSH.
597    pub async fn health_check(&self) -> Result<bool> {
598        match self.execute(HEALTH_CHECK_COMMAND).await {
599            Ok(result) => Ok(result.success() && is_expected_health_check_output(&result.stdout)),
600            Err(e) => {
601                warn!("Health check failed for {}: {}", self.config.id, e);
602                Ok(false)
603            }
604        }
605    }
606}
607
/// Connection pool for managing multiple SSH connections.
///
/// Holds at most one `SshClient` per worker ID. Each client sits behind its
/// own lock, so connecting to one worker does not hold the map lock.
pub struct SshPool {
    /// Pool entries keyed by worker ID. An entry is created before its
    /// client connects, so an entry may hold a client that is not connected.
    connections: Arc<RwLock<HashMap<WorkerId, Arc<RwLock<SshClient>>>>>,
    /// Default SSH options, cloned into every client the pool creates.
    options: SshOptions,
}
615
616impl SshPool {
617    /// Create a new connection pool.
618    pub fn new(options: SshOptions) -> Self {
619        Self {
620            connections: Arc::new(RwLock::new(HashMap::new())),
621            options,
622        }
623    }
624
625    /// Get or create a connection to a worker.
626    pub async fn get_or_connect(&self, config: &WorkerConfig) -> Result<Arc<RwLock<SshClient>>> {
627        let shared_client = self.get_or_create_client_entry(config).await;
628
629        let is_connected = {
630            let guard = shared_client.read().await;
631            guard.is_connected()
632        };
633        if is_connected {
634            debug!("Reusing existing connection to {}", config.id);
635            return Ok(shared_client);
636        }
637
638        // Perform the slow connection process while holding only the lock for this specific worker
639        let mut client_guard = shared_client.write().await;
640        // Double check it wasn't connected by another task that won the race
641        if !client_guard.is_connected() {
642            client_guard.connect().await?;
643        }
644        // Drop write lock before returning
645        drop(client_guard);
646
647        Ok(shared_client)
648    }
649
    /// Return the pool entry for `config`, creating or replacing it as needed.
    ///
    /// * An existing entry whose client targets the same endpoint (see
    ///   `SshClient::is_configured_for`) is returned unchanged.
    /// * An existing entry for a different endpoint is swapped for a fresh,
    ///   disconnected client built from `config`.
    /// * A missing entry is created with a fresh, disconnected client.
    ///
    /// The map lock is never held while a client lock is awaited. Every
    /// mutation therefore re-validates the map under the write lock and the
    /// loop starts over when another task changed the entry in between.
    async fn get_or_create_client_entry(&self, config: &WorkerConfig) -> Arc<RwLock<SshClient>> {
        let worker_id = config.id.clone();

        loop {
            // Clone the Arc out so the map read lock is released immediately.
            let existing_client = {
                let connections = self.connections.read().await;
                connections.get(&worker_id).cloned()
            };

            if let Some(client) = existing_client {
                let is_configured_for_worker = {
                    let guard = client.read().await;
                    guard.is_configured_for(config)
                };
                if is_configured_for_worker {
                    return client;
                }

                // The endpoint changed: prepare a replacement client.
                let replacement = Arc::new(RwLock::new(SshClient::new(
                    config.clone(),
                    self.options.clone(),
                )));
                let replaced = {
                    let mut connections = self.connections.write().await;
                    // Only swap if the map still holds the exact entry that
                    // was inspected above; otherwise another task already
                    // changed it and the decision must be re-evaluated.
                    if connections
                        .get(&worker_id)
                        .is_some_and(|current| Arc::ptr_eq(current, &client))
                    {
                        connections.insert(worker_id.clone(), replacement.clone());
                        true
                    } else {
                        false
                    }
                };

                if replaced {
                    debug!(
                        "Replaced SSH connection entry for {} after endpoint config changed",
                        worker_id
                    );
                    return replacement;
                }

                // Lost the race; retry from the top with the current entry.
                continue;
            }

            // No entry yet: build one and insert it unless another task
            // inserted an entry for this worker in the meantime.
            let new_client = Arc::new(RwLock::new(SshClient::new(
                config.clone(),
                self.options.clone(),
            )));
            let inserted = {
                let mut connections = self.connections.write().await;
                if connections.contains_key(&worker_id) {
                    false
                } else {
                    connections.insert(worker_id.clone(), new_client.clone());
                    true
                }
            };

            if inserted {
                return new_client;
            }
            // Otherwise loop again and pick up the entry the other task added.
        }
    }
715
716    /// Close a specific connection.
717    pub async fn close(&self, worker_id: &WorkerId) -> Result<()> {
718        let client = {
719            let mut connections = self.connections.write().await;
720            connections.remove(worker_id)
721        };
722
723        if let Some(client) = client {
724            let mut client = client.write().await;
725            client.disconnect().await?;
726        }
727
728        Ok(())
729    }
730
731    /// Close all connections.
732    pub async fn close_all(&self) -> Result<()> {
733        let clients: Vec<_> = {
734            let mut connections = self.connections.write().await;
735            connections.drain().map(|(_, v)| v).collect()
736        };
737
738        for client in clients {
739            let mut client = client.write().await;
740            if let Err(e) = client.disconnect().await {
741                error!("Error closing connection: {}", e);
742            }
743        }
744
745        Ok(())
746    }
747
748    /// Get the number of active connections.
749    pub async fn active_connections(&self) -> usize {
750        self.connections.read().await.len()
751    }
752}
753
754impl Default for SshPool {
755    fn default() -> Self {
756        Self::new(SshOptions::default())
757    }
758}
759
760#[cfg(test)]
761mod tests {
762    use super::*;
763    use crate::test_guard;
764
765    #[test]
766    fn test_command_result_success() {
767        let _guard = test_guard!();
768        let result = CommandResult {
769            exit_code: 0,
770            stdout: "output".to_string(),
771            stderr: String::new(),
772            duration_ms: 100,
773        };
774        assert!(result.success());
775
776        let failed = CommandResult {
777            exit_code: 1,
778            stdout: String::new(),
779            stderr: "error".to_string(),
780            duration_ms: 50,
781        };
782        assert!(!failed.success());
783    }
784
785    #[test]
786    fn test_ssh_options_default() {
787        let _guard = test_guard!();
788        let options = SshOptions::default();
789        assert_eq!(options.connect_timeout, Duration::from_secs(10));
790        assert_eq!(options.command_timeout, Duration::from_secs(300));
791        assert!(options.server_alive_interval.is_none());
792        assert!(options.control_persist_idle.is_none());
793        assert!(!options.control_master);
794    }
795
796    #[test]
797    fn test_ssh_client_creation() {
798        let _guard = test_guard!();
799        let config = WorkerConfig {
800            id: WorkerId::new("test-worker"),
801            host: "192.168.1.100".to_string(),
802            user: "ubuntu".to_string(),
803            identity_file: "~/.ssh/id_rsa".to_string(),
804            total_slots: 8,
805            priority: 100,
806            tags: vec!["rust".to_string()],
807        };
808
809        let client = SshClient::new(config.clone(), SshOptions::default());
810        assert_eq!(client.worker_id().as_str(), "test-worker");
811        assert!(!client.is_connected());
812    }
813
814    #[test]
815    fn test_expected_health_check_output_accepts_sentinel_as_last_line() {
816        let _guard = test_guard!();
817
818        assert!(is_expected_health_check_output("ok\n"));
819        assert!(is_expected_health_check_output("login banner\nok\n"));
820        assert!(!is_expected_health_check_output(""));
821        assert!(!is_expected_health_check_output("not ok\n"));
822        assert!(!is_expected_health_check_output("ok\npost-command noise\n"));
823    }
824
825    fn worker_config(id: &str, host: &str, user: &str, identity_file: &str) -> WorkerConfig {
826        WorkerConfig {
827            id: WorkerId::new(id),
828            host: host.to_string(),
829            user: user.to_string(),
830            identity_file: identity_file.to_string(),
831            total_slots: 8,
832            priority: 100,
833            tags: vec!["rust".to_string()],
834        }
835    }
836
837    #[test]
838    fn test_ssh_client_configured_for_ignores_scheduling_fields() {
839        let _guard = test_guard!();
840        let config = worker_config("worker-a", "192.168.1.100", "ubuntu", "~/.ssh/id_rsa");
841        let client = SshClient::new(config.clone(), SshOptions::default());
842
843        let mut scheduling_only_change = config;
844        scheduling_only_change.total_slots = 16;
845        scheduling_only_change.priority = 250;
846        scheduling_only_change.tags = vec!["rust".to_string(), "gpu".to_string()];
847
848        assert!(client.is_configured_for(&scheduling_only_change));
849    }
850
851    #[test]
852    fn test_ssh_client_configured_for_detects_endpoint_changes() {
853        let _guard = test_guard!();
854        let config = worker_config("worker-a", "192.168.1.100", "ubuntu", "~/.ssh/id_rsa");
855        let client = SshClient::new(config, SshOptions::default());
856
857        assert!(!client.is_configured_for(&worker_config(
858            "worker-a",
859            "192.168.1.101",
860            "ubuntu",
861            "~/.ssh/id_rsa",
862        )));
863        assert!(!client.is_configured_for(&worker_config(
864            "worker-a",
865            "192.168.1.100",
866            "admin",
867            "~/.ssh/id_rsa",
868        )));
869        assert!(!client.is_configured_for(&worker_config(
870            "worker-a",
871            "192.168.1.100",
872            "ubuntu",
873            "~/.ssh/other_key",
874        )));
875    }
876
877    #[tokio::test]
878    async fn test_ssh_pool_reuses_matching_disconnected_entry() {
879        let _guard = test_guard!();
880        let pool = SshPool::default();
881        let config = worker_config("worker-a", "192.168.1.100", "ubuntu", "~/.ssh/id_rsa");
882
883        let first = pool.get_or_create_client_entry(&config).await;
884        let second = pool.get_or_create_client_entry(&config).await;
885
886        assert!(Arc::ptr_eq(&first, &second));
887        assert_eq!(pool.active_connections().await, 1);
888    }
889
890    #[tokio::test]
891    async fn test_ssh_pool_replaces_stale_entry_when_endpoint_changes() {
892        let _guard = test_guard!();
893        let pool = SshPool::default();
894        let old_config = worker_config("worker-a", "192.168.1.100", "ubuntu", "~/.ssh/id_rsa");
895        let new_config = worker_config("worker-a", "192.168.1.101", "admin", "~/.ssh/new_key");
896
897        let stale = pool.get_or_create_client_entry(&old_config).await;
898        let replacement = pool.get_or_create_client_entry(&new_config).await;
899
900        assert!(!Arc::ptr_eq(&stale, &replacement));
901        assert_eq!(pool.active_connections().await, 1);
902
903        let replacement_guard = replacement.read().await;
904        assert!(replacement_guard.is_configured_for(&new_config));
905    }
906
907    #[test]
908    fn test_build_env_prefix_quotes_and_rejects() {
909        let _guard = test_guard!();
910        let mut env = HashMap::new();
911        env.insert("RUSTFLAGS".to_string(), "-C target-cpu=native".to_string());
912        env.insert("QUOTED".to_string(), "a'b".to_string());
913        env.insert("BADVAL".to_string(), "line1\nline2".to_string());
914
915        let allowlist = vec![
916            "RUSTFLAGS".to_string(),
917            "QUOTED".to_string(),
918            "MISSING".to_string(),
919            "BADVAL".to_string(),
920            "BAD=KEY".to_string(),
921        ];
922
923        let prefix = build_env_prefix(&allowlist, |key| env.get(key).cloned());
924
925        assert!(prefix.prefix.contains("RUSTFLAGS='-C target-cpu=native'"));
926        // shell_escape uses '\'' style (end string, escaped quote, start string)
927        assert!(prefix.prefix.contains("QUOTED='a'\\''b'"));
928        assert!(!prefix.prefix.contains("MISSING="));
929        assert!(!prefix.prefix.contains("BADVAL="));
930        assert!(prefix.rejected.contains(&"BADVAL".to_string()));
931        assert!(prefix.rejected.contains(&"BAD=KEY".to_string()));
932        assert_eq!(
933            prefix.applied,
934            vec!["RUSTFLAGS".to_string(), "QUOTED".to_string()]
935        );
936    }
937
938    // ==========================================================================
939    // Proptest: SSH command escaping with special chars (bd-2elj)
940    // ==========================================================================
941
942    mod proptest_ssh_escaping {
943        use super::*;
944        use proptest::prelude::*;
945        use std::collections::HashMap;
946
947        proptest! {
948            #![proptest_config(ProptestConfig::with_cases(1000))]
949
950            // Test 1: is_valid_env_key never panics on arbitrary strings
951            #[test]
952            fn test_is_valid_env_key_no_panic(s in ".*") {
953        let _guard = test_guard!();
954                let _ = is_valid_env_key(&s);
955            }
956
957            // Test 2: Valid env keys start with letter/_ and contain only alphanum/_
958            #[test]
959            fn test_is_valid_env_key_accepts_valid(
960                first in "[a-zA-Z_]",
961                rest in "[a-zA-Z0-9_]{0,50}"
962            ) {
963        let _guard = test_guard!();
964                let key = format!("{}{}", first, rest);
965                prop_assert!(is_valid_env_key(&key), "Should accept valid key: {}", key);
966            }
967
968            // Test 3: Env keys starting with digit are invalid
969            #[test]
970            fn test_is_valid_env_key_rejects_digit_start(
971                digit in "[0-9]",
972                rest in "[a-zA-Z0-9_]{0,20}"
973            ) {
974        let _guard = test_guard!();
975                let key = format!("{}{}", digit, rest);
976                prop_assert!(!is_valid_env_key(&key), "Should reject digit-start key: {}", key);
977            }
978
979            // Test 4: shell_escape_value never panics on arbitrary strings
980            #[test]
981            fn test_shell_escape_value_no_panic(s in ".*") {
982        let _guard = test_guard!();
983                let _ = shell_escape_value(&s);
984            }
985
986            // Test 5: shell_escape_value rejects newlines/carriage returns/NUL
987            #[test]
988            fn test_shell_escape_value_rejects_unsafe(
989                prefix in "[a-zA-Z0-9 ]{0,10}",
990                bad_char in "[\n\r\0]",
991                suffix in "[a-zA-Z0-9 ]{0,10}"
992            ) {
993        let _guard = test_guard!();
994                let value = format!("{}{}{}", prefix, bad_char, suffix);
995                prop_assert!(shell_escape_value(&value).is_none(),
996                    "Should reject value with unsafe char: {:?}", value);
997            }
998
999            // Test 6: shell_escape_value handles safe values
1000            #[test]
1001            fn test_shell_escape_value_accepts_safe(s in "[a-zA-Z0-9 !@#$%^&*()_+=\\-\\[\\]{}|;:,./<>?]{0,100}") {
1002                // These don't contain \n, \r, or \0
1003                let result = shell_escape_value(&s);
1004                prop_assert!(result.is_some(), "Should accept safe value: {:?}", s);
1005
1006                // shell_escape only quotes values that need it (contain special chars)
1007                // Simple alphanumeric strings may be returned unquoted
1008                let escaped = match result {
1009                    Some(escaped) => escaped,
1010                    None => {
1011                        prop_assert!(false, "Should accept safe value: {:?}", s);
1012                        String::new()
1013                    }
1014                };
1015                if s.chars().any(|c| !c.is_ascii_alphanumeric() && c != '_') {
1016                    // Values with special chars should be quoted
1017                    prop_assert!(escaped.starts_with('\'') || escaped.contains('\''),
1018                        "Value with special chars should be quoted: {:?} -> {:?}", s, escaped);
1019                }
1020            }
1021
1022            // Test 7: shell_escape_value properly escapes single quotes
1023            #[test]
1024            fn test_shell_escape_value_escapes_quotes(
1025                prefix in "[a-zA-Z0-9]{0,10}",
1026                suffix in "[a-zA-Z0-9]{0,10}"
1027            ) {
1028        let _guard = test_guard!();
1029                let value = format!("{}'{}", prefix, suffix);
1030                let result = shell_escape_value(&value);
1031                prop_assert!(result.is_some());
1032
1033                let escaped = match result {
1034                    Some(escaped) => escaped,
1035                    None => {
1036                        prop_assert!(false, "Should escape single quote: {}", value);
1037                        String::new()
1038                    }
1039                };
1040                // shell_escape uses '\'' style (end string, escaped quote, start string)
1041                prop_assert!(escaped.contains("'\\''"),
1042                    "Should escape single quote: {} -> {}", value, escaped);
1043            }
1044
1045            // Test 8: build_env_prefix never panics
1046            #[test]
1047            fn test_build_env_prefix_no_panic(
1048                keys in prop::collection::vec("[a-zA-Z_][a-zA-Z0-9_]{0,10}", 0..10),
1049                values in prop::collection::vec(".*", 0..10)
1050            ) {
1051                let mut env = HashMap::new();
1052                for (i, key) in keys.iter().enumerate() {
1053                    if let Some(val) = values.get(i) {
1054                        env.insert(key.clone(), val.clone());
1055                    }
1056                }
1057
1058                let allowlist: Vec<String> = keys;
1059                let _ = build_env_prefix(&allowlist, |k| env.get(k).cloned());
1060            }
1061
1062            // Test 9: build_env_prefix rejects invalid keys (non-empty after trim)
1063            #[test]
1064            fn test_build_env_prefix_rejects_invalid_keys(
1065                // Generate keys that are invalid even after trimming
1066                invalid_key in "[0-9][a-zA-Z0-9_]{0,10}"  // Starts with digit
1067            ) {
1068        let _guard = test_guard!();
1069                let mut env = HashMap::new();
1070                env.insert(invalid_key.clone(), "value".to_string());
1071
1072                let allowlist = vec![invalid_key.clone()];
1073                let prefix = build_env_prefix(&allowlist, |k| env.get(k).cloned());
1074
1075                // Key should be rejected since it starts with a digit
1076                prop_assert!(!is_valid_env_key(&invalid_key),
1077                    "Key should be invalid: {}", invalid_key);
1078                prop_assert!(prefix.rejected.contains(&invalid_key),
1079                    "Should reject invalid key: {}", invalid_key);
1080                prop_assert!(prefix.prefix.is_empty());
1081            }
1082
1083            // Test 10: build_env_prefix handles missing values gracefully
1084            #[test]
1085            fn test_build_env_prefix_missing_values(
1086                keys in prop::collection::vec("[A-Z_][A-Z0-9_]{0,10}", 1..5)
1087            ) {
1088                // Empty env - all keys missing
1089                let env: HashMap<String, String> = HashMap::new();
1090                let prefix = build_env_prefix(&keys, |k| env.get(k).cloned());
1091
1092                // Should be empty prefix since no values found
1093                prop_assert!(prefix.prefix.is_empty(), "Should be empty when no values");
1094                prop_assert!(prefix.applied.is_empty());
1095                // Missing values don't count as rejected
1096                prop_assert!(prefix.rejected.is_empty());
1097            }
1098        }
1099
1100        // Targeted edge case tests
1101        #[test]
1102        fn test_shell_escape_edge_cases() {
1103            let _guard = test_guard!();
1104            // Empty string
1105            let result = shell_escape_value("");
1106            assert_eq!(result, Some("''".to_string()));
1107
1108            // Just single quote - shell_escape uses '\'' style (end string, escaped quote, start string)
1109            let result = shell_escape_value("'");
1110            assert_eq!(result, Some("''\\'''".to_string()));
1111
1112            // Multiple single quotes
1113            let result = shell_escape_value("'''");
1114            // shell_escape uses '\'' style for each single quote
1115            assert_eq!(
1116                result
1117                    .as_deref()
1118                    .map(|escaped| escaped.matches("'\\''").count()),
1119                Some(3)
1120            );
1121
1122            // Unicode
1123            let result = shell_escape_value("ζ—₯本θͺž");
1124            assert!(result.is_some());
1125
1126            // Emoji
1127            let result = shell_escape_value("πŸ”₯πŸš€");
1128            assert!(result.is_some());
1129
1130            // Mixed quotes and special chars
1131            let result = shell_escape_value("it's a \"test\" with $vars");
1132            assert!(result.is_some());
1133        }
1134
1135        #[test]
1136        fn test_is_valid_env_key_edge_cases() {
1137            let _guard = test_guard!();
1138            // Empty
1139            assert!(!is_valid_env_key(""));
1140
1141            // Single underscore
1142            assert!(is_valid_env_key("_"));
1143
1144            // Single letter
1145            assert!(is_valid_env_key("A"));
1146
1147            // Typical env vars
1148            assert!(is_valid_env_key("PATH"));
1149            assert!(is_valid_env_key("HOME"));
1150            assert!(is_valid_env_key("RUSTFLAGS"));
1151            assert!(is_valid_env_key("CC"));
1152            assert!(is_valid_env_key("_PRIVATE"));
1153            assert!(is_valid_env_key("MY_VAR_123"));
1154
1155            // Invalid: starts with number
1156            assert!(!is_valid_env_key("1VAR"));
1157            assert!(!is_valid_env_key("123"));
1158
1159            // Invalid: contains special chars
1160            assert!(!is_valid_env_key("MY-VAR"));
1161            assert!(!is_valid_env_key("MY.VAR"));
1162            assert!(!is_valid_env_key("MY VAR"));
1163            assert!(!is_valid_env_key("MY=VAR"));
1164
1165            // Invalid: Unicode
1166            assert!(!is_valid_env_key("ζ—₯本θͺž"));
1167            assert!(!is_valid_env_key("VARπŸ”₯"));
1168        }
1169
1170        #[test]
1171        fn test_build_env_prefix_integration() {
1172            let _guard = test_guard!();
1173            // Complex scenario with mixed valid/invalid
1174            let mut env = HashMap::new();
1175            env.insert("VALID".to_string(), "simple".to_string());
1176            env.insert("WITH_QUOTE".to_string(), "it's here".to_string());
1177            env.insert("NEWLINE".to_string(), "line1\nline2".to_string());
1178            env.insert("UNICODE".to_string(), "ζ—₯本θͺž".to_string());
1179            env.insert("EMPTY".to_string(), String::new());
1180            env.insert("123INVALID".to_string(), "value".to_string());
1181
1182            let allowlist = vec![
1183                "VALID".to_string(),
1184                "WITH_QUOTE".to_string(),
1185                "NEWLINE".to_string(),
1186                "UNICODE".to_string(),
1187                "EMPTY".to_string(),
1188                "123INVALID".to_string(),
1189                "MISSING".to_string(),
1190            ];
1191
1192            let prefix = build_env_prefix(&allowlist, |k| env.get(k).cloned());
1193
1194            // VALID should be applied
1195            assert!(prefix.applied.contains(&"VALID".to_string()));
1196            // shell_escape doesn't quote simple alphanumeric strings
1197            assert!(prefix.prefix.contains("VALID=simple"));
1198
1199            // WITH_QUOTE should be applied with escaped quote
1200            assert!(prefix.applied.contains(&"WITH_QUOTE".to_string()));
1201
1202            // NEWLINE should be rejected (unsafe value)
1203            assert!(prefix.rejected.contains(&"NEWLINE".to_string()));
1204
1205            // UNICODE should be applied (safe unicode)
1206            assert!(prefix.applied.contains(&"UNICODE".to_string()));
1207
1208            // EMPTY should be applied
1209            assert!(prefix.applied.contains(&"EMPTY".to_string()));
1210
1211            // 123INVALID should be rejected (invalid key)
1212            assert!(prefix.rejected.contains(&"123INVALID".to_string()));
1213
1214            // MISSING should not appear in either list (not found = silently ignored)
1215            assert!(!prefix.applied.contains(&"MISSING".to_string()));
1216            assert!(!prefix.rejected.contains(&"MISSING".to_string()));
1217        }
1218
1219        #[test]
1220        fn test_shell_escape_roundtrip_safety() {
1221            let _guard = test_guard!();
1222            // Values that when escaped and passed through shell should reconstruct original
1223            let test_values = [
1224                "simple",
1225                "with spaces",
1226                "with\ttab",
1227                "special!@#$%^&*()",
1228                "quoted\"value",
1229                "path/to/file",
1230                "-flag",
1231                "--long-flag=value",
1232                "",
1233            ];
1234
1235            for value in &test_values {
1236                let escaped = shell_escape_value(value);
1237                assert!(escaped.is_some(), "Should escape: {:?}", value);
1238            }
1239        }
1240    }
1241}