1use std::fmt;
2use std::io::{Read, Write};
3use std::ops::DerefMut;
4use std::os::unix::net::UnixStream;
5use std::path::Path;
6use std::str::FromStr;
7
8use byteorder::{BigEndian, ByteOrder, WriteBytesExt};
9use log::*;
10use thiserror::Error;
11use zeroize::Zeroize as _;
12
13use crate::agent::msg;
14use crate::agent::Constraint;
15use crate::encoding::{self, Encodable};
16use crate::encoding::{Buffer, Encoding, Reader};
17
/// A raw 64-byte signature, as returned by [`AgentClient::sign`].
pub type Signature = [u8; 64];
20
/// Errors produced by the ssh-agent client.
#[derive(Debug, Error)]
pub enum Error {
    /// The agent sent a reply of a type this client did not expect.
    #[error("Agent protocol error")]
    AgentProtocolError,
    /// The agent answered with a failure message (also returned by the
    /// non-Unix stub transport, where no agent can be reached).
    #[error("Agent failure")]
    AgentFailure,
    /// `SSH_AUTH_SOCK` was set, but connecting to it failed with `NotFound`.
    #[error("Unable to connect to ssh-agent. The environment variable `SSH_AUTH_SOCK` was set, but it points to a nonexistent file or directory.")]
    BadAuthSock,
    /// An error from the SSH wire-encoding layer, e.g. a truncated reply.
    #[error(transparent)]
    Encoding(#[from] encoding::Error),
    /// The named environment variable could not be read.
    #[error("Environment variable `{0}` not found")]
    EnvVar(&'static str),
    /// An I/O error, e.g. while connecting to or talking over the agent socket.
    #[error(transparent)]
    Io(#[from] std::io::Error),
    /// A boxed error concerning a private key.
    /// NOTE(review): not constructed in this module; meaning inferred from the name.
    #[error(transparent)]
    Private(Box<dyn std::error::Error + Send + Sync + 'static>),
    /// A boxed error concerning a public key.
    /// NOTE(review): not constructed in this module; meaning inferred from the name.
    #[error(transparent)]
    Public(Box<dyn std::error::Error + Send + Sync + 'static>),
    /// A boxed error concerning a signature.
    #[error(transparent)]
    Signature(Box<dyn std::error::Error + Send + Sync + 'static>),
}
43
44impl Error {
45 pub fn is_not_running(&self) -> bool {
46 matches!(self, Self::EnvVar("SSH_AUTH_SOCK"))
47 }
48}
49
/// A client speaking the ssh-agent protocol over the transport stream `S`.
pub struct AgentClient<S> {
    // The connection requests are written to and replies are read from.
    stream: S,
}
54
55impl<S> AgentClient<S> {
56 pub fn connect(stream: S) -> Self {
58 AgentClient { stream }
59 }
60
61 pub fn pid(&self) -> Option<u32> {
63 std::env::var("SSH_AGENT_PID")
64 .ok()
65 .and_then(|v| u32::from_str(&v).ok())
66 }
67}
68
69pub trait ClientStream: Sized + Send + Sync {
70 fn request(&mut self, req: &[u8]) -> Result<Buffer, Error>;
72
73 fn connect<P>(path: P) -> Result<AgentClient<Self>, Error>
75 where
76 P: AsRef<Path> + Send;
77
78 fn connect_env() -> Result<AgentClient<Self>, Error> {
79 let Ok(var) = std::env::var("SSH_AUTH_SOCK") else {
80 return Err(Error::EnvVar("SSH_AUTH_SOCK"));
81 };
82 match Self::connect(var) {
83 Err(Error::Io(io_err)) if io_err.kind() == std::io::ErrorKind::NotFound => {
84 Err(Error::BadAuthSock)
85 }
86 other => other,
87 }
88 }
89}
90
91impl<S: ClientStream> AgentClient<S> {
92 pub fn add_identity<K>(&mut self, key: &K, constraints: &[Constraint]) -> Result<(), Error>
95 where
96 K: Encodable,
97 K::Error: std::error::Error + Send + Sync + 'static,
98 {
99 let mut buf = Buffer::default();
100
101 buf.resize(4, 0);
102
103 if constraints.is_empty() {
104 buf.push(msg::ADD_IDENTITY)
105 } else {
106 buf.push(msg::ADD_ID_CONSTRAINED)
107 }
108 key.write(&mut buf);
109
110 if !constraints.is_empty() {
111 for cons in constraints {
112 match *cons {
113 Constraint::KeyLifetime { seconds } => {
114 buf.push(msg::CONSTRAIN_LIFETIME);
115 buf.deref_mut().write_u32::<BigEndian>(seconds)?
116 }
117 Constraint::Confirm => buf.push(msg::CONSTRAIN_CONFIRM),
118 Constraint::Extensions {
119 ref name,
120 ref details,
121 } => {
122 buf.push(msg::CONSTRAIN_EXTENSION);
123 buf.extend_ssh_string(name);
124 buf.extend_ssh_string(details);
125 }
126 }
127 }
128 }
129 buf.write_len();
130 self.stream.request(&buf)?;
131
132 Ok(())
133 }
134
135 pub fn add_smartcard_key(
138 &mut self,
139 id: &str,
140 pin: &[u8],
141 constraints: &[Constraint],
142 ) -> Result<(), Error> {
143 let mut buf = Buffer::default();
144
145 buf.resize(4, 0);
146
147 if constraints.is_empty() {
148 buf.push(msg::ADD_SMARTCARD_KEY)
149 } else {
150 buf.push(msg::ADD_SMARTCARD_KEY_CONSTRAINED)
151 }
152 buf.extend_ssh_string(id.as_bytes());
153 buf.extend_ssh_string(pin);
154
155 if !constraints.is_empty() {
156 buf.deref_mut()
157 .write_u32::<BigEndian>(constraints.len() as u32)?;
158 for cons in constraints {
159 match *cons {
160 Constraint::KeyLifetime { seconds } => {
161 buf.push(msg::CONSTRAIN_LIFETIME);
162 buf.deref_mut().write_u32::<BigEndian>(seconds)?;
163 }
164 Constraint::Confirm => buf.push(msg::CONSTRAIN_CONFIRM),
165 Constraint::Extensions {
166 ref name,
167 ref details,
168 } => {
169 buf.push(msg::CONSTRAIN_EXTENSION);
170 buf.extend_ssh_string(name);
171 buf.extend_ssh_string(details);
172 }
173 }
174 }
175 }
176 buf.write_len();
177 self.stream.request(&buf)?;
178
179 Ok(())
180 }
181
182 pub fn lock(&mut self, passphrase: &[u8]) -> Result<(), Error> {
184 let mut buf = Buffer::default();
185
186 buf.resize(4, 0);
187 buf.push(msg::LOCK);
188 buf.extend_ssh_string(passphrase);
189 buf.write_len();
190
191 self.stream.request(&buf)?;
192
193 Ok(())
194 }
195
196 pub fn unlock(&mut self, passphrase: &[u8]) -> Result<(), Error> {
198 let mut buf = Buffer::default();
199 buf.resize(4, 0);
200 buf.push(msg::UNLOCK);
201 buf.extend_ssh_string(passphrase);
202 buf.write_len();
203
204 self.stream.request(&buf)?;
205
206 Ok(())
207 }
208
209 pub fn request_identities<K>(&mut self) -> Result<Vec<K>, Error>
212 where
213 K: Encodable,
214 K::Error: std::error::Error + Send + Sync + 'static,
215 {
216 let mut buf = Buffer::default();
217 buf.resize(4, 0);
218 buf.push(msg::REQUEST_IDENTITIES);
219 buf.write_len();
220
221 let mut keys = Vec::new();
222 let resp = self.stream.request(&buf)?;
223
224 if resp[0] == msg::IDENTITIES_ANSWER {
225 let mut r = resp.reader(1);
226 let n = r.read_u32()?;
227
228 for _ in 0..n {
229 let key = r.read_string()?;
230 let _ = r.read_string()?;
231 let mut r = key.reader(0);
232
233 if let Ok(pk) = K::read(&mut r) {
234 keys.push(pk);
235 }
236 }
237 }
238
239 Ok(keys)
240 }
241
242 pub fn sign<K>(&mut self, public: &K, data: &[u8]) -> Result<Signature, Error>
244 where
245 K: Encodable + fmt::Debug,
246 {
247 let req = self.prepare_sign_request(public, data);
248 let resp = self.stream.request(&req)?;
249
250 if !resp.is_empty() && resp[0] == msg::SIGN_RESPONSE {
251 self.read_signature(&resp)
252 } else if !resp.is_empty() && resp[0] == msg::FAILURE {
253 Err(Error::AgentFailure)
254 } else {
255 Err(Error::AgentProtocolError)
256 }
257 }
258
259 fn prepare_sign_request<K>(&self, public: &K, data: &[u8]) -> Buffer
260 where
261 K: Encodable + fmt::Debug,
262 {
263 let mut pk = Buffer::default();
269 public.write(&mut pk);
270
271 let total = 1 + pk.len() + 4 + data.len() + 4;
272
273 let mut buf = Buffer::default();
274 buf.write_u32::<BigEndian>(total as u32)
275 .expect("Writing to a vector never fails");
276 buf.push(msg::SIGN_REQUEST);
277 buf.extend_from_slice(&pk);
278 buf.extend_ssh_string(data);
279
280 buf.write_u32::<BigEndian>(0).unwrap();
282 buf
283 }
284
285 fn read_signature(&self, sig: &[u8]) -> Result<Signature, Error> {
286 let mut r = sig.reader(1);
287 let mut resp = r.read_string()?.reader(0);
288 let _t = resp.read_string()?;
289 let sig = resp.read_string()?;
290
291 let mut out = [0; 64];
292 out.copy_from_slice(sig);
293
294 Ok(out)
295 }
296
297 pub fn remove_identity<K>(&mut self, public: &K) -> Result<(), Error>
299 where
300 K: Encodable,
301 {
302 let mut pk: Buffer = Vec::new().into();
303 public.write(&mut pk);
304
305 let total = 1 + pk.len();
306
307 let mut buf = Buffer::default();
308 buf.write_u32::<BigEndian>(total as u32)?;
309 buf.push(msg::REMOVE_IDENTITY);
310 buf.extend_from_slice(&pk);
311
312 self.stream.request(&buf)?;
313
314 Ok(())
315 }
316
317 pub fn remove_smartcard_key(&mut self, id: &str, pin: &[u8]) -> Result<(), Error> {
319 let mut buf = Buffer::default();
320 buf.resize(4, 0);
321 buf.push(msg::REMOVE_SMARTCARD_KEY);
322 buf.extend_ssh_string(id.as_bytes());
323 buf.extend_ssh_string(pin);
324 buf.write_len();
325
326 self.stream.request(&buf)?;
327
328 Ok(())
329 }
330
331 pub fn remove_all_identities(&mut self) -> Result<(), Error> {
333 let mut buf = Buffer::default();
334 buf.resize(4, 0);
335 buf.push(msg::REMOVE_ALL_IDENTITIES);
336 buf.write_len();
337
338 self.stream.request(&buf)?;
339
340 Ok(())
341 }
342
343 pub fn extension(&mut self, typ: &[u8], ext: &[u8]) -> Result<(), Error> {
345 let mut buf = Buffer::default();
346
347 buf.resize(4, 0);
348 buf.push(msg::EXTENSION);
349 buf.extend_ssh_string(typ);
350 buf.extend_ssh_string(ext);
351 buf.write_len();
352
353 self.stream.request(&buf)?;
354
355 Ok(())
356 }
357
358 pub fn query_extension(&mut self, typ: &[u8], mut ext: Buffer) -> Result<bool, Error> {
360 let mut req = Buffer::default();
361
362 req.resize(4, 0);
363 req.push(msg::EXTENSION);
364 req.extend_ssh_string(typ);
365 req.write_len();
366
367 let resp = self.stream.request(&req)?;
368 let mut r = resp.reader(1);
369 ext.extend(r.read_string()?);
370
371 Ok(!resp.is_empty() && resp[0] == msg::SUCCESS)
372 }
373}
374
/// Stub transport for non-Unix targets, where this client has no way to reach
/// an agent: every operation fails with [`Error::AgentFailure`].
///
/// Implements exactly the `ClientStream` trait methods (the previous version
/// named methods the trait does not have and so could not compile).
/// NOTE(review): the unconditional `use std::os::unix::net::UnixStream;` at
/// the top of this file must also be gated with `#[cfg(unix)]` for non-Unix
/// builds to compile.
#[cfg(not(unix))]
impl ClientStream for std::net::TcpStream {
    fn connect<P>(_: P) -> Result<AgentClient<Self>, Error>
    where
        P: AsRef<Path> + Send,
    {
        Err(Error::AgentFailure)
    }

    fn request(&mut self, _: &[u8]) -> Result<Buffer, Error> {
        Err(Error::AgentFailure)
    }

    fn connect_env() -> Result<AgentClient<Self>, Error> {
        Err(Error::AgentFailure)
    }
}
392
393#[cfg(unix)]
394impl ClientStream for UnixStream {
395 fn connect<P>(path: P) -> Result<AgentClient<Self>, Error>
396 where
397 P: AsRef<Path> + Send,
398 {
399 let stream = UnixStream::connect(path)?;
400
401 Ok(AgentClient { stream })
402 }
403
404 fn request(&mut self, msg: &[u8]) -> Result<Buffer, Error> {
405 let mut resp = Buffer::default();
406
407 self.write_all(msg)?;
409 self.flush()?;
410
411 resp.resize(4, 0);
413 self.read_exact(&mut resp)?;
414
415 let len = BigEndian::read_u32(&resp) as usize;
417 resp.zeroize();
418 resp.resize(len, 0);
419 self.read_exact(&mut resp)?;
420
421 Ok(resp)
422 }
423}