#![deny(missing_docs)]
use crossbeam_channel::{unbounded, Sender};
use crossbeam_utils::thread::scope;
use ipc_channel::ipc::{channel, IpcReceiver, IpcSender};
use serde::{Deserialize, Serialize};
use snafu::{ensure, OptionExt, ResultExt, Snafu};
use std::{
any::Any,
collections::HashMap,
io,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
pub mod ipc_error;
pub mod labor;
mod panic_error;
use ipc_error::IpcErrorWrapper;
pub use panic_error::PanicError;
/// Signals a watcher when the thread owning it finishes.
///
/// A guard is created at the start of each worker thread spawned by
/// `Processors::run_in_parallel`; dropping it (normal return or unwinding after a panic)
/// pushes a unit message into `sender`.
#[derive(Clone, Debug)]
struct ThreadGuard {
    /// Channel that receives one `()` when this guard is dropped.
    sender: Sender<()>,
}

impl Drop for ThreadGuard {
    fn drop(&mut self) {
        // If the receiving side is already gone there is nobody left to notify, so a send
        // failure is deliberately discarded.
        self.sender.send(()).unwrap_or_default();
    }
}
/// Errors produced by clients, processors and the router.
#[derive(Snafu, Debug)]
pub enum Error {
    /// The global IPC channel could not be created.
    #[snafu(display("Unable to initialize communication channel for {} channels", channels))]
    MainChannelInit {
        /// Underlying I/O error.
        source: io::Error,
        /// Number of channels that were requested.
        channels: usize,
    },
    /// One of the per-channel communication channels could not be created.
    #[snafu(display(
        "Unable to initialize {} channels (at channel #{}): {}",
        channels,
        channel_id,
        source,
    ))]
    ChannelsInit {
        /// Underlying I/O error.
        source: io::Error,
        /// Index of the channel whose initialization failed.
        channel_id: usize,
        /// Total number of channels that were requested.
        channels: usize,
    },
    /// A request was addressed to a channel index that does not exist.
    #[snafu(display(
        "Can't make request, since there is no channel #{} ({} total channels)",
        channel_id,
        channels
    ))]
    ChannelNotFound {
        /// The requested channel index.
        channel_id: usize,
        /// Number of existing channels.
        channels: usize,
    },
    /// A response channel could not be created.
    #[snafu(display(
        "Unable to initialize a channel for a response while working on channel #{}: {}",
        channel_id,
        source
    ))]
    ResponseChannelInit {
        /// Underlying I/O error.
        source: io::Error,
        /// Channel the response channel was created for.
        channel_id: usize,
    },
    /// A quit confirmation channel could not be created.
    #[snafu(display("Unable to initialize a quit confirmation channel: {}", source))]
    QuitChannelInit {
        /// Underlying I/O error.
        source: io::Error,
    },
    /// A client failed to send its request to the router.
    #[snafu(display("Unable to send a request on a channel #{}: {}", channel_id, source))]
    SendingRequest {
        /// Underlying IPC (serialization/transport) error.
        source: ipc_channel::Error,
        /// Channel the request was addressed to.
        channel_id: u64,
    },
    /// A client failed to receive the response to its request.
    #[snafu(display(
        "Unable to receive a response on a channel #{}: {}",
        channel_id,
        source
    ))]
    ReceivingResponse {
        /// Underlying IPC receive error.
        #[snafu(source(from(ipc_channel::ipc::IpcError, From::from)))]
        source: IpcErrorWrapper,
        /// Channel the request was addressed to.
        channel_id: u64,
    },
    /// A processor's request queue was disconnected.
    #[snafu(display("Unable to receive a request on a channel: {}", source))]
    ReceivingRequest {
        /// Underlying crossbeam receive error.
        source: crossbeam_channel::RecvError,
    },
    /// A processor failed to send a response back to a client.
    #[snafu(display("Unable to send a response to client {}: {}", client_id, source))]
    SendingResponse {
        /// Id of the client the response was meant for.
        client_id: u64,
        /// Underlying IPC (serialization/transport) error.
        source: ipc_channel::Error,
    },
    /// The system was stopped before a request could be sent.
    #[snafu(display("Unable to send a request because a system has stopped"))]
    StoppedSendingRequest,
    /// The system was stopped before a response could be received.
    #[snafu(display("Unable to receive a response because a system has stopped"))]
    StoppedReceivingResponse,
    /// The router failed to receive from the global IPC channel.
    #[snafu(display("Error while receiving a message on a global IPC channel: {}", source))]
    RouterReceive {
        /// Underlying IPC receive error.
        #[snafu(source(from(ipc_channel::ipc::IpcError, From::from)))]
        source: IpcErrorWrapper,
    },
    /// The router failed to forward a request to a processor's queue.
    #[snafu(display("Unable to send a request to a processor on channel #{}", channel_id))]
    RouterSend {
        /// Channel whose processor queue rejected the request.
        channel_id: u64,
    },
}
impl Error {
pub fn is_disconnected(&self) -> bool {
self.ipc_error()
.map(IpcErrorWrapper::is_disconnected)
.unwrap_or(false)
}
pub fn has_stopped(&self) -> bool {
match self {
Error::StoppedSendingRequest | Error::StoppedReceivingResponse => true,
_ => false,
}
}
pub fn ipc_error(&self) -> Option<&IpcErrorWrapper> {
match self {
Error::ReceivingResponse { source, .. } => Some(source),
_ => None,
}
}
}
/// Messages sent over the global IPC channel and consumed by [`Router::route`].
#[derive(Serialize, Deserialize)]
enum Message<Request, Response> {
    /// A client request that the router forwards to the processor of channel `channel_id`.
    Request {
        /// Index of the target processor channel.
        channel_id: u64,
        /// The request payload.
        request: Request,
        /// Id of the registered client whose response sender receives the answer.
        respond_to: u64,
    },
    /// Registers a client's response sender under `client_id` (sent by `Client::new`).
    Register {
        client_id: u64,
        sender: IpcSender<Response>,
    },
    /// Removes the client registered under `client_id` (sent when a `Client` is dropped).
    Unregister {
        client_id: u64,
    },
    /// Makes the router forward `InternalRequest::Quit` to every processor and stop routing.
    Quit,
}
/// Work items the router pushes into a single processor's in-process queue.
///
/// NOTE(review): the serde derives appear to be needed only to satisfy the bounds of
/// `maybe_message`; the values never cross a process boundary in this file.
#[derive(Serialize, Deserialize)]
enum InternalRequest<Request, Response> {
    /// A request to process, together with where to send the response.
    Normal {
        request: Request,
        /// Id of the requesting client (used when reporting a failed response send).
        respond_to: u64,
        /// The requesting client's response sender.
        respond_channel: IpcSender<Response>,
    },
    /// Makes `Processor::run_loop` return `Ok(())`.
    Quit,
}
/// Factory for [`Client`]s connected to the same router.
///
/// Cloning a builder is cheap: the clone shares the router sender and the `running` flag.
pub struct ClientBuilder<Request, Response>
where
    Request: Serialize,
    Response: Serialize,
{
    /// Sender to the router's global IPC channel.
    sender: IpcSender<Message<Request, Response>>,
    /// Number of processor channels, handed to every built client.
    total_channels: u64,
    /// Flag shared with the system's `ProcessorsHandle`; cleared when the system stops.
    running: Arc<AtomicBool>,
}

