iroh-relay 0.98.0

Iroh's relay server and client
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
//! Implements the handshake protocol that authenticates and authorizes clients connecting to the relays.
//!
//! The purpose of the handshake is to
//! 1. Inform the relay of the client's EndpointId
//! 2. Check that the connecting client owns the secret key for its EndpointId ("is authentic"/"authentication")
//! 3. Possibly check that the client has access to this relay, if the relay requires authorization.
//!
//! Additional complexity comes from the fact that there are two ways that clients can authenticate with
//! relays.
//!
//! One way is via an explicitly sent challenge:
//!
//! 1. Once a websocket connection is opened, a client receives a challenge (the `ServerChallenge` frame)
//! 2. The client sends back what is essentially a signature of that challenge with their secret key
//!    that matches the EndpointId they have, as well as the EndpointId (the `ClientAuth` frame)
//!
//! The second way is very similar to the [Concealed HTTP Auth RFC], and involves sending a header that
//! contains a signature of some shared keying material extracted from TLS ([RFC 5705]).
//!
//! The second way can save a full round trip, because the challenge doesn't have to be sent to the client
//! first. However, it won't always work, as it relies on the keying material extraction feature of TLS,
//! which is not available in browsers (but might be in the future) and might break when there's an
//! HTTPS proxy that doesn't properly deal with this TLS feature. In that case the relay falls back to
//! the challenge-based way.
//!
//! [Concealed HTTP Auth RFC]: https://datatracker.ietf.org/doc/rfc9729/
//! [RFC 5705]: https://datatracker.ietf.org/doc/html/rfc5705
use bytes::{BufMut, Bytes, BytesMut};
use data_encoding::BASE32HEX_NOPAD as HEX;
#[cfg(not(wasm_browser))]
use http::HeaderValue;
#[cfg(feature = "server")]
use iroh_base::Signature;
use iroh_base::{PublicKey, SecretKey};
use n0_error::{e, ensure, stack_error};
use n0_future::{SinkExt, TryStreamExt};
#[cfg(feature = "server")]
use rand::CryptoRng;
use tracing::trace;

use super::{
    common::{FrameType, FrameTypeError},
    streams::BytesStreamSink,
};
use crate::ExportKeyingMaterial;

/// Domain separation string for the [`ServerChallenge`] signature
///
/// Used as the BLAKE3 `derive_key` context when turning a challenge into the message the
/// client signs (`ServerChallenge::message_to_sign`).
/// NOTE(review): presumably part of the wire protocol — changing it would break interop.
const DOMAIN_SEP_CHALLENGE: &str = "iroh-relay handshake v1 challenge signature";

/// Domain separation label for [`KeyMaterialClientAuth`]'s use of [`ExportKeyingMaterial`]
///
/// Passed as the exporter label both when the client creates its auth
/// (`KeyMaterialClientAuth::new`) and when the server verifies it
/// (`KeyMaterialClientAuth::verify`).
#[cfg(not(wasm_browser))]
const DOMAIN_SEP_TLS_EXPORT_LABEL: &[u8] = b"iroh-relay handshake v1";

/// Authentication message from the client, based on exported TLS keying material.
///
/// Unlike [`ClientAuth`], this doesn't require a challenge from the server: the client signs
/// keying material exported from the TLS session and sends it up front, postcard-serialized and
/// base64url-encoded into a header value (see `KeyMaterialClientAuth::into_header_value`).
// NOTE: field order and serde attributes define the postcard wire format; don't reorder.
#[derive(derive_more::Debug, serde::Serialize, serde::Deserialize)]
#[cfg_attr(wasm_browser, allow(unused))]
pub(crate) struct KeyMaterialClientAuth {
    /// The client's public key.
    ///
    /// Also used as the exporter context when extracting the keying material.
    pub(crate) public_key: PublicKey,
    /// A signature over the first 16 bytes of the exported keying material.
    #[serde(with = "serde_bytes")]
    #[debug("{}", HEX.encode(signature))]
    pub(crate) signature: [u8; 64],
    /// The last 16 bytes of the exported keying material, sent unsigned.
    ///
    /// Allows making sure we have the same underlying key material: a mismatch here means
    /// both TLS ends exported different material (e.g. due to a proxy), as opposed to an
    /// invalid signature.
    #[debug("{}", HEX.encode(key_material_suffix))]
    pub(crate) key_material_suffix: [u8; 16],
}

/// A challenge for the client to sign with their secret key for EndpointId authentication.
///
/// The server sends this when authentication via [`KeyMaterialClientAuth`] wasn't possible
/// (no header, or it didn't verify); the client answers with a [`ClientAuth`].
#[derive(derive_more::Debug, serde::Serialize, serde::Deserialize)]
pub(crate) struct ServerChallenge {
    /// The challenge to sign.
    /// Must be randomly generated with an RNG that is safe to use for crypto.
    ///
    /// Clients don't sign these bytes directly, but a key derived from them
    /// (see `ServerChallenge::message_to_sign`).
    #[debug("{}", HEX.encode(challenge))]
    pub(crate) challenge: [u8; 16],
}

