use core::{fmt, hash, marker::PhantomData};
use super::{
state, task, timeout, utils,
wait::{WaitPayload, WaitQueue},
BadIdError, DrainSemaphoreError, GetSemaphoreError, Id, Kernel, PollSemaphoreError, Port,
SignalSemaphoreError, WaitSemaphoreError, WaitSemaphoreTimeoutError,
};
use crate::{time::Duration, utils::Init};
// Integer type used for a semaphore's permit count (`SemaphoreCb::value`)
// and its upper bound (`SemaphoreCb::max_value`). The public rustdoc text is
// pulled in from `../common.md` by the attribute below.
#[doc(include = "../common.md")]
pub type SemaphoreValue = usize;
// Handle to a kernel semaphore: a kernel object `Id` tagged with the `System`
// type it belongs to. `PhantomData<System>` carries the type parameter without
// storing a value of it, and `#[repr(transparent)]` gives the handle the same
// layout as `Id`. The standard traits below are implemented by hand rather
// than derived, presumably so that they do not pick up bounds on `System`.
// The public rustdoc text is pulled in from `../common.md`.
#[doc(include = "../common.md")]
#[repr(transparent)]
pub struct Semaphore<System>(Id, PhantomData<System>);
// Implemented by hand rather than derived: `#[derive(Clone, Copy)]` would add
// `System: Clone` / `System: Copy` bounds, even though the handle stores only
// an `Id` and a `PhantomData<System>`.
impl<System> Clone for Semaphore<System> {
    /// Duplicates the handle. The type is `Copy`, so the canonical
    /// implementation is a plain bitwise copy of `*self`.
    fn clone(&self) -> Self {
        *self
    }
}
impl<System> Copy for Semaphore<System> {}
impl<System> PartialEq for Semaphore<System> {
fn eq(&self, other: &Self) -> bool {
self.0 == other.0
}
}
impl<System> Eq for Semaphore<System> {}
impl<System> hash::Hash for Semaphore<System> {
fn hash<H>(&self, state: &mut H)
where
H: hash::Hasher,
{
hash::Hash::hash(&self.0, state);
}
}
impl<System> fmt::Debug for Semaphore<System> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_tuple("Semaphore").field(&self.0).finish()
}
}
impl<System> Semaphore<System> {
    /// Construct a `Semaphore` handle from a raw kernel object ID.
    ///
    /// # Safety
    ///
    /// NOTE(review): the exact contract is not visible in this file; confirm
    /// it against the kernel's object-safety documentation.
    ///
    /// An invalid ID by itself does not look like the hazard here. The
    /// `System: Kernel` methods resolve the ID through `semaphore_cb`, which
    /// maps a missing control block to `BadIdError::BadId`. The obligation is
    /// therefore presumably about *entitlement*: the caller must be allowed to
    /// operate on the semaphore identified by `id` for `System`. A fabricated
    /// handle would let code wait on, signal, or drain a semaphore it was never
    /// given.
    pub const unsafe fn from_id(id: Id) -> Self {
        Self(id, PhantomData)
    }

