use std::collections::HashMap;
use std::io;
use std::path::PathBuf;
use std::sync::{Arc, Mutex};
use anyhow::{anyhow, bail, ensure, Context as _, Result};
use bifrostlink::declarative::RemoteEndpoints;
use bifrostlink::{Port, Remote, Rpc, Rtt, WeakRpc};
use bifrostlink_ports::unix_socket::from_socket;
use bytes::{Bytes, BytesMut};
use camino::{Utf8Path, Utf8PathBuf};
use remowt_link_shared::plugin::PluginEndpointsClient;
use remowt_link_shared::{
Address, BifConfig, ElevateEndpoints, ElevateError, Elevator, Fs, Pty, PtyClient, ShellId,
Systemd,
};
use russh::client::{connect, Config, Handle, Handler, Msg, Session};
use russh::keys::agent::client::AgentClient;
use russh::keys::agent::AgentIdentity;
use russh::keys::check_known_hosts;
use russh::keys::ssh_key::PublicKey;
use russh::{Channel, ChannelMsg, ChannelStream};
use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _, DuplexStream, ReadHalf, WriteHalf};
use tokio::join;
use tokio::net::UnixListener;
use tokio::sync::mpsc;
use tokio::sync::oneshot::{self, channel};
use tracing::error;
use uuid::Uuid;
pub mod editor;
/// Pending streamlocal-forward subscriptions, keyed by remote socket path.
///
/// An entry is inserted before a forward is requested and is removed by
/// `SshHandler::server_channel_open_forwarded_streamlocal` when the server opens
/// a forwarded channel for that path; the channel is then sent to the waiting
/// receiver. Shared between the `SshHandler` and `Remowt`.
type Subs = Arc<Mutex<HashMap<Utf8PathBuf, oneshot::Sender<Channel<Msg>>>>>;
/// Reads one length-prefixed frame from the SSH channel.
///
/// Wire format (the counterpart of `write`): a big-endian `u32` byte count,
/// followed by exactly that many payload bytes. Any I/O error, including EOF
/// in the middle of a frame, is returned unchanged.
async fn read(srx: &mut ReadHalf<ChannelStream<Msg>>) -> io::Result<BytesMut> {
    let declared_len = srx.read_u32().await?;
    let mut frame = BytesMut::zeroed(declared_len as usize);
    srx.read_exact(&mut frame).await?;
    Ok(frame)
}
/// Writes one length-prefixed frame to the SSH channel.
///
/// Wire format (the counterpart of `read`): a big-endian `u32` byte count,
/// followed by the payload.
///
/// # Errors
/// Returns `io::ErrorKind::InvalidInput` when `value` is longer than the `u32`
/// length prefix can express. Previously this panicked and took down the port's
/// write task. Other I/O errors are returned unchanged.
async fn write(stx: &mut WriteHalf<ChannelStream<Msg>>, value: Bytes) -> io::Result<()> {
    let len = u32::try_from(value.len()).map_err(|_| {
        io::Error::new(
            io::ErrorKind::InvalidInput,
            format!(
                "frame of {} bytes exceeds the u32 length prefix",
                value.len()
            ),
        )
    })?;
    stx.write_u32(len).await?;
    stx.write_all(&value).await?;
    Ok(())
}
/// Quotes `s` as a single POSIX-shell word.
///
/// The text is wrapped in single quotes. Each embedded `'` becomes `'\''`
/// (close the quote, add an escaped quote, reopen), so the shell sees the
/// original text literally.
fn sh_quote(s: impl AsRef<str>) -> String {
    let raw = s.as_ref();
    let mut quoted = String::with_capacity(raw.len() + 2);
    quoted.push('\'');
    for ch in raw.chars() {
        if ch == '\'' {
            quoted.push_str("'\\''");
        } else {
            quoted.push(ch);
        }
    }
    quoted.push('\'');
    quoted
}
/// Privilege-escalation tools to try, in order of preference, each with the
/// extra flags placed between the tool name and the agent command.
///
/// The privileged agent's protocol runs over the launched process's stdio when
/// started through SSH (see `privileged_cmd` / `port_from_channel`). That is why
/// `run0` gets `--pipe`, which per run0(1) passes stdio straight through.
/// `--background=` with an empty value turns off run0's terminal background
/// tint. NOTE(review): confirm both flags against the systemd version on target hosts.
const ESCALATORS: [(&str, &[&str]); 3] = [
    ("run0", &["--background=", "--pipe"]),
    ("sudo", &[]),
    ("doas", &[]),
];
/// Local directory of prebuilt `remowt-agent-<arch>` binaries plus a `hashes`
/// index file.
///
/// Each non-blank line of `<dir>/hashes` has the form `<arch> <hash>`, separated
/// by whitespace. The arch is matched against the remote's `uname -m`, and the
/// hash names the cached binary on the remote.
pub struct AgentBundle {
    dir: PathBuf,
    hashes: HashMap<String, String>,
}
impl AgentBundle {
    /// Loads the bundle rooted at `dir` by parsing `<dir>/hashes`.
    ///
    /// If an arch appears more than once, the last entry wins.
    ///
    /// # Errors
    /// Fails if the index cannot be read, if a line has no whitespace separator,
    /// or if the index contains no entries at all.
    pub fn from_dir(dir: impl Into<PathBuf>) -> Result<Self> {
        let dir: PathBuf = dir.into();
        let hashes_path = dir.join("hashes");
        let raw = std::fs::read_to_string(&hashes_path)
            .with_context(|| format!("reading agent hashes at {}", hashes_path.display()))?;
        let hashes = raw
            .lines()
            .map(str::trim)
            .filter(|line| !line.is_empty())
            .map(|line| {
                line.split_once(char::is_whitespace)
                    .map(|(arch, hash)| (arch.to_owned(), hash.trim().to_owned()))
                    .ok_or_else(|| anyhow!("malformed hashes line: {line:?}"))
            })
            .collect::<Result<HashMap<_, _>>>()?;
        ensure!(
            !hashes.is_empty(),
            "agent bundle {} has no hashes",
            dir.display()
        );
        Ok(Self { dir, hashes })
    }

