/// Protobuf message types for the Noise handshake payload (e.g.
/// `NoiseHandshakePayload`), included from `$OUT_DIR/payload.proto.rs`, which is
/// presumably generated by a build script — TODO confirm against `build.rs`.
mod payload_proto {
include!(concat!(env!("OUT_DIR"), "/payload.proto.rs"));
}
use crate::error::NoiseError;
use crate::io::{framed::NoiseFramed, NoiseOutput};
use crate::protocol::{KeypairIdentity, Protocol, PublicKey};
use crate::LegacyConfig;
use bytes::Bytes;
use futures::prelude::*;
use futures::task;
use libp2prs_core::identity;
use prost::Message;
use std::{io, pin::Pin, task::Context};
/// What is known about the remote peer once a handshake completes
/// (produced by `State::finish`).
pub enum RemoteIdentity<C> {
    /// The transport reported no remote static DH public key.
    Unknown,
    /// The remote static DH public key is known, but no remote identity key
    /// was known in advance or received during the handshake.
    StaticDhKey(PublicKey<C>),
    /// The remote identity public key, returned only after `C::verify`
    /// accepted the received signature against the remote static DH key.
    IdentityKey(identity::PublicKey),
}
/// How identity public keys are exchanged during the handshake.
///
/// Each variant decides whether the local identity key is included in the
/// payloads we send, and whether the remote identity key is known up front
/// (in which case a key received from the remote must match it).
pub enum IdentityExchange {
    /// Send the local identity key; learn the remote's from its payload, if sent.
    Mutual,
    /// Send the local identity key; the remote identity key is already known.
    Send { remote: identity::PublicKey },
    /// Do not send the local identity key; learn the remote's from its payload, if sent.
    Receive,
    /// Do not send the local identity key; the remote identity key is already known.
    None { remote: identity::PublicKey },
}
/// A pending Noise handshake that resolves to the remote peer's identity and
/// the transport-mode I/O once every handshake message has been exchanged.
///
/// The inner future is boxed and type-erased so that every handshake pattern
/// in this module can return this single type.
// The boxed trait-object type is verbose, but it is confined to this wrapper.
#[allow(clippy::type_complexity)]
pub struct Handshake<T, C>(Pin<Box<dyn Future<Output = Result<(RemoteIdentity<C>, NoiseOutput<T>), NoiseError>> + Send>>);

