use std::ffi::OsString;
use std::io::{Read, Write};
use std::os::unix::net::UnixStream;
use std::path::{Path, PathBuf};
use std::time::Duration;
use crate::error::{Error, Result};
use super::protocol::{
decode_identities_answer, decode_sign_response, encode_request_identities, encode_sign_request,
IdentityEntry, MAX_REPLY_LEN, SSH_AGENT_FAILURE, SSH_AGENT_IDENTITIES_ANSWER,
SSH_AGENT_SIGN_RESPONSE,
};
/// Read and write timeout applied to the agent socket by [`Agent::connect`].
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10);
/// One identity (public key) reported by the agent.
#[derive(Debug, Clone)]
pub struct AgentIdentity {
    /// Public key blob exactly as returned by the agent.
    pub key_blob: Vec<u8>,
    /// Comment the agent associated with the key.
    pub comment: String,
}

impl AgentIdentity {
    /// Returns the algorithm name stored at the front of the key blob.
    ///
    /// The blob is expected to start with an SSH `string`: a big-endian `u32`
    /// length followed by that many name bytes (e.g. `ssh-ed25519`). Returns
    /// an empty string when the blob is too short to hold the declared name.
    /// Invalid UTF-8 is replaced lossily.
    pub fn algorithm(&self) -> String {
        self.key_blob
            .get(..4)
            .and_then(|hdr| {
                let len = u32::from_be_bytes([hdr[0], hdr[1], hdr[2], hdr[3]]) as usize;
                // Checked slicing instead of `4 + len`: on 32-bit targets an
                // agent-supplied length near u32::MAX would overflow the
                // addition and make the slice index panic.
                self.key_blob.get(4..)?.get(..len)
            })
            .map(|name| String::from_utf8_lossy(name).into_owned())
            .unwrap_or_default()
    }

    /// Returns the comment the agent attached to this key.
    pub fn comment(&self) -> &str {
        &self.comment
    }

