remowt-client 0.1.2

russh-based client connection to a remowt agent
Documentation
use anyhow::{anyhow, bail, Result};
use camino::Utf8PathBuf;
use remowt_endpoints::subprocess::{ProcId, SpawnSpec, StderrSpec, StdioSpec, SubprocessClient};
use remowt_link_shared::BifConfig;
use tokio::io::{AsyncBufReadExt as _, AsyncWriteExt as _, BufReader};
use tracing::{info, warn};

use crate::forwarded::{RemowtListener, RemowtStream};
use crate::Remowt;

/// How the child's stdin or stdout is connected on this side of the link.
///
/// - `Null`: no socket is bound for the stream (the default).
/// - `Pipe`: the connected stream is handed back on [`RemowtChild`].
/// - `Inherit`: the stream is forwarded locally by a background task
///   (stdin is fed from this process's stdin; stdout is logged via `tracing`).
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq)]
pub enum StdioMode {
	/// Stream is not connected.
	#[default]
	Null,
	/// Stream is returned to the caller.
	Pipe,
	/// Stream is forwarded locally in the background.
	Inherit,
}

/// How the child's stderr is connected on this side of the link.
///
/// Same as [`StdioMode`] plus `MergeWithStdout`, which asks the agent to send
/// stderr into stdout's stream instead of a socket of its own (this requires
/// stdout to be `Pipe` or `Inherit`).
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq)]
pub enum StderrMode {
	/// Stream is not connected.
	#[default]
	Null,
	/// Stream is returned to the caller.
	Pipe,
	/// Stream is logged locally (at WARN level) in the background.
	Inherit,
	/// Stderr is merged into stdout on the agent side.
	MergeWithStdout,
}

/// Parameters for [`Remowt::spawn`].
///
/// Construct with struct-update syntax, e.g.
/// `SpawnOptions { program: "ls".into(), ..Default::default() }`; every stdio
/// mode defaults to `Null`.
#[derive(Clone, Debug, Default)]
pub struct SpawnOptions {
	/// Program to execute (forwarded as `SpawnSpec::program`).
	pub program: String,
	/// Arguments passed to the program.
	pub args: Vec<String>,
	/// Environment variables as `(name, value)` pairs.
	pub env: Vec<(String, String)>,
	/// Forwarded as `SpawnSpec::env_clear`; presumably starts the child from an
	/// empty environment before applying `env` — confirm against the agent.
	pub env_clear: bool,
	/// Working directory for the child; `None` leaves the choice to the agent.
	pub cwd: Option<Utf8PathBuf>,
	/// Spawn through the escalated (`run0`) endpoints instead of the plain ones.
	pub escalated: bool,
	/// How the child's stdin is connected.
	pub stdin: StdioMode,
	/// How the child's stdout is connected.
	pub stdout: StdioMode,
	/// How the child's stderr is connected.
	pub stderr: StderrMode,
}

/// Handle to a process started by [`Remowt::spawn`].
///
/// Each stdio field is `Some` only when that stream was requested in `Pipe`
/// mode. `Inherit` streams are forwarded by background tasks instead, and
/// `Null` streams are never connected.
pub struct RemowtChild {
	/// Stream feeding the child's stdin (`StdioMode::Pipe` only).
	pub stdin: Option<RemowtStream>,
	/// Stream carrying the child's stdout (`StdioMode::Pipe` only). When
	/// `StderrMode::MergeWithStdout` was requested, the agent was asked to
	/// merge stderr into this stream.
	pub stdout: Option<RemowtStream>,
	/// Stream carrying the child's stderr (`StderrMode::Pipe` only).
	pub stderr: Option<RemowtStream>,
	/// Agent-side process identifier used by `wait`/`kill`.
	id: ProcId,
	/// Endpoint client the process was spawned through (plain or escalated).
	client: SubprocessClient<BifConfig>,
}

impl RemowtChild {
	/// Close this handle's stdio streams and wait for the process to exit.
	///
	/// All three streams are dropped *before* the wait request is sent, so take
	/// any stream you still need out of the handle first. Closing them up front
	/// means a child reading stdin sees EOF and is not left waiting on us.
	///
	/// Returns the agent-reported status. It is presumably `None` when the
	/// process did not exit with a code (e.g. it was killed by a signal);
	/// confirm against the agent.
	///
	/// # Errors
	///
	/// Fails if the RPC fails or the agent reports an error.
	pub async fn wait(self) -> Result<Option<i32>> {
		let Self {
			stdin,
			stdout,
			stderr,
			id,
			client,
		} = self;
		// A tuple drops its elements in order: stdin, stdout, then stderr.
		drop((stdin, stdout, stderr));
		let outcome = client.wait(id).await?;
		outcome.map_err(|e| anyhow!("agent wait failed: {e}"))
	}

