#![cfg_attr(not(test), no_std)]
use core::fmt::{self, Debug};
use core::sync::atomic::Ordering;
#[cfg(loom)]
use loom::{
cell::UnsafeCell,
sync::atomic::{AtomicBool, AtomicU8, AtomicUsize},
};
#[cfg(not(loom))]
use core::{
cell::UnsafeCell,
sync::atomic::{AtomicBool, AtomicU8, AtomicUsize},
};
/// Error returned when a channel operation cannot be performed in the current state.
///
/// The type carries no payload; callers only learn that the operation was refused.
#[derive(Clone, Copy)]
pub struct Error;
/// Hand-written `Debug` so that the error prints a human-readable explanation
/// instead of the bare type name.
impl Debug for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const BUSY: &str = "The interchange is busy, this operation could not be performed";
        f.write_str(BUSY)
    }
}
/// Stage of the request/response exchange a channel is currently in.
///
/// The explicit discriminants are the raw `u8` values kept in the channel's
/// atomic state byte.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum State {
    /// No exchange in progress; the requester may start a new request.
    Idle = 0,
    /// The requester is assembling a request in place.
    BuildingRequest = 1,
    /// A request has been submitted and not yet picked up by the responder.
    Requested = 2,
    /// The responder has picked up the request and is working on a response.
    BuildingResponse = 3,
    /// A response has been submitted and not yet taken by the requester.
    Responded = 4,
    /// The requester canceled while the responder was building its response.
    Canceled = 12,
}
/// Lets a [`State`] be compared directly against a raw state byte, e.g. the
/// value loaded from the channel's atomic.
impl PartialEq<u8> for State {
    #[inline]
    fn eq(&self, other: &u8) -> bool {
        let discriminant = *self as u8;
        discriminant == *other
    }
}
impl From<u8> for State {
fn from(byte: u8) -> Self {
match byte {
1 => State::BuildingRequest,
2 => State::Requested,
3 => State::BuildingResponse,
4 => State::Responded,
12 => State::Canceled,
_ => State::Idle,
}
}
}
/// Contents of a channel's message slot.
#[repr(u8)]
enum Message<Rq, Rp> {
/// The slot is empty (initial value, and what the `take_*` helpers leave behind).
None,
/// The slot holds a request.
Request(Rq),
/// The slot holds a response.
Response(Rp),
}
/// Accessors for the message slot.
///
/// The `take_*`, `*_ref` and `*_mut` helpers panic (via `unreachable!()`) when
/// the slot does not hold the variant they are named after.
impl<Rq, Rp> Message<Rq, Rp> {
    /// Returns `true` when the slot currently holds a request.
    fn is_request_state(&self) -> bool {
        match self {
            Self::Request(_) => true,
            Self::None | Self::Response(_) => false,
        }
    }

    /// Returns `true` when the slot currently holds a response.
    fn is_response_state(&self) -> bool {
        match self {
            Self::Response(_) => true,
            Self::None | Self::Request(_) => false,
        }
    }

    /// Moves the request out of the slot, leaving it empty.
    fn take_rq(&mut self) -> Rq {
        let Self::Request(request) = core::mem::replace(self, Self::None) else {
            unreachable!()
        };
        request
    }

    /// Borrows the stored request.
    fn rq_ref(&self) -> &Rq {
        let Self::Request(request) = self else {
            unreachable!()
        };
        request
    }

    /// Mutably borrows the stored request.
    fn rq_mut(&mut self) -> &mut Rq {
        let Self::Request(request) = self else {
            unreachable!()
        };
        request
    }

    /// Moves the response out of the slot, leaving it empty.
    fn take_rp(&mut self) -> Rp {
        let Self::Response(response) = core::mem::replace(self, Self::None) else {
            unreachable!()
        };
        response
    }

    /// Borrows the stored response.
    fn rp_ref(&self) -> &Rp {
        let Self::Response(response) = self else {
            unreachable!()
        };
        response
    }

    /// Mutably borrows the stored response.
    fn rp_mut(&mut self) -> &mut Rp {
        let Self::Response(response) = self else {
            unreachable!()
        };
        response
    }

