bc_components/encrypted_key/ssh_agent_params.rs

1use std::{cell::RefCell, env, path::Path, rc::Rc};
2
3use bc_crypto::hkdf_hmac_sha256;
4use dcbor::prelude::*;
5use ssh_agent_client_rs::{Client, Identity};
6
7use super::{KeyDerivation, KeyDerivationMethod, SALT_LEN};
8use crate::{EncryptedMessage, Error, Nonce, Result, Salt, SymmetricKey};
9
10#[allow(dead_code)]
11pub trait SSHAgent {
12    fn list_identities(&mut self) -> Result<Vec<ssh_key::PublicKey>>;
13    fn add_identity(&mut self, key: &ssh_key::PrivateKey) -> Result<()>;
14    fn remove_identity(&mut self, key: &ssh_key::PrivateKey) -> Result<()>;
15    fn remove_all_identities(&mut self) -> Result<()>;
16    fn sign(
17        &mut self,
18        key: &ssh_key::PublicKey,
19        data: &[u8],
20    ) -> Result<ssh_key::Signature>;
21}
22
23impl SSHAgent for Client {
24    fn list_identities(&mut self) -> Result<Vec<ssh_key::PublicKey>> {
25        self.list_all_identities()
26            .map(|identities| {
27                identities
28                    .into_iter()
29                    .filter_map(|i| match i {
30                        Identity::PublicKey(pk) => Some(pk.into_owned()),
31                        _ => None,
32                    })
33                    .collect()
34            })
35            .map_err(|e| Error::ssh_agent(e.to_string()))
36    }
37
38    fn add_identity(&mut self, key: &ssh_key::PrivateKey) -> Result<()> {
39        self.add_identity(key)
40            .map_err(|e| Error::ssh_agent(e.to_string()))
41    }
42
43    fn remove_identity(&mut self, key: &ssh_key::PrivateKey) -> Result<()> {
44        self.remove_identity(key)
45            .map_err(|e| Error::ssh_agent(e.to_string()))
46    }
47
48    fn remove_all_identities(&mut self) -> Result<()> {
49        self.remove_all_identities()
50            .map_err(|e| Error::ssh_agent(e.to_string()))
51    }
52
53    fn sign(
54        &mut self,
55        key: &ssh_key::PublicKey,
56        data: &[u8],
57    ) -> Result<ssh_key::Signature> {
58        self.sign(key, data)
59            .map_err(|e| Error::ssh_agent(e.to_string()))
60    }
61}
62
63/// Struct representing SSH Agent parameters.
64///
65/// CDDL:
66/// ```cddl
67/// SSHAgentParams = [4, Salt, id: tstr]
68/// ```
69#[derive(Clone)]
70pub struct SSHAgentParams {
71    salt: Salt,
72    id: String,
73
74    agent: Option<Rc<RefCell<dyn SSHAgent + 'static>>>,
75}
76
77impl PartialEq for SSHAgentParams {
78    fn eq(&self, other: &Self) -> bool {
79        self.salt == other.salt && self.id == other.id
80    }
81}
82
83impl Eq for SSHAgentParams {}
84
85impl std::fmt::Debug for SSHAgentParams {
86    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87        f.debug_struct("SSHAgentParams")
88            .field("salt", &self.salt)
89            .field("id", &self.id)
90            .finish()
91    }
92}
93
94impl SSHAgentParams {
95    pub fn new() -> Self {
96        Self::new_opt(
97            Salt::new_with_len(SALT_LEN).unwrap(),
98            String::new(),
99            None,
100        )
101    }
102
103    pub fn new_opt(
104        salt: Salt,
105        id: impl AsRef<str>,
106        agent: Option<Rc<RefCell<dyn SSHAgent + 'static>>>,
107    ) -> Self {
108        Self { salt, id: id.as_ref().to_string(), agent }
109    }
110
111    pub fn salt(&self) -> &Salt { &self.salt }
112
113    pub fn id(&self) -> &String { &self.id }
114
115    pub fn agent(&self) -> Option<Rc<RefCell<dyn SSHAgent + 'static>>> {
116        self.agent.clone()
117    }
118
119    pub fn set_agent(
120        &mut self,
121        agent: Option<Rc<RefCell<dyn SSHAgent + 'static>>>,
122    ) {
123        self.agent = agent;
124    }
125}
126
127impl Default for SSHAgentParams {
128    fn default() -> Self { Self::new() }
129}
130
131/// Connect to whatever socket/pipe `$SSH_AUTH_SOCK` points at.
132pub fn connect_to_ssh_agent() -> Result<Rc<RefCell<dyn SSHAgent + 'static>>> {
133    let sock = env::var("SSH_AUTH_SOCK")
134        .map_err(|_| Error::ssh_agent("SSH_AUTH_SOCK env var not set"))?;
135    let client = Client::connect(Path::new(&sock))
136        .map_err(|_| Error::ssh_agent("no ssh-agent reachable"))?;
137    Ok(Rc::new(RefCell::new(client)))
138}
139
140impl KeyDerivation for SSHAgentParams {
141    const INDEX: usize = KeyDerivationMethod::SSHAgent as usize;
142
143    fn lock(
144        &mut self,
145        content_key: &SymmetricKey,
146        secret: impl AsRef<[u8]>,
147    ) -> Result<EncryptedMessage> {
148        // Convert `secret` to a string for the SSH ID.
149        let id = String::from_utf8(secret.as_ref().to_vec()).map_err(|_| {
150            Error::ssh_agent("SSH Agent secret must be a valid UTF-8 string")
151        })?;
152
153        // If None call connect_to_agent to get the agent.
154        let agent = self
155            .agent
156            .as_ref()
157            .map_or_else(|| connect_to_ssh_agent(), |a| Ok(a.clone()))?;
158
159        // List all identities in the SSH agent.
160        let ids = agent.borrow_mut().list_identities()?;
161
162        // Filter down to the identities that have Ed25519 keys.
163        let ids: Vec<_> = ids
164            .into_iter()
165            .filter(|k| k.key_data().ed25519().is_some())
166            .collect();
167
168        if ids.is_empty() {
169            return Err(Error::ssh_agent(
170                "No Ed25519 identities available in SSH agent",
171            ));
172        }
173
174        // If `id` is empty, use the first available identity, otherwise find
175        // the one matching `id`.
176        let identity = if id.is_empty() {
177            // If there is more than one identity, throw an error.
178            if ids.len() > 1 {
179                return Err(Error::ssh_agent(
180                    "Multiple identities available in SSH agent, but no ID provided",
181                ));
182            }
183            // Safe to unwrap because we checked that `ids` is not empty
184            ids.first().unwrap()
185        } else {
186            ids.iter()
187                .find(|k| k.comment() == id)
188                .ok_or_else(|| Error::ssh_agent("No matching identity found"))?
189        };
190
191        // Sign the salt with the identity.
192        let salt = self.salt().clone();
193        let sig = agent
194            .borrow_mut()
195            .sign(identity, salt.as_bytes())
196            .map_err(|_| Error::ssh_agent("SSH agent refused to sign"))?;
197
198        // Derive the symmetric key using HKDF with HMAC-SHA256.
199        let derived_key = SymmetricKey::from_data_ref(hkdf_hmac_sha256(
200            &sig,
201            &salt,
202            SymmetricKey::SYMMETRIC_KEY_SIZE,
203        ))
204        .unwrap(); // Safe to unwrap because SYMMETRIC_KEY_SIZE is valid.
205
206        // Set the ID in the parameters.
207        self.id = id;
208
209        // Encode the method as CBOR data.
210        let encoded_method = self.to_cbor_data();
211
212        // Encrypt the content key with the derived key, using the
213        // encoded method as additional authenticated data.
214        Ok(derived_key.encrypt(
215            content_key,
216            Some(encoded_method),
217            Option::<Nonce>::None,
218        ))
219    }
220
221    fn unlock(
222        &self,
223        encrypted_message: &EncryptedMessage,
224        secret: impl AsRef<[u8]>,
225    ) -> Result<SymmetricKey> {
226        // Convert `secret` to a string for the SSH ID.
227        let id = String::from_utf8(secret.as_ref().to_vec()).map_err(|_| {
228            Error::ssh_agent("SSH Agent secret must be a valid UTF-8 string")
229        })?;
230
231        // If None call connect_to_agent to get the agent.
232        let agent = self
233            .agent
234            .as_ref()
235            .map_or_else(|| connect_to_ssh_agent(), |a| Ok(a.clone()))?;
236
237        // List all identities in the SSH agent.
238        let ids = agent.borrow_mut().list_identities()?;
239
240        // Filter down to the identities that have Ed25519 keys.
241        let ids: Vec<_> = ids
242            .into_iter()
243            .filter(|k| k.key_data().ed25519().is_some())
244            .collect();
245
246        if ids.is_empty() {
247            return Err(Error::ssh_agent(
248                "No Ed25519 identities available in SSH agent",
249            ));
250        }
251
252        // id priority:
253        // 1. `id` passed in as secret if not empty,
254        // 2. `self.id` if not empty,
255        // 3. first available identity.
256        let identity = if !id.is_empty() {
257            ids.iter()
258                .find(|k| k.comment() == id)
259                .ok_or_else(|| Error::ssh_agent("No matching identity found"))?
260        } else if !self.id.is_empty() {
261            ids.iter()
262                .find(|k| k.comment() == self.id)
263                .ok_or_else(|| Error::ssh_agent("No matching identity found"))?
264        } else {
265            // Safe to unwrap because we checked that `ids` is not empty
266            ids.first().unwrap()
267        };
268
269        // Sign the salt with the identity.
270        let sig = agent
271            .borrow_mut()
272            .sign(identity, self.salt.as_bytes())
273            .map_err(|_| Error::ssh_agent("SSH agent refused to sign"))?;
274
275        // Derive the symmetric key using HKDF with HMAC-SHA256.
276        let derived_key = SymmetricKey::from_data_ref(hkdf_hmac_sha256(
277            &sig,
278            &self.salt,
279            SymmetricKey::SYMMETRIC_KEY_SIZE,
280        ))
281        .unwrap(); // Safe to unwrap because SYMMETRIC_KEY_SIZE is valid.
282
283        // Decrypt the encrypted key with the derived key.
284        let decrypted_key =
285            derived_key.decrypt(encrypted_message).map_err(|e| {
286                Error::crypto(format!(
287                    "Failed to decrypt the encrypted key: {}",
288                    e
289                ))
290            })?;
291
292        let content_key = decrypted_key.try_into().map_err(|e| {
293            Error::crypto(format!(
294                "Failed to convert decrypted key to SymmetricKey: {}",
295                e
296            ))
297        })?;
298
299        // If the decryption was successful, return the symmetric key.
300        Ok(content_key)
301    }
302}
303
304impl std::fmt::Display for SSHAgentParams {
305    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
306        write!(f, r#"SSHAgent("{}")"#, self.id)
307    }
308}
309
310impl From<SSHAgentParams> for CBOR {
311    fn from(val: SSHAgentParams) -> Self {
312        vec![
313            CBOR::from(SSHAgentParams::INDEX),
314            val.salt.into(),
315            val.id.into(),
316        ]
317        .into()
318    }
319}
320
321impl TryFrom<CBOR> for SSHAgentParams {
322    type Error = dcbor::Error;
323
324    fn try_from(cbor: CBOR) -> dcbor::Result<Self> {
325        let a = cbor.try_into_array()?;
326        a.len()
327            .eq(&3)
328            .then_some(())
329            .ok_or_else(|| dcbor::Error::msg("Invalid SSHAgentParams"))?;
330        let mut iter = a.into_iter();
331        let _index: usize = iter
332            .next()
333            .ok_or_else(|| dcbor::Error::msg("Missing index"))?
334            .try_into()?;
335        let salt: Salt = iter
336            .next()
337            .ok_or_else(|| dcbor::Error::msg("Missing salt"))?
338            .try_into()?;
339        let id: String = iter
340            .next()
341            .ok_or_else(|| dcbor::Error::msg("Missing id"))?
342            .try_into()?;
343        Ok(SSHAgentParams { salt, id, agent: None })
344    }
345}
346
#[cfg(test)]
mod tests_common {
    use std::{cell::RefCell, rc::Rc};

