// x11rb_async/rust_connection/shared_state.rs

//! The state of the connection that is shared with the reading future.
2
3use event_listener::Event;
4use futures_lite::future;
5use std::convert::Infallible;
6use std::io;
7use std::mem;
8use std::sync::{
9    atomic::{AtomicBool, Ordering},
10    Arc, Mutex as StdMutex, MutexGuard as StdMutexGuard,
11};
12use x11rb::errors::ConnectionError;
13use x11rb_protocol::connection::Connection as ProtoConnection;
14use x11rb_protocol::packet_reader::PacketReader as ProtoPacketReader;
15use x11rb_protocol::RawFdContainer;
16
17use super::Stream;
18
19/// State shared between the `RustConnection` and the future polling for new packets.
20#[derive(Debug)]
21pub(super) struct SharedState<S> {
22    /// The underlying connection manager.
23    ///
24    /// This is never held across an `.await` point, so it's fine to use a standard library mutex.
25    inner: StdMutex<ProtoConnection>,
26
27    /// The stream for communicating with the X11 server.
28    pub(super) stream: S,
29
30    /// Listener for when new data is available on the stream.
31    new_input: Event,
32
33    /// Flag that indicates that the future for drive() was dropped and we no longer read input.
34    driver_dropped: AtomicBool,
35}
36
37impl<S: Stream> SharedState<S> {
38    pub(super) fn new(stream: S) -> Self {
39        Self {
40            inner: Default::default(),
41            stream,
42            new_input: Event::new(),
43            driver_dropped: AtomicBool::new(false),
44        }
45    }
46
47    /// Lock the inner connection and return a mutex guard for it.
48    pub(super) fn lock_connection(&self) -> StdMutexGuard<'_, ProtoConnection> {
49        self.inner.lock().unwrap()
50    }
51
52    /// Wait for an incoming packet.
53    ///
54    /// The given function get_reply should check whether the needed package was already received
55    /// and put into the inner connection. It should return `None` if nothing is present yet and
56    /// new incoming X11 packets should be awaited.
57    pub(super) async fn wait_for_incoming<R, F>(&self, mut get_reply: F) -> Result<R, io::Error>
58    where
59        F: FnMut(&mut ProtoConnection) -> Option<R>,
60    {
61        loop {
62            // See if we can find the reply in the connection.
63            if let Some(reply) = get_reply(&mut self.lock_connection()) {
64                return Ok(reply);
65            }
66
67            // Register a listener for the reply.
68            let listener = self.new_input.listen();
69
70            // Maybe a packet was delivered while we were registering the listener.
71            if let Some(reply) = get_reply(&mut self.lock_connection()) {
72                return Ok(reply);
73            }
74
75            // Maybe the future from drive() was dropped?
76            // We only check this down here and not before the listener since this is unlikely
77            if self.driver_dropped.load(Ordering::SeqCst) {
78                return Err(io::Error::new(
79                    io::ErrorKind::Other,
80                    "Driving future was dropped",
81                ));
82            }
83
84            // Wait for the next packet.
85            listener.await;
86        }
87    }
88
89    /// Read incoming packets from the stream and put them into the inner connection.
90    pub(super) async fn drive(
91        &self,
92        _break_on_drop: BreakOnDrop<S>,
93    ) -> Result<Infallible, ConnectionError> {
94        let mut packet_reader = PacketReader {
95            read_buffer: vec![0; 4096].into_boxed_slice(),
96            inner: ProtoPacketReader::new(),
97        };
98        let mut fds = vec![];
99        let mut packets = vec![];
100
101        loop {
102            for _ in 0..50 {
103                // Try to read packets from the stream.
104                packet_reader.try_read_packets(&self.stream, &mut packets, &mut fds)?;
105                let packet_count = packets.len();
106
107                // Now, actually enqueue the packets.
108                {
109                    let mut inner = self.inner.lock().unwrap();
110                    inner.enqueue_fds(mem::take(&mut fds));
111                    packets
112                        .drain(..)
113                        .for_each(|packet| inner.enqueue_packet(packet));
114                }
115
116                if packet_count > 0 {
117                    // Notify any listeners that there is new data.
118                    let _num_notified = self.new_input.notify_additional(usize::MAX);
119                } else {
120                    // Wait for more data.
121                    self.stream.readable().await?;
122                }
123            }
124
125            // In the case of a large influx of packets, don't starve other tasks.
126            future::yield_now().await;
127        }
128    }
129}
130
131#[derive(Debug)]
132struct PacketReader {
133    /// The read buffer to store incoming bytes in.
134    read_buffer: Box<[u8]>,
135
136    /// The inner reader that breaks these bytes into packets.
137    inner: ProtoPacketReader,
138}
139
140impl PacketReader {
141    /// Try to read packets from the stream.
142    fn try_read_packets(
143        &mut self,
144        stream: &impl Stream,
145        out_packets: &mut Vec<Vec<u8>>,
146        fd_storage: &mut Vec<RawFdContainer>,
147    ) -> io::Result<()> {
148        let original_length = out_packets.len();
149        loop {
150            // If the necessary packet size is larger than our buffer, just fill straight
151            // into the buffer.
152            if self.inner.remaining_capacity() >= self.read_buffer.len() {
153                tracing::trace!(
154                    "Trying to read large packet with {} bytes remaining",
155                    self.inner.remaining_capacity()
156                );
157                match stream.read(self.inner.buffer(), fd_storage) {
158                    Ok(0) => {
159                        tracing::error!("Large read returned zero");
160                        return Err(io::Error::new(
161                            io::ErrorKind::UnexpectedEof,
162                            "The X11 server closed the connection",
163                        ));
164                    }
165                    Ok(n) => {
166                        tracing::trace!("Read {} bytes directly into large packet", n);
167                        if let Some(packet) = self.inner.advance(n) {
168                            out_packets.push(packet);
169                        }
170                    }
171                    Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => break,
172                    Err(e) => return Err(e),
173                }
174            } else {
175                // read into our buffer
176                let nread = match stream.read(&mut self.read_buffer, fd_storage) {
177                    Ok(0) => {
178                        tracing::error!("Buffer read returned zero");
179                        return Err(io::Error::new(
180                            io::ErrorKind::UnexpectedEof,
181                            "The X11 server closed the connection",
182                        ));
183                    }
184                    Ok(n) => n,
185                    Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => break,
186                    Err(e) => return Err(e),
187                };
188                tracing::trace!("Read {} bytes into read buffer", nread);
189
190                // begin reading that data into packets
191                let mut src = &self.read_buffer[..nread];
192                while !src.is_empty() {
193                    let dest = self.inner.buffer();
194                    let amt_to_read = std::cmp::min(src.len(), dest.len());
195
196                    dest[..amt_to_read].copy_from_slice(&src[..amt_to_read]);
197
198                    // reborrow src
199                    src = &src[amt_to_read..];
200
201                    // advance by the given amount
202                    if let Some(packet) = self.inner.advance(amt_to_read) {
203                        out_packets.push(packet);
204                    }
205                }
206            }
207        }
208        tracing::trace!(
209            "Read {} complete packet(s)",
210            out_packets.len() - original_length
211        );
212
213        Ok(())
214    }
215}
216
217#[derive(Debug)]
218pub(super) struct BreakOnDrop<S>(pub(super) Arc<SharedState<S>>);
219
220impl<S> Drop for BreakOnDrop<S> {
221    fn drop(&mut self) {
222        // Mark the connection as broken
223        self.0.driver_dropped.store(true, Ordering::SeqCst);
224
225        // Wake up everyone that might be waiting
226        let _num_notified = self.0.new_input.notify_additional(usize::MAX);
227    }
228}