	/// Ask the agent to deliver `signal` to the process.
	///
	/// This does not wait for the process to exit; call [`RemowtChild::wait`]
	/// for that.
	///
	/// # Errors
	///
	/// Fails if the RPC fails or the agent reports an error.
	pub async fn kill(&self, signal: i32) -> Result<()> {
		let outcome = self.client.kill(self.id, signal).await?;
		outcome.map_err(|e| anyhow!("agent kill failed: {e}"))
	}
}

/// Whether a stdin/stdout stream in mode `m` needs a runtime socket bound for it.
fn needs_socket(m: StdioMode) -> bool {
	match m {
		StdioMode::Null => false,
		StdioMode::Pipe | StdioMode::Inherit => true,
	}
}

/// Whether stderr in mode `m` needs its own runtime socket.
///
/// A merged stderr is sent to the agent as `StderrSpec::MergeWithStdout`
/// rather than getting a separate socket.
fn stderr_needs_socket(m: StderrMode) -> bool {
	match m {
		StderrMode::Pipe | StderrMode::Inherit => true,
		StderrMode::Null | StderrMode::MergeWithStdout => false,
	}
}

impl Remowt {
	/// Spawn a process on the remote machine (or locally, when this is a local
	/// connection) and return a handle to its stdio plus a way to wait/kill.
	pub async fn spawn(&self, opts: SpawnOptions) -> Result<RemowtChild> {
		let SpawnOptions {
			program,
			args,
			env,
			env_clear,
			cwd,
			escalated,
			stdin,
			stdout,
			stderr,
		} = opts;

		if matches!(stderr, StderrMode::MergeWithStdout) && !needs_socket(stdout) {
			bail!("stderr=MergeWithStdout requires stdout=Pipe or Inherit");
		}

		let stdin_bound = if needs_socket(stdin) {
			Some(self.bind_runtime_unix("proc-stdin").await?)
		} else {
			None
		};
		let stdout_bound = if needs_socket(stdout) {
			Some(self.bind_runtime_unix("proc-stdout").await?)
		} else {
			None
		};
		let stderr_bound = if stderr_needs_socket(stderr) {
			Some(self.bind_runtime_unix("proc-stderr").await?)
		} else {
			None
		};

		let stdin_spec = match &stdin_bound {
			Some((_, p)) => StdioSpec::Socket(p.clone()),
			None => StdioSpec::Null,
		};
		let stdout_spec = match &stdout_bound {
			Some((_, p)) => StdioSpec::Socket(p.clone()),
			None => StdioSpec::Null,
		};
		let stderr_spec = match (&stderr, &stderr_bound) {
			(StderrMode::Pipe | StderrMode::Inherit, Some((_, p))) => StderrSpec::Socket(p.clone()),
			(StderrMode::MergeWithStdout, _) => StderrSpec::MergeWithStdout,
			_ => StderrSpec::Null,
		};

		let client: SubprocessClient<BifConfig> = if escalated {
			// Boxed to break the async-fn type cycle
			Box::pin(self.run0_endpoints::<SubprocessClient<BifConfig>>()).await?
		} else {
			self.endpoints()
		};

		let spec = SpawnSpec {
			program: program.clone(),
			args,
			env,
			env_clear,
			cwd,
			stdin: stdin_spec,
			stdout: stdout_spec,
			stderr: stderr_spec,
		};
		let id = client
			.spawn(spec)
			.await?
			.map_err(|e| anyhow!("agent spawn failed: {e}"))?;

		let (stdin_res, stdout_res, stderr_res) = tokio::join!(
			accept(stdin_bound),
			accept(stdout_bound),
			accept(stderr_bound),
		);

		let stdin_stream = handle_stdin(stdin, stdin_res?, &program);
		let stdout_stream = handle_output(stdout, stdout_res?, &program, false);
		let stderr_stream = handle_output_err(stderr, stderr_res?, &program);

		Ok(RemowtChild {
			stdin: stdin_stream,
			stdout: stdout_stream,
			stderr: stderr_stream,
			id,
			client,
		})
	}
}

