use crate::codec::ReadMessage::{Failure, Success};
use crate::Error::UnknownMessageType;
use crate::{Error, Identity, Result};
use bytes::{Buf, Bytes, BytesMut};
use ssh_encoding::{Decode, Encode};
use ssh_key::{Algorithm, Certificate, PrivateKey, PublicKey, Signature};
use std::io::{Read, Write};
/// The one-byte message type that follows the 4-byte length prefix of every agent packet.
type MessageTypeId = u8;
// Request message numbers sent from this client to the agent.
const SSH_AGENTC_REQUEST_IDENTITIES: MessageTypeId = 11;
const SSH_AGENTC_SIGN_REQUEST: MessageTypeId = 13;
const SSH_AGENTC_ADD_IDENTITY: MessageTypeId = 17;
const SSH_AGENTC_REMOVE_IDENTITY: MessageTypeId = 18;
const SSH_AGENTC_REMOVE_ALL_IDENTITIES: MessageTypeId = 19;
// Reply message numbers sent from the agent back to this client.
const SSH_AGENT_FAILURE: MessageTypeId = 5;
const SSH_AGENT_SUCCESS: MessageTypeId = 6;
const SSH_AGENT_SIGN_RESPONSE: MessageTypeId = 14;
const SSH_AGENT_IDENTITIES_ANSWER: MessageTypeId = 12;
/// Sign-request flag asking for an RSA/SHA-512 signature; written as the flags word of
/// sign requests made with RSA public keys.
const SSH_AGENT_RSA_SHA2_512: usize = 0x04;
/// Largest declared packet length accepted from the agent; bigger packets are rejected
/// before any buffer is allocated for them.
const MAX_MESSAGE_SIZE: usize = 1024 * 1024;
/// A request from this client to the SSH agent, serialized by `write_message`.
pub enum WriteMessage<'a> {
/// List the identities held by the agent (`SSH_AGENTC_REQUEST_IDENTITIES`).
RequestIdentities,
/// Sign the given data with the given identity (`SSH_AGENTC_SIGN_REQUEST`).
Sign(&'a Identity<'a>, &'a [u8]),
/// Load a private key into the agent (`SSH_AGENTC_ADD_IDENTITY`).
AddIdentity(&'a PrivateKey),
/// Remove the identity matching this key's public half (`SSH_AGENTC_REMOVE_IDENTITY`).
RemoveIdentity(&'a PrivateKey),
/// Remove every identity held by the agent (`SSH_AGENTC_REMOVE_ALL_IDENTITIES`).
RemoveAllIdentities,
}
/// A reply from the SSH agent, parsed by `read_message`.
#[derive(Debug)]
pub enum ReadMessage {
/// The agent reported failure (`SSH_AGENT_FAILURE`).
Failure,
/// The agent reported success (`SSH_AGENT_SUCCESS`).
Success,
/// The identities listed by the agent (`SSH_AGENT_IDENTITIES_ANSWER`).
Identities(Vec<Identity<'static>>),
/// The signature returned in an `SSH_AGENT_SIGN_RESPONSE`.
Signature(Signature),
}
/// Reads one framed agent reply from `input` and decodes it into a [`ReadMessage`].
///
/// # Errors
/// Returns `UnknownMessageType` for an unrecognised type byte, `InvalidMessage` when a
/// sign response's inner signature length disagrees with the bytes that follow it, and
/// propagates I/O and key/signature decoding failures.
pub fn read_message(input: &mut dyn Read) -> Result<ReadMessage> {
    let (message_type, payload) = read_packet(input)?;
    let message = match message_type {
        SSH_AGENT_FAILURE => Failure,
        SSH_AGENT_SUCCESS => Success,
        SSH_AGENT_IDENTITIES_ANSWER => ReadMessage::Identities(make_identities(payload)?),
        SSH_AGENT_SIGN_RESPONSE => {
            // The signature is wrapped in its own length-prefixed string, which must
            // account for every remaining byte of the payload.
            let mut reader: &[u8] = &payload[..];
            let inner_len = reader.get_length()?;
            if inner_len != reader.len() {
                return invalid_data("different inner and outer size");
            }
            ReadMessage::Signature(Signature::decode(&mut reader)?)
        }
        unknown => return Err(UnknownMessageType(unknown)),
    };
    Ok(message)
}
/// Serializes `message` as a single length-prefixed agent packet and writes it to `output`.
///
/// The whole payload is assembled in memory first so the 4-byte big-endian length prefix
/// can be emitted ahead of it.
///
/// # Errors
/// Propagates key-encoding failures, I/O errors from `output`, and `InvalidMessage` when a
/// length does not fit in a `u32`.
pub fn write_message(output: &mut dyn Write, message: WriteMessage) -> Result<()> {
    let mut payload: Vec<u8> = Vec::new();
    match message {
        WriteMessage::RequestIdentities => {
            payload.write_all(&[SSH_AGENTC_REQUEST_IDENTITIES])?;
        }
        WriteMessage::RemoveAllIdentities => {
            payload.write_all(&[SSH_AGENTC_REMOVE_ALL_IDENTITIES])?;
        }
        WriteMessage::AddIdentity(private_key) => {
            payload.write_all(&[SSH_AGENTC_ADD_IDENTITY])?;
            private_key.key_data().encode(&mut payload)?;
            let comment = private_key.comment();
            write_u32(comment.len(), &mut payload)?;
            payload.write_all(comment.as_bytes())?;
        }
        WriteMessage::RemoveIdentity(private_key) => {
            let key_data = private_key.public_key().key_data();
            payload.write_all(&[SSH_AGENTC_REMOVE_IDENTITY])?;
            write_u32(key_data.encoded_len()?, &mut payload)?;
            key_data.encode(&mut payload)?;
        }
        WriteMessage::Sign(identity, data) => {
            payload.write_all(&[SSH_AGENTC_SIGN_REQUEST])?;
            // Emit the length-prefixed key blob and pick the flags word for this identity.
            let flags = match identity {
                Identity::PublicKey(public_key) => {
                    let key_data = public_key.key_data();
                    write_u32(key_data.encoded_len()?, &mut payload)?;
                    key_data.encode(&mut payload)?;
                    if matches!(public_key.algorithm(), Algorithm::Rsa { .. }) {
                        SSH_AGENT_RSA_SHA2_512
                    } else {
                        0
                    }
                }
                Identity::Certificate(cert) => {
                    write_u32(cert.encoded_len()?, &mut payload)?;
                    cert.encode(&mut payload)?;
                    // NOTE(review): certificates always use flags 0, even RSA ones;
                    // confirm a SHA-1 RSA signature is acceptable there.
                    0
                }
            };
            write_u32(data.len(), &mut payload)?;
            payload.write_all(data)?;
            write_u32(flags, &mut payload)?;
        }
    }
    write_u32(payload.len(), output)?;
    output.write_all(&payload)?;
    Ok(())
}
fn write_u32(i: usize, output: &mut dyn Write) -> Result<()> {
let i = u32::try_from(i)
.map_err(|_| Error::InvalidMessage(format!("Could not encode {i} into an u32 value")))?;
output.write_all(&i.to_be_bytes())?;
Ok(())
}
/// Reads one agent packet: a 4-byte big-endian length, then that many bytes, the first
/// of which is the message type.
///
/// Returns the message type and the remaining payload (the body after the type byte).
///
/// # Errors
/// `InvalidMessage` when the declared length is 0 (no room for a type byte) or exceeds
/// `MAX_MESSAGE_SIZE`; I/O errors (including early EOF) from `input`.
fn read_packet(mut input: impl Read) -> Result<(MessageTypeId, Bytes)> {
    // Read only the length first, so a bogus length never causes bytes belonging to a
    // following packet to be consumed.
    let mut len_field = [0u8; 4];
    input.read_exact(&mut len_field)?;
    let mut header: &[u8] = &len_field;
    let len = header.get_length()?;
    if len > MAX_MESSAGE_SIZE {
        return invalid_data(&format!(
            "Refusing to read message with size larger than {MAX_MESSAGE_SIZE}"
        ));
    }
    // Every packet carries at least the type byte; a zero length used to underflow.
    if len == 0 {
        return invalid_data("message is empty; expected at least a message type byte");
    }
    let mut body = BytesMut::zeroed(len);
    input.read_exact(body.as_mut())?;
    let mut body = body.freeze();
    let message_type = body.get_u8();
    Ok((message_type, body))
}
fn invalid_data<T>(message: &str) -> Result<T> {
Err(Error::InvalidMessage(String::from(message)))
}
fn make_identities<'a>(mut buf: Bytes) -> Result<Vec<Identity<'a>>> {
let len = buf.get_length()?;
let mut result: Vec<Identity> = Vec::with_capacity(len);
for _ in 0..len {
let key_len = buf.get_length()?;
let key_bytes = &buf.chunk()[..key_len];
if get_key_type(key_bytes)?.contains("-cert-") {
let cert = Certificate::from_bytes(key_bytes)?;
buf.advance(key_len);
let encoded_cert = format!("{} {}", cert.to_openssh()?, get_comment(&mut buf)?);
let cert_with_comment = Certificate::from_openssh(&encoded_cert)?;
result.push(cert_with_comment.into());
} else {
let mut public_key = PublicKey::from_bytes(&buf.chunk()[..key_len])?;
buf.advance(key_len);
public_key.set_comment(get_comment(&mut buf)?);
result.push(public_key.into());
}
}
Ok(result)
}
/// Consumes a length-prefixed UTF-8 comment from the front of `buf`.
///
/// # Errors
/// `InvalidMessage` when the length prefix is missing, the comment is longer than the
/// remaining data, or the bytes are not valid UTF-8.
fn get_comment(buf: &mut Bytes) -> Result<String> {
    let comment_len = buf.get_length()?;
    // The length is peer-controlled; slicing past the end would panic.
    if buf.len() < comment_len {
        return invalid_data("comment is longer than the remaining message");
    }
    let comment = match std::str::from_utf8(&buf.chunk()[..comment_len]) {
        Ok(comment) => comment.to_string(),
        Err(_) => return invalid_data("Invalid utf-8 sequence in comment"),
    };
    buf.advance(comment_len);
    Ok(comment)
}
fn get_key_type(bytes: &[u8]) -> Result<String> {
let mut buf = bytes;
let len = buf.get_length()?;
if buf.len() < len {
return invalid_data("buffer too short");
}
String::from_utf8(buf[..len].to_vec())
.map_err(|e| Error::InvalidMessage(format!("Invalid key type: {e}")))
}
/// Reads a big-endian `u32` length prefix from the front of a buffer, advancing past it.
trait GetLength {
    /// Consumes four bytes and returns them as a `usize`, or fails with
    /// `InvalidMessage` when fewer than four bytes are left.
    fn get_length(&mut self) -> Result<usize>;
}

/// Implements [`GetLength`] for a `Buf` type whose `len()` is its unread byte count.
macro_rules! get_length {
    ($t:ty) => {
        impl GetLength for $t {
            fn get_length(&mut self) -> Result<usize> {
                if self.len() >= 4 {
                    Ok(self.get_u32() as usize)
                } else {
                    invalid_data("length field is too short")
                }
            }
        }
    };
}

get_length!(Bytes);
get_length!(&[u8]);
#[cfg(test)]
mod test {
    //! Unit tests for packet framing, reply parsing and request serialization.
    //! Fixture files are embedded from `tests/data/` at compile time.
    use crate::codec::{
        get_key_type, make_identities, read_message, write_message, write_u32, ReadMessage,
        WriteMessage,
    };
    use crate::Error::InvalidMessage;
    use crate::{Error, Identity};
    use bytes::Bytes;
    use ssh_key::{Certificate, PrivateKey, PublicKey};
    use std::io::Cursor;

    /// Wraps static bytes in a `Read` implementation for `read_message`.
    fn reader(data: &'static [u8]) -> Cursor<&'static [u8]> {
        Cursor::new(data)
    }

    #[test]
    fn test_read_message_identities_answer() {
        let mut cursor = reader(b"\0\0\0\x05\x0c\0\0\0\0");
        let result = read_message(&mut cursor).expect("failed to read_message()");
        match result {
            ReadMessage::Identities(identities) => {
                assert_eq!(identities, vec![])
            }
            _ => panic!("result was not IdentitiesAnswer"),
        }
    }

    #[test]
    fn test_read_message_failure() {
        let mut cursor = reader(b"\0\0\0\x01\x05");
        let result = read_message(&mut cursor).expect("failed to read_message()");
        match result {
            ReadMessage::Failure => (),
            _ => panic!("result was not FailureAnswer"),
        }
    }

    #[test]
    fn test_read_message_success() {
        let mut cursor = reader(b"\0\0\0\x01\x06");
        let result = read_message(&mut cursor).expect("failed to read_message()");
        match result {
            ReadMessage::Success => (),
            _ => panic!("result was not SuccessAnswer"),
        }
    }

    #[test]
    fn test_read_message_unknown() {
        let mut cursor = reader(b"\0\0\0\x01\xff");
        let result = read_message(&mut cursor);
        match result {
            Err(Error::UnknownMessageType(_)) => (),
            _ => panic!("did not receive expected error UnknownMessageType"),
        }
    }

    #[test]
    fn test_read_overly_long_message_length() {
        let mut cursor = reader(b"\x01\0\0\x01\xff");
        let result = read_message(&mut cursor);
        match result {
            Err(InvalidMessage(msg)) => assert_eq!(
                msg,
                "Refusing to read message with size larger than 1048576"
            ),
            _ => panic!("did not receive expected error InvalidMessage"),
        }
    }

    #[test]
    fn test_make_identities() {
        let data = Bytes::from_static(include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/tests/data/identity_list_response.bin"
        )));
        let key = include_str!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/tests/data/id_ed25519.pub"
        ));
        let identity: Identity = PublicKey::from_openssh(key).unwrap().into();
        assert_eq!(
            make_identities(data).expect("Could not decode"),
            vec![identity]
        )
    }

    /// Embeds a text fixture from `tests/data/`.
    macro_rules! read_str {
        ($s:expr) => {
            include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/data/", $s))
        };
    }

    #[test]
    fn test_make_identities_with_cert() -> Result<(), Error> {
        let data = Bytes::from_static(include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/tests/data/identities_with_cert.bin"
        )));
        let identities = make_identities(data)?;
        assert_eq!(3, identities.len());
        let mut identities = identities.iter();
        if let Identity::PublicKey(pk) = identities.next().unwrap() {
            compare_to_key(read_str!("id_ed25519_for_cert.pub"), pk);
        } else {
            panic!("did not receive expected public key");
        }
        if let Identity::Certificate(cert) = identities.next().unwrap() {
            compare_to_cert(read_str!("id_ed25519_for_cert-cert.pub"), cert);
        } else {
            panic!("did not receive expected cert");
        }
        if let Identity::PublicKey(pk) = identities.next().unwrap() {
            compare_to_key(read_str!("id_ecdsa.pub"), pk);
        } else {
            panic!("did not receive expected public key");
        }
        Ok(())
    }

    fn compare_to_key(expected: &str, actual: &PublicKey) {
        assert_eq!(
            PublicKey::from_openssh(expected).unwrap().key_data(),
            actual.key_data()
        )
    }

    fn compare_to_cert(expected: &str, actual: &Certificate) {
        assert_eq!(
            Certificate::from_openssh(expected).unwrap().public_key(),
            actual.public_key()
        )
    }

    #[test]
    fn test_write_message() {
        let mut output: Vec<u8> = Vec::new();
        write_message(&mut output, WriteMessage::RequestIdentities).expect("failed writing");
        assert_eq!(vec![0_u8, 0, 0, 1, 11], output)
    }

    /// Checks that `AddIdentity` for the key fixture serializes to the message fixture.
    macro_rules! add_identity {
        ($message_path:expr, $key_path:expr) => {
            let key = include_str!(concat!(
                env!("CARGO_MANIFEST_DIR"),
                "/tests/data/",
                $key_path
            ));
            let expected = include_bytes!(concat!(
                env!("CARGO_MANIFEST_DIR"),
                "/tests/data/",
                $message_path
            ));
            let mut output: Vec<u8> = Vec::new();
            let key = PrivateKey::from_openssh(key).expect("failed to parse key");
            write_message(&mut output, WriteMessage::AddIdentity(&key)).unwrap();
            assert_eq!(expected, output.as_slice());
        };
    }

    #[test]
    fn test_write_add_identity() {
        add_identity!("ssh-add_rsa.bin", "id_rsa");
        add_identity!("ssh-add_dsa.bin", "id_dsa");
        add_identity!("ssh-add_ed25519.bin", "id_ed25519");
        add_identity!("ssh-add_ecdsa.bin", "id_ecdsa");
    }

    /// Checks that `RemoveIdentity` for the key fixture serializes to the message fixture.
    macro_rules! remove_identity {
        ($message_path:expr, $key_path:expr) => {
            let key = include_str!(concat!(
                env!("CARGO_MANIFEST_DIR"),
                "/tests/data/",
                $key_path
            ));
            let expected = include_bytes!(concat!(
                env!("CARGO_MANIFEST_DIR"),
                "/tests/data/",
                $message_path
            ));
            let mut output: Vec<u8> = Vec::new();
            let key = PrivateKey::from_openssh(key).expect("failed to parse key");
            write_message(&mut output, WriteMessage::RemoveIdentity(&key)).unwrap();
            assert_eq!(expected, output.as_slice());
        };
    }

    #[test]
    fn test_write_remove_identity() {
        remove_identity!("ssh-remove_rsa.bin", "id_rsa");
        remove_identity!("ssh-remove_dsa.bin", "id_dsa");
        remove_identity!("ssh-remove_ed25519.bin", "id_ed25519");
        remove_identity!("ssh-remove_ecdsa.bin", "id_ecdsa");
    }

    #[test]
    fn test_write_remove_all_identities() {
        let mut output: Vec<u8> = Vec::new();
        write_message(&mut output, WriteMessage::RemoveAllIdentities).expect("failed writing");
        assert_eq!(vec![0_u8, 0, 0, 1, 19], output)
    }

    #[test]
    #[cfg(target_pointer_width = "64")]
    fn test_write_too_large() {
        let mut output: Vec<u8> = Vec::new();
        let result = write_u32(usize::MAX, &mut output);
        match result {
            Err(InvalidMessage(msg)) => {
                assert_eq!(
                    format!("Could not encode {} into an u32 value", usize::MAX),
                    msg
                )
            }
            _ => panic!("expected InvalidMessage"),
        }
    }

    #[test]
    fn test_write_sign_rsa() {
        let key = include_str!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/tests/data/id_rsa.pub",
        ));
        let expected = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/tests/data/sign_rsa.bin",
        ));
        let key = PublicKey::from_openssh(key).expect("failed to parse key");
        let mut output: Vec<u8> = Vec::new();
        write_message(&mut output, WriteMessage::Sign(&key.into(), b"a")).unwrap();
        assert_eq!(expected, output.as_slice());
    }

    #[test]
    fn test_get_key_type() -> Result<(), Error> {
        let buf = b"\0\0\0\x03foo";
        assert_eq!(get_key_type(buf)?, "foo");
        let buf = b"\0\0\0\x03foobar";
        assert_eq!(get_key_type(buf)?, "foo");
        let buf = b"\0\0\0";
        match get_key_type(buf).unwrap_err() {
            InvalidMessage(msg) => {
                assert_eq!("length field is too short", msg)
            }
            _ => panic!("expected InvalidMessage"),
        }
        let buf = b"\0\0\0\x03f";
        match get_key_type(buf).unwrap_err() {
            InvalidMessage(msg) => {
                assert_eq!("buffer too short", msg)
            }
            _ => panic!("expected InvalidMessage"),
        }
        let buf = b"\0\0\0\x03f\xc0\xaf";
        match get_key_type(buf).unwrap_err() {
            InvalidMessage(msg) => {
                assert_eq!(
                    "Invalid key type: invalid utf-8 sequence of 1 bytes from index 1",
                    msg
                )
            }
            _ => panic!("expected InvalidMessage"),
        }
        Ok(())
    }
}