golang_ipc_rs/
message.rs

1use crate::{connection::Connection, encryption, Context};
2use anyhow::{anyhow, bail};
3use bytes::{Bytes, BytesMut};
4use std::fmt;
5use tokio::io::{AsyncReadExt, AsyncWriteExt};
6use tracing::*;
7use std::sync::Arc;
8
9pub(crate) type Channel = crate::bichannel::Channel<Message, Message>;
10
11#[derive(Debug, Clone)]
12pub struct Message {
13    msg_type: Type,
14    data: Option<RawData>,
15}
16
17#[derive(Debug, Clone)]
18pub enum Type {
19    Control,
20    StatusChange(Status),
21    Error(Arc<str>),
22    General(u16),
23}
24
25#[derive(Debug, Clone)]
26struct RawData {
27    nonce: Option<BytesMut>,
28    msg_type_raw: BytesMut,
29    data: BytesMut,
30}
31
/// Connection status values, carried by `Type::StatusChange` messages.
///
/// `NotConnected` is the [`Default`].
// NOTE(review): `dead_code` is allowed presumably because not every variant is
// constructed inside this crate (the set likely mirrors the Go library's
// statuses) — confirm the allow is still needed.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub enum Status {
    #[default]
    NotConnected,
    Listening,
    Connecting,
    Connected,
    ReConnecting,
    Closed,
    Closing,
    Error,
    Timeout,
}
46
/// Size in bytes of the big-endian `i32` message-type prefix of every frame body.
const I32_SIZE: usize = (i32::BITS / u8::BITS) as usize;
48
49impl Message {
50    pub(crate) fn new(msg_type_in: i32, data_in: Bytes, encryption_padding: bool) -> Message {
51        // Create a BytesMut that can hold all the data
52        //  This accounts for encryption overhead and the msg_type
53        let data_len = data_in.len();
54
55        let (mut data, nonce) = if encryption_padding {
56            let (data, nonce) = encryption::new_data_buffer(I32_SIZE + data_len);
57            (data, Some(nonce))
58        } else {
59            // Reserve for just the message
60            (BytesMut::with_capacity(I32_SIZE + data_len), None)
61        };
62
63        // Write the message type, and split
64        data.extend(msg_type_in.to_be_bytes());
65        let msg_type_raw = data.split();
66
67        // Write the data
68        data.extend(data_in);
69
70        Message {
71            msg_type: msg_type_in.into(),
72            data: Some(RawData {
73                nonce,
74                msg_type_raw,
75                data,
76            }),
77        }
78    }
79
80    pub fn from_status(status: Status) -> Message {
81        Message {
82            msg_type: Type::StatusChange(status),
83            data: None,
84        }
85    }
86
87    pub fn from_error(err: crate::Error) -> Message {
88        Message {
89            msg_type: Type::Error(err.to_string().into()),
90            data: None,
91        }
92    }
93
94    pub fn is_data_message(&self) -> bool {
95        matches!(self.msg_type, Type::General { .. })
96    }
97
98    pub fn data_length(&self) -> usize {
99        if let Some(RawData { data, .. }) = &self.data {
100            data.len()
101        } else {
102            0
103        }
104    }
105
106    pub fn data(&self) -> Option<&[u8]> {
107        if let Some(RawData { data, .. }) = &self.data {
108            Some(data.as_ref())
109        } else {
110            None
111        }
112    }
113
114    pub fn msg_type(&self) -> Type {
115        self.msg_type.clone()
116    }
117
118    // pub fn stauts(&self) -> Type
119
120    #[instrument(level = "trace")]
121    pub(crate) async fn from_stream(connection: &mut Connection) -> crate::Result<Option<Message>> {
122        let data_len = connection.reader.read_u32().await? as usize;
123
124        if data_len < 4 {
125            bail!("Not enough data for a message");
126        }
127
128        // let msg_type = connection.reader.read_i32().await?;
129
130        // Allocate a buffer large enough for the data.
131        let mut data = BytesMut::with_capacity(data_len);
132
133        if 0 == connection.reader.read_buf(&mut data).await? {
134            // The remote closed the connection. For this to be a clean
135            // shutdown, there should be no data in the read buffer. If
136            // there is, this means that the peer closed the socket while
137            // sending a message.
138            if data.is_empty() {
139                return Ok(None);
140            } else {
141                bail!("connection reset by peer");
142            }
143        }
144
145        // Decrypt the data
146        let (mut data, nonce) = encryption::maybe_decrypt(&connection.cipher, data)?;
147
148        // Split and decode the message type
149        let msg_type_raw = data.split_to(I32_SIZE);
150        let msg_type = i32::from_be_bytes(msg_type_raw.as_ref().try_into().unwrap()).into();
151
152        let message = Message {
153            msg_type,
154            data: Some(RawData {
155                nonce,
156                msg_type_raw,
157                data,
158            }),
159        };
160        info!("recv message {}", message);
161
162        Ok(Some(message))
163    }
164
165    #[instrument(level = "trace")]
166    pub(crate) async fn to_stream(
167        connection: &mut Connection,
168        message: Message,
169    ) -> crate::Result<()> {
170        info!("send message {}", message);
171
172        if let Some(data) = message.data {
173            // Extract raw data objects
174            let RawData {
175                nonce,
176                mut msg_type_raw,
177                data,
178                ..
179            } = data;
180
181            // Unsplit the buffer
182            //  Note, this looks backwards but otherwise the data is in the wrong order
183            msg_type_raw.unsplit(data);
184            let data = msg_type_raw;
185
186            // Encrypt the data
187            let data = encryption::maybe_encrypt(&connection.cipher, data, nonce)?;
188
189            connection.writer.write_u32(data.len() as u32).await?;
190            connection.writer.write_all(data.as_ref()).await?;
191        } else {
192            bail!("Internal messages can not send.");
193        }
194
195        Ok(())
196    }
197}
198
199pub(crate) async fn message_loop(
200    mut connection: Connection,
201    context: &mut Context,
202) -> crate::Result<()> {
203    loop {
204        tokio::select! {
205            result = Message::from_stream(&mut connection) => {
206                match result {
207                    Ok(Some(message)) => {
208                        // We received a message over the socket, so send it to the consumer
209                        if let Err(_err) = context.get_channel().send(message).await {
210                            context.report_error_status(anyhow!("Dropped message")).await;
211                        }
212                    },
213                    Ok(None) => {
214                        context.report_error_status(anyhow!("Remote disconnect")).await;
215                        context.report_status(Status::Closing).await?;
216                        break;
217                    },
218                    Err(err) =>  {
219                        let new_status = match err.downcast_ref::<std::io::Error>() {
220                            Some(io_err) => {
221                                if io_err.kind() == std::io::ErrorKind::UnexpectedEof {
222                                    Status::ReConnecting
223                                } else {
224                                    Status::Closing
225                                }
226                            },
227                            None => {
228                                Status::Closing
229                            },
230                        };
231                        context.report_error_status(err).await;
232                        context.report_status(new_status).await?;
233                        break;
234                    },
235                }
236
237            },
238            to_send = context.get_channel().recv() => {
239                match to_send {
240                    Some(message) => {
241
242                        if let Err(err) = Message::to_stream(&mut connection, message).await {
243                            // TODO: The message is dropped in this case.  Can we put it back into `context.get_channel()`?
244                            context.report_error_status(err).await;
245                        }
246                    },
247                    None => {
248                        context.report_error_status(anyhow!("bad recv")).await;
249                    },
250                }
251            },
252        };
253    }
254
255    Ok(())
256}
257
258impl From<i32> for Type {
259    fn from(value: i32) -> Type {
260        match value {
261            0 => Type::Control,
262            -1 => Type::StatusChange(Status::Error), // not expected
263            -2 => Type::Error("Unknown Error".into()), // not expected
264            _ => {
265                if value > 0 {
266                    Type::General(value as u16)
267                } else {
268                    // TODO: how does Go version handle this
269                    panic!("Unexpected mess_type {}", value)
270                }
271            }
272        }
273    }
274}
275
276impl From<Type> for i32 {
277    fn from(value: Type) -> i32 {
278        match value {
279            Type::Control => 0,
280            Type::StatusChange(_) => -1,
281            Type::Error(_) => -2,
282            Type::General(value) => value as i32,
283        }
284    }
285}
286
287impl fmt::Display for Message {
288    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
289        match &self.msg_type {
290            Type::StatusChange(status) => write!(f, "{}: {}", self.msg_type, status),
291            Type::Control => write!(f, "{}", self.msg_type),
292            Type::Error(err) => write!(f, "{}: {}", self.msg_type, err),
293            Type::General(_) => {
294                write!(
295                    f,
296                    "{}: [{}] {:?}",
297                    self.msg_type,
298                    self.data_length(),
299                    self.data
300                )
301            }
302        }
303    }
304}
305
306impl fmt::Display for Type {
307    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
308        write!(f, "{:?}", self)
309    }
310}
311
312impl fmt::Display for Status {
313    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
314        write!(f, "{:?}", self)
315    }
316}