use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr};
use std::path::{Path, PathBuf};
use std::process::Stdio;
use std::sync::Arc;
use std::time::Duration;
use anyhow::{Context, Result};
use arc_swap::ArcSwap;
use iroh::EndpointId;
use russh::CryptoVec;
use russh::keys::PrivateKey;
use russh::server::{Auth, Handle, Handler, Msg, Session};
use russh::{Channel, ChannelId, MethodKind, MethodSet};
use smol_str::SmolStr;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use tracing::{debug, info, warn};
use crate::peers::{DeviceUserMap, PeerTable};
/// Port clients use to reach SSH on a mesh address (the standard SSH port).
pub const SSH_PORT: u16 = 22;
/// Port the listeners actually bind. The startup log reports the server as
/// "reachable as :22", so traffic to [`SSH_PORT`] is presumably redirected
/// here elsewhere — confirm against the forwarding setup.
pub const SSH_LISTEN_PORT: u16 = 30022;
/// Shared, atomically replaceable SSH authorization rules keyed by network name.
pub type SshAuthz = Arc<ArcSwap<HashMap<String, Vec<crate::config::SshRule>>>>;
/// Returns an empty rule set; until rules are stored, no peer matches any rule
/// and every login attempt is rejected.
pub fn new_authz() -> SshAuthz {
    let no_rules: HashMap<String, Vec<crate::config::SshRule>> = HashMap::new();
    Arc::new(ArcSwap::new(Arc::new(no_rules)))
}
/// Effective SSH login policy for one peer, accumulated from every rule that
/// matched it.
#[derive(Default, Debug, PartialEq)]
struct UserPolicy {
    /// At least one rule matched, so the peer may authenticate at all.
    matched: bool,
    /// Some matching rule listed `"*"`: any local account, including uid 0.
    any: bool,
    /// Some matching rule had an empty user list: any account whose uid is not 0.
    nonroot: bool,
    /// Account names explicitly granted by matching rules.
    users: std::collections::HashSet<String>,
}
impl UserPolicy {
    /// Merges one matching rule's user list into the policy.
    ///
    /// A list containing `"*"` grants every account. An empty list grants
    /// every non-root account. Otherwise only the listed names are granted.
    fn add(&mut self, users: &[String]) {
        self.matched = true;
        let wildcard = users.iter().any(|u| u == "*");
        match (wildcard, users.is_empty()) {
            (true, _) => self.any = true,
            (false, true) => self.nonroot = true,
            (false, false) => self.users.extend(users.iter().cloned()),
        }
    }
    /// Whether any rule matched, i.e. the peer may attempt a login at all.
    fn authorized(&self) -> bool {
        self.matched
    }
    /// Whether logging in as account `name` (with uid `uid`) is allowed.
    fn permits(&self, name: &str, uid: u32) -> bool {
        if self.any || self.users.contains(name) {
            return true;
        }
        self.nonroot && uid != 0
    }
}
/// Builds the login policy for `user` from every rule, across all of the
/// peer's `networks`, whose `peer` is either `"*"` or the user's id string.
/// Networks without rules contribute nothing.
fn resolve_user_policy(authz: &SshAuthz, user: &EndpointId, networks: &[SmolStr]) -> UserPolicy {
    let rules_by_net = authz.load();
    let me = user.to_string();
    let mut policy = UserPolicy::default();
    networks
        .iter()
        .filter_map(|net| rules_by_net.get(net.as_str()))
        .flatten()
        .filter(|r| r.peer == "*" || r.peer == me)
        .for_each(|r| policy.add(&r.users));
    policy
}
/// Mesh-only SSH server: peers are identified by their mesh source address
/// and authorized by per-network rules rather than passwords or keys.
pub struct SshServer {
    /// Resolves a mesh source IP to the peer's endpoint id and networks.
    peers: PeerTable,
    /// Resolves a device endpoint id to the user identity it acts for.
    device_user_map: DeviceUserMap,
    /// Per-network SSH authorization rules (replaceable at runtime).
    authz: SshAuthz,
}
impl SshServer {
    /// Creates a server over the shared peer table, device-to-user map and
    /// authorization rules.
    pub fn new(peers: PeerTable, device_user_map: DeviceUserMap, authz: SshAuthz) -> Self {
        Self {
            peers,
            device_user_map,
            authz,
        }
    }
    /// Starts the server in the background.
    ///
    /// Loads (or creates) the host key, then binds one listener per address
    /// in `addrs` on [`SSH_LISTEN_PORT`]. If the host key cannot be loaded,
    /// SSH is disabled. Addresses that fail to bind are logged and skipped.
    /// Each listener stops accepting when `token` is cancelled.
    pub fn spawn(self, addrs: Vec<IpAddr>, token: CancellationToken) {
        // Pause after a failed accept(). Errors such as EMFILE/ENFILE are
        // returned immediately on every retry, so without a pause the accept
        // loop would spin at full CPU and flood the log.
        const ACCEPT_RETRY_DELAY: Duration = Duration::from_millis(100);
        tokio::spawn(async move {
            let key = match load_or_generate_host_key() {
                Ok(k) => k,
                Err(e) => {
                    warn!(error = %e, "mesh SSH: could not load host key; SSH disabled");
                    return;
                }
            };
            let config = Arc::new(russh::server::Config {
                keys: vec![key],
                // Only the "none" method is offered; authorization is decided
                // from the peer's mesh identity in the handler's `auth_none`.
                methods: MethodSet::from(&[MethodKind::None][..]),
                inactivity_timeout: Some(Duration::from_secs(3600)),
                auth_rejection_time: Duration::from_secs(1),
                ..Default::default()
            });
            for addr in addrs {
                let listener = match bind_listener(addr, SSH_LISTEN_PORT) {
                    Ok(l) => l,
                    Err(e) => {
                        warn!(%addr, port = SSH_LISTEN_PORT, error = %e, "mesh SSH: cannot bind listener; skipping");
                        continue;
                    }
                };
                info!(%addr, port = SSH_LISTEN_PORT, "mesh SSH listening (reachable as :22)");
                let peers = self.peers.clone();
                let dum = self.device_user_map.clone();
                let authz = self.authz.clone();
                let config = config.clone();
                let token = token.clone();
                tokio::spawn(async move {
                    loop {
                        tokio::select! {
                            _ = token.cancelled() => break,
                            accepted = listener.accept() => {
                                let (stream, peer) = match accepted {
                                    Ok(p) => p,
                                    Err(e) => {
                                        debug!(error = %e, "mesh SSH accept failed");
                                        tokio::time::sleep(ACCEPT_RETRY_DELAY).await;
                                        continue;
                                    }
                                };
                                let config = config.clone();
                                let peers = peers.clone();
                                let dum = dum.clone();
                                let authz = authz.clone();
                                tokio::spawn(async move {
                                    handle_conn(stream, peer, config, peers, dum, authz).await;
                                });
                            }
                        }
                    }
                    debug!(%addr, "mesh SSH listener stopped");
                });
            }
        });
    }
}
/// Binds a TCP listener on `ip:port` (backlog 128) and hands it to tokio.
///
/// `SO_REUSEADDR` is always set, and `SO_REUSEPORT` is also set on Unix.
///
/// # Errors
/// Returns any socket creation, option, bind or listen failure.
fn bind_listener(ip: IpAddr, port: u16) -> Result<tokio::net::TcpListener> {
    use socket2::{Domain, Protocol, Socket, Type};
    let addr = SocketAddr::new(ip, port);
    let socket = Socket::new(Domain::for_address(addr), Type::STREAM, Some(Protocol::TCP))?;
    socket.set_reuse_address(true)?;
    #[cfg(unix)]
    socket.set_reuse_port(true)?;
    // tokio's `from_std` expects the listener to already be non-blocking.
    socket.set_nonblocking(true)?;
    socket.bind(&addr.into())?;
    socket.listen(128)?;
    let listener = tokio::net::TcpListener::from_std(std::net::TcpListener::from(socket))?;
    Ok(listener)
}
async fn handle_conn(
stream: tokio::net::TcpStream,
peer: SocketAddr,
config: Arc<russh::server::Config>,
peers: PeerTable,
device_user_map: DeviceUserMap,
authz: SshAuthz,
) {
let src = peer.ip();
let Some((peer_id, networks)) = peers.identity_and_networks(src) else {
debug!(%src, "mesh SSH: connection from unknown mesh IP, dropping");
return;
};
let user_identity = device_user_map.resolve(&peer_id);
let policy = resolve_user_policy(&authz, &user_identity, &networks);
debug!(%src, peer = %user_identity.fmt_short(), authorized = policy.authorized(), "mesh SSH connection");
let handler = SshHandler::new(policy, user_identity);
match russh::server::run_stream(config, stream, handler).await {
Ok(session) => {
let _ = session.await;
}
Err(e) => debug!(error = %e, "mesh SSH session ended with error"),
}
}
/// Terminal parameters from a client's `pty-req`, applied when the
/// session's shell or command is started.
struct PtyReq {
    /// Initial terminal width in character cells.
    col: u16,
    /// Initial terminal height in character cells.
    row: u16,
    /// Value exported to the child as `TERM`.
    term: String,
}
struct SshHandler {
policy: UserPolicy,
user: EndpointId,
login_user: String,
login: Option<LoginInfo>,
pty: Option<PtyReq>,
channel: Option<Channel<Msg>>,
resize_tx: Option<mpsc::UnboundedSender<pty_process::Size>>,
}
impl SshHandler {
fn new(policy: UserPolicy, user: EndpointId) -> Self {
Self {
policy,
user,
login_user: String::new(),
login: None,
pty: None,
channel: None,
resize_tx: None,
}
}
fn start(&mut self, command: Option<String>, session: &mut Session) {
let Some(channel) = self.channel.take() else {
return;
};
let Some(info) = self.login.take() else {
return;
};
let channel_id = channel.id();
let handle = session.handle();
let login_name = info.name.clone();
let pty = self.pty.take();
let peer = self.user;
let (resize_tx, resize_rx) = mpsc::unbounded_channel();
self.resize_tx = Some(resize_tx);
tokio::spawn(async move {
let result = match pty {
Some(pty_req) => run_pty_session(channel, info, command, pty_req, resize_rx).await,
None => run_pipe_session(channel, handle.clone(), channel_id, info, command).await,
};
let code = match result {
Ok(c) => c,
Err(e) => {
warn!(peer = %peer.fmt_short(), user = %login_name, error = %e, "mesh SSH session failed");
1
}
};
let _ = handle.exit_status_request(channel_id, code).await;
let _ = handle.eof(channel_id).await;
let _ = handle.close(channel_id).await;
});
}
}
impl Handler for SshHandler {
type Error = russh::Error;
async fn auth_none(&mut self, user: &str) -> Result<Auth, Self::Error> {
self.login_user = user.to_string();
if !self.policy.authorized() {
info!(peer = %self.user.fmt_short(), "mesh SSH: rejecting unauthorized peer");
return Ok(Auth::reject());
}
match resolve_login(user) {
Ok(info) if self.policy.permits(user, info.uid) => {
self.login = Some(info);
Ok(Auth::Accept)
}
Ok(info) => {
info!(peer = %self.user.fmt_short(), user, uid = info.uid,
"mesh SSH: peer not permitted to log in as this user");
Ok(Auth::reject())
}
Err(e) => {
debug!(peer = %self.user.fmt_short(), user, error = %e,
"mesh SSH: requested login user not found");
Ok(Auth::reject())
}
}
}
async fn channel_open_session(
&mut self,
channel: Channel<Msg>,
_session: &mut Session,
) -> Result<bool, Self::Error> {
self.channel = Some(channel);
Ok(true)
}
#[allow(clippy::too_many_arguments)]
async fn pty_request(
&mut self,
channel: ChannelId,
term: &str,
col_width: u32,
row_height: u32,
_pix_width: u32,
_pix_height: u32,
_modes: &[(russh::Pty, u32)],
session: &mut Session,
) -> Result<(), Self::Error> {
self.pty = Some(PtyReq {
term: term.to_string(),
col: col_width as u16,
row: row_height as u16,
});
session.channel_success(channel)?;
Ok(())
}
async fn shell_request(
&mut self,
channel: ChannelId,
session: &mut Session,
) -> Result<(), Self::Error> {
self.start(None, session);
session.channel_success(channel)?;
Ok(())
}
async fn exec_request(
&mut self,
channel: ChannelId,
data: &[u8],
session: &mut Session,
) -> Result<(), Self::Error> {
let cmd = String::from_utf8_lossy(data).to_string();
self.start(Some(cmd), session);
session.channel_success(channel)?;
Ok(())
}
async fn window_change_request(
&mut self,
channel: ChannelId,
col_width: u32,
row_height: u32,
_pix_width: u32,
_pix_height: u32,
session: &mut Session,
) -> Result<(), Self::Error> {
if let Some(tx) = &self.resize_tx {
let _ = tx.send(pty_process::Size::new(row_height as u16, col_width as u16));
}
session.channel_success(channel)?;
Ok(())
}
}
/// Local account a session runs as, looked up by name via `uzers`.
struct LoginInfo {
    /// Account name; passed to `initgroups` and exported as `USER`/`LOGNAME`.
    name: String,
    /// Numeric user id the child switches to.
    uid: u32,
    /// Primary group id the child switches to.
    gid: u32,
    /// Home directory: the child's working directory and `HOME`.
    home: PathBuf,
    /// Login shell used to run the session (and exported as `SHELL`).
    shell: PathBuf,
}
/// Looks up local account `login_user` and returns its ids, home directory,
/// shell and canonical name.
///
/// # Errors
/// Fails with "no such local user: …" when the account does not exist.
fn resolve_login(login_user: &str) -> Result<LoginInfo> {
    use uzers::os::unix::UserExt;
    let account = uzers::get_user_by_name(login_user)
        .with_context(|| format!("no such local user: {login_user}"))?;
    let name = account.name().to_string_lossy().into_owned();
    Ok(LoginInfo {
        name,
        home: account.home_dir().to_owned(),
        shell: account.shell().to_owned(),
        gid: account.primary_group_id(),
        uid: account.uid(),
    })
}
/// Builds the `pre_exec` hook that switches a spawned child to the target
/// account.
///
/// The returned closure calls `initgroups(name, gid)`, `setgid(gid)` and
/// `setuid(uid)` in that order. It stops at the first failure and returns
/// that call's OS error. `setuid` comes last because, once the process has
/// given up root, it can typically no longer change its groups or gid.
///
/// # Errors
/// Fails immediately if `name` contains an interior NUL byte, because the
/// name must be passed to `initgroups` as a C string.
///
/// NOTE(review): the closure runs in the forked child before `exec`. The std
/// `CommandExt::pre_exec` docs warn that allocation and mutexes are not
/// guaranteed to work there when the parent is multi-threaded (as a
/// multi-threaded tokio runtime is). `initgroups` reads the group database
/// and presumably allocates. Consider resolving the group list before
/// forking and calling only `setgroups` in the hook. TODO confirm.
fn drop_privs(
    uid: u32,
    gid: u32,
    name: &str,
) -> Result<impl FnMut() -> std::io::Result<()> + Send + Sync + 'static> {
    // Built outside the hook, so the name conversion does not happen in the
    // forked child; the closure takes ownership of it.
    let cname = std::ffi::CString::new(name).context("user name contains NUL")?;
    Ok(move || {
        unsafe {
            // `initgroups` takes its base-group argument as a `c_int` on
            // macOS and as a `gid_t` elsewhere.
            #[cfg(target_os = "macos")]
            let basegroup = gid as libc::c_int;
            #[cfg(not(target_os = "macos"))]
            let basegroup = gid as libc::gid_t;
            if libc::initgroups(cname.as_ptr(), basegroup) != 0 {
                return Err(std::io::Error::last_os_error());
            }
            if libc::setgid(gid as libc::gid_t) != 0 {
                return Err(std::io::Error::last_os_error());
            }
            if libc::setuid(uid as libc::uid_t) != 0 {
                return Err(std::io::Error::last_os_error());
            }
        }
        Ok(())
    })
}
/// Environment for a login session: `HOME`, `USER`, `LOGNAME`, `SHELL` and a
/// fixed system `PATH`, in that order. Callers clear the inherited
/// environment before applying it.
fn login_env<'a>(home: &Path, shell: &Path, name: &str) -> [(&'a str, std::ffi::OsString); 5] {
    use std::ffi::OsString;
    const DEFAULT_PATH: &str = "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin";
    let account = OsString::from(name);
    [
        ("HOME", home.as_os_str().to_owned()),
        ("USER", account.clone()),
        ("LOGNAME", account),
        ("SHELL", shell.as_os_str().to_owned()),
        ("PATH", OsString::from(DEFAULT_PATH)),
    ]
}
async fn run_pty_session(
channel: Channel<Msg>,
info: LoginInfo,
command: Option<String>,
pty_req: PtyReq,
mut resize_rx: mpsc::UnboundedReceiver<pty_process::Size>,
) -> Result<u32> {
let drop = drop_privs(info.uid, info.gid, &info.name)?;
let (pty, pts) = pty_process::open().context("opening pty")?;
let _ = pty.resize(pty_process::Size::new(pty_req.row, pty_req.col));
let mut cmd = pty_process::Command::new(&info.shell);
match &command {
Some(c) => cmd = cmd.arg("-c").arg(c),
None => cmd = cmd.arg("-l"),
}
cmd = cmd
.current_dir(&info.home)
.env_clear()
.envs(login_env(&info.home, &info.shell, &info.name))
.env("TERM", &pty_req.term);
cmd = unsafe { cmd.pre_exec(drop) };
let mut child = cmd.spawn(pts).context("spawning login shell")?;
let stream = channel.into_stream();
let (mut chan_read, mut chan_write) = tokio::io::split(stream);
let (mut pty_read, mut pty_write) = pty.into_split();
let c2p = tokio::spawn(async move {
let mut buf = [0u8; 8192];
loop {
tokio::select! {
r = chan_read.read(&mut buf) => match r {
Ok(0) | Err(_) => break,
Ok(n) => {
if pty_write.write_all(&buf[..n]).await.is_err() {
break;
}
}
},
Some(size) = resize_rx.recv() => {
let _ = pty_write.resize(size);
}
}
}
});
let p2c = tokio::spawn(async move {
let _ = tokio::io::copy(&mut pty_read, &mut chan_write).await;
let _ = chan_write.shutdown().await;
});
let status = child.wait().await.context("waiting on child")?;
let _ = p2c.await;
c2p.abort();
Ok(status.code().unwrap_or(0) as u32)
}
async fn run_pipe_session(
channel: Channel<Msg>,
handle: Handle,
channel_id: ChannelId,
info: LoginInfo,
command: Option<String>,
) -> Result<u32> {
let drop = drop_privs(info.uid, info.gid, &info.name)?;
let mut cmd = tokio::process::Command::new(&info.shell);
match &command {
Some(c) => {
cmd.arg("-c").arg(c);
}
None => {
cmd.arg("-l");
}
}
cmd.current_dir(&info.home)
.env_clear()
.envs(login_env(&info.home, &info.shell, &info.name))
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped());
unsafe {
cmd.pre_exec(drop);
}
let mut child = cmd.spawn().context("spawning command")?;
let mut stdin = child.stdin.take().context("child stdin")?;
let mut stdout = child.stdout.take().context("child stdout")?;
let mut stderr = child.stderr.take().context("child stderr")?;
let stream = channel.into_stream();
let (mut chan_read, _chan_write) = tokio::io::split(stream);
let stdin_task = tokio::spawn(async move {
let _ = tokio::io::copy(&mut chan_read, &mut stdin).await;
});
let h_out = handle.clone();
let out_task = tokio::spawn(async move {
let mut buf = [0u8; 8192];
loop {
match stdout.read(&mut buf).await {
Ok(0) | Err(_) => break,
Ok(n) => {
if h_out
.data(channel_id, CryptoVec::from(&buf[..n]))
.await
.is_err()
{
break;
}
}
}
}
});
let h_err = handle.clone();
let err_task = tokio::spawn(async move {
let mut buf = [0u8; 8192];
loop {
match stderr.read(&mut buf).await {
Ok(0) | Err(_) => break,
Ok(n) => {
if h_err
.extended_data(channel_id, 1, CryptoVec::from(&buf[..n]))
.await
.is_err()
{
break;
}
}
}
}
});
let status = child.wait().await.context("waiting on child")?;
let _ = out_task.await;
let _ = err_task.await;
stdin_task.abort();
Ok(status.code().unwrap_or(0) as u32)
}
/// Returns the server's host key from `<config dir>/ssh_host_key`.
///
/// If the file is missing, a new Ed25519 key is generated, written there in
/// OpenSSH format, and returned.
///
/// # Errors
/// Fails if the config directory cannot be determined, or if the key cannot
/// be read, parsed, generated, encoded or written.
fn load_or_generate_host_key() -> Result<PrivateKey> {
    use russh::keys::ssh_key::{LineEnding, rand_core::OsRng};
    let key_path = crate::config::config_dir()?.join("ssh_host_key");
    if !key_path.exists() {
        let fresh = PrivateKey::random(&mut OsRng, russh::keys::Algorithm::Ed25519)
            .context("generating ssh host key")?;
        let encoded = fresh
            .to_openssh(LineEnding::LF)
            .context("encoding ssh host key")?;
        crate::config::write_file(&key_path, encoded.as_bytes(), true)?;
        return Ok(fresh);
    }
    let pem = std::fs::read_to_string(&key_path).context("reading ssh host key")?;
    PrivateKey::from_openssh(&pem).context("parsing ssh host key")
}
#[cfg(test)]
mod tests {
    use super::*;
    /// Deterministic test identity: the public key of a secret key whose
    /// first byte is `seed` and whose remaining bytes are zero.
    fn id(seed: u8) -> EndpointId {
        let mut secret = [0u8; 32];
        secret[0] = seed;
        iroh::SecretKey::from(secret).public()
    }
    /// Builds a rule granting `peer` the listed `users`.
    fn rule(peer: &str, users: &[&str]) -> crate::config::SshRule {
        let users = users.iter().map(ToString::to_string).collect();
        crate::config::SshRule {
            peer: peer.to_owned(),
            users,
        }
    }
    /// Network names in the `SmolStr` form the resolver takes.
    fn nets(names: &[&str]) -> Vec<SmolStr> {
        names.iter().map(SmolStr::new).collect()
    }
    /// Replaces the current rule set with `entries` (network name -> rules).
    fn install(authz: &SshAuthz, entries: Vec<(&str, Vec<crate::config::SshRule>)>) {
        let map = entries
            .into_iter()
            .map(|(net, rules)| (net.to_string(), rules))
            .collect::<HashMap<_, _>>();
        authz.store(Arc::new(map));
    }
    #[test]
    fn authz_matches_identity_and_wildcard_per_network() {
        let (alice, bob) = (id(1), id(2));
        let authz = new_authz();
        install(
            &authz,
            vec![
                ("net1", vec![rule(&alice.to_string(), &[])]),
                ("net2", vec![rule("*", &[])]),
            ],
        );
        let authorized = |who: &EndpointId, names: &[&str]| {
            resolve_user_policy(&authz, who, &nets(names)).authorized()
        };
        // An identity rule applies only to the named peer.
        assert!(authorized(&alice, &["net1"]));
        assert!(!authorized(&bob, &["net1"]));
        // A wildcard rule applies to every peer on that network.
        assert!(authorized(&bob, &["net2"]));
        // A network without rules grants nothing but does not mask others.
        assert!(!authorized(&alice, &["net3"]));
        assert!(authorized(&alice, &["net3", "net2"]));
    }
    #[test]
    fn user_policy_default_is_nonroot() {
        let alice = id(1);
        let authz = new_authz();
        install(&authz, vec![("net", vec![rule(&alice.to_string(), &[])])]);
        let policy = resolve_user_policy(&authz, &alice, &nets(&["net"]));
        assert!(policy.permits("deploy", 1000), "non-root user allowed");
        assert!(!policy.permits("root", 0), "root (uid 0) blocked by default");
        assert!(!policy.permits("toor", 0), "any uid-0 account blocked, not just 'root'");
    }
    #[test]
    fn user_policy_explicit_and_wildcard() {
        let alice = id(1);
        let authz = new_authz();
        install(
            &authz,
            vec![
                ("net1", vec![rule(&alice.to_string(), &["deploy"])]),
                ("net2", vec![rule(&alice.to_string(), &["*"])]),
            ],
        );
        let for_nets = |names: &[&str]| resolve_user_policy(&authz, &alice, &nets(names));
        let explicit = for_nets(&["net1"]);
        assert!(explicit.permits("deploy", 1000));
        assert!(!explicit.permits("ci", 1001));
        assert!(!explicit.permits("root", 0));
        let wildcard = for_nets(&["net2"]);
        assert!(wildcard.permits("root", 0));
        let union = for_nets(&["net1", "net2"]);
        assert!(union.permits("root", 0));
        assert!(union.permits("anyone", 1234));
    }
}