use core::num::NonZeroU8;
use crate::crypto::Crypto;
use crate::dm::{
ArrayAttributeRead, ArrayAttributeWrite, AttrChangeNotifier, Cluster, Dataver, InvokeContext,
ReadContext, WriteContext,
};
use crate::dm::{AttrId, EndptId, NodeId};
use crate::error::{Error, ErrorCode};
use crate::fabric::MAX_FABRICS;
use crate::persist::{KvBlobStore, Persist, OTA_PROVIDERS_KEY};
use crate::tlv::{FromTLV, Nullable, Octets, TLVArray, TLVBuilderParent, TLVElement, ToTLV};
use crate::transport::exchange::Exchange;
use crate::utils::cell::RefCell;
use crate::utils::init::{init, Init};
use crate::utils::storage::Vec;
use crate::utils::sync::blocking::Mutex;
use crate::utils::sync::Notification;
use crate::with;
use crate::Matter;
pub use crate::dm::clusters::decl::ota_software_update_requestor::*;
use crate::dm::clusters::decl::ota_software_update_provider::{
ApplyUpdateActionEnum, DownloadProtocolEnum, OtaSoftwareUpdateProviderClient,
QueryImageResponse,
};
use crate::dm::clusters::ota_prov::OtaApplyOutcome;
/// Maximum number of providers remembered from `AnnounceOTAProvider` commands.
/// When the list is full, `Providers::add_announced` evicts the oldest entry.
const ANNOUNCED_PROVIDERS: usize = 4;
/// Location of an OTA provider: the node and endpoint to contact, plus the fabric
/// through which that node is reached.
#[derive(Debug, Clone, Copy, PartialEq, Eq, FromTLV, ToTLV)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct Provider {
    /// Index of the fabric used to open exchanges with the provider node.
    pub fab_idx: NonZeroU8,
    /// Node ID of the provider on that fabric.
    pub node_id: NodeId,
    /// Endpoint on the provider node that hosts the OTA Software Update Provider cluster.
    pub endpoint: EndptId,
}
impl Provider {
    /// Sends a `QueryImage` command to this provider and hands the decoded response to `f`.
    ///
    /// - `protocols`: download protocols supported by the requestor, sent in the given order.
    /// - `current_version`: software version to advertise; defaults to the device's
    ///   `sw_ver` when `None`.
    /// - `requestor_can_consent`: forwarded as the `RequestorCanConsent` field.
    ///
    /// Vendor ID, product ID and hardware version come from the device details. The
    /// location comes from the basic-information settings. No provider metadata is sent.
    ///
    /// The exchange is completed before this returns. This also happens when the
    /// response cannot be decoded or `f` fails; in that case the error is returned
    /// after completion.
    ///
    /// # Errors
    ///
    /// Fails if the exchange cannot be initiated, the invoke fails, the response cannot be
    /// decoded, `f` returns an error, or completing the exchange fails.
    pub async fn query<C, F, R>(
        &self,
        matter: &Matter<'_>,
        crypto: C,
        protocols: &[DownloadProtocolEnum],
        current_version: Option<u32>,
        requestor_can_consent: bool,
        f: F,
    ) -> Result<R, Error>
    where
        C: Crypto,
        F: FnOnce(&QueryImageResponse<'_>) -> Result<R, Error>,
    {
        let dev = matter.dev_det();
        let version = current_version.unwrap_or(dev.sw_ver);
        // Copy the configured location out of the shared Matter state before starting the exchange.
        let location = matter.with_state(|state| state.basic_info_settings.location.clone());
        let location = location.as_deref();

        let exchange = Exchange::initiate(matter, crypto, self.fab_idx, self.node_id).await?;
        let handle = exchange
            .ota_software_update_provider()
            .query_image(self.endpoint, |b| {
                let mut protos = b
                    .vendor_id(dev.vid)?
                    .product_id(dev.pid)?
                    .software_version(version)?
                    .protocols_supported()?;
                for proto in protocols {
                    protos = protos.push(proto)?;
                }
                protos
                    .end()?
                    .hardware_version(Some(dev.hw_ver))?
                    .location(location)?
                    .requestor_can_consent(Some(requestor_can_consent))?
                    .metadata_for_provider(None)?
                    .end()
            })
            .await?;

        // Decoding and callback errors are held back so the exchange is always completed.
        let result = handle.response().and_then(|response| f(&response));
        handle.complete().await?;
        result
    }

    /// Sends an `ApplyUpdateRequest` for the image identified by `update_token` and
    /// `new_version`, then maps the provider's answer to an [`OtaApplyOutcome`].
    ///
    /// `Proceed` and `AwaitNextAction` carry the response's `DelayedActionTime`.
    /// `Discontinue` carries no delay.
    ///
    /// The exchange is completed before this returns, also when the response cannot be
    /// decoded.
    ///
    /// # Errors
    ///
    /// Fails if the exchange cannot be initiated, the invoke fails, the response cannot be
    /// decoded, or completing the exchange fails.
    pub async fn apply_update<C>(
        &self,
        matter: &Matter<'_>,
        crypto: C,
        update_token: &[u8],
        new_version: u32,
    ) -> Result<OtaApplyOutcome, Error>
    where
        C: Crypto,
    {
        let exchange = Exchange::initiate(matter, crypto, self.fab_idx, self.node_id).await?;
        let handle = exchange
            .ota_software_update_provider()
            .apply_update_request(self.endpoint, |b| {
                b.update_token(Octets(update_token))?
                    .new_version(new_version)?
                    .end()
            })
            .await?;

        // Decoding errors are held back so the exchange is always completed.
        let outcome = handle.response().and_then(|response| {
            let delay_secs = response.delayed_action_time()?;
            let outcome = match response.action()? {
                ApplyUpdateActionEnum::Proceed => OtaApplyOutcome::Proceed { delay_secs },
                ApplyUpdateActionEnum::AwaitNextAction => OtaApplyOutcome::Await { delay_secs },
                ApplyUpdateActionEnum::Discontinue => OtaApplyOutcome::Discontinue,
            };
            Ok(outcome)
        });
        handle.complete().await?;
        outcome
    }

    /// Sends `NotifyUpdateApplied` to this provider, reporting that the image identified by
    /// `update_token` is now running as `software_version`.
    ///
    /// # Errors
    ///
    /// Fails if the exchange cannot be initiated or the invoke fails.
    pub async fn notify_applied<C>(
        &self,
        matter: &Matter<'_>,
        crypto: C,
        update_token: &[u8],
        software_version: u32,
    ) -> Result<(), Error>
    where
        C: Crypto,
    {
        let exchange = Exchange::initiate(matter, crypto, self.fab_idx, self.node_id).await?;
        exchange
            .ota_software_update_provider()
            .notify_update_applied(self.endpoint, |b| {
                b.update_token(Octets(update_token))?
                    .software_version(software_version)?
                    .end()
            })
            .await
    }
}
/// Registry of OTA providers known to the requestor.
///
/// It holds two lists:
/// - the `DefaultOTAProviders` list, which is persisted and holds at most one entry per fabric;
/// - a bounded in-memory list of providers learned from `AnnounceOTAProvider` commands.
pub struct Providers {
    /// Both provider lists, behind a blocking mutex.
    state: Mutex<RefCell<ProvidersState>>,
    /// Signalled after writes to the default list and after new announcements;
    /// awaited by `wait_changed`.
    changed: Notification,
}
/// Mutable contents of `Providers`.
struct ProvidersState {
    /// `DefaultOTAProviders` entries; persisted under `OTA_PROVIDERS_KEY`.
    default: Vec<Provider, MAX_FABRICS>,
    /// Providers received via `AnnounceOTAProvider`, oldest first; kept in memory only.
    announced: Vec<Provider, ANNOUNCED_PROVIDERS>,
}
impl ProvidersState {
    /// In-place initializer producing empty default and announced lists.
    fn init() -> impl Init<Self> {
        init!(Self {
            default <- Vec::init(),
            announced <- Vec::init(),
        })
    }
}
impl Providers {
pub const fn new() -> Self {
Self {
state: Mutex::new(RefCell::new(ProvidersState {
default: Vec::new(),
announced: Vec::new(),
})),
changed: Notification::new(),
}
}
pub fn init() -> impl Init<Self> {
init!(Self {
state <- Mutex::init(RefCell::init(ProvidersState::init())),
changed: Notification::new(),
})
}
pub async fn load_persist<S: KvBlobStore>(
&self,
mut store: S,
buf: &mut [u8],
) -> Result<(), Error> {
let Some(data) = store.load(OTA_PROVIDERS_KEY, buf)? else {
self.state.lock(|cell| cell.borrow_mut().default.clear());
return Ok(());
};
let loaded = Vec::<Provider, MAX_FABRICS>::from_tlv(&TLVElement::new(data))?;
self.state.lock(|cell| cell.borrow_mut().default = loaded);
info!("Loaded OTA provider entries from storage");
Ok(())
}
fn store_persist<C: WriteContext>(&self, ctx: &C) -> Result<(), Error> {
let mut persist = Persist::new(ctx.kv());
self.state.lock(|cell| {
let state = cell.borrow();
persist.store_tlv(OTA_PROVIDERS_KEY, &state.default)
})?;
persist.run()
}
pub fn len(&self) -> usize {
self.state.lock(|cell| cell.borrow().default.len())
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn get(&self, index: usize) -> Option<Provider> {
self.state
.lock(|cell| cell.borrow().default.get(index).copied())
}
pub fn announced_len(&self) -> usize {
self.state.lock(|cell| cell.borrow().announced.len())
}
pub fn announced(&self, index: usize) -> Option<Provider> {
self.state
.lock(|cell| cell.borrow().announced.get(index).copied())
}
pub fn clear_announced(&self) {
self.state.lock(|cell| cell.borrow_mut().announced.clear());
}
pub fn take_announced(&self) -> Vec<Provider, ANNOUNCED_PROVIDERS> {
self.state.lock(|cell| {
let mut state = cell.borrow_mut();
let taken = state.announced.clone();
state.announced.clear();
taken
})
}
pub async fn wait_changed(&self) {
self.changed.wait().await;
}
fn replace_default<C: WriteContext>(
&self,
ctx: &C,
fab_idx: NonZeroU8,
provider: Option<Provider>,
) -> Result<(), Error> {
self.state.lock(|cell| {
let mut state = cell.borrow_mut();
state.default.retain(|p| p.fab_idx != fab_idx);
if let Some(provider) = provider {
state
.default
.push(provider)
.map_err(|_| ErrorCode::ResourceExhausted)?;
}
Ok::<_, Error>(())
})?;
self.changed.notify();
self.store_persist(ctx)
}
fn add_default<C: WriteContext>(
&self,
ctx: &C,
fab_idx: NonZeroU8,
provider: Provider,
) -> Result<(), Error> {
self.state.lock(|cell| {
let mut state = cell.borrow_mut();
if state.default.iter().any(|p| p.fab_idx == fab_idx) {
return Err(ErrorCode::ConstraintError.into());
}
state
.default
.push(provider)
.map_err(|_| ErrorCode::ResourceExhausted)?;
Ok::<_, Error>(())
})?;
self.changed.notify();
self.store_persist(ctx)
}
fn add_announced(&self, provider: Provider) {
self.state.lock(|cell| {
let mut state = cell.borrow_mut();
state
.announced
.retain(|p| !(p.fab_idx == provider.fab_idx && p.node_id == provider.node_id));
if state.announced.is_full() {
state.announced.remove(0);
}
let _ = state.announced.push(provider);
});
self.changed.notify();
}
fn render<P: TLVBuilderParent>(
&self,
fab_filter: Option<NonZeroU8>,
builder: ArrayAttributeRead<ProviderLocationArrayBuilder<P>, ProviderLocationBuilder<P>>,
) -> Result<P, Error> {
self.state.lock(|cell| {
let state = cell.borrow();
let mut iter = state
.default
.iter()
.filter(|p| fab_filter.is_none_or(|f| p.fab_idx == f));
match builder {
ArrayAttributeRead::ReadAll(mut array) => {
for p in iter {
array = array
.push()?
.provider_node_id(p.node_id)?
.endpoint(p.endpoint)?
.fabric_index(Some(p.fab_idx.get()))?
.end()?;
}
array.end()
}
ArrayAttributeRead::ReadOne(index, item) => {
let Some(p) = iter.nth(index as usize) else {
return Err(ErrorCode::ConstraintError.into());
};
item.provider_node_id(p.node_id)?
.endpoint(p.endpoint)?
.fabric_index(Some(p.fab_idx.get()))?
.end()
}
ArrayAttributeRead::ReadNone(array) => array.end(),
}
})
}
}
/// Same as [`Providers::new`]: both provider lists start empty.
impl Default for Providers {
    fn default() -> Self {
        Self::new()
    }
}
/// Requestor-side state backing the cluster instance on `endpoint_id`.
///
/// It backs the `UpdatePossible`, `UpdateState` and `UpdateStateProgress` attributes.
pub struct OtaState {
    /// Endpoint used when emitting attribute/cluster change notifications.
    endpoint_id: EndptId,
    /// Current attribute values, behind a blocking mutex.
    reported: Mutex<RefCell<Reported>>,
}
/// Values currently reported through the requestor cluster's attributes.
struct Reported {
    /// Value of the `UpdateState` attribute.
    update_state: UpdateStateEnum,
    /// Value of `UpdateStateProgress`; `None` is reported as null.
    progress: Option<u8>,
    /// Value of the `UpdatePossible` attribute.
    update_possible: bool,
}
impl OtaState {
    /// Creates the state for the requestor cluster instance on `endpoint_id`.
    ///
    /// Initially `UpdateState` is `Idle`, `UpdateStateProgress` is null and
    /// `UpdatePossible` is `true`.
    pub const fn new(endpoint_id: EndptId) -> Self {
        Self {
            endpoint_id,
            reported: Mutex::new(RefCell::new(Reported {
                update_state: UpdateStateEnum::Idle,
                progress: None,
                update_possible: true,
            })),
        }
    }

