1use crate::{connection::Connection, encryption, Context};
2use anyhow::{anyhow, bail};
3use bytes::{Bytes, BytesMut};
4use std::fmt;
5use tokio::io::{AsyncReadExt, AsyncWriteExt};
6use tracing::*;
7use std::sync::Arc;
8
9pub(crate) type Channel = crate::bichannel::Channel<Message, Message>;
10
11#[derive(Debug, Clone)]
12pub struct Message {
13 msg_type: Type,
14 data: Option<RawData>,
15}
16
17#[derive(Debug, Clone)]
18pub enum Type {
19 Control,
20 StatusChange(Status),
21 Error(Arc<str>),
22 General(u16),
23}
24
25#[derive(Debug, Clone)]
26struct RawData {
27 nonce: Option<BytesMut>,
28 msg_type_raw: BytesMut,
29 data: BytesMut,
30}
31
/// Lifecycle state of a connection, reported via `Context::report_status` and
/// carried inside [`Type::StatusChange`] messages.
///
/// The [`Default`] value is `NotConnected`.
// NOTE(review): `dead_code` is allowed presumably because not every variant is
// constructed inside this crate — confirm before removing the attribute.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub enum Status {
    /// Initial state (the `#[default]` variant).
    #[default]
    NotConnected,
    Listening,
    Connecting,
    Connected,
    ReConnecting,
    Closed,
    Closing,
    Error,
    Timeout,
}
46
/// Byte width of the big-endian `i32` type header at the front of each
/// (decrypted) frame.
const I32_SIZE: usize = (i32::BITS / 8) as usize;
48
49impl Message {
50 pub(crate) fn new(msg_type_in: i32, data_in: Bytes, encryption_padding: bool) -> Message {
51 let data_len = data_in.len();
54
55 let (mut data, nonce) = if encryption_padding {
56 let (data, nonce) = encryption::new_data_buffer(I32_SIZE + data_len);
57 (data, Some(nonce))
58 } else {
59 (BytesMut::with_capacity(I32_SIZE + data_len), None)
61 };
62
63 data.extend(msg_type_in.to_be_bytes());
65 let msg_type_raw = data.split();
66
67 data.extend(data_in);
69
70 Message {
71 msg_type: msg_type_in.into(),
72 data: Some(RawData {
73 nonce,
74 msg_type_raw,
75 data,
76 }),
77 }
78 }
79
80 pub fn from_status(status: Status) -> Message {
81 Message {
82 msg_type: Type::StatusChange(status),
83 data: None,
84 }
85 }
86
87 pub fn from_error(err: crate::Error) -> Message {
88 Message {
89 msg_type: Type::Error(err.to_string().into()),
90 data: None,
91 }
92 }
93
94 pub fn is_data_message(&self) -> bool {
95 matches!(self.msg_type, Type::General { .. })
96 }
97
98 pub fn data_length(&self) -> usize {
99 if let Some(RawData { data, .. }) = &self.data {
100 data.len()
101 } else {
102 0
103 }
104 }
105
106 pub fn data(&self) -> Option<&[u8]> {
107 if let Some(RawData { data, .. }) = &self.data {
108 Some(data.as_ref())
109 } else {
110 None
111 }
112 }
113
114 pub fn msg_type(&self) -> Type {
115 self.msg_type.clone()
116 }
117
118 #[instrument(level = "trace")]
121 pub(crate) async fn from_stream(connection: &mut Connection) -> crate::Result<Option<Message>> {
122 let data_len = connection.reader.read_u32().await? as usize;
123
124 if data_len < 4 {
125 bail!("Not enough data for a message");
126 }
127
128 let mut data = BytesMut::with_capacity(data_len);
132
133 if 0 == connection.reader.read_buf(&mut data).await? {
134 if data.is_empty() {
139 return Ok(None);
140 } else {
141 bail!("connection reset by peer");
142 }
143 }
144
145 let (mut data, nonce) = encryption::maybe_decrypt(&connection.cipher, data)?;
147
148 let msg_type_raw = data.split_to(I32_SIZE);
150 let msg_type = i32::from_be_bytes(msg_type_raw.as_ref().try_into().unwrap()).into();
151
152 let message = Message {
153 msg_type,
154 data: Some(RawData {
155 nonce,
156 msg_type_raw,
157 data,
158 }),
159 };
160 info!("recv message {}", message);
161
162 Ok(Some(message))
163 }
164
165 #[instrument(level = "trace")]
166 pub(crate) async fn to_stream(
167 connection: &mut Connection,
168 message: Message,
169 ) -> crate::Result<()> {
170 info!("send message {}", message);
171
172 if let Some(data) = message.data {
173 let RawData {
175 nonce,
176 mut msg_type_raw,
177 data,
178 ..
179 } = data;
180
181 msg_type_raw.unsplit(data);
184 let data = msg_type_raw;
185
186 let data = encryption::maybe_encrypt(&connection.cipher, data, nonce)?;
188
189 connection.writer.write_u32(data.len() as u32).await?;
190 connection.writer.write_all(data.as_ref()).await?;
191 } else {
192 bail!("Internal messages can not send.");
193 }
194
195 Ok(())
196 }
197}
198
199pub(crate) async fn message_loop(
200 mut connection: Connection,
201 context: &mut Context,
202) -> crate::Result<()> {
203 loop {
204 tokio::select! {
205 result = Message::from_stream(&mut connection) => {
206 match result {
207 Ok(Some(message)) => {
208 if let Err(_err) = context.get_channel().send(message).await {
210 context.report_error_status(anyhow!("Dropped message")).await;
211 }
212 },
213 Ok(None) => {
214 context.report_error_status(anyhow!("Remote disconnect")).await;
215 context.report_status(Status::Closing).await?;
216 break;
217 },
218 Err(err) => {
219 let new_status = match err.downcast_ref::<std::io::Error>() {
220 Some(io_err) => {
221 if io_err.kind() == std::io::ErrorKind::UnexpectedEof {
222 Status::ReConnecting
223 } else {
224 Status::Closing
225 }
226 },
227 None => {
228 Status::Closing
229 },
230 };
231 context.report_error_status(err).await;
232 context.report_status(new_status).await?;
233 break;
234 },
235 }
236
237 },
238 to_send = context.get_channel().recv() => {
239 match to_send {
240 Some(message) => {
241
242 if let Err(err) = Message::to_stream(&mut connection, message).await {
243 context.report_error_status(err).await;
245 }
246 },
247 None => {
248 context.report_error_status(anyhow!("bad recv")).await;
249 },
250 }
251 },
252 };
253 }
254
255 Ok(())
256}
257
258impl From<i32> for Type {
259 fn from(value: i32) -> Type {
260 match value {
261 0 => Type::Control,
262 -1 => Type::StatusChange(Status::Error), -2 => Type::Error("Unknown Error".into()), _ => {
265 if value > 0 {
266 Type::General(value as u16)
267 } else {
268 panic!("Unexpected mess_type {}", value)
270 }
271 }
272 }
273 }
274}
275
276impl From<Type> for i32 {
277 fn from(value: Type) -> i32 {
278 match value {
279 Type::Control => 0,
280 Type::StatusChange(_) => -1,
281 Type::Error(_) => -2,
282 Type::General(value) => value as i32,
283 }
284 }
285}
286
287impl fmt::Display for Message {
288 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
289 match &self.msg_type {
290 Type::StatusChange(status) => write!(f, "{}: {}", self.msg_type, status),
291 Type::Control => write!(f, "{}", self.msg_type),
292 Type::Error(err) => write!(f, "{}: {}", self.msg_type, err),
293 Type::General(_) => {
294 write!(
295 f,
296 "{}: [{}] {:?}",
297 self.msg_type,
298 self.data_length(),
299 self.data
300 )
301 }
302 }
303 }
304}
305
306impl fmt::Display for Type {
307 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
308 write!(f, "{:?}", self)
309 }
310}
311
312impl fmt::Display for Status {
313 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
314 write!(f, "{:?}", self)
315 }
316}