    /// Returns the raw public key blob.
    pub fn key_blob(&self) -> &[u8] {
        &self.key_blob
    }
}
pub struct Agent {
stream: UnixStream,
}
impl Agent {
pub fn connect(path: impl AsRef<Path>) -> Result<Self> {
let stream = UnixStream::connect(path.as_ref()).map_err(Error::from)?;
stream
.set_read_timeout(Some(DEFAULT_TIMEOUT))
.map_err(Error::from)?;
stream
.set_write_timeout(Some(DEFAULT_TIMEOUT))
.map_err(Error::from)?;
Ok(Self { stream })
}
pub fn connect_env() -> Result<Option<Self>> {
let raw: OsString = match std::env::var_os("SSH_AUTH_SOCK") {
Some(v) if !v.is_empty() => v,
_ => return Ok(None),
};
let path = PathBuf::from(raw);
if let Err(why) = validate_auth_sock(&path) {
eprintln!(
"warning: SSH_AUTH_SOCK={} rejected: {why}; ignoring agent",
path.display()
);
return Ok(None);
}
Self::connect(path).map(Some)
}
pub fn identities(&mut self) -> Result<Vec<AgentIdentity>> {
self.write_frame(&encode_request_identities())?;
let (msg_type, body) = self.read_frame()?;
if msg_type == SSH_AGENT_FAILURE {
return Err(Error::Protocol("agent: failure on identities request"));
}
if msg_type != SSH_AGENT_IDENTITIES_ANSWER {
return Err(Error::Protocol("agent: unexpected identities-reply type"));
}
let raw = decode_identities_answer(&body)?;
Ok(raw
.into_iter()
.map(|IdentityEntry { key_blob, comment }| AgentIdentity { key_blob, comment })
.collect())
}
pub fn sign(&mut self, key_blob: &[u8], data: &[u8], flags: u32) -> Result<Vec<u8>> {
self.write_frame(&encode_sign_request(key_blob, data, flags))?;
let (msg_type, body) = self.read_frame()?;
if msg_type == SSH_AGENT_FAILURE {
return Err(Error::Protocol("agent: failure on sign request"));
}
if msg_type != SSH_AGENT_SIGN_RESPONSE {
return Err(Error::Protocol("agent: unexpected sign-reply type"));
}
decode_sign_response(&body)
}
fn write_frame(&mut self, frame: &[u8]) -> Result<()> {
self.stream.write_all(frame).map_err(Error::from)?;
Ok(())
}
fn read_frame(&mut self) -> Result<(u8, Vec<u8>)> {
let mut len_buf = [0u8; 4];
self.stream.read_exact(&mut len_buf).map_err(Error::from)?;
let len = u32::from_be_bytes(len_buf) as usize;
if len == 0 {
return Err(Error::Format("agent: zero-length frame"));
}
if len > MAX_REPLY_LEN {
return Err(Error::Format("agent: reply exceeds MAX_REPLY_LEN"));
}
let mut buf = vec![0u8; len];
self.stream.read_exact(&mut buf).map_err(Error::from)?;
let msg_type = buf[0];
let body = buf.split_off(1);
Ok((msg_type, body))
}
}
/// Checks that `path` is safe to use as the agent socket from `SSH_AUTH_SOCK`.
///
/// The checks run in this order. The path must:
/// 1. be stat-able, without following a final symlink;
/// 2. not itself be a symlink;
/// 3. be a Unix-domain socket;
/// 4. be owned by the current effective uid;
/// 5. have no group or other permission bits set.
///
/// On rejection, returns a human-readable reason.
///
/// NOTE(review): only the last path component is checked for being a symlink,
/// and the containing directory's permissions are not inspected. Agents that
/// place a world-accessible socket inside a private directory are rejected by
/// the mode check — confirm that is intended.
fn validate_auth_sock(path: &Path) -> core::result::Result<(), String> {
    use std::os::unix::fs::{FileTypeExt, MetadataExt};

    let meta = match std::fs::symlink_metadata(path) {
        Ok(m) => m,
        Err(e) => return Err(format!("cannot stat: {e} (does the agent socket exist?)")),
    };

    let kind = meta.file_type();
    if kind.is_symlink() {
        return Err(String::from(
            "path is a symlink (a malicious symlink could redirect to another agent)",
        ));
    }
    if !kind.is_socket() {
        return Err(String::from("path is not a Unix-domain socket"));
    }

    let owner = meta.uid();
    let me = nix::unistd::geteuid().as_raw();
    if owner != me {
        return Err(format!(
            "socket is owned by uid {owner} but we are euid {me} (refusing to trust another user's agent)"
        ));
    }

    let perms = meta.mode();
    match perms & 0o077 {
        0 => Ok(()),
        _ => Err(format!(
            "socket is group/world-accessible (mode {:o}); refusing to use it",
            perms & 0o777
        )),
    }
}
#[cfg(test)]
mod auth_sock_tests {
    use super::*;
    use std::os::unix::fs::PermissionsExt;
    use std::os::unix::net::UnixListener;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Returns a temp-dir path that no other call in this process returns.
    ///
    /// Tests run on parallel threads, and a timestamp alone can repeat across
    /// threads on a coarse clock. Two tests could then share (and delete) one
    /// socket. The atomic sequence number guarantees uniqueness within the
    /// process, and the pid separates concurrent processes.
    fn unique_path(tag: &str) -> PathBuf {
        static SEQ: AtomicUsize = AtomicUsize::new(0);
        // Only uniqueness is needed, not ordering with other memory.
        let seq = SEQ.fetch_add(1, Ordering::Relaxed);
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.subsec_nanos())
            .unwrap_or(0);
        std::env::temp_dir().join(format!(
            "puressh-{tag}-{}-{seq}-{nanos}",
            std::process::id()
        ))
    }

    /// Binds a listening socket at a fresh path and sets its mode to 0600.
    ///
    /// The listener is returned so the caller keeps it bound while the test runs.
    fn make_socket() -> (PathBuf, UnixListener) {
        let p = unique_path("auth-sock-test");
        let _ = std::fs::remove_file(&p);
        let l = UnixListener::bind(&p).expect("bind unix listener");
        // The mode after bind depends on the umask. Force owner-only so the
        // valid case is not rejected by the group/world check.
        std::fs::set_permissions(&p, std::fs::Permissions::from_mode(0o600))
            .expect("chmod 0600 test socket");
        (p, l)
    }

    #[test]
    fn valid_owned_socket_is_accepted() {
        let (p, _l) = make_socket();
        let r = validate_auth_sock(&p);
        std::fs::remove_file(&p).ok();
        assert!(r.is_ok(), "expected Ok, got {r:?}");
    }

    #[test]
    fn nonexistent_is_rejected() {
        let p = unique_path("noent");
        assert!(validate_auth_sock(&p).is_err());
    }

    #[test]
    fn regular_file_is_rejected() {
        let p = unique_path("regular");
        std::fs::write(&p, b"hello").unwrap();
        std::fs::set_permissions(&p, std::fs::Permissions::from_mode(0o600)).unwrap();
        let r = validate_auth_sock(&p);
        std::fs::remove_file(&p).ok();
        let msg = r.unwrap_err();
        assert!(msg.contains("not a Unix-domain socket"), "got: {msg}");
    }

    #[test]
    fn symlink_is_rejected() {
        let (target, _l) = make_socket();
        let link = unique_path("symlink");
        let _ = std::fs::remove_file(&link);
        std::os::unix::fs::symlink(&target, &link).unwrap();
        let r = validate_auth_sock(&link);
        std::fs::remove_file(&link).ok();
        std::fs::remove_file(&target).ok();
        let msg = r.unwrap_err();
        assert!(msg.contains("symlink"), "got: {msg}");
    }

    #[test]
    fn group_or_world_accessible_is_rejected() {
        let (p, _l) = make_socket();
        std::fs::set_permissions(&p, std::fs::Permissions::from_mode(0o666)).unwrap();
        let r = validate_auth_sock(&p);
        std::fs::remove_file(&p).ok();
        let msg = r.unwrap_err();
        assert!(msg.contains("group/world-accessible"), "got: {msg}");
    }

    #[test]
    fn group_readable_only_is_rejected() {
        let (p, _l) = make_socket();
        std::fs::set_permissions(&p, std::fs::Permissions::from_mode(0o640)).unwrap();
        let r = validate_auth_sock(&p);
        std::fs::remove_file(&p).ok();
        let msg = r.unwrap_err();
        assert!(msg.contains("group/world-accessible"), "got: {msg}");
    }
}