retty_io/
broadcast.rs

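//! Thread-safe channel whose receivers implement `Evented`: every send wakes all registered receivers (a broadcast wakeup), while each value is still delivered to only one receiver.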
2use super::{Evented, Poll, PollOpt, Ready, Registration, SetReadiness, Token};
3use std::collections::HashMap;
4use std::sync::atomic::{AtomicUsize, Ordering};
5use std::{
6    io,
7    sync::{Arc, Mutex},
8};
9
10/// Create a pair of the [`Sender`] and the [`Receiver`].
11///
12/// The [`Receiver`] implements the [`Evented`] so that it can be registered
13/// with the [`Poll`], while the [`Sender`] doesn't.
14pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
15    let (tx, rx) = crossbeam::channel::unbounded();
16
17    let waker = Arc::new(Mutex::new(HashMap::new()));
18
19    (
20        Sender {
21            waker: waker.clone(),
22            tx,
23        },
24        Receiver {
25            waker,
26            rx,
27            id: 0,
28            next_id: Arc::new(AtomicUsize::new(1)),
29        },
30    )
31}
32
33/// A wrapper of the [`crossbeam::channel::Receiver`].
34///
35/// It implements the [`Evented`] so that it can be registered with the [`Poll`].
36/// It ignores the [`Ready`] and always cause readable events.
37pub struct Receiver<T> {
38    waker: Arc<Mutex<HashMap<usize, (Registration, SetReadiness)>>>,
39    rx: crossbeam::channel::Receiver<T>,
40    id: usize,
41    next_id: Arc<AtomicUsize>,
42}
43
44impl<T> Clone for Receiver<T> {
45    fn clone(&self) -> Self {
46        let next_id = self.next_id.clone();
47        let id = next_id.fetch_add(1, Ordering::Relaxed);
48        Self {
49            waker: self.waker.clone(),
50            rx: self.rx.clone(),
51            id,
52            next_id,
53        }
54    }
55}
56
57impl<T> Receiver<T> {
58    /// Try to receive a value. It works just like [`crossbeam::channel::Receiver::try_recv`].
59    pub fn try_recv(&self) -> Result<T, crossbeam::channel::TryRecvError> {
60        self.rx.try_recv()
61    }
62}
63
64impl<T> Evented for Receiver<T> {
65    fn register(
66        &self,
67        poll: &Poll,
68        token: Token,
69        interest: Ready,
70        opts: PollOpt,
71    ) -> io::Result<()> {
72        let (registration, set_readiness) = Registration::new2();
73        poll.register(&registration, token, interest, opts)?;
74
75        let mut waker_map = self.waker.lock().unwrap();
76        if let std::collections::hash_map::Entry::Vacant(e) = waker_map.entry(self.id) {
77            e.insert((registration, set_readiness));
78            Ok(())
79        } else {
80            Err(io::Error::new(
81                io::ErrorKind::Other,
82                "receiver already registered",
83            ))
84        }
85    }
86
87    fn reregister(
88        &self,
89        poll: &Poll,
90        token: Token,
91        interest: Ready,
92        opts: PollOpt,
93    ) -> io::Result<()> {
94        let waker_map = self.waker.lock().unwrap();
95        if let Some((registration, _set_readiness)) = waker_map.get(&self.id) {
96            poll.reregister(registration, token, interest, opts)
97        } else {
98            Err(io::Error::new(
99                io::ErrorKind::Other,
100                "receiver not registered",
101            ))
102        }
103    }
104
105    fn deregister(&self, poll: &Poll) -> io::Result<()> {
106        let waker_map = self.waker.lock().unwrap();
107        if let Some((registration, _set_readiness)) = waker_map.get(&self.id) {
108            poll.deregister(registration)
109        } else {
110            Err(io::Error::new(
111                io::ErrorKind::Other,
112                "receiver not registered",
113            ))
114        }
115    }
116}
117
118/// A wrapper of the [`crossbeam::channel::Sender`].
119pub struct Sender<T> {
120    waker: Arc<Mutex<HashMap<usize, (Registration, SetReadiness)>>>,
121    tx: crossbeam::channel::Sender<T>,
122}
123
124impl<T> Sender<T> {
125    /// Try to send a value. It works just like [`crossbeam::channel::Sender::send`].
126    /// After sending it, it's waking up the [`Poll`].
127    ///
128    /// Note that it does not return any I/O error even if it occurs
129    /// when waking up the [`Poll`].
130    pub fn send(&self, t: T) -> Result<(), crossbeam::channel::SendError<T>> {
131        self.tx.send(t)?;
132
133        let mut waker_map = self.waker.lock().unwrap();
134        for (_registration, set_readiness) in waker_map.values_mut() {
135            let _ = set_readiness.set_readiness(Ready::readable());
136        }
137
138        Ok(())
139    }
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Events;
    use crossbeam::sync::WaitGroup;

    /// Two cloned receivers, each registered with its own `Poll`, must both be
    /// woken by a single send, and exactly one of them must receive the value.
    #[test]
    fn test_channel() -> Result<(), Box<dyn std::error::Error>> {
        const CHANNEL: Token = Token(0);
        const RECEIVERS: usize = 2;

        let (tx, rx) = channel();
        // Number of receiver threads whose `try_recv` returned the value.
        let received = Arc::new(AtomicUsize::new(0));
        // Each thread drops its clone once it is registered. Sending only after
        // `wait()` returns replaces a fixed sleep, which was slow and racy:
        // a registration made after the send would never be woken.
        let registered = WaitGroup::new();

        let mut handles = Vec::with_capacity(RECEIVERS);
        for i in 0..RECEIVERS {
            let rx = rx.clone();
            let registered = registered.clone();
            let received = Arc::clone(&received);
            handles.push(std::thread::spawn(move || -> std::io::Result<()> {
                let poll = Poll::new()?;
                let mut events = Events::with_capacity(2);
                poll.register(&rx, CHANNEL, Ready::readable(), PollOpt::edge())?;
                drop(registered);

                // `poll` may return without events, so keep waiting until the
                // channel's wakeup is actually observed.
                loop {
                    poll.poll(&mut events, None)?;
                    if events.iter().any(|event| event.token() == CHANNEL) {
                        break;
                    }
                }

                println!("receive CHANNEL {}", i);
                if rx.try_recv().is_ok() {
                    received.fetch_add(1, Ordering::SeqCst);
                }
                Ok(())
            }));
        }

        registered.wait();
        tx.send("Hello world!")?;

        for handle in handles {
            handle.join().expect("receiver thread panicked")?;
        }
        // All clones share one crossbeam queue, so the value is consumed once.
        assert_eq!(received.load(Ordering::SeqCst), 1);

        Ok(())
    }
}