impl<Request, Response> Clone for ClientBuilder<Request, Response>
where
    for<'de> Request: Deserialize<'de> + Serialize,
    for<'de> Response: Deserialize<'de> + Serialize,
{
    fn clone(&self) -> Self {
        let ClientBuilder {
            sender,
            total_channels,
            running,
        } = self;
        ClientBuilder {
            sender: sender.clone(),
            total_channels: *total_channels,
            running: running.clone(),
        }
    }
}

impl<Request, Response> ClientBuilder<Request, Response>
where
    for<'de> Request: Deserialize<'de> + Serialize,
    for<'de> Response: Deserialize<'de> + Serialize,
{
    /// Builds a new [`Client`] and registers it with the router.
    ///
    /// # Panics
    ///
    /// Panics if the client's response channel can't be created or its registration message
    /// can't be sent.
    pub fn build(&self) -> Client<Request, Response> {
        let router_sender = self.sender.clone();
        Client::new(router_sender, &self.running, self.total_channels)
    }
}
/// Sends requests to the processors and waits for their responses.
///
/// Every client owns a private response channel. On construction the channel's sending half
/// is registered with the router under a random id, and on drop the client asks the router to
/// unregister that id.
pub struct Client<Request, Response>
where
    Request: Serialize,
    Response: Serialize,
{
    /// Randomly generated id under which this client is registered with the router.
    id: u64,
    /// Number of processor channels; valid channel ids are `0..total_channels`.
    total_channels: u64,
    /// Sender to the router's global IPC channel.
    sender: IpcSender<Message<Request, Response>>,
    /// Receiving half of this client's private response channel.
    receiver: IpcReceiver<Response>,
    /// Flag shared with the system's `ProcessorsHandle`; cleared when the system stops.
    running: Arc<AtomicBool>,
}

