1use std::collections::{HashSet, VecDeque};
7use std::io::{IoSlice, IoSliceMut};
8use std::mem;
9use std::net::Shutdown;
10use std::num::NonZeroU32;
11use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
12use std::os::unix::net::UnixStream as StdUnixStream;
13
14use tokio::io::*;
15use tokio::io::{AsyncRead, AsyncWrite};
16
17use std::io::ErrorKind;
18use std::path::Path;
19use tokio::net::TcpStream;
20use tokio::net::ToSocketAddrs;
21use tokio::net::UnixStream;
22
23use super::rustbus_core;
24
25use rustbus_core::message_builder::MarshalledMessage;
26
27mod ancillary;
28
29mod addr;
30pub use addr::{get_session_bus_addr, get_system_bus_addr, DBusAddr, DBUS_SESS_ENV, DBUS_SYS_PATH};
31mod recv;
32use recv::InState;
33pub(crate) use recv::RecvState;
34
35mod sender;
36pub(crate) use sender::SendState;
37
38use ancillary::{
39 recv_vectored_with_ancillary, send_vectored_with_ancillary, AncillaryData, SocketAncillary,
40};
41
/// Line terminator used by every command and reply of the D-Bus authentication protocol.
const DBUS_LINE_END_STR: &str = "\r\n";
/// [`DBUS_LINE_END_STR`] as raw bytes, for comparing against data read from the socket.
const DBUS_LINE_END: &[u8] = DBUS_LINE_END_STR.as_bytes();
/// Upper bound on the number of file descriptors handled per message
/// (NOTE(review): consumers live outside this view — confirm exact semantics there).
const DBUS_MAX_FD_MESSAGE: usize = 32;
45
/// Owning wrapper around a raw socket file descriptor.
///
/// The descriptor is closed when the wrapper is dropped, so exactly one `GenStream` may own a
/// given descriptor at a time.
pub(crate) struct GenStream {
    /// The owned socket descriptor.
    fd: RawFd,
}
52
53impl AsRawFd for GenStream {
54 fn as_raw_fd(&self) -> RawFd {
55 self.fd
56 }
57}
58impl FromRawFd for GenStream {
59 unsafe fn from_raw_fd(fd: RawFd) -> Self {
60 Self { fd }
61 }
62}
63
64impl GenStream {
65 fn recv_vectored_with_ancillary(
66 &self,
67 bufs: &mut [IoSliceMut<'_>],
68 ancillary: &mut SocketAncillary<'_>,
69 ) -> std::io::Result<usize> {
70 recv_vectored_with_ancillary(self.as_raw_fd(), bufs, ancillary)
71 }
72 fn send_vectored_with_ancillary(
73 &self,
74 bufs: &[IoSlice<'_>],
75 ancillary: &mut SocketAncillary<'_>,
76 ) -> std::io::Result<usize> {
77 send_vectored_with_ancillary(self.as_raw_fd(), bufs, ancillary)
78 }
79 fn shutdown(&self, how: Shutdown) -> std::io::Result<()> {
80 let how = match how {
81 Shutdown::Read => libc::SHUT_RD,
82 Shutdown::Write => libc::SHUT_WR,
83 Shutdown::Both => libc::SHUT_RDWR,
84 };
85 unsafe {
86 if libc::shutdown(self.as_raw_fd(), how) == -1 {
87 Err(std::io::Error::last_os_error())
88 } else {
89 Ok(())
90 }
91 }
92 }
93}
94impl Drop for GenStream {
95 fn drop(&mut self) {
96 unsafe {
97 libc::close(self.fd);
98 }
99 }
100}
101pub struct Conn {
109 pub(super) stream: GenStream,
110 pub(super) recv_state: RecvState,
111 pub(super) send_state: SendState,
112 serial: u32,
113}
/// Turns a libc-style return value into an `io::Result`.
///
/// `-1` is treated as the error sentinel and mapped to the current `errno`. Any other value is
/// returned unchanged.
fn fd_or_os_err(fd: i32) -> std::io::Result<i32> {
    match fd {
        -1 => Err(std::io::Error::last_os_error()),
        value => Ok(value),
    }
}
121trait IntoRawFd {
124 fn into_raw_fd(self) -> std::io::Result<RawFd>;
125}
126impl<T: AsRawFd> IntoRawFd for T {
127 fn into_raw_fd(self) -> std::io::Result<RawFd> {
128 let fd = self.as_raw_fd();
129 unsafe { fd_or_os_err(libc::dup(fd)) }
130 }
131}
132impl Conn {
133 async fn conn_handshake<T>(mut stream: T, with_fd: bool) -> std::io::Result<Self>
134 where
135 T: AsyncRead + AsyncWrite + Unpin + IntoRawFd,
136 {
137 do_auth(&mut stream).await?;
138 if with_fd && !negotiate_unix_fds(&mut stream).await? {
139 return Err(std::io::Error::new(
140 ErrorKind::ConnectionAborted,
141 "Failed to negotiate Unix FDs!",
142 ));
143 }
144 stream.write_all(b"BEGIN\r\n").await?;
145 let stream = unsafe {
148 let fd = stream.into_raw_fd()?;
149 GenStream::from_raw_fd(fd)
150 };
151 Ok(Self {
152 recv_state: RecvState {
153 in_state: InState::Header(Vec::new()),
154 in_fds: Vec::new(),
155 with_fd,
156 remaining: Vec::with_capacity(4096),
157 rem_loc: 0,
158 },
159 send_state: SendState {
160 with_fd,
161 idx: 0,
162 queue: VecDeque::new(),
163 },
164 stream,
165 serial: 0,
166 })
167 }
168 pub async fn connect_to_addr<P: AsRef<Path>, S: ToSocketAddrs, B: AsRef<[u8]>>(
169 addr: &DBusAddr<P, S, B>,
170 with_fd: bool,
171 ) -> std::io::Result<Self> {
172 match addr {
173 DBusAddr::Path(p) => Self::conn_handshake(UnixStream::connect(p).await?, with_fd).await,
174 DBusAddr::Tcp(s) => {
175 if with_fd {
176 Err(std::io::Error::new(
177 ErrorKind::InvalidInput,
178 "Cannot use Fds over TCP.",
179 ))
180 } else {
181 Self::conn_handshake(TcpStream::connect(s).await?, with_fd).await
182 }
183 }
184 #[cfg(target_os = "linux")]
185 DBusAddr::Abstract(buf) => unsafe {
186 let buf = buf.as_ref();
187 let mut addr: libc::sockaddr_un = mem::zeroed();
188 addr.sun_family = libc::AF_UNIX as u16;
189 #[cfg(not(target_arch = "arm"))]
191 let c_buf = &*(buf as *const [u8] as *const [i8]);
192
193 #[cfg(target_arch = "arm")]
195 let c_buf = &buf[..];
196 addr.sun_path
197 .get_mut(1..1 + buf.len())
198 .ok_or_else(|| {
199 std::io::Error::new(
200 ErrorKind::InvalidData,
201 "Abstract unix socket address was too long!",
202 )
203 })?
204 .copy_from_slice(c_buf);
205 let fd = fd_or_os_err(libc::socket(libc::AF_UNIX, libc::SOCK_STREAM, 0))?;
207 if let Err(e) = fd_or_os_err(libc::connect(
208 fd,
209 &addr as *const libc::sockaddr_un as *const libc::sockaddr,
210 (mem::size_of_val(&addr) - (108 - buf.len() - 1)) as u32,
211 )) {
212 libc::close(fd);
213 return Err(e);
214 }
215 let stream = StdUnixStream::from_raw_fd(fd);
216 let stream = UnixStream::from_std(stream)?;
217 Self::conn_handshake(stream, with_fd).await
218 },
219 }
220 }
221 async fn connect_to_path_byteorder<P: AsRef<Path>>(
222 p: P,
223 with_fd: bool,
224 ) -> std::io::Result<Self> {
225 let addr = DBusAddr::unix_path(p);
226 Self::connect_to_addr(&addr, with_fd).await
227 }
228 pub async fn connect_to_path<P: AsRef<Path>>(p: P, with_fd: bool) -> std::io::Result<Self> {
229 Self::connect_to_path_byteorder(p, with_fd).await
230 }
231 pub fn get_next_message(&mut self) -> std::io::Result<MarshalledMessage> {
232 self.recv_state.get_next_message(&self.stream)
233 }
234 pub fn finish_sending_next(&mut self) -> std::io::Result<u64> {
235 self.send_state.finish_sending_next(&self.stream)
236 }
237 pub fn write_next_message(
238 &mut self,
239 msg: &MarshalledMessage,
240 ) -> std::io::Result<(Option<u64>, Option<u32>)> {
241 self.serial += 1;
242 let mut idx;
243 loop {
244 self.serial += 1;
245 idx = self.serial;
246 if idx != 0 {
247 break;
248 }
249 }
250 self.send_state
251 .write_next_message(&self.stream, msg, NonZeroU32::new(idx).unwrap())
252 .map(|b| (b, Some(idx)))
253 }
254}
255impl AsRawFd for Conn {
256 fn as_raw_fd(&self) -> RawFd {
257 self.stream.as_raw_fd()
258 }
259}
260
261fn find_line_ending(buf: &[u8]) -> Option<usize> {
262 buf.windows(2).position(|w| w == DBUS_LINE_END)
263}
264
265async fn starts_with<T: AsyncRead + AsyncWrite + Unpin>(
270 buf: &[u8],
271 stream: &mut T,
272) -> std::io::Result<Option<Vec<u8>>> {
273 debug_assert!(buf.len() <= 510);
274 let mut pos = 0;
275 let mut read_buf = [0; 512];
276 loop {
277 match find_line_ending(&read_buf[..pos]) {
278 Some(loc) => {
279 if buf.len() > loc {
280 return Ok(None);
281 }
282 return if &read_buf[..buf.len()] == buf {
283 Ok(Some(read_buf[buf.len()..loc].to_owned()))
284 } else {
285 Ok(None)
286 };
287 }
288 None => {
289 if pos == 512 {
290 return Ok(None);
292 }
293 pos += stream.read(&mut read_buf[pos..]).await?;
294 }
295 }
296 }
297}
298async fn find_auth_mechs<T: AsyncRead + AsyncWrite + Unpin>(
299 stream: &mut T,
300) -> std::io::Result<HashSet<String>> {
301 stream.write_all(b"AUTH\r\n").await?;
302 let ret = starts_with(b"REJECTED", stream).await?;
303 match ret {
304 Some(s) if s.is_empty() => Ok(HashSet::new()),
305 Some(s) => {
306 let s = std::str::from_utf8(&s[..]).map_err(|_| {
307 std::io::Error::new(
308 ErrorKind::PermissionDenied,
309 "Invalid AUTH response from remote!",
310 )
311 })?;
312
313 Ok(s.split(' ').map(|s| s.to_owned()).collect())
314 }
315 None => Ok(HashSet::new()), }
317}
318async fn await_ok<T: AsyncRead + AsyncWrite + Unpin>(stream: &mut T) -> std::io::Result<()> {
319 match starts_with(b"OK", stream).await? {
320 Some(_) => Ok(()),
321 None => Err(std::io::Error::new(
322 ErrorKind::PermissionDenied,
323 "External authentication failed with remote!",
324 )),
325 }
326}
327async fn do_external_auth<T: AsyncRead + AsyncWrite + Unpin>(
328 stream: &mut T,
329) -> std::io::Result<()> {
330 let mut to_write = Vec::from(&b"AUTH EXTERNAL "[..]);
331 let mut pid = unsafe { libc::geteuid() };
332 let mut order = 1;
333 loop {
334 let next = order * 10;
335 if pid / next == 0 {
336 break;
337 }
338 order = next;
339 }
340 while order > 0 {
341 to_write.push(b'3');
342 let digit = pid / order;
343 to_write.push(0x30 + digit as u8);
344 pid -= digit * order;
345 order /= 10;
346 }
347 to_write.extend_from_slice(DBUS_LINE_END);
348 stream.write_all(&to_write).await?;
349 await_ok(stream).await
350}
351async fn do_anon_auth<T: AsyncRead + AsyncWrite + Unpin>(stream: &mut T) -> std::io::Result<()> {
352 stream.write_all(b"AUTH ANONYMOUS\r\n").await?;
353 await_ok(stream).await
354}
355async fn do_auth<T: AsyncRead + AsyncWrite + Unpin>(stream: &mut T) -> std::io::Result<()> {
356 stream.write_all(b"\0").await?;
357 let auth_mechs = find_auth_mechs(stream).await?;
358 let mut err = None;
359 if auth_mechs.contains("EXTERNAL") {
360 match do_external_auth(stream).await {
361 Ok(_) => return Ok(()),
362 Err(e) => err = Some(e),
363 }
364 }
365 if auth_mechs.contains("ANONYMOUS") {
366 match do_anon_auth(stream).await {
367 Ok(_) => return Ok(()),
368 Err(e) => err = Some(e),
369 }
370 }
371 match err {
372 Some(err) => Err(err),
373 None => Err(std::io::Error::new(
374 ErrorKind::PermissionDenied,
375 "Remote doesn't support our auth methods!",
376 )),
377 }
378}
379async fn negotiate_unix_fds<T: AsyncRead + AsyncWrite + Unpin>(
380 stream: &mut T,
381) -> std::io::Result<bool> {
382 stream.write_all(b"NEGOTIATE_UNIX_FD\r\n").await?;
383 starts_with(b"AGREE_UNIX_FD", stream)
384 .await
385 .map(|o| o.is_some())
386}