    /// Wraps a request as slot contents.
    fn from_rq(rq: Rq) -> Self {
        Message::Request(rq)
    }

    /// Wraps a response as slot contents.
    fn from_rp(rp: Rp) -> Self {
        Message::Response(rp)
    }
}
/// A single request/response slot shared between one [`Requester`] and one [`Responder`].
pub struct Channel<Rq, Rp> {
/// The message slot; the endpoints reach it through the cell's raw pointer.
data: UnsafeCell<Message<Rq, Rp>>,
/// Current [`State`], stored as its `u8` discriminant.
state: AtomicU8,
/// `true` while a [`Requester`] obtained from this channel is alive.
requester_claimed: AtomicBool,
/// `true` while a [`Responder`] obtained from this channel is alive.
responder_claimed: AtomicBool,
}
impl<Rq, Rp> Channel<Rq, Rp> {
    /// Creates an idle channel with an empty slot and neither endpoint claimed.
    #[cfg(not(loom))]
    pub const fn new() -> Self {
        Self {
            data: UnsafeCell::new(Message::None),
            state: AtomicU8::new(0),
            requester_claimed: AtomicBool::new(false),
            responder_claimed: AtomicBool::new(false),
        }
    }

    /// Creates an idle channel with an empty slot and neither endpoint claimed.
    ///
    /// Non-`const` counterpart used when building with `cfg(loom)`.
    #[cfg(loom)]
    pub fn new() -> Self {
        Self {
            data: UnsafeCell::new(Message::None),
            state: AtomicU8::new(0),
            requester_claimed: AtomicBool::new(false),
            responder_claimed: AtomicBool::new(false),
        }
    }

    /// Tries to flip a claim flag from `false` to `true`.
    ///
    /// Success uses `Acquire` so that it synchronizes with the `Release` store the
    /// previous endpoint performs when it is dropped. Without that edge, a new
    /// endpoint claimed on another thread could touch the message slot without
    /// happening-after the previous owner's last writes to it.
    fn try_claim(flag: &AtomicBool) -> bool {
        flag.compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
            .is_ok()
    }

