1use std::fmt::{Debug, Display, Formatter};
2use std::io::Read;
3use std::mem::size_of;
4use std::net::{AddrParseError, IpAddr, SocketAddr, TcpStream};
5use std::str::{FromStr, Utf8Error};
6use std::time::Duration;
7use std::{io, thread};
8
/// IP address used by [`MsgIterator::default`].
pub const DEFAULT_ADDR: &str = "127.0.0.1";
/// TCP port used by [`MsgIterator::default`].
pub const DEFAULT_PORT: u16 = 13579;

/// Pause, in milliseconds, between connection attempts while no server is reachable.
const CONNECT_WAIT_TIME: u64 = 250;
/// Initial capacity, in bytes, of the receive buffer.
const BUFFER_SIZE: usize = 4 * 1024;
/// Width of the big-endian `u32` length prefix that starts every frame (it counts itself).
const LEN_FIELD_SIZE: usize = size_of::<u32>();
/// Version byte the server must send immediately after accepting a connection.
const WIRE_PROTOCOL_VERSION: u8 = 1;
18
/// One-byte wire tag that selects which [`MsgPayload`] variant follows in a message body.
///
/// The discriminants are the on-the-wire values.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
enum MsgPayloadVal {
    /// Followed by a single length-prefixed string.
    Message = 1,
    /// Followed by a `u32` count of (name, value) length-prefixed string pairs.
    Values = 2,
}
26
27impl MsgPayloadVal {
28 #[inline]
29 fn from_buffer(buffer: &mut ByteBuffer) -> Result<MsgPayloadVal, Error> {
30 buffer.read_u8()?.try_into()
31 }
32}
33
34impl TryFrom<u8> for MsgPayloadVal {
35 type Error = Error;
36
37 fn try_from(value: u8) -> Result<Self, Self::Error> {
38 match value {
39 1 => Ok(MsgPayloadVal::Message),
40 2 => Ok(MsgPayloadVal::Values),
41 _ => Err(Error::CorruptMsg),
42 }
43 }
44}
45
/// Owned bytes plus a read cursor, used to decode frames received from the server.
///
/// Every successful read advances `idx` past the bytes it consumed.
struct ByteBuffer {
    /// Bytes of the frame currently being decoded.
    buffer: Vec<u8>,
    /// Offset of the next unread byte in `buffer`.
    idx: usize,
}
52
53impl ByteBuffer {
54 #[inline]
55 fn new(capacity: usize) -> Self {
56 Self::from_vec(Vec::with_capacity(capacity))
57 }
58
59 #[inline]
60 fn from_vec(buffer: Vec<u8>) -> Self {
61 Self { buffer, idx: 0 }
62 }
63
64 fn read_from_stream(&mut self, stream: &mut TcpStream, size: usize) -> io::Result<()> {
65 self.buffer.resize(size, 0);
66 stream.read_exact(&mut self.buffer)?;
67 self.idx = 0;
69 Ok(())
70 }
71
72 fn as_slice(&mut self, len: usize) -> Result<&[u8], Error> {
73 if self.idx + len <= self.buffer.len() {
74 self.idx += len;
75 Ok(&self.buffer[(self.idx - len)..self.idx])
76 } else {
77 Err(Error::CorruptMsg)
78 }
79 }
80
81 fn read_u8(&mut self) -> Result<u8, Error> {
82 Ok(u8::from_be_bytes(
83 self.as_slice(size_of::<u8>())?.try_into().unwrap(),
84 ))
85 }
86
87 fn read_u64(&mut self) -> Result<u64, Error> {
88 Ok(u64::from_be_bytes(
89 self.as_slice(size_of::<u64>())?.try_into().unwrap(),
90 ))
91 }
92
93 fn read_u32(&mut self) -> Result<u32, Error> {
94 Ok(u32::from_be_bytes(
95 self.as_slice(size_of::<u32>())?.try_into().unwrap(),
96 ))
97 }
98
99 fn read_str(&mut self) -> Result<String, Error> {
100 let len = self.read_u32()?;
101
102 match std::str::from_utf8(self.as_slice(len as usize)?) {
103 Ok(str) => Ok(str.to_string()),
104 Err(err) => Err(Error::BadUtf8(err)),
105 }
106 }
107}
108
/// Body of a [`Message`], selected on the wire by a one-byte tag.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MsgPayload {
    /// A single text string.
    Message(String),
    /// `(name, value)` string pairs, in the order they were received.
    Values(Vec<(String, String)>),
}
120
121impl MsgPayload {
122 fn from_buffer(buffer: &mut ByteBuffer) -> Result<Self, Error> {
123 match MsgPayloadVal::from_buffer(buffer)? {
124 MsgPayloadVal::Message => {
125 let s = buffer.read_str()?;
126 Ok(MsgPayload::Message(s))
127 }
128 MsgPayloadVal::Values => {
129 let len = buffer.read_u32()?;
130 let mut values = Vec::with_capacity(len as usize);
133
134 for _ in 0..len {
135 let name = buffer.read_str()?;
136 let val = buffer.read_str()?;
137 values.push((name, val));
138 }
139
140 Ok(MsgPayload::Values(values))
141 }
142 }
143 }
144}
145
146#[derive(Clone, Debug, Eq, PartialEq)]
151pub struct Message {
152 pub time: u64,
154 pub thread_id: String,
156 pub filename: String,
158 pub line: u32,
160 pub payload: MsgPayload,
162}
163
164impl Message {
165 fn from_buffer(buffer: &mut ByteBuffer) -> Result<Message, Error> {
166 let time = buffer.read_u64()?;
167 let thread_id = buffer.read_str()?;
168 let filename = buffer.read_str()?;
169 let line = buffer.read_u32()?;
170 let payload = MsgPayload::from_buffer(buffer)?;
171
172 Ok(Self {
173 time,
174 thread_id,
175 filename,
176 line,
177 payload,
178 })
179 }
180}
181
// Round-trip tests: each test builds a message with the `rdbg` crate's serializer and checks
// that this crate decodes it back into the same fields.
// NOTE(review): `rdbg` is not declared in this file; presumably it is a dev-dependency.
#[cfg(test)]
mod tests {
    use std::thread;

