use futures::channel::mpsc::{Receiver, Sender};
use futures::future::{Fuse, FutureExt};
use futures::io::{AsyncRead, AsyncWrite};
use futures::io::{BufReader, BufWriter};
use futures::sink::SinkExt;
use futures::stream::{SelectAll, Stream, StreamExt};
use futures_timer::Delay;
use log::*;
use std::collections::VecDeque;
use std::fmt;
use std::io::{Error, ErrorKind, Result};
use std::time::Duration;
use crate::channels::{Channel, Channelizer};
use crate::constants::DEFAULT_KEEPALIVE;
use crate::message::{ChannelMessage, Message};
use crate::noise::{Handshake, HandshakeResult};
use crate::reader::ProtocolReader;
use crate::schema::*;
use crate::util::map_channel_err;
use crate::util::pretty_hash;
use crate::writer::ProtocolWriter;
/// Buffer size of the control channel created in `Protocol::new`.
const CHANNEL_CAP: usize = 1000;
/// Keepalive interval: `DEFAULT_KEEPALIVE` interpreted as a number of seconds.
const KEEPALIVE_DURATION: Duration = Duration::from_secs(DEFAULT_KEEPALIVE as u64);
/// An event produced by `Protocol::loop_next`.
pub enum Event {
    /// The noise handshake completed; carries the remote peer's public key.
    Handshake(Vec<u8>),
    /// The remote opened a channel for a discovery key that has not been opened
    /// locally (yet). Presumably accepted by calling `Protocol::open` with the key
    /// that derives this discovery key — confirm against `Channelizer`.
    DiscoveryKey(Vec<u8>),
    /// A channel is now open on both sides and ready for use.
    Channel(Channel),
    /// A channel was closed, locally or by the remote; carries its discovery key.
    Close(Vec<u8>),
}
impl fmt::Debug for Event {
    /// Formats the event, shortening every key/discovery key with `pretty_hash`
    /// so log lines stay readable. Channels use their own `Debug` output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Event::Channel(channel) => write!(f, "{:?}", channel),
            Event::Handshake(remote_key) => {
                let short = pretty_hash(remote_key);
                write!(f, "Handshake(remote_key={})", short)
            }
            Event::DiscoveryKey(discovery_key) => {
                let short = pretty_hash(discovery_key);
                write!(f, "DiscoveryKey({})", short)
            }
            Event::Close(discovery_key) => {
                let short = pretty_hash(discovery_key);
                write!(f, "Close({})", short)
            }
        }
    }
}
/// Options controlling how a `Protocol` connection is set up.
#[derive(Debug)]
pub struct ProtocolOptions {
    /// Whether this side initiates the noise handshake.
    pub is_initiator: bool,
    /// Whether to run the noise handshake at all. When false, the protocol goes
    /// straight to the established state.
    pub noise: bool,
    /// Whether to upgrade the reader and writer with the handshake result once the
    /// handshake completes. This only has an effect when `noise` is true.
    pub encrypted: bool,
}
/// Builder for `Protocol`; wraps the `ProtocolOptions` being configured.
pub struct ProtocolBuilder(ProtocolOptions);
impl ProtocolBuilder {
pub fn new(is_initiator: bool) -> Self {
Self(ProtocolOptions {
is_initiator,
noise: true,
encrypted: true,
})
}
pub fn initiator() -> Self {
Self::new(true)
}
pub fn responder() -> Self {
Self::new(false)
}
pub fn set_encrypted(mut self, encrypted: bool) -> Self {
self.0.encrypted = encrypted;
self
}
pub fn set_noise(mut self, noise: bool) -> Self {
self.0.noise = noise;
self
}
pub fn connect<S>(self, stream: S) -> Protocol<S, S>
where
S: AsyncRead + AsyncWrite + Send + Unpin + Clone + 'static,
{
Protocol::new(stream.clone(), stream, self.0)
}
pub fn connect_rw<R, W>(self, reader: R, writer: W) -> Protocol<R, W>
where
R: AsyncRead + Send + Unpin + 'static,
W: AsyncWrite + Send + Unpin + 'static,
{
Protocol::new(reader, writer, self.0)
}
#[deprecated(since = "0.0.1", note = "Use connect_rw")]
pub fn build_from_io<R, W>(self, reader: R, writer: W) -> Protocol<R, W>
where
R: AsyncRead + Send + Unpin + 'static,
W: AsyncWrite + Send + Unpin + 'static,
{
self.connect_rw(reader, writer)
}
#[deprecated(since = "0.0.1", note = "Use connect")]
pub fn build_from_stream<S>(self, stream: S) -> Protocol<S, S>
where
S: AsyncRead + AsyncWrite + Send + Unpin + Clone + 'static,
{
self.connect(stream)
}
}
/// Lifecycle state of a `Protocol`.
// NOTE(review): presumably `Handshake` is much larger than the other variants,
// which is why clippy's large_enum_variant lint is silenced — confirm.
#[allow(clippy::large_enum_variant)]
pub enum State {
    /// Created but not started. `init` has not run, or its first send failed.
    NotInitialized,
    /// Noise handshake in progress.
    ///
    /// The handshake lives in an `Option` so the message handler can take
    /// ownership of it while processing a frame. If processing fails, it is
    /// left as `None`.
    Handshake(Option<Handshake>),
    /// Handshake done, or skipped because noise is disabled. Channel messages are
    /// exchanged in this state.
    Established,
}
impl fmt::Debug for State {
    /// Writes a short state name. The handshake payload is deliberately not printed.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let label = match self {
            State::NotInitialized => "NotInitialized",
            State::Handshake(_) => "Handshaking",
            State::Established => "Established",
        };
        f.write_str(label)
    }
}
/// All outgoing message sources merged into one stream.
///
/// Sources include the per-channel send streams and the one-shot `Open` messages
/// pushed by `Protocol::open`.
type CombinedOutputStream = SelectAll<Box<dyn Stream<Item = ChannelMessage> + Send + Unpin>>;
/// Protocol state machine for a single connection.
///
/// Owns buffered reader and writer halves and multiplexes channels over them.
/// Drive it by calling `Protocol::loop_next` repeatedly, or turn it into a
/// `Stream` with `Protocol::into_stream`.
pub struct Protocol<R, W>
where
    R: AsyncRead + Send + Unpin + 'static,
    W: AsyncWrite + Send + Unpin + 'static,
{
    /// Outgoing message writer; upgraded with the handshake result when encryption is on.
    writer: ProtocolWriter<BufWriter<W>>,
    /// Incoming message reader; upgraded with the handshake result when encryption is on.
    reader: ProtocolReader<BufReader<R>>,
    /// Current lifecycle state.
    state: State,
    /// Options passed at construction time.
    options: ProtocolOptions,
    /// Handshake result, set once the noise handshake completes.
    handshake: Option<HandshakeResult>,
    /// Bookkeeping for local and remote channel ids.
    channels: Channelizer,
    /// Error recorded by `destroy`.
    error: Option<Error>,
    /// Merged outgoing messages (see `CombinedOutputStream`).
    outbound_rx: CombinedOutputStream,
    /// Receives requests sent through `ControlTx` handles.
    control_rx: Receiver<ControlEvent>,
    /// Sender half kept so `control()` can hand out clones.
    control_tx: ControlTx,
    /// Events already produced but not yet returned by `loop_next`.
    events: VecDeque<Event>,
    /// Keepalive timer, kept between `loop_next` calls.
    keepalive: Option<Fuse<Delay>>,
}
impl<R, W> Protocol<R, W>
where
R: AsyncRead + Send + Unpin + 'static,
W: AsyncWrite + Send + Unpin + 'static,
{
pub fn new(reader: R, writer: W, options: ProtocolOptions) -> Self {
let reader = ProtocolReader::new(BufReader::new(reader));
let writer = ProtocolWriter::new(BufWriter::new(writer));
let (control_tx, control_rx) = futures::channel::mpsc::channel(CHANNEL_CAP);
Protocol {
writer,
reader,
options,
state: State::NotInitialized,
channels: Channelizer::new(),
handshake: None,
error: None,
outbound_rx: SelectAll::new(),
control_rx,
control_tx: ControlTx(control_tx),
events: VecDeque::new(),
keepalive: None,
}
}
pub fn builder(is_initiator: bool) -> ProtocolBuilder {
ProtocolBuilder::new(is_initiator)
}
async fn init(&mut self) -> Result<()> {
trace!(
"protocol init, state {:?}, options {:?}",
self.state,
self.options
);
match self.state {
State::NotInitialized => {}
_ => return Ok(()),
};
self.state = if self.options.noise {
let mut handshake = Handshake::new(self.options.is_initiator)?;
if let Some(buf) = handshake.start()? {
self.writer.send_prefixed(buf).await?;
}
State::Handshake(Some(handshake))
} else {
State::Established
};
self.reset_keepalive();
Ok(())
}
fn reset_keepalive(&mut self) {
let keepalive_duration = Duration::from_secs(DEFAULT_KEEPALIVE as u64);
self.keepalive = Some(Delay::new(keepalive_duration).fuse());
}
pub fn is_initiator(&self) -> bool {
self.options.is_initiator
}
pub async fn loop_next(&mut self) -> Result<Event> {
if let State::NotInitialized = self.state {
self.init().await?;
}
let mut keepalive = if let Some(keepalive) = self.keepalive.take() {
keepalive
} else {
Delay::new(KEEPALIVE_DURATION).fuse()
};
loop {
if let Some(event) = self.events.pop_front() {
return Ok(event);
}
let event = futures::select! {
_ = keepalive => {
self.ping().await?;
keepalive = Delay::new(KEEPALIVE_DURATION).fuse();
None
},
buf = self.reader.select_next_some() => {
self.on_message(buf?).await?
},
channel_message = self.outbound_rx.select_next_some() => {
let event = match channel_message {
ChannelMessage { channel, message: Message::Close(_) } => {
self.close_local(channel).await?
},
_ => None
};
self.send(channel_message).await?;
event
},
ev = self.control_rx.select_next_some() => {
match ev {
ControlEvent::Open(key) => {
self.open(key).await?;
None
}
}
},
};
if let Some(event) = event {
self.keepalive = Some(keepalive);
return Ok(event);
}
}
}
pub fn remote_key(&self) -> Option<&[u8]> {
match &self.handshake {
None => None,
Some(handshake) => Some(handshake.remote_pubkey.as_slice()),
}
}
pub fn destroy(&mut self, error: Error) {
self.error = Some(error)
}
async fn on_message(&mut self, buf: Vec<u8>) -> Result<Option<Event>> {
match self.state {
State::Handshake(_) => self.on_handshake_message(buf).await,
State::Established => self.on_proto_message(buf).await,
State::NotInitialized => panic!("cannot receive messages before starting the protocol"),
}
}
async fn on_handshake_message(&mut self, buf: Vec<u8>) -> Result<Option<Event>> {
let mut handshake = match &mut self.state {
State::Handshake(handshake) => handshake.take().unwrap(),
_ => panic!("cannot call on_handshake_message when not in Handshake state"),
};
if let Some(response_buf) = handshake.read(&buf)? {
self.writer.send_prefixed(response_buf).await?;
}
if !handshake.complete() {
self.state = State::Handshake(Some(handshake));
Ok(None)
} else {
let result = handshake.into_result()?;
if self.options.encrypted {
self.reader.upgrade_with_handshake(&result)?;
self.writer.upgrade_with_handshake(&result)?;
}
let remote_key = result.remote_pubkey.to_vec();
log::trace!(
"handshake complete, remote_key {}",
pretty_hash(&remote_key)
);
self.handshake = Some(result);
self.state = State::Established;
Ok(Some(Event::Handshake(remote_key)))
}
}
async fn on_proto_message(&mut self, buf: Vec<u8>) -> Result<Option<Event>> {
let channel_message = ChannelMessage::decode(buf)?;
log::trace!("recv {:?}", channel_message);
let (remote_id, message) = channel_message.into_split();
match message {
Message::Open(msg) => self.open_remote(remote_id, msg).await,
Message::Close(msg) => self.close_remote(remote_id, msg).await,
Message::Extension(_msg) => unimplemented!(),
_ => {
self.channels.forward(remote_id as usize, message).await?;
Ok(None)
}
}
}
pub async fn open(&mut self, key: Vec<u8>) -> Result<()> {
let inner_channel = self.channels.attach_local(key.clone());
let local_id = inner_channel.local_id.unwrap();
let discovery_key = inner_channel.discovery_key.clone();
if let Some(_remote_id) = inner_channel.remote_id {
let remote_capability = inner_channel.remote_capability.clone();
self.verify_remote_capability(remote_capability, &key)?;
let channel = self.create_channel(local_id).await?;
self.events.push_back(Event::Channel(channel));
}
let capability = self.capability(&key);
let message = Message::Open(Open {
discovery_key,
capability,
});
let channel_message = ChannelMessage::new(local_id as u64, message);
self.outbound_rx.push(Box::new(
futures::future::ready(channel_message).into_stream(),
));
Ok(())
}
async fn open_remote(&mut self, ch: u64, msg: Open) -> Result<Option<Event>> {
let inner_channel = self.channels.attach_remote(
msg.discovery_key.clone(),
ch as usize,
msg.capability.clone(),
);
if let Some(local_id) = inner_channel.local_id {
let key = inner_channel.key.as_ref().unwrap().clone();
self.verify_remote_capability(msg.capability, &key)?;
let channel = self.create_channel(local_id).await?;
Ok(Some(Event::Channel(channel)))
} else {
Ok(Some(Event::DiscoveryKey(msg.discovery_key.clone())))
}
}
async fn create_channel(&mut self, local_id: usize) -> Result<Channel> {
let inner_channel = self.channels.get_local_mut(local_id).unwrap();
let (channel, send_rx) = inner_channel.open().await?;
self.outbound_rx.push(Box::new(send_rx));
Ok(channel)
}
async fn close_local(&mut self, local_id: u64) -> Result<Option<Event>> {
if let Some(channel) = self.channels.get_local_mut(local_id as usize) {
let discovery_key = channel.discovery_key.clone();
channel.recv_close(None).await?;
self.channels.remove(&discovery_key);
Ok(Some(Event::Close(discovery_key)))
} else {
Ok(None)
}
}
async fn close_remote(&mut self, remote_id: u64, msg: Close) -> Result<Option<Event>> {
if let Some(channel) = self.channels.get_remote_mut(remote_id as usize) {
let discovery_key = channel.discovery_key.clone();
channel.recv_close(Some(msg)).await?;
self.channels.remove(&discovery_key);
Ok(Some(Event::Close(discovery_key)))
} else {
Ok(None)
}
}
async fn send(&mut self, channel_message: ChannelMessage) -> Result<()> {
log::trace!("send {:?}", channel_message);
let buf = channel_message.encode()?;
self.writer.send_prefixed(&buf).await
}
async fn ping(&mut self) -> Result<()> {
self.writer.ping().await
}
pub fn release(self) -> (R, W) {
(
self.reader.into_inner().into_inner(),
self.writer.into_inner().into_inner(),
)
}
fn capability(&self, key: &[u8]) -> Option<Vec<u8>> {
match self.handshake.as_ref() {
Some(handshake) => handshake.capability(key),
None => None,
}
}
fn verify_remote_capability(&self, capability: Option<Vec<u8>>, key: &[u8]) -> Result<()> {
match self.handshake.as_ref() {
Some(handshake) => handshake.verify_remote_capability(capability, key),
None => Err(Error::new(
ErrorKind::PermissionDenied,
"Missing handshake state for capability verification",
)),
}
}
pub fn into_stream(self) -> stream::ProtocolStream<R, W> {
let control = self.control();
stream::ProtocolStream::new(self, control)
}
pub fn control(&self) -> ControlTx {
self.control_tx.clone()
}
}
/// Request sent through a `ControlTx` and handled inside `Protocol::loop_next`.
#[derive(Debug)]
pub enum ControlEvent {
    /// Open a channel for the given key (handled by `Protocol::open`).
    Open(Vec<u8>),
}
/// Cloneable handle for sending `ControlEvent`s to a `Protocol`.
///
/// Obtained from `Protocol::control`. The requests are processed while the
/// protocol is being driven by `Protocol::loop_next`.
#[derive(Clone)]
pub struct ControlTx(Sender<ControlEvent>);
impl ControlTx {
    /// Asks the protocol to open a channel for `key`.
    ///
    /// This only queues the request; the channel is opened the next time the
    /// protocol is driven.
    ///
    /// # Errors
    /// Fails if the control channel can no longer accept messages (presumably
    /// because the owning `Protocol` was dropped). The send error is converted
    /// with `map_channel_err`.
    pub async fn open(&mut self, key: Vec<u8>) -> Result<()> {
        let request = ControlEvent::Open(key);
        let outcome = self.0.send(request).await;
        outcome.map_err(map_channel_err)
    }
}
pub use stream::ProtocolStream;
mod stream {
    //! Adapter that exposes a `Protocol` as a `futures` `Stream` of events.
    use super::ControlTx;
    use crate::{Event, Protocol};
    use futures::future::FutureExt;
    use futures::io::{AsyncRead, AsyncWrite};
    use futures::stream::Stream;
    use std::future::Future;
    use std::io::Result;
    use std::pin::Pin;
    use std::task::Poll;

