ironrdp_pdu/rdp/server_license/
server_license_request.rs

1pub mod cert;
2
3#[cfg(test)]
4mod tests;
5
6use cert::{CertificateType, ProprietaryCertificate, X509CertificateChain};
7use ironrdp_core::{
8    cast_length, ensure_fixed_part_size, ensure_size, invalid_field_err, Decode, DecodeResult, Encode, EncodeResult,
9    ReadCursor, WriteCursor,
10};
11
12use super::{
13    BlobHeader, BlobType, LicenseHeader, PreambleType, ServerLicenseError, BLOB_LENGTH_SIZE, BLOB_TYPE_SIZE,
14    KEY_EXCHANGE_ALGORITHM_RSA, RANDOM_NUMBER_SIZE, UTF16_NULL_TERMINATOR_SIZE, UTF8_NULL_TERMINATOR_SIZE,
15};
16use crate::utils;
17
/// Size of the 32-bit `dwVersion` field that prefixes a server certificate.
const CERT_VERSION_FIELD_SIZE: usize = 4;
/// Size of one 32-bit key exchange algorithm identifier inside the key exchange blob.
const KEY_EXCHANGE_FIELD_SIZE: usize = 4;
/// Size of the 32-bit scope count that precedes the scope list.
const SCOPE_ARRAY_SIZE_FIELD_SIZE: usize = 4;
/// Version, company-name length and product-id length: three 32-bit fields.
const PRODUCT_INFO_STATIC_FIELDS_SIZE: usize = 3 * 4;

/// Low 31 bits of `dwVersion`: certificate chain version (1 = proprietary, 2 = X.509).
const CERT_CHAIN_VERSION_MASK: u32 = 0x7FFF_FFFF;
/// Top bit of `dwVersion`: set when the certificate was issued permanently.
const CERT_CHAIN_ISSUED_MASK: u32 = 0x8000_0000;

// Upper bounds enforced by the decoders to reject absurd counts/lengths from the peer.
const MAX_SCOPE_COUNT: u32 = 256;
const MAX_COMPANY_NAME_LEN: usize = 1024;
const MAX_PRODUCT_ID_LEN: usize = 1024;

