use std::io::{Error as IoError, ErrorKind as IoErrorKind};
use asynchronous_codec::{Framed, FramedParts};
use bytes::Bytes;
use futures::prelude::*;
use libp2p_identity::{PeerId, PublicKey};
use crate::{
error::{DecodeError, Error},
proto::Exchange,
Config,
};
/// Runs the plaintext identity exchange over `socket`.
///
/// The local side sends one [`Exchange`] message carrying its peer ID and its
/// protobuf-encoded public key, then reads one [`Exchange`] back from the remote.
/// Messages are framed by `quick_protobuf_codec` with a 100-byte limit per message.
/// The remote's key is accepted only if the peer ID it sent equals the peer ID
/// derived from that key. No other proof of key ownership is performed.
///
/// On success, returns three values:
/// - the underlying socket;
/// - the remote's public key;
/// - any bytes that the framing layer had already read past the remote's message.
///
/// Those bytes are returned so they are not lost when the framing layer is dropped.
///
/// # Errors
///
/// - [`DecodeError`] if sending the local exchange or reading the remote's fails.
/// - The key-decoding or peer-ID parsing error if the remote's fields are malformed.
///   Missing fields are treated as empty byte strings before decoding.
/// - [`Error::PeerIdMismatch`] if the remote's peer ID is not derived from its key.
/// - An I/O error of kind [`IoErrorKind::BrokenPipe`] ("unexpected eof") if the
///   stream ends before the remote's exchange arrives.
///
/// # Panics
///
/// Panics if the framed write buffer is not empty after the exchange was sent.
/// This is not expected, because `SinkExt::send` flushes before it completes.
pub(crate) async fn handshake<S>(socket: S, config: Config) -> Result<(S, PublicKey, Bytes), Error>
where
    S: AsyncRead + AsyncWrite + Send + Unpin,
{
    // Maximum protobuf message length, in bytes, that the codec will accept.
    const MAX_EXCHANGE_LEN: usize = 100;

    let codec = quick_protobuf_codec::Codec::<Exchange>::new(MAX_EXCHANGE_LEN);
    let mut framed = Framed::new(socket, codec);

    tracing::trace!("sending exchange to remote");
    let local_key = &config.local_public_key;
    let outgoing = Exchange {
        id: Some(local_key.to_peer_id().to_bytes()),
        pubkey: Some(local_key.encode_protobuf()),
    };
    framed.send(outgoing).await.map_err(DecodeError)?;

    tracing::trace!("receiving the remote's exchange");
    let incoming = framed.next().await.transpose().map_err(DecodeError)?;
    let Some(remote) = incoming else {
        tracing::debug!("unexpected eof while waiting for remote's exchange");
        let eof = IoError::new(IoErrorKind::BrokenPipe, "unexpected eof");
        return Err(eof.into());
    };

    // Decode the key before the peer ID, so a malformed key is the error reported
    // when both fields are bad. Then require that the advertised peer ID is the one
    // derived from that key.
    let public_key = PublicKey::try_decode_protobuf(&remote.pubkey.unwrap_or_default())?;
    let claimed_id = PeerId::from_bytes(&remote.id.unwrap_or_default())?;
    if claimed_id != public_key.to_peer_id() {
        return Err(Error::PeerIdMismatch);
    }
    tracing::trace!(?public_key, "received exchange from remote");

    let FramedParts {
        io: stream,
        read_buffer: leftover,
        write_buffer: pending_writes,
        ..
    } = framed.into_parts();
    // `send` flushes before resolving, so nothing should still be queued for writing.
    assert!(pending_writes.is_empty());
    Ok((stream, public_key, leftover.freeze()))
}