impl<Request, Response> Drop for Client<Request, Response>
where
    Request: Serialize,
    Response: Serialize,
{
    fn drop(&mut self) {
        // Best effort: the router may already be gone, in which case there is nothing to undo.
        let _ = self.sender.send(Message::Unregister { client_id: self.id });
    }
}

impl<Request, Response> Clone for Client<Request, Response>
where
    for<'de> Request: Deserialize<'de> + Serialize,
    for<'de> Response: Deserialize<'de> + Serialize,
{
    /// Creates a new, independently registered client (fresh id and response channel)
    /// connected to the same router.
    ///
    /// # Panics
    ///
    /// Same as [`ClientBuilder::build`].
    fn clone(&self) -> Self {
        Client::new(self.sender.clone(), &self.running, self.total_channels)
    }
}

impl<Request, Response> Client<Request, Response>
where
    for<'de> Request: Deserialize<'de> + Serialize,
    for<'de> Response: Deserialize<'de> + Serialize,
{
    /// Creates a client with a random id and registers its response sender with the router.
    ///
    /// # Panics
    ///
    /// Panics if the response channel can't be created or the `Register` message can't be
    /// sent to the router.
    fn new(
        server_sender: IpcSender<Message<Request, Response>>,
        running: &Arc<AtomicBool>,
        total_channels: u64,
    ) -> Self {
        let new_id = rand::Rng::gen(&mut rand::thread_rng());
        let (sender, receiver) =
            channel().expect("Can't initialize a sender-receiver pair; shouldn't fail");
        server_sender
            .send(Message::Register {
                client_id: new_id,
                sender: sender.clone(),
            })
            .expect("Unable to register a client");
        Client {
            id: new_id,
            sender: server_sender,
            running: Arc::clone(running),
            receiver,
            total_channels,
        }
    }

    /// Number of processor channels; valid `channel_id`s are `0..total_channels()`.
    pub fn total_channels(&self) -> u64 {
        self.total_channels
    }

    /// Sends `request` to the processor of channel `channel_id` and waits for its response.
    ///
    /// # Errors
    ///
    /// * [`Error::StoppedSendingRequest`] if the system was stopped before sending;
    /// * [`Error::ChannelNotFound`] if `channel_id >= self.total_channels()` (the router would
    ///   drop such a request without answering, leaving this call blocked until the router
    ///   shuts down);
    /// * [`Error::SendingRequest`] if the request can't be sent to the router;
    /// * [`Error::StoppedReceivingResponse`] if the system was stopped after sending;
    /// * [`Error::ReceivingResponse`] if receiving the response fails.
    pub fn make_request(&self, channel_id: u64, request: Request) -> Result<Response, Error> {
        ensure!(self.running.load(Ordering::SeqCst), StoppedSendingRequest);
        ensure!(
            channel_id < self.total_channels,
            ChannelNotFound {
                // Only used for the error message. `total_channels` originates from a `usize`,
                // so that cast is lossless; an out-of-range `channel_id` could only be
                // truncated in the message on targets where `usize` is narrower than `u64`.
                channel_id: channel_id as usize,
                channels: self.total_channels as usize,
            }
        );
        self.sender
            .send(Message::Request {
                channel_id,
                request,
                respond_to: self.id,
            })
            .context(SendingRequest { channel_id })?;
        ensure!(
            self.running.load(Ordering::SeqCst),
            StoppedReceivingResponse
        );
        self.receiver
            .recv()
            .context(ReceivingResponse { channel_id })
    }
}
/// A single processing unit, fed by the router through an in-process queue.
///
/// Run it with [`Processor::run_loop`].
pub struct Processor<Request, Response> {
    /// Receiving half of the queue the router forwards this channel's requests into.
    receiver: crossbeam_channel::Receiver<InternalRequest<Request, Response>>,
}
/// Outcome of an idle ("loafing") step.
///
/// NOTE(review): this type is not used in this file; `Processor::run_loop` matches on
/// `labor::LoafingResult` instead. Left unchanged because it is public API.
#[derive(Debug, Clone, Copy)]
pub enum LoaferResult {
    /// Presumably: no more idle work, the caller may block until new work arrives.
    ImDone,
    /// Presumably: more idle work is pending, the caller should call again.
    CallMeAgain,
}
fn maybe_message<T>(
rcv: &crossbeam_channel::Receiver<T>,
) -> Result<Option<T>, crossbeam_channel::RecvError>
where
for<'de> T: Deserialize<'de> + Serialize,
{
match rcv.try_recv() {
Ok(item) => Ok(Some(item)),
Err(e) => match e {
crossbeam_channel::TryRecvError::Empty => Ok(None),
crossbeam_channel::TryRecvError::Disconnected => Err(crossbeam_channel::RecvError),
},
}
}
impl<Request, Response> Processor<Request, Response>
where
    for<'de> Request: Serialize + Deserialize<'de>,
    for<'de> Response: Serialize + Deserialize<'de>,
{
    /// Serves requests from this processor's queue until a quit request arrives.
    ///
    /// While the queue is empty, `proletarian.loaf()` is called. After
    /// `LoafingResult::ImDone` the next wait for a request blocks; after
    /// `LoafingResult::TouchMeAgain` the queue is polled again without blocking. Each
    /// normal request is passed to `process_request`, and the result is sent on the request's
    /// response channel. A failure to send the response is only logged.
    ///
    /// # Errors
    ///
    /// Returns [`Error::ReceivingRequest`] if the request queue becomes disconnected.
    pub fn run_loop<P>(&self, mut proletarian: P) -> Result<(), Error>
    where
        P: labor::Proletarian<Request, Response>,
    {
        // Set when the proletarian reported it has no idle work left: wait for a request then.
        let mut idle_done = false;
        loop {
            let next = if idle_done {
                Some(self.receiver.recv().context(ReceivingRequest)?)
            } else {
                maybe_message(&self.receiver).context(ReceivingRequest)?
            };
            let item = match next {
                Some(item) => item,
                None => {
                    idle_done = match proletarian.loaf() {
                        labor::LoafingResult::ImDone => true,
                        labor::LoafingResult::TouchMeAgain => false,
                    };
                    continue;
                }
            };
            idle_done = false;
            match item {
                InternalRequest::Quit => return Ok(()),
                InternalRequest::Normal {
                    request,
                    respond_to,
                    respond_channel,
                } => {
                    let response = proletarian.process_request(request);
                    let delivery = respond_channel.send(response).context(SendingResponse {
                        client_id: respond_to,
                    });
                    if let Err(e) = delivery {
                        log::error!("Unable to send a response: {}", e);
                    }
                }
            }
        }
    }
}
/// Server side of the system: one processor per channel plus the router that feeds them.
///
/// Either call [`Processors::run_in_parallel`], or run the public `processors` and `router`
/// yourself.
#[must_use = "One must call process requests in order for the communication to run"]
pub struct Processors<Request, Response> {
    /// One processor per channel; index `i` serves channel `#i`.
    pub processors: Vec<Processor<Request, Response>>,
    /// Router that dispatches incoming requests to `processors`.
    pub router: Router<Request, Response>,
    /// Used by `run_in_parallel` to stop everything once any thread finishes.
    handle: ProcessorsHandle<Request, Response>,
}
/// A cloneable handle that stops the whole communication system.
pub struct ProcessorsHandle<Request, Response> {
    /// Sender to the router's global IPC channel, used to deliver `Message::Quit`.
    sender: IpcSender<Message<Request, Response>>,
    /// Flag checked by clients before sending a request and before awaiting its response.
    running: Arc<AtomicBool>,
}

