use std::cell::UnsafeCell;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicU8, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use futures_util::task::{waker_ref, ArcWake};
use super::CallTag;
use crate::call::Call;
use crate::cq::{CompletionQueue, WorkQueue};
use crate::error::{Error, Result};
use crate::grpc_sys::{self, grpc_call_error};
/// Type-erased, heap-pinned future driven by a [`SpawnTask`]; it produces `()`
/// and must be `Send + 'static`.
type SpawnHandle = Pin<Box<dyn Future<Output = ()> + Send + 'static>>;
/// Handle that wakes a completion queue by submitting a tag through a call
/// (see `Kicker::kick`).
pub(crate) struct Kicker {
    /// The call whose completion queue receives kicked tags.
    call: Call,
}
impl Kicker {
    /// Wraps `call` so it can be used to kick its completion queue.
    pub fn from_call(call: Call) -> Kicker {
        Kicker { call }
    }

    /// Submits `tag` to this call's completion queue via
    /// `grpcwrap_call_kick_completion_queue`.
    ///
    /// # Errors
    ///
    /// - Whatever `self.call.cq.borrow()` returns when the queue cannot be
    ///   borrowed (presumably `Error::QueueShutdown` after shutdown — confirm).
    /// - `Error::CallFailure(status)` when the kick is rejected.
    ///
    /// On either error `tag` is dropped here rather than leaked.
    pub fn kick(&self, tag: Box<CallTag>) -> Result<()> {
        // Keep the completion queue borrowed while the FFI call runs.
        let _ref = self.call.cq.borrow()?;
        let ptr = Box::into_raw(tag);
        // SAFETY: `self.call.call` is the call handle owned by this `Kicker`, and
        // `ptr` is a valid, uniquely owned tag pointer produced just above.
        let status =
            unsafe { grpc_sys::grpcwrap_call_kick_completion_queue(self.call.call, ptr as _) };
        if status == grpc_call_error::GRPC_CALL_OK {
            // The tag is presumably owned by gRPC until it is delivered back.
            Ok(())
        } else {
            // SAFETY: `ptr` came from `Box::into_raw` above. A rejected request is
            // assumed not to enqueue the tag (per `grpc_call_start_batch`'s
            // contract, a failed call changes no state), so nothing else will
            // ever free it; reclaim it to avoid leaking the tag and its contents.
            unsafe { drop(Box::from_raw(ptr)) };
            Err(Error::CallFailure(status))
        }
    }
}
// SAFETY: `Kicker` only exposes `&self` operations that hand the raw call
// pointer to gRPC (`kick`) or take an extra reference on it (`clone`).
// NOTE(review): soundness assumes those gRPC entry points are safe to call
// concurrently from multiple threads — confirm against gRPC core docs.
unsafe impl Sync for Kicker {}
impl Clone for Kicker {
    /// Produces a second `Kicker` for the same underlying call.
    ///
    /// The C call object gets an extra reference via `grpc_call_ref`, so the new
    /// `Call` holds its own reference (presumably released when that `Call` is
    /// dropped — confirm). The completion queue handle is cloned as well.
    fn clone(&self) -> Kicker {
        let raw = self.call.call;
        // SAFETY: `raw` is the live call held by `self`; the extra reference is
        // what the new `Call` below takes ownership of.
        unsafe { grpc_sys::grpc_call_ref(raw) };
        Kicker {
            call: Call {
                call: raw,
                cq: self.call.cq.clone(),
            },
        }
    }
}
// Lifecycle states stored in `SpawnTask::state`.

/// Woken: either a poll has been scheduled, or the thread currently polling
/// will poll again once the current poll returns `Pending`.
const NOTIFIED: u8 = 1;
/// Pending and unscheduled: the next wake-up must schedule a poll.
const IDLE: u8 = 2;
/// A thread has claimed the task (via CAS in `poll`) and is polling the future.
const POLLING: u8 = 3;
/// Terminal: the future returned `Ready` and has been dropped.
const COMPLETED: u8 = 4;
/// A spawned future together with the state needed to reschedule it on wake-up.
pub struct SpawnTask {
    /// The future being driven; set to `None` once it completes. Only accessed
    /// by the thread that moved `state` into `POLLING` (see `poll`).
    handle: UnsafeCell<Option<SpawnHandle>>,
    /// One of `NOTIFIED`, `IDLE`, `POLLING`, `COMPLETED`.
    state: AtomicU8,
    /// Used to kick the completion queue when `queue.push_work` hands the
    /// wake-up back instead of queuing it.
    kicker: Kicker,
    /// Work queue that woken tasks are pushed onto.
    queue: Arc<WorkQueue>,
}
// SAFETY: `handle` is an `UnsafeCell` and is only dereferenced in `poll`, by the
// thread that won the CAS into `POLLING`; every other path (`mark_notified`,
// a concurrent `poll`) only touches the atomic `state`. `Kicker` is declared
// `Sync` above. NOTE(review): this also assumes `WorkQueue` is `Send + Sync` — confirm.
unsafe impl Sync for SpawnTask {}
impl SpawnTask {
    /// Creates an idle task that owns `future` and reschedules itself through
    /// `queue` (falling back to `kicker`).
    fn new(future: SpawnHandle, kicker: Kicker, queue: Arc<WorkQueue>) -> SpawnTask {
        let state = AtomicU8::new(IDLE);
        let handle = UnsafeCell::new(Some(future));
        SpawnTask {
            handle,
            state,
            kicker,
            queue,
        }
    }

