use std::{path::PathBuf, sync::Arc};
use nix::unistd::{Gid, Uid, User};
use pty_process::{OwnedWritePty, Size};
use russh::{ChannelId, Sig, server::Handle};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
sync::Mutex,
};
use crate::{
Device,
ssh::{ChannelEvent, ChannelHandler, SshAccept},
};
/// Shell used when the account's passwd entry has an empty login-shell field.
const DEFAULT_SHELL: &str = "/bin/sh";
/// `PATH` exported to the spawned shell. The child environment is cleared
/// (`env_clear`) before the variables from `build_env` are applied, so this is
/// the only `PATH` the shell starts with.
const DEFAULT_PATH: &str = "/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin";
/// The passwd-database facts about the local account a shell session runs as.
#[derive(Debug, Clone)]
struct ResolvedUser {
    /// Login name. Exported as `USER`/`LOGNAME` and handed to `initgroups`.
    name: String,
    /// User id the forked child switches to with `setuid`.
    uid: Uid,
    /// Primary group id the forked child switches to with `setgid`.
    gid: Gid,
    /// Home directory. Used as the shell's working directory and as `HOME`.
    home: PathBuf,
    /// Shell to exec. `resolve_user` substitutes `DEFAULT_SHELL` when the
    /// passwd field is empty.
    shell: PathBuf,
}
/// Looks up `local_user` in the system passwd database.
///
/// A blank login-shell field is replaced with [`DEFAULT_SHELL`]. Every other
/// field is copied unchanged.
///
/// # Errors
/// * [`std::io::ErrorKind::NotFound`] when no entry exists for the name.
/// * [`std::io::ErrorKind::Other`] when the lookup itself fails.
fn resolve_user(local_user: &str) -> std::io::Result<ResolvedUser> {
    let entry = match User::from_name(local_user) {
        Ok(Some(entry)) => entry,
        Ok(None) => {
            return Err(std::io::Error::new(
                std::io::ErrorKind::NotFound,
                format!("ssh: local user {local_user:?} not found in passwd database"),
            ));
        }
        Err(e) => {
            return Err(std::io::Error::other(format!(
                "ssh: resolving local user {local_user:?} failed: {e}"
            )));
        }
    };

    // Keep the configured shell unless the field is blank.
    let shell = Some(entry.shell)
        .filter(|configured| !configured.as_os_str().is_empty())
        .unwrap_or_else(|| PathBuf::from(DEFAULT_SHELL));

    Ok(ResolvedUser {
        name: entry.name,
        uid: entry.uid,
        gid: entry.gid,
        home: entry.dir,
        shell,
    })
}
/// Builds the complete environment for the user's shell.
///
/// Returns exactly six variables, in this order:
/// 1. `HOME`
/// 2. `USER`
/// 3. `LOGNAME`
/// 4. `SHELL`
/// 5. `PATH` (always [`DEFAULT_PATH`])
/// 6. `TERM` (always `xterm-256color`)
///
/// Path values containing non-UTF-8 bytes are converted lossily.
fn build_env(user: &ResolvedUser) -> Vec<(String, String)> {
    let var = |key: &str, value: String| (key.to_string(), value);
    let lossy = |path: &std::path::Path| path.to_string_lossy().into_owned();

    vec![
        var("HOME", lossy(user.home.as_path())),
        var("USER", user.name.clone()),
        var("LOGNAME", user.name.clone()),
        var("SHELL", lossy(user.shell.as_path())),
        var("PATH", DEFAULT_PATH.to_string()),
        var("TERM", "xterm-256color".to_string()),
    ]
}
/// Argument passed to the user's shell to request a login shell.
///
/// NOTE(review): bash, zsh and dash accept `-l`, but it is not part of POSIX
/// `sh`. The traditional convention (`login(1)`, OpenSSH) is a `-`-prefixed
/// `argv[0]` instead. Confirm that every shell you expect to support accepts `-l`.
const LOGIN_SHELL_ARG: &str = "-l";
/// One privilege-changing call made in the forked child before `exec`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PrivDropStep {
    /// Reset the supplementary group list to the user's groups via
    /// `initgroups(3)`, with this gid as the base group. A no-op on Apple
    /// targets.
    InitGroups(Gid),
    /// Switch the process group id via `setgid(2)`.
    SetGid(Gid),
    /// Switch the process user id via `setuid(2)`. Must come last.
    SetUid(Uid),
}
/// Builds the ordered privilege-drop sequence for the target account.
///
/// Group changes come first, optionally preceded by `InitGroups`, and
/// `SetUid` is always last. Changing groups needs privileges that the
/// process gives up once it stops being root.
fn priv_drop_plan(uid: Uid, gid: Gid, with_initgroups: bool) -> Vec<PrivDropStep> {
    let leading = with_initgroups.then_some(PrivDropStep::InitGroups(gid));
    leading
        .into_iter()
        .chain([PrivDropStep::SetGid(gid), PrivDropStep::SetUid(uid)])
        .collect()
}
fn apply_priv_drop_step(
step: &PrivDropStep,
user_cname: Option<&std::ffi::CStr>,
) -> std::io::Result<()> {
match step {
PrivDropStep::InitGroups(gid) => {
#[cfg(not(target_vendor = "apple"))]
{
let cname = user_cname.ok_or_else(|| {
std::io::Error::other("ssh: initgroups step without user name")
})?;
nix::unistd::initgroups(cname, *gid)
.map_err(|e| std::io::Error::from_raw_os_error(e as i32))?;
}
#[cfg(target_vendor = "apple")]
{
let _ = (gid, user_cname);
}
}
PrivDropStep::SetGid(gid) => {
nix::unistd::setgid(*gid).map_err(|e| std::io::Error::from_raw_os_error(e as i32))?;
}
PrivDropStep::SetUid(uid) => {
nix::unistd::setuid(*uid).map_err(|e| std::io::Error::from_raw_os_error(e as i32))?;
}
}
Ok(())
}
/// Per-channel state for an interactive shell session running on a PTY.
pub struct ShellHandler {
    /// Channel this handler serves; included in log fields.
    channel_id: ChannelId,
    /// Write half of the PTY master. Client input and resize requests go here.
    pty_write: OwnedWritePty,
    /// The spawned shell, shared with the output-relay task that reaps it.
    child: Arc<Mutex<tokio::process::Child>>,
}
impl ShellHandler {
    /// Sends `signum` to the shell process if it has not been reaped yet.
    ///
    /// The child mutex stays locked until `kill(2)` has been issued. The only
    /// code in this file that reaps the child is the output pump's `wait()`,
    /// and it needs the same mutex. So the pid returned by `id()` cannot be
    /// reaped and recycled for an unrelated process before the signal is sent.
    /// Unknown signal numbers and delivery failures are logged at debug level
    /// and otherwise ignored.
    async fn signal_child(&self, signum: i32) {
        let child = self.child.lock().await;
        let Some(pid) = child.id() else {
            return;
        };
        let Ok(signal) = nix::sys::signal::Signal::try_from(signum) else {
            tracing::debug!(signum, "ssh: unmapped signal; not forwarding");
            return;
        };
        // `id()` is a u32; a value that does not fit pid_t cannot be a real pid.
        let Ok(raw_pid) = nix::libc::pid_t::try_from(pid) else {
            tracing::debug!(pid, signum, "ssh: shell pid out of range; not forwarding");
            return;
        };
        if let Err(e) = nix::sys::signal::kill(nix::unistd::Pid::from_raw(raw_pid), signal) {
            tracing::debug!(error = %e, signum, "ssh: failed forwarding signal to shell");
        }
        // The guard in `child` is released only here, after the signal was sent.
    }

