use super::*;
use crate::proto::pq_ratchet as pqrpb;
use crate::{Error, SerializedMessage, Version};
use num_enum::IntoPrimitive;
impl States {
pub fn into_pb(self) -> pqrpb::V1State {
pqrpb::V1State {
inner_state: Some(match self {
Self::KeysUnsampled(state) => {
pqrpb::v1_state::InnerState::KeysUnsampled(state.into_pb())
}
Self::KeysSampled(state) => {
pqrpb::v1_state::InnerState::KeysSampled(state.into_pb())
}
Self::HeaderSent(state) => pqrpb::v1_state::InnerState::HeaderSent(state.into_pb()),
Self::Ct1Received(state) => {
pqrpb::v1_state::InnerState::Ct1Received(state.into_pb())
}
Self::EkSentCt1Received(state) => {
pqrpb::v1_state::InnerState::EkSentCt1Received(state.into_pb())
}
Self::NoHeaderReceived(state) => {
pqrpb::v1_state::InnerState::NoHeaderReceived(state.into_pb())
}
Self::HeaderReceived(state) => {
pqrpb::v1_state::InnerState::HeaderReceived(state.into_pb())
}
Self::Ct1Sampled(state) => pqrpb::v1_state::InnerState::Ct1Sampled(state.into_pb()),
Self::EkReceivedCt1Sampled(state) => {
pqrpb::v1_state::InnerState::EkReceivedCt1Sampled(state.into_pb())
}
Self::Ct1Acknowledged(state) => {
pqrpb::v1_state::InnerState::Ct1Acknowledged(state.into_pb())
}
Self::Ct2Sampled(state) => pqrpb::v1_state::InnerState::Ct2Sampled(state.into_pb()),
}),
}
}
pub fn from_pb(pb: pqrpb::V1State) -> Result<Self, Error> {
Ok(match pb.inner_state {
Some(pqrpb::v1_state::InnerState::KeysUnsampled(pb)) => {
Self::KeysUnsampled(send_ek::KeysUnsampled::from_pb(pb)?)
}
Some(pqrpb::v1_state::InnerState::KeysSampled(pb)) => {
Self::KeysSampled(send_ek::KeysSampled::from_pb(pb)?)
}
Some(pqrpb::v1_state::InnerState::HeaderSent(pb)) => {
Self::HeaderSent(send_ek::HeaderSent::from_pb(pb)?)
}
Some(pqrpb::v1_state::InnerState::Ct1Received(pb)) => {
Self::Ct1Received(send_ek::Ct1Received::from_pb(pb)?)
}
Some(pqrpb::v1_state::InnerState::EkSentCt1Received(pb)) => {
Self::EkSentCt1Received(send_ek::EkSentCt1Received::from_pb(pb)?)
}
Some(pqrpb::v1_state::InnerState::NoHeaderReceived(pb)) => {
Self::NoHeaderReceived(send_ct::NoHeaderReceived::from_pb(pb)?)
}
Some(pqrpb::v1_state::InnerState::HeaderReceived(pb)) => {
Self::HeaderReceived(send_ct::HeaderReceived::from_pb(pb)?)
}
Some(pqrpb::v1_state::InnerState::Ct1Sampled(pb)) => {
Self::Ct1Sampled(send_ct::Ct1Sampled::from_pb(pb)?)
}
Some(pqrpb::v1_state::InnerState::EkReceivedCt1Sampled(pb)) => {
Self::EkReceivedCt1Sampled(send_ct::EkReceivedCt1Sampled::from_pb(pb)?)
}
Some(pqrpb::v1_state::InnerState::Ct1Acknowledged(pb)) => {
Self::Ct1Acknowledged(send_ct::Ct1Acknowledged::from_pb(pb)?)
}
Some(pqrpb::v1_state::InnerState::Ct2Sampled(pb)) => {
Self::Ct2Sampled(send_ct::Ct2Sampled::from_pb(pb)?)
}
_ => {
return Err(Error::StateDecode);
}
})
}
}
/// One-byte tag written after the epoch and index of a serialized [`Message`], identifying which
/// [`MessagePayload`] variant the message carries.
#[derive(IntoPrimitive)]
#[repr(u8)]
enum MessageType {
    /// [`MessagePayload::None`]; nothing follows the tag.
    None = 0,
    /// [`MessagePayload::Hdr`]; an encoded chunk follows the tag.
    Hdr = 1,
    /// [`MessagePayload::Ek`]; an encoded chunk follows the tag.
    Ek = 2,
    /// [`MessagePayload::EkCt1Ack`]; an encoded chunk follows the tag.
    EkCt1Ack = 3,
    /// [`MessagePayload::Ct1Ack`]; nothing follows the tag.
    Ct1Ack = 4,
    /// [`MessagePayload::Ct1`]; an encoded chunk follows the tag.
    Ct1 = 5,
    /// [`MessagePayload::Ct2`]; an encoded chunk follows the tag.
    Ct2 = 6,
}
#[hax_lib::opaque]
impl TryFrom<u8> for MessageType {
    type Error = String;

