use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Condvar, Mutex};
use std::thread;
use std::time::Duration;
use crate::async_rt::{ChannelEnd, ChannelReceiver, ChannelSender, Subscriber, TaskHandle};
use crate::error::IonError;
use crate::value::Value;
/// Shared completion state for one spawned task.
///
/// The worker thread and the task's [`StdTaskHandle`] each hold an `Arc` to the same
/// slot. The worker notifies `cv` (with `notify_all`) right after it moves
/// `inner.state` to `Finished`.
#[derive(Debug)]
struct TaskSlot {
    /// Task state and pending subscribers, guarded together so that checking the
    /// state and registering a subscriber happen atomically.
    inner: Mutex<SlotInner>,
    /// Signalled once the task has finished; joiners wait on it with `inner` locked.
    cv: Condvar,
}
/// Mutex-protected contents of a [`TaskSlot`].
#[derive(Debug)]
struct SlotInner {
    /// `Running` until the worker publishes the task's final result.
    state: SlotState,
    /// Subscribers registered while the task was still running. The worker drains
    /// this list under the lock when it publishes the result, then notifies each one.
    subs: Vec<Subscriber>,
}
/// Lifecycle of a task.
///
/// The slot starts as `Running`. Only the worker thread changes it: once, to
/// `Finished`, after the task body has returned or panicked.
#[derive(Debug)]
enum SlotState {
    /// The task body has not yet produced a result.
    Running,
    /// The task's outcome; a panic in the task body is recorded here as an `Err`.
    Finished(Result<Value, IonError>),
}
/// [`TaskHandle`] implementation that runs its task on a dedicated OS thread
/// (created by [`spawn_task_with_cancel`]).
#[derive(Debug)]
pub struct StdTaskHandle {
    /// Completion state shared with the worker thread.
    slot: Arc<TaskSlot>,
    /// The worker's join handle; `take`n (leaving `None`) by whichever caller joins
    /// the thread first.
    join_handle: Mutex<Option<thread::JoinHandle<()>>>,
    /// Cancellation flag, also passed to the task body (which decides whether and
    /// when to honour it).
    cancel_flag: Arc<AtomicBool>,
}
impl TaskHandle for StdTaskHandle {
fn join(&self) -> Result<Value, IonError> {
let handle_opt = self.join_handle.lock().unwrap().take();
if let Some(h) = handle_opt {
let mut inner = self.slot.inner.lock().unwrap();
while matches!(inner.state, SlotState::Running) {
inner = self.slot.cv.wait(inner).unwrap();
}
drop(inner);
let _ = h.join();
} else {
let mut inner = self.slot.inner.lock().unwrap();
while matches!(inner.state, SlotState::Running) {
inner = self.slot.cv.wait(inner).unwrap();
}
}
let inner = self.slot.inner.lock().unwrap();
match &inner.state {
SlotState::Finished(r) => r.clone(),
SlotState::Running => Err(IonError::runtime(
"task completion signalled but state still Running".to_string(),
0,
0,
)),
}
}
fn join_timeout(&self, timeout: Duration) -> Option<Result<Value, IonError>> {
let deadline = std::time::Instant::now() + timeout;
let mut inner = self.slot.inner.lock().unwrap();
while matches!(inner.state, SlotState::Running) {
let now = std::time::Instant::now();
if now >= deadline {
return None;
}
let (g, res) = self.slot.cv.wait_timeout(inner, deadline - now).unwrap();
inner = g;
if res.timed_out() && matches!(inner.state, SlotState::Running) {
return None;
}
}
drop(inner);
if let Some(h) = self.join_handle.lock().unwrap().take() {
let _ = h.join();
}
let inner = self.slot.inner.lock().unwrap();
match &inner.state {
SlotState::Finished(r) => Some(r.clone()),
SlotState::Running => None,
}
}
fn is_finished(&self) -> bool {
let inner = self.slot.inner.lock().unwrap();
matches!(inner.state, SlotState::Finished(_))
}
fn cancel(&self) {
self.cancel_flag.store(true, Ordering::Relaxed);
}
fn is_cancelled(&self) -> bool {
self.cancel_flag.load(Ordering::Relaxed)
}
fn subscribe(&self, sub: Subscriber) {
let mut inner = self.slot.inner.lock().unwrap();
if matches!(inner.state, SlotState::Finished(_)) {
drop(inner);
notify_subscriber(sub);
} else {
inner.subs.push(sub);
}
}
}
/// Records `sub.my_index` in the subscriber's rendezvous cell and wakes one waiter.
///
/// Only the first notification is recorded. If the cell already holds an index,
/// it is left untouched and no waiter is woken.
fn notify_subscriber(sub: Subscriber) {
    let (cell, signal) = &*sub.rendezvous;
    let mut recorded = cell.lock().unwrap();
    if recorded.is_some() {
        return;
    }
    *recorded = Some(sub.my_index);
    signal.notify_one();
}
pub fn spawn_task_with_cancel<F>(cancel_flag: Arc<AtomicBool>, f: F) -> Arc<dyn TaskHandle>
where
F: FnOnce(Arc<AtomicBool>) -> Result<Value, IonError> + Send + 'static,
{
let slot = Arc::new(TaskSlot {
inner: Mutex::new(SlotInner {
state: SlotState::Running,
subs: Vec::new(),
}),
cv: Condvar::new(),
});
let worker_slot = slot.clone();
let worker_cancel = cancel_flag.clone();
let join_handle = thread::spawn(move || {
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(worker_cancel)));
let result = match result {
Ok(r) => r,
Err(_) => Err(IonError::runtime("task panicked".to_string(), 0, 0)),
};
let subs_to_wake = {
let mut inner = worker_slot.inner.lock().unwrap();
inner.state = SlotState::Finished(result);
std::mem::take(&mut inner.subs)
};
worker_slot.cv.notify_all();
for sub in subs_to_wake {
notify_subscriber(sub);
}
});
Arc::new(StdTaskHandle {
slot,
join_handle: Mutex::new(Some(join_handle)),
cancel_flag,
})
}
use crossbeam_channel::{bounded, Receiver, Sender};
/// Sending end of a channel created by [`create_channel`].
#[derive(Debug)]
pub struct StdChannelSender {
    /// The crossbeam sender, or `None` once `close` has been called.
    inner: Mutex<Option<Sender<Value>>>,
}
impl ChannelSender for StdChannelSender {
    /// Sends `val`, blocking while the channel's bounded buffer is full.
    ///
    /// # Errors
    /// Returns a runtime error if this sender has been closed ("channel is closed"),
    /// or if the receiving end has been dropped ("channel send failed: ...").
    fn send(&self, val: Value) -> Result<(), IonError> {
        // Clone the sender and release the mutex *before* sending. `send` blocks while
        // the buffer is full, and holding the lock across that wait would also block
        // `close()` and every other thread sending through this handle until a
        // receiver made room. Crossbeam senders are cheap to clone and safe to use
        // concurrently.
        let sender = self.inner.lock().unwrap().as_ref().cloned();
        match sender {
            Some(sender) => sender
                .send(val)
                .map_err(|e| IonError::runtime(format!("channel send failed: {}", e), 0, 0)),
            None => Err(IonError::runtime("channel is closed".to_string(), 0, 0)),
        }
    }

