rdbg_client/
lib.rs

1use std::fmt::{Debug, Display, Formatter};
2use std::io::Read;
3use std::mem::size_of;
4use std::net::{AddrParseError, IpAddr, SocketAddr, TcpStream};
5use std::str::{FromStr, Utf8Error};
6use std::time::Duration;
7use std::{io, thread};
8
/// IP address of the debugged program that is connected to when none is specified
pub const DEFAULT_ADDR: &str = "127.0.0.1";
/// TCP port of the debugged program that is connected to when none is specified
pub const DEFAULT_PORT: u16 = 13579;

/// Pause between connection attempts while waiting for the debugged program, in milliseconds
const CONNECT_WAIT_TIME: u64 = 250;
/// Initial capacity of the receive buffer (it is resized to fit each frame as it arrives)
const BUFFER_SIZE: usize = 4096;
/// Width in bytes of the big-endian `u32` length prefix at the start of every frame
const LEN_FIELD_SIZE: usize = size_of::<u32>();
/// Protocol version byte the debugged program must send immediately after accepting a connection
const WIRE_PROTOCOL_VERSION: u8 = 1;
18
19// *** MsgPayloadVal ***
20
21#[repr(u8)]
22enum MsgPayloadVal {
23    Message = 1,
24    Values = 2,
25}
26
27impl MsgPayloadVal {
28    #[inline]
29    fn from_buffer(buffer: &mut ByteBuffer) -> Result<MsgPayloadVal, Error> {
30        buffer.read_u8()?.try_into()
31    }
32}
33
34impl TryFrom<u8> for MsgPayloadVal {
35    type Error = Error;
36
37    fn try_from(value: u8) -> Result<Self, Self::Error> {
38        match value {
39            1 => Ok(MsgPayloadVal::Message),
40            2 => Ok(MsgPayloadVal::Values),
41            _ => Err(Error::CorruptMsg),
42        }
43    }
44}
45
46// *** ByteBuffer ***
47
48struct ByteBuffer {
49    buffer: Vec<u8>,
50    idx: usize,
51}
52
53impl ByteBuffer {
54    #[inline]
55    fn new(capacity: usize) -> Self {
56        Self::from_vec(Vec::with_capacity(capacity))
57    }
58
59    #[inline]
60    fn from_vec(buffer: Vec<u8>) -> Self {
61        Self { buffer, idx: 0 }
62    }
63
64    fn read_from_stream(&mut self, stream: &mut TcpStream, size: usize) -> io::Result<()> {
65        self.buffer.resize(size, 0);
66        stream.read_exact(&mut self.buffer)?;
67        // We start over every time we read
68        self.idx = 0;
69        Ok(())
70    }
71
72    fn as_slice(&mut self, len: usize) -> Result<&[u8], Error> {
73        if self.idx + len <= self.buffer.len() {
74            self.idx += len;
75            Ok(&self.buffer[(self.idx - len)..self.idx])
76        } else {
77            Err(Error::CorruptMsg)
78        }
79    }
80
81    fn read_u8(&mut self) -> Result<u8, Error> {
82        Ok(u8::from_be_bytes(
83            self.as_slice(size_of::<u8>())?.try_into().unwrap(),
84        ))
85    }
86
87    fn read_u64(&mut self) -> Result<u64, Error> {
88        Ok(u64::from_be_bytes(
89            self.as_slice(size_of::<u64>())?.try_into().unwrap(),
90        ))
91    }
92
93    fn read_u32(&mut self) -> Result<u32, Error> {
94        Ok(u32::from_be_bytes(
95            self.as_slice(size_of::<u32>())?.try_into().unwrap(),
96        ))
97    }
98
99    fn read_str(&mut self) -> Result<String, Error> {
100        let len = self.read_u32()?;
101
102        match std::str::from_utf8(self.as_slice(len as usize)?) {
103            Ok(str) => Ok(str.to_string()),
104            Err(err) => Err(Error::BadUtf8(err)),
105        }
106    }
107}
108
109// *** MsgPayload ***
110
111/// The payload as sent by the remote program - this can either be a string message or a list
112/// of expressions and their values
113#[derive(Clone, Debug, Eq, PartialEq)]
114pub enum MsgPayload {
115    /// A formatted string
116    Message(String),
117    /// A list of name/value pairs from expressions
118    Values(Vec<(String, String)>),
119}
120
121impl MsgPayload {
122    fn from_buffer(buffer: &mut ByteBuffer) -> Result<Self, Error> {
123        match MsgPayloadVal::from_buffer(buffer)? {
124            MsgPayloadVal::Message => {
125                let s = buffer.read_str()?;
126                Ok(MsgPayload::Message(s))
127            }
128            MsgPayloadVal::Values => {
129                let len = buffer.read_u32()?;
130                // TODO: Do we need to protect against VERY large values here? We will still check
131                // bounds but not before a LOT of memory could be allocated
132                let mut values = Vec::with_capacity(len as usize);
133
134                for _ in 0..len {
135                    let name = buffer.read_str()?;
136                    let val = buffer.read_str()?;
137                    values.push((name, val));
138                }
139
140                Ok(MsgPayload::Values(values))
141            }
142        }
143    }
144}
145
146// *** Message ***
147
148/// The primary structure. Represents all the fields of debug information as received from the
149/// debugged program
150#[derive(Clone, Debug, Eq, PartialEq)]
151pub struct Message {
152    /// Milliseconds since epoch at the exact moment the debug message was triggered in the remote program
153    pub time: u64,
154    /// The thread ID that invoked the message in the remote program
155    pub thread_id: String,
156    /// The filename that invoked the message in the remote program
157    pub filename: String,
158    /// The line number at which the message was invoked in the remote program
159    pub line: u32,
160    /// The message OR expression values sent from the remote program
161    pub payload: MsgPayload,
162}
163
164impl Message {
165    fn from_buffer(buffer: &mut ByteBuffer) -> Result<Message, Error> {
166        let time = buffer.read_u64()?;
167        let thread_id = buffer.read_str()?;
168        let filename = buffer.read_str()?;
169        let line = buffer.read_u32()?;
170        let payload = MsgPayload::from_buffer(buffer)?;
171
172        Ok(Self {
173            time,
174            thread_id,
175            filename,
176            line,
177            payload,
178        })
179    }
180}
181
#[cfg(test)]
mod tests {
    use std::thread;