    use crate::{ByteBuffer, LEN_FIELD_SIZE};

    // A `Message` payload survives encode -> decode unchanged.
    #[test]
    fn deserialize_msg() {
        let filename = file!();
        let line: u32 = line!();
        let message = "message".to_string();

        let raw_msg =
            rdbg::Message::new(filename, line, rdbg::MsgPayload::Message(message.clone()));

        // `time: 42` is a placeholder. The decoded timestamp is overwritten below because its
        // value presumably comes from `rdbg::Message::new` and cannot be predicted here.
        let expected_msg = crate::Message {
            time: 42,
            thread_id: format!("{:?}", thread::current().id()),
            filename: filename.to_string(),
            line,
            payload: crate::MsgPayload::Message(message),
        };
        // `Message::from_buffer` decodes only the body, so skip the leading length prefix.
        let mut buffer = ByteBuffer::from_vec(raw_msg.as_slice()[LEN_FIELD_SIZE..].to_vec());
        let mut actual_msg = crate::Message::from_buffer(&mut buffer).expect("Corrupt message");

        actual_msg.time = expected_msg.time;
        assert_eq!(expected_msg, actual_msg);
    }

    // A `Values` payload survives encode -> decode unchanged. `rdbg` takes `&str` names,
    // which this crate decodes as owned `String`s.
    #[test]
    fn deserialize_vals() {
        let filename = file!();
        let line: u32 = line!();
        let values = vec![("name1", "val1".to_string()), ("name2", "val2".to_string())];

        let raw_msg = rdbg::Message::new(filename, line, rdbg::MsgPayload::Values(values.clone()));

        // `time: 42` is a placeholder; see `deserialize_msg`.
        let expected_msg = crate::Message {
            time: 42,
            thread_id: format!("{:?}", thread::current().id()),
            filename: filename.to_string(),
            line,
            payload: crate::MsgPayload::Values(
                values
                    .into_iter()
                    .map(|(k, v)| (k.to_string(), v))
                    .collect(),
            ),
        };
        // Skip the length prefix, as `Message::from_buffer` decodes only the body.
        let mut buffer = ByteBuffer::from_vec(raw_msg.as_slice()[LEN_FIELD_SIZE..].to_vec());
        let mut actual_msg = crate::Message::from_buffer(&mut buffer).expect("Corrupt message");

        actual_msg.time = expected_msg.time;
        assert_eq!(expected_msg, actual_msg);
    }
}
240
/// Errors produced while connecting to the server or decoding its messages.
pub enum Error {
    /// The server announced a wire-protocol version this crate does not speak.
    BadVersion,
    /// A string field in a message was not valid UTF-8.
    BadUtf8(Utf8Error),
    /// A message was truncated or carried an unrecognized payload tag.
    CorruptMsg,
}
252
253impl Debug for Error {
254 #[inline]
255 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
256 <Self as Display>::fmt(self, f)
257 }
258}
259
260impl Display for Error {
261 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
262 match self {
263 Error::BadVersion => f.write_str("This library only supports protocol version 1"),
264 Error::BadUtf8(err) => std::fmt::Display::fmt(err, f),
265 Error::CorruptMsg => f.write_str("The message payload was corrupted"),
266 }
267 }
268}
269
270impl std::error::Error for Error {}
271
272pub enum Event {
276 Connected(SocketAddr),
278 Disconnected(SocketAddr),
280 Message(Message),
282}
283
284pub struct MsgIterator {
295 addr: SocketAddr,
296 stream: Option<TcpStream>,
297 buffer: ByteBuffer,
298}
299
300impl MsgIterator {
301 #[inline]
303 pub fn new(ip: &str, port: u16) -> Result<Self, AddrParseError> {
304 Ok(Self {
305 addr: SocketAddr::new(IpAddr::from_str(ip)?, port),
306 stream: None,
307 buffer: ByteBuffer::new(BUFFER_SIZE),
308 })
309 }
310}
311
312impl Default for MsgIterator {
313 #[inline]
314 fn default() -> Self {
315 Self::new(DEFAULT_ADDR, DEFAULT_PORT).unwrap()
316 }
317}
318
319impl Iterator for MsgIterator {
320 type Item = Result<Event, Error>;
321
322 fn next(&mut self) -> Option<Self::Item> {
323 match &mut self.stream {
324 Some(stream) => match self.buffer.read_from_stream(stream, LEN_FIELD_SIZE) {
325 Ok(_) => {
326 let len = self.buffer.read_u32().unwrap();
328
329 match self
330 .buffer
331 .read_from_stream(stream, len as usize - LEN_FIELD_SIZE)
332 {
333 Ok(_) => match Message::from_buffer(&mut self.buffer) {
334 Ok(msg) => Some(Ok(Event::Message(msg))),
335 Err(err) => {
336 self.stream = None;
337 Some(Err(err))
338 }
339 },
340 Err(_) => {
341 self.stream = None;
342 Some(Ok(Event::Disconnected(self.addr)))
343 }
344 }
345 }
346 Err(_) => {
347 self.stream = None;
348 Some(Ok(Event::Disconnected(self.addr)))
349 }
350 },
351 None => loop {
352 if let Ok(mut stream) = TcpStream::connect(self.addr) {
353 match self.buffer.read_from_stream(&mut stream, size_of::<u8>()) {
354 Ok(_) if self.buffer.read_u8().unwrap() == WIRE_PROTOCOL_VERSION => {
356 self.stream = Some(stream);
357 return Some(Ok(Event::Connected(self.addr)));
358 }
359 Ok(_) => return Some(Err(Error::BadVersion)),
360 Err(_) => {
361 }
363 }
364 }
365
366 thread::sleep(Duration::from_millis(CONNECT_WAIT_TIME));
367 },
368 }
369 }
370}