    /// Closes this sending end by dropping the stored sender.
    ///
    /// Later `send` calls fail with "channel is closed". A `send` already blocked on a
    /// full buffer is not aborted: it holds its own sender clone, so the receiver
    /// observes disconnection only after that send completes. Values already buffered
    /// remain receivable.
    fn close(&self) {
        *self.inner.lock().unwrap() = None;
    }
}
/// Receiving end of a channel created by [`create_channel`].
#[derive(Debug)]
pub struct StdChannelReceiver {
    /// The underlying crossbeam receiver.
    inner: Receiver<Value>,
}
impl ChannelReceiver for StdChannelReceiver {
    /// Blocks until a value arrives. Returns `None` once the channel is empty and all
    /// senders are gone.
    fn recv(&self) -> Option<Value> {
        Result::ok(self.inner.recv())
    }

    /// Returns a value only if one is available right now. Returns `None` if the
    /// channel is empty or disconnected.
    fn try_recv(&self) -> Option<Value> {
        Result::ok(self.inner.try_recv())
    }

    /// Waits at most `timeout` for a value. Returns `None` on timeout or disconnection.
    fn recv_timeout(&self, timeout: Duration) -> Option<Value> {
        let outcome = self.inner.recv_timeout(timeout);
        Result::ok(outcome)
    }
}
/// Creates a bounded channel and returns `(sender_value, receiver_value)`.
///
/// A requested `buffer` of 0 is raised to 1, so the channel always has room for at
/// least one value and is never a zero-capacity (rendezvous) channel.
pub fn create_channel(buffer: usize) -> (Value, Value) {
    let capacity = std::cmp::max(buffer, 1);
    let (sender, receiver) = bounded::<Value>(capacity);

    let send_end = StdChannelSender {
        inner: Mutex::new(Some(sender)),
    };
    let recv_end = StdChannelReceiver { inner: receiver };

    let tx = Value::Channel(ChannelEnd::Sender(Arc::new(send_end)));
    let rx = Value::Channel(ChannelEnd::Receiver(Arc::new(recv_end)));
    (tx, rx)
}