#[cfg(feature = "noise_sv2")]
use binary_sv2::Deserialize;
#[cfg(feature = "noise_sv2")]
use binary_sv2::GetSize;
use binary_sv2::Serialize;
pub use buffer_sv2::AeadBuffer;
use core::marker::PhantomData;
#[cfg(feature = "noise_sv2")]
use framing_sv2::framing::HandShakeFrame;
use framing_sv2::{
framing::{Frame, Sv2Frame},
header::Header,
};
#[cfg(feature = "noise_sv2")]
use framing_sv2::{ENCRYPTED_SV2_FRAME_HEADER_SIZE, SV2_FRAME_CHUNK_SIZE, SV2_FRAME_HEADER_SIZE};
#[cfg(feature = "noise_sv2")]
use noise_sv2::NoiseCodec;
#[cfg(feature = "noise_sv2")]
use noise_sv2::NOISE_FRAME_HEADER_SIZE;
#[cfg(feature = "noise_sv2")]
use crate::error::Error;
use crate::error::Result;
use crate::Error::MissingBytes;
#[cfg(feature = "noise_sv2")]
use crate::State;
#[cfg(not(feature = "with_buffer_pool"))]
use buffer_sv2::{Buffer as IsBuffer, BufferFromSystemMemory as Buffer};
#[cfg(feature = "with_buffer_pool")]
use buffer_sv2::{Buffer as IsBuffer, BufferFromSystemMemory, BufferPool};
#[cfg(feature = "with_buffer_pool")]
type Buffer = BufferPool<BufferFromSystemMemory>;
/// A [`Frame`] (either an SV2 frame or a handshake frame) whose bytes live in the module's
/// `Buffer` slice type (`BufferFromSystemMemory`, or `BufferPool` when the `with_buffer_pool`
/// feature is enabled).
pub type StandardEitherFrame<T> = Frame<T, <Buffer as IsBuffer>::Slice>;
/// An [`Sv2Frame`] whose bytes live in the module's `Buffer` slice type.
pub type StandardSv2Frame<T> = Sv2Frame<T, <Buffer as IsBuffer>::Slice>;
#[cfg(feature = "noise_sv2")]
/// Noise-encrypted decoder ([`WithNoise`]) backed by the module's `Buffer` type.
pub type StandardNoiseDecoder<T> = WithNoise<Buffer, T>;
/// Plain, unencrypted decoder ([`WithoutNoise`]) backed by the module's `Buffer` type.
pub type StandardDecoder<T> = WithoutNoise<Buffer, T>;
#[cfg(feature = "noise_sv2")]
#[derive(Debug)]
/// Incremental decoder for a Noise-encrypted SV2 stream.
///
/// The caller writes exactly the bytes requested by `writable()` and then calls
/// `next_frame`, which either yields a frame or reports how many more bytes are needed.
pub struct WithNoise<B: IsBuffer, T: Serialize + binary_sv2::GetSize> {
    // Marks the message type `T` this decoder produces without storing a value of it.
    frame: PhantomData<T>,
    // Number of bytes the next `writable()` call hands out.
    missing_noise_b: usize,
    // Bytes as written by the caller: handshake bytes, or encrypted transport data.
    noise_buffer: B,
    // Decrypted SV2 header and payload being assembled into a frame.
    sv2_buffer: B,
}
#[cfg(feature = "noise_sv2")]
impl<'a, T: Serialize + GetSize + Deserialize<'a>, B: IsBuffer + AeadBuffer> WithNoise<B, T> {
    /// Tries to produce the next frame from the bytes written so far through `writable()`.
    ///
    /// Behaviour depends on `state`:
    /// - `State::NotInitialized(msg_len)`: once `msg_len` bytes are buffered, they are
    ///   returned as a handshake frame, and the next read size becomes
    ///   `NOISE_FRAME_HEADER_SIZE`.
    /// - `State::Transport(noise_codec)`: the decoder first collects
    ///   `ENCRYPTED_SV2_FRAME_HEADER_SIZE` bytes and decrypts the SV2 header. It then collects
    ///   `header.encrypted_len()` bytes and decrypts them into an SV2 frame.
    ///
    /// # Errors
    /// Returns `Error::MissingBytes(n)` when `n` more bytes are needed. The same `n` is
    /// stored as the size of the next `writable()` slice. Errors from `Header::from_bytes`
    /// and `NoiseCodec::decrypt` are propagated.
    ///
    /// # Panics
    /// Panics (`unreachable!`) when `state` is `State::HandShake(_)`.
    #[inline]
    pub fn next_frame(&mut self, state: &mut State) -> Result<Frame<T, B::Slice>> {
        match state {
            State::HandShake(_) => unreachable!(),
            State::NotInitialized(msg_len) => {
                // NOTE(review): this subtraction underflows if more than `msg_len` bytes were
                // buffered. It assumes callers only write what `writable()` hands out.
                let hint = *msg_len - self.noise_buffer.as_ref().len();
                match hint {
                    0 => {
                        self.missing_noise_b = NOISE_FRAME_HEADER_SIZE;
                        Ok(self.while_handshaking())
                    }
                    _ => {
                        self.missing_noise_b = hint;
                        Err(Error::MissingBytes(hint))
                    }
                }
            }
            State::Transport(noise_codec) => {
                // Header not decrypted yet: count the bytes still missing from the encrypted
                // header. Otherwise parse the decrypted header already in `sv2_buffer` and
                // count the bytes still missing from the encrypted payload.
                let hint = if IsBuffer::len(&self.sv2_buffer) < SV2_FRAME_HEADER_SIZE {
                    let len = IsBuffer::len(&self.noise_buffer);
                    let src = self.noise_buffer.get_data_by_ref(len);
                    if src.len() < ENCRYPTED_SV2_FRAME_HEADER_SIZE {
                        ENCRYPTED_SV2_FRAME_HEADER_SIZE - src.len()
                    } else {
                        0
                    }
                } else {
                    let src = self.sv2_buffer.get_data_by_ref(SV2_FRAME_HEADER_SIZE);
                    let header = Header::from_bytes(src)?;
                    header.encrypted_len() - IsBuffer::len(&self.noise_buffer)
                };
                match hint {
                    0 => {
                        // Default next read: the next frame's encrypted header.
                        // `decode_noise_frame` overwrites this after decrypting a header.
                        self.missing_noise_b = ENCRYPTED_SV2_FRAME_HEADER_SIZE;
                        self.decode_noise_frame(noise_codec)
                    }
                    _ => {
                        self.missing_noise_b = hint;
                        Err(Error::MissingBytes(hint))
                    }
                }
            }
        }
    }
    /// Number of bytes the next `writable()` call will hand out.
    pub fn writable_len(&self) -> usize {
        self.missing_noise_b
    }
    /// Returns a mutable slice of `missing_noise_b` bytes (from `noise_buffer.get_writable`).
    /// The caller fills it before calling `next_frame`.
    #[inline]
    pub fn writable(&mut self) -> &mut [u8] {
        self.noise_buffer.get_writable(self.missing_noise_b)
    }
    /// `true` when both internal buffers report `is_droppable()`.
    pub fn droppable(&self) -> bool {
        self.noise_buffer.is_droppable() && self.sv2_buffer.is_droppable()
    }
    /// Copies all buffered handshake bytes into a `Vec` and wraps them, unvalidated
    /// (`from_bytes_unchecked`), in a handshake frame.
    fn while_handshaking(&mut self) -> Frame<T, B::Slice> {
        let src = self.noise_buffer.get_data_owned().as_mut().to_vec();
        #[cfg(feature = "with_buffer_pool")]
        let frame = HandShakeFrame::from_bytes_unchecked(src.into());
        #[cfg(not(feature = "with_buffer_pool"))]
        let frame = HandShakeFrame::from_bytes_unchecked(src);
        frame.into()
    }
    /// Decrypts the encrypted bytes that `next_frame` found to be complete.
    ///
    /// - With exactly `ENCRYPTED_SV2_FRAME_HEADER_SIZE` encrypted bytes and an empty
    ///   `sv2_buffer`, the header is decrypted into `sv2_buffer`. `missing_noise_b` is set to
    ///   the header's `encrypted_len()`, and `Error::MissingBytes` is returned.
    /// - Otherwise, the buffered bytes are treated as the encrypted payload. They are
    ///   decrypted in chunks of at most `SV2_FRAME_CHUNK_SIZE` bytes, placed after the
    ///   decrypted header. The whole `sv2_buffer` content is then returned as an `Sv2Frame`,
    ///   built with `from_bytes_unchecked`.
    #[inline]
    fn decode_noise_frame(&mut self, noise_codec: &mut NoiseCodec) -> Result<Frame<T, B::Slice>> {
        match (
            IsBuffer::len(&self.noise_buffer),
            IsBuffer::len(&self.sv2_buffer),
        ) {
            // The encrypted SV2 header is complete and nothing is decrypted yet.
            (ENCRYPTED_SV2_FRAME_HEADER_SIZE, 0) => {
                let src = self.noise_buffer.get_data_owned();
                let decrypted_header = self
                    .sv2_buffer
                    .get_writable(ENCRYPTED_SV2_FRAME_HEADER_SIZE);
                decrypted_header.copy_from_slice(src.as_ref());
                // NOTE(review): the result of this call is discarded. It may exist for a side
                // effect of the buffer's `as_ref` implementation; confirm before removing.
                self.sv2_buffer.as_ref();
                noise_codec.decrypt(&mut self.sv2_buffer)?;
                let header =
                    Header::from_bytes(self.sv2_buffer.get_data_by_ref(SV2_FRAME_HEADER_SIZE))?;
                self.missing_noise_b = header.encrypted_len();
                Err(Error::MissingBytes(header.encrypted_len()))
            }
            // The encrypted payload is complete.
            _ => {
                let encrypted_payload = self.noise_buffer.get_data_owned();
                let encrypted_payload_len = encrypted_payload.as_ref().len();
                let mut start = 0;
                let mut end = if encrypted_payload_len < SV2_FRAME_CHUNK_SIZE {
                    encrypted_payload_len
                } else {
                    SV2_FRAME_CHUNK_SIZE
                };
                // Bytes of `sv2_buffer` already decrypted; starts after the decrypted header.
                let mut decrypted_len = SV2_FRAME_HEADER_SIZE;
                while start < encrypted_payload_len {
                    let decrypted_payload = self.sv2_buffer.get_writable(end - start);
                    decrypted_payload.copy_from_slice(&encrypted_payload.as_ref()[start..end]);
                    // Presumably restricts the buffer view to the bytes after
                    // `decrypted_len`, so `decrypt` only sees the chunk just copied.
                    // TODO confirm against buffer_sv2.
                    self.sv2_buffer.danger_set_start(decrypted_len);
                    noise_codec.decrypt(&mut self.sv2_buffer)?;
                    start = end;
                    end = (start + SV2_FRAME_CHUNK_SIZE).min(encrypted_payload_len);
                    decrypted_len += self.sv2_buffer.as_ref().len();
                }
                // Restore the view to the beginning so the header is part of the frame.
                self.sv2_buffer.danger_set_start(0);
                let src = self.sv2_buffer.get_data_owned();
                let frame = Sv2Frame::<T, B::Slice>::from_bytes_unchecked(src);
                Ok(frame.into())
            }
        }
    }
}
#[cfg(feature = "noise_sv2")]
impl<T: Serialize + binary_sv2::GetSize> WithNoise<Buffer, T> {
    /// Builds a Noise decoder with two freshly allocated buffers of `5 * 2^16` bytes each.
    ///
    /// `missing_noise_b` starts at 0, so the first `writable()` call requests zero bytes,
    /// and the first `next_frame` call reports how many bytes are actually needed.
    pub fn new() -> Self {
        const CAPACITY: usize = 5 * (1 << 16);
        let noise_buffer = Buffer::new(CAPACITY);
        let sv2_buffer = Buffer::new(CAPACITY);
        WithNoise {
            frame: PhantomData,
            missing_noise_b: 0,
            noise_buffer,
            sv2_buffer,
        }
    }
}
#[cfg(feature = "noise_sv2")]
impl<T: Serialize + binary_sv2::GetSize> Default for WithNoise<Buffer, T> {
    /// Same as `WithNoise::new`.
    fn default() -> Self {
        WithNoise::<Buffer, T>::new()
    }
}
#[derive(Debug)]
/// Incremental decoder for a plain, unencrypted SV2 stream.
///
/// The caller writes exactly the bytes requested by `writable()` and then calls
/// `next_frame`, which either yields a frame or reports how many more bytes are needed.
pub struct WithoutNoise<B: IsBuffer, T: Serialize + binary_sv2::GetSize> {
    // Marks the message type `T` this decoder produces without storing a value of it.
    frame: PhantomData<T>,
    // Number of bytes the next `writable()` call hands out.
    missing_b: usize,
    // Raw frame bytes written by the caller.
    buffer: B,
}
impl<T: Serialize + binary_sv2::GetSize, B: IsBuffer> WithoutNoise<B, T> {
    /// Tries to decode a complete SV2 frame from the bytes written so far.
    ///
    /// `Sv2Frame::size_hint` is evaluated over everything buffered. When nothing is missing,
    /// the buffered bytes (obtained with `get_data_owned`) are wrapped unvalidated in an
    /// `Sv2Frame`, and the next read size is reset to `Header::SIZE`.
    ///
    /// # Errors
    /// `MissingBytes(n)` when `n` more bytes are needed. `n` is also the length of the
    /// slice the next `writable()` call returns.
    #[inline]
    pub fn next_frame(&mut self) -> Result<Sv2Frame<T, B::Slice>> {
        let buffered = self.buffer.len();
        let needed =
            Sv2Frame::<T, B::Slice>::size_hint(self.buffer.get_data_by_ref(buffered)) as usize;
        if needed != 0 {
            self.missing_b = needed;
            return Err(MissingBytes(self.missing_b));
        }
        self.missing_b = Header::SIZE;
        let raw = self.buffer.get_data_owned();
        Ok(Sv2Frame::<T, B::Slice>::from_bytes_unchecked(raw))
    }