    /// Local path of the agent binary built for `arch`.
    fn binary(&self, arch: &str) -> PathBuf {
        let file_name = ["remowt-agent-", arch].concat();
        self.dir.join(file_name)
    }
}
async fn run(sess: &Handle<SshHandler>, cmd: &str) -> Result<(Option<u32>, Vec<u8>)> {
let mut ch = sess.channel_open_session().await?;
ch.exec(true, cmd).await?;
let mut out = Vec::new();
let mut code = None;
while let Some(msg) = ch.wait().await {
match msg {
ChannelMsg::Data { data } => out.extend(data.as_ref()),
ChannelMsg::ExtendedData { data, .. } => {
error!(
"remote stderr: {}",
String::from_utf8_lossy(data.as_ref()).trim()
);
}
ChannelMsg::ExitStatus { exit_status } => code = Some(exit_status),
_ => {}
}
}
Ok((code, out))
}
/// Runs `cmd` remotely, requires exit status 0, and returns stdout as a string
/// with the single trailing newline removed.
///
/// Empty output is accepted and returned as `""`, so silent commands such as
/// `mkdir`, `chmod` and `mv` also work. Non-empty output must end in `\n`, which
/// guards against truncated output.
///
/// # Errors
/// Fails if the command exits non-zero (or with no status), if non-empty output
/// is not newline-terminated, or if it is not valid UTF-8. Each message includes
/// the command.
async fn run_string_ok(sess: &Handle<SshHandler>, cmd: &str) -> Result<String> {
    let (code, mut out) = run(sess, cmd).await?;
    ensure!(
        code == Some(0),
        "remote command failed (exit {code:?}): {cmd}"
    );
    if !out.is_empty() {
        ensure!(
            out.ends_with(b"\n"),
            "output of remote command is not newline-terminated: {cmd}"
        );
        out.pop();
    }
    String::from_utf8(out).context("expected utf8 output for command")
}
/// Ensures the agent binary for the remote's architecture is installed, and
/// returns its remote path.
///
/// The arch comes from `uname -m` and is looked up in `bundle`. The binary is
/// cached as `<cache>/remowt/<hash>`, where `<cache>` is `$XDG_CACHE_HOME` or,
/// when that is unset or empty, `$HOME/.cache` (the XDG Base Directory default).
/// Because the file name is the hash, a new build gets a new path. If
/// `test -x <path>` fails, the local binary is read and uploaded.
///
/// Note: the fallback used to be `$HOME/cache`. Hosts with an agent cached
/// there get a fresh upload once.
///
/// # Errors
/// Fails if the arch has no build in `bundle`, if both `$XDG_CACHE_HOME` and
/// `$HOME` are empty, or if reading or uploading the binary fails.
async fn deploy_agent(sess: &Handle<SshHandler>, bundle: &AgentBundle) -> Result<Utf8PathBuf> {
    let arch = run_string_ok(sess, "uname -m").await?;
    let hash = bundle
        .hashes
        .get(&arch)
        .ok_or_else(|| anyhow!("no remowt-agent build for remote arch {arch:?}"))?;
    let cache = run_string_ok(sess, "echo \"$XDG_CACHE_HOME\"")
        .await?
        .trim()
        .to_owned();
    let dir = if cache.is_empty() {
        let home = run_string_ok(sess, "echo \"$HOME\"")
            .await?
            .trim()
            .to_owned();
        ensure!(
            !home.is_empty(),
            "remote $HOME and $XDG_CACHE_HOME both empty"
        );
        // XDG Base Directory spec: $XDG_CACHE_HOME defaults to $HOME/.cache.
        Utf8PathBuf::from(home).join(".cache/remowt")
    } else {
        Utf8PathBuf::from(cache).join("remowt")
    };
    let path = dir.join(hash);
    let (present, _) = run(sess, &format!("test -x {}", sh_quote(&path))).await?;
    if present != Some(0) {
        let bin = bundle.binary(&arch);
        let bytes = std::fs::read(&bin)
            .with_context(|| format!("reading agent binary {}", bin.display()))?;
        upload_agent(sess, &dir, &path, bytes).await?;
    }
    Ok(path)
}
async fn upload_agent(
sess: &Handle<SshHandler>,
dir: &Utf8Path,
path: &Utf8Path,
bytes: Vec<u8>,
) -> Result<()> {
run_string_ok(sess, &format!("mkdir -p {}", sh_quote(dir))).await?;
let tmp = path.join(format!("tmp.{}", Uuid::new_v4()));
let ch = sess.channel_open_session().await?;
ch.exec(true, format!("cat > {}", sh_quote(&tmp))).await?;
ch.data_bytes(bytes).await?;
ch.eof().await?;
let mut ch = ch;
let mut code = None;
while let Some(msg) = ch.wait().await {
match msg {
ChannelMsg::ExitStatus { exit_status } => code = Some(exit_status),
ChannelMsg::ExtendedData { data, .. } => {
error!(
"agent upload: {}",
String::from_utf8_lossy(data.as_ref()).trim()
);
}
_ => {}
}
}
ensure!(code == Some(0), "agent upload failed (exit {code:?})");
run_string_ok(sess, &format!("chmod 0755 {}", sh_quote(&tmp))).await?;
run_string_ok(
sess,
&format!("mv -f {} {}", sh_quote(&tmp), sh_quote(path)),
)
.await?;
Ok(())
}
/// Finds the first tool in `ESCALATORS` that the remote shell resolves with
/// `command -v`, and returns it with its flags.
///
/// # Errors
/// Fails if none of run0/sudo/doas is available, or if a probe cannot be run.
async fn detect_escalation(
    sess: &Handle<SshHandler>,
) -> Result<(&'static str, &'static [&'static str])> {
    for candidate in ESCALATORS {
        let probe = format!("command -v {}", candidate.0);
        if let (Some(0), _) = run(sess, &probe).await? {
            return Ok(candidate);
        }
    }
    bail!("no escalation tool (run0/sudo/doas) found on remote")
}
/// Builds the shell command line that starts the agent in privileged mode under
/// an escalation tool.
///
/// Shape: `<tool> <flags...> '<agent_path>' real-agent --privileged [--path '<path>']`.
/// Only `agent_path` and `path` are shell-quoted. `tool` and `flags` are inserted
/// verbatim and must be trusted constants, as in `ESCALATORS`.
fn privileged_cmd(tool: &str, flags: &[&str], agent_path: &Utf8Path, path: Option<&str>) -> String {
    let path_args = path
        .map(|p| vec!["--path".to_owned(), sh_quote(p)])
        .unwrap_or_default();
    std::iter::once(tool.to_owned())
        .chain(flags.iter().map(|flag| (*flag).to_owned()))
        .chain([
            sh_quote(agent_path),
            "real-agent".to_owned(),
            "--privileged".to_owned(),
        ])
        .chain(path_args)
        .collect::<Vec<String>>()
        .join(" ")
}
/// Returns the first `<dir>/<name>` that is a regular file, searching the
/// directories of the local `$PATH` in order.
///
/// Returns `None` if `$PATH` is unset or nothing matches. The executable bit
/// is not checked.
fn find_in_path(name: &str) -> Option<std::path::PathBuf> {
    let search_path = std::env::var_os("PATH")?;
    for dir in std::env::split_paths(&search_path) {
        let candidate = dir.join(name);
        if candidate.is_file() {
            return Some(candidate);
        }
    }
    None
}
/// Adapts an SSH channel into a bifrostlink `Port` using `read`/`write`
/// length-prefixed framing.
///
/// Two pumps run concurrently, and the port task ends when both have stopped:
/// - inbound: frames read from the channel are handed to the rpc. It stops on
///   a read error (logged) or when the rpc side is gone.
/// - outbound: messages from the rpc are written as frames. It stops on a write
///   error (logged) or when the rpc stops sending.
fn port_from_channel(ch: Channel<Msg>) -> Port {
    Port::new(move |mut outgoing, incoming| async move {
        let (mut reader, mut writer) = tokio::io::split(ch.into_stream());
        let inbound = async move {
            loop {
                let frame = match read(&mut reader).await {
                    Ok(frame) => frame,
                    Err(e) => {
                        error!("channel read failed: {e}");
                        break;
                    }
                };
                if incoming.send(frame.freeze()).is_err() {
                    break;
                }
            }
        };
        let outbound = async move {
            loop {
                let Some(value) = outgoing.recv().await else {
                    break;
                };
                if let Err(e) = write(&mut writer, value).await {
                    error!("channel write failed: {e}");
                    break;
                }
            }
        };
        join!(inbound, outbound);
    })
}
pub struct SshHandler {
host: String,
port: u16,
subs: Subs,
}
impl Handler for SshHandler {
type Error = russh::Error;
async fn check_server_key(
&mut self,
server_public_key: &PublicKey,
) -> Result<bool, Self::Error> {
Ok(check_known_hosts(&self.host, self.port, server_public_key)?)
}
async fn server_channel_open_forwarded_streamlocal(
&mut self,
channel: Channel<Msg>,
socket_path: &str,
_session: &mut Session,
) -> Result<(), Self::Error> {
let Some(ch) = self
.subs
.lock()
.expect("lock")
.remove(&Utf8PathBuf::from(socket_path))
else {
return Err(russh::Error::WrongChannel);
};
let _ = ch.send(channel);
Ok(())
}
}
/// [`Elevator`] that launches a privileged agent on the SSH remote and adds an
/// `Address::AgentPrivileged` route carried over that exec channel's stdio.
struct SshElevator {
    /// Session used to probe for an escalation tool and start the agent.
    sess: Arc<Handle<SshHandler>>,
    /// Weak handle to the rpc node that receives the new route.
    rpc: WeakRpc<BifConfig>,
    /// Remote path of the deployed agent binary.
    agent_path: Utf8PathBuf,
}
impl Elevator for SshElevator {
    /// Detects run0/sudo/doas, runs the agent with `--privileged` under it, and
    /// wires the channel into the rpc.
    ///
    /// # Errors
    /// Every failure, including a dropped rpc, is reported as
    /// `ElevateError::Failed` with the error's message.
    async fn elevate(&self) -> Result<(), ElevateError> {
        fn failed<E: ToString>(err: E) -> ElevateError {
            ElevateError::Failed(err.to_string())
        }
        let (tool, flags) = detect_escalation(&self.sess).await.map_err(failed)?;
        let channel = self.sess.channel_open_session().await.map_err(failed)?;
        let command = privileged_cmd(tool, flags, &self.agent_path, None);
        channel.exec(true, command).await.map_err(failed)?;
        let Some(rpc) = self.rpc.clone().upgrade() else {
            return Err(failed("rpc is gone"));
        };
        rpc.add_direct(Address::AgentPrivileged, port_from_channel(channel), Rtt(0));
        Ok(())
    }
}
/// Handle to a command started with `Remowt::exec`.
///
/// Output is forwarded by a background task into in-memory pipes. Each pipe is
/// shut down when that task finishes.
pub struct RemoteChild {
    /// Remote stdout (SSH channel data).
    pub stdout: DuplexStream,
    /// Remote stderr (SSH extended data).
    pub stderr: DuplexStream,
    /// Resolves when the forwarding task finishes. Holds the exit status, or
    /// `None` if no exit status was observed.
    pub exit: oneshot::Receiver<Option<u32>>,
}
/// How a `Remowt` reaches its agent.
enum Transport {
    /// The agent runs on a remote host reached over SSH.
    Ssh {
        /// Authenticated SSH session.
        sess: Arc<Handle<SshHandler>>,
        /// Pending streamlocal-forward subscriptions, shared with the `SshHandler`.
        subs: Subs,
        /// Remote temp directory (from `remote_mktemp`) that holds the agent's sockets.
        remote_dir: Utf8PathBuf,
        /// Remote path of the deployed agent binary.
        agent_path: Utf8PathBuf,
    },
    /// Agent endpoints are served in-process over a `loopback` port pair. A real
    /// agent binary is only spawned for elevation.
    Local {
        // Never read after construction. It is held so the in-process agent
        // `Rpc` and its registered endpoints stay alive; presumably dropping it
        // would tear down the loopback agent. NOTE(review): confirm.
        #[allow(dead_code)]
        agent: Rpc<BifConfig>,
        /// Local path of the agent binary, launched under an escalation tool
        /// when elevation is requested.
        agent_path: String,
    },
}
/// Connection to a remowt agent, over SSH or in-process.
pub struct Remowt {
    /// How the agent is reached.
    transport: Transport,
    /// This side's rpc node (`Address::User`), with routes to the agent(s).
    rpc: Rpc<BifConfig>,
    /// Initialized once an `Address::AgentPrivileged` route has been added.
    elevated: tokio::sync::OnceCell<()>,
    /// Escalation-tool processes spawned for local elevation. They are spawned
    /// with `kill_on_drop(true)` and are held here to keep them running.
    children: Mutex<Vec<tokio::process::Child>>,
}
/// Alias for a bifrostlink `Remote` using remowt's `BifConfig`.
pub type RemowtRemote = Remote<BifConfig>;
/// Creates a connected pair of in-memory ports, returned as `(user, agent)`.
///
/// Messages sent out through either port are delivered as incoming messages
/// on the other. Each port's pump stops as soon as any channel it uses closes.
fn loopback() -> (Port, Port) {
    let (user_to_agent_tx, mut user_to_agent_rx) = mpsc::unbounded_channel::<Bytes>();
    let (agent_to_user_tx, mut agent_to_user_rx) = mpsc::unbounded_channel::<Bytes>();
    let user = Port::new(move |mut outgoing, incoming| async move {
        loop {
            tokio::select! {
                frame = outgoing.recv() => {
                    let Some(frame) = frame else { break };
                    if user_to_agent_tx.send(frame).is_err() {
                        break;
                    }
                }
                frame = agent_to_user_rx.recv() => {
                    let Some(frame) = frame else { break };
                    if incoming.send(frame).is_err() {
                        break;
                    }
                }
            }
        }
    });
    let agent = Port::new(move |mut outgoing, incoming| async move {
        loop {
            tokio::select! {
                frame = outgoing.recv() => {
                    let Some(frame) = frame else { break };
                    if agent_to_user_tx.send(frame).is_err() {
                        break;
                    }
                }
                frame = user_to_agent_rx.recv() => {
                    let Some(frame) = frame else { break };
                    if incoming.send(frame).is_err() {
                        break;
                    }
                }
            }
        }
    });
    (user, agent)
}
impl Remowt {
pub async fn connect(host: &str, bundle: &AgentBundle) -> Result<Self> {
let conf = russh_config::parse_home(host)?;
let port = conf.host_config.port.unwrap_or(22);
let hostname = conf
.host_config
.hostname
.clone()
.unwrap_or_else(|| conf.host_name.clone());
let user = conf
.user
.clone()
.unwrap_or_else(|| std::env::var("USER").unwrap_or_else(|_| "root".to_owned()));
let subs: Subs = Arc::new(Mutex::new(HashMap::new()));
let mut sess = connect(
Arc::new(Config::default()),
(hostname.clone(), port),
SshHandler {
host: hostname,
port,
subs: subs.clone(),
},
)
.await?;
let mut agent = AgentClient::connect_env().await?;
let rsa_hash = sess.best_supported_rsa_hash().await?.flatten();
let mut authenticated = false;
for ident in agent.request_identities().await? {
let AgentIdentity::PublicKey { key, .. } = ident else {
continue;
};
if sess
.authenticate_publickey_with(user.clone(), key, rsa_hash, &mut agent)
.await?
.success()
{
authenticated = true;
break;
}
}
ensure!(authenticated, "ssh authentication failed");
let sess = Arc::new(sess);
let agent_path = deploy_agent(&sess, bundle).await?;
let remote_dir = remote_mktemp(&sess).await?;
let primary = remote_dir.join("primary.sock");
let (onetx, onerx) = channel();
subs.lock().expect("lock").insert(primary.clone(), onetx);
sess.streamlocal_forward(primary.clone()).await?;
let rpc = Rpc::<BifConfig>::new(Address::User);
let cmd_chan = sess.channel_open_session().await?;
cmd_chan
.exec(
true,
format!(
"{} real-agent --path={}",
sh_quote(&agent_path),
sh_quote(&primary)
),
)
.await?;
let port = port_from_channel(
onerx
.await
.map_err(|_| anyhow!("agent never opened its channel"))?,
);
rpc.add_direct(Address::Agent, port, Rtt(0));
Ok(Self {
transport: Transport::Ssh {
sess,
subs,
remote_dir,
agent_path,
},
rpc,
elevated: tokio::sync::OnceCell::new(),
children: Mutex::new(Vec::new()),
})
}
pub async fn connect_local(agent_path: &str) -> Result<Self> {
let (port_user, port_agent) = loopback();
let rpc = Rpc::<BifConfig>::new(Address::User);
let mut agent = Rpc::<BifConfig>::new(Address::Agent);
Fs::new().register_endpoints(&mut agent);
Systemd.register_endpoints(&mut agent);
Pty::new().register_endpoints(&mut agent);
agent.add_direct(Address::User, port_agent, Rtt(0));
rpc.add_direct(Address::Agent, port_user, Rtt(0));
Ok(Self {
transport: Transport::Local {
agent,
agent_path: agent_path.to_owned(),
},
rpc,
elevated: tokio::sync::OnceCell::new(),
children: Mutex::new(Vec::new()),
})
}
pub fn ssh(&self) -> Option<Arc<Handle<SshHandler>>> {
match &self.transport {
Transport::Ssh { sess, .. } => Some(sess.clone()),
Transport::Local { .. } => None,
}
}
pub fn rpc(&self) -> Rpc<BifConfig> {
self.rpc.clone()
}
pub fn endpoints<R: RemoteEndpoints<BifConfig>>(&self) -> R {
R::wrap(self.rpc.remote(Address::Agent))
}
pub async fn load_plugin(&self, id: u16, name: &str) -> Result<()> {
let client: PluginEndpointsClient<BifConfig> = self.endpoints();
client
.load_plugin(id, name.to_owned())
.await?
.map_err(|e| anyhow!("agent failed to load plugin: {e}"))
}
pub async fn run0_load_plugin_path(&self, id: u16, path: &str) -> Result<()> {
self.ensure_elevated().await?;
let client: PluginEndpointsClient<BifConfig> =
PluginEndpointsClient::wrap(self.rpc.remote(Address::AgentPrivileged));
client
.load_plugin_path(id, path.to_owned())
.await?
.map_err(|e| anyhow!("privileged agent failed to load plugin: {e}"))
}
pub fn plugin_endpoints<R: RemoteEndpoints<BifConfig>>(&self, id: u16) -> R {
R::wrap(self.rpc.remote(Address::Plugin(id)))
}
pub async fn run0_endpoints<R: RemoteEndpoints<BifConfig>>(&self) -> Result<R> {
self.ensure_elevated().await?;
Ok(R::wrap(self.rpc.remote(Address::AgentPrivileged)))
}
async fn ensure_elevated(&self) -> Result<()> {
self.elevated
.get_or_try_init(|| async {
let port = match &self.transport {
Transport::Ssh {
sess, agent_path, ..
} => {
let (tool, flags) = detect_escalation(sess).await?;
let ch = sess.channel_open_session().await?;
ch.exec(true, privileged_cmd(tool, flags, agent_path, None))
.await?;
port_from_channel(ch)
}
Transport::Local { agent_path, .. } => {
let sock = std::env::temp_dir()
.join(format!("remowt-priv-{}.sock", uuid::Uuid::new_v4()));
let _ = std::fs::remove_file(&sock);
let listener = UnixListener::bind(&sock)?;
let (tool, flags) = ESCALATORS
.iter()
.find(|(t, _)| find_in_path(t).is_some())
.ok_or_else(|| anyhow!("no escalation tool (run0/sudo/doas) found"))?;
let child = tokio::process::Command::new(tool)
.args(*flags)
.arg(agent_path)
.arg("real-agent")
.arg("--privileged")
.arg("--path")
.arg(sock.to_str().expect("temp path is utf-8"))
.kill_on_drop(true)
.spawn()?;
self.children.lock().expect("lock").push(child);
let (stream, _) = listener.accept().await?;
let _ = std::fs::remove_file(&sock);
from_socket(stream)
}
};
self.rpc.add_direct(Address::AgentPrivileged, port, Rtt(0));
anyhow::Ok(())
})
.await?;
Ok(())
}
pub async fn exec(&self, command: String) -> Result<RemoteChild> {
let Some(sess) = self.ssh() else {
bail!("exec should not be called on local")
};
let ch = sess.channel_open_session().await?;
ch.exec(true, command).await?;
let (mut out_w, stdout) = tokio::io::duplex(64 * 1024);
let (mut err_w, stderr) = tokio::io::duplex(64 * 1024);
let (exit_tx, exit) = oneshot::channel();
tokio::spawn(async move {
let mut ch = ch;
let mut code = None;
while let Some(msg) = ch.wait().await {
match msg {
ChannelMsg::Data { data } => {
if out_w.write_all(&data).await.is_err() {
break;
}
}
ChannelMsg::ExtendedData { data, .. } => {
if err_w.write_all(&data).await.is_err() {
break;
}
}
ChannelMsg::ExitStatus { exit_status } => code = Some(exit_status),
_ => {}
}
}
let _ = out_w.shutdown().await;
let _ = err_w.shutdown().await;
let _ = exit_tx.send(code);
});
Ok(RemoteChild {
stdout,
stderr,
exit,
})
}
pub fn serve_elevate(&self) -> Result<()> {
let Transport::Ssh {
sess, agent_path, ..
} = &self.transport
else {
bail!("elevate should not be called on local")
};
let mut rpc = self.rpc.clone();
ElevateEndpoints(SshElevator {
sess: sess.clone(),
rpc: self.rpc.clone().downgrade(),
agent_path: agent_path.to_owned(),
})
.register_endpoints(&mut rpc);
Ok(())
}
pub fn remote_dir(&self) -> Option<&Utf8Path> {
match &self.transport {
Transport::Ssh { remote_dir, .. } => Some(remote_dir),
Transport::Local { .. } => None,
}
}
pub async fn forward_socket(
&self,
remote_path: &Utf8Path,
) -> Result<oneshot::Receiver<Channel<Msg>>> {
let Transport::Ssh { sess, subs, .. } = &self.transport else {
bail!("forward_socket should not be called on local")
};
let (tx, rx) = oneshot::channel();
subs.lock()
.expect("lock")
.insert(remote_path.to_owned(), tx);
sess.streamlocal_forward(remote_path.to_owned()).await?;
Ok(rx)
}
pub async fn open_shell(&self, term: &str, cols: u16, rows: u16) -> Result<Shell> {
let Transport::Ssh { remote_dir, .. } = &self.transport else {
bail!("open_shell should not be called on local")
};
let sock = remote_dir.join(format!("shell-{}.sock", uuid::Uuid::new_v4()));
let rx = self.forward_socket(&sock).await?;
let client: PtyClient<BifConfig> = self.endpoints();
let id = client
.open_shell(sock, term.to_owned(), cols, rows)
.await?
.map_err(|e| anyhow!("agent failed to open shell: {e}"))?;
let ch = rx
.await
.map_err(|_| anyhow!("agent never connected the shell socket"))?;
Ok(Shell {
id,
stream: ch.into_stream(),
remote: self.rpc.remote(Address::Agent),
})
}
}
/// Interactive PTY shell opened by `Remowt::open_shell`.
pub struct Shell {
    /// Identifier the agent assigned to this shell, used for resize requests.
    pub id: ShellId,
    /// Byte stream connected to the shell's forwarded socket.
    pub stream: ChannelStream<Msg>,
    remote: Remote<BifConfig>,
}
impl Shell {
    /// Returns a cloneable handle for resizing this shell that does not
    /// borrow `stream`.
    pub fn resizer(&self) -> ShellResizer {
        let remote = self.remote.clone();
        ShellResizer {
            id: self.id,
            remote,
        }
    }
}
#[derive(Clone)]
pub struct ShellResizer {
remote: Remote<BifConfig>,
id: ShellId,
}
impl ShellResizer {
pub async fn resize(&self, cols: u16, rows: u16) -> Result<()> {
PtyClient::wrap(self.remote.clone())
.resize(self.id, cols, rows)
.await?
.map_err(|e| anyhow!("failed to resize remote shell: {e}"))
}
}
async fn remote_mktemp(sess: &Handle<SshHandler>) -> Result<Utf8PathBuf> {
let mut cmd_chan = sess.channel_open_session().await?;
cmd_chan
.exec(true, "mktemp -d remowt.XXXXXXXXXXXX --tmpdir")
.await?;
let mut stdout = vec![];
loop {
let Some(msg) = cmd_chan.wait().await else {
bail!("unexpected channel end");
};
match msg {
russh::ChannelMsg::Data { data } => stdout.extend(data.as_ref()),
russh::ChannelMsg::ExitStatus { exit_status } => {
if exit_status != 0 {
bail!("mktemp failed");
}
break;
}
_ => {}
}
}
ensure!(stdout.ends_with(b"\n"));
stdout.pop();
Ok(Utf8PathBuf::from(String::from_utf8(stdout)?))
}