ipc_communication/
lib.rs

1//! IPC communication helper crate.
2
3#![deny(missing_docs)]
4
use crossbeam_channel::{unbounded, Sender};
use crossbeam_utils::thread::scope;
use ipc_channel::ipc::{channel, IpcReceiver, IpcSender};
use serde::{Deserialize, Serialize};
use snafu::{ensure, OptionExt, ResultExt, Snafu};
use std::{
    any::Any,
    collections::HashMap,
    convert::TryFrom,
    io,
    sync::{
        atomic::{AtomicBool, Ordering},
        Arc,
    },
};
19
20pub mod ipc_error;
21pub mod labor;
22mod panic_error;
23
24use ipc_error::IpcErrorWrapper;
25pub use panic_error::PanicError;
26
27#[derive(Clone, Debug)]
28struct ThreadGuard {
29    sender: Sender<()>,
30}
31
32impl Drop for ThreadGuard {
33    fn drop(&mut self) {
34        let _ = self.sender.send(());
35    }
36}
37
38/// IPC communication error.
39#[derive(Snafu, Debug)]
40pub enum Error {
41    /// Unable to initialize communication channel.
42    #[snafu(display("Unable to initialize communication channel for {} channels", channels))]
43    MainChannelInit {
44        /// Source IPC error.
45        source: io::Error,
46
47        /// Channels amount.
48        channels: usize,
49    },
50
51    /// Unable to initialize channels.
52    #[snafu(display(
53        "Unable to initialize {} channels (at channel #{}): {}",
54        channels,
55        channel_id,
56        source,
57    ))]
58    ChannelsInit {
59        /// Source I/O error.
60        source: io::Error,
61
62        /// Channel ID.
63        channel_id: usize,
64
65        /// Channels amount.
66        channels: usize,
67    },
68
69    /// Channel not found while making a request.
70    #[snafu(display(
71        "Can't make request, since there is no channel #{} ({} total channels)",
72        channel_id,
73        channels
74    ))]
75    ChannelNotFound {
76        /// Channel ID.
77        channel_id: usize,
78
79        /// Channels amount.
80        channels: usize,
81    },
82
83    /// Unable to initialize a channel for a response.
84    #[snafu(display(
85        "Unable to initialize a channel for a response while working on channel #{}: {}",
86        channel_id,
87        source
88    ))]
89    ResponseChannelInit {
90        /// Source I/O error.
91        source: io::Error,
92
93        /// Channel ID.
94        channel_id: usize,
95    },
96
97    /// Unable to initialize a quit confirmation channel.
98    #[snafu(display("Unable to initialize a quit confirmation channel: {}", source))]
99    QuitChannelInit {
100        /// Source I/O error,
101        source: io::Error,
102    },
103
104    /// Unable to send a request on a channel.
105    #[snafu(display("Unable to send a request on a channel #{}: {}", channel_id, source))]
106    SendingRequest {
107        /// Source IPC error.
108        source: ipc_channel::Error,
109
110        /// Channel ID.
111        channel_id: u64,
112    },
113
114    /// Unable to receiver a response on a channel.
115    #[snafu(display(
116        "Unable to receiver a response on a channel #{}: {}",
117        channel_id,
118        source
119    ))]
120    ReceivingResponse {
121        /// Source IPC error.
122        #[snafu(source(from(ipc_channel::ipc::IpcError, From::from)))]
123        source: IpcErrorWrapper,
124
125        /// Channel ID.
126        channel_id: u64,
127    },
128
129    /// Unable to receive a request on a channel.
130    #[snafu(display("Unable to receive a request on a channel: {}", source))]
131    ReceivingRequest {
132        /// Source IPC error.
133        source: crossbeam_channel::RecvError,
134    },
135
136    /// Unable to send a response on a channel.
137    #[snafu(display("Unable to send a response to client {}: {}", client_id, source))]
138    SendingResponse {
139        /// Client's ID.
140        client_id: u64,
141
142        /// IPC error.
143        source: ipc_channel::Error,
144    },
145
146    /// Unable to send a request because a system has stopped.
147    #[snafu(display("Unable to send a request because a system has stopped"))]
148    StoppedSendingRequest,
149
150    /// Unable to receive a response because a system has stopped.
151    #[snafu(display("Unable to receive a response because a system has stopped"))]
152    StoppedReceivingResponse,
153
154    /// Error while receiving a message on a global IPC channel.
155    #[snafu(display("Error while receiving a message on a global IPC channel: {}", source))]
156    RouterReceive {
157        /// Source IPC error.
158        #[snafu(source(from(ipc_channel::ipc::IpcError, From::from)))]
159        source: IpcErrorWrapper,
160    },
161
162    /// Unable to send a request to a processor.
163    #[snafu(display("Unable to send a request to a processor on channel #{}", channel_id))]
164    RouterSend {
165        /// Channel id.
166        channel_id: u64,
167    },
168}
169
170impl Error {
171    /// Checks if the other end has terminated.
172    pub fn is_disconnected(&self) -> bool {
173        self.ipc_error()
174            .map(IpcErrorWrapper::is_disconnected)
175            .unwrap_or(false)
176    }
177
178    /// Checks if the Error happened because the system was stopped.
179    pub fn has_stopped(&self) -> bool {
180        match self {
181            Error::StoppedSendingRequest | Error::StoppedReceivingResponse => true,
182            _ => false,
183        }
184    }
185
186    /// Returns the underlying `ipc-channel` error, if any.
187    pub fn ipc_error(&self) -> Option<&IpcErrorWrapper> {
188        match self {
189            // Error::SendingRequest { source, .. }
190            Error::ReceivingResponse { source, .. } => Some(source),
191            // | Error::SendingResponse { source, .. } => Some(source),
192            _ => None,
193        }
194    }
195}
196#[derive(Serialize, Deserialize)]
197enum Message<Request, Response> {
198    Request {
199        channel_id: u64,
200        request: Request,
201        respond_to: u64,
202    },
203    Register {
204        client_id: u64,
205        sender: IpcSender<Response>,
206    },
207    Unregister {
208        client_id: u64,
209    },
210    Quit,
211}
212
213#[derive(Serialize, Deserialize)]
214enum InternalRequest<Request, Response> {
215    Normal {
216        request: Request,
217        respond_to: u64,
218        respond_channel: IpcSender<Response>,
219    },
220    Quit,
221}
222
223/// An immutable clients builder.
224pub struct ClientBuilder<Request, Response>
225where
226    Request: Serialize,
227    Response: Serialize,
228{
229    sender: IpcSender<Message<Request, Response>>,
230    total_channels: u64,
231    running: Arc<AtomicBool>,
232}
233
234impl<Request, Response> Clone for ClientBuilder<Request, Response>
235where
236    for<'de> Request: Deserialize<'de> + Serialize,
237    for<'de> Response: Deserialize<'de> + Serialize,
238{
239    fn clone(&self) -> Self {
240        Self {
241            sender: self.sender.clone(),
242            running: Arc::clone(&self.running),
243            total_channels: self.total_channels,
244        }
245    }
246}
247
248impl<Request, Response> ClientBuilder<Request, Response>
249where
250    for<'de> Request: Deserialize<'de> + Serialize,
251    for<'de> Response: Deserialize<'de> + Serialize,
252{
253    /// Builds a client.
254    pub fn build(&self) -> Client<Request, Response> {
255        Client::new(self.sender.clone(), &self.running, self.total_channels)
256    }
257}
258
259/// A "client" capable of sending requests to processors.
260pub struct Client<Request, Response>
261where
262    Request: Serialize,
263    Response: Serialize,
264{
265    id: u64,
266    total_channels: u64,
267    sender: IpcSender<Message<Request, Response>>,
268    receiver: IpcReceiver<Response>,
269    running: Arc<AtomicBool>,
270}
271
272impl<Request, Response> Drop for Client<Request, Response>
273where
274    Request: Serialize,
275    Response: Serialize,
276{
277    fn drop(&mut self) {
278        let _ = self.sender.send(Message::Unregister { client_id: self.id });
279    }
280}
281
282impl<Request, Response> Clone for Client<Request, Response>
283where
284    for<'de> Request: Deserialize<'de> + Serialize,
285    for<'de> Response: Deserialize<'de> + Serialize,
286{
287    fn clone(&self) -> Self {
288        Client::new(self.sender.clone(), &self.running, self.total_channels)
289    }
290}
291
292impl<Request, Response> Client<Request, Response>
293where
294    for<'de> Request: Deserialize<'de> + Serialize,
295    for<'de> Response: Deserialize<'de> + Serialize,
296{
297    fn new(
298        server_sender: IpcSender<Message<Request, Response>>,
299        running: &Arc<AtomicBool>,
300        total_channels: u64,
301    ) -> Self {
302        let new_id = rand::Rng::gen(&mut rand::thread_rng());
303        let (sender, receiver) =
304            channel().expect("Can't initialize a sender-receiver pair; shouldn't fail");
305        server_sender
306            .send(Message::Register {
307                client_id: new_id,
308                sender: sender.clone(),
309            })
310            .expect("Unable to register a client");
311        Client {
312            id: new_id,
313            sender: server_sender,
314            running: Arc::clone(running),
315            receiver,
316            total_channels,
317        }
318    }
319
320    /// Returns the amount of available channels.
321    pub fn total_channels(&self) -> u64 {
322        self.total_channels
323    }
324
325    /// Sends a request to a given channel id and waits for a response.
326    #[allow(clippy::redundant_clone)]
327    pub fn make_request(&self, channel_id: u64, request: Request) -> Result<Response, Error> {
328        ensure!(self.running.load(Ordering::SeqCst), StoppedSendingRequest);
329        self.sender
330            .send(Message::Request {
331                channel_id,
332                request,
333                respond_to: self.id,
334            })
335            .context(SendingRequest { channel_id })?;
336        ensure!(
337            self.running.load(Ordering::SeqCst),
338            StoppedReceivingResponse
339        );
340        self.receiver
341            .recv()
342            .context(ReceivingResponse { channel_id })
343    }
344}
345
346/// Requests processor.
347pub struct Processor<Request, Response> {
348    receiver: crossbeam_channel::Receiver<InternalRequest<Request, Response>>,
349}
350
/// What a "loafer" tells its caller once it has run.
///
/// NOTE(review): `Processor::run_loop` in this file matches on `labor::LoafingResult`, not on
/// this type — confirm whether this enum is still used anywhere.
#[derive(Copy, Clone, Debug)]
pub enum LoaferResult {
    /// Nothing left to do: the caller may block while waiting for data.
    ImDone,

