1use crate::constants::{NODE_ADDRESS_LENGTH, PAYLOAD_KEY_SEED_SIZE, PAYLOAD_KEY_SIZE};
2use crate::header::delays::Delay;
3use crate::payload::key::{PayloadKey, PayloadKeySeed};
4use crate::payload::Payload;
5use crate::route::{Destination, Node, NodeAddressBytes};
6use crate::version::Version;
7use crate::{header, SphinxPacket};
8use crate::{Error, ErrorKind, Result};
9use header::{SphinxHeader, HEADER_SIZE};
10use std::fmt;
11use x25519_dalek::StaticSecret;
12
/// Per-hop payload encryption material stored inside a [`SURB`].
///
/// Which variant a SURB holds is decided in `SURB::new` from the packet version
/// (`Version::expects_legacy_full_payload_keys`).
#[derive(Debug)]
enum PayloadKeysMaterial {
    /// Full payload keys, serialized as `PAYLOAD_KEY_SIZE` bytes each; used for versions
    /// that expect legacy full payload keys.
    DerivedKeys(Vec<PayloadKey>),
    /// Payload key seeds, serialized as `PAYLOAD_KEY_SEED_SIZE` bytes each; handed to
    /// `Payload::encapsulate_message` in place of full keys.
    KeySeeds(Vec<PayloadKeySeed>),
}
19
20impl PayloadKeysMaterial {
21 fn from_bytes(bytes: &[u8]) -> Result<PayloadKeysMaterial> {
22 if bytes.len() < PAYLOAD_KEY_SIZE {
26 if bytes.len() % PAYLOAD_KEY_SEED_SIZE != 0 {
28 return Err(Error::new(
29 ErrorKind::InvalidSURB,
30 "bytes of invalid length provided",
31 ));
32 }
33 let seeds_count = bytes.len() / PAYLOAD_KEY_SEED_SIZE;
34 let mut payload_key_seeds = Vec::with_capacity(seeds_count);
35 for i in 0..seeds_count {
36 let mut payload_key = [0u8; PAYLOAD_KEY_SEED_SIZE];
37 payload_key.copy_from_slice(
38 &bytes[i * PAYLOAD_KEY_SEED_SIZE..(i + 1) * PAYLOAD_KEY_SEED_SIZE],
39 );
40 payload_key_seeds.push(payload_key);
41 }
42 Ok(PayloadKeysMaterial::KeySeeds(payload_key_seeds))
43 } else {
44 if bytes.len() % PAYLOAD_KEY_SIZE != 0 {
46 return Err(Error::new(
47 ErrorKind::InvalidSURB,
48 "bytes of invalid length provided",
49 ));
50 }
51 let key_count = bytes.len() / PAYLOAD_KEY_SIZE;
52 let mut payload_keys = Vec::with_capacity(key_count);
53 for i in 0..key_count {
54 let mut payload_key = [0u8; PAYLOAD_KEY_SIZE];
55 payload_key
56 .copy_from_slice(&bytes[i * PAYLOAD_KEY_SIZE..(i + 1) * PAYLOAD_KEY_SIZE]);
57 payload_keys.push(payload_key);
58 }
59 Ok(PayloadKeysMaterial::DerivedKeys(payload_keys))
60 }
61 }
62}
63
/// A SURB (single-use reply block): a pre-built Sphinx header for a reply route, the
/// address of that route's first node, and the per-hop payload key material that
/// [`SURB::use_surb`] uses to encapsulate a message into a [`SphinxPacket`].
// The `SURB_header` field keeps its upper-case acronym, hence the lint allowance.
#[allow(non_snake_case)]
pub struct SURB {
    /// Header built for the reply route; becomes the header of the packet produced by
    /// `use_surb`.
    SURB_header: header::SphinxHeader,
    /// Address of the first node of the route; returned alongside the packet by
    /// `use_surb` (presumably where the caller should forward it — confirm with callers).
    first_hop_address: NodeAddressBytes,
    /// Full payload keys or key seeds, depending on the version the SURB was built for.
    payload_keys_material: PayloadKeysMaterial,
}
73
74impl fmt::Debug for SURB {
75 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76 f.debug_struct("SURB")
77 .field("SURB_header", &self.SURB_header)
78 .field("first_hop_address", &self.first_hop_address)
79 .field("payload_keys_material", &self.payload_keys_material)
80 .finish()
81 }
82}
83
/// Inputs from which a [`SURB`] is built: the reply route, one delay per hop, the final
/// destination, and the packet [`Version`] that decides which payload key format the
/// SURB will store.
pub struct SURBMaterial {
    /// Nodes of the reply route; the first element is used as the SURB's first hop.
    surb_route: Vec<Node>,
    /// Per-hop delays; `SURB::new` rejects material whose length differs from the route's.
    surb_delays: Vec<Delay>,
    /// Destination passed to the header builder.
    surb_destination: Destination,
    /// Packet version; `Version::default()` unless overridden with `with_version`.
    version: Version,
}
90
91impl SURBMaterial {
92 pub fn new(route: Vec<Node>, delays: Vec<Delay>, destination: Destination) -> Self {
93 SURBMaterial {
94 surb_route: route,
95 surb_delays: delays,
96 surb_destination: destination,
97 version: Default::default(),
98 }
99 }
100
101 #[allow(non_snake_case)]
102 pub fn construct_SURB(self) -> Result<SURB> {
103 let surb_initial_secret = StaticSecret::random();
104 SURB::new(surb_initial_secret, self)
105 }
106
107 #[must_use]
108 pub fn with_version(mut self, version: Version) -> Self {
109 self.version = version;
110 self
111 }
112}
113
114#[allow(non_snake_case)]
115impl SURB {
116 pub fn new(surb_initial_secret: StaticSecret, surb_material: SURBMaterial) -> Result<Self> {
117 let surb_route = surb_material.surb_route;
118 let surb_delays = surb_material.surb_delays;
119 let surb_destination = surb_material.surb_destination;
120
121 let Some(first_hop) = surb_route.first() else {
125 return Err(Error::new(
126 ErrorKind::InvalidSURB,
127 "tried to create SURB for an empty route",
128 ));
129 };
130
131 if surb_route.len() != surb_delays.len() {
132 return Err(Error::new(ErrorKind::InvalidSURB, format!("creating SURB for contradictory data: route has len {} while there are {} delays generated", surb_route.len(), surb_delays.len())));
133 }
134
135 #[allow(deprecated)]
136 let built_header = header::SphinxHeader::new_versioned(
137 &surb_initial_secret,
138 &surb_route,
139 &surb_delays,
140 &surb_destination,
141 surb_material.version,
142 );
143
144 if surb_material.version.expects_legacy_full_payload_keys() {
145 Ok(SURB {
146 first_hop_address: first_hop.address,
147 payload_keys_material: PayloadKeysMaterial::DerivedKeys(
148 built_header.legacy_full_payload_keys(),
149 ),
150 SURB_header: built_header.into_header(),
151 })
152 } else {
153 Ok(SURB {
154 first_hop_address: first_hop.address,
155 payload_keys_material: PayloadKeysMaterial::KeySeeds(
156 built_header.payload_key_seeds(),
157 ),
158 SURB_header: built_header.into_header(),
159 })
160 }
161 }
162
163 pub fn use_surb(
167 self,
168 plaintext_message: &[u8],
169 payload_size: usize,
170 ) -> Result<(SphinxPacket, NodeAddressBytes)> {
171 let header = self.SURB_header;
172
173 let payload = match self.payload_keys_material {
176 PayloadKeysMaterial::DerivedKeys(keys) => {
177 Payload::encapsulate_message(plaintext_message, keys.as_slice(), payload_size)?
178 }
179 PayloadKeysMaterial::KeySeeds(seeds) => {
180 Payload::encapsulate_message(plaintext_message, &seeds, payload_size)?
181 }
182 };
183
184 Ok((SphinxPacket { header, payload }, self.first_hop_address))
185 }
186
187 pub fn to_bytes(&self) -> Vec<u8> {
188 let initial_bytes = self
189 .SURB_header
190 .to_bytes()
191 .into_iter()
192 .chain(self.first_hop_address.to_bytes());
193
194 match &self.payload_keys_material {
195 PayloadKeysMaterial::DerivedKeys(keys) => initial_bytes
196 .chain(keys.iter().flat_map(|k| k.iter().copied()))
197 .collect(),
198 PayloadKeysMaterial::KeySeeds(seeds) => initial_bytes
199 .chain(seeds.iter().flat_map(|s| s.iter().copied()))
200 .collect(),
201 }
202 }
203
204 pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
205 if bytes.len() < HEADER_SIZE + NODE_ADDRESS_LENGTH + PAYLOAD_KEY_SEED_SIZE {
207 return Err(Error::new(
208 ErrorKind::InvalidSURB,
209 "not enough bytes provided to try to recover a SURB",
210 ));
211 }
212
213 let header_bytes = &bytes[..HEADER_SIZE];
214 let first_hop_bytes = &bytes[HEADER_SIZE..HEADER_SIZE + NODE_ADDRESS_LENGTH];
215 let payload_keys_material_bytes = &bytes[HEADER_SIZE + NODE_ADDRESS_LENGTH..];
216
217 let SURB_header = SphinxHeader::from_bytes(header_bytes)?;
218 let first_hop_address = NodeAddressBytes::try_from_byte_slice(first_hop_bytes)?;
219 let payload_keys_material = PayloadKeysMaterial::from_bytes(payload_keys_material_bytes)?;
220
221 Ok(SURB {
222 SURB_header,
223 first_hop_address,
224 payload_keys_material,
225 })
226 }
227
228 pub fn first_hop(&self) -> NodeAddressBytes {
229 self.first_hop_address
230 }
231
232 pub fn materials_count(&self) -> usize {
233 match &self.payload_keys_material {
234 PayloadKeysMaterial::DerivedKeys(keys) => keys.len(),
235 PayloadKeysMaterial::KeySeeds(seeds) => seeds.len(),
236 }
237 }
238
239 pub fn uses_key_seeds(&self) -> bool {
240 matches!(self.payload_keys_material, PayloadKeysMaterial::KeySeeds(_))
241 }
242}
243
#[cfg(test)]
mod prepare_and_use_process_surb {
    use super::*;
    use crate::constants::NODE_ADDRESS_LENGTH;
    use crate::header::{delays, HEADER_SIZE};
    use crate::version::{PAYLOAD_KEYS_SEEDS_VERSION, X25519_WITH_EXPLICIT_PAYLOAD_KEYS_VERSION};
    use crate::{
        packet::builder::DEFAULT_PAYLOAD_SIZE,
        test_utils::fixtures::{destination_fixture, keygen},
    };
    use std::time::Duration;

