1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
use super::DataType;

use super::Error;
use super::Result;

use cfg_if::cfg_if;
use std::convert::TryFrom;
use std::io::Cursor;

use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use zeroize::Zeroize;

#[cfg(feature = "fuzz")]
use arbitrary::Arbitrary;

const SIGNATURE: u16 = 0x0C0D;

pub trait HeaderType {
    cfg_if! {
        if #[cfg(feature = "fuzz")] {
            type Version: Into<u16> + TryFrom<u16> + Clone + Default + Zeroize + std::fmt::Debug + Arbitrary;
            type Subtype: Into<u16> + TryFrom<u16> + Clone + Default + Zeroize + std::fmt::Debug + Arbitrary;
        }
        else {
            type Version: Into<u16> + TryFrom<u16> + Clone + Default + Zeroize + std::fmt::Debug;
            type Subtype: Into<u16> + TryFrom<u16> + Clone + Default + Zeroize + std::fmt::Debug;
        }
    }

    fn data_type() -> DataType;

    fn default_version() -> Self::Version {
        Default::default()
    }

    fn subtype() -> Self::Subtype {
        Default::default()
    }
}

// Default values, used for len()
impl HeaderType for () {
    type Version = super::CiphertextVersion;
    type Subtype = super::CiphertextSubtype;

    fn data_type() -> DataType {
        super::DataType::None
    }
}

#[derive(Clone, Debug)]
#[cfg_attr(feature = "fuzz", derive(Arbitrary))]
pub struct Header<M>
where
    M: HeaderType,
{
    pub signature: u16,
    pub data_type: DataType,
    pub data_subtype: M::Subtype,
    pub version: M::Version,
}

impl<M> TryFrom<&[u8]> for Header<M>
where
    M: HeaderType,
{
    type Error = crate::error::Error;
    fn try_from(data: &[u8]) -> Result<Self> {
        let mut data_cursor = Cursor::new(data);
        let signature = data_cursor.read_u16::<LittleEndian>()?;
        let data_type = data_cursor.read_u16::<LittleEndian>()?;
        let data_subtype = data_cursor.read_u16::<LittleEndian>()?;
        let version = data_cursor.read_u16::<LittleEndian>()?;

        if signature != SIGNATURE {
            return Err(Error::InvalidSignature);
        }

        let data_type = match DataType::try_from(data_type) {
            Ok(d) => d,
            Err(_) => return Err(Error::UnknownType),
        };

        let data_subtype = match M::Subtype::try_from(data_subtype) {
            Ok(d) => d,
            Err(_) => return Err(Error::UnknownSubtype),
        };

        let version = match M::Version::try_from(version) {
            Ok(d) => d,
            Err(_) => return Err(Error::UnknownVersion),
        };

        if data_type != M::data_type() {
            return Err(Error::InvalidData);
        };

        Ok(Header {
            signature,
            data_type,
            data_subtype,
            version,
        })
    }
}

impl<M> From<Header<M>> for Vec<u8>
where
    M: HeaderType,
{
    fn from(header: Header<M>) -> Vec<u8> {
        let mut data = Vec::with_capacity(8);
        data.write_u16::<LittleEndian>(header.signature).unwrap();
        data.write_u16::<LittleEndian>(header.data_type.into())
            .unwrap();
        data.write_u16::<LittleEndian>(header.data_subtype.into())
            .unwrap();
        data.write_u16::<LittleEndian>(header.version.into())
            .unwrap();
        data
    }
}

impl<M> Default for Header<M>
where
    M: HeaderType,
{
    fn default() -> Self {
        Header {
            signature: SIGNATURE,
            data_type: M::data_type(),
            data_subtype: M::subtype(),
            version: M::default_version(),
        }
    }
}

impl Header<()> {
    pub fn len() -> usize {
        8
    }
}