use core::cell::{Cell, RefCell};
use core::future::Future;
use crate::dm::clusters::decl::globals::{
ICECandidateStruct, StreamUsageEnum, WebRTCEndReasonEnum, WebRTCSessionStructArrayBuilder,
WebRTCSessionStructBuilder,
};
use crate::dm::{
ArrayAttributeRead, Cluster, Dataver, EndptId, HandlerContext, InvokeContext, ReadContext,
};
use crate::error::{Error, ErrorCode};
use crate::tlv::{Nullable, TLVArray, TLVBuilderParent};
use crate::transport::exchange::Exchange;
use crate::utils::storage::Vec;
use crate::utils::sync::blocking::Mutex;
use crate::with;
use super::super::decl::web_rtc_transport_provider as decl;
use super::super::decl::web_rtc_transport_requestor::WebRtcTransportRequestorClient;
#[allow(unused_imports)]
pub use crate::dm::clusters::decl::web_rtc_transport_provider::*;
/// Failure reported by a [`WebRtcHooks`] callback.
///
/// Each variant is turned into exactly one IM error code when it is converted into
/// [`Error`], so hooks choose the status the client sees by choosing the variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum WebRtcError {
/// The operation is not allowed in the session's current state (`ErrorCode::InvalidAction`).
InvalidInState,
/// The command contents are not acceptable (`ErrorCode::InvalidCommand`).
InvalidCommand,
/// A constraint that depends on runtime state was violated (`ErrorCode::DynamicConstraintError`).
DynamicConstraint,
/// Out of sessions, buffers, or other resources (`ErrorCode::ResourceExhausted`).
ResourceExhausted,
/// Any other failure (`ErrorCode::Failure`).
Failure,
}
/// Translates a hook failure into the IM error code reported back to the client.
impl From<WebRtcError> for Error {
    fn from(e: WebRtcError) -> Self {
        // Pick the error code first, then convert once at the end.
        let code = match e {
            WebRtcError::InvalidInState => ErrorCode::InvalidAction,
            WebRtcError::InvalidCommand => ErrorCode::InvalidCommand,
            WebRtcError::DynamicConstraint => ErrorCode::DynamicConstraintError,
            WebRtcError::ResourceExhausted => ErrorCode::ResourceExhausted,
            WebRtcError::Failure => ErrorCode::Failure,
        };
        Self::from(code)
    }
}
/// Offer-related fields decoded from a `SolicitOffer` or `ProvideOffer` request and
/// handed to the hooks.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct OfferParams {
/// Requested stream usage.
pub stream_usage: StreamUsageEnum,
/// Endpoint on the requestor node; outbound commands for the session are sent there.
pub originating_endpoint_id: EndptId,
/// Outer `None`: field absent from the request; `Some(None)`: present but null.
pub video_stream_id: Option<Option<u16>>,
/// Outer `None`: field absent from the request; `Some(None)`: present but null.
pub audio_stream_id: Option<Option<u16>>,
/// `false` when the request omits the field.
pub metadata_enabled: bool,
}
/// Result of [`WebRtcHooks::on_solicit_offer`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct SolicitOutcome {
/// Echoed as `DeferredOffer` in the response. When `true`, the session is recorded as
/// waiting for the provider's own offer (pushed later via `OutboundWork::Offer`).
pub deferred: bool,
/// Video stream chosen for the session; stored and echoed (null when `None`).
pub video_stream_id: Option<u16>,
/// Audio stream chosen for the session; stored and echoed (null when `None`).
pub audio_stream_id: Option<u16>,
}
/// Result of [`WebRtcHooks::on_offer`]: the streams bound to the session.
///
/// Both IDs are stored in the session table and echoed in `ProvideOfferResponse`
/// (null when `None`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct AnswerOutcome {
pub video_stream_id: Option<u16>,
pub audio_stream_id: Option<u16>,
}
/// A provider-initiated command to push to the requestor of a session.
///
/// Produced by [`WebRtcHooks::next_outbound`] and consumed by the handler's `run` loop.
/// Work for a session ID the handler does not track is silently dropped.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum OutboundWork {
/// Send an `Offer`; the SDP is fetched with `take_offer_sdp`.
Offer {
session_id: u16,
},
/// Send an `Answer`; the SDP is fetched with `take_answer_sdp`.
Answer {
session_id: u16,
},
/// Send `ICECandidates`; the candidates are fetched with `take_ice_candidates`.
IceCandidates {
session_id: u16,
},
/// Send `End` with `reason`.
End {
session_id: u16,
reason: WebRTCEndReasonEnum,
},
}
/// Destination the hooks write outbound ICE candidate strings into.
///
/// `push` returns an error when the candidate cannot be stored (for the handler's own
/// sink: too long, or too many candidates).
pub trait IceCandidateSink {
fn push(&mut self, candidate: &str) -> Result<(), WebRtcError>;
}
/// [`IceCandidateSink`] backed by a fixed-capacity vector of fixed-capacity strings:
/// at most `MAX_CAND` candidates of at most `CAND_LEN` bytes each.
struct VecCandidateSink<'a, const CAND_LEN: usize, const MAX_CAND: usize> {
buf: &'a mut Vec<heapless::String<CAND_LEN>, MAX_CAND>,
}
impl<const CAND_LEN: usize, const MAX_CAND: usize> IceCandidateSink
    for VecCandidateSink<'_, CAND_LEN, MAX_CAND>
{
    /// Copies `candidate` into an owned fixed-capacity string and appends it.
    ///
    /// Fails with `ResourceExhausted` if the candidate exceeds `CAND_LEN` bytes or the
    /// buffer already holds `MAX_CAND` candidates.
    fn push(&mut self, candidate: &str) -> Result<(), WebRtcError> {
        let mut owned = heapless::String::<CAND_LEN>::new();
        if owned.push_str(candidate).is_err() {
            return Err(WebRtcError::ResourceExhausted);
        }
        match self.buf.push(owned) {
            Ok(()) => Ok(()),
            Err(_) => Err(WebRtcError::ResourceExhausted),
        }
    }
}
/// Application callbacks behind [`WebRtcProvHandler`].
///
/// The handler owns the Matter side: the session table, fabric/peer ownership checks
/// and SDP length checks. The hooks own the WebRTC/media side. A `WebRtcError` returned
/// from a command callback is converted into the IM error of that command.
pub trait WebRtcHooks {
/// Called for `SolicitOffer` with a freshly allocated `session_id`.
/// On error the command fails and no session is recorded.
async fn on_solicit_offer(
&self,
session_id: u16,
params: &OfferParams,
) -> Result<SolicitOutcome, WebRtcError>;
/// Called for `ProvideOffer` with the peer's SDP offer. `session_id` is either newly
/// allocated (null session ID in the request) or an existing session already checked
/// to belong to the same fabric and peer. On error the session table is not updated.
async fn on_offer(
&self,
session_id: u16,
sdp: &str,
params: &OfferParams,
) -> Result<AnswerOutcome, WebRtcError>;
/// Called for `ProvideAnswer` on a session owned by the calling peer.
async fn on_answer(&self, session_id: u16, sdp: &str) -> Result<(), WebRtcError>;
/// Called for `ProvideICECandidates` on a session owned by the calling peer; receives
/// the candidate list as the TLV array from the request.
async fn on_ice_candidates(
&self,
session_id: u16,
candidates: &TLVArray<'_, ICECandidateStruct<'_>>,
) -> Result<(), WebRtcError>;
/// Called for `EndSession` on a session owned by the calling peer. The handler drops
/// the session afterwards even if this returns an error.
async fn on_end_session(
&self,
session_id: u16,
reason: WebRTCEndReasonEnum,
) -> Result<(), WebRtcError>;
/// Yields the next provider-initiated command; awaited in a loop by the handler's
/// `run`. The default never resolves, meaning "no outbound traffic".
async fn next_outbound(&self) -> OutboundWork {
core::future::pending().await
}
/// Fills `_out` with the local ICE candidates to send for `_session_id` when an
/// `OutboundWork::IceCandidates` item is processed. Default: `InvalidInState`.
async fn take_ice_candidates(
&self,
_session_id: u16,
_out: &mut dyn IceCandidateSink,
) -> Result<(), WebRtcError> {
Err(WebRtcError::InvalidInState)
}
/// Writes the local SDP answer into `_sdp_out` (`SDP_LEN` bytes) and returns the byte
/// count; the handler requires those bytes to be UTF-8. Default: `InvalidInState`.
async fn take_answer_sdp(
&self,
_session_id: u16,
_sdp_out: &mut [u8],
) -> Result<usize, WebRtcError> {
Err(WebRtcError::InvalidInState)
}
/// Writes the local SDP offer into `_sdp_out` (`SDP_LEN` bytes) and returns the byte
/// count; the handler requires those bytes to be UTF-8. Default: `InvalidInState`.
async fn take_offer_sdp(
&self,
_session_id: u16,
_sdp_out: &mut [u8],
) -> Result<usize, WebRtcError> {
Err(WebRtcError::InvalidInState)
}
}
/// Lets a shared reference to a hooks implementation be used wherever hooks are
/// expected; every call is forwarded unchanged to `T`.
impl<T> WebRtcHooks for &T
where
    T: WebRtcHooks,
{
    fn on_solicit_offer(
        &self,
        session_id: u16,
        params: &OfferParams,
    ) -> impl Future<Output = Result<SolicitOutcome, WebRtcError>> {
        T::on_solicit_offer(*self, session_id, params)
    }

    fn on_offer(
        &self,
        session_id: u16,
        sdp: &str,
        params: &OfferParams,
    ) -> impl Future<Output = Result<AnswerOutcome, WebRtcError>> {
        T::on_offer(*self, session_id, sdp, params)
    }

    fn on_answer(
        &self,
        session_id: u16,
        sdp: &str,
    ) -> impl Future<Output = Result<(), WebRtcError>> {
        T::on_answer(*self, session_id, sdp)
    }

    fn on_ice_candidates(
        &self,
        session_id: u16,
        candidates: &TLVArray<'_, ICECandidateStruct<'_>>,
    ) -> impl Future<Output = Result<(), WebRtcError>> {
        T::on_ice_candidates(*self, session_id, candidates)
    }

    fn on_end_session(
        &self,
        session_id: u16,
        reason: WebRTCEndReasonEnum,
    ) -> impl Future<Output = Result<(), WebRtcError>> {
        T::on_end_session(*self, session_id, reason)
    }

    fn next_outbound(&self) -> impl Future<Output = OutboundWork> {
        T::next_outbound(*self)
    }

    fn take_ice_candidates(
        &self,
        session_id: u16,
        out: &mut dyn IceCandidateSink,
    ) -> impl Future<Output = Result<(), WebRtcError>> {
        T::take_ice_candidates(*self, session_id, out)
    }

    fn take_answer_sdp(
        &self,
        session_id: u16,
        sdp_out: &mut [u8],
    ) -> impl Future<Output = Result<usize, WebRtcError>> {
        T::take_answer_sdp(*self, session_id, sdp_out)
    }

    fn take_offer_sdp(
        &self,
        session_id: u16,
        sdp_out: &mut [u8],
    ) -> impl Future<Output = Result<usize, WebRtcError>> {
        T::take_offer_sdp(*self, session_id, sdp_out)
    }
}
/// Negotiation state of a tracked session.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
enum SessionState {
/// `SolicitOffer` was answered with a deferred offer; the provider's offer has not
/// been pushed yet.
AwaitingDeferredOffer,
/// The provider's offer is pending or has been pushed; a `ProvideAnswer` is expected.
AwaitingAnswer,
/// Negotiation finished (`ProvideOffer` handled or `ProvideAnswer` accepted).
Established,
}
/// One row of the session table; also the source of the `CurrentSessions` attribute.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
struct SessionEntry {
/// Session ID, never 0.
id: u16,
/// Fabric index of the creating command; outbound pushes reject 0.
fab_idx: u8,
/// Node ID of the peer that created the session; only that peer may act on it.
peer_node_id: u64,
/// Requestor endpoint that outbound commands are sent to.
peer_endpoint_id: EndptId,
stream_usage: StreamUsageEnum,
video_stream_id: Option<u16>,
audio_stream_id: Option<u16>,
metadata_enabled: bool,
state: SessionState,
}
/// WebRTC Transport Provider cluster handler.
///
/// Const parameters:
/// - `N_SESSIONS`: maximum number of tracked sessions.
/// - `SDP_LEN`: maximum inbound SDP length and size of the outbound SDP buffer.
/// - `OUT_LEN`: NOTE(review): not referenced by any code visible here.
/// - `CAND_LEN` / `MAX_CAND`: maximum length / count of outbound ICE candidates per push.
pub struct WebRtcProvHandler<
H: WebRtcHooks,
const N_SESSIONS: usize,
const SDP_LEN: usize,
const OUT_LEN: usize,
const CAND_LEN: usize,
const MAX_CAND: usize,
> {
dataver: Dataver,
endpoint_id: EndptId,
hooks: H,
/// Session table; the `RefCell` is only borrowed inside `Mutex::lock` closures.
sessions: Mutex<RefCell<Vec<SessionEntry, N_SESSIONS>>>,
/// Next session ID candidate for `allocate_id`.
next_id: Mutex<Cell<u16>>,
}
impl<
H: WebRtcHooks,
const N_SESSIONS: usize,
const SDP_LEN: usize,
const OUT_LEN: usize,
const CAND_LEN: usize,
const MAX_CAND: usize,
> WebRtcProvHandler<H, N_SESSIONS, SDP_LEN, OUT_LEN, CAND_LEN, MAX_CAND>
{
pub const CLUSTER: Cluster<'static> = decl::FULL_CLUSTER
.with_revision(2)
.with_attrs(with!(required))
.with_cmds(with!(
decl::CommandId::SolicitOffer
| decl::CommandId::ProvideOffer
| decl::CommandId::ProvideAnswer
| decl::CommandId::ProvideICECandidates
| decl::CommandId::EndSession
));
pub const fn new(dataver: Dataver, endpoint_id: EndptId, hooks: H) -> Self {
Self {
dataver,
endpoint_id,
hooks,
sessions: Mutex::new(RefCell::new(Vec::new())),
next_id: Mutex::new(Cell::new(1)),
}
}
pub const fn adapt(self) -> decl::HandlerAsyncAdaptor<Self> {
decl::HandlerAsyncAdaptor(self)
}
pub fn remove_fabric_sessions(&self, fab_idx: u8) {
let changed = self.sessions.lock(|cell| {
let mut sessions = cell.borrow_mut();
let before = sessions.len();
sessions.retain(|s| s.fab_idx != fab_idx);
before != sessions.len()
});
if changed {
self.dataver.changed();
}
}
pub const fn endpoint_id(&self) -> EndptId {
self.endpoint_id
}
fn allocate_id(&self) -> u16 {
self.sessions.lock(|cell| {
let sessions = cell.borrow();
self.next_id.lock(|n| loop {
let candidate = n.get();
let next = if candidate == u16::MAX {
1
} else {
candidate + 1
};
n.set(next);
if candidate != 0 && !sessions.iter().any(|s| s.id == candidate) {
return candidate;
}
})
})
}
fn session_copy(&self, id: u16) -> Option<SessionEntry> {
self.sessions
.lock(|cell| cell.borrow().iter().find(|s| s.id == id).copied())
}
fn upsert_session(&self, entry: SessionEntry) -> Result<(), Error> {
self.sessions.lock(|cell| {
let mut sessions = cell.borrow_mut();
if let Some(existing) = sessions.iter_mut().find(|s| s.id == entry.id) {
*existing = entry;
Ok(())
} else {
sessions
.push(entry)
.map_err(|_| Error::from(ErrorCode::ResourceExhausted))
}
})
}
fn remove_session(&self, id: u16) {
self.sessions.lock(|cell| {
let mut sessions = cell.borrow_mut();
sessions.retain(|s| s.id != id);
});
}
fn set_state(&self, id: u16, state: SessionState) {
self.sessions.lock(|cell| {
if let Some(s) = cell.borrow_mut().iter_mut().find(|s| s.id == id) {
s.state = state;
}
});
}
fn check_peer(&self, s: &SessionEntry, fab_idx: u8, peer: u64) -> Result<(), Error> {
if s.fab_idx != fab_idx || s.peer_node_id != peer {
Err(ErrorCode::NotFound.into())
} else {
Ok(())
}
}
async fn push_outbound(
&self,
ctx: &impl HandlerContext,
work: OutboundWork,
) -> Result<(), Error> {
let session_id = match work {
OutboundWork::Offer { session_id } => session_id,
OutboundWork::Answer { session_id } => session_id,
OutboundWork::IceCandidates { session_id } => session_id,
OutboundWork::End { session_id, .. } => session_id,
};
let Some(session) = self.session_copy(session_id) else {
return Ok(());
};
let mut sdp_buf = [0u8; SDP_LEN];
let sdp_len = match &work {
OutboundWork::Offer { session_id } => self
.hooks
.take_offer_sdp(*session_id, &mut sdp_buf)
.await
.map_err(Error::from)?,
OutboundWork::Answer { session_id } => self
.hooks
.take_answer_sdp(*session_id, &mut sdp_buf)
.await
.map_err(Error::from)?,
_ => 0,
};
let sdp = core::str::from_utf8(&sdp_buf[..sdp_len])
.map_err(|_| Error::from(ErrorCode::Invalid))?;
if let OutboundWork::Offer { session_id } = &work {
self.set_state(*session_id, SessionState::AwaitingAnswer);
}
let fab_idx = core::num::NonZeroU8::new(session.fab_idx).ok_or(ErrorCode::Invalid)?;
let exchange =
Exchange::initiate(ctx.matter(), ctx.crypto(), fab_idx, session.peer_node_id).await?;
let endpoint = session.peer_endpoint_id;
match &work {
OutboundWork::Offer { session_id } => {
let session_id = *session_id;
exchange
.web_rtc_transport_requestor()
.offer(endpoint, |req| {
req.web_rtc_session_id(session_id)?
.sdp(sdp)?
.ice_servers()?
.none()
.ice_transport_policy(None)?
.end()
})
.await?;
}
OutboundWork::Answer { session_id } => {
let session_id = *session_id;
exchange
.web_rtc_transport_requestor()
.answer(endpoint, |req| {
req.web_rtc_session_id(session_id)?.sdp(sdp)?.end()
})
.await?;
}
OutboundWork::IceCandidates { session_id } => {
let session_id = *session_id;
let mut ice_buf: crate::utils::storage::Vec<heapless::String<CAND_LEN>, MAX_CAND> =
crate::utils::storage::Vec::new();
{
let mut sink = VecCandidateSink { buf: &mut ice_buf };
self.hooks
.take_ice_candidates(session_id, &mut sink)
.await
.map_err(Error::from)?;
}
exchange
.web_rtc_transport_requestor()
.ice_candidates(endpoint, |req| {
let req = req.web_rtc_session_id(session_id)?;
let mut arr = req.ice_candidates()?;
for cand in ice_buf.iter() {
arr = arr
.push()?
.candidate(cand.as_str())?
.sdp_mid(Nullable::none())?
.sdpm_line_index(Nullable::none())?
.end()?;
}
arr.end()?.end()
})
.await?;
}
OutboundWork::End { session_id, reason } => {
let session_id = *session_id;
let reason = *reason;
exchange
.web_rtc_transport_requestor()
.end(endpoint, |req| {
req.web_rtc_session_id(session_id)?.reason(reason)?.end()
})
.await?;
}
}
Ok(())
}
}
impl<
H: WebRtcHooks,
const N_SESSIONS: usize,
const SDP_LEN: usize,
const OUT_LEN: usize,
const CAND_LEN: usize,
const MAX_CAND: usize,
> decl::ClusterAsyncHandler
for WebRtcProvHandler<H, N_SESSIONS, SDP_LEN, OUT_LEN, CAND_LEN, MAX_CAND>
{
const CLUSTER: Cluster<'static> = Self::CLUSTER;
fn dataver(&self) -> u32 {
self.dataver.get()
}
fn dataver_changed(&self) {
self.dataver.changed();
}
async fn run(&self, ctx: impl HandlerContext) -> Result<(), Error> {
loop {
let work = self.hooks.next_outbound().await;
if let Err(err) = self.push_outbound(&ctx, work).await {
warn!("webrtc_prov: outbound push failed: {}", err);
}
}
}
async fn current_sessions<P: TLVBuilderParent>(
&self,
ctx: impl ReadContext,
builder: ArrayAttributeRead<
WebRTCSessionStructArrayBuilder<P>,
WebRTCSessionStructBuilder<P>,
>,
) -> Result<P, Error> {
let attr = ctx.attr();
let mut snapshot: Vec<SessionEntry, N_SESSIONS> = Vec::new();
self.sessions.lock(|cell| {
for s in cell.borrow().iter() {
if !attr.fab_filter || s.fab_idx == attr.fab_idx {
let _ = snapshot.push(*s);
}
}
});
match builder {
ArrayAttributeRead::ReadAll(mut arr) => {
for s in &snapshot {
arr = encode_session_struct(arr.push()?, s)?;
}
arr.end()
}
ArrayAttributeRead::ReadOne(index, b) => {
let s = snapshot
.get(index as usize)
.ok_or(Error::from(ErrorCode::ConstraintError))?;
encode_session_struct(b, s)
}
ArrayAttributeRead::ReadNone(b) => b.end(),
}
}
async fn handle_solicit_offer<P: TLVBuilderParent>(
&self,
ctx: impl InvokeContext,
request: decl::SolicitOfferRequest<'_>,
response: decl::SolicitOfferResponseBuilder<P>,
) -> Result<P, Error> {
let cmd = ctx.cmd();
let fab_idx = cmd.fab_idx;
let peer_node_id = exchange_peer_node_id(ctx.exchange())?;
let params = OfferParams {
stream_usage: request.stream_usage()?,
originating_endpoint_id: request.originating_endpoint_id()?,
video_stream_id: request.video_stream_id()?.map(|n| n.into_option()),
audio_stream_id: request.audio_stream_id()?.map(|n| n.into_option()),
metadata_enabled: request.metadata_enabled()?.unwrap_or(false),
};
let session_id = self.allocate_id();
let outcome = self
.hooks
.on_solicit_offer(session_id, ¶ms)
.await
.map_err(Error::from)?;
let state = if outcome.deferred {
SessionState::AwaitingDeferredOffer
} else {
SessionState::AwaitingAnswer
};
self.upsert_session(SessionEntry {
id: session_id,
fab_idx,
peer_node_id,
peer_endpoint_id: params.originating_endpoint_id,
stream_usage: params.stream_usage,
video_stream_id: outcome.video_stream_id,
audio_stream_id: outcome.audio_stream_id,
metadata_enabled: params.metadata_enabled,
state,
})?;
ctx.notify_own_attr_changed(AttributeId::CurrentSessions as _);
response
.web_rtc_session_id(session_id)?
.deferred_offer(outcome.deferred)?
.video_stream_id(wrap_opt_u16_nullable(outcome.video_stream_id))?
.audio_stream_id(wrap_opt_u16_nullable(outcome.audio_stream_id))?
.end()
}
async fn handle_provide_offer<P: TLVBuilderParent>(
&self,
ctx: impl InvokeContext,
request: decl::ProvideOfferRequest<'_>,
response: decl::ProvideOfferResponseBuilder<P>,
) -> Result<P, Error> {
let cmd = ctx.cmd();
let fab_idx = cmd.fab_idx;
let peer_node_id = exchange_peer_node_id(ctx.exchange())?;
let sdp = request.sdp()?;
if sdp.len() > SDP_LEN {
return Err(ErrorCode::ConstraintError.into());
}
let params = OfferParams {
stream_usage: request.stream_usage()?,
originating_endpoint_id: request.originating_endpoint_id()?,
video_stream_id: request.video_stream_id()?.map(|n| n.into_option()),
audio_stream_id: request.audio_stream_id()?.map(|n| n.into_option()),
metadata_enabled: request.metadata_enabled()?.unwrap_or(false),
};
let session_id = match request.web_rtc_session_id()?.into_option() {
None => self.allocate_id(),
Some(id) => {
let s = self
.session_copy(id)
.ok_or(Error::from(ErrorCode::NotFound))?;
self.check_peer(&s, fab_idx, peer_node_id)?;
id
}
};
let outcome = self
.hooks
.on_offer(session_id, sdp, ¶ms)
.await
.map_err(Error::from)?;
self.upsert_session(SessionEntry {
id: session_id,
fab_idx,
peer_node_id,
peer_endpoint_id: params.originating_endpoint_id,
stream_usage: params.stream_usage,
video_stream_id: outcome.video_stream_id,
audio_stream_id: outcome.audio_stream_id,
metadata_enabled: params.metadata_enabled,
state: SessionState::Established,
})?;
ctx.notify_own_attr_changed(AttributeId::CurrentSessions as _);
response
.web_rtc_session_id(session_id)?
.video_stream_id(wrap_opt_u16_nullable(outcome.video_stream_id))?
.audio_stream_id(wrap_opt_u16_nullable(outcome.audio_stream_id))?
.end()
}
async fn handle_provide_answer(
&self,
ctx: impl InvokeContext,
request: decl::ProvideAnswerRequest<'_>,
) -> Result<(), Error> {
let cmd = ctx.cmd();
let fab_idx = cmd.fab_idx;
let peer_node_id = exchange_peer_node_id(ctx.exchange())?;
let session_id = request.web_rtc_session_id()?;
let sdp = request.sdp()?;
if sdp.len() > SDP_LEN {
return Err(ErrorCode::ConstraintError.into());
}
let session = self
.session_copy(session_id)
.ok_or(Error::from(ErrorCode::NotFound))?;
self.check_peer(&session, fab_idx, peer_node_id)?;
match session.state {
SessionState::AwaitingAnswer | SessionState::AwaitingDeferredOffer => {}
SessionState::Established => return Err(ErrorCode::InvalidAction.into()),
}
self.hooks
.on_answer(session_id, sdp)
.await
.map_err(Error::from)?;
self.set_state(session_id, SessionState::Established);
Ok(())
}
async fn handle_provide_ice_candidates(
&self,
ctx: impl InvokeContext,
request: decl::ProvideICECandidatesRequest<'_>,
) -> Result<(), Error> {
let cmd = ctx.cmd();
let fab_idx = cmd.fab_idx;
let peer_node_id = exchange_peer_node_id(ctx.exchange())?;
let session_id = request.web_rtc_session_id()?;
let session = self
.session_copy(session_id)
.ok_or(Error::from(ErrorCode::NotFound))?;
self.check_peer(&session, fab_idx, peer_node_id)?;
let candidates = request.ice_candidates()?;
self.hooks
.on_ice_candidates(session_id, &candidates)
.await
.map_err(Error::from)
}
async fn handle_end_session(
&self,
ctx: impl InvokeContext,
request: decl::EndSessionRequest<'_>,
) -> Result<(), Error> {
let cmd = ctx.cmd();
let fab_idx = cmd.fab_idx;
let peer_node_id = exchange_peer_node_id(ctx.exchange())?;
let session_id = request.web_rtc_session_id()?;
let reason = request.reason()?;
let session = self
.session_copy(session_id)
.ok_or(Error::from(ErrorCode::NotFound))?;
self.check_peer(&session, fab_idx, peer_node_id)?;
let _ = self.hooks.on_end_session(session_id, reason).await;
self.remove_session(session_id);
ctx.notify_own_attr_changed(AttributeId::CurrentSessions as _);
Ok(())
}
}
/// Returns the node ID of the peer on the session that carries `exchange`.
///
/// Fails with `ErrorCode::Invalid` when that session has no peer node ID.
fn exchange_peer_node_id(exchange: &Exchange<'_>) -> Result<u64, Error> {
    exchange.with_state(|state| {
        let exchange_id = exchange.id();
        let session = exchange_id.session(&mut state.sessions);
        match session.get_peer_node_id() {
            Some(node_id) => Ok(node_id),
            None => Err(ErrorCode::Invalid.into()),
        }
    })
}
/// Converts an optional stream ID into the "present, possibly null" form used by
/// response builders: always `Some`, holding null when `v` is `None`.
fn wrap_opt_u16_nullable(v: Option<u16>) -> Option<Nullable<u16>> {
    let nullable = if let Some(id) = v {
        Nullable::some(id)
    } else {
        Nullable::none()
    };
    Some(nullable)
}
fn encode_session_struct<P: TLVBuilderParent>(
b: WebRTCSessionStructBuilder<P>,
s: &SessionEntry,
) -> Result<P, Error> {
let video = match s.video_stream_id {
Some(x) => Nullable::some(x),
None => Nullable::none(),
};
let audio = match s.audio_stream_id {
Some(x) => Nullable::some(x),
None => Nullable::none(),
};
b.id(s.id)?
.peer_node_id(s.peer_node_id)?
.peer_endpoint_id(s.peer_endpoint_id)?
.stream_usage(s.stream_usage)?
.video_stream_id(video)?
.audio_stream_id(audio)?
.metadata_enabled(s.metadata_enabled)?
.video_streams()?
.none()
.audio_streams()?
.none()
.fabric_index(Some(s.fab_idx))?
.end()
}