1use std::fmt;
2use std::io::{Read, Write};
3use std::path::{Path, PathBuf};
4
5#[cfg(unix)]
6pub use std::os::unix::net::UnixStream as Stream;
7
8#[cfg(windows)]
9pub use winpipe::WinStream as Stream;
10
11use thiserror::Error;
12use zeroize::Zeroize as _;
13
14use crate::agent::msg;
15use crate::agent::Constraint;
16use crate::encoding::{self, Encodable};
17use crate::encoding::{Buffer, Encoding, Reader};
18
/// Fixed-size signature bytes extracted from an agent `SIGN_RESPONSE`.
///
/// NOTE(review): the 64-byte size presumably targets Ed25519 signatures — confirm before
/// using this client with other key algorithms.
pub type Signature = [u8; 64];
21
22#[derive(Debug, Error)]
23pub enum Error {
24 #[error("SSH agent replied with unexpected data, violating the SSH agent protocol.")]
26 AgentProtocolError,
27 #[error(
28 "SSH agent replied with failure (protocol message number 5), which could not be handled."
29 )]
30 AgentFailure,
31 #[error("Unable to connect to SSH agent because '{path}' was not found: {source}")]
32 BadAuthSock {
33 path: String,
34 source: std::io::Error,
35 },
36 #[error("Encoding error while communicating with SSH agent: {0}")]
37 Encoding(#[from] encoding::Error),
38 #[error("Unable to read environment variable '{var}': {source}")]
39 EnvVar {
40 var: String,
41 source: std::env::VarError,
42 },
43 #[error("Unable to connect SSH agent using the path '{path}': {source}")]
44 Connect {
45 path: String,
46 #[source]
47 source: std::io::Error,
48 },
49 #[error("I/O error while communicating with SSH agent: {0}")]
50 Io(#[from] std::io::Error),
51}
52
53impl Error {
54 pub fn is_not_running(&self) -> bool {
55 matches!(self, Self::EnvVar { .. } | Self::BadAuthSock { .. })
56 }
57}
58
/// Synchronous client for the SSH agent protocol.
///
/// Generic over its transport `S`, which defaults to the platform `Stream`
/// (a Unix domain socket on Unix targets, `winpipe::WinStream` on Windows).
pub struct AgentClient<S = Stream> {
    /// Location the client was connected through, when one was recorded.
    path: Option<PathBuf>,
    /// Transport over which framed agent requests and replies are exchanged.
    stream: S,
}
67
68impl<S> AgentClient<S> {
69 pub fn path(&self) -> Option<&Path> {
70 self.path.as_deref()
71 }
72}
73
74impl AgentClient<Stream> {
75 pub fn connect<P>(path: P) -> Result<Self, Error>
77 where
78 P: AsRef<Path>,
79 {
80 let path = path.as_ref().to_owned();
81
82 let stream = match Stream::connect(&path) {
83 Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
84 return Err(Error::BadAuthSock {
85 path: path.display().to_string(),
86 source: err,
87 })
88 }
89 Err(err) => {
90 return Err(Error::Connect {
91 path: path.display().to_string(),
92 source: err,
93 })
94 }
95 Ok(stream) => stream,
96 };
97
98 Ok(Self {
99 path: Some(path),
100 stream,
101 })
102 }
103
104 pub fn connect_env() -> Result<Self, Error> {
105 const SSH_AUTH_SOCK: &str = "SSH_AUTH_SOCK";
106
107 let path = match std::env::var(SSH_AUTH_SOCK) {
108 Ok(var) => var,
109 Err(err) => {
110 if cfg!(windows) {
111 "\\\\.\\pipe\\openssh-ssh-agent".to_string()
115 } else {
116 return Err(Error::EnvVar {
117 var: SSH_AUTH_SOCK.to_string(),
118 source: err,
119 });
120 }
121 }
122 };
123
124 Self::connect(path)
125 }
126}
127
128impl<Stream: ClientStream> AgentClient<Stream> {
129 pub fn new(path: Option<PathBuf>, stream: Stream) -> Self {
130 Self { path, stream }
131 }
132
133 pub fn add_identity<K>(&mut self, key: &K, constraints: &[Constraint]) -> Result<(), Error>
136 where
137 K: Encodable,
138 K::Error: std::error::Error + Send + Sync + 'static,
139 {
140 let mut buf = Buffer::default();
141
142 buf.resize(4, 0);
143
144 if constraints.is_empty() {
145 buf.push(msg::ADD_IDENTITY)
146 } else {
147 buf.push(msg::ADD_ID_CONSTRAINED)
148 }
149 key.write(&mut buf);
150
151 if !constraints.is_empty() {
152 for cons in constraints {
153 match *cons {
154 Constraint::KeyLifetime { seconds } => {
155 buf.push(msg::CONSTRAIN_LIFETIME);
156 buf.extend_u32(seconds);
157 }
158 Constraint::Confirm => buf.push(msg::CONSTRAIN_CONFIRM),
159 Constraint::Extensions {
160 ref name,
161 ref details,
162 } => {
163 buf.push(msg::CONSTRAIN_EXTENSION);
164 buf.extend_ssh_string(name);
165 buf.extend_ssh_string(details);
166 }
167 }
168 }
169 }
170 buf.write_len();
171 self.stream.request(&buf)?;
172
173 Ok(())
174 }
175
176 pub fn add_smartcard_key(
179 &mut self,
180 id: &str,
181 pin: &[u8],
182 constraints: &[Constraint],
183 ) -> Result<(), Error> {
184 let mut buf = Buffer::default();
185
186 buf.resize(4, 0);
187
188 if constraints.is_empty() {
189 buf.push(msg::ADD_SMARTCARD_KEY)
190 } else {
191 buf.push(msg::ADD_SMARTCARD_KEY_CONSTRAINED)
192 }
193 buf.extend_ssh_string(id.as_bytes());
194 buf.extend_ssh_string(pin);
195
196 if !constraints.is_empty() {
197 buf.extend_usize(constraints.len());
198 for cons in constraints {
199 match *cons {
200 Constraint::KeyLifetime { seconds } => {
201 buf.push(msg::CONSTRAIN_LIFETIME);
202 buf.extend_u32(seconds);
203 }
204 Constraint::Confirm => buf.push(msg::CONSTRAIN_CONFIRM),
205 Constraint::Extensions {
206 ref name,
207 ref details,
208 } => {
209 buf.push(msg::CONSTRAIN_EXTENSION);
210 buf.extend_ssh_string(name);
211 buf.extend_ssh_string(details);
212 }
213 }
214 }
215 }
216 buf.write_len();
217 self.stream.request(&buf)?;
218
219 Ok(())
220 }
221
222 pub fn lock(&mut self, passphrase: &[u8]) -> Result<(), Error> {
224 let mut buf = Buffer::default();
225
226 buf.resize(4, 0);
227 buf.push(msg::LOCK);
228 buf.extend_ssh_string(passphrase);
229 buf.write_len();
230
231 self.stream.request(&buf)?;
232
233 Ok(())
234 }
235
236 pub fn unlock(&mut self, passphrase: &[u8]) -> Result<(), Error> {
238 let mut buf = Buffer::default();
239 buf.resize(4, 0);
240 buf.push(msg::UNLOCK);
241 buf.extend_ssh_string(passphrase);
242 buf.write_len();
243
244 self.stream.request(&buf)?;
245
246 Ok(())
247 }
248
249 pub fn request_identities<K>(&mut self) -> Result<Vec<K>, Error>
252 where
253 K: Encodable,
254 K::Error: std::error::Error + Send + Sync + 'static,
255 {
256 let mut buf = Buffer::default();
257 buf.resize(4, 0);
258 buf.push(msg::REQUEST_IDENTITIES);
259 buf.write_len();
260
261 let mut keys = Vec::new();
262 let resp = self.stream.request(&buf)?;
263
264 if resp[0] == msg::IDENTITIES_ANSWER {
265 let mut r = resp.reader(1);
266 let n = r.read_u32()?;
267
268 for _ in 0..n {
269 let key = r.read_string()?;
270 let _ = r.read_string()?;
271 let mut r = key.reader(0);
272
273 if let Ok(pk) = K::read(&mut r) {
274 keys.push(pk);
275 }
276 }
277 }
278
279 Ok(keys)
280 }
281
282 pub fn sign<K>(&mut self, public: &K, data: &[u8]) -> Result<Signature, Error>
284 where
285 K: Encodable + fmt::Debug,
286 {
287 let req = self.prepare_sign_request(public, data);
288 let resp = self.stream.request(&req)?;
289
290 if !resp.is_empty() && resp[0] == msg::SIGN_RESPONSE {
291 self.read_signature(&resp)
292 } else if !resp.is_empty() && resp[0] == msg::FAILURE {
293 Err(Error::AgentFailure)
294 } else {
295 Err(Error::AgentProtocolError)
296 }
297 }
298
299 fn prepare_sign_request<K>(&self, public: &K, data: &[u8]) -> Buffer
300 where
301 K: Encodable + fmt::Debug,
302 {
303 let mut pk = Buffer::default();
309 public.write(&mut pk);
310
311 let total = 1 + pk.len() + 4 + data.len() + 4;
312
313 let mut buf = Buffer::default();
314 buf.extend_usize(total);
315 buf.push(msg::SIGN_REQUEST);
316 buf.extend_from_slice(&pk);
317 buf.extend_ssh_string(data);
318
319 buf.extend_u32(0);
321 buf
322 }
323
324 fn read_signature(&self, sig: &[u8]) -> Result<Signature, Error> {
325 let mut r = sig.reader(1);
326 let mut resp = r.read_string()?.reader(0);
327 let _t = resp.read_string()?;
328 let sig = resp.read_string()?;
329
330 let mut out = [0; 64];
331 out.copy_from_slice(sig);
332
333 Ok(out)
334 }
335
336 pub fn remove_identity<K>(&mut self, public: &K) -> Result<(), Error>
338 where
339 K: Encodable,
340 {
341 let mut pk: Buffer = Vec::new().into();
342 public.write(&mut pk);
343
344 let total = 1 + pk.len();
345
346 let mut buf = Buffer::default();
347 buf.extend_usize(total);
348 buf.push(msg::REMOVE_IDENTITY);
349 buf.extend_from_slice(&pk);
350
351 self.stream.request(&buf)?;
352
353 Ok(())
354 }
355
356 pub fn remove_smartcard_key(&mut self, id: &str, pin: &[u8]) -> Result<(), Error> {
358 let mut buf = Buffer::default();
359 buf.resize(4, 0);
360 buf.push(msg::REMOVE_SMARTCARD_KEY);
361 buf.extend_ssh_string(id.as_bytes());
362 buf.extend_ssh_string(pin);
363 buf.write_len();
364
365 self.stream.request(&buf)?;
366
367 Ok(())
368 }
369
370 pub fn remove_all_identities(&mut self) -> Result<(), Error> {
372 let mut buf = Buffer::default();
373 buf.resize(4, 0);
374 buf.push(msg::REMOVE_ALL_IDENTITIES);
375 buf.write_len();
376
377 self.stream.request(&buf)?;
378
379 Ok(())
380 }
381
382 pub fn extension(&mut self, typ: &[u8], ext: &[u8]) -> Result<(), Error> {
384 let mut buf = Buffer::default();
385
386 buf.resize(4, 0);
387 buf.push(msg::EXTENSION);
388 buf.extend_ssh_string(typ);
389 buf.extend_ssh_string(ext);
390 buf.write_len();
391
392 self.stream.request(&buf)?;
393
394 Ok(())
395 }
396
397 pub fn query_extension(&mut self, typ: &[u8], mut ext: Buffer) -> Result<bool, Error> {
399 let mut req = Buffer::default();
400
401 req.resize(4, 0);
402 req.push(msg::EXTENSION);
403 req.extend_ssh_string(typ);
404 req.write_len();
405
406 let resp = self.stream.request(&req)?;
407 let mut r = resp.reader(1);
408 ext.extend(r.read_string()?);
409
410 Ok(!resp.is_empty() && resp[0] == msg::SUCCESS)
411 }
412}
413
414pub trait ClientStream: Sized + Send + Sync {
415 fn request(&mut self, msg: &[u8]) -> Result<Buffer, Error>;
416}
417
418impl<S: Read + Write + Sized + Send + Sync> ClientStream for S {
419 fn request(&mut self, msg: &[u8]) -> Result<Buffer, Error> {
420 let mut resp = Buffer::default();
421
422 self.write_all(msg)?;
424 self.flush()?;
425
426 resp.resize(4, 0);
428 self.read_exact(&mut resp)?;
429
430 let len = u32::from_be_bytes(resp.as_slice().try_into().unwrap()) as usize;
432
433 resp.zeroize();
434 resp.resize(len, 0);
435 self.read_exact(&mut resp)?;
436
437 Ok(resp)
438 }
439}