    /// Records a wake-up by moving the task into `NOTIFIED`.
    ///
    /// Returns `true` only when the task was `IDLE`. In that case the caller is
    /// now responsible for scheduling a poll. When the task is `POLLING`, it is
    /// marked `NOTIFIED` and the current poller re-polls it, so `false` is
    /// returned. `NOTIFIED` and `COMPLETED` are left untouched, and the method
    /// also returns `false` for them.
    fn mark_notified(&self) -> bool {
        let previous = self
            .state
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| match current {
                IDLE | POLLING => Some(NOTIFIED),
                _ => None,
            });
        matches!(previous, Ok(IDLE))
    }
}
/// Re-polls a task that has been woken; its state is expected to be `NOTIFIED`.
///
/// # Panics
///
/// Panics if `success` is `false`, or if `poll` finds the task in an
/// unexpected state.
pub fn resolve(task: Arc<SpawnTask>, success: bool) {
    assert!(success);
    let woken = true;
    poll(task, woken)
}
impl ArcWake for SpawnTask {
fn wake_by_ref(task: &Arc<Self>) {
if !task.mark_notified() {
return;
}
if let Some(UnfinishedWork(w)) = task.queue.push_work(UnfinishedWork(task.clone())) {
match task.kicker.kick(Box::new(CallTag::Spawn(w))) {
Err(Error::QueueShutdown) => (),
Err(e) => panic!("unexpected error when canceling call: {:?}", e),
_ => (),
}
}
}
}
/// A woken task waiting to be polled. It is created in `wake_by_ref` and handed
/// to the work queue; call `finish` to poll it.
pub struct UnfinishedWork(Arc<SpawnTask>);
impl UnfinishedWork {
pub fn finish(self) {
resolve(self.0, true);
}
}
/// Drives `task`'s future and keeps re-polling it while it is woken mid-poll.
///
/// `woken` selects the state the task is expected to be in:
/// - `true` (`NOTIFIED`) when re-run after a wake-up via `resolve`;
/// - `false` (`IDLE`) for the first poll from `Executor::spawn`.
///
/// Returns once the future is `Pending` with no outstanding notification, once
/// it completes, or immediately if the task has already completed.
///
/// # Panics
///
/// Panics if the task is in any other state than expected or `COMPLETED`, or if
/// the future itself panics.
fn poll(task: Arc<SpawnTask>, woken: bool) {
    let mut expected = if woken { NOTIFIED } else { IDLE };
    // NOTE(review): a future that keeps waking itself is re-polled here without
    // yielding, which may starve other work on this thread.
    loop {
        // Claim exclusive access to `handle` by moving into POLLING.
        match task
            .state
            .compare_exchange(expected, POLLING, Ordering::AcqRel, Ordering::Acquire)
        {
            Ok(_) => {}
            Err(COMPLETED) => return,
            Err(s) => panic!("unexpected state {}", s),
        }

        let waker = waker_ref(&task);
        let mut cx = Context::from_waker(&waker);
        // SAFETY: this thread won the CAS into POLLING above, so it is the only
        // one dereferencing `handle`; other threads only touch `state`.
        let outcome = unsafe { &mut *task.handle.get() }
            .as_mut()
            .unwrap()
            .as_mut()
            .poll(&mut cx);

        match outcome {
            Poll::Ready(()) => {
                // Publish COMPLETED before dropping the future, so a wake-up issued
                // from its destructor is a no-op (`mark_notified` ignores COMPLETED).
                task.state.store(COMPLETED, Ordering::Release);
                // SAFETY: once COMPLETED, no other path dereferences `handle`
                // (a concurrent `poll` returns on `Err(COMPLETED)`), so this
                // thread still has exclusive access.
                unsafe { &mut *task.handle.get() }.take();
                return;
            }
            Poll::Pending => {
                match task.state.compare_exchange(
                    POLLING,
                    IDLE,
                    Ordering::AcqRel,
                    Ordering::Acquire,
                ) {
                    Ok(_) => return,
                    // Woken during the poll: `mark_notified` left this thread in
                    // charge of re-polling, so go around again.
                    Err(NOTIFIED) => expected = NOTIFIED,
                    Err(s) => panic!("unexpected state {}", s),
                }
            }
        }
    }
}
/// Spawns futures as `SpawnTask`s tied to a borrowed completion queue.
pub(crate) struct Executor<'a> {
    /// Completion queue whose work queue spawned tasks are rescheduled on.
    cq: &'a CompletionQueue,
}
impl<'a> Executor<'a> {
    /// Builds an executor that borrows `cq`.
    pub fn new(cq: &CompletionQueue) -> Executor<'_> {
        Executor { cq }
    }

    /// The completion queue this executor was created with.
    pub fn cq(&self) -> &CompletionQueue {
        self.cq
    }

    /// Wraps `f` in a new task and polls it right away on the calling thread.
    ///
    /// The task reschedules itself through this queue's work queue, or through
    /// `kicker`, when it is woken later.
    pub fn spawn<F>(&self, f: F, kicker: Kicker)
    where
        F: Future<Output = ()> + Send + 'static,
    {
        let pinned = Box::pin(f);
        let task = Arc::new(SpawnTask::new(pinned, kicker, Arc::clone(&self.cq.worker)));
        poll(task, false)
    }
}