impl<T, C> Future for Handshake<T, C> {
    type Output = Result<(RemoteIdentity<C>, NoiseOutput<T>), NoiseError>;

    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> task::Poll<Self::Output> {
        // `Handshake` is `Unpin` (its only field is a `Pin<Box<_>>`), so the
        // wrapper can be unpinned and the already-pinned inner future polled.
        let inner = &mut self.get_mut().0;
        inner.as_mut().poll(ctx)
    }
}
/// Identity exchange for a two-message handshake, initiator side.
///
/// Message order: send the local identity payload, then receive the remote
/// one, then switch to transport mode via `State::finish`.
pub fn rt1_initiator<T, C>(
    io: T,
    session: Result<snow::HandshakeState, NoiseError>,
    identity: KeypairIdentity,
    identity_x: IdentityExchange,
    legacy: LegacyConfig,
) -> Handshake<T, C>
where
    T: AsyncWrite + AsyncRead + Send + Unpin + 'static,
    C: Protocol<C> + AsRef<[u8]>,
{
    let exchange = async move {
        let mut hs = State::new(io, session, identity, identity_x, legacy)?;
        // -> identity
        send_identity(&mut hs).await?;
        // <- identity
        recv_identity(&mut hs).await?;
        hs.finish()
    };
    Handshake(Box::pin(exchange))
}
/// Identity exchange for a two-message handshake, responder side.
///
/// Message order: receive the remote identity payload, then send the local
/// one, then switch to transport mode via `State::finish`.
pub fn rt1_responder<T, C>(
    io: T,
    session: Result<snow::HandshakeState, NoiseError>,
    identity: KeypairIdentity,
    identity_x: IdentityExchange,
    legacy: LegacyConfig,
) -> Handshake<T, C>
where
    T: AsyncWrite + AsyncRead + Send + Unpin + 'static,
    C: Protocol<C> + AsRef<[u8]>,
{
    let exchange = async move {
        let mut hs = State::new(io, session, identity, identity_x, legacy)?;
        // <- identity
        recv_identity(&mut hs).await?;
        // -> identity
        send_identity(&mut hs).await?;
        hs.finish()
    };
    Handshake(Box::pin(exchange))
}
/// Identity exchange for a three-message handshake, initiator side.
///
/// Message order: send an empty payload, receive the remote identity payload,
/// send the local identity payload, then switch to transport mode via
/// `State::finish`.
pub fn rt15_initiator<T, C>(
    io: T,
    session: Result<snow::HandshakeState, NoiseError>,
    identity: KeypairIdentity,
    identity_x: IdentityExchange,
    legacy: LegacyConfig,
) -> Handshake<T, C>
where
    T: AsyncWrite + AsyncRead + Unpin + Send + 'static,
    C: Protocol<C> + AsRef<[u8]>,
{
    let exchange = async move {
        let mut hs = State::new(io, session, identity, identity_x, legacy)?;
        // -> (empty)
        send_empty(&mut hs).await?;
        // <- identity
        recv_identity(&mut hs).await?;
        // -> identity
        send_identity(&mut hs).await?;
        hs.finish()
    };
    Handshake(Box::pin(exchange))
}
/// Identity exchange for a three-message handshake, responder side.
///
/// Message order: receive an empty payload, send the local identity payload,
/// receive the remote identity payload, then switch to transport mode via
/// `State::finish`.
pub fn rt15_responder<T, C>(
    io: T,
    session: Result<snow::HandshakeState, NoiseError>,
    identity: KeypairIdentity,
    identity_x: IdentityExchange,
    legacy: LegacyConfig,
) -> Handshake<T, C>
where
    T: AsyncWrite + AsyncRead + Unpin + Send + 'static,
    C: Protocol<C> + AsRef<[u8]>,
{
    let exchange = async move {
        let mut hs = State::new(io, session, identity, identity_x, legacy)?;
        // <- (empty)
        recv_empty(&mut hs).await?;
        // -> identity
        send_identity(&mut hs).await?;
        // <- identity
        recv_identity(&mut hs).await?;
        hs.finish()
    };
    Handshake(Box::pin(exchange))
}
/// Mutable state threaded through the send/receive steps of one handshake.
struct State<T> {
    /// Framed Noise I/O, still in handshake mode until `finish` is called.
    io: NoiseFramed<T, snow::HandshakeState>,
    /// Local identity: its `public` key and optional `signature` are what
    /// `send_identity` puts into outgoing payloads.
    identity: KeypairIdentity,
    /// Signature received in the remote payload, if any; checked by `finish`.
    dh_remote_pubkey_sig: Option<Vec<u8>>,
    /// Remote identity key, either known up front (from `IdentityExchange`)
    /// or learned from the remote payload.
    id_remote_pubkey: Option<identity::PublicKey>,
    /// Whether the local identity public key is included in sent payloads.
    send_identity: bool,
    /// Legacy payload-format switches for sending and receiving.
    legacy: LegacyConfig,
}
impl<T> State<T> {
    /// Creates the handshake state around `io`.
    ///
    /// `identity_x` is translated into two facts: whether the local identity
    /// key will be sent, and which remote identity key (if any) is expected in
    /// advance.
    ///
    /// # Errors
    /// Returns the error carried by `session` when it is `Err`.
    fn new(
        io: T,
        session: Result<snow::HandshakeState, NoiseError>,
        identity: KeypairIdentity,
        identity_x: IdentityExchange,
        legacy: LegacyConfig,
    ) -> Result<Self, NoiseError> {
        let handshake = session?;

        // Only `Mutual` and `Send` advertise the local identity key.
        let send_identity = matches!(identity_x, IdentityExchange::Mutual | IdentityExchange::Send { .. });
        let id_remote_pubkey = match identity_x {
            IdentityExchange::Send { remote } | IdentityExchange::None { remote } => Some(remote),
            IdentityExchange::Mutual | IdentityExchange::Receive => None,
        };

        Ok(State {
            io: NoiseFramed::new(io, handshake),
            identity,
            dh_remote_pubkey_sig: None,
            id_remote_pubkey,
            send_identity,
            legacy,
        })
    }
}
impl<T> State<T> {
    /// Ends the handshake: converts the framed I/O into transport mode and
    /// classifies what is known about the remote peer.
    ///
    /// - no remote static DH key: `RemoteIdentity::Unknown`;
    /// - DH key but no identity key: `RemoteIdentity::StaticDhKey`;
    /// - both: `RemoteIdentity::IdentityKey`, but only if `C::verify` accepts
    ///   the received signature.
    ///
    /// # Errors
    /// Propagates errors from `into_transport`; returns
    /// `NoiseError::InvalidKey` if signature verification fails.
    fn finish<C>(self) -> Result<(RemoteIdentity<C>, NoiseOutput<T>), NoiseError>
    where
        C: Protocol<C> + AsRef<[u8]>,
    {
        let (dh_remote, output) = self.io.into_transport()?;

        let dh_pk = match dh_remote {
            Some(pk) => pk,
            None => return Ok((RemoteIdentity::Unknown, output)),
        };

        let remote = match self.id_remote_pubkey {
            None => RemoteIdentity::StaticDhKey(dh_pk),
            Some(id_pk) => {
                // The identity key is only trusted once it is bound to the DH key.
                if !C::verify(&id_pk, &dh_pk, &self.dh_remote_pubkey_sig) {
                    return Err(NoiseError::InvalidKey);
                }
                RemoteIdentity::IdentityKey(id_pk)
            }
        };

        Ok((remote, output))
    }
}
/// Reads the next handshake frame from `state.io`.
///
/// # Errors
/// A finished stream becomes an `UnexpectedEof` I/O error; stream errors are
/// converted into `NoiseError`.
async fn recv<T>(state: &mut State<T>) -> Result<Bytes, NoiseError>
where
    T: AsyncRead + Unpin,
{
    if let Some(frame) = state.io.next().await {
        return frame.map_err(Into::into);
    }
    Err(io::Error::new(io::ErrorKind::UnexpectedEof, "eof").into())
}
/// Receives one handshake frame whose payload must be empty.
///
/// # Errors
/// Propagates errors from `recv`; a non-empty payload yields an
/// `InvalidData` I/O error.
async fn recv_empty<T>(state: &mut State<T>) -> Result<(), NoiseError>
where
    T: AsyncRead + Unpin,
{
    let payload = recv(state).await?;
    if payload.is_empty() {
        Ok(())
    } else {
        Err(io::Error::new(io::ErrorKind::InvalidData, "Unexpected handshake payload.").into())
    }
}
async fn send_empty<T>(state: &mut State<T>) -> Result<(), NoiseError>
where
T: AsyncWrite + Unpin,
{
state.io.send(&Vec::new()).await?;
Ok(())
}
/// Receives the remote identity payload and records what it carries.
///
/// The frame is decoded as a `NoiseHandshakePayload`. If that fails and
/// `state.legacy.recv_legacy_handshake` is set, the frame is retried as a
/// legacy payload: a 2-byte big-endian length prefix followed by exactly that
/// many protobuf bytes (see `strip_legacy_length_prefix`).
///
/// A non-empty identity key is parsed and stored in `state.id_remote_pubkey`.
/// If a remote key was already expected, the received one must equal it. A
/// non-empty signature is stored in `state.dh_remote_pubkey_sig` for
/// `State::finish` to verify.
///
/// # Errors
/// Propagates receive and decode errors; returns `NoiseError::InvalidKey` if
/// the identity key cannot be parsed or differs from the expected key.
async fn recv_identity<T>(state: &mut State<T>) -> Result<(), NoiseError>
where
    T: AsyncRead + Unpin,
{
    let msg = recv(state).await?;

    let pb = match payload_proto::NoiseHandshakePayload::decode(&msg[..]) {
        Ok(pb) => pb,
        Err(e) => {
            // Legacy fallback, only attempted after the regular decode failed.
            // A legacy prefix for a payload under 256 bytes starts with a 0 byte,
            // which protobuf does not accept as a field tag — so such frames
            // should reach this branch rather than mis-decode above.
            let legacy_body = if state.legacy.recv_legacy_handshake {
                strip_legacy_length_prefix(&msg[..])
            } else {
                None
            };
            match legacy_body {
                Some(body) => {
                    log::debug!("Attempting fallback legacy protobuf decoding.");
                    payload_proto::NoiseHandshakePayload::decode(body)?
                }
                None => return Err(e.into()),
            }
        }
    };

    if !pb.identity_key.is_empty() {
        let pk = identity::PublicKey::from_protobuf_encoding(&pb.identity_key).map_err(|_| NoiseError::InvalidKey)?;
        // A key known up front must not be replaced by a different one.
        if let Some(ref k) = state.id_remote_pubkey {
            if k != &pk {
                return Err(NoiseError::InvalidKey);
            }
        }
        state.id_remote_pubkey = Some(pk);
    }

    if !pb.identity_sig.is_empty() {
        state.dh_remote_pubkey_sig = Some(pb.identity_sig);
    }

    Ok(())
}