/// Authentication message from the client, sent in response to a [`ServerChallenge`].
///
/// Used when authentication via [`KeyMaterialClientAuth`] didn't work.
// NOTE: field order and serde attributes define the postcard wire format; don't reorder.
#[derive(derive_more::Debug, serde::Serialize, serde::Deserialize)]
pub(crate) struct ClientAuth {
    /// The client's public key, a.k.a. the `EndpointId`
    pub(crate) public_key: PublicKey,
    /// A signature of (a key derived from) the [`ServerChallenge`].
    ///
    /// This is what provides the authentication.
    #[serde(with = "serde_bytes")]
    #[debug("{}", HEX.encode(signature))]
    pub(crate) signature: [u8; 64],
}

/// Confirmation of successful connection.
///
/// Sent by the server once the client has been authenticated and authorized
/// (from `SuccessfulAuthentication::authorize_if`). It carries no payload.
#[derive(derive_more::Debug, serde::Serialize, serde::Deserialize)]
pub(crate) struct ServerConfirmsAuth;

/// Denial of connection.
///
/// Sent by the server either when the client couldn't be verified as authentic
/// (reason `"signature invalid"`) or when it was authenticated but isn't authorized
/// (reason `"not authorized"`).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub(crate) struct ServerDeniesAuth {
    /// Human-readable reason for the denial; the client surfaces it via
    /// `Error::ServerDeniedAuth`.
    reason: String,
}

/// Trait for getting the frame type tag for a frame.
///
/// Used only in the handshake, as the frame we expect next
/// is fairly stateful.
/// Not used in the send/recv protocol, as any frame is
/// allowed to happen at any time there.
trait Frame {
    /// The frame type this frame is identified by and prefixed with
    const TAG: FrameType;
}

impl<T: Frame> Frame for &T {
    const TAG: FrameType = T::TAG;
}

impl Frame for ServerChallenge {
    const TAG: FrameType = FrameType::ServerChallenge;
}

impl Frame for ClientAuth {
    const TAG: FrameType = FrameType::ClientAuth;
}

impl Frame for ServerConfirmsAuth {
    const TAG: FrameType = FrameType::ServerConfirmsAuth;
}

impl Frame for ServerDeniesAuth {
    const TAG: FrameType = FrameType::ServerDeniesAuth;
}

/// Errors that can happen while running the relay handshake, on the client or server side.
#[stack_error(derive, add_meta)]
#[allow(missing_docs)]
#[non_exhaustive]
pub enum Error {
    // The underlying websocket stream failed (tokio_websockets natively, ws_stream_wasm in
    // browsers).
    #[error(transparent)]
    Websocket {
        #[cfg(not(wasm_browser))]
        #[error(from, std_err)]
        source: tokio_websockets::Error,
        #[cfg(wasm_browser)]
        #[error(from, std_err)]
        source: ws_stream_wasm::WsErr,
    },
    // The stream ended before the next expected handshake frame arrived.
    #[error("Handshake stream ended prematurely")]
    UnexpectedEnd {},
    // Reading a frame's type tag failed.
    #[error(transparent)]
    FrameTypeError {
        #[error(from)]
        source: FrameTypeError,
    },
    // The client received a denial from the server; the server also returns this after it
    // has sent a denial.
    #[error("The relay denied our authentication ({reason})")]
    ServerDeniedAuth { reason: String },
    // A frame arrived whose tag isn't valid at this point in the handshake.
    #[error("Unexpected tag, got {frame_type:?}, but expected one of {expected_types:?}")]
    UnexpectedFrameType {
        frame_type: FrameType,
        expected_types: Vec<FrameType>,
    },
    // The frame's payload couldn't be postcard-deserialized into the expected type.
    #[error("Handshake failed while deserializing {frame_type:?} frame")]
    DeserializationError {
        frame_type: FrameType,
        #[error(std_err)]
        source: postcard::Error,
    },
    // Server side only: the client auth header wasn't valid base64url, or didn't decode into
    // a `KeyMaterialClientAuth`.
    #[cfg(feature = "server")]
    /// Failed to deserialize client auth header
    ClientAuthHeaderInvalid { value: HeaderValue },
}

