use binary_sv2::{Deserialize, GetSize, Serialize};
pub mod noise_connection;
pub mod plain_connection;
#[cfg(feature = "sv1")]
pub mod sv1_connection;
use async_channel::{Receiver, RecvError, SendError, Sender};
use codec_sv2::{Error as CodecError, HandShakeFrame, HandshakeRole, StandardEitherFrame};
use const_sv2::{
INITIATOR_EXPECTED_HANDSHAKE_MESSAGE_SIZE, RESPONDER_EXPECTED_HANDSHAKE_MESSAGE_SIZE,
};
use futures::lock::Mutex;
use std::{
convert::TryInto,
sync::{atomic::AtomicBool, Arc},
};
/// Errors produced by the connection helpers in this module.
#[derive(Debug)]
pub enum Error {
/// A handshake message received from the peer could not be converted into a
/// `HandShakeFrame`, or its payload did not have the expected fixed size.
HandshakeRemoteInvalidMessage,
/// An error reported by `codec_sv2`, carried unchanged as the payload.
CodecError(CodecError),
/// Receiving from an `async_channel` receiver failed; the underlying
/// `RecvError` value is discarded by the `From` conversion.
RecvError,
/// Sending on an `async_channel` sender failed; the underlying `SendError`
/// (and the message it carries) is discarded by the `From` conversion.
SendError,
/// NOTE(review): never constructed in this file; presumably raised by the
/// connection submodules when the underlying socket closes - confirm.
SocketClosed,
}
impl From<CodecError> for Error {
fn from(e: CodecError) -> Self {
Error::CodecError(e)
}
}
impl From<RecvError> for Error {
fn from(_: RecvError) -> Self {
Error::RecvError
}
}
impl<T> From<SendError<T>> for Error {
fn from(_: SendError<T>) -> Self {
Error::SendError
}
}
/// Hook through which the handshake routines in this file hand a new codec
/// state to a connection object.
///
/// Implementors are not visible in this file.
trait SetState {
/// Receives the shared, mutex-guarded connection (`self_`) and the
/// `codec_sv2::State` to apply. The `initialize_as_*` functions call this with
/// the transport-mode state returned by the final handshake step.
async fn set_state(self_: Arc<Mutex<Self>>, state: codec_sv2::State);
}
/// Drives the downstream side of the handshake and switches the connection
/// to transport mode.
///
/// Sequence:
/// 1. Create a handshake state from `role` and build the first message with
///    `step_0`, then send it on `sender_outgoing`.
/// 2. Receive the peer's reply from `receiver_incoming`, require it to be a
///    `HandShakeFrame`, and require its payload to be exactly
///    `INITIATOR_EXPECTED_HANDSHAKE_MESSAGE_SIZE` bytes.
/// 3. Feed the payload to `step_2` and pass the resulting state to
///    `T::set_state`.
/// 4. Wait until the `TRANSPORT_READY` flag reads `true`.
///
/// # Errors
///
/// * [`Error::HandshakeRemoteInvalidMessage`] if the reply is not a handshake
///   frame or its payload has the wrong length.
/// * [`Error::SendError`] / [`Error::RecvError`] if the channels fail.
/// * Any error from `step_0` / `step_2`, converted into [`Error`] by `?`.
async fn initialize_as_downstream<
    'a,
    Message: Serialize + Deserialize<'a> + GetSize,
    T: SetState,
