bc_components/encrypted_key/
ssh_agent_params.rs

1use std::{cell::RefCell, env, path::Path, rc::Rc};
2
3use anyhow::{Context, Result, bail};
4use bc_crypto::hkdf_hmac_sha256;
5use dcbor::prelude::*;
6use ssh_agent_client_rs::Client;
7
8use super::{KeyDerivation, KeyDerivationMethod, SALT_LEN};
9use crate::{EncryptedMessage, Nonce, Salt, SymmetricKey};
10
11#[allow(dead_code)]
12pub trait SSHAgent {
13    fn list_identities(&mut self) -> Result<Vec<ssh_key::PublicKey>>;
14    fn add_identity(&mut self, key: &ssh_key::PrivateKey) -> Result<()>;
15    fn remove_identity(&mut self, key: &ssh_key::PrivateKey) -> Result<()>;
16    fn remove_all_identities(&mut self) -> Result<()>;
17    fn sign(
18        &mut self,
19        key: &ssh_key::PublicKey,
20        data: &[u8],
21    ) -> Result<ssh_key::Signature>;
22}
23
24impl SSHAgent for Client {
25    fn list_identities(&mut self) -> Result<Vec<ssh_key::PublicKey>> {
26        self.list_identities()
27            .map_err(|e| anyhow::anyhow!(e.to_string()))
28    }
29
30    fn add_identity(&mut self, key: &ssh_key::PrivateKey) -> Result<()> {
31        self.add_identity(key)
32            .map_err(|e| anyhow::anyhow!(e.to_string()))
33    }
34
35    fn remove_identity(&mut self, key: &ssh_key::PrivateKey) -> Result<()> {
36        self.remove_identity(key)
37            .map_err(|e| anyhow::anyhow!(e.to_string()))
38    }
39
40    fn remove_all_identities(&mut self) -> Result<()> {
41        self.remove_all_identities()
42            .map_err(|e| anyhow::anyhow!(e.to_string()))
43    }
44
45    fn sign(
46        &mut self,
47        key: &ssh_key::PublicKey,
48        data: &[u8],
49    ) -> Result<ssh_key::Signature> {
50        self.sign(key, data)
51            .map_err(|e| anyhow::anyhow!(e.to_string()))
52    }
53}
54
55/// Struct representing SSH Agent parameters.
56///
57/// CDDL:
58/// ```cddl
59/// SSHAgentParams = [4, Salt, id: tstr]
60/// ```
61#[derive(Clone)]
62pub struct SSHAgentParams {
63    salt: Salt,
64    id: String,
65
66    agent: Option<Rc<RefCell<dyn SSHAgent + 'static>>>,
67}
68
69impl PartialEq for SSHAgentParams {
70    fn eq(&self, other: &Self) -> bool {
71        self.salt == other.salt && self.id == other.id
72    }
73}
74
75impl Eq for SSHAgentParams {}
76
77impl std::fmt::Debug for SSHAgentParams {
78    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79        f.debug_struct("SSHAgentParams")
80            .field("salt", &self.salt)
81            .field("id", &self.id)
82            .finish()
83    }
84}
85
86impl SSHAgentParams {
87    pub fn new() -> Self {
88        Self::new_opt(
89            Salt::new_with_len(SALT_LEN).unwrap(),
90            String::new(),
91            None,
92        )
93    }
94
95    pub fn new_opt(
96        salt: Salt,
97        id: impl AsRef<str>,
98        agent: Option<Rc<RefCell<dyn SSHAgent + 'static>>>,
99    ) -> Self {
100        Self { salt, id: id.as_ref().to_string(), agent }
101    }
102
103    pub fn salt(&self) -> &Salt { &self.salt }
104
105    pub fn id(&self) -> &String { &self.id }
106
107    pub fn agent(&self) -> Option<Rc<RefCell<dyn SSHAgent + 'static>>> {
108        self.agent.clone()
109    }
110
111    pub fn set_agent(
112        &mut self,
113        agent: Option<Rc<RefCell<dyn SSHAgent + 'static>>>,
114    ) {
115        self.agent = agent;
116    }
117}
118
119/// Connect to whatever socket/pipe `$SSH_AUTH_SOCK` points at.
120pub fn connect_to_ssh_agent() -> Result<Rc<RefCell<dyn SSHAgent + 'static>>> {
121    let sock =
122        env::var("SSH_AUTH_SOCK").context("SSH_AUTH_SOCK env var not set")?;
123    let client =
124        Client::connect(Path::new(&sock)).context("no ssh-agent reachable")?;
125    Ok(Rc::new(RefCell::new(client)))
126}
127
128impl KeyDerivation for SSHAgentParams {
129    const INDEX: usize = KeyDerivationMethod::SSHAgent as usize;
130
131    fn lock(
132        &mut self,
133        content_key: &SymmetricKey,
134        secret: impl AsRef<[u8]>,
135    ) -> Result<EncryptedMessage> {
136        // Convert `secret` to a string for the SSH ID.
137        let id = String::from_utf8(secret.as_ref().to_vec())
138            .context("SSH Agent secret must be a valid UTF-8 string")?;
139
140        // If None call connect_to_agent to get the agent.
141        let agent = self
142            .agent
143            .as_ref()
144            .map_or_else(|| connect_to_ssh_agent(), |a| Ok(a.clone()))?;
145
146        // List all identities in the SSH agent.
147        let ids = agent.borrow_mut().list_identities()?;
148
149        // Filter down to the identities that have Ed25519 keys.
150        let ids: Vec<_> = ids
151            .into_iter()
152            .filter(|k| k.key_data().ed25519().is_some())
153            .collect();
154
155        if ids.is_empty() {
156            bail!("No Ed25519 identities available in SSH agent");
157        }
158
159        // If `id` is empty, use the first available identity, otherwise find
160        // the one matching `id`.
161        let identity = if id.is_empty() {
162            // If there is more than one identity, throw an error.
163            if ids.len() > 1 {
164                bail!("Multiple identities available in SSH agent, but no ID provided");
165            }
166            // Safe to unwrap because we checked that `ids` is not empty
167            ids.first().unwrap()
168        } else {
169            ids.iter()
170                .find(|k| k.comment() == id)
171                .context("No matching identity found")?
172        };
173
174        // Sign the salt with the identity.
175        let salt = self.salt().clone();
176        let sig = agent
177            .borrow_mut()
178            .sign(identity, salt.data())
179            .context("SSH agent refused to sign")?;
180
181        // Derive the symmetric key using HKDF with HMAC-SHA256.
182        let derived_key = SymmetricKey::from_data_ref(hkdf_hmac_sha256(
183            &sig,
184            &salt,
185            SymmetricKey::SYMMETRIC_KEY_SIZE,
186        ))
187        .unwrap(); // Safe to unwrap because SYMMETRIC_KEY_SIZE is valid.
188
189        // Set the ID in the parameters.
190        self.id = id;
191
192        // Encode the method as CBOR data.
193        let encoded_method = self.to_cbor_data();
194
195        // Encrypt the content key with the derived key, using the
196        // encoded method as additional authenticated data.
197        Ok(derived_key.encrypt(
198            content_key,
199            Some(encoded_method),
200            Option::<Nonce>::None,
201        ))
202    }
203
204    fn unlock(
205        &self,
206        encrypted_message: &EncryptedMessage,
207        secret: impl AsRef<[u8]>,
208    ) -> Result<SymmetricKey> {
209        // Convert `secret` to a string for the SSH ID.
210        let id = String::from_utf8(secret.as_ref().to_vec())
211            .context("SSH Agent secret must be a valid UTF-8 string")?;
212
213        // If None call connect_to_agent to get the agent.
214        let agent = self
215            .agent
216            .as_ref()
217            .map_or_else(|| connect_to_ssh_agent(), |a| Ok(a.clone()))?;
218
219        // List all identities in the SSH agent.
220        let ids = agent.borrow_mut().list_identities()?;
221
222        // Filter down to the identities that have Ed25519 keys.
223        let ids: Vec<_> = ids
224            .into_iter()
225            .filter(|k| k.key_data().ed25519().is_some())
226            .collect();
227
228        if ids.is_empty() {
229            bail!("No Ed25519 identities available in SSH agent");
230        }
231
232        // id priority:
233        // 1. `id` passed in as secret if not empty,
234        // 2. `self.id` if not empty,
235        // 3. first available identity.
236        let identity = if !id.is_empty() {
237            ids.iter()
238                .find(|k| k.comment() == id)
239                .context("No matching identity found")?
240        } else if !self.id.is_empty() {
241            ids.iter()
242                .find(|k| k.comment() == self.id)
243                .context("No matching identity found")?
244        } else {
245            // Safe to unwrap because we checked that `ids` is not empty
246            ids.first().unwrap()
247        };
248
249        // Sign the salt with the identity.
250        let sig = agent
251            .borrow_mut()
252            .sign(identity, self.salt.data())
253            .context("SSH agent refused to sign")?;
254
255        // Derive the symmetric key using HKDF with HMAC-SHA256.
256        let derived_key = SymmetricKey::from_data_ref(hkdf_hmac_sha256(
257            &sig,
258            &self.salt,
259            SymmetricKey::SYMMETRIC_KEY_SIZE,
260        ))
261        .unwrap(); // Safe to unwrap because SYMMETRIC_KEY_SIZE is valid.
262
263        // Decrypt the encrypted key with the derived key.
264        let decrypted_key = derived_key
265            .decrypt(encrypted_message)
266            .context("Failed to decrypt the encrypted key")?;
267
268        let content_key = decrypted_key
269            .try_into()
270            .context("Failed to convert decrypted key to SymmetricKey")?;
271
272        // If the decryption was successful, return the symmetric key.
273        Ok(content_key)
274    }
275}
276
277impl std::fmt::Display for SSHAgentParams {
278    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
279        write!(f, r#"SSHAgent("{}")"#, self.id)
280    }
281}
282
283impl Into<CBOR> for SSHAgentParams {
284    fn into(self) -> CBOR {
285        vec![CBOR::from(Self::INDEX), self.salt.into(), self.id.into()].into()
286    }
287}
288
289impl TryFrom<CBOR> for SSHAgentParams {
290    type Error = dcbor::Error;
291
292    fn try_from(cbor: CBOR) -> dcbor::Result<Self> {
293        let a = cbor.try_into_array()?;
294        a.len()
295            .eq(&3)
296            .then_some(())
297            .ok_or_else(|| dcbor::Error::msg("Invalid SSHAgentParams"))?;
298        let mut iter = a.into_iter();
299        let _index: usize = iter
300            .next()
301            .ok_or_else(|| dcbor::Error::msg("Missing index"))?
302            .try_into()?;
303        let salt: Salt = iter
304            .next()
305            .ok_or_else(|| dcbor::Error::msg("Missing salt"))?
306            .try_into()?;
307        let id: String = iter
308            .next()
309            .ok_or_else(|| dcbor::Error::msg("Missing id"))?
310            .try_into()?;
311        Ok(SSHAgentParams { salt, id, agent: None })
312    }
313}
314
/// Helpers shared by the mock-agent and real-agent test modules.
#[cfg(test)]
mod tests_common {
    use std::{cell::RefCell, rc::Rc};