/// Reasons why verifying a client's authentication failed on the server.
#[cfg(feature = "server")]
#[stack_error(derive, add_meta)]
pub(crate) enum VerificationError {
    // Exporting TLS keying material returned `None` on the server's side.
    #[error("Couldn't export TLS keying material on our end")]
    NoKeyingMaterial,
    // The client-sent key material suffix differs from ours.
    // NOTE(review): `expected` holds the client-sent suffix and `actual` the locally exported
    // one, see `KeyMaterialClientAuth::verify`.
    #[error(
        "Client didn't extract the same keying material, the suffix mismatched: expected {expected:X?} but got {actual:X?}"
    )]
    MismatchedSuffix {
        expected: [u8; 16],
        actual: [u8; 16],
    },
    // The signature didn't verify for the given message and public key.
    #[error(
        "Client signature {signature:X?} for message {message:X?} invalid for public key {public_key}"
    )]
    SignatureInvalid {
        source: iroh_base::SignatureError,
        message: Vec<u8>,
        signature: [u8; 64],
        public_key: PublicKey,
    },
}

impl ServerChallenge {
    /// Creates a new challenge from 16 bytes of randomness drawn from `rng`.
    ///
    /// The `CryptoRng` bound restricts callers to RNGs meant for cryptographic use, since the
    /// challenge must be unpredictable.
    #[cfg(feature = "server")]
    pub(crate) fn new<R: CryptoRng + ?Sized>(rng: &mut R) -> Self {
        let mut random_bytes = [0u8; 16];
        rng.fill_bytes(&mut random_bytes);
        Self {
            challenge: random_bytes,
        }
    }

    /// Returns the message that the client signs, and the server verifies, for this challenge.
    ///
    /// The message is a BLAKE3 key derived from the challenge under [`DOMAIN_SEP_CHALLENGE`],
    /// not the raw challenge. This provides domain separation, and in particular protects
    /// against a malicious relay. If clients signed the challenge directly, such a relay could
    /// choose the "challenge" instead of generating it randomly. It would then obtain the
    /// client's signature over an arbitrary 16-byte message of its choosing.
    fn message_to_sign(&self) -> [u8; 32] {
        blake3::derive_key(DOMAIN_SEP_CHALLENGE, &self.challenge)
    }
}

impl ClientAuth {
    /// Answers the server's `challenge` by signing its derived message with `secret_key`.
    pub(crate) fn new(secret_key: &SecretKey, challenge: &ServerChallenge) -> Self {
        let message = challenge.message_to_sign();
        let signature = secret_key.sign(&message);
        Self {
            public_key: secret_key.public(),
            signature: signature.to_bytes(),
        }
    }

    /// Checks that this response is a valid signature by `public_key` for `challenge`.
    ///
    /// # Errors
    ///
    /// Returns [`VerificationError::SignatureInvalid`] if the signature doesn't verify.
    #[cfg(feature = "server")]
    pub(crate) fn verify(&self, challenge: &ServerChallenge) -> Result<(), Box<VerificationError>> {
        let message = challenge.message_to_sign();
        let signature = Signature::from_bytes(&self.signature);
        if let Err(err) = self.public_key.verify(&message, &signature) {
            return Err(Box::new(e!(VerificationError::SignatureInvalid {
                source: err,
                message: message.to_vec(),
                signature: self.signature,
                public_key: self.public_key
            })));
        }
        Ok(())
    }
}

#[cfg(not(wasm_browser))]
impl KeyMaterialClientAuth {
    /// Creates the client's authentication from TLS keying material exported from `io`.
    ///
    /// This is the counterpart of [`ClientAuth`] that needs no challenge from the server.
    /// 32 bytes are exported under [`DOMAIN_SEP_TLS_EXPORT_LABEL`], with the client's public key
    /// as context. The first 16 bytes are signed and the last 16 are passed along unsigned
    /// (see `verify` for why).
    ///
    /// Returns `None` if `io` can't export keying material.
    pub(crate) fn new(secret_key: &SecretKey, io: &impl ExportKeyingMaterial) -> Option<Self> {
        let public_key = secret_key.public();
        let exported = io.export_keying_material(
            [0u8; 32],
            DOMAIN_SEP_TLS_EXPORT_LABEL,
            Some(public_key.as_bytes()),
        )?;
        let (signed_part, passthrough) = exported.split_at(16);
        let signature = secret_key.sign(signed_part).to_bytes();
        let key_material_suffix: [u8; 16] = passthrough.try_into().expect("hardcoded length");
        Some(Self {
            public_key,
            signature,
            key_material_suffix,
        })
    }

    /// Serializes this auth with postcard and encodes it as a base64url (no padding) header value.
    pub(crate) fn into_header_value(self) -> HeaderValue {
        let serialized = postcard::to_allocvec(&self).expect("encoding never fails");
        let encoded = data_encoding::BASE64URL_NOPAD.encode(&serialized);
        HeaderValue::from_str(&encoded)
            .expect("BASE64URL_NOPAD encoding contained invisible ascii characters")
    }