/// Returns the bytes that follow a legacy 2-byte big-endian length prefix.
///
/// Returns `None` if `msg` is shorter than the prefix or the declared length
/// does not equal the number of remaining bytes. A prefix of zero followed by
/// no bytes (a legacy-framed empty payload) yields `Some(&[])`.
fn strip_legacy_length_prefix(msg: &[u8]) -> Option<&[u8]> {
    if msg.len() < 2 {
        return None;
    }
    let (prefix, body) = msg.split_at(2);
    let declared = usize::from(u16::from_be_bytes([prefix[0], prefix[1]]));
    if declared == body.len() {
        Some(body)
    } else {
        None
    }
}
async fn send_identity<T>(state: &mut State<T>) -> Result<(), NoiseError>
where
T: AsyncWrite + Unpin,
{
let mut pb = payload_proto::NoiseHandshakePayload::default();
if state.send_identity {
pb.identity_key = state.identity.public.clone().into_protobuf_encoding()
}
if let Some(ref sig) = state.identity.signature {
pb.identity_sig = sig.clone()
}
let mut msg = if state.legacy.send_legacy_handshake {
let mut msg = Vec::with_capacity(2 + pb.encoded_len());
msg.extend_from_slice(&(pb.encoded_len() as u16).to_be_bytes());
msg
} else {
Vec::with_capacity(pb.encoded_len())
};
pb.encode(&mut msg).expect("Vec<u8> provides capacity as needed");
state.io.send(&msg).await?;
Ok(())
}