// iroh_relay/protos/handshake.rs

1//! Implements the handshake protocol that authenticates and authorizes clients connecting to the relays.
2//!
3//! The purpose of the handshake is to
4//! 1. Inform the relay of the client's EndpointId
5//! 2. Check that the connecting client owns the secret key for its EndpointId ("is authentic"/"authentication")
6//! 3. Possibly check that the client has access to this relay, if the relay requires authorization.
7//!
//! Additional complexity comes from the fact that there are two ways that clients can authenticate with
//! relays.
10//!
11//! One way is via an explicitly sent challenge:
12//!
13//! 1. Once a websocket connection is opened, a client receives a challenge (the `ServerChallenge` frame)
14//! 2. The client sends back what is essentially a signature of that challenge with their secret key
15//!    that matches the EndpointId they have, as well as the EndpointId (the `ClientAuth` frame)
16//!
//! The second way is very similar to the [Concealed HTTP Auth RFC], and involves sending a header that
//! contains a signature of some shared keying material extracted from TLS ([RFC 5705]).
//!
//! The second way can save a full round trip, because the challenge doesn't have to be sent to the client
//! first. However, it won't always work: it relies on the keying material extraction feature of TLS,
//! which is not available in browsers (but might be in the future?) and might break when there's an
//! HTTPS proxy that doesn't properly deal with this TLS feature. When it fails, the relay falls back
//! to the first (challenge-based) way.
24//!
25//! [Concealed HTTP Auth RFC]: https://datatracker.ietf.org/doc/rfc9729/
26//! [RFC 5705]: https://datatracker.ietf.org/doc/html/rfc5705
27use bytes::{BufMut, Bytes, BytesMut};
28use data_encoding::BASE32HEX_NOPAD as HEX;
29#[cfg(not(wasm_browser))]
30use http::HeaderValue;
31#[cfg(feature = "server")]
32use iroh_base::Signature;
33use iroh_base::{PublicKey, SecretKey};
34use n0_error::{e, ensure, stack_error};
35use n0_future::{SinkExt, TryStreamExt};
36#[cfg(feature = "server")]
37use rand::CryptoRng;
38use tracing::trace;
39
40use super::{
41    common::{FrameType, FrameTypeError},
42    streams::BytesStreamSink,
43};
44use crate::ExportKeyingMaterial;
45
/// Context string handed to `blake3::derive_key` when turning a [`ServerChallenge`] into the
/// message a client signs, so that signature is bound to this specific purpose.
const DOMAIN_SEP_CHALLENGE: &str = "iroh-relay handshake v1 challenge signature";

