use log::debug;
#[allow(unused_imports)]
use std::{
marker::PhantomData,
mem,
sync::{mpsc, Arc},
thread,
};
use crate::{Error, Result};
/// Sending half of a [`Thread`]'s request channel.
///
/// Every message is a request of type `Q` paired with an optional reply sender; when a
/// reply sender is attached (see `Tx::request`) the caller waits for an `R` on it.
/// `N` wraps an unbounded [`mpsc::Sender`] (returned by `Thread::new`) and `S` wraps a
/// bounded [`mpsc::SyncSender`] (returned by `Thread::new_sync`).
#[derive(Debug)]
pub enum Tx<Q, R = ()> {
    /// Unbounded (asynchronous) channel.
    N(mpsc::Sender<(Q, Option<mpsc::Sender<R>>)>),
    /// Bounded (synchronous) channel; sends block while its buffer is full.
    S(mpsc::SyncSender<(Q, Option<mpsc::Sender<R>>)>),
}
impl<Q, R> Clone for Tx<Q, R> {
fn clone(&self) -> Self {
match self {
Tx::N(tx) => Tx::N(tx.clone()),
Tx::S(tx) => Tx::S(tx.clone()),
}
}
}
impl<Q, R> Tx<Q, R> {
    /// Send `msg` to the thread without waiting for a response.
    ///
    /// On a bounded (`Tx::S`) channel this blocks while the channel buffer is full.
    ///
    /// # Errors
    ///
    /// Returns an `IPCFail` error if the message cannot be delivered, i.e. the
    /// receiving end has been dropped.
    pub fn post(&self, msg: Q) -> Result<()> {
        err_at!(IPCFail, self.deliver((msg, None)))?;
        Ok(())
    }

    /// Send `request` to the thread and block until it replies.
    ///
    /// A fresh reply channel is created for each call and its sender travels with the
    /// request.
    ///
    /// # Errors
    ///
    /// Returns an `IPCFail` error if the request cannot be delivered, or if the reply
    /// sender is dropped without a response being sent.
    pub fn request(&self, request: Q) -> Result<R> {
        let (reply_tx, reply_rx) = mpsc::channel();
        err_at!(IPCFail, self.deliver((request, Some(reply_tx))))?;
        Ok(err_at!(IPCFail, reply_rx.recv())?)
    }

    /// Push a raw `(message, reply-sender)` pair onto whichever channel backs this `Tx`.
    ///
    /// Both flavours report failure with the same `SendError` type, so callers can
    /// treat them uniformly.
    fn deliver(
        &self,
        item: (Q, Option<mpsc::Sender<R>>),
    ) -> std::result::Result<(), mpsc::SendError<(Q, Option<mpsc::Sender<R>>)>> {
        match self {
            Tx::N(sender) => sender.send(item),
            Tx::S(sender) => sender.send(item),
        }
    }
}
/// Receiving half of a [`Thread`]'s request channel, handed to the `main_loop` factory.
///
/// Each item is a request plus an optional sender on which a reply of type `R` can be
/// returned to the requester.
pub type Rx<Q, R = ()> = std::sync::mpsc::Receiver<(Q, Option<std::sync::mpsc::Sender<R>>)>;
/// Handle to a worker thread created by `Thread::new` or `Thread::new_sync`.
///
/// `Q` is the request type, `R` the reply type and `T` the value returned by the
/// thread body. Dropping the handle joins the thread if it has not been joined yet.
pub struct Thread<Q, R = (), T = ()> {
    // Name supplied at construction; used in log messages.
    name: String,
    // Join handle; taken (leaving `None`) by `Thread::join` or on drop.
    inner: Option<Inner<Q, R, T>>,
}
/// Join handle of the spawned thread, plus type markers for the request/reply types.
struct Inner<Q, R, T> {
    handle: thread::JoinHandle<T>,
    // Zero-sized markers so `Inner` (and hence `Thread`) is generic over the
    // request/reply types without storing any values of them.
    _req: PhantomData<Q>,
    _res: PhantomData<R>,
}
impl<Q, R, T> Inner<Q, R, T> {
    /// Block until the thread finishes and return the value its body produced.
    ///
    /// # Errors
    ///
    /// Returns a `ThreadFail` error if the thread panicked. The error message carries
    /// the panic message when the payload is a string (as `panic!` produces).
    fn join(self) -> Result<T> {
        match self.handle.join() {
            Ok(val) => Ok(val),
            Err(payload) => {
                // Debug-formatting the raw `Box<dyn Any>` only prints `Any { .. }`,
                // so extract the human-readable panic message first.
                let reason = panic_payload_message(&*payload);
                err_at!(ThreadFail, msg: "fail {}", reason)
            }
        }
    }
}

