#![no_std]
#![doc = include_str!("../README.md")]
use core::cell::RefCell;
use core::future::poll_fn;
use core::task::Waker;
use embassy_sync::blocking_mutex::Mutex;
use embassy_sync::blocking_mutex::raw::RawMutex;
/// Error returned by [`RpcService::request`] when the server dropped the
/// [`ServedRequest`] handle without calling [`ServedRequest::respond`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RequestDroppedError;

impl core::fmt::Display for RequestDroppedError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("request was dropped by the server without a response")
    }
}
/// Single-slot RPC channel: clients call [`RpcService::request`], a server
/// task calls [`RpcService::serve`].
///
/// All shared bookkeeping lives behind one blocking mutex parameterised by the
/// raw mutex type `M`.
pub struct RpcService<M: RawMutex, Req, Resp> {
    /// Slot, queues and registered wakers shared by clients and server.
    state: Mutex<M, RefCell<State<Req, Resp>>>,
}
/// Mutable bookkeeping shared between the client and server sides.
struct State<Req, Resp> {
    /// `true` while some client owns the single request slot.
    client_busy: bool,
    /// Set when the owning client's `request` future was dropped after the
    /// server had already taken the request; the server side then frees the slot.
    client_abandoned: bool,
    /// Task waiting for the client slot to become free.
    waiting_client_slot_waker: Option<Waker>,
    /// In-flight client waiting for its response.
    waiting_client_response_waker: Option<Waker>,
    /// Server task waiting in `serve` for a request.
    waiting_server_waker: Option<Waker>,
    /// Request handed over by the client and not yet taken by the server.
    queued_request: Option<Req>,
    /// Server's answer (or `RequestDroppedError`) not yet collected by the client.
    queued_response: Option<Result<Resp, RequestDroppedError>>,
}
impl<Req, Resp> State<Req, Resp> {
const fn new() -> Self {
Self {
client_busy: false,
client_abandoned: false,
waiting_client_slot_waker: None,
waiting_client_response_waker: None,
waiting_server_waker: None,
queued_request: None,
queued_response: None,
}
}
}
/// Cancellation guard held by [`RpcService::request`] while its request is in
/// flight.
///
/// If the `request` future is dropped before the response has been received,
/// the guard's `Drop` impl cleans up the shared state; `defuse` switches that
/// cleanup off once the response is in hand.
struct InFlightGuard<'a, M: RawMutex, Req, Resp> {
    /// Service whose shared state is cleaned up on drop.
    service: &'a RpcService<M, Req, Resp>,
    /// When `true`, `Drop` skips the cleanup.
    disarm: bool,
}
impl<'a, M, Req, Resp> InFlightGuard<'a, M, Req, Resp>
where
M: RawMutex,
{
fn new(service: &'a RpcService<M, Req, Resp>) -> Self {
Self {
service,
disarm: false,
}
}
fn defuse(mut self) {
self.disarm = true;
core::mem::forget(self);
}
}
impl<'a, M, Req, Resp> Drop for InFlightGuard<'a, M, Req, Resp>
where
    M: RawMutex,
{
    /// Cleans up after a `request` future that was dropped before it collected
    /// its response.
    ///
    /// * If the request is still queued, or a response is already queued, the
    ///   exchange is torn down here: the queued value is discarded, the client
    ///   slot is released and the registered slot waiter (if any) is woken.
    /// * Otherwise the server currently holds the request, so the exchange is
    ///   only flagged as abandoned; the server side releases the slot when it
    ///   responds or drops its `ServedRequest`.
    ///
    /// Discarded payloads are dropped, and the waker is invoked, only after the
    /// lock has been released, so user `Drop` code never runs inside the
    /// critical section.
    fn drop(&mut self) {
        if self.disarm {
            return;
        }
        let (pending_request, pending_response, slot_waker) = self.service.state.lock(|state| {
            let mut s = state.borrow_mut();
            // The cancelled task will never poll for a response again.
            s.waiting_client_response_waker = None;

            let pending_request = s.queued_request.take();
            // Only look for a response when no request was still queued.
            let pending_response = if pending_request.is_some() {
                None
            } else {
                s.queued_response.take()
            };

            if pending_request.is_none() && pending_response.is_none() {
                // The server owns the request; let it free the slot later.
                s.client_abandoned = true;
                return (None, None, None);
            }

            // Nothing of this exchange is left in flight: free the slot now.
            s.client_abandoned = false;
            s.client_busy = false;
            (
                pending_request,
                pending_response,
                s.waiting_client_slot_waker.take(),
            )
        });
        // Deliberately outside the critical section.
        drop(pending_request);
        drop(pending_response);
        if let Some(waker) = slot_waker {
            waker.wake();
        }
    }
}
impl<M, Req, Resp> RpcService<M, Req, Resp>
where
    M: RawMutex,
{
    /// Creates an idle service: client slot free, nothing queued.
    ///
    /// This is a `const fn`, so it can be evaluated in `const`/`static`
    /// initializers.
    pub const fn new() -> Self {
        Self {
            state: Mutex::new(RefCell::new(State::new())),
        }
    }

    /// Sends `req` to the server and waits for its reply.
    ///
    /// Only one request is in flight at a time. The call first waits for the
    /// client slot, then queues `req` and wakes the task registered in
    /// [`Self::serve`].
    ///
    /// # Errors
    ///
    /// Returns [`RequestDroppedError`] if the server drops the
    /// [`ServedRequest`] handle without calling [`ServedRequest::respond`].
    ///
    /// # Cancellation
    ///
    /// If this future is dropped after `req` has been queued, an internal guard
    /// either discards the queued request/response and frees the slot, or, when
    /// the server is already handling `req`, marks the exchange abandoned so the
    /// server side frees the slot when it finishes.
    pub async fn request(&self, req: Req) -> Result<Resp, RequestDroppedError> {
        self.acquire_client_slot().await;
        self.state.lock(|state| {
            let mut state = state.borrow_mut();
            state.queued_request = Some(req);
            if let Some(waker) = state.waiting_server_waker.take() {
                waker.wake();
            }
        });
        // No `.await` between queuing and arming the guard, so the request
        // cannot be orphaned without cleanup.
        let in_flight = InFlightGuard::new(self);
        let result = poll_fn(|cx| {
            self.state.lock(|state| {
                let mut state = state.borrow_mut();
                if let Some(resp) = state.queued_response.take() {
                    state.client_busy = false;
                    if let Some(waker) = state.waiting_client_slot_waker.take() {
                        waker.wake();
                    }
                    return core::task::Poll::Ready(resp);
                }
                state.waiting_client_response_waker = Some(cx.waker().clone());
                core::task::Poll::Pending
            })
        })
        .await;
        in_flight.defuse();
        result
    }

    /// Waits for the next queued request and takes it.
    ///
    /// Returns the request together with a [`ServedRequest`] handle; answer the
    /// client with [`ServedRequest::respond`]. Dropping the handle without
    /// responding delivers [`RequestDroppedError`] to the client if it is still
    /// waiting.
    ///
    /// NOTE(review): only one server waker is stored, so concurrent `serve`
    /// callers overwrite each other's registration; presumably a single server
    /// task is intended — confirm.
    pub async fn serve(&self) -> (Req, ServedRequest<'_, M, Req, Resp>) {
        let req = poll_fn(|cx| {
            self.state.lock(|state| {
                let mut state = state.borrow_mut();
                if let Some(req) = state.queued_request.take() {
                    return core::task::Poll::Ready(req);
                }
                state.waiting_server_waker = Some(cx.waker().clone());
                core::task::Poll::Pending
            })
        })
        .await;
        let served = ServedRequest {
            state: &self.state,
            completed: false,
        };
        (req, served)
    }

    /// Waits until the single client slot is free and claims it.
    async fn acquire_client_slot(&self) {
        poll_fn(|cx| {
            self.state.lock(|state| {
                let mut state = state.borrow_mut();
                if !state.client_busy {
                    state.client_busy = true;
                    return core::task::Poll::Ready(());
                }
                // Only one waker fits in this register. If it belongs to a
                // different task, wake that task before replacing it so it polls
                // again and re-registers instead of being forgotten (lost
                // wakeup). With several waiters they take turns re-polling until
                // the slot frees; that costs CPU but cannot deadlock.
                let already_registered = matches!(
                    &state.waiting_client_slot_waker,
                    Some(w) if w.will_wake(cx.waker())
                );
                if !already_registered {
                    if let Some(displaced) = state
                        .waiting_client_slot_waker
                        .replace(cx.waker().clone())
                    {
                        displaced.wake();
                    }
                }
                core::task::Poll::Pending
            })
        })
        .await;
    }
}
impl<M, Req, Resp> Default for RpcService<M, Req, Resp>
where
M: RawMutex,
{
fn default() -> Self {
Self::new()
}
}
/// Server-side handle for a request taken by [`RpcService::serve`].
///
/// Answer with [`ServedRequest::respond`]. If the handle is dropped instead,
/// its `Drop` impl either releases the client slot (when the client has given
/// up) or queues [`RequestDroppedError`] for the waiting client.
pub struct ServedRequest<'a, M: RawMutex, Req, Resp> {
    /// Shared state of the service the request came from.
    state: &'a Mutex<M, RefCell<State<Req, Resp>>>,
    /// Set by `respond` so that `Drop` does not also report an error.
    completed: bool,
}
impl<'a, M, Req, Resp> ServedRequest<'a, M, Req, Resp>
where
    M: RawMutex,
{
    /// Answers the client with `Ok(resp)`, consuming the handle.
    ///
    /// If the client abandoned the request in the meantime, nobody can receive
    /// the answer: `resp` is discarded, the client slot is released and the
    /// registered slot waiter (if any) is woken. Otherwise `resp` is queued and
    /// the waiting client is woken.
    ///
    /// A discarded `resp` is dropped, and the waker invoked, only after the lock
    /// is released, so user `Drop` code never runs inside the critical section.
    pub fn respond(mut self, resp: Resp) {
        let (unclaimed, waker) = self.state.lock(|state| {
            let mut state = state.borrow_mut();
            if state.client_abandoned {
                state.client_abandoned = false;
                state.client_busy = false;
                (Some(resp), state.waiting_client_slot_waker.take())
            } else {
                state.queued_response = Some(Ok(resp));
                (None, state.waiting_client_response_waker.take())
            }
        });
        // The exchange is settled; keep `Drop` from reporting an error.
        self.completed = true;
        drop(unclaimed);
        if let Some(waker) = waker {
            waker.wake();
        }
    }
}
impl<'a, M, Req, Resp> Drop for ServedRequest<'a, M, Req, Resp>
where
    M: RawMutex,
{
    /// Settles an exchange the server never answered.
    ///
    /// Does nothing after `respond`. Otherwise, if the client already abandoned
    /// the request, the client slot is released and the registered slot waiter
    /// (if any) is woken; if the client is still waiting, it receives
    /// `Err(RequestDroppedError)` and is woken.
    fn drop(&mut self) {
        if self.completed {
            return;
        }
        self.state.lock(|cell| {
            let mut st = cell.borrow_mut();
            let to_wake = if st.client_abandoned {
                // Nobody is waiting for an answer: just free the slot.
                st.client_abandoned = false;
                st.client_busy = false;
                st.waiting_client_slot_waker.take()
            } else {
                st.queued_response = Some(Err(RequestDroppedError));
                st.waiting_client_response_waker.take()
            };
            if let Some(waker) = to_wake {
                waker.wake();
            }
        });
    }
}