/// Key exchange algorithm identifier the decoder accepts (RSA).
const RSA_EXCHANGE_ALGORITHM: u32 = 1;
29
30/// [2.2.2.1] Server License Request (SERVER_LICENSE_REQUEST)
31///
32/// [2.2.2.1]: https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-rdpele/e17772e9-9642-4bb6-a2bc-82875dd6da7c
33#[derive(Debug, PartialEq, Eq)]
34pub struct ServerLicenseRequest {
35    pub license_header: LicenseHeader,
36    pub server_random: Vec<u8>,
37    pub product_info: ProductInfo,
38    pub server_certificate: Option<ServerCertificate>,
39    pub scope_list: Vec<Scope>,
40}
41
42impl ServerLicenseRequest {
43    const NAME: &'static str = "ServerLicenseRequest";
44
45    pub fn get_public_key(&self) -> Result<Option<Vec<u8>>, ServerLicenseError> {
46        self.server_certificate.as_ref().map(|c| c.get_public_key()).transpose()
47    }
48}
49
50impl ServerLicenseRequest {
51    pub fn encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()> {
52        ensure_size!(in: dst, size: self.size());
53
54        self.license_header.encode(dst)?;
55
56        dst.write_slice(&self.server_random);
57        self.product_info.encode(dst)?;
58
59        BlobHeader::new(BlobType::KEY_EXCHANGE_ALGORITHM, KEY_EXCHANGE_FIELD_SIZE).encode(dst)?;
60        dst.write_u32(KEY_EXCHANGE_ALGORITHM_RSA);
61
62        let cert_size = self.server_certificate.as_ref().map(|v| v.size()).unwrap_or(0);
63        BlobHeader::new(BlobType::CERTIFICATE, cert_size).encode(dst)?;
64
65        if let Some(cert) = &self.server_certificate {
66            cert.encode(dst)?;
67        }
68
69        dst.write_u32(cast_length!("listLen", self.scope_list.len())?);
70
71        for scope in self.scope_list.iter() {
72            scope.encode(dst)?;
73        }
74
75        Ok(())
76    }
77
78    pub fn name(&self) -> &'static str {
79        Self::NAME
80    }
81
82    pub fn size(&self) -> usize {
83        self.license_header.size()
84            + RANDOM_NUMBER_SIZE
85            + self.product_info.size()
86            + BLOB_LENGTH_SIZE * 2 // KeyExchangeBlob + CertificateBlob
87            + BLOB_TYPE_SIZE * 2 // KeyExchangeBlob + CertificateBlob
88            + KEY_EXCHANGE_FIELD_SIZE
89            + self.server_certificate.as_ref().map(|c| c.size()).unwrap_or(0)
90            + SCOPE_ARRAY_SIZE_FIELD_SIZE
91            + self.scope_list.iter().map(|s| s.size()).sum::<usize>()
92    }
93}
94
95impl ServerLicenseRequest {
96    pub fn decode(license_header: LicenseHeader, src: &mut ReadCursor<'_>) -> DecodeResult<Self> {
97        if license_header.preamble_message_type != PreambleType::LicenseRequest {
98            return Err(invalid_field_err!("preambleMessageType", "unexpected preamble type"));
99        }
100
101        ensure_size!(in: src, size: RANDOM_NUMBER_SIZE);
102        let server_random = src.read_slice(RANDOM_NUMBER_SIZE).into();
103
104        let product_info = ProductInfo::decode(src)?;
105
106        let key_exchange_algorithm_blob = BlobHeader::decode(src)?;
107        if key_exchange_algorithm_blob.blob_type != BlobType::KEY_EXCHANGE_ALGORITHM {
108            return Err(invalid_field_err!("blobType", "invalid blob type"));
109        }
110
111        ensure_size!(in: src, size: 4);
112        let key_exchange_algorithm = src.read_u32();
113        if key_exchange_algorithm != RSA_EXCHANGE_ALGORITHM {
114            return Err(invalid_field_err!("keyAlgo", "invalid key exchange algorithm"));
115        }
116
117        let cert_blob = BlobHeader::decode(src)?;
118        if cert_blob.blob_type != BlobType::CERTIFICATE {
119            return Err(invalid_field_err!("blobType", "invalid blob type"));
120        }
121
122        // The terminal server can choose not to send the certificate by setting the wblobLen field in the Licensing Binary BLOB structure to 0
123        let server_certificate = if cert_blob.length != 0 {
124            Some(ServerCertificate::decode(src)?)
125        } else {
126            None
127        };
128
129        ensure_size!(in: src, size: 4);
130        let scope_count = src.read_u32();
131        if scope_count > MAX_SCOPE_COUNT {
132            return Err(invalid_field_err!("scopeCount", "invalid scope count"));
133        }
134
135        let mut scope_list = Vec::with_capacity(scope_count as usize);
136
137        for _ in 0..scope_count {
138            scope_list.push(Scope::decode(src)?);
139        }
140
141        Ok(Self {
142            license_header,
143            server_random,
144            product_info,
145            server_certificate,
146            scope_list,
147        })
148    }
149}
150
/// A licensing scope name.
///
/// On the wire it is a null-terminated string inside a `BlobType::SCOPE` blob and is decoded
/// as UTF-8. The terminator is not stored in the wrapped `String`.
#[derive(Debug, PartialEq, Eq)]
pub struct Scope(pub String);
153
154impl Scope {
155    const NAME: &'static str = "Scope";
156
157    const FIXED_PART_SIZE: usize = BLOB_TYPE_SIZE + BLOB_LENGTH_SIZE;
158}
159
160impl Encode for Scope {
161    fn encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()> {
162        ensure_size!(in: dst, size: self.size());
163
164        let data_size = self.0.len() + UTF8_NULL_TERMINATOR_SIZE;
165        BlobHeader::new(BlobType::SCOPE, data_size).encode(dst)?;
166        dst.write_slice(self.0.as_bytes());
167        dst.write_u8(0); // null terminator
168
169        Ok(())
170    }
171
172    fn name(&self) -> &'static str {
173        Self::NAME
174    }
175
176    fn size(&self) -> usize {
177        Self::FIXED_PART_SIZE + self.0.len() + UTF8_NULL_TERMINATOR_SIZE
178    }
179}
180
181impl<'de> Decode<'de> for Scope {
182    fn decode(src: &mut ReadCursor<'de>) -> DecodeResult<Self> {
183        let blob_header = BlobHeader::decode(src)?;
184        if blob_header.blob_type != BlobType::SCOPE {
185            return Err(invalid_field_err!("blobType", "invalid blob type"));
186        }
187        if blob_header.length < UTF8_NULL_TERMINATOR_SIZE {
188            return Err(invalid_field_err!("blobLen", "blob too small"));
189        }
190        ensure_size!(in: src, size: blob_header.length);
191        let mut blob_data = src.read_slice(blob_header.length).to_vec();
192        blob_data.resize(blob_data.len() - UTF8_NULL_TERMINATOR_SIZE, 0);
193
194        if let Ok(data) = core::str::from_utf8(&blob_data) {
195            Ok(Self(String::from(data)))
196        } else {
197            Err(invalid_field_err!("scope", "scope is not utf8"))
198        }
199    }
200}
201
202/// [2.2.1.4.3.1] Server Certificate (SERVER_CERTIFICATE)
203///
204/// [2.2.1.4.3.1]: https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-rdpbcgr/54e72cc6-3422-404c-a6b4-2486db125342
205#[derive(Debug, PartialEq, Eq)]
206pub struct ServerCertificate {
207    pub issued_permanently: bool,
208    pub certificate: CertificateType,
209}
210
211impl ServerCertificate {
212    const NAME: &'static str = "ServerCertificate";
213
214    const FIXED_PART_SIZE: usize = CERT_VERSION_FIELD_SIZE;
215
216    pub fn get_public_key(&self) -> Result<Vec<u8>, ServerLicenseError> {
217        use x509_cert::der::Decode as _;
218
219        match &self.certificate {
220            CertificateType::Proprietary(certificate) => {
221                let public_exponent = certificate.public_key.public_exponent.to_le_bytes();
222
223                let rsa_public_key = pkcs1::RsaPublicKey {
224                    modulus: pkcs1::UintRef::new(&certificate.public_key.modulus).unwrap(),
225                    public_exponent: pkcs1::UintRef::new(&public_exponent).unwrap(),
226                };
227
228                let public_key = pkcs1::der::Encode::to_der(&rsa_public_key).unwrap();
229
230                Ok(public_key)
231            }
232            CertificateType::X509(certificate) => {
233                let cert_der = certificate
234                    .certificate_array
235                    .last()
236                    .ok_or_else(|| ServerLicenseError::InvalidX509CertificatesAmount)?;
237
238                let cert = x509_cert::Certificate::from_der(cert_der).map_err(|source| {
239                    ServerLicenseError::InvalidX509Certificate {
240                        source,
241                        cert_der: cert_der.clone(),
242                    }
243                })?;
244
245                let public_key = cert
246                    .tbs_certificate
247                    .subject_public_key_info
248                    .subject_public_key
249                    .raw_bytes()
250                    .to_owned();
251
252                Ok(public_key)
253            }
254        }
255    }
256}
257
258impl Encode for ServerCertificate {
259    fn encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()> {
260        ensure_size!(in: dst, size: self.size());
261
262        let cert_version: u32 = match self.certificate {
263            CertificateType::Proprietary(_) => 1,
264            CertificateType::X509(_) => 2,
265        };
266        let mask = if self.issued_permanently {
267            CERT_CHAIN_ISSUED_MASK
268        } else {
269            0
270        };
271
272        dst.write_u32(cert_version | mask);
273
274        match &self.certificate {
275            CertificateType::Proprietary(cert) => cert.encode(dst)?,
276            CertificateType::X509(cert) => cert.encode(dst)?,
277        }
278
279        Ok(())
280    }
281
282    fn name(&self) -> &'static str {
283        Self::NAME
284    }
285
286    fn size(&self) -> usize {
287        let certificate_size = match &self.certificate {
288            CertificateType::Proprietary(cert) => cert.size(),
289            CertificateType::X509(cert) => cert.size(),
290        };
291
292        Self::FIXED_PART_SIZE + certificate_size
293    }
294}
295
296impl<'de> Decode<'de> for ServerCertificate {
297    fn decode(src: &mut ReadCursor<'de>) -> DecodeResult<Self> {
298        ensure_fixed_part_size!(in: src);
299
300        let cert_version = src.read_u32();
301
302        let issued_permanently = cert_version & CERT_CHAIN_ISSUED_MASK == CERT_CHAIN_ISSUED_MASK;
303
304        let certificate = match cert_version & CERT_CHAIN_VERSION_MASK {
305            1 => CertificateType::Proprietary(ProprietaryCertificate::decode(src)?),
306            2 => CertificateType::X509(X509CertificateChain::decode(src)?),
307            _ => return Err(invalid_field_err!("certVersion", "invalid certificate version")),
308        };
309
310        Ok(Self {
311            issued_permanently,
312            certificate,
313        })
314    }
315}
316
/// Product information (PRODUCT_INFO) carried in the server license request.
///
/// On the wire, both strings are length-prefixed, null-terminated UTF-16. They are stored
/// here without the terminator.
#[derive(Debug, PartialEq, Eq)]
pub struct ProductInfo {
    /// Product version, copied verbatim from the wire.
    pub version: u32,
    /// Company name, without its UTF-16 null terminator.
    pub company_name: String,
    /// Product identifier, without its UTF-16 null terminator.
    pub product_id: String,
}
323
324impl ProductInfo {
325    const NAME: &'static str = "ProductInfo";
326
327    const FIXED_PART_SIZE: usize = PRODUCT_INFO_STATIC_FIELDS_SIZE;
328}
329
330impl Encode for ProductInfo {
331    fn encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()> {
332        ensure_size!(in: dst, size: self.size());
333
334        dst.write_u32(self.version);
335
336        let mut company_name = utils::to_utf16_bytes(&self.company_name);
337        company_name.resize(company_name.len() + 2, 0);
338
339        dst.write_u32(cast_length!("companyLen", company_name.len())?);
340        dst.write_slice(&company_name);
341
342        let mut product_id = utils::to_utf16_bytes(&self.product_id);
343        product_id.resize(product_id.len() + 2, 0);
344
345        dst.write_u32(cast_length!("produceLen", product_id.len())?);
346        dst.write_slice(&product_id);
347
348        Ok(())
349    }
350
351    fn name(&self) -> &'static str {
352        Self::NAME
353    }
354
355    fn size(&self) -> usize {
356        let company_name_utf_16 = utils::to_utf16_bytes(&self.company_name);
357        let product_id_utf_16 = utils::to_utf16_bytes(&self.product_id);
358
359        Self::FIXED_PART_SIZE
360            + company_name_utf_16.len()
361            + UTF16_NULL_TERMINATOR_SIZE
362            + product_id_utf_16.len()
363            + UTF16_NULL_TERMINATOR_SIZE
364    }
365}
366
367impl<'de> Decode<'de> for ProductInfo {
368    fn decode(src: &mut ReadCursor<'de>) -> DecodeResult<Self> {
369        ensure_fixed_part_size!(in: src);
370
371        let version = src.read_u32();
372
373        let company_name_len = cast_length!("companyLen", src.read_u32())?;
374        if !(2..=MAX_COMPANY_NAME_LEN).contains(&company_name_len) {
375            return Err(invalid_field_err!("companyLen", "invalid company name length"));
376        }
377
378        ensure_size!(in: src, size: company_name_len);
379        let mut company_name = src.read_slice(company_name_len).to_vec();
380        company_name.resize(company_name_len - 2, 0);
381        let company_name = utils::from_utf16_bytes(company_name.as_slice());
382
383        ensure_size!(in: src, size: 4);
384        let product_id_len = cast_length!("productIdLen", src.read_u32())?;
385        if !(2..=MAX_PRODUCT_ID_LEN).contains(&product_id_len) {
386            return Err(invalid_field_err!("productIdLen", "invalid produce ID length"));
387        }
388
389        ensure_size!(in: src, size: product_id_len);
390        let mut product_id = src.read_slice(product_id_len).to_vec();
391        product_id.resize(product_id_len - 2, 0);
392        let product_id = utils::from_utf16_bytes(product_id.as_slice());
393
394        Ok(Self {
395            version,
396            company_name,
397            product_id,
398        })
399    }
400}