    /// Verifies this client auth on the server side, using keying material exported from `io`.
    ///
    /// # Errors
    ///
    /// - [`VerificationError::NoKeyingMaterial`] if `io` can't export keying material.
    /// - [`VerificationError::MismatchedSuffix`] if the exported material differs between both
    ///   ends of the TLS session, e.g. because an HTTPS proxy in between doesn't handle the
    ///   TLS keying material exporter.
    /// - [`VerificationError::SignatureInvalid`] if the signature itself doesn't verify.
    #[cfg(feature = "server")]
    pub(crate) fn verify(
        &self,
        io: &impl ExportKeyingMaterial,
    ) -> Result<(), Box<VerificationError>> {
        let Some(exported) = io.export_keying_material(
            [0u8; 32],
            DOMAIN_SEP_TLS_EXPORT_LABEL,
            Some(self.public_key.as_bytes()),
        ) else {
            return Err(Box::new(e!(VerificationError::NoKeyingMaterial)));
        };

        // Only the first 16 bytes are signed; the last 16 travel along unsigned. The suffix
        // tells us what went wrong. If it mismatches, both ends exported different material
        // (perhaps a TLS proxy sits in between). If it matches but the signature fails,
        // something is wrong with the client's key or signature.
        let (signed_part, passthrough) = exported.split_at(16);
        let local_suffix: [u8; 16] = passthrough.try_into().expect("hardcoded length");
        ensure!(
            local_suffix == self.key_material_suffix,
            VerificationError::MismatchedSuffix {
                expected: self.key_material_suffix,
                actual: local_suffix
            }
        );

        // Unlike `ServerChallenge::message_to_sign`, no extra blake3 step is needed here. The
        // TLS exporter already applies a domain separation label and a keyed derivation.
        let signature = Signature::from_bytes(&self.signature);
        if let Err(err) = self.public_key.verify(signed_part, &signature) {
            return Err(Box::new(e!(VerificationError::SignatureInvalid {
                source: err,
                message: signed_part.to_vec(),
                public_key: self.public_key,
                signature: self.signature
            })));
        }
        Ok(())
    }
}

/// Runs the client side of the handshake protocol.
///
/// See the module docs for details on the protocol.
/// This is already after having potentially transferred a [`KeyMaterialClientAuth`],
/// but before having received a response for whether that worked or not.
///
/// The first frame received from the server is one of:
/// - [`ServerChallenge`]: key-material authentication wasn't possible. We sign the challenge,
///   reply with a [`ClientAuth`], and then wait for the server's verdict.
/// - [`ServerConfirmsAuth`]: key-material authentication succeeded and we're authorized.
/// - [`ServerDeniesAuth`]: key-material authentication succeeded, but the server decided not
///   to authorize us.
///
/// This requires access to the client's secret key to sign a challenge.
///
/// # Errors
///
/// Returns [`Error::ServerDeniedAuth`] (carrying the server's reason) if the server denies the
/// connection. Other variants are returned for transport, framing or deserialization failures.
pub(crate) async fn clientside(
    io: &mut (impl BytesStreamSink + ExportKeyingMaterial),
    secret_key: &SecretKey,
) -> Result<ServerConfirmsAuth, Error> {
    // If the server already authenticated us via keying material, it replies with its
    // authorization decision right away. That decision may be a denial, so all three frame
    // types are valid here.
    let (tag, frame) = read_frame(
        io,
        &[
            ServerChallenge::TAG,
            ServerConfirmsAuth::TAG,
            ServerDeniesAuth::TAG,
        ],
    )
    .await?;

    let (tag, frame) = if tag == ServerChallenge::TAG {
        let challenge: ServerChallenge = deserialize_frame(frame)?;

        let client_info = ClientAuth::new(secret_key, &challenge);
        write_frame(io, client_info).await?;

        read_frame(io, &[ServerConfirmsAuth::TAG, ServerDeniesAuth::TAG]).await?
    } else {
        (tag, frame)
    };

    match tag {
        FrameType::ServerConfirmsAuth => deserialize_frame(frame),
        FrameType::ServerDeniesAuth => {
            let denial: ServerDeniesAuth = deserialize_frame(frame)?;
            Err(e!(Error::ServerDeniedAuth {
                reason: denial.reason
            }))
        }
        other => unreachable!("read_frame only returns expected frame types, got {other:?}"),
    }
}

/// Result of a successful authentication handshake.
///
/// This represents a client that has proven ownership of the secret key for `client_key`,
/// via the authentication [`Mechanism`] `mechanism`. Authorization must still be decided by
/// calling [`Self::authorize_if`]. That call completes the protocol and notifies the client of
/// success or failure.
#[cfg(feature = "server")]
#[derive(Debug)]
#[must_use = "the protocol is not finished unless `authorize_if` is called"]
pub struct SuccessfulAuthentication {
    /// The authenticated client's public key.
    pub client_key: PublicKey,
    /// The authentication mechanism that was used.
    pub mechanism: Mechanism,
}