    use dcbor::prelude::*;

    use crate::{
        EncryptedKey, KeyDerivation, KeyDerivationParams, SALT_LEN, SSHAgent,
        SSHAgentParams, Salt,
    };

    /// Comment of the Ed25519 identity the tests expect the agent to hold.
    pub fn test_id() -> String { "your_email@example.com".to_string() }

    /// Round-trips a content key through `SSHAgentParams` using `agent`.
    ///
    /// The steps are:
    /// 1. Lock with an empty secret, so the agent must hold exactly one
    ///    Ed25519 identity.
    /// 2. Serialize the resulting `EncryptedKey` to CBOR and decode it back.
    /// 3. Rebuild the parameters from the AAD.
    /// 4. Unlock the *decoded* encrypted message.
    pub fn test_ssh_agent_params(agent: Rc<RefCell<dyn SSHAgent>>) {
        // Create SSHAgentParams with the agent.
        let params = SSHAgentParams::new_opt(
            Salt::new_with_len(SALT_LEN).unwrap(),
            "",
            Some(agent.clone()),
        );

        // Create a content key to encrypt.
        let content_key = crate::SymmetricKey::new();

        // Empty: use the only identity in the agent.
        let secret = b"";

        // Lock the content key with the SSH agent parameters.
        let encrypted_key = EncryptedKey::lock_opt(
            KeyDerivationParams::SSHAgent(params),
            secret,
            &content_key,
        )
        .expect("Lock content key with SSH agent params");

        // Serialize the encrypted key to CBOR.
        let cbor_data = encrypted_key.to_cbor_data();

        // Deserialize the CBOR data.
        let cbor = CBOR::try_from_data(cbor_data)
            .expect("Convert encrypted key to CBOR");

        // Convert the CBOR back to an EncryptedKey.
        let encrypted_key_2 = EncryptedKey::try_from_cbor(&cbor)
            .expect("Convert CBOR to EncryptedKey");

        // Extract the SSH agent parameters from the AAD CBOR.
        let aad_cbor = encrypted_key_2
            .aad_cbor()
            .expect("Get AAD CBOR from EncryptedKey");
        let mut params_2 = SSHAgentParams::try_from(aad_cbor)
            .expect("Convert AAD CBOR to SSHAgentParams");

        // Deserialized parameters never carry an agent; attach it.
        params_2.set_agent(Some(agent.clone()));

        // Unlock the round-tripped message, not the original one, so the
        // CBOR decoding of the ciphertext is actually exercised.
        let decrypted_content_key =
            params_2.unlock(encrypted_key_2.encrypted_message(), secret);

        // Assert that the decrypted key matches the original content key.
        assert_eq!(
            content_key,
            decrypted_content_key
                .expect("Unlock content key with SSH agent params")
        );
    }
}
415
#[cfg(test)]
mod mock_agent_tests {
    use std::{cell::RefCell, collections::HashMap, rc::Rc};

