1use std::io::{ErrorKind, IoSliceMut};
2use std::mem;
3use std::net::Shutdown;
4use std::num::NonZeroU32;
5
6use crate::rustbus_core;
7use rustbus_core::message_builder::{DynamicHeader, MarshalledMessage, MarshalledMessageBody};
8use rustbus_core::wire::marshal::traits::SignatureBuffer;
9use rustbus_core::wire::util::align_offset;
10use rustbus_core::wire::{unmarshal, UnixFd};
11use unmarshal::traits::Unmarshal;
12use unmarshal::HEADER_LEN;
13
14use crate::utils::{align_num, parse_u32};
15
16use super::{AncillaryData, GenStream, SocketAncillary, DBUS_MAX_FD_MESSAGE};
17
18pub enum InState {
19 Header(Vec<u8>),
20 DynHdr(unmarshal::Header, Vec<u8>),
21 Finishing(unmarshal::Header, DynamicHeader, Vec<u8>),
22}
23
24impl Default for InState {
25 fn default() -> Self {
26 InState::Header(Vec::new())
27 }
28}
29
30impl InState {
31 fn into_buf(self) -> Vec<u8> {
32 let mut ret = match self {
33 InState::Header(b) | InState::DynHdr(_, b) | InState::Finishing(_, _, b) => b,
34 };
35 ret.clear();
36 ret
37 }
38 fn into_hdr(self) -> Self {
39 let buf = self.into_buf();
40 InState::Header(buf)
41 }
42 fn get_mut_buf(&mut self) -> &mut Vec<u8> {
43 match self {
44 InState::Header(b) | InState::DynHdr(_, b) | InState::Finishing(_, _, b) => b,
45 }
46 }
47 fn bytes_needed_for_next(&self) -> usize {
48 match self {
49 InState::Header(b) => HEADER_LEN + 4 - b.len(),
50 InState::DynHdr(hdr, b) => {
51 if b.len() < 16 {
52 16 - b.len()
53 } else {
54 let array_len = parse_u32(&b[12..16], hdr.byteorder) as usize;
55 align_num(HEADER_LEN + 4 + array_len, 8) - b.len()
56 }
57 }
58 InState::Finishing(hdr, _, b) => hdr.body_len as usize - b.len(),
59 }
60 }
61}
62
63pub(crate) struct RecvState {
64 pub(super) in_state: InState,
65 pub(super) in_fds: Vec<UnixFd>,
66 pub(super) remaining: Vec<u8>,
67 pub(super) rem_loc: usize,
68 pub(super) with_fd: bool,
69}
70
/// Tops `vec` up to `target` bytes using the unread part of `buf`.
///
/// `loc` is the read cursor into `buf`; it is advanced by the number of bytes
/// copied. Returns `true` once `vec` holds at least `target` bytes and
/// `false` when `buf` ran out first (everything available was still copied).
/// If `vec` is already long enough nothing is touched.
fn extend_max(vec: &mut Vec<u8>, buf: &[u8], loc: &mut usize, target: usize) -> bool {
    let missing = match target.checked_sub(vec.len()) {
        // Already at (or beyond) the requested length.
        None | Some(0) => return true,
        Some(missing) => missing,
    };
    let unread = &buf[*loc..];
    let copied = missing.min(unread.len());
    vec.extend_from_slice(&unread[..copied]);
    *loc += copied;
    copied == missing
}
87impl RecvState {
88 fn try_get_msg(
89 &mut self,
90 stream: &GenStream,
91 ) -> std::io::Result<Option<(unmarshal::Header, DynamicHeader, Vec<u8>)>> {
92 let mut try_block = || {
93 match &mut self.in_state {
94 InState::Header(hdr_buf) => {
95 use unmarshal::unmarshal_header;
96 if !extend_max(hdr_buf, &self.remaining, &mut self.rem_loc, HEADER_LEN) {
97 return Ok(None);
98 }
99
100 let (_, hdr) = unmarshal_header(&hdr_buf[..], 0).map_err(|_e| {
101 eprintln!("{:?} ({:?}", _e, hdr_buf);
102 std::io::Error::new(ErrorKind::Other, "Bad header!")
103 })?;
104 self.in_state = InState::DynHdr(hdr, mem::take(hdr_buf));
105 self.try_get_msg(stream)
106 }
107 InState::DynHdr(hdr, dyn_buf) => {
108 if !extend_max(dyn_buf, &self.remaining, &mut self.rem_loc, HEADER_LEN + 4) {
109 return Ok(None);
110 }
111
112 let array_len =
114 parse_u32(&dyn_buf[HEADER_LEN..HEADER_LEN + 4], hdr.byteorder) as usize;
115 let total_hdr_len = align_num(HEADER_LEN + 4 + array_len, 8);
116 if !extend_max(dyn_buf, &self.remaining, &mut self.rem_loc, total_hdr_len) {
117 return Ok(None);
118 }
119 let mut ctx = unmarshal::UnmarshalContext {
120 byteorder: hdr.byteorder,
121 offset: HEADER_LEN,
122 buf: &dyn_buf[..],
123 fds: &[],
124 };
125 let (used, mut dynhdr) = DynamicHeader::unmarshal(&mut ctx).map_err(|e| {
126 std::io::Error::new(ErrorKind::Other, format!("Bad header!: {:?}", e))
127 })?;
128 drop(ctx);
129 let serial = NonZeroU32::new(hdr.serial)
130 .ok_or_else(|| std::io::Error::new(ErrorKind::Other, "Serial was zero!"))?;
131 dynhdr.serial = Some(serial);
132
133 align_offset(8, &dyn_buf[..], HEADER_LEN + used)
135 .map_err(|_| std::io::Error::new(ErrorKind::Other, "Data in offset!"))?;
136
137 if dynhdr.num_fds.unwrap_or(0) > 0 && !self.with_fd {
139 return Err(std::io::Error::new(ErrorKind::Other, "Bad header!"));
140 }
141 dyn_buf.clear();
142 dyn_buf.reserve(hdr.body_len as usize);
143 self.in_state = InState::Finishing(*hdr, dynhdr, mem::take(dyn_buf));
144 self.try_get_msg(stream)
145 }
146 InState::Finishing(hdr, dynhdr, body_buf) => {
147 if !extend_max(
148 body_buf,
149 &self.remaining,
150 &mut self.rem_loc,
151 hdr.body_len as usize,
152 ) {
153 return Ok(None);
154 }
155 let hdr = *hdr;
156 let dynhdr = mem::take(dynhdr);
157 let body = mem::take(body_buf);
158 self.in_state = InState::default();
159 Ok(Some((hdr, dynhdr, body)))
160 }
161 }
162 };
163 match try_block() {
164 Err(e) if e.kind() == ErrorKind::WouldBlock => Ok(None), Err(e) => {
166 self.in_fds.clear();
167 self.in_state = mem::take(&mut self.in_state).into_hdr();
168 stream.shutdown(Shutdown::Both).ok();
170 Err(e)
171 }
172 els => els,
173 }
174 }
175
176 pub(crate) fn get_next_message(
177 &mut self,
178 stream: &GenStream,
179 ) -> std::io::Result<MarshalledMessage> {
180 let res = self.try_get_msg(stream);
181 if let Some((hdr, dynhdr, body)) = res? {
182 let msg = mm_from_raw(hdr, dynhdr, body, Vec::new());
183 match msg.body.validate() {
184 Ok(_) => return Ok(msg),
185 Err(e) => {
186 stream.shutdown(Shutdown::Both).ok();
187 return Err(std::io::Error::new(
188 ErrorKind::Other,
189 format!("Bad message body!: {:?}", e),
190 ));
191 }
192 }
193 }
194 let mut anc_buf = [0; 256];
195 loop {
196 debug_assert_eq!(self.remaining.len(), self.rem_loc);
197 debug_assert!(self.remaining.capacity() >= 4096);
198 self.remaining.clear();
199 self.rem_loc = 0;
200 let needed = self.in_state.bytes_needed_for_next();
201 debug_assert!(needed > 0);
203 let vec = self.in_state.get_mut_buf();
204 let uninit_buf = unsafe { vec_uninit_slice(vec, Some(needed)) };
206 let uninit_len = uninit_buf.len();
207
208 debug_assert!(self.remaining.is_empty());
209 let mut rem: Vec<u8> = mem::take(&mut self.remaining);
210 let mut bufs = [IoSliceMut::new(uninit_buf), IoSliceMut::new(&mut [])];
211 let (bufs, mut anc) = if self.with_fd {
212 (&mut bufs[..1], SocketAncillary::new(&mut anc_buf[..]))
213 } else {
214 let rem_buf = unsafe { vec_uninit_slice(&mut rem, None) };
216 bufs[1] = IoSliceMut::new(rem_buf);
217 (&mut bufs[..], SocketAncillary::new(&mut []))
218 };
219 let res = stream.recv_vectored_with_ancillary(bufs, &mut anc);
220 let gotten = match &res {
221 Ok(0) | Err(_) => {
222 self.remaining = rem;
223 res?; return Err(std::io::Error::new(
225 ErrorKind::BrokenPipe,
226 "DBus daemon hung up!",
227 ));
228 }
229 Ok(i) => *i,
230 };
231 unsafe {
232 if gotten > uninit_len {
233 vec.set_len(vec.len() + uninit_len);
234 rem.set_len(gotten - uninit_len);
235 } else {
236 vec.set_len(vec.len() + gotten);
237 }
238 }
239 self.remaining = rem;
240 if self.with_fd {
241 let anc_fds_iter =
242 anc.messages()
243 .flat_map(|res| match res.expect("Anc Data should be valid.") {
244 AncillaryData::ScmRights(rights) => rights.map(UnixFd::new),
245 });
246 self.in_fds.extend(anc_fds_iter);
247 if self.in_fds.len() > DBUS_MAX_FD_MESSAGE {
248 self.in_state = mem::take(&mut self.in_state).into_hdr();
250 self.in_fds.clear();
251 return Err(std::io::Error::new(
253 ErrorKind::Other,
254 "Too many unix fds received!",
255 ));
256 }
257 }
258 let res = self.try_get_msg(stream);
259 if let Some((hdr, dynhdr, body)) = res? {
260 if self.in_fds.len() != dynhdr.num_fds.unwrap_or(0) as usize {
261 self.in_fds.clear();
262 return Err(std::io::Error::new(
263 ErrorKind::Other,
264 "Unepexted number of fds received!",
265 ));
266 }
267 let msg = mm_from_raw(hdr, dynhdr, body, mem::take(&mut self.in_fds));
268 match msg.body.validate() {
269 Ok(_) => return Ok(msg),
270 Err(e) => {
271 stream.shutdown(Shutdown::Both).ok();
272 return Err(std::io::Error::new(
273 ErrorKind::Other,
274 format!("Bad message body!: {:?}", e),
275 ));
276 }
277 }
278 }
279 }
280 }
281 #[allow(dead_code)]
282 pub fn pos_next_msg(&self) -> bool {
283 let needed = self.in_state.bytes_needed_for_next();
284 self.remaining.len() >= needed
285 }
286}
287
/// Exposes the unused tail of `vec`'s allocation as a byte slice that can be
/// handed to a read call.
///
/// With `Some(n)` the vector is first grown so that at least `n` spare bytes
/// exist and the slice is exactly `n` bytes long; with `None` the slice
/// covers whatever spare capacity the vector currently has. The vector's
/// length is not changed.
///
/// # Safety
/// The returned bytes are not initialized. The caller must only write to
/// them, and must not raise the vector's length beyond the bytes actually
/// written.
unsafe fn vec_uninit_slice(vec: &mut Vec<u8>, wanted: Option<usize>) -> &mut [u8] {
    let slice_len = if let Some(count) = wanted {
        vec.reserve(count);
        count
    } else {
        vec.capacity() - vec.len()
    };
    let used = vec.len();
    let tail = vec.as_mut_ptr().add(used);
    std::slice::from_raw_parts_mut(tail, slice_len)
}
304fn mm_from_raw(
305 hdr: unmarshal::Header,
306 dynhdr: DynamicHeader,
307 body: Vec<u8>,
308 fds: Vec<UnixFd>,
309) -> MarshalledMessage {
310 let sig = dynhdr.signature.as_deref().unwrap_or("");
311 let sig = SignatureBuffer::from_string(sig.to_string());
312 MarshalledMessage {
313 typ: hdr.typ,
314 flags: hdr.flags,
315 body: MarshalledMessageBody::from_parts(body, fds, sig, hdr.byteorder),
316 dynheader: dynhdr,
317 }
318}