use crate::error::{DecodeError, PlainTextError};
use crate::proto::Exchange;
use crate::PlainText2Config;
use asynchronous_codec::{Framed, FramedParts};
use bytes::{Bytes, BytesMut};
use futures::prelude::*;
use libp2p_identity::{PeerId, PublicKey};
use log::{debug, trace};
use quick_protobuf::{BytesReader, MessageRead, MessageWrite, Writer};
use std::io::{Error as IoError, ErrorKind as IoErrorKind};
use unsigned_varint::codec::UviBytes;
/// Handshake data that is carried from one stage of the plaintext handshake
/// to the next.
///
/// The type parameter `T` is the stage-specific state. In this file it is
/// instantiated with [`Local`] (before the remote's exchange has been
/// processed) and [`Remote`] (after it has been decoded and validated).
struct HandshakeContext<T> {
/// Configuration the handshake was started with; it provides the local
/// public key and is moved unchanged between stages.
config: PlainText2Config,
/// Stage-specific state, see the type-level documentation.
state: T,
}
/// Handshake state that is known before anything has been received from the
/// remote.
struct Local {
/// Protobuf-encoded `Exchange` message (local peer id and public key) that
/// is sent to the remote as a single frame.
exchange_bytes: Vec<u8>,
}
/// Information about the remote peer obtained from its `Exchange` message.
pub(crate) struct Remote {
    /// Peer id announced by the remote in its exchange.
    pub(crate) peer_id: PeerId,
    /// Public key announced by the remote in its exchange.
    pub(crate) public_key: PublicKey,
}
impl HandshakeContext<Local> {
    /// Creates the initial handshake context.
    ///
    /// Builds an `Exchange` message from the local public key found in
    /// `config` (its derived peer id plus the protobuf-encoded key) and
    /// serialises it up front so that it only has to be written to the wire.
    ///
    /// # Panics
    ///
    /// Panics if serialising the `Exchange` message into the in-memory buffer
    /// fails.
    fn new(config: PlainText2Config) -> Self {
        let local_key = &config.local_public_key;
        let message = Exchange {
            id: Some(local_key.to_peer_id().to_bytes()),
            pubkey: Some(local_key.encode_protobuf()),
        };

        // Size the buffer from the message so encoding does not reallocate.
        let mut exchange_bytes = Vec::with_capacity(message.get_size());
        message
            .write_message(&mut Writer::new(&mut exchange_bytes))
            .expect("Encoding to succeed");

        HandshakeContext {
            config,
            state: Local { exchange_bytes },
        }
    }

    /// Consumes the local context and the raw `Exchange` frame received from
    /// the remote, producing the context for the [`Remote`] stage.
    ///
    /// Missing `pubkey` / `id` fields are treated as empty byte strings and
    /// handed to the respective decoder as such.
    ///
    /// # Errors
    ///
    /// Returns an error if the frame is not a decodable `Exchange` message,
    /// if the public key or the peer id cannot be decoded, or
    /// [`PlainTextError::PeerIdMismatch`] if the announced peer id is not the
    /// one derived from the announced public key.
    fn with_remote(
        self,
        exchange_bytes: BytesMut,
    ) -> Result<HandshakeContext<Remote>, PlainTextError> {
        let mut reader = BytesReader::from_bytes(&exchange_bytes);
        let remote_exchange =
            Exchange::from_reader(&mut reader, &exchange_bytes).map_err(DecodeError)?;

        // The public key is decoded before the peer id so that, if both are
        // malformed, the public-key error is the one reported.
        let pubkey_bytes = remote_exchange.pubkey.unwrap_or_default();
        let public_key = PublicKey::try_decode_protobuf(&pubkey_bytes)?;

        let id_bytes = remote_exchange.id.unwrap_or_default();
        let peer_id = PeerId::from_bytes(&id_bytes)?;

        // The announced peer id must be the one derived from the announced key.
        if peer_id == public_key.to_peer_id() {
            Ok(HandshakeContext {
                config: self.config,
                state: Remote {
                    peer_id,
                    public_key,
                },
            })
        } else {
            Err(PlainTextError::PeerIdMismatch)
        }
    }
}
/// Runs the plaintext handshake over `socket`.
///
/// The local `Exchange` message is sent as one length-prefixed frame, then a
/// single frame is read from the remote and validated.
///
/// On success returns the underlying socket, the [`Remote`] information taken
/// from the peer's exchange, and any bytes that were already read from the
/// socket but not consumed by the handshake.
///
/// # Errors
///
/// Returns an error if sending or receiving a frame fails, if the stream ends
/// before the remote's exchange arrives (reported as a `BrokenPipe` I/O
/// error), or if the remote's exchange is rejected.
///
/// # Panics
///
/// Panics if the framed writer still holds unsent bytes once the handshake
/// has completed.
pub(crate) async fn handshake<S>(
    socket: S,
    config: PlainText2Config,
) -> Result<(S, Remote, Bytes), PlainTextError>
where
    S: AsyncRead + AsyncWrite + Send + Unpin,
{
    let mut framed = Framed::new(socket, UviBytes::default());

    trace!("starting handshake");
    let local = HandshakeContext::new(config);

    trace!("sending exchange to remote");
    let outgoing = BytesMut::from(local.state.exchange_bytes.as_slice());
    framed.send(outgoing).await?;

    trace!("receiving the remote's exchange");
    // `None` means the stream ended before a frame arrived.
    let frame = framed.next().await.ok_or_else(|| {
        debug!("unexpected eof while waiting for remote's exchange");
        IoError::new(IoErrorKind::BrokenPipe, "unexpected eof")
    })?;
    let remote = local.with_remote(frame?)?;

    trace!(
        "received exchange from remote; pubkey = {:?}",
        remote.state.public_key
    );

    // Take the framing layer apart again to hand the raw socket back.
    let FramedParts {
        io,
        read_buffer: leftover,
        write_buffer: unsent,
        ..
    } = framed.into_parts();
    assert!(unsent.is_empty());

    Ok((io, remote.state, leftover.freeze()))
}