use core::{cell::UnsafeCell, fmt::Debug};
use portable_atomic::{AtomicU32, Ordering};
use critical_section::CriticalSection;
use embassy_time_driver::{Driver, time_driver_impl};
use embassy_time_queue_utils::Queue;
use crate::interrupt;
use crate::sys::watchdog::mcwdt::{
AnyUnlockedMultiWdt, CascadeMode, Config, Control, InterruptSet, MatchAction, MultiWdt,
SubCounter,
};
pub fn init<const N: u8>(wdt: MultiWdt<N>)
where
MultiWdt<N>: McwdtInterrupt,
{
DRIVER
.counter_hi
.compare_exchange(u32::MAX, u32::MAX - 1, Ordering::Relaxed, Ordering::Relaxed)
.expect("Time driver already initialized");
let mut wdt = wdt.into_unlocked();
wdt.set_config(Config {
mode_16bit_0: MatchAction::Interrupt,
mode_16bit_1: MatchAction::Interrupt,
interrupt_bit_32bit: Some(31),
cascade_0_1: CascadeMode::CarryOnMatch,
..Default::default()
});
wdt.set_control(Control {
reset_16bit_0: true,
enable_16bit_0: true,
reset_16bit_1: true,
enable_16bit_1: true,
reset_32bit: true,
enable_32bit: true,
});
wdt.wait_control();
wdt.set_enabled_interrupts(InterruptSet::new());
wdt.clear_interrupts(InterruptSet::all());
unsafe {
DRIVER
.mcwdt
.get()
.as_mut_unchecked()
.replace(wdt.into_erased());
}
DRIVER.counter_hi.store(0, Ordering::Release);
MultiWdt::<N>::enable();
}
/// State of the embassy time driver backed by a single MCWDT instance.
struct TimeDriver {
    /// Incremented by the interrupt handler on each 32-bit sub-counter interrupt (configured on
    /// bit 31 in `init`). Bit 31 set (`u32::MAX` initially, `u32::MAX - 1` while `init` runs)
    /// marks the driver as not yet initialized.
    counter_hi: AtomicU32,
    /// The MCWDT handed to `init`; `None` until `init` stores it.
    mcwdt: UnsafeCell<Option<AnyUnlockedMultiWdt<'static>>>,
    /// Pending timers; only accessed inside a critical section.
    queue: UnsafeCell<Queue>,
}
// SAFETY: `counter_hi` is atomic and `queue` is only touched inside critical sections.
// `mcwdt` is written once by `init` before `counter_hi` is released.
// NOTE(review): afterwards `wdt()` hands out handles from both thread mode (`now()`) and the
// interrupt handler without a lock — confirm the register accesses tolerate that.
unsafe impl Send for TimeDriver {}
unsafe impl Sync for TimeDriver {}
// Registers `DRIVER` as the global `embassy_time_driver` implementation.
time_driver_impl!(static DRIVER: TimeDriver = TimeDriver::new());
impl TimeDriver {
    /// Builds the driver in its "not initialized" state.
    ///
    /// `counter_hi` starts at `u32::MAX`: bit 31 set tells `now()` that `init` has not run yet.
    const fn new() -> Self {
        Self {
            queue: UnsafeCell::new(Queue::new()),
            mcwdt: UnsafeCell::new(None),
            counter_hi: AtomicU32::new(u32::MAX),
        }
    }

    /// Returns a fresh handle to the MCWDT stored by `init`.
    ///
    /// # Safety
    ///
    /// `init` must already have stored the MCWDT. The slot is unwrapped without a check, so
    /// calling this while it is still `None` is undefined behavior.
    unsafe fn wdt(&'_ self) -> AnyUnlockedMultiWdt<'_> {
        // SAFETY: the caller guarantees the slot has been filled by `init`.
        let slot = unsafe { &mut *self.mcwdt.get() };
        let stored = unsafe { slot.as_mut().unwrap_unchecked() };
        stored.reborrow()
    }
}
impl Driver for TimeDriver {
    /// Returns the current time in MCWDT ticks.
    ///
    /// The low 32 bits come from the hardware 32-bit sub-counter. The high bits are derived from
    /// `counter_hi`, which the interrupt handler increments on every 32-bit sub-counter interrupt
    /// (configured on bit 31 in `init`).
    ///
    /// # Panics
    ///
    /// Panics if `init` has not completed.
    fn now(&self) -> u64 {
        let mut counter_hi = self.counter_hi.load(Ordering::Acquire);
        // Bit 31 set: `init` has not run yet, or is still running.
        if counter_hi & (1 << 31) != 0 {
            panic!(
                "Time driver not initialized\n help: call init() before using time functions"
            );
        }
        // SAFETY: `init` clears bit 31 (release store) only after storing the MCWDT, and the
        // acquire load above makes that store visible here.
        let counter_lo = unsafe { self.wdt().counter_32bit() };
        // The math below assumes the interrupt fires on every toggle of bit 31. `counter_hi` then
        // counts half-periods of the 32-bit counter, and when it is up to date its parity equals
        // bit 31 of `counter_lo`. If they differ, a toggle has happened whose interrupt has not
        // been handled yet. Add that half-period, then halve to get the number of completed
        // 32-bit periods.
        counter_hi += (counter_hi & 1) ^ (counter_lo >> 31);
        counter_hi >>= 1;
        ((counter_hi as u64) << 32) | counter_lo as u64
    }

