use std::ops::BitOr;
use super::*;
/// Result type used throughout this module, with [`ErrorCode`] as the error.
type Result<T> = std::result::Result<T, ErrorCode>;
/// An access type (read and/or write) together with its security requirements.
///
/// Start from one of the `READ`/`WRITE`/`READ_WRITE` constants and add
/// requirements with [`Access::authn`], [`Access::authz`] and
/// [`Access::encrypt`]. Two values with different access types combine into a
/// [`Perms`] table with `|`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[must_use]
#[repr(transparent)]
pub struct Access(Perm);
impl Access {
    /// No access at all.
    pub const NONE: Self = Self(Perm::empty());
    /// Read access with no security requirements.
    pub const READ: Self = Self(Perm::READ);
    /// Write access with no security requirements.
    pub const WRITE: Self = Self(Perm::WRITE);
    /// Read and write access with no security requirements.
    pub const READ_WRITE: Self = Self(Perm::READ_WRITE);

    /// Returns this value's access type with the security level of `sec`.
    ///
    /// Any security requirements already on `self` are discarded. The
    /// authentication and authorization flags of `sec` are copied, and the
    /// `hci::ConnSec::KEY_LEN` bits of `sec` are applied as the key length.
    ///
    /// # Panics
    ///
    /// Panics if the `KEY_LEN` bits of `sec` are not a key length accepted
    /// by `Access::key_len`.
    #[inline]
    pub(crate) fn copy(self, sec: hci::ConnSec) -> Self {
        let mut perm = self.0.access_type();
        perm.set(Perm::AUTHN, sec.contains(hci::ConnSec::AUTHN));
        perm.set(Perm::AUTHZ, sec.contains(hci::ConnSec::AUTHZ));
        let key_bits = sec.intersection(hci::ConnSec::KEY_LEN).bits();
        Self(perm).key_len(key_bits)
    }

    /// Adds an authentication requirement.
    #[inline]
    pub const fn authn(self) -> Self {
        let Self(perm) = self;
        Self(perm.union(Perm::AUTHN))
    }

    /// Adds an authorization requirement.
    #[inline]
    pub const fn authz(self) -> Self {
        let Self(perm) = self;
        Self(perm.union(Perm::AUTHZ))
    }

    /// Requires encryption with the maximum key length (`Perm::KEY_MAX`).
    #[inline]
    pub const fn encrypt(self) -> Self {
        self.key_len(Perm::KEY_MAX)
    }

    /// Replaces the required key length with `n`; `0` clears it.
    ///
    /// The length is stored in the `KEY_LEN` bits as `n - Perm::KEY_OFF`.
    ///
    /// # Panics
    ///
    /// Panics unless `n` is `0`, or a multiple of 8 in
    /// `Perm::KEY_MIN..=Perm::KEY_MAX`.
    #[inline]
    const fn key_len(self, n: u8) -> Self {
        let in_range = Perm::KEY_MIN <= n && n <= Perm::KEY_MAX && n % 8 == 0;
        assert!(in_range || n == 0, "invalid encryption key length");
        let field = n.saturating_sub(Perm::KEY_OFF) & Perm::KEY_LEN.bits();
        let rest = self.0.difference(Perm::KEY_LEN).bits();
        Self(Perm::from_bits_retain(rest | field))
    }

    /// Returns only the access type, dropping every security requirement.
    #[inline]
    pub const fn typ(self) -> Self {
        Self(self.0.intersection(Perm::READ_WRITE))
    }

    /// Index of this value's access type in a [`Perms`] table (0 to 3).
    #[inline]
    #[must_use]
    const fn index(self) -> usize {
        let Self(ty) = self.typ();
        ty.bits() as usize
    }
}
/// `a | b` builds a [`Perms`] table holding both accesses.
impl BitOr for Access {
    type Output = Perms;

    /// Equivalent to [`Perms::allow`].
    ///
    /// # Panics
    ///
    /// Panics if both operands have the same access type.
    #[inline]
    fn bitor(self, other: Self) -> Perms {
        Perms::allow(self, other)
    }
}
/// A request: an operation code paired with the access it asks for.
#[derive(Clone, Copy, Debug)]
#[must_use]
pub struct Request {
    /// Operation code of the request.
    pub(crate) op: Opcode,
    /// Access type and security level carried by the request.
    pub(crate) ac: Access,
}
/// Permission table with one entry per access type.
///
/// Entries are indexed by the access-type bits of an [`Access`]: `0` none,
/// `1` read, `2` write, `3` read-write. Each entry stores the complete
/// permission (access type plus security requirements) for that access type.
#[derive(Clone, Copy, Debug, Default)]
#[must_use]
#[repr(transparent)]
pub struct Perms([Perm; 4]);
impl Perms {
    /// Creates a table whose only entry is `allow`.
    #[inline]
    pub const fn new(allow: Access) -> Self {
        let mut ps = Self([Perm::empty(); 4]);
        ps.0[allow.index()] = allow.0;
        ps
    }