    /// Returns a mutable slice of `missing_b` bytes (from `buffer.get_writable`). The caller
    /// fills it before calling `next_frame`.
    pub fn writable(&mut self) -> &mut [u8] {
        let wanted = self.missing_b;
        self.buffer.get_writable(wanted)
    }
}
impl<T: Serialize + binary_sv2::GetSize> WithoutNoise<Buffer, T> {
pub fn new() -> Self {
Self {
frame: PhantomData,
missing_b: Header::SIZE,
buffer: Buffer::new(2_usize.pow(16) * 5),
}
}
}
impl<T: Serialize + binary_sv2::GetSize> Default for WithoutNoise<Buffer, T> {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use binary_sv2::{self, Serialize};

    /// Empty message type used only to instantiate the decoder's generic parameter.
    #[derive(Serialize)]
    pub struct TestMessage {}

    /// A freshly built plain decoder must expose exactly `Header::SIZE` writable bytes,
    /// all of them zero.
    #[test]
    fn unencrypted_writable_with_missing_b_initialized_as_header_size() {
        let zeroed_header = [0u8; Header::SIZE];
        let mut fresh_decoder = StandardDecoder::<TestMessage>::new();
        let writable_region = fresh_decoder.writable();
        assert_eq!(writable_region, zeroed_header);
    }
}
#[cfg(test)]
mod prop_tests {
    use crate::{decoder::Buffer, encoder::Encoder, StandardDecoder};
    #[cfg(feature = "noise_sv2")]
    use crate::{HandshakeRole, NoiseEncoder, StandardNoiseDecoder, State};
    use binary_sv2::{self, Deserialize, Serialize};
    use buffer_sv2::Buffer as IsBuffer;
    #[cfg(feature = "noise_sv2")]
    use framing_sv2::framing::Frame;
    use framing_sv2::framing::Sv2Frame;
    #[cfg(feature = "noise_sv2")]
    use key_utils::{Secp256k1PublicKey, Secp256k1SecretKey};
    #[cfg(feature = "noise_sv2")]
    use noise_sv2::{ELLSWIFT_ENCODING_SIZE, INITIATOR_EXPECTED_HANDSHAKE_MESSAGE_SIZE};
    use quickcheck::{Arbitrary, Gen, TestResult};
    use quickcheck_macros::quickcheck;
    #[cfg(feature = "noise_sv2")]
    use std::convert::TryInto;
    #[cfg(feature = "noise_sv2")]
    use std::time::Duration;