    /// Registers `waker` to be woken at tick `at`, re-arming the hardware alarm when needed.
    fn schedule_wake(&self, at: u64, waker: &core::task::Waker) {
        critical_section::with(|cs| {
            // `now()` never returns a value with bit 63 set, so such a deadline can never be
            // reached and the waker is not registered.
            if at & (1 << 63) != 0 {
                return;
            }
            // SAFETY: the queue is only accessed inside a critical section.
            let queue = unsafe { &mut *self.queue.get() };
            // Only re-arm when the queue reports that its next expiration may have changed.
            if queue.schedule_wake(at, waker) {
                self.update_alarm(cs);
            }
        });
    }
}
impl TimeDriver {
    /// Body of the MCWDT interrupt handler.
    ///
    /// Acknowledges the pending MCWDT interrupts and counts a 32-bit sub-counter interrupt as a
    /// rollover. If a 16-bit alarm sub-counter matched, it also re-evaluates the timer queue.
    ///
    /// # Safety
    ///
    /// `init` must already have stored the MCWDT (this calls [`Self::wdt`]).
    unsafe fn handle_interrupt(&self) {
        let (alarm, rollover) = unsafe {
            // NOTE(review): presumably the requested-and-enabled set (PSoC `INTR_MASKED`).
            let interrupts = self.wdt().masked_interrupts();
            let alarm = interrupts.contains(SubCounter::_16bit_0)
                || interrupts.contains(SubCounter::_16bit_1);
            let rollover = interrupts.contains(SubCounter::_32bit);
            // Acknowledge exactly what was observed. Clearing `InterruptSet::all()` would also
            // discard an interrupt raised between the read above and this write, and a
            // discarded rollover is never counted.
            self.wdt().clear_interrupts(interrupts);
            // NOTE(review): result discarded; presumably a read-back so the clear reaches the
            // peripheral before the handler returns — confirm.
            self.wdt().requested_interrupts();
            (alarm, rollover)
        };

        // Handle both causes independently. Both were acknowledged above, so a rollover that is
        // pending together with an alarm must still be counted, or `counter_hi` permanently
        // falls out of step with the hardware counter.
        if rollover {
            unsafe { self.rollover_interrupt() };
        }
        if alarm {
            critical_section::with(|cs| self.update_alarm(cs));
        }
    }

    /// Counts one 32-bit sub-counter interrupt (configured on bit 31 in `init`).
    ///
    /// # Safety
    ///
    /// NOTE(review): nothing memory-unsafe happens here. Calling it other than once per hardware
    /// rollover interrupt corrupts the time returned by `now()`.
    unsafe fn rollover_interrupt(&self) {
        self.counter_hi.fetch_add(1, Ordering::Release);
    }

