crankshaft_engine/service/runner/backend/generic/
driver.rs

1//! Command drivers in a generic backend.
2
3use std::io::Read as _;
4#[cfg(unix)]
5use std::os::unix::process::ExitStatusExt;
6#[cfg(windows)]
7use std::os::windows::process::ExitStatusExt;
8use std::process::ExitStatus;
9use std::process::Output;
10use std::sync::Arc;
11use std::time::Duration;
12
13use anyhow::Context as _;
14use anyhow::Result;
15use anyhow::bail;
16use crankshaft_config::backend::generic::driver::Config;
17use crankshaft_config::backend::generic::driver::Locale;
18use crankshaft_config::backend::generic::driver::Shell;
19use crankshaft_config::backend::generic::driver::ssh;
20use rand::Rng as _;
21use ssh2::Channel;
22use ssh2::Session;
23use thiserror::Error;
24use tokio::net::TcpStream;
25use tokio::process::Command;
26use tracing::debug;
27use tracing::error;
28use tracing::trace;
29
30/// An error related to a [`Driver`].
31#[derive(Error, Debug)]
32pub enum Error {
33    /// An i/o error.
34    #[error(transparent)]
35    Io(std::io::Error),
36
37    /// An error related to joining a [`tokio`] task.
38    #[error(transparent)]
39    Join(tokio::task::JoinError),
40
41    /// An [ssh error](ssh2::Error).
42    #[error(transparent)]
43    SSH2(ssh2::Error),
44}
45
46/// A command transport.
47///
48/// The command transport is what ships commands off to be run within an
49/// [`Driver`]. This might be executing commands locally or on a remote server
50/// via SSH.
51pub enum Transport {
52    /// Local command execution.
53    Local,
54
55    /// Command execution over an SSH session.
56    SSH(Arc<Session>),
57}
58
59impl std::fmt::Debug for Transport {
60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61        match self {
62            Self::Local => write!(f, "Local"),
63            Self::SSH(_) => f.debug_tuple("SSH").finish(),
64        }
65    }
66}
67
68/// A command driver.
69///
70/// A command driver is an abstraction through which shell commands can be
71/// dispatched within various locales (e.g., your local computer, remotely over
72/// SSH, etc).
73///
74/// In addition to containing the state around those connections, the driver
75/// also holds configuration necessary to know how to execute the commands
76/// (e.g., which shell to use when running).
77#[derive(Debug)]
78pub struct Driver {
79    /// The command transport.
80    transport: Transport,
81
82    /// The configuration.
83    config: Config,
84}
85
86impl Driver {
87    /// Initializes a new [`Driver`].
88    ///
89    /// This command requires an async runtime because, for some transports,
90    /// negotiation is done via subprocesses or network calls to initialize the
91    /// necessary state (e.g., establishing an SSH session with a remote host).
92    ///
93    /// **NOTE:** this method returns an [`anyhow::Result`] because any errors
94    /// are intended to be returned directly to the user in the calling binary
95    /// (i.e., the errors are typically unrecoverable).
96    pub async fn initialize(config: Config) -> Result<Self> {
97        // NOTE: this is cloned because `default()` is only implemented on the
98        // owned [`Locale`] type (not a reference).
99        let transport = match config.locale() {
100            // NOTE: no initialization is needed here, as we simply spawn a
101            // [`tokio::process::Command`] when [`command()`] is called.
102            Some(Locale::Local) | None => Ok(Transport::Local),
103            Some(Locale::SSH(config)) => create_ssh_transport(config).await,
104        }?;
105
106        Ok(Self { transport, config })
107    }
108
109    /// Runs a shell command within the configuration locale.
110    ///
111    /// **NOTE:** this method returns an [`anyhow::Result`] because any errors
112    /// are intended to be returned directly to the user in the calling binary
113    /// (i.e., the errors are typically unrecoverable).
114    pub async fn run(&self, command: impl Into<String>) -> Result<Output> {
115        let command = command.into();
116
117        match &self.transport {
118            Transport::Local => run_local_command(command, &self.config).await,
119            Transport::SSH(session) => {
120                run_ssh_command(session.clone(), &self.config, command).await
121            }
122        }
123    }
124
125    /// Gets the inner transport.
126    pub fn transport(&self) -> &Transport {
127        &self.transport
128    }
129
130    /// Gets the inner config.
131    pub fn config(&self) -> &Config {
132        &self.config
133    }
134}
135
136//=================//
137// Local Execution //
138//=================//
139
140/// Runs a command in a local context.
141async fn run_local_command(command: String, config: &Config) -> Result<Output> {
142    trace!("executing local command: `{command}`");
143
144    // NOTE: this is cloned because `default()` is only implemented on the owned
145    // [`Locale`] type (not a reference).
146    let command = match config.shell().unwrap_or_default() {
147        Shell::Bash => Command::new("/usr/bin/env")
148            .args(["bash", "-c", &command])
149            .stdout(std::process::Stdio::piped())
150            .stderr(std::process::Stdio::piped())
151            .spawn(),
152        Shell::Sh => Command::new("/usr/bin/env")
153            .args(["sh", "-c", &command])
154            .stdout(std::process::Stdio::piped())
155            .stderr(std::process::Stdio::piped())
156            .spawn(),
157    }
158    .context("spawning the local command")?;
159
160    command
161        .wait_with_output()
162        .await
163        .context("executing the local command")
164}
165
166//===============//
167// SSH Execution //
168//===============//
169
170/// Attempts to create an SSH transport.
171async fn create_ssh_transport(config: &ssh::Config) -> Result<Transport> {
172    let addr = format!("{host}:{port}", host = config.host(), port = config.port());
173
174    // Connect to the remote SSH host.
175    let message = format!("connecting to SSH host: {addr}");
176    debug!(message);
177    let tcp = TcpStream::connect(addr)
178        .await
179        .map_err(Error::Io)
180        .context(message)?;
181
182    // Create a new SSH session.
183    debug!("establishing a new SSH session and connecting");
184
185    trace!("creating a new SSH session");
186    let mut sess = Session::new()
187        .map_err(Error::SSH2)
188        .context("creating a new SSH session")?;
189    sess.set_tcp_stream(tcp);
190    trace!("performing the SSH handshake");
191    sess.handshake()
192        .map_err(Error::SSH2)
193        .context("performing the SSH handshake")?;
194
195    // Connect to the SSH agent and authenticate within the current
196    // session.
197    debug!("retrieving identities from the SSH agent");
198
199    trace!("initializing the SSH agent");
200    let mut agent = sess
201        .agent()
202        .map_err(Error::SSH2)
203        .context("initializing the SSH agent")?;
204
205    trace!("connecting to the SSH agent");
206    agent
207        .connect()
208        .map_err(Error::SSH2)
209        .context("connecting to the SSH agent")?;
210
211    trace!("listing identities within the SSH agent");
212    agent
213        .list_identities()
214        .map_err(Error::SSH2)
215        .context("listing identities within the SSH agent")?;
216
217    trace!("accessing the retrieved identities");
218    let identities = agent
219        .identities()
220        .map_err(Error::SSH2)
221        .context("accessing retrieved identities")?;
222
223    let key = match identities.len() {
224        0 => bail!("no identities found in the SSH agent! Try using `ssh-add` on your SSH key."),
225        // SAFETY: we just checked that there is exactly one SSH key
226        // in the agent, so this will always unwrap.
227        1 => identities.first().unwrap(),
228        _ => unimplemented!(
229            "`crankshaft` does not yet support multiple keys in an SSH agent. Please file an \
230             issue!"
231        ),
232    };
233
234    trace!(
235        "found a single identifier with the comment `{}`",
236        key.comment()
237    );
238
239    // Authenticate the SSH session.
240    debug!("authenticating SSH session");
241
242    if let Some(username) = config.username() {
243        agent
244            .userauth(username, key)
245            .map_err(Error::SSH2)
246            .with_context(|| {
247                format!(
248                    "authenticating with username `{}` and identity `{}`",
249                    username,
250                    key.comment()
251                )
252            })?;
253    } else {
254        let username = whoami::username();
255
256        agent
257            .userauth(&username, key)
258            .map_err(Error::SSH2)
259            .with_context(|| {
260                format!(
261                    "authenticating with username `{}` and identity `{}`",
262                    username,
263                    key.comment()
264                )
265            })?;
266    }
267
268    if sess.authenticated() {
269        debug!("authentication successful");
270        Ok(Transport::SSH(Arc::new(sess)))
271    } else {
272        error!("authentication failed!");
273        bail!("failed authentication")
274    }
275}
276
/// The fixed portion, in milliseconds, added to the wait time after each
/// failed attempt to open a channel.
const WAIT_FLOOR: u64 = 300;