/// Label handed to [`ExportKeyingMaterial::export_keying_material`] when deriving the TLS keying
/// material that a [`KeyMaterialClientAuth`] signs.
#[cfg(not(wasm_browser))]
const DOMAIN_SEP_TLS_EXPORT_LABEL: &[u8] = b"iroh-relay handshake v1";
52
53/// Authentication message from the client.
54#[derive(derive_more::Debug, serde::Serialize, serde::Deserialize)]
55#[cfg_attr(wasm_browser, allow(unused))]
56pub(crate) struct KeyMaterialClientAuth {
57    /// The client's public key
58    pub(crate) public_key: PublicKey,
59    /// A signature of (a hash of) extracted key material.
60    #[serde(with = "serde_bytes")]
61    #[debug("{}", HEX.encode(signature))]
62    pub(crate) signature: [u8; 64],
63    /// Part of the extracted key material.
64    ///
65    /// Allows making sure we have the same underlying key material.
66    #[debug("{}", HEX.encode(key_material_suffix))]
67    pub(crate) key_material_suffix: [u8; 16],
68}
69
70/// A challenge for the client to sign with their secret key for EndpointId authentication.
71#[derive(derive_more::Debug, serde::Serialize, serde::Deserialize)]
72pub(crate) struct ServerChallenge {
73    /// The challenge to sign.
74    /// Must be randomly generated with an RNG that is safe to use for crypto.
75    #[debug("{}", HEX.encode(challenge))]
76    pub(crate) challenge: [u8; 16],
77}
78
79/// Authentication message from the client.
80///
81/// Used when authentication via [`KeyMaterialClientAuth`] didn't work.
82#[derive(derive_more::Debug, serde::Serialize, serde::Deserialize)]
83pub(crate) struct ClientAuth {
84    /// The client's public key, a.k.a. the `EndpointId`
85    pub(crate) public_key: PublicKey,
86    /// A signature of (a hash of) the [`ServerChallenge`].
87    ///
88    /// This is what provides the authentication.
89    #[serde(with = "serde_bytes")]
90    #[debug("{}", HEX.encode(signature))]
91    pub(crate) signature: [u8; 64],
92}
93
94/// Confirmation of successful connection.
95#[derive(derive_more::Debug, serde::Serialize, serde::Deserialize)]
96pub(crate) struct ServerConfirmsAuth;
97
98/// Denial of connection. The client couldn't be verified as authentic.
99#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
100pub(crate) struct ServerDeniesAuth {
101    reason: String,
102}
103
104/// Trait for getting the frame type tag for a frame.
105///
106/// Used only in the handshake, as the frame we expect next
107/// is fairly stateful.
108/// Not used in the send/recv protocol, as any frame is
109/// allowed to happen at any time there.
110trait Frame {
111    /// The frame type this frame is identified by and prefixed with
112    const TAG: FrameType;
113}
114
115impl<T: Frame> Frame for &T {
116    const TAG: FrameType = T::TAG;
117}
118
119impl Frame for ServerChallenge {
120    const TAG: FrameType = FrameType::ServerChallenge;
121}
122
123impl Frame for ClientAuth {
124    const TAG: FrameType = FrameType::ClientAuth;
125}
126
127impl Frame for ServerConfirmsAuth {
128    const TAG: FrameType = FrameType::ServerConfirmsAuth;
129}
130
131impl Frame for ServerDeniesAuth {
132    const TAG: FrameType = FrameType::ServerDeniesAuth;
133}
134
135#[stack_error(derive, add_meta)]
136#[allow(missing_docs)]
137#[non_exhaustive]
138pub enum Error {
139    #[error(transparent)]
140    Websocket {
141        #[cfg(not(wasm_browser))]
142        #[error(from, std_err)]
143        source: tokio_websockets::Error,
144        #[cfg(wasm_browser)]
145        #[error(from, std_err)]
146        source: ws_stream_wasm::WsErr,
147    },
148    #[error("Handshake stream ended prematurely")]
149    UnexpectedEnd {},
150    #[error(transparent)]
151    FrameTypeError {
152        #[error(from)]
153        source: FrameTypeError,
154    },
155    #[error("The relay denied our authentication ({reason})")]
156    ServerDeniedAuth { reason: String },
157    #[error("Unexpected tag, got {frame_type:?}, but expected one of {expected_types:?}")]
158    UnexpectedFrameType {
159        frame_type: FrameType,
160        expected_types: Vec<FrameType>,
161    },
162    #[error("Handshake failed while deserializing {frame_type:?} frame")]
163    DeserializationError {
164        frame_type: FrameType,
165        #[error(std_err)]
166        source: postcard::Error,
167    },
168    #[cfg(feature = "server")]
169    /// Failed to deserialize client auth header
170    ClientAuthHeaderInvalid { value: HeaderValue },
171}
172
/// Reasons why the relay could not verify a client's authentication message.
#[cfg(feature = "server")]
#[stack_error(derive, add_meta)]
pub(crate) enum VerificationError {
    /// Exporting keying material from the verifying side's TLS session returned nothing.
    #[error("Couldn't export TLS keying material on our end")]
    NoKeyingMaterial,
    /// The unsigned key material suffixes differ, so both ends exported different material.
    #[error(
        "Client didn't extract the same keying material, the suffix mismatched: expected {expected:X?} but got {actual:X?}"
    )]
    MismatchedSuffix {
        expected: [u8; 16],
        actual: [u8; 16],
    },
    /// The signature does not verify against the client's claimed public key.
    #[error(
        "Client signature {signature:X?} for message {message:X?} invalid for public key {public_key}"
    )]
    SignatureInvalid {
        source: iroh_base::SignatureError,
        message: Vec<u8>,
        signature: [u8; 64],
        public_key: PublicKey,
    },
}
195
196impl ServerChallenge {
197    /// Generates a new challenge.
198    #[cfg(feature = "server")]
199    pub(crate) fn new<R: CryptoRng + ?Sized>(rng: &mut R) -> Self {
200        let mut challenge = [0u8; 16];
201        rng.fill_bytes(&mut challenge);
202        Self { challenge }
203    }
204
205    /// The actual message bytes to sign (and verify against) for this challenge.
206    fn message_to_sign(&self) -> [u8; 32] {
207        // We're signing a key instead of the direct challenge.
208        // This gives us domain separation protecting from multiple possible attacks,
209        // but especially this one:
210        // Assume a malicious relay. If the protocol required the client to sign the
211        // challenge directly, this would allow the relay to obtain an arbitrary 16-byte
212        // signature, if it maliciously choses the challenge instead of generating it
213        // randomly.
214        // Deriving a key to sign instead mitigates this attack.
215        blake3::derive_key(DOMAIN_SEP_CHALLENGE, &self.challenge)
216    }
217}
218
219impl ClientAuth {
220    /// Generates a signature for the given challenge from the server.
221    pub(crate) fn new(secret_key: &SecretKey, challenge: &ServerChallenge) -> Self {
222        Self {
223            public_key: secret_key.public(),
224            signature: secret_key.sign(&challenge.message_to_sign()).to_bytes(),
225        }
226    }
227
228    /// Verifies this client's authentication given the challenge this was sent in response to.
229    #[cfg(feature = "server")]
230    pub(crate) fn verify(&self, challenge: &ServerChallenge) -> Result<(), Box<VerificationError>> {
231        let message = challenge.message_to_sign();
232        self.public_key
233            .verify(&message, &Signature::from_bytes(&self.signature))
234            .map_err(|err| {
235                e!(VerificationError::SignatureInvalid {
236                    source: err,
237                    message: message.to_vec(),
238                    signature: self.signature,
239                    public_key: self.public_key
240                })
241            })
242            .map_err(Box::new)
243    }
244}
245
246#[cfg(not(wasm_browser))]
247impl KeyMaterialClientAuth {
248    /// Generates a client's authentication, similar to [`ClientAuth`], but by using TLS keying material
249    /// instead of a received challenge.
250    pub(crate) fn new(secret_key: &SecretKey, io: &impl ExportKeyingMaterial) -> Option<Self> {
251        let public_key = secret_key.public();
252        let key_material = io.export_keying_material(
253            [0u8; 32],
254            DOMAIN_SEP_TLS_EXPORT_LABEL,
255            Some(secret_key.public().as_bytes()),
256        )?;
257        // We split the export and only sign the first 16 bytes, and
258        // pass through the last 16 bytes. See also the note in [Self::verify].
259        let (message, suffix) = key_material.split_at(16);
260        Some(Self {
261            public_key,
262            signature: secret_key.sign(message).to_bytes(),
263            key_material_suffix: suffix.try_into().expect("hardcoded length"),
264        })
265    }
266
267    /// Generate the base64url-nopad-encoded header value.
268    pub(crate) fn into_header_value(self) -> HeaderValue {
269        HeaderValue::from_str(
270            &data_encoding::BASE64URL_NOPAD
271                .encode(&postcard::to_allocvec(&self).expect("encoding never fails")),
272        )
273        .expect("BASE64URL_NOPAD encoding contained invisible ascii characters")
274    }
275
276    /// Verifies this client auth on the server side using the same key material.
277    ///
278    /// This might return false for a couple of reasons:
279    /// 1. The exported keying material might not be the same between both ends of the TLS session
280    ///    (e.g. there's an HTTPS proxy in between that doesn't think/care about the TLS keying material exporter).
281    ///    This situation is detected when the key material suffix mismatches.
282    /// 2. The signature itself doesn't verify.
283    #[cfg(feature = "server")]
284    pub(crate) fn verify(
285        &self,
286        io: &impl ExportKeyingMaterial,
287    ) -> Result<(), Box<VerificationError>> {
288        let key_material = io
289            .export_keying_material(
290                [0u8; 32],
291                DOMAIN_SEP_TLS_EXPORT_LABEL,
292                Some(self.public_key.as_bytes()),
293            )
294            .ok_or_else(|| e!(VerificationError::NoKeyingMaterial))?;
295        // We split the export and only sign the first 16 bytes, and
296        // pass through the last 16 bytes.
297        // Passing on the suffix helps the verifying end figure out what
298        // went wrong: If there's a suffix mismatch, then the exported keying
299        // material on both ends wasn't the same - so perhaps there was a
300        // TLS proxy in between or similar.
301        // If the suffix does match, but the signature doesn't verify, then
302        // there must be something wrong with the client's secret key or signature.
303        let (message, suffix) = key_material.split_at(16);
304        let suffix: [u8; 16] = suffix.try_into().expect("hardcoded length");
305        ensure!(
306            suffix == self.key_material_suffix,
307            VerificationError::MismatchedSuffix {
308                expected: self.key_material_suffix,
309                actual: suffix
310            }
311        );
312        // NOTE: We don't blake3-hash here as we do it in [`ServerChallenge::message_to_sign`],
313        // because we already have a domain separation string and keyed hashing step in
314        // the TLS export keying material above.
315        self.public_key
316            .verify(message, &Signature::from_bytes(&self.signature))
317            .map_err(|err| {
318                e!(VerificationError::SignatureInvalid {
319                    source: err,
320                    message: message.to_vec(),
321                    public_key: self.public_key,
322                    signature: self.signature
323                })
324            })
325            .map_err(Box::new)
326    }
327}
328
329/// Runs the client side of the handshake protocol.
330///
331/// See the module docs for details on the protocol.
332/// This is already after having potentially transferred a [`KeyMaterialClientAuth`],
333/// but before having received a response for whether that worked or not.
334///
335/// This requires access to the client's secret key to sign a challenge.
336pub(crate) async fn clientside(
337    io: &mut (impl BytesStreamSink + ExportKeyingMaterial),
338    secret_key: &SecretKey,
339) -> Result<ServerConfirmsAuth, Error> {
340    let (tag, frame) = read_frame(io, &[ServerChallenge::TAG, ServerConfirmsAuth::TAG]).await?;
341
342    let (tag, frame) = if tag == ServerChallenge::TAG {
343        let challenge: ServerChallenge = deserialize_frame(frame)?;
344
345        let client_info = ClientAuth::new(secret_key, &challenge);
346        write_frame(io, client_info).await?;
347
348        read_frame(io, &[ServerConfirmsAuth::TAG, ServerDeniesAuth::TAG]).await?
349    } else {
350        (tag, frame)
351    };
352
353    match tag {
354        FrameType::ServerConfirmsAuth => {
355            let confirmation: ServerConfirmsAuth = deserialize_frame(frame)?;
356            Ok(confirmation)
357        }
358        FrameType::ServerDeniesAuth => {
359            let denial: ServerDeniesAuth = deserialize_frame(frame)?;
360            Err(e!(Error::ServerDeniedAuth {
361                reason: denial.reason
362            }))
363        }
364        _ => unreachable!(),
365    }
366}
367
/// Records that the client owning `client_key` authenticated itself, and via which [`Mechanism`].
///
/// Authentication alone doesn't finish the protocol: the client only learns the outcome once
/// [`SuccessfulAuthentication::authorize_if`] is called.
#[cfg(feature = "server")]
#[derive(Debug)]
#[must_use = "the protocol is not finished unless `authorize_if` is called"]
pub(crate) struct SuccessfulAuthentication {
    /// The authenticated client's public key (its `EndpointId`).
    pub(crate) client_key: PublicKey,
    /// How the client proved ownership of `client_key`.
    pub(crate) mechanism: Mechanism,
}

