Skip to main content

orion/test_framework/
streamcipher_interface.rs

1// MIT License
2
3// Copyright (c) 2019-2026 The orion Developers
4
5// Permission is hereby granted, free of charge, to any person obtaining a copy
6// of this software and associated documentation files (the "Software"), to deal
7// in the Software without restriction, including without limitation the rights
8// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9// copies of the Software, and to permit persons to whom the Software is
10// furnished to do so, subject to the following conditions:
11
12// The above copyright notice and this permission notice shall be included in
13// all copies or substantial portions of the Software.
14
15// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21// SOFTWARE.
22
23#![allow(non_snake_case)]
24
25#[cfg(feature = "safe_api")]
26use crate::errors::UnknownCryptoError;
27
28#[cfg(test)]
29#[cfg(feature = "safe_api")]
/// Types that can produce a value of themselves for use in tests.
///
/// `test_diff_params_diff_output` requires this of its `Key` and `Nonce`
/// type parameters so it can create two values of each and assert that they
/// differ.
pub trait TestingRandom {
    /// Randomly generate self.
    fn gen() -> Self;
}
34
35#[cfg(feature = "safe_api")]
36/// Test runner for stream ciphers.
37pub fn StreamCipherTestRunner<Encryptor, Decryptor, Key, Nonce>(
38    encryptor: Encryptor,
39    decryptor: Decryptor,
40    key: Key,
41    nonce: Nonce,
42    counter: u32,
43    input: &[u8],
44    expected_ct: Option<&[u8]>,
45) where
46    Encryptor: Fn(&Key, &Nonce, u32, &[u8], &mut [u8]) -> Result<(), UnknownCryptoError>,
47    Decryptor: Fn(&Key, &Nonce, u32, &[u8], &mut [u8]) -> Result<(), UnknownCryptoError>,
48{
49    if !input.is_empty() {
50        encrypt_decrypt_out_length(&encryptor, &decryptor, &key, &nonce, input);
51        encrypt_decrypt_equals_expected(
52            &encryptor,
53            &decryptor,
54            &key,
55            &nonce,
56            counter,
57            input,
58            expected_ct,
59        );
60    }
61
62    encrypt_decrypt_input_empty(&encryptor, &decryptor, &key, &nonce);
63    initial_counter_overflow_err(&encryptor, &decryptor, &key, &nonce);
64    initial_counter_max_ok(&encryptor, &decryptor, &key, &nonce);
65}
66
67#[cfg(feature = "safe_api")]
/// Number of times the initial counter passed to encrypt()/decrypt() is
/// increased while processing an input of length `a`.
///
/// Lengths are measured in units of 64: a trailing partial unit counts as a
/// whole one, and the first unit is processed without increasing the counter.
/// Lengths of at most 64 therefore yield `0`.
fn counter_increase_times(a: f32) -> u32 {
    const BLOCK: f32 = 64f32;

    // A single (possibly partial) block never bumps the counter. Returning
    // here also keeps the `- 1` below from underflowing.
    if a <= BLOCK {
        return 0;
    }

    let quotient = a / BLOCK;
    let whole_blocks = quotient.floor();
    assert!(quotient >= whole_blocks);

    // Round a trailing partial block up to a full one.
    let blocks = if quotient > whole_blocks {
        quotient.ceil()
    } else {
        quotient
    };

    // The first block is processed with the initial counter itself.
    (blocks as u32) - 1
}
89
90#[cfg(feature = "safe_api")]
91fn return_if_counter_will_overflow<Encryptor, Decryptor, Key, Nonce>(
92    encryptor: &Encryptor,
93    decryptor: &Decryptor,
94    key: &Key,
95    nonce: &Nonce,
96    counter: u32,
97    input: &[u8],
98) -> bool
99where
100    Encryptor: Fn(&Key, &Nonce, u32, &[u8], &mut [u8]) -> Result<(), UnknownCryptoError>,
101    Decryptor: Fn(&Key, &Nonce, u32, &[u8], &mut [u8]) -> Result<(), UnknownCryptoError>,
102{
103    assert!(!input.is_empty());
104    let mut dst_out = vec![0u8; input.len()];
105
106    // Overflow will occur and the operation should fail
107    let enc_res = encryptor(key, nonce, counter, &[0u8; 0], &mut dst_out).is_err();
108    let dec_res = decryptor(key, nonce, counter, &[0u8; 0], &mut dst_out).is_err();
109
110    enc_res && dec_res
111}
112
113#[cfg(feature = "safe_api")]
114fn encrypt_decrypt_input_empty<Encryptor, Decryptor, Key, Nonce>(
115    encryptor: &Encryptor,
116    decryptor: &Decryptor,
117    key: &Key,
118    nonce: &Nonce,
119) where
120    Encryptor: Fn(&Key, &Nonce, u32, &[u8], &mut [u8]) -> Result<(), UnknownCryptoError>,
121    Decryptor: Fn(&Key, &Nonce, u32, &[u8], &mut [u8]) -> Result<(), UnknownCryptoError>,
122{
123    let mut dst_out = [0u8; 64];
124    assert!(encryptor(key, nonce, 0, &[0u8; 0], &mut dst_out).is_err());
125    assert!(decryptor(key, nonce, 0, &[0u8; 0], &mut dst_out).is_err());
126}
127
128#[cfg(feature = "safe_api")]
129fn encrypt_decrypt_out_length<Encryptor, Decryptor, Key, Nonce>(
130    encryptor: &Encryptor,
131    decryptor: &Decryptor,
132    key: &Key,
133    nonce: &Nonce,
134    input: &[u8],
135) where
136    Encryptor: Fn(&Key, &Nonce, u32, &[u8], &mut [u8]) -> Result<(), UnknownCryptoError>,
137    Decryptor: Fn(&Key, &Nonce, u32, &[u8], &mut [u8]) -> Result<(), UnknownCryptoError>,
138{
139    assert!(!input.is_empty());
140
141    let mut dst_out_empty = vec![0u8; 0];
142    assert!(encryptor(key, nonce, 0, input, &mut dst_out_empty).is_err());
143    assert!(decryptor(key, nonce, 0, input, &mut dst_out_empty).is_err());
144
145    let mut dst_out_less = vec![0u8; input.len() - 1];
146    assert!(encryptor(key, nonce, 0, input, &mut dst_out_less).is_err());
147    assert!(decryptor(key, nonce, 0, input, &mut dst_out_less).is_err());
148
149    let mut dst_out_exact = vec![0u8; input.len()];
150    assert!(encryptor(key, nonce, 0, input, &mut dst_out_exact).is_ok());
151    assert!(decryptor(key, nonce, 0, input, &mut dst_out_exact).is_ok());
152
153    let mut dst_out_greater = vec![0u8; input.len() + 1];
154    assert!(encryptor(key, nonce, 0, input, &mut dst_out_greater).is_ok());
155    assert!(decryptor(key, nonce, 0, input, &mut dst_out_greater).is_ok());
156}
157
158#[cfg(feature = "safe_api")]
159/// Test that encrypting and decrypting produces expected plaintext/ciphertext.
160fn encrypt_decrypt_equals_expected<Encryptor, Decryptor, Key, Nonce>(
161    encryptor: &Encryptor,
162    decryptor: &Decryptor,
163    key: &Key,
164    nonce: &Nonce,
165    counter: u32,
166    input: &[u8],
167    expected_ct: Option<&[u8]>,
168) where
169    Encryptor: Fn(&Key, &Nonce, u32, &[u8], &mut [u8]) -> Result<(), UnknownCryptoError>,
170    Decryptor: Fn(&Key, &Nonce, u32, &[u8], &mut [u8]) -> Result<(), UnknownCryptoError>,
171{
172    assert!(!input.is_empty());
173
174    // Check if the counter would overflow. If yes, ensure that both encryptor and
175    // decryptor returned errors.
176    if counter_increase_times(input.len() as f32)
177        .checked_add(counter)
178        .is_none()
179    {
180        assert!(return_if_counter_will_overflow(
181            encryptor, decryptor, key, nonce, counter, input
182        ));
183
184        return;
185    }
186
187    let mut dst_out_ct = vec![0u8; input.len()];
188    encryptor(key, nonce, counter, input, &mut dst_out_ct).unwrap();
189    if let Some(expected_result) = expected_ct {
190        assert_eq!(expected_result, &dst_out_ct[..]);
191    }
192
193    let mut dst_out_pt = vec![0u8; input.len()];
194    decryptor(key, nonce, counter, &dst_out_ct, &mut dst_out_pt).unwrap();
195    assert_eq!(input, &dst_out_pt[..]);
196    if let Some(expected_result) = expected_ct {
197        decryptor(key, nonce, counter, expected_result, &mut dst_out_pt).unwrap();
198        assert_eq!(input, &dst_out_pt[..]);
199    }
200}
201
202#[cfg(feature = "safe_api")]
203/// Test that an initial counter will not overflow the internal.
204fn initial_counter_overflow_err<Encryptor, Decryptor, Key, Nonce>(
205    encryptor: &Encryptor,
206    decryptor: &Decryptor,
207    key: &Key,
208    nonce: &Nonce,
209) where
210    Encryptor: Fn(&Key, &Nonce, u32, &[u8], &mut [u8]) -> Result<(), UnknownCryptoError>,
211    Decryptor: Fn(&Key, &Nonce, u32, &[u8], &mut [u8]) -> Result<(), UnknownCryptoError>,
212{
213    let mut dst_out = [0u8; 128];
214    assert!(encryptor(
215        key,
216        nonce,
217        u32::MAX,
218        &[0u8; 65], //  CHACHA_BLOCKSIZE + 1 one to trigger internal block counter addition.
219        &mut dst_out
220    )
221    .is_err());
222    assert!(decryptor(
223        key,
224        nonce,
225        u32::MAX,
226        &[0u8; 65], //  CHACHA_BLOCKSIZE + 1 one to trigger internal block counter addition.
227        &mut dst_out
228    )
229    .is_err());
230}
231
232#[cfg(feature = "safe_api")]
233/// Test that processing one block does not fail on the largest possible initial block counter.
234fn initial_counter_max_ok<Encryptor, Decryptor, Key, Nonce>(
235    encryptor: &Encryptor,
236    decryptor: &Decryptor,
237    key: &Key,
238    nonce: &Nonce,
239) where
240    Encryptor: Fn(&Key, &Nonce, u32, &[u8], &mut [u8]) -> Result<(), UnknownCryptoError>,
241    Decryptor: Fn(&Key, &Nonce, u32, &[u8], &mut [u8]) -> Result<(), UnknownCryptoError>,
242{
243    let mut dst_out = [0u8; 64];
244    assert!(encryptor(
245        key,
246        nonce,
247        u32::MAX,
248        &[0u8; 64], // Only needs to process one keystream
249        &mut dst_out
250    )
251    .is_ok());
252    assert!(decryptor(
253        key,
254        nonce,
255        u32::MAX,
256        &[0u8; 64], // Only needs to process one keystream
257        &mut dst_out
258    )
259    .is_ok());
260}
261
262#[cfg(test)]
263#[cfg(feature = "safe_api")]
264/// Test that encrypting using different secret-key/nonce/initial-counter combinations yields different
265/// ciphertexts.
266pub fn test_diff_params_diff_output<Encryptor, Decryptor, Key, Nonce>(
267    encryptor: &Encryptor,
268    decryptor: &Decryptor,
269) where
270    Key: TestingRandom + PartialEq<Key>,
271    Nonce: TestingRandom + PartialEq<Nonce>,
272    Encryptor: Fn(&Key, &Nonce, u32, &[u8], &mut [u8]) -> Result<(), UnknownCryptoError>,
273    Decryptor: Fn(&Key, &Nonce, u32, &[u8], &mut [u8]) -> Result<(), UnknownCryptoError>,
274{
275    let input = &[0u8; 16];
276
277    let sk1 = Key::gen();
278    let sk2 = Key::gen();
279    assert!(sk1 != sk2);
280
281    let n1 = Nonce::gen();
282    let n2 = Nonce::gen();
283    assert!(n1 != n2);
284
285    let c1 = 0u32;
286    let c2 = 1u32;
287
288    let mut dst_out_ct = vec![0u8; input.len()];
289    let mut dst_out_pt = vec![0u8; input.len()];
290
291    // Different secret key
292    encryptor(&sk1, &n1, c1, input, &mut dst_out_ct).unwrap();
293    decryptor(&sk2, &n1, c1, &dst_out_ct, &mut dst_out_pt).unwrap();
294    assert_ne!(&dst_out_pt[..], input);
295
296    // Different nonce
297    encryptor(&sk1, &n1, c1, input, &mut dst_out_ct).unwrap();
298    decryptor(&sk1, &n2, c1, &dst_out_ct, &mut dst_out_pt).unwrap();
299    assert_ne!(&dst_out_pt[..], input);
300
301    // Different initial counter
302    encryptor(&sk1, &n1, c1, input, &mut dst_out_ct).unwrap();
303    decryptor(&sk1, &n1, c2, &dst_out_ct, &mut dst_out_pt).unwrap();
304    assert_ne!(&dst_out_pt[..], input);
305}