use std::io::{self, Read, Write};
use anyhow::{Context, Result, anyhow, bail};
use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use sha2::{Digest, Sha256};
use snow::{TransportState, params::NoiseParams};
const NOISE_PROTOCOL_NAME: &str = "Noise_NNpsk0_25519_ChaChaPoly_SHA256";
const MAX_FRAME_LEN: usize = 16 * 1024 * 1024;
const AEAD_TAG_LEN: usize = 16;
const HANDSHAKE_MESSAGE_MAX: usize = 1024;
fn derive_psk(token: &[u8]) -> [u8; 32] {
let mut hasher = Sha256::new();
hasher.update(token);
let digest = hasher.finalize();
let mut psk = [0u8; 32];
psk.copy_from_slice(&digest);
psk
}
pub async fn perform_client_handshake<R, W>(
reader: &mut R,
writer: &mut W,
token: &[u8],
) -> Result<TransportState>
where
R: AsyncRead + Unpin,
W: AsyncWrite + Unpin,
{
let params: NoiseParams = NOISE_PROTOCOL_NAME
.parse()
.context("failed to parse Noise protocol definition")?;
let psk = derive_psk(token);
let builder = snow::Builder::new(params).psk(0, &psk);
let mut state = builder
.build_initiator()
.context("failed to build Noise initiator")?;
let mut buffer = vec![0u8; HANDSHAKE_MESSAGE_MAX];
let len = state
.write_message(&[], &mut buffer)
.context("failed to write Noise handshake message")?;
send_frame(writer, &buffer[..len]).await?;
writer
.flush()
.await
.context("failed to flush Noise handshake message")?;
let message = read_frame(reader).await?;
let frame = message.ok_or_else(|| anyhow!("peer closed connection during Noise handshake"))?;
state
.read_message(&frame, &mut [])
.context("failed to read Noise handshake response")?;
state
.into_transport_mode()
.context("failed to construct Noise transport state")
}
pub async fn perform_server_handshake<R, W>(
reader: &mut R,
writer: &mut W,
token: &[u8],
) -> Result<TransportState>
where
R: AsyncRead + Unpin,
W: AsyncWrite + Unpin,
{
let params: NoiseParams = NOISE_PROTOCOL_NAME
.parse()
.context("failed to parse Noise protocol definition")?;
let psk = derive_psk(token);
let builder = snow::Builder::new(params).psk(0, &psk);
let mut state = builder
.build_responder()
.context("failed to build Noise responder")?;
let message = read_frame(reader).await?;
let frame = message.ok_or_else(|| anyhow!("peer closed connection during Noise handshake"))?;
state
.read_message(&frame, &mut [])
.context("failed to process Noise handshake message")?;
let mut buffer = vec![0u8; HANDSHAKE_MESSAGE_MAX];
let len = state
.write_message(&[], &mut buffer)
.context("failed to write Noise handshake response")?;
send_frame(writer, &buffer[..len]).await?;
writer
.flush()
.await
.context("failed to flush Noise handshake response")?;
state
.into_transport_mode()
.context("failed to construct Noise transport state")
}
pub async fn write_encrypted_frame<W>(
writer: &mut W,
state: &mut TransportState,
plaintext: &[u8],
) -> Result<()>
where
W: AsyncWrite + Unpin,
{
if plaintext.len() > MAX_FRAME_LEN {
bail!(
"plaintext message exceeds maximum Noise frame length ({} bytes)",
MAX_FRAME_LEN
);
}
let mut buffer = vec![0u8; plaintext.len() + AEAD_TAG_LEN];
let len = state
.write_message(plaintext, &mut buffer)
.context("failed to encrypt Noise frame")?;
send_frame(writer, &buffer[..len]).await?;
writer
.flush()
.await
.context("failed to flush encrypted Noise frame")?;
Ok(())
}
pub async fn read_encrypted_frame<R>(
reader: &mut R,
state: &mut TransportState,
) -> Result<Option<Vec<u8>>>
where
R: AsyncRead + Unpin,
{
let frame = match read_frame(reader).await? {
Some(frame) => frame,
None => return Ok(None),
};
if frame.len() > MAX_FRAME_LEN + AEAD_TAG_LEN {
bail!(
"encrypted Noise frame exceeds maximum length ({} bytes)",
MAX_FRAME_LEN + AEAD_TAG_LEN
);
}
let mut buffer = vec![0u8; frame.len()];
let len = state
.read_message(&frame, &mut buffer)
.context("failed to decrypt Noise frame")?;
buffer.truncate(len);
Ok(Some(buffer))
}
async fn send_frame<W>(writer: &mut W, payload: &[u8]) -> Result<()>
where
W: AsyncWrite + Unpin,
{
let len = payload.len();
if len > u32::MAX as usize {
bail!("frame payload exceeds u32 length");
}
let mut header = [0u8; 4];
header.copy_from_slice(&(len as u32).to_be_bytes());
writer
.write_all(&header)
.await
.context("failed to write Noise frame header")?;
writer
.write_all(payload)
.await
.context("failed to write Noise frame payload")?;
Ok(())
}
async fn read_frame<R>(reader: &mut R) -> Result<Option<Vec<u8>>>
where
R: AsyncRead + Unpin,
{
let mut header = [0u8; 4];
match reader.read_exact(&mut header).await {
Ok(()) => {}
Err(err) if err.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
Err(err) => {
return Err(anyhow!("failed to read Noise frame header: {}", err));
}
}
let len = u32::from_be_bytes(header) as usize;
if len > MAX_FRAME_LEN + AEAD_TAG_LEN {
bail!(
"frame length {} exceeds allowed maximum {}",
len,
MAX_FRAME_LEN + AEAD_TAG_LEN
);
}
let mut payload = vec![0u8; len];
if len > 0 {
if let Err(err) = reader.read_exact(&mut payload).await {
if err.kind() == io::ErrorKind::UnexpectedEof {
return Ok(None);
}
return Err(anyhow!("failed to read Noise frame payload: {}", err));
}
}
Ok(Some(payload))
}
pub fn perform_client_handshake_blocking<S>(stream: &mut S, token: &[u8]) -> Result<TransportState>
where
S: Read + Write,
{
let params: NoiseParams = NOISE_PROTOCOL_NAME
.parse()
.context("failed to parse Noise protocol definition")?;
let psk = derive_psk(token);
let builder = snow::Builder::new(params).psk(0, &psk);
let mut state = builder
.build_initiator()
.context("failed to build Noise initiator")?;
let mut buffer = vec![0u8; HANDSHAKE_MESSAGE_MAX];
let len = state
.write_message(&[], &mut buffer)
.context("failed to write Noise handshake message")?;
send_frame_blocking(stream, &buffer[..len])?;
stream.flush()?;
let frame = read_frame_blocking(stream)?
.ok_or_else(|| anyhow!("peer closed connection during Noise handshake"))?;
state
.read_message(&frame, &mut [])
.context("failed to read Noise handshake response")?;
state
.into_transport_mode()
.context("failed to construct Noise transport state")
}
pub fn write_encrypted_frame_blocking<W: Write>(
writer: &mut W,
state: &mut TransportState,
plaintext: &[u8],
) -> Result<()> {
if plaintext.len() > MAX_FRAME_LEN {
bail!(
"plaintext message exceeds maximum Noise frame length ({} bytes)",
MAX_FRAME_LEN
);
}
let mut buffer = vec![0u8; plaintext.len() + AEAD_TAG_LEN];
let len = state
.write_message(plaintext, &mut buffer)
.context("failed to encrypt Noise frame")?;
send_frame_blocking(writer, &buffer[..len])?;
writer.flush()?;
Ok(())
}
pub fn read_encrypted_frame_blocking<R: Read>(
reader: &mut R,
state: &mut TransportState,
) -> Result<Option<Vec<u8>>> {
let frame = match read_frame_blocking(reader)? {
Some(frame) => frame,
None => return Ok(None),
};
if frame.len() > MAX_FRAME_LEN + AEAD_TAG_LEN {
bail!(
"encrypted Noise frame exceeds maximum length ({} bytes)",
MAX_FRAME_LEN + AEAD_TAG_LEN
);
}
let mut buffer = vec![0u8; frame.len()];
let len = state
.read_message(&frame, &mut buffer)
.context("failed to decrypt Noise frame")?;
buffer.truncate(len);
Ok(Some(buffer))
}
fn send_frame_blocking<W: Write>(writer: &mut W, payload: &[u8]) -> Result<()> {
let len = payload.len();
if len > u32::MAX as usize {
bail!("frame payload exceeds u32 length");
}
let mut header = [0u8; 4];
header.copy_from_slice(&(len as u32).to_be_bytes());
writer
.write_all(&header)
.context("failed to write Noise frame header")?;
writer
.write_all(payload)
.context("failed to write Noise frame payload")?;
Ok(())
}
fn read_frame_blocking<R: Read>(reader: &mut R) -> Result<Option<Vec<u8>>> {
let mut header = [0u8; 4];
match reader.read_exact(&mut header) {
Ok(()) => {}
Err(err) if err.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
Err(err) => {
return Err(anyhow!("failed to read Noise frame header: {}", err));
}
}
let len = u32::from_be_bytes(header) as usize;
if len > MAX_FRAME_LEN + AEAD_TAG_LEN {
bail!(
"frame length {} exceeds allowed maximum {}",
len,
MAX_FRAME_LEN + AEAD_TAG_LEN
);
}
let mut payload = vec![0u8; len];
if len > 0 {
if let Err(err) = reader.read_exact(&mut payload) {
if err.kind() == io::ErrorKind::UnexpectedEof {
return Ok(None);
}
return Err(anyhow!("failed to read Noise frame payload: {}", err));
}
}
Ok(Some(payload))
}