1use std::{
48 io::{Error, ErrorKind},
49 net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
50};
51
52use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
53
/// Encoding half of the wire format: a value that can serialize itself onto
/// an async byte stream.
// `async_fn_in_trait` is allowed deliberately: an `async fn` in a public trait
// gives callers no way to require that the returned future is `Send`.
// NOTE(review): confirm no caller needs to spawn these futures on a
// multi-threaded runtime with a `Send` bound.
#[allow(async_fn_in_trait)]
pub trait ByteWrite {
    /// Encodes `self` into `writer`, propagating any I/O or encoding error.
    async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error>;
}

/// Decoding half of the wire format: constructs a value by consuming bytes
/// from an async byte stream. `Sized` is required because `read` returns
/// `Self` by value.
// Same rationale for the lint allowance as on `ByteWrite`.
#[allow(async_fn_in_trait)]
pub trait ByteRead: Sized {
    /// Decodes one value from `reader`, returning an error on I/O failure or
    /// malformed input.
    async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error>;
}
67
68impl ByteWrite for () {
69 async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, _: &mut W) -> Result<(), Error> {
70 Ok(())
71 }
72}
73
74impl ByteRead for () {
75 async fn read<R: AsyncRead + Unpin + ?Sized>(_: &mut R) -> Result<Self, Error> {
76 Ok(())
77 }
78}
79
80impl ByteWrite for bool {
81 async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
82 writer.write_u8(*self as u8).await
83 }
84}
85
86impl ByteRead for bool {
87 async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
88 Ok(reader.read_u8().await? != 0)
89 }
90}
91
92impl ByteWrite for u8 {
93 async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
94 writer.write_u8(*self).await
95 }
96}
97
98impl ByteRead for u8 {
99 async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
100 reader.read_u8().await
101 }
102}
103
104impl ByteWrite for u16 {
105 async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
106 writer.write_u16(*self).await
107 }
108}
109
110impl ByteRead for u16 {
111 async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
112 reader.read_u16().await
113 }
114}
115
116impl ByteWrite for u32 {
117 async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
118 writer.write_u32(*self).await
119 }
120}
121
122impl ByteRead for u32 {
123 async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
124 reader.read_u32().await
125 }
126}
127
128impl ByteWrite for u64 {
129 async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
130 writer.write_u64(*self).await
131 }
132}
133
134impl ByteRead for u64 {
135 async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
136 reader.read_u64().await
137 }
138}
139
140impl ByteWrite for i64 {
141 async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
142 writer.write_i64(*self).await
143 }
144}
145
146impl ByteRead for i64 {
147 async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
148 reader.read_i64().await
149 }
150}
151
152impl ByteWrite for char {
153 async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
154 let mut buf = [0u8; 4];
155 let s = self.encode_utf8(&mut buf);
156 writer.write_all(s.as_bytes()).await
157 }
158}
159
160impl ByteRead for char {
161 async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
162 let mut buf = [0u8; 4];
163 let mut byte_count = 0;
164 loop {
165 reader.read_exact(&mut buf[byte_count..(byte_count + 1)]).await?;
166 byte_count += 1;
167 if let Ok(s) = std::str::from_utf8(&buf[0..byte_count]) {
168 return Ok(s.chars().next().unwrap());
169 }
170
171 if byte_count == 4 {
172 return Err(Error::new(ErrorKind::InvalidData, "char is not valid UTF-8"));
173 }
174 }
175 }
176}
177
178impl ByteWrite for Ipv4Addr {
179 async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
180 writer.write_all(&self.octets()).await
181 }
182}
183
184impl ByteRead for Ipv4Addr {
185 async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
186 let mut octets = [0u8; 4];
187 reader.read_exact(&mut octets).await?;
188 Ok(octets.into())
189 }
190}
191
192impl ByteWrite for Ipv6Addr {
193 async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
194 writer.write_all(&self.octets()).await
195 }
196}
197
198impl ByteRead for Ipv6Addr {
199 async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
200 let mut octets = [0u8; 16];
201 reader.read_exact(&mut octets).await?;
202
203 Ok(octets.into())
204 }
205}
206
207impl ByteWrite for SocketAddrV4 {
208 async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
209 self.ip().write(writer).await?;
210 writer.write_u16(self.port()).await
211 }
212}
213
214impl ByteRead for SocketAddrV4 {
215 async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
216 let mut octets = [0u8; 4];
217 reader.read_exact(&mut octets).await?;
218 let port = reader.read_u16().await?;
219
220 Ok(SocketAddrV4::new(octets.into(), port))
221 }
222}
223
224impl ByteWrite for SocketAddrV6 {
225 async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
226 self.ip().write(writer).await?;
227 writer.write_u16(self.port()).await?;
228 writer.write_u32(self.flowinfo()).await?;
229 writer.write_u32(self.scope_id()).await
230 }
231}
232
233impl ByteRead for SocketAddrV6 {
234 async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
235 let mut octets = [0u8; 16];
236 reader.read_exact(&mut octets).await?;
237 let port = reader.read_u16().await?;
238 let flowinfo = reader.read_u32().await?;
239 let scope_id = reader.read_u32().await?;
240
241 Ok(SocketAddrV6::new(octets.into(), port, flowinfo, scope_id))
242 }
243}
244
245impl ByteWrite for SocketAddr {
246 async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
247 match self {
248 SocketAddr::V4(v4) => {
249 writer.write_u8(4).await?;
250 v4.write(writer).await
251 }
252 SocketAddr::V6(v6) => {
253 writer.write_u8(6).await?;
254 v6.write(writer).await
255 }
256 }
257 }
258}
259
260impl ByteRead for SocketAddr {
261 async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
262 let addr_type = reader.read_u8().await?;
263 match addr_type {
264 4 => Ok(SocketAddr::V4(SocketAddrV4::read(reader).await?)),
265 6 => Ok(SocketAddr::V6(SocketAddrV6::read(reader).await?)),
266 v => Err(Error::new(ErrorKind::InvalidData, format!("Invalid socket address type, {v}"))),
267 }
268 }
269}
270
271impl<T: ByteWrite> ByteWrite for Option<T> {
272 async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
273 match self {
274 Some(value) => {
275 writer.write_u8(1).await?;
276 value.write(writer).await
277 }
278 None => writer.write_u8(0).await,
279 }
280 }
281}
282
283impl<T: ByteRead> ByteRead for Option<T> {
284 async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
285 let has_value = reader.read_u8().await?;
286 match has_value {
287 0 => Ok(None),
288 _ => Ok(Some(T::read(reader).await?)),
289 }
290 }
291}
292
293impl<T: ByteWrite, E: ByteWrite> ByteWrite for Result<T, E> {
294 async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
295 match self {
296 Ok(v) => {
297 writer.write_u8(1).await?;
298 v.write(writer).await
299 }
300 Err(e) => {
301 writer.write_u8(0).await?;
302 e.write(writer).await
303 }
304 }
305 }
306}
307
308impl<T: ByteRead, E: ByteRead> ByteRead for Result<T, E> {
309 async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
310 match reader.read_u8().await? {
311 0 => Ok(Err(E::read(reader).await?)),
312 _ => Ok(Ok(T::read(reader).await?)),
313 }
314 }
315}
316
317impl ByteWrite for Error {
318 async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
319 let kind_id = match self.kind() {
320 ErrorKind::NotFound => 1,
321 ErrorKind::PermissionDenied => 2,
322 ErrorKind::ConnectionRefused => 3,
323 ErrorKind::ConnectionReset => 4,
324 ErrorKind::ConnectionAborted => 5,
325 ErrorKind::NotConnected => 6,
326 ErrorKind::AddrInUse => 7,
327 ErrorKind::AddrNotAvailable => 8,
328 ErrorKind::BrokenPipe => 9,
329 ErrorKind::AlreadyExists => 10,
330 ErrorKind::WouldBlock => 11,
331 ErrorKind::InvalidInput => 12,
332 ErrorKind::InvalidData => 13,
333 ErrorKind::TimedOut => 14,
334 ErrorKind::WriteZero => 15,
335 ErrorKind::Interrupted => 16,
336 ErrorKind::Unsupported => 17,
337 ErrorKind::UnexpectedEof => 18,
338 ErrorKind::OutOfMemory => 19,
339 ErrorKind::Other => 20,
340 _ => 0,
341 };
342
343 writer.write_u8(kind_id).await?;
344 self.to_string().write(writer).await
345 }
346}
347
348impl ByteRead for Error {
349 async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
350 let kind_id = reader.read_u8().await?;
351
352 let error_kind = match kind_id {
353 1 => ErrorKind::NotFound,
354 2 => ErrorKind::PermissionDenied,
355 3 => ErrorKind::ConnectionRefused,
356 4 => ErrorKind::ConnectionReset,
357 5 => ErrorKind::ConnectionAborted,
358 6 => ErrorKind::NotConnected,
359 7 => ErrorKind::AddrInUse,
360 8 => ErrorKind::AddrNotAvailable,
361 9 => ErrorKind::BrokenPipe,
362 10 => ErrorKind::AlreadyExists,
363 11 => ErrorKind::WouldBlock,
364 12 => ErrorKind::InvalidInput,
365 13 => ErrorKind::InvalidData,
366 14 => ErrorKind::TimedOut,
367 15 => ErrorKind::WriteZero,
368 16 => ErrorKind::Interrupted,
369 17 => ErrorKind::Unsupported,
370 18 => ErrorKind::UnexpectedEof,
371 19 => ErrorKind::OutOfMemory,
372 _ => ErrorKind::Other,
373 };
374
375 let message = String::read(reader).await?;
376
377 Ok(Error::new(error_kind, message))
378 }
379}
380
381impl ByteWrite for str {
382 async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
383 let bytes = self.as_bytes();
384 let len = bytes.len();
385 if len > u16::MAX as usize {
386 return Err(Error::new(ErrorKind::InvalidData, "String is too long (>= 64KB)"));
387 }
388
389 let len = len as u16;
390 writer.write_u16(len).await?;
391 writer.write_all(bytes).await
392 }
393}
394
395impl ByteWrite for String {
396 async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
397 self.as_str().write(writer).await
398 }
399}
400
401impl ByteRead for String {
402 async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
403 let len = reader.read_u16().await? as usize;
404
405 let mut s = String::with_capacity(len);
406 unsafe {
407 let v = s.as_mut_vec();
409 v.set_len(len);
410 reader.read_exact(&mut v[0..len]).await?;
411 if std::str::from_utf8(v).is_err() {
412 return Err(Error::new(ErrorKind::InvalidData, "String is not valid UTF-8"));
413 }
414 }
415
416 Ok(s)
417 }
418}
419
420pub struct SmallWriteString<'a>(pub &'a str);
423
424impl<'a> ByteWrite for SmallWriteString<'a> {
425 async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
426 let bytes = self.0.as_bytes();
427 let len = bytes.len();
428 if len > u8::MAX as usize {
429 return Err(Error::new(ErrorKind::InvalidData, "Small string is too long (>= 256B)"));
430 }
431
432 let len = len as u8;
433 writer.write_u8(len).await?;
434 writer.write_all(bytes).await
435 }
436}
437
438pub struct SmallReadString(pub String);
441
442impl ByteRead for SmallReadString {
443 async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
444 let len = reader.read_u8().await? as usize;
445
446 let mut s = String::with_capacity(len);
447 unsafe {
448 let v = s.as_mut_vec();
450 v.set_len(len);
451 reader.read_exact(&mut v[0..len]).await?;
452 if std::str::from_utf8(v).is_err() {
453 return Err(Error::new(ErrorKind::InvalidData, "Small string is not valid UTF-8"));
454 }
455 }
456
457 Ok(SmallReadString(s))
458 }
459}
460
461impl<T: ByteWrite> ByteWrite for &[T] {
462 async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
463 let len = self.len();
464 if len > u16::MAX as usize {
465 return Err(Error::new(ErrorKind::InvalidData, "List is too long (>= 64K)"));
466 }
467
468 let len = len as u16;
469 writer.write_u16(len).await?;
470 for ele in self.iter() {
471 ele.write(writer).await?;
472 }
473
474 Ok(())
475 }
476}
477
478impl<T: ByteRead> ByteRead for Vec<T> {
479 async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
480 let len = reader.read_u16().await? as usize;
481
482 let mut v = Vec::with_capacity(len);
483 for _ in 0..len {
484 v.push(T::read(reader).await?);
485 }
486
487 Ok(v)
488 }
489}
490
491pub struct SmallWriteList<'a, T>(pub &'a [T]);
494
495impl<'a, T: ByteWrite> ByteWrite for SmallWriteList<'a, T> {
496 async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
497 let len = self.0.len();
498 if len > u8::MAX as usize {
499 return Err(Error::new(ErrorKind::InvalidData, "Small list is too long (>= 256)"));
500 }
501
502 let len = len as u8;
503 writer.write_u8(len).await?;
504 for ele in self.0.iter() {
505 ele.write(writer).await?;
506 }
507
508 Ok(())
509 }
510}
511pub struct SmallReadList<T>(pub Vec<T>);
514
515impl<T: ByteRead> ByteRead for SmallReadList<T> {
516 async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
517 let len = reader.read_u8().await? as usize;
518
519 let mut v = Vec::with_capacity(len);
520 for _ in 0..len {
521 v.push(T::read(reader).await?);
522 }
523
524 Ok(SmallReadList(v))
525 }
526}
527
528impl<T: ByteWrite> ByteWrite for &T {
529 async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
530 (*self).write(writer).await
531 }
532}
533
534impl<T0: ByteWrite, T1: ByteWrite> ByteWrite for (T0, T1) {
535 async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
536 self.0.write(writer).await?;
537 self.1.write(writer).await
538 }
539}
540
541impl<T0: ByteRead, T1: ByteRead> ByteRead for (T0, T1) {
542 async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
543 Ok((T0::read(reader).await?, T1::read(reader).await?))
544 }
545}
546
547impl<T0: ByteWrite, T1: ByteWrite, T2: ByteWrite> ByteWrite for (T0, T1, T2) {
548 async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
549 self.0.write(writer).await?;
550 self.1.write(writer).await?;
551 self.2.write(writer).await
552 }
553}
554
555impl<T0: ByteRead, T1: ByteRead, T2: ByteRead> ByteRead for (T0, T1, T2) {
556 async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
557 Ok((T0::read(reader).await?, T1::read(reader).await?, T2::read(reader).await?))
558 }
559}
560
561impl<T0: ByteWrite, T1: ByteWrite, T2: ByteWrite, T3: ByteWrite> ByteWrite for (T0, T1, T2, T3) {
562 async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
563 self.0.write(writer).await?;
564 self.1.write(writer).await?;
565 self.2.write(writer).await?;
566 self.3.write(writer).await
567 }
568}
569
570impl<T0: ByteRead, T1: ByteRead, T2: ByteRead, T3: ByteRead> ByteRead for (T0, T1, T2, T3) {
571 async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
572 Ok((
573 T0::read(reader).await?,
574 T1::read(reader).await?,
575 T2::read(reader).await?,
576 T3::read(reader).await?,
577 ))
578 }
579}
580
581impl<T0: ByteWrite, T1: ByteWrite, T2: ByteWrite, T3: ByteWrite, T4: ByteWrite> ByteWrite for (T0, T1, T2, T3, T4) {
582 async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
583 self.0.write(writer).await?;
584 self.1.write(writer).await?;
585 self.2.write(writer).await?;
586 self.3.write(writer).await?;
587 self.4.write(writer).await
588 }
589}
590
591impl<T0: ByteRead, T1: ByteRead, T2: ByteRead, T3: ByteRead, T4: ByteRead> ByteRead for (T0, T1, T2, T3, T4) {
592 async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
593 Ok((
594 T0::read(reader).await?,
595 T1::read(reader).await?,
596 T2::read(reader).await?,
597 T3::read(reader).await?,
598 T4::read(reader).await?,
599 ))
600 }
601}