1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
use std::fmt;
use std::str::FromStr;

use derive_more::{Display, Error};
use rand::rngs::OsRng;
use rand::RngCore;

/// Error returned when a secret key cannot be produced: the requested key length is out of
/// range, a byte slice does not have the required length, or a string is not valid hex.
#[derive(Debug, Display, Error)]
pub struct SecretKeyError;

impl From<SecretKeyError> for std::io::Error {
    /// Converts a [`SecretKeyError`] into an I/O error of kind `InvalidData` with a fixed message.
    fn from(_err: SecretKeyError) -> Self {
        use std::io::ErrorKind;

        let message = "not valid secret key format";
        Self::new(ErrorKind::InvalidData, message)
    }
}

/// Represents a 16-byte (128-bit) secret key, i.e. a [`SecretKey`] with `N = 16`
pub type SecretKey16 = SecretKey<16>;

/// Represents a 24-byte (192-bit) secret key, i.e. a [`SecretKey`] with `N = 24`
pub type SecretKey24 = SecretKey<24>;

/// Represents a 32-byte (256-bit) secret key, i.e. a [`SecretKey`] with `N = 32`
pub type SecretKey32 = SecretKey<32>;

/// Represents a secret key used with transport encryption and authentication
///
/// The key's `N` bytes are stored inline as a `[u8; N]` array, so the size is fixed by the type
/// (see the `SecretKey16`, `SecretKey24`, and `SecretKey32` aliases for common sizes).
///
/// NOTE(review): the derived `PartialEq` is ordinary byte equality and is not guaranteed to run
/// in constant time — confirm keys are never compared where timing could leak information.
#[derive(Clone, PartialEq, Eq)]
pub struct SecretKey<const N: usize>([u8; N]);

impl<const N: usize> fmt::Debug for SecretKey<N> {
    /// Formats as `SecretKey("**OMITTED**")` so key bytes never appear in debug output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // A `&str` placeholder debug-prints identically to a `String` without allocating one.
        f.debug_tuple("SecretKey").field(&"**OMITTED**").finish()
    }
}

impl<const N: usize> Default for SecretKey<N> {
    /// Creates a new, randomly generated secret key of `N` bytes
    ///
    /// # Panics
    ///
    /// Panics if `N` is less than 1 or greater than `isize::MAX`, because
    /// [`SecretKey::generate`] rejects those sizes.
    fn default() -> Self {
        // An invalid `N` is a programming error fixed at compile time, so panicking is acceptable.
        Self::generate().expect("SecretKey size N must be between 1 and isize::MAX bytes")
    }
}

impl<const N: usize> SecretKey<N> {
    /// Returns a byte slice over the key's bytes
    pub fn unprotected_as_bytes(&self) -> &[u8] {
        &self.0
    }

    /// Returns a reference to the array of the key's bytes
    pub fn unprotected_as_byte_array(&self) -> &[u8; N] {
        &self.0
    }

    /// Consumes the secret key and returns the array of the key's bytes
    pub fn unprotected_into_byte_array(self) -> [u8; N] {
        self.0
    }

    /// Consumes the secret key and returns the key's bytes as a [`HeapSecretKey`]
    pub fn into_heap_secret_key(self) -> HeapSecretKey {
        HeapSecretKey(self.0.to_vec())
    }

    /// Returns the length of the key in bytes, which is always `N`
    #[allow(clippy::len_without_is_empty)] // the length is the type parameter `N`, so `is_empty` would be a constant
    pub fn len(&self) -> usize {
        N
    }

    /// Generates a new secret key filled with bytes from `OsRng`
    ///
    /// # Errors
    ///
    /// Returns [`SecretKeyError`] if `N` is less than 1 or greater than `isize::MAX`.
    pub fn generate() -> Result<Self, SecretKeyError> {
        // Limitation described in https://github.com/orion-rs/orion/issues/130
        if N < 1 || N > (isize::MAX as usize) {
            return Err(SecretKeyError);
        }

        let mut key = [0; N];
        OsRng.fill_bytes(&mut key);

        Ok(Self(key))
    }