    /// Material for a three-hop route whose first hop has address `[5u8; NODE_ADDRESS_LENGTH]`.
    fn surb_material_fixture() -> SURBMaterial {
        let (_, node1_pk) = keygen();
        let node1 = Node {
            address: NodeAddressBytes::from_bytes([5u8; NODE_ADDRESS_LENGTH]),
            pub_key: node1_pk,
        };
        let (_, node2_pk) = keygen();
        let node2 = Node {
            address: NodeAddressBytes::from_bytes([4u8; NODE_ADDRESS_LENGTH]),
            pub_key: node2_pk,
        };
        let (_, node3_pk) = keygen();
        let node3 = Node {
            address: NodeAddressBytes::from_bytes([2u8; NODE_ADDRESS_LENGTH]),
            pub_key: node3_pk,
        };

        let surb_route = vec![node1, node2, node3];
        let surb_destination = destination_fixture();
        let surb_delays =
            delays::generate_from_average_duration(surb_route.len(), Duration::from_secs(3));

        SURBMaterial::new(surb_route, surb_delays, surb_destination)
    }

    /// SURB built for a version that stores full derived payload keys.
    #[allow(non_snake_case)]
    fn legacy_SURB_fixture() -> SURB {
        let surb_initial_secret = StaticSecret::random();
        let surb_material =
            surb_material_fixture().with_version(X25519_WITH_EXPLICIT_PAYLOAD_KEYS_VERSION);

        SURB::new(surb_initial_secret, surb_material).unwrap()
    }

