use std::collections::hash_map::Entry;
use std::collections::HashMap;
use crate::channel::ChannelId;
use crate::ieee80211::{FrameLayout, WifiFrame};
use crate::pipeline::{
    MockPayloadPipeline, PayloadPipeline, PayloadPipelineEvent, RecoveredPayload,
};
use crate::wfb::{FecCounters, WfbError, WfbKeypair};
/// Identifier of a payload route, wrapping a raw `u64` chosen by the caller.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PayloadRouteId(u64);

impl PayloadRouteId {
    /// Wraps `raw` as a route identifier.
    pub const fn new(raw: u64) -> Self {
        Self(raw)
    }

    /// Returns the raw value originally passed to [`PayloadRouteId::new`].
    pub const fn raw(self) -> u64 {
        let Self(inner) = self;
        inner
    }
}
/// Key of one shared runtime: a channel paired with a key slot.
///
/// Every route registered with the same `(channel_id, key_slot)` pair shares the
/// runtime stored under this key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PayloadRuntimeKey {
    channel_id: ChannelId,
    key_slot: u64,
}

impl PayloadRuntimeKey {
    /// Builds the key for `channel_id` / `key_slot`.
    pub const fn new(channel_id: ChannelId, key_slot: u64) -> Self {
        Self { channel_id, key_slot }
    }

    /// Channel half of the key.
    pub const fn channel_id(self) -> ChannelId {
        let Self { channel_id, .. } = self;
        channel_id
    }

    /// Key-slot half of the key.
    pub const fn key_slot(self) -> u64 {
        let Self { key_slot, .. } = self;
        key_slot
    }
}
/// Event produced by `PayloadRouteManager` after it feeds input to a runtime.
///
/// Non-ignored events carry the key of the runtime that produced them and a copy
/// of every route id attached to that runtime when the event was mapped.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PayloadRouteEvent {
    /// The input was not used: the frame did not parse, had no channel id, matched
    /// no runtime, or the underlying pipeline itself reported it as ignored.
    IgnoredFrame,
    /// Mirrors `PayloadPipelineEvent::SessionEstablished`, tagged with routing info.
    SessionEstablished {
        runtime: PayloadRuntimeKey,
        route_ids: Vec<PayloadRouteId>,
        epoch: u64,
        fec_k: usize,
        fec_n: usize,
    },
    /// Mirrors `PayloadPipelineEvent::Payload`, tagged with every route id on `runtime`.
    Payload {
        runtime: PayloadRuntimeKey,
        route_ids: Vec<PayloadRouteId>,
        payload: RecoveredPayload,
    },
}
/// Errors returned by `PayloadRouteManager` operations.
#[derive(Debug, PartialEq, Eq)]
pub enum PayloadRouteError {
    /// No runtime is registered under this channel / key-slot pair.
    UnknownRuntime(PayloadRuntimeKey),
    /// An error reported while building or driving the underlying pipeline.
    Wfb(WfbError),
}
impl std::fmt::Display for PayloadRouteError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::UnknownRuntime(key) => write!(
f,
"unknown payload runtime for channel 0x{:08x} key slot {}",
key.channel_id().raw(),
key.key_slot()
),
Self::Wfb(err) => std::fmt::Display::fmt(err, f),
}
}
}
impl std::error::Error for PayloadRouteError {}
impl From<WfbError> for PayloadRouteError {
fn from(err: WfbError) -> Self {
Self::Wfb(err)
}
}
/// Pipeline backing one runtime: either a full `PayloadPipeline` or a
/// `MockPayloadPipeline` that is fed payloads directly.
#[derive(Debug, Clone)]
enum PayloadChannelPipeline {
    /// Full pipeline, stored boxed (NOTE(review): presumably to keep the enum
    /// small relative to the mock variant — confirm).
    Real(Box<PayloadPipeline>),
    /// Mock pipeline; created by `PayloadChannelRuntime::mock`.
    Mock(MockPayloadPipeline),
}
impl PayloadChannelPipeline {
    /// Result reported when a variant does not handle the requested kind of input.
    fn ignored() -> Vec<PayloadPipelineEvent> {
        vec![PayloadPipelineEvent::IgnoredFrame]
    }

