pub use self::error::SecioError;
use futures::stream::MapErr as StreamMapErr;
use futures::prelude::*;
use tet_libp2p_core::{PeerId, PublicKey, identity, upgrade::{UpgradeInfo, InboundUpgrade, OutboundUpgrade}};
use log::debug;
use rw_stream_sink::RwStreamSink;
use std::{io, iter, pin::Pin, task::Context, task::Poll};
mod algo_support;
mod codec;
mod error;
mod exchange;
mod handshake;
/// Message types generated at build time into `$OUT_DIR/spipe.pb.rs`
/// (presumably emitted by a protobuf code generator in the build script — TODO confirm).
mod structs_proto {
include!(concat!(env!("OUT_DIR"), "/spipe.pb.rs"));
}
mod stream_cipher;
pub use crate::algo_support::Digest;
pub use crate::exchange::KeyAgreement;
pub use crate::stream_cipher::Cipher;
/// Configuration for the secio connection upgrade.
///
/// Created with [`SecioConfig::new`] and adjusted through its builder methods;
/// consumed by the inbound/outbound upgrade, which hands it to the handshake.
#[derive(Clone)]
pub struct SecioConfig {
/// Local identity keypair passed along to the handshake
/// (presumably used to authenticate this side — confirm in `handshake`).
pub(crate) key: identity::Keypair,
/// Key-agreement proposition string from `algo_support::key_agreements_proposition`;
/// `None` unless `key_agreements` was called.
pub(crate) agreements_prop: Option<String>,
/// Cipher proposition string from `algo_support::ciphers_proposition`;
/// `None` unless `ciphers` was called.
pub(crate) ciphers_prop: Option<String>,
/// Digest proposition string from `algo_support::digests_proposition`;
/// `None` unless `digests` was called.
pub(crate) digests_prop: Option<String>,
/// Maximum frame length (defaults to `8 * 1024 * 1024`; presumably bytes — confirm in `codec`).
pub(crate) max_frame_len: usize
}
impl SecioConfig {
pub fn new(kp: identity::Keypair) -> Self {
SecioConfig {
key: kp,
agreements_prop: None,
ciphers_prop: None,
digests_prop: None,
max_frame_len: 8 * 1024 * 1024
}
}
pub fn key_agreements<'a, I>(mut self, xs: I) -> Self
where
I: IntoIterator<Item=&'a KeyAgreement>
{
self.agreements_prop = Some(algo_support::key_agreements_proposition(xs));
self
}
pub fn ciphers<'a, I>(mut self, xs: I) -> Self
where
I: IntoIterator<Item=&'a Cipher>
{
self.ciphers_prop = Some(algo_support::ciphers_proposition(xs));
self
}
pub fn digests<'a, I>(mut self, xs: I) -> Self
where
I: IntoIterator<Item=&'a Digest>
{
self.digests_prop = Some(algo_support::digests_proposition(xs));
self
}
pub fn max_frame_len(mut self, n: usize) -> Self {
self.max_frame_len = n;
self
}
fn handshake<T>(self, socket: T) -> impl Future<Output = Result<(PeerId, SecioOutput<T>), SecioError>>
where
T: AsyncRead + AsyncWrite + Unpin + Send + 'static
{
debug!("Starting secio upgrade");
SecioMiddleware::handshake(socket, self)
.map_ok(|(stream_sink, pubkey, ephemeral)| {
let mapped = stream_sink.map_err(map_err as fn(_) -> _);
let peer = pubkey.clone().into_peer_id();
let io = SecioOutput {
stream: RwStreamSink::new(mapped),
remote_key: pubkey,
ephemeral_public_key: ephemeral
};
(peer, io)
})
}
}
/// Result of a successful secio upgrade: the secured byte stream plus details
/// about the remote side obtained during the handshake.
///
/// Implements `AsyncRead`/`AsyncWrite` by forwarding to `stream`.
pub struct SecioOutput<S>
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static
{
/// Message stream adapted to a byte stream; stream errors are converted to
/// `io::Error` through the `map_err` function pointer.
pub stream: RwStreamSink<StreamMapErr<SecioMiddleware<S>, fn(SecioError) -> io::Error>>,
/// Remote public key returned by the handshake; the upgrade's `PeerId` is derived from it.
pub remote_key: PublicKey,
/// Ephemeral public key bytes returned by the handshake
/// (NOTE(review): whose ephemeral key this is is not visible here — confirm in `handshake`).
pub ephemeral_public_key: Vec<u8>,
}
impl UpgradeInfo for SecioConfig {
    type Info = &'static [u8];
    type InfoIter = iter::Once<Self::Info>;