    #[cfg(feature = "noise_sv2")]
    const AUTHORITY_PUBLIC_K: &str = "9auqWEzQDVyd2oe1JVGFLMLHZtCo2FFqZwtKA5gd9xbuEu7PH72";
    #[cfg(feature = "noise_sv2")]
    const AUTHORITY_PRIVATE_K: &str = "mkDLTBBRxdBv998612qipDYoTK3YUrqLe8uWw7gu3iXbSrn2n";

    type Slice = <Buffer as IsBuffer>::Slice;

    /// Minimal message with a single `u16` field, shared by every property below.
    #[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
    struct TestMessage {
        value: u16,
    }

    impl Arbitrary for TestMessage {
        fn arbitrary(g: &mut Gen) -> Self {
            TestMessage {
                value: u16::arbitrary(g),
            }
        }
    }

    /// Fills all of `dst` with the next `dst.len()` bytes of `src`, starting at `*offset`, and
    /// advances `*offset` past them.
    ///
    /// Bytes are copied at most `chunk_size` at a time to mimic short transport reads; a
    /// `chunk_size` of 0 is treated as 1. Returns `false`, copying nothing, when `src` does not
    /// hold enough bytes.
    ///
    /// The decoders hand out exactly the region they will parse through `writable()`, so the
    /// slice must be completely filled before `next_frame` is called. Calling it on a
    /// half-filled slice would presumably let the decoder parse the unfilled tail.
    fn fill_writable(dst: &mut [u8], src: &[u8], offset: &mut usize, chunk_size: usize) -> bool {
        let end = *offset + dst.len();
        if end > src.len() {
            return false;
        }
        let chunk_size = chunk_size.max(1);
        for (dst_chunk, src_chunk) in dst
            .chunks_mut(chunk_size)
            .zip(src[*offset..end].chunks(chunk_size))
        {
            dst_chunk.copy_from_slice(src_chunk);
        }
        *offset = end;
        true
    }