    /// Channel the wrapped pipeline was created for.
    fn channel_id(&self) -> ChannelId {
        match self {
            Self::Mock(mock) => mock.channel_id(),
            Self::Real(real) => real.channel_id(),
        }
    }

    /// Current FEC counters of the wrapped pipeline.
    fn fec_counters(&self) -> FecCounters {
        match self {
            Self::Mock(mock) => mock.fec_counters(),
            Self::Real(real) => real.fec_counters(),
        }
    }

    /// Asks the real pipeline whether it accepts `frame`; mock pipelines never do.
    fn accepts_80211_frame(&self, frame: &[u8]) -> bool {
        matches!(self, Self::Real(real) if real.accepts_80211_frame(frame))
    }

    /// Forwards the payload of an already channel-matched frame to the real
    /// pipeline; mock pipelines report a single `IgnoredFrame`.
    fn push_matched_payload(
        &mut self,
        payload: &[u8],
    ) -> Result<Vec<PayloadPipelineEvent>, WfbError> {
        let Self::Real(real) = self else {
            return Ok(Self::ignored());
        };
        real.push_matched_payload(payload)
    }

    /// Forwards a frame plus its decrypted fragment to the real pipeline; mock
    /// pipelines report a single `IgnoredFrame`.
    fn push_decrypted_80211_frame(
        &mut self,
        frame: &[u8],
        decrypted_fragment: &[u8],
    ) -> Result<Vec<PayloadPipelineEvent>, WfbError> {
        let Self::Real(real) = self else {
            return Ok(Self::ignored());
        };
        real.push_decrypted_80211_frame(frame, decrypted_fragment)
    }

    /// Forwards a decrypted fragment. For mock pipelines, `data_nonce` is passed
    /// as the mock's sequence argument and the fragment as its data; the mock
    /// path cannot fail.
    fn push_decrypted_fragment(
        &mut self,
        data_nonce: u64,
        decrypted_fragment: &[u8],
    ) -> Result<Vec<PayloadPipelineEvent>, WfbError> {
        match self {
            Self::Mock(mock) => Ok(mock.push_payload(data_nonce, decrypted_fragment)),
            Self::Real(real) => real.push_decrypted_fragment(data_nonce, decrypted_fragment),
        }
    }

    /// Pushes raw bytes into a mock pipeline; real pipelines report `IgnoredFrame`.
    fn push_mock_payload(&mut self, packet_seq: u64, data: &[u8]) -> Vec<PayloadPipelineEvent> {
        match self {
            Self::Mock(mock) => mock.push_payload(packet_seq, data),
            Self::Real(_) => Self::ignored(),
        }
    }
}
/// One pipeline shared by one or more routes.
#[derive(Debug, Clone)]
pub struct PayloadChannelRuntime {
    /// Real or mock pipeline that processes this runtime's input.
    pipeline: PayloadChannelPipeline,
    // Routes attached to this runtime; kept free of duplicates by `push_route_id`
    // and ordered by first registration.
    route_ids: Vec<PayloadRouteId>,
}
impl PayloadChannelRuntime {
fn real(pipeline: PayloadPipeline, route_id: PayloadRouteId) -> Self {
Self {
pipeline: PayloadChannelPipeline::Real(Box::new(pipeline)),
route_ids: vec![route_id],
}
}
fn mock(channel_id: ChannelId, route_id: PayloadRouteId) -> Self {
Self {
pipeline: PayloadChannelPipeline::Mock(MockPayloadPipeline::new(channel_id)),
route_ids: vec![route_id],
}
}
pub fn channel_id(&self) -> ChannelId {
self.pipeline.channel_id()
}
pub fn route_ids(&self) -> &[PayloadRouteId] {
self.route_ids.as_slice()
}
fn push_route_id(&mut self, route_id: PayloadRouteId) {
push_route_id(&mut self.route_ids, route_id);
}
}
/// Maps routes onto shared payload runtimes keyed by channel and key slot.
///
/// Routes registered with the same channel and key slot share one runtime. The
/// first registration for a key decides which kind of pipeline is created; later
/// registrations only attach their route id.
#[derive(Debug, Clone)]
pub struct PayloadRouteManager {
    /// Layout used to parse incoming 802.11 frames and to build real pipelines.
    frame_layout: FrameLayout,
    runtimes: HashMap<PayloadRuntimeKey, PayloadChannelRuntime>,
}
impl PayloadRouteManager {
    /// Creates a manager with no runtimes that parses frames using `frame_layout`.
    pub fn new(frame_layout: FrameLayout) -> Self {
        Self {
            frame_layout,
            runtimes: HashMap::new(),
        }
    }