/// The mechanism that was used for authentication.
///
/// Reported in [`SuccessfulAuthentication::mechanism`].
#[cfg(feature = "server")]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Mechanism {
    /// Authentication was performed by verifying a signature of a challenge we sent
    /// (costs an extra round trip).
    SignedChallenge,
    /// Authentication was performed by verifying a signature of shared extracted TLS keying material
    /// (sent by the client in a header, no challenge needed).
    SignedKeyMaterial,
}

/// Runs the server side of the handshaking protocol.
///
/// See the module documentation for an overview of the handshaking protocol.
///
/// If the client sent a `client_auth_header` (a base64url-encoded, postcard-serialized
/// `KeyMaterialClientAuth`), this first tries to authenticate the client with it, saving a
/// round trip. Verification may legitimately fail, e.g. when the TLS keying material differs
/// between both ends. In that case the server falls back to sending a `ServerChallenge` and
/// verifying the client's signed response.
///
/// The challenge's randomness comes from `rand::rng()`.
///
/// The return value [`SuccessfulAuthentication`] still needs to be resolved by calling
/// [`SuccessfulAuthentication::authorize_if`] to finish the whole authorization protocol
/// (otherwise the client won't be notified about auth success or failure).
///
/// # Errors
///
/// - [`Error::ClientAuthHeaderInvalid`] if the header isn't valid base64url, or doesn't decode
///   into a `KeyMaterialClientAuth`. There's no fallback to the challenge in this case.
/// - [`Error::ServerDeniedAuth`] if the client's challenge signature is invalid. A
///   `ServerDeniesAuth` frame is sent to the client before returning.
/// - Other variants for transport, framing or deserialization failures.
#[cfg(feature = "server")]
pub async fn serverside(
    io: &mut (impl BytesStreamSink + ExportKeyingMaterial),
    client_auth_header: Option<HeaderValue>,
) -> Result<SuccessfulAuthentication, Error> {
    if let Some(client_auth_header) = client_auth_header {
        let client_auth_bytes = data_encoding::BASE64URL_NOPAD
            .decode(client_auth_header.as_ref())
            .map_err(|_| {
                e!(Error::ClientAuthHeaderInvalid {
                    value: client_auth_header.clone()
                })
            })?;

        let client_auth: KeyMaterialClientAuth =
            postcard::from_bytes(&client_auth_bytes).map_err(|_| {
                e!(Error::ClientAuthHeaderInvalid {
                    value: client_auth_header.clone()
                })
            })?;

        match client_auth.verify(io) {
            Ok(()) => {
                trace!(?client_auth.public_key, "authentication succeeded via keying material");
                return Ok(SuccessfulAuthentication {
                    client_key: client_auth.public_key,
                    mechanism: Mechanism::SignedKeyMaterial,
                });
            }
            Err(err) => {
                // Verification not succeeding is part of normal operation: The TLS exporter
                // isn't required to match. Record why, then fall back to verification that
                // takes another round trip more time.
                trace!(
                    ?client_auth.public_key,
                    ?err,
                    "authentication via keying material failed, falling back to challenge"
                );
            }
        }
    }

    let challenge = ServerChallenge::new(&mut rand::rng());
    write_frame(io, &challenge).await?;

    let (_, frame) = read_frame(io, &[ClientAuth::TAG]).await?;
    let client_auth: ClientAuth = deserialize_frame(frame)?;

    if let Err(err) = client_auth.verify(&challenge) {
        trace!(?client_auth.public_key, ?err, "authentication failed");
        let denial = ServerDeniesAuth {
            reason: "signature invalid".into(),
        };
        // Written by reference (`Frame` and `Serialize` are implemented for `&T`), so the
        // reason can be moved into the returned error without cloning.
        write_frame(io, &denial).await?;
        Err(e!(Error::ServerDeniedAuth {
            reason: denial.reason
        }))
    } else {
        trace!(?client_auth.public_key, "authentication succeeded via challenge");
        Ok(SuccessfulAuthentication {
            client_key: client_auth.public_key,
            mechanism: Mechanism::SignedChallenge,
        })
    }
}

