/// Salsa20 constant ("sigma") placed in state words 0, 5, 10 and 15 when the
/// key is 32 bytes long.
const SIGMA: [u8; 16] = *b"expand 32-byte k";
/// Salsa20 constant ("tau") placed in state words 0, 5, 10 and 15 when the
/// key is 16 bytes long.
const TAU: [u8; 16] = *b"expand 16-byte k";
14
/// Decodes a little-endian `u32` from `bytes`.
///
/// # Panics
///
/// Panics if `bytes` is not exactly four bytes long.
#[inline]
fn load_u32_le(bytes: &[u8]) -> u32 {
    let mut word = [0u8; 4];
    word.copy_from_slice(bytes);
    // Walk from the most significant byte (last) down to the least significant.
    word.iter()
        .rev()
        .fold(0, |acc, &byte| (acc << 8) | u32::from(byte))
}
21
/// Salsa20 quarter-round, updating the four words in place.
///
/// Each step adds two words, rotates the sum left (by 7, 9, 13, 18), and
/// XORs the result into the next word in the cycle y1, y2, y3, y0.
#[inline]
fn quarter_round(y0: &mut u32, y1: &mut u32, y2: &mut u32, y3: &mut u32) {
    let (a, b, c, d) = (*y0, *y1, *y2, *y3);
    let b = b ^ a.wrapping_add(d).rotate_left(7);
    let c = c ^ b.wrapping_add(a).rotate_left(9);
    let d = d ^ c.wrapping_add(b).rotate_left(13);
    let a = a ^ d.wrapping_add(c).rotate_left(18);
    *y0 = a;
    *y1 = b;
    *y2 = c;
    *y3 = d;
}
31
32#[inline]
33fn salsa20_block(state: &[u32; 16]) -> [u8; 64] {
34 let mut x = *state;
35
36 for _ in 0..10 {
37 let (mut y0, mut y4, mut y8, mut y12) = (x[0], x[4], x[8], x[12]);
38 quarter_round(&mut y0, &mut y4, &mut y8, &mut y12);
39 (x[0], x[4], x[8], x[12]) = (y0, y4, y8, y12);
40
41 let (mut y5, mut y9, mut y13, mut y1) = (x[5], x[9], x[13], x[1]);
42 quarter_round(&mut y5, &mut y9, &mut y13, &mut y1);
43 (x[5], x[9], x[13], x[1]) = (y5, y9, y13, y1);
44
45 let (mut y10, mut y14, mut y2, mut y6) = (x[10], x[14], x[2], x[6]);
46 quarter_round(&mut y10, &mut y14, &mut y2, &mut y6);
47 (x[10], x[14], x[2], x[6]) = (y10, y14, y2, y6);
48
49 let (mut y15, mut y3, mut y7, mut y11) = (x[15], x[3], x[7], x[11]);
50 quarter_round(&mut y15, &mut y3, &mut y7, &mut y11);
51 (x[15], x[3], x[7], x[11]) = (y15, y3, y7, y11);
52
53 let (mut y0, mut y1, mut y2, mut y3) = (x[0], x[1], x[2], x[3]);
54 quarter_round(&mut y0, &mut y1, &mut y2, &mut y3);
55 (x[0], x[1], x[2], x[3]) = (y0, y1, y2, y3);
56
57 let (mut y5, mut y6, mut y7, mut y4) = (x[5], x[6], x[7], x[4]);
58 quarter_round(&mut y5, &mut y6, &mut y7, &mut y4);
59 (x[5], x[6], x[7], x[4]) = (y5, y6, y7, y4);
60
61 let (mut y10, mut y11, mut y8, mut y9) = (x[10], x[11], x[8], x[9]);
62 quarter_round(&mut y10, &mut y11, &mut y8, &mut y9);
63 (x[10], x[11], x[8], x[9]) = (y10, y11, y8, y9);
64
65 let (mut y15, mut y12, mut y13, mut y14) = (x[15], x[12], x[13], x[14]);
66 quarter_round(&mut y15, &mut y12, &mut y13, &mut y14);
67 (x[15], x[12], x[13], x[14]) = (y15, y12, y13, y14);
68 }
69
70 let mut out = [0u8; 64];
71 for i in 0..16 {
72 out[4 * i..4 * i + 4].copy_from_slice(&x[i].wrapping_add(state[i]).to_le_bytes());
73 }
74 out
75}
76
77#[inline]
78fn key_setup(key: &[u8], nonce: [u8; 8], counter: u64) -> [u32; 16] {
79 assert!(
80 key.len() == 16 || key.len() == 32,
81 "Salsa20 key length must be 16 or 32 bytes, got {}",
82 key.len()
83 );
84
85 let constants = if key.len() == 32 { &SIGMA } else { &TAU };
88 let k0 = &key[..16];
89 let k1 = if key.len() == 32 {
90 &key[16..32]
91 } else {
92 &key[..16]
93 };
94
95 let counter_bytes = counter.to_le_bytes();
96 let counter_low = u32::from_le_bytes([
97 counter_bytes[0],
98 counter_bytes[1],
99 counter_bytes[2],
100 counter_bytes[3],
101 ]);
102 let counter_high = u32::from_le_bytes([
103 counter_bytes[4],
104 counter_bytes[5],
105 counter_bytes[6],
106 counter_bytes[7],
107 ]);
108
109 [
110 load_u32_le(&constants[0..4]),
111 load_u32_le(&k0[0..4]),
112 load_u32_le(&k0[4..8]),
113 load_u32_le(&k0[8..12]),
114 load_u32_le(&k0[12..16]),
115 load_u32_le(&constants[4..8]),
116 load_u32_le(&nonce[0..4]),
117 load_u32_le(&nonce[4..8]),
118 counter_low,
119 counter_high,
120 load_u32_le(&constants[8..12]),
121 load_u32_le(&k1[0..4]),
122 load_u32_le(&k1[4..8]),
123 load_u32_le(&k1[8..12]),
124 load_u32_le(&k1[12..16]),
125 load_u32_le(&constants[12..16]),
126 ]
127}
128
129pub struct Salsa20 {
135 state: [u32; 16],
136 block: [u8; 64],
137 offset: usize,
138}
139
140impl Salsa20 {
141 #[must_use]
143 pub fn new(key: &[u8; 32], nonce: &[u8; 8]) -> Self {
144 Self::with_key_bytes(key, nonce)
145 }
146
147 #[must_use]
149 pub fn with_key_bytes(key: &[u8], nonce: &[u8; 8]) -> Self {
150 Self::with_counter(key, nonce, 0)
151 }
152
153 #[must_use]
155 pub fn with_counter(key: &[u8], nonce: &[u8; 8], counter: u64) -> Self {
156 Self {
157 state: key_setup(key, *nonce, counter),
158 block: [0u8; 64],
159 offset: 64,
160 }
161 }
162
163 pub fn new_wiping(key: &mut [u8; 32], nonce: &mut [u8; 8]) -> Self {
165 let out = Self::new(key, nonce);
166 crate::ct::zeroize_slice(key.as_mut_slice());
167 crate::ct::zeroize_slice(nonce.as_mut_slice());
168 out
169 }
170
171 pub fn with_key_bytes_wiping(key: &mut [u8], nonce: &mut [u8; 8]) -> Self {
173 let out = Self::with_key_bytes(key, nonce);
174 crate::ct::zeroize_slice(key);
175 crate::ct::zeroize_slice(nonce.as_mut_slice());
176 out
177 }
178
179 #[inline]
180 fn refill(&mut self) {
181 self.block = salsa20_block(&self.state);
182 self.offset = 0;
183 self.state[8] = self.state[8].wrapping_add(1);
184 if self.state[8] == 0 {
185 self.state[9] = self.state[9].wrapping_add(1);
186 }
187 }
188
189 pub fn apply_keystream(&mut self, buf: &mut [u8]) {
191 let mut done = 0usize;
192 while done < buf.len() {
193 if self.offset == 64 {
194 self.refill();
195 }
196 let take = core::cmp::min(64 - self.offset, buf.len() - done);
197 for i in 0..take {
198 buf[done + i] ^= self.block[self.offset + i];
199 }
200 self.offset += take;
201 done += take;
202 }
203 }
204
205 pub fn fill(&mut self, buf: &mut [u8]) {
207 self.apply_keystream(buf);
208 }
209
210 pub fn keystream_block(&mut self) -> [u8; 64] {
212 let mut out = [0u8; 64];
213 self.apply_keystream(&mut out);
214 out
215 }
216
217 pub fn set_counter(&mut self, counter: u64) {
219 let counter_bytes = counter.to_le_bytes();
220 self.state[8] = u32::from_le_bytes([
221 counter_bytes[0],
222 counter_bytes[1],
223 counter_bytes[2],
224 counter_bytes[3],
225 ]);
226 self.state[9] = u32::from_le_bytes([
227 counter_bytes[4],
228 counter_bytes[5],
229 counter_bytes[6],
230 counter_bytes[7],
231 ]);
232 crate::ct::zeroize_slice(self.block.as_mut_slice());
233 self.offset = 64;
234 }
235}
236
237impl Drop for Salsa20 {
238 fn drop(&mut self) {
239 crate::ct::zeroize_slice(self.state.as_mut_slice());
240 crate::ct::zeroize_slice(self.block.as_mut_slice());
241 self.offset = 0;
242 }
243}
244
#[cfg(test)]
mod tests {
    use super::*;

