use crate::error::AgentError;
/// Opcode of the agent's generic failure reply (see [`encode_failure`]).
pub const SSH_AGENT_FAILURE: u8 = 5;
/// Opcode of the client's identity-listing request ([`Request::RequestIdentities`]).
pub const SSH_AGENTC_REQUEST_IDENTITIES: u8 = 11;
/// Opcode of the agent's reply to [`SSH_AGENTC_REQUEST_IDENTITIES`].
pub const SSH_AGENT_IDENTITIES_ANSWER: u8 = 12;
/// Opcode of the client's signing request ([`Request::SignRequest`]).
pub const SSH_AGENTC_SIGN_REQUEST: u8 = 13;
/// Opcode of the agent's reply to [`SSH_AGENTC_SIGN_REQUEST`].
pub const SSH_AGENT_SIGN_RESPONSE: u8 = 14;
/// Upper bound (256 KiB) on an incoming frame body and on any single
/// length-prefixed string inside it.
pub const MAX_FRAME_LEN: usize = 256 * 1024;
/// A decoded client request, as produced by [`parse_request`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Request {
    /// `SSH_AGENTC_REQUEST_IDENTITIES`: the message carries no payload.
    RequestIdentities,
    /// `SSH_AGENTC_SIGN_REQUEST`: two length-prefixed strings followed by a `u32`.
    SignRequest {
        /// First string of the message: the key blob the client wants used.
        key_blob: Vec<u8>,
        /// Second string of the message: the bytes to be signed.
        data: Vec<u8>,
        /// Trailing big-endian `u32`; the parser passes it through without interpreting it.
        flags: u32,
    },
}
/// One key entry written into an `SSH_AGENT_IDENTITIES_ANSWER` reply by
/// [`encode_identities_answer`].
///
/// Derives the same traits as [`Request`] so callers can log, copy and
/// compare identities (e.g. in tests) without hand-written impls.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Identity {
    /// Key blob, written verbatim as a length-prefixed string.
    pub key_blob: Vec<u8>,
    /// Comment shown to the user, written as its UTF-8 bytes.
    pub comment: String,
}
/// Forward-only cursor over an SSH agent message body.
///
/// `pos` is the index of the next unread byte. Every accessor checks the
/// remaining length before touching `buf`, so malformed input produces an
/// `AgentError::Protocol` rather than a panic.
struct Reader<'a> {
    buf: &'a [u8],
    pos: usize,
}

impl<'a> Reader<'a> {
    /// Creates a cursor positioned at the first byte of `buf`.
    fn new(buf: &'a [u8]) -> Self {
        Reader { buf, pos: 0 }
    }

    /// Number of bytes not yet consumed.
    fn remaining(&self) -> usize {
        self.buf.len() - self.pos
    }

