apple_codesign/remote_signing/session_negotiation.rs

1// This Source Code Form is subject to the terms of the Mozilla Public
2// License, v. 2.0. If a copy of the MPL was not distributed with this
3// file, You can obtain one at https://mozilla.org/MPL/2.0/.
4
5//! Session establishment and crypto code for remote signing protocol.
6//!
7//! The intent of this module / file is to isolate the code with the highest
8//! sensitivity for security matters.
9
10use {
11    crate::remote_signing::RemoteSignError,
12    base64::Engine,
13    der::{Decode, Encode},
14    minicbor::{encode::Write, Decode as CborDecode, Decoder, Encode as CborEncode, Encoder},
15    oid_registry::OID_PKCS1_RSAENCRYPTION,
16    pkcs1::RsaPublicKey as RsaPublicKeyAsn1,
17    ring::{
18        aead::{
19            Aad, BoundKey, Nonce, NonceSequence, OpeningKey, SealingKey, UnboundKey, AES_128_GCM,
20            CHACHA20_POLY1305, NONCE_LEN,
21        },
22        agreement::{agree_ephemeral, EphemeralPrivateKey, UnparsedPublicKey, X25519},
23        hkdf::{Salt, HKDF_SHA256},
24        rand::{SecureRandom, SystemRandom},
25    },
26    rsa::{BigUint, Oaep, RsaPublicKey},
27    scroll::{Pwrite, LE},
28    spake2::{Ed25519Group, Identity, Password, Spake2},
29    spki::SubjectPublicKeyInfoRef,
30    std::fmt::{Display, Formatter},
31};
32
33type Result<T> = std::result::Result<T, RemoteSignError>;
34
35fn base64_engine() -> impl Engine {
36    base64::engine::general_purpose::URL_SAFE_NO_PAD
37}
38
39/// A generator of nonces that is a simple incrementing counter.
40///
41/// Assumed use with ChaCha20+Poly1305.
42#[derive(Default)]
43struct RemoteSigningNonceSequence {
44    id: u32,
45}
46
47impl NonceSequence for RemoteSigningNonceSequence {
48    fn advance(&mut self) -> ::std::result::Result<Nonce, ring::error::Unspecified> {
49        let mut data = [0u8; NONCE_LEN];
50        data.pwrite_with(self.id, 0, LE)
51            .map_err(|_| ring::error::Unspecified)?;
52
53        self.id += 1;
54
55        Ok(Nonce::assume_unique_for_key(data))
56    }
57}
58
59/// A nonce sequence that emits a constant value exactly once.
60#[derive(Default)]
61struct ConstantNonceSequence {
62    used: bool,
63}
64
65impl NonceSequence for ConstantNonceSequence {
66    fn advance(&mut self) -> ::std::result::Result<Nonce, ring::error::Unspecified> {
67        if self.used {
68            return Err(ring::error::Unspecified);
69        }
70
71        self.used = true;
72
73        Ok(Nonce::assume_unique_for_key([0x42; NONCE_LEN]))
74    }
75}
76
/// The role being assumed by a peer.
///
/// The role's textual form (`"A"` / `"B"`, via `Display`) is mixed into HKDF
/// info strings and SPAKE2 identities, so both peers must agree on who plays
/// which role.
#[derive(Clone, Copy, Debug)]
pub enum Role {
    /// Peer who initiated the session.
    A,
    /// Peer who joined the session.
    B,
}

