// ironrdp_pdu/gcc/security_data.rs

1use std::io;
2
3use bitflags::bitflags;
4use ironrdp_core::{
5    cast_length, ensure_fixed_part_size, ensure_size, invalid_field_err, Decode, DecodeResult, Encode, EncodeResult,
6    ReadCursor, WriteCursor,
7};
8use num_derive::{FromPrimitive, ToPrimitive};
9use num_traits::{FromPrimitive as _, ToPrimitive as _};
10use thiserror::Error;
11
// Wire widths of the fixed `u32` fields of the client security data.
const CLIENT_ENCRYPTION_METHODS_SIZE: usize = std::mem::size_of::<u32>();
const CLIENT_EXT_ENCRYPTION_METHODS_SIZE: usize = std::mem::size_of::<u32>();

// Wire widths of the `u32` header and length fields of the server security data.
const SERVER_ENCRYPTION_METHOD_SIZE: usize = std::mem::size_of::<u32>();
const SERVER_ENCRYPTION_LEVEL_SIZE: usize = std::mem::size_of::<u32>();
const SERVER_RANDOM_LEN_SIZE: usize = std::mem::size_of::<u32>();
const SERVER_CERT_LEN_SIZE: usize = std::mem::size_of::<u32>();

/// Exact length, in bytes, that the server random must have (32).
const SERVER_RANDOM_LEN: usize = 32;
/// Largest server certificate length, in bytes, accepted when decoding.
const MAX_SERVER_CERT_LEN: usize = 1024;
21
22#[derive(Debug, Clone, PartialEq, Eq)]
23pub struct ClientSecurityData {
24    pub encryption_methods: EncryptionMethod,
25    pub ext_encryption_methods: u32,
26}
27
28impl ClientSecurityData {
29    const NAME: &'static str = "ClientSecurityData";
30
31    const FIXED_PART_SIZE: usize = CLIENT_ENCRYPTION_METHODS_SIZE + CLIENT_EXT_ENCRYPTION_METHODS_SIZE;
32
33    pub fn no_security() -> Self {
34        Self {
35            encryption_methods: EncryptionMethod::empty(),
36            ext_encryption_methods: 0,
37        }
38    }
39}
40
41impl Encode for ClientSecurityData {
42    fn encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()> {
43        ensure_size!(in: dst, size: self.size());
44
45        dst.write_u32(self.encryption_methods.bits());
46        dst.write_u32(self.ext_encryption_methods);
47
48        Ok(())
49    }
50
51    fn name(&self) -> &'static str {
52        Self::NAME
53    }
54
55    fn size(&self) -> usize {
56        Self::FIXED_PART_SIZE
57    }
58}
59
60impl<'de> Decode<'de> for ClientSecurityData {
61    fn decode(src: &mut ReadCursor<'de>) -> DecodeResult<Self> {
62        ensure_fixed_part_size!(in: src);
63
64        let encryption_methods = EncryptionMethod::from_bits(src.read_u32())
65            .ok_or_else(|| invalid_field_err!("encryptionMethods", "invalid encryption methods"))?;
66        let ext_encryption_methods = src.read_u32();
67
68        Ok(Self {
69            encryption_methods,
70            ext_encryption_methods,
71        })
72    }
73}
74
75#[derive(Debug, Clone, PartialEq, Eq)]
76pub struct ServerSecurityData {
77    pub encryption_method: EncryptionMethod,
78    pub encryption_level: EncryptionLevel,
79    pub server_random: Option<[u8; SERVER_RANDOM_LEN]>,
80    pub server_cert: Vec<u8>,
81}
82
83impl ServerSecurityData {
84    const NAME: &'static str = "ServerSecurityData";
85
86    const FIXED_PART_SIZE: usize = SERVER_ENCRYPTION_METHOD_SIZE + SERVER_ENCRYPTION_LEVEL_SIZE;
87
88    pub fn no_security() -> Self {
89        Self {
90            encryption_method: EncryptionMethod::empty(),
91            encryption_level: EncryptionLevel::None,
92            server_random: None,
93            server_cert: Vec::new(),
94        }
95    }
96}
97
98impl Encode for ServerSecurityData {
99    fn encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()> {
100        ensure_size!(in: dst, size: self.size());
101
102        dst.write_u32(self.encryption_method.bits());
103        dst.write_u32(self.encryption_level.to_u32().unwrap());
104
105        if self.encryption_method.is_empty() && self.encryption_level == EncryptionLevel::None {
106            if self.server_random.is_some() || !self.server_cert.is_empty() {
107                Err(invalid_field_err!("serverRandom", "An encryption method and encryption level is none, but the server random or certificate is not empty"))
108            } else {
109                Ok(())
110            }
111        } else {
112            let server_random_len = match self.server_random {
113                Some(ref server_random) => server_random.len(),
114                None => 0,
115            };
116            dst.write_u32(cast_length!("serverRandomLen", server_random_len)?);
117            dst.write_u32(cast_length!("serverCertLen", self.server_cert.len())?);
118
119            if let Some(ref server_random) = self.server_random {
120                dst.write_slice(server_random.as_ref());
121            }
122            dst.write_slice(self.server_cert.as_ref());
123
124            Ok(())
125        }
126    }
127
128    fn name(&self) -> &'static str {
129        Self::NAME
130    }
131
132    fn size(&self) -> usize {
133        let mut size = Self::FIXED_PART_SIZE;
134
135        if let Some(ref server_random) = self.server_random {
136            size += SERVER_RANDOM_LEN_SIZE + server_random.len();
137        }
138        if !self.server_cert.is_empty() {
139            size += SERVER_CERT_LEN_SIZE + self.server_cert.len();
140        }
141
142        size
143    }
144}
145
146impl<'de> Decode<'de> for ServerSecurityData {
147    fn decode(src: &mut ReadCursor<'de>) -> DecodeResult<Self> {
148        ensure_fixed_part_size!(in: src);
149
150        let encryption_method = EncryptionMethod::from_bits(src.read_u32())
151            .ok_or_else(|| invalid_field_err!("encryptionMethod", "invalid encryption method"))?;
152        let encryption_level = EncryptionLevel::from_u32(src.read_u32())
153            .ok_or_else(|| invalid_field_err!("encryptionLevel", "invalid encryption level"))?;
154
155        let (server_random, server_cert) = if encryption_method.is_empty() && encryption_level == EncryptionLevel::None
156        {
157            (None, Vec::new())
158        } else {
159            ensure_size!(in: src, size: 4 + 4);
160
161            let server_random_len: usize = cast_length!("serverRandomLen", src.read_u32())?;
162            if server_random_len != SERVER_RANDOM_LEN {
163                return Err(invalid_field_err!("serverRandomLen", "Invalid server random length"));
164            }
165
166            let server_cert_len = cast_length!("serverCertLen", src.read_u32())?;
167
168            if server_cert_len > MAX_SERVER_CERT_LEN {
169                return Err(invalid_field_err!("serverCetLen", "Invalid server certificate length"));
170            }
171
172            ensure_size!(in: src, size: SERVER_RANDOM_LEN);
173            let server_random = src.read_array();
174
175            ensure_size!(in: src, size: server_cert_len);
176            let server_cert = src.read_slice(server_cert_len);
177
178            (Some(server_random), server_cert.into())
179        };
180
181        Ok(Self {
182            encryption_method,
183            encryption_level,
184            server_random,
185            server_cert,
186        })
187    }
188}
189
190bitflags! {
191    #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
192    pub struct EncryptionMethod: u32 {
193        const BIT_40 = 0x0000_0001;
194        const BIT_128 = 0x0000_0002;
195        const BIT_56 = 0x0000_0008;
196        const FIPS = 0x0000_0010;
197    }
198}
199
200#[derive(Debug, Copy, Clone, PartialEq, Eq, FromPrimitive, ToPrimitive)]
201pub enum EncryptionLevel {
202    None = 0,
203    Low = 1,
204    ClientCompatible = 2,
205    High = 3,
206    Fips = 4,
207}
208
209#[derive(Debug, Error)]
210pub enum SecurityDataError {
211    #[error("IO error")]
212    IOError(#[from] io::Error),
213    #[error("invalid encryption methods field")]
214    InvalidEncryptionMethod,
215    #[error("invalid encryption level field")]
216    InvalidEncryptionLevel,
217    #[error("invalid server random length field: {0}")]
218    InvalidServerRandomLen(u32),
219    #[error("invalid input: {0}")]
220    InvalidInput(String),
221    #[error("invalid server certificate length: {0}")]
222    InvalidServerCertificateLen(u32),
223}