    /// Creates a table with the two entries `a` and `b`.
    ///
    /// # Panics
    ///
    /// Panics if `a` and `b` have the same access type.
    #[inline]
    pub const fn allow(a: Access, b: Access) -> Self {
        let (i, j) = (a.index(), b.index());
        assert!(i != j, "access type must be different");
        let mut ps = Self([Perm::empty(); 4]);
        ps.0[i] = a.0;
        ps.0[j] = b.0;
        ps
    }

    /// Returns whether the access type of `rw` is allowed by this table.
    ///
    /// Security requirements are ignored. A request that would fail only for
    /// insufficient authentication, authorization, encryption, or key size
    /// still counts as contained. For `READ_WRITE`, both the read and the
    /// write half must be allowed.
    #[inline]
    #[must_use]
    pub const fn contains(self, rw: Access) -> bool {
        // Returns whether `ps` does not deny the access type of `req` outright.
        const fn permits(ps: Perms, req: Access) -> bool {
            use ErrorCode::*;
            !matches!(ps.test(req), Err(ReadNotPermitted | WriteNotPermitted))
        }
        if rw.0.access_type().contains(Perm::READ_WRITE) {
            // `test` checks a read-write request only against the read-write
            // entry. An unset entry yields `RequestNotSupported`, which is not
            // a denial, so a read-only table would appear to contain
            // READ_WRITE. Check each half on its own instead.
            permits(self, Access(rw.0.difference(Perm::WRITE)))
                && permits(self, Access(rw.0.difference(Perm::READ)))
        } else {
            permits(self, rw)
        }
    }

    /// Checks whether `req` is permitted by this table.
    ///
    /// The entry for the access type of `req` is used. If that entry is
    /// unset, the read-write entry is used instead. When a specific entry
    /// rejects `req`, the read-write entry is also tried. Its result wins if
    /// it grants access or fails only with `EncryptionKeySizeTooShort`.
    ///
    /// # Errors
    ///
    /// Returns the [`ErrorCode`] explaining why `req` is not permitted.
    pub const fn test(self, req: Access) -> Result<()> {
        const RW: usize = Access(Perm::READ_WRITE).index();
        let mut op = req.index();
        if !self.0[op].is_set() {
            op = RW;
        }
        let exact = self.0[op].test(req.0);
        if exact.is_ok() || op == RW {
            return exact;
        }
        // The specific entry rejected the request; the read-write entry may
        // still allow it.
        let rw = self.0[RW].test(req.0);
        match rw {
            Ok(_) | Err(ErrorCode::EncryptionKeySizeTooShort) => rw,
            _ => exact,
        }
    }
}
impl From<Access> for Perms {
#[inline]
fn from(v: Access) -> Self {
Self::new(v)
}
}
bitflags::bitflags! {
    /// Packed access-type and security-requirement bits.
    ///
    /// Bits 3-6 (`KEY_LEN`) hold an encoded minimum key length. `ENCRYPT`
    /// aliases the lowest of those bits.
    #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
    #[must_use]
    #[repr(transparent)]
    struct Perm: u8 {
        /// Read access.
        const READ = 0b0000_0001;
        /// Write access.
        const WRITE = 0b0000_0010;
        /// Read and write access.
        const READ_WRITE = Self::READ.bits() | Self::WRITE.bits();
        /// Authentication requirement.
        const AUTHN = 0b0000_0100;
        /// Encryption requirement; shares its bit with the low bit of `KEY_LEN`.
        const ENCRYPT = 0b0000_1000;
        /// Encoded key-length field.
        const KEY_LEN = 0b0111_1000;
        /// Authorization requirement.
        const AUTHZ = 0b1000_0000;
    }
}
impl Perm {
    /// Smallest non-zero key length accepted by `Access::key_len`.
    const KEY_MIN: u8 = 56;
    /// Largest key length; `Access::encrypt` requires this length.
    const KEY_MAX: u8 = 128;
    /// Offset subtracted from a key length before it is stored in `KEY_LEN`,
    /// so `KEY_MIN` is stored as `8`, the lowest `KEY_LEN` bit.
    const KEY_OFF: u8 = Self::KEY_MIN - 8;

    /// Returns whether any access-type bit is set.
    #[inline]
    #[must_use]
    const fn is_set(self) -> bool {
        self.intersects(Self::READ_WRITE)
    }

    /// Returns only the access-type bits.
    #[inline]
    const fn access_type(self) -> Self {
        self.intersection(Self::READ_WRITE)
    }

    /// Returns the security flags: `AUTHN` and `AUTHZ` as stored, plus
    /// `ENCRYPT` whenever a key length is set. Access-type and key-length
    /// bits are removed.
    #[inline]
    const fn security(self) -> Self {
        self.difference(Self::READ_WRITE.union(Self::KEY_LEN))
            .union(if self.intersects(Self::KEY_LEN) {
                Self::ENCRYPT
            } else {
                Self::empty()
            })
    }