impl Display for Role {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Role::A => "A",
            Role::B => "B",
        };

        f.write_str(label)
    }
}
94
95/// Derives the identifier / info value used for HKDF expansion.
96fn derive_hkdf_info(role: Role, session_id: &str, extra_identifier: &[u8]) -> Vec<u8> {
97    role.to_string()
98        .as_bytes()
99        .iter()
100        .chain(std::iter::once(&b':'))
101        .chain(session_id.as_bytes().iter())
102        .chain(std::iter::once(&b':'))
103        .chain(extra_identifier.iter())
104        .copied()
105        .collect::<Vec<_>>()
106}
107
108pub struct PeerKeys {
109    sealing: SealingKey<RemoteSigningNonceSequence>,
110    opening: OpeningKey<RemoteSigningNonceSequence>,
111}
112
113impl PeerKeys {
114    /// Encrypt / seal a plaintext message using AEAD.
115    ///
116    /// Receives the plaintext message to encrypt.
117    ///
118    /// Returns the encrypted ciphertext.
119    pub fn seal(&mut self, plaintext: &[u8]) -> Result<Vec<u8>> {
120        let mut output = plaintext.to_vec();
121        self.sealing
122            .seal_in_place_append_tag(Aad::empty(), &mut output)
123            .map_err(|_| RemoteSignError::Crypto("AEAD sealing error".into()))?;
124
125        Ok(output)
126    }
127
128    /// Decrypt / open a ciphertext using AEAD.
129    ///
130    /// Receives the ciphertext message to decrypt.
131    ///
132    /// Returns the decrypted and verified plaintext.
133    pub fn open(&mut self, mut ciphertext: Vec<u8>) -> Result<Vec<u8>> {
134        let plaintext = self
135            .opening
136            .open_in_place(Aad::empty(), &mut ciphertext)
137            .map_err(|_| RemoteSignError::Crypto("failed to decrypt message".into()))?;
138
139        Ok(plaintext.to_vec())
140    }
141}
142
143/// Derives a pair of AEAD keys from a shared encryption key.
144///
145/// Returns a pair of keys. One key is used for sealing / encrypting and the
146/// other for opening / decrypting.
147///
148/// `role` is the role that the current peer is playing. The session initiator
149/// generally uses `A` and the joiner / signer uses `B`.
150///
151/// `shared_key` is a private key that is mutually derived and identical on both
152/// peers. The mechanism for obtaining it varies.
153///
154/// `session_id` is the server-registered session identifier.
155///
156/// `extra_identifier` is an extra value to use when constructing identities for
157/// HKDF extraction.
158fn derive_aead_keys(
159    role: Role,
160    shared_key: Vec<u8>,
161    session_id: &str,
162    extra_identifier: &[u8],
163) -> Result<(
164    SealingKey<RemoteSigningNonceSequence>,
165    OpeningKey<RemoteSigningNonceSequence>,
166)> {
167    let salt = Salt::new(HKDF_SHA256, &[]);
168    let prk = salt.extract(&shared_key);
169
170    let a_identifier = derive_hkdf_info(Role::A, session_id, extra_identifier);
171    let b_identifier = derive_hkdf_info(Role::B, session_id, extra_identifier);
172
173    let a_info = [a_identifier.as_ref()];
174    let b_info = [b_identifier.as_ref()];
175
176    let a_key = prk
177        .expand(&a_info, &CHACHA20_POLY1305)
178        .map_err(|_| RemoteSignError::Crypto("error performing HKDF key derivation".into()))?;
179
180    let b_key = prk
181        .expand(&b_info, &CHACHA20_POLY1305)
182        .map_err(|_| RemoteSignError::Crypto("error performing HKDF key derivation".into()))?;
183
184    let (sealing_key, opening_key) = match role {
185        Role::A => (a_key, b_key),
186        Role::B => (b_key, a_key),
187    };
188
189    let sealing_key = SealingKey::new(sealing_key.into(), RemoteSigningNonceSequence::default());
190    let opening_key = OpeningKey::new(opening_key.into(), RemoteSigningNonceSequence::default());
191
192    Ok((sealing_key, opening_key))
193}
194
195fn encode_sjs(
196    scheme: &str,
197    payload: impl CborEncode<()>,
198) -> ::std::result::Result<Vec<u8>, minicbor::encode::Error<std::convert::Infallible>> {
199    let mut encoder = Encoder::new(Vec::<u8>::new());
200
201    {
202        let encoder = encoder.array(2)?;
203        encoder.str(scheme)?;
204        payload.encode(encoder, &mut ())?;
205        encoder.end()?;
206    }
207
208    Ok(encoder.into_writer())
209}
210
211/// Common behaviors for a session join string.
212///
213/// Implementations must also implement [Encode], which will emit the CBOR
214/// encoding of the instance to an encoder.
215pub trait SessionJoinString<'de>: CborDecode<'de, ()> + CborEncode<()> {
216    /// The scheme / name for this SJS implementation.
217    ///
218    /// This is advertised as the first component in the encoded SJS.
219    fn scheme() -> &'static str;
220
221    /// Obtain the raw bytes constituting the session join string.
222    fn to_bytes(&self) -> Result<Vec<u8>> {
223        encode_sjs(Self::scheme(), self)
224            .map_err(|e| RemoteSignError::SessionJoinString(format!("CBOR encoding error: {e}")))
225    }
226}
227
228struct PublicKeySessionJoinString {
229    aes_ciphertext: Vec<u8>,
230    public_key: Vec<u8>,
231    message_ciphertext: Vec<u8>,
232}
233
234impl<'de, C> CborDecode<'de, C> for PublicKeySessionJoinString {
235    fn decode(
236        d: &mut Decoder<'de>,
237        _ctx: &mut C,
238    ) -> std::result::Result<Self, minicbor::decode::Error> {
239        if !matches!(d.array()?, Some(3)) {
240            return Err(minicbor::decode::Error::message(
241                "not an array of 3 elements",
242            ));
243        }
244
245        let aes_ciphertext = d.bytes()?.to_vec();
246        let public_key = d.bytes()?.to_vec();
247        let message_ciphertext = d.bytes()?.to_vec();
248
249        Ok(Self {
250            aes_ciphertext,
251            public_key,
252            message_ciphertext,
253        })
254    }
255}
256
257impl<C> CborEncode<C> for PublicKeySessionJoinString {
258    fn encode<W: Write>(
259        &self,
260        e: &mut Encoder<W>,
261        _ctx: &mut C,
262    ) -> ::std::result::Result<(), minicbor::encode::Error<W::Error>> {
263        e.array(3)?;
264        e.bytes(&self.aes_ciphertext)?;
265        e.bytes(&self.public_key)?;
266        e.bytes(&self.message_ciphertext)?;
267        e.end()?;
268
269        Ok(())
270    }
271}
272
273impl SessionJoinString<'static> for PublicKeySessionJoinString {
274    fn scheme() -> &'static str {
275        "publickey0"
276    }
277}
278
279struct SharedSecretSessionJoinString {
280    session_id: String,
281    extra_identifier: Vec<u8>,
282    role_a_init_message: Vec<u8>,
283}
284
285impl<'de, C> CborDecode<'de, C> for SharedSecretSessionJoinString {
286    fn decode(
287        d: &mut Decoder<'de>,
288        _ctx: &mut C,
289    ) -> std::result::Result<Self, minicbor::decode::Error> {
290        if !matches!(d.array()?, Some(3)) {
291            return Err(minicbor::decode::Error::message(
292                "not an array of 3 elements",
293            ));
294        }
295
296        let session_id = d.str()?.to_string();
297        let extra_identifier = d.bytes()?.to_vec();
298        let role_a_init_message = d.bytes()?.to_vec();
299
300        Ok(Self {
301            session_id,
302            extra_identifier,
303            role_a_init_message,
304        })
305    }
306}
307
308impl<C> CborEncode<C> for SharedSecretSessionJoinString {
309    fn encode<W: Write>(
310        &self,
311        e: &mut Encoder<W>,
312        _ctx: &mut C,
313    ) -> ::std::result::Result<(), minicbor::encode::Error<W::Error>> {
314        e.array(3)?;
315        e.str(&self.session_id)?;
316        e.bytes(&self.extra_identifier)?;
317        e.bytes(&self.role_a_init_message)?;
318        e.end()?;
319
320        Ok(())
321    }
322}
323
324impl SessionJoinString<'static> for SharedSecretSessionJoinString {
325    fn scheme() -> &'static str {
326        "sharedsecret0"
327    }
328}
329
330/// A peer that initiates a remote signing session.
331pub trait SessionInitiatePeer {
332    /// Obtain the session ID to create / use.
333    fn session_id(&self) -> &str;
334
335    /// Obtain additional session context to store with the server.
336    ///
337    /// This context will be sent to the peer when it joins.
338    fn session_create_context(&self) -> Option<Vec<u8>>;
339
340    /// Obtain the raw bytes constituting the session join string.
341    fn session_join_string_bytes(&self) -> Result<Vec<u8>>;
342
343    /// Obtain the base 64 encoded session join string.
344    fn session_join_string_base64(&self) -> Result<String> {
345        Ok(base64_engine().encode(self.session_join_string_bytes()?))
346    }
347
348    /// Obtain the PEM encoded session join string.
349    fn session_join_string_pem(&self) -> Result<String> {
350        Ok(pem::encode(&pem::Pem::new(
351            "SESSION JOIN STRING",
352            self.session_join_string_bytes()?,
353        )))
354    }
355
356    /// Finalize a peer joined session using optional context provided by the peer.
357    ///
358    /// Yields encryption keys for this peer.
359    fn negotiate_session(self: Box<Self>, peer_context: Option<Vec<u8>>) -> Result<PeerKeys>;
360}
361
362pub enum SessionJoinState {
363    /// A generic shared secret value.
364    SharedSecret(Vec<u8>),
365
366    /// An entity capable of decrypting messages encrypted by the peer.
367    PublicKeyDecrypt(Box<dyn PublicKeyPeerDecrypt>),
368}
369
370/// A peer that joins sessions in a state before it has spoken to the server.
371pub trait SessionJoinPeerPreJoin {
372    /// Register additional state with the peer.
373    ///
374    /// This is used as a generic way to import implementation-specific state that
375    /// enables the peer join to complete.
376    fn register_state(&mut self, state: SessionJoinState) -> Result<()>;
377
378    /// Obtain information needed to join to a session.
379    ///
380    /// Consumes self because joining should be a one-time operation.
381    fn join_context(self: Box<Self>) -> Result<SessionJoinContext>;
382}
383
384pub trait SessionJoinPeerHandshake {
385    /// Finalize a peer joining session.
386    ///
387    /// Yields encryption keys for this peer.
388    fn negotiate_session(self: Box<Self>) -> Result<PeerKeys>;
389}
390
391/// Holds data needs to enable a joining peer to join a session.
392pub struct SessionJoinContext {
393    /// URL of server to join.
394    ///
395    /// If not set, the client default URL is used.
396    pub server_url: Option<String>,
397
398    /// The session ID to join.
399    pub session_id: String,
400
401    /// Additional data to relay to the peer to enable it to finalize the session.
402    pub peer_context: Option<Vec<u8>>,
403
404    /// Object that will finalize the peer handshake and derive encryption keys.
405    pub peer_handshake: Box<dyn SessionJoinPeerHandshake>,
406}
407
/// Message an initiator seals (AES-128-GCM) into a `publickey0` session join
/// string for the joining peer.
///
/// The `minicbor` derive encodes it as a CBOR array positioned by the
/// `#[n(..)]` indices. Those indices are part of the wire format and must not
/// change.
#[derive(CborDecode, CborEncode)]
#[cbor(array)]
struct PublicKeySecretMessage {
    /// Server URL for the joiner to use; `None` means the client default.
    #[n(0)]
    server_url: Option<String>,