    /// Creates the key by copying the given byte slice
    ///
    /// # Errors
    ///
    /// Returns [`SecretKeyError`] if `slice.len()` is not exactly `N`.
    pub fn from_slice(slice: &[u8]) -> Result<Self, SecretKeyError> {
        // The std array conversion performs both the length check and the copy. It is written
        // fully qualified so it does not depend on `TryFrom` being in the prelude (edition).
        <[u8; N] as std::convert::TryFrom<&[u8]>>::try_from(slice)
            .map(Self)
            .map_err(|_| SecretKeyError)
    }
}

impl<const N: usize> From<[u8; N]> for SecretKey<N> {
    fn from(arr: [u8; N]) -> Self {
        Self(arr)
    }
}

impl<const N: usize> FromStr for SecretKey<N> {
    type Err = SecretKeyError;

    /// Parses a hex string into an `N`-byte secret key.
    ///
    /// Fails if the input is not valid hex or does not decode to exactly `N` bytes.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        hex::decode(s)
            .map_err(|_| SecretKeyError)
            .and_then(|decoded| Self::from_slice(&decoded))
    }
}

impl<const N: usize> fmt::Display for SecretKey<N> {
    /// Renders the key's `N` bytes as a hex string produced by `hex::encode`.
    ///
    /// Formatter options such as width or fill are not applied to the output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let encoded = hex::encode(&self.0);
        f.write_str(&encoded)
    }
}

/// Represents a secret key used with transport encryption and authentication that is stored on the
/// heap, so its length is determined at runtime rather than by a type parameter
///
/// NOTE(review): the derived `PartialEq` is ordinary byte equality and is not guaranteed to run
/// in constant time — confirm keys are never compared where timing could leak information.
#[derive(Clone, PartialEq, Eq)]
pub struct HeapSecretKey(Vec<u8>);

impl fmt::Debug for HeapSecretKey {
    /// Formats as `HeapSecretKey("**OMITTED**")` so key bytes never appear in debug output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // A `&str` placeholder debug-prints identically to a `String` without allocating one.
        f.debug_tuple("HeapSecretKey").field(&"**OMITTED**").finish()
    }
}

impl HeapSecretKey {
    /// Returns a byte slice over the key's bytes
    pub fn unprotected_as_bytes(&self) -> &[u8] {
        &self.0
    }

    /// Consumes the secret key and returns the key's bytes
    ///
    /// The owned buffer is moved out rather than cloned, so no second heap copy of the key
    /// material is created.
    pub fn unprotected_into_bytes(self) -> Vec<u8> {
        self.0
    }

    /// Returns the length of the key in bytes
    #[allow(clippy::len_without_is_empty)] // clippy would otherwise demand an `is_empty` companion
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// Generates a random key of `n` bytes in length, filled from `OsRng`.
    ///
    /// # Errors
    ///
    /// Returns [`SecretKeyError`] if `n` < 1 or `n` > `isize::MAX`.
    pub fn generate(n: usize) -> Result<Self, SecretKeyError> {
        // Limitation described in https://github.com/orion-rs/orion/issues/130
        if n < 1 || n > (isize::MAX as usize) {
            return Err(SecretKeyError);
        }

        // Allocate the whole buffer once and fill it in a single call. Growing a `Vec`
        // incrementally would reallocate, leaving stale copies of key bytes in freed memory.
        let mut key = vec![0u8; n];
        OsRng.fill_bytes(&mut key);

        Ok(Self(key))
    }
}

impl From<Vec<u8>> for HeapSecretKey {
    fn from(bytes: Vec<u8>) -> Self {
        Self(bytes)
    }
}

