async_rustbus/conn/
mod.rs

//! Low-level, non-blocking implementation of the DBus connection and some helper functions.
//!
//! These are used to create `RpcConn`, the primary async connection of this crate.
//! Most end-users of this library will never need to touch this module.
5
6use std::collections::{HashSet, VecDeque};
7use std::io::{IoSlice, IoSliceMut};
8use std::mem;
9use std::net::Shutdown;
10use std::num::NonZeroU32;
11use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
12use std::os::unix::net::UnixStream as StdUnixStream;
13
14use tokio::io::*;
15use tokio::io::{AsyncRead, AsyncWrite};
16
17use std::io::ErrorKind;
18use std::path::Path;
19use tokio::net::TcpStream;
20use tokio::net::ToSocketAddrs;
21use tokio::net::UnixStream;
22
23use super::rustbus_core;
24
25use rustbus_core::message_builder::MarshalledMessage;
26
27mod ancillary;
28
29mod addr;
30pub use addr::{get_session_bus_addr, get_system_bus_addr, DBusAddr, DBUS_SESS_ENV, DBUS_SYS_PATH};
31mod recv;
32use recv::InState;
33pub(crate) use recv::RecvState;
34
35mod sender;
36pub(crate) use sender::SendState;
37
38use ancillary::{
39    recv_vectored_with_ancillary, send_vectored_with_ancillary, AncillaryData, SocketAncillary,
40};
41
/// Line terminator of the D-Bus authentication protocol (`\r\n`).
const DBUS_LINE_END_STR: &str = "\r\n";
/// [`DBUS_LINE_END_STR`] as raw bytes, for matching against data read from the socket.
const DBUS_LINE_END: &[u8] = DBUS_LINE_END_STR.as_bytes();
/// Presumably the maximum number of file descriptors handled per message (inferred from the
/// name). NOTE(review): not used in this file; likely consumed by the `recv`/`sender`
/// submodules — confirm there.
const DBUS_MAX_FD_MESSAGE: usize = 32;
45
/// Owning wrapper around a raw socket descriptor used to exchange D-Bus messages.
///
/// Offers vectored send/receive with ancillary data (usually FDs) and socket shutdown.
/// The descriptor is closed when the `GenStream` is dropped.
pub(crate) struct GenStream {
    /// The owned socket descriptor.
    fd: RawFd,
}
52
53impl AsRawFd for GenStream {
54    fn as_raw_fd(&self) -> RawFd {
55        self.fd
56    }
57}
58impl FromRawFd for GenStream {
59    unsafe fn from_raw_fd(fd: RawFd) -> Self {
60        Self { fd }
61    }
62}
63
64impl GenStream {
65    fn recv_vectored_with_ancillary(
66        &self,
67        bufs: &mut [IoSliceMut<'_>],
68        ancillary: &mut SocketAncillary<'_>,
69    ) -> std::io::Result<usize> {
70        recv_vectored_with_ancillary(self.as_raw_fd(), bufs, ancillary)
71    }
72    fn send_vectored_with_ancillary(
73        &self,
74        bufs: &[IoSlice<'_>],
75        ancillary: &mut SocketAncillary<'_>,
76    ) -> std::io::Result<usize> {
77        send_vectored_with_ancillary(self.as_raw_fd(), bufs, ancillary)
78    }
79    fn shutdown(&self, how: Shutdown) -> std::io::Result<()> {
80        let how = match how {
81            Shutdown::Read => libc::SHUT_RD,
82            Shutdown::Write => libc::SHUT_WR,
83            Shutdown::Both => libc::SHUT_RDWR,
84        };
85        unsafe {
86            if libc::shutdown(self.as_raw_fd(), how) == -1 {
87                Err(std::io::Error::last_os_error())
88            } else {
89                Ok(())
90            }
91        }
92    }
93}
94impl Drop for GenStream {
95    fn drop(&mut self) {
96        unsafe {
97            libc::close(self.fd);
98        }
99    }
100}
101/// A synchronous non-blocking connection to DBus session.
102///
103/// Most people will want to use `RpcConn`. This is a low-level
104/// struct used by `RpcConn` to read and write messages to/from the DBus
105/// socket. It does minimal processing of data and provides no Async interfaces.
106/// # Notes
107/// * If you are interested in synchronous interface for DBus, the `rustbus` is a better solution.
108pub struct Conn {
109    pub(super) stream: GenStream,
110    pub(super) recv_state: RecvState,
111    pub(super) send_state: SendState,
112    serial: u32,
113}
/// Converts the return value of a libc call into an `io::Result`.
///
/// `-1` is libc's failure convention, in which case the current `errno` is captured with
/// [`std::io::Error::last_os_error`]; any other value (a new descriptor, or `0` for calls
/// such as `connect`) is passed through unchanged.
fn fd_or_os_err(fd: i32) -> std::io::Result<i32> {
    match fd {
        -1 => Err(std::io::Error::last_os_error()),
        valid => Ok(valid),
    }
}
121// TODO: if https://github.com/async-rs/async-std/pull/961
122// is completed then perhaps this trait can be elimnated
123trait IntoRawFd {
124    fn into_raw_fd(self) -> std::io::Result<RawFd>;
125}
126impl<T: AsRawFd> IntoRawFd for T {
127    fn into_raw_fd(self) -> std::io::Result<RawFd> {
128        let fd = self.as_raw_fd();
129        unsafe { fd_or_os_err(libc::dup(fd)) }
130    }
131}
132impl Conn {
133    async fn conn_handshake<T>(mut stream: T, with_fd: bool) -> std::io::Result<Self>
134    where
135        T: AsyncRead + AsyncWrite + Unpin + IntoRawFd,
136    {
137        do_auth(&mut stream).await?;
138        if with_fd && !negotiate_unix_fds(&mut stream).await? {
139            return Err(std::io::Error::new(
140                ErrorKind::ConnectionAborted,
141                "Failed to negotiate Unix FDs!",
142            ));
143        }
144        stream.write_all(b"BEGIN\r\n").await?;
145        // SAFETY: into_raw_fd() gets an "owned" fd
146        // that can be taken by the StdUnixStream.
147        let stream = unsafe {
148            let fd = stream.into_raw_fd()?;
149            GenStream::from_raw_fd(fd)
150        };
151        Ok(Self {
152            recv_state: RecvState {
153                in_state: InState::Header(Vec::new()),
154                in_fds: Vec::new(),
155                with_fd,
156                remaining: Vec::with_capacity(4096),
157                rem_loc: 0,
158            },
159            send_state: SendState {
160                with_fd,
161                idx: 0,
162                queue: VecDeque::new(),
163            },
164            stream,
165            serial: 0,
166        })
167    }
168    pub async fn connect_to_addr<P: AsRef<Path>, S: ToSocketAddrs, B: AsRef<[u8]>>(
169        addr: &DBusAddr<P, S, B>,
170        with_fd: bool,
171    ) -> std::io::Result<Self> {
172        match addr {
173            DBusAddr::Path(p) => Self::conn_handshake(UnixStream::connect(p).await?, with_fd).await,
174            DBusAddr::Tcp(s) => {
175                if with_fd {
176                    Err(std::io::Error::new(
177                        ErrorKind::InvalidInput,
178                        "Cannot use Fds over TCP.",
179                    ))
180                } else {
181                    Self::conn_handshake(TcpStream::connect(s).await?, with_fd).await
182                }
183            }
184            #[cfg(target_os = "linux")]
185            DBusAddr::Abstract(buf) => unsafe {
186                let buf = buf.as_ref();
187                let mut addr: libc::sockaddr_un = mem::zeroed();
188                addr.sun_family = libc::AF_UNIX as u16;
189                // SAFETY: &[u8] has identical memory layout and size to &[i8]
190                #[cfg(not(target_arch = "arm"))]
191                let c_buf = &*(buf as *const [u8] as *const [i8]);
192
193                // for some reason ARM uses &[u8] instead of &[i8]
194                #[cfg(target_arch = "arm")]
195                let c_buf = &buf[..];
196                addr.sun_path
197                    .get_mut(1..1 + buf.len())
198                    .ok_or_else(|| {
199                        std::io::Error::new(
200                            ErrorKind::InvalidData,
201                            "Abstract unix socket address was too long!",
202                        )
203                    })?
204                    .copy_from_slice(c_buf);
205                //SAFETY: errors are apporiately handled
206                let fd = fd_or_os_err(libc::socket(libc::AF_UNIX, libc::SOCK_STREAM, 0))?;
207                if let Err(e) = fd_or_os_err(libc::connect(
208                    fd,
209                    &addr as *const libc::sockaddr_un as *const libc::sockaddr,
210                    (mem::size_of_val(&addr) - (108 - buf.len() - 1)) as u32,
211                )) {
212                    libc::close(fd);
213                    return Err(e);
214                }
215                let stream = StdUnixStream::from_raw_fd(fd);
216                let stream = UnixStream::from_std(stream)?;
217                Self::conn_handshake(stream, with_fd).await
218            },
219        }
220    }
221    async fn connect_to_path_byteorder<P: AsRef<Path>>(
222        p: P,
223        with_fd: bool,
224    ) -> std::io::Result<Self> {
225        let addr = DBusAddr::unix_path(p);
226        Self::connect_to_addr(&addr, with_fd).await
227    }
228    pub async fn connect_to_path<P: AsRef<Path>>(p: P, with_fd: bool) -> std::io::Result<Self> {
229        Self::connect_to_path_byteorder(p, with_fd).await
230    }
231    pub fn get_next_message(&mut self) -> std::io::Result<MarshalledMessage> {
232        self.recv_state.get_next_message(&self.stream)
233    }
234    pub fn finish_sending_next(&mut self) -> std::io::Result<u64> {
235        self.send_state.finish_sending_next(&self.stream)
236    }
237    pub fn write_next_message(
238        &mut self,
239        msg: &MarshalledMessage,
240    ) -> std::io::Result<(Option<u64>, Option<u32>)> {
241        self.serial += 1;
242        let mut idx;
243        loop {
244            self.serial += 1;
245            idx = self.serial;
246            if idx != 0 {
247                break;
248            }
249        }
250        self.send_state
251            .write_next_message(&self.stream, msg, NonZeroU32::new(idx).unwrap())
252            .map(|b| (b, Some(idx)))
253    }
254}
255impl AsRawFd for Conn {
256    fn as_raw_fd(&self) -> RawFd {
257        self.stream.as_raw_fd()
258    }
259}
260
261fn find_line_ending(buf: &[u8]) -> Option<usize> {
262    buf.windows(2).position(|w| w == DBUS_LINE_END)
263}
264
265/// start_with will read from the stream until it finds a line ending or reads 512 bytes.
266/// If it finds a line that starts with buf, it will return the rest of the line, other wise
267/// it will return `Ok(None)`.
268/// Note: starts_with assmes the writer will not send any data after the line ending.
269async fn starts_with<T: AsyncRead + AsyncWrite + Unpin>(
270    buf: &[u8],
271    stream: &mut T,
272) -> std::io::Result<Option<Vec<u8>>> {
273    debug_assert!(buf.len() <= 510);
274    let mut pos = 0;
275    let mut read_buf = [0; 512];
276    loop {
277        match find_line_ending(&read_buf[..pos]) {
278            Some(loc) => {
279                if buf.len() > loc {
280                    return Ok(None);
281                }
282                return if &read_buf[..buf.len()] == buf {
283                    Ok(Some(read_buf[buf.len()..loc].to_owned()))
284                } else {
285                    Ok(None)
286                };
287            }
288            None => {
289                if pos == 512 {
290                    // Line was too long.
291                    return Ok(None);
292                }
293                pos += stream.read(&mut read_buf[pos..]).await?;
294            }
295        }
296    }
297}
298async fn find_auth_mechs<T: AsyncRead + AsyncWrite + Unpin>(
299    stream: &mut T,
300) -> std::io::Result<HashSet<String>> {
301    stream.write_all(b"AUTH\r\n").await?;
302    let ret = starts_with(b"REJECTED", stream).await?;
303    match ret {
304        Some(s) if s.is_empty() => Ok(HashSet::new()),
305        Some(s) => {
306            let s = std::str::from_utf8(&s[..]).map_err(|_| {
307                std::io::Error::new(
308                    ErrorKind::PermissionDenied,
309                    "Invalid AUTH response from remote!",
310                )
311            })?;
312
313            Ok(s.split(' ').map(|s| s.to_owned()).collect())
314        }
315        None => Ok(HashSet::new()), // TODO: Should this be an error?
316    }
317}
318async fn await_ok<T: AsyncRead + AsyncWrite + Unpin>(stream: &mut T) -> std::io::Result<()> {
319    match starts_with(b"OK", stream).await? {
320        Some(_) => Ok(()),
321        None => Err(std::io::Error::new(
322            ErrorKind::PermissionDenied,
323            "External authentication failed with remote!",
324        )),
325    }
326}
327async fn do_external_auth<T: AsyncRead + AsyncWrite + Unpin>(
328    stream: &mut T,
329) -> std::io::Result<()> {
330    let mut to_write = Vec::from(&b"AUTH EXTERNAL "[..]);
331    let mut pid = unsafe { libc::geteuid() };
332    let mut order = 1;
333    loop {
334        let next = order * 10;
335        if pid / next == 0 {
336            break;
337        }
338        order = next;
339    }
340    while order > 0 {
341        to_write.push(b'3');
342        let digit = pid / order;
343        to_write.push(0x30 + digit as u8);
344        pid -= digit * order;
345        order /= 10;
346    }
347    to_write.extend_from_slice(DBUS_LINE_END);
348    stream.write_all(&to_write).await?;
349    await_ok(stream).await
350}
351async fn do_anon_auth<T: AsyncRead + AsyncWrite + Unpin>(stream: &mut T) -> std::io::Result<()> {
352    stream.write_all(b"AUTH ANONYMOUS\r\n").await?;
353    await_ok(stream).await
354}
355async fn do_auth<T: AsyncRead + AsyncWrite + Unpin>(stream: &mut T) -> std::io::Result<()> {
356    stream.write_all(b"\0").await?;
357    let auth_mechs = find_auth_mechs(stream).await?;
358    let mut err = None;
359    if auth_mechs.contains("EXTERNAL") {
360        match do_external_auth(stream).await {
361            Ok(_) => return Ok(()),
362            Err(e) => err = Some(e),
363        }
364    }
365    if auth_mechs.contains("ANONYMOUS") {
366        match do_anon_auth(stream).await {
367            Ok(_) => return Ok(()),
368            Err(e) => err = Some(e),
369        }
370    }
371    match err {
372        Some(err) => Err(err),
373        None => Err(std::io::Error::new(
374            ErrorKind::PermissionDenied,
375            "Remote doesn't support our auth methods!",
376        )),
377    }
378}
379async fn negotiate_unix_fds<T: AsyncRead + AsyncWrite + Unpin>(
380    stream: &mut T,
381) -> std::io::Result<bool> {
382    stream.write_all(b"NEGOTIATE_UNIX_FD\r\n").await?;
383    starts_with(b"AGREE_UNIX_FD", stream)
384        .await
385        .map(|o| o.is_some())
386}