#[cfg(feature = "server")]
impl SuccessfulAuthentication {
    /// Finishes the handshake by telling the authenticated client whether it may proceed.
    ///
    /// Call this after [`serverside`] has authenticated the client.
    ///
    /// - If `is_authorized` is `true`, a `ServerConfirmsAuth` frame is written to `io` and the
    ///   client's public key is returned.
    /// - Otherwise a `ServerDeniesAuth` frame with reason `"not authorized"` is written and
    ///   [`Error::ServerDeniedAuth`] is returned.
    ///
    /// # Errors
    ///
    /// Returns [`Error::ServerDeniedAuth`] when `is_authorized` is `false`, or another variant
    /// if writing the response frame to `io` fails.
    pub async fn authorize_if(
        self,
        is_authorized: bool,
        io: &mut (impl BytesStreamSink + ExportKeyingMaterial),
    ) -> Result<PublicKey, Error> {
        if !is_authorized {
            trace!("denying client auth");
            let denial = ServerDeniesAuth {
                reason: "not authorized".into(),
            };
            // Sent by reference, so the reason can be reused for the error without a clone.
            write_frame(io, &denial).await?;
            return Err(e!(Error::ServerDeniedAuth {
                reason: denial.reason
            }));
        }

        trace!("authorizing client");
        write_frame(io, ServerConfirmsAuth).await?;
        Ok(self.client_key)
    }
}

async fn write_frame<F: serde::Serialize + Frame>(
    io: &mut impl BytesStreamSink,
    frame: F,
) -> Result<(), Error> {
    let mut bytes = BytesMut::new();
    trace!(frame_type = ?F::TAG, "Writing frame");
    F::TAG.write_to(&mut bytes);
    let bytes = postcard::to_io(&frame, bytes.writer())
        .expect("serialization failed") // buffer can't become "full" without being a critical failure, datastructures shouldn't ever fail serialization
        .into_inner()
        .freeze();
    io.send(bytes).await?;
    io.flush().await?;
    Ok(())
}

async fn read_frame(
    io: &mut impl BytesStreamSink,
    expected_types: &[FrameType],
) -> Result<(FrameType, Bytes), Error> {
    let mut payload = io
        .try_next()
        .await?
        .ok_or_else(|| e!(Error::UnexpectedEnd))?;

    let frame_type = FrameType::from_bytes(&mut payload)?;
    trace!(?frame_type, "Reading frame");
    ensure!(
        expected_types.contains(&frame_type),
        Error::UnexpectedFrameType {
            frame_type,
            expected_types: expected_types.to_vec()
        }
    );

    Ok((frame_type, payload))
}

fn deserialize_frame<F: Frame + serde::de::DeserializeOwned>(frame: Bytes) -> Result<F, Error> {
    postcard::from_bytes(&frame).map_err(|err| {
        e!(Error::DeserializationError {
            frame_type: F::TAG,
            source: err
        })
    })
}

// Tests for the handshake. Client and server run concurrently over an in-memory duplex pipe,
// with TLS keying material export simulated by `TestKeyingMaterial`.
#[cfg(all(test, feature = "server"))]
mod tests {
    use bytes::BytesMut;
    use iroh_base::{PublicKey, SecretKey};
    use n0_error::{Result, StackResultExt, StdResultExt};
    use n0_future::{Sink, SinkExt, Stream, TryStreamExt};
    use n0_tracing_test::traced_test;
    use rand::{RngExt, SeedableRng};
    use tokio_util::codec::{Framed, LengthDelimitedCodec};
    use tracing::{Instrument, info_span};

    use super::{
        ClientAuth, KeyMaterialClientAuth, Mechanism, ServerChallenge, ServerConfirmsAuth,
    };
    use crate::ExportKeyingMaterial;

    /// Wraps a stream/sink `inner` and adds a fake [`ExportKeyingMaterial`] implementation.
    ///
    /// The exported material is derived from `shared_secret`. `None` simulates an endpoint
    /// that can't export keying material at all.
    struct TestKeyingMaterial<IO> {
        shared_secret: Option<u64>,
        inner: IO,
    }

    /// Extension trait to wrap any value in a [`TestKeyingMaterial`].
    trait WithTlsSharedSecret: Sized {
        fn with_shared_secret(self, shared_secret: Option<u64>) -> TestKeyingMaterial<Self>;
    }

    impl<T: Sized> WithTlsSharedSecret for T {
        fn with_shared_secret(self, shared_secret: Option<u64>) -> TestKeyingMaterial<Self> {
            TestKeyingMaterial {
                shared_secret,
                inner: self,
            }
        }
    }

    // Both ends produce equal output only if they have the same shared secret, label and
    // context; without a shared secret the export returns `None`.
    impl<IO> ExportKeyingMaterial for TestKeyingMaterial<IO> {
        fn export_keying_material<T: AsMut<[u8]>>(
            &self,
            mut output: T,
            label: &[u8],
            context: Option<&[u8]>,
        ) -> Option<T> {
            // we simulate something like exporting keying material using blake3

            let label_key = blake3::hash(label);
            let context_key = blake3::keyed_hash(label_key.as_bytes(), context.unwrap_or(&[]));
            let mut hasher = blake3::Hasher::new_keyed(context_key.as_bytes());
            hasher.update(&self.shared_secret?.to_le_bytes());
            hasher.finalize_xof().fill(output.as_mut());

            Some(output)
        }
    }