    use super::tests_common::{test_id, test_ssh_agent_params};
    use crate::{Error, Result, SSHAgent};

    /// In-memory SSH agent keyed by identity comment. It lets
    /// `SSHAgentParams` be tested without a running `ssh-agent`.
    struct MockSSHAgent {
        identities: HashMap<String, ssh_key::PrivateKey>,
    }

    impl MockSSHAgent {
        fn new() -> Self { Self { identities: HashMap::new() } }

        /// Stores `key` under its comment, replacing any key with the same
        /// comment. This is distinct from the trait's `add_identity` (which
        /// takes `&PrivateKey`), so the two are not confused.
        fn insert_identity(&mut self, key: ssh_key::PrivateKey) {
            self.identities.insert(key.comment().to_string(), key);
        }
    }

    impl SSHAgent for MockSSHAgent {
        fn list_identities(&mut self) -> Result<Vec<ssh_key::PublicKey>> {
            Ok(self
                .identities
                .values()
                .map(|k| k.public_key().clone())
                .collect())
        }

        fn add_identity(&mut self, key: &ssh_key::PrivateKey) -> Result<()> {
            self.insert_identity(key.clone());
            Ok(())
        }

        fn remove_identity(&mut self, key: &ssh_key::PrivateKey) -> Result<()> {
            self.identities.remove(key.comment());
            Ok(())
        }