    /// Decodes a wire tag byte back into a `MessageType`; any value above 6 is rejected.
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        let decoded = match value {
            0 => Self::None,
            1 => Self::Hdr,
            2 => Self::Ek,
            3 => Self::EkCt1Ack,
            4 => Self::Ct1Ack,
            5 => Self::Ct1,
            6 => Self::Ct2,
            _ => return Err("Expected a number between 0 and 6".to_owned()),
        };
        Ok(decoded)
    }
}
impl MessageType {
fn from_payload(mp: &MessagePayload) -> Self {
match mp {
MessagePayload::None => Self::None,
MessagePayload::Hdr(_) => Self::Hdr,
MessagePayload::Ek(_) => Self::Ek,
MessagePayload::EkCt1Ack(_) => Self::EkCt1Ack,
MessagePayload::Ct1Ack(_) => Self::Ct1Ack,
MessagePayload::Ct1(_) => Self::Ct1,
MessagePayload::Ct2(_) => Self::Ct2,
}
}
}
/// Appends `a` to `into` as a base-128 varint: seven value bits per byte, least-significant
/// group first, with the high (continuation) bit set on every byte except the last.
///
/// A `u64` needs at most 10 such bytes, so the bounded loop always emits the complete value.
fn encode_varint(mut a: u64, into: &mut SerializedMessage) {
    for _ in 0..10 {
        let low_bits = (a & 0x7F) as u8;
        if a >= 0x80 {
            // More groups remain: flag continuation and move on to the next seven bits.
            into.push(low_bits | 0x80);
            a >>= 7;
        } else {
            into.push(low_bits);
            break;
        }
    }
}
#[hax_lib::opaque] fn decode_varint(from: &SerializedMessage, at: &mut usize) -> Result<u64, Error> {
let mut out = 0u64;
let mut shift = 0;
while *at < from.len() {
let byte = from[*at];
out |= ((byte as u64) & 0x7f) << shift;
*at += 1;
if byte & 0x80 == 0 {
return Ok(out);
}
shift += 7;
}
Err(Error::MsgDecode)
}
#[hax_lib::fstar::verification_status(lax)]
fn encode_chunk(c: &Chunk, into: &mut SerializedMessage) {
encode_varint(c.index as u64, into);
into.extend_from_slice(&c.data[..]);
}
/// Reads one chunk from `from` at `*at`: a varint index followed by exactly 32 data bytes.
///
/// On success `*at` points just past the chunk. On error `*at` may already have been advanced.
///
/// # Errors
///
/// Returns [`Error::MsgDecode`] if the index varint cannot be decoded, fewer than 32 bytes remain
/// after it, or the index does not fit in a `u16`.
#[hax_lib::fstar::verification_status(lax)]
fn decode_chunk(from: &SerializedMessage, at: &mut usize) -> Result<Chunk, Error> {
    let index = decode_varint(from, at)?;
    let data_start = *at;
    *at += 32;
    let data_end = *at;
    let truncated = data_end > from.len();
    let index_too_large = index > 0xFFFF;
    if truncated || index_too_large {
        return Err(Error::MsgDecode);
    }
    Ok(Chunk {
        index: index as u16,
        data: from[data_start..data_end].try_into().expect("correct size"),
    })
}
impl Message {
    /// Serializes this message into the V1 wire format, tagged with `index`.
    ///
    /// Layout: the `Version::V1` byte, the varint epoch, the varint `index`, a one-byte
    /// `MessageType` tag, and then the encoded chunk, but only for `Hdr`, `Ek`, `EkCt1Ack`,
    /// `Ct1` and `Ct2` payloads. `None` and `Ct1Ack` messages end after the tag byte.
    pub fn serialize(&self, index: u32) -> SerializedMessage {
        hax_lib::fstar!("admit()");
        let mut into = Vec::with_capacity(40);
        into.push(Version::V1.into());
        encode_varint(self.epoch, &mut into);
        encode_varint(index as u64, &mut into);
        into.push(MessageType::from_payload(&self.payload).into());
        encode_chunk(
            match &self.payload {
                MessagePayload::Hdr(chunk) => chunk,
                MessagePayload::Ek(chunk) => chunk,
                MessagePayload::EkCt1Ack(chunk) => chunk,
                MessagePayload::Ct1(chunk) => chunk,
                MessagePayload::Ct2(chunk) => chunk,
                // `None` and `Ct1Ack` carry no chunk; the tag byte alone identifies them.
                _ => {
                    return into;
                }
            },
            &mut into,
        );
        into
    }