    /// Layout used for frame parsing and for pipelines built by this manager.
    pub const fn frame_layout(&self) -> FrameLayout {
        self.frame_layout
    }

    /// Number of distinct `(channel, key slot)` runtimes currently registered.
    pub fn runtime_count(&self) -> usize {
        self.runtimes.len()
    }

    /// Registers `route_id` on the runtime for `channel_id` / `key_slot`, creating
    /// it with `PayloadPipeline::new` (no keypair) if none exists yet.
    ///
    /// If a runtime already exists for the key, whatever kind it is, `route_id` is
    /// only attached to it and `fec_k` / `fec_n` are not used.
    ///
    /// # Errors
    /// Returns [`PayloadRouteError::Wfb`] if `PayloadPipeline::new` fails.
    pub fn add_plain_route(
        &mut self,
        route_id: PayloadRouteId,
        channel_id: ChannelId,
        key_slot: u64,
        fec_k: usize,
        fec_n: usize,
    ) -> Result<PayloadRuntimeKey, PayloadRouteError> {
        let key = PayloadRuntimeKey::new(channel_id, key_slot);
        let frame_layout = self.frame_layout;
        // Entry API: a single hash lookup for both the "attach" and "create" paths.
        match self.runtimes.entry(key) {
            Entry::Occupied(slot) => slot.into_mut().push_route_id(route_id),
            Entry::Vacant(slot) => {
                let pipeline = PayloadPipeline::new(channel_id, frame_layout, fec_k, fec_n)?;
                slot.insert(PayloadChannelRuntime::real(pipeline, route_id));
            }
        }
        Ok(key)
    }

    /// Registers `route_id` on the runtime for `channel_id` / `key_slot`, creating
    /// it with `PayloadPipeline::with_keypair` if none exists yet.
    ///
    /// If a runtime already exists for the key, `route_id` is only attached and
    /// `keypair` / `minimum_epoch` are dropped unused.
    ///
    /// # Errors
    /// Returns [`PayloadRouteError::Wfb`] if `PayloadPipeline::with_keypair` fails.
    pub fn add_keyed_route(
        &mut self,
        route_id: PayloadRouteId,
        channel_id: ChannelId,
        key_slot: u64,
        keypair: WfbKeypair,
        minimum_epoch: u64,
    ) -> Result<PayloadRuntimeKey, PayloadRouteError> {
        let key = PayloadRuntimeKey::new(channel_id, key_slot);
        let frame_layout = self.frame_layout;
        match self.runtimes.entry(key) {
            Entry::Occupied(slot) => slot.into_mut().push_route_id(route_id),
            Entry::Vacant(slot) => {
                let pipeline =
                    PayloadPipeline::with_keypair(channel_id, frame_layout, keypair, minimum_epoch)?;
                slot.insert(PayloadChannelRuntime::real(pipeline, route_id));
            }
        }
        Ok(key)
    }