    /// Feeds `encoded_bytes` into `decoder` until it yields a frame.
    ///
    /// Every `writable()` slice is filled completely, in pieces of at most `chunk_size` bytes
    /// when given, before `next_frame` is called. Returns `None` when the input runs out before
    /// a frame is complete, or when the decoder reports an error other than `MissingBytes`.
    fn decode_frame(
        decoder: &mut StandardDecoder<TestMessage>,
        encoded_bytes: &[u8],
        chunk_size: Option<usize>,
    ) -> Option<framing_sv2::framing::Sv2Frame<TestMessage, Slice>> {
        let chunk_size = chunk_size.unwrap_or(usize::MAX);
        let mut offset = 0;
        loop {
            if !fill_writable(decoder.writable(), encoded_bytes, &mut offset, chunk_size) {
                return None;
            }
            match decoder.next_frame() {
                Ok(frame) => return Some(frame),
                Err(crate::Error::MissingBytes(_)) => continue,
                Err(_) => return None,
            }
        }
    }

    /// Encoding a frame and decoding it back preserves the message, `msg_type`, and
    /// `ext_type`.
    #[quickcheck]
    fn prop_encode_decode_roundtrip(msg: TestMessage, msg_type: u8, ext_type: u16) -> TestResult {
        let original_msg = msg.clone();
        let frame =
            match Sv2Frame::<TestMessage, Slice>::from_message(msg, msg_type, ext_type, false) {
                Some(f) => f,
                None => return TestResult::discard(),
            };
        let expected_ext_type = frame.get_header().unwrap().ext_type();
        let mut encoder = Encoder::<TestMessage>::new();
        let encoded = match encoder.encode(frame) {
            Ok(e) => e,
            Err(_) => return TestResult::failed(),
        };
        let mut decoder = StandardDecoder::<TestMessage>::new();
        match decode_frame(&mut decoder, encoded.as_ref(), None) {
            Some(mut decoded_frame) => {
                let header = match decoded_frame.get_header() {
                    Some(h) => h,
                    None => return TestResult::failed(),
                };
                let actual_msg_type = header.msg_type();
                let actual_ext_type = header.ext_type();
                let decoded_msg: TestMessage = match binary_sv2::from_bytes(decoded_frame.payload())
                {
                    Ok(m) => m,
                    Err(_) => return TestResult::failed(),
                };
                TestResult::from_bool(
                    decoded_msg == original_msg
                        && actual_msg_type == msg_type
                        && actual_ext_type == expected_ext_type,
                )
            }
            None => TestResult::failed(),
        }
    }