/// Accept one connection on a bound runtime socket, if a socket was bound.
///
/// Returns `Ok(None)` when `b` is `None` (the stream was not requested). The
/// socket path half of the pair is not used here.
async fn accept(b: Option<(RemowtListener, Utf8PathBuf)>) -> Result<Option<RemowtStream>> {
	if let Some((listener, _)) = b {
		let stream = listener.accept().await?;
		Ok(Some(stream))
	} else {
		Ok(None)
	}
}

/// Decide what to do with the child's stdin connection according to `mode`.
///
/// The return value and side effects depend on `mode`:
///
/// - `Pipe`: returns the connection to the caller.
/// - `Inherit`: spawns a detached task that copies this process's stdin into
///   the connection and then shuts down its write side. Returns `None`.
/// - `Null`: returns `None`.
///
/// A copy error is logged at WARN level with the `program` field.
fn handle_stdin(mode: StdioMode, s: Option<RemowtStream>, program: &str) -> Option<RemowtStream> {
	match (mode, s) {
		(StdioMode::Pipe, stream) => stream,
		(StdioMode::Inherit, Some(mut remote)) => {
			let program = program.to_owned();
			tokio::spawn(async move {
				let mut local = tokio::io::stdin();
				if let Err(e) = tokio::io::copy(&mut local, &mut remote).await {
					warn!(program, "stdin forward ended: {e}");
				}
				// Best effort: signal EOF to the child even if the copy failed.
				let _ = remote.shutdown().await;
			});
			None
		}
		(StdioMode::Inherit | StdioMode::Null, _) => None,
	}
}

/// Decide what to do with a child output connection according to `mode`.
///
/// The return value and side effects depend on `mode`:
///
/// - `Pipe`: returns the connection to the caller.
/// - `Inherit`: spawns [`pump_to_tracing`] on it (at WARN level when
///   `is_stderr`, otherwise INFO). Returns `None`.
/// - `Null`: returns `None`.
fn handle_output(
	mode: StdioMode,
	s: Option<RemowtStream>,
	program: &str,
	is_stderr: bool,
) -> Option<RemowtStream> {
	match (mode, s) {
		(StdioMode::Pipe, stream) => stream,
		(StdioMode::Inherit, Some(stream)) => {
			tokio::spawn(pump_to_tracing(stream, program.to_owned(), is_stderr));
			None
		}
		(StdioMode::Inherit | StdioMode::Null, _) => None,
	}
}

fn handle_output_err(
	mode: StderrMode,
	s: Option<RemowtStream>,
	program: &str,
) -> Option<RemowtStream> {
	match mode {
		StderrMode::Pipe => s,
		StderrMode::Inherit => {
			if let Some(s) = s {
				let program = program.to_owned();
				tokio::spawn(pump_to_tracing(s, program, true));
			}
			None
		}
		StderrMode::MergeWithStdout | StderrMode::Null => None,
	}
}

/// Log every line of `stream` through `tracing` until EOF or a read error.
///
/// Lines are emitted at WARN level when `is_stderr` is set and at INFO
/// otherwise, each tagged with a `program` field. A trailing `\n` or `\r\n` is
/// stripped. A read error is logged and ends the pump.
///
/// Bytes are decoded lossily (invalid UTF-8 becomes U+FFFD) instead of using
/// `lines()`. `lines()` turns a single non-UTF-8 line into an `InvalidData`
/// error, which ended the pump and dropped the stream while the child could
/// still be writing to it.
async fn pump_to_tracing(stream: RemowtStream, program: String, is_stderr: bool) {
	let mut reader = BufReader::new(stream);
	// Reused across iterations to avoid one allocation per line.
	let mut buf = Vec::new();
	loop {
		buf.clear();
		match reader.read_until(b'\n', &mut buf).await {
			Ok(0) => break,
			Ok(_) => {
				// Mirror `lines()`: drop a trailing `\n`, and a `\r` right before it.
				let raw = match buf.strip_suffix(b"\n") {
					Some(rest) => rest.strip_suffix(b"\r").unwrap_or(rest),
					None => &buf[..],
				};
				let line = String::from_utf8_lossy(raw);
				if is_stderr {
					warn!(program, "{line}");
				} else {
					info!(program, "{line}");
				}
			}
			Err(e) => {
				warn!(program, "child log stream error: {e}");
				break;
			}
		}
	}
}