    /// Registers `route_id` on the runtime for `channel_id` / `key_slot`, creating
    /// a `MockPayloadPipeline`-backed runtime if none exists yet.
    ///
    /// Such runtimes are fed through [`Self::push_direct_payload`].
    pub fn add_direct_route(
        &mut self,
        route_id: PayloadRouteId,
        channel_id: ChannelId,
        key_slot: u64,
    ) -> PayloadRuntimeKey {
        let key = PayloadRuntimeKey::new(channel_id, key_slot);
        match self.runtimes.entry(key) {
            Entry::Occupied(slot) => slot.into_mut().push_route_id(route_id),
            Entry::Vacant(slot) => {
                slot.insert(PayloadChannelRuntime::mock(channel_id, route_id));
            }
        }
        key
    }

    /// Alias of [`Self::add_direct_route`].
    pub fn add_mock_route(
        &mut self,
        route_id: PayloadRouteId,
        channel_id: ChannelId,
        key_slot: u64,
    ) -> PayloadRuntimeKey {
        self.add_direct_route(route_id, channel_id, key_slot)
    }

    /// Routes attached to runtime `key`, or `None` if the key is unknown.
    pub fn route_ids(&self, key: PayloadRuntimeKey) -> Option<&[PayloadRouteId]> {
        self.runtimes
            .get(&key)
            .map(PayloadChannelRuntime::route_ids)
    }

    /// FEC counters of runtime `key`, or `None` if the key is unknown.
    pub fn fec_counters(&self, key: PayloadRuntimeKey) -> Option<FecCounters> {
        self.runtimes
            .get(&key)
            .map(|runtime| runtime.pipeline.fec_counters())
    }

    /// Whether runtime `key` accepts `frame`; `false` for unknown keys and for
    /// mock-backed runtimes.
    pub fn accepts_80211_frame(&self, key: PayloadRuntimeKey, frame: &[u8]) -> bool {
        self.runtimes
            .get(&key)
            .map(|runtime| runtime.pipeline.accepts_80211_frame(frame))
            .unwrap_or(false)
    }

    /// Parses `frame` and feeds its payload to every runtime on the frame's channel.
    ///
    /// Runtimes are visited in ascending key-slot order. Their events are
    /// concatenated in that order, so the output is deterministic across runs
    /// (`HashMap` iteration order alone is not).
    ///
    /// Returns a single `IgnoredFrame` if the frame does not parse, has no channel
    /// id, or no runtime is registered on its channel. Mock-backed runtimes always
    /// contribute an `IgnoredFrame`.
    ///
    /// # Errors
    /// Per-runtime errors are surfaced only when no runtime produced any event;
    /// in that case the first error in key-slot order is returned.
    /// NOTE(review): an `IgnoredFrame` from any runtime (including a mock) counts
    /// as an event and therefore hides a sibling runtime's error — confirm this is
    /// intended.
    pub fn push_80211_frame(
        &mut self,
        frame: &[u8],
    ) -> Result<Vec<PayloadRouteEvent>, PayloadRouteError> {
        let Ok(frame_view) = WifiFrame::parse(frame, self.frame_layout) else {
            return Ok(vec![PayloadRouteEvent::IgnoredFrame]);
        };
        let Some(channel_id) = frame_view.channel_id() else {
            return Ok(vec![PayloadRouteEvent::IgnoredFrame]);
        };
        let mut matching: Vec<_> = self
            .runtimes
            .iter_mut()
            .filter(|(key, _)| key.channel_id() == channel_id)
            .collect();
        if matching.is_empty() {
            return Ok(vec![PayloadRouteEvent::IgnoredFrame]);
        }
        // All matches share `channel_id`, so the key slot is unique among them.
        matching.sort_unstable_by_key(|(key, _)| key.key_slot());

        let mut route_events = Vec::new();
        let mut first_error = None;
        for (key, runtime) in matching {
            match runtime.pipeline.push_matched_payload(frame_view.payload()) {
                Ok(events) => {
                    route_events.extend(map_pipeline_events(*key, runtime.route_ids(), events));
                }
                Err(err) => {
                    if first_error.is_none() {
                        first_error = Some(err);
                    }
                }
            }
        }
        if route_events.is_empty() {
            if let Some(err) = first_error {
                return Err(err.into());
            }
        }
        Ok(route_events)
    }