    /// Delivering the encoded frame in reads of at most `chunk_size` bytes must still decode
    /// the original message.
    ///
    /// The header and payload are requested separately, so at least one `MissingBytes` must
    /// be reported along the way.
    #[quickcheck]
    fn prop_decoder_handles_partial_data(
        msg: TestMessage,
        msg_type: u8,
        chunk_size: u8,
    ) -> TestResult {
        if chunk_size == 0 {
            return TestResult::discard();
        }
        let original_msg = msg.clone();
        let frame = match Sv2Frame::<TestMessage, Slice>::from_message(msg, msg_type, 0, false) {
            Some(f) => f,
            None => return TestResult::discard(),
        };
        let mut encoder = Encoder::<TestMessage>::new();
        let encoded = match encoder.encode(frame) {
            Ok(e) => e,
            Err(_) => return TestResult::failed(),
        };
        let mut decoder = StandardDecoder::<TestMessage>::new();
        let encoded_bytes: &[u8] = encoded.as_ref();
        let chunk_size = chunk_size as usize;
        let mut offset = 0;
        let mut missing_bytes_count = 0;
        loop {
            // The decoder asks for more bytes than were encoded, so it can never finish.
            if !fill_writable(decoder.writable(), encoded_bytes, &mut offset, chunk_size) {
                return TestResult::failed();
            }
            match decoder.next_frame() {
                Ok(mut decoded) => {
                    let header = match decoded.get_header() {
                        Some(h) => h,
                        None => return TestResult::failed(),
                    };
                    let decoded_msg: TestMessage = match binary_sv2::from_bytes(decoded.payload())
                    {
                        Ok(m) => m,
                        Err(_) => return TestResult::failed(),
                    };
                    return TestResult::from_bool(
                        decoded_msg == original_msg
                            && header.msg_type() == msg_type
                            && missing_bytes_count > 0,
                    );
                }
                Err(crate::Error::MissingBytes(n)) => {
                    missing_bytes_count += 1;
                    assert!(n > 0);
                }
                Err(_) => return TestResult::failed(),
            }
        }
    }