    use crate::{ByteBuffer, LEN_FIELD_SIZE};

    /// Stand-in timestamp: the real one is taken inside `rdbg` and cannot be predicted
    const FIXED_TIME: u64 = 42;

    /// Strips the length prefix from an `rdbg`-encoded frame and decodes the remainder, pinning
    /// the timestamp to [FIXED_TIME] so the result can be compared with an expected value.
    fn decode(frame: impl AsRef<[u8]>) -> crate::Message {
        let body = frame.as_ref()[LEN_FIELD_SIZE..].to_vec();
        let mut decoded =
            crate::Message::from_buffer(&mut ByteBuffer::from_vec(body)).expect("Corrupt message");
        decoded.time = FIXED_TIME;
        decoded
    }

    /// Thread ID of the calling thread, in the form these tests expect `rdbg` to record
    fn this_thread_id() -> String {
        format!("{:?}", thread::current().id())
    }

    #[test]
    fn deserialize_msg() {
        let filename = file!();
        let line: u32 = line!();
        let text = "message".to_string();

        let raw_msg = rdbg::Message::new(filename, line, rdbg::MsgPayload::Message(text.clone()));
        let actual_msg = decode(raw_msg.as_slice());

        let expected_msg = crate::Message {
            time: FIXED_TIME,
            thread_id: this_thread_id(),
            filename: filename.to_string(),
            line,
            payload: crate::MsgPayload::Message(text),
        };
        assert_eq!(expected_msg, actual_msg);
    }