    /// The loafer asks to be invoked again.
    CallMeAgain,
}
360
361fn maybe_message<T>(
362    rcv: &crossbeam_channel::Receiver<T>,
363) -> Result<Option<T>, crossbeam_channel::RecvError>
364where
365    for<'de> T: Deserialize<'de> + Serialize,
366{
367    match rcv.try_recv() {
368        Ok(item) => Ok(Some(item)),
369        Err(e) => match e {
370            crossbeam_channel::TryRecvError::Empty => Ok(None),
371            crossbeam_channel::TryRecvError::Disconnected => Err(crossbeam_channel::RecvError),
372        },
373    }
374}
375
376impl<Request, Response> Processor<Request, Response>
377where
378    for<'de> Request: Serialize + Deserialize<'de>,
379    for<'de> Response: Serialize + Deserialize<'de>,
380{
381    /// Runs infinitely, processing incoming request using a given closure and sending generated
382    /// responses back to the clients.
383    ///
384    /// The `loafer` is called every time there are no more messages in the queue.
385    pub fn run_loop<P>(&self, mut proletarian: P) -> Result<(), Error>
386    where
387        P: labor::Proletarian<Request, Response>,
388    {
389        let mut should_block = false;
390        loop {
391            let item = if should_block {
392                self.receiver.recv().context(ReceivingRequest)?
393            } else if let Some(item) = maybe_message(&self.receiver).context(ReceivingRequest)? {
394                item
395            } else {
396                match proletarian.loaf() {
397                    labor::LoafingResult::ImDone => {
398                        should_block = true;
399                        continue;
400                    }
401                    labor::LoafingResult::TouchMeAgain => {
402                        should_block = false;
403                        continue;
404                    }
405                }
406            };
407            should_block = false;
408            match item {
409                InternalRequest::Quit => break Ok(()),
410                InternalRequest::Normal {
411                    request,
412                    respond_to,
413                    respond_channel,
414                } => {
415                    let response = proletarian.process_request(request);
416                    if let Err(e) = respond_channel.send(response).context(SendingResponse {
417                        client_id: respond_to,
418                    }) {
419                        // Do not stop execution when sending a response fails.
420                        log::error!("Unable to send a response: {}", e);
421                    }
422                }
423            }
424        }
425    }
426}
427
428/// Request processors.
429#[must_use = "One must call process requests in order for the communication to run"]
430pub struct Processors<Request, Response> {
431    /// The underlying processors.
432    pub processors: Vec<Processor<Request, Response>>,
433
434    /// Requests router.
435    pub router: Router<Request, Response>,
436
437    /// Processors handle.
438    handle: ProcessorsHandle<Request, Response>,
439}
440
441/// Processors handler.
442pub struct ProcessorsHandle<Request, Response> {
443    sender: IpcSender<Message<Request, Response>>,
444    running: Arc<AtomicBool>,
445}
446
447impl<Request, Response> Clone for ProcessorsHandle<Request, Response>
448where
449    for<'de> Request: Deserialize<'de> + Serialize,
450    for<'de> Response: Deserialize<'de> + Serialize,
451{
452    fn clone(&self) -> Self {
453        ProcessorsHandle {
454            sender: self.sender.clone(),
455            running: self.running.clone(),
456        }
457    }
458}
459
460impl<Request, Response> ProcessorsHandle<Request, Response>
461where
462    for<'de> Request: Deserialize<'de> + Serialize,
463    for<'de> Response: Deserialize<'de> + Serialize,
464{
465    /// Sends a stop signal to all the running processors and waits for them to receive the signal.
466    pub fn stop(&self) -> Result<(), Error> {
467        self.running.store(false, Ordering::SeqCst);
468        let _ = self.sender.send(Message::Quit);
469        Ok(())
470    }
471}
472
473/// An error that happened during a parallel execution of processors.
474#[derive(Snafu, Debug)]
475pub enum ParallelRunError {
476    /// A thread has panicked.
477    #[snafu(display("Thread {:?} panicked: {}", thread_name, source))]
478    ThreadPanic {
479        /// Thread name.
480        thread_name: String,
481
482        /// Panic message.
483        #[snafu(source(from(Box<dyn Any + Send + 'static>, PanicError::new)))]
484        source: PanicError,
485    },
486
487    /// An unknown thread has panicked. Shouldn't happen.
488    #[snafu(display("Non-joined thread panicked: {}", source))]
489    UnjoinedThreadPanic {
490        /// Panic message.
491        #[snafu(source(from(Box<dyn Any + Send + 'static>, PanicError::new)))]
492        source: PanicError,
493    },
494
495    /// IPC communication error.
496    #[snafu(display("Thread {:?} terminated with error: {}", thread_name, source))]
497    IpcError {
498        /// Thread name.
499        thread_name: String,
500
501        /// IPC error.
502        source: Error,
503    },
504
505    /// Spawn of a processor failed.
506    #[snafu(display(
507        "Failed to spawn a thread for processing channel #{}: {}",
508        channel_id,
509        source
510    ))]
511    SpawnError {
512        /// Channel ID.
513        channel_id: usize,
514
515        /// Spawn I/O error.
516        source: io::Error,
517    },
518
519    /// Can't spawn a router thread.
520    #[snafu(display("Failed to spawn a thread for router: {}", source))]
521    RouterSpawn {
522        /// Spawn I/O error.
523        source: io::Error,
524    },
525}
526
527impl<Request, Response> Processors<Request, Response>
528where
529    for<'de> Request: Serialize + Deserialize<'de> + Send,
530    for<'de> Response: Serialize + Deserialize<'de> + Send,
531{
532    /// Runs all the underlying responses in separate thread each using given socium.
533    pub fn run_in_parallel<S>(self, socium: S) -> Result<Vec<ParallelRunError>, ParallelRunError>
534    where
535        S: labor::Socium<Request, Response> + Sync,
536        S::Proletarian: labor::Proletarian<Request, Response>,
537    {
538        let res = scope(|s| {
539            let (tx, rx) = unbounded::<()>();
540            // let router_handler = self.router.route().context(RouterError)?;
541            let router_handler = {
542                let tx = tx.clone();
543                let router = self.router;
544                s.builder()
545                    .name("Router".to_string())
546                    .spawn(move |_| {
547                        let _guard = ThreadGuard { sender: tx };
548                        router.route()
549                    })
550                    .context(RouterSpawn)
551            };
552            let handlers = self
553                .processors
554                .into_iter()
555                .enumerate()
556                .map(|(channel_id, processor)| {
557                    let name = format!("Channel #{}", channel_id);
558                    let socium = &socium;
559                    let tx = tx.clone();
560                    s.builder()
561                        .name(name)
562                        .spawn(move |_| {
563                            let _guard = ThreadGuard { sender: tx };
564                            let prolet = socium.construct_proletarian(channel_id);
565                            processor.run_loop(prolet)
566                        })
567                        .context(SpawnError { channel_id })
568                })
569                .chain(std::iter::once(router_handler))
570                .collect::<Result<Vec<_>, _>>()?;
571
572            // Wait for the first channel to end and then join 'em all!
573            let _ = rx.recv();
574            let _ = self.handle.stop();
575
576            let join_errors: Vec<_> = handlers
577                .into_iter()
578                .map(|handler| {
579                    let thread_name = handler
580                        .thread()
581                        .name()
582                        .unwrap_or("[unknown thread]")
583                        .to_string();
584                    let thread_name = &thread_name;
585                    handler
586                        .join()
587                        .context(ThreadPanic { thread_name })?
588                        .context(IpcError { thread_name })
589                })
590                .filter_map(|res| match res {
591                    Ok(()) => None,
592                    Err(e) => Some(e),
593                })
594                .collect();
595            Ok(join_errors)
596        })
597        .context(UnjoinedThreadPanic)??;
598        Ok(res)
599    }
600}
601
602/// A helper structure that contains communication objects.
603pub struct Communication<Request, Response>
604where
605    Request: Serialize,
606    Response: Serialize,
607{
608    /// A clonable client builder.
609    pub client_builder: ClientBuilder<Request, Response>,
610
611    /// A processors container.
612    pub processors: Processors<Request, Response>,
613
614    /// A processors control handle.
615    pub handle: ProcessorsHandle<Request, Response>,
616}
617
618/// Sets up communication channels for a given request-response pairs.
619///
620/// `Client` is clonable and can be sent across multiple threads or even processes.
621///
622/// `Processor`s, on the other hand, are not clonable and can not be accessed from
623/// multiple threads simultaneously, but still can be sent across threads and processes.
624pub fn communication<Request, Response>(
625    channels: usize,
626) -> Result<Communication<Request, Response>, Error>
627where
628    for<'de> Request: Deserialize<'de> + Serialize,
629    for<'de> Response: Deserialize<'de> + Serialize,
630{
631    let mut processors = Vec::with_capacity(channels);
632    let mut senders = Vec::with_capacity(channels);
633
634    let (ipc_sender, ipc_receiver) = ipc_channel::ipc::channel::<Message<Request, Response>>()
635        .context(MainChannelInit { channels })?;
636
637    for _channel_id in 0..channels {
638        let (sender, receiver) = unbounded();
639        processors.push(Processor { receiver });
640        senders.push(sender);
641    }
642
643    let running = Arc::new(AtomicBool::new(true));
644    let handle = ProcessorsHandle {
645        sender: ipc_sender.clone(),
646        running: Arc::clone(&running),
647    };
648    let client_builder = ClientBuilder {
649        sender: ipc_sender,
650        running,
651        total_channels: channels as u64,
652    };
653    let router = Router {
654        channels: senders,
655        ipc_receiver,
656    };
657    let processors = Processors {
658        processors,
659        handle: handle.clone(),
660        router,
661    };
662    Ok(Communication {
663        client_builder,
664        processors,
665        handle,
666    })
667}
668
669/// Routs requests from IPC to internal processors.
670pub struct Router<Request, Response> {
671    ipc_receiver: IpcReceiver<Message<Request, Response>>,
672    channels: Vec<Sender<InternalRequest<Request, Response>>>,
673}
674
675impl<Request, Response> Router<Request, Response>
676where
677    for<'de> Request: Deserialize<'de> + Serialize,
678    for<'de> Response: Deserialize<'de> + Serialize,
679{
680    /// Starts routing.
681    pub fn route(&self) -> Result<(), Error> {
682        let mut clients = HashMap::<u64, IpcSender<Response>>::new();
683
684        loop {
685            match self.ipc_receiver.recv().context(RouterReceive)? {
686                Message::Quit => {
687                    for snd in &self.channels {
688                        let _ = snd.send(InternalRequest::Quit);
689                    }
690                    break;
691                }
692                Message::Unregister { client_id } => {
693                    if clients.remove(&client_id).is_none() {
694                        log::error!("Client #{} wasn't registered!", client_id);
695                    }
696                }
697                Message::Register { client_id, sender } => {
698                    if clients.insert(client_id, sender).is_some() {
699                        log::error!("A client #{} was alreay registered!", client_id);
700                    }
701                }
702                Message::Request {
703                    channel_id,
704                    request,
705                    respond_to,
706                } => {
707                    if let Some(respond_channel) = clients.get(&respond_to) {
708                        if let Some(channel) = self.channels.get(channel_id as usize) {
709                            channel
710                                .send(InternalRequest::Normal {
711                                    request,
712                                    respond_to,
713                                    respond_channel: respond_channel.clone(),
714                                })
715                                .ok()
716                                .context(RouterSend { channel_id })?;
717                        } else {
718                            log::error!(
719                                "Received a request from a client #{} on an unknown channel #{}",
720                                respond_to,
721                                channel_id
722                            );
723                        }
724                    } else {
725                        log::error!("Received a request from an unknown client #{}", respond_to);
726                    }
727                }
728            }
729        }
730        Ok(())
731    }
732}
733
#[cfg(test)]
mod test {
    use super::*;
    use rand::{distributions::Standard, prelude::*};

