privsep/imsg.rs

1//! Internal message handling between privilege-separated processes.
2
3use crate::net::{AncillaryData, Fd, SocketAncillary, UnixStream, UnixStreamExt};
4use bytes::{BufMut, BytesMut};
5use derive_more::Into;
6use nix::unistd::{close, getpid};
7use parking_lot::Mutex;
8use serde::{de::DeserializeOwned, Serialize};
9use std::{
10    convert::TryFrom,
11    io::{self, Result},
12    mem,
13    os::unix::io::{AsRawFd, IntoRawFd, RawFd},
14    slice,
15    sync::atomic::{AtomicBool, Ordering},
16};
17use zerocopy::{AsBytes, FromBytes};
18
/// `imsg` handler: one end of a UNIX socket pair over which framed,
/// bincode-encoded messages (optionally carrying a file descriptor) are
/// exchanged.
#[derive(Debug, Into)]
pub struct Handler {
    /// Async half of a UNIX socketpair.
    socket: UnixStream,
    /// Set by `Handler::shutdown`; once set, sending and receiving fail
    /// with `NotConnected`.
    shutdown: AtomicBool,
    /// Bytes read from the socket that have not yet been returned as part
    /// of a complete message.
    read_buffer: Mutex<BytesMut>,
}
29
30impl From<UnixStream> for Handler {
31    fn from(socket: UnixStream) -> Self {
32        Self {
33            socket,
34            shutdown: Default::default(),
35            read_buffer: Mutex::new(BytesMut::with_capacity(Self::BUFFER_LENGTH)),
36        }
37    }
38}
39
40impl Handler {
41    pub const BUFFER_LENGTH: usize = 0xffff;
42
43    /// Create new handler pair.
44    pub fn pair() -> Result<(Self, Self)> {
45        UnixStream::pair().map(|(a, b)| (a.into(), b.into()))
46    }
47
48    pub fn socketpair() -> Result<(Fd, Fd)> {
49        let (a, b) = Self::pair()?;
50        let fd_a = Fd::from(a.as_raw_fd());
51        let fd_b = Fd::from(b.as_raw_fd());
52        mem::forget(a);
53        mem::forget(b);
54        Ok((fd_a, fd_b))
55    }
56
57    /// Create half of a handler pair from a file descriptor.
58    pub fn from_raw_fd<T: IntoRawFd>(fd: T) -> Result<Handler> {
59        let fd = fd.into_raw_fd();
60        unsafe { UnixStream::from_raw_fd(fd).map(Into::into) }
61    }
62
63    /// Send message to remote end.
64    pub async fn send_message<T: Serialize>(
65        &self,
66        message: Message,
67        fd: Option<&Fd>,
68        data: &T,
69    ) -> Result<()> {
70        if message.id < Message::RESERVED {
71            return Err(io::Error::new(io::ErrorKind::Other, "Reserved message ID"));
72        }
73        self.send_message_internal(message, fd, data).await
74    }
75
76    /// Send message to the remote end.
77    pub(crate) async fn send_message_internal<T: Serialize>(
78        &self,
79        mut message: Message,
80        fd: Option<&Fd>,
81        data: &T,
82    ) -> Result<()> {
83        if self.shutdown.load(Ordering::SeqCst) {
84            return Err(io::Error::new(
85                io::ErrorKind::NotConnected,
86                "Handler is closed",
87            ));
88        }
89        let data = bincode::serialize(data)
90            .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
91        message.pid = getpid().as_raw();
92        message.length = u16::try_from(data.len() + message.length as usize)
93            .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
94        let message_length = message.length as usize;
95        let iovs = [
96            io::IoSlice::new(message.as_bytes()),
97            io::IoSlice::new(&data),
98        ];
99        let bufs = if data.is_empty() {
100            &iovs[..1]
101        } else {
102            &iovs[..]
103        };
104
105        let mut ancillary_buffer = [0; 128];
106        let mut ancillary = SocketAncillary::new(&mut ancillary_buffer[..]);
107        if let Some(fd) = fd {
108            if !ancillary.add_fds(&[fd.as_raw_fd()]) {
109                return Err(io::Error::new(io::ErrorKind::Other, "failed to add fd"));
110            }
111        }
112
113        let length = self
114            .socket
115            .send_vectored_with_ancillary(bufs, &mut ancillary)
116            .await?;
117
118        if length != message_length {
119            return Err(io::Error::new(io::ErrorKind::WriteZero, "short message"));
120        }
121
122        Ok(())
123    }
124
125    /// Receive message from the remote end.
126    pub async fn recv_message<T: DeserializeOwned>(
127        &self,
128    ) -> Result<Option<(Message, Option<Fd>, T)>> {
129        if self.shutdown.load(Ordering::SeqCst) {
130            return Err(io::Error::new(
131                io::ErrorKind::NotConnected,
132                "Handler is closed",
133            ));
134        }
135
136        let mut fd_result = None;
137        let mut message = Message::default();
138        let mut message_length: usize;
139
140        let received_buf = loop {
141            let mut buf = self.read_buffer.lock();
142
143            if buf.len() >= Message::HEADER_LENGTH {
144                message
145                    .as_bytes_mut()
146                    .copy_from_slice(&buf[..Message::HEADER_LENGTH]);
147                message_length = message.length as usize;
148
149                // We have a complete message, break out of the loop.
150                if buf.len() >= message_length {
151                    break buf.split_to(message_length);
152                }
153            }
154
155            let mut ancillary_buffer = [0u8; 128];
156            let mut ancillary = SocketAncillary::new(&mut ancillary_buffer[..]);
157
158            buf.reserve(Self::BUFFER_LENGTH);
159            let slice = unsafe {
160                slice::from_raw_parts_mut(buf.chunk_mut().as_mut_ptr(), Self::BUFFER_LENGTH)
161            };
162            let bufs = &mut [io::IoSliceMut::new(slice)][..];
163
164            // Read more data.  This is also our yield point in the loop.
165            let length = self
166                .socket
167                .recv_vectored_with_ancillary(bufs, &mut ancillary)
168                .await?;
169            if length == 0 {
170                return Ok(None);
171            }
172            unsafe { buf.advance_mut(length) };
173
174            for ancillary_result in ancillary.messages().flatten() {
175                #[allow(irrefutable_let_patterns)]
176                if let AncillaryData::ScmRights(scm_rights) = ancillary_result {
177                    for fd in scm_rights {
178                        let fd = Fd::from(fd);
179
180                        // We only return one fd per message and auto-
181                        // close all the remaining ones once the `Fd`
182                        // is dropped.
183                        if fd_result.is_none() {
184                            fd_result = Some(fd);
185                        }
186                    }
187                }
188            }
189        };
190
191        let result = if message_length > Message::HEADER_LENGTH {
192            bincode::deserialize(&received_buf[Message::HEADER_LENGTH..message_length])
193        } else {
194            bincode::deserialize(&[])
195        }
196        .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
197
198        Ok(Some((message, fd_result, result)))
199    }
200
201    /// Forcefully close the imsg handler without dropping it.
202    pub fn shutdown(&self) {
203        let fd = self.as_raw_fd();
204        let _ = close(fd);
205        self.shutdown.store(true, Ordering::SeqCst);
206    }
207}
208
209impl AsRawFd for Handler {
210    fn as_raw_fd(&self) -> RawFd {
211        self.socket.as_raw_fd()
212    }
213}
214
/// Internal message header.
///
/// The header is sent as its raw in-memory bytes (`repr(C)`, host byte
/// order) ahead of the payload. Field order and field types therefore define
/// the wire format, and both ends must use the same definition.
#[derive(Debug, AsBytes, FromBytes, Default)]
#[repr(C)]
pub struct Message {
    /// Request type; IDs below `Message::RESERVED` are rejected by
    /// `Handler::send_message`.
    pub id: u32,
    /// Total message length in bytes (header + payload).
    pub length: u16,
    /// Optional flags.
    pub flags: u16,
    /// Optional peer ID.
    pub peer_id: u32,
    /// PID of the sending process (overwritten when the message is sent).
    pub pid: libc::pid_t,
}
230
231impl Message {
232    /// Reserved IDs 0-10
233    pub const RESERVED: u32 = 10;
234
235    /// Message header length.
236    pub const HEADER_LENGTH: usize = mem::size_of::<Self>();
237
238    /// Create new message header.
239    pub fn new<T: Into<u32>>(id: T) -> Self {
240        let length = Self::HEADER_LENGTH as u16;
241        Message {
242            id: id.into(),
243            pid: getpid().as_raw(),
244            length,
245            ..Default::default()
246        }
247    }
248
249    pub fn min() -> Self {
250        Self::RESERVED.into()
251    }
252
253    pub fn connect(peer_id: usize) -> Self {
254        Self {
255            peer_id: peer_id as u32,
256            ..Self::new(1u32)
257        }
258    }
259}
260
261impl<T: Into<u32>> From<T> for Message {
262    fn from(id: T) -> Self {
263        Message::new(id)
264    }
265}
266
#[cfg(test)]
mod tests {
    /// Serializing `()` must yield no bytes, so unit payloads go out as
    /// header-only messages.
    #[test]
    fn test_empty_data() {
        let encoded = bincode::serialize(&()).unwrap();
        assert_eq!(encoded.len(), 0);
    }
}