use ahash::AHashSet;
use mavlink::MavHeader;
use serde::Deserialize;
/// Set type used for every filter list in this module; an alias for `ahash::AHashSet<T>`.
type HashSet<T> = AHashSet<T>;
/// Per-endpoint message filters, split by traffic direction.
///
/// Each criterion (message id, source component id, source system id) has an
/// allow set and a block set. An empty allow set imposes no restriction; a
/// non-empty one admits only its members. Anything present in a block set is
/// rejected, even if it also appears in the matching allow set. The `*_out`
/// sets are consulted by [`EndpointFilters::check_outgoing`] and the `*_in`
/// sets by [`EndpointFilters::check_incoming`].
///
/// Every field is optional when deserializing: a missing field takes its value
/// from the derived `Default`, i.e. an empty set.
#[derive(Debug, Clone, Default, Deserialize)]
#[serde(default)]
pub struct EndpointFilters {
    /// Outgoing: message ids to admit (empty = no restriction).
    pub allow_msg_id_out: HashSet<u32>,
    /// Outgoing: message ids to reject.
    pub block_msg_id_out: HashSet<u32>,
    /// Outgoing: source component ids to admit (empty = no restriction).
    pub allow_src_comp_out: HashSet<u8>,
    /// Outgoing: source component ids to reject.
    pub block_src_comp_out: HashSet<u8>,
    /// Outgoing: source system ids to admit (empty = no restriction).
    pub allow_src_sys_out: HashSet<u8>,
    /// Outgoing: source system ids to reject.
    pub block_src_sys_out: HashSet<u8>,
    /// Incoming: message ids to admit (empty = no restriction).
    pub allow_msg_id_in: HashSet<u32>,
    /// Incoming: message ids to reject.
    pub block_msg_id_in: HashSet<u32>,
    /// Incoming: source component ids to admit (empty = no restriction).
    pub allow_src_comp_in: HashSet<u8>,
    /// Incoming: source component ids to reject.
    pub block_src_comp_in: HashSet<u8>,
    /// Incoming: source system ids to admit (empty = no restriction).
    pub allow_src_sys_in: HashSet<u8>,
    /// Incoming: source system ids to reject.
    pub block_src_sys_in: HashSet<u8>,
}
impl EndpointFilters {
    /// Returns whether a message passes the incoming (`*_in`) filters.
    ///
    /// The message id, the header's source component id and the header's
    /// source system id must each pass their allow/block pair.
    pub fn check_incoming(&self, header: &MavHeader, msg_id: u32) -> bool {
        Self::passes(&self.allow_msg_id_in, &self.block_msg_id_in, msg_id)
            && Self::passes(
                &self.allow_src_comp_in,
                &self.block_src_comp_in,
                header.component_id,
            )
            && Self::passes(&self.allow_src_sys_in, &self.block_src_sys_in, header.system_id)
    }

    /// Returns whether a message passes the outgoing (`*_out`) filters.
    ///
    /// The message id, the header's source component id and the header's
    /// source system id must each pass their allow/block pair.
    pub fn check_outgoing(&self, header: &MavHeader, msg_id: u32) -> bool {
        Self::passes(&self.allow_msg_id_out, &self.block_msg_id_out, msg_id)
            && Self::passes(
                &self.allow_src_comp_out,
                &self.block_src_comp_out,
                header.component_id,
            )
            && Self::passes(&self.allow_src_sys_out, &self.block_src_sys_out, header.system_id)
    }

    /// Applies a single allow/block pair to one value.
    ///
    /// An empty `allow` set places no restriction; otherwise `value` must be a
    /// member. Independently, `value` must not be in `block`, so a block entry
    /// overrides an allow entry for the same value.
    fn passes<T>(allow: &HashSet<T>, block: &HashSet<T>, value: T) -> bool
    where
        T: Eq + std::hash::Hash,
    {
        let allowed = allow.is_empty() || allow.contains(&value);
        allowed && !block.contains(&value)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use mavlink::MavHeader;

    /// Builds a header with the given source system and component ids.
    ///
    /// Mutates a default header rather than using a struct literal, so it does
    /// not depend on the full field list of `MavHeader`.
    fn make_header(system_id: u8, component_id: u8) -> MavHeader {
        let mut header = MavHeader::default();
        header.system_id = system_id;
        header.component_id = component_id;
        header
    }

    #[test]
    fn test_filter_logic() {
        let filters = EndpointFilters {
            allow_msg_id_out: HashSet::from([0]),
            ..Default::default()
        };
        let header = MavHeader::default();
        assert!(filters.check_outgoing(&header, 0));
        assert!(!filters.check_outgoing(&header, 1));
    }

    #[test]
    fn test_filter_block() {
        let filters = EndpointFilters {
            block_msg_id_out: HashSet::from([30]),
            ..Default::default()
        };
        let header = MavHeader::default();
        assert!(filters.check_outgoing(&header, 0));
        assert!(!filters.check_outgoing(&header, 30));
    }

    #[test]
    fn test_default_allows_everything() {
        let filters = EndpointFilters::default();
        let header = make_header(1, 1);
        assert!(filters.check_incoming(&header, 0));
        assert!(filters.check_outgoing(&header, 30));
    }

    #[test]
    fn test_block_wins_over_allow() {
        // The same id in both sets must be rejected: the block check is independent.
        let filters = EndpointFilters {
            allow_msg_id_in: HashSet::from([30]),
            block_msg_id_in: HashSet::from([30]),
            ..Default::default()
        };
        assert!(!filters.check_incoming(&make_header(1, 1), 30));
    }

    #[test]
    fn test_directions_are_independent() {
        let filters = EndpointFilters {
            block_msg_id_out: HashSet::from([30]),
            ..Default::default()
        };
        let header = make_header(1, 1);
        assert!(filters.check_incoming(&header, 30));
        assert!(!filters.check_outgoing(&header, 30));
    }

    #[test]
    fn test_source_component_and_system_filters() {
        let filters = EndpointFilters {
            allow_src_sys_in: HashSet::from([1]),
            block_src_comp_in: HashSet::from([191]),
            ..Default::default()
        };
        assert!(filters.check_incoming(&make_header(1, 1), 0));
        assert!(!filters.check_incoming(&make_header(2, 1), 0));
        assert!(!filters.check_incoming(&make_header(1, 191), 0));
    }
}