1use std::{collections::VecDeque, net::SocketAddr};
2
3use bytes::{Buf, BytesMut};
4use tokio::{
5 io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf, split},
6 net::{TcpListener, TcpStream},
7};
8
9use crate::{
10 ProtocolInformation, TpktConnection, TpktError, TpktReader, TpktWriter,
11 parser::{TpktParser, TpktParserResult},
12 serialiser::TpktSerialiser,
13};
14
/// Protocol information describing the TCP transport beneath a TPKT connection.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct TcpTpktProtocolInformation {
    /// Address of the remote peer.
    ///
    /// On the server side this is the peer address reported by `accept`; on the
    /// client side it is the address that was passed to `connect`.
    pub remote_address: SocketAddr,
}
20
21impl ProtocolInformation for TcpTpktProtocolInformation {}
22
23pub struct TcpTpktServer {
25 listener: TcpListener,
26}
27
28impl TcpTpktServer {
29 pub async fn listen(address: SocketAddr) -> Result<Self, TpktError> {
31 Ok(Self { listener: TcpListener::bind(address).await? })
32 }
33
34 pub async fn accept<'a>(&self) -> Result<TcpTpktConnection, TpktError> {
36 let (stream, remote_host) = self.listener.accept().await?;
37 let (reader, writer) = split(stream);
38 Ok(TcpTpktConnection::new(TcpTpktReader::new(reader), TcpTpktWriter::new(writer), Box::new(TcpTpktProtocolInformation { remote_address: remote_host })))
39 }
40}
41
42pub struct TcpTpktConnection {
44 reader: TcpTpktReader,
45 writer: TcpTpktWriter,
46 protocol_information_list: Vec<Box<dyn ProtocolInformation>>,
47}
48
49impl TcpTpktConnection {
50 pub async fn connect<'a>(address: SocketAddr) -> Result<TcpTpktConnection, TpktError> {
52 let stream = TcpStream::connect(address).await?;
53 let (reader, writer) = split(stream);
54 return Ok(TcpTpktConnection::new(TcpTpktReader::new(reader), TcpTpktWriter::new(writer), Box::new(TcpTpktProtocolInformation { remote_address: address })));
55 }
56
57 fn new(reader: TcpTpktReader, writer: TcpTpktWriter, protocol_information: Box<dyn ProtocolInformation>) -> Self {
58 TcpTpktConnection { reader, writer, protocol_information_list: vec![protocol_information] }
59 }
60}
61
62impl TpktConnection for TcpTpktConnection {
63 fn get_protocol_infomation_list(&self) -> &Vec<Box<dyn crate::ProtocolInformation>> {
64 &self.protocol_information_list
65 }
66
67 async fn split(self) -> Result<(impl TpktReader, impl TpktWriter), TpktError> {
68 Ok((self.reader, self.writer))
69 }
70}
71
72pub struct TcpTpktReader {
74 parser: TpktParser,
75 receive_buffer: BytesMut,
76 reader: ReadHalf<TcpStream>,
77}
78
79impl TcpTpktReader {
80 fn new(reader: ReadHalf<TcpStream>) -> Self {
81 Self { reader, parser: TpktParser::new(), receive_buffer: BytesMut::new() }
82 }
83}
84
85impl TpktReader for TcpTpktReader {
86 async fn recv(&mut self) -> Result<Option<Vec<u8>>, TpktError> {
87 loop {
88 match self.parser.parse(&mut self.receive_buffer) {
89 Ok(TpktParserResult::Data(x)) => return Ok(Some(x)),
90 Ok(TpktParserResult::InProgress) => (),
91 Err(x) => return Err(x),
92 };
93 if self.reader.read_buf(&mut self.receive_buffer).await? == 0 {
94 return Ok(None);
95 };
96 }
97 }
98}
99
100pub struct TcpTpktWriter {
102 write_buffer: BytesMut,
103 serialiser: TpktSerialiser,
104 writer: WriteHalf<TcpStream>,
105}
106
107impl TcpTpktWriter {
108 fn new(writer: WriteHalf<TcpStream>) -> Self {
109 Self { serialiser: TpktSerialiser::new(), writer, write_buffer: BytesMut::new() }
110 }
111}
112
113impl TpktWriter for TcpTpktWriter {
114 async fn send(&mut self, input: &mut VecDeque<Vec<u8>>) -> Result<(), TpktError> {
115 while let Some(packet) = input.pop_front() {
116 self.write_buffer.extend(self.serialiser.serialise(&packet)?);
117 }
118
119 while self.write_buffer.has_remaining() {
120 self.writer.write_buf(&mut self.write_buffer).await?;
121 }
122 Ok(())
123 }
124}