    /// Feeds `frame` and its already-decrypted fragment to runtime `key`.
    /// Mock-backed runtimes report a single `IgnoredFrame`.
    ///
    /// # Errors
    /// [`PayloadRouteError::UnknownRuntime`] if `key` is not registered;
    /// [`PayloadRouteError::Wfb`] if the pipeline rejects the input.
    pub fn push_decrypted_80211_frame(
        &mut self,
        key: PayloadRuntimeKey,
        frame: &[u8],
        decrypted_fragment: &[u8],
    ) -> Result<Vec<PayloadRouteEvent>, PayloadRouteError> {
        let runtime = self
            .runtimes
            .get_mut(&key)
            .ok_or(PayloadRouteError::UnknownRuntime(key))?;
        let events = runtime
            .pipeline
            .push_decrypted_80211_frame(frame, decrypted_fragment)?;
        Ok(map_pipeline_events(key, runtime.route_ids(), events))
    }

    /// Feeds a decrypted fragment to runtime `key`. Mock-backed runtimes receive
    /// `data_nonce` as their sequence argument and the fragment as data.
    ///
    /// # Errors
    /// [`PayloadRouteError::UnknownRuntime`] if `key` is not registered;
    /// [`PayloadRouteError::Wfb`] if the pipeline rejects the input.
    pub fn push_decrypted_fragment(
        &mut self,
        key: PayloadRuntimeKey,
        data_nonce: u64,
        decrypted_fragment: &[u8],
    ) -> Result<Vec<PayloadRouteEvent>, PayloadRouteError> {
        let runtime = self
            .runtimes
            .get_mut(&key)
            .ok_or(PayloadRouteError::UnknownRuntime(key))?;
        let events = runtime
            .pipeline
            .push_decrypted_fragment(data_nonce, decrypted_fragment)?;
        Ok(map_pipeline_events(key, runtime.route_ids(), events))
    }

    /// Pushes raw bytes into the mock pipeline of runtime `key`. Real runtimes
    /// report a single `IgnoredFrame`.
    ///
    /// # Errors
    /// [`PayloadRouteError::UnknownRuntime`] if `key` is not registered.
    pub fn push_direct_payload(
        &mut self,
        key: PayloadRuntimeKey,
        packet_seq: u64,
        data: &[u8],
    ) -> Result<Vec<PayloadRouteEvent>, PayloadRouteError> {
        let runtime = self
            .runtimes
            .get_mut(&key)
            .ok_or(PayloadRouteError::UnknownRuntime(key))?;
        let events = runtime.pipeline.push_mock_payload(packet_seq, data);
        Ok(map_pipeline_events(key, runtime.route_ids(), events))
    }

