ironrdp_pdu/rdp/server_license/server_license_request/
cert.rs

1use ironrdp_core::{
2    cast_length, ensure_fixed_part_size, ensure_size, invalid_field_err, read_padding, write_padding, Decode,
3    DecodeResult, Encode, EncodeResult, ReadCursor, WriteCursor,
4};
5
6use super::{BlobHeader, BlobType, KEY_EXCHANGE_ALGORITHM_RSA};
7
/// Signature algorithm identifier written as the first `u32` of a proprietary
/// certificate; decoding rejects any other value.
pub const SIGNATURE_ALGORITHM_RSA: u32 = 1;
/// Number of bytes required before the two leading `u32` algorithm identifiers
/// of a proprietary certificate are read.
pub const PROP_CERT_NO_BLOBS_SIZE: usize = 8;
/// Part of the proprietary certificate's fixed size; presumably the combined
/// size of its two blob headers (the `BlobHeader` layout is defined elsewhere —
/// TODO confirm).
pub const PROP_CERT_BLOBS_HEADERS_SIZE: usize = 8;
/// Size in bytes of the `u32` length prefix that precedes every certificate in
/// an X.509 chain.
pub const X509_CERT_LENGTH_FIELD_SIZE: usize = 4;
/// Size in bytes of the `u32` certificate-count field that opens an X.509
/// chain. Despite the name this is a byte size, not a number of certificates.
pub const X509_CERT_COUNT: usize = 4;
/// Number of bytes by which the encoded `keylen` exceeds `bitlen / 8` in an
/// RSA public key.
pub const RSA_KEY_PADDING_LENGTH: u32 = 8;
/// Magic value that opens an encoded RSA public key (the ASCII bytes `RSA1`
/// when the value is laid out little-endian).
pub const RSA_SENTINEL: u32 = 0x3141_5352;
/// Size in bytes of the five `u32` fields (magic, keylen, bitlen, datalen and
/// public exponent) that precede the modulus of an RSA public key.
pub const RSA_KEY_SIZE_WITHOUT_MODULUS: usize = 20;

/// Smallest certificate count accepted when decoding an X.509 chain.
const MIN_CERTIFICATE_AMOUNT: usize = 2;
/// Upper limit used when validating the certificate count of a decoded X.509
/// chain.
const MAX_CERTIFICATE_AMOUNT: usize = 200;
/// Largest per-certificate byte length accepted when decoding an X.509 chain.
const MAX_CERTIFICATE_LEN: usize = 4096;
20
/// A server certificate in one of the two encodings modelled by this module.
#[derive(Debug, PartialEq, Eq)]
pub enum CertificateType {
    /// Proprietary certificate: an RSA public key together with a signature blob.
    Proprietary(ProprietaryCertificate),
    /// Chain of X.509 certificates, each kept as raw bytes.
    X509(X509CertificateChain),
}
26
/// [2.2.1.4.2] X.509 Certificate Chain (X509_CERTIFICATE_CHAIN)
///
/// On the wire the chain is a `u32` certificate count, followed by each
/// certificate behind its own `u32` length prefix, followed by
/// `8 + 4 * count` bytes of padding.
///
/// [2.2.1.4.2]: https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-rdpele/bf2cc9cc-2b01-442e-a288-6ddfa3b80d59
#[derive(Debug, PartialEq, Eq)]
pub struct X509CertificateChain {
    /// Raw bytes of every certificate, in the order they appear on the wire.
    pub certificate_array: Vec<Vec<u8>>,
}