>(
    self_: Arc<Mutex<T>>,
    role: HandshakeRole,
    sender_outgoing: Sender<StandardEitherFrame<Message>>,
    receiver_incoming: Receiver<StandardEitherFrame<Message>>,
) -> Result<(), Error> {
    let mut state = codec_sv2::State::initialized(role);

    // Build and send the first handshake message.
    let first_message = state.step_0()?;
    sender_outgoing.send(first_message.into()).await?;

    // Receive the reply and validate both its frame kind and payload size.
    let second_message = receiver_incoming.recv().await?;
    let second_message: HandShakeFrame = second_message
        .try_into()
        .map_err(|_| Error::HandshakeRemoteInvalidMessage)?;
    let second_message: [u8; INITIATOR_EXPECTED_HANDSHAKE_MESSAGE_SIZE] = second_message
        .get_payload_when_handshaking()
        .try_into()
        .map_err(|_| Error::HandshakeRemoteInvalidMessage)?;

    // Complete the handshake and hand the transport-mode state to the connection.
    let transport_mode = state.step_2(second_message)?;
    T::set_state(self_, transport_mode).await;

    // Wait for `TRANSPORT_READY` cooperatively. A plain spin loop here would
    // never yield, pinning the executor thread; on a single-threaded executor
    // the task that sets the flag could then never run. Each poll that finds
    // the flag unset re-schedules this task and returns `Pending`, so other
    // tasks get a turn between checks.
    std::future::poll_fn(|cx| {
        if TRANSPORT_READY.load(std::sync::atomic::Ordering::SeqCst) {
            std::task::Poll::Ready(())
        } else {
            cx.waker().wake_by_ref();
            std::task::Poll::Pending
        }
    })
    .await;

    Ok(())
}
/// Drives the upstream side of the handshake and switches the connection to
/// transport mode.
///
/// Sequence:
/// 1. Create a handshake state from `role`.
/// 2. Receive the peer's first message from `receiver_incoming`, require it
///    to be a `HandShakeFrame`, and require its payload to be exactly
///    `RESPONDER_EXPECTED_HANDSHAKE_MESSAGE_SIZE` bytes.
/// 3. Feed the payload to `step_1`, which yields the reply message and the
///    transport-mode state.
/// 4. Clear `HANDSHAKE_READY`, send the reply on `sender_outgoing`, then pass
///    the transport-mode state to `T::set_state`.
/// 5. Wait until the `TRANSPORT_READY` flag reads `true`.
///
/// # Errors
///
/// * [`Error::HandshakeRemoteInvalidMessage`] if the first message is not a
///   handshake frame or its payload has the wrong length.
/// * [`Error::SendError`] / [`Error::RecvError`] if the channels fail.
/// * Any error from `step_1`, converted into [`Error`] by `?`.
async fn initialize_as_upstream<'a, Message: Serialize + Deserialize<'a> + GetSize, T: SetState>(
    self_: Arc<Mutex<T>>,
    role: HandshakeRole,
    sender_outgoing: Sender<StandardEitherFrame<Message>>,
    receiver_incoming: Receiver<StandardEitherFrame<Message>>,
) -> Result<(), Error> {
    let mut state = codec_sv2::State::initialized(role);

    // Receive the first handshake message and validate frame kind and size.
    let first_message: HandShakeFrame = receiver_incoming
        .recv()
        .await?
        .try_into()
        .map_err(|_| Error::HandshakeRemoteInvalidMessage)?;
    let first_message: [u8; RESPONDER_EXPECTED_HANDSHAKE_MESSAGE_SIZE] = first_message
        .get_payload_when_handshaking()
        .try_into()
        .map_err(|_| Error::HandshakeRemoteInvalidMessage)?;

    // Produce the reply together with the transport-mode state.
    let (second_message, transport_mode) = state.step_1(first_message)?;

    // The flag is cleared before the reply goes out; statement order is kept
    // exactly as it was. NOTE(review): the reader of `HANDSHAKE_READY` is not
    // visible in this file - confirm against the connection submodules.
    HANDSHAKE_READY.store(false, std::sync::atomic::Ordering::SeqCst);
    sender_outgoing.send(second_message.into()).await?;

    T::set_state(self_, transport_mode).await;

    // Wait for `TRANSPORT_READY` cooperatively. A plain spin loop here would
    // never yield, pinning the executor thread; on a single-threaded executor
    // the task that sets the flag could then never run. Each poll that finds
    // the flag unset re-schedules this task and returns `Pending`, so other
    // tasks get a turn between checks.
    std::future::poll_fn(|cx| {
        if TRANSPORT_READY.load(std::sync::atomic::Ordering::SeqCst) {
            std::task::Poll::Ready(())
        } else {
            cx.waker().wake_by_ref();
            std::task::Poll::Pending
        }
    })
    .await;

    Ok(())
}
/// Process-wide flag, initially `false`. Within this file it is only written:
/// `initialize_as_upstream` stores `false` just before sending its handshake
/// reply. NOTE(review): no reader and no `true` store are visible here -
/// presumably the connection submodules use it; confirm.
static HANDSHAKE_READY: AtomicBool = AtomicBool::new(false);
/// Process-wide flag, initially `false`. Both `initialize_as_*` functions wait
/// for it to read `true` before returning `Ok`. It is never set in this file.
/// NOTE(review): presumably set by the connection tasks once the codec is in
/// transport mode - confirm. As a `static`, it is a single value shared by
/// everything in the process that touches it.
static TRANSPORT_READY: AtomicBool = AtomicBool::new(false);