    /// Sets the `UpdatePossible` attribute and notifies a change of that attribute.
    pub fn set_update_possible(&self, notifier: &dyn AttrChangeNotifier, possible: bool) {
        self.reported
            .lock(|cell| cell.borrow_mut().update_possible = possible);
        self.notify(notifier, AttributeId::UpdatePossible as _);
    }

    /// Current value of the `UpdateState` attribute.
    fn update_state(&self) -> UpdateStateEnum {
        self.reported.lock(|cell| cell.borrow().update_state)
    }

    /// Current value of `UpdateStateProgress` (`None` means null).
    fn progress(&self) -> Option<u8> {
        self.reported.lock(|cell| cell.borrow().progress)
    }

    /// Current value of the `UpdatePossible` attribute.
    fn update_possible(&self) -> bool {
        self.reported.lock(|cell| cell.borrow().update_possible)
    }

    /// Records a new `UpdateState` / `UpdateStateProgress` pair, then notifies a change
    /// of the whole cluster on this endpoint.
    ///
    /// `UpdateStateProgress` is a percentage constrained to `0..=100`. Larger values
    /// are clamped to 100 rather than reported out of range.
    fn report(
        &self,
        notifier: &dyn AttrChangeNotifier,
        state: UpdateStateEnum,
        progress: Option<u8>,
    ) {
        /// Upper bound of the `UpdateStateProgress` percentage.
        const MAX_PROGRESS_PERCENT: u8 = 100;

        let progress = progress.map(|percent| percent.min(MAX_PROGRESS_PERCENT));

        self.reported.lock(|cell| {
            let mut reported = cell.borrow_mut();
            reported.update_state = state;
            reported.progress = progress;
        });
        notifier.notify_cluster_changed(self.endpoint_id, FULL_CLUSTER.id);
    }