impl X509CertificateChain {
    /// Name reported by `Encode::name` for this structure.
    const NAME: &'static str = "X509CertificateChain";
}
38
39impl Encode for X509CertificateChain {
40    fn encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()> {
41        ensure_size!(in: dst, size: self.size());
42
43        dst.write_u32(cast_length!("certArrayLen", self.certificate_array.len())?);
44
45        for certificate in &self.certificate_array {
46            dst.write_u32(cast_length!("certLen", certificate.len())?);
47            dst.write_slice(certificate);
48        }
49
50        let padding_len = 8 + 4 * self.certificate_array.len(); // MSDN: A byte array of the length 8 + 4*NumCertBlobs
51        write_padding!(dst, padding_len);
52
53        Ok(())
54    }
55
56    fn name(&self) -> &'static str {
57        Self::NAME
58    }
59
60    fn size(&self) -> usize {
61        let certificates_length: usize = self
62            .certificate_array
63            .iter()
64            .map(|certificate| certificate.len() + X509_CERT_LENGTH_FIELD_SIZE)
65            .sum();
66        let padding: usize = 8 + 4 * self.certificate_array.len();
67        X509_CERT_COUNT + certificates_length + padding
68    }
69}
70
71impl<'de> Decode<'de> for X509CertificateChain {
72    fn decode(src: &mut ReadCursor<'de>) -> DecodeResult<Self> {
73        ensure_size!(in: src, size: 4);
74        let certificate_count = cast_length!("certArrayLen", src.read_u32())?;
75        if !(MIN_CERTIFICATE_AMOUNT..MAX_CERTIFICATE_AMOUNT).contains(&certificate_count) {
76            return Err(invalid_field_err!("certArrayLen", "invalid x509 certificate amount"));
77        }
78
79        let certificate_array: Vec<_> = (0..certificate_count)
80            .map(|_| {
81                ensure_size!(in: src, size: 4);
82                let certificate_len = cast_length!("certLen", src.read_u32())?;
83                if certificate_len > MAX_CERTIFICATE_LEN {
84                    return Err(invalid_field_err!("certLen", "invalid x509 certificate length"));
85                }
86
87                ensure_size!(in: src, size: certificate_len);
88                let certificate = src.read_slice(certificate_len).into();
89
90                Ok(certificate)
91            })
92            .collect::<Result<_, _>>()?;
93
94        let padding = 8 + 4 * certificate_count; // MSDN: A byte array of the length 8 + 4*NumCertBlobs
95        ensure_size!(in: src, size: padding);
96        read_padding!(src, padding);
97
98        Ok(Self { certificate_array })
99    }
100}
101
/// [2.2.1.4.3.1.1] Server Proprietary Certificate (PROPRIETARYSERVERCERTIFICATE)
///
/// Encoded as two `u32` algorithm identifiers, a `BlobType::RSA_KEY` blob
/// holding the public key, and a `BlobType::RSA_SIGNATURE` blob holding the
/// signature.
///
/// [2.2.1.4.3.1.1]: https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-rdpbcgr/a37d449a-73ac-4f00-9b9d-56cefc954634
#[derive(Debug, PartialEq, Eq)]
pub struct ProprietaryCertificate {
    /// RSA public key carried in the `BlobType::RSA_KEY` blob.
    pub public_key: RsaPublicKey,
    /// Raw contents of the `BlobType::RSA_SIGNATURE` blob.
    pub signature: Vec<u8>,
}

impl ProprietaryCertificate {
    /// Name reported by `Encode::name` for this structure.
    const NAME: &'static str = "ProprietaryCertificate";