    use dcbor::prelude::*;

    use crate::{
        EncryptedKey, KeyDerivation, KeyDerivationParams, SALT_LEN, SSHAgent,
        SSHAgentParams, Salt,
    };

    /// Comment of the SSH identity the tests expect the agent to hold.
    pub fn test_id() -> String { "your_email@example.com".to_string() }

    /// Locks a fresh content key through `agent`, round-trips the encrypted
    /// key through CBOR, and checks that the decoded key unlocks to the
    /// original content key.
    pub fn test_ssh_agent_params(agent: Rc<RefCell<dyn SSHAgent>>) {
        // Create SSHAgentParams with the agent.
        let params = SSHAgentParams::new_opt(
            Salt::new_with_len(SALT_LEN).unwrap(),
            "",
            Some(agent.clone()),
        );

        // Create a content key to encrypt.
        let content_key = crate::SymmetricKey::new();

        // Empty: let the agent's single identity be used.
        let secret = b"";

        // Lock the content key with the SSH agent parameters.
        let encrypted_key = EncryptedKey::lock_opt(
            KeyDerivationParams::SSHAgent(params),
            secret,
            &content_key,
        )
        .expect("Lock content key with SSH agent params");

        // Serialize the encrypted key to CBOR.
        let cbor_data = encrypted_key.to_cbor_data();

        // Deserialize the CBOR data.
        let cbor = CBOR::try_from_data(cbor_data)
            .expect("Convert encrypted key to CBOR");

        // Convert the CBOR back to an EncryptedKey.
        let encrypted_key_2 = EncryptedKey::try_from_cbor(&cbor)
            .expect("Convert CBOR to EncryptedKey");

        // Extract the SSH agent parameters from the AAD CBOR.
        let aad_cbor = encrypted_key_2
            .aad_cbor()
            .expect("Get AAD CBOR from EncryptedKey");
        let mut params_2 = SSHAgentParams::try_from(aad_cbor)
            .expect("Convert AAD CBOR to SSHAgentParams");

        // Decoding leaves the agent unset, so attach it again.
        params_2.set_agent(Some(agent.clone()));

        // Unlock the *decoded* message so the ciphertext's CBOR round trip is
        // exercised too, not only the parameters'.
        let decrypted_content_key =
            params_2.unlock(encrypted_key_2.encrypted_message(), secret);

        // Assert that the decrypted key matches the original content key.
        assert_eq!(
            content_key,
            decrypted_content_key
                .expect("Unlock content key with SSH agent params")
        );
    }
}
383
/// Tests that run against an in-memory agent and need no external setup.
#[cfg(test)]
mod mock_agent_tests {
    use std::{cell::RefCell, collections::HashMap, rc::Rc};