/// Render a thread's panic payload as text.
///
/// `panic!` payloads are `&'static str` for literal messages and `String` for
/// formatted ones; any other payload type cannot be printed and yields a placeholder.
fn panic_payload_message(payload: &(dyn std::any::Any + Send)) -> String {
    if let Some(msg) = payload.downcast_ref::<&str>() {
        (*msg).to_string()
    } else if let Some(msg) = payload.downcast_ref::<String>() {
        msg.clone()
    } else {
        "<non-string panic payload>".to_string()
    }
}
impl<Q, R, T> Drop for Thread<Q, R, T> {
fn drop(&mut self) {
if let Some(inner) = self.inner.take() {
inner.join().ok();
}
debug!(target: "thread", "dropped thread `{}`", self.name);
}
}
impl<Q, R, T> Thread<Q, R, T> {
    /// Spawn a worker thread fed by an unbounded (asynchronous) channel.
    ///
    /// `main_loop` runs on the *calling* thread. It receives the [`Rx`] end of the
    /// channel and returns the closure that becomes the new thread's body. The OS
    /// thread is named after `name`, which is also used in log messages.
    ///
    /// Returns the [`Thread`] handle and a [`Tx`] for sending messages to it.
    ///
    /// # Panics
    ///
    /// Panics if the OS fails to create the thread, just like [`thread::spawn`].
    pub fn new<F, N>(name: &str, main_loop: F) -> (Thread<Q, R, T>, Tx<Q, R>)
    where
        F: 'static + FnOnce(Rx<Q, R>) -> N + Send,
        N: 'static + Send + FnOnce() -> T,
        T: 'static + Send,
    {
        let (tx, rx) = mpsc::channel();
        let th = Self::spawn_named(name, main_loop(rx));
        debug!(target: "thread", "{} spawned in async mode", name);
        (th, Tx::N(tx))
    }

    /// Spawn a worker thread fed by a bounded channel with `channel_size` slots.
    ///
    /// The channel is created with [`mpsc::sync_channel`], so sends through the
    /// returned [`Tx`] block while the buffer is full. With `channel_size == 0`, every
    /// send waits for the thread to receive it. Otherwise identical to [`Thread::new`].
    ///
    /// # Panics
    ///
    /// Panics if the OS fails to create the thread, just like [`thread::spawn`].
    pub fn new_sync<F, N>(
        name: &str,
        channel_size: usize,
        main_loop: F,
    ) -> (Thread<Q, R, T>, Tx<Q, R>)
    where
        F: 'static + FnOnce(Rx<Q, R>) -> N + Send,
        N: 'static + Send + FnOnce() -> T,
        T: 'static + Send,
    {
        let (tx, rx) = mpsc::sync_channel(channel_size);
        let th = Self::spawn_named(name, main_loop(rx));
        debug!(target: "thread", "{} spawned in sync mode", name);
        (th, Tx::S(tx))
    }

    /// Wait for the thread to finish and return the value produced by its body.
    ///
    /// # Errors
    ///
    /// Returns the error produced while joining the thread (e.g. when it panicked).
    pub fn join(mut self) -> Result<T> {
        // `inner` is only ever taken here (which consumes `self`) or in `Drop`, so it
        // is still populated; `Drop` then sees `None` and skips its own join.
        let inner = self.inner.take().expect("thread handle already taken");
        inner.join()
    }

    /// Start `body` on a new OS thread and wrap its join handle in a [`Thread`].
    ///
    /// Shared by [`Thread::new`] and [`Thread::new_sync`] so both name the thread and
    /// build the handle identically.
    fn spawn_named<N>(name: &str, body: N) -> Thread<Q, R, T>
    where
        N: 'static + Send + FnOnce() -> T,
        T: 'static + Send,
    {
        let mut builder = thread::Builder::new();
        // OS thread names may not contain NUL bytes (spawning would panic on one);
        // fall back to an unnamed thread in that case, as before.
        if !name.contains('\0') {
            builder = builder.name(name.to_string());
        }
        // Same failure behaviour as `thread::spawn`, which panics with this message.
        let handle = builder.spawn(body).expect("failed to spawn thread");
        Thread {
            name: name.to_string(),
            inner: Some(Inner {
                handle,
                _req: PhantomData,
                _res: PhantomData,
            }),
        }
    }
}