1#![deny(missing_docs)]
4
5use crossbeam_channel::{unbounded, Sender};
6use crossbeam_utils::thread::scope;
7use ipc_channel::ipc::{channel, IpcReceiver, IpcSender};
8use serde::{Deserialize, Serialize};
9use snafu::{ensure, OptionExt, ResultExt, Snafu};
10use std::{
11 any::Any,
12 collections::HashMap,
13 io,
14 sync::{
15 atomic::{AtomicBool, Ordering},
16 Arc,
17 },
18};
19
20pub mod ipc_error;
21pub mod labor;
22mod panic_error;
23
24use ipc_error::IpcErrorWrapper;
25pub use panic_error::PanicError;
26
27#[derive(Clone, Debug)]
28struct ThreadGuard {
29 sender: Sender<()>,
30}
31
32impl Drop for ThreadGuard {
33 fn drop(&mut self) {
34 let _ = self.sender.send(());
35 }
36}
37
38#[derive(Snafu, Debug)]
40pub enum Error {
41 #[snafu(display("Unable to initialize communication channel for {} channels", channels))]
43 MainChannelInit {
44 source: io::Error,
46
47 channels: usize,
49 },
50
51 #[snafu(display(
53 "Unable to initialize {} channels (at channel #{}): {}",
54 channels,
55 channel_id,
56 source,
57 ))]
58 ChannelsInit {
59 source: io::Error,
61
62 channel_id: usize,
64
65 channels: usize,
67 },
68
69 #[snafu(display(
71 "Can't make request, since there is no channel #{} ({} total channels)",
72 channel_id,
73 channels
74 ))]
75 ChannelNotFound {
76 channel_id: usize,
78
79 channels: usize,
81 },
82
83 #[snafu(display(
85 "Unable to initialize a channel for a response while working on channel #{}: {}",
86 channel_id,
87 source
88 ))]
89 ResponseChannelInit {
90 source: io::Error,
92
93 channel_id: usize,
95 },
96
97 #[snafu(display("Unable to initialize a quit confirmation channel: {}", source))]
99 QuitChannelInit {
100 source: io::Error,
102 },
103
104 #[snafu(display("Unable to send a request on a channel #{}: {}", channel_id, source))]
106 SendingRequest {
107 source: ipc_channel::Error,
109
110 channel_id: u64,
112 },
113
114 #[snafu(display(
116 "Unable to receiver a response on a channel #{}: {}",
117 channel_id,
118 source
119 ))]
120 ReceivingResponse {
121 #[snafu(source(from(ipc_channel::ipc::IpcError, From::from)))]
123 source: IpcErrorWrapper,
124
125 channel_id: u64,
127 },
128
129 #[snafu(display("Unable to receive a request on a channel: {}", source))]
131 ReceivingRequest {
132 source: crossbeam_channel::RecvError,
134 },
135
136 #[snafu(display("Unable to send a response to client {}: {}", client_id, source))]
138 SendingResponse {
139 client_id: u64,
141
142 source: ipc_channel::Error,
144 },
145
146 #[snafu(display("Unable to send a request because a system has stopped"))]
148 StoppedSendingRequest,
149
150 #[snafu(display("Unable to receive a response because a system has stopped"))]
152 StoppedReceivingResponse,
153
154 #[snafu(display("Error while receiving a message on a global IPC channel: {}", source))]
156 RouterReceive {
157 #[snafu(source(from(ipc_channel::ipc::IpcError, From::from)))]
159 source: IpcErrorWrapper,
160 },
161
162 #[snafu(display("Unable to send a request to a processor on channel #{}", channel_id))]
164 RouterSend {
165 channel_id: u64,
167 },
168}
169
170impl Error {
171 pub fn is_disconnected(&self) -> bool {
173 self.ipc_error()
174 .map(IpcErrorWrapper::is_disconnected)
175 .unwrap_or(false)
176 }
177
178 pub fn has_stopped(&self) -> bool {
180 match self {
181 Error::StoppedSendingRequest | Error::StoppedReceivingResponse => true,
182 _ => false,
183 }
184 }
185
186 pub fn ipc_error(&self) -> Option<&IpcErrorWrapper> {
188 match self {
189 Error::ReceivingResponse { source, .. } => Some(source),
191 _ => None,
193 }
194 }
195}
196#[derive(Serialize, Deserialize)]
197enum Message<Request, Response> {
198 Request {
199 channel_id: u64,
200 request: Request,
201 respond_to: u64,
202 },
203 Register {
204 client_id: u64,
205 sender: IpcSender<Response>,
206 },
207 Unregister {
208 client_id: u64,
209 },
210 Quit,
211}
212
213#[derive(Serialize, Deserialize)]
214enum InternalRequest<Request, Response> {
215 Normal {
216 request: Request,
217 respond_to: u64,
218 respond_channel: IpcSender<Response>,
219 },
220 Quit,
221}
222
223pub struct ClientBuilder<Request, Response>
225where
226 Request: Serialize,
227 Response: Serialize,
228{
229 sender: IpcSender<Message<Request, Response>>,
230 total_channels: u64,
231 running: Arc<AtomicBool>,
232}
233
234impl<Request, Response> Clone for ClientBuilder<Request, Response>
235where
236 for<'de> Request: Deserialize<'de> + Serialize,
237 for<'de> Response: Deserialize<'de> + Serialize,
238{
239 fn clone(&self) -> Self {
240 Self {
241 sender: self.sender.clone(),
242 running: Arc::clone(&self.running),
243 total_channels: self.total_channels,
244 }
245 }
246}
247
248impl<Request, Response> ClientBuilder<Request, Response>
249where
250 for<'de> Request: Deserialize<'de> + Serialize,
251 for<'de> Response: Deserialize<'de> + Serialize,
252{
253 pub fn build(&self) -> Client<Request, Response> {
255 Client::new(self.sender.clone(), &self.running, self.total_channels)
256 }
257}
258
259pub struct Client<Request, Response>
261where
262 Request: Serialize,
263 Response: Serialize,
264{
265 id: u64,
266 total_channels: u64,
267 sender: IpcSender<Message<Request, Response>>,
268 receiver: IpcReceiver<Response>,
269 running: Arc<AtomicBool>,
270}
271
272impl<Request, Response> Drop for Client<Request, Response>
273where
274 Request: Serialize,
275 Response: Serialize,
276{
277 fn drop(&mut self) {
278 let _ = self.sender.send(Message::Unregister { client_id: self.id });
279 }
280}
281
282impl<Request, Response> Clone for Client<Request, Response>
283where
284 for<'de> Request: Deserialize<'de> + Serialize,
285 for<'de> Response: Deserialize<'de> + Serialize,
286{
287 fn clone(&self) -> Self {
288 Client::new(self.sender.clone(), &self.running, self.total_channels)
289 }
290}
291
292impl<Request, Response> Client<Request, Response>
293where
294 for<'de> Request: Deserialize<'de> + Serialize,
295 for<'de> Response: Deserialize<'de> + Serialize,
296{
297 fn new(
298 server_sender: IpcSender<Message<Request, Response>>,
299 running: &Arc<AtomicBool>,
300 total_channels: u64,
301 ) -> Self {
302 let new_id = rand::Rng::gen(&mut rand::thread_rng());
303 let (sender, receiver) =
304 channel().expect("Can't initialize a sender-receiver pair; shouldn't fail");
305 server_sender
306 .send(Message::Register {
307 client_id: new_id,
308 sender: sender.clone(),
309 })
310 .expect("Unable to register a client");
311 Client {
312 id: new_id,
313 sender: server_sender,
314 running: Arc::clone(running),
315 receiver,
316 total_channels,
317 }
318 }
319
320 pub fn total_channels(&self) -> u64 {
322 self.total_channels
323 }
324
325 #[allow(clippy::redundant_clone)]
327 pub fn make_request(&self, channel_id: u64, request: Request) -> Result<Response, Error> {
328 ensure!(self.running.load(Ordering::SeqCst), StoppedSendingRequest);
329 self.sender
330 .send(Message::Request {
331 channel_id,
332 request,
333 respond_to: self.id,
334 })
335 .context(SendingRequest { channel_id })?;
336 ensure!(
337 self.running.load(Ordering::SeqCst),
338 StoppedReceivingResponse
339 );
340 self.receiver
341 .recv()
342 .context(ReceivingResponse { channel_id })
343 }
344}
345
346pub struct Processor<Request, Response> {
348 receiver: crossbeam_channel::Receiver<InternalRequest<Request, Response>>,
349}
350
/// Outcome of an idle ("loafing") step.
///
/// NOTE(review): nothing in this file uses this type; `Processor::run_loop` matches on
/// `labor::LoafingResult` instead. It presumably mirrors that enum (TODO confirm) and is
/// kept for API compatibility.
#[derive(Clone, Copy, Debug)]
pub enum LoaferResult {
    /// Presumably: no idle work is left, so waiting for requests may block.
    ImDone,