/// How a client proved its identity during the handshake.
#[cfg(feature = "server")]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum Mechanism {
    /// The client signed a [`ServerChallenge`] we sent it, costing an extra round trip.
    SignedChallenge,
    /// The client signed keying material exported from the shared TLS session.
    SignedKeyMaterial,
}
389
/// Runs the server side of the handshaking protocol.
///
/// See the module documentation for an overview of the handshaking protocol.
///
/// If the client sent a `client_auth_header` (an encoded [`KeyMaterialClientAuth`]), it is
/// tried first. This authenticates the client without sending a challenge, saving a round trip.
/// If verifying it fails, the protocol falls back to doing a normal extra round trip with a
/// [`ServerChallenge`], whose randomness comes from [`rand::rng`]. Verification can fail, for
/// example, when a proxy in between makes both ends export different TLS keying material.
///
/// The return value [`SuccessfulAuthentication`] still needs to be resolved by calling
/// [`SuccessfulAuthentication::authorize_if`] to finish the whole authorization protocol
/// (otherwise the client won't be notified about auth success or failure).
///
/// # Errors
///
/// - [`Error::ClientAuthHeaderInvalid`] if the header is present but can't be decoded. There is
///   no fallback to the challenge in that case, and nothing is sent to the client.
/// - [`Error::ServerDeniedAuth`] if the client's challenge signature doesn't verify. The client
///   is sent a [`ServerDeniesAuth`] frame before this is returned.
/// - Other [`Error`] variants for transport failures or unexpected/malformed frames.
#[cfg(feature = "server")]
pub(crate) async fn serverside(
    io: &mut (impl BytesStreamSink + ExportKeyingMaterial),
    client_auth_header: Option<HeaderValue>,
) -> Result<SuccessfulAuthentication, Error> {
    if let Some(client_auth_header) = client_auth_header {
        let invalid_header = || {
            e!(Error::ClientAuthHeaderInvalid {
                value: client_auth_header.clone()
            })
        };
        let client_auth_bytes = data_encoding::BASE64URL_NOPAD
            .decode(client_auth_header.as_ref())
            .map_err(|_| invalid_header())?;
        let client_auth: KeyMaterialClientAuth =
            postcard::from_bytes(&client_auth_bytes).map_err(|_| invalid_header())?;

        match client_auth.verify(io) {
            Ok(()) => {
                trace!(?client_auth.public_key, "authentication succeeded via keying material");
                return Ok(SuccessfulAuthentication {
                    client_key: client_auth.public_key,
                    mechanism: Mechanism::SignedKeyMaterial,
                });
            }
            Err(err) => {
                // Verification not succeeding is part of normal operation: The TLS exporter isn't
                // required to match. We'll fall back to verification that takes another round trip,
                // but keep the reason around for debugging.
                trace!(
                    ?client_auth.public_key,
                    ?err,
                    "keying material authentication failed, falling back to challenge"
                );
            }
        }
    }

    let challenge = ServerChallenge::new(&mut rand::rng());
    write_frame(io, &challenge).await?;

    let (_, frame) = read_frame(io, &[ClientAuth::TAG]).await?;
    let client_auth: ClientAuth = deserialize_frame(frame)?;

    if let Err(err) = client_auth.verify(&challenge) {
        trace!(?client_auth.public_key, ?err, "authentication failed");
        let denial = ServerDeniesAuth {
            reason: "signature invalid".into(),
        };
        // Frames can be written by reference, so no clone of the denial is needed.
        write_frame(io, &denial).await?;
        return Err(e!(Error::ServerDeniedAuth {
            reason: denial.reason
        }));
    }

    trace!(?client_auth.public_key, "authentication succeeded via challenge");
    Ok(SuccessfulAuthentication {
        client_key: client_auth.public_key,
        mechanism: Mechanism::SignedChallenge,
    })
}
459
#[cfg(feature = "server")]
impl SuccessfulAuthentication {
    /// Finishes the handshake by telling the client whether it is authorized.
    ///
    /// If `is_authorized` is true, sends [`ServerConfirmsAuth`] over `io` and returns the
    /// authenticated client's public key. Otherwise sends [`ServerDeniesAuth`] with the reason
    /// `"not authorized"`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::ServerDeniedAuth`] when `is_authorized` is false, and other [`Error`]
    /// variants if writing the response frame fails.
    pub async fn authorize_if(
        self,
        is_authorized: bool,
        // Only frames are written here, so plain stream/sink access is all that's required.
        io: &mut impl BytesStreamSink,
    ) -> Result<PublicKey, Error> {
        if !is_authorized {
            trace!("denying client auth");
            let denial = ServerDeniesAuth {
                reason: "not authorized".into(),
            };
            // Frames can be written by reference, so no clone of the denial is needed.
            write_frame(io, &denial).await?;
            return Err(e!(Error::ServerDeniedAuth {
                reason: denial.reason
            }));
        }

        trace!("authorizing client");
        write_frame(io, ServerConfirmsAuth).await?;
        Ok(self.client_key)
    }
}
483
484async fn write_frame<F: serde::Serialize + Frame>(
485    io: &mut impl BytesStreamSink,
486    frame: F,
487) -> Result<(), Error> {
488    let mut bytes = BytesMut::new();
489    trace!(frame_type = ?F::TAG, "Writing frame");
490    F::TAG.write_to(&mut bytes);
491    let bytes = postcard::to_io(&frame, bytes.writer())
492        .expect("serialization failed") // buffer can't become "full" without being a critical failure, datastructures shouldn't ever fail serialization
493        .into_inner()
494        .freeze();
495    io.send(bytes).await?;
496    io.flush().await?;
497    Ok(())
498}
499
500async fn read_frame(
501    io: &mut impl BytesStreamSink,
502    expected_types: &[FrameType],
503) -> Result<(FrameType, Bytes), Error> {
504    let mut payload = io
505        .try_next()
506        .await?
507        .ok_or_else(|| e!(Error::UnexpectedEnd))?;
508
509    let frame_type = FrameType::from_bytes(&mut payload)?;
510    trace!(?frame_type, "Reading frame");
511    ensure!(
512        expected_types.contains(&frame_type),
513        Error::UnexpectedFrameType {
514            frame_type,
515            expected_types: expected_types.to_vec()
516        }
517    );
518
519    Ok((frame_type, payload))
520}
521
522fn deserialize_frame<F: Frame + serde::de::DeserializeOwned>(frame: Bytes) -> Result<F, Error> {
523    postcard::from_bytes(&frame).map_err(|err| {
524        e!(Error::DeserializationError {
525            frame_type: F::TAG,
526            source: err
527        })
528    })
529}
530
#[cfg(all(test, feature = "server"))]
mod tests {
    use bytes::BytesMut;
    use iroh_base::{PublicKey, SecretKey};
    use n0_error::{Result, StackResultExt, StdResultExt};
    use n0_future::{Sink, SinkExt, Stream, TryStreamExt};
    use rand::SeedableRng;
    use tokio_util::codec::{Framed, LengthDelimitedCodec};
    use tracing::{Instrument, info_span};
    use tracing_test::traced_test;

