ironrdp_pdu/rdp/server_license/
server_license_request.rs1pub mod cert;
2
3#[cfg(test)]
4mod tests;
5
6use cert::{CertificateType, ProprietaryCertificate, X509CertificateChain};
7use ironrdp_core::{
8 cast_length, ensure_fixed_part_size, ensure_size, invalid_field_err, Decode, DecodeResult, Encode, EncodeResult,
9 ReadCursor, WriteCursor,
10};
11
12use super::{
13 BlobHeader, BlobType, LicenseHeader, PreambleType, ServerLicenseError, BLOB_LENGTH_SIZE, BLOB_TYPE_SIZE,
14 KEY_EXCHANGE_ALGORITHM_RSA, RANDOM_NUMBER_SIZE, UTF16_NULL_TERMINATOR_SIZE, UTF8_NULL_TERMINATOR_SIZE,
15};
16use crate::utils;
17
/// Size of the version field that prefixes a server certificate.
const CERT_VERSION_FIELD_SIZE: usize = core::mem::size_of::<u32>();
/// Size of the key exchange algorithm value carried inside its blob.
const KEY_EXCHANGE_FIELD_SIZE: usize = core::mem::size_of::<u32>();
/// Size of the scope count that precedes the scope list.
const SCOPE_ARRAY_SIZE_FIELD_SIZE: usize = core::mem::size_of::<u32>();
/// Fixed part of `ProductInfo`: version, company name length and product ID length.
const PRODUCT_INFO_STATIC_FIELDS_SIZE: usize = 3 * core::mem::size_of::<u32>();

/// Top bit of the certificate version field; decoded into `ServerCertificate::issued_permanently`.
const CERT_CHAIN_ISSUED_MASK: u32 = 0x8000_0000;
/// All remaining bits of the certificate version field hold the certificate version itself.
const CERT_CHAIN_VERSION_MASK: u32 = !CERT_CHAIN_ISSUED_MASK;

// Sanity limits applied while decoding untrusted input.
/// Maximum number of scopes accepted in a scope list.
const MAX_SCOPE_COUNT: u32 = 256;
/// Maximum on-the-wire byte length of the company name, UTF-16 terminator included.
const MAX_COMPANY_NAME_LEN: usize = 1024;
/// Maximum on-the-wire byte length of the product ID, UTF-16 terminator included.
const MAX_PRODUCT_ID_LEN: usize = 1024;

