1use crate::error::AgentError;
22
/// Client → agent opcode: ask for the list of identities (parsed by `parse_request`).
pub const SSH_AGENTC_REQUEST_IDENTITIES: u8 = 11;
/// Agent → client opcode: reply to [`SSH_AGENTC_REQUEST_IDENTITIES`]
/// (emitted by `encode_identities_answer`).
pub const SSH_AGENT_IDENTITIES_ANSWER: u8 = 12;
/// Client → agent opcode: ask the agent to sign data (parsed by `parse_request`).
pub const SSH_AGENTC_SIGN_REQUEST: u8 = 13;
/// Agent → client opcode: signature reply to [`SSH_AGENTC_SIGN_REQUEST`]
/// (emitted by `encode_sign_response`).
pub const SSH_AGENT_SIGN_RESPONSE: u8 = 14;
/// Agent → client opcode: failure reply; sent as a bare one-byte body by `encode_failure`.
pub const SSH_AGENT_FAILURE: u8 = 5;

/// Upper bound (256 KiB) on an incoming frame body accepted by `read_frame`, and on the
/// declared length of any single length-prefixed string parsed out of a request.
/// Checked before allocating, so a hostile length field cannot force a huge allocation.
pub const MAX_FRAME_LEN: usize = 256 * 1024;
39
/// A decoded client request, as produced by [`parse_request`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Request {
    /// `SSH_AGENTC_REQUEST_IDENTITIES`: list the identities the agent offers.
    /// Carries no payload; any trailing bytes are rejected by the parser.
    RequestIdentities,
    /// `SSH_AGENTC_SIGN_REQUEST`: a request to sign `data`.
    SignRequest {
        /// Key blob exactly as sent by the client; presumably matched by the caller
        /// against [`Identity::key_blob`] to pick the signing key — confirm at call site.
        key_blob: Vec<u8>,
        /// Bytes the client wants signed, exactly as sent.
        data: Vec<u8>,
        /// Flags word as sent by the client; the parser does not interpret it.
        flags: u32,
    },
}
57
/// One entry of an `SSH_AGENT_IDENTITIES_ANSWER`: a key blob plus its comment.
///
/// `encode_identities_answer` writes each identity as two length-prefixed strings:
/// first `key_blob`, then the UTF-8 bytes of `comment`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Identity {
    /// Key blob advertised to the client, written verbatim.
    pub key_blob: Vec<u8>,
    /// Human-readable label shown to the client alongside the key.
    pub comment: String,
}
65
66struct Reader<'a> {
69 buf: &'a [u8],
70 pos: usize,
71}
72
73impl<'a> Reader<'a> {
74 fn new(buf: &'a [u8]) -> Self {
75 Self { buf, pos: 0 }
76 }
77
78 fn remaining(&self) -> usize {
79 self.buf.len() - self.pos
80 }
81
82 fn u8(&mut self) -> Result<u8, AgentError> {
84 if self.remaining() < 1 {
85 return Err(protocol("truncated: expected a byte"));
86 }
87 let b = self.buf[self.pos];
88 self.pos += 1;
89 Ok(b)
90 }
91
92 fn u32(&mut self) -> Result<u32, AgentError> {
94 if self.remaining() < 4 {
95 return Err(protocol("truncated: expected a u32"));
96 }
97 let v = u32::from_be_bytes(self.buf[self.pos..self.pos + 4].try_into().unwrap());
98 self.pos += 4;
99 Ok(v)
100 }
101
102 fn string(&mut self) -> Result<Vec<u8>, AgentError> {
106 let len = self.u32()? as usize;
107 if len > MAX_FRAME_LEN {
108 return Err(protocol("string length exceeds the frame cap"));
109 }
110 if self.remaining() < len {
111 return Err(protocol("string length exceeds the remaining buffer"));
112 }
113 let out = self.buf[self.pos..self.pos + len].to_vec();
114 self.pos += len;
115 Ok(out)
116 }
117}
118
119fn protocol(msg: &str) -> AgentError {
120 AgentError::Protocol(msg.to_string())
121}
122
123pub fn parse_request(body: &[u8]) -> Result<Request, AgentError> {
128 let mut r = Reader::new(body);
129 let msg_type = r.u8()?;
130 match msg_type {
131 SSH_AGENTC_REQUEST_IDENTITIES => {
132 if r.remaining() != 0 {
134 return Err(protocol(
135 "REQUEST_IDENTITIES carries unexpected trailing bytes",
136 ));
137 }
138 Ok(Request::RequestIdentities)
139 }
140 SSH_AGENTC_SIGN_REQUEST => {
141 let key_blob = r.string()?;
142 let data = r.string()?;
143 let flags = r.u32()?;
144 if r.remaining() != 0 {
145 return Err(protocol("SIGN_REQUEST carries unexpected trailing bytes"));
146 }
147 Ok(Request::SignRequest {
148 key_blob,
149 data,
150 flags,
151 })
152 }
153 other => Err(AgentError::Protocol(format!(
154 "unsupported ssh-agent opcode {other}"
155 ))),
156 }
157}
158
159fn put_string(out: &mut Vec<u8>, bytes: &[u8]) {
162 kovra_core::write_string(out, bytes);
163}
164
165pub fn encode_identities_answer(identities: &[Identity]) -> Vec<u8> {
168 let mut out = Vec::new();
169 out.push(SSH_AGENT_IDENTITIES_ANSWER);
170 out.extend_from_slice(&(identities.len() as u32).to_be_bytes());
171 for id in identities {
172 put_string(&mut out, &id.key_blob);
173 put_string(&mut out, id.comment.as_bytes());
174 }
175 out
176}
177
178pub fn encode_sign_response(signature: &[u8]) -> Vec<u8> {
182 let mut out = Vec::new();
183 out.push(SSH_AGENT_SIGN_RESPONSE);
184 put_string(&mut out, signature);
185 out
186}
187
188pub fn encode_failure() -> Vec<u8> {
190 vec![SSH_AGENT_FAILURE]
191}
192
/// Wraps `body` in a wire frame: a big-endian `u32` byte length, then the body itself.
///
/// This is the inverse of `read_frame`. It does not enforce `MAX_FRAME_LEN`; callers
/// are responsible for keeping outgoing bodies within whatever limit the peer applies.
///
/// # Panics
///
/// Panics if `body.len()` does not fit in a `u32`. Truncating the length with an
/// `as` cast would write a header that disagrees with the payload and corrupt every
/// following frame on the stream.
pub fn frame(body: &[u8]) -> Vec<u8> {
    let len = u32::try_from(body.len())
        .expect("frame body longer than u32::MAX bytes cannot be length-prefixed");
    let mut out = Vec::with_capacity(4 + body.len());
    out.extend_from_slice(&len.to_be_bytes());
    out.extend_from_slice(body);
    out
}
201
202pub fn read_frame<R: std::io::Read>(stream: &mut R) -> Result<Option<Vec<u8>>, AgentError> {
209 let mut len_buf = [0u8; 4];
210 if !read_exact_or_eof(stream, &mut len_buf)? {
211 return Ok(None);
213 }
214 let len = u32::from_be_bytes(len_buf) as usize;
215 if len == 0 {
216 return Err(protocol("zero-length frame"));
217 }
218 if len > MAX_FRAME_LEN {
219 return Err(protocol("frame length exceeds the cap"));
220 }
221 let mut body = vec![0u8; len];
222 stream
223 .read_exact(&mut body)
224 .map_err(|e| AgentError::Io(e.to_string()))?;
225 Ok(Some(body))
226}
227
228fn read_exact_or_eof<R: std::io::Read>(stream: &mut R, buf: &mut [u8]) -> Result<bool, AgentError> {
231 let mut read = 0;
232 while read < buf.len() {
233 match stream.read(&mut buf[read..]) {
234 Ok(0) => {
235 if read == 0 {
236 return Ok(false);
237 }
238 return Err(protocol("unexpected EOF mid-frame"));
239 }
240 Ok(n) => read += n,
241 Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => {}
242 Err(e) => return Err(AgentError::Io(e.to_string())),
243 }
244 }
245 Ok(true)
246}
247
#[cfg(test)]
mod tests {
    use super::*;