    use super::{
        ClientAuth, KeyMaterialClientAuth, Mechanism, ServerChallenge, ServerConfirmsAuth,
    };
    use crate::ExportKeyingMaterial;

    /// Wraps a test transport and simulates TLS keying material export.
    ///
    /// Two ends with the same `shared_secret` export identical keying material; `None` simulates
    /// an end that can't export keying material at all.
    struct TestKeyingMaterial<IO> {
        shared_secret: Option<u64>,
        inner: IO,
    }

    trait WithTlsSharedSecret: Sized {
        fn with_shared_secret(self, shared_secret: Option<u64>) -> TestKeyingMaterial<Self>;
    }

    impl<T: Sized> WithTlsSharedSecret for T {
        fn with_shared_secret(self, shared_secret: Option<u64>) -> TestKeyingMaterial<Self> {
            TestKeyingMaterial {
                shared_secret,
                inner: self,
            }
        }
    }

    impl<IO> ExportKeyingMaterial for TestKeyingMaterial<IO> {
        fn export_keying_material<T: AsMut<[u8]>>(
            &self,
            mut output: T,
            label: &[u8],
            context: Option<&[u8]>,
        ) -> Option<T> {
            // we simulate something like exporting keying material using blake3

            let label_key = blake3::hash(label);
            let context_key = blake3::keyed_hash(label_key.as_bytes(), context.unwrap_or(&[]));
            let mut hasher = blake3::Hasher::new_keyed(context_key.as_bytes());
            hasher.update(&self.shared_secret?.to_le_bytes());
            hasher.finalize_xof().fill(output.as_mut());

            Some(output)
        }
    }