    /// Obtains the requester end of the channel.
    ///
    /// Returns `None` while another [`Requester`] for this channel is alive; the
    /// claim becomes available again once that endpoint has been dropped.
    pub fn requester(&self) -> Option<Requester<'_, Rq, Rp>> {
        // The endpoint must only be constructed after a successful claim: its
        // `Drop` releases the flag unconditionally.
        if Self::try_claim(&self.requester_claimed) {
            Some(Requester { channel: self })
        } else {
            None
        }
    }

    /// Obtains the responder end of the channel.
    ///
    /// Returns `None` while another [`Responder`] for this channel is alive; the
    /// claim becomes available again once that endpoint has been dropped.
    pub fn responder(&self) -> Option<Responder<'_, Rq, Rp>> {
        if Self::try_claim(&self.responder_claimed) {
            Some(Responder { channel: self })
        } else {
            None
        }
    }

    /// Obtains both ends of the channel, or `None` if either is already claimed.
    ///
    /// If the requester can be claimed but the responder cannot, the freshly
    /// created requester is dropped on the early return, which releases its claim.
    pub fn split(&self) -> Option<(Requester<'_, Rq, Rp>, Responder<'_, Rq, Rp>)> {
        Some((self.requester()?, self.responder()?))
    }

    /// Atomically moves the state from `from` to `to`.
    ///
    /// Returns `true` only if the state was `from` and has been replaced.
    fn transition(&self, from: State, to: State) -> bool {
        self.state
            .compare_exchange(from as u8, to as u8, Ordering::AcqRel, Ordering::Relaxed)
            .is_ok()
    }
}
impl<Rq, Rp> Default for Channel<Rq, Rp> {
fn default() -> Self {
Self::new()
}
}
/// Requesting end of a [`Channel`].
pub struct Requester<'i, Rq, Rp> {
/// The channel this endpoint operates on.
channel: &'i Channel<Rq, Rp>,
}
/// Dropping the requester clears the channel's requester claim flag.
impl<Rq, Rp> Drop for Requester<'_, Rq, Rp> {
    fn drop(&mut self) {
        let claimed = &self.channel.requester_claimed;
        claimed.store(false, Ordering::Release);
    }
}
impl<'i, Rq, Rp> Requester<'i, Rq, Rp> {
    /// Returns the [`Channel`] this endpoint operates on.
    pub fn channel(&self) -> &'i Channel<Rq, Rp> {
        self.channel
    }

    /// Borrows the message slot immutably.
    ///
    /// # Safety
    ///
    /// The caller must have checked, through the channel state, that the other
    /// endpoint is not writing the slot while the returned reference is alive.
    #[cfg(not(loom))]
    unsafe fn data(&self) -> &Message<Rq, Rp> {
        // Shared reborrow on purpose: this runs behind `&self`, so several results
        // can be alive at once. Going through `&mut` here would claim exclusive
        // access that does not exist.
        &*self.channel.data.get()
    }

    /// Borrows the message slot mutably.
    ///
    /// # Safety
    ///
    /// The caller must have checked, through the channel state, that the other
    /// endpoint does not access the slot while the returned reference is alive.
    #[cfg(not(loom))]
    unsafe fn data_mut(&mut self) -> &mut Message<Rq, Rp> {
        &mut *self.channel.data.get()
    }

    /// Runs `f` with shared access to the message slot.
    ///
    /// # Safety
    ///
    /// Same contract as `data`: the other endpoint must not be writing the slot.
    #[cfg(not(loom))]
    unsafe fn with_data<R>(&self, f: impl FnOnce(&Message<Rq, Rp>) -> R) -> R {
        f(&*self.channel.data.get())
    }

    /// Runs `f` with exclusive access to the message slot.
    ///
    /// # Safety
    ///
    /// Same contract as `data_mut`: the other endpoint must not access the slot.
    #[cfg(not(loom))]
    unsafe fn with_data_mut<R>(&mut self, f: impl FnOnce(&mut Message<Rq, Rp>) -> R) -> R {
        f(&mut *self.channel.data.get())
    }

    /// Runs `f` with shared access to the message slot (loom-instrumented cell).
    ///
    /// # Safety
    ///
    /// The other endpoint must not be writing the slot.
    #[cfg(loom)]
    unsafe fn with_data<R>(&self, f: impl FnOnce(&Message<Rq, Rp>) -> R) -> R {
        self.channel.data.with(|i| f(&*i))
    }

    /// Runs `f` with exclusive access to the message slot (loom-instrumented cell).
    ///
    /// # Safety
    ///
    /// The other endpoint must not access the slot.
    #[cfg(loom)]
    unsafe fn with_data_mut<R>(&mut self, f: impl FnOnce(&mut Message<Rq, Rp>) -> R) -> R {
        self.channel.data.with_mut(|i| f(&mut *i))
    }

    /// Returns the channel's current [`State`].
    ///
    /// The value may already be stale when it is returned if the other endpoint
    /// is active concurrently.
    #[inline]
    pub fn state(&self) -> State {
        State::from(self.channel.state.load(Ordering::Acquire))
    }

    /// Submits `request`, moving the channel from `Idle` to `Requested`.
    ///
    /// # Errors
    ///
    /// Returns [`Error`] if the channel is not `Idle`; `request` is dropped in
    /// that case.
    pub fn request(&mut self, request: Rq) -> Result<(), Error> {
        if State::Idle == self.channel.state.load(Ordering::Acquire) {
            // SAFETY: none of the responder's operations in this file act on an
            // `Idle` channel, so the slot is ours until `Requested` is published.
            unsafe {
                self.with_data_mut(|i| *i = Message::from_rq(request));
            }
            // Publish the request only after the slot has been written.
            self.channel
                .state
                .store(State::Requested as u8, Ordering::Release);
            Ok(())
        } else {
            Err(Error)
        }
    }

    /// Attempts to cancel a pending request.
    ///
    /// * `Ok(None)`: the responder had already picked the request up
    ///   (`BuildingResponse`); the channel is now `Canceled`.
    /// * `Ok(Some(request))`: the request was still waiting (`Requested`); it is
    ///   handed back and the channel returns to `Idle`.
    ///
    /// # Errors
    ///
    /// Returns [`Error`] when the channel is in any other state.
    pub fn cancel(&mut self) -> Result<Option<Rq>, Error> {
        if self
            .channel
            .transition(State::BuildingResponse, State::Canceled)
        {
            return Ok(None);
        }
        if self.channel.transition(State::Requested, State::Idle) {
            // SAFETY: the channel is `Idle` again, so the responder no longer
            // touches the slot.
            return Ok(Some(unsafe { self.with_data_mut(|i| i.take_rq()) }));
        }
        Err(Error)
    }

    /// Borrows the response without consuming it; the channel stays `Responded`.
    ///
    /// # Errors
    ///
    /// Returns [`Error`] unless the channel is `Responded`.
    ///
    /// # Panics
    ///
    /// Panics if the channel is `Responded` but the slot holds no response.
    #[cfg(not(loom))]
    pub fn response(&self) -> Result<&Rp, Error> {
        // A `Responded -> Responded` exchange doubles as an acquiring state check.
        if self.channel.transition(State::Responded, State::Responded) {
            // SAFETY: in `Responded` the responder has finished writing the slot.
            Ok(unsafe { self.data().rp_ref() })
        } else {
            Err(Error)
        }
    }

    /// Runs `f` on the response without consuming it; the channel stays `Responded`.
    ///
    /// # Errors
    ///
    /// Returns [`Error`] unless the channel is `Responded`.
    ///
    /// # Panics
    ///
    /// Panics if the channel is `Responded` but the slot holds no response.
    pub fn with_response<R>(&self, f: impl FnOnce(&Rp) -> R) -> Result<R, Error> {
        if self.channel.transition(State::Responded, State::Responded) {
            // SAFETY: in `Responded` the responder has finished writing the slot.
            Ok(unsafe { self.with_data(|i| f(i.rp_ref())) })
        } else {
            Err(Error)
        }
    }

    /// Takes the response out of the channel, returning it to `Idle`.
    ///
    /// Returns `None` unless the channel is `Responded`.
    ///
    /// # Panics
    ///
    /// Panics if the channel is `Responded` but the slot holds no response.
    pub fn take_response(&mut self) -> Option<Rp> {
        if self.channel.transition(State::Responded, State::Idle) {
            // SAFETY: the channel is `Idle` again, so the responder no longer
            // touches the slot.
            Some(unsafe { self.with_data_mut(|i| i.take_rp()) })
        } else {
            None
        }
    }
}
impl<Rq, Rp> Requester<'_, Rq, Rp>
where
Rq: Default,
{
pub fn with_request_mut<R>(&mut self, f: impl FnOnce(&mut Rq) -> R) -> Result<R, Error> {
if self.channel.transition(State::Idle, State::BuildingRequest)
|| self
.channel
.transition(State::BuildingRequest, State::BuildingRequest)
{
let res = unsafe {
self.with_data_mut(|i| {
if !i.is_request_state() {
*i = Message::from_rq(Rq::default());
}
f(i.rq_mut())
})
};
Ok(res)
} else {
Err(Error)
}
}
#[cfg(not(loom))]
pub fn request_mut(&mut self) -> Result<&mut Rq, Error> {
if self.channel.transition(State::Idle, State::BuildingRequest)
|| self
.channel
.transition(State::BuildingRequest, State::BuildingRequest)
{
unsafe {
self.with_data_mut(|i| {
if !i.is_request_state() {
*i = Message::from_rq(Rq::default());
}
})
}
Ok(unsafe { self.data_mut().rq_mut() })
} else {
Err(Error)
}
}
pub fn send_request(&mut self) -> Result<(), Error> {
if State::BuildingRequest == self.channel.state.load(Ordering::Acquire)
&& self
.channel
.transition(State::BuildingRequest, State::Requested)
{
Ok(())
} else {
Err(Error)
}
}
}
/// Responding end of a [`Channel`].
pub struct Responder<'i, Rq, Rp> {
/// The channel this endpoint operates on.
channel: &'i Channel<Rq, Rp>,
}
/// Dropping the responder clears the channel's responder claim flag.
impl<Rq, Rp> Drop for Responder<'_, Rq, Rp> {
    fn drop(&mut self) {
        let claimed = &self.channel.responder_claimed;
        claimed.store(false, Ordering::Release);
    }
}
impl<'i, Rq, Rp> Responder<'i, Rq, Rp> {
    /// Returns the [`Channel`] this endpoint operates on.
    pub fn channel(&self) -> &'i Channel<Rq, Rp> {
        self.channel
    }

    /// Borrows the message slot immutably.
    ///
    /// # Safety
    ///
    /// The caller must have checked, through the channel state, that the other
    /// endpoint is not writing the slot while the returned reference is alive.
    #[cfg(not(loom))]
    unsafe fn data(&self) -> &Message<Rq, Rp> {
        // Shared reborrow on purpose: this runs behind `&self`, so several results
        // can be alive at once. Going through `&mut` here would claim exclusive
        // access that does not exist.
        &*self.channel.data.get()
    }

    /// Borrows the message slot mutably.
    ///
    /// # Safety
    ///
    /// The caller must have checked, through the channel state, that the other
    /// endpoint does not access the slot while the returned reference is alive.
    #[cfg(not(loom))]
    unsafe fn data_mut(&mut self) -> &mut Message<Rq, Rp> {
        &mut *self.channel.data.get()
    }

    /// Runs `f` with shared access to the message slot.
    ///
    /// # Safety
    ///
    /// Same contract as `data`: the other endpoint must not be writing the slot.
    #[cfg(not(loom))]
    unsafe fn with_data<R>(&self, f: impl FnOnce(&Message<Rq, Rp>) -> R) -> R {
        f(&*self.channel.data.get())
    }

    /// Runs `f` with exclusive access to the message slot.
    ///
    /// # Safety
    ///
    /// Same contract as `data_mut`: the other endpoint must not access the slot.
    #[cfg(not(loom))]
    unsafe fn with_data_mut<R>(&mut self, f: impl FnOnce(&mut Message<Rq, Rp>) -> R) -> R {
        f(&mut *self.channel.data.get())
    }

    /// Runs `f` with shared access to the message slot (loom-instrumented cell).
    ///
    /// # Safety
    ///
    /// The other endpoint must not be writing the slot.
    #[cfg(loom)]
    unsafe fn with_data<R>(&self, f: impl FnOnce(&Message<Rq, Rp>) -> R) -> R {
        self.channel.data.with(|i| f(&*i))
    }

    /// Runs `f` with exclusive access to the message slot (loom-instrumented cell).
    ///
    /// # Safety
    ///
    /// The other endpoint must not access the slot.
    #[cfg(loom)]
    unsafe fn with_data_mut<R>(&mut self, f: impl FnOnce(&mut Message<Rq, Rp>) -> R) -> R {
        self.channel.data.with_mut(|i| f(&mut *i))
    }

    /// Returns the channel's current [`State`].
    ///
    /// The value may already be stale when it is returned if the other endpoint
    /// is active concurrently.
    #[inline]
    pub fn state(&self) -> State {
        State::from(self.channel.state.load(Ordering::Acquire))
    }

    /// Runs `f` on the pending request, moving `Requested` to `BuildingResponse`.
    ///
    /// The request stays in the slot.
    ///
    /// # Errors
    ///
    /// Returns [`Error`] unless the channel is `Requested`. Because the call itself
    /// leaves `Requested`, a second call fails.
    pub fn with_request<R>(&self, f: impl FnOnce(&Rq) -> R) -> Result<R, Error> {
        if self
            .channel
            .transition(State::Requested, State::BuildingResponse)
        {
            // SAFETY: the requester does not write the slot once the channel has
            // left `Requested` for `BuildingResponse`.
            Ok(unsafe { self.with_data(|i| f(i.rq_ref())) })
        } else {
            Err(Error)
        }
    }

    /// Borrows the pending request, moving `Requested` to `BuildingResponse`.
    ///
    /// The request stays in the slot.
    ///
    /// # Errors
    ///
    /// Returns [`Error`] unless the channel is `Requested`. Because the call itself
    /// leaves `Requested`, a second call fails.
    #[cfg(not(loom))]
    pub fn request(&self) -> Result<&Rq, Error> {
        if self
            .channel
            .transition(State::Requested, State::BuildingResponse)
        {
            // SAFETY: the requester does not write the slot once the channel has
            // left `Requested` for `BuildingResponse`.
            Ok(unsafe { self.data().rq_ref() })
        } else {
            Err(Error)
        }
    }

    /// Takes the pending request out of the slot, moving `Requested` to
    /// `BuildingResponse`.
    ///
    /// Returns `None` unless the channel is `Requested`.
    pub fn take_request(&mut self) -> Option<Rq> {
        if self
            .channel
            .transition(State::Requested, State::BuildingResponse)
        {
            // SAFETY: the requester does not write the slot once the channel has
            // left `Requested` for `BuildingResponse`.
            Some(unsafe { self.with_data_mut(|i| i.take_rq()) })
        } else {
            None
        }
    }

    /// Returns `true` if the channel is currently `Canceled`.
    pub fn is_canceled(&self) -> bool {
        self.channel.state.load(Ordering::SeqCst) == State::Canceled as u8
    }

    /// Acknowledges a cancellation, moving `Canceled` back to `Idle`.
    ///
    /// # Errors
    ///
    /// Returns [`Error`] unless the channel is `Canceled`.
    pub fn acknowledge_cancel(&self) -> Result<(), Error> {
        if self.channel.transition(State::Canceled, State::Idle) {
            Ok(())
        } else {
            Err(Error)
        }
    }

    /// Submits `response`, moving `BuildingResponse` to `Responded`.
    ///
    /// # Errors
    ///
    /// Returns [`Error`] if the channel is not `BuildingResponse`. If the state
    /// changes between the initial check and the final transition (for example
    /// because the requester canceled), the response has already been written to
    /// the slot but the call still fails.
    pub fn respond(&mut self, response: Rp) -> Result<(), Error> {
        if State::BuildingResponse == self.channel.state.load(Ordering::Acquire) {
            // SAFETY: none of the requester's operations in this file access the
            // slot while the channel is `BuildingResponse` or `Canceled`.
            unsafe {
                self.with_data_mut(|i| *i = Message::from_rp(response));
            }
            if self
                .channel
                .transition(State::BuildingResponse, State::Responded)
            {
                Ok(())
            } else {
                Err(Error)
            }
        } else {
            Err(Error)
        }
    }
}
impl<Rq, Rp> Responder<'_, Rq, Rp>
where
Rp: Default,
{
pub fn with_response_mut<R>(&mut self, f: impl FnOnce(&mut Rp) -> R) -> Result<R, Error> {
if self
.channel
.transition(State::Requested, State::BuildingResponse)
|| self
.channel
.transition(State::BuildingResponse, State::BuildingResponse)
{
let res = unsafe {
self.with_data_mut(|i| {
if !i.is_response_state() {
*i = Message::from_rp(Rp::default());
}
f(i.rp_mut())
})
};
Ok(res)
} else {
Err(Error)
}
}
#[cfg(not(loom))]
pub fn response_mut(&mut self) -> Result<&mut Rp, Error> {
if self
.channel
.transition(State::Requested, State::BuildingResponse)
|| self
.channel
.transition(State::BuildingResponse, State::BuildingResponse)
{
unsafe {
self.with_data_mut(|i| {
if !i.is_response_state() {
*i = Message::from_rp(Rp::default());
}
})
}
Ok(unsafe { self.data_mut().rp_mut() })
} else {
Err(Error)
}
}
pub fn send_response(&mut self) -> Result<(), Error> {
if State::BuildingResponse == self.channel.state.load(Ordering::Acquire)
&& self
.channel
.transition(State::BuildingResponse, State::Responded)
{
Ok(())
} else {
Err(Error)
}
}
}
// `Channel` contains an `UnsafeCell`, so it is not `Sync` automatically. The
// endpoints only hold `&Channel`, so this impl is what lets a `Requester` and a
// `Responder` be used from different threads; slot accesses are gated by the
// atomic `state` byte.
// NOTE(review): requests and responses are moved between the endpoints, which
// explains the `Send` bounds; whether the additional `Sync` bounds are strictly
// required is not evident from this file — confirm before relaxing them.
unsafe impl<Rq, Rp> Sync for Channel<Rq, Rp>
where
Rq: Send + Sync,
Rp: Send + Sync,
{
}
/// A fixed-size set of `N` channels from which endpoint pairs can be claimed.
pub struct Interchange<Rq, Rp, const N: usize> {
/// The channels handed out by `claim`.
channels: [Channel<Rq, Rp>; N],
/// Counter incremented on every claim attempt; selects where the search starts.
last_claimed: AtomicUsize,
}
impl<Rq, Rp, const N: usize> Interchange<Rq, Rp, N> {
#[cfg(not(loom))]
pub const fn new() -> Self {
Self {
channels: [const { Channel::new() }; N],
last_claimed: AtomicUsize::new(0),
}
}
#[cfg(loom)]
pub fn new() -> Self {
Self {
channels: core::array::from_fn(|_| Channel::new()),
last_claimed: AtomicUsize::new(0),
}
}
pub fn claim(&self) -> Option<(Requester<Rq, Rp>, Responder<Rq, Rp>)> {
self.as_interchange_ref().claim()
}
pub const fn as_interchange_ref(&self) -> InterchangeRef<'_, Rq, Rp> {
InterchangeRef {
channels: &self.channels,
last_claimed: &self.last_claimed,
}
}
}
/// Borrowed view of an [`Interchange`] that does not carry the channel count in its type.
pub struct InterchangeRef<'alloc, Rq, Rp> {
/// The channels of the underlying interchange.
channels: &'alloc [Channel<Rq, Rp>],
/// The underlying interchange's claim counter.
last_claimed: &'alloc AtomicUsize,
}
impl<'alloc, Rq, Rp> InterchangeRef<'alloc, Rq, Rp> {
pub fn claim(&self) -> Option<(Requester<'alloc, Rq, Rp>, Responder<'alloc, Rq, Rp>)> {
let index = self.last_claimed.fetch_add(1, Ordering::Relaxed);
let n = self.channels.len();
for i in (index % n)..n {
let tmp = self.channels[i].split();
if tmp.is_some() {
return tmp;
}
}
for i in 0..(index % n) {
let tmp = self.channels[i].split();
if tmp.is_some() {
return tmp;
}
}
None
}
}
// Implemented by hand rather than derived: a derive would add `Rq: Clone` and
// `Rp: Clone` bounds, which are not needed because only references are copied.
impl<Rq, Rp> Clone for InterchangeRef<'_, Rq, Rp> {
fn clone(&self) -> Self {
// The type is `Copy`, so cloning is a plain copy of the two references.
*self
}
}
// The view consists of two shared references, so it is `Copy` regardless of `Rq` and `Rp`.
impl<Rq, Rp> Copy for InterchangeRef<'_, Rq, Rp> {}
impl<Rq, Rp, const N: usize> Default for Interchange<Rq, Rp, N> {
fn default() -> Self {
Self::new()
}
}
// Unit constant with an empty body; it has no effect on its own.
// NOTE(review): the name suggests it is meant as an anchor for `compile_fail`
// doc tests, but none are attached here — confirm whether they were lost.
const _ASSERT_COMPILE_FAILS: () = {};
#[cfg(all(not(loom), test))]
mod tests {
    use super::*;

