magic_wormhole/transit/
crypto.rs

1//! Cryptographic backbone of the Transit protocol
2//!
3//! This handles the encrypted handshakes during connection setup, then provides
4//! a simple "encrypt/decrypt" abstraction that will be used for all messages.
5
6#![allow(deprecated)]
7
8use super::{
9    TransitError, TransitKey, TransitRxKey, TransitTransport, TransitTransportRx,
10    TransitTransportTx, TransitTxKey,
11};
12use crate::Key;
13use async_trait::async_trait;
14use crypto_secretbox as secretbox;
15use crypto_secretbox::{aead::Aead, KeyInit};
16use futures::{future::BoxFuture, io::AsyncWriteExt};
17use std::sync::Arc;
18
19/// Private, because we try multiple handshakes and only
20/// one needs to succeed
21#[derive(Debug, thiserror::Error)]
22#[non_exhaustive]
23pub(super) enum TransitHandshakeError {
24    #[error("Handshake failed")]
25    HandshakeFailed,
26    #[error("Relay handshake failed")]
27    RelayHandshakeFailed,
28    #[error("Malformed peer address")]
29    BadAddress(
30        #[from]
31        #[source]
32        std::net::AddrParseError,
33    ),
34    #[error("Noise cryptography error")]
35    NoiseCrypto(
36        #[from]
37        #[source]
38        noise_protocol::Error,
39    ),
40    #[error("Decryption error")]
41    Decryption,
42    #[error("IO error")]
43    IO(
44        #[from]
45        #[source]
46        std::io::Error,
47    ),
48    #[cfg(target_family = "wasm")]
49    #[error("WASM error")]
50    WASM(
51        #[from]
52        #[source]
53        ws_stream_wasm::WsErr,
54    ),
55}
56
57impl From<()> for TransitHandshakeError {
58    fn from(_: ()) -> Self {
59        Self::Decryption
60    }
61}
62
63/// The Transit protocol has the property that the last message of the handshake is from the leader
64/// and confirms the usage of that specific connection. This trait represents that specific type state.
65pub(super) trait TransitCryptoInitFinalizer: Send {
66    fn handshake_finalize(
67        self: Box<Self>,
68        socket: &mut dyn TransitTransport,
69    ) -> BoxFuture<Result<DynTransitCrypto, TransitHandshakeError>>;
70}
71
72/// Due to poorly chosen abstractions elsewhere, the [`TransitCryptoInitFinalizer`] trait is also
73/// used by the follower side. Since it is a no-op there, simply implement the trait for the result.
74impl TransitCryptoInitFinalizer for DynTransitCrypto {
75    fn handshake_finalize(
76        self: Box<Self>,
77        _socket: &mut dyn TransitTransport,
78    ) -> BoxFuture<Result<DynTransitCrypto, TransitHandshakeError>> {
79        Box::pin(futures::future::ready(Ok(*self)))
80    }
81}
82
83/// Do a handshake. Multiple handshakes can be started from one instance on multiple streams.
84#[async_trait]
85pub(super) trait TransitCryptoInit: Send + Sync {
86    // Yes, this method returns a nested future. TODO explain
87    async fn handshake_leader(
88        &self,
89        socket: &mut dyn TransitTransport,
90    ) -> Result<Box<dyn TransitCryptoInitFinalizer>, TransitHandshakeError>;
91    async fn handshake_follower(
92        &self,
93        socket: &mut dyn TransitTransport,
94    ) -> Result<Box<dyn TransitCryptoInitFinalizer>, TransitHandshakeError>;
95}
96
97/// The classic Transit cryptography backend, powered by libsodium's "Secretbox" API.
98///
99/// The handshake looks like this (leader perspective):
100/// ```text
101/// -> transit sender ${transit_key.derive("transit_sender)")} ready\n\n
102/// <- transit receiver ${transit_key.derive("transit_receiver")} ready\n\n
103/// -> go\n
104/// ```
105pub struct SecretboxInit {
106    pub key: Arc<Key<TransitKey>>,
107}
108
109#[async_trait]
110impl TransitCryptoInit for SecretboxInit {
111    async fn handshake_leader(
112        &self,
113        socket: &mut dyn TransitTransport,
114    ) -> Result<Box<dyn TransitCryptoInitFinalizer>, TransitHandshakeError> {
115        // 9. create record keys
116        let rkey = self
117            .key
118            .derive_subkey_from_purpose("transit_record_receiver_key");
119        let skey = self
120            .key
121            .derive_subkey_from_purpose("transit_record_sender_key");
122
123        // for transmit mode, send send_handshake_msg and compare.
124        // the received message with send_handshake_msg
125        socket
126            .write_all(
127                format!(
128                    "transit sender {} ready\n\n",
129                    self.key
130                        .derive_subkey_from_purpose::<crate::GenericKey>("transit_sender")
131                        .to_hex()
132                )
133                .as_bytes(),
134            )
135            .await?;
136
137        let expected_rx_handshake = format!(
138            "transit receiver {} ready\n\n",
139            self.key
140                .derive_subkey_from_purpose::<crate::GenericKey>("transit_receiver")
141                .to_hex()
142        );
143        assert_eq!(expected_rx_handshake.len(), 89);
144        socket.read_expect(expected_rx_handshake.as_bytes()).await?;
145
146        struct Finalizer {
147            skey: Key<TransitTxKey>,
148            rkey: Key<TransitRxKey>,
149        }
150
151        impl TransitCryptoInitFinalizer for Finalizer {
152            fn handshake_finalize(
153                self: Box<Self>,
154                socket: &mut dyn TransitTransport,
155            ) -> BoxFuture<Result<DynTransitCrypto, TransitHandshakeError>> {
156                Box::pin(async move {
157                    socket.write_all(b"go\n").await?;
158
159                    Ok::<_, TransitHandshakeError>((
160                        Box::new(SecretboxCryptoEncrypt {
161                            skey: self.skey,
162                            snonce: Default::default(),
163                        }) as Box<dyn TransitCryptoEncrypt>,
164                        Box::new(SecretboxCryptoDecrypt {
165                            rkey: self.rkey,
166                            rnonce: Default::default(),
167                        }) as Box<dyn TransitCryptoDecrypt>,
168                    ))
169                })
170            }
171        }
172
173        Ok(Box::new(Finalizer { skey, rkey }))
174    }
175
176    async fn handshake_follower(
177        &self,
178        socket: &mut dyn TransitTransport,
179    ) -> Result<Box<dyn TransitCryptoInitFinalizer>, TransitHandshakeError> {
180        // 9. create record keys
181        /* The order here is correct. The "sender" and "receiver" side are a misnomer and should be called
182         * "leader" and "follower" instead. As a follower, we use the leader key for receiving and our
183         * key for sending.
184         */
185        let rkey = self
186            .key
187            .derive_subkey_from_purpose("transit_record_sender_key");
188        let skey = self
189            .key
190            .derive_subkey_from_purpose("transit_record_receiver_key");
191
192        // for receive mode, send receive_handshake_msg and compare.
193        // the received message with send_handshake_msg
194        socket
195            .write_all(
196                format!(
197                    "transit receiver {} ready\n\n",
198                    self.key
199                        .derive_subkey_from_purpose::<crate::GenericKey>("transit_receiver")
200                        .to_hex(),
201                )
202                .as_bytes(),
203            )
204            .await?;
205
206        let expected_tx_handshake = format!(
207            "transit sender {} ready\n\ngo\n",
208            self.key
209                .derive_subkey_from_purpose::<crate::GenericKey>("transit_sender")
210                .to_hex(),
211        );
212        assert_eq!(expected_tx_handshake.len(), 90);
213        socket.read_expect(expected_tx_handshake.as_bytes()).await?;
214
215        Ok(Box::new((
216            Box::new(SecretboxCryptoEncrypt {
217                skey,
218                snonce: Default::default(),
219            }) as Box<dyn TransitCryptoEncrypt>,
220            Box::new(SecretboxCryptoDecrypt {
221                rkey,
222                rnonce: Default::default(),
223            }) as Box<dyn TransitCryptoDecrypt>,
224        )) as Box<dyn TransitCryptoInitFinalizer>)
225    }
226}
227
228type NoiseHandshakeState = noise_protocol::HandshakeState<
229    noise_rust_crypto::X25519,
230    noise_rust_crypto::ChaCha20Poly1305,
231    noise_rust_crypto::Blake2s,
232>;
233type NoiseCipherState = noise_protocol::CipherState<noise_rust_crypto::ChaCha20Poly1305>;
234
235/// Cryptography based on the [noise protocol](noiseprotocol.org).
236/// → "Magic-Wormhole Dilation Handshake v1 Leader\n\n"
237/// ← "Magic-Wormhole Dilation Handshake v1 Follower\n\n"
238/// → psk, e // Handshake
239/// ← e, ee
240/// ← "" // First real message
241/// → "" // Not in this method, to confirm the connection
242///
243/// The noise protocol pattern used is "Noise_NNpsk0_25519_ChaChaPoly_BLAKE2s"
244pub struct NoiseInit {
245    pub key: Arc<Key<TransitKey>>,
246}
247
248#[async_trait]
249impl TransitCryptoInit for NoiseInit {
250    async fn handshake_leader(
251        &self,
252        socket: &mut dyn TransitTransport,
253    ) -> Result<Box<dyn TransitCryptoInitFinalizer>, TransitHandshakeError> {
254        socket
255            .write_all(b"Magic-Wormhole Dilation Handshake v1 Leader\n\n")
256            .await?;
257        socket
258            .read_expect(b"Magic-Wormhole Dilation Handshake v1 Follower\n\n")
259            .await?;
260
261        let mut handshake: NoiseHandshakeState = {
262            let mut builder = noise_protocol::HandshakeStateBuilder::new();
263            builder.set_pattern(noise_protocol::patterns::noise_nn_psk0());
264            builder.set_prologue(&[]);
265            builder.set_is_initiator(true);
266            builder.build_handshake_state()
267        };
268        handshake.push_psk(&self.key);
269
270        // → psk, e
271        socket
272            .write_transit_message(&handshake.write_message_vec(&[])?)
273            .await?;
274
275        // ← e, ee
276        handshake.read_message(&socket.read_transit_message().await?, &mut [])?;
277
278        assert!(handshake.completed());
279        let (tx, mut rx) = handshake.get_ciphers();
280
281        // ← ""
282        let peer_confirmation_message = rx.decrypt_vec(&socket.read_transit_message().await?)?;
283        ensure!(
284            peer_confirmation_message.is_empty(),
285            TransitHandshakeError::HandshakeFailed
286        );
287
288        struct Finalizer {
289            tx: NoiseCipherState,
290            rx: NoiseCipherState,
291        }
292
293        impl TransitCryptoInitFinalizer for Finalizer {
294            fn handshake_finalize(
295                mut self: Box<Self>,
296                socket: &mut dyn TransitTransport,
297            ) -> BoxFuture<Result<DynTransitCrypto, TransitHandshakeError>> {
298                Box::pin(async move {
299                    // → ""
300                    socket
301                        .write_transit_message(&self.tx.encrypt_vec(&[]))
302                        .await?;
303
304                    Ok::<_, TransitHandshakeError>((
305                        Box::new(NoiseCryptoEncrypt { tx: self.tx })
306                            as Box<dyn TransitCryptoEncrypt>,
307                        Box::new(NoiseCryptoDecrypt { rx: self.rx })
308                            as Box<dyn TransitCryptoDecrypt>,
309                    ))
310                })
311            }
312        }
313
314        Ok(Box::new(Finalizer { tx, rx }))
315    }
316
317    async fn handshake_follower(
318        &self,
319        socket: &mut dyn TransitTransport,
320    ) -> Result<Box<dyn TransitCryptoInitFinalizer>, TransitHandshakeError> {
321        socket
322            .write_all(b"Magic-Wormhole Dilation Handshake v1 Follower\n\n")
323            .await?;
324        socket
325            .read_expect(b"Magic-Wormhole Dilation Handshake v1 Leader\n\n")
326            .await?;
327
328        let mut handshake: NoiseHandshakeState = {
329            let mut builder = noise_protocol::HandshakeStateBuilder::new();
330            builder.set_pattern(noise_protocol::patterns::noise_nn_psk0());
331            builder.set_prologue(&[]);
332            builder.set_is_initiator(false);
333            builder.build_handshake_state()
334        };
335        handshake.push_psk(&self.key);
336
337        // ← psk, e
338        handshake.read_message(&socket.read_transit_message().await?, &mut [])?;
339
340        // → e, ee
341        socket
342            .write_transit_message(&handshake.write_message_vec(&[])?)
343            .await?;
344
345        assert!(handshake.completed());
346        // Warning: rx and tx are swapped here (read the `get_ciphers` doc carefully)
347        let (mut rx, mut tx) = handshake.get_ciphers();
348
349        // → ""
350        socket.write_transit_message(&tx.encrypt_vec(&[])).await?;
351
352        // ← ""
353        let peer_confirmation_message = rx.decrypt_vec(&socket.read_transit_message().await?)?;
354        ensure!(
355            peer_confirmation_message.is_empty(),
356            TransitHandshakeError::HandshakeFailed
357        );
358
359        Ok(Box::new((
360            Box::new(NoiseCryptoEncrypt { tx }) as Box<dyn TransitCryptoEncrypt>,
361            Box::new(NoiseCryptoDecrypt { rx }) as Box<dyn TransitCryptoDecrypt>,
362        )) as Box<dyn TransitCryptoInitFinalizer>)
363    }
364}
365
366type DynTransitCrypto = (Box<dyn TransitCryptoEncrypt>, Box<dyn TransitCryptoDecrypt>);
367
368#[async_trait]
369pub(super) trait TransitCryptoEncrypt: Send {
370    async fn encrypt(
371        &mut self,
372        socket: &mut dyn TransitTransportTx,
373        plaintext: &[u8],
374    ) -> Result<(), TransitError>;
375}
376
377#[async_trait]
378pub(super) trait TransitCryptoDecrypt: Send {
379    async fn decrypt(
380        &mut self,
381        socket: &mut dyn TransitTransportRx,
382    ) -> Result<Box<[u8]>, TransitError>;
383}
384
385struct SecretboxCryptoEncrypt {
386    /** Our key, used for sending */
387    pub skey: Key<TransitTxKey>,
388    /** Nonce for sending */
389    pub snonce: secretbox::Nonce,
390}
391
392struct SecretboxCryptoDecrypt {
393    /** Their key, used for receiving */
394    pub rkey: Key<TransitRxKey>,
395    /**
396     * Nonce for receiving
397     *
398     * We'll count as receiver and track if messages come in in order
399     */
400    pub rnonce: secretbox::Nonce,
401}
402
403#[async_trait]
404impl TransitCryptoEncrypt for SecretboxCryptoEncrypt {
405    async fn encrypt(
406        &mut self,
407        socket: &mut dyn TransitTransportTx,
408        plaintext: &[u8],
409    ) -> Result<(), TransitError> {
410        let nonce = &mut self.snonce;
411        let sodium_key = secretbox::Key::from_slice(&self.skey);
412
413        let ciphertext = {
414            let nonce_le = secretbox::Nonce::from_slice(nonce);
415
416            let cipher = secretbox::XSalsa20Poly1305::new(sodium_key);
417            cipher
418                .encrypt(nonce_le, plaintext)
419                /* TODO replace with (TransitError::Crypto) after the next xsalsa20poly1305 update */
420                .map_err(|_| TransitError::Crypto)?
421        };
422
423        // send the encrypted record
424        socket
425            .write_all(&((ciphertext.len() + nonce.len()) as u32).to_be_bytes())
426            .await?;
427        socket.write_all(nonce).await?;
428        socket.write_all(&ciphertext).await?;
429
430        crate::util::sodium_increment_be(nonce);
431
432        Ok(())
433    }
434}
435
436#[async_trait]
437impl TransitCryptoDecrypt for SecretboxCryptoDecrypt {
438    async fn decrypt(
439        &mut self,
440        socket: &mut dyn TransitTransportRx,
441    ) -> Result<Box<[u8]>, TransitError> {
442        let nonce = &mut self.rnonce;
443
444        let enc_packet = socket.read_transit_message().await?;
445
446        use std::io::{Error, ErrorKind};
447        ensure!(
448            enc_packet.len() >= secretbox::SecretBox::<secretbox::XSalsa20Poly1305>::NONCE_SIZE,
449            Error::new(
450                ErrorKind::InvalidData,
451                "Message must be long enough to contain at least the nonce"
452            )
453        );
454
455        // 3. decrypt the vector 'enc_packet' with the key.
456        let plaintext = {
457            let (received_nonce, ciphertext) = enc_packet
458                .split_at(secretbox::SecretBox::<secretbox::XSalsa20Poly1305>::NONCE_SIZE);
459            {
460                // Nonce check
461                ensure!(
462                    nonce.as_slice() == received_nonce,
463                    TransitError::Nonce(received_nonce.into(), nonce.as_slice().into()),
464                );
465
466                crate::util::sodium_increment_be(nonce);
467            }
468
469            let cipher = secretbox::XSalsa20Poly1305::new(secretbox::Key::from_slice(&self.rkey));
470            cipher
471                .decrypt(secretbox::Nonce::from_slice(received_nonce), ciphertext)
472                /* TODO replace with (TransitError::Crypto) after the next xsalsa20poly1305 update */
473                .map_err(|_| TransitError::Crypto)?
474        };
475
476        Ok(plaintext.into_boxed_slice())
477    }
478}
479
480struct NoiseCryptoEncrypt {
481    tx: NoiseCipherState,
482}
483
484struct NoiseCryptoDecrypt {
485    rx: NoiseCipherState,
486}
487
488#[async_trait]
489impl TransitCryptoEncrypt for NoiseCryptoEncrypt {
490    async fn encrypt(
491        &mut self,
492        socket: &mut dyn TransitTransportTx,
493        plaintext: &[u8],
494    ) -> Result<(), TransitError> {
495        socket
496            .write_transit_message(&self.tx.encrypt_vec(plaintext))
497            .await?;
498        Ok(())
499    }
500}
501
502#[async_trait]
503impl TransitCryptoDecrypt for NoiseCryptoDecrypt {
504    async fn decrypt(
505        &mut self,
506        socket: &mut dyn TransitTransportRx,
507    ) -> Result<Box<[u8]>, TransitError> {
508        let plaintext = self.rx.decrypt_vec(&socket.read_transit_message().await?)?;
509        Ok(plaintext.into_boxed_slice())
510    }
511}