    /// Return the raw kernel object ID wrapped by this handle.
    pub const fn id(self) -> Id {
        self.0
    }
}
impl<System: Kernel> Semaphore<System> {
    /// Resolve this handle to its control block.
    ///
    /// The ID is converted to a zero-based index by subtracting one before it
    /// is passed to `System::get_semaphore_cb`. A missing entry is reported as
    /// `BadIdError::BadId`.
    fn semaphore_cb(self) -> Result<&'static SemaphoreCb<System>, BadIdError> {
        let index = self.0.get() - 1;
        System::get_semaphore_cb(index).ok_or(BadIdError::BadId)
    }

    /// Reset the semaphore's permit count to zero.
    ///
    /// Only the stored count is cleared. The wait queue is not touched, so
    /// tasks already blocked on the semaphore stay blocked.
    #[cfg_attr(not(feature = "inline_syscall"), inline(never))]
    pub fn drain(self) -> Result<(), DrainSemaphoreError> {
        let mut guard = utils::lock_cpu::<System>()?;
        let cb = self.semaphore_cb()?;
        // `replace` hands back the previous count, which is not needed here.
        cb.value.replace(&mut *guard, 0);
        Ok(())
    }

    /// Read the current permit count while holding the CPU lock.
    #[cfg_attr(not(feature = "inline_syscall"), inline(never))]
    pub fn get(self) -> Result<SemaphoreValue, GetSemaphoreError> {
        let guard = utils::lock_cpu::<System>()?;
        let count = self.semaphore_cb()?.value.get(&*guard);
        Ok(count)
    }

    /// Release `count` permits.
    ///
    /// Blocked waiters are woken first, and each woken task receives one
    /// permit. Any remainder is added to the stored count. If the stored count
    /// plus `count` would exceed the semaphore's maximum, the call fails with
    /// `SignalSemaphoreError::QueueOverflow`.
    #[cfg_attr(not(feature = "inline_syscall"), inline(never))]
    pub fn signal(self, count: SemaphoreValue) -> Result<(), SignalSemaphoreError> {
        let guard = utils::lock_cpu::<System>()?;
        let cb = self.semaphore_cb()?;
        signal(cb, guard, count)
    }

    /// Release a single permit; shorthand for `signal(1)`.
    pub fn signal_one(self) -> Result<(), SignalSemaphoreError> {
        self.signal(1)
    }

    /// Take one permit, blocking the calling task until one is available.
    ///
    /// The caller must be in a context where waiting is allowed; this is
    /// checked by `state::expect_waitable_context` after the CPU lock is taken.
    #[cfg_attr(not(feature = "inline_syscall"), inline(never))]
    pub fn wait_one(self) -> Result<(), WaitSemaphoreError> {
        let guard = utils::lock_cpu::<System>()?;
        state::expect_waitable_context::<System>()?;
        let cb = self.semaphore_cb()?;
        wait_one(cb, guard)
    }

    /// Like `wait_one`, but gives up once `timeout` has elapsed.
    ///
    /// `timeout` is converted by `timeout::time32_from_duration` before the CPU
    /// lock is taken, so a conversion failure is reported without touching any
    /// kernel state.
    #[cfg_attr(not(feature = "inline_syscall"), inline(never))]
    pub fn wait_one_timeout(self, timeout: Duration) -> Result<(), WaitSemaphoreTimeoutError> {
        let deadline = timeout::time32_from_duration(timeout)?;
        let guard = utils::lock_cpu::<System>()?;
        state::expect_waitable_context::<System>()?;
        let cb = self.semaphore_cb()?;
        wait_one_timeout(cb, guard, deadline)
    }

    /// Take one permit without blocking.
    ///
    /// Fails with `PollSemaphoreError::Timeout` when no permit is available.
    /// No waitable-context check is performed.
    #[cfg_attr(not(feature = "inline_syscall"), inline(never))]
    pub fn poll_one(self) -> Result<(), PollSemaphoreError> {
        let guard = utils::lock_cpu::<System>()?;
        let cb = self.semaphore_cb()?;
        poll_one(cb, guard)
    }
}
/// Control block holding the runtime state of one semaphore.
#[doc(hidden)]
pub struct SemaphoreCb<System: Port> {
    /// Current permit count. Every access in this module passes the CPU lock
    /// guard (`&*lock` / `&mut *lock`).
    pub(super) value: utils::CpuLockCell<System, SemaphoreValue>,
    /// Upper bound for `value`. `signal` rejects a count that would push
    /// `value` past it with `QueueOverflow`.
    pub(super) max_value: SemaphoreValue,
    /// Tasks blocked in `wait_one` / `wait_one_timeout`, which `signal` wakes
    /// one permit at a time.
    pub(super) wait_queue: WaitQueue<System>,
}
impl<System: Port> Init for SemaphoreCb<System> {
    // Each field takes its own type's `Init::INIT`. The lint is silenced
    // because `value` is mutated through a shared `&'static SemaphoreCb`
    // (interior mutability), which clippy flags when such a value is declared
    // as a `const`.
    #[allow(clippy::declare_interior_mutable_const)]
    const INIT: Self = Self {
        wait_queue: Init::INIT,
        max_value: Init::INIT,
        value: Init::INIT,
    };
}
impl<System: Kernel> fmt::Debug for SemaphoreCb<System> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The control block's own address is printed first so that distinct
        // control blocks can be told apart in debug output.
        let address: *const Self = self;
        let mut out = f.debug_struct("SemaphoreCb");
        out.field("self", &address);
        out.field("value", &self.value);
        out.field("max_value", &self.max_value);
        out.field("wait_queue", &self.wait_queue);
        out.finish()
    }
}
/// Try to take one permit without blocking.
///
/// Returns `PollSemaphoreError::Timeout` if the permit count is zero. The
/// guard is dropped on return.
fn poll_one<System: Kernel>(
    semaphore_cb: &'static SemaphoreCb<System>,
    mut lock: utils::CpuLockGuard<System>,
) -> Result<(), PollSemaphoreError> {
    let acquired = poll_core(semaphore_cb.value.write(&mut *lock));
    if !acquired {
        return Err(PollSemaphoreError::Timeout);
    }
    Ok(())
}
/// Take one permit, blocking on the semaphore's wait queue if none is
/// available.
///
/// A task that blocks does not decrement `value` after waking. `signal` hands
/// a permit directly to each task it wakes instead of depositing it.
fn wait_one<System: Kernel>(
    semaphore_cb: &'static SemaphoreCb<System>,
    mut lock: utils::CpuLockGuard<System>,
) -> Result<(), WaitSemaphoreError> {
    // Fast path: a permit is available right now.
    if poll_core(semaphore_cb.value.write(&mut *lock)) {
        return Ok(());
    }

    // Slow path: park on the wait queue. Any wait error is converted into
    // `WaitSemaphoreError` by `?`; the wait's success value is not needed.
    semaphore_cb
        .wait_queue
        .wait(lock.borrow_mut(), WaitPayload::Semaphore)?;
    Ok(())
}
/// Like `wait_one`, but the blocking wait is bounded by `time32`.
///
/// A permit obtained after waking is handed over by `signal`, so `value` is
/// not decremented on that path. Timeout and other wait errors are converted
/// into `WaitSemaphoreTimeoutError` by `?`.
fn wait_one_timeout<System: Kernel>(
    semaphore_cb: &'static SemaphoreCb<System>,
    mut lock: utils::CpuLockGuard<System>,
    time32: timeout::Time32,
) -> Result<(), WaitSemaphoreTimeoutError> {
    // Fast path: take an available permit without waiting.
    if poll_core(semaphore_cb.value.write(&mut *lock)) {
        return Ok(());
    }

    // Slow path: bounded wait on the queue.
    semaphore_cb
        .wait_queue
        .wait_timeout(lock.borrow_mut(), WaitPayload::Semaphore, time32)?;
    Ok(())
}
/// Consume one permit from `value` if one is available.
///
/// Returns `true` and decrements `value` when it was non-zero. Returns `false`
/// and leaves `value` unchanged when it was zero.
#[inline]
fn poll_core(value: &mut SemaphoreValue) -> bool {
    match value.checked_sub(1) {
        Some(remaining) => {
            *value = remaining;
            true
        }
        None => false,
    }
}
/// Release `count` permits on `semaphore_cb`, consuming the CPU lock guard.
///
/// Permits are first handed directly to blocked tasks, one per woken task.
/// Permits still left once the wait queue is empty are added to the stored
/// count.
///
/// If at least one task was woken, the guard is passed to
/// `task::unlock_cpu_and_check_preemption`; judging by its name, this lets a
/// newly woken task preempt the caller. Otherwise the guard is simply dropped.
///
/// Returns `SignalSemaphoreError::QueueOverflow`, without waking anyone or
/// changing the count, when `count` exceeds the headroom `max_value - value`.
fn signal<System: Kernel>(
    semaphore_cb: &'static SemaphoreCb<System>,
    mut lock: utils::CpuLockGuard<System>,
    count: SemaphoreValue,
) -> Result<(), SignalSemaphoreError> {
    let current = semaphore_cb.value.get(&*lock);
    let headroom = semaphore_cb.max_value - current;
    if count > headroom {
        return Err(SignalSemaphoreError::QueueOverflow);
    }

    // Give permits to waiting tasks one at a time. Stop when all permits are
    // used up or nobody is left waiting.
    let mut remaining = count;
    while remaining > 0 && semaphore_cb.wait_queue.wake_up_one(lock.borrow_mut()) {
        remaining -= 1;
    }

    // Deposit whatever could not be handed to a waiter.
    if remaining > 0 {
        semaphore_cb.value.replace(&mut *lock, current + remaining);
    }

    // At least one task was woken, so release the lock through the
    // preemption-checking path.
    if remaining != count {
        task::unlock_cpu_and_check_preemption(lock);
    }

    Ok(())
}