use std::io::{ErrorKind, Read, Write};
use std::net::{TcpListener, TcpStream};
use std::path::PathBuf;
use std::process::ExitCode;
use std::sync::{Arc, Mutex};
use std::thread;
use puressh::auth::ClientCredential;
use puressh::client::{
ChannelStream, Client, ClientHandlers, Config, ForwardedTcpipCallback, ForwardedTcpipOrigin,
ServeContext,
};
#[path = "common.rs"]
mod common;
use common::{
build_host_key_policy, connect_agent_credentials, load_identity, read_password_from_stdin,
resolve_user, StrictMode,
};
const VERSION: &str = env!("CARGO_PKG_VERSION");
const USAGE: &str = "usage: ssh [-p port] [-i identity_file] [-l user] \
[-o StrictHostKeyChecking={yes,no,accept-new,ask}] \
[-o UserKnownHostsFile=PATH] [-o HashKnownHosts={yes,no}] \
[-o IdentitiesOnly={yes,no}] \
[-L LPORT:RHOST:RPORT] [-R RPORT:LHOST:LPORT] \
[-N] [-A] [-X] [-Y] \
[user@]host [command...]";
#[derive(Clone, Debug)]
struct LocalForward {
listen_port: u16,
remote_host: String,
remote_port: u16,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum X11Forward {
Untrusted,
Trusted,
}
#[derive(Clone, Debug)]
struct RemoteForward {
remote_port: u16,
local_host: String,
local_port: u16,
}
struct Cli {
port: u16,
identities: Vec<String>,
cli_user: Option<String>,
strict: StrictMode,
known_hosts_path: Option<PathBuf>,
hash_known_hosts: bool,
identities_only: bool,
locals: Vec<LocalForward>,
remotes: Vec<RemoteForward>,
no_command: bool,
agent_forward: bool,
x11_forward: Option<X11Forward>,
host: String,
user_in_host: Option<String>,
command: Option<String>,
}
fn parse_local_forward(s: &str) -> Result<LocalForward, String> {
let parts: Vec<&str> = s.splitn(3, ':').collect();
if parts.len() != 3 {
return Err(format!("-L expects LPORT:RHOST:RPORT, got {s:?}"));
}
let listen_port: u16 = parts[0]
.parse()
.map_err(|_| format!("-L: invalid LPORT {:?}", parts[0]))?;
let remote_host = parts[1].to_string();
if remote_host.is_empty() {
return Err("-L: RHOST cannot be empty".into());
}
let remote_port: u16 = parts[2]
.parse()
.map_err(|_| format!("-L: invalid RPORT {:?}", parts[2]))?;
Ok(LocalForward {
listen_port,
remote_host,
remote_port,
})
}
fn parse_remote_forward(s: &str) -> Result<RemoteForward, String> {
let parts: Vec<&str> = s.splitn(3, ':').collect();
if parts.len() != 3 {
return Err(format!("-R expects RPORT:LHOST:LPORT, got {s:?}"));
}
let remote_port: u16 = parts[0]
.parse()
.map_err(|_| format!("-R: invalid RPORT {:?}", parts[0]))?;
let local_host = parts[1].to_string();
if local_host.is_empty() {
return Err("-R: LHOST cannot be empty".into());
}
let local_port: u16 = parts[2]
.parse()
.map_err(|_| format!("-R: invalid LPORT {:?}", parts[2]))?;
Ok(RemoteForward {
remote_port,
local_host,
local_port,
})
}
fn parse_args(args: &[String]) -> Result<Cli, String> {
let mut port = 22u16;
let mut identities: Vec<String> = Vec::new();
let mut cli_user: Option<String> = None;
let mut strict = StrictMode::Ask;
let mut known_hosts_path: Option<PathBuf> = None;
let mut hash_known_hosts = false;
let mut identities_only = false;
let mut locals: Vec<LocalForward> = Vec::new();
let mut remotes: Vec<RemoteForward> = Vec::new();
let mut no_command = false;
let mut agent_forward = false;
let mut x11_forward: Option<X11Forward> = None;
let mut positional: Vec<String> = Vec::new();
let mut i = 0;
while i < args.len() {
let a = &args[i];
if a == "--" {
positional.extend_from_slice(&args[i + 1..]);
break;
}
match a.as_str() {
"-p" => {
i += 1;
let v = args.get(i).ok_or("-p requires a value")?;
port = v.parse::<u16>().map_err(|_| "invalid port".to_string())?;
}
"-i" => {
i += 1;
let v = args.get(i).ok_or("-i requires a value")?.clone();
identities.push(v);
}
"-l" => {
i += 1;
let v = args.get(i).ok_or("-l requires a value")?.clone();
cli_user = Some(v);
}
"-L" => {
i += 1;
let v = args.get(i).ok_or("-L requires a value")?;
locals.push(parse_local_forward(v)?);
}
"-R" => {
i += 1;
let v = args.get(i).ok_or("-R requires a value")?;
remotes.push(parse_remote_forward(v)?);
}
"-N" => {
no_command = true;
}
"-A" => {
agent_forward = true;
}
"-X" => {
x11_forward = Some(X11Forward::Untrusted);
}
"-Y" => {
x11_forward = Some(X11Forward::Trusted);
}
"-o" => {
i += 1;
let v = args.get(i).ok_or("-o requires a value")?;
let (k, val) = v
.split_once('=')
.ok_or_else(|| format!("-o expects KEY=VALUE, got {v:?}"))?;
match k.to_ascii_lowercase().as_str() {
"stricthostkeychecking" => {
strict = match val.to_ascii_lowercase().as_str() {
"yes" => StrictMode::Yes,
"no" | "off" => StrictMode::No,
"accept-new" => StrictMode::AcceptNew,
"ask" => StrictMode::Ask,
other => return Err(format!("unknown StrictHostKeyChecking={other}")),
};
}
"userknownhostsfile" => {
known_hosts_path = Some(PathBuf::from(val));
}
"hashknownhosts" => {
hash_known_hosts =
matches!(val.to_ascii_lowercase().as_str(), "yes" | "on");
}
"identitiesonly" => {
identities_only = matches!(val.to_ascii_lowercase().as_str(), "yes" | "on");
}
other => {
return Err(format!("unsupported -o option: {other}={val}"));
}
}
}
s if s.starts_with('-') => {
return Err(format!("unknown flag: {s}"));
}
_ => positional.push(a.clone()),
}
i += 1;
}
if positional.is_empty() {
return Err("missing host argument".into());
}
let target = positional.remove(0);
let (user_in_host, host) = match target.split_once('@') {
Some((u, h)) => (Some(u.to_string()), h.to_string()),
None => (None, target),
};
if host.is_empty() {
return Err("empty host".into());
}
let command = if positional.is_empty() {
None
} else {
Some(positional.join(" "))
};
Ok(Cli {
port,
identities,
cli_user,
strict,
known_hosts_path,
hash_known_hosts,
identities_only,
locals,
remotes,
no_command,
agent_forward,
x11_forward,
host,
user_in_host,
command,
})
}
fn run() -> Result<i32, String> {
let args: Vec<String> = std::env::args().skip(1).collect();
if args.iter().any(|a| a == "-h" || a == "--help") {
println!("{USAGE}");
println!();
println!("A pure-Rust SSH client built on puressh {VERSION}.");
return Ok(0);
}
if args.iter().any(|a| a == "-V" || a == "--version") {
println!("puressh ssh {VERSION}");
return Ok(0);
}
let cli = parse_args(&args).map_err(|e| format!("{e}\n{USAGE}"))?;
let user = resolve_user(cli.cli_user.as_deref(), cli.user_in_host.as_deref())?;
let policy = build_host_key_policy(
cli.strict,
cli.known_hosts_path.clone(),
cli.hash_known_hosts,
)?;
let cfg = Config {
host_key_policy: policy,
timeout: None,
};
let mut client = Client::connect_to_host(cli.host.as_str(), cli.port, cfg)
.map_err(|e| format!("connect: {e}"))?;
let mut credentials: Vec<ClientCredential> = Vec::new();
if !cli.identities_only {
match connect_agent_credentials() {
Ok(mut from_agent) => credentials.append(&mut from_agent),
Err(e) => eprintln!("warning: agent: {e}"),
}
}
for id_path in &cli.identities {
let pk = match load_identity(id_path) {
Ok(p) => p,
Err(e) => {
eprintln!("warning: {e}");
continue;
}
};
match pk.into_host_key() {
Ok(hk) => credentials.push(ClientCredential::PublicKey(hk)),
Err(e) => eprintln!("warning: identity {id_path}: {e}"),
}
}
let authed = if !credentials.is_empty() {
match client.authenticate(&user, credentials) {
Ok(()) => true,
Err(e) => {
eprintln!("publickey auth: {e}");
false
}
}
} else {
false
};
if !authed {
let password = read_password_from_stdin().map_err(|e| format!("read password: {e}"))?;
client
.authenticate_password(&user, &password)
.map_err(|e| format!("Auth failed: {e}"))?;
}
let want_forwarding = cli.no_command
|| !cli.remotes.is_empty()
|| !cli.locals.is_empty()
|| cli.agent_forward
|| cli.x11_forward.is_some();
if want_forwarding {
if cli.command.is_some() {
return Err(
"running a command alongside -A/-L/-R/-N/-X/-Y is not yet supported; \
invoke ssh twice or wire the forward without a command"
.into(),
);
}
if cli.no_command
&& cli.remotes.is_empty()
&& cli.locals.is_empty()
&& !cli.agent_forward
&& cli.x11_forward.is_none()
{
return Err("-N requires at least one of -A, -L, -R, -X, -Y".into());
}
return run_forwarding(client, &cli);
}
let command = cli
.command
.ok_or_else(|| "interactive shell not yet implemented".to_string())?;
let out = client.exec(&command).map_err(|e| format!("exec: {e}"))?;
let _ = std::io::stdout().write_all(&out.stdout);
let _ = std::io::stderr().write_all(&out.stderr);
Ok(out.exit_status.map(|s| s as i32).unwrap_or(255))
}
fn spawn_splice_to_tcp(stream: ChannelStream, tcp: TcpStream) {
use puressh::client::ChannelEgress;
let (chan_rx, chan_tx) = stream.into_raw();
let tcp_in = match tcp.try_clone() {
Ok(c) => c,
Err(_) => {
let _ = chan_tx.send(ChannelEgress::Eof);
let _ = chan_tx.send(ChannelEgress::Close);
return;
}
};
let tcp_out = tcp;
let chan_tx_a = chan_tx.clone();
let mut tcp_in_a = tcp_in;
let a = thread::spawn(move || {
let mut buf = [0u8; 32 * 1024];
loop {
match tcp_in_a.read(&mut buf) {
Ok(0) => break,
Ok(n) => {
if chan_tx_a
.send(ChannelEgress::Data(buf[..n].to_vec()))
.is_err()
{
break;
}
}
Err(e) if e.kind() == ErrorKind::Interrupted => continue,
Err(_) => break,
}
}
let _ = chan_tx_a.send(ChannelEgress::Eof);
});
let mut tcp_out_b = tcp_out;
let b = thread::spawn(move || {
while let Ok(Some(chunk)) = chan_rx.recv() {
if tcp_out_b.write_all(&chunk).is_err() {
break;
}
}
let _ = tcp_out_b.shutdown(std::net::Shutdown::Read);
});
thread::spawn(move || {
let _ = a.join();
let _ = b.join();
let _ = chan_tx.send(ChannelEgress::Close);
});
}
fn run_forwarding(mut client: Client, cli: &Cli) -> Result<i32, String> {
let mut routes: std::collections::BTreeMap<(String, u16), (String, u16)> =
std::collections::BTreeMap::new();
for r in &cli.remotes {
let bound_port = client
.request_tcpip_forward("127.0.0.1", r.remote_port)
.map_err(|e| format!("tcpip-forward 127.0.0.1:{}: {e}", r.remote_port))?;
eprintln!(
"ssh: -R 127.0.0.1:{}:{}:{} active",
bound_port, r.local_host, r.local_port,
);
routes.insert(
("127.0.0.1".to_string(), bound_port),
(r.local_host.clone(), r.local_port),
);
}
let routes = Arc::new(Mutex::new(routes));
let routes_for_cb = Arc::clone(&routes);
let cb: Arc<ForwardedTcpipCallback> =
Arc::new(move |origin: ForwardedTcpipOrigin, stream: ChannelStream| {
let target = {
let map = match routes_for_cb.lock() {
Ok(g) => g,
Err(_) => return,
};
map.get(&(origin.bound_address.clone(), origin.bound_port))
.cloned()
};
let (local_host, local_port) = match target {
Some(t) => t,
None => {
eprintln!(
"ssh: forwarded-tcpip for unknown binding {}:{}; dropping",
origin.bound_address, origin.bound_port
);
return;
}
};
match TcpStream::connect((local_host.as_str(), local_port)) {
Ok(tcp) => spawn_splice_to_tcp(stream, tcp),
Err(e) => eprintln!(
"ssh: dial {}:{} for forwarded-tcpip from {}:{}: {e}",
local_host, local_port, origin.orig_address, origin.orig_port
),
}
});
let mut handlers = ClientHandlers::new().with_forwarded_tcpip(cb);
let agent_fwd_channel: Option<u32> = if cli.agent_forward {
#[cfg(unix)]
{
use puressh::forwarding::agent::splice_to_local_agent_callback;
let cb = splice_to_local_agent_callback().ok_or_else(|| {
"-A: $SSH_AUTH_SOCK is unset or names a socket that doesn't exist".to_string()
})?;
handlers = handlers.with_auth_agent(cb);
let id = client
.open_session_for_agent_forward()
.map_err(|e| format!("agent-forward session: {e}"))?;
eprintln!("ssh: -A agent forwarding requested");
Some(id)
}
#[cfg(not(unix))]
{
return Err("-A agent forwarding is not supported on this platform".to_string());
}
} else {
None
};
let x11_fwd_channel: Option<u32> = if let Some(mode) = cli.x11_forward {
#[cfg(not(unix))]
{
let _ = mode;
return Err("-X/-Y X11 forwarding is not supported on this platform".to_string());
}
#[cfg(unix)]
{
use puressh::forwarding::x11::splice_to_local_display_callback;
let cb = splice_to_local_display_callback().ok_or_else(|| {
"-X/-Y: $DISPLAY is unset or names a display we don't know how to dial".to_string()
})?;
handlers = handlers.with_x11(cb);
let cookie = mint_x11_cookie();
let id = client
.open_session_for_x11_forward(false, "MIT-MAGIC-COOKIE-1", &cookie, 0)
.map_err(|e| format!("x11-forward session: {e}"))?;
eprintln!(
"ssh: -{} X11 forwarding requested (cookie={} chars)",
if mode == X11Forward::Trusted {
"Y"
} else {
"X"
},
cookie.len(),
);
Some(id)
}
} else {
None
};
let ctx_opt: Option<ServeContext> = if cli.locals.is_empty() {
None
} else {
let (h, ctx) = handlers.with_serve_context();
handlers = h;
for l in &cli.locals {
let listener = TcpListener::bind(("127.0.0.1", l.listen_port))
.map_err(|e| format!("-L bind 127.0.0.1:{}: {e}", l.listen_port))?;
eprintln!(
"ssh: -L 127.0.0.1:{}:{}:{} active",
l.listen_port, l.remote_host, l.remote_port,
);
spawn_local_forward_listener(listener, l.clone(), ctx.clone());
}
Some(ctx)
};
let result = match client.serve(handlers) {
Ok(()) => Ok(0),
Err(e) => Err(format!("serve: {e}")),
};
if let Some(id) = agent_fwd_channel {
let _ = client.close_session(id);
}
if let Some(id) = x11_fwd_channel {
let _ = client.close_session(id);
}
drop(ctx_opt);
result
}
fn spawn_local_forward_listener(listener: TcpListener, spec: LocalForward, ctx: ServeContext) {
thread::spawn(move || {
for accept in listener.incoming() {
let tcp = match accept {
Ok(s) => s,
Err(e) => {
eprintln!("ssh: -L accept on 127.0.0.1:{}: {e}", spec.listen_port);
continue;
}
};
let orig = tcp
.peer_addr()
.map(|a| (a.ip().to_string(), a.port()))
.unwrap_or_else(|_| ("127.0.0.1".to_string(), 0));
let stream =
match ctx.open_direct_tcpip(&spec.remote_host, spec.remote_port, &orig.0, orig.1) {
Ok(s) => s,
Err(e) => {
eprintln!(
"ssh: -L direct-tcpip {}:{}: {e}",
spec.remote_host, spec.remote_port
);
continue;
}
};
spawn_splice_to_tcp(stream, tcp);
}
});
}
#[cfg(unix)]
fn mint_x11_cookie() -> String {
let mut bytes = [0u8; 16];
if let Ok(mut f) = std::fs::File::open("/dev/urandom") {
let _ = std::io::Read::read_exact(&mut f, &mut bytes);
}
let pid = std::process::id();
let nanos = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.subsec_nanos())
.unwrap_or(0);
bytes[0] ^= (pid & 0xff) as u8;
bytes[1] ^= ((pid >> 8) & 0xff) as u8;
bytes[2] ^= (nanos & 0xff) as u8;
bytes[3] ^= ((nanos >> 8) & 0xff) as u8;
let mut s = String::with_capacity(32);
for b in bytes {
s.push_str(&format!("{b:02x}"));
}
s
}
fn main() -> ExitCode {
match run() {
Ok(code) => {
let clamped = code.clamp(0, 255) as u8;
ExitCode::from(clamped)
}
Err(msg) => {
eprintln!("ssh: {msg}");
ExitCode::from(255)
}
}
}