1use crate::error::CodecError;
2use bytes::{Buf, BufMut, Bytes, BytesMut};
3use flate2::{read::GzDecoder, write::GzEncoder, Compression};
4use ppaass_crypto::crypto::{decrypt_with_aes, encrypt_with_aes, RsaCryptoFetcher};
5use ppaass_crypto::error::CryptoError;
6use ppaass_protocol::message::{PpaassPacket, PpaassPacketPayloadEncryption};
7use std::io::{Read, Write};
8use tokio_util::codec::{Decoder, Encoder, LengthDelimitedCodec};
9use tracing::error;
10struct PpaassPacketEncoder<'a, T>
11where
12 T: RsaCryptoFetcher,
13{
14 rsa_crypto_fetcher: &'a T,
15 length_delimited_codec: LengthDelimitedCodec,
16}
17
18impl<'a, T> PpaassPacketEncoder<'a, T>
19where
20 T: RsaCryptoFetcher,
21{
22 pub fn new(rsa_crypto_fetcher: &'a T) -> Self {
23 Self {
24 rsa_crypto_fetcher,
25 length_delimited_codec: LengthDelimitedCodec::new(),
26 }
27 }
28}
29
30impl<'a, T> Encoder<PpaassPacket> for PpaassPacketEncoder<'a, T>
32where
33 T: RsaCryptoFetcher,
34{
35 type Error = CodecError;
36
37 fn encode(
38 &mut self,
39 original_packet: PpaassPacket,
40 dst: &mut BytesMut,
41 ) -> Result<(), Self::Error> {
42 let rsa_crypto = self
43 .rsa_crypto_fetcher
44 .fetch(original_packet.user_token())?
45 .ok_or(CryptoError::Other(format!(
46 "Crypto not exist for user: {}",
47 original_packet.user_token()
48 )))?;
49
50 let (encrypted_payload_bytes, encrypted_encryption) = match original_packet.encryption() {
51 PpaassPacketPayloadEncryption::Plain => (
52 original_packet.payload().to_vec(),
53 PpaassPacketPayloadEncryption::Plain,
54 ),
55 PpaassPacketPayloadEncryption::Aes(ref original_aes_token) => {
56 let rsa_encrypted_aes_token = Bytes::from(rsa_crypto.encrypt(original_aes_token)?);
57 let mut original_payload_bytes: BytesMut =
58 BytesMut::from(original_packet.payload());
59 let aes_encrypted_payload_bytes =
60 encrypt_with_aes(original_aes_token, &mut original_payload_bytes)?;
61 (
62 aes_encrypted_payload_bytes.to_vec(),
63 PpaassPacketPayloadEncryption::Aes(rsa_encrypted_aes_token),
64 )
65 }
66 };
67
68 let packet_to_send = PpaassPacket::new(
69 original_packet.packet_id().to_owned(),
70 original_packet.user_token().to_owned(),
71 encrypted_encryption,
72 encrypted_payload_bytes.into(),
73 );
74 let packet_bytes_to_send: Bytes = packet_to_send.try_into()?;
75 let gz_encoder_buf = BytesMut::new();
76 let mut gzip_encoder = GzEncoder::new(gz_encoder_buf.writer(), Compression::fast());
77 gzip_encoder.write_all(&packet_bytes_to_send)?;
78 let packet_bytes_to_send = gzip_encoder.finish()?.into_inner().freeze();
79 self.length_delimited_codec
80 .encode(packet_bytes_to_send, dst)?;
81 Ok(())
82 }
83}
84
85struct PpaassPacketDecoder<'a, T>
86where
87 T: RsaCryptoFetcher,
88{
89 rsa_crypto_fetcher: &'a T,
90 length_delimited_codec: LengthDelimitedCodec,
91}
92
93impl<'a, T> PpaassPacketDecoder<'a, T>
94where
95 T: RsaCryptoFetcher,
96{
97 pub fn new(rsa_crypto_fetcher: &'a T) -> Self {
98 Self {
99 rsa_crypto_fetcher,
100 length_delimited_codec: LengthDelimitedCodec::new(),
101 }
102 }
103}
104
105impl<'a, T> Decoder for PpaassPacketDecoder<'a, T>
107where
108 T: RsaCryptoFetcher,
109{
110 type Item = PpaassPacket;
111 type Error = CodecError;
112
113 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
114 let length_decode_result = self.length_delimited_codec.decode(src)?;
115 let decompressed_packet: PpaassPacket = match length_decode_result {
116 None => return Ok(None),
117 Some(packet_bytes) => {
118 let mut gzip_decoder = GzDecoder::new(packet_bytes.reader());
119 let mut decompressed_packet_bytes = Vec::new();
120 if let Err(e) = gzip_decoder.read_to_end(&mut decompressed_packet_bytes) {
121 error!("Fail to decompress incoming message bytes because of error: {e:?}");
122 return Err(CodecError::StdIo(e));
123 };
124 let decompressed_packet_bytes = Bytes::from_iter(decompressed_packet_bytes);
125 decompressed_packet_bytes.try_into()?
126 }
127 };
128 let decrypted_packed = match decompressed_packet.encryption() {
129 PpaassPacketPayloadEncryption::Plain => decompressed_packet,
130 PpaassPacketPayloadEncryption::Aes(rsa_encrypted_aes_token) => {
131 let rsa_crypto = self
132 .rsa_crypto_fetcher
133 .fetch(decompressed_packet.user_token())?
134 .ok_or(CryptoError::Other(format!(
135 "Crypto not exist for user: {}",
136 decompressed_packet.user_token()
137 )))?;
138 let decrypted_aes_token = Bytes::from(rsa_crypto.decrypt(rsa_encrypted_aes_token)?);
139 let mut decrypted_payload_bytes =
140 BytesMut::from_iter(decompressed_packet.payload());
141 let decrypted_payload =
142 decrypt_with_aes(&decrypted_aes_token, &mut decrypted_payload_bytes)?.freeze();
143 PpaassPacket::new(
144 decompressed_packet.packet_id().to_owned(),
145 decompressed_packet.user_token().to_owned(),
146 PpaassPacketPayloadEncryption::Aes(decrypted_aes_token),
147 decrypted_payload,
148 )
149 }
150 };
151 Ok(Some(decrypted_packed))
152 }
153}
154
155pub struct PpaassPacketCodec<'a, T>
156where
157 T: RsaCryptoFetcher,
158{
159 encoder: PpaassPacketEncoder<'a, T>,
160 decoder: PpaassPacketDecoder<'a, T>,
161}
162
163impl<'a, T> PpaassPacketCodec<'a, T>
164where
165 T: RsaCryptoFetcher,
166{
167 pub fn new(rsa_crypto_fetcher: &'a T) -> Self {
168 Self {
169 encoder: PpaassPacketEncoder::new(&rsa_crypto_fetcher),
170 decoder: PpaassPacketDecoder::new(&rsa_crypto_fetcher),
171 }
172 }
173}
174
175impl<'a, T> Encoder<PpaassPacket> for PpaassPacketCodec<'a, T>
176where
177 T: RsaCryptoFetcher,
178{
179 type Error = CodecError;
180 fn encode(&mut self, item: PpaassPacket, dst: &mut BytesMut) -> Result<(), Self::Error> {
181 self.encoder.encode(item, dst)
182 }
183}
184
185impl<'a, T> Decoder for PpaassPacketCodec<'a, T>
186where
187 T: RsaCryptoFetcher,
188{
189 type Item = PpaassPacket;
190 type Error = CodecError;
191 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
192 self.decoder.decode(src)
193 }
194}