impl<Request, Response> Clone for ProcessorsHandle<Request, Response>
where
    for<'de> Request: Deserialize<'de> + Serialize,
    for<'de> Response: Deserialize<'de> + Serialize,
{
    fn clone(&self) -> Self {
        let sender = self.sender.clone();
        let running = Arc::clone(&self.running);
        ProcessorsHandle { sender, running }
    }
}

impl<Request, Response> ProcessorsHandle<Request, Response>
where
    for<'de> Request: Deserialize<'de> + Serialize,
    for<'de> Response: Deserialize<'de> + Serialize,
{
    /// Stops the system.
    ///
    /// First clears the shared `running` flag, so clients refuse further requests. Then it
    /// asks the router to quit; the router forwards the quit request to every processor.
    /// Delivering the quit message is best effort: a send failure (router already gone) is
    /// ignored, so this currently always returns `Ok(())`.
    pub fn stop(&self) -> Result<(), Error> {
        self.running.store(false, Ordering::SeqCst);
        // A failed send means the router has already shut down; nothing left to notify.
        let _ = self.sender.send(Message::Quit);
        Ok(())
    }
}
/// Errors reported by [`Processors::run_in_parallel`].
#[derive(Snafu, Debug)]
pub enum ParallelRunError {
    /// A joined thread panicked.
    #[snafu(display("Thread {:?} panicked: {}", thread_name, source))]
    ThreadPanic {
        /// Name of the thread ("Router" or "Channel #i").
        thread_name: String,
        /// The panic payload.
        #[snafu(source(from(Box<dyn Any + Send + 'static>, PanicError::new)))]
        source: PanicError,
    },
    /// The thread scope reported a panic in a thread that was not joined explicitly.
    #[snafu(display("Non-joined thread panicked: {}", source))]
    UnjoinedThreadPanic {
        /// The panic payload.
        #[snafu(source(from(Box<dyn Any + Send + 'static>, PanicError::new)))]
        source: PanicError,
    },
    /// A thread returned an [`Error`].
    #[snafu(display("Thread {:?} terminated with error: {}", thread_name, source))]
    IpcError {
        /// Name of the thread ("Router" or "Channel #i").
        thread_name: String,
        /// The error returned by the router or processor.
        source: Error,
    },
    /// A processor thread could not be spawned.
    #[snafu(display(
        "Failed to spawn a thread for processing channel #{}: {}",
        channel_id,
        source
    ))]
    SpawnError {
        /// Channel whose processor thread failed to spawn.
        channel_id: usize,
        /// Underlying I/O error.
        source: io::Error,
    },
    /// The router thread could not be spawned.
    #[snafu(display("Failed to spawn a thread for router: {}", source))]
    RouterSpawn {
        /// Underlying I/O error.
        source: io::Error,
    },
}
impl<Request, Response> Processors<Request, Response>
where
    for<'de> Request: Serialize + Deserialize<'de> + Send,
    for<'de> Response: Serialize + Deserialize<'de> + Send,
{
    /// Runs the router and every processor on its own scoped thread, then joins them.
    ///
    /// The router runs [`Router::route`] on a thread named "Router". Processor `i` runs
    /// [`Processor::run_loop`] on a thread named "Channel #i", using the proletarian returned
    /// by `socium.construct_proletarian(i)`. As soon as any of these threads finishes (returns
    /// or panics), the system is stopped through the internal handle and all threads are
    /// joined.
    ///
    /// Per-thread failures (panics, or errors returned by `route`/`run_loop`) are collected in
    /// the returned `Vec`.
    ///
    /// # Errors
    ///
    /// * [`ParallelRunError::RouterSpawn`] / [`ParallelRunError::SpawnError`] if a thread
    ///   can't be spawned. The threads that did start are stopped before returning, so the
    ///   scope can join them.
    /// * [`ParallelRunError::UnjoinedThreadPanic`] if the scope reports a panic in a thread
    ///   that was not joined explicitly.
    pub fn run_in_parallel<S>(self, socium: S) -> Result<Vec<ParallelRunError>, ParallelRunError>
    where
        S: labor::Socium<Request, Response> + Sync,
        S::Proletarian: labor::Proletarian<Request, Response>,
    {
        let res = scope(|s| {
            // Every spawned thread holds a `ThreadGuard` wrapping a clone of `tx`; the first
            // guard to be dropped wakes up `rx.recv()` below.
            let (tx, rx) = unbounded::<()>();
            let router_handler = {
                let tx = tx.clone();
                let router = self.router;
                s.builder()
                    .name("Router".to_string())
                    .spawn(move |_| {
                        let _guard = ThreadGuard { sender: tx };
                        router.route()
                    })
                    .context(RouterSpawn)
            };
            let spawned = self
                .processors
                .into_iter()
                .enumerate()
                .map(|(channel_id, processor)| {
                    let name = format!("Channel #{}", channel_id);
                    let socium = &socium;
                    let tx = tx.clone();
                    s.builder()
                        .name(name)
                        .spawn(move |_| {
                            let _guard = ThreadGuard { sender: tx };
                            let prolet = socium.construct_proletarian(channel_id);
                            processor.run_loop(prolet)
                        })
                        .context(SpawnError { channel_id })
                })
                .chain(std::iter::once(router_handler))
                .collect::<Result<Vec<_>, _>>();
            let handlers = match spawned {
                Ok(handlers) => handlers,
                Err(e) => {
                    // Threads that already started (the router in particular) would otherwise
                    // block forever, and `scope` would hang joining them. Ask them to quit
                    // before bailing out.
                    let _ = self.handle.stop();
                    return Err(e);
                }
            };
            // Wait until the first thread finishes, then shut everything else down.
            let _ = rx.recv();
            let _ = self.handle.stop();
            let join_errors: Vec<_> = handlers
                .into_iter()
                .map(|handler| {
                    let thread_name = handler
                        .thread()
                        .name()
                        .unwrap_or("[unknown thread]")
                        .to_string();
                    let thread_name = &thread_name;
                    handler
                        .join()
                        .context(ThreadPanic { thread_name })?
                        .context(IpcError { thread_name })
                })
                .filter_map(|res| match res {
                    Ok(()) => None,
                    Err(e) => Some(e),
                })
                .collect();
            Ok(join_errors)
        })
        .context(UnjoinedThreadPanic)??;
        Ok(res)
    }
}
/// The parts of a communication system, as returned by [`communication`].
pub struct Communication<Request, Response>
where
    Request: Serialize,
    Response: Serialize,
{
    /// Builds clients that send requests to the processors.
    pub client_builder: ClientBuilder<Request, Response>,
    /// The processors and router; they must be run for requests to be served.
    pub processors: Processors<Request, Response>,
    /// Handle that stops the system.
    pub handle: ProcessorsHandle<Request, Response>,
}
pub fn communication<Request, Response>(
channels: usize,
) -> Result<Communication<Request, Response>, Error>
where
for<'de> Request: Deserialize<'de> + Serialize,
for<'de> Response: Deserialize<'de> + Serialize,
{
let mut processors = Vec::with_capacity(channels);
let mut senders = Vec::with_capacity(channels);
let (ipc_sender, ipc_receiver) = ipc_channel::ipc::channel::<Message<Request, Response>>()
.context(MainChannelInit { channels })?;
for _channel_id in 0..channels {
let (sender, receiver) = unbounded();
processors.push(Processor { receiver });
senders.push(sender);
}
let running = Arc::new(AtomicBool::new(true));
let handle = ProcessorsHandle {
sender: ipc_sender.clone(),
running: Arc::clone(&running),
};
let client_builder = ClientBuilder {
sender: ipc_sender,
running,
total_channels: channels as u64,
};
let router = Router {
channels: senders,
ipc_receiver,
};
let processors = Processors {
processors,
handle: handle.clone(),
router,
};
Ok(Communication {
client_builder,
processors,
handle,
})
}
pub struct Router<Request, Response> {
ipc_receiver: IpcReceiver<Message<Request, Response>>,
channels: Vec<Sender<InternalRequest<Request, Response>>>,
}
impl<Request, Response> Router<Request, Response>
where
for<'de> Request: Deserialize<'de> + Serialize,
for<'de> Response: Deserialize<'de> + Serialize,
{
pub fn route(&self) -> Result<(), Error> {
let mut clients = HashMap::<u64, IpcSender<Response>>::new();
loop {
match self.ipc_receiver.recv().context(RouterReceive)? {
Message::Quit => {
for snd in &self.channels {
let _ = snd.send(InternalRequest::Quit);
}
break;
}
Message::Unregister { client_id } => {
if clients.remove(&client_id).is_none() {
log::error!("Client #{} wasn't registered!", client_id);
}
}
Message::Register { client_id, sender } => {
if clients.insert(client_id, sender).is_some() {
log::error!("A client #{} was alreay registered!", client_id);
}
}
Message::Request {
channel_id,
request,
respond_to,
} => {
if let Some(respond_channel) = clients.get(&respond_to) {
if let Some(channel) = self.channels.get(channel_id as usize) {
channel
.send(InternalRequest::Normal {
request,
respond_to,
respond_channel: respond_channel.clone(),
})
.ok()
.context(RouterSend { channel_id })?;
} else {
log::error!(
"Received a request from a client #{} on an unknown channel #{}",
respond_to,
channel_id
);
}
} else {
log::error!("Received a request from an unknown client #{}", respond_to);
}
}
}
}
Ok(())
}
}
#[cfg(test)]
mod test {
    use super::*;
    use rand::{distributions::Standard, prelude::*};