        fn remove_all_identities(&mut self) -> Result<()> {
            self.identities.clear();
            Ok(())
        }

        fn sign(
            &mut self,
            key: &ssh_key::PublicKey,
            data: &[u8],
        ) -> Result<ssh_key::Signature> {
            // Keys are looked up by comment, so `key` must carry the same
            // comment as the stored private key.
            let private_key = self
                .identities
                .get(key.comment())
                .ok_or_else(|| Error::ssh_agent("Identity not found"))?;
            // Produce an SSHSIG and return only its inner signature.
            // `test_mock_agent` relies on repeated signatures being equal.
            let sig: ssh_key::SshSig = private_key
                .sign("test_namespace", ssh_key::HashAlg::Sha256, data)
                .map_err(|e| {
                    Error::ssh_agent(format!("Failed to sign data: {}", e))
                })?;
            Ok(sig.signature().clone())
        }
    }

    /// A mock agent holding one freshly generated Ed25519 key whose comment
    /// is `test_id()`.
    fn mock_agent() -> Rc<RefCell<dyn SSHAgent>> {
        let mut agent = MockSSHAgent::new();
        let mut rng = bc_rand::SecureRandomNumberGenerator;
        let keypair: ssh_key::private::Ed25519Keypair =
            ssh_key::private::Ed25519Keypair::random(&mut rng);
        let private_key =
            ssh_key::PrivateKey::new(keypair.into(), test_id()).unwrap();
        agent.insert_identity(private_key);
        Rc::new(RefCell::new(agent))
    }

