1use std::collections::HashMap;
2
3use crate::channel::ChannelId;
4use crate::ieee80211::{FrameLayout, WifiFrame};
5use crate::pipeline::{
6 MockPayloadPipeline, PayloadPipeline, PayloadPipelineEvent, RecoveredPayload,
7};
8use crate::wfb::{FecCounters, WfbError, WfbKeypair};
9
/// Opaque identifier of a consumer route.
///
/// Several routes may be attached to the same payload runtime; each one is
/// identified by its own `PayloadRouteId`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PayloadRouteId(u64);

impl PayloadRouteId {
    /// Wraps the raw identifier `raw`.
    pub const fn new(raw: u64) -> Self {
        PayloadRouteId(raw)
    }

    /// Returns the raw identifier this id was built from.
    pub const fn raw(self) -> u64 {
        let PayloadRouteId(raw) = self;
        raw
    }
}
25
26#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
31pub struct PayloadRuntimeKey {
32 channel_id: ChannelId,
33 key_slot: u64,
34}
35
36impl PayloadRuntimeKey {
37 pub const fn new(channel_id: ChannelId, key_slot: u64) -> Self {
39 Self {
40 channel_id,
41 key_slot,
42 }
43 }
44
45 pub const fn channel_id(self) -> ChannelId {
47 self.channel_id
48 }
49
50 pub const fn key_slot(self) -> u64 {
52 self.key_slot
53 }
54}
55
56#[derive(Debug, Clone, PartialEq, Eq)]
58pub enum PayloadRouteEvent {
59 IgnoredFrame,
61 SessionEstablished {
63 runtime: PayloadRuntimeKey,
65 route_ids: Vec<PayloadRouteId>,
67 epoch: u64,
69 fec_k: usize,
71 fec_n: usize,
73 },
74 Payload {
76 runtime: PayloadRuntimeKey,
78 route_ids: Vec<PayloadRouteId>,
80 payload: RecoveredPayload,
82 },
83}
84
85#[derive(Debug, PartialEq, Eq)]
87pub enum PayloadRouteError {
88 UnknownRuntime(PayloadRuntimeKey),
90 Wfb(WfbError),
92}
93
94impl std::fmt::Display for PayloadRouteError {
95 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96 match self {
97 Self::UnknownRuntime(key) => write!(
98 f,
99 "unknown payload runtime for channel 0x{:08x} key slot {}",
100 key.channel_id().raw(),
101 key.key_slot()
102 ),
103 Self::Wfb(err) => std::fmt::Display::fmt(err, f),
104 }
105 }
106}
107
108impl std::error::Error for PayloadRouteError {}
109
110impl From<WfbError> for PayloadRouteError {
111 fn from(err: WfbError) -> Self {
112 Self::Wfb(err)
113 }
114}
115
116#[derive(Debug, Clone)]
117enum PayloadChannelPipeline {
118 Real(Box<PayloadPipeline>),
119 Mock(MockPayloadPipeline),
120}
121
122impl PayloadChannelPipeline {
123 fn channel_id(&self) -> ChannelId {
124 match self {
125 Self::Real(pipeline) => pipeline.channel_id(),
126 Self::Mock(pipeline) => pipeline.channel_id(),
127 }
128 }
129
130 fn fec_counters(&self) -> FecCounters {
131 match self {
132 Self::Real(pipeline) => pipeline.fec_counters(),
133 Self::Mock(pipeline) => pipeline.fec_counters(),
134 }
135 }
136
137 fn accepts_80211_frame(&self, frame: &[u8]) -> bool {
138 match self {
139 Self::Real(pipeline) => pipeline.accepts_80211_frame(frame),
140 Self::Mock(_) => false,
141 }
142 }
143
144 fn push_matched_payload(
145 &mut self,
146 payload: &[u8],
147 ) -> Result<Vec<PayloadPipelineEvent>, WfbError> {
148 match self {
149 Self::Real(pipeline) => pipeline.push_matched_payload(payload),
150 Self::Mock(_) => Ok(vec![PayloadPipelineEvent::IgnoredFrame]),
151 }
152 }
153
154 fn push_decrypted_80211_frame(
155 &mut self,
156 frame: &[u8],
157 decrypted_fragment: &[u8],
158 ) -> Result<Vec<PayloadPipelineEvent>, WfbError> {
159 match self {
160 Self::Real(pipeline) => pipeline.push_decrypted_80211_frame(frame, decrypted_fragment),
161 Self::Mock(_) => Ok(vec![PayloadPipelineEvent::IgnoredFrame]),
162 }
163 }
164
165 fn push_decrypted_fragment(
166 &mut self,
167 data_nonce: u64,
168 decrypted_fragment: &[u8],
169 ) -> Result<Vec<PayloadPipelineEvent>, WfbError> {
170 match self {
171 Self::Real(pipeline) => {
172 pipeline.push_decrypted_fragment(data_nonce, decrypted_fragment)
173 }
174 Self::Mock(pipeline) => Ok(pipeline.push_payload(data_nonce, decrypted_fragment)),
175 }
176 }
177
178 fn push_mock_payload(&mut self, packet_seq: u64, data: &[u8]) -> Vec<PayloadPipelineEvent> {
179 match self {
180 Self::Real(_) => vec![PayloadPipelineEvent::IgnoredFrame],
181 Self::Mock(pipeline) => pipeline.push_payload(packet_seq, data),
182 }
183 }
184}
185
186#[derive(Debug, Clone)]
192pub struct PayloadChannelRuntime {
193 pipeline: PayloadChannelPipeline,
194 route_ids: Vec<PayloadRouteId>,
195}
196
197impl PayloadChannelRuntime {
198 fn real(pipeline: PayloadPipeline, route_id: PayloadRouteId) -> Self {
199 Self {
200 pipeline: PayloadChannelPipeline::Real(Box::new(pipeline)),
201 route_ids: vec![route_id],
202 }
203 }
204
205 fn mock(channel_id: ChannelId, route_id: PayloadRouteId) -> Self {
206 Self {
207 pipeline: PayloadChannelPipeline::Mock(MockPayloadPipeline::new(channel_id)),
208 route_ids: vec![route_id],
209 }
210 }
211
212 pub fn channel_id(&self) -> ChannelId {
214 self.pipeline.channel_id()
215 }
216
217 pub fn route_ids(&self) -> &[PayloadRouteId] {
219 self.route_ids.as_slice()
220 }
221
222 fn push_route_id(&mut self, route_id: PayloadRouteId) {
223 push_route_id(&mut self.route_ids, route_id);
224 }
225}
226
227#[derive(Debug, Clone)]
233pub struct PayloadRouteManager {
234 frame_layout: FrameLayout,
235 runtimes: HashMap<PayloadRuntimeKey, PayloadChannelRuntime>,
236}
237
238impl PayloadRouteManager {
239 pub fn new(frame_layout: FrameLayout) -> Self {
241 Self {
242 frame_layout,
243 runtimes: HashMap::new(),
244 }
245 }
246
247 pub const fn frame_layout(&self) -> FrameLayout {
249 self.frame_layout
250 }
251
252 pub fn runtime_count(&self) -> usize {
254 self.runtimes.len()
255 }
256
257 pub fn add_plain_route(
261 &mut self,
262 route_id: PayloadRouteId,
263 channel_id: ChannelId,
264 key_slot: u64,
265 fec_k: usize,
266 fec_n: usize,
267 ) -> Result<PayloadRuntimeKey, PayloadRouteError> {
268 let key = PayloadRuntimeKey::new(channel_id, key_slot);
269 if let Some(runtime) = self.runtimes.get_mut(&key) {
270 runtime.push_route_id(route_id);
271 return Ok(key);
272 }
273
274 let pipeline = PayloadPipeline::new(channel_id, self.frame_layout, fec_k, fec_n)?;
275 self.runtimes
276 .insert(key, PayloadChannelRuntime::real(pipeline, route_id));
277 Ok(key)
278 }
279
280 pub fn add_keyed_route(
284 &mut self,
285 route_id: PayloadRouteId,
286 channel_id: ChannelId,
287 key_slot: u64,
288 keypair: WfbKeypair,
289 minimum_epoch: u64,
290 ) -> Result<PayloadRuntimeKey, PayloadRouteError> {
291 let key = PayloadRuntimeKey::new(channel_id, key_slot);
292 if let Some(runtime) = self.runtimes.get_mut(&key) {
293 runtime.push_route_id(route_id);
294 return Ok(key);
295 }
296
297 let pipeline =
298 PayloadPipeline::with_keypair(channel_id, self.frame_layout, keypair, minimum_epoch)?;
299 self.runtimes
300 .insert(key, PayloadChannelRuntime::real(pipeline, route_id));
301 Ok(key)
302 }
303
304 pub fn add_direct_route(
311 &mut self,
312 route_id: PayloadRouteId,
313 channel_id: ChannelId,
314 key_slot: u64,
315 ) -> PayloadRuntimeKey {
316 let key = PayloadRuntimeKey::new(channel_id, key_slot);
317 if let Some(runtime) = self.runtimes.get_mut(&key) {
318 runtime.push_route_id(route_id);
319 return key;
320 }
321
322 self.runtimes
323 .insert(key, PayloadChannelRuntime::mock(channel_id, route_id));
324 key
325 }
326
327 pub fn add_mock_route(
331 &mut self,
332 route_id: PayloadRouteId,
333 channel_id: ChannelId,
334 key_slot: u64,
335 ) -> PayloadRuntimeKey {
336 self.add_direct_route(route_id, channel_id, key_slot)
337 }
338
339 pub fn route_ids(&self, key: PayloadRuntimeKey) -> Option<&[PayloadRouteId]> {
341 self.runtimes
342 .get(&key)
343 .map(PayloadChannelRuntime::route_ids)
344 }
345
346 pub fn fec_counters(&self, key: PayloadRuntimeKey) -> Option<FecCounters> {
348 self.runtimes
349 .get(&key)
350 .map(|runtime| runtime.pipeline.fec_counters())
351 }
352
353 pub fn accepts_80211_frame(&self, key: PayloadRuntimeKey, frame: &[u8]) -> bool {
355 self.runtimes
356 .get(&key)
357 .map(|runtime| runtime.pipeline.accepts_80211_frame(frame))
358 .unwrap_or(false)
359 }
360
361 pub fn push_80211_frame(
363 &mut self,
364 frame: &[u8],
365 ) -> Result<Vec<PayloadRouteEvent>, PayloadRouteError> {
366 let Ok(frame_view) = WifiFrame::parse(frame, self.frame_layout) else {
367 return Ok(vec![PayloadRouteEvent::IgnoredFrame]);
368 };
369 let Some(channel_id) = frame_view.channel_id() else {
370 return Ok(vec![PayloadRouteEvent::IgnoredFrame]);
371 };
372
373 let mut matched = false;
374 let mut route_events = Vec::new();
375 let mut first_error = None;
376
377 for (key, runtime) in self
378 .runtimes
379 .iter_mut()
380 .filter(|(key, _)| key.channel_id() == channel_id)
381 {
382 matched = true;
383 match runtime.pipeline.push_matched_payload(frame_view.payload()) {
384 Ok(events) => {
385 route_events.extend(map_pipeline_events(*key, runtime.route_ids(), events));
386 }
387 Err(err) => {
388 if first_error.is_none() {
389 first_error = Some(err);
390 }
391 }
392 }
393 }
394
395 if !matched {
396 return Ok(vec![PayloadRouteEvent::IgnoredFrame]);
397 }
398 if route_events.is_empty() {
399 if let Some(err) = first_error {
400 return Err(err.into());
401 }
402 }
403 Ok(route_events)
404 }
405
406 pub fn push_decrypted_80211_frame(
408 &mut self,
409 key: PayloadRuntimeKey,
410 frame: &[u8],
411 decrypted_fragment: &[u8],
412 ) -> Result<Vec<PayloadRouteEvent>, PayloadRouteError> {
413 let runtime = self
414 .runtimes
415 .get_mut(&key)
416 .ok_or(PayloadRouteError::UnknownRuntime(key))?;
417 let events = runtime
418 .pipeline
419 .push_decrypted_80211_frame(frame, decrypted_fragment)?;
420 Ok(map_pipeline_events(key, runtime.route_ids(), events))
421 }
422
423 pub fn push_decrypted_fragment(
425 &mut self,
426 key: PayloadRuntimeKey,
427 data_nonce: u64,
428 decrypted_fragment: &[u8],
429 ) -> Result<Vec<PayloadRouteEvent>, PayloadRouteError> {
430 let runtime = self
431 .runtimes
432 .get_mut(&key)
433 .ok_or(PayloadRouteError::UnknownRuntime(key))?;
434 let events = runtime
435 .pipeline
436 .push_decrypted_fragment(data_nonce, decrypted_fragment)?;
437 Ok(map_pipeline_events(key, runtime.route_ids(), events))
438 }
439
440 pub fn push_direct_payload(
442 &mut self,
443 key: PayloadRuntimeKey,
444 packet_seq: u64,
445 data: &[u8],
446 ) -> Result<Vec<PayloadRouteEvent>, PayloadRouteError> {
447 let runtime = self
448 .runtimes
449 .get_mut(&key)
450 .ok_or(PayloadRouteError::UnknownRuntime(key))?;
451 let events = runtime.pipeline.push_mock_payload(packet_seq, data);
452 Ok(map_pipeline_events(key, runtime.route_ids(), events))
453 }
454
455 pub fn push_mock_payload(
459 &mut self,
460 key: PayloadRuntimeKey,
461 packet_seq: u64,
462 data: &[u8],
463 ) -> Result<Vec<PayloadRouteEvent>, PayloadRouteError> {
464 self.push_direct_payload(key, packet_seq, data)
465 }
466}
467
468fn push_route_id(route_ids: &mut Vec<PayloadRouteId>, route_id: PayloadRouteId) {
469 if !route_ids.contains(&route_id) {
470 route_ids.push(route_id);
471 }
472}
473
474fn map_pipeline_events(
475 runtime: PayloadRuntimeKey,
476 route_ids: &[PayloadRouteId],
477 events: Vec<PayloadPipelineEvent>,
478) -> Vec<PayloadRouteEvent> {
479 events
480 .into_iter()
481 .map(|event| match event {
482 PayloadPipelineEvent::IgnoredFrame => PayloadRouteEvent::IgnoredFrame,
483 PayloadPipelineEvent::SessionEstablished {
484 epoch,
485 fec_k,
486 fec_n,
487 } => PayloadRouteEvent::SessionEstablished {
488 runtime,
489 route_ids: route_ids.to_vec(),
490 epoch,
491 fec_k,
492 fec_n,
493 },
494 PayloadPipelineEvent::Payload(payload) => PayloadRouteEvent::Payload {
495 runtime,
496 route_ids: route_ids.to_vec(),
497 payload,
498 },
499 })
500 .collect()
501}
502
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a plain fragment: a leading `0` byte, the payload length as a
    /// big-endian `u16`, then the payload bytes.
    fn plain(payload: &[u8]) -> Vec<u8> {
        let len = u16::try_from(payload.len()).expect("test payload length fits in u16");
        let mut out = Vec::with_capacity(3 + payload.len());
        out.push(0);
        out.extend_from_slice(&len.to_be_bytes());
        out.extend_from_slice(payload);
        out
    }

    #[test]
    fn routes_share_one_runtime_per_channel_and_key_slot() {
        let mut manager = PayloadRouteManager::new(FrameLayout::WithFcs);
        let channel = ChannelId::default_video();
        let runtime = manager
            .add_plain_route(PayloadRouteId::new(1), channel, 0, 1, 1)
            .unwrap();
        let same_runtime = manager
            .add_plain_route(PayloadRouteId::new(2), channel, 0, 1, 1)
            .unwrap();

        assert_eq!(runtime, same_runtime);
        assert_eq!(manager.runtime_count(), 1);

        let events = manager
            .push_decrypted_fragment(runtime, 0, &plain(b"rtp bytes"))
            .unwrap();
        assert_eq!(
            events,
            vec![PayloadRouteEvent::Payload {
                runtime,
                route_ids: vec![PayloadRouteId::new(1), PayloadRouteId::new(2)],
                payload: RecoveredPayload {
                    channel_id: channel,
                    packet_seq: 0,
                    data: b"rtp bytes".to_vec(),
                },
            }]
        );
    }

    #[test]
    fn different_channels_get_different_runtimes() {
        let mut manager = PayloadRouteManager::new(FrameLayout::WithFcs);
        let video = ChannelId::default_video();
        let telemetry = ChannelId::from_link_port(
            crate::channel::DEFAULT_LINK_ID,
            crate::RadioPort::TelemetryRx,
        );

        let video_runtime = manager
            .add_plain_route(PayloadRouteId::new(1), video, 0, 1, 1)
            .unwrap();
        let telemetry_runtime = manager
            .add_plain_route(PayloadRouteId::new(2), telemetry, 0, 1, 1)
            .unwrap();

        assert_ne!(video_runtime, telemetry_runtime);
        assert_eq!(manager.runtime_count(), 2);
    }

    #[test]
    fn re_adding_the_same_route_id_does_not_duplicate_it() {
        let mut manager = PayloadRouteManager::new(FrameLayout::WithFcs);
        let channel = ChannelId::default_video();
        let first = manager
            .add_plain_route(PayloadRouteId::new(7), channel, 0, 1, 1)
            .unwrap();
        let second = manager
            .add_plain_route(PayloadRouteId::new(7), channel, 0, 1, 1)
            .unwrap();

        assert_eq!(first, second);
        assert_eq!(manager.runtime_count(), 1);
        assert_eq!(
            manager.route_ids(first),
            Some(&[PayloadRouteId::new(7)][..])
        );
    }

    #[test]
    fn direct_and_mock_routes_share_one_runtime() {
        let mut manager = PayloadRouteManager::new(FrameLayout::WithFcs);
        let channel = ChannelId::default_video();
        let direct = manager.add_direct_route(PayloadRouteId::new(1), channel, 3);
        let mock = manager.add_mock_route(PayloadRouteId::new(2), channel, 3);

        assert_eq!(direct, mock);
        assert_eq!(manager.runtime_count(), 1);
        assert_eq!(
            manager.route_ids(direct),
            Some(&[PayloadRouteId::new(1), PayloadRouteId::new(2)][..])
        );
    }

    #[test]
    fn direct_payload_on_real_runtime_is_ignored() {
        let mut manager = PayloadRouteManager::new(FrameLayout::WithFcs);
        let runtime = manager
            .add_plain_route(PayloadRouteId::new(1), ChannelId::default_video(), 0, 1, 1)
            .unwrap();

        assert_eq!(
            manager.push_direct_payload(runtime, 5, b"data"),
            Ok(vec![PayloadRouteEvent::IgnoredFrame])
        );
    }

    #[test]
    fn unknown_runtime_is_reported() {
        let mut manager = PayloadRouteManager::new(FrameLayout::WithFcs);
        let missing = PayloadRuntimeKey::new(ChannelId::default_video(), 9);

        assert_eq!(
            manager.push_decrypted_fragment(missing, 0, &plain(b"x")),
            Err(PayloadRouteError::UnknownRuntime(missing))
        );
        assert_eq!(
            manager.push_direct_payload(missing, 0, b"x"),
            Err(PayloadRouteError::UnknownRuntime(missing))
        );
        assert_eq!(manager.route_ids(missing), None);
        assert!(manager.fec_counters(missing).is_none());
        assert!(!manager.accepts_80211_frame(missing, &[]));

        let message = PayloadRouteError::UnknownRuntime(missing).to_string();
        assert!(message.starts_with("unknown payload runtime for channel 0x"));
        assert!(message.ends_with(" key slot 9"));
    }
}