    /// Lowercase hex encoding used to compare against published test vectors.
    fn hex(bytes: &[u8]) -> String {
        let mut out = String::with_capacity(bytes.len() * 2);
        for b in bytes {
            use core::fmt::Write;
            let _ = write!(&mut out, "{b:02x}");
        }
        out
    }

    #[test]
    fn quarter_round_spec_vectors() {
        let (mut a, mut b, mut c, mut d) = (1u32, 0u32, 0u32, 0u32);
        quarter_round(&mut a, &mut b, &mut c, &mut d);
        assert_eq!((a, b, c, d), (0x0800_8145, 0x0000_0080, 0x0001_0200, 0x2050_0000));

        let (mut a, mut b, mut c, mut d) = (0u32, 1u32, 0u32, 0u32);
        quarter_round(&mut a, &mut b, &mut c, &mut d);
        assert_eq!((a, b, c, d), (0x8800_0100, 0x0000_0001, 0x0000_0200, 0x0040_2000));
    }

    #[test]
    fn salsa20_128bit_estream_vector0_first_block() {
        let mut key = [0u8; 16];
        key[0] = 0x80;
        let nonce = [0u8; 8];
        let mut s = Salsa20::with_key_bytes(&key, &nonce);
        let block = s.keystream_block();
        assert_eq!(
            hex(&block),
            "4dfa5e481da23ea09a31022050859936".to_owned()
                + "da52fcee218005164f267cb65f5cfd7f"
                + "2b4f97e0ff16924a52df269515110a07"
                + "f9e460bc65ef95da58f740b7d1dbb0aa"
        );
    }