    /// Parses a V1 wire-format message from the start of `from`.
    ///
    /// Returns the message, the `index` it was serialized with, and the number of bytes consumed.
    /// Bytes after the consumed prefix are not inspected.
    ///
    /// # Errors
    ///
    /// Returns [`Error::MsgDecode`] if the version byte is not `Version::V1`, the input is
    /// truncated, the index does not fit in a `u32`, the type tag is unknown, or decoding a
    /// varint or chunk fails.
    pub fn deserialize(from: &SerializedMessage) -> Result<(Self, u32, usize), Error> {
        hax_lib::fstar!("admit()");
        if from.is_empty() || from[0] != Version::V1.into() {
            return Err(Error::MsgDecode);
        }
        let mut at = 1usize;
        let epoch = decode_varint(from, &mut at)? as Epoch;
        let index: u32 = decode_varint(from, &mut at)?
            .try_into()
            .map_err(|_| Error::MsgDecode)?;
        // The input may end right after the index varint. Indexing `from[at]` unchecked would
        // panic on such truncated (untrusted) input instead of reporting a decode error.
        if at >= from.len() {
            return Err(Error::MsgDecode);
        }
        let msg_type = MessageType::try_from(from[at]).map_err(|_| Error::MsgDecode)?;
        at += 1;
        let payload = match msg_type {
            MessageType::None => MessagePayload::None,
            MessageType::Ct1Ack => MessagePayload::Ct1Ack(true),
            MessageType::Hdr => MessagePayload::Hdr(decode_chunk(from, &mut at)?),
            MessageType::Ek => MessagePayload::Ek(decode_chunk(from, &mut at)?),
            MessageType::EkCt1Ack => MessagePayload::EkCt1Ack(decode_chunk(from, &mut at)?),
            MessageType::Ct1 => MessagePayload::Ct1(decode_chunk(from, &mut at)?),
            MessageType::Ct2 => MessagePayload::Ct2(decode_chunk(from, &mut at)?),
        };
        Ok((Self { epoch, payload }, index, at))
    }
}
#[cfg(test)]
mod test {
    use super::{decode_varint, encode_varint};
    use rand::RngCore;
    use rand::TryRngCore;
    use rand_core::OsRng;

    #[test]
    fn encoding_varint() {
        let mut v = vec![];
        encode_varint(0x012C, &mut v);
        assert_eq!(&v, &[0xAC, 0x02][..]);
    }

    #[test]
    fn decoding_varint() {
        let v = vec![0xFF, 0xAC, 0x02, 0xFF];
        let mut at = 1usize;
        assert_eq!(0x012C, decode_varint(&v, &mut at).unwrap());
        assert_eq!(at, 3);
    }

    /// Deterministic boundary values that a random roundtrip is unlikely to hit, including the
    /// single-byte limits and the maximal 10-byte encoding of `u64::MAX`.
    #[test]
    fn varint_boundaries() {
        let cases: [(u64, &[u8]); 5] = [
            (0, &[0x00]),
            (0x7F, &[0x7F]),
            (0x80, &[0x80, 0x01]),
            (0x012C, &[0xAC, 0x02]),
            (u64::MAX, &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01]),
        ];
        for &(value, expected) in cases.iter() {
            let mut v = vec![];
            encode_varint(value, &mut v);
            assert_eq!(&v[..], expected);
            let mut at = 0usize;
            assert_eq!(decode_varint(&v, &mut at).unwrap(), value);
            assert_eq!(at, v.len());
        }
    }

    /// Input that ends before a terminating byte must be reported as an error.
    #[test]
    fn decoding_truncated_varint_fails() {
        let dangling_continuation = vec![0x80];
        let mut at = 0usize;
        assert!(decode_varint(&dangling_continuation, &mut at).is_err());

        let empty: Vec<u8> = vec![];
        let mut at = 0usize;
        assert!(decode_varint(&empty, &mut at).is_err());
    }

    #[test]
    fn roundtrip_varint() {
        let mut rng = OsRng.unwrap_err();
        for _i in 0..10000 {
            let u = rng.next_u64();
            let mut v = vec![];
            encode_varint(u, &mut v);
            let mut at = 0usize;
            assert_eq!(u, decode_varint(&v, &mut at).unwrap());
            assert_eq!(at, v.len());
        }
    }
}