    /// Requests termination of the shell without waiting for it to exit.
    ///
    /// Failures (e.g. the child was already reaped) are logged at debug level.
    async fn kill_child(&self) {
        let mut child = self.child.lock().await;
        if let Err(e) = child.start_kill() {
            tracing::debug!(error = %e, "ssh: failed to kill shell child");
        }
    }
}
/// Maps an SSH `signal` channel request name to the host's signal number.
///
/// Covers the standard RFC 4254 signal names that have a dedicated `Sig`
/// variant. Any other name, such as `Sig::Custom`, yields `None` and is not
/// forwarded.
fn sig_to_signum(sig: &Sig) -> Option<i32> {
    let signum = match sig {
        Sig::ABRT => nix::libc::SIGABRT,
        Sig::ALRM => nix::libc::SIGALRM,
        Sig::FPE => nix::libc::SIGFPE,
        Sig::HUP => nix::libc::SIGHUP,
        Sig::ILL => nix::libc::SIGILL,
        Sig::INT => nix::libc::SIGINT,
        Sig::KILL => nix::libc::SIGKILL,
        Sig::PIPE => nix::libc::SIGPIPE,
        Sig::QUIT => nix::libc::SIGQUIT,
        Sig::SEGV => nix::libc::SIGSEGV,
        Sig::TERM => nix::libc::SIGTERM,
        Sig::USR1 => nix::libc::SIGUSR1,
        _ => return None,
    };
    Some(signum)
}
impl ChannelHandler for ShellHandler {
type Error = std::io::Error;
fn new(
rt: tokio::runtime::Handle,
channel_id: ChannelId,
session: Handle,
_dev: Arc<Device>,
accept: &SshAccept,
) -> Result<Self, Self::Error> {
let user = resolve_user(&accept.local_user)?;
let env = build_env(&user);
let (pty, pts) = pty_process::open().map_err(std::io::Error::other)?;
#[cfg(not(target_vendor = "apple"))]
let with_initgroups = true;
#[cfg(target_vendor = "apple")]
let with_initgroups = false;
let plan = priv_drop_plan(user.uid, user.gid, with_initgroups);
#[cfg(not(target_vendor = "apple"))]
let user_cname = std::ffi::CString::new(user.name.clone())
.map_err(|e| std::io::Error::other(format!("ssh: user name has NUL byte: {e}")))?;
let mut cmd = pty_process::Command::new(&user.shell);
cmd = cmd.arg(LOGIN_SHELL_ARG).current_dir(&user.home).env_clear();
for (k, v) in env {
cmd = cmd.env(k, v);
}
cmd = unsafe {
cmd.pre_exec(move || {
#[cfg(not(target_vendor = "apple"))]
let user_cname = Some(user_cname.as_c_str());
#[cfg(target_vendor = "apple")]
let user_cname: Option<&std::ffi::CStr> = None;
for step in &plan {
apply_priv_drop_step(step, user_cname)?;
}
Ok(())
})
};
let child = cmd.spawn(pts).map_err(std::io::Error::other)?;
let (mut pty_read, pty_write) = pty.into_split();
let child = Arc::new(Mutex::new(child));
let pump_child = child.clone();
rt.spawn(async move {
let mut buf = [0u8; 16 * 1024];
loop {
match pty_read.read(&mut buf).await {
Ok(0) => break,
Ok(n) => {
if session.data(channel_id, buf[..n].to_vec()).await.is_err() {
tracing::debug!(%channel_id, "ssh: client gone; stopping shell pump");
break;
}
}
Err(e) => {
tracing::debug!(error = %e, %channel_id, "ssh: pty read error");
break;
}
}
}
let status = { pump_child.lock().await.wait().await };
match status {
Ok(status) => {
use std::os::unix::process::ExitStatusExt as _;
let code = status
.code()
.unwrap_or_else(|| 128 + status.signal().unwrap_or(0))
as u32;
if session.exit_status_request(channel_id, code).await.is_err() {
tracing::debug!(%channel_id, "ssh: failed sending exit-status");
}
}
Err(e) => {
tracing::debug!(error = %e, %channel_id, "ssh: waiting on shell child");
}
}
if session.close(channel_id).await.is_err() {
tracing::trace!(%channel_id, "ssh: channel already closed");
}
});
Ok(Self {
channel_id,
pty_write,
child,
})
}
async fn handle_event(&mut self, event: &ChannelEvent) -> Result<(), Self::Error> {
match event {
ChannelEvent::Data(bytes) => {
self.pty_write.write_all(bytes).await?;
self.pty_write.flush().await?;
}
ChannelEvent::Resize { width, height } => {
if let Err(e) = self.pty_write.resize(Size::new(*height, *width)) {
tracing::debug!(error = %e, channel_id = %self.channel_id, "ssh: pty resize");
}
}
ChannelEvent::Signal(sig) => {
if let Some(signum) = sig_to_signum(sig) {
self.signal_child(signum).await;
} else {
tracing::debug!(?sig, "ssh: unhandled signal; not forwarding");
}
}
ChannelEvent::Close | ChannelEvent::Eof => {
tracing::debug!(channel_id = %self.channel_id, ?event, "ssh: closing shell");
self.kill_child().await;
}
}
Ok(())
}
}
#[cfg(all(test, feature = "ssh"))]
mod tests {
    use super::*;