    /// Boxed future that runs one `loop_next` and then hands the protocol back.
    type LoopFuture<R, W> = Pin<Box<dyn Future<Output = (Result<Event>, Protocol<R, W>)> + Send>>;

    /// Runs a single `Protocol::loop_next` on an owned protocol.
    ///
    /// Returns the outcome together with the protocol so the next iteration can
    /// reuse it.
    async fn next_event<R, W>(mut protocol: Protocol<R, W>) -> (Result<Event>, Protocol<R, W>)
    where
        R: AsyncRead + Send + Unpin + 'static,
        W: AsyncWrite + Send + Unpin + 'static,
    {
        let outcome = protocol.loop_next().await;
        (outcome, protocol)
    }

    /// Stream of protocol events; each item is the result of one `loop_next` call.
    ///
    /// The stream never yields `None`. Errors are yielded as items and the
    /// protocol keeps being polled afterwards.
    pub struct ProtocolStream<R, W>
    where
        R: AsyncRead + Send + Unpin + 'static,
        W: AsyncWrite + Send + Unpin + 'static,
    {
        pending: LoopFuture<R, W>,
        control: ControlTx,
    }

    impl<R, W> ProtocolStream<R, W>
    where
        R: AsyncRead + Send + Unpin + 'static,
        W: AsyncWrite + Send + Unpin + 'static,
    {
        /// Wraps `protocol`, using `tx` to send control requests to it.
        pub fn new(protocol: Protocol<R, W>, tx: ControlTx) -> Self {
            Self {
                pending: next_event(protocol).boxed(),
                control: tx,
            }
        }

        /// Asks the wrapped protocol to open a channel for `key`.
        pub async fn open(&mut self, key: Vec<u8>) -> Result<()> {
            self.control.open(key).await
        }
    }

    impl<R, W> Stream for ProtocolStream<R, W>
    where
        R: AsyncRead + Send + Unpin + 'static,
        W: AsyncWrite + Send + Unpin + 'static,
    {
        type Item = Result<Event>;

        fn poll_next(
            mut self: Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> Poll<Option<Self::Item>> {
            let (event, protocol) = match self.pending.as_mut().poll(cx) {
                Poll::Ready(output) => output,
                Poll::Pending => return Poll::Pending,
            };
            // Create the next iteration right away so the protocol is already
            // owned by a future the next time the stream is polled.
            self.pending = next_event(protocol).boxed();
            Poll::Ready(Some(event))
        }
    }
}