    /// Consumes the next `n` bytes and returns them.
    ///
    /// Callers must have already checked `n <= self.remaining()`.
    fn advance(&mut self, n: usize) -> &'a [u8] {
        let whole: &'a [u8] = self.buf;
        let start = self.pos;
        self.pos += n;
        &whole[start..self.pos]
    }

    /// Reads a single byte.
    fn u8(&mut self) -> Result<u8, AgentError> {
        if self.remaining() == 0 {
            return Err(protocol("truncated: expected a byte"));
        }
        Ok(self.advance(1)[0])
    }

    /// Reads a big-endian `u32`.
    fn u32(&mut self) -> Result<u32, AgentError> {
        if self.remaining() < 4 {
            return Err(protocol("truncated: expected a u32"));
        }
        let word = self.advance(4);
        Ok(u32::from_be_bytes([word[0], word[1], word[2], word[3]]))
    }

    /// Reads a `u32`-length-prefixed byte string and returns an owned copy.
    ///
    /// The declared length is compared against [`MAX_FRAME_LEN`] first, and
    /// then against the bytes actually left, before anything is copied.
    fn string(&mut self) -> Result<Vec<u8>, AgentError> {
        let declared = self.u32()? as usize;
        if declared > MAX_FRAME_LEN {
            return Err(protocol("string length exceeds the frame cap"));
        }
        if declared > self.remaining() {
            return Err(protocol("string length exceeds the remaining buffer"));
        }
        Ok(self.advance(declared).to_vec())
    }
}
/// Builds an [`AgentError::Protocol`] holding an owned copy of `msg`.
fn protocol(msg: &str) -> AgentError {
    let owned: String = msg.to_owned();
    AgentError::Protocol(owned)
}
/// Decodes one request body (the frame payload, without its length prefix).
///
/// The first byte selects the opcode. The rest of the body must match that
/// opcode's layout exactly, with no leftover bytes.
///
/// # Errors
///
/// Returns `AgentError::Protocol` when any of these holds:
/// - the body is empty;
/// - a field is truncated or a string length is out of range;
/// - bytes remain after the last field;
/// - the opcode is not one this agent supports.
pub fn parse_request(body: &[u8]) -> Result<Request, AgentError> {
    let mut cursor = Reader::new(body);
    let opcode = cursor.u8()?;

    // Each supported opcode yields the decoded request together with the
    // message to report if anything follows its last field.
    let (request, trailing_msg) = match opcode {
        SSH_AGENTC_REQUEST_IDENTITIES => (
            Request::RequestIdentities,
            "REQUEST_IDENTITIES carries unexpected trailing bytes",
        ),
        SSH_AGENTC_SIGN_REQUEST => {
            let key_blob = cursor.string()?;
            let data = cursor.string()?;
            let flags = cursor.u32()?;
            (
                Request::SignRequest {
                    key_blob,
                    data,
                    flags,
                },
                "SIGN_REQUEST carries unexpected trailing bytes",
            )
        }
        other => {
            return Err(AgentError::Protocol(format!(
                "unsupported ssh-agent opcode {other}"
            )));
        }
    };

    if cursor.remaining() > 0 {
        return Err(protocol(trailing_msg));
    }
    Ok(request)
}
fn put_string(out: &mut Vec<u8>, bytes: &[u8]) {
kovra_core::write_string(out, bytes);
}
/// Encodes an `SSH_AGENT_IDENTITIES_ANSWER` message body (without a frame prefix).
///
/// The body is laid out as:
/// 1. the opcode byte;
/// 2. the number of identities as a big-endian `u32`;
/// 3. for each identity, its key blob and then its comment's UTF-8 bytes,
///    each written with [`put_string`].
///
/// # Panics
///
/// Panics if `identities` has more than `u32::MAX` entries. The count field
/// cannot represent that, and truncating it would make the client misparse
/// the rest of the message.
pub fn encode_identities_answer(identities: &[Identity]) -> Vec<u8> {
    let count = u32::try_from(identities.len())
        .expect("identity count does not fit the protocol's u32 count field");
    let mut out = Vec::new();
    out.push(SSH_AGENT_IDENTITIES_ANSWER);
    out.extend_from_slice(&count.to_be_bytes());
    for id in identities {
        put_string(&mut out, &id.key_blob);
        put_string(&mut out, id.comment.as_bytes());
    }
    out
}
/// Encodes an `SSH_AGENT_SIGN_RESPONSE` message body: the opcode byte
/// followed by `signature` written with [`put_string`].
pub fn encode_sign_response(signature: &[u8]) -> Vec<u8> {
    let mut msg = vec![SSH_AGENT_SIGN_RESPONSE];
    put_string(&mut msg, signature);
    msg
}
/// Encodes the one-byte `SSH_AGENT_FAILURE` message body.
pub fn encode_failure() -> Vec<u8> {
    Vec::from([SSH_AGENT_FAILURE])
}
/// Prepends a 4-byte big-endian length prefix to `body`.
///
/// This is the framing that [`read_frame`] strips.
///
/// # Panics
///
/// Panics if `body` is longer than `u32::MAX` bytes. The prefix cannot
/// represent that length, and truncating it would emit a header that
/// disagrees with the payload and desynchronise the peer.
pub fn frame(body: &[u8]) -> Vec<u8> {
    let len = u32::try_from(body.len()).expect("frame body longer than u32::MAX bytes");
    let mut out = Vec::with_capacity(4 + body.len());
    out.extend_from_slice(&len.to_be_bytes());
    out.extend_from_slice(body);
    out
}
/// Reads one length-prefixed frame from `stream` and returns its body.
///
/// Returns `Ok(None)` when the stream is at EOF before the first byte of a
/// new frame. Otherwise returns `Ok(Some(body))` with the length prefix
/// stripped.
///
/// # Errors
///
/// * `AgentError::Protocol` if the declared length is zero or exceeds
///   [`MAX_FRAME_LEN`].
/// * `AgentError::Protocol` if the stream ends part-way through either the
///   length prefix or the body. Both kinds of truncation are reported the
///   same way.
/// * `AgentError::Io` for any other read failure.
pub fn read_frame<R: std::io::Read>(stream: &mut R) -> Result<Option<Vec<u8>>, AgentError> {
    let mut len_buf = [0u8; 4];
    if !read_exact_or_eof(stream, &mut len_buf)? {
        return Ok(None);
    }
    let len = u32::from_be_bytes(len_buf) as usize;
    if len == 0 {
        return Err(protocol("zero-length frame"));
    }
    if len > MAX_FRAME_LEN {
        return Err(protocol("frame length exceeds the cap"));
    }
    // The cap check above bounds this allocation before any body bytes arrive.
    let mut body = vec![0u8; len];
    // A header has already been consumed, so EOF here (even before the first
    // body byte) means the peer sent a truncated frame.
    if !read_exact_or_eof(stream, &mut body)? {
        return Err(protocol("unexpected EOF mid-frame"));
    }
    Ok(Some(body))
}
/// Fills `buf` completely from `stream`, retrying reads that fail with
/// `ErrorKind::Interrupted`.
///
/// Returns `Ok(true)` once `buf` is full. Returns `Ok(false)` if the stream
/// reports EOF before a single byte was read. Returns
/// `AgentError::Protocol` if EOF arrives after a partial read. Any other read
/// error is returned as `AgentError::Io`.
fn read_exact_or_eof<R: std::io::Read>(stream: &mut R, buf: &mut [u8]) -> Result<bool, AgentError> {
    let mut filled = 0;
    while filled < buf.len() {
        let got = match stream.read(&mut buf[filled..]) {
            Ok(n) => n,
            Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(AgentError::Io(e.to_string())),
        };
        if got == 0 {
            // EOF: a clean boundary only if nothing of this buffer was read yet.
            return if filled == 0 {
                Ok(false)
            } else {
                Err(protocol("unexpected EOF mid-frame"))
            };
        }
        filled += got;
    }
    Ok(true)
}
// Unit tests for request parsing, response encoding and length-prefixed framing.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn request_identities_round_trip() {
        let body = vec![SSH_AGENTC_REQUEST_IDENTITIES];
        assert_eq!(parse_request(&body).unwrap(), Request::RequestIdentities);
        let answer = encode_identities_answer(&[Identity {
            key_blob: vec![1, 2, 3],
            comment: "kovra:dev/ssh/deploy".into(),
        }]);
        assert_eq!(answer[0], SSH_AGENT_IDENTITIES_ANSWER);
        assert_eq!(&answer[1..5], &1u32.to_be_bytes());
    }

    #[test]
    fn sign_request_round_trip() {
        let mut body = vec![SSH_AGENTC_SIGN_REQUEST];
        put_string(&mut body, b"PUBKEYBLOB");
        put_string(&mut body, b"challenge-data");
        body.extend_from_slice(&2u32.to_be_bytes());
        let framed = frame(&body);
        let mut cursor = std::io::Cursor::new(framed);
        let read_body = read_frame(&mut cursor).unwrap().unwrap();
        assert_eq!(read_body, body);
        match parse_request(&read_body).unwrap() {
            Request::SignRequest {
                key_blob,
                data,
                flags,
            } => {
                assert_eq!(key_blob, b"PUBKEYBLOB");
                assert_eq!(data, b"challenge-data");
                assert_eq!(flags, 2);
            }
            other => panic!("expected SignRequest, got {other:?}"),
        }
    }

    #[test]
    fn sign_response_encodes_signature_string() {
        let resp = encode_sign_response(b"SIGBLOB");
        assert_eq!(resp[0], SSH_AGENT_SIGN_RESPONSE);
        assert_eq!(&resp[1..5], &(b"SIGBLOB".len() as u32).to_be_bytes());
        assert_eq!(&resp[5..], b"SIGBLOB");
    }

    #[test]
    fn oversized_string_length_is_rejected_not_allocated() {
        let mut body = vec![SSH_AGENTC_SIGN_REQUEST];
        body.extend_from_slice(&0xFFFF_FFFFu32.to_be_bytes());
        let err = parse_request(&body).unwrap_err();
        assert!(matches!(err, AgentError::Protocol(_)));
    }

    #[test]
    fn string_longer_than_remaining_buffer_is_rejected() {
        // Declares a 4-byte string but only 2 bytes follow.
        let body = [SSH_AGENTC_SIGN_REQUEST, 0, 0, 0, 4, 1, 2];
        assert!(matches!(
            parse_request(&body).unwrap_err(),
            AgentError::Protocol(_)
        ));
    }

    #[test]
    fn sign_request_missing_flags_is_rejected() {
        let mut body = vec![SSH_AGENTC_SIGN_REQUEST];
        put_string(&mut body, b"k");
        put_string(&mut body, b"d");
        assert!(matches!(
            parse_request(&body).unwrap_err(),
            AgentError::Protocol(_)
        ));
    }

    #[test]
    fn unknown_opcode_is_rejected() {
        let err = parse_request(&[200]).unwrap_err();
        assert!(matches!(err, AgentError::Protocol(_)));
    }

    #[test]
    fn empty_body_is_rejected() {
        assert!(matches!(
            parse_request(&[]).unwrap_err(),
            AgentError::Protocol(_)
        ));
    }

    #[test]
    fn trailing_bytes_are_rejected() {
        let body = vec![SSH_AGENTC_REQUEST_IDENTITIES, 0xAA];
        assert!(matches!(
            parse_request(&body).unwrap_err(),
            AgentError::Protocol(_)
        ));
    }

    #[test]
    fn sign_request_trailing_bytes_are_rejected() {
        let mut body = vec![SSH_AGENTC_SIGN_REQUEST];
        put_string(&mut body, b"PUBKEYBLOB");
        put_string(&mut body, b"challenge-data");
        body.extend_from_slice(&0u32.to_be_bytes());
        body.push(0xAA);
        assert!(matches!(
            parse_request(&body).unwrap_err(),
            AgentError::Protocol(_)
        ));
    }

    #[test]
    fn frame_prefixes_big_endian_length() {
        assert_eq!(frame(&[1, 2, 3]), vec![0, 0, 0, 3, 1, 2, 3]);
        assert_eq!(frame(&[]), vec![0, 0, 0, 0]);
    }

    #[test]
    fn read_frame_rejects_oversized_length() {
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&((MAX_FRAME_LEN + 1) as u32).to_be_bytes());
        let mut cursor = std::io::Cursor::new(bytes);
        assert!(matches!(
            read_frame(&mut cursor).unwrap_err(),
            AgentError::Protocol(_)
        ));
    }

    #[test]
    fn read_frame_rejects_zero_length() {
        let mut cursor = std::io::Cursor::new(vec![0u8, 0, 0, 0]);
        assert!(matches!(
            read_frame(&mut cursor).unwrap_err(),
            AgentError::Protocol(_)
        ));
    }

    #[test]
    fn read_frame_eof_at_boundary_is_none() {
        let mut cursor = std::io::Cursor::new(Vec::<u8>::new());
        assert!(read_frame(&mut cursor).unwrap().is_none());
    }

    #[test]
    fn read_frame_partial_length_is_error() {
        let mut cursor = std::io::Cursor::new(vec![0u8, 0u8]);
        assert!(read_frame(&mut cursor).is_err());
    }

    #[test]
    fn read_frame_truncated_body_is_error() {
        // Header promises 5 body bytes; only 2 arrive before EOF.
        let mut cursor = std::io::Cursor::new(vec![0u8, 0, 0, 5, 1, 2]);
        assert!(read_frame(&mut cursor).is_err());
    }

    #[test]
    fn read_frame_reads_consecutive_frames_then_none() {
        let mut bytes = frame(&[SSH_AGENTC_REQUEST_IDENTITIES]);
        bytes.extend_from_slice(&frame(&[SSH_AGENT_FAILURE]));
        let mut cursor = std::io::Cursor::new(bytes);
        assert_eq!(
            read_frame(&mut cursor).unwrap(),
            Some(vec![SSH_AGENTC_REQUEST_IDENTITIES])
        );
        assert_eq!(
            read_frame(&mut cursor).unwrap(),
            Some(vec![SSH_AGENT_FAILURE])
        );
        assert!(read_frame(&mut cursor).unwrap().is_none());
    }

    #[test]
    fn arbitrary_inputs_never_panic() {
        let samples: &[&[u8]] = &[
            &[],
            &[0],
            &[5],
            &[11],
            &[11, 0],
            &[13],
            &[13, 0, 0, 0, 4],
            &[13, 0, 0, 0, 4, 1, 2, 3, 4],
            &[13, 255, 255, 255, 255],
            &[13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        ];
        for s in samples {
            let _ = parse_request(s);
        }
    }
}