demand_sv2_connection/
lib.rs1use binary_sv2::{Deserialize, GetSize, Serialize};
2
3pub mod noise_connection_tokio;
4
5use codec_sv2::{Error as CodecError, HandShakeFrame, HandshakeRole, StandardEitherFrame};
6use const_sv2::{
7 INITIATOR_EXPECTED_HANDSHAKE_MESSAGE_SIZE, RESPONDER_EXPECTED_HANDSHAKE_MESSAGE_SIZE,
8};
9use futures::lock::Mutex;
10use std::{
11 convert::TryInto,
12 sync::{atomic::AtomicBool, Arc},
13};
14use tokio::sync::mpsc::{error::SendError, Receiver, Sender};
15
16#[derive(Debug)]
17pub enum Error {
18 HandshakeRemoteInvalidMessage,
19 CodecError(CodecError),
20 RecvError,
21 SendError,
22 SocketClosed,
24}
25
26impl From<CodecError> for Error {
27 fn from(e: CodecError) -> Self {
28 Error::CodecError(e)
29 }
30}
31impl<T> From<SendError<T>> for Error {
32 fn from(_: SendError<T>) -> Self {
33 Error::SendError
34 }
35}
36
37trait SetState {
38 async fn set_state(self_: Arc<Mutex<Self>>, state: codec_sv2::State);
39}
40
41async fn initialize_as_downstream<
42 'a,
43 Message: Serialize + Deserialize<'a> + GetSize,
44 T: SetState,
45>(
46 self_: Arc<Mutex<T>>,
47 role: HandshakeRole,
48 sender_outgoing: &mut Sender<StandardEitherFrame<Message>>,
49 receiver_incoming: &mut Receiver<StandardEitherFrame<Message>>,
50) -> Result<(), Error> {
51 let mut state = codec_sv2::State::initialized(role);
52
53 let first_message = state.step_0()?;
55 sender_outgoing
56 .send(first_message.into())
57 .await
58 .map_err(|_| Error::SendError)?;
59
60 let second_message = receiver_incoming.recv().await.ok_or(Error::RecvError)?;
62 let second_message: HandShakeFrame = second_message
63 .try_into()
64 .map_err(|_| Error::HandshakeRemoteInvalidMessage)?;
65 let second_message: [u8; INITIATOR_EXPECTED_HANDSHAKE_MESSAGE_SIZE] = second_message
66 .get_payload_when_handshaking()
67 .try_into()
68 .map_err(|_| Error::HandshakeRemoteInvalidMessage)?;
69
70 let transport_mode = state.step_2(second_message)?;
72
73 T::set_state(self_, transport_mode).await;
74 while !TRANSPORT_READY.load(std::sync::atomic::Ordering::SeqCst) {
75 std::hint::spin_loop()
76 }
77 Ok(())
78}
79
80async fn initialize_as_upstream<'a, Message: Serialize + Deserialize<'a> + GetSize, T: SetState>(
81 self_: Arc<Mutex<T>>,
82 role: HandshakeRole,
83 sender_outgoing: &mut Sender<StandardEitherFrame<Message>>,
84 receiver_incoming: &mut Receiver<StandardEitherFrame<Message>>,
85) -> Result<(), Error> {
86 let mut state = codec_sv2::State::initialized(role);
87
88 let first_message: HandShakeFrame = receiver_incoming
90 .recv()
91 .await
92 .ok_or(Error::RecvError)?
93 .try_into()
94 .map_err(|_| Error::HandshakeRemoteInvalidMessage)?;
95 let first_message: [u8; RESPONDER_EXPECTED_HANDSHAKE_MESSAGE_SIZE] = first_message
96 .get_payload_when_handshaking()
97 .try_into()
98 .map_err(|_| Error::HandshakeRemoteInvalidMessage)?;
99
100 let (second_message, transport_mode) = state.step_1(first_message)?;
102 HANDSHAKE_READY.store(false, std::sync::atomic::Ordering::SeqCst);
103 sender_outgoing.send(second_message.into()).await?;
104
105 T::set_state(self_, transport_mode).await;
109 while !TRANSPORT_READY.load(std::sync::atomic::Ordering::SeqCst) {
110 std::hint::spin_loop()
111 }
112
113 Ok(())
114}
115
/// Process-wide flag, initially `false`.
///
/// In this file it is only written: `initialize_as_upstream` sets it to `false` just before
/// sending its handshake reply. NOTE(review): no code here sets it to `true` or reads it;
/// presumably another module of the crate does — confirm.
static HANDSHAKE_READY: AtomicBool = AtomicBool::new(false);

/// Process-wide flag, initially `false`.
///
/// Both handshake helpers wait for it to become `true` after `SetState::set_state` returns.
/// NOTE(review): it is never written in this file and never reset, so it is presumably raised
/// elsewhere (likely by a `SetState` implementation) — confirm.
static TRANSPORT_READY: AtomicBool = AtomicBool::new(false);