    #[test]
    fn test_mock_agent() {
        let agent = mock_agent();
        let identities = agent.borrow_mut().list_identities().unwrap();
        assert!(!identities.is_empty(), "No identities found in SSH agent");

        let first_identity = &identities[0];
        assert_eq!(first_identity.comment(), test_id());
        let data = b"test data";
        let signature1 = agent.borrow_mut().sign(first_identity, data).unwrap();
        let signature2 = agent.borrow_mut().sign(first_identity, data).unwrap();
        assert_eq!(
            signature1, signature2,
            "Signatures should match for the same data"
        );
    }

    #[test]
    fn test_ssh_agent_params_with_mock_agent() {
        // Create a mock SSH agent.
        let agent = mock_agent();

        // Test the SSHAgentParams with the mock agent.
        test_ssh_agent_params(agent);
    }
}
518
/// Integration tests against a live `ssh-agent`. They are compiled only with
/// the `ssh_agent_tests` feature.
///
/// Prerequisites: a running SSH agent that holds at least one Ed25519
/// identity whose comment is `your_email@example.com`.
/// `test_ssh_agent_params_with_real_agent` locks with an empty secret, so it
/// also requires that this be the agent's *only* Ed25519 identity.
///
/// Run them with:
/// ```bash
/// cargo test real_agent_tests --features ssh_agent_tests
/// ```
///
/// `SSH_AUTH_SOCK` must point at the socket the agent listens on. It is
/// normally set when the agent starts; you can check it with:
/// ```bash
/// echo $SSH_AUTH_SOCK
/// ```
///
/// To see which keys the agent currently holds:
/// ```bash
/// ssh-add -l
/// ```
///
/// To create a suitable Ed25519 test identity and load it into the agent:
/// ```bash
/// ssh-keygen -t ed25519 -C "your_email@example.com" -f <your_key_file>
/// ssh-add <your_key_file>
/// ```
#[cfg(test)]
#[cfg(feature = "ssh_agent_tests")]
mod real_agent_tests {
    use dcbor::prelude::*;

    use super::tests_common::{test_id, test_ssh_agent_params};
    use crate::{
        EncryptedKey, KeyDerivationMethod, SymmetricKey, connect_to_ssh_agent,
    };

    /// A fresh random content key for each test.
    pub fn test_content_key() -> SymmetricKey { SymmetricKey::new() }

    #[test]
    fn test_ssh_agent_params_with_real_agent() {
        let live_agent = connect_to_ssh_agent().expect("Connect to SSH agent");
        test_ssh_agent_params(live_agent);
    }

    #[test]
    fn test_encrypted_key_ssh_agent_roundtrip() {
        let identity_id = test_id();
        let original_key = test_content_key();

        let locked = EncryptedKey::lock(
            KeyDerivationMethod::SSHAgent,
            identity_id.clone(),
            &original_key,
        )
        .unwrap();

        let expected_display =
            format!("EncryptedKey(SSHAgent(\"{}\"))", identity_id);
        assert_eq!(format!("{}", locked), expected_display);

        let decoded = EncryptedKey::try_from(locked.clone().to_cbor()).unwrap();
        let unlocked = EncryptedKey::unlock(&decoded, identity_id).unwrap();

        assert_eq!(original_key, unlocked);
    }

    #[test]
    fn test_encrypted_key_ssh_agent_wrong_secret_fails() {
        let locked = EncryptedKey::lock(
            KeyDerivationMethod::SSHAgent,
            test_id(),
            &test_content_key(),
        )
        .unwrap();
        let outcome = EncryptedKey::unlock(&locked, b"wrong secret");
        assert!(outcome.is_err(), "Unlock should fail with wrong secret");
    }

    #[test]
    fn test_ssh_agent_lock_fails_with_nonexistent_identity() {
        let outcome = EncryptedKey::lock(
            KeyDerivationMethod::SSHAgent,
            b"nonexistent_identity",
            &test_content_key(),
        );
        assert!(
            outcome.is_err(),
            "Lock should fail with nonexistent identity"
        );
    }
}