1use std::collections::{BTreeMap, HashMap, HashSet, VecDeque};
10
11use crate::{
12 ieee80211::{FrameLayout, WifiFrame},
13 wfb::{parse_forwarder_packet, WfbPacket, CRYPTO_BOX_NONCE_LEN},
14 ChannelId,
15};
16
/// Identifier of the receive path (radio) a frame arrived on.
///
/// The combiner never interprets the value; it is only used as the key of the
/// per-source statistics.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct DiversitySourceId(u16);

impl DiversitySourceId {
    /// Wraps a raw source number.
    pub const fn new(value: u16) -> Self {
        DiversitySourceId(value)
    }

    /// Returns the raw source number given to [`DiversitySourceId::new`].
    pub const fn get(self) -> u16 {
        let DiversitySourceId(raw) = self;
        raw
    }
}
32
/// Outcome of feeding one frame to the diversity combiner.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DiversityDecision {
    /// First copy of a recognised packet; it should be forwarded.
    Accept,
    /// Another copy of a packet that was already accepted; it should be dropped.
    Duplicate,
    /// The frame was not recognised as a deduplicable packet and is forwarded
    /// without being tracked.
    Passthrough,
}

impl DiversityDecision {
    /// Returns `true` for every decision except [`DiversityDecision::Duplicate`].
    pub const fn should_forward(self) -> bool {
        match self {
            Self::Duplicate => false,
            Self::Accept | Self::Passthrough => true,
        }
    }
}
50
/// Counters kept for a single [`DiversitySourceId`].
///
/// Each frame observed from the source increments `observed` and exactly one
/// of `accepted`, `duplicates` or `passthrough`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct DiversitySourceStats {
    /// Frames delivered by this source.
    pub observed: u64,
    /// Frames from this source that were the first copy of their packet.
    pub accepted: u64,
    /// Frames from this source that repeated an already accepted packet.
    pub duplicates: u64,
    /// Frames from this source that were not recognised and passed through.
    pub passthrough: u64,
}
63
64#[derive(Debug, Clone, Default, PartialEq, Eq)]
66pub struct DiversityStats {
67 pub accepted: u64,
69 pub duplicates: u64,
71 pub passthrough: u64,
73 pub cached_packets: usize,
75 pub sources: BTreeMap<DiversitySourceId, DiversitySourceStats>,
77}
78
79#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
80enum PacketIdentity {
81 Data {
82 channel_id: ChannelId,
83 session_generation: u64,
84 data_nonce: u64,
85 },
86 Session {
87 channel_id: ChannelId,
88 nonce: [u8; CRYPTO_BOX_NONCE_LEN],
89 },
90}
91
92#[derive(Debug, Clone)]
94pub struct DiversityCombiner {
95 capacity: usize,
96 seen: HashSet<PacketIdentity>,
97 insertion_order: VecDeque<PacketIdentity>,
98 session_generations: HashMap<ChannelId, u64>,
99 current_sessions: HashMap<ChannelId, [u8; CRYPTO_BOX_NONCE_LEN]>,
100 session_order: VecDeque<ChannelId>,
101 stats: DiversityStats,
102}
103
104impl Default for DiversityCombiner {
105 fn default() -> Self {
106 Self::new(8_192)
107 }
108}
109
110impl DiversityCombiner {
111 pub fn new(capacity: usize) -> Self {
113 Self {
114 capacity: capacity.max(1),
115 seen: HashSet::with_capacity(capacity.min(8_192)),
116 insertion_order: VecDeque::with_capacity(capacity.min(8_192)),
117 session_generations: HashMap::new(),
118 current_sessions: HashMap::new(),
119 session_order: VecDeque::new(),
120 stats: DiversityStats::default(),
121 }
122 }
123
124 pub fn observe_frame(
130 &mut self,
131 source: DiversitySourceId,
132 frame: &[u8],
133 layout: FrameLayout,
134 ) -> DiversityDecision {
135 self.source_mut(source).observed += 1;
136 let Some(identity) = self.packet_identity(frame, layout) else {
137 self.stats.passthrough += 1;
138 self.source_mut(source).passthrough += 1;
139 return DiversityDecision::Passthrough;
140 };
141
142 if self.seen.contains(&identity) {
143 self.stats.duplicates += 1;
144 self.source_mut(source).duplicates += 1;
145 return DiversityDecision::Duplicate;
146 }
147
148 if let PacketIdentity::Session { channel_id, nonce } = identity {
149 let changed = self
150 .current_sessions
151 .get(&channel_id)
152 .map(|current| current != &nonce)
153 .unwrap_or(true);
154 if changed {
155 if !self.current_sessions.contains_key(&channel_id) {
156 self.remember_session_channel(channel_id);
157 }
158 let generation = self.session_generations.entry(channel_id).or_default();
159 *generation = generation.wrapping_add(1);
160 self.current_sessions.insert(channel_id, nonce);
161 }
162 }
163
164 self.remember(identity);
165 self.stats.accepted += 1;
166 self.source_mut(source).accepted += 1;
167 DiversityDecision::Accept
168 }
169
170 pub fn stats(&self) -> DiversityStats {
172 let mut stats = self.stats.clone();
173 stats.cached_packets = self.seen.len();
174 stats
175 }
176
177 pub fn reset(&mut self) {
179 self.seen.clear();
180 self.insertion_order.clear();
181 self.session_generations.clear();
182 self.current_sessions.clear();
183 self.session_order.clear();
184 self.stats = DiversityStats::default();
185 }
186
187 fn packet_identity(&self, frame: &[u8], layout: FrameLayout) -> Option<PacketIdentity> {
188 let frame = WifiFrame::parse(frame, layout).ok()?;
189 let channel_id = frame.channel_id()?;
190 match parse_forwarder_packet(frame.payload()).ok()? {
191 WfbPacket::Data { data_nonce, .. } => Some(PacketIdentity::Data {
192 channel_id,
193 session_generation: self
194 .session_generations
195 .get(&channel_id)
196 .copied()
197 .unwrap_or(0),
198 data_nonce,
199 }),
200 WfbPacket::SessionKey { session_nonce, .. } => {
201 let nonce = session_nonce.try_into().ok()?;
202 Some(PacketIdentity::Session { channel_id, nonce })
203 }
204 }
205 }
206
207 fn remember(&mut self, identity: PacketIdentity) {
208 if self.seen.insert(identity) {
209 self.insertion_order.push_back(identity);
210 }
211 while self.insertion_order.len() > self.capacity {
212 if let Some(expired) = self.insertion_order.pop_front() {
213 self.seen.remove(&expired);
214 }
215 }
216 }
217
218 fn source_mut(&mut self, source: DiversitySourceId) -> &mut DiversitySourceStats {
219 self.stats.sources.entry(source).or_default()
220 }
221
222 fn remember_session_channel(&mut self, channel_id: ChannelId) {
223 while self.session_order.len() >= self.capacity {
224 if let Some(expired) = self.session_order.pop_front() {
225 self.current_sessions.remove(&expired);
226 self.session_generations.remove(&expired);
227 }
228 }
229 self.session_order.push_back(channel_id);
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        fec::FecCode,
        ieee80211::build_wfb_header,
        wfb::{PlainAssembler, CHACHA20_POLY1305_TAG_LEN, MAX_FEC_PAYLOAD, WPACKET_HDR_LEN},
        PayloadPipeline, PayloadPipelineEvent, WfbKeypair, WfbTransmitter, WfbTxKeypair,
    };
    use crypto_box::SecretKey;

    /// Builds a synthetic 802.11 frame carrying a WFB data packet with `data_nonce`.
    ///
    /// Everything after the nonce is zero-filled; these frames are only fed to the
    /// combiner, which keys data packets on channel and nonce.
    fn data_frame(channel: ChannelId, data_nonce: u64) -> Vec<u8> {
        let mut frame = Vec::from(build_wfb_header(channel, [0, 0]));
        // Packet-type byte; presumably the WFB data-packet tag (confirm in `wfb`).
        frame.push(1);
        frame.extend_from_slice(&data_nonce.to_be_bytes());
        // Zero body sized as packet header plus ChaCha20-Poly1305 tag.
        frame.resize(frame.len() + WPACKET_HDR_LEN + CHACHA20_POLY1305_TAG_LEN, 0);
        // Four trailing bytes, presumably a dummy FCS for `FrameLayout::WithFcs`.
        frame.extend_from_slice(&[0; 4]);
        frame
    }

    /// Builds a synthetic session-key frame whose nonce is `marker` repeated,
    /// so different markers yield different sessions.
    fn session_frame(channel: ChannelId, marker: u8) -> Vec<u8> {
        let mut frame = Vec::from(build_wfb_header(channel, [0, 0]));
        // Packet-type byte; presumably the WFB session-key tag (confirm in `wfb`).
        frame.push(2);
        frame.extend_from_slice(&[marker; CRYPTO_BOX_NONCE_LEN]);
        // Zero session body; the extra 16 bytes presumably stand for the crypto_box MAC.
        frame.resize(frame.len() + crate::wfb::WSESSION_DATA_LEN + 16, 0);
        // Four trailing bytes, presumably a dummy FCS for `FrameLayout::WithFcs`.
        frame.extend_from_slice(&[0; 4]);
        frame
    }

    /// Builds a plaintext FEC fragment: a leading zero byte, the big-endian
    /// payload length, the payload, then zero padding up to `MAX_FEC_PAYLOAD`.
    fn plain(payload: &[u8]) -> Vec<u8> {
        // Leading byte presumably holds packet flags; zero here.
        let mut out = vec![0];
        out.extend_from_slice(&(payload.len() as u16).to_be_bytes());
        out.extend_from_slice(payload);
        out.resize(MAX_FEC_PAYLOAD, 0);
        out
    }

    /// Returns a matching transmitter/receiver key pair derived from fixed secrets,
    /// each side holding its own secret key and the other side's public key.
    fn linked_keypairs() -> (WfbTxKeypair, WfbKeypair) {
        let transmitter = SecretKey::from([3u8; 32]);
        let receiver = SecretKey::from([9u8; 32]);
        (
            WfbTxKeypair {
                tx_secretkey: transmitter.to_bytes(),
                rx_publickey: receiver.public_key().to_bytes(),
            },
            WfbKeypair {
                rx_secretkey: receiver.to_bytes(),
                tx_publickey: transmitter.public_key().to_bytes(),
            },
        )
    }

    /// Wraps a forwarder packet in an 802.11 WFB header plus four trailing bytes
    /// (presumably the FCS expected by `FrameLayout::WithFcs`).
    fn wrap_forwarder_packet(channel: ChannelId, packet: &[u8]) -> Vec<u8> {
        let mut frame = Vec::from(build_wfb_header(channel, [0, 0]));
        frame.extend_from_slice(packet);
        frame.extend_from_slice(&[0; 4]);
        frame
    }

    // The same frame from two sources: the first is accepted, the second is a
    // duplicate, and the per-source counters reflect who won.
    #[test]
    fn first_valid_radio_wins_without_delaying_the_packet() {
        let mut combiner = DiversityCombiner::default();
        let frame = data_frame(ChannelId::default_video(), 42);

        assert_eq!(
            combiner.observe_frame(DiversitySourceId::new(1), &frame, FrameLayout::WithFcs),
            DiversityDecision::Accept
        );
        assert_eq!(
            combiner.observe_frame(DiversitySourceId::new(2), &frame, FrameLayout::WithFcs),
            DiversityDecision::Duplicate
        );
        let stats = combiner.stats();
        assert_eq!(stats.accepted, 1);
        assert_eq!(stats.duplicates, 1);
        assert_eq!(stats.sources[&DiversitySourceId::new(1)].accepted, 1);
        assert_eq!(stats.sources[&DiversitySourceId::new(2)].duplicates, 1);
    }

    // Data nonce 7 is deduplicated within session 1, but after session 2 starts
    // the same nonce is accepted again because it belongs to a new generation.
    #[test]
    fn a_new_session_can_reuse_data_nonces() {
        let mut combiner = DiversityCombiner::default();
        let channel = ChannelId::default_video();
        let data = data_frame(channel, 7);

        assert_eq!(
            combiner.observe_frame(
                DiversitySourceId::new(0),
                &session_frame(channel, 1),
                FrameLayout::WithFcs,
            ),
            DiversityDecision::Accept
        );
        assert_eq!(
            combiner.observe_frame(DiversitySourceId::new(0), &data, FrameLayout::WithFcs),
            DiversityDecision::Accept
        );
        assert_eq!(
            combiner.observe_frame(DiversitySourceId::new(1), &data, FrameLayout::WithFcs),
            DiversityDecision::Duplicate
        );
        assert_eq!(
            combiner.observe_frame(
                DiversitySourceId::new(1),
                &session_frame(channel, 2),
                FrameLayout::WithFcs,
            ),
            DiversityDecision::Accept
        );
        assert_eq!(
            combiner.observe_frame(DiversitySourceId::new(1), &data, FrameLayout::WithFcs),
            DiversityDecision::Accept
        );
    }

    // With capacity 1 only one channel's session is tracked: a session on a
    // second channel evicts the first channel's session state.
    #[test]
    fn session_tracking_is_bounded_with_the_packet_cache() {
        let mut combiner = DiversityCombiner::new(1);
        let first = ChannelId::new(1);
        let second = ChannelId::new(2);

        combiner.observe_frame(
            DiversitySourceId::new(0),
            &session_frame(first, 1),
            FrameLayout::WithFcs,
        );
        combiner.observe_frame(
            DiversitySourceId::new(0),
            &session_frame(second, 2),
            FrameLayout::WithFcs,
        );

        assert_eq!(combiner.current_sessions.len(), 1);
        assert!(!combiner.current_sessions.contains_key(&first));
        assert!(combiner.current_sessions.contains_key(&second));
    }

    // A (k = 3, n = 5) FEC block assembled from two radios: fragment 0 arrives
    // twice (the second copy is dropped), fragment 2 and parity fragment 3 arrive
    // once, and the missing fragment 1 is recovered. The data nonce doubles as
    // the fragment index passed to the assembler.
    #[test]
    fn fragments_from_two_radios_recover_one_shared_fec_block() {
        let channel = ChannelId::default_video();
        let primary = [plain(b"first"), plain(b"second"), plain(b"third")];
        let parity = FecCode::new(3, 5)
            .unwrap()
            .encode(&primary, MAX_FEC_PAYLOAD)
            .unwrap();
        let arrivals = [
            (DiversitySourceId::new(0), 0, primary[0].as_slice()),
            (DiversitySourceId::new(1), 0, primary[0].as_slice()),
            (DiversitySourceId::new(0), 2, primary[2].as_slice()),
            (DiversitySourceId::new(1), 3, parity[0].as_slice()),
        ];
        let mut combiner = DiversityCombiner::default();
        let mut assembler = PlainAssembler::new(3, 5).unwrap();
        let mut output = Vec::new();

        for (source, nonce, fragment) in arrivals {
            let frame = data_frame(channel, nonce);
            if combiner
                .observe_frame(source, &frame, FrameLayout::WithFcs)
                .should_forward()
            {
                output.extend(assembler.push_decrypted_fragment(nonce, fragment).unwrap());
            }
        }

        assert_eq!(
            output
                .into_iter()
                .map(|packet| packet.payload)
                .collect::<Vec<_>>(),
            vec![b"first".to_vec(), b"second".to_vec(), b"third".to_vec()]
        );
        assert_eq!(combiner.stats().duplicates, 1);
        assert_eq!(assembler.recovered_packets, 1);
    }

    // End-to-end with real encryption: a transmitter with fec_k = 2, fec_n = 3
    // emits a session key and two payloads. The first payload's fragment is
    // never delivered; the second payload's fragment arrives from both radios
    // (one copy dropped) and its parity from radio 1, which lets the pipeline
    // recover "first" as well.
    #[test]
    fn encrypted_fragments_from_two_radios_share_one_pipeline() {
        let channel = ChannelId::default_video();
        let (tx_keys, rx_keys) = linked_keypairs();
        let mut transmitter = WfbTransmitter::new(channel, tx_keys, 42, 2, 3).unwrap();
        let mut pipeline =
            PayloadPipeline::with_keypair(channel, FrameLayout::WithFcs, rx_keys, 0).unwrap();
        let mut combiner = DiversityCombiner::default();

        let session = wrap_forwarder_packet(channel, transmitter.session_forwarder_packet());
        assert!(combiner
            .observe_frame(DiversitySourceId::new(0), &session, FrameLayout::WithFcs)
            .should_forward());
        let events = pipeline.push_80211_frame(&session).unwrap();
        assert!(matches!(
            events.as_slice(),
            [PayloadPipelineEvent::SessionEstablished {
                epoch: 42,
                fec_k: 2,
                fec_n: 3
            }]
        ));

        // Produced but deliberately never delivered, so FEC must recover it.
        let missing_primary = transmitter
            .forwarder_packets_for_payload(b"first", 0)
            .unwrap();
        assert_eq!(missing_primary.len(), 1);
        // Completing the block yields the second primary fragment plus one parity fragment.
        let second_and_parity = transmitter
            .forwarder_packets_for_payload(b"second", 0)
            .unwrap();
        assert_eq!(second_and_parity.len(), 2);
        let second = wrap_forwarder_packet(channel, &second_and_parity[0]);
        let parity = wrap_forwarder_packet(channel, &second_and_parity[1]);

        let arrivals = [
            (DiversitySourceId::new(0), second.as_slice()),
            (DiversitySourceId::new(1), second.as_slice()),
            (DiversitySourceId::new(1), parity.as_slice()),
        ];
        let mut payloads = Vec::new();
        for (source, frame) in arrivals {
            if combiner
                .observe_frame(source, frame, FrameLayout::WithFcs)
                .should_forward()
            {
                for event in pipeline.push_80211_frame(frame).unwrap() {
                    if let PayloadPipelineEvent::Payload(payload) = event {
                        payloads.push(payload.data);
                    }
                }
            }
        }

        assert_eq!(payloads, [b"first".to_vec(), b"second".to_vec()]);
        let stats = combiner.stats();
        assert_eq!(stats.duplicates, 1);
        assert_eq!(stats.sources[&DiversitySourceId::new(1)].accepted, 1);
    }
}