    /// Reader that reports `Interrupted` before every byte and then yields at most one
    /// byte per call. It exercises the short-read and retry paths of `read_frame`.
    struct Trickle {
        data: Vec<u8>,
        pos: usize,
        interrupt_next: bool,
    }

    impl std::io::Read for Trickle {
        fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
            if self.interrupt_next {
                self.interrupt_next = false;
                return Err(std::io::ErrorKind::Interrupted.into());
            }
            self.interrupt_next = true;
            if buf.is_empty() || self.pos == self.data.len() {
                return Ok(0);
            }
            buf[0] = self.data[self.pos];
            self.pos += 1;
            Ok(1)
        }
    }

    /// Builds a well-formed SIGN_REQUEST body with the given flags.
    fn sign_request_body(flags: u32) -> Vec<u8> {
        let mut body = vec![SSH_AGENTC_SIGN_REQUEST];
        put_string(&mut body, b"PUBKEYBLOB");
        put_string(&mut body, b"challenge-data");
        body.extend_from_slice(&flags.to_be_bytes());
        body
    }

    #[test]
    fn request_identities_round_trip() {
        let body = vec![SSH_AGENTC_REQUEST_IDENTITIES];
        assert_eq!(parse_request(&body).unwrap(), Request::RequestIdentities);

        let comment = "kovra:dev/ssh/deploy";
        let answer = encode_identities_answer(&[Identity {
            key_blob: vec![1, 2, 3],
            comment: comment.into(),
        }]);
        assert_eq!(answer[0], SSH_AGENT_IDENTITIES_ANSWER);
        assert_eq!(&answer[1..5], &1u32.to_be_bytes());

        // Pin the full layout: count, then key blob and comment as length-prefixed strings.
        let mut expected = vec![SSH_AGENT_IDENTITIES_ANSWER];
        expected.extend_from_slice(&1u32.to_be_bytes());
        expected.extend_from_slice(&3u32.to_be_bytes());
        expected.extend_from_slice(&[1, 2, 3]);
        expected.extend_from_slice(&(comment.len() as u32).to_be_bytes());
        expected.extend_from_slice(comment.as_bytes());
        assert_eq!(answer, expected);
    }

    #[test]
    fn empty_identities_answer_has_zero_count() {
        assert_eq!(
            encode_identities_answer(&[]),
            [SSH_AGENT_IDENTITIES_ANSWER, 0, 0, 0, 0]
        );
    }

    #[test]
    fn sign_request_round_trip() {
        let body = sign_request_body(2);
        let framed = frame(&body);
        let mut cursor = std::io::Cursor::new(framed);
        let read_body = read_frame(&mut cursor).unwrap().unwrap();
        assert_eq!(read_body, body);

        match parse_request(&read_body).unwrap() {
            Request::SignRequest {
                key_blob,
                data,
                flags,
            } => {
                assert_eq!(key_blob, b"PUBKEYBLOB");
                assert_eq!(data, b"challenge-data");
                assert_eq!(flags, 2);
            }
            other => panic!("expected SignRequest, got {other:?}"),
        }
    }

    #[test]
    fn sign_request_trailing_bytes_are_rejected() {
        let mut body = sign_request_body(0);
        body.push(0xAA);
        assert!(matches!(
            parse_request(&body).unwrap_err(),
            AgentError::Protocol(_)
        ));
    }

    #[test]
    fn sign_request_missing_flags_is_rejected() {
        let mut body = vec![SSH_AGENTC_SIGN_REQUEST];
        put_string(&mut body, b"PUBKEYBLOB");
        put_string(&mut body, b"challenge-data");
        assert!(matches!(
            parse_request(&body).unwrap_err(),
            AgentError::Protocol(_)
        ));
    }

    #[test]
    fn sign_response_encodes_signature_string() {
        let resp = encode_sign_response(b"SIGBLOB");
        assert_eq!(resp[0], SSH_AGENT_SIGN_RESPONSE);
        assert_eq!(&resp[1..5], &(b"SIGBLOB".len() as u32).to_be_bytes());
        assert_eq!(&resp[5..], b"SIGBLOB");
    }

    #[test]
    fn failure_is_a_bare_opcode() {
        assert_eq!(encode_failure(), vec![SSH_AGENT_FAILURE]);
    }

    // A length field near u32::MAX must be rejected by the cap check before any
    // allocation or slicing happens.
    #[test]
    fn oversized_string_length_is_rejected_not_allocated() {
        let mut body = vec![SSH_AGENTC_SIGN_REQUEST];
        body.extend_from_slice(&0xFFFF_FFFFu32.to_be_bytes());
        let err = parse_request(&body).unwrap_err();
        assert!(matches!(err, AgentError::Protocol(_)));
    }

    #[test]
    fn string_longer_than_buffer_is_rejected() {
        let body = [SSH_AGENTC_SIGN_REQUEST, 0, 0, 0, 4, 1, 2];
        assert!(matches!(
            parse_request(&body).unwrap_err(),
            AgentError::Protocol(_)
        ));
    }

    #[test]
    fn string_length_cap_boundary() {
        // Exactly at the cap is accepted.
        let mut body = vec![SSH_AGENTC_SIGN_REQUEST];
        put_string(&mut body, &vec![0u8; MAX_FRAME_LEN]);
        put_string(&mut body, b"");
        body.extend_from_slice(&0u32.to_be_bytes());
        assert!(parse_request(&body).is_ok());

        // One past the cap is rejected, even though the bytes are present.
        let mut body = vec![SSH_AGENTC_SIGN_REQUEST];
        put_string(&mut body, &vec![0u8; MAX_FRAME_LEN + 1]);
        put_string(&mut body, b"");
        body.extend_from_slice(&0u32.to_be_bytes());
        assert!(matches!(
            parse_request(&body).unwrap_err(),
            AgentError::Protocol(_)
        ));
    }

    #[test]
    fn unknown_opcode_is_rejected() {
        let err = parse_request(&[200]).unwrap_err();
        assert!(matches!(err, AgentError::Protocol(_)));
    }

    #[test]
    fn empty_body_is_rejected() {
        assert!(matches!(
            parse_request(&[]).unwrap_err(),
            AgentError::Protocol(_)
        ));
    }

    #[test]
    fn trailing_bytes_are_rejected() {
        let body = vec![SSH_AGENTC_REQUEST_IDENTITIES, 0xAA];
        assert!(matches!(
            parse_request(&body).unwrap_err(),
            AgentError::Protocol(_)
        ));
    }

    #[test]
    fn frame_prefixes_big_endian_length() {
        assert_eq!(frame(b"abc"), [0, 0, 0, 3, b'a', b'b', b'c']);
    }

    #[test]
    fn read_frame_rejects_oversized_length() {
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&((MAX_FRAME_LEN + 1) as u32).to_be_bytes());
        let mut cursor = std::io::Cursor::new(bytes);
        assert!(matches!(
            read_frame(&mut cursor).unwrap_err(),
            AgentError::Protocol(_)
        ));
    }

    #[test]
    fn read_frame_accepts_max_length() {
        let body = vec![SSH_AGENT_FAILURE; MAX_FRAME_LEN];
        let mut cursor = std::io::Cursor::new(frame(&body));
        assert_eq!(read_frame(&mut cursor).unwrap(), Some(body));
    }

    #[test]
    fn read_frame_rejects_zero_length() {
        let mut cursor = std::io::Cursor::new(vec![0u8; 4]);
        assert!(matches!(
            read_frame(&mut cursor).unwrap_err(),
            AgentError::Protocol(_)
        ));
    }

    #[test]
    fn read_frame_eof_at_boundary_is_none() {
        let mut cursor = std::io::Cursor::new(Vec::<u8>::new());
        assert!(read_frame(&mut cursor).unwrap().is_none());
    }

    #[test]
    fn read_frame_partial_length_is_error() {
        let mut cursor = std::io::Cursor::new(vec![0u8, 0u8]);
        assert!(matches!(
            read_frame(&mut cursor).unwrap_err(),
            AgentError::Protocol(_)
        ));
    }

    #[test]
    fn read_frame_truncated_body_is_error() {
        let mut bytes = 8u32.to_be_bytes().to_vec();
        bytes.extend_from_slice(&[1, 2, 3]);
        let mut cursor = std::io::Cursor::new(bytes);
        // Only the failure is pinned; which error variant is used is not asserted.
        assert!(read_frame(&mut cursor).is_err());
    }

    #[test]
    fn read_frame_reads_consecutive_frames_then_eof() {
        let mut bytes = frame(&[SSH_AGENTC_REQUEST_IDENTITIES]);
        bytes.extend(frame(&encode_failure()));
        let mut cursor = std::io::Cursor::new(bytes);
        assert_eq!(
            read_frame(&mut cursor).unwrap(),
            Some(vec![SSH_AGENTC_REQUEST_IDENTITIES])
        );
        assert_eq!(
            read_frame(&mut cursor).unwrap(),
            Some(vec![SSH_AGENT_FAILURE])
        );
        assert!(read_frame(&mut cursor).unwrap().is_none());
    }

    #[test]
    fn read_frame_tolerates_short_reads_and_interrupts() {
        let body = sign_request_body(1);
        let mut reader = Trickle {
            data: frame(&body),
            pos: 0,
            interrupt_next: true,
        };
        assert_eq!(read_frame(&mut reader).unwrap(), Some(body));
        assert!(read_frame(&mut reader).unwrap().is_none());
    }

    // Malformed inputs must only ever produce errors, never panics.
    #[test]
    fn arbitrary_inputs_never_panic() {
        let samples: &[&[u8]] = &[
            &[],
            &[0],
            &[5],
            &[11],
            &[11, 0],
            &[13],
            &[13, 0, 0, 0, 4],
            &[13, 0, 0, 0, 4, 1, 2, 3, 4],
            &[13, 255, 255, 255, 255],
            &[13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        ];
        for s in samples {
            let _ = parse_request(s);
            let _ = read_frame(&mut std::io::Cursor::new(s.to_vec()));
        }
    }
}