async_rustbus/conn/
recv.rs

1use std::io::{ErrorKind, IoSliceMut};
2use std::mem;
3use std::net::Shutdown;
4use std::num::NonZeroU32;
5
6use crate::rustbus_core;
7use rustbus_core::message_builder::{DynamicHeader, MarshalledMessage, MarshalledMessageBody};
8use rustbus_core::wire::marshal::traits::SignatureBuffer;
9use rustbus_core::wire::util::align_offset;
10use rustbus_core::wire::{unmarshal, UnixFd};
11use unmarshal::traits::Unmarshal;
12use unmarshal::HEADER_LEN;
13
14use crate::utils::{align_num, parse_u32};
15
16use super::{AncillaryData, GenStream, SocketAncillary, DBUS_MAX_FD_MESSAGE};
17
/// Parsing progress for the message currently being received.
///
/// Each variant owns the buffer being filled for its step. The buffer is handed on from
/// `Header` to `DynHdr` to `Finishing` (and cleared where needed) rather than reallocated.
pub enum InState {
    /// Accumulating the fixed header. The state advances once `HEADER_LEN` bytes are present;
    /// a read may already have pulled in the 4-byte header-field array length that follows.
    Header(Vec<u8>),
    /// Fixed header parsed. The buffer still begins with the fixed header bytes and is
    /// accumulating the header-field array plus padding up to the 8-byte aligned body start.
    DynHdr(unmarshal::Header, Vec<u8>),
    /// Both headers parsed. The buffer was cleared and is accumulating the `body_len` bytes
    /// of the message body.
    Finishing(unmarshal::Header, DynamicHeader, Vec<u8>),
}
23
24impl Default for InState {
25    fn default() -> Self {
26        InState::Header(Vec::new())
27    }
28}
29
30impl InState {
31    fn into_buf(self) -> Vec<u8> {
32        let mut ret = match self {
33            InState::Header(b) | InState::DynHdr(_, b) | InState::Finishing(_, _, b) => b,
34        };
35        ret.clear();
36        ret
37    }
38    fn into_hdr(self) -> Self {
39        let buf = self.into_buf();
40        InState::Header(buf)
41    }
42    fn get_mut_buf(&mut self) -> &mut Vec<u8> {
43        match self {
44            InState::Header(b) | InState::DynHdr(_, b) | InState::Finishing(_, _, b) => b,
45        }
46    }
47    fn bytes_needed_for_next(&self) -> usize {
48        match self {
49            InState::Header(b) => HEADER_LEN + 4 - b.len(),
50            InState::DynHdr(hdr, b) => {
51                if b.len() < 16 {
52                    16 - b.len()
53                } else {
54                    let array_len = parse_u32(&b[12..16], hdr.byteorder) as usize;
55                    align_num(HEADER_LEN + 4 + array_len, 8) - b.len()
56                }
57            }
58            InState::Finishing(hdr, _, b) => hdr.body_len as usize - b.len(),
59        }
60    }
61}
62
/// Receive-side state of a connection: the partially parsed message plus any bytes and fds
/// read from the socket that have not been consumed yet.
pub(crate) struct RecvState {
    /// Parsing progress of the message currently being received.
    pub(super) in_state: InState,
    /// Unix fds received as ancillary data and not yet attached to a message.
    pub(super) in_fds: Vec<UnixFd>,
    /// Bytes read beyond what the current parsing step needed; only `remaining[rem_loc..]`
    /// is still unconsumed. Reads only spill into this buffer when `with_fd` is false.
    pub(super) remaining: Vec<u8>,
    /// Index of the first unconsumed byte in `remaining`.
    pub(super) rem_loc: usize,
    /// Whether unix fd passing is enabled. When false, messages declaring fds are rejected and
    /// surplus bytes may be read into `remaining`; when true, each read is limited to what the
    /// current message still needs (presumably so received fds stay tied to that message).
    pub(super) with_fd: bool,
}
70
/// Move bytes from `buf[*loc..]` onto the end of `vec` until `vec.len()` reaches `target`.
///
/// `*loc` is advanced past every byte copied. Returns `true` when `vec` holds at least
/// `target` bytes afterwards (including when it already did, in which case neither `vec`,
/// `buf` nor `*loc` is touched), and `false` when `buf` ran out first, in which case all of
/// `buf[*loc..]` has been consumed.
fn extend_max(vec: &mut Vec<u8>, buf: &[u8], loc: &mut usize, target: usize) -> bool {
    let have = vec.len();
    if have >= target {
        return true;
    }
    let unread = &buf[*loc..];
    let take = unread.len().min(target - have);
    vec.extend_from_slice(&unread[..take]);
    *loc += take;
    have + take == target
}
87impl RecvState {
88    fn try_get_msg(
89        &mut self,
90        stream: &GenStream,
91    ) -> std::io::Result<Option<(unmarshal::Header, DynamicHeader, Vec<u8>)>> {
92        let mut try_block = || {
93            match &mut self.in_state {
94                InState::Header(hdr_buf) => {
95                    use unmarshal::unmarshal_header;
96                    if !extend_max(hdr_buf, &self.remaining, &mut self.rem_loc, HEADER_LEN) {
97                        return Ok(None);
98                    }
99
100                    let (_, hdr) = unmarshal_header(&hdr_buf[..], 0).map_err(|_e| {
101                        eprintln!("{:?} ({:?}", _e, hdr_buf);
102                        std::io::Error::new(ErrorKind::Other, "Bad header!")
103                    })?;
104                    self.in_state = InState::DynHdr(hdr, mem::take(hdr_buf));
105                    self.try_get_msg(stream)
106                }
107                InState::DynHdr(hdr, dyn_buf) => {
108                    if !extend_max(dyn_buf, &self.remaining, &mut self.rem_loc, HEADER_LEN + 4) {
109                        return Ok(None);
110                    }
111
112                    // copy bytes for header
113                    let array_len =
114                        parse_u32(&dyn_buf[HEADER_LEN..HEADER_LEN + 4], hdr.byteorder) as usize;
115                    let total_hdr_len = align_num(HEADER_LEN + 4 + array_len, 8);
116                    if !extend_max(dyn_buf, &self.remaining, &mut self.rem_loc, total_hdr_len) {
117                        return Ok(None);
118                    }
119                    let mut ctx = unmarshal::UnmarshalContext {
120                        byteorder: hdr.byteorder,
121                        offset: HEADER_LEN,
122                        buf: &dyn_buf[..],
123                        fds: &[],
124                    };
125                    let (used, mut dynhdr) = DynamicHeader::unmarshal(&mut ctx).map_err(|e| {
126                        std::io::Error::new(ErrorKind::Other, format!("Bad header!: {:?}", e))
127                    })?;
128                    drop(ctx);
129                    let serial = NonZeroU32::new(hdr.serial)
130                        .ok_or_else(|| std::io::Error::new(ErrorKind::Other, "Serial was zero!"))?;
131                    dynhdr.serial = Some(serial);
132
133                    // DBus Spec says body is aligned to 8 bytes.
134                    align_offset(8, &dyn_buf[..], HEADER_LEN + used)
135                        .map_err(|_| std::io::Error::new(ErrorKind::Other, "Data in offset!"))?;
136
137                    // Validate dynhdr
138                    if dynhdr.num_fds.unwrap_or(0) > 0 && !self.with_fd {
139                        return Err(std::io::Error::new(ErrorKind::Other, "Bad header!"));
140                    }
141                    dyn_buf.clear();
142                    dyn_buf.reserve(hdr.body_len as usize);
143                    self.in_state = InState::Finishing(*hdr, dynhdr, mem::take(dyn_buf));
144                    self.try_get_msg(stream)
145                }
146                InState::Finishing(hdr, dynhdr, body_buf) => {
147                    if !extend_max(
148                        body_buf,
149                        &self.remaining,
150                        &mut self.rem_loc,
151                        hdr.body_len as usize,
152                    ) {
153                        return Ok(None);
154                    }
155                    let hdr = *hdr;
156                    let dynhdr = mem::take(dynhdr);
157                    let body = mem::take(body_buf);
158                    self.in_state = InState::default();
159                    Ok(Some((hdr, dynhdr, body)))
160                }
161            }
162        };
163        match try_block() {
164            Err(e) if e.kind() == ErrorKind::WouldBlock => Ok(None), // TODO: is this reachable?
165            Err(e) => {
166                self.in_fds.clear();
167                self.in_state = mem::take(&mut self.in_state).into_hdr();
168                // Parsing errors mean that we need to close the stream
169                stream.shutdown(Shutdown::Both).ok();
170                Err(e)
171            }
172            els => els,
173        }
174    }
175
176    pub(crate) fn get_next_message(
177        &mut self,
178        stream: &GenStream,
179    ) -> std::io::Result<MarshalledMessage> {
180        let res = self.try_get_msg(stream);
181        if let Some((hdr, dynhdr, body)) = res? {
182            let msg = mm_from_raw(hdr, dynhdr, body, Vec::new());
183            match msg.body.validate() {
184                Ok(_) => return Ok(msg),
185                Err(e) => {
186                    stream.shutdown(Shutdown::Both).ok();
187                    return Err(std::io::Error::new(
188                        ErrorKind::Other,
189                        format!("Bad message body!: {:?}", e),
190                    ));
191                }
192            }
193        }
194        let mut anc_buf = [0; 256];
195        loop {
196            debug_assert_eq!(self.remaining.len(), self.rem_loc);
197            debug_assert!(self.remaining.capacity() >= 4096);
198            self.remaining.clear();
199            self.rem_loc = 0;
200            let needed = self.in_state.bytes_needed_for_next();
201            // Read the stream directly into the in_state buffer
202            debug_assert!(needed > 0);
203            let vec = self.in_state.get_mut_buf();
204            // SAFETY: uninit_buf is never read from
205            let uninit_buf = unsafe { vec_uninit_slice(vec, Some(needed)) };
206            let uninit_len = uninit_buf.len();
207
208            debug_assert!(self.remaining.is_empty());
209            let mut rem: Vec<u8> = mem::take(&mut self.remaining);
210            let mut bufs = [IoSliceMut::new(uninit_buf), IoSliceMut::new(&mut [])];
211            let (bufs, mut anc) = if self.with_fd {
212                (&mut bufs[..1], SocketAncillary::new(&mut anc_buf[..]))
213            } else {
214                // SAFETY: rem_buf is never read from
215                let rem_buf = unsafe { vec_uninit_slice(&mut rem, None) };
216                bufs[1] = IoSliceMut::new(rem_buf);
217                (&mut bufs[..], SocketAncillary::new(&mut []))
218            };
219            let res = stream.recv_vectored_with_ancillary(bufs, &mut anc);
220            let gotten = match &res {
221                Ok(0) | Err(_) => {
222                    self.remaining = rem;
223                    res?; // return if err otherwise return Hungup in case of Ok(0)
224                    return Err(std::io::Error::new(
225                        ErrorKind::BrokenPipe,
226                        "DBus daemon hung up!",
227                    ));
228                }
229                Ok(i) => *i,
230            };
231            unsafe {
232                if gotten > uninit_len {
233                    vec.set_len(vec.len() + uninit_len);
234                    rem.set_len(gotten - uninit_len);
235                } else {
236                    vec.set_len(vec.len() + gotten);
237                }
238            }
239            self.remaining = rem;
240            if self.with_fd {
241                let anc_fds_iter =
242                    anc.messages()
243                        .flat_map(|res| match res.expect("Anc Data should be valid.") {
244                            AncillaryData::ScmRights(rights) => rights.map(UnixFd::new),
245                        });
246                self.in_fds.extend(anc_fds_iter);
247                if self.in_fds.len() > DBUS_MAX_FD_MESSAGE {
248                    // We received too many fds
249                    self.in_state = mem::take(&mut self.in_state).into_hdr();
250                    self.in_fds.clear();
251                    //TODO: Find better error
252                    return Err(std::io::Error::new(
253                        ErrorKind::Other,
254                        "Too many unix fds received!",
255                    ));
256                }
257            }
258            let res = self.try_get_msg(stream);
259            if let Some((hdr, dynhdr, body)) = res? {
260                if self.in_fds.len() != dynhdr.num_fds.unwrap_or(0) as usize {
261                    self.in_fds.clear();
262                    return Err(std::io::Error::new(
263                        ErrorKind::Other,
264                        "Unepexted number of fds received!",
265                    ));
266                }
267                let msg = mm_from_raw(hdr, dynhdr, body, mem::take(&mut self.in_fds));
268                match msg.body.validate() {
269                    Ok(_) => return Ok(msg),
270                    Err(e) => {
271                        stream.shutdown(Shutdown::Both).ok();
272                        return Err(std::io::Error::new(
273                            ErrorKind::Other,
274                            format!("Bad message body!: {:?}", e),
275                        ));
276                    }
277                }
278            }
279        }
280    }
281    #[allow(dead_code)]
282    pub fn pos_next_msg(&self) -> bool {
283        let needed = self.in_state.bytes_needed_for_next();
284        self.remaining.len() >= needed
285    }
286}
287
/// Return a mutable slice over the spare (uninitialized) capacity of `vec`.
///
/// With `wanted == Some(n)`, the vector first reserves room for at least `n` more bytes and
/// the slice is exactly `n` bytes long. With `None`, the slice covers all of the current spare
/// capacity. The vector's length is not changed.
///
/// # Safety
/// The returned slice points at uninitialized memory. Callers must only write to it (e.g. by
/// handing it to a read/recv call) and must never read from it. To make written bytes part of
/// the vector, call `Vec::set_len` with no more than the number of bytes actually written.
unsafe fn vec_uninit_slice(vec: &mut Vec<u8>, wanted: Option<usize>) -> &mut [u8] {
    let len = if let Some(extra) = wanted {
        vec.reserve(extra);
        extra
    } else {
        vec.capacity() - vec.len()
    };
    // SAFETY: `vec.len() + len <= vec.capacity()` (guaranteed by `reserve`, or by construction
    // in the `None` case), so the range lies inside the allocation, directly after the
    // initialized elements. The slice's lifetime is tied to the `&mut Vec` borrow.
    unsafe {
        let start = vec.as_mut_ptr().add(vec.len());
        std::slice::from_raw_parts_mut(start, len)
    }
}
304fn mm_from_raw(
305    hdr: unmarshal::Header,
306    dynhdr: DynamicHeader,
307    body: Vec<u8>,
308    fds: Vec<UnixFd>,
309) -> MarshalledMessage {
310    let sig = dynhdr.signature.as_deref().unwrap_or("");
311    let sig = SignatureBuffer::from_string(sig.to_string());
312    MarshalledMessage {
313        typ: hdr.typ,
314        flags: hdr.flags,
315        body: MarshalledMessageBody::from_parts(body, fds, sig, hdr.byteorder),
316        dynheader: dynhdr,
317    }
318}