use core::any::type_name;
use core::marker::PhantomData;
use embedded_time::duration::Milliseconds;
use embedded_time::Instant;
use no_std_net::SocketAddr;
use tinyvec::ArrayVec;
use toad_array::Array;
use toad_len::Len;
use toad_map::{InsertError, Map};
use toad_msg::Id;
use toad_stem::Stem;
use super::{Step, _try, log};
use crate::config::Config;
use crate::net::Addrd;
use crate::platform;
use crate::platform::PlatformTypes;
use crate::req::Req;
use crate::resp::Resp;
use crate::time::Stamped;
/// A map from a socket address to the timestamped message Ids recorded for
/// that address (both Ids received from the peer and Ids generated for
/// messages sent to it).
///
/// Implemented below for `BTreeMap` (behind the `alloc` feature) and for an
/// `ArrayVec` of `(address, ids)` pairs.
pub trait IdsBySocketAddr<P: PlatformTypes>: Map<SocketAddrWithDefault, Self::Ids> {
/// Per-address buffer of Ids, each stamped with the instant it was recorded.
type Ids: Array<Item = Stamped<P::Clock, IdWithDefault>>;
}
/// Heap-backed Id store: any `BTreeMap` keyed by address whose values are
/// Id buffers can serve as the seen-Id map.
#[cfg(feature = "alloc")]
impl<P, A> IdsBySocketAddr<P> for std_alloc::collections::BTreeMap<SocketAddrWithDefault, A>
  where P: platform::PlatformTypes,
        A: Array<Item = Stamped<P::Clock, IdWithDefault>>
{
  type Ids = A;
}
/// Fixed-capacity Id store: an `ArrayVec` of up to `N` `(address, ids)`
/// pairs can serve as the seen-Id map without allocating.
impl<P, A, const N: usize> IdsBySocketAddr<P> for ArrayVec<[(SocketAddrWithDefault, A); N]>
  where P: platform::PlatformTypes,
        A: Array<Item = Stamped<P::Clock, IdWithDefault>>
{
  type Ids = A;
}
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
#[non_exhaustive]
pub struct SocketAddrWithDefault(pub SocketAddr);
impl Default for SocketAddrWithDefault {
fn default() -> Self {
use no_std_net::*;
Self(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)))
}
}
/// Newtype around a message [`Id`] whose [`Default`] is `Id(0)`.
///
/// Presumably exists so Ids can be stored in containers that require
/// `Default` elements (e.g. `ArrayVec`) — TODO confirm.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
#[non_exhaustive]
pub struct IdWithDefault(pub Id);