    /// Request payload used by the tests.
    #[derive(Clone, Debug, PartialEq)]
    pub enum Request {
        This(u8, u32),
    }

    /// Response payload used by the tests.
    #[derive(Clone, Debug, PartialEq)]
    pub enum Response {
        Here(u8, u8, u8),
        There(i16),
    }

    impl Default for Response {
        fn default() -> Self {
            Response::There(1)
        }
    }

    impl Default for Request {
        fn default() -> Self {
            Request::This(0, 0)
        }
    }

    /// Drives one claimed endpoint pair through the scenarios shared by all tests:
    /// a plain round trip, a cancel before pickup, a cancel after pickup, and a
    /// round trip that builds request and response in place.
    fn exercise_channel(
        mut rq: Requester<'_, Request, Response>,
        mut rp: Responder<'_, Request, Response>,
    ) {
        assert_eq!(rq.state(), State::Idle);

        // Plain round trip.
        let request = Request::This(1, 2);
        assert!(rq.request(request).is_ok());
        let request = rp.take_request().unwrap();
        println!("rp got request: {request:?}");
        let response = Response::There(-1);
        assert!(!rp.is_canceled());
        assert!(rp.respond(response).is_ok());
        let response = rq.take_response().unwrap();
        println!("rq got response: {response:?}");

        // Cancel before the responder picked the request up: it is handed back.
        assert!(rq.request(request).is_ok());
        let request = rq.cancel().unwrap().unwrap();
        println!("responder could cancel: {request:?}");
        assert!(rp.take_request().is_none());
        assert_eq!(State::Idle, rq.state());

        // Cancel after the responder picked the request up: it must acknowledge.
        assert!(rq.request(request).is_ok());
        let request = rp.take_request().unwrap();
        println!(
            "responder could cancel: {:?}",
            &rq.cancel().unwrap().is_none()
        );
        assert_eq!(request, Request::This(1, 2));
        assert!(rp.is_canceled());
        assert!(rp.respond(response).is_err());
        assert!(rp.acknowledge_cancel().is_ok());
        assert_eq!(State::Idle, rq.state());

        // Round trip with request and response built in place.
        rq.with_request_mut(|r| *r = Request::This(1, 2)).unwrap();
        assert!(rq.send_request().is_ok());
        let request = rp.take_request().unwrap();
        assert_eq!(request, Request::This(1, 2));
        println!("rp got request: {request:?}");
        rp.with_response_mut(|r| *r = Response::Here(3, 2, 1))
            .unwrap();
        assert!(rp.send_response().is_ok());
        let response = rq.take_response().unwrap();
        assert_eq!(response, Response::Here(3, 2, 1));
    }