    /// A synthetic account used by tests that never touch the passwd database.
    fn fake_user() -> ResolvedUser {
        ResolvedUser {
            name: "alice".to_string(),
            uid: Uid::from_raw(1000),
            gid: Gid::from_raw(1000),
            home: PathBuf::from("/home/alice"),
            shell: PathBuf::from("/bin/bash"),
        }
    }

    #[test]
    fn env_is_minimal_and_correct() {
        let env = build_env(&fake_user());
        let get = |k: &str| {
            env.iter()
                .find(|(key, _)| key == k)
                .map(|(_, v)| v.as_str())
        };
        assert_eq!(get("HOME"), Some("/home/alice"));
        assert_eq!(get("USER"), Some("alice"));
        assert_eq!(get("LOGNAME"), Some("alice"));
        assert_eq!(get("SHELL"), Some("/bin/bash"));
        assert_eq!(get("TERM"), Some("xterm-256color"));
        assert_eq!(get("PATH"), Some(DEFAULT_PATH));
        assert_eq!(env.len(), 6);
    }

    #[test]
    fn resolve_unknown_user_fails_closed() {
        let err = resolve_user("definitely-not-a-real-user-xyz")
            .expect_err("bogus user must fail closed");
        assert!(matches!(
            err.kind(),
            std::io::ErrorKind::NotFound | std::io::ErrorKind::Other
        ));
    }