    impl<V, IO: Stream<Item = V> + Unpin> Stream for TestKeyingMaterial<IO> {
        type Item = V;

        fn poll_next(
            mut self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Option<Self::Item>> {
            std::pin::Pin::new(&mut self.inner).poll_next(cx)
        }
    }

    impl<V, E, IO: Sink<V, Error = E> + Unpin> Sink<V> for TestKeyingMaterial<IO> {
        type Error = E;

        fn poll_ready(
            mut self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Result<(), Self::Error>> {
            std::pin::Pin::new(&mut self.inner).poll_ready(cx)
        }

        fn start_send(mut self: std::pin::Pin<&mut Self>, item: V) -> Result<(), Self::Error> {
            std::pin::Pin::new(&mut self.inner).start_send(item)
        }

        fn poll_flush(
            mut self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Result<(), Self::Error>> {
            std::pin::Pin::new(&mut self.inner).poll_flush(cx)
        }

        fn poll_close(
            mut self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Result<(), Self::Error>> {
            std::pin::Pin::new(&mut self.inner).poll_close(cx)
        }
    }

    async fn simulate_handshake(
        secret_key: &SecretKey,
        client_shared_secret: Option<u64>,
        server_shared_secret: Option<u64>,
        restricted_to: Option<PublicKey>,
    ) -> (Result<ServerConfirmsAuth>, Result<(PublicKey, Mechanism)>) {
        let (client, server) = tokio::io::duplex(1024);

        let mut client_io = Framed::new(client, LengthDelimitedCodec::new())
            .map_ok(BytesMut::freeze)
            .map_err(tokio_websockets::Error::Io)
            .sink_map_err(tokio_websockets::Error::Io)
            .with_shared_secret(client_shared_secret);
        let mut server_io = Framed::new(server, LengthDelimitedCodec::new())
            .map_ok(BytesMut::freeze)
            .map_err(tokio_websockets::Error::Io)
            .sink_map_err(tokio_websockets::Error::Io)
            .with_shared_secret(server_shared_secret);

        let client_auth_header = KeyMaterialClientAuth::new(secret_key, &client_io)
            .map(KeyMaterialClientAuth::into_header_value);

        n0_future::future::zip(
            async {
                super::clientside(&mut client_io, secret_key)
                    .await
                    .context("clientside")
            }
            .instrument(info_span!("clientside")),
            async {
                let auth_n = super::serverside(&mut server_io, client_auth_header)
                    .await
                    .context("serverside")?;
                let mechanism = auth_n.mechanism;
                let is_authorized = restricted_to.is_none_or(|key| key == auth_n.client_key);
                let key = auth_n.authorize_if(is_authorized, &mut server_io).await?;
                Ok((key, mechanism))
            }
            .instrument(info_span!("serverside")),
        )
        .await
    }

    #[tokio::test]
    #[traced_test]
    async fn test_handshake_via_shared_secrets() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);

        let secret_key = SecretKey::generate(&mut rng);
        let (client, server) = simulate_handshake(&secret_key, Some(42), Some(42), None).await;
        client?;
        let (public_key, auth) = server?;
        assert_eq!(public_key, secret_key.public());
        assert_eq!(auth, Mechanism::SignedKeyMaterial); // it got verified via shared key material
        Ok(())
    }

