#![feature(async_await, await_macro)]
extern crate futures;
extern crate shs_core;
use futures::io::{
AsyncRead,
AsyncReadExt,
AsyncWrite,
AsyncWriteExt,
};
use core::mem::size_of;
use ssb_crypto::{NetworkKey, NonceGen, PublicKey, SecretKey};
use shs_core::{*, messages::*};
pub use shs_core::HandshakeError;
/// Runs the client side of the secret handshake over `stream`.
///
/// `net_key` is the network key used to build and verify the hello messages,
/// `pk`/`sk` are this client's long-term keypair, and `server_pk` is the
/// long-term public key the client expects the server to hold.
///
/// If the handshake fails, the stream is closed on a best-effort basis (any
/// error produced while closing is discarded) and the original handshake
/// error is returned. On success the stream is left open for the caller.
///
/// # Errors
///
/// Returns the [`HandshakeError`] produced by whichever handshake step failed.
pub async fn client<S>(mut stream: S,
                       net_key: NetworkKey,
                       pk: PublicKey,
                       sk: SecretKey,
                       server_pk: PublicKey)
                       -> Result<HandshakeOutcome, HandshakeError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let outcome = await!(try_client_side(&mut stream, net_key, pk, sk, server_pk));
    match outcome {
        Ok(keys) => Ok(keys),
        Err(err) => {
            // Best-effort shutdown: a failure to close must not mask the
            // handshake error we are about to report.
            let _ = await!(stream.close());
            Err(err)
        }
    }
}
/// Performs the client half of the handshake on `stream`.
///
/// This does not close `stream` on failure; [`client`] wraps it and does.
///
/// Steps, in order:
/// 1. Generate a fresh ephemeral keypair and send a `ClientHello` built from
///    the ephemeral public key and `net_key`.
/// 2. Read a `ServerHello`, verify it against `net_key`, and obtain the
///    server's ephemeral public key from it.
/// 3. Derive the shared secrets `a` (client ephemeral / server ephemeral),
///    `b` (client ephemeral / server long-term) and `c` (client long-term /
///    server ephemeral).
/// 4. Send a `ClientAuth` built from the long-term keys and secrets `a`, `b`.
/// 5. Read the server's `ServerAccept` and verify it using all three secrets.
///
/// On success, returns the per-direction keys and nonce generators. The write
/// nonce generator is seeded from the server's ephemeral key, and the read
/// nonce generator from the client's own ephemeral key.
///
/// # Errors
///
/// Any of the following is propagated as a [`HandshakeError`]: an I/O error,
/// a message that fails to parse or verify, or a failed secret derivation.
async fn try_client_side<S>(mut stream: S,
                            net_key: NetworkKey,
                            pk: PublicKey,
                            sk: SecretKey,
                            server_pk: PublicKey)
                            -> Result<HandshakeOutcome, HandshakeError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    // Wire length of the server's `ServerAccept` message.
    // NOTE(review): presumably a 64-byte signature plus a 16-byte MAC —
    // confirm against shs_core.
    const SERVER_ACCEPT_LEN: usize = 80;

    let pk = ClientPublicKey(pk);
    let sk = ClientSecretKey(sk);
    let server_pk = ServerPublicKey(server_pk);

    // 1. Send our hello.
    let (eph_pk, eph_sk) = client::generate_eph_keypair();
    let hello = ClientHello::new(&eph_pk, &net_key);
    await!(stream.write_all(hello.as_slice()))?;
    await!(stream.flush())?;

    // 2. Read and verify the server's hello; it carries the server's
    //    ephemeral public key.
    let server_eph_pk = {
        let mut buf = [0u8; size_of::<ServerHello>()];
        await!(stream.read_exact(&mut buf))?;
        let server_hello = ServerHello::from_slice(&buf)?;
        server_hello.verify(&net_key)?
    };

    // 3. Derive the three shared secrets.
    let shared_a = SharedA::client_side(&eph_sk, &server_eph_pk)?;
    let shared_b = SharedB::client_side(&eph_sk, &server_pk)?;
    let shared_c = SharedC::client_side(&sk, &server_eph_pk)?;

    // 4. Authenticate ourselves to the server.
    let client_auth = ClientAuth::new(&sk, &pk, &server_pk, &net_key, &shared_a, &shared_b);
    await!(stream.write_all(client_auth.as_slice()))?;
    await!(stream.flush())?;

    // 5. Read and verify the server's acceptance.
    let mut buf = [0u8; SERVER_ACCEPT_LEN];
    await!(stream.read_exact(&mut buf))?;
    let server_acc = ServerAccept::from_buffer(buf.to_vec())?;
    server_acc.open_and_verify(&sk, &pk, &server_pk,
                               &net_key, &shared_a,
                               &shared_b, &shared_c)?;

    Ok(HandshakeOutcome {
        read_key: server_to_client_key(&pk, &net_key, &shared_a, &shared_b, &shared_c),
        read_noncegen: NonceGen::new(&eph_pk.0, &net_key),
        write_key: client_to_server_key(&server_pk, &net_key, &shared_a, &shared_b, &shared_c),
        write_noncegen: NonceGen::new(&server_eph_pk.0, &net_key),
    })
}
/// Runs the server side of the secret handshake over `stream`.
///
/// `net_key` is the network key used to verify the client's hello and build
/// the server's hello, and `pk`/`sk` are this server's long-term keypair.
///
/// On failure, the stream is closed on a best-effort basis (an error from
/// closing is discarded) and the handshake error is returned. On success the
/// stream is left open for the caller.
///
/// # Errors
///
/// Returns the [`HandshakeError`] produced by whichever handshake step failed.
pub async fn server<S>(mut stream: S,
                       net_key: NetworkKey,
                       pk: PublicKey,
                       sk: SecretKey)
                       -> Result<HandshakeOutcome, HandshakeError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let attempt = await!(try_server_side(&mut stream, net_key, pk, sk));
    match attempt {
        done @ Ok(_) => done,
        Err(failure) => {
            // Closing is best-effort; the handshake error takes precedence.
            let _ = await!(stream.close());
            Err(failure)
        }
    }
}
/// Performs the server half of the handshake on `stream`.
///
/// This does not close `stream` on failure; [`server`] wraps it and does.
///
/// Steps, in order:
/// 1. Generate a fresh ephemeral keypair. Read the client's `ClientHello`,
///    verify it against `net_key`, and obtain the client's ephemeral key.
/// 2. Send a `ServerHello` built from our ephemeral public key and `net_key`.
/// 3. Derive the shared secrets `a` (server ephemeral / client ephemeral) and
///    `b` (server long-term / client ephemeral).
/// 4. Read the client's `ClientAuth`, then open and verify it. This yields the
///    client's signature and long-term public key.
/// 5. Derive secret `c` (server ephemeral / client long-term) and send a
///    `ServerAccept` built from all three secrets and the client's signature.
///
/// On success, returns the per-direction keys and nonce generators. The read
/// nonce generator is seeded from the server's own ephemeral key, and the
/// write nonce generator from the client's ephemeral key.
///
/// NOTE(review): the client's long-term key learned in step 4 is not checked
/// against anything here, and it is not part of the `HandshakeOutcome`
/// constructed below. Callers therefore cannot tell which client connected.
/// Confirm whether that is intended.
///
/// # Errors
///
/// Any of the following is propagated as a [`HandshakeError`]: an I/O error,
/// a message that fails to parse or verify, or a failed secret derivation.
async fn try_server_side<S>(mut stream: S,
                            net_key: NetworkKey,
                            pk: PublicKey,
                            sk: SecretKey)
                            -> Result<HandshakeOutcome, HandshakeError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    // Wire lengths of the two messages the client sends.
    // NOTE(review): presumably ClientHello = 32-byte HMAC + 32-byte ephemeral
    // key, and ClientAuth = 16-byte MAC + 64-byte signature + 32-byte public
    // key — confirm against shs_core.
    const CLIENT_HELLO_LEN: usize = 64;
    const CLIENT_AUTH_LEN: usize = 112;

    let pk = ServerPublicKey(pk);
    let sk = ServerSecretKey(sk);
    let (eph_pk, eph_sk) = server::generate_eph_keypair();

    // 1. Read and verify the client's hello; it carries the client's
    //    ephemeral public key.
    let client_eph_pk = {
        let mut buf = [0u8; CLIENT_HELLO_LEN];
        await!(stream.read_exact(&mut buf))?;
        let client_hello = ClientHello::from_slice(&buf)?;
        client_hello.verify(&net_key)?
    };

    // 2. Reply with our hello.
    let hello = ServerHello::new(&eph_pk, &net_key);
    await!(stream.write_all(hello.as_slice()))?;
    await!(stream.flush())?;

    // 3. Derive the secrets available before we know the client's identity.
    let shared_a = SharedA::server_side(&eph_sk, &client_eph_pk)?;
    let shared_b = SharedB::server_side(&sk, &client_eph_pk)?;

    // 4. Read and verify the client's authentication message.
    let (client_sig, client_pk) = {
        let mut buf = [0u8; CLIENT_AUTH_LEN];
        await!(stream.read_exact(&mut buf))?;
        let client_auth = ClientAuth::from_buffer(buf.to_vec())?;
        client_auth.open_and_verify(&pk, &net_key, &shared_a, &shared_b)?
    };

    // 5. Derive the final secret and accept the client.
    let shared_c = SharedC::server_side(&eph_sk, &client_pk)?;
    let server_acc = ServerAccept::new(&sk, &client_pk, &net_key, &client_sig,
                                       &shared_a, &shared_b, &shared_c);
    await!(stream.write_all(server_acc.as_slice()))?;
    await!(stream.flush())?;

    Ok(HandshakeOutcome {
        read_key: client_to_server_key(&pk, &net_key, &shared_a, &shared_b, &shared_c),
        read_noncegen: NonceGen::new(&eph_pk.0, &net_key),
        write_key: server_to_client_key(&client_pk, &net_key, &shared_a, &shared_b, &shared_c),
        write_noncegen: NonceGen::new(&client_eph_pk.0, &net_key),
    })
}
#[cfg(test)]
mod tests {
// End-to-end handshake tests: a client and a server run against each other
// over a pair of cross-wired in-memory ring buffers, driven together with
// `join!` inside `block_on`.
use super::*;
use core::task::Context;
use core::pin::Pin;
use std::io::{self, ErrorKind};
use futures::{join, Poll};
use futures::executor::block_on;
extern crate async_ringbuffer;
extern crate pin_utils;
use pin_utils::unsafe_pinned;
use ssb_crypto::{generate_longterm_keypair, NetworkKey, PublicKey};
/// Glues a separate reader `r` and writer `w` into one bidirectional stream.
/// Reads are forwarded to `r`; writes, flushes and closes are forwarded to `w`.
struct Duplex<R, W> {
r: R,
w: W,
}
impl<R, W> Duplex<R, W> {
// Generate `r()` / `w()` accessors that project `Pin<&mut Self>` to a pinned
// field; the trait impls below use them to forward polls.
// NOTE(review): sound only while `Duplex` never moves these fields out of a
// pinned `self` (there is no `Drop` impl in this module) — keep it that way.
unsafe_pinned!(r: R);
unsafe_pinned!(w: W);
}
// Reading from the duplex reads from the read half only.
impl<R, W> AsyncRead for Duplex<R, W>
where
R: AsyncRead + Unpin,
{
fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll<Result<usize, io::Error>> {
self.r().poll_read(cx, buf)
}
}
// Writing, flushing and closing act on the write half only; closing the
// duplex does not touch the read half.
impl<R, W> AsyncWrite for Duplex<R, W>
where
W: AsyncWrite + Unpin,
{
fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
self.w().poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
self.w().poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
self.w().poll_close(cx)
}
}
/// One end of an in-memory connection built from two ring buffers.
type DuplexRingbufStream = Duplex<async_ringbuffer::Reader, async_ringbuffer::Writer>;
/// Builds a connected `(client, server)` stream pair.
///
/// Bytes the client writes go into the `c2s` buffer, which the server reads.
/// Bytes the server writes go into the `s2c` buffer, which the client reads.
/// Each buffer is created with `ring_buffer(1024)` (presumably a capacity in
/// bytes — ample for the handshake messages).
fn make_streams() -> (DuplexRingbufStream, DuplexRingbufStream) {
let (c2s_w, c2s_r) = async_ringbuffer::ring_buffer(1024);
let (s2c_w, s2c_r) = async_ringbuffer::ring_buffer(1024);
(Duplex { r: s2c_r, w: c2s_w }, Duplex { r: c2s_r, w: s2c_w })
}
/// Happy path: both sides succeed. Each side's write key and nonce generator
/// must match the other side's read key and nonce generator.
#[test]
fn basic() {
let (mut c_stream, mut s_stream) = make_streams();
let (s_pk, s_sk) = generate_longterm_keypair();
let (c_pk, c_sk) = generate_longterm_keypair();
let net_key = NetworkKey::SSB_MAIN_NET;
let client_side = client(&mut c_stream, net_key.clone(), c_pk, c_sk, s_pk.clone());
let server_side = server(&mut s_stream, net_key.clone(), s_pk, s_sk);
// Both futures must make progress together, since each waits on the other's
// messages.
let (c_out, s_out) = block_on(async {
join!(client_side, server_side)
});
let mut c_out = c_out.unwrap();
let mut s_out = s_out.unwrap();
assert_eq!(c_out.write_key, s_out.read_key);
assert_eq!(c_out.read_key, s_out.write_key);
// Compare the first nonce produced on each direction.
assert_eq!(c_out.write_noncegen.next(),
s_out.read_noncegen.next());
assert_eq!(c_out.read_noncegen.next(),
s_out.write_noncegen.next());
}
/// Returns true iff `r` is a `HandshakeError::Io` whose kind is
/// `ErrorKind::UnexpectedEof`.
fn is_eof_err<T>(r: &Result<T, HandshakeError>) -> bool {
match r {
Err(HandshakeError::Io(e)) => e.kind() == ErrorKind::UnexpectedEof,
_ => false,
}
}
/// Client and server use two independently generated random network keys
/// (presumably distinct).
///
/// The server must reject the `ClientHello` with `ClientHelloVerifyFailed`.
/// The client is expected to hit an unexpected EOF while waiting for the
/// `ServerHello`, presumably because `server` closes its stream on failure.
#[test]
fn server_rejects_wrong_netkey() {
let (mut c_stream, mut s_stream) = make_streams();
let (s_pk, s_sk) = generate_longterm_keypair();
let (c_pk, c_sk) = generate_longterm_keypair();
let client_side = client(&mut c_stream, NetworkKey::random(), c_pk, c_sk, s_pk.clone());
let server_side = server(&mut s_stream, NetworkKey::random(), s_pk, s_sk);
let (c_out, s_out) = block_on(async {
join!(client_side, server_side)
});
assert!(is_eof_err(&c_out));
match s_out {
Err(HandshakeError::ClientHelloVerifyFailed) => {},
_ => panic!(),
};
}
/// The client is told the wrong server key. Two cases are tried: an all-zero
/// 32-byte key and an unrelated, freshly generated key.
#[test]
fn reject_wrong_server_pk() {
test_handshake_with_bad_server_pk(
PublicKey::from_slice(&[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]).unwrap());
let (pk, _sk) = generate_longterm_keypair();
test_handshake_with_bad_server_pk(pk);
}
/// Runs a handshake in which the client expects `bad_pk` as the server's
/// long-term key. Asserts only that both sides fail, not which error each
/// reports.
fn test_handshake_with_bad_server_pk(bad_pk: PublicKey) {
let (mut c_stream, mut s_stream) = make_streams();
let (s_pk, s_sk) = generate_longterm_keypair();
let (c_pk, c_sk) = generate_longterm_keypair();
let net_key = NetworkKey::SSB_MAIN_NET;
let client_side = client(&mut c_stream, net_key.clone(), c_pk, c_sk, bad_pk);
let server_side = server(&mut s_stream, net_key.clone(), s_pk, s_sk);
let (c_out, s_out) = block_on(async {
join!(client_side, server_side)
});
assert!(c_out.is_err());
assert!(s_out.is_err());
}
}