    /// Portion of the encoded size that does not depend on the key or the
    /// signature; `size()` adds the public key size and signature length to it.
    const FIXED_PART_SIZE: usize = PROP_CERT_BLOBS_HEADERS_SIZE + PROP_CERT_NO_BLOBS_SIZE;
}
116
117impl Encode for ProprietaryCertificate {
118    fn encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()> {
119        ensure_size!(in: dst, size: self.size());
120
121        dst.write_u32(SIGNATURE_ALGORITHM_RSA);
122        dst.write_u32(KEY_EXCHANGE_ALGORITHM_RSA);
123
124        BlobHeader::new(BlobType::RSA_KEY, self.public_key.size()).encode(dst)?;
125        self.public_key.encode(dst)?;
126
127        BlobHeader::new(BlobType::RSA_SIGNATURE, self.signature.len()).encode(dst)?;
128        dst.write_slice(&self.signature);
129
130        Ok(())
131    }
132
133    fn name(&self) -> &'static str {
134        Self::NAME
135    }
136
137    fn size(&self) -> usize {
138        Self::FIXED_PART_SIZE + self.public_key.size() + self.signature.len()
139    }
140}
141
142impl<'de> Decode<'de> for ProprietaryCertificate {
143    fn decode(src: &mut ReadCursor<'de>) -> DecodeResult<Self> {
144        ensure_size!(in: src, size: PROP_CERT_NO_BLOBS_SIZE);
145
146        let signature_algorithm_id = src.read_u32();
147        if signature_algorithm_id != SIGNATURE_ALGORITHM_RSA {
148            return Err(invalid_field_err!("sigAlgId", "invalid signature algorithm ID"));
149        }
150
151        let key_algorithm_id = src.read_u32();
152        if key_algorithm_id != KEY_EXCHANGE_ALGORITHM_RSA {
153            return Err(invalid_field_err!("keyAlgId", "invalid key algorithm ID"));
154        }
155
156        let key_blob_header = BlobHeader::decode(src)?;
157        if key_blob_header.blob_type != BlobType::RSA_KEY {
158            return Err(invalid_field_err!("blobType", "invalid blob type"));
159        }
160        let public_key = RsaPublicKey::decode(src)?;
161
162        let sig_blob_header = BlobHeader::decode(src)?;
163        if sig_blob_header.blob_type != BlobType::RSA_SIGNATURE {
164            return Err(invalid_field_err!("blobType", "invalid blob type"));
165        }
166        ensure_size!(in: src, size: sig_blob_header.length);
167        let signature = src.read_slice(sig_blob_header.length).into();
168
169        Ok(Self { public_key, signature })
170    }
171}
172
/// RSA public key as carried inside a proprietary certificate.
///
/// Encoded as five `u32` fields (magic, keylen, bitlen, datalen and public
/// exponent) followed by the modulus bytes.
#[derive(PartialEq, Eq, Debug, Clone)]
pub struct RsaPublicKey {
    /// Public exponent, read and written as a single `u32`.
    pub public_exponent: u32,
    /// Modulus bytes exactly as they appear on the wire: all `keylen` bytes,
    /// which is `RSA_KEY_PADDING_LENGTH` more than `bitlen / 8`.
    pub modulus: Vec<u8>,
}

impl RsaPublicKey {
    /// Name reported by `Encode::name` for this structure.
    const NAME: &'static str = "RsaPublicKey";

    /// Size of everything that precedes the modulus.
    const FIXED_PART_SIZE: usize = RSA_KEY_SIZE_WITHOUT_MODULUS;
}
184
185impl Encode for RsaPublicKey {
186    fn encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()> {
187        ensure_size!(in: dst, size: self.size());
188
189        let keylen = cast_length!("modulusLen", self.modulus.len())?;
190        let bitlen = (keylen - RSA_KEY_PADDING_LENGTH) * 8;
191        let datalen = bitlen / 8 - 1;
192
193        dst.write_u32(RSA_SENTINEL); // magic
194        dst.write_u32(keylen);
195        dst.write_u32(bitlen);
196        dst.write_u32(datalen);
197        dst.write_u32(self.public_exponent);
198        dst.write_slice(&self.modulus);
199
200        Ok(())
201    }
202
203    fn name(&self) -> &'static str {
204        Self::NAME
205    }
206
207    fn size(&self) -> usize {
208        Self::FIXED_PART_SIZE + self.modulus.len()
209    }
210}
211
212impl<'de> Decode<'de> for RsaPublicKey {
213    fn decode(src: &mut ReadCursor<'de>) -> DecodeResult<Self> {
214        ensure_fixed_part_size!(in: src);
215
216        let magic = src.read_u32();
217        if magic != RSA_SENTINEL {
218            return Err(invalid_field_err!("magic", "invalid RSA public key magic"));
219        }
220
221        let keylen = cast_length!("keyLen", src.read_u32())?;
222
223        let bitlen: usize = cast_length!("bitlen", src.read_u32())?;
224        if keylen != (bitlen / 8) + 8 {
225            return Err(invalid_field_err!("bitlen", "invalid RSA public key length"));
226        }
227
228        if bitlen < 8 {
229            return Err(invalid_field_err!("bitlen", "invalid RSA public key length"));
230        }
231
232        let datalen: usize = cast_length!("dataLen", src.read_u32())?;
233        if datalen != (bitlen / 8) - 1 {
234            return Err(invalid_field_err!("dataLen", "invalid RSA public key data length"));
235        }
236
237        let public_exponent = src.read_u32();
238
239        ensure_size!(in: src, size: keylen);
240        let modulus = src.read_slice(keylen).into();
241
242        Ok(Self {
243            public_exponent,
244            modulus,
245        })
246    }
247}