    use anyhow::Result;

    use super::tests_common::{test_id, test_ssh_agent_params};
    use crate::SSHAgent;

    /// In-memory agent that stores private keys under their comment.
    struct MockSSHAgent {
        identities: HashMap<String, ssh_key::PrivateKey>,
    }

    impl MockSSHAgent {
        fn new() -> Self { Self { identities: HashMap::new() } }

        /// Stores `key`, replacing any key that has the same comment.
        fn add_identity(&mut self, key: ssh_key::PrivateKey) {
            let comment = key.comment().to_string();
            self.identities.insert(comment, key);
        }
    }

    impl SSHAgent for MockSSHAgent {
        fn list_identities(&mut self) -> Result<Vec<ssh_key::PublicKey>> {
            let mut public_keys = Vec::with_capacity(self.identities.len());
            for private_key in self.identities.values() {
                public_keys.push(private_key.public_key().clone());
            }
            Ok(public_keys)
        }

        fn add_identity(&mut self, key: &ssh_key::PrivateKey) -> Result<()> {
            // Delegates to the inherent, by-value method above.
            MockSSHAgent::add_identity(self, key.clone());
            Ok(())
        }

        fn remove_identity(&mut self, key: &ssh_key::PrivateKey) -> Result<()> {
            let _removed = self.identities.remove(key.comment());
            Ok(())
        }