    /// Decodes the stored key length (`0` when none is set).
    #[inline]
    #[must_use]
    const fn key_len(self) -> u8 {
        match self.intersection(Self::KEY_LEN).bits() {
            0 => 0,
            n => n + Self::KEY_OFF,
        }
    }

    /// Checks request `req` against this permission.
    ///
    /// The access type is checked first, then authorization, authentication
    /// and encryption, then key length.
    ///
    /// # Errors
    ///
    /// * `RequestNotSupported` if `req` has no access type.
    /// * `ReadNotPermitted` / `WriteNotPermitted` if `req` asks for an access
    ///   type that `self` does not grant. Read is reported first when both
    ///   are missing.
    /// * `InsufficientAuthorization`, `InsufficientAuthentication` or
    ///   `InsufficientEncryption` for the first missing security flag, in
    ///   that order.
    /// * `EncryptionKeySizeTooShort` if the key length of `req` is below the
    ///   one `self` requires.
    const fn test(self, req: Self) -> Result<()> {
        use ErrorCode::*;
        let want = req.access_type();
        if want.is_empty() {
            return Err(RequestNotSupported);
        }
        // Access types requested but not granted. When both are missing, this
        // is still a denial; `RequestNotSupported` is reserved for an empty
        // request.
        let denied = want.difference(self.access_type());
        if denied.contains(Self::READ) {
            return Err(ReadNotPermitted);
        }
        if denied.contains(Self::WRITE) {
            return Err(WriteNotPermitted);
        }
        // Security flags required by `self` that `req` lacks.
        let missing = self.security().difference(req.security());
        if missing.contains(Self::AUTHZ) {
            Err(InsufficientAuthorization)
        } else if missing.contains(Self::AUTHN) {
            Err(InsufficientAuthentication)
        } else if !missing.is_empty() {
            Err(InsufficientEncryption)
        } else if req.key_len() < self.key_len() {
            Err(EncryptionKeySizeTooShort)
        } else {
            Ok(())
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Single-entry checks: access type, then security flags, then key length.
    #[test]
    fn access() {
        use ErrorCode::*;
        let (ro, wo, rw) = (Access::READ, Access::WRITE, Access::READ_WRITE);
        let cases = [
            (ro, ro, Ok(())),
            (ro, wo, Err(WriteNotPermitted)),
            (ro, rw, Err(WriteNotPermitted)),
            (wo, ro, Err(ReadNotPermitted)),
            (wo, wo, Ok(())),
            (wo, rw, Err(ReadNotPermitted)),
            (rw, ro, Ok(())),
            (rw, wo, Ok(())),
            (rw, rw, Ok(())),
            (rw, Access::NONE, Err(RequestNotSupported)),
            (ro.authn(), ro.authn().authz(), Ok(())),
            (wo.authn(), wo.authz(), Err(InsufficientAuthentication)),
            (
                rw.authn().authz(),
                ro.authn(),
                Err(InsufficientAuthorization),
            ),
            (
                rw.authn().encrypt(),
                wo.authn().authz(),
                Err(InsufficientEncryption),
            ),
            (ro.key_len(0), ro.key_len(0), Ok(())),
            (ro.key_len(56), ro.authn().encrypt(), Ok(())),
            (
                rw.key_len(80),
                rw.authn().authz().key_len(56),
                Err(EncryptionKeySizeTooShort),
            ),
        ];
        for (perm, req, want) in &cases {
            assert_eq!(
                &perm.0.test(req.0),
                want,
                "perm {:?}, request {:?}",
                perm,
                req
            );
        }
    }

    /// Table checks, including fallback to the read-write entry.
    #[test]
    fn perms() {
        use ErrorCode::*;
        let (ro, wo, rw) = (Access::READ, Access::WRITE, Access::READ_WRITE);
        let ps = Access::READ.authn() | Access::READ_WRITE.authz().encrypt();
        let cases = [
            (Access(Perm::default()), Err(RequestNotSupported)),
            (ro, Err(InsufficientAuthentication)),
            (ro.authz(), Err(InsufficientAuthentication)),
            (ro.authz().key_len(120), Err(EncryptionKeySizeTooShort)),
            (ro.authn(), Ok(())),
            (ro.authz().encrypt(), Ok(())),
            (wo, Err(InsufficientAuthorization)),
            (wo.authz(), Err(InsufficientEncryption)),
            (wo.authz().encrypt(), Ok(())),
            (rw.encrypt(), Err(InsufficientAuthorization)),
            (rw.authz().encrypt(), Ok(())),
        ];
        for (req, want) in &cases {
            assert_eq!(&ps.test(*req), want, "request {:?}", req);
        }
    }
}