Skip to main content

rusty_tpkt/
service.rs

1use std::{collections::VecDeque, net::SocketAddr};
2
3use bytes::{Buf, BytesMut};
4use tokio::{
5    io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf, split},
6    net::{TcpListener, TcpStream},
7};
8
9use crate::{
10    ProtocolInformation, TpktConnection, TpktError, TpktReader, TpktWriter,
11    parser::{TpktParser, TpktParserResult},
12    serialiser::TpktSerialiser,
13};
14
15/// Keeps track of tpkt connection information
16#[derive(Clone, Debug)]
17pub struct TcpTpktProtocolInformation {
18    pub remote_address: SocketAddr,
19}
20
21impl ProtocolInformation for TcpTpktProtocolInformation {}
22
23/// A TPKT server implemented over a TCP connection.
24pub struct TcpTpktServer {
25    listener: TcpListener,
26}
27
28impl TcpTpktServer {
29    /// Start listening on the provided TCP port.
30    pub async fn listen(address: SocketAddr) -> Result<Self, TpktError> {
31        Ok(Self { listener: TcpListener::bind(address).await? })
32    }
33
34    /// Accept an incoming connection. This may be called multiple times.
35    pub async fn accept<'a>(&self) -> Result<TcpTpktConnection, TpktError> {
36        let (stream, remote_host) = self.listener.accept().await?;
37        let (reader, writer) = split(stream);
38        Ok(TcpTpktConnection::new(TcpTpktReader::new(reader), TcpTpktWriter::new(writer), Box::new(TcpTpktProtocolInformation { remote_address: remote_host })))
39    }
40}
41
42/// An established TPKT connection.
43pub struct TcpTpktConnection {
44    reader: TcpTpktReader,
45    writer: TcpTpktWriter,
46    protocol_information_list: Vec<Box<dyn ProtocolInformation>>,
47}
48
49impl TcpTpktConnection {
50    /// Initiates a client TPKT connection.
51    pub async fn connect<'a>(address: SocketAddr) -> Result<TcpTpktConnection, TpktError> {
52        let stream = TcpStream::connect(address).await?;
53        let (reader, writer) = split(stream);
54        return Ok(TcpTpktConnection::new(TcpTpktReader::new(reader), TcpTpktWriter::new(writer), Box::new(TcpTpktProtocolInformation { remote_address: address })));
55    }
56
57    fn new(reader: TcpTpktReader, writer: TcpTpktWriter, protocol_information: Box<dyn ProtocolInformation>) -> Self {
58        TcpTpktConnection { reader, writer, protocol_information_list: vec![protocol_information] }
59    }
60}
61
62impl TpktConnection for TcpTpktConnection {
63    fn get_protocol_infomation_list(&self) -> &Vec<Box<dyn crate::ProtocolInformation>> {
64        &self.protocol_information_list
65    }
66
67    async fn split(self) -> Result<(impl TpktReader, impl TpktWriter), TpktError> {
68        Ok((self.reader, self.writer))
69    }
70}
71
72/// The read half of a TPKT connection.
73pub struct TcpTpktReader {
74    parser: TpktParser,
75    receive_buffer: BytesMut,
76    reader: ReadHalf<TcpStream>,
77}
78
79impl TcpTpktReader {
80    fn new(reader: ReadHalf<TcpStream>) -> Self {
81        Self { reader, parser: TpktParser::new(), receive_buffer: BytesMut::new() }
82    }
83}
84
85impl TpktReader for TcpTpktReader {
86    async fn recv(&mut self) -> Result<Option<Vec<u8>>, TpktError> {
87        loop {
88            let buffer = &mut self.receive_buffer;
89            match self.parser.parse(buffer) {
90                Ok(TpktParserResult::Data(x)) => return Ok(Some(x)),
91                Ok(TpktParserResult::InProgress) => (),
92                Err(x) => return Err(x),
93            };
94            if self.reader.read_buf(buffer).await? == 0 {
95                return Ok(None);
96            };
97        }
98    }
99}
100
101/// The write half of a TPKT connection.
102pub struct TcpTpktWriter {
103    write_buffer: BytesMut,
104    serialiser: TpktSerialiser,
105    writer: WriteHalf<TcpStream>,
106}
107
108impl TcpTpktWriter {
109    fn new(writer: WriteHalf<TcpStream>) -> Self {
110        Self { serialiser: TpktSerialiser::new(), writer, write_buffer: BytesMut::new() }
111    }
112}
113
114impl TpktWriter for TcpTpktWriter {
115    async fn send(&mut self, input: &mut VecDeque<Vec<u8>>) -> Result<(), TpktError> {
116        while let Some(packet) = input.pop_front() {
117            self.write_buffer.extend(self.serialiser.serialise(&packet)?);
118        }
119
120        while self.write_buffer.has_remaining() {
121            self.writer.write_buf(&mut self.write_buffer).await?;
122        }
123        Ok(())
124    }
125}