ppaass_codec/
codec.rs

1use crate::error::CodecError;
2use bytes::{Buf, BufMut, Bytes, BytesMut};
3use flate2::{read::GzDecoder, write::GzEncoder, Compression};
4use ppaass_crypto::crypto::{decrypt_with_aes, encrypt_with_aes, RsaCryptoFetcher};
5use ppaass_crypto::error::CryptoError;
6use ppaass_protocol::message::{PpaassPacket, PpaassPacketPayloadEncryption};
7use std::io::{Read, Write};
8use tokio_util::codec::{Decoder, Encoder, LengthDelimitedCodec};
9use tracing::error;
10struct PpaassPacketEncoder<'a, T>
11where
12    T: RsaCryptoFetcher,
13{
14    rsa_crypto_fetcher: &'a T,
15    length_delimited_codec: LengthDelimitedCodec,
16}
17
18impl<'a, T> PpaassPacketEncoder<'a, T>
19where
20    T: RsaCryptoFetcher,
21{
22    pub fn new(rsa_crypto_fetcher: &'a T) -> Self {
23        Self {
24            rsa_crypto_fetcher,
25            length_delimited_codec: LengthDelimitedCodec::new(),
26        }
27    }
28}
29
30/// Encode the ppaass message to bytes buffer
31impl<'a, T> Encoder<PpaassPacket> for PpaassPacketEncoder<'a, T>
32where
33    T: RsaCryptoFetcher,
34{
35    type Error = CodecError;
36
37    fn encode(
38        &mut self,
39        original_packet: PpaassPacket,
40        dst: &mut BytesMut,
41    ) -> Result<(), Self::Error> {
42        let rsa_crypto = self
43            .rsa_crypto_fetcher
44            .fetch(original_packet.user_token())?
45            .ok_or(CryptoError::Other(format!(
46                "Crypto not exist for user: {}",
47                original_packet.user_token()
48            )))?;
49
50        let (encrypted_payload_bytes, encrypted_encryption) = match original_packet.encryption() {
51            PpaassPacketPayloadEncryption::Plain => (
52                original_packet.payload().to_vec(),
53                PpaassPacketPayloadEncryption::Plain,
54            ),
55            PpaassPacketPayloadEncryption::Aes(ref original_aes_token) => {
56                let rsa_encrypted_aes_token = Bytes::from(rsa_crypto.encrypt(original_aes_token)?);
57                let mut original_payload_bytes: BytesMut =
58                    BytesMut::from(original_packet.payload());
59                let aes_encrypted_payload_bytes =
60                    encrypt_with_aes(original_aes_token, &mut original_payload_bytes)?;
61                (
62                    aes_encrypted_payload_bytes.to_vec(),
63                    PpaassPacketPayloadEncryption::Aes(rsa_encrypted_aes_token),
64                )
65            }
66        };
67
68        let packet_to_send = PpaassPacket::new(
69            original_packet.packet_id().to_owned(),
70            original_packet.user_token().to_owned(),
71            encrypted_encryption,
72            encrypted_payload_bytes.into(),
73        );
74        let packet_bytes_to_send: Bytes = packet_to_send.try_into()?;
75        let gz_encoder_buf = BytesMut::new();
76        let mut gzip_encoder = GzEncoder::new(gz_encoder_buf.writer(), Compression::fast());
77        gzip_encoder.write_all(&packet_bytes_to_send)?;
78        let packet_bytes_to_send = gzip_encoder.finish()?.into_inner().freeze();
79        self.length_delimited_codec
80            .encode(packet_bytes_to_send, dst)?;
81        Ok(())
82    }
83}
84
85struct PpaassPacketDecoder<'a, T>
86where
87    T: RsaCryptoFetcher,
88{
89    rsa_crypto_fetcher: &'a T,
90    length_delimited_codec: LengthDelimitedCodec,
91}
92
93impl<'a, T> PpaassPacketDecoder<'a, T>
94where
95    T: RsaCryptoFetcher,
96{
97    pub fn new(rsa_crypto_fetcher: &'a T) -> Self {
98        Self {
99            rsa_crypto_fetcher,
100            length_delimited_codec: LengthDelimitedCodec::new(),
101        }
102    }
103}
104
105/// Decode the input bytes buffer to ppaass message
106impl<'a, T> Decoder for PpaassPacketDecoder<'a, T>
107where
108    T: RsaCryptoFetcher,
109{
110    type Item = PpaassPacket;
111    type Error = CodecError;
112
113    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
114        let length_decode_result = self.length_delimited_codec.decode(src)?;
115        let decompressed_packet: PpaassPacket = match length_decode_result {
116            None => return Ok(None),
117            Some(packet_bytes) => {
118                let mut gzip_decoder = GzDecoder::new(packet_bytes.reader());
119                let mut decompressed_packet_bytes = Vec::new();
120                if let Err(e) = gzip_decoder.read_to_end(&mut decompressed_packet_bytes) {
121                    error!("Fail to decompress incoming message bytes because of error: {e:?}");
122                    return Err(CodecError::StdIo(e));
123                };
124                let decompressed_packet_bytes = Bytes::from_iter(decompressed_packet_bytes);
125                decompressed_packet_bytes.try_into()?
126            }
127        };
128        let decrypted_packed = match decompressed_packet.encryption() {
129            PpaassPacketPayloadEncryption::Plain => decompressed_packet,
130            PpaassPacketPayloadEncryption::Aes(rsa_encrypted_aes_token) => {
131                let rsa_crypto = self
132                    .rsa_crypto_fetcher
133                    .fetch(decompressed_packet.user_token())?
134                    .ok_or(CryptoError::Other(format!(
135                        "Crypto not exist for user: {}",
136                        decompressed_packet.user_token()
137                    )))?;
138                let decrypted_aes_token = Bytes::from(rsa_crypto.decrypt(rsa_encrypted_aes_token)?);
139                let mut decrypted_payload_bytes =
140                    BytesMut::from_iter(decompressed_packet.payload());
141                let decrypted_payload =
142                    decrypt_with_aes(&decrypted_aes_token, &mut decrypted_payload_bytes)?.freeze();
143                PpaassPacket::new(
144                    decompressed_packet.packet_id().to_owned(),
145                    decompressed_packet.user_token().to_owned(),
146                    PpaassPacketPayloadEncryption::Aes(decrypted_aes_token),
147                    decrypted_payload,
148                )
149            }
150        };
151        Ok(Some(decrypted_packed))
152    }
153}
154
155pub struct PpaassPacketCodec<'a, T>
156where
157    T: RsaCryptoFetcher,
158{
159    encoder: PpaassPacketEncoder<'a, T>,
160    decoder: PpaassPacketDecoder<'a, T>,
161}
162
163impl<'a, T> PpaassPacketCodec<'a, T>
164where
165    T: RsaCryptoFetcher,
166{
167    pub fn new(rsa_crypto_fetcher: &'a T) -> Self {
168        Self {
169            encoder: PpaassPacketEncoder::new(&rsa_crypto_fetcher),
170            decoder: PpaassPacketDecoder::new(&rsa_crypto_fetcher),
171        }
172    }
173}
174
175impl<'a, T> Encoder<PpaassPacket> for PpaassPacketCodec<'a, T>
176where
177    T: RsaCryptoFetcher,
178{
179    type Error = CodecError;
180    fn encode(&mut self, item: PpaassPacket, dst: &mut BytesMut) -> Result<(), Self::Error> {
181        self.encoder.encode(item, dst)
182    }
183}
184
185impl<'a, T> Decoder for PpaassPacketCodec<'a, T>
186where
187    T: RsaCryptoFetcher,
188{
189    type Item = PpaassPacket;
190    type Error = CodecError;
191    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
192        self.decoder.decode(src)
193    }
194}