    /// Server-registered session identifier to join.
    #[n(1)]
    session_id: String,

    /// Random bytes from the initiator. Both peers use them as the
    /// `extra_identifier` when deriving the session AEAD keys.
    #[n(2)]
    challenge: Vec<u8>,

    /// The initiator's ephemeral X25519 public key.
    #[n(3)]
    agreement_public: Vec<u8>,
}
423
424pub struct PublicKeyInitiator {
425    session_id: String,
426    extra_identifier: Vec<u8>,
427    sjs: PublicKeySessionJoinString,
428    agreement_private: EphemeralPrivateKey,
429}
430
431impl SessionInitiatePeer for PublicKeyInitiator {
432    fn session_id(&self) -> &str {
433        &self.session_id
434    }
435
436    fn session_create_context(&self) -> Option<Vec<u8>> {
437        None
438    }
439
440    fn session_join_string_bytes(&self) -> Result<Vec<u8>> {
441        self.sjs.to_bytes()
442    }
443
444    fn negotiate_session(self: Box<Self>, peer_context: Option<Vec<u8>>) -> Result<PeerKeys> {
445        let public_key = peer_context.ok_or_else(|| {
446            RemoteSignError::Crypto(
447                "missing peer public key context in session join message".into(),
448            )
449        })?;
450
451        let public_key = UnparsedPublicKey::new(&X25519, public_key);
452
453        let (sealing, opening) =
454            agree_ephemeral(self.agreement_private, &public_key, |agreement_key| {
455                derive_aead_keys(
456                    Role::A,
457                    agreement_key.to_vec(),
458                    &self.session_id,
459                    &self.extra_identifier,
460                )
461            })
462            .map_err(|_| RemoteSignError::Crypto("error deriving agreement key".into()))?
463            .map_err(|_| {
464                RemoteSignError::Crypto("error deriving AEAD keys from agreement key".into())
465            })?;
466
467        Ok(PeerKeys { sealing, opening })
468    }
469}
470
impl PublicKeyInitiator {
    /// Create a new initiator using public key agreement.
    ///
    /// `peer_public_key` is the DER encoded SubjectPublicKeyInfo of the joining
    /// peer. Only RSA (`rsaEncryption`) keys are currently supported; any other
    /// algorithm yields an error.
    ///
    /// `server_url` is embedded in the encrypted message to the peer. It
    /// surfaces as the joiner's `SessionJoinContext::server_url`.
    ///
    /// The resulting session join string carries:
    ///
    /// 1. A fresh random AES-128 key, wrapped with RSA-OAEP (SHA-256) to the
    ///    peer's public key.
    /// 2. The peer's SPKI, re-encoded as DER.
    /// 3. A CBOR [PublicKeySecretMessage] sealed with that AES key using
    ///    AES-128-GCM. It holds the server URL, session ID, a random challenge,
    ///    and this initiator's ephemeral X25519 public key.
    ///
    /// # Errors
    ///
    /// Returns [RemoteSignError::Crypto] in any of these cases:
    ///
    /// - The SPKI or RSA key cannot be parsed.
    /// - Random data or ephemeral key generation fails.
    /// - CBOR encoding or any encryption step fails.
    /// - The key algorithm is unsupported.
    pub fn new(peer_public_key: impl AsRef<[u8]>, server_url: Option<String>) -> Result<Self> {
        // Parse the peer's SPKI first so bad input fails before any key
        // material is generated.
        let spki = SubjectPublicKeyInfoRef::from_der(peer_public_key.as_ref())
            .map_err(|e| RemoteSignError::Crypto(format!("when parsing SPKI data: {e}")))?;

        // A random (v4) UUID serves as the server session identifier.
        let session_id = uuid::Uuid::new_v4().to_string();

        let rng = SystemRandom::new();

        // Random challenge sent to the peer inside the encrypted message. It is
        // also kept as `extra_identifier`, which feeds HKDF info when the
        // session AEAD keys are derived.
        let mut challenge = [0u8; 32];
        rng.fill(&mut challenge)
            .map_err(|_| RemoteSignError::Crypto("failed to generate random data".into()))?;

        // 16 random bytes: a one-time AES-128 key for sealing the message below.
        let mut aes_key_data = [0u8; 16];
        rng.fill(&mut aes_key_data)
            .map_err(|_| RemoteSignError::Crypto("failed to generate random data".into()))?;

        // Ephemeral X25519 key pair. The private half is kept for
        // `negotiate_session`; the public half travels in the message.
        let agreement_private = EphemeralPrivateKey::generate(&X25519, &rng).map_err(|_| {
            RemoteSignError::Crypto("failed to generate ephemeral agreement key".into())
        })?;

        let agreement_public = agreement_private.compute_public_key().map_err(|_| {
            RemoteSignError::Crypto(
                "failed to derive public key from ephemeral agreement key".into(),
            )
        })?;

        let peer_message = PublicKeySecretMessage {
            server_url,
            session_id: session_id.clone(),
            challenge: challenge.as_ref().to_vec(),
            agreement_public: agreement_public.as_ref().to_vec(),
        };

        // The unique AES key is used to encrypt the main CBOR message.
        // A constant nonce is acceptable only because this key is freshly
        // generated and seals exactly one message. `ConstantNonceSequence`
        // errors if asked for a second nonce.
        let mut message_ciphertext = minicbor::to_vec(peer_message)
            .map_err(|e| RemoteSignError::Crypto(format!("CBOR encode error: {e}")))?;
        let aes_key = UnboundKey::new(&AES_128_GCM, &aes_key_data).map_err(|_| {
            RemoteSignError::Crypto("failed to load AES encryption key into ring".into())
        })?;
        let mut sealing_key = SealingKey::new(aes_key, ConstantNonceSequence::default());
        sealing_key
            .seal_in_place_append_tag(Aad::empty(), &mut message_ciphertext)
            .map_err(|_| RemoteSignError::Crypto("failed to AES encrypt message to peer".into()))?;

        // The AES encrypting key is encrypted using asymmetric encryption.
        // Only PKCS#1 rsaEncryption keys are handled; the AES key is wrapped
        // with RSA-OAEP using SHA-256.

        let aes_ciphertext = match spki.algorithm.oid.as_ref() {
            x if x == OID_PKCS1_RSAENCRYPTION.as_bytes() => {
                let public_key = RsaPublicKeyAsn1::from_der(spki.subject_public_key.raw_bytes())
                    .map_err(|e| {
                        RemoteSignError::Crypto(format!("when parsing RSA public key: {e}"))
                    })?;

                let n = BigUint::from_bytes_be(public_key.modulus.as_bytes());
                let e = BigUint::from_bytes_be(public_key.public_exponent.as_bytes());

                let rsa_public = RsaPublicKey::new(n, e).map_err(|e| {
                    RemoteSignError::Crypto(format!("when constructing RSA public key: {e}"))
                })?;

                let padding = Oaep::new::<sha2::Sha256>();

                rsa_public
                    .encrypt(&mut rand::thread_rng(), padding, &aes_key_data)
                    .map_err(|e| {
                        RemoteSignError::Crypto(format!("RSA public key encryption error: {e}"))
                    })?
            }
            _ => {
                return Err(RemoteSignError::Crypto(format!(
                    "do not know how to encrypt for algorithm {}",
                    spki.algorithm.oid
                )));
            }
        };

        // Re-encode the parsed SPKI so the join string carries the peer's
        // public key in DER form.
        let public_key = spki
            .to_der()
            .map_err(|e| RemoteSignError::Crypto(format!("when encoding SPKI to DER: {e}")))?;

        let sjs = PublicKeySessionJoinString {
            aes_ciphertext,
            public_key,
            message_ciphertext,
        };

        Ok(Self {
            session_id,
            extra_identifier: challenge.as_ref().to_vec(),
            sjs,
            agreement_private,
        })
    }
}
567
568/// Describes a type that is capable of decrypting messages used during public key negotiation.
569pub trait PublicKeyPeerDecrypt {
570    /// Decrypt an encrypted message.
571    fn decrypt(&self, ciphertext: &[u8]) -> Result<Vec<u8>>;
572}
573
574/// A joining peer using public key encryption.
575struct PublicKeyPeerPreJoined {
576    sjs: PublicKeySessionJoinString,
577
578    decrypter: Option<Box<dyn PublicKeyPeerDecrypt>>,
579}
580
581impl SessionJoinPeerPreJoin for PublicKeyPeerPreJoined {
582    fn register_state(&mut self, state: SessionJoinState) -> Result<()> {
583        match state {
584            SessionJoinState::PublicKeyDecrypt(decrypt) => {
585                self.decrypter = Some(decrypt);
586                Ok(())
587            }
588            SessionJoinState::SharedSecret(_) => Ok(()),
589        }
590    }
591
592    fn join_context(self: Box<Self>) -> Result<SessionJoinContext> {
593        let decrypter = self
594            .decrypter
595            .ok_or_else(|| RemoteSignError::Crypto("decryption key not registered".into()))?;
596
597        let aes_key = decrypter.decrypt(&self.sjs.aes_ciphertext)?;
598        let aes_key = UnboundKey::new(&AES_128_GCM, &aes_key).map_err(|_| {
599            RemoteSignError::Crypto("failed to construct AES key from key data".into())
600        })?;
601        let mut opening_key = OpeningKey::new(aes_key, ConstantNonceSequence::default());
602
603        let mut cbor_message = self.sjs.message_ciphertext.clone();
604        let cbor_plaintext = opening_key
605            .open_in_place(Aad::empty(), &mut cbor_message)
606            .map_err(|_| {
607                RemoteSignError::Crypto("failed to decrypt using shared AES key".into())
608            })?;
609
610        // The plaintext is a CBOR encoded message.
611        let message = minicbor::decode::<PublicKeySecretMessage>(cbor_plaintext)
612            .map_err(|e| RemoteSignError::Crypto(format!("CBOR decode error: {e}")))?;
613
614        let agreement_private = EphemeralPrivateKey::generate(&X25519, &SystemRandom::new())
615            .map_err(|_| {
616                RemoteSignError::Crypto("failed to generate ephemeral agreement key".into())
617            })?;
618        let agreement_public = agreement_private.compute_public_key().map_err(|_| {
619            RemoteSignError::Crypto(
620                "failed to derive public key from ephemeral agreement key".into(),
621            )
622        })?;
623
624        let peer_handshake = Box::new(PublicKeyHandshakePeer {
625            session_id: message.session_id.clone(),
626            extra_identifier: message.challenge,
627            agreement_private,
628            agreement_public: message.agreement_public,
629        });
630
631        Ok(SessionJoinContext {
632            server_url: message.server_url,
633            session_id: message.session_id,
634            peer_context: Some(agreement_public.as_ref().to_vec()),
635            peer_handshake,
636        })
637    }
638}
639
640impl PublicKeyPeerPreJoined {
641    fn new(sjs: PublicKeySessionJoinString) -> Result<Self> {
642        Ok(Self {
643            sjs,
644            decrypter: None,
645        })
646    }
647}
648
649pub struct PublicKeyHandshakePeer {
650    session_id: String,
651    extra_identifier: Vec<u8>,
652    agreement_private: EphemeralPrivateKey,
653    agreement_public: Vec<u8>,
654}
655
656impl SessionJoinPeerHandshake for PublicKeyHandshakePeer {
657    fn negotiate_session(self: Box<Self>) -> Result<PeerKeys> {
658        let peer_public_key = UnparsedPublicKey::new(&X25519, &self.agreement_public);
659
660        let (sealing, opening) =
661            agree_ephemeral(self.agreement_private, &peer_public_key, |agreement_key| {
662                derive_aead_keys(
663                    Role::B,
664                    agreement_key.to_vec(),
665                    &self.session_id,
666                    &self.extra_identifier,
667                )
668            })
669            .map_err(|_| RemoteSignError::Crypto("error deriving agreement key".into()))?
670            .map_err(|_| {
671                RemoteSignError::Crypto("error deriving AEAD keys from agreement key".into())
672            })?;
673
674        Ok(PeerKeys { sealing, opening })
675    }
676}
677
678fn spake_identity(role: Role, session_id: &str, extra_identifier: &[u8]) -> Identity {
679    Identity::new(&derive_hkdf_info(role, session_id, extra_identifier))
680}
681
682pub struct SharedSecretInitiator {
683    sjs: SharedSecretSessionJoinString,
684    spake: Spake2<Ed25519Group>,
685}
686
687impl SessionInitiatePeer for SharedSecretInitiator {
688    fn session_id(&self) -> &str {
689        &self.sjs.session_id
690    }
691
692    fn session_create_context(&self) -> Option<Vec<u8>> {
693        None
694    }
695
696    fn session_join_string_bytes(&self) -> Result<Vec<u8>> {
697        self.sjs.to_bytes()
698    }
699
700    fn negotiate_session(self: Box<Self>, peer_context: Option<Vec<u8>>) -> Result<PeerKeys> {
701        let spake_b = peer_context.ok_or_else(|| {
702            RemoteSignError::Crypto(
703                "missing SPAKE2 initialization context in session join message".into(),
704            )
705        })?;
706
707        let shared_key = self.spake.finish(&spake_b).map_err(|e| {
708            RemoteSignError::Crypto(format!("error finishing SPAKE2 key negotiation: {e}"))
709        })?;
710
711        let (sealing, opening) = derive_aead_keys(
712            Role::A,
713            shared_key,
714            &self.sjs.session_id,
715            &self.sjs.extra_identifier,
716        )?;
717
718        Ok(PeerKeys { sealing, opening })
719    }
720}
721
722impl SharedSecretInitiator {
723    pub fn new(shared_secret: Vec<u8>) -> Result<Self> {
724        let session_id = uuid::Uuid::new_v4().to_string();
725
726        let rng = SystemRandom::new();
727        let mut extra_identifier = [0u8; 16];
728        rng.fill(&mut extra_identifier)
729            .map_err(|_| RemoteSignError::Crypto("unable to generate random value".into()))?;
730
731        let (spake, role_a_init_message) = Spake2::<Ed25519Group>::start_a(
732            &Password::new(shared_secret),
733            &spake_identity(Role::A, &session_id, &extra_identifier),
734            &spake_identity(Role::B, &session_id, &extra_identifier),
735        );
736
737        Ok(Self {
738            sjs: SharedSecretSessionJoinString {
739                session_id,
740                extra_identifier: extra_identifier.as_ref().to_vec(),
741                role_a_init_message,
742            },
743            spake,
744        })
745    }
746}
747
748/// A joining peer using shared secrets.
749struct SharedSecretPeerPreJoined {
750    sjs: SharedSecretSessionJoinString,
751    shared_secret: Option<Vec<u8>>,
752}
753
754impl SessionJoinPeerPreJoin for SharedSecretPeerPreJoined {
755    fn register_state(&mut self, state: SessionJoinState) -> Result<()> {
756        match state {
757            SessionJoinState::SharedSecret(secret) => {
758                self.shared_secret = Some(secret);
759                Ok(())
760            }
761            SessionJoinState::PublicKeyDecrypt(_) => Ok(()),
762        }
763    }
764
765    fn join_context(self: Box<Self>) -> Result<SessionJoinContext> {
766        let shared_secret = self
767            .shared_secret
768            .as_ref()
769            .ok_or_else(|| RemoteSignError::Crypto("shared secret not defined".into()))?;
770
771        let (spake, init_message) = Spake2::<Ed25519Group>::start_b(
772            &Password::new(shared_secret),
773            &spake_identity(Role::A, &self.sjs.session_id, &self.sjs.extra_identifier),
774            &spake_identity(Role::B, &self.sjs.session_id, &self.sjs.extra_identifier),
775        );
776
777        let peer_handshake = Box::new(SharedSecretHandshakePeer {
778            session_id: self.sjs.session_id.clone(),
779            extra_identifier: self.sjs.extra_identifier,
780            role_a_init_message: self.sjs.role_a_init_message,
781            spake,
782        });
783
784        Ok(SessionJoinContext {
785            // TODO set this field if not the default.
786            server_url: None,
787            session_id: self.sjs.session_id,
788            peer_context: Some(init_message),
789            peer_handshake,
790        })
791    }
792}
793
794impl SharedSecretPeerPreJoined {
795    fn new(sjs: SharedSecretSessionJoinString) -> Result<Self> {
796        Ok(Self {
797            sjs,
798            shared_secret: None,
799        })
800    }
801}
802
803pub struct SharedSecretHandshakePeer {
804    session_id: String,
805    extra_identifier: Vec<u8>,
806    role_a_init_message: Vec<u8>,
807    spake: Spake2<Ed25519Group>,
808}
809
810impl SessionJoinPeerHandshake for SharedSecretHandshakePeer {
811    fn negotiate_session(self: Box<Self>) -> Result<PeerKeys> {
812        let shared_key = self.spake.finish(&self.role_a_init_message).map_err(|e| {
813            RemoteSignError::Crypto(format!("error finishing SPAKE2 key negotiation: {e}"))
814        })?;
815
816        let (sealing, opening) = derive_aead_keys(
817            Role::B,
818            shared_key,
819            &self.session_id,
820            &self.extra_identifier,
821        )?;
822
823        Ok(PeerKeys { sealing, opening })
824    }
825}
826
827pub fn create_session_joiner(
828    session_join_string: impl ToString,
829) -> Result<Box<dyn SessionJoinPeerPreJoin>> {
830    let input = session_join_string.to_string();
831
832    let trimmed = input.trim();
833
834    // Multiline is assumed to be PEM.
835    let sjs = if trimmed.contains('\n') {
836        let no_comments = trimmed
837            .lines()
838            .filter(|line| !line.starts_with('#'))
839            .collect::<Vec<_>>()
840            .join("\n");
841
842        let doc = pem::parse(no_comments.as_bytes())?;
843
844        if doc.tag() == "SESSION JOIN STRING" {
845            doc.contents().to_vec()
846        } else {
847            return Err(RemoteSignError::SessionJoinString(
848                "PEM does not define a SESSION JOIN STRING".into(),
849            ));
850        }
851    } else {
852        base64_engine().decode(trimmed.as_bytes())?
853    };
854
855    let mut decoder = Decoder::new(&sjs);
856    if !matches!(
857        decoder.array().map_err(|_| {
858            RemoteSignError::SessionJoinString("decode error: not a CBOR array".into())
859        })?,
860        Some(2)
861    ) {
862        return Err(RemoteSignError::SessionJoinString(
863            "decode error: not a CBOR array with 2 elements".into(),
864        ));
865    }
866
867    let scheme = decoder
868        .str()
869        .map_err(|_| RemoteSignError::SessionJoinString("failed to decode scheme name".into()))?;
870
871    match scheme {
872        _ if scheme == PublicKeySessionJoinString::scheme() => {
873            let sjs = PublicKeySessionJoinString::decode(&mut decoder, &mut ()).map_err(|e| {
874                RemoteSignError::SessionJoinString(format!("error decoding payload: {e}"))
875            })?;
876
877            Ok(Box::new(PublicKeyPeerPreJoined::new(sjs)?) as Box<dyn SessionJoinPeerPreJoin>)
878        }
879        _ if scheme == SharedSecretSessionJoinString::scheme() => {
880            let sjs =
881                SharedSecretSessionJoinString::decode(&mut decoder, &mut ()).map_err(|e| {
882                    RemoteSignError::SessionJoinString(format!("error decoding payload: {e}"))
883                })?;
884
885            Ok(Box::new(SharedSecretPeerPreJoined::new(sjs)?) as Box<dyn SessionJoinPeerPreJoin>)
886        }
887        _ => Err(RemoteSignError::SessionJoinString(format!(
888            "unknown scheme: {scheme}"
889        ))),
890    }
891}