    /// SURB built for a version that stores payload key seeds.
    #[allow(non_snake_case)]
    fn seeded_SURB_fixture() -> SURB {
        let surb_initial_secret = StaticSecret::random();
        let surb_material = surb_material_fixture().with_version(PAYLOAD_KEYS_SEEDS_VERSION);

        SURB::new(surb_initial_secret, surb_material).unwrap()
    }

    #[test]
    fn returns_error_if_surb_route_empty() {
        let surb_route = Vec::new();
        let surb_destination = destination_fixture();
        let surb_initial_secret = StaticSecret::random();
        let surb_delays =
            delays::generate_from_average_duration(surb_route.len(), Duration::from_secs(3));
        let expected = ErrorKind::InvalidSURB;

        match SURB::new(
            surb_initial_secret,
            SURBMaterial::new(surb_route, surb_delays, surb_destination),
        ) {
            Err(err) => assert_eq!(expected, err.kind()),
            _ => panic!("Should have returned an error when route empty"),
        };
    }

    #[test]
    fn surb_header_has_correct_length() {
        let pre_surb = legacy_SURB_fixture();
        assert_eq!(pre_surb.SURB_header.to_bytes().len(), HEADER_SIZE);
    }

    #[test]
    fn to_bytes_returns_correct_value() {
        let pre_surb = legacy_SURB_fixture();
        let PayloadKeysMaterial::DerivedKeys(keys) = &pre_surb.payload_keys_material else {
            unreachable!()
        };

        let pre_surb_bytes = pre_surb.to_bytes();
        let expected = [
            pre_surb.SURB_header.to_bytes(),
            [5u8; NODE_ADDRESS_LENGTH].to_vec(),
            keys[0].to_vec(),
            keys[1].to_vec(),
            keys[2].to_vec(),
        ]
        .concat();
        assert_eq!(pre_surb_bytes, expected);

        let pre_surb = seeded_SURB_fixture();
        let PayloadKeysMaterial::KeySeeds(seeds) = &pre_surb.payload_keys_material else {
            unreachable!()
        };

        let pre_surb_bytes = pre_surb.to_bytes();
        let expected = [
            pre_surb.SURB_header.to_bytes(),
            [5u8; NODE_ADDRESS_LENGTH].to_vec(),
            seeds[0].to_vec(),
            seeds[1].to_vec(),
            seeds[2].to_vec(),
        ]
        .concat();
        assert_eq!(pre_surb_bytes, expected);
    }