    /// Exercises the real shell-fallback path through `resolve_user`, using
    /// whichever account runs the tests. Skipped when the current uid has no
    /// passwd entry (common in minimal containers).
    #[test]
    fn resolved_shell_is_never_empty() {
        let Ok(Some(me)) = User::from_uid(Uid::current()) else {
            return;
        };
        let resolved = resolve_user(&me.name).expect("current user must resolve by name");
        assert_eq!(resolved.uid, me.uid);
        assert_eq!(resolved.gid, me.gid);
        assert_eq!(resolved.home, me.dir);
        assert!(!resolved.shell.as_os_str().is_empty());
        if me.shell.as_os_str().is_empty() {
            assert_eq!(resolved.shell, PathBuf::from(DEFAULT_SHELL));
        } else {
            assert_eq!(resolved.shell, me.shell);
        }
    }

    #[test]
    fn login_shell_uses_dash_l() {
        assert_eq!(LOGIN_SHELL_ARG, "-l");
    }

    #[test]
    fn core_signals_are_forwarded() {
        assert_eq!(sig_to_signum(&Sig::HUP), Some(nix::libc::SIGHUP));
        assert_eq!(sig_to_signum(&Sig::INT), Some(nix::libc::SIGINT));
        assert_eq!(sig_to_signum(&Sig::QUIT), Some(nix::libc::SIGQUIT));
        assert_eq!(sig_to_signum(&Sig::KILL), Some(nix::libc::SIGKILL));
        assert_eq!(sig_to_signum(&Sig::TERM), Some(nix::libc::SIGTERM));
        assert_eq!(sig_to_signum(&Sig::Custom("NOT-A-SIGNAL".to_string())), None);
    }

    #[test]
    fn priv_drop_plan_orders_uid_last() {
        let uid = Uid::from_raw(1000);
        let gid = Gid::from_raw(1000);
        let plan = priv_drop_plan(uid, gid, true);
        assert_eq!(
            plan,
            vec![
                PrivDropStep::InitGroups(gid),
                PrivDropStep::SetGid(gid),
                PrivDropStep::SetUid(uid),
            ],
            "drop sequence must be initgroups → setgid → setuid"
        );
        assert_eq!(plan.last(), Some(&PrivDropStep::SetUid(uid)));
    }

    #[test]
    fn priv_drop_plan_apple_skips_initgroups() {
        let uid = Uid::from_raw(1000);
        let gid = Gid::from_raw(1000);
        let plan = priv_drop_plan(uid, gid, false);
        assert_eq!(
            plan,
            vec![PrivDropStep::SetGid(gid), PrivDropStep::SetUid(uid)],
        );
        assert!(!plan.contains(&PrivDropStep::InitGroups(gid)));
        assert_eq!(plan.last(), Some(&PrivDropStep::SetUid(uid)));
    }

    #[test]
    fn priv_drop_setgid_before_setuid() {
        let uid = Uid::from_raw(1000);
        let gid = Gid::from_raw(1000);
        for with_initgroups in [true, false] {
            let plan = priv_drop_plan(uid, gid, with_initgroups);
            let setgid_idx = plan
                .iter()
                .position(|s| *s == PrivDropStep::SetGid(gid))
                .expect("plan must set gid");
            let setuid_idx = plan
                .iter()
                .position(|s| *s == PrivDropStep::SetUid(uid))
                .expect("plan must set uid");
            assert!(
                setgid_idx < setuid_idx,
                "setgid must precede setuid (with_initgroups={with_initgroups})"
            );
        }
    }
}