use async_channel::{Receiver, Sender};
use futures_lite::io::{AsyncRead, AsyncWrite};
use futures_lite::stream::Stream;
use futures_timer::Delay;
use log::*;
use std::collections::VecDeque;
use std::convert::TryInto;
use std::fmt;
use std::future::Future;
use std::io::{self, Error, ErrorKind, Result};
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use crate::builder::{Builder, Options};
use crate::channels::{Channel, ChannelMap};
use crate::constants::DEFAULT_KEEPALIVE;
use crate::extension::{Extension, Extensions};
use crate::message::{ChannelMessage, EncodeError, Frame, FrameType, Message};
use crate::noise::{Handshake, HandshakeResult};
use crate::reader::ReadState;
use crate::schema::*;
use crate::util::map_channel_err;
use crate::util::pretty_hash;
use crate::writer::WriteState;
macro_rules! return_error {
($msg:expr) => {
if let Err(e) = $msg {
return Poll::Ready(Err(e));
}
};
}
const CHANNEL_CAP: usize = 1000;
const KEEPALIVE_DURATION: Duration = Duration::from_secs(DEFAULT_KEEPALIVE as u64);
pub type RemotePublicKey = [u8; 32];
pub type DiscoveryKey = [u8; 32];
pub type Key = [u8; 32];
#[non_exhaustive]
#[derive(PartialEq)]
pub enum Event {
Handshake(RemotePublicKey),
DiscoveryKey(DiscoveryKey),
Channel(Channel),
Close(DiscoveryKey),
}
#[derive(Debug)]
pub enum Command {
Open(Key),
Close(DiscoveryKey),
}
impl fmt::Debug for Event {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Event::Handshake(remote_key) => {
write!(f, "Handshake(remote_key={})", &pretty_hash(remote_key))
}
Event::DiscoveryKey(discovery_key) => {
write!(f, "DiscoveryKey({})", &pretty_hash(discovery_key))
}
Event::Channel(channel) => {
write!(f, "Channel({})", &pretty_hash(channel.discovery_key()))
}
Event::Close(discovery_key) => write!(f, "Close({})", &pretty_hash(discovery_key)),
}
}
}
#[allow(clippy::large_enum_variant)]
pub enum State {
NotInitialized,
Handshake(Option<Handshake>),
Established,
}
impl fmt::Debug for State {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
State::NotInitialized => write!(f, "NotInitialized"),
State::Handshake(_) => write!(f, "Handshaking"),
State::Established => write!(f, "Established"),
}
}
}
#[derive(Debug)]
pub struct Protocol<IO> {
write_state: WriteState,
read_state: ReadState,
io: IO,
state: State,
options: Options,
handshake: Option<HandshakeResult>,
channels: ChannelMap,
command_rx: Receiver<Command>,
command_tx: CommandTx,
outbound_rx: Receiver<ChannelMessage>,
outbound_tx: Sender<ChannelMessage>,
keepalive: Delay,
queued_events: VecDeque<Event>,
extensions: Extensions,
}
impl<IO> Protocol<IO>
where
IO: AsyncWrite + AsyncRead + Send + Unpin + 'static,
{
pub fn new(io: IO, options: Options) -> Self {
let (command_tx, command_rx) = async_channel::bounded(CHANNEL_CAP);
let (outbound_tx, outbound_rx) = async_channel::bounded(1);
Protocol {
io,
read_state: ReadState::new(),
write_state: WriteState::new(),
options,
state: State::NotInitialized,
channels: ChannelMap::new(),
handshake: None,
extensions: Extensions::new(outbound_tx.clone(), 0),
command_rx,
command_tx: CommandTx(command_tx),
outbound_tx,
outbound_rx,
keepalive: Delay::new(Duration::from_secs(DEFAULT_KEEPALIVE as u64)),
queued_events: VecDeque::new(),
}
}
pub fn with_defaults(io: IO, is_initiator: bool) -> Self {
let options = Options::new(is_initiator);
Protocol::new(io, options)
}
pub fn builder(is_initiator: bool) -> Builder {
Builder::new(is_initiator)
}
pub fn is_initiator(&self) -> bool {
self.options.is_initiator
}
pub fn public_key(&self) -> Option<&[u8]> {
match &self.handshake {
None => None,
Some(handshake) => Some(handshake.local_pubkey.as_slice()),
}
}
pub fn remote_public_key(&self) -> Option<&[u8]> {
match &self.handshake {
None => None,
Some(handshake) => Some(handshake.remote_pubkey.as_slice()),
}
}
pub fn commands(&self) -> CommandTx {
self.command_tx.clone()
}
pub async fn command(&mut self, command: Command) -> Result<()> {
self.command_tx.send(command).await
}
pub async fn register_extension(&mut self, name: impl ToString) -> Extension {
self.extensions.register(name.to_string()).await
}
pub async fn open(&mut self, key: Key) -> Result<()> {
self.command_tx.open(key).await
}
pub fn channels(&self) -> impl Iterator<Item = &DiscoveryKey> {
self.channels.iter().map(|c| c.discovery_key())
}
pub fn release(self) -> IO {
self.io
}
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<Event>> {
let this = self.get_mut();
if let State::NotInitialized = this.state {
return_error!(this.init());
}
if let Some(event) = this.queued_events.pop_front() {
return Poll::Ready(Ok(event));
}
return_error!(this.poll_inbound_read(cx));
if let State::Established = this.state {
return_error!(this.poll_commands(cx));
}
this.poll_keepalive(cx);
return_error!(this.poll_outbound_write(cx));
if let Some(event) = this.queued_events.pop_front() {
Poll::Ready(Ok(event))
} else {
Poll::Pending
}
}
fn init(&mut self) -> Result<()> {
debug!(
"protocol init, state {:?}, options {:?}",
self.state, self.options
);
match self.state {
State::NotInitialized => {}
_ => return Ok(()),
};
self.state = if self.options.noise {
let mut handshake = Handshake::new(self.options.is_initiator)?;
if let Some(buf) = handshake.start()? {
self.queue_frame_direct(buf.to_vec()).unwrap();
}
self.read_state.set_frame_type(FrameType::Raw);
State::Handshake(Some(handshake))
} else {
self.read_state.set_frame_type(FrameType::Message);
State::Established
};
Ok(())
}
fn poll_commands(&mut self, cx: &mut Context<'_>) -> Result<()> {
while let Poll::Ready(Some(command)) = Pin::new(&mut self.command_rx).poll_next(cx) {
self.on_command(command)?;
}
Ok(())
}
fn poll_keepalive(&mut self, cx: &mut Context<'_>) {
if Pin::new(&mut self.keepalive).poll(cx).is_ready() {
self.write_state.queue_frame(Frame::Raw(vec![0u8; 0]));
self.keepalive.reset(KEEPALIVE_DURATION);
}
}
fn on_outbound_message(&mut self, message: &ChannelMessage) {
if let ChannelMessage {
channel,
message: Message::Close(_),
} = message
{
self.close_local(*channel);
}
}
fn poll_inbound_read(&mut self, cx: &mut Context<'_>) -> Result<()> {
loop {
let msg = self.read_state.poll_reader(cx, &mut self.io);
match msg {
Poll::Ready(Ok(message)) => {
self.on_inbound_frame(message)?;
}
Poll::Ready(Err(e)) => return Err(e),
Poll::Pending => return Ok(()),
}
}
}
fn poll_outbound_write(&mut self, cx: &mut Context<'_>) -> Result<()> {
loop {
if let Poll::Ready(Err(e)) = self.write_state.poll_send(cx, &mut self.io) {
return Err(e);
}
if !self.write_state.can_park_frame() || !matches!(self.state, State::Established) {
return Ok(());
}
match Pin::new(&mut self.outbound_rx).poll_next(cx) {
Poll::Ready(Some(message)) => {
self.on_outbound_message(&message);
let frame = Frame::Message(message);
self.write_state.park_frame(frame);
}
Poll::Ready(None) => unreachable!("Channel closed before end"),
Poll::Pending => return Ok(()),
}
}
}
fn on_inbound_frame(&mut self, frame: Frame) -> Result<()> {
match frame {
Frame::Raw(buf) => match self.state {
State::Handshake(_) => self.on_handshake_message(buf),
_ => unreachable!("May not receive raw frames outside of handshake state"),
},
Frame::Message(channel_message) => match self.state {
State::Established => self.on_inbound_message(channel_message),
_ => unreachable!("May not receive message frames when not established"),
},
}
}
fn on_handshake_message(&mut self, buf: Vec<u8>) -> Result<()> {
let mut handshake = match &mut self.state {
State::Handshake(handshake) => handshake.take().unwrap(),
_ => unreachable!("May not call on_handshake_message when not in Handshake state"),
};
if let Some(response_buf) = handshake.read(&buf)? {
self.queue_frame_direct(response_buf.to_vec()).unwrap();
}
if !handshake.complete() {
self.state = State::Handshake(Some(handshake));
} else {
let result = handshake.into_result()?;
if self.options.encrypted {
self.read_state.upgrade_with_handshake(&result)?;
self.write_state.upgrade_with_handshake(&result)?;
}
self.read_state.set_frame_type(FrameType::Message);
let remote_public_key = parse_key(&result.remote_pubkey)?;
log::debug!(
"handshake complete, remote_key {}",
pretty_hash(&remote_public_key)
);
self.handshake = Some(result);
self.state = State::Established;
self.queue_event(Event::Handshake(remote_public_key));
}
Ok(())
}
fn on_inbound_message(&mut self, channel_message: ChannelMessage) -> Result<()> {
log::debug!("[{}] recv {:?}", self.is_initiator(), channel_message);
let (remote_id, message) = channel_message.into_split();
match remote_id {
0 => match message {
Message::Options(msg) => self.extensions.on_remote_update(msg.extensions),
Message::Extension(msg) => self.extensions.on_message(msg),
_ => {}
},
_ => match message {
Message::Open(msg) => self.on_open(remote_id, msg)?,
Message::Close(msg) => self.on_close(remote_id, msg)?,
_ => self
.channels
.forward_inbound_message(remote_id as usize, message)?,
},
}
Ok(())
}
fn on_command(&mut self, command: Command) -> Result<()> {
match command {
Command::Open(key) => self.command_open(key),
_ => Ok(()),
}
}
fn command_open(&mut self, key: Key) -> Result<()> {
let channel_handle = self.channels.attach_local(key);
let local_id = channel_handle.local_id().unwrap();
let discovery_key = *channel_handle.discovery_key();
if channel_handle.is_connected() {
self.accept_channel(local_id)?;
}
let capability = self.capability(&key);
let message = Message::Open(Open {
discovery_key: discovery_key.to_vec(),
capability,
});
let channel_message = ChannelMessage::new(local_id as u64, message);
self.write_state
.queue_frame(Frame::Message(channel_message));
Ok(())
}
fn on_open(&mut self, ch: u64, msg: Open) -> Result<()> {
let discovery_key: DiscoveryKey = parse_key(&msg.discovery_key)?;
let channel_handle =
self.channels
.attach_remote(discovery_key, ch as usize, msg.capability);
if channel_handle.is_connected() {
let local_id = channel_handle.local_id().unwrap();
self.accept_channel(local_id)?;
} else {
self.queue_event(Event::DiscoveryKey(discovery_key));
}
Ok(())
}
fn queue_event(&mut self, event: Event) {
self.queued_events.push_back(event);
}
fn queue_frame_direct(&mut self, body: Vec<u8>) -> std::result::Result<bool, EncodeError> {
let frame = Frame::Raw(body);
self.write_state.try_queue_direct(&frame)
}
fn accept_channel(&mut self, local_id: usize) -> Result<()> {
let (key, remote_capability) = self.channels.prepare_to_verify(local_id)?;
self.verify_remote_capability(remote_capability.cloned(), key)?;
let channel = self.channels.accept(local_id, self.outbound_tx.clone())?;
self.queue_event(Event::Channel(channel));
Ok(())
}
fn close_local(&mut self, local_id: u64) {
if let Some(channel) = self.channels.get_local(local_id as usize) {
let discovery_key = *channel.discovery_key();
self.channels.remove(&discovery_key);
self.queue_event(Event::Close(discovery_key));
}
}
fn on_close(&mut self, remote_id: u64, msg: Close) -> Result<()> {
if let Some(channel_handle) = self.channels.get_remote(remote_id as usize) {
let discovery_key = *channel_handle.discovery_key();
self.channels
.forward_inbound_message(remote_id as usize, Message::Close(msg))?;
self.channels.remove(&discovery_key);
self.queue_event(Event::Close(discovery_key));
}
Ok(())
}
fn capability(&self, key: &[u8]) -> Option<Vec<u8>> {
match self.handshake.as_ref() {
Some(handshake) => handshake.capability(key),
None => None,
}
}
fn verify_remote_capability(&self, capability: Option<Vec<u8>>, key: &[u8]) -> Result<()> {
match self.handshake.as_ref() {
Some(handshake) => handshake.verify_remote_capability(capability, key),
None => Err(Error::new(
ErrorKind::PermissionDenied,
"Missing handshake state for capability verification",
)),
}
}
}
impl<IO> Stream for Protocol<IO>
where
IO: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
type Item = Result<Event>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Protocol::poll_next(self, cx).map(Some)
}
}
#[derive(Clone, Debug)]
pub struct CommandTx(Sender<Command>);
impl CommandTx {
pub async fn send(&mut self, command: Command) -> Result<()> {
self.0.send(command).await.map_err(map_channel_err)
}
pub async fn open(&mut self, key: Key) -> Result<()> {
self.send(Command::Open(key)).await
}
pub async fn close(&mut self, discovery_key: DiscoveryKey) -> Result<()> {
self.send(Command::Close(discovery_key)).await
}
}
fn parse_key(key: &[u8]) -> io::Result<[u8; 32]> {
key.try_into()
.map_err(|_e| io::Error::new(io::ErrorKind::InvalidInput, "Key must be 32 bytes long"))
}