open_protocol_client/
network.rs1use std::collections::VecDeque;
2use std::io;
3use bytes::{Buf, BytesMut};
4use tokio::io::{AsyncReadExt, AsyncWriteExt};
5use tokio::net::TcpStream;
6use open_protocol::{Header, Message};
7use open_protocol::decode::{self, Decoder, Decode};
8use crate::client::{ConnectionError, Event};
9
10pub struct Network {
11 pub socket: TcpStream,
12 pub read_buf: BytesMut,
13}
14
15impl Network {
16 pub fn new(socket: TcpStream) -> Self {
17 Self {
18 socket,
19 read_buf: BytesMut::with_capacity(10 * 1024),
20 }
21 }
22
23 async fn read_bytes(&mut self, required: usize) -> io::Result<usize> {
24 let mut total_read = 0;
25 loop {
26 let read = self.socket.read_buf(&mut self.read_buf).await?;
27
28 if 0 == read {
29 return if self.read_buf.is_empty() {
30 Err(io::Error::new(
31 io::ErrorKind::ConnectionAborted,
32 "connection closed by peer",
33 ))
34 } else {
35 Err(io::Error::new(
36 io::ErrorKind::ConnectionReset,
37 "connection reset by peer",
38 ))
39 };
40 }
41
42 total_read += read;
43 if total_read >= required {
44 return Ok(total_read);
45 }
46 }
47 }
48
49 pub async fn read(&mut self, events: &mut VecDeque<Event>) -> io::Result<()> {
50 loop {
51 let required = match read_message(&mut self.read_buf) {
52 Ok(message) => {
53 events.push_back(Event::Incoming(message));
54 return Ok(());
55 },
56 Err(decode::Error::InsufficientBytes { have, need }) => need - have,
57 Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())),
58 };
59
60 self.read_bytes(required).await?;
61 }
62 }
63
64 pub async fn flush(&mut self, write_buf: &mut BytesMut) -> Result<(), ConnectionError> {
65 if write_buf.is_empty() {
66 return Ok(());
67 }
68
69 self.socket.write_all(&write_buf[..]).await?;
70 write_buf.clear();
71 Ok(())
72 }
73}
74
75fn read_message(stream: &mut BytesMut) -> decode::Result<Message> {
76 if stream.len() < 20 {
77 return Err(decode::Error::InsufficientBytes { have: stream.len(), need: 20 });
78 }
79
80 let mut decoder = Decoder::new(&stream[..]);
81 let header = Header::decode(&mut decoder)?;
82
83 if stream.len() < (header.length as usize) {
84 return Err(decode::Error::InsufficientBytes { have: stream.len(), need: header.length as usize });
85 }
86
87 let message = Message::decode_payload(header.mid, header.revision_number(), &mut decoder)?;
88 decoder.expect_char(0x0 as char)?;
89 stream.advance((header.length + 1) as usize);
90 Ok(message)
91}