skrillax_stream/
stream.rs

1use bytes::Bytes;
2use futures::{SinkExt, Stream, StreamExt};
3use skrillax_codec::{SilkroadCodec, SilkroadFrame};
4use skrillax_packet::{
5    AsFrames, FramingError, FromFrames, IncomingPacket, OutgoingPacket, Packet, PacketError,
6    ReframingError, SecurityBytes, SecurityContext, TryFromPacket,
7};
8use skrillax_security::SilkroadEncryption;
9use std::io;
10use std::sync::Arc;
11use thiserror::Error;
12use tokio::io::{AsyncRead, AsyncWrite};
13use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
14use tokio::net::TcpStream;
15use tokio_util::codec::{FramedRead, FramedWrite};
16
17/// Errors for possible problems writing packets.
18///
19/// When writing packets to be sent over the wire a few issues can appear,
20/// which are represented by this error.
21#[derive(Debug, Error)]
22pub enum OutStreamError {
23    /// Some I/O related issue occurred. This generally means the underlying
24    /// transport layer was disconnected or otherwise impaired.
25    #[error("Some IO level error occurred")]
26    IoError(#[from] io::Error),
27    /// Something went wrong when trying to create frame(s) for the packet.
28    /// This currently can only happen if an encrypted frame is supposed to
29    /// be built, but no encryption has been configured.
30    #[error("Error occurred when trying to create frames")]
31    Framing(#[from] FramingError),
32}
33
34/// Errors encountered when reading packets.
35///
36/// Unlike [OutStreamError], there are many more possibilities for an error
37/// to occur here, due to accepting mostly untrusted input.
38#[derive(Debug, Error)]
39pub enum InStreamError {
40    /// Something went wrong on the I/O layer.
41    ///
42    /// When the underlying transport layer was disconnected or had other
43    /// issues while trying to read data, this error occurs.
44    #[error("Some IO level error occurred")]
45    IoError(#[from] io::Error),
46    #[error("Error occurred at the packet level")]
47    PacketError(#[from] PacketError),
48    #[error("Error when trying to turn frames into packets")]
49    ReframingError(#[from] ReframingError),
50    /// The end of the stream was reached, but we expected more data.
51    #[error("Reached the end of the stream")]
52    EndOfStream,
53    /// When trying to receive a specific packet or protocol, a different or
54    /// unknown packet was received.
55    #[error("Received unexpected opcode: {0:#06x}")]
56    UnmatchedOpcode(u16),
57}
58
59///
60pub trait InputProtocol {
61    type Proto: Send;
62
63    fn create_from(opcode: u16, data: &[u8]) -> Result<(usize, Self::Proto), InStreamError>;
64}
65
66impl<T: TryFromPacket + Packet + Send> InputProtocol for T {
67    type Proto = T;
68
69    fn create_from(opcode: u16, data: &[u8]) -> Result<(usize, T), InStreamError> {
70        if opcode != T::ID {
71            return Err(InStreamError::UnmatchedOpcode(opcode));
72        }
73
74        Ok(T::try_deserialize(data)?)
75    }
76}
77
78/// Extensions to [TcpStream] to convert it into a silkroad stream, sending
79/// and receiving silkroad packets.
80pub trait SilkroadTcpExt {
81    /// Creates a stream using the existing socket, wrapping it into a stream to
82    /// read and write [IncomingPacket] & [OutgoingPacket].
83    ///
84    /// ```
85    /// # use std::error::Error;
86    /// use skrillax_stream::stream::SilkroadTcpExt;
87    ///
88    /// # async fn test() -> Result<(), Box<dyn Error>> {
89    /// # use tokio::net::TcpStream;
90    /// let stream = TcpStream::connect("127.0.0.1:1337").await?;
91    /// let (reader, writer) = stream.into_silkroad_stream();
92    /// # Ok(())
93    /// # }
94    /// ```
95    fn into_silkroad_stream(
96        self,
97    ) -> (
98        SilkroadStreamRead<OwnedReadHalf>,
99        SilkroadStreamWrite<OwnedWriteHalf>,
100    );
101}
102
103impl SilkroadTcpExt for TcpStream {
104    fn into_silkroad_stream(
105        self,
106    ) -> (
107        SilkroadStreamRead<OwnedReadHalf>,
108        SilkroadStreamWrite<OwnedWriteHalf>,
109    ) {
110        let (read, write) = self.into_split();
111        let reader = FramedRead::new(read, SilkroadCodec);
112        let writer = FramedWrite::new(write, SilkroadCodec);
113
114        let stream_reader = SilkroadStreamRead::new(reader);
115        let stream_writer = SilkroadStreamWrite::new(writer);
116
117        (stream_reader, stream_writer)
118    }
119}
120
121/// The writing side of a Silkroad Online connection.
122///
123/// This is an analog to [OwnedWriteHalf], containing additional state to
124/// facilitate a Silkroad connection, such as encryption.
125pub struct SilkroadStreamWrite<T: AsyncWrite + Unpin> {
126    writer: FramedWrite<T, SilkroadCodec>,
127    encryption: Option<Arc<SilkroadEncryption>>,
128    security_bytes: Option<Arc<SecurityBytes>>,
129}
130
131impl<T: AsyncWrite + Unpin> SilkroadStreamWrite<T> {
132    fn new(writer: FramedWrite<T, SilkroadCodec>) -> Self {
133        Self {
134            writer,
135            encryption: None,
136            security_bytes: None,
137        }
138    }
139
140    #[allow(unused)]
141    fn with_encryption(
142        writer: FramedWrite<T, SilkroadCodec>,
143        encryption: Arc<SilkroadEncryption>,
144        security_bytes: Arc<SecurityBytes>,
145    ) -> Self {
146        Self {
147            writer,
148            encryption: Some(encryption),
149            security_bytes: Some(security_bytes),
150        }
151    }
152
153    pub fn enable_encryption(&mut self, encryption: Arc<SilkroadEncryption>) {
154        self.encryption = Some(encryption);
155    }
156
157    pub fn enable_security_checks(&mut self, security_bytes: Arc<SecurityBytes>) {
158        self.security_bytes = Some(security_bytes);
159    }
160
161    pub fn encryption(&self) -> Option<&SilkroadEncryption> {
162        self.encryption.as_deref()
163    }
164
165    pub fn security_bytes(&self) -> Option<&SecurityBytes> {
166        self.security_bytes.as_deref()
167    }
168
169    pub fn security_context(&self) -> SecurityContext {
170        SecurityContext::new(self.encryption(), self.security_bytes())
171    }
172
173    pub async fn write(&mut self, packet: OutgoingPacket) -> Result<(), OutStreamError> {
174        let frames = packet.as_frames(self.security_context())?;
175        for frame in frames {
176            self.writer.send(frame).await?;
177        }
178        Ok(())
179    }
180
181    pub async fn write_packet<S: Into<OutgoingPacket>>(
182        &mut self,
183        packet: S,
184    ) -> Result<(), OutStreamError> {
185        let outgoing_packet = packet.into();
186        self.write(outgoing_packet).await
187    }
188}
189
190/// The reading side of a Silkroad Online connection.
191///
192/// This is an analog to [OwnedReadHalf], containing additional state to
193/// facilitate a Silkroad connection, such as encryption.
194pub struct SilkroadStreamRead<T: AsyncRead + Unpin> {
195    reader: FramedRead<T, SilkroadCodec>,
196    encryption: Option<Arc<SilkroadEncryption>>,
197    security_bytes: Option<Arc<SecurityBytes>>,
198    unconsumed: Option<(u16, Bytes)>,
199}
200
201impl<T: AsyncRead + Unpin> SilkroadStreamRead<T>
202where
203    FramedRead<T, SilkroadCodec>: Stream<Item = Result<SilkroadFrame, io::Error>>,
204{
205    fn new(reader: FramedRead<T, SilkroadCodec>) -> Self {
206        Self {
207            reader,
208            encryption: None,
209            security_bytes: None,
210            unconsumed: None,
211        }
212    }
213
214    #[allow(unused)]
215    fn with_encryption(
216        reader: FramedRead<T, SilkroadCodec>,
217        encryption: Arc<SilkroadEncryption>,
218        security_bytes: Arc<SecurityBytes>,
219    ) -> Self {
220        Self {
221            reader,
222            encryption: Some(encryption),
223            security_bytes: Some(security_bytes),
224            unconsumed: None,
225        }
226    }
227
228    /// Enables encryption for this stream.
229    ///
230    /// Upon starting a connection, a stream will not be encrypted. Only after
231    /// the handshake finished will the encryption be present. This should
232    /// generally be set implicitly by the handshake protocol, but it is
233    /// possible to manually configure it.
234    ///
235    /// An [Arc] is expected here because it is assumed that the same encryption
236    /// will be set on the write half as well.
237    pub fn enable_encryption(&mut self, encryption: Arc<SilkroadEncryption>) {
238        self.encryption = Some(encryption);
239    }
240
241    /// Enables additional security checks for this stream.
242    ///
243    /// In addition to encryption, there are additional security checks
244    /// available on packets. In particular this is the counter and CRC
245    /// checksum.
246    ///
247    /// An [Arc] is expected here because it is assumed that the same encryption
248    /// will be set on the write half as well.
249    pub fn enable_security_checks(&mut self, security_bytes: Arc<SecurityBytes>) {
250        self.security_bytes = Some(security_bytes);
251    }
252
253    /// Provides the currently set encryption configuration, if present.
254    pub fn encryption(&self) -> Option<&SilkroadEncryption> {
255        self.encryption.as_deref()
256    }
257
258    /// Provides the currently set security data, if present.
259    pub fn security_bytes(&self) -> Option<&SecurityBytes> {
260        self.security_bytes.as_deref()
261    }
262
263    /// Provides the security context present for the reader.
264    ///
265    /// This will always return a new context with the
266    /// [SilkroadStreamRead::encryption] and
267    /// [SilkroadStreamRead::security_bytes] data inside. Essentially, this
268    /// is a convenience wrapper around those functions to provide
269    /// a single struct that can be passed around.
270    pub fn security_context(&self) -> SecurityContext {
271        SecurityContext::new(self.encryption(), self.security_bytes())
272    }
273
274    /// Read next packet and handle re-framing.
275    ///
276    /// [skrillax_codec] deals on single packets (i.e. frames) and some packets
277    /// may span multiple frames. It does not attempt to collect those
278    /// frames where appropriate and instead pushes the problem up the
279    /// abstraction chain. Thus, at the current abstraction level we're
280    /// performing this merging of frames into logical packets. Thus, it is
281    /// possible the resulting [IncomingPacket] is actually a massive packet
282    /// containing multiple operations inside it. At this point we can't
283    /// split that into the individual operations, because we don't know the
284    /// length of those operations.
285    ///
286    /// This should only be necessary if you're not interested in actual packet
287    /// data or work really generically. Otherwise,
288    /// [SilkroadStreamRead::next_packet] should be used instead.
289    pub async fn next(&mut self) -> Result<IncomingPacket, InStreamError> {
290        let mut buffer = Vec::new();
291        let mut remaining = 1usize;
292        while let Some(res) = self.reader.next().await {
293            let frame = res?;
294            buffer.push(frame);
295            remaining -= 1;
296            if remaining == 0 {
297                match IncomingPacket::from_frames(&buffer, self.security_context()) {
298                    Ok(packet) => return Ok(packet),
299                    Err(ReframingError::Incomplete(required)) => {
300                        remaining += required.unwrap_or(1);
301                    },
302                    Err(e) => return Err(InStreamError::ReframingError(e)),
303                }
304            }
305        }
306
307        Err(InStreamError::EndOfStream)
308    }
309
310    /// Tries to serialize the next incoming packet into the given protocol.
311    ///
312    /// This will poll the underlying transport layer to read a new packet
313    /// and will then try to serialize into a matching packet of the given
314    /// protocol. We expect that all packets are part of the given protocol,
315    /// otherwise it will be _discarded_ and [InStreamError::UnmatchedOpcode]
316    /// will be returned.
317    ///
318    /// Since [InputProtocol] is automatically derived for structs that have
319    /// [skrillax_packet::Packet] & [skrillax_serde::Deserialize], you can
320    /// both expect a single packet and a full protocol here.
321    pub async fn next_packet<S: InputProtocol>(&mut self) -> Result<S::Proto, InStreamError> {
322        let (opcode, mut buffer) = match self.unconsumed.take() {
323            Some(inner) => inner,
324            _ => self.next().await?.consume(),
325        };
326
327        let (consumed, p) = S::create_from(opcode, &buffer)?;
328        let _ = buffer.split_to(consumed);
329        if !buffer.is_empty() {
330            self.unconsumed = Some((opcode, buffer));
331        }
332
333        Ok(p)
334    }
335}
336
#[cfg(test)]
mod test {
    use super::*;
    use skrillax_serde::{ByteSize, Deserialize, Serialize};

    /// The exact bytes exchanged on the wire for a single [Empty] packet.
    const EMPTY_FRAME: [u8; 6] = [0x00, 0x00, 0x42, 0x00, 0x00, 0x00];

    #[derive(Packet, Deserialize, Serialize, ByteSize)]
    #[packet(opcode = 0x0042)]
    struct Empty;

    #[tokio::test]
    pub async fn test_read_packet_from_stream() {
        let input: &[u8] = &EMPTY_FRAME;
        let mut reader = SilkroadStreamRead::new(FramedRead::new(input, SilkroadCodec));
        reader
            .next_packet::<Empty>()
            .await
            .expect("Should read empty packet.");
    }

    #[tokio::test]
    pub async fn test_write_packet_to_stream() {
        let mut output: Vec<u8> = Vec::new();
        {
            // Scoped so the writer releases its borrow of `output` before
            // the assertion below.
            let mut writer = SilkroadStreamWrite::new(FramedWrite::new(&mut output, SilkroadCodec));
            writer
                .write_packet(Empty)
                .await
                .expect("Should write empty packet.");
        }
        assert_eq!(&EMPTY_FRAME[..], output.as_slice());
    }
}