    #[test]
    fn salsa20_256bit_estream_vector0_first_block() {
        let mut key = [0u8; 32];
        key[0] = 0x80;
        let nonce = [0u8; 8];
        let mut s = Salsa20::new(&key, &nonce);
        let block = s.keystream_block();
        assert_eq!(
            hex(&block),
            "e3be8fdd8beca2e3ea8ef9475b29a6e7".to_owned()
                + "003951e1097a5c38d23b7a5fad9f6844"
                + "b22c97559e2723c7cbbd3fe4fc8d9a07"
                + "44652a83e72a9c461876af4d7ef1a117"
        );
    }

    #[test]
    fn salsa20_roundtrip_xor() {
        let key = [0x42u8; 32];
        let nonce = [0x24u8; 8];
        let msg = *b"the same function encrypts and decrypts with xor.....";

        let mut enc = Salsa20::new(&key, &nonce);
        let mut ct = msg;
        enc.apply_keystream(&mut ct);

        let mut dec = Salsa20::new(&key, &nonce);
        dec.apply_keystream(&mut ct);

        assert_eq!(ct, msg);
    }

    #[test]
    fn salsa20_chunked_stream_matches_one_shot() {
        let key = [0x11u8; 32];
        let nonce = [0x22u8; 8];

        let mut one = Salsa20::new(&key, &nonce);
        let mut full = [0u8; 96];
        one.fill(&mut full);

        let mut two = Salsa20::new(&key, &nonce);
        let mut split = [0u8; 96];
        two.fill(&mut split[..17]);
        two.fill(&mut split[17..81]);
        two.fill(&mut split[81..]);

        assert_eq!(full, split);
    }

    #[test]
    fn salsa20_set_counter_and_with_counter_seek_to_block() {
        let key = [0x33u8; 32];
        let nonce = [0x44u8; 8];

        let mut seq = Salsa20::new(&key, &nonce);
        let first = seq.keystream_block();
        let second = seq.keystream_block();

        let mut seek = Salsa20::new(&key, &nonce);
        // Leave a partially consumed block buffered; set_counter must discard it.
        let mut scratch = [0u8; 10];
        seek.apply_keystream(&mut scratch);
        seek.set_counter(1);
        assert_eq!(seek.keystream_block(), second);
        seek.set_counter(0);
        assert_eq!(seek.keystream_block(), first);

        let mut direct = Salsa20::with_counter(&key, &nonce, 1);
        assert_eq!(direct.keystream_block(), second);
    }

    #[test]
    fn salsa20_counter_carries_into_high_word() {
        let key = [0x55u8; 16];
        let nonce = [0x66u8; 8];

        let mut across = Salsa20::with_counter(&key, &nonce, u64::from(u32::MAX));
        let _ = across.keystream_block();
        let after_carry = across.keystream_block();

        let mut direct = Salsa20::with_counter(&key, &nonce, 1u64 << 32);
        assert_eq!(after_carry, direct.keystream_block());
    }

    #[test]
    #[should_panic(expected = "Salsa20 key length must be 16 or 32 bytes")]
    fn salsa20_rejects_bad_key_length() {
        let _ = Salsa20::with_key_bytes(&[0u8; 24], &[0u8; 8]);
    }
}