        fn remove_all_identities(&mut self) -> Result<()> {
            self.identities.clear();
            Ok(())
        }

        fn sign(
            &mut self,
            key: &ssh_key::PublicKey,
            data: &[u8],
        ) -> Result<ssh_key::Signature> {
            // The private key is looked up by the public key's comment.
            let private_key = match self.identities.get(key.comment()) {
                Some(found) => found,
                None => return Err(anyhow::anyhow!("Identity not found")),
            };
            let ssh_sig: ssh_key::SshSig = private_key
                .sign("test_namespace", ssh_key::HashAlg::Sha256, data)
                .map_err(|e| anyhow::anyhow!("Failed to sign data: {}", e))?;
            Ok(ssh_sig.signature().clone())
        }
    }

    /// Builds a mock agent holding one new Ed25519 key commented `test_id()`.
    fn mock_agent() -> Rc<RefCell<dyn SSHAgent>> {
        let mut rng = bc_rand::SecureRandomNumberGenerator;
        let keypair: ssh_key::private::Ed25519Keypair =
            ssh_key::private::Ed25519Keypair::random(&mut rng);
        let private_key =
            ssh_key::PrivateKey::new(keypair.into(), test_id()).unwrap();

        let mut agent = MockSSHAgent::new();
        agent.add_identity(private_key);
        Rc::new(RefCell::new(agent))
    }