impl Default for IdWithDefault {
  /// `Id(0)`.
  fn default() -> Self {
    IdWithDefault(Id(u16::MIN))
  }
}
/// Step that assigns message Ids to outbound messages whose Id is `Id(0)`,
/// avoiding Ids recently seen for the same remote address.
///
/// Ids of received requests/responses and of generated outbound Ids are
/// recorded per address in `seen`.
#[derive(Debug)]
pub struct ProvisionIds<P, Inner, SeenIds> {
// The wrapped step; its output is passed through after Id bookkeeping.
inner: Inner,
// Seen Ids keyed by address. Accessed via `Stem::map_ref` / `map_mut` so it
// can be updated from the `&self` methods of `Step`.
seen: Stem<SeenIds>,
// Ties the platform type parameter `P` to the struct without storing one.
__p: PhantomData<P>,
}
impl<P, Inner, SeenIds> Default for ProvisionIds<P, Inner, SeenIds>
where Inner: Default,
SeenIds: Default
{
fn default() -> Self {
Self { inner: Default::default(),
seen: Default::default(),
__p: PhantomData }
}
}
impl<P, Inner, Ids> ProvisionIds<P, Inner, Ids>
  where Ids: IdsBySocketAddr<P>,
        P: PlatformTypes
{
  /// Forget, for every address in `seen`, all Ids recorded at least
  /// `config.exchange_lifetime_millis()` before `now`.
  ///
  /// Each address's buffer is sorted oldest-first. Everything before the
  /// first Id still inside the exchange lifetime is removed; if no Id is
  /// inside the lifetime, the buffer is reset to its default (empty) value.
  /// An Id for which `checked_duration_since` returns `None` compares less
  /// than `Some(lifetime)` and is therefore kept.
  fn prune(effs: &mut P::Effects, seen: &mut Ids, now: Instant<P::Clock>, config: Config) {
    for (_, ids) in seen.iter_mut() {
      ids.sort_by_key(|t| t.time());

      let ix_of_first_id_to_keep = ids.iter().position(|id| {
                                               now.checked_duration_since(&id.time())
                                               < Some(Milliseconds(config.exchange_lifetime_millis()).into())
                                             });

      match ix_of_first_id_to_keep {
        | Some(0) => (),
        | Some(keep_at) => {
          log!(ProvisionIds::prune,
               effs,
               log::Level::Trace,
               "removing {} old irrelevant ids",
               keep_at);
          // The expired Ids occupy indices `0..keep_at`. Every `remove` shifts
          // the remaining Ids left, so always remove the front element.
          // (Removing at a rising index would skip every other expired Id,
          // drop live ones, and eventually index past the end.)
          for _ in 0..keep_at {
            ids.remove(0);
          }
        },
        | None => {
          *ids = Default::default();
        },
      }
    }
  }

  /// Start tracking `addr` with an empty Id buffer.
  ///
  /// If the map has no room, an address is evicted INSTEAD of inserting
  /// `addr`. The victim is the first address with no Ids, or else the address
  /// whose most recent Id is the oldest. Callers (`next`, `seen`) then retry,
  /// and the second insert succeeds.
  ///
  /// # Panics
  /// If the map reports it is out of capacity while holding no address that
  /// could be evicted.
  fn new_addr(effs: &mut P::Effects, seen: &mut Ids, addr: SocketAddr) {
    log!(ProvisionIds::new_addr,
         effs,
         log::Level::Trace,
         "haven't seen {:?} before",
         addr);
    match seen.insert(SocketAddrWithDefault(addr), Default::default()) {
      | Ok(_) => (),
      | Err(InsertError::CapacityExhausted) => {
        let mut to_remove: Option<Stamped<P::Clock, SocketAddrWithDefault>> = None;
        for (addr, ids) in seen.iter_mut() {
          // Only the newest timestamp matters; no need to reorder the buffer.
          let newest_id_time = match ids.iter().map(|t| t.time()).max() {
            | Some(time) => time,
            | None => {
              // An address with no Ids is the cheapest thing to forget.
              to_remove = Some(Stamped(*addr, Instant::new(0)));
              break;
            },
          };
          if to_remove.map_or(true, |t| newest_id_time < t.time()) {
            to_remove = Some(Stamped(*addr, newest_id_time));
          }
        }
        let to_remove =
          to_remove.expect("address map is out of capacity but holds no address to evict");
        seen.remove(&to_remove.discard_timestamp());
      },
      | Err(InsertError::Exists(_)) => unreachable!(),
    };
  }

  /// Generate an Id for a message to `addr` that differs from the Ids
  /// currently recorded for `addr`, and record it as seen at `time`.
  ///
  /// In order of preference: `Id(1)` if nothing is recorded, one more than
  /// the largest recorded Id, one less than the smallest (if that stays
  /// above 0), or the first value inside a gap between recorded Ids.
  ///
  /// # Panics
  /// If every Id from the smallest recorded Id up to `u16::MAX` is in use.
  fn next(effs: &mut P::Effects,
          seen: &mut Ids,
          config: Config,
          time: Instant<P::Clock>,
          addr: SocketAddr)
          -> Id {
    match seen.get_mut(&SocketAddrWithDefault(addr)) {
      | None => {
        Self::new_addr(effs, seen, addr);
        Self::next(effs, seen, config, time, addr)
      },
      | Some(ids) => {
        // NOTE(review): assumes `Stamped`'s `Ord` orders by Id first, so the
        // smallest and largest Ids end up at the ends of the buffer.
        ids.sort_unstable();
        let smallest = || ids[0].data().0 .0;
        let biggest = || ids[ids.len() - 1].data().0 .0;
        let next = if ids.is_empty() {
          Id(1)
        } else if biggest() < u16::MAX {
          Id(biggest() + 1)
        } else if smallest() > 1 {
          Id(smallest() - 1)
        } else {
          // Walk adjacent pairs looking for a hole in the sorted Ids.
          let mut ahead = ids.iter();
          ahead.next();
          let (Stamped(IdWithDefault(Id(before_gap)), _), _) =
            ids.iter()
               .zip(ahead)
               .find(|(Stamped(IdWithDefault(Id(cur)), _), Stamped(IdWithDefault(Id(next)), _))| {
                 next - cur > 1
               })
               .expect("every Id between the smallest recorded Id and u16::MAX is in use");
          Id(before_gap + 1)
        };
        log!(ProvisionIds::next,
             effs,
             log::Level::Debug,
             "Generated new {:?}",
             next);
        Self::seen(effs, seen, config, time, addr, next);
        next
      },
    }
  }

  /// Record that `id` was seen for `addr` at `now`.
  ///
  /// Prunes expired Ids first, and starts tracking `addr` if needed. If the
  /// address's Id buffer is full, the oldest Id (by timestamp) is dropped to
  /// make room.
  fn seen(effs: &mut P::Effects,
          seen: &mut Ids,
          config: Config,
          now: Instant<P::Clock>,
          addr: SocketAddr,
          id: Id) {
    Self::prune(effs, seen, now, config);
    match seen.get_mut(&SocketAddrWithDefault(addr)) {
      | None => {
        Self::new_addr(effs, seen, addr);
        Self::seen(effs, seen, config, now, addr, id)
      },
      | Some(ids) => {
        if ids.is_full() {
          // Report the per-address Id buffer (`IdsBySocketAddr::Ids`), which is
          // what is full here. The address map `Ids` is a different type with
          // a different capacity.
          let buffer_type = type_name::<<Ids as IdsBySocketAddr<P>>::Ids>();
          let buffer_capacity =
            <<Ids as IdsBySocketAddr<P>>::Ids as Len>::CAPACITY.unwrap_or(usize::MAX);
          log!(ProvisionIds::seen,
               effs,
               log::Level::Warn,
               "Id buffer {} has reached capacity of {}. Forgetting the oldest Id to make room for {:?}",
               buffer_type,
               buffer_capacity,
               id);
          ids.sort_by_key(|s| s.time());
          ids.remove(0);
          ids.sort();
        }
        log!(ProvisionIds::seen,
             effs,
             log::Level::Trace,
             "Saw new {:?}",
             id);
        ids.push(Stamped(IdWithDefault(id), now));
      },
    }
  }
}
// Shared tail of `poll_req` / `poll_resp`: record the Id of the received
// `$req_or_resp` as seen for its address (`r.addr()`) at `$snap.time`, via
// `Self::seen`, then yield the message unchanged as `Some(Ok(r))`.
//
// Expands to a call to `Self::seen`, so it must be invoked inside the
// `ProvisionIds` `Step` impl.
macro_rules! common {
($self:expr, $effs:expr, $snap:expr, $req_or_resp:expr) => {{
let r = $req_or_resp;
$self.seen.map_mut(|s| {
Self::seen($effs,
s,
$snap.config,
$snap.time,
r.addr(),
r.data().msg().id)
});
Some(Ok(r))
}};
}
impl<P, E: super::Error, Inner, Ids> Step<P> for ProvisionIds<P, Inner, Ids>
  where P: PlatformTypes,
        Inner: Step<P, PollReq = Addrd<Req<P>>, PollResp = Addrd<Resp<P>>, Error = E>,
        Ids: IdsBySocketAddr<P>
{
  type PollReq = Addrd<Req<P>>;
  type PollResp = Addrd<Resp<P>>;
  type Error = E;
  type Inner = Inner;

  fn inner(&self) -> &Inner {
    &self.inner
  }

  /// Poll the inner step for a request. When one arrives, record its Id as
  /// seen for its address and pass it on unchanged.
  fn poll_req(&self,
              snap: &crate::platform::Snapshot<P>,
              effects: &mut <P as PlatformTypes>::Effects)
              -> super::StepOutput<Self::PollReq, Self::Error> {
    let polled = self.inner.poll_req(snap, effects);
    let incoming = _try!(Option<nb::Result>; polled);
    common!(self, effects, snap, incoming)
  }

  /// Poll the inner step for a response. When one arrives, record its Id as
  /// seen for its address and pass it on unchanged.
  fn poll_resp(&self,
               snap: &crate::platform::Snapshot<P>,
               effects: &mut <P as PlatformTypes>::Effects,
               token: toad_msg::Token,
               addr: SocketAddr)
               -> super::StepOutput<Self::PollResp, Self::Error> {
    let polled = self.inner.poll_resp(snap, effects, token, addr);
    let incoming = _try!(Option<nb::Result>; polled);
    common!(self, effects, snap, incoming)
  }

  /// After the inner step runs, give any outbound message still carrying
  /// `Id(0)` a freshly generated Id for its destination address.
  fn before_message_sent(&self,
                         snap: &platform::Snapshot<P>,
                         effs: &mut P::Effects,
                         msg: &mut Addrd<platform::Message<P>>)
                         -> Result<(), Self::Error> {
    self.inner.before_message_sent(snap, effs, msg)?;

    // An Id was already chosen; leave it alone.
    if msg.data().id != Id(0) {
      return Ok(());
    }

    let fresh_id = self.seen.map_mut(|ids_by_addr| {
                                Self::next(effs, ids_by_addr, snap.config, snap.time, msg.addr())
                              });
    msg.data_mut().id = fresh_id;
    Ok(())
  }
}
#[cfg(test)]
// Unit tests for `ProvisionIds`: pass-through of inner errors / WouldBlock,
// Id assignment on outbound messages, and the per-address bookkeeping done by
// `seen`, `next`, `prune` and `new_addr`.
mod test {
use std::collections::BTreeMap;
use embedded_time::duration::Microseconds;
use super::*;
use crate::step::test::test_step;
use crate::test::{self, ClockMock, Platform as P};
type InnerPollReq = Addrd<Req<test::Platform>>;
type InnerPollResp = Addrd<Resp<test::Platform>>;
// `ProvisionIds` on the test platform wrapping inner step `S`, with seen Ids
// stored in a `BTreeMap` from address to a `Vec` of timestamped Ids.
type ProvisionIds<S> = super::ProvisionIds<P,
S,
BTreeMap<SocketAddrWithDefault,
Vec<Stamped<ClockMock, IdWithDefault>>>>;
// Builds a `Con` message with the given Id, code 0.00, no options and an
// empty payload, addressed to `test::dummy_addr()`.
fn test_msg(id: Id) -> Addrd<test::Message> {
use toad_msg::*;
Addrd(test::Message { id,
ty: Type::Con,
ver: Default::default(),
code: Code::new(0, 0),
opts: Default::default(),
payload: Payload(vec![]),
token: Token(Default::default()) },
test::dummy_addr())
}
// Errors from the inner step are returned unchanged.
test_step!(
GIVEN ProvisionIds::<Dummy> where Dummy: {Step<PollReq = InnerPollReq, PollResp = InnerPollResp, Error = ()>};
WHEN inner_errors [
(inner.poll_req => { Some(Err(nb::Error::Other(()))) }),
(inner.poll_resp => { Some(Err(nb::Error::Other(()))) })
]
THEN this_should_error [
(poll_req(_, _) should satisfy { |out| assert_eq!(out, Some(Err(nb::Error::Other(())))) }),
(poll_resp(_, _, _, _) should satisfy { |out| assert_eq!(out, Some(Err(nb::Error::Other(())))) })
]
);
// WouldBlock from the inner step is returned unchanged.
test_step!(
GIVEN ProvisionIds::<Dummy> where Dummy: {Step<PollReq = InnerPollReq, PollResp = InnerPollResp, Error = ()>};
WHEN inner_blocks [
(inner.poll_req => { Some(Err(nb::Error::WouldBlock)) }),
(inner.poll_resp => { Some(Err(nb::Error::WouldBlock)) })
]
THEN this_should_block [
(poll_req(_, _) should satisfy { |out| assert_eq!(out, Some(Err(nb::Error::WouldBlock))) }),
(poll_resp(_, _, _, _) should satisfy { |out| assert_eq!(out, Some(Err(nb::Error::WouldBlock))) })
]
);
// An outbound message with `Id(0)` is given a non-zero Id.
test_step!(
GIVEN ProvisionIds::<Dummy> where Dummy: {Step<PollReq = InnerPollReq, PollResp = InnerPollResp, Error = ()>};
WHEN message_sent_with_id_zero []
THEN this_should_assign_nonzero_id [
(before_message_sent(_, _, test_msg(Id(0))) should be ok with { |msg| assert!(matches!(msg.data().id, Id(n) if n > 0)) })
]
);
// Received messages keep their Id, even `Id(0)`; only outbound Ids are rewritten.
test_step!(
GIVEN ProvisionIds::<Dummy> where Dummy: {Step<PollReq = InnerPollReq, PollResp = InnerPollResp, Error = ()>};
WHEN req_or_resp_recvd_with_id_zero [
(inner.poll_req => { Some(Ok(test_msg(Id(0)).map(Req::from))) }),
(inner.poll_resp => { Some(Ok(test_msg(Id(0)).map(Resp::from))) })
]
THEN id_should_be_respected [
(poll_req(_, _) should satisfy { |out| assert!(matches!(out.unwrap().unwrap().data().as_ref().id, Id(0))) }),
(poll_resp(_, _, _, _) should satisfy { |out| assert!(matches!(out.unwrap().unwrap().data().as_ref().id, Id(0))) })
]
);
// The address map holds 2 addresses. When a 3rd arrives, the address whose
// newest Id is oldest (dummy_addr_2, last seen at t=1) is evicted; dummy_addr
// (last seen at t=2) survives.
#[test]
fn seen_should_remove_oldest_addr_when_new_addr_would_exceed_capacity() {
type Ids = ArrayVec<[Stamped<ClockMock, IdWithDefault>; 16]>;
type IdsByAddr = ArrayVec<[(SocketAddrWithDefault, Ids); 2]>;
type Step = super::ProvisionIds<P, (), IdsByAddr>;
let mut effs = Vec::<test::Effect>::new();
let step = Step::default();
let cfg = Config::default();
step.seen.map_mut(|s| {
Step::seen(&mut effs,
s,
cfg,
ClockMock::instant(0),
test::dummy_addr(),
Id(1));
Step::seen(&mut effs,
s,
cfg,
ClockMock::instant(1),
test::dummy_addr_2(),
Id(1));
Step::seen(&mut effs,
s,
cfg,
ClockMock::instant(2),
test::dummy_addr(),
Id(2));
Step::seen(&mut effs,
s,
cfg,
ClockMock::instant(3),
test::dummy_addr_3(),
Id(1));
});
let mut addrs: Vec<_> = step.seen.map_ref(|s| s.iter().map(|(k, _)| k.0).collect());
addrs.sort();
assert_eq!(addrs, vec![test::dummy_addr(), test::dummy_addr_3()]);
}
// An address with no recorded Ids (inserted directly with an empty buffer)
// is evicted first when the 2-address map overflows.
#[test]
fn seen_should_remove_empty_addr_when_new_addr_would_exceed_capacity() {
type Ids = ArrayVec<[Stamped<ClockMock, IdWithDefault>; 16]>;
type IdsByAddr = ArrayVec<[(SocketAddrWithDefault, Ids); 2]>;
type Step = super::ProvisionIds<P, (), IdsByAddr>;
let mut effs = Vec::<test::Effect>::new();
let step = Step::default();
let cfg = Config::default();
step.seen.map_mut(|seen| {
Map::insert(seen,
SocketAddrWithDefault(test::dummy_addr()),
Default::default()).unwrap();
Step::seen(&mut effs,
seen,
cfg,
ClockMock::instant(1),
test::dummy_addr_2(),
Id(1));
Step::seen(&mut effs,
seen,
cfg,
ClockMock::instant(3),
test::dummy_addr_3(),
Id(1));
});
let mut addrs: Vec<_> = step.seen.map_ref(|s| s.iter().map(|(k, _)| k.0).collect());
addrs.sort();
assert_eq!(addrs, vec![test::dummy_addr_2(), test::dummy_addr_3()]);
}
// The per-address Id buffer holds 2 Ids. Recording a 3rd drops the oldest by
// timestamp (Id(0) at t=0).
#[test]
fn seen_should_remove_oldest_id_when_about_to_exceed_capacity() {
type Ids = ArrayVec<[Stamped<ClockMock, IdWithDefault>; 2]>;
type IdsByAddr = ArrayVec<[(SocketAddrWithDefault, Ids); 1]>;
type Step = super::ProvisionIds<P, (), IdsByAddr>;
let mut effs = Vec::<test::Effect>::new();
let step = Step::default();
let cfg = Config::default();
step.seen.map_mut(|seen| {
Step::seen(&mut effs,
seen,
cfg,
ClockMock::instant(0),
test::dummy_addr(),
Id(0));
Step::seen(&mut effs,
seen,
cfg,
ClockMock::instant(1),
test::dummy_addr(),
Id(1));
Step::seen(&mut effs,
seen,
cfg,
ClockMock::instant(2),
test::dummy_addr(),
Id(2));
});
let ids: Vec<_> = step.seen.map_ref(|s| {
s.get(&SocketAddrWithDefault(test::dummy_addr()))
.unwrap()
.into_iter()
.map(|Stamped(IdWithDefault(id), _)| *id)
.collect()
});
assert_eq!(ids, vec![Id(1), Id(2)]);
}
// `ClockMock` instants are in microseconds (asserted below). Ids seen at t=0
// and t=1us are past the exchange lifetime when Id(3) is recorded one
// millisecond after that lifetime, so only Id(3) remains.
#[test]
fn seen_should_prune_ids_older_than_exchange_lifetime() {
type Step = ProvisionIds<()>;
let mut effs = Vec::<test::Effect>::new();
let step = Step::default();
let cfg = Config::default();
let exchange_lifetime_micros = cfg.exchange_lifetime_millis() * 1_000;
assert_eq!(Microseconds::try_from(ClockMock::instant(1).duration_since_epoch()),
Ok(Microseconds(1u64)));
step.seen.map_mut(|s| {
Step::seen(&mut effs,
s,
cfg,
ClockMock::instant(0),
test::dummy_addr(),
Id(1));
Step::seen(&mut effs,
s,
cfg,
ClockMock::instant(1),
test::dummy_addr(),
Id(2));
Step::seen(&mut effs,
s,
cfg,
ClockMock::instant(exchange_lifetime_micros + 1_000),
test::dummy_addr(),
Id(3));
});
let ids: Vec<_> = step.seen.map_ref(|s| {
s.get(&SocketAddrWithDefault(test::dummy_addr()))
.unwrap()
.iter()
.map(|Stamped(IdWithDefault(id), _)| *id)
.collect()
});
assert_eq!(ids, vec![Id(3)]);
}
// Seen {1, 2, 22}: the largest is below u16::MAX, so the next Id is 22 + 1.
#[test]
fn next_should_generate_largest_plus_one_when_largest_lt_max() {
type Step = ProvisionIds<()>;
let mut effs = Vec::<test::Effect>::new();
let step = Step::default();
let time = ClockMock::instant(0);
step.seen.map_mut(|seen| {
Step::seen(&mut effs,
seen,
Default::default(),
time,
test::dummy_addr(),
Id(22));
Step::seen(&mut effs,
seen,
Default::default(),
time,
test::dummy_addr(),
Id(1));
Step::seen(&mut effs,
seen,
Default::default(),
time,
test::dummy_addr(),
Id(2));
let generated = Step::next(&mut effs,
seen,
Default::default(),
time,
test::dummy_addr());
assert_eq!(generated, Id(23))
});
}
// Seen {2, u16::MAX}: the largest is taken and the smallest is > 1, so the
// next Id is 2 - 1.
#[test]
fn next_should_generate_smallest_minus_one_when_largest_is_max() {
type Step = ProvisionIds<()>;
let mut effs = Vec::<test::Effect>::new();
let step = Step::default();
let time = ClockMock::instant(0);
step.seen.map_mut(|seen| {
Step::seen(&mut effs,
seen,
Default::default(),
time,
test::dummy_addr(),
Id(2));
Step::seen(&mut effs,
seen,
Default::default(),
time,
test::dummy_addr(),
Id(u16::MAX));
let generated = Step::next(&mut effs,
seen,
Default::default(),
time,
test::dummy_addr());
assert_eq!(generated, Id(1))
});
}
// Seen {1, 2, 3, 5, u16::MAX}: neither end has room, so the first gap (after 3)
// is used.
#[test]
fn next_should_generate_in_gap_when_smallest_1_and_largest_max() {
type Step = ProvisionIds<()>;
let mut effs = Vec::<test::Effect>::new();
let step = Step::default();
let time = ClockMock::instant(0);
step.seen.map_mut(|seen| {
Step::seen(&mut effs,
seen,
Default::default(),
time,
test::dummy_addr(),
Id(1));
Step::seen(&mut effs,
seen,
Default::default(),
time,
test::dummy_addr(),
Id(2));
Step::seen(&mut effs,
seen,
Default::default(),
time,
test::dummy_addr(),
Id(3));
Step::seen(&mut effs,
seen,
Default::default(),
time,
test::dummy_addr(),
Id(5));
Step::seen(&mut effs,
seen,
Default::default(),
time,
test::dummy_addr(),
Id(u16::MAX));
let generated = Step::next(&mut effs,
seen,
Default::default(),
time,
test::dummy_addr());
assert_eq!(generated, Id(4))
});
}
// With nothing recorded for the address, the first generated Id is 1.
#[test]
fn next_should_generate_initial_id() {
type Step = ProvisionIds<()>;
let step = Step::default();
let id = step.seen.map_mut(|s| {
Step::next(&mut vec![],
s,
Default::default(),
ClockMock::instant(0),
test::dummy_addr())
});
assert_eq!(id, Id(1))
}
}