use coap_message::{MessageOption as _, ReadableMessage};
/// An authorization policy that decides whether a CoAP request is allowed.
///
/// The decision is made from the policy value (`self`) and the request alone.
pub trait Scope: core::fmt::Debug + Sized {
    /// Returns `true` if `request` is allowed under this scope, and `false` otherwise.
    fn request_is_allowed<M>(&self, request: &M) -> bool
    where
        M: ReadableMessage;
}
/// `Infallible` has no values, so this method can never actually be called.
impl Scope for core::convert::Infallible {
    fn request_is_allowed<M: ReadableMessage>(&self, _message: &M) -> bool {
        // No value of `Infallible` exists; the empty match proves this body unreachable.
        let never: core::convert::Infallible = *self;
        match never {}
    }
}
/// Error returned when a byte string cannot be accepted as a scope
/// (for example by [`AifValue::parse`]).
///
/// Derives `defmt::Format` under the `defmt` feature, like the other public types in this
/// module. It is comparable so that callers can match on parse results.
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct InvalidScope;
/// A scope that allows every request.
///
/// This is a stateless marker, so it is `Copy` like any other unit value.
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[derive(Debug, Clone, Copy)]
pub struct AllowAll;
impl Scope for AllowAll {
    /// Always returns `true`; the request is not inspected.
    fn request_is_allowed<M: ReadableMessage>(&self, _request: &M) -> bool {
        true
    }
}
/// A scope that denies every request.
///
/// This is a stateless marker, so it is `Copy` like any other unit value.
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[derive(Debug, Clone, Copy)]
pub struct DenyAll;
impl Scope for DenyAll {
    /// Always returns `false`; the request is not inspected.
    fn request_is_allowed<M: ReadableMessage>(&self, _request: &M) -> bool {
        false
    }
}
/// Capacity, in bytes, of the buffer that holds an encoded AIF scope in [`AifValue`].
const AIF_SCOPE_MAX_LEN: usize = 64;

/// An AIF scope kept as its CBOR encoding in a fixed-size buffer.
///
/// When built by [`AifValue::parse`], the encoding occupies the start of the buffer and the
/// remaining bytes are zero.
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[derive(Clone, Debug)]
pub struct AifValue([u8; AIF_SCOPE_MAX_LEN]);
impl AifValue {
    /// Validates `bytes` as an AIF scope and copies it into a zero-padded fixed-size buffer.
    ///
    /// `bytes` must decode as a CBOR array. Each element must decode as a `(&str, u32)` tuple
    /// (path, permission mask), and every path must start with `/`.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidScope`] if `bytes` is longer than `AIF_SCOPE_MAX_LEN`, does not decode
    /// as described above, or contains a path that does not start with `/`.
    pub fn parse(bytes: &[u8]) -> Result<Self, InvalidScope> {
        let mut storage = [0u8; AIF_SCOPE_MAX_LEN];
        let Some(prefix) = storage.get_mut(..bytes.len()) else {
            return Err(InvalidScope);
        };
        prefix.copy_from_slice(bytes);

        // Validate from the caller's slice; `storage` holds the same bytes plus zero padding.
        let mut decoder = minicbor::Decoder::new(bytes);
        let entries = decoder
            .array_iter::<(&str, u32)>()
            .map_err(|_| InvalidScope)?;
        for entry in entries {
            let (path, _permissions) = entry.map_err(|_| InvalidScope)?;
            if !path.starts_with('/') {
                return Err(InvalidScope);
            }
        }

        Ok(Self(storage))
    }
}
impl Scope for AifValue {
    /// Returns `true` if any AIF entry grants `request`.
    ///
    /// An entry grants the request when both of these hold:
    /// * its permission mask has bit `code - 1` set, where `code` is the request's code;
    /// * its path matches the request's Uri-Path options exactly. The leading `/` is removed
    ///   and the rest is split at every `/`, keeping empty segments, so `/a/` is `["a", ""]`.
    ///   The resulting segments must equal the Uri-Path option values, in order, with no extra
    ///   options. The path `/` matches only a request without Uri-Path options.
    ///
    /// Codes with no bit in a `u32` mask (code 0, or codes above 32) are denied, not panicked on.
    ///
    /// # Panics
    ///
    /// Panics if the stored bytes are not a valid AIF array. This cannot happen for values
    /// produced by [`AifValue::parse`].
    fn request_is_allowed<M: ReadableMessage>(&self, request: &M) -> bool {
        let code: u8 = request.code().into();
        // Code 0 (Empty) has no method bit; codes >= 33 would shift past the u32 mask.
        // Both come from the peer, so deny instead of panicking.
        let Some(codebit) = u32::from(code)
            .checked_sub(1)
            .and_then(|bit| 1u32.checked_shl(bit))
        else {
            return false;
        };

        let mut decoder = minicbor::Decoder::new(&self.0);
        let entries = decoder
            .array_iter::<(&str, u32)>()
            .expect("AifValue holds an array validated by AifValue::parse");
        'outer: for entry in entries {
            let (path, perms) =
                entry.expect("AifValue holds entries validated by AifValue::parse");
            if perms & codebit == 0 {
                continue;
            }

            let mut pathopts = request
                .options()
                .filter(|o| o.number() == coap_numbers::option::URI_PATH);

            let segments = path.strip_prefix('/').expect("Invalid AIF");
            // "/" is the root and needs zero Uri-Path options. Any other path needs one option
            // per `/`-separated segment, including a trailing empty one.
            if !segments.is_empty() {
                for segment in segments.split('/') {
                    match pathopts.next() {
                        Some(opt) if opt.value() == segment.as_bytes() => {}
                        _ => continue 'outer,
                    }
                }
            }
            // All segments matched; reject requests carrying additional Uri-Path options.
            if pathopts.next().is_none() {
                return true;
            }
        }
        false
    }
}
/// A scope whose concrete kind is chosen at runtime from the scopes defined in this module.
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[derive(Clone, Debug)]
pub enum UnionScope {
    /// Decision delegated to the contained [`AifValue`].
    AifValue(AifValue),
    /// Behaves like [`AllowAll`].
    AllowAll,
    /// Behaves like [`DenyAll`].
    DenyAll,
}
impl Scope for UnionScope {
    /// Forwards to the [`Scope`] implementation matching the active variant.
    fn request_is_allowed<M: ReadableMessage>(&self, request: &M) -> bool {
        match self {
            Self::AifValue(aif) => aif.request_is_allowed(request),
            Self::AllowAll => AllowAll.request_is_allowed(request),
            Self::DenyAll => DenyAll.request_is_allowed(request),
        }
    }
}
impl From<AifValue> for UnionScope {
fn from(value: AifValue) -> Self {
UnionScope::AifValue(value)
}
}
impl From<AllowAll> for UnionScope {
fn from(_value: AllowAll) -> Self {
UnionScope::AllowAll
}
}
impl From<DenyAll> for UnionScope {
fn from(_value: DenyAll) -> Self {
UnionScope::DenyAll
}
}
impl From<core::convert::Infallible> for UnionScope {
fn from(value: core::convert::Infallible) -> Self {
match value {}
}
}