    #[test]
    fn test_mock_agent() {
        let agent = mock_agent();
        let identities = agent.borrow_mut().list_identities().unwrap();
        assert!(!identities.is_empty(), "No identities found in SSH agent");

        let first_identity = &identities[0];
        assert_eq!(first_identity.comment(), test_id());

        // Signing the same data twice must give the same signature.
        let data = b"test data";
        let first_sig = agent.borrow_mut().sign(first_identity, data).unwrap();
        let second_sig = agent.borrow_mut().sign(first_identity, data).unwrap();
        assert_eq!(
            first_sig, second_sig,
            "Signatures should match for the same data"
        );
    }

    #[test]
    fn test_ssh_agent_params_with_mock_agent() {
        // Run the shared lock/round-trip/unlock scenario on a mock agent.
        test_ssh_agent_params(mock_agent());
    }
}
486
/// Tests that talk to a real SSH agent.
///
/// They require a running SSH agent that holds at least one Ed25519 identity
/// whose comment is `your_email@example.com`.
///
/// Run them with:
/// ```bash
/// cargo test real_agent_tests --features ssh_agent_tests
/// ```
///
/// The `SSH_AUTH_SOCK` environment variable must name the socket the agent
/// is listening on. It is usually set automatically when the agent starts;
/// check it with:
/// ```bash
/// echo $SSH_AUTH_SOCK
/// ```
///
/// List the keys currently held by the agent with:
/// ```bash
/// ssh-add -l
/// ```
///
/// Generate a new Ed25519 key and add it to the agent as the test identity
/// with:
/// ```bash
/// ssh-keygen -t ed25519 -C "your_email@example.com" -f <your_key_file>
/// ssh-add <your_key_file>
/// ```
#[cfg(all(test, feature = "ssh_agent_tests"))]
mod real_agent_tests {
    use dcbor::prelude::*;

    use super::tests_common::{test_id, test_ssh_agent_params};
    use crate::{
        EncryptedKey, KeyDerivationMethod, SymmetricKey, connect_to_ssh_agent,
    };

    /// A new content key for a single test.
    pub fn test_content_key() -> SymmetricKey { SymmetricKey::new() }

    #[test]
    fn test_ssh_agent_params_with_real_agent() {
        // Run the shared lock/round-trip/unlock scenario on the real agent.
        let agent = connect_to_ssh_agent().expect("Connect to SSH agent");
        test_ssh_agent_params(agent);
    }

    #[test]
    fn test_encrypted_key_ssh_agent_roundtrip() {
        let id = test_id();
        let content_key = test_content_key();

        let encrypted_key = EncryptedKey::lock(
            KeyDerivationMethod::SSHAgent,
            id.clone(),
            &content_key,
        )
        .unwrap();

        // The display form names the method and the identity id.
        let expected = format!(r#"EncryptedKey(SSHAgent("{}"))"#, id);
        assert_eq!(encrypted_key.to_string(), expected);

        // Round-trip through CBOR, then unlock the decoded key.
        let cbor = encrypted_key.clone().to_cbor();
        let decoded = EncryptedKey::try_from(cbor).unwrap();
        let decrypted = EncryptedKey::unlock(&decoded, id).unwrap();

        assert_eq!(content_key, decrypted);
    }

    #[test]
    fn test_encrypted_key_ssh_agent_wrong_secret_fails() {
        let content_key = test_content_key();
        let encrypted = EncryptedKey::lock(
            KeyDerivationMethod::SSHAgent,
            test_id(),
            &content_key,
        )
        .unwrap();

        let wrong_secret = b"wrong secret";
        let outcome = EncryptedKey::unlock(&encrypted, wrong_secret);
        assert!(outcome.is_err(), "Unlock should fail with wrong secret");
    }

    #[test]
    fn test_ssh_agent_lock_fails_with_nonexistent_identity() {
        let content_key = test_content_key();
        let secret = b"nonexistent_identity";

        let outcome = EncryptedKey::lock(
            KeyDerivationMethod::SSHAgent,
            secret,
            &content_key,
        );
        assert!(
            outcome.is_err(),
            "Lock should fail with nonexistent identity"
        );
    }
}