1use crate::error::AgentError;
7use log::{debug, error, info, warn};
8use ssh_agent_lib::proto::Identity;
9use ssh_key::PrivateKey as SshPrivateKey;
10use ssh_key::private::{Ed25519Keypair as SshEd25519Keypair, KeypairData};
11use ssh_key::public::{Ed25519PublicKey, KeyData};
12use std::io::{Read, Write};
13use std::os::unix::net::UnixStream;
14use std::path::Path;
15use std::time::Duration;
16
/// Outcome of probing an SSH agent socket with [`check_agent_status`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AgentStatus {
    /// The socket accepted a connection and answered an identities request.
    Running {
        /// Number of identities in the agent's reply that this module could
        /// parse (only `ssh-ed25519` keys are parsed; other key types are
        /// skipped and not counted).
        key_count: usize,
    },
    /// Something exists at the socket path, but connecting to it or querying
    /// it failed.
    ConnectionFailed,
    /// Nothing exists at the socket path.
    NotRunning,
}
30
/// Message type numbers of the SSH agent protocol (draft-miller-ssh-agent).
mod proto {
    // Requests sent by this client to the agent.
    pub const SSH_AGENTC_REQUEST_IDENTITIES: u8 = 11;
    pub const SSH_AGENTC_SIGN_REQUEST: u8 = 13;
    pub const SSH_AGENTC_ADD_IDENTITY: u8 = 17;
    pub const SSH_AGENTC_REMOVE_ALL_IDENTITIES: u8 = 19;