    #[test]
    fn interchange() {
        static INTERCHANGE: Interchange<Request, Response, 1> = Interchange::new();
        let (rq, rp) = INTERCHANGE.claim().unwrap();
        exercise_channel(rq, rp);
    }

    #[test]
    fn interchange_ref() {
        static INTERCHANGE_INNER: Interchange<Request, Response, 1> = Interchange::new();
        static INTERCHANGE: InterchangeRef<'static, Request, Response> =
            INTERCHANGE_INNER.as_interchange_ref();
        let (rq, rp) = INTERCHANGE.claim().unwrap();
        exercise_channel(rq, rp);
    }

    // Never called: the body only has to type-check, which proves each listed
    // type is `Send`. The self-recursion is what the lint allowances are for.
    #[allow(unconditional_recursion, clippy::extra_unused_type_parameters, unused)]
    fn assert_send<T: Send>() {
        assert_send::<Channel<String, u32>>();
        assert_send::<Responder<'static, String, u32>>();
        assert_send::<Requester<'static, String, u32>>();
        assert_send::<Channel<&'static mut String, u32>>();
        assert_send::<Responder<'static, &'static mut String, u32>>();
        assert_send::<Requester<'static, &'static mut String, u32>>();
    }

    // Never called: the body only has to type-check, which proves each listed
    // type is `Sync`.
    #[allow(unconditional_recursion, clippy::extra_unused_type_parameters, unused)]
    fn assert_sync<T: Sync>() {
        assert_sync::<Channel<String, u32>>();
        assert_sync::<Responder<'static, String, u32>>();
        assert_sync::<Requester<'static, String, u32>>();
        assert_sync::<Channel<&'static mut String, u32>>();
        assert_sync::<Responder<'static, &'static mut String, u32>>();
        assert_sync::<Requester<'static, &'static mut String, u32>>();
    }
}