    /// Notifies a change of the single attribute `attr_id` on this cluster instance.
    fn notify(&self, notifier: &dyn AttrChangeNotifier, attr_id: AttrId) {
        notifier.notify_attr_changed(self.endpoint_id, FULL_CLUSTER.id, attr_id);
    }

    /// Starts an update session that reports state changes through `notifier`.
    ///
    /// When the returned guard is completed or dropped, `UpdateState` goes back to `Idle`.
    pub fn initiate_update<'a>(&'a self, notifier: &'a dyn AttrChangeNotifier) -> OtaUpdate<'a> {
        OtaUpdate {
            state: self,
            notifier,
            done: false,
        }
    }
}
/// Guard for an in-progress update, created by [`OtaState::initiate_update`].
///
/// Its methods report state transitions. `complete`, or dropping the guard, reports
/// `Idle` with no progress.
pub struct OtaUpdate<'a> {
    /// State whose attributes are updated.
    state: &'a OtaState,
    /// Sink for change notifications.
    notifier: &'a dyn AttrChangeNotifier,
    /// Set by `complete` so that `Drop` does not report `Idle` a second time.
    done: bool,
}
impl OtaUpdate<'_> {
    /// Reports `Querying`, with no progress value.
    pub fn querying(&self) {
        self.report(UpdateStateEnum::Querying, None);
    }