    /// One decoder can decode two consecutive frames.
    #[quickcheck]
    fn prop_decoder_multiple_frames(
        msg1: TestMessage,
        msg2: TestMessage,
        msg_type: u8,
    ) -> TestResult {
        let frame1 =
            match Sv2Frame::<TestMessage, Slice>::from_message(msg1.clone(), msg_type, 0, false) {
                Some(f) => f,
                None => return TestResult::discard(),
            };
        let frame2 =
            match Sv2Frame::<TestMessage, Slice>::from_message(msg2.clone(), msg_type, 0, false) {
                Some(f) => f,
                None => return TestResult::discard(),
            };
        let mut encoder = Encoder::<TestMessage>::new();
        let encoded1 = match encoder.encode(frame1) {
            Ok(e) => e,
            Err(_) => return TestResult::failed(),
        };
        let encoded2 = match encoder.encode(frame2) {
            Ok(e) => e,
            Err(_) => return TestResult::failed(),
        };
        let mut decoder = StandardDecoder::<TestMessage>::new();
        let decoded_msg1 = match decode_frame(&mut decoder, encoded1.as_ref(), None) {
            Some(mut f) => match binary_sv2::from_bytes::<TestMessage>(f.payload()) {
                Ok(m) => m,
                Err(_) => return TestResult::failed(),
            },
            None => return TestResult::failed(),
        };
        let decoded_msg2 = match decode_frame(&mut decoder, encoded2.as_ref(), None) {
            Some(mut f) => match binary_sv2::from_bytes::<TestMessage>(f.payload()) {
                Ok(m) => m,
                Err(_) => return TestResult::failed(),
            },
            None => return TestResult::failed(),
        };
        TestResult::from_bool(decoded_msg1 == msg1 && decoded_msg2 == msg2)
    }