    // Forwards polling to the wrapped stream.
    impl<V, IO: Stream<Item = V> + Unpin> Stream for TestKeyingMaterial<IO> {
        type Item = V;

        fn poll_next(
            mut self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Option<Self::Item>> {
            std::pin::Pin::new(&mut self.inner).poll_next(cx)
        }
    }

    // Forwards all sink operations to the wrapped sink.
    impl<V, E, IO: Sink<V, Error = E> + Unpin> Sink<V> for TestKeyingMaterial<IO> {
        type Error = E;

        fn poll_ready(
            mut self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Result<(), Self::Error>> {
            std::pin::Pin::new(&mut self.inner).poll_ready(cx)
        }

        fn start_send(mut self: std::pin::Pin<&mut Self>, item: V) -> Result<(), Self::Error> {
            std::pin::Pin::new(&mut self.inner).start_send(item)
        }

        fn poll_flush(
            mut self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Result<(), Self::Error>> {
            std::pin::Pin::new(&mut self.inner).poll_flush(cx)
        }

        fn poll_close(
            mut self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Result<(), Self::Error>> {
            std::pin::Pin::new(&mut self.inner).poll_close(cx)
        }
    }

    /// Runs `clientside` and `serverside` + `authorize_if` concurrently, connected by an
    /// in-memory duplex pipe with length-delimited framing.
    ///
    /// `client_shared_secret` / `server_shared_secret` control the simulated TLS keying
    /// material of each end. If `restricted_to` is `Some`, the server only authorizes that key.
    /// Returns the client's result and the server's `(client key, mechanism)` result.
    async fn simulate_handshake(
        secret_key: &SecretKey,
        client_shared_secret: Option<u64>,
        server_shared_secret: Option<u64>,
        restricted_to: Option<PublicKey>,
    ) -> (Result<ServerConfirmsAuth>, Result<(PublicKey, Mechanism)>) {
        let (client, server) = tokio::io::duplex(1024);

        let mut client_io = Framed::new(client, LengthDelimitedCodec::new())
            .map_ok(BytesMut::freeze)
            .map_err(tokio_websockets::Error::Io)
            .sink_map_err(tokio_websockets::Error::Io)
            .with_shared_secret(client_shared_secret);
        let mut server_io = Framed::new(server, LengthDelimitedCodec::new())
            .map_ok(BytesMut::freeze)
            .map_err(tokio_websockets::Error::Io)
            .sink_map_err(tokio_websockets::Error::Io)
            .with_shared_secret(server_shared_secret);

        // `None` when the client has no shared secret, i.e. no header is sent.
        let client_auth_header = KeyMaterialClientAuth::new(secret_key, &client_io)
            .map(KeyMaterialClientAuth::into_header_value);

        n0_future::future::zip(
            async {
                super::clientside(&mut client_io, secret_key)
                    .await
                    .context("clientside")
            }
            .instrument(info_span!("clientside")),
            async {
                let auth_n = super::serverside(&mut server_io, client_auth_header)
                    .await
                    .context("serverside")?;
                let mechanism = auth_n.mechanism;
                let is_authorized = restricted_to.is_none_or(|key| key == auth_n.client_key);
                let key = auth_n.authorize_if(is_authorized, &mut server_io).await?;
                Ok((key, mechanism))
            }
            .instrument(info_span!("serverside")),
        )
        .await
    }

    /// Matching shared secrets let the server authenticate via keying material alone.
    #[tokio::test]
    #[traced_test]
    async fn test_handshake_via_shared_secrets() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);

        let secret_key = SecretKey::from_bytes(&rng.random());
        let (client, server) = simulate_handshake(&secret_key, Some(42), Some(42), None).await;
        client?;
        let (public_key, auth) = server?;
        assert_eq!(public_key, secret_key.public());
        assert_eq!(auth, Mechanism::SignedKeyMaterial); // it got verified via shared key material
        Ok(())
    }

    /// Without any keying material the handshake uses the challenge.
    #[tokio::test]
    #[traced_test]
    async fn test_handshake_via_challenge() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);

        let secret_key = SecretKey::from_bytes(&rng.random());
        let (client, server) = simulate_handshake(&secret_key, None, None, None).await;
        client?;
        let (public_key, auth) = server?;
        assert_eq!(public_key, secret_key.public());
        assert_eq!(auth, Mechanism::SignedChallenge);
        Ok(())
    }