impl<const N: usize> From<[u8; N]> for HeapSecretKey {
    fn from(arr: [u8; N]) -> Self {
        Self::from(arr.to_vec())
    }
}

impl<const N: usize> From<SecretKey<N>> for HeapSecretKey {
    fn from(key: SecretKey<N>) -> Self {
        key.into_heap_secret_key()
    }
}

impl FromStr for HeapSecretKey {
    type Err = SecretKeyError;

    /// Parse a str of hex as secret key on heap
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(Self(hex::decode(s).map_err(|_| SecretKeyError)?))
    }
}

impl fmt::Display for HeapSecretKey {
    /// Renders the key's bytes, whatever their length, as a hex string produced by `hex::encode`.
    ///
    /// Formatter options such as width or fill are not applied to the output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let encoded = hex::encode(&self.0);
        f.write_str(&encoded)
    }
}

impl<const N: usize> PartialEq<[u8; N]> for HeapSecretKey {
    fn eq(&self, other: &[u8; N]) -> bool {
        self.0.eq(other)
    }
}

impl<const N: usize> PartialEq<HeapSecretKey> for [u8; N] {
    fn eq(&self, other: &HeapSecretKey) -> bool {
        other.eq(self)
    }
}

impl<const N: usize> PartialEq<HeapSecretKey> for &[u8; N] {
    fn eq(&self, other: &HeapSecretKey) -> bool {
        other.eq(*self)
    }
}

impl PartialEq<[u8]> for HeapSecretKey {
    fn eq(&self, other: &[u8]) -> bool {
        self.0.eq(other)
    }
}

impl PartialEq<HeapSecretKey> for [u8] {
    fn eq(&self, other: &HeapSecretKey) -> bool {
        other.eq(self)
    }
}

impl PartialEq<HeapSecretKey> for &[u8] {
    fn eq(&self, other: &HeapSecretKey) -> bool {
        other.eq(*self)
    }
}

impl PartialEq<String> for HeapSecretKey {
    fn eq(&self, other: &String) -> bool {
        self.0.eq(other.as_bytes())
    }
}

impl PartialEq<HeapSecretKey> for String {
    fn eq(&self, other: &HeapSecretKey) -> bool {
        other.eq(self)
    }
}

impl PartialEq<HeapSecretKey> for &String {
    fn eq(&self, other: &HeapSecretKey) -> bool {
        other.eq(*self)
    }
}

impl PartialEq<str> for HeapSecretKey {
    fn eq(&self, other: &str) -> bool {
        self.0.eq(other.as_bytes())
    }
}

impl PartialEq<HeapSecretKey> for str {
    fn eq(&self, other: &HeapSecretKey) -> bool {
        other.eq(self)
    }
}

impl PartialEq<HeapSecretKey> for &str {
    fn eq(&self, other: &HeapSecretKey) -> bool {
        other.eq(*self)
    }
}

#[cfg(test)]
mod tests {
    use test_log::test;

    use super::*;

    #[test]
    fn secret_key_should_be_able_to_be_generated() {
        // A zero-sized key must be rejected
        assert!(SecretKey::<0>::generate().is_err());

        let smallest = SecretKey::<1>::generate().unwrap();
        assert_eq!(smallest.len(), 1);

        // NOTE: Sizes of isize::MAX (or one past it) are intentionally not exercised because
        //       generating them takes far too long
        let larger = SecretKey::<100>::generate().unwrap();
        assert_eq!(larger.len(), 100);
    }

    #[test]
    fn heap_secret_key_should_be_able_to_be_generated() {
        // A zero-length key must be rejected
        assert!(HeapSecretKey::generate(0).is_err());

        // NOTE: Sizes of isize::MAX (or one past it) are intentionally not exercised because
        //       generating them takes far too long
        for &size in &[1usize, 100] {
            let key = HeapSecretKey::generate(size).unwrap();
            assert_eq!(key.len(), size);
        }
    }
}