    /// Reports `Downloading`, optionally with a completion percentage.
    pub fn downloading(&self, percent: Option<u8>) {
        self.report(UpdateStateEnum::Downloading, percent);
    }

    /// Reports `Applying`, with no progress value.
    pub fn applying(&self) {
        self.report(UpdateStateEnum::Applying, None);
    }

    /// Reports an arbitrary `state` / `progress` pair.
    pub fn report(&self, state: UpdateStateEnum, progress: Option<u8>) {
        let ota = self.state;
        ota.report(self.notifier, state, progress);
    }

    /// Finishes the update, reporting `Idle`, and disarms the drop-time report.
    pub fn complete(mut self) {
        self.report(UpdateStateEnum::Idle, None);
        self.done = true;
    }
}
impl Drop for OtaUpdate<'_> {
    /// Reports `Idle` unless [`OtaUpdate::complete`] has already done so.
    ///
    /// This ensures that an update abandoned before completion does not leave a
    /// non-idle state behind.
    fn drop(&mut self) {
        if self.done {
            return;
        }
        self.report(UpdateStateEnum::Idle, None);
    }
}
/// Handler for the OTA Software Update Requestor cluster.
///
/// It is backed by a shared [`Providers`] registry and an [`OtaState`].
pub struct OtaRequestorHandler<'a> {
    /// Data version of this cluster instance.
    dataver: Dataver,
    /// Default and announced OTA providers.
    providers: &'a Providers,
    /// Reported update state.
    state: &'a OtaState,
}
impl<'a> OtaRequestorHandler<'a> {
    /// Creates a handler over the given data version, provider registry and update state.
    pub const fn new(dataver: Dataver, providers: &'a Providers, state: &'a OtaState) -> Self {
        Self {
            dataver,
            providers,
            state,
        }
    }

