use std::convert::TryFrom;
use crate::{
packet::Tag,
Error,
Result
};
use crate::packet::header::BodyLength;
/// State that is shared by both the new and the old CTB formats.
#[derive(Debug, Clone)]
struct CTBCommon {
    /// The packet tag encoded in the CTB.
    tag: Tag,
}
/// A new-format CTB.
///
/// Only the tag is stored; `CTB::try_from` fills it from the low six
/// bits of the header byte.
#[derive(Debug, Clone)]
pub struct CTBNew {
    /// Fields shared with `CTBOld`.
    common: CTBCommon,
}
assert_send_and_sync!(CTBNew);

impl CTBNew {
    /// Builds a new-format CTB that carries `tag`.
    pub fn new(tag: Tag) -> Self {
        let common = CTBCommon { tag };
        Self { common }
    }

    /// Returns the packet tag stored in this CTB.
    pub fn tag(&self) -> Tag {
        let CTBCommon { tag } = self.common;
        tag
    }
}
/// How an old-format packet header encodes its body length.
///
/// The value is carried in the two low-order bits of an old-format
/// CTB. It converts from and to that two-bit value with
/// `TryFrom<u8>` and `From<PacketLengthType> for u8`.
///
/// The enum is fieldless, so it also derives `Eq` and `Hash` and can
/// be used as a map or set key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum PacketLengthType {
    /// One-octet length; chosen for body lengths up to 0xFF.
    OneOctet,
    /// Two-octet length; chosen for body lengths 0x100 through 0xFFFF.
    TwoOctets,
    /// Four-octet length; chosen for larger body lengths.
    FourOctets,
    /// No fixed length; chosen for `BodyLength::Indeterminate`.
    Indeterminate,
}
assert_send_and_sync!(PacketLengthType);
impl TryFrom<u8> for PacketLengthType {
type Error = anyhow::Error;
fn try_from(u: u8) -> Result<Self> {
match u {
0 => Ok(PacketLengthType::OneOctet),
1 => Ok(PacketLengthType::TwoOctets),
2 => Ok(PacketLengthType::FourOctets),
3 => Ok(PacketLengthType::Indeterminate),
_ => Err(Error::InvalidArgument(
format!("Invalid packet length: {}", u)).into()),
}
}
}
impl From<PacketLengthType> for u8 {
    /// Encodes a length type as the two-bit value that
    /// `PacketLengthType::try_from` accepts.
    fn from(length_type: PacketLengthType) -> u8 {
        match length_type {
            PacketLengthType::OneOctet => 0b00,
            PacketLengthType::TwoOctets => 0b01,
            PacketLengthType::FourOctets => 0b10,
            PacketLengthType::Indeterminate => 0b11,
        }
    }
}
/// An old-format CTB.
///
/// In addition to the tag, an old-format CTB records how the body
/// length that follows it is encoded.
#[derive(Debug, Clone)]
pub struct CTBOld {
    /// Fields shared with `CTBNew`.
    common: CTBCommon,
    /// How the body length is encoded.
    length_type: PacketLengthType,
}
assert_send_and_sync!(CTBOld);

impl CTBOld {
    /// Constructs an old-format CTB for `tag`.
    ///
    /// The length type is derived from `length`. For a full length,
    /// the smallest encoding that can hold the value is chosen.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidArgument` if:
    ///
    /// - the numeric value of `tag` is greater than 15, or
    /// - `length` is `BodyLength::Partial`, which old-format headers
    ///   cannot express.
    pub fn new(tag: Tag, length: BodyLength) -> Result<Self> {
        let n: u8 = tag.into();
        // `CTB::try_from` reads the old-format tag from only four bits
        // (bits 5-2), so larger tags cannot be represented.
        if n > 15 {
            return Err(Error::InvalidArgument(
                format!("Only tags 0-15 are supported, got: {:?} ({})",
                        tag, n)).into());
        }

        let length_type = match length {
            // Pick the narrowest fixed-width encoding that fits `l`.
            BodyLength::Full(l) => match l {
                0 ..= 0xFF => PacketLengthType::OneOctet,
                0x1_00 ..= 0xFF_FF => PacketLengthType::TwoOctets,
                _ => PacketLengthType::FourOctets,
            },
            BodyLength::Partial(_) =>
                return Err(Error::InvalidArgument(
                    "Partial body lengths are not supported for old format packets".
                    into()).into()),
            BodyLength::Indeterminate =>
                PacketLengthType::Indeterminate,
        };

        Ok(CTBOld {
            common: CTBCommon {
                tag,
            },
            length_type,
        })
    }

    /// Returns the packet tag stored in this CTB.
    pub fn tag(&self) -> Tag {
        self.common.tag
    }