    // Reply types sent by the agent back to this client.
    pub const SSH_AGENT_FAILURE: u8 = 5;
    pub const SSH_AGENT_SUCCESS: u8 = 6;
    pub const SSH_AGENT_IDENTITIES_ANSWER: u8 = 12;
    pub const SSH_AGENT_SIGN_RESPONSE: u8 = 14;
}
45
46pub fn check_agent_status<P: AsRef<Path>>(socket_path: P) -> AgentStatus {
59 let socket_path = socket_path.as_ref();
60
61 if !socket_path.exists() {
62 debug!("Agent socket does not exist: {:?}", socket_path);
63 return AgentStatus::NotRunning;
64 }
65
66 let mut stream = match UnixStream::connect(socket_path) {
68 Ok(s) => s,
69 Err(e) => {
70 warn!("Failed to connect to agent socket {:?}: {}", socket_path, e);
71 return AgentStatus::ConnectionFailed;
72 }
73 };
74
75 if let Err(e) = stream.set_read_timeout(Some(Duration::from_secs(5))) {
77 warn!("Failed to set read timeout: {}", e);
78 }
79 if let Err(e) = stream.set_write_timeout(Some(Duration::from_secs(5))) {
80 warn!("Failed to set write timeout: {}", e);
81 }
82
83 match request_identities_raw(&mut stream) {
85 Ok(identities) => {
86 info!("Agent is running with {} keys loaded", identities.len());
87 AgentStatus::Running {
88 key_count: identities.len(),
89 }
90 }
91 Err(e) => {
92 warn!("Failed to query agent identities: {}", e);
93 AgentStatus::ConnectionFailed
94 }
95 }
96}
97
98pub fn agent_sign<P: AsRef<Path>>(
112 socket_path: P,
113 pubkey: &[u8],
114 data: &[u8],
115) -> Result<Vec<u8>, AgentError> {
116 let socket_path = socket_path.as_ref();
117
118 if pubkey.len() != 32 {
119 return Err(AgentError::InvalidInput(format!(
120 "Public key must be 32 bytes, got {}",
121 pubkey.len()
122 )));
123 }
124
125 debug!(
126 "Signing via agent at {:?} with pubkey {:?}...",
127 socket_path,
128 hex::encode(&pubkey[..4])
129 );
130
131 let mut stream = UnixStream::connect(socket_path).map_err(|e| {
133 error!("Failed to connect to agent: {}", e);
134 AgentError::IO(e)
135 })?;
136
137 stream
138 .set_read_timeout(Some(Duration::from_secs(30)))
139 .map_err(AgentError::IO)?;
140 stream
141 .set_write_timeout(Some(Duration::from_secs(30)))
142 .map_err(AgentError::IO)?;
143
144 let pubkey_array: [u8; 32] = pubkey
146 .try_into()
147 .map_err(|_| AgentError::InvalidInput("Public key must be exactly 32 bytes".to_string()))?;
148
149 let key_data = KeyData::Ed25519(Ed25519PublicKey(pubkey_array));
150
151 let signature = sign_request_raw(&mut stream, &key_data, data)?;
153
154 debug!("Successfully signed via agent");
155 Ok(signature)
156}
157
158pub fn add_identity<P: AsRef<Path>>(
171 socket_path: P,
172 pkcs8_bytes: &[u8],
173) -> Result<Vec<u8>, AgentError> {
174 let socket_path = socket_path.as_ref();
175
176 debug!("Adding identity to agent at {:?}", socket_path);
177
178 let seed = extract_ed25519_seed(pkcs8_bytes)?;
180
181 let mut stream = UnixStream::connect(socket_path).map_err(|e| {
183 error!("Failed to connect to agent: {}", e);
184 AgentError::IO(e)
185 })?;
186
187 stream
188 .set_read_timeout(Some(Duration::from_secs(30)))
189 .map_err(AgentError::IO)?;
190 stream
191 .set_write_timeout(Some(Duration::from_secs(30)))
192 .map_err(AgentError::IO)?;
193
194 let ssh_keypair = SshEd25519Keypair::from_seed(&seed);
196 let pubkey_bytes = ssh_keypair.public.0.to_vec();
197 let keypair_data = KeypairData::Ed25519(ssh_keypair);
198 let private_key = SshPrivateKey::new(keypair_data, "auths-key")
199 .map_err(|e| AgentError::CryptoError(format!("Failed to create SSH key: {}", e)))?;
200
201 add_identity_raw(&mut stream, &private_key)?;
203
204 info!(
205 "Successfully added identity to agent: {:?}...",
206 hex::encode(&pubkey_bytes[..4])
207 );
208 Ok(pubkey_bytes)
209}
210
211pub fn list_identities<P: AsRef<Path>>(socket_path: P) -> Result<Vec<Vec<u8>>, AgentError> {
220 let socket_path = socket_path.as_ref();
221
222 let mut stream = UnixStream::connect(socket_path).map_err(|e| {
223 error!("Failed to connect to agent: {}", e);
224 AgentError::IO(e)
225 })?;
226
227 stream
228 .set_read_timeout(Some(Duration::from_secs(5)))
229 .map_err(AgentError::IO)?;
230 stream
231 .set_write_timeout(Some(Duration::from_secs(5)))
232 .map_err(AgentError::IO)?;
233
234 let identities = request_identities_raw(&mut stream)?;
235
236 let pubkeys: Vec<Vec<u8>> = identities
237 .into_iter()
238 .filter_map(|id| match id.pubkey {
239 KeyData::Ed25519(pk) => Some(pk.0.to_vec()),
240 _ => None,
241 })
242 .collect();
243
244 Ok(pubkeys)
245}
246
247pub fn remove_all_identities<P: AsRef<Path>>(socket_path: P) -> Result<(), AgentError> {
256 let socket_path = socket_path.as_ref();
257
258 debug!("Removing all identities from agent at {:?}", socket_path);
259
260 let mut stream = UnixStream::connect(socket_path).map_err(|e| {
261 error!("Failed to connect to agent: {}", e);
262 AgentError::IO(e)
263 })?;
264
265 stream
266 .set_read_timeout(Some(Duration::from_secs(5)))
267 .map_err(AgentError::IO)?;
268 stream
269 .set_write_timeout(Some(Duration::from_secs(5)))
270 .map_err(AgentError::IO)?;
271
272 let msg = [proto::SSH_AGENTC_REMOVE_ALL_IDENTITIES];
273 send_message(&mut stream, &msg)?;
274
275 let response = read_message(&mut stream)?;
276 if response.is_empty() {
277 return Err(AgentError::Proto(
278 "Empty remove-all response from agent".to_string(),
279 ));
280 }
281
282 match response[0] {
283 proto::SSH_AGENT_SUCCESS => {
284 info!("All identities removed from agent");
285 Ok(())
286 }
287 proto::SSH_AGENT_FAILURE => Err(AgentError::Proto(
288 "Agent refused to remove identities".to_string(),
289 )),
290 other => Err(AgentError::Proto(format!(
291 "Unexpected remove-all response: {}",
292 other
293 ))),
294 }
295}
296
297fn extract_ed25519_seed(pkcs8_bytes: &[u8]) -> Result<[u8; 32], AgentError> {
301 use pkcs8::PrivateKeyInfo;
302 use pkcs8::der::Decode;
303
304 let pk_info = PrivateKeyInfo::from_der(pkcs8_bytes).map_err(|e| {
306 AgentError::KeyDeserializationError(format!("Failed to parse PKCS#8: {}", e))
307 })?;
308
309 let seed = pk_info.private_key;
310 if seed.len() == 32 {
311 let mut arr = [0u8; 32];
312 arr.copy_from_slice(seed);
313 return Ok(arr);
314 }
315
316 if pkcs8_bytes.len() >= 48 {
319 let mut arr = [0u8; 32];
320 arr.copy_from_slice(&pkcs8_bytes[16..48]);
321 return Ok(arr);
322 }
323
324 Err(AgentError::KeyDeserializationError(format!(
325 "Could not extract Ed25519 seed (got {} bytes)",
326 seed.len()
327 )))
328}
329
330fn request_identities_raw(stream: &mut UnixStream) -> Result<Vec<Identity>, AgentError> {
332 let msg = [proto::SSH_AGENTC_REQUEST_IDENTITIES];
334 send_message(stream, &msg)?;
335
336 let response = read_message(stream)?;
338
339 if response.is_empty() {
340 return Err(AgentError::Proto("Empty response from agent".to_string()));
341 }
342
343 match response[0] {
344 proto::SSH_AGENT_IDENTITIES_ANSWER => parse_identities_answer(&response[1..]),
345 proto::SSH_AGENT_FAILURE => Err(AgentError::Proto("Agent returned failure".to_string())),
346 other => Err(AgentError::Proto(format!(
347 "Unexpected response type: {}",
348 other
349 ))),
350 }
351}
352
353fn parse_identities_answer(data: &[u8]) -> Result<Vec<Identity>, AgentError> {
355 if data.len() < 4 {
356 return Err(AgentError::Proto("Identities answer too short".to_string()));
357 }
358
359 let num_keys = u32::from_be_bytes([data[0], data[1], data[2], data[3]]) as usize;
360 let mut identities = Vec::with_capacity(num_keys);
361 let mut pos = 4;
362
363 for _ in 0..num_keys {
364 if pos + 4 > data.len() {
366 return Err(AgentError::Proto("Truncated key blob length".to_string()));
367 }
368 let blob_len =
369 u32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]) as usize;
370 pos += 4;
371
372 if pos + blob_len > data.len() {
374 return Err(AgentError::Proto("Truncated key blob".to_string()));
375 }
376 let blob = &data[pos..pos + blob_len];
377 pos += blob_len;
378
379 if pos + 4 > data.len() {
381 return Err(AgentError::Proto("Truncated comment length".to_string()));
382 }
383 let comment_len =
384 u32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]) as usize;
385 pos += 4;
386
387 if pos + comment_len > data.len() {
389 return Err(AgentError::Proto("Truncated comment".to_string()));
390 }
391 let comment = String::from_utf8_lossy(&data[pos..pos + comment_len]).to_string();
392 pos += comment_len;
393
394 if let Some(pubkey) = parse_ssh_pubkey_blob(blob) {
396 identities.push(Identity { pubkey, comment });
397 }
398 }
399
400 Ok(identities)
401}
402
403fn parse_ssh_pubkey_blob(blob: &[u8]) -> Option<KeyData> {
405 if blob.len() < 4 {
406 return None;
407 }
408
409 let type_len = u32::from_be_bytes([blob[0], blob[1], blob[2], blob[3]]) as usize;
411 if blob.len() < 4 + type_len {
412 return None;
413 }
414
415 let key_type = std::str::from_utf8(&blob[4..4 + type_len]).ok()?;
416 let rest = &blob[4 + type_len..];
417
418 match key_type {
419 "ssh-ed25519" => {
420 if rest.len() < 4 {
421 return None;
422 }
423 let key_len = u32::from_be_bytes([rest[0], rest[1], rest[2], rest[3]]) as usize;
424 if rest.len() < 4 + key_len || key_len != 32 {
425 return None;
426 }
427 let key_bytes: [u8; 32] = rest[4..4 + 32].try_into().ok()?;
428 Some(KeyData::Ed25519(Ed25519PublicKey(key_bytes)))
429 }
430 _ => None,
431 }
432}
433
434fn sign_request_raw(
436 stream: &mut UnixStream,
437 pubkey: &KeyData,
438 data: &[u8],
439) -> Result<Vec<u8>, AgentError> {
440 let pubkey_blob = encode_pubkey_blob(pubkey)?;
442
443 let mut msg = Vec::new();
445 msg.push(proto::SSH_AGENTC_SIGN_REQUEST);
446
447 msg.extend_from_slice(&(pubkey_blob.len() as u32).to_be_bytes());
449 msg.extend_from_slice(&pubkey_blob);
450
451 msg.extend_from_slice(&(data.len() as u32).to_be_bytes());
453 msg.extend_from_slice(data);
454
455 msg.extend_from_slice(&0u32.to_be_bytes());
457
458 send_message(stream, &msg)?;
459
460 let response = read_message(stream)?;
462
463 if response.is_empty() {
464 return Err(AgentError::Proto("Empty sign response".to_string()));
465 }
466
467 match response[0] {
468 proto::SSH_AGENT_SIGN_RESPONSE => parse_sign_response(&response[1..]),
469 proto::SSH_AGENT_FAILURE => Err(AgentError::SigningFailed(
470 "Agent refused to sign".to_string(),
471 )),
472 other => Err(AgentError::Proto(format!(
473 "Unexpected sign response type: {}",
474 other
475 ))),
476 }
477}
478
479fn parse_sign_response(data: &[u8]) -> Result<Vec<u8>, AgentError> {
481 if data.len() < 4 {
482 return Err(AgentError::Proto("Sign response too short".to_string()));
483 }
484
485 let sig_len = u32::from_be_bytes([data[0], data[1], data[2], data[3]]) as usize;
487 if data.len() < 4 + sig_len {
488 return Err(AgentError::Proto("Truncated signature blob".to_string()));
489 }
490
491 let sig_blob = &data[4..4 + sig_len];
492
493 if sig_blob.len() < 4 {
495 return Err(AgentError::Proto("Signature blob too short".to_string()));
496 }
497
498 let type_len =
499 u32::from_be_bytes([sig_blob[0], sig_blob[1], sig_blob[2], sig_blob[3]]) as usize;
500 if sig_blob.len() < 4 + type_len + 4 {
501 return Err(AgentError::Proto("Truncated signature type".to_string()));
502 }
503
504 let rest = &sig_blob[4 + type_len..];
505 let sig_data_len = u32::from_be_bytes([rest[0], rest[1], rest[2], rest[3]]) as usize;
506 if rest.len() < 4 + sig_data_len {
507 return Err(AgentError::Proto("Truncated signature data".to_string()));
508 }
509
510 Ok(rest[4..4 + sig_data_len].to_vec())
511}
512
513fn encode_pubkey_blob(pubkey: &KeyData) -> Result<Vec<u8>, AgentError> {
515 match pubkey {
516 KeyData::Ed25519(pk) => {
517 let mut blob = Vec::new();
518
519 let key_type = b"ssh-ed25519";
521 blob.extend_from_slice(&(key_type.len() as u32).to_be_bytes());
522 blob.extend_from_slice(key_type);
523
524 blob.extend_from_slice(&32u32.to_be_bytes());
526 blob.extend_from_slice(&pk.0);
527
528 Ok(blob)
529 }
530 _ => Err(AgentError::InvalidInput(
531 "Only Ed25519 keys are supported".to_string(),
532 )),
533 }
534}
535
536fn add_identity_raw(
538 stream: &mut UnixStream,
539 private_key: &SshPrivateKey,
540) -> Result<(), AgentError> {
541 let mut msg = Vec::new();
543 msg.push(proto::SSH_AGENTC_ADD_IDENTITY);
544
545 match private_key.key_data() {
546 KeypairData::Ed25519(kp) => {
547 let key_type = b"ssh-ed25519";
549 msg.extend_from_slice(&(key_type.len() as u32).to_be_bytes());
550 msg.extend_from_slice(key_type);
551
552 msg.extend_from_slice(&32u32.to_be_bytes());
554 msg.extend_from_slice(&kp.public.0);
555
556 let mut priv_bytes = Vec::with_capacity(64);
559 priv_bytes.extend_from_slice(&kp.private.to_bytes());
560 priv_bytes.extend_from_slice(&kp.public.0);
561 msg.extend_from_slice(&(priv_bytes.len() as u32).to_be_bytes());
562 msg.extend_from_slice(&priv_bytes);
563
564 let comment = b"auths-key";
566 msg.extend_from_slice(&(comment.len() as u32).to_be_bytes());
567 msg.extend_from_slice(comment);
568 }
569 _ => {
570 return Err(AgentError::InvalidInput(
571 "Only Ed25519 keys are supported".to_string(),
572 ));
573 }
574 }
575
576 send_message(stream, &msg)?;
577
578 let response = read_message(stream)?;
580
581 if response.is_empty() {
582 return Err(AgentError::Proto("Empty add identity response".to_string()));
583 }
584
585 match response[0] {
586 proto::SSH_AGENT_SUCCESS => Ok(()),
587 proto::SSH_AGENT_FAILURE => Err(AgentError::Proto(
588 "Agent refused to add identity".to_string(),
589 )),
590 other => Err(AgentError::Proto(format!(
591 "Unexpected add identity response: {}",
592 other
593 ))),
594 }
595}
596
597fn send_message(stream: &mut UnixStream, msg: &[u8]) -> Result<(), AgentError> {
599 let len = (msg.len() as u32).to_be_bytes();
600 stream.write_all(&len).map_err(AgentError::IO)?;
601 stream.write_all(msg).map_err(AgentError::IO)?;
602 stream.flush().map_err(AgentError::IO)?;
603 Ok(())
604}
605
606fn read_message(stream: &mut UnixStream) -> Result<Vec<u8>, AgentError> {
608 let mut len_buf = [0u8; 4];
609 stream.read_exact(&mut len_buf).map_err(AgentError::IO)?;
610 let len = u32::from_be_bytes(len_buf) as usize;
611
612 if len > 256 * 1024 {
613 return Err(AgentError::Proto(format!(
614 "Message too large: {} bytes",
615 len
616 )));
617 }
618
619 let mut msg = vec![0u8; len];
620 stream.read_exact(&mut msg).map_err(AgentError::IO)?;
621 Ok(msg)
622}
623
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `bytes` as an SSH wire `string`: a big-endian u32 length
    /// followed by the bytes.
    fn ssh_string(bytes: &[u8]) -> Vec<u8> {
        let mut out = (bytes.len() as u32).to_be_bytes().to_vec();
        out.extend_from_slice(bytes);
        out
    }

    #[test]
    fn test_check_agent_status_not_running() {
        let status = check_agent_status("/nonexistent/path/to/socket.sock");
        assert_eq!(status, AgentStatus::NotRunning);
    }

    #[test]
    fn test_encode_pubkey_blob() {
        let pubkey = Ed25519PublicKey([0x42; 32]);
        let key_data = KeyData::Ed25519(pubkey);
        let blob = encode_pubkey_blob(&key_data).unwrap();

        assert_eq!(blob.len(), 51);
        assert_eq!(&blob[0..4], &11u32.to_be_bytes());
        assert_eq!(&blob[4..15], b"ssh-ed25519");
        assert_eq!(&blob[15..19], &32u32.to_be_bytes());
        assert_eq!(&blob[19..51], &[0x42; 32]);
    }

    #[test]
    fn test_parse_ssh_pubkey_blob() {
        let mut blob = ssh_string(b"ssh-ed25519");
        blob.extend(ssh_string(&[0x42; 32]));

        match parse_ssh_pubkey_blob(&blob) {
            Some(KeyData::Ed25519(pk)) => assert_eq!(pk.0, [0x42; 32]),
            other => panic!("Expected Ed25519 key, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_ssh_pubkey_blob_rejects_other_key_types() {
        let mut blob = ssh_string(b"ssh-rsa");
        blob.extend(ssh_string(&[0x01; 32]));
        assert!(parse_ssh_pubkey_blob(&blob).is_none());
    }

    #[test]
    fn test_parse_ssh_pubkey_blob_rejects_short_key() {
        let mut blob = ssh_string(b"ssh-ed25519");
        blob.extend(ssh_string(&[0x42; 16]));
        assert!(parse_ssh_pubkey_blob(&blob).is_none());
    }

    #[test]
    fn test_parse_identities_answer_roundtrip() {
        let key_data = KeyData::Ed25519(Ed25519PublicKey([0x42; 32]));
        let mut answer = 1u32.to_be_bytes().to_vec();
        answer.extend(ssh_string(&encode_pubkey_blob(&key_data).unwrap()));
        answer.extend(ssh_string(b"my comment"));

        let identities = parse_identities_answer(&answer).unwrap();
        assert_eq!(identities.len(), 1);
        assert_eq!(identities[0].comment, "my comment");
        match &identities[0].pubkey {
            KeyData::Ed25519(pk) => assert_eq!(pk.0, [0x42; 32]),
            other => panic!("Expected Ed25519 key, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_identities_answer_truncated() {
        // Too short to hold the key count.
        assert!(parse_identities_answer(&[0, 0]).is_err());
        // Claims one key but carries no entry data.
        assert!(parse_identities_answer(&1u32.to_be_bytes()).is_err());
    }

    #[test]
    fn test_parse_sign_response() {
        let mut sig_blob = ssh_string(b"ssh-ed25519");
        sig_blob.extend(ssh_string(&[0x07; 64]));
        let body = ssh_string(&sig_blob);

        assert_eq!(parse_sign_response(&body).unwrap(), vec![0x07u8; 64]);
    }

    #[test]
    fn test_parse_sign_response_truncated() {
        let mut body = 10u32.to_be_bytes().to_vec();
        body.extend_from_slice(&[1, 2, 3]);
        assert!(parse_sign_response(&body).is_err());
    }

    #[test]
    fn test_extract_ed25519_seed_pkcs8() {
        let result = extract_ed25519_seed(&[0u8; 10]);
        assert!(result.is_err());
    }

    #[test]
    fn test_extract_ed25519_seed_rfc8410_vector() {
        // Ed25519 private key example from RFC 8410, section 10.3.
        const PKCS8: [u8; 48] = [
            0x30, 0x2e, 0x02, 0x01, 0x00, 0x30, 0x05, 0x06, 0x03, 0x2b, 0x65, 0x70, 0x04, 0x22,
            0x04, 0x20, 0xd4, 0xee, 0x72, 0xdb, 0xf9, 0x13, 0x58, 0x4a, 0xd5, 0xb6, 0xd8, 0xf1,
            0xf7, 0x69, 0xf8, 0xad, 0x3a, 0xfe, 0x7c, 0x28, 0xcb, 0xf1, 0xd4, 0xfb, 0xe0, 0x97,
            0xa8, 0x8f, 0x44, 0x75, 0x58, 0x42,
        ];
        let seed = extract_ed25519_seed(&PKCS8).unwrap();
        assert_eq!(&seed[..], &PKCS8[16..]);
    }
}