    /// Wraps this handler in a `HandlerAdaptor`.
    pub const fn adapt(self) -> HandlerAdaptor<Self> {
        HandlerAdaptor(self)
    }
}
impl ClusterHandler for OtaRequestorHandler<'_> {
    /// The requestor cluster, exposing its required attributes.
    const CLUSTER: Cluster<'static> = FULL_CLUSTER.with_attrs(with!(required));

    fn dataver(&self) -> u32 {
        self.dataver.get()
    }

    fn dataver_changed(&self) {
        self.dataver.changed();
    }

    /// Reads `DefaultOTAProviders`.
    ///
    /// On a fabric-filtered read, only the accessing fabric's entries are returned. A
    /// filtered read without an accessing fabric (index 0) fails with `UnsupportedAccess`.
    fn default_ota_providers<P: TLVBuilderParent>(
        &self,
        ctx: impl ReadContext,
        builder: ArrayAttributeRead<ProviderLocationArrayBuilder<P>, ProviderLocationBuilder<P>>,
    ) -> Result<P, Error> {
        let attr = ctx.attr();
        let fab_filter = attr
            .fab_filter
            .then(|| NonZeroU8::new(attr.fab_idx).ok_or(ErrorCode::UnsupportedAccess))
            .transpose()?;

        self.providers.render(fab_filter, builder)
    }

    fn update_possible(&self, _ctx: impl ReadContext) -> Result<bool, Error> {
        Ok(self.state.update_possible())
    }

    fn update_state(&self, _ctx: impl ReadContext) -> Result<UpdateStateEnum, Error> {
        Ok(self.state.update_state())
    }

    /// Reads `UpdateStateProgress`; a missing progress value is reported as null.
    fn update_state_progress(&self, _ctx: impl ReadContext) -> Result<Nullable<u8>, Error> {
        let progress = self.state.progress();
        Ok(progress.map_or_else(Nullable::none, Nullable::some))
    }

    /// Writes `DefaultOTAProviders` for the accessing fabric.
    ///
    /// Every written entry is attributed to the accessing fabric, and at most one entry
    /// per fabric is accepted. Item updates and removals are rejected with `InvalidAction`.
    fn set_default_ota_providers(
        &self,
        ctx: impl WriteContext,
        value: ArrayAttributeWrite<TLVArray<'_, ProviderLocation<'_>>, ProviderLocation<'_>>,
    ) -> Result<(), Error> {
        let fab_idx = NonZeroU8::new(ctx.attr().fab_idx).ok_or(ErrorCode::UnsupportedAccess)?;

        let to_provider = |loc: &ProviderLocation<'_>| -> Result<Provider, Error> {
            let node_id = loc.provider_node_id()?;
            let endpoint = loc.endpoint()?;
            Ok(Provider {
                fab_idx,
                node_id,
                endpoint,
            })
        };

        match value {
            ArrayAttributeWrite::Replace(list) => {
                let mut entries = list.iter();
                let first = entries.next().transpose()?;
                // A fabric may only have a single default provider.
                if entries.next().is_some() {
                    return Err(ErrorCode::ConstraintError.into());
                }
                let provider = match first {
                    Some(loc) => Some(to_provider(&loc)?),
                    None => None,
                };
                self.providers.replace_default(&ctx, fab_idx, provider)?;
            }
            ArrayAttributeWrite::Add(loc) => {
                let provider = to_provider(&loc)?;
                self.providers.add_default(&ctx, fab_idx, provider)?;
            }
            ArrayAttributeWrite::Update(..) | ArrayAttributeWrite::Remove(..) => {
                return Err(ErrorCode::InvalidAction.into());
            }
        }

        ctx.notify_changed();
        Ok(())
    }

    /// Handles `AnnounceOTAProvider` by remembering the announcing provider.
    ///
    /// The provider is recorded on the invoking fabric.
    fn handle_announce_ota_provider(
        &self,
        ctx: impl InvokeContext,
        request: AnnounceOTAProviderRequest<'_>,
    ) -> Result<(), Error> {
        let fab_idx = NonZeroU8::new(ctx.cmd().fab_idx).ok_or(ErrorCode::UnsupportedAccess)?;
        let node_id = request.provider_node_id()?;
        let endpoint = request.endpoint()?;

        self.providers.add_announced(Provider {
            fab_idx,
            node_id,
            endpoint,
        });
        Ok(())
    }
}
/// Splits a BDX image URI of the form `bdx://<NodeId>/<FileDesignator>`.
///
/// `<NodeId>` must be exactly 16 hexadecimal digits encoding the 64-bit node ID. The
/// Matter specification mandates upper-case digits; lower-case digits are tolerated here.
/// `<FileDesignator>` is everything after the first `/` and must not be empty.
///
/// Returns the node ID and a slice of `url` holding the file designator.
///
/// # Errors
///
/// Returns `ErrorCode::InvalidData` in any of these cases:
/// - the scheme is not `bdx://`;
/// - the `/` separator is missing;
/// - the node ID is not exactly 16 hex digits;
/// - the file designator is empty.
pub fn parse_bdx_url(url: &str) -> Result<(NodeId, &str), Error> {
    /// Number of hex digits used to encode a 64-bit node ID in a BDX URI.
    const NODE_ID_HEX_DIGITS: usize = 16;

    let rest = url.strip_prefix("bdx://").ok_or(ErrorCode::InvalidData)?;
    let (node, fd) = rest.split_once('/').ok_or(ErrorCode::InvalidData)?;

    // `u64::from_str_radix` alone would also accept a leading `+` sign and any digit
    // count from 1 to 16, so check the exact shape first.
    let node_is_well_formed =
        node.len() == NODE_ID_HEX_DIGITS && node.bytes().all(|b| b.is_ascii_hexdigit());
    if !node_is_well_formed || fd.is_empty() {
        return Err(ErrorCode::InvalidData.into());
    }

    let node_id = u64::from_str_radix(node, 16).map_err(|_| ErrorCode::InvalidData)?;
    Ok((node_id, fd))
}
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a provider on fabric 1, endpoint 0, for the given node.
    fn provider_on_fabric_1(node_id: NodeId) -> Provider {
        Provider {
            fab_idx: NonZeroU8::new(1).unwrap(),
            node_id,
            endpoint: 0,
        }
    }

    #[test]
    fn parse_bdx_url_extracts_node_and_fd() {
        assert_eq!(
            parse_bdx_url("bdx://00112233AABBCCDD/my-firmware.ota").unwrap(),
            (0x0011_2233_AABB_CCDD, "my-firmware.ota")
        );

        // Wrong scheme, missing separator, non-hex node ID.
        for bad in [
            "https://example.com/x",
            "bdx://nodeid-no-slash",
            "bdx://zzzz/fd",
        ] {
            assert!(parse_bdx_url(bad).is_err());
        }
    }

    #[test]
    fn announced_dedup_evict_and_clear() {
        let providers = Providers::new();

        // A repeated announcement from the same node keeps a single entry.
        providers.add_announced(provider_on_fabric_1(0xaa));
        providers.add_announced(provider_on_fabric_1(0xaa));
        assert_eq!(providers.announced_len(), 1);

        // Overfilling evicts the oldest entries and caps the list length.
        (0..ANNOUNCED_PROVIDERS as u64 + 2)
            .map(|n| provider_on_fabric_1(0x100 + n))
            .for_each(|p| providers.add_announced(p));
        assert_eq!(providers.announced_len(), ANNOUNCED_PROVIDERS);

        providers.clear_announced();
        assert_eq!(providers.announced_len(), 0);
    }
}