use core::num::NonZeroU8;
use crate::crypto::AEAD_CANON_KEY_LEN;
use crate::dm::{
ArrayAttributeRead, ArrayAttributeWrite, Cluster, Dataver, InvokeContext, ReadContext,
WriteContext,
};
use crate::error::{Error, ErrorCode};
use crate::fabric::{
FabricPersist, GroupKeyMapping, MAX_GROUPS_PER_FABRIC, MAX_GROUP_KEYS_PER_FABRIC,
};
use crate::group_keys::{GroupEpochKeyEntry, GroupKeySet};
use crate::tlv::{Nullable, Octets, TLVArray, TLVBuilderParent};
use crate::with;
pub use crate::dm::clusters::decl::group_key_management::*;
/// Handler for the Group Key Management cluster.
///
/// The cluster behavior itself is provided by the `ClusterHandler` implementation for this
/// type; the struct only carries the cluster's data version state.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct GrpKeyMgmtHandler {
/// Data version state surfaced through `ClusterHandler::dataver` and
/// `ClusterHandler::dataver_changed`.
dataver: Dataver,
}
impl GrpKeyMgmtHandler {
pub const fn new(dataver: Dataver) -> Self {
Self { dataver }
}
pub const fn adapt(self) -> HandlerAdaptor<Self> {
HandlerAdaptor(self)
}
}
impl ClusterHandler for GrpKeyMgmtHandler {
/// Cluster metadata for this handler: `FULL_CLUSTER` narrowed with `with_attrs(with!(required))`.
// NOTE(review): presumably this limits the advertised attributes to the required ones —
// confirm against the `with!` macro.
const CLUSTER: Cluster<'static> = FULL_CLUSTER.with_attrs(with!(required));
/// Returns the current data version by reading the wrapped `Dataver`.
fn dataver(&self) -> u32 {
self.dataver.get()
}
/// Signals a change of cluster data by forwarding to `Dataver::changed`.
fn dataver_changed(&self) {
self.dataver.changed();
}
/// Reads the `GroupKeyMap` attribute.
///
/// The candidate rows are the key-map entries of every fabric in the exchange state, paired
/// with the index of the fabric that owns them. When the read is fabric-filtered
/// (`attr.fab_filter`), only the fabric whose index equals `attr.fab_idx` contributes rows.
///
/// - `ReadAll` writes every row into the array builder.
/// - `ReadOne` writes the row at the requested position, or fails with `ConstraintError`
///   when there is no such row.
/// - `ReadNone` closes the builder without writing any row.
fn group_key_map<P: TLVBuilderParent>(
    &self,
    ctx: impl ReadContext,
    builder: ArrayAttributeRead<GroupKeyMapStructArrayBuilder<P>, GroupKeyMapStructBuilder<P>>,
) -> Result<P, Error> {
    let attr = ctx.attr();

    ctx.exchange().with_state(|state| {
        // Lazily yields `(fabric index, key-map entry)` pairs in fabric order.
        let mut mappings = state
            .fabrics
            .iter()
            .filter(|fabric| !attr.fab_filter || fabric.fab_idx().get() == attr.fab_idx)
            .flat_map(|fabric| {
                fabric
                    .groups()
                    .key_map_iter()
                    .map(move |mapping| (fabric.fab_idx(), mapping))
            });

        match builder {
            ArrayAttributeRead::ReadNone(array) => array.end(),
            ArrayAttributeRead::ReadOne(index, item) => {
                let (fab_idx, mapping) = mappings
                    .nth(index as usize)
                    .ok_or(ErrorCode::ConstraintError)?;

                item.group_id(mapping.group_id)?
                    .group_key_set_id(mapping.group_key_set_id)?
                    .fabric_index(Some(fab_idx.get()))?
                    .end()
            }
            ArrayAttributeRead::ReadAll(array) => mappings
                // Thread the array builder through every row, stopping at the first error.
                .try_fold(array, |array, (fab_idx, mapping)| {
                    array
                        .push()?
                        .group_id(mapping.group_id)?
                        .group_key_set_id(mapping.group_key_set_id)?
                        .fabric_index(Some(fab_idx.get()))?
                        .end()
                })?
                .end(),
        }
    })
}
/// Reads the `GroupTable` attribute.
///
/// The candidate rows are the group entries of every fabric in the exchange state, paired
/// with the index of the fabric that owns them. When the read is fabric-filtered
/// (`attr.fab_filter`), only the fabric whose index equals `attr.fab_idx` contributes rows.
/// Each row is written with its group ID, its endpoint list, its group name and the fabric
/// index.
///
/// - `ReadAll` writes every row into the array builder.
/// - `ReadOne` writes the row at the requested position, or fails with `ConstraintError`
///   when there is no such row.
/// - `ReadNone` closes the builder without writing any row.
fn group_table<P: TLVBuilderParent>(
    &self,
    ctx: impl ReadContext,
    builder: ArrayAttributeRead<
        GroupInfoMapStructArrayBuilder<P>,
        GroupInfoMapStructBuilder<P>,
    >,
) -> Result<P, Error> {
    let attr = ctx.attr();

    ctx.exchange().with_state(|state| {
        // Lazily yields `(fabric index, group entry)` pairs in fabric order.
        let mut groups = state
            .fabrics
            .iter()
            .filter(|fabric| !attr.fab_filter || fabric.fab_idx().get() == attr.fab_idx)
            .flat_map(|fabric| {
                fabric
                    .groups()
                    .iter()
                    .map(move |group| (fabric.fab_idx(), group))
            });

        match builder {
            ArrayAttributeRead::ReadNone(array) => array.end(),
            ArrayAttributeRead::ReadOne(index, item) => {
                let (fab_idx, group) = groups
                    .nth(index as usize)
                    .ok_or(ErrorCode::ConstraintError)?;

                let endpoints = group.endpoints.iter().try_fold(
                    item.group_id(group.group_id)?.endpoints()?,
                    |endpoints, endpoint| endpoints.push(endpoint),
                )?;

                endpoints
                    .end()?
                    .group_name(Some(group.group_name.as_str()))?
                    .fabric_index(Some(fab_idx.get()))?
                    .end()
            }
            ArrayAttributeRead::ReadAll(array) => groups
                // Thread the array builder through every row, stopping at the first error.
                .try_fold(array, |array, (fab_idx, group)| {
                    let endpoints = group.endpoints.iter().try_fold(
                        array.push()?.group_id(group.group_id)?.endpoints()?,
                        |endpoints, endpoint| endpoints.push(endpoint),
                    )?;

                    endpoints
                        .end()?
                        .group_name(Some(group.group_name.as_str()))?
                        .fabric_index(Some(fab_idx.get()))?
                        .end()
                })?
                .end(),
        }
    })
}
/// Reads the `MaxGroupsPerFabric` attribute: the `MAX_GROUPS_PER_FABRIC` constant as `u16`.
fn max_groups_per_fabric(&self, _ctx: impl ReadContext) -> Result<u16, Error> {
    let max_groups = MAX_GROUPS_PER_FABRIC as u16;

    Ok(max_groups)
}
/// Reads the `MaxGroupKeysPerFabric` attribute: `MAX_GROUP_KEYS_PER_FABRIC` plus one.
// NOTE(review): the `+ 1` presumably accounts for key set ID 0, which this handler always
// reports in `KeySetReadAllIndices` in addition to the stored key sets — confirm.
fn max_group_keys_per_fabric(&self, _ctx: impl ReadContext) -> Result<u16, Error> {
Ok(MAX_GROUP_KEYS_PER_FABRIC as u16 + 1)
}
/// Handles a write to the `GroupKeyMap` attribute on behalf of the accessing fabric.
///
/// - `Replace` validates the whole incoming list first and only then replaces the fabric's
///   key map, so a rejected list leaves the stored map untouched.
/// - `Add` appends a single mapping.
/// - Any other write kind is rejected with `InvalidAction`.
///
/// After a successful change the fabric is handed to `FabricPersist::store` unless the
/// fail-safe is armed for this fabric, the transport is notified that groups changed, and
/// the result of `persist.run()` is returned.
///
/// # Errors
///
/// - `UnsupportedAccess` when the write carries no (non-zero) fabric index.
/// - `ResourceExhausted` when a replacement list has more than
///   `MAX_GROUP_KEYS_PER_FABRIC` entries.
/// - `ConstraintError` when an entry maps a group to key set ID 0.
/// - The decoding error of any entry (or entry field) that cannot be parsed.
/// - Whatever the fabric lookup, the group table update or the persistence step report.
fn set_group_key_map(
    &self,
    ctx: impl WriteContext,
    value: ArrayAttributeWrite<TLVArray<'_, GroupKeyMapStruct<'_>>, GroupKeyMapStruct<'_>>,
) -> Result<(), Error> {
    let fab_idx = NonZeroU8::new(ctx.attr().fab_idx).ok_or(ErrorCode::UnsupportedAccess)?;
    let mut persist = FabricPersist::new(ctx.kv());

    ctx.exchange().with_state(|state| {
        let fabric = state.fabrics.fabric_mut(fab_idx)?;

        match value {
            ArrayAttributeWrite::Replace(list) => {
                // First pass: validate everything before touching the stored map.
                let mut count: usize = 0;
                for entry in &list {
                    count += 1;
                    if count > MAX_GROUP_KEYS_PER_FABRIC {
                        // Running out of room for list entries is a resource problem,
                        // not a generic failure.
                        return Err(ErrorCode::ResourceExhausted.into());
                    }

                    let entry = entry?;

                    // Decode *every* field used below. The second pass skips entries it
                    // cannot decode, so anything not checked here would be silently
                    // dropped from the replacement list instead of failing the write.
                    entry.group_id()?;
                    if entry.group_key_set_id()? == 0 {
                        return Err(ErrorCode::ConstraintError.into());
                    }
                }

                // Second pass: every entry was decoded successfully above, so this
                // `filter_map` does not discard anything.
                let entries = list.into_iter().filter_map(|entry| {
                    let entry = entry.ok()?;
                    Some(GroupKeyMapping {
                        group_id: entry.group_id().ok()?,
                        group_key_set_id: entry.group_key_set_id().ok()?,
                    })
                });

                fabric.groups_mut().key_map_replace(entries)?;
            }
            ArrayAttributeWrite::Add(entry) => {
                let group_key_set_id = entry.group_key_set_id()?;
                if group_key_set_id == 0 {
                    return Err(ErrorCode::ConstraintError.into());
                }

                let group_id = entry.group_id().map_err(|_| ErrorCode::InvalidCommand)?;

                fabric.groups_mut().key_map_add(GroupKeyMapping {
                    group_id,
                    group_key_set_id,
                })?;
            }
            _ => {
                return Err(ErrorCode::InvalidAction.into());
            }
        }

        // Skip persisting while the fail-safe is armed for this fabric.
        if !state.failsafe.is_armed_for(fab_idx.get()) {
            persist.store(fabric)?;
        }

        ctx.exchange().matter().transport().notify_groups_changed();

        Ok(())
    })?;

    persist.run()
}
fn handle_key_set_write(
&self,
ctx: impl InvokeContext,
request: KeySetWriteRequest<'_>,
) -> Result<(), Error> {
let fab_idx = ctx.exchange().accessor()?.fab_idx()?;
let key_set = request.group_key_set()?;
let group_key_set_id = key_set.group_key_set_id()?;
let group_key_security_policy = key_set.group_key_security_policy()?;
if group_key_set_id == 0 {
return Err(ErrorCode::InvalidCommand.into());
}
let epoch_key_0 = key_set.epoch_key_0()?;
let epoch_start_time_0 = key_set.epoch_start_time_0()?;
let epoch_key_1 = key_set.epoch_key_1()?;
let epoch_start_time_1 = key_set.epoch_start_time_1()?;
let epoch_key_2 = key_set.epoch_key_2()?;
let epoch_start_time_2 = key_set.epoch_start_time_2()?;
let Some(epoch_key_0_val) = epoch_key_0.as_opt_ref() else {
return Err(ErrorCode::InvalidCommand.into());
};
let Some(&epoch_start_time_0_val) = epoch_start_time_0.as_opt_ref() else {
return Err(ErrorCode::InvalidCommand.into());
};
if epoch_start_time_0_val == 0 {
return Err(ErrorCode::InvalidCommand.into());
}
if epoch_key_0_val.0.len() != AEAD_CANON_KEY_LEN {
return Err(ErrorCode::ConstraintError.into());
}
let has_epoch_key_1 = epoch_key_1.as_opt_ref().is_some();
let has_epoch_start_time_1 = epoch_start_time_1.as_opt_ref().is_some();
if has_epoch_key_1 != has_epoch_start_time_1 {
return Err(ErrorCode::InvalidCommand.into());
}
let mut entry = GroupKeySet {
group_key_set_id,
group_key_security_policy: group_key_security_policy as u8,
..Default::default()
};
let mut key0 = GroupEpochKeyEntry {
epoch_key: Default::default(),
epoch_start_time: epoch_start_time_0_val,
};
key0.epoch_key
.try_load_from_slice(epoch_key_0_val.0)
.map_err(|_| ErrorCode::ConstraintError)?;
entry
.epoch_keys
.push(key0)
.map_err(|_| Error::from(ErrorCode::ConstraintError))?;
if has_epoch_key_1 {
let epoch_key_1_val = epoch_key_1.as_opt_ref().unwrap();
let &epoch_start_time_1_val = epoch_start_time_1.as_opt_ref().unwrap();
if epoch_key_1_val.0.len() != AEAD_CANON_KEY_LEN {
return Err(ErrorCode::ConstraintError.into());
}
if epoch_start_time_1_val <= epoch_start_time_0_val {
return Err(ErrorCode::InvalidCommand.into());
}
let mut key1 = GroupEpochKeyEntry {
epoch_key: Default::default(),
epoch_start_time: epoch_start_time_1_val,
};
key1.epoch_key
.try_load_from_slice(epoch_key_1_val.0)
.map_err(|_| ErrorCode::ConstraintError)?;
entry
.epoch_keys
.push(key1)
.map_err(|_| Error::from(ErrorCode::ConstraintError))?;
let has_epoch_key_2 = epoch_key_2.as_opt_ref().is_some();
let has_epoch_start_time_2 = epoch_start_time_2.as_opt_ref().is_some();
if has_epoch_key_2 != has_epoch_start_time_2 {
return Err(ErrorCode::InvalidCommand.into());
}
if has_epoch_key_2 {
let epoch_key_2_val = epoch_key_2.as_opt_ref().unwrap();
let &epoch_start_time_2_val = epoch_start_time_2.as_opt_ref().unwrap();
if epoch_key_2_val.0.len() != AEAD_CANON_KEY_LEN {
return Err(ErrorCode::ConstraintError.into());
}
if epoch_start_time_2_val <= epoch_start_time_1_val {
return Err(ErrorCode::InvalidCommand.into());
}
let mut key2 = GroupEpochKeyEntry {
epoch_key: Default::default(),
epoch_start_time: epoch_start_time_2_val,
};
key2.epoch_key
.try_load_from_slice(epoch_key_2_val.0)
.map_err(|_| ErrorCode::ConstraintError)?;
entry
.epoch_keys
.push(key2)
.map_err(|_| Error::from(ErrorCode::ConstraintError))?;
}
} else {
let has_epoch_key_2 = epoch_key_2.as_opt_ref().is_some();
let has_epoch_start_time_2 = epoch_start_time_2.as_opt_ref().is_some();
if has_epoch_key_2 || has_epoch_start_time_2 {
return Err(ErrorCode::InvalidCommand.into());
}
}
let mut persist = FabricPersist::new(ctx.kv());
ctx.exchange().with_state(|state| {
let fabric = state.fabrics.fabric_mut(fab_idx)?;
fabric.groups_mut().key_set_add(entry)?;
if !state.failsafe.is_armed_for(fab_idx.get()) {
persist.store(fabric)?;
}
ctx.exchange().matter().transport().notify_groups_changed();
Ok(())
})?;
ctx.notify_own_cluster_changed();
persist.run()
}
/// Handles the `KeySetRead` command for the accessing fabric.
///
/// Epoch key material is never returned: all three `EpochKeyN` fields of the response are
/// written as null; only the key set ID, the security policy and the epoch start times are
/// reported.
///
/// Key set ID 0 is answered with a fixed placeholder (`TrustFirst` policy, start time 0 for
/// epoch 0, every other field null) without consulting the stored key sets.
///
/// # Errors
///
/// - `NotFound` when the fabric has no key set with the requested ID.
/// - `Failure` when the stored key set is not representable in a response: its security
///   policy byte matches no known `GroupKeySecurityPolicyEnum` variant, or it holds no
///   epoch key at all.
/// - Whatever the accessor/fabric lookup or the response builder report.
fn handle_key_set_read<P: TLVBuilderParent>(
    &self,
    ctx: impl InvokeContext,
    request: KeySetReadRequest<'_>,
    response: KeySetReadResponseBuilder<P>,
) -> Result<P, Error> {
    let fab_idx = ctx.exchange().accessor()?.fab_idx()?;
    let group_key_set_id = request.group_key_set_id()?;

    ctx.exchange().with_state(|state| {
        let fabric = state.fabrics.fabric(fab_idx)?;

        if group_key_set_id == 0 {
            return response
                .group_key_set()?
                .group_key_set_id(0)?
                .group_key_security_policy(GroupKeySecurityPolicyEnum::TrustFirst)?
                .epoch_key_0(Nullable::<Octets<'_>>::none())?
                .epoch_start_time_0(Nullable::some(0))?
                .epoch_key_1(Nullable::<Octets<'_>>::none())?
                .epoch_start_time_1(Nullable::none())?
                .epoch_key_2(Nullable::<Octets<'_>>::none())?
                .epoch_start_time_2(Nullable::none())?
                .end()?
                .end();
        }

        let entry = fabric
            .groups()
            .key_set_get(group_key_set_id)
            .ok_or(ErrorCode::NotFound)?;

        // The policy is stored as a raw byte. Map it back through the known variants
        // instead of transmuting: a byte that is not a valid discriminant (e.g. coming from
        // corrupted or outdated storage) must surface as an error, not as undefined
        // behavior.
        let stored_policy = entry.group_key_security_policy;
        let policy = if stored_policy == GroupKeySecurityPolicyEnum::TrustFirst as u8 {
            GroupKeySecurityPolicyEnum::TrustFirst
        } else if stored_policy == GroupKeySecurityPolicyEnum::CacheAndSync as u8 {
            GroupKeySecurityPolicyEnum::CacheAndSync
        } else {
            return Err(ErrorCode::Failure.into());
        };

        // Checked access: an empty stored key set is reported as an error rather than
        // panicking on an out-of-bounds index.
        let epoch_0 = entry.epoch_keys.first().ok_or(ErrorCode::Failure)?;

        response
            .group_key_set()?
            .group_key_set_id(group_key_set_id)?
            .group_key_security_policy(policy)?
            .epoch_key_0(Nullable::<Octets<'_>>::none())?
            .epoch_start_time_0(Nullable::some(epoch_0.epoch_start_time))?
            .epoch_key_1(Nullable::<Octets<'_>>::none())?
            .epoch_start_time_1(match entry.epoch_keys.get(1) {
                Some(epoch) => Nullable::some(epoch.epoch_start_time),
                None => Nullable::none(),
            })?
            .epoch_key_2(Nullable::<Octets<'_>>::none())?
            .epoch_start_time_2(match entry.epoch_keys.get(2) {
                Some(epoch) => Nullable::some(epoch.epoch_start_time),
                None => Nullable::none(),
            })?
            .end()?
            .end()
    })
}
/// Handles the `KeySetRemove` command for the accessing fabric.
///
/// Key set ID 0 cannot be removed and is rejected with `InvalidCommand`. Otherwise the key
/// set is removed from the fabric's group data, the fabric is handed to
/// `FabricPersist::store` unless the fail-safe is armed for this fabric, the transport is
/// notified that groups changed, the cluster is flagged as changed, and the result of
/// `persist.run()` is returned.
fn handle_key_set_remove(
    &self,
    ctx: impl InvokeContext,
    request: KeySetRemoveRequest<'_>,
) -> Result<(), Error> {
    let fab_idx = ctx.exchange().accessor()?.fab_idx()?;

    let key_set_id = match request.group_key_set_id()? {
        0 => return Err(ErrorCode::InvalidCommand.into()),
        id => id,
    };

    let mut persist = FabricPersist::new(ctx.kv());

    ctx.exchange().with_state(|state| {
        let fabric = state.fabrics.fabric_mut(fab_idx)?;
        fabric.groups_mut().key_set_remove(key_set_id)?;

        let failsafe_armed = state.failsafe.is_armed_for(fab_idx.get());
        if !failsafe_armed {
            persist.store(fabric)?;
        }

        ctx.exchange().matter().transport().notify_groups_changed();

        Ok(())
    })?;

    ctx.notify_own_cluster_changed();

    persist.run()
}
/// Handles the `KeySetReadAllIndices` command for the accessing fabric.
///
/// The response always starts with key set ID 0, followed by the ID of every key set
/// yielded by the fabric's `key_set_iter()`.
fn handle_key_set_read_all_indices<P: TLVBuilderParent>(
    &self,
    ctx: impl InvokeContext,
    response: KeySetReadAllIndicesResponseBuilder<P>,
) -> Result<P, Error> {
    let fab_idx = ctx.exchange().accessor()?.fab_idx()?;

    ctx.exchange().with_state(|state| {
        let fabric = state.fabrics.fabric(fab_idx)?;

        // ID 0 is reported unconditionally, ahead of the stored key sets.
        let ids = response.group_key_set_i_ds()?.push(&0u16)?;

        fabric
            .groups()
            .key_set_iter()
            .try_fold(ids, |ids, key_set| ids.push(&key_set.group_key_set_id))?
            .end()?
            .end()
    })
}
}