    /// Advertises the single protocol name supported by this upgrade.
    fn protocol_info(&self) -> Self::InfoIter {
        const PROTOCOL_NAME: &[u8] = b"/secio/1.0.0";
        iter::once(PROTOCOL_NAME)
    }
}
impl<T> InboundUpgrade<T> for SecioConfig
where
    T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    type Output = (PeerId, SecioOutput<T>);
    type Error = SecioError;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>>;

    /// Upgrades an inbound socket by running the same handshake as the outbound side.
    fn upgrade_inbound(self, socket: T, _info: Self::Info) -> Self::Future {
        let negotiation = self.handshake(socket);
        Box::pin(negotiation)
    }
}
impl<T> OutboundUpgrade<T> for SecioConfig
where
T: AsyncRead + AsyncWrite + Unpin + Send + 'static
{
type Output = (PeerId, SecioOutput<T>);
type Error = SecioError;
type Future = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>>;
fn upgrade_outbound(self, socket: T, _: Self::Info) -> Self::Future {
Box::pin(self.handshake(socket))
}
}
impl<S> AsyncRead for SecioOutput<S>
where
    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    /// Reads bytes by forwarding to the wrapped `stream`.
    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<Result<usize, io::Error>> {
        let stream = &mut self.get_mut().stream;
        AsyncRead::poll_read(Pin::new(stream), cx, buf)
    }
}
impl<S> AsyncWrite for SecioOutput<S>
where
    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    /// Writes bytes by forwarding to the wrapped `stream`.
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
        let stream = &mut self.get_mut().stream;
        AsyncWrite::poll_write(Pin::new(stream), cx, buf)
    }

    /// Flushes the wrapped `stream`.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        let stream = &mut self.get_mut().stream;
        AsyncWrite::poll_flush(Pin::new(stream), cx)
    }

    /// Closes the wrapped `stream`.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        let stream = &mut self.get_mut().stream;
        AsyncWrite::poll_close(Pin::new(stream), cx)
    }
}
fn map_err(err: SecioError) -> io::Error {
debug!("error during secio handshake {:?}", err);
io::Error::new(io::ErrorKind::InvalidData, err)
}
/// Message-level stream/sink over a socket `S`, produced by
/// [`SecioMiddleware::handshake`].
///
/// Sends `Vec<u8>` messages (`Sink`) and yields `Result<Vec<u8>, SecioError>`
/// (`Stream`) by forwarding every call to the inner codec.
pub struct SecioMiddleware<S> {
/// Codec returned by the handshake; all `Sink`/`Stream` calls delegate to it.
inner: codec::FullCodec<S>,
}
impl<S> SecioMiddleware<S>
where
    S: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
    /// Runs the secio handshake over `socket` using `config`.
    ///
    /// On success it yields the middleware wrapping the negotiated codec, the
    /// remote public key and the ephemeral public key bytes, all as returned by
    /// `handshake::handshake`.
    pub fn handshake(socket: S, config: SecioConfig)
        -> impl Future<Output = Result<(SecioMiddleware<S>, PublicKey, Vec<u8>), SecioError>>
    {
        let negotiation = handshake::handshake(socket, config);
        negotiation.map_ok(|(codec, remote_key, ephemeral_key)| (Self { inner: codec }, remote_key, ephemeral_key))
    }
}
impl<S> Sink<Vec<u8>> for SecioMiddleware<S>
where
    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    type Error = io::Error;

    /// Delegates readiness to the inner codec.
    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let codec = &mut self.get_mut().inner;
        Sink::<Vec<u8>>::poll_ready(Pin::new(codec), cx)
    }

    /// Hands one message to the inner codec.
    fn start_send(self: Pin<&mut Self>, item: Vec<u8>) -> Result<(), Self::Error> {
        let codec = &mut self.get_mut().inner;
        Sink::<Vec<u8>>::start_send(Pin::new(codec), item)
    }

    /// Flushes the inner codec.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let codec = &mut self.get_mut().inner;
        Sink::<Vec<u8>>::poll_flush(Pin::new(codec), cx)
    }

    /// Closes the inner codec.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let codec = &mut self.get_mut().inner;
        Sink::<Vec<u8>>::poll_close(Pin::new(codec), cx)
    }
}
impl<S> Stream for SecioMiddleware<S>
where
    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    type Item = Result<Vec<u8>, SecioError>;

    /// Yields the next message (or error) produced by the inner codec.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.get_mut().inner.poll_next_unpin(cx)
    }
}