    /// Alias of [`Self::push_direct_payload`].
    pub fn push_mock_payload(
        &mut self,
        key: PayloadRuntimeKey,
        packet_seq: u64,
        data: &[u8],
    ) -> Result<Vec<PayloadRouteEvent>, PayloadRouteError> {
        self.push_direct_payload(key, packet_seq, data)
    }
}
/// Appends `route_id` to `route_ids` only if it is not already present, so the
/// list stays duplicate-free and ordered by first insertion.
fn push_route_id(route_ids: &mut Vec<PayloadRouteId>, route_id: PayloadRouteId) {
    let already_present = route_ids.iter().any(|existing| *existing == route_id);
    if !already_present {
        route_ids.push(route_id);
    }
}
/// Converts pipeline events into route events for `runtime`.
///
/// `IgnoredFrame` passes through unchanged. Session and payload events are tagged
/// with `runtime` and each receive their own copy of `route_ids`. Event order is
/// preserved.
fn map_pipeline_events(
    runtime: PayloadRuntimeKey,
    route_ids: &[PayloadRouteId],
    events: Vec<PayloadPipelineEvent>,
) -> Vec<PayloadRouteEvent> {
    let mut routed = Vec::with_capacity(events.len());
    for event in events {
        let mapped = match event {
            PayloadPipelineEvent::IgnoredFrame => PayloadRouteEvent::IgnoredFrame,
            PayloadPipelineEvent::SessionEstablished {
                epoch,
                fec_k,
                fec_n,
            } => PayloadRouteEvent::SessionEstablished {
                runtime,
                route_ids: route_ids.to_vec(),
                epoch,
                fec_k,
                fec_n,
            },
            PayloadPipelineEvent::Payload(payload) => PayloadRouteEvent::Payload {
                runtime,
                route_ids: route_ids.to_vec(),
                payload,
            },
        };
        routed.push(mapped);
    }
    routed
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fragment as a leading `0` byte, a big-endian `u16` payload length,
    /// then the payload bytes (NOTE(review): the leading byte is presumably a
    /// flags/type field — confirm against the WFB framing).
    fn plain(payload: &[u8]) -> Vec<u8> {
        let mut out = Vec::new();
        out.push(0);
        out.extend_from_slice(&(payload.len() as u16).to_be_bytes());
        out.extend_from_slice(payload);
        out
    }

    #[test]
    fn routes_share_one_runtime_per_channel_and_key_slot() {
        let mut manager = PayloadRouteManager::new(FrameLayout::WithFcs);
        let channel = ChannelId::default_video();
        let runtime = manager
            .add_plain_route(PayloadRouteId::new(1), channel, 0, 1, 1)
            .unwrap();
        let same_runtime = manager
            .add_plain_route(PayloadRouteId::new(2), channel, 0, 1, 1)
            .unwrap();
        assert_eq!(runtime, same_runtime);
        assert_eq!(manager.runtime_count(), 1);
        let events = manager
            .push_decrypted_fragment(runtime, 0, &plain(b"rtp bytes"))
            .unwrap();
        assert_eq!(
            events,
            vec![PayloadRouteEvent::Payload {
                runtime,
                route_ids: vec![PayloadRouteId::new(1), PayloadRouteId::new(2)],
                payload: RecoveredPayload {
                    channel_id: channel,
                    packet_seq: 0,
                    data: b"rtp bytes".to_vec(),
                },
            }]
        );
    }

    #[test]
    fn different_channels_get_different_runtimes() {
        let mut manager = PayloadRouteManager::new(FrameLayout::WithFcs);
        let video = ChannelId::default_video();
        let telemetry = ChannelId::from_link_port(
            crate::channel::DEFAULT_LINK_ID,
            crate::RadioPort::TelemetryRx,
        );
        let video_runtime = manager
            .add_plain_route(PayloadRouteId::new(1), video, 0, 1, 1)
            .unwrap();
        let telemetry_runtime = manager
            .add_plain_route(PayloadRouteId::new(2), telemetry, 0, 1, 1)
            .unwrap();
        assert_ne!(video_runtime, telemetry_runtime);
        assert_eq!(manager.runtime_count(), 2);
    }

    #[test]
    fn duplicate_route_ids_are_recorded_once() {
        let mut manager = PayloadRouteManager::new(FrameLayout::WithFcs);
        let channel = ChannelId::default_video();
        let key = manager.add_direct_route(PayloadRouteId::new(7), channel, 3);
        let again = manager.add_direct_route(PayloadRouteId::new(7), channel, 3);
        assert_eq!(key, again);
        assert_eq!(manager.runtime_count(), 1);
        assert_eq!(manager.route_ids(key), Some(&[PayloadRouteId::new(7)][..]));
    }

    #[test]
    fn unknown_runtime_is_reported() {
        let mut manager = PayloadRouteManager::new(FrameLayout::WithFcs);
        let key = PayloadRuntimeKey::new(ChannelId::default_video(), 9);
        assert_eq!(
            manager.push_decrypted_fragment(key, 0, &plain(b"x")),
            Err(PayloadRouteError::UnknownRuntime(key))
        );
        assert_eq!(manager.route_ids(key), None);
        assert!(!manager.accepts_80211_frame(key, &[]));
    }
}