    /// Re-arms the MCWDT alarm for the queue's next deadline.
    ///
    /// Takes a [`CriticalSection`] token because it mutably accesses the timer queue.
    ///
    /// # Panics
    ///
    /// Panics (via `now()`) if the driver is not initialized.
    fn update_alarm(&self, _cs: CriticalSection) {
        // `now()` must run before `wdt()`. It panics while the driver is uninitialized, whereas
        // `wdt()` would unwrap the still-empty MCWDT slot unchecked (undefined behavior). That
        // happens when a timer is scheduled before `init`.
        let mut now = self.now();
        // SAFETY: `now()` returned, so `init` has stored the MCWDT.
        let mut wdt = unsafe { self.wdt() };
        // SAFETY: the critical-section token guarantees exclusive access to the queue.
        let queue = unsafe { &mut *self.queue.get() };
        // Per the embassy_time_queue_utils docs, this wakes expired timers and returns the next
        // deadline.
        let mut timestamp = queue.next_expiration(now);
        loop {
            // Unreachable deadline (e.g. an empty queue): keep only the rollover interrupt.
            if timestamp >= (1 << 63) {
                wdt.set_enabled_interrupts(SubCounter::_32bit);
                break;
            }
            if timestamp < now + 0x2_0000 {
                // Near deadline: match its low 16 bits on sub-counter 0.
                // NOTE(review): relies on sub-counter 0 tracking the low 16 bits of the 32-bit
                // counter (all counters are reset together in `init`). A deadline more than
                // 0x1_0000 ticks out can match one wrap early; the resulting interrupt re-runs
                // this function, which re-arms the alarm.
                wdt.set_match_16bit((timestamp as u16).saturating_sub(1), 0);
                wdt.clear_interrupts(SubCounter::_16bit_0);
                wdt.set_enabled_interrupts(
                    InterruptSet::new()
                        .insert(SubCounter::_16bit_0)
                        .insert(SubCounter::_32bit),
                );
                now = self.now();
                // If the deadline is (almost) here, the match may already have been passed.
                // Re-arm a few ticks in the future so an interrupt is guaranteed.
                if timestamp < now + 3 {
                    timestamp = now + 3;
                    continue;
                } else {
                    break;
                }
            } else {
                // Far deadline: reset sub-counter 1, which advances on sub-counter-0 matches
                // (`CascadeMode::CarryOnMatch`), and interrupt on its match with the number of
                // 0x1_0000-tick periods until `target`. The count is clamped to `u16::MAX`; an
                // even later deadline fires early and is re-armed from the interrupt.
                wdt.set_control(Control {
                    enable_16bit_0: true,
                    enable_16bit_1: true,
                    enable_32bit: true,
                    reset_16bit_1: true,
                    ..Default::default()
                });
                let target = timestamp - 2;
                let periods = (target - now - 3) >> 16;
                let match0 = target as u16;
                let match1 = periods.try_into().unwrap_or(u16::MAX);
                wdt.set_match_16bit(match0, match1);
                wdt.clear_interrupts(SubCounter::_16bit_1);
                wdt.set_enabled_interrupts(
                    InterruptSet::new()
                        .insert(SubCounter::_16bit_1)
                        .insert(SubCounter::_32bit),
                );
                now = self.now();
                // If time crossed into another period while arming, the count is stale.
                let new_periods = (target - now - 3) >> 16;
                if periods != new_periods {
                    continue;
                } else {
                    break;
                }
            }
        }
    }
}
impl Debug for TimeDriver {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.debug_struct("TimeDriver")
.field("now", &self.now())
.finish()
}
}
/// Implemented for the `MultiWdt` instances that can back the time driver. It lets `init`
/// unmask the matching MCWDT interrupt.
#[doc(hidden)]
pub trait McwdtInterrupt {
    /// Unmasks this instance's MCWDT interrupt so the time driver's handler can run.
    fn enable();
}
/// Implements [`McwdtInterrupt`] for `MultiWdt<$n>` and defines the `SRSS_INTERRUPT_MCWDT_$n`
/// handler, which forwards to the time driver.
macro_rules! impl_interrupt {
    ($n:literal) => {
        paste::paste! {
            impl McwdtInterrupt for MultiWdt<$n> {
                fn enable() {
                    interrupt::unmask(interrupt::Interrupt::[<SRSS_INTERRUPT_MCWDT_ $n>]);
                    // NOTE(review): presumably makes the optional handler below a required
                    // part of the build once this instance is used — confirm in psoc_macros.
                    psoc_macros::require_interrupt!([<SRSS_INTERRUPT_MCWDT_ $n>]);
                }
            }
            #[psoc_macros::optional_interrupt]
            fn [<SRSS_INTERRUPT_MCWDT_ $n>]() {
                // SAFETY: assumes the line is only unmasked via `enable()` above, which `init`
                // calls after storing the MCWDT in `DRIVER`.
                unsafe { DRIVER.handle_interrupt() }
            }
        }
    };
}
// One impl per MCWDT instance; the `mcwdt0`/`mcwdt1` cfgs are presumably set by the build for
// the instances present on the target chip.
#[cfg(mcwdt0)]
impl_interrupt!(0);
#[cfg(mcwdt1)]
impl_interrupt!(1);