1#![allow(non_snake_case)]
24
25#[cfg(feature = "safe_api")]
26use crate::errors::UnknownCryptoError;
27
/// Test-only source of fresh instances for keys and nonces.
///
/// `test_diff_params_diff_output` calls `gen()` twice per type and asserts the
/// two results are unequal, so implementations are expected to produce a new
/// (presumably random) value on every call.
#[cfg(test)]
#[cfg(feature = "safe_api")]
pub trait TestingRandom {
    /// Returns a freshly generated value of `Self`.
    fn gen() -> Self;
}
34
/// Runs the shared stream-cipher test suite against an encryptor/decryptor pair.
///
/// When `input` is non-empty, output-buffer length handling and the
/// encrypt/decrypt round trip starting at `counter` are checked. The round trip
/// is compared against `expected_ct` when it is provided. Rejection of empty
/// input and the `u32::MAX` initial-counter boundary are always checked.
///
/// # Panics
///
/// Panics (via assertions) when any check fails.
#[cfg(feature = "safe_api")]
pub fn StreamCipherTestRunner<Encryptor, Decryptor, Key, Nonce>(
    encryptor: Encryptor,
    decryptor: Decryptor,
    key: Key,
    nonce: Nonce,
    counter: u32,
    input: &[u8],
    expected_ct: Option<&[u8]>,
) where
    Encryptor: Fn(&Key, &Nonce, u32, &[u8], &mut [u8]) -> Result<(), UnknownCryptoError>,
    Decryptor: Fn(&Key, &Nonce, u32, &[u8], &mut [u8]) -> Result<(), UnknownCryptoError>,
{
    let (enc, dec, k, n) = (&encryptor, &decryptor, &key, &nonce);

    // These checks need at least one byte of input to be meaningful.
    if !input.is_empty() {
        encrypt_decrypt_out_length(enc, dec, k, n, input);
        encrypt_decrypt_equals_expected(enc, dec, k, n, counter, input, expected_ct);
    }

    // Input-independent checks, always run.
    encrypt_decrypt_input_empty(enc, dec, k, n);
    initial_counter_overflow_err(enc, dec, k, n);
    initial_counter_max_ok(enc, dec, k, n);
}
66
/// Returns how many times the initial block counter is advanced while
/// processing `a` bytes of input in 64-byte blocks.
///
/// The first block uses the initial counter itself, so only the blocks after
/// it count. Any `a <= 64` gives `0`, and `a` in `(64k, 64(k+1)]` gives `k`.
///
/// NOTE(review): `a` is a byte length carried as `f32`, which represents
/// integers exactly only up to 2^24. Larger lengths may already be rounded at
/// the call site.
#[cfg(feature = "safe_api")]
fn counter_increase_times(a: f32) -> u32 {
    if a <= 64f32 {
        // Nothing past the first block; also keeps the `- 1` below from underflowing.
        0
    } else {
        let blocks = a / 64f32;
        let whole_blocks = blocks.floor();
        assert!(blocks >= whole_blocks);

        // A fractional part means there is a trailing partial block.
        let total_blocks = if blocks > whole_blocks {
            blocks.ceil() as u32
        } else {
            blocks as u32
        };

        // The first block does not advance the counter.
        total_blocks - 1
    }
}
89
/// Returns `true` if both `encryptor` and `decryptor` reject processing `input`
/// when starting from the initial `counter`.
///
/// `encrypt_decrypt_equals_expected` calls this when `input` is long enough
/// that the block counter would run past `u32::MAX`. A correct implementation
/// must then fail both operations instead of producing output.
///
/// # Panics
///
/// Panics if `input` is empty.
#[cfg(feature = "safe_api")]
fn return_if_counter_will_overflow<Encryptor, Decryptor, Key, Nonce>(
    encryptor: &Encryptor,
    decryptor: &Decryptor,
    key: &Key,
    nonce: &Nonce,
    counter: u32,
    input: &[u8],
) -> bool
where
    Encryptor: Fn(&Key, &Nonce, u32, &[u8], &mut [u8]) -> Result<(), UnknownCryptoError>,
    Decryptor: Fn(&Key, &Nonce, u32, &[u8], &mut [u8]) -> Result<(), UnknownCryptoError>,
{
    assert!(!input.is_empty());
    let mut dst_out = vec![0u8; input.len()];

    // Use the real `input`. An empty slice would make both calls fail on the
    // empty-input check alone, so this would return `true` whether or not
    // counter overflow is actually detected.
    let enc_res = encryptor(key, nonce, counter, input, &mut dst_out).is_err();
    let dec_res = decryptor(key, nonce, counter, input, &mut dst_out).is_err();

    enc_res && dec_res
}
112
/// Asserts that both `encryptor` and `decryptor` return an error when given
/// empty input, even with a non-empty output buffer.
#[cfg(feature = "safe_api")]
fn encrypt_decrypt_input_empty<Encryptor, Decryptor, Key, Nonce>(
    encryptor: &Encryptor,
    decryptor: &Decryptor,
    key: &Key,
    nonce: &Nonce,
) where
    Encryptor: Fn(&Key, &Nonce, u32, &[u8], &mut [u8]) -> Result<(), UnknownCryptoError>,
    Decryptor: Fn(&Key, &Nonce, u32, &[u8], &mut [u8]) -> Result<(), UnknownCryptoError>,
{
    let empty_input: [u8; 0] = [];
    let mut dst_out = [0u8; 64];

    let encrypt_failed = encryptor(key, nonce, 0, &empty_input, &mut dst_out).is_err();
    assert!(encrypt_failed);

    let decrypt_failed = decryptor(key, nonce, 0, &empty_input, &mut dst_out).is_err();
    assert!(decrypt_failed);
}
127
/// Asserts how output-buffer length is validated, for both `encryptor` and
/// `decryptor` with counter `0`.
///
/// Buffers shorter than `input` (empty, or one byte short) must be rejected.
/// Buffers of exactly `input.len()` or one byte longer must be accepted.
///
/// # Panics
///
/// Panics if `input` is empty or if any expectation above is violated.
#[cfg(feature = "safe_api")]
fn encrypt_decrypt_out_length<Encryptor, Decryptor, Key, Nonce>(
    encryptor: &Encryptor,
    decryptor: &Decryptor,
    key: &Key,
    nonce: &Nonce,
    input: &[u8],
) where
    Encryptor: Fn(&Key, &Nonce, u32, &[u8], &mut [u8]) -> Result<(), UnknownCryptoError>,
    Decryptor: Fn(&Key, &Nonce, u32, &[u8], &mut [u8]) -> Result<(), UnknownCryptoError>,
{
    assert!(!input.is_empty());

    let len = input.len();
    // (output buffer length, whether both calls are expected to succeed)
    let cases = [(0, false), (len - 1, false), (len, true), (len + 1, true)];

    for &(out_len, should_succeed) in cases.iter() {
        let mut out = vec![0u8; out_len];
        assert_eq!(encryptor(key, nonce, 0, input, &mut out).is_ok(), should_succeed);
        assert_eq!(decryptor(key, nonce, 0, input, &mut out).is_ok(), should_succeed);
    }
}
157
/// Encrypts `input` starting at `counter` and optionally compares the result
/// with `expected_ct`. Then checks that decryption recovers `input`, both from
/// the produced ciphertext and (if given) from `expected_ct`.
///
/// If processing `input.len()` bytes from `counter` would advance the block
/// counter past `u32::MAX`, the round trip is skipped. Both operations are
/// then required to fail instead.
///
/// # Panics
///
/// Panics if `input` is empty or if any of the checks above fails.
#[cfg(feature = "safe_api")]
fn encrypt_decrypt_equals_expected<Encryptor, Decryptor, Key, Nonce>(
    encryptor: &Encryptor,
    decryptor: &Decryptor,
    key: &Key,
    nonce: &Nonce,
    counter: u32,
    input: &[u8],
    expected_ct: Option<&[u8]>,
) where
    Encryptor: Fn(&Key, &Nonce, u32, &[u8], &mut [u8]) -> Result<(), UnknownCryptoError>,
    Decryptor: Fn(&Key, &Nonce, u32, &[u8], &mut [u8]) -> Result<(), UnknownCryptoError>,
{
    assert!(!input.is_empty());

    if counter_increase_times(input.len() as f32)
        .checked_add(counter)
        .is_none()
    {
        assert!(return_if_counter_will_overflow(
            encryptor, decryptor, key, nonce, counter, input
        ));

        return;
    }

    // Plaintext buffers start as the bitwise complement of `input`, so every
    // byte differs from the expected plaintext. A decryptor that returned `Ok`
    // without writing its output then fails the comparison, instead of passing
    // on stale or coincidentally-equal buffer contents.
    let poisoned_pt = || -> Vec<u8> { input.iter().map(|&b| !b).collect() };

    let mut dst_out_ct = vec![0u8; input.len()];
    encryptor(key, nonce, counter, input, &mut dst_out_ct).unwrap();
    if let Some(expected_result) = expected_ct {
        assert_eq!(expected_result, &dst_out_ct[..]);
    }

    let mut dst_out_pt = poisoned_pt();
    decryptor(key, nonce, counter, &dst_out_ct, &mut dst_out_pt).unwrap();
    assert_eq!(input, &dst_out_pt[..]);

    if let Some(expected_result) = expected_ct {
        // Fresh buffer: reusing the one above would already hold the correct
        // plaintext from the previous decryption.
        let mut dst_out_pt = poisoned_pt();
        decryptor(key, nonce, counter, expected_result, &mut dst_out_pt).unwrap();
        assert_eq!(input, &dst_out_pt[..]);
    }
}
201
/// Asserts that starting at counter `u32::MAX` with 65 bytes of input fails for
/// both `encryptor` and `decryptor`.
///
/// 65 bytes is one byte past a single 64-byte block. The trailing byte would
/// need the counter value after `u32::MAX`, which does not exist.
#[cfg(feature = "safe_api")]
fn initial_counter_overflow_err<Encryptor, Decryptor, Key, Nonce>(
    encryptor: &Encryptor,
    decryptor: &Decryptor,
    key: &Key,
    nonce: &Nonce,
) where
    Encryptor: Fn(&Key, &Nonce, u32, &[u8], &mut [u8]) -> Result<(), UnknownCryptoError>,
    Decryptor: Fn(&Key, &Nonce, u32, &[u8], &mut [u8]) -> Result<(), UnknownCryptoError>,
{
    let one_block_plus_one = [0u8; 65];
    let mut out = [0u8; 128];

    let encrypted = encryptor(key, nonce, u32::MAX, &one_block_plus_one, &mut out);
    assert!(encrypted.is_err());

    let decrypted = decryptor(key, nonce, u32::MAX, &one_block_plus_one, &mut out);
    assert!(decrypted.is_err());
}
231
/// Asserts that an initial counter of exactly `u32::MAX` is still usable for a
/// single 64-byte block: both `encryptor` and `decryptor` must succeed.
#[cfg(feature = "safe_api")]
fn initial_counter_max_ok<Encryptor, Decryptor, Key, Nonce>(
    encryptor: &Encryptor,
    decryptor: &Decryptor,
    key: &Key,
    nonce: &Nonce,
) where
    Encryptor: Fn(&Key, &Nonce, u32, &[u8], &mut [u8]) -> Result<(), UnknownCryptoError>,
    Decryptor: Fn(&Key, &Nonce, u32, &[u8], &mut [u8]) -> Result<(), UnknownCryptoError>,
{
    let single_block = [0u8; 64];
    let mut out = [0u8; 64];

    let encrypted = encryptor(key, nonce, u32::MAX, &single_block, &mut out);
    assert!(encrypted.is_ok());

    let decrypted = decryptor(key, nonce, u32::MAX, &single_block, &mut out);
    assert!(decrypted.is_ok());
}
261
/// Checks that changing any single parameter between encryption and decryption
/// fails to recover the plaintext.
///
/// Data is encrypted under `(sk1, n1, c1)`. It is then decrypted with a
/// different key, a different nonce, or a different initial counter, and each
/// result must differ from the original input.
///
/// # Panics
///
/// Panics if the two generated keys or nonces are equal, if any call fails, or
/// if a mismatched decryption reproduces the input.
#[cfg(test)]
#[cfg(feature = "safe_api")]
pub fn test_diff_params_diff_output<Encryptor, Decryptor, Key, Nonce>(
    encryptor: &Encryptor,
    decryptor: &Decryptor,
) where
    Key: TestingRandom + PartialEq<Key>,
    Nonce: TestingRandom + PartialEq<Nonce>,
    Encryptor: Fn(&Key, &Nonce, u32, &[u8], &mut [u8]) -> Result<(), UnknownCryptoError>,
    Decryptor: Fn(&Key, &Nonce, u32, &[u8], &mut [u8]) -> Result<(), UnknownCryptoError>,
{
    let input = &[0u8; 16];

    let sk1 = Key::gen();
    let sk2 = Key::gen();
    assert!(sk1 != sk2);

    let n1 = Nonce::gen();
    let n2 = Nonce::gen();
    assert!(n1 != n2);

    let c1 = 0u32;
    let c2 = 1u32;

    // Encrypt under (sk1, n1, c1) and decrypt under the given parameters, using
    // fresh zeroed buffers every time. Because `input` is all zeros, a
    // decryptor that returned `Ok` without writing would leave the plaintext
    // equal to `input` and fail the check. It cannot pass on stale data left
    // over from an earlier case.
    let roundtrip_with = |key: &Key, nonce: &Nonce, counter: u32| -> Vec<u8> {
        let mut dst_out_ct = vec![0u8; input.len()];
        let mut dst_out_pt = vec![0u8; input.len()];
        encryptor(&sk1, &n1, c1, input, &mut dst_out_ct).unwrap();
        decryptor(key, nonce, counter, &dst_out_ct, &mut dst_out_pt).unwrap();
        dst_out_pt
    };

    // Different key.
    assert_ne!(&roundtrip_with(&sk2, &n1, c1)[..], input);
    // Different nonce.
    assert_ne!(&roundtrip_with(&sk1, &n2, c1)[..], input);
    // Different initial counter.
    assert_ne!(&roundtrip_with(&sk1, &n1, c2)[..], input);
}