demand_sv2_connection/
lib.rs

1use binary_sv2::{Deserialize, GetSize, Serialize};
2
3pub mod noise_connection_tokio;
4
5use codec_sv2::{Error as CodecError, HandShakeFrame, HandshakeRole, StandardEitherFrame};
6use const_sv2::{
7    INITIATOR_EXPECTED_HANDSHAKE_MESSAGE_SIZE, RESPONDER_EXPECTED_HANDSHAKE_MESSAGE_SIZE,
8};
9use futures::lock::Mutex;
10use std::{
11    convert::TryInto,
12    sync::{atomic::AtomicBool, Arc},
13};
14use tokio::sync::mpsc::{error::SendError, Receiver, Sender};
15
16#[derive(Debug)]
17pub enum Error {
18    HandshakeRemoteInvalidMessage,
19    CodecError(CodecError),
20    RecvError,
21    SendError,
22    // This means that a socket that was supposed to be opened have been closed, likley by the peer
23    SocketClosed,
24}
25
26impl From<CodecError> for Error {
27    fn from(e: CodecError) -> Self {
28        Error::CodecError(e)
29    }
30}
31impl<T> From<SendError<T>> for Error {
32    fn from(_: SendError<T>) -> Self {
33        Error::SendError
34    }
35}
36
37trait SetState {
38    async fn set_state(self_: Arc<Mutex<Self>>, state: codec_sv2::State);
39}
40
41async fn initialize_as_downstream<
42    'a,
43    Message: Serialize + Deserialize<'a> + GetSize,
44    T: SetState,
45>(
46    self_: Arc<Mutex<T>>,
47    role: HandshakeRole,
48    sender_outgoing: &mut Sender<StandardEitherFrame<Message>>,
49    receiver_incoming: &mut Receiver<StandardEitherFrame<Message>>,
50) -> Result<(), Error> {
51    let mut state = codec_sv2::State::initialized(role);
52
53    // Create and send first handshake message
54    let first_message = state.step_0()?;
55    sender_outgoing
56        .send(first_message.into())
57        .await
58        .map_err(|_| Error::SendError)?;
59
60    // Receive and deserialize second handshake message
61    let second_message = receiver_incoming.recv().await.ok_or(Error::RecvError)?;
62    let second_message: HandShakeFrame = second_message
63        .try_into()
64        .map_err(|_| Error::HandshakeRemoteInvalidMessage)?;
65    let second_message: [u8; INITIATOR_EXPECTED_HANDSHAKE_MESSAGE_SIZE] = second_message
66        .get_payload_when_handshaking()
67        .try_into()
68        .map_err(|_| Error::HandshakeRemoteInvalidMessage)?;
69
70    // Create and send thirth handshake message
71    let transport_mode = state.step_2(second_message)?;
72
73    T::set_state(self_, transport_mode).await;
74    while !TRANSPORT_READY.load(std::sync::atomic::Ordering::SeqCst) {
75        std::hint::spin_loop()
76    }
77    Ok(())
78}
79
80async fn initialize_as_upstream<'a, Message: Serialize + Deserialize<'a> + GetSize, T: SetState>(
81    self_: Arc<Mutex<T>>,
82    role: HandshakeRole,
83    sender_outgoing: &mut Sender<StandardEitherFrame<Message>>,
84    receiver_incoming: &mut Receiver<StandardEitherFrame<Message>>,
85) -> Result<(), Error> {
86    let mut state = codec_sv2::State::initialized(role);
87
88    // Receive and deserialize first handshake message
89    let first_message: HandShakeFrame = receiver_incoming
90        .recv()
91        .await
92        .ok_or(Error::RecvError)?
93        .try_into()
94        .map_err(|_| Error::HandshakeRemoteInvalidMessage)?;
95    let first_message: [u8; RESPONDER_EXPECTED_HANDSHAKE_MESSAGE_SIZE] = first_message
96        .get_payload_when_handshaking()
97        .try_into()
98        .map_err(|_| Error::HandshakeRemoteInvalidMessage)?;
99
100    // Create and send second handshake message
101    let (second_message, transport_mode) = state.step_1(first_message)?;
102    HANDSHAKE_READY.store(false, std::sync::atomic::Ordering::SeqCst);
103    sender_outgoing.send(second_message.into()).await?;
104
105    // This sets the state to Handshake state - this prompts the task above to move the state
106    // to transport mode so that the next incoming message will be decoded correctly
107    // It is important to do this directly before sending the fourth message
108    T::set_state(self_, transport_mode).await;
109    while !TRANSPORT_READY.load(std::sync::atomic::Ordering::SeqCst) {
110        std::hint::spin_loop()
111    }
112
113    Ok(())
114}
115
/// Process-wide handshake flag.
///
/// Within this file it is only ever written, never read: `initialize_as_upstream` resets it to
/// `false` right before sending its handshake reply.
/// NOTE(review): presumably consulted by the connection tasks elsewhere in the crate — confirm.
static HANDSHAKE_READY: AtomicBool = AtomicBool::new(false);

/// Process-wide flag that both handshake helpers poll, after calling `SetState::set_state`,
/// until it reads `true`.
///
/// Nothing in this file sets it, so it must be raised elsewhere (presumably once the connection
/// has switched to transport mode — confirm).
/// NOTE(review): as a single `static` it is shared by every connection in the process. Unless it
/// is reset elsewhere, once one connection raises it, later handshakes stop waiting.
static TRANSPORT_READY: AtomicBool = AtomicBool::new(false);