    /// End-to-end check. Many client threads send random byte vectors to random channels; each
    /// response must equal the vector's length. After stopping, no processor or router thread may
    /// report an error.
    #[test]
    fn check() {
        const CHANNELS: usize = 4;
        const MAX_LEN: usize = 1024;
        const CLIENT_THREADS: usize = 100;
        const MESSAGES_PER_CLIENT: usize = 100;

        let Communication {
            client_builder,
            processors,
            handle,
        } = communication::<Vec<u8>, _>(CHANNELS).unwrap();

        let processors = std::thread::spawn(move || {
            processors
                .run_in_parallel(|_channel_id| |v: Vec<_>| v.len())
                .unwrap()
        });
        scope(|s| {
            for _ in 0..CLIENT_THREADS {
                let client_builder = client_builder.clone();
                s.spawn(move |_| {
                    let mut rng = thread_rng();
                    for _ in 0..MESSAGES_PER_CLIENT {
                        let channel_id = rng.gen_range(0, CHANNELS as u64);
                        let length = rng.gen_range(0, MAX_LEN);
                        let data = rng.sample_iter(Standard).take(length).collect();

                        let client = client_builder.build();
                        let response = client.make_request(channel_id, data).unwrap();
                        assert_eq!(response, length);
                    }
                });
            }
        })
        .unwrap();
        handle.stop().unwrap();
        // `run_in_parallel` reports failing threads in its `Ok` value; don't let them slip by.
        let errors = processors.join().unwrap();
        assert!(errors.is_empty(), "processors reported errors: {:?}", errors);
    }
}