    #[tokio::test]
    #[traced_test]
    async fn test_handshake_via_challenge() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);

        let secret_key = SecretKey::generate(&mut rng);
        let (client, server) = simulate_handshake(&secret_key, None, None, None).await;
        client?;
        let (public_key, auth) = server?;
        assert_eq!(public_key, secret_key.public());
        assert_eq!(auth, Mechanism::SignedChallenge);
        Ok(())
    }

    #[tokio::test]
    #[traced_test]
    async fn test_handshake_mismatching_shared_secrets() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);

        let secret_key = SecretKey::generate(&mut rng);
        // mismatching shared secrets *might* happen with HTTPS proxies that don't also middle-man the shared secret
        let (client, server) = simulate_handshake(&secret_key, Some(10), Some(99), None).await;
        client?;
        let (public_key, auth) = server?;
        assert_eq!(public_key, secret_key.public());
        assert_eq!(auth, Mechanism::SignedChallenge);
        Ok(())
    }

    #[tokio::test]
    #[traced_test]
    async fn test_handshake_challenge_fallback() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
        let secret_key = SecretKey::generate(&mut rng);
        // clients might not have access to shared secrets
        let (client, server) = simulate_handshake(&secret_key, None, Some(99), None).await;
        client?;
        let (public_key, auth) = server?;
        assert_eq!(public_key, secret_key.public());
        assert_eq!(auth, Mechanism::SignedChallenge);
        Ok(())
    }

    #[tokio::test]
    #[traced_test]
    async fn test_handshake_with_auth_positive() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
        let secret_key = SecretKey::generate(&mut rng);
        let public_key = secret_key.public();
        let (client, server) = simulate_handshake(&secret_key, None, None, Some(public_key)).await;
        client?;
        let (public_key, _) = server?;
        assert_eq!(public_key, secret_key.public());
        Ok(())
    }

    #[tokio::test]
    #[traced_test]
    async fn test_handshake_with_auth_negative() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
        let secret_key = SecretKey::generate(&mut rng);
        let public_key = secret_key.public();
        let wrong_secret_key = SecretKey::generate(&mut rng);
        let (client, server) =
            simulate_handshake(&wrong_secret_key, None, None, Some(public_key)).await;
        assert!(client.is_err());
        assert!(server.is_err());
        Ok(())
    }

    #[tokio::test]
    #[traced_test]
    async fn test_handshake_via_shared_secret_with_auth_negative() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
        let secret_key = SecretKey::generate(&mut rng);
        let public_key = secret_key.public();
        let wrong_secret_key = SecretKey::generate(&mut rng);
        let (client, server) =
            simulate_handshake(&wrong_secret_key, Some(42), Some(42), Some(public_key)).await;
        assert!(client.is_err());
        assert!(server.is_err());
        Ok(())
    }

    /// When keying-material auth succeeds but the client isn't authorized, the relay denies the
    /// client without sending a challenge first. The client must report that as a denial rather
    /// than as an unexpected frame type.
    #[tokio::test]
    #[traced_test]
    async fn test_client_receives_denial_without_challenge() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
        let secret_key = SecretKey::generate(&mut rng);
        let allowed_key = SecretKey::generate(&mut rng).public();

        let (client, server) = tokio::io::duplex(1024);
        let mut client_io = Framed::new(client, LengthDelimitedCodec::new())
            .map_ok(BytesMut::freeze)
            .map_err(tokio_websockets::Error::Io)
            .sink_map_err(tokio_websockets::Error::Io)
            .with_shared_secret(Some(42));
        let mut server_io = Framed::new(server, LengthDelimitedCodec::new())
            .map_ok(BytesMut::freeze)
            .map_err(tokio_websockets::Error::Io)
            .sink_map_err(tokio_websockets::Error::Io)
            .with_shared_secret(Some(42));

        let client_auth_header = KeyMaterialClientAuth::new(&secret_key, &client_io)
            .map(KeyMaterialClientAuth::into_header_value);

        let (client_res, server_res) = n0_future::future::zip(
            super::clientside(&mut client_io, &secret_key),
            async {
                let auth_n = super::serverside(&mut server_io, client_auth_header).await?;
                assert_eq!(auth_n.mechanism, Mechanism::SignedKeyMaterial);
                let is_authorized = auth_n.client_key == allowed_key;
                auth_n.authorize_if(is_authorized, &mut server_io).await
            },
        )
        .await;

        assert!(server_res.is_err());
        assert!(
            matches!(
                &client_res,
                Err(super::Error::ServerDeniedAuth { reason, .. }) if reason == "not authorized"
            ),
            "unexpected client result: {client_res:?}"
        );
        Ok(())
    }

    #[test]
    fn test_client_auth_roundtrip() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
        let secret_key = SecretKey::generate(&mut rng);
        let challenge = ServerChallenge::new(&mut rng);
        let client_auth = ClientAuth::new(&secret_key, &challenge);

        let bytes = postcard::to_allocvec(&client_auth).anyerr()?;
        let decoded: ClientAuth = postcard::from_bytes(&bytes).anyerr()?;

        assert_eq!(client_auth.public_key, decoded.public_key);
        assert_eq!(client_auth.signature, decoded.signature);

        Ok(())
    }

    #[test]
    fn test_km_client_auth_roundtrip() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
        let secret_key = SecretKey::generate(&mut rng);
        let client_auth = KeyMaterialClientAuth::new(
            &secret_key,
            &TestKeyingMaterial {
                inner: (),
                shared_secret: Some(42),
            },
        )
        .anyerr()?;

        let bytes = postcard::to_allocvec(&client_auth).anyerr()?;
        let decoded: KeyMaterialClientAuth = postcard::from_bytes(&bytes).anyerr()?;

        assert_eq!(client_auth.public_key, decoded.public_key);
        assert_eq!(client_auth.signature, decoded.signature);
        assert_eq!(client_auth.key_material_suffix, decoded.key_material_suffix);

        Ok(())
    }

    #[test]
    fn test_challenge_verification() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
        let secret_key = SecretKey::generate(&mut rng);
        let challenge = ServerChallenge::new(&mut rng);
        let client_auth = ClientAuth::new(&secret_key, &challenge);
        assert!(client_auth.verify(&challenge).is_ok());

        Ok(())
    }

    #[test]
    fn test_key_material_verification() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
        let secret_key = SecretKey::generate(&mut rng);
        let io = TestKeyingMaterial {
            inner: (),
            shared_secret: Some(42),
        };
        let client_auth = KeyMaterialClientAuth::new(&secret_key, &io).anyerr()?;
        assert!(client_auth.verify(&io).is_ok());

        Ok(())
    }
}