    /// Presumably: more idle work remains, so this should be called again.
    CallMeAgain,
}
360
361fn maybe_message<T>(
362 rcv: &crossbeam_channel::Receiver<T>,
363) -> Result<Option<T>, crossbeam_channel::RecvError>
364where
365 for<'de> T: Deserialize<'de> + Serialize,
366{
367 match rcv.try_recv() {
368 Ok(item) => Ok(Some(item)),
369 Err(e) => match e {
370 crossbeam_channel::TryRecvError::Empty => Ok(None),
371 crossbeam_channel::TryRecvError::Disconnected => Err(crossbeam_channel::RecvError),
372 },
373 }
374}
375
376impl<Request, Response> Processor<Request, Response>
377where
378 for<'de> Request: Serialize + Deserialize<'de>,
379 for<'de> Response: Serialize + Deserialize<'de>,
380{
381 pub fn run_loop<P>(&self, mut proletarian: P) -> Result<(), Error>
386 where
387 P: labor::Proletarian<Request, Response>,
388 {
389 let mut should_block = false;
390 loop {
391 let item = if should_block {
392 self.receiver.recv().context(ReceivingRequest)?
393 } else if let Some(item) = maybe_message(&self.receiver).context(ReceivingRequest)? {
394 item
395 } else {
396 match proletarian.loaf() {
397 labor::LoafingResult::ImDone => {
398 should_block = true;
399 continue;
400 }
401 labor::LoafingResult::TouchMeAgain => {
402 should_block = false;
403 continue;
404 }
405 }
406 };
407 should_block = false;
408 match item {
409 InternalRequest::Quit => break Ok(()),
410 InternalRequest::Normal {
411 request,
412 respond_to,
413 respond_channel,
414 } => {
415 let response = proletarian.process_request(request);
416 if let Err(e) = respond_channel.send(response).context(SendingResponse {
417 client_id: respond_to,
418 }) {
419 log::error!("Unable to send a response: {}", e);
421 }
422 }
423 }
424 }
425 }
426}
427
428#[must_use = "One must call process requests in order for the communication to run"]
430pub struct Processors<Request, Response> {
431 pub processors: Vec<Processor<Request, Response>>,
433
434 pub router: Router<Request, Response>,
436
437 handle: ProcessorsHandle<Request, Response>,
439}
440
441pub struct ProcessorsHandle<Request, Response> {
443 sender: IpcSender<Message<Request, Response>>,
444 running: Arc<AtomicBool>,
445}
446
447impl<Request, Response> Clone for ProcessorsHandle<Request, Response>
448where
449 for<'de> Request: Deserialize<'de> + Serialize,
450 for<'de> Response: Deserialize<'de> + Serialize,
451{
452 fn clone(&self) -> Self {
453 ProcessorsHandle {
454 sender: self.sender.clone(),
455 running: self.running.clone(),
456 }
457 }
458}
459
460impl<Request, Response> ProcessorsHandle<Request, Response>
461where
462 for<'de> Request: Deserialize<'de> + Serialize,
463 for<'de> Response: Deserialize<'de> + Serialize,
464{
465 pub fn stop(&self) -> Result<(), Error> {
467 self.running.store(false, Ordering::SeqCst);
468 let _ = self.sender.send(Message::Quit);
469 Ok(())
470 }
471}
472
473#[derive(Snafu, Debug)]
475pub enum ParallelRunError {
476 #[snafu(display("Thread {:?} panicked: {}", thread_name, source))]
478 ThreadPanic {
479 thread_name: String,
481
482 #[snafu(source(from(Box<dyn Any + Send + 'static>, PanicError::new)))]
484 source: PanicError,
485 },
486
487 #[snafu(display("Non-joined thread panicked: {}", source))]
489 UnjoinedThreadPanic {
490 #[snafu(source(from(Box<dyn Any + Send + 'static>, PanicError::new)))]
492 source: PanicError,
493 },
494
495 #[snafu(display("Thread {:?} terminated with error: {}", thread_name, source))]
497 IpcError {
498 thread_name: String,
500
501 source: Error,
503 },
504
505 #[snafu(display(
507 "Failed to spawn a thread for processing channel #{}: {}",
508 channel_id,
509 source
510 ))]
511 SpawnError {
512 channel_id: usize,
514
515 source: io::Error,
517 },
518
519 #[snafu(display("Failed to spawn a thread for router: {}", source))]
521 RouterSpawn {
522 source: io::Error,
524 },
525}
526
527impl<Request, Response> Processors<Request, Response>
528where
529 for<'de> Request: Serialize + Deserialize<'de> + Send,
530 for<'de> Response: Serialize + Deserialize<'de> + Send,
531{
532 pub fn run_in_parallel<S>(self, socium: S) -> Result<Vec<ParallelRunError>, ParallelRunError>
534 where
535 S: labor::Socium<Request, Response> + Sync,
536 S::Proletarian: labor::Proletarian<Request, Response>,
537 {
538 let res = scope(|s| {
539 let (tx, rx) = unbounded::<()>();
540 let router_handler = {
542 let tx = tx.clone();
543 let router = self.router;
544 s.builder()
545 .name("Router".to_string())
546 .spawn(move |_| {
547 let _guard = ThreadGuard { sender: tx };
548 router.route()
549 })
550 .context(RouterSpawn)
551 };
552 let handlers = self
553 .processors
554 .into_iter()
555 .enumerate()
556 .map(|(channel_id, processor)| {
557 let name = format!("Channel #{}", channel_id);
558 let socium = &socium;
559 let tx = tx.clone();
560 s.builder()
561 .name(name)
562 .spawn(move |_| {
563 let _guard = ThreadGuard { sender: tx };
564 let prolet = socium.construct_proletarian(channel_id);
565 processor.run_loop(prolet)
566 })
567 .context(SpawnError { channel_id })
568 })
569 .chain(std::iter::once(router_handler))
570 .collect::<Result<Vec<_>, _>>()?;
571
572 let _ = rx.recv();
574 let _ = self.handle.stop();
575
576 let join_errors: Vec<_> = handlers
577 .into_iter()
578 .map(|handler| {
579 let thread_name = handler
580 .thread()
581 .name()
582 .unwrap_or("[unknown thread]")
583 .to_string();
584 let thread_name = &thread_name;
585 handler
586 .join()
587 .context(ThreadPanic { thread_name })?
588 .context(IpcError { thread_name })
589 })
590 .filter_map(|res| match res {
591 Ok(()) => None,
592 Err(e) => Some(e),
593 })
594 .collect();
595 Ok(join_errors)
596 })
597 .context(UnjoinedThreadPanic)??;
598 Ok(res)
599 }
600}
601
602pub struct Communication<Request, Response>
604where
605 Request: Serialize,
606 Response: Serialize,
607{
608 pub client_builder: ClientBuilder<Request, Response>,
610
611 pub processors: Processors<Request, Response>,
613
614 pub handle: ProcessorsHandle<Request, Response>,
616}
617
618pub fn communication<Request, Response>(
625 channels: usize,
626) -> Result<Communication<Request, Response>, Error>
627where
628 for<'de> Request: Deserialize<'de> + Serialize,
629 for<'de> Response: Deserialize<'de> + Serialize,
630{
631 let mut processors = Vec::with_capacity(channels);
632 let mut senders = Vec::with_capacity(channels);
633
634 let (ipc_sender, ipc_receiver) = ipc_channel::ipc::channel::<Message<Request, Response>>()
635 .context(MainChannelInit { channels })?;
636
637 for _channel_id in 0..channels {
638 let (sender, receiver) = unbounded();
639 processors.push(Processor { receiver });
640 senders.push(sender);
641 }
642
643 let running = Arc::new(AtomicBool::new(true));
644 let handle = ProcessorsHandle {
645 sender: ipc_sender.clone(),
646 running: Arc::clone(&running),
647 };
648 let client_builder = ClientBuilder {
649 sender: ipc_sender,
650 running,
651 total_channels: channels as u64,
652 };
653 let router = Router {
654 channels: senders,
655 ipc_receiver,
656 };
657 let processors = Processors {
658 processors,
659 handle: handle.clone(),
660 router,
661 };
662 Ok(Communication {
663 client_builder,
664 processors,
665 handle,
666 })
667}
668
669pub struct Router<Request, Response> {
671 ipc_receiver: IpcReceiver<Message<Request, Response>>,
672 channels: Vec<Sender<InternalRequest<Request, Response>>>,
673}
674
675impl<Request, Response> Router<Request, Response>
676where
677 for<'de> Request: Deserialize<'de> + Serialize,
678 for<'de> Response: Deserialize<'de> + Serialize,
679{
680 pub fn route(&self) -> Result<(), Error> {
682 let mut clients = HashMap::<u64, IpcSender<Response>>::new();
683
684 loop {
685 match self.ipc_receiver.recv().context(RouterReceive)? {
686 Message::Quit => {
687 for snd in &self.channels {
688 let _ = snd.send(InternalRequest::Quit);
689 }
690 break;
691 }
692 Message::Unregister { client_id } => {
693 if clients.remove(&client_id).is_none() {
694 log::error!("Client #{} wasn't registered!", client_id);
695 }
696 }
697 Message::Register { client_id, sender } => {
698 if clients.insert(client_id, sender).is_some() {
699 log::error!("A client #{} was alreay registered!", client_id);
700 }
701 }
702 Message::Request {
703 channel_id,
704 request,
705 respond_to,
706 } => {
707 if let Some(respond_channel) = clients.get(&respond_to) {
708 if let Some(channel) = self.channels.get(channel_id as usize) {
709 channel
710 .send(InternalRequest::Normal {
711 request,
712 respond_to,
713 respond_channel: respond_channel.clone(),
714 })
715 .ok()
716 .context(RouterSend { channel_id })?;
717 } else {
718 log::error!(
719 "Received a request from a client #{} on an unknown channel #{}",
720 respond_to,
721 channel_id
722 );
723 }
724 } else {
725 log::error!("Received a request from an unknown client #{}", respond_to);
726 }
727 }
728 }
729 }
730 Ok(())
731 }
732}
733
#[cfg(test)]
mod test {
    use super::*;
    use rand::{distributions::Standard, prelude::*};