    #[test]
    fn deserialize_vals() {
        let filename = file!();
        let line: u32 = line!();
        let values = vec![("name1", "val1".to_string()), ("name2", "val2".to_string())];

        let raw_msg = rdbg::Message::new(filename, line, rdbg::MsgPayload::Values(values.clone()));
        let actual_msg = decode(raw_msg.as_slice());

        let owned_values = values
            .into_iter()
            .map(|(name, val)| (name.to_string(), val))
            .collect();
        let expected_msg = crate::Message {
            time: FIXED_TIME,
            thread_id: this_thread_id(),
            filename: filename.to_string(),
            line,
            payload: crate::MsgPayload::Values(owned_values),
        };
        assert_eq!(expected_msg, actual_msg);
    }
}
240
241// *** Error ***
242
243/// Errors that can occur based on data received from the debugged program
244pub enum Error {
245    /// The remote debugged program is using a different version of rdbg that is incompatible
246    BadVersion,
247    /// A string in the [Message] was not valid UTF8
248    BadUtf8(Utf8Error),
249    /// The binary payload of the [Message] was corrupted and could not be decoded
250    CorruptMsg,
251}
252
253impl Debug for Error {
254    #[inline]
255    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
256        <Self as Display>::fmt(self, f)
257    }
258}
259
260impl Display for Error {
261    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
262        match self {
263            Error::BadVersion => f.write_str("This library only supports protocol version 1"),
264            Error::BadUtf8(err) => std::fmt::Display::fmt(err, f),
265            Error::CorruptMsg => f.write_str("The message payload was corrupted"),
266        }
267    }
268}
269
270impl std::error::Error for Error {}
271
272// *** Event ***
273
274/// This represents various events that occur during iteration and are returned by [MsgIterator]
275pub enum Event {
276    /// Returned when attached to debugged program
277    Connected(SocketAddr),
278    /// Returned when loses connection to debugged program
279    Disconnected(SocketAddr),
280    /// Returned when a new message from the debugged program arrives
281    Message(Message),
282}
283
284// *** MsgIterator ***
285
286/// An iterator that returns [Event]s based on a connection to the debugged program. The primary
287/// objective is to receive [Message]s
288///
289/// This iterator never completes (so [Option] is never `None`). If a disconnect occurs, it will
290/// simply wait for a new connection and then continue returning messages.
291///
292/// This iterator returns a [Result] with either the next [Event] or an [Error]. Errors are not
293/// fatal and the user and handle (or not handle) as they see fit.
294pub struct MsgIterator {
295    addr: SocketAddr,
296    stream: Option<TcpStream>,
297    buffer: ByteBuffer,
298}
299
300impl MsgIterator {
301    /// Create a new message iterator to a custom destination IP and port
302    #[inline]
303    pub fn new(ip: &str, port: u16) -> Result<Self, AddrParseError> {
304        Ok(Self {
305            addr: SocketAddr::new(IpAddr::from_str(ip)?, port),
306            stream: None,
307            buffer: ByteBuffer::new(BUFFER_SIZE),
308        })
309    }
310}
311
312impl Default for MsgIterator {
313    #[inline]
314    fn default() -> Self {
315        Self::new(DEFAULT_ADDR, DEFAULT_PORT).unwrap()
316    }
317}
318
319impl Iterator for MsgIterator {
320    type Item = Result<Event, Error>;
321
322    fn next(&mut self) -> Option<Self::Item> {
323        match &mut self.stream {
324            Some(stream) => match self.buffer.read_from_stream(stream, LEN_FIELD_SIZE) {
325                Ok(_) => {
326                    // We know this is long enough - guaranteed by read above
327                    let len = self.buffer.read_u32().unwrap();
328
329                    match self
330                        .buffer
331                        .read_from_stream(stream, len as usize - LEN_FIELD_SIZE)
332                    {
333                        Ok(_) => match Message::from_buffer(&mut self.buffer) {
334                            Ok(msg) => Some(Ok(Event::Message(msg))),
335                            Err(err) => {
336                                self.stream = None;
337                                Some(Err(err))
338                            }
339                        },
340                        Err(_) => {
341                            self.stream = None;
342                            Some(Ok(Event::Disconnected(self.addr)))
343                        }
344                    }
345                }
346                Err(_) => {
347                    self.stream = None;
348                    Some(Ok(Event::Disconnected(self.addr)))
349                }
350            },
351            None => loop {
352                if let Ok(mut stream) = TcpStream::connect(self.addr) {
353                    match self.buffer.read_from_stream(&mut stream, size_of::<u8>()) {
354                        // We know this is long enough - guaranteed by read above
355                        Ok(_) if self.buffer.read_u8().unwrap() == WIRE_PROTOCOL_VERSION => {
356                            self.stream = Some(stream);
357                            return Some(Ok(Event::Connected(self.addr)));
358                        }
359                        Ok(_) => return Some(Err(Error::BadVersion)),
360                        Err(_) => {
361                            // No op
362                        }
363                    }
364                }
365
366                thread::sleep(Duration::from_millis(CONNECT_WAIT_TIME));
367            },
368        }
369    }
370}