    /// Mismatching keying material makes the server fall back to the challenge.
    #[tokio::test]
    #[traced_test]
    async fn test_handshake_mismatching_shared_secrets() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);

        let secret_key = SecretKey::from_bytes(&rng.random());
        // mismatching shared secrets *might* happen with HTTPS proxies that don't also middle-man the shared secret
        let (client, server) = simulate_handshake(&secret_key, Some(10), Some(99), None).await;
        client?;
        let (public_key, auth) = server?;
        assert_eq!(public_key, secret_key.public());
        assert_eq!(auth, Mechanism::SignedChallenge);
        Ok(())
    }

    /// A client that can't export keying material (so it sends no header) still
    /// authenticates via the challenge.
    #[tokio::test]
    #[traced_test]
    async fn test_handshake_challenge_fallback() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
        let secret_key = SecretKey::from_bytes(&rng.random());
        // clients might not have access to shared secrets
        let (client, server) = simulate_handshake(&secret_key, None, Some(99), None).await;
        client?;
        let (public_key, auth) = server?;
        assert_eq!(public_key, secret_key.public());
        assert_eq!(auth, Mechanism::SignedChallenge);
        Ok(())
    }

    /// A client whose key is the allowed one is authorized.
    #[tokio::test]
    #[traced_test]
    async fn test_handshake_with_auth_positive() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
        let secret_key = SecretKey::from_bytes(&rng.random());
        let public_key = secret_key.public();
        let (client, server) = simulate_handshake(&secret_key, None, None, Some(public_key)).await;
        client?;
        let (public_key, _) = server?;
        assert_eq!(public_key, secret_key.public());
        Ok(())
    }

    /// An authentic client with a key other than the allowed one is denied (challenge path).
    #[tokio::test]
    #[traced_test]
    async fn test_handshake_with_auth_negative() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
        let secret_key = SecretKey::from_bytes(&rng.random());
        let public_key = secret_key.public();
        let wrong_secret_key = SecretKey::from_bytes(&rng.random());
        let (client, server) =
            simulate_handshake(&wrong_secret_key, None, None, Some(public_key)).await;
        assert!(client.is_err());
        assert!(server.is_err());
        Ok(())
    }

    /// An authentic client with a key other than the allowed one is denied (keying material
    /// path, where the denial is the first frame the client receives).
    #[tokio::test]
    #[traced_test]
    async fn test_handshake_via_shared_secret_with_auth_negative() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
        let secret_key = SecretKey::from_bytes(&rng.random());
        let public_key = secret_key.public();
        let wrong_secret_key = SecretKey::from_bytes(&rng.random());
        let (client, server) =
            simulate_handshake(&wrong_secret_key, Some(42), Some(42), Some(public_key)).await;
        assert!(client.is_err());
        assert!(server.is_err());
        Ok(())
    }

    /// `ClientAuth` survives a postcard round trip.
    #[test]
    fn test_client_auth_roundtrip() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
        let secret_key = SecretKey::from_bytes(&rng.random());
        let challenge = ServerChallenge::new(&mut rng);
        let client_auth = ClientAuth::new(&secret_key, &challenge);

        let bytes = postcard::to_allocvec(&client_auth).anyerr()?;
        let decoded: ClientAuth = postcard::from_bytes(&bytes).anyerr()?;

        assert_eq!(client_auth.public_key, decoded.public_key);
        assert_eq!(client_auth.signature, decoded.signature);

        Ok(())
    }

    /// `KeyMaterialClientAuth` survives a postcard round trip.
    #[test]
    fn test_km_client_auth_roundtrip() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
        let secret_key = SecretKey::from_bytes(&rng.random());
        let client_auth = KeyMaterialClientAuth::new(
            &secret_key,
            &TestKeyingMaterial {
                inner: (),
                shared_secret: Some(42),
            },
        )
        .anyerr()?;

        let bytes = postcard::to_allocvec(&client_auth).anyerr()?;
        let decoded: KeyMaterialClientAuth = postcard::from_bytes(&bytes).anyerr()?;

        assert_eq!(client_auth.public_key, decoded.public_key);
        assert_eq!(client_auth.signature, decoded.signature);

        Ok(())
    }

    /// A challenge response signed by the client verifies against the same challenge.
    #[test]
    fn test_challenge_verification() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
        let secret_key = SecretKey::from_bytes(&rng.random());
        let challenge = ServerChallenge::new(&mut rng);
        let client_auth = ClientAuth::new(&secret_key, &challenge);
        assert!(client_auth.verify(&challenge).is_ok());

        Ok(())
    }

    /// A keying-material auth verifies when both sides export the same material.
    #[test]
    fn test_key_material_verification() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
        let secret_key = SecretKey::from_bytes(&rng.random());
        let io = TestKeyingMaterial {
            inner: (),
            shared_secret: Some(42),
        };
        let client_auth = KeyMaterialClientAuth::new(&secret_key, &io).anyerr()?;
        assert!(client_auth.verify(&io).is_ok());

        Ok(())
    }
}