/// Key exchange algorithm value accepted by the decoder (RSA).
///
/// NOTE(review): the encoder writes `KEY_EXCHANGE_ALGORITHM_RSA` from the parent module; the
/// two are presumably equal — confirm and consider reusing that constant here.
const RSA_EXCHANGE_ALGORITHM: u32 = 1;
29
30#[derive(Debug, PartialEq, Eq)]
34pub struct ServerLicenseRequest {
35 pub license_header: LicenseHeader,
36 pub server_random: Vec<u8>,
37 pub product_info: ProductInfo,
38 pub server_certificate: Option<ServerCertificate>,
39 pub scope_list: Vec<Scope>,
40}
41
42impl ServerLicenseRequest {
43 const NAME: &'static str = "ServerLicenseRequest";
44
45 pub fn get_public_key(&self) -> Result<Option<Vec<u8>>, ServerLicenseError> {
46 self.server_certificate.as_ref().map(|c| c.get_public_key()).transpose()
47 }
48}
49
50impl ServerLicenseRequest {
51 pub fn encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()> {
52 ensure_size!(in: dst, size: self.size());
53
54 self.license_header.encode(dst)?;
55
56 dst.write_slice(&self.server_random);
57 self.product_info.encode(dst)?;
58
59 BlobHeader::new(BlobType::KEY_EXCHANGE_ALGORITHM, KEY_EXCHANGE_FIELD_SIZE).encode(dst)?;
60 dst.write_u32(KEY_EXCHANGE_ALGORITHM_RSA);
61
62 let cert_size = self.server_certificate.as_ref().map(|v| v.size()).unwrap_or(0);
63 BlobHeader::new(BlobType::CERTIFICATE, cert_size).encode(dst)?;
64
65 if let Some(cert) = &self.server_certificate {
66 cert.encode(dst)?;
67 }
68
69 dst.write_u32(cast_length!("listLen", self.scope_list.len())?);
70
71 for scope in self.scope_list.iter() {
72 scope.encode(dst)?;
73 }
74
75 Ok(())
76 }
77
78 pub fn name(&self) -> &'static str {
79 Self::NAME
80 }
81
82 pub fn size(&self) -> usize {
83 self.license_header.size()
84 + RANDOM_NUMBER_SIZE
85 + self.product_info.size()
86 + BLOB_LENGTH_SIZE * 2 + BLOB_TYPE_SIZE * 2 + KEY_EXCHANGE_FIELD_SIZE
89 + self.server_certificate.as_ref().map(|c| c.size()).unwrap_or(0)
90 + SCOPE_ARRAY_SIZE_FIELD_SIZE
91 + self.scope_list.iter().map(|s| s.size()).sum::<usize>()
92 }
93}
94
95impl ServerLicenseRequest {
96 pub fn decode(license_header: LicenseHeader, src: &mut ReadCursor<'_>) -> DecodeResult<Self> {
97 if license_header.preamble_message_type != PreambleType::LicenseRequest {
98 return Err(invalid_field_err!("preambleMessageType", "unexpected preamble type"));
99 }
100
101 ensure_size!(in: src, size: RANDOM_NUMBER_SIZE);
102 let server_random = src.read_slice(RANDOM_NUMBER_SIZE).into();
103
104 let product_info = ProductInfo::decode(src)?;
105
106 let key_exchange_algorithm_blob = BlobHeader::decode(src)?;
107 if key_exchange_algorithm_blob.blob_type != BlobType::KEY_EXCHANGE_ALGORITHM {
108 return Err(invalid_field_err!("blobType", "invalid blob type"));
109 }
110
111 ensure_size!(in: src, size: 4);
112 let key_exchange_algorithm = src.read_u32();
113 if key_exchange_algorithm != RSA_EXCHANGE_ALGORITHM {
114 return Err(invalid_field_err!("keyAlgo", "invalid key exchange algorithm"));
115 }
116
117 let cert_blob = BlobHeader::decode(src)?;
118 if cert_blob.blob_type != BlobType::CERTIFICATE {
119 return Err(invalid_field_err!("blobType", "invalid blob type"));
120 }
121
122 let server_certificate = if cert_blob.length != 0 {
124 Some(ServerCertificate::decode(src)?)
125 } else {
126 None
127 };
128
129 ensure_size!(in: src, size: 4);
130 let scope_count = src.read_u32();
131 if scope_count > MAX_SCOPE_COUNT {
132 return Err(invalid_field_err!("scopeCount", "invalid scope count"));
133 }
134
135 let mut scope_list = Vec::with_capacity(scope_count as usize);
136
137 for _ in 0..scope_count {
138 scope_list.push(Scope::decode(src)?);
139 }
140
141 Ok(Self {
142 license_header,
143 server_random,
144 product_info,
145 server_certificate,
146 scope_list,
147 })
148 }
149}
150
151#[derive(Debug, PartialEq, Eq)]
152pub struct Scope(pub String);
153
154impl Scope {
155 const NAME: &'static str = "Scope";
156
157 const FIXED_PART_SIZE: usize = BLOB_TYPE_SIZE + BLOB_LENGTH_SIZE;
158}
159
160impl Encode for Scope {
161 fn encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()> {
162 ensure_size!(in: dst, size: self.size());
163
164 let data_size = self.0.len() + UTF8_NULL_TERMINATOR_SIZE;
165 BlobHeader::new(BlobType::SCOPE, data_size).encode(dst)?;
166 dst.write_slice(self.0.as_bytes());
167 dst.write_u8(0); Ok(())
170 }
171
172 fn name(&self) -> &'static str {
173 Self::NAME
174 }
175
176 fn size(&self) -> usize {
177 Self::FIXED_PART_SIZE + self.0.len() + UTF8_NULL_TERMINATOR_SIZE
178 }
179}
180
181impl<'de> Decode<'de> for Scope {
182 fn decode(src: &mut ReadCursor<'de>) -> DecodeResult<Self> {
183 let blob_header = BlobHeader::decode(src)?;
184 if blob_header.blob_type != BlobType::SCOPE {
185 return Err(invalid_field_err!("blobType", "invalid blob type"));
186 }
187 if blob_header.length < UTF8_NULL_TERMINATOR_SIZE {
188 return Err(invalid_field_err!("blobLen", "blob too small"));
189 }
190 ensure_size!(in: src, size: blob_header.length);
191 let mut blob_data = src.read_slice(blob_header.length).to_vec();
192 blob_data.resize(blob_data.len() - UTF8_NULL_TERMINATOR_SIZE, 0);
193
194 if let Ok(data) = core::str::from_utf8(&blob_data) {
195 Ok(Self(String::from(data)))
196 } else {
197 Err(invalid_field_err!("scope", "scope is not utf8"))
198 }
199 }
200}
201
202#[derive(Debug, PartialEq, Eq)]
206pub struct ServerCertificate {
207 pub issued_permanently: bool,
208 pub certificate: CertificateType,
209}
210
211impl ServerCertificate {
212 const NAME: &'static str = "ServerCertificate";
213
214 const FIXED_PART_SIZE: usize = CERT_VERSION_FIELD_SIZE;
215
216 pub fn get_public_key(&self) -> Result<Vec<u8>, ServerLicenseError> {
217 use x509_cert::der::Decode as _;
218
219 match &self.certificate {
220 CertificateType::Proprietary(certificate) => {
221 let public_exponent = certificate.public_key.public_exponent.to_le_bytes();
222
223 let rsa_public_key = pkcs1::RsaPublicKey {
224 modulus: pkcs1::UintRef::new(&certificate.public_key.modulus).unwrap(),
225 public_exponent: pkcs1::UintRef::new(&public_exponent).unwrap(),
226 };
227
228 let public_key = pkcs1::der::Encode::to_der(&rsa_public_key).unwrap();
229
230 Ok(public_key)
231 }
232 CertificateType::X509(certificate) => {
233 let cert_der = certificate
234 .certificate_array
235 .last()
236 .ok_or_else(|| ServerLicenseError::InvalidX509CertificatesAmount)?;
237
238 let cert = x509_cert::Certificate::from_der(cert_der).map_err(|source| {
239 ServerLicenseError::InvalidX509Certificate {
240 source,
241 cert_der: cert_der.clone(),
242 }
243 })?;
244
245 let public_key = cert
246 .tbs_certificate
247 .subject_public_key_info
248 .subject_public_key
249 .raw_bytes()
250 .to_owned();
251
252 Ok(public_key)
253 }
254 }
255 }
256}
257
258impl Encode for ServerCertificate {
259 fn encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()> {
260 ensure_size!(in: dst, size: self.size());
261
262 let cert_version: u32 = match self.certificate {
263 CertificateType::Proprietary(_) => 1,
264 CertificateType::X509(_) => 2,
265 };
266 let mask = if self.issued_permanently {
267 CERT_CHAIN_ISSUED_MASK
268 } else {
269 0
270 };
271
272 dst.write_u32(cert_version | mask);
273
274 match &self.certificate {
275 CertificateType::Proprietary(cert) => cert.encode(dst)?,
276 CertificateType::X509(cert) => cert.encode(dst)?,
277 }
278
279 Ok(())
280 }
281
282 fn name(&self) -> &'static str {
283 Self::NAME
284 }
285
286 fn size(&self) -> usize {
287 let certificate_size = match &self.certificate {
288 CertificateType::Proprietary(cert) => cert.size(),
289 CertificateType::X509(cert) => cert.size(),
290 };
291
292 Self::FIXED_PART_SIZE + certificate_size
293 }
294}
295
296impl<'de> Decode<'de> for ServerCertificate {
297 fn decode(src: &mut ReadCursor<'de>) -> DecodeResult<Self> {
298 ensure_fixed_part_size!(in: src);
299
300 let cert_version = src.read_u32();
301
302 let issued_permanently = cert_version & CERT_CHAIN_ISSUED_MASK == CERT_CHAIN_ISSUED_MASK;
303
304 let certificate = match cert_version & CERT_CHAIN_VERSION_MASK {
305 1 => CertificateType::Proprietary(ProprietaryCertificate::decode(src)?),
306 2 => CertificateType::X509(X509CertificateChain::decode(src)?),
307 _ => return Err(invalid_field_err!("certVersion", "invalid certificate version")),
308 };
309
310 Ok(Self {
311 issued_permanently,
312 certificate,
313 })
314 }
315}
316
317#[derive(Debug, PartialEq, Eq)]
318pub struct ProductInfo {
319 pub version: u32,
320 pub company_name: String,
321 pub product_id: String,
322}
323
324impl ProductInfo {
325 const NAME: &'static str = "ProductInfo";
326
327 const FIXED_PART_SIZE: usize = PRODUCT_INFO_STATIC_FIELDS_SIZE;
328}
329
330impl Encode for ProductInfo {
331 fn encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()> {
332 ensure_size!(in: dst, size: self.size());
333
334 dst.write_u32(self.version);
335
336 let mut company_name = utils::to_utf16_bytes(&self.company_name);
337 company_name.resize(company_name.len() + 2, 0);
338
339 dst.write_u32(cast_length!("companyLen", company_name.len())?);
340 dst.write_slice(&company_name);
341
342 let mut product_id = utils::to_utf16_bytes(&self.product_id);
343 product_id.resize(product_id.len() + 2, 0);
344
345 dst.write_u32(cast_length!("produceLen", product_id.len())?);
346 dst.write_slice(&product_id);
347
348 Ok(())
349 }
350
351 fn name(&self) -> &'static str {
352 Self::NAME
353 }
354
355 fn size(&self) -> usize {
356 let company_name_utf_16 = utils::to_utf16_bytes(&self.company_name);
357 let product_id_utf_16 = utils::to_utf16_bytes(&self.product_id);
358
359 Self::FIXED_PART_SIZE
360 + company_name_utf_16.len()
361 + UTF16_NULL_TERMINATOR_SIZE
362 + product_id_utf_16.len()
363 + UTF16_NULL_TERMINATOR_SIZE
364 }
365}
366
367impl<'de> Decode<'de> for ProductInfo {
368 fn decode(src: &mut ReadCursor<'de>) -> DecodeResult<Self> {
369 ensure_fixed_part_size!(in: src);
370
371 let version = src.read_u32();
372
373 let company_name_len = cast_length!("companyLen", src.read_u32())?;
374 if !(2..=MAX_COMPANY_NAME_LEN).contains(&company_name_len) {
375 return Err(invalid_field_err!("companyLen", "invalid company name length"));
376 }
377
378 ensure_size!(in: src, size: company_name_len);
379 let mut company_name = src.read_slice(company_name_len).to_vec();
380 company_name.resize(company_name_len - 2, 0);
381 let company_name = utils::from_utf16_bytes(company_name.as_slice());
382
383 ensure_size!(in: src, size: 4);
384 let product_id_len = cast_length!("productIdLen", src.read_u32())?;
385 if !(2..=MAX_PRODUCT_ID_LEN).contains(&product_id_len) {
386 return Err(invalid_field_err!("productIdLen", "invalid produce ID length"));
387 }
388
389 ensure_size!(in: src, size: product_id_len);
390 let mut product_id = src.read_slice(product_id_len).to_vec();
391 product_id.resize(product_id_len - 2, 0);
392 let product_id = utils::from_utf16_bytes(product_id.as_slice());
393
394 Ok(Self {
395 version,
396 company_name,
397 product_id,
398 })
399 }
400}