    #[test]
    fn returns_error_if_payload_too_large() {
        let pre_surb = legacy_SURB_fixture();
        let plaintext_message = vec![42u8; 5000];
        let expected = ErrorKind::InvalidPayload;

        match SURB::use_surb(pre_surb, &plaintext_message, DEFAULT_PAYLOAD_SIZE) {
            Err(err) => assert_eq!(expected, err.kind()),
            _ => panic!("Should have returned an error when payload bytes too long"),
        };
    }

    #[test]
    #[allow(non_snake_case)]
    fn can_be_converted_to_and_from_bytes_with_legacy_keys() {
        let dummy_SURB = legacy_SURB_fixture();
        let bytes = dummy_SURB.to_bytes();
        let recovered_SURB = SURB::from_bytes(&bytes).unwrap();

        assert_eq!(
            dummy_SURB.first_hop_address,
            recovered_SURB.first_hop_address
        );
        assert!(!recovered_SURB.uses_key_seeds());
        assert_eq!(dummy_SURB.materials_count(), recovered_SURB.materials_count());

        let PayloadKeysMaterial::DerivedKeys(original_keys) = &dummy_SURB.payload_keys_material
        else {
            unreachable!()
        };

        let PayloadKeysMaterial::DerivedKeys(recovered_keys) =
            &recovered_SURB.payload_keys_material
        else {
            panic!("recovered SURB should hold derived payload keys")
        };

        // Compare whole vectors so that a length mismatch is detected too.
        assert_eq!(original_keys, recovered_keys);

        // The header must survive the round trip as well.
        assert_eq!(
            dummy_SURB.SURB_header.to_bytes(),
            recovered_SURB.SURB_header.to_bytes()
        );
    }

    #[test]
    #[allow(non_snake_case)]
    fn can_be_converted_to_and_from_bytes_with_key_seeds() {
        let dummy_SURB = seeded_SURB_fixture();
        let bytes = dummy_SURB.to_bytes();
        let recovered_SURB = SURB::from_bytes(&bytes).unwrap();

        assert_eq!(
            dummy_SURB.first_hop_address,
            recovered_SURB.first_hop_address
        );
        assert!(recovered_SURB.uses_key_seeds());
        assert_eq!(dummy_SURB.materials_count(), recovered_SURB.materials_count());

        let PayloadKeysMaterial::KeySeeds(original_seeds) = &dummy_SURB.payload_keys_material
        else {
            unreachable!()
        };

        let PayloadKeysMaterial::KeySeeds(recovered_seeds) = &recovered_SURB.payload_keys_material
        else {
            panic!("recovered SURB should hold payload key seeds")
        };

        // Compare whole vectors so that a length mismatch is detected too.
        assert_eq!(original_seeds, recovered_seeds);

        // The header must survive the round trip as well.
        assert_eq!(
            dummy_SURB.SURB_header.to_bytes(),
            recovered_SURB.SURB_header.to_bytes()
        );
    }
}