    /// Runs a complete initiator/responder Noise handshake with the authority key pair above.
    ///
    /// Returns `(sender, receiver)` states, both in transport mode.
    #[cfg(feature = "noise_sv2")]
    fn make_transport_state_pair() -> (State, State) {
        let pub_k: Secp256k1PublicKey = AUTHORITY_PUBLIC_K.to_string().try_into().unwrap();
        let pub_k_bytes = pub_k.into_bytes();
        let priv_k: Secp256k1SecretKey = AUTHORITY_PRIVATE_K.to_string().try_into().unwrap();
        let priv_k_bytes = priv_k.into_bytes();
        let initiator = noise_sv2::Initiator::from_raw_k(pub_k_bytes).unwrap();
        let responder = noise_sv2::Responder::from_authority_kp(
            &pub_k_bytes,
            &priv_k_bytes,
            Duration::from_secs(3600),
        )
        .unwrap();
        let mut sender_state = State::initialized(HandshakeRole::Initiator(initiator));
        let mut receiver_state = State::initialized(HandshakeRole::Responder(responder));
        let msg0 = sender_state.step_0().unwrap();
        let msg0: [u8; ELLSWIFT_ENCODING_SIZE] =
            msg0.get_payload_when_handshaking().try_into().unwrap();
        let (msg1, receiver_transport) = receiver_state.step_1(msg0).unwrap();
        let msg1: [u8; INITIATOR_EXPECTED_HANDSHAKE_MESSAGE_SIZE] =
            msg1.get_payload_when_handshaking().try_into().unwrap();
        let sender_transport = sender_state.step_2(msg1).unwrap();
        let sender_state = match sender_transport {
            State::Transport(c) => State::with_transport_mode(c),
            _ => unreachable!(),
        };
        let receiver_state = match receiver_transport {
            State::Transport(c) => State::with_transport_mode(c),
            _ => unreachable!(),
        };
        (sender_state, receiver_state)
    }

    /// Feeds the encrypted bytes into the Noise decoder until it yields an SV2 frame.
    ///
    /// Returns `None` in any of these cases:
    /// - the decoder yields a non-SV2 frame or an error other than `MissingBytes`;
    /// - the decoder asks for more bytes than remain;
    /// - the decoder makes two consecutive zero-byte requests, i.e. it stopped making progress.
    ///
    /// The last two checks keep a decoding regression from hanging the test.
    #[cfg(feature = "noise_sv2")]
    fn decode_noise_frame(
        decoder: &mut StandardNoiseDecoder<TestMessage>,
        state: &mut State,
        encoded: &[u8],
    ) -> Option<Sv2Frame<TestMessage, Slice>> {
        let mut offset = 0;
        let mut previous_request_was_empty = false;
        loop {
            let writable = decoder.writable();
            let requested = writable.len();
            if requested == 0 && previous_request_was_empty {
                return None;
            }
            previous_request_was_empty = requested == 0;
            if !fill_writable(writable, encoded, &mut offset, usize::MAX) {
                return None;
            }
            match decoder.next_frame(state) {
                Ok(Frame::Sv2(frame)) => return Some(frame),
                Ok(_) => return None,
                Err(crate::Error::MissingBytes(_)) => {}
                Err(_) => return None,
            }
        }
    }

    /// Encrypting a frame and decrypting it back preserves the message, `msg_type`, and
    /// `ext_type`.
    #[cfg(feature = "noise_sv2")]
    #[quickcheck]
    fn prop_noise_encode_decode_roundtrip(
        msg: TestMessage,
        msg_type: u8,
        ext_type: u16,
    ) -> TestResult {
        let (mut sender_state, mut receiver_state) = make_transport_state_pair();
        let original = msg.clone();
        let sv2_frame =
            match Sv2Frame::<TestMessage, Slice>::from_message(msg, msg_type, ext_type, false) {
                Some(f) => f,
                None => return TestResult::discard(),
            };
        let expected_ext = sv2_frame.get_header().unwrap().ext_type();
        let frame = Frame::Sv2(sv2_frame);
        let mut encoder = NoiseEncoder::<TestMessage>::new();
        let encrypted = match encoder.encode(frame, &mut sender_state) {
            Ok(e) => e,
            Err(_) => return TestResult::failed(),
        };
        let mut decoder = StandardNoiseDecoder::<TestMessage>::new();
        let encrypted_bytes: &[u8] = encrypted.as_ref();
        match decode_noise_frame(&mut decoder, &mut receiver_state, encrypted_bytes) {
            Some(mut decoded) => {
                let header = match decoded.get_header() {
                    Some(h) => h,
                    None => return TestResult::failed(),
                };
                let decoded_msg: TestMessage = match binary_sv2::from_bytes(decoded.payload()) {
                    Ok(m) => m,
                    Err(_) => return TestResult::failed(),
                };
                TestResult::from_bool(
                    decoded_msg == original
                        && header.msg_type() == msg_type
                        && header.ext_type() == expected_ext,
                )
            }
            None => TestResult::failed(),
        }
    }

