use crate::{constants::SALT_LEN, utils::Session};
use codec::DistantCodec;
use derive_more::{Display, Error, From};
use futures::SinkExt;
use k256::{ecdh::EphemeralSecret, EncodedPoint, PublicKey};
use orion::{
aead::{self, SecretKey},
auth::{self, Tag},
errors::UnknownCryptoError,
kdf::{self, Salt},
pwhash::Password,
};
use serde::{de::DeserializeOwned, Serialize};
use std::{net::SocketAddr, sync::Arc};
use tokio::{
io,
net::{tcp, TcpStream},
};
use tokio_stream::StreamExt;
use tokio_util::codec::{Framed, FramedRead, FramedWrite};
mod codec;
/// Errors that can occur while sending or receiving messages over a transport half.
#[derive(Debug, Display, Error, From)]
pub enum TransportError {
    // Both crypto variants wrap `UnknownCryptoError`, so `From` is not derived for
    // them; call sites choose the variant explicitly via `map_err`.
    /// An authentication tag could not be parsed, computed, or verified.
    #[from(ignore)]
    AuthError(UnknownCryptoError),
    /// Encrypting or decrypting a message failed.
    #[from(ignore)]
    EncryptError(UnknownCryptoError),
    /// An underlying I/O failure, or a malformed frame reported as an I/O error.
    IoError(io::Error),
    /// A message could not be serialized to, or deserialized from, CBOR.
    SerializeError(serde_cbor::Error),
}
/// A TCP connection whose messages are encrypted and authenticated.
///
/// Created by [`Transport::from_handshake`] or [`Transport::connect`]; call
/// [`Transport::into_split`] to obtain independent read and write halves.
pub struct Transport {
    /// TCP connection framed by `DistantCodec`.
    conn: Framed<TcpStream, DistantCodec>,
    /// Caller-provided key used to authenticate messages.
    auth_key: Arc<SecretKey>,
    /// Key derived during the handshake, used to encrypt and decrypt messages.
    crypt_key: Arc<SecretKey>,
}
impl Transport {
    /// Performs a key-exchange handshake over `stream` and wraps it in a [`Transport`].
    ///
    /// Both peers run the same procedure:
    ///
    /// 1. Generate an ephemeral k256 secret and a random salt of `SALT_LEN` bytes.
    /// 2. Send one frame containing the salt followed by the SEC1-encoded public key.
    /// 3. Read the peer's frame and split it back into salt and public key.
    /// 4. Compute the ECDH shared secret and stretch it with `kdf::derive_key`. The salt
    ///    is the byte-wise XOR of both salts, so both sides use the same salt regardless
    ///    of who sent first.
    ///
    /// The derived key becomes the encryption key. `auth_key` is stored unchanged and
    /// is used to authenticate subsequent messages.
    ///
    /// # Errors
    ///
    /// * `UnexpectedEof` if the stream ends before the peer's frame arrives.
    /// * `InvalidData` if the peer's frame is shorter than `SALT_LEN` bytes, or its salt,
    ///   public key, shared secret, or mixed salt cannot be converted.
    /// * `Other` if salt generation, sending our frame, or key derivation fails.
    /// * Any I/O error reported by the codec while reading the peer's frame.
    pub async fn from_handshake(stream: TcpStream, auth_key: Arc<SecretKey>) -> io::Result<Self> {
        let addr_str = stream
            .peer_addr()
            .map(|x| x.to_string())
            .unwrap_or_else(|_| String::from("???"));
        log::trace!("Beginning handshake @ {}", addr_str);

        let mut conn = Framed::new(stream, DistantCodec);

        // Ephemeral keypair and salt, generated fresh for this connection.
        let private_key = EphemeralSecret::random(&mut rand::rngs::OsRng);
        let public_key = EncodedPoint::from(private_key.public_key());
        let salt = Salt::generate(SALT_LEN).map_err(|x| io::Error::new(io::ErrorKind::Other, x))?;

        // Send `salt || public_key` as a single frame.
        let mut data = Vec::new();
        data.extend_from_slice(salt.as_ref());
        data.extend_from_slice(public_key.as_bytes());
        conn.send(&data)
            .await
            .map_err(|x| io::Error::new(io::ErrorKind::Other, x))?;

        // Receive the peer's `salt || public_key` frame.
        let data = conn.next().await.ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "Stream ended before handshake completed",
            )
        })??;

        // The peer's frame is untrusted, and `split_at` panics when the index exceeds
        // the length. Reject short frames explicitly instead of crashing the task.
        if data.len() < SALT_LEN {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "Handshake data is too short to contain a salt",
            ));
        }
        let (salt_bytes, other_public_key_bytes) = data.split_at(SALT_LEN);
        let other_salt = Salt::from_slice(salt_bytes)
            .map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))?;
        let other_public_key = PublicKey::from_sec1_bytes(other_public_key_bytes)
            .map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))?;

        let shared_secret = private_key.diffie_hellman(&other_public_key);
        let password = Password::from_slice(shared_secret.as_bytes())
            .map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))?;

        // XOR is commutative, so both peers arrive at the same mixed salt.
        let mixed_salt = Salt::from_slice(
            &salt
                .as_ref()
                .iter()
                .zip(other_salt.as_ref().iter())
                .map(|(x, y)| x ^ y)
                .collect::<Vec<u8>>(),
        )
        .map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))?;

        let derived_key = kdf::derive_key(&password, &mixed_salt, 3, 1 << 16, 32)
            .map_err(|x| io::Error::new(io::ErrorKind::Other, x))?;
        let crypt_key = Arc::new(derived_key);
        log::trace!("Completed handshake @ {}", addr_str);

        Ok(Self {
            conn,
            auth_key,
            crypt_key,
        })
    }

    /// Connects to the address described by `session` and performs the handshake,
    /// using the session's `auth_key` to authenticate messages.
    ///
    /// # Errors
    ///
    /// Returns an error if the address cannot be resolved, the TCP connection fails,
    /// or the handshake fails (see [`Transport::from_handshake`]).
    pub async fn connect(session: Session) -> io::Result<Self> {
        let stream = TcpStream::connect(session.to_socket_addr().await?).await?;
        Self::from_handshake(stream, Arc::new(session.auth_key)).await
    }

    /// Returns the remote address of the underlying TCP stream.
    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
        self.conn.get_ref().peer_addr()
    }

    /// Splits the transport into independent read and write halves.
    ///
    /// Both halves share the same keys through `Arc`. Any bytes already buffered by
    /// the framed connection are moved into the corresponding half, so no data is lost.
    pub fn into_split(self) -> (TransportReadHalf, TransportWriteHalf) {
        let auth_key = self.auth_key;
        let crypt_key = self.crypt_key;
        let parts = self.conn.into_parts();
        let (read_half, write_half) = parts.io.into_split();

        // Carry over buffered-but-unprocessed bytes from the combined framer.
        let mut f_read = FramedRead::new(read_half, parts.codec);
        *f_read.read_buffer_mut() = parts.read_buf;
        let mut f_write = FramedWrite::new(write_half, parts.codec);
        *f_write.write_buffer_mut() = parts.write_buf;

        let t_read = TransportReadHalf {
            conn: f_read,
            auth_key: Arc::clone(&auth_key),
            crypt_key: Arc::clone(&crypt_key),
        };
        let t_write = TransportWriteHalf {
            conn: f_write,
            auth_key,
            crypt_key,
        };
        (t_read, t_write)
    }
}
/// Write half of a [`Transport`], produced by [`Transport::into_split`].
///
/// Holds its own handles to the keys shared with the matching read half.
pub struct TransportWriteHalf {
    /// Framed writer over the write half of the TCP stream.
    conn: FramedWrite<tcp::OwnedWriteHalf, DistantCodec>,
    /// Key used to compute the authentication tag of each outgoing frame.
    auth_key: Arc<SecretKey>,
    /// Key used to encrypt each outgoing message.
    crypt_key: Arc<SecretKey>,
}
impl TransportWriteHalf {
    /// Serializes `data` to CBOR, encrypts it, tags the ciphertext, and sends one frame.
    ///
    /// The frame layout is `[tag_len: u8][tag: tag_len bytes][ciphertext]`. The tag is
    /// computed over the ciphertext (not the plaintext) with the authentication key.
    ///
    /// # Errors
    ///
    /// * [`TransportError::SerializeError`] if `data` cannot be encoded as CBOR.
    /// * [`TransportError::EncryptError`] if encryption fails.
    /// * [`TransportError::AuthError`] if computing the tag fails.
    /// * [`TransportError::IoError`] (via the codec) if writing the frame fails.
    pub async fn send<T: Serialize>(&mut self, data: T) -> Result<(), TransportError> {
        let plaintext = serde_cbor::to_vec(&data)?;
        let ciphertext =
            aead::seal(&self.crypt_key, &plaintext).map_err(TransportError::EncryptError)?;
        let tag = auth::authenticate(&self.auth_key, &ciphertext)
            .map_err(TransportError::AuthError)?;
        let tag_bytes = tag.unprotected_as_bytes();

        // Length prefix, then the tag, then the ciphertext it covers.
        let frame: Vec<u8> = std::iter::once(tag_bytes.len() as u8)
            .chain(tag_bytes.iter().copied())
            .chain(ciphertext)
            .collect();

        self.conn.send(&frame).await.map_err(TransportError::from)
    }
}
/// Read half of a [`Transport`], produced by [`Transport::into_split`].
///
/// Holds its own handles to the keys shared with the matching write half.
pub struct TransportReadHalf {
    /// Framed reader over the read half of the TCP stream.
    conn: FramedRead<tcp::OwnedReadHalf, DistantCodec>,
    /// Key used to verify the authentication tag of each incoming frame.
    auth_key: Arc<SecretKey>,
    /// Key used to decrypt each incoming message.
    crypt_key: Arc<SecretKey>,
}
impl TransportReadHalf {
pub async fn receive<T: DeserializeOwned>(&mut self) -> Result<Option<T>, TransportError> {
if let Some(data) = self.conn.next().await {
let mut data = data?;
if data.is_empty() {
return Err(TransportError::from(io::Error::new(
io::ErrorKind::InvalidData,
"Received data is empty",
)));
}
let tag_len = data[0];
let tag =
Tag::from_slice(&data[1..=tag_len as usize]).map_err(TransportError::AuthError)?;
let data = data.split_off(tag_len as usize + 1);
auth::authenticate_verify(&tag, &self.auth_key, &data)
.map_err(TransportError::AuthError)?;
let data = aead::open(&self.crypt_key, &data).map_err(TransportError::EncryptError)?;
let data = serde_cbor::from_slice(&data)?;
Ok(Some(data))
} else {
Ok(None)
}
}
}