    /// End-to-end check. Many client threads send random byte vectors to random channels, and
    /// every processor answers with the length of the vector it received.
    #[test]
    fn check() {
        const CHANNELS: usize = 4;
        const MAX_LEN: usize = 1024;
        const CLIENT_THREADS: usize = 100;
        const MESSAGES_PER_CLIENT: usize = 100;
        let Communication {
            client_builder,
            processors,
            handle,
        } = communication::<Vec<u8>, _>(CHANNELS).unwrap();
        // Serve requests on a background thread; each channel's proletarian is a closure
        // returning the request length.
        let processors = std::thread::spawn(move || {
            processors
                .run_in_parallel(|_channel_id| |v: Vec<_>| v.len())
                .unwrap()
        });
        scope(|s| {
            for _ in 0..CLIENT_THREADS {
                let client_builder = client_builder.clone();
                s.spawn(move |_| {
                    let mut rng = thread_rng();
                    for _ in 0..MESSAGES_PER_CLIENT {
                        let channel_id = rng.gen_range(0, CHANNELS as u64);
                        let length = rng.gen_range(0, MAX_LEN);
                        let data = rng.sample_iter(Standard).take(length).collect();
                        // A fresh client per message also exercises Register/Unregister.
                        let client = client_builder.build();
                        let response = client.make_request(channel_id, data).unwrap();
                        assert_eq!(response, length);
                    }
                });
            }
        })
        .unwrap();
        // All clients are done; shut the system down and make sure it terminates.
        handle.stop().unwrap();
        processors.join().unwrap();
    }
}