    /// The Noise decoder requests the encrypted header and payload in separate steps: at
    /// least one `MissingBytes` must be reported, and the decrypted message must match the
    /// original.
    #[cfg(feature = "noise_sv2")]
    #[quickcheck]
    fn prop_noise_decoder_handles_partial_data(msg: TestMessage, msg_type: u8) -> TestResult {
        let original = msg.clone();
        let frame = match Sv2Frame::<TestMessage, Slice>::from_message(msg, msg_type, 0, false) {
            Some(f) => Frame::Sv2(f),
            None => return TestResult::discard(),
        };
        let (mut sender_state, mut receiver_state) = make_transport_state_pair();
        let mut encoder = NoiseEncoder::<TestMessage>::new();
        let encrypted = match encoder.encode(frame, &mut sender_state) {
            Ok(e) => e,
            Err(_) => return TestResult::failed(),
        };
        let mut decoder = StandardNoiseDecoder::<TestMessage>::new();
        let encoded_bytes: &[u8] = encrypted.as_ref();
        let mut offset = 0;
        let mut missing_bytes_count = 0;
        let mut previous_request_was_empty = false;
        loop {
            let writable = decoder.writable();
            let requested = writable.len();
            // Two consecutive zero-byte requests mean the decoder stopped making progress.
            if requested == 0 && previous_request_was_empty {
                return TestResult::failed();
            }
            previous_request_was_empty = requested == 0;
            // The decoder asks for more bytes than were encrypted, so it can never finish.
            if !fill_writable(writable, encoded_bytes, &mut offset, usize::MAX) {
                return TestResult::failed();
            }
            match decoder.next_frame(&mut receiver_state) {
                Ok(Frame::Sv2(mut decoded)) => {
                    let decoded_msg =
                        match binary_sv2::from_bytes::<TestMessage>(decoded.payload()) {
                            Ok(m) => m,
                            Err(_) => return TestResult::failed(),
                        };
                    return TestResult::from_bool(
                        decoded_msg == original && missing_bytes_count > 0,
                    );
                }
                Ok(_) => return TestResult::failed(),
                Err(crate::Error::MissingBytes(n)) => {
                    missing_bytes_count += 1;
                    assert!(n > 0);
                }
                Err(_) => return TestResult::failed(),
            }
        }
    }

    /// One Noise decoder can decrypt two consecutive frames from the same session.
    #[cfg(feature = "noise_sv2")]
    #[quickcheck]
    fn prop_noise_decoder_multiple_frames(
        msg1: TestMessage,
        msg2: TestMessage,
        msg_type: u8,
    ) -> TestResult {
        let (mut sender_state, mut receiver_state) = make_transport_state_pair();
        let frame1 =
            match Sv2Frame::<TestMessage, Slice>::from_message(msg1.clone(), msg_type, 0, false) {
                Some(f) => Frame::Sv2(f),
                None => return TestResult::discard(),
            };
        let frame2 =
            match Sv2Frame::<TestMessage, Slice>::from_message(msg2.clone(), msg_type, 0, false) {
                Some(f) => Frame::Sv2(f),
                None => return TestResult::discard(),
            };
        let mut encoder = NoiseEncoder::<TestMessage>::new();
        let enc1 = match encoder.encode(frame1, &mut sender_state) {
            Ok(e) => e,
            Err(_) => return TestResult::failed(),
        };
        let enc2 = match encoder.encode(frame2, &mut sender_state) {
            Ok(e) => e,
            Err(_) => return TestResult::failed(),
        };
        let mut decoder = StandardNoiseDecoder::<TestMessage>::new();
        let decoded_msg1 =
            match decode_noise_frame(&mut decoder, &mut receiver_state, enc1.as_ref()) {
                Some(mut f) => match binary_sv2::from_bytes::<TestMessage>(f.payload()) {
                    Ok(m) => m,
                    Err(_) => return TestResult::failed(),
                },
                None => return TestResult::failed(),
            };
        let decoded_msg2 =
            match decode_noise_frame(&mut decoder, &mut receiver_state, enc2.as_ref()) {
                Some(mut f) => match binary_sv2::from_bytes::<TestMessage>(f.payload()) {
                    Ok(m) => m,
                    Err(_) => return TestResult::failed(),
                },
                None => return TestResult::failed(),
            };
        TestResult::from_bool(decoded_msg1 == msg1 && decoded_msg2 == msg2)
    }
}