    /// End-to-end check: many client threads send random byte vectors to random channels,
    /// and every processor answers with the vector's length. After a clean stop, no thread
    /// may report an error.
    #[test]
    fn check() {
        const CHANNELS: usize = 4;
        const MAX_LEN: usize = 1024;
        const CLIENT_THREADS: usize = 100;
        const MESSAGES_PER_CLIENT: usize = 100;

        let Communication {
            client_builder,
            processors,
            handle,
        } = communication::<Vec<u8>, _>(CHANNELS).unwrap();

        let processors = std::thread::spawn(move || {
            processors
                .run_in_parallel(|_channel_id| |v: Vec<_>| v.len())
                .unwrap()
        });
        scope(|s| {
            for _ in 0..CLIENT_THREADS {
                let client_builder = client_builder.clone();
                s.spawn(move |_| {
                    let mut rng = thread_rng();
                    for _ in 0..MESSAGES_PER_CLIENT {
                        let channel_id = rng.gen_range(0, CHANNELS as u64);
                        let length = rng.gen_range(0, MAX_LEN);
                        let data = rng.sample_iter(Standard).take(length).collect();

                        let client = client_builder.build();
                        let response = client.make_request(channel_id, data).unwrap();
                        assert_eq!(response, length);
                    }
                });
            }
        })
        .unwrap();
        handle.stop().unwrap();
        // `run_in_parallel` reports per-thread failures in its `Ok` value; a clean shutdown
        // must not produce any.
        let errors = processors.join().unwrap();
        assert!(errors.is_empty(), "threads reported errors: {:?}", errors);
    }
}
778}