/// The upper bound (inclusive), in milliseconds, of the random jitter added on
/// top of [`WAIT_FLOOR`] after each failed attempt.
const WAIT_JITTER: u64 = 150;
282
283/// Attempts to create a new [`Channel`] with a backoff on failures.
284fn channel_session_with_backoff(
285    session: &Session,
286    max_attempts: u32,
287) -> std::result::Result<Channel, Error> {
288    let mut attempts = 0u32;
289    let mut wait_time = 0u64;
290
291    while attempts < max_attempts {
292        match session.channel_session() {
293            Ok(channel) => return Ok(channel),
294            Err(e) => {
295                attempts += 1;
296                trace!(
297                    "failed to connect: {}; attempt {}/{}",
298                    e, attempts, max_attempts,
299                );
300
301                if attempts >= max_attempts {
302                    return Err(Error::SSH2(e));
303                }
304
305                let jitter = rand::rng().random_range(0..=WAIT_JITTER);
306                wait_time += WAIT_FLOOR + jitter;
307
308                trace!("waiting for {} ms.", wait_time);
309                // NOTE: this will always be called from a blocking thread in
310                // the async runtime, so it's okay.
311                std::thread::sleep(Duration::from_millis(wait_time));
312            }
313        }
314    }
315
316    // SAFETY: the loop above should always return.
317    unreachable!()
318}
319
320/// Runs a remote command over SSH.
321async fn run_ssh_command(
322    session: Arc<ssh2::Session>,
323    config: &Config,
324    command: String,
325) -> Result<Output> {
326    let max_attempts = config.max_attempts();
327
328    let f = move || {
329        debug!("running command on remote host: `{}`", command);
330
331        // Create a new channel with which to communicate with the host.
332        trace!("creating a new session-based channel");
333        let mut channel =
334            channel_session_with_backoff(&session, max_attempts.unwrap_or_default().inner())
335                .context("creating a new session-based channel")?;
336
337        // Send a command across the channel.
338        trace!("sending the execution command");
339        channel
340            .exec(&command)
341            .map_err(Error::SSH2)
342            .context("executing a command over SSH")?;
343
344        // Read the entire output that was written to the channel.
345        trace!("reading the stdout of the command");
346        let mut stdout = Vec::new();
347        channel
348            .read_to_end(&mut stdout)
349            .map_err(Error::Io)
350            .context("reading the stdout of the command over SSH")?;
351
352        for line in String::from_utf8_lossy(&stdout).lines() {
353            trace!("stdout: {line}");
354        }
355
356        // Read the entire stderr that was written to the channel.
357        trace!("reading the stderr of the command");
358        let mut stderr = Vec::new();
359        channel
360            .stderr()
361            .read_to_end(&mut stderr)
362            .map_err(Error::Io)
363            .context("reading the stderr of the command over SSH")?;
364
365        for line in String::from_utf8_lossy(&stderr).lines() {
366            trace!("stderr: {line}");
367        }
368
369        // Getting the exit code.
370        let status = channel
371            .exit_status()
372            .map_err(Error::SSH2)
373            .context("getting the exit status of the command")?;
374
375        // Indicate to the remote host that we won't be sending any
376        // more data over this connection.
377        trace!("closing the client's end of the channel");
378        channel
379            .close()
380            .map_err(Error::SSH2)
381            .context("closing the SSH channel")?;
382
383        // Wait until the remote host also closes the connection.
384        trace!("waiting for the remote host to close their end of the channel");
385        channel
386            .wait_close()
387            .map_err(Error::SSH2)
388            .context("waiting for the SSH channel to be closed from the client's end")?;
389
390        #[cfg(unix)]
391        let output = Output {
392            // See WEXITSTATUS from wait(2) to explain the shift
393            status: ExitStatus::from_raw(status << 8),
394            stdout,
395            stderr,
396        };
397
398        #[cfg(windows)]
399        let output = Output {
400            status: ExitStatus::from_raw(status as u32),
401            stdout,
402            stderr,
403        };
404
405        Ok(output)
406    };
407
408    tokio::task::spawn_blocking(f)
409        .await
410        .map_err(Error::Join)
411        .context("running an SSH command")?
412}