    /// Returns how this header encodes the body length.
    pub fn length_type(&self) -> PacketLengthType {
        self.length_type
    }
}
/// A CTB in either of its two formats.
///
/// `CTB::try_from` chooses the variant from bit 6 of the header byte.
#[derive(Clone, Debug)]
pub enum CTB {
    /// A new-format CTB (bit 6 set).
    New(CTBNew),
    /// An old-format CTB (bit 6 clear).
    Old(CTBOld),
}
assert_send_and_sync!(CTB);

impl CTB {
    /// Creates a CTB for `tag`, always using the new format.
    pub fn new(tag: Tag) -> Self {
        Self::New(CTBNew::new(tag))
    }

    /// Returns the packet tag, whichever format this CTB uses.
    pub fn tag(&self) -> Tag {
        match *self {
            Self::New(ref new_ctb) => new_ctb.tag(),
            Self::Old(ref old_ctb) => old_ctb.tag(),
        }
    }
}
impl TryFrom<u8> for CTB {
type Error = anyhow::Error;
fn try_from(ptag: u8) -> Result<CTB> {
if ptag & 0b1000_0000 == 0 {
return Err(
Error::MalformedPacket(
format!("Malformed CTB: MSB of ptag ({:#010b}) not set{}.",
ptag,
if ptag == b'-' {
" (ptag is a dash, perhaps this is an \
ASCII-armor encoded message)"
} else {
""
})).into());
}
let new_format = ptag & 0b0100_0000 != 0;
let ctb = if new_format {
let tag = ptag & 0b0011_1111;
CTB::New(CTBNew {
common: CTBCommon {
tag: tag.into()
}})
} else {
let tag = (ptag & 0b0011_1100) >> 2;
let length_type = ptag & 0b0000_0011;
CTB::Old(CTBOld {
common: CTBCommon {
tag: tag.into(),
},
length_type: PacketLengthType::try_from(length_type)?,
})
};
Ok(ctb)
}
}
/// Exercises CTB parsing in both formats, its error paths, the
/// `PacketLengthType` conversions and `CTBOld::new`.
#[test]
fn ctb() {
    // 0x99 = 0b1001_1001: old format, tag 6, length type 1.
    match CTB::try_from(0x99).unwrap() {
        CTB::Old(ctb) => {
            assert_eq!(ctb.tag(), Tag::PublicKey);
            assert_eq!(ctb.length_type, PacketLengthType::TwoOctets);
        }
        CTB::New(_) => panic!("Expected an old format packet."),
    }

    // 0xa3 = 0b1010_0011: old format, tag 8, length type 3.
    match CTB::try_from(0xa3).unwrap() {
        CTB::Old(ctb) => {
            assert_eq!(ctb.tag(), Tag::CompressedData);
            assert_eq!(ctb.length_type, PacketLengthType::Indeterminate);
        }
        CTB::New(_) => panic!("Expected an old format packet."),
    }

    // 0xcb = 0b1100_1011: new format, tag 11.
    match CTB::try_from(0xcb).unwrap() {
        CTB::New(ctb) => assert_eq!(ctb.tag(), Tag::Literal),
        CTB::Old(_) => panic!("Expected a new format packet."),
    }

    // Bytes with the MSB clear are rejected, including the dash that
    // starts ASCII-armored input.
    assert!(CTB::try_from(0x00).is_err());
    assert!(CTB::try_from(0x7f).is_err());
    assert!(CTB::try_from(b'-').is_err());

    // The two-bit length type round-trips, and out-of-range values
    // are rejected.
    for n in 0..4u8 {
        assert_eq!(u8::from(PacketLengthType::try_from(n).unwrap()), n);
    }
    assert!(PacketLengthType::try_from(4).is_err());

    // CTBOld::new picks the narrowest length encoding.
    let one = CTBOld::new(Tag::Literal, BodyLength::Full(0xFF)).unwrap();
    assert_eq!(one.length_type(), PacketLengthType::OneOctet);
    let two = CTBOld::new(Tag::Literal, BodyLength::Full(0x1_00)).unwrap();
    assert_eq!(two.length_type(), PacketLengthType::TwoOctets);
    let four = CTBOld::new(Tag::Literal, BodyLength::Full(0x1_00_00)).unwrap();
    assert_eq!(four.length_type(), PacketLengthType::FourOctets);
    let indet = CTBOld::new(Tag::Literal, BodyLength::Indeterminate).unwrap();
    assert_eq!(indet.length_type(), PacketLengthType::Indeterminate);
    assert_eq!(indet.tag(), Tag::Literal);

    // Tags that do not fit in four bits cannot use the old format.
    assert!(CTBOld::new(Tag::from(17u8), BodyLength::Full(0)).is_err());
}