crankshaft_engine/service/runner/backend/generic/
driver.rs1use std::io::Read as _;
4#[cfg(unix)]
5use std::os::unix::process::ExitStatusExt;
6#[cfg(windows)]
7use std::os::windows::process::ExitStatusExt;
8use std::process::ExitStatus;
9use std::process::Output;
10use std::sync::Arc;
11use std::time::Duration;
12
13use anyhow::Context as _;
14use anyhow::Result;
15use anyhow::bail;
16use crankshaft_config::backend::generic::driver::Config;
17use crankshaft_config::backend::generic::driver::Locale;
18use crankshaft_config::backend::generic::driver::Shell;
19use crankshaft_config::backend::generic::driver::ssh;
20use rand::Rng as _;
21use ssh2::Channel;
22use ssh2::Session;
23use thiserror::Error;
24use tokio::net::TcpStream;
25use tokio::process::Command;
26use tracing::debug;
27use tracing::error;
28use tracing::trace;
29
30#[derive(Error, Debug)]
32pub enum Error {
33 #[error(transparent)]
35 Io(std::io::Error),
36
37 #[error(transparent)]
39 Join(tokio::task::JoinError),
40
41 #[error(transparent)]
43 SSH2(ssh2::Error),
44}
45
46pub enum Transport {
52 Local,
54
55 SSH(Arc<Session>),
57}
58
59impl std::fmt::Debug for Transport {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 match self {
62 Self::Local => write!(f, "Local"),
63 Self::SSH(_) => f.debug_tuple("SSH").finish(),
64 }
65 }
66}
67
68#[derive(Debug)]
78pub struct Driver {
79 transport: Transport,
81
82 config: Config,
84}
85
86impl Driver {
87 pub async fn initialize(config: Config) -> Result<Self> {
97 let transport = match config.locale() {
100 Some(Locale::Local) | None => Ok(Transport::Local),
103 Some(Locale::SSH(config)) => create_ssh_transport(config).await,
104 }?;
105
106 Ok(Self { transport, config })
107 }
108
109 pub async fn run(&self, command: impl Into<String>) -> Result<Output> {
115 let command = command.into();
116
117 match &self.transport {
118 Transport::Local => run_local_command(command, &self.config).await,
119 Transport::SSH(session) => {
120 run_ssh_command(session.clone(), &self.config, command).await
121 }
122 }
123 }
124
125 pub fn transport(&self) -> &Transport {
127 &self.transport
128 }
129
130 pub fn config(&self) -> &Config {
132 &self.config
133 }
134}
135
136async fn run_local_command(command: String, config: &Config) -> Result<Output> {
142 trace!("executing local command: `{command}`");
143
144 let command = match config.shell().unwrap_or_default() {
147 Shell::Bash => Command::new("/usr/bin/env")
148 .args(["bash", "-c", &command])
149 .stdout(std::process::Stdio::piped())
150 .stderr(std::process::Stdio::piped())
151 .spawn(),
152 Shell::Sh => Command::new("/usr/bin/env")
153 .args(["sh", "-c", &command])
154 .stdout(std::process::Stdio::piped())
155 .stderr(std::process::Stdio::piped())
156 .spawn(),
157 }
158 .context("spawning the local command")?;
159
160 command
161 .wait_with_output()
162 .await
163 .context("executing the local command")
164}
165
166async fn create_ssh_transport(config: &ssh::Config) -> Result<Transport> {
172 let addr = format!("{host}:{port}", host = config.host(), port = config.port());
173
174 let message = format!("connecting to SSH host: {addr}");
176 debug!(message);
177 let tcp = TcpStream::connect(addr)
178 .await
179 .map_err(Error::Io)
180 .context(message)?;
181
182 debug!("establishing a new SSH session and connecting");
184
185 trace!("creating a new SSH session");
186 let mut sess = Session::new()
187 .map_err(Error::SSH2)
188 .context("creating a new SSH session")?;
189 sess.set_tcp_stream(tcp);
190 trace!("performing the SSH handshake");
191 sess.handshake()
192 .map_err(Error::SSH2)
193 .context("performing the SSH handshake")?;
194
195 debug!("retrieving identities from the SSH agent");
198
199 trace!("initializing the SSH agent");
200 let mut agent = sess
201 .agent()
202 .map_err(Error::SSH2)
203 .context("initializing the SSH agent")?;
204
205 trace!("connecting to the SSH agent");
206 agent
207 .connect()
208 .map_err(Error::SSH2)
209 .context("connecting to the SSH agent")?;
210
211 trace!("listing identities within the SSH agent");
212 agent
213 .list_identities()
214 .map_err(Error::SSH2)
215 .context("listing identities within the SSH agent")?;
216
217 trace!("accessing the retrieved identities");
218 let identities = agent
219 .identities()
220 .map_err(Error::SSH2)
221 .context("accessing retrieved identities")?;
222
223 let key = match identities.len() {
224 0 => bail!("no identities found in the SSH agent! Try using `ssh-add` on your SSH key."),
225 1 => identities.first().unwrap(),
228 _ => unimplemented!(
229 "`crankshaft` does not yet support multiple keys in an SSH agent. Please file an \
230 issue!"
231 ),
232 };
233
234 trace!(
235 "found a single identifier with the comment `{}`",
236 key.comment()
237 );
238
239 debug!("authenticating SSH session");
241
242 if let Some(username) = config.username() {
243 agent
244 .userauth(username, key)
245 .map_err(Error::SSH2)
246 .with_context(|| {
247 format!(
248 "authenticating with username `{}` and identity `{}`",
249 username,
250 key.comment()
251 )
252 })?;
253 } else {
254 let username = whoami::username();
255
256 agent
257 .userauth(&username, key)
258 .map_err(Error::SSH2)
259 .with_context(|| {
260 format!(
261 "authenticating with username `{}` and identity `{}`",
262 username,
263 key.comment()
264 )
265 })?;
266 }
267
268 if sess.authenticated() {
269 debug!("authentication successful");
270 Ok(Transport::SSH(Arc::new(sess)))
271 } else {
272 error!("authentication failed!");
273 bail!("failed authentication")
274 }
275}
276
277const WAIT_FLOOR: u64 = 300;
279
280const WAIT_JITTER: u64 = 150;
282
283fn channel_session_with_backoff(
285 session: &Session,
286 max_attempts: u32,
287) -> std::result::Result<Channel, Error> {
288 let mut attempts = 0u32;
289 let mut wait_time = 0u64;
290
291 while attempts < max_attempts {
292 match session.channel_session() {
293 Ok(channel) => return Ok(channel),
294 Err(e) => {
295 attempts += 1;
296 trace!(
297 "failed to connect: {}; attempt {}/{}",
298 e, attempts, max_attempts,
299 );
300
301 if attempts >= max_attempts {
302 return Err(Error::SSH2(e));
303 }
304
305 let jitter = rand::rng().random_range(0..=WAIT_JITTER);
306 wait_time += WAIT_FLOOR + jitter;
307
308 trace!("waiting for {} ms.", wait_time);
309 std::thread::sleep(Duration::from_millis(wait_time));
312 }
313 }
314 }
315
316 unreachable!()
318}
319
320async fn run_ssh_command(
322 session: Arc<ssh2::Session>,
323 config: &Config,
324 command: String,
325) -> Result<Output> {
326 let max_attempts = config.max_attempts();
327
328 let f = move || {
329 debug!("running command on remote host: `{}`", command);
330
331 trace!("creating a new session-based channel");
333 let mut channel =
334 channel_session_with_backoff(&session, max_attempts.unwrap_or_default().inner())
335 .context("creating a new session-based channel")?;
336
337 trace!("sending the execution command");
339 channel
340 .exec(&command)
341 .map_err(Error::SSH2)
342 .context("executing a command over SSH")?;
343
344 trace!("reading the stdout of the command");
346 let mut stdout = Vec::new();
347 channel
348 .read_to_end(&mut stdout)
349 .map_err(Error::Io)
350 .context("reading the stdout of the command over SSH")?;
351
352 for line in String::from_utf8_lossy(&stdout).lines() {
353 trace!("stdout: {line}");
354 }
355
356 trace!("reading the stderr of the command");
358 let mut stderr = Vec::new();
359 channel
360 .stderr()
361 .read_to_end(&mut stderr)
362 .map_err(Error::Io)
363 .context("reading the stderr of the command over SSH")?;
364
365 for line in String::from_utf8_lossy(&stderr).lines() {
366 trace!("stderr: {line}");
367 }
368
369 let status = channel
371 .exit_status()
372 .map_err(Error::SSH2)
373 .context("getting the exit status of the command")?;
374
375 trace!("closing the client's end of the channel");
378 channel
379 .close()
380 .map_err(Error::SSH2)
381 .context("closing the SSH channel")?;
382
383 trace!("waiting for the remote host to close their end of the channel");
385 channel
386 .wait_close()
387 .map_err(Error::SSH2)
388 .context("waiting for the SSH channel to be closed from the client's end")?;
389
390 #[cfg(unix)]
391 let output = Output {
392 status: ExitStatus::from_raw(status << 8),
394 stdout,
395 stderr,
396 };
397
398 #[cfg(windows)]
399 let output = Output {
400 status: ExitStatus::from_raw(status as u32),
401 stdout,
402 stderr,
403 };
404
405 Ok(output)
406 };
407
408 tokio::task::spawn_blocking(f)
409 .await
410 .map_err(Error::Join)
411 .context("running an SSH command")?
412}