async_protocol/
transport.rs

1use std::collections::VecDeque;
2use std::io::Cursor;
3use std::mem;
4
5use async_trait::async_trait;
6use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
7
8use protocol::{Error, Parcel, Settings};
9
10#[async_trait]
11pub trait Transport {
12    async fn process_data<R: AsyncRead + Send + Unpin>(
13        &mut self,
14        read: &mut R,
15        settings: &Settings,
16    ) -> Result<(), Error>;
17
18    async fn receive_raw_packet(&mut self) -> Result<Option<Vec<u8>>, Error>;
19
20    async fn send_raw_packet<W: AsyncWrite + Send + Unpin>(
21        &mut self,
22        write: &mut W,
23        packet: &[u8],
24        settings: &Settings,
25    ) -> Result<(), Error>;
26}
27
/// The type that we use to describe packet sizes.
///
/// A value of this type is written in front of every packet as its length
/// prefix, and `mem::size_of::<PacketSize>()` is the number of prefix bytes
/// the receiver collects before decoding the size.
pub type PacketSize = u32;
30
/// The current state of the receive-side framing state machine.
#[derive(Clone, Debug)]
enum State {
    /// We are awaiting packet size bytes.
    ///
    /// Holds the size-prefix bytes received so far.
    AwaitingSize(Vec<u8>),
    /// The size prefix has been decoded and we are awaiting the packet body.
    AwaitingPacket {
        /// Total length of the packet body in bytes, as decoded from the prefix.
        size: PacketSize,
        /// The body bytes received so far.
        received_data: Vec<u8>,
    },
}
41
/// A simple transport.
///
/// Each packet on the wire is a `PacketSize` length prefix followed by that
/// many bytes of packet data.
#[derive(Clone, Debug)]
pub struct Simple {
    /// Progress through the packet that is currently being received.
    state: State,
    /// Fully received packets that have not been handed out yet, oldest first.
    packets: VecDeque<Vec<u8>>,
}
48
49impl Simple {
50    pub fn new() -> Self {
51        Simple {
52            state: State::AwaitingSize(Vec::new()),
53            packets: VecDeque::new(),
54        }
55    }
56
57    async fn process_bytes(&mut self, bytes: &[u8], settings: &Settings) -> Result<(), Error> {
58        let mut read = Cursor::new(bytes);
59
60        loop {
61            match self.state.clone() {
62                State::AwaitingSize(mut size_bytes) => {
63                    let remaining_bytes = mem::size_of::<PacketSize>() - size_bytes.len();
64
65                    let mut received_bytes = vec![0; remaining_bytes];
66                    let bytes_read = std::io::Read::read(&mut read, &mut received_bytes)?;
67                    received_bytes.drain(bytes_read..);
68
69                    assert_eq!(received_bytes.len(), bytes_read);
70
71                    size_bytes.extend(received_bytes.into_iter());
72
73                    if size_bytes.len() == mem::size_of::<PacketSize>() {
74                        let mut size_buffer = Cursor::new(size_bytes);
75
76                        let size = PacketSize::read(&mut size_buffer, settings).unwrap();
77
78                        // We are now ready to receive packet data.
79                        self.state = State::AwaitingPacket {
80                            size,
81                            received_data: Vec::new(),
82                        }
83                    } else {
84                        // Still waiting to receive the whole packet.
85                        self.state = State::AwaitingSize(size_bytes);
86                        break;
87                    }
88                }
89                State::AwaitingPacket {
90                    size,
91                    mut received_data,
92                } => {
93                    let remaining_bytes = (size as usize) - received_data.len();
94                    assert!(remaining_bytes > 0);
95
96                    let mut received_bytes = vec![0; remaining_bytes];
97                    let bytes_read = read.read(&mut received_bytes).await?;
98                    received_bytes.drain(bytes_read..);
99
100                    assert_eq!(received_bytes.len(), bytes_read);
101
102                    received_data.extend(received_bytes.into_iter());
103
104                    assert!(received_data.len() <= (size as usize));
105
106                    if (size as usize) == received_data.len() {
107                        self.packets.push_back(received_data);
108
109                        // Start reading the next packet.
110                        self.state = State::AwaitingSize(Vec::new());
111                    } else {
112                        // Keep reading the current packet.
113                        self.state = State::AwaitingPacket {
114                            size,
115                            received_data,
116                        };
117                        break;
118                    }
119                }
120            }
121        }
122
123        Ok(())
124    }
125}
126
/// Size in bytes of the temporary buffer that `process_data` reads into.
const BUFFER_SIZE: usize = 10000;
128
129#[async_trait]
130impl Transport for Simple {
131    async fn process_data<R: AsyncRead + Send + Unpin>(
132        &mut self,
133        read: &mut R,
134        settings: &Settings,
135    ) -> Result<(), Error> {
136        // Load the data into a temporary buffer before we process it.
137        loop {
138            let mut buffer = [0u8; BUFFER_SIZE];
139            let bytes_read = read.read(&mut buffer).await.unwrap();
140            let buffer = &buffer[0..bytes_read];
141
142            if bytes_read == 0 {
143                break;
144            } else {
145                self.process_bytes(buffer, settings).await?;
146
147                // We didn't fill the whole buffer so stop now.
148                if bytes_read != BUFFER_SIZE {
149                    break;
150                }
151            }
152        }
153
154        Ok(())
155    }
156
157    async fn send_raw_packet<W: AsyncWrite + Send + Unpin>(
158        &mut self,
159        write: &mut W,
160        packet: &[u8],
161        settings: &Settings,
162    ) -> Result<(), Error> {
163        let mut w = Cursor::new(Vec::<u8>::new());
164        // Prefix the packet size.
165        (packet.len() as PacketSize).write(&mut w, settings)?;
166        // Write the packet data.
167        w.write_all(&packet).await?;
168
169        write.write(&w.into_inner()).await?;
170
171        Ok(())
172    }
173
174    async fn receive_raw_packet(&mut self) -> Result<Option<Vec<u8>>, Error> {
175        Ok(self.packets.pop_front())
176    }
177}