ironrdp_pdu/gcc/
security_data.rs1use std::io;
2
3use bitflags::bitflags;
4use ironrdp_core::{
5 cast_length, ensure_fixed_part_size, ensure_size, invalid_field_err, Decode, DecodeResult, Encode, EncodeResult,
6 ReadCursor, WriteCursor,
7};
8use num_derive::{FromPrimitive, ToPrimitive};
9use num_traits::{FromPrimitive as _, ToPrimitive as _};
10use thiserror::Error;
11
/// Encoded size of the client `encryptionMethods` field (one `u32`).
const CLIENT_ENCRYPTION_METHODS_SIZE: usize = core::mem::size_of::<u32>();
/// Encoded size of the client `extEncryptionMethods` field (one `u32`).
const CLIENT_EXT_ENCRYPTION_METHODS_SIZE: usize = core::mem::size_of::<u32>();

/// Encoded size of the server `encryptionMethod` field (one `u32`).
const SERVER_ENCRYPTION_METHOD_SIZE: usize = core::mem::size_of::<u32>();
/// Encoded size of the server `encryptionLevel` field (one `u32`).
const SERVER_ENCRYPTION_LEVEL_SIZE: usize = core::mem::size_of::<u32>();
/// Encoded size of the `serverRandomLen` length prefix (one `u32`).
const SERVER_RANDOM_LEN_SIZE: usize = core::mem::size_of::<u32>();
/// Encoded size of the `serverCertLen` length prefix (one `u32`).
const SERVER_CERT_LEN_SIZE: usize = core::mem::size_of::<u32>();
/// The only server random length the decoder accepts, in bytes (32).
const SERVER_RANDOM_LEN: usize = 32;
/// Largest `serverCertLen` the decoder accepts, in bytes.
const MAX_SERVER_CERT_LEN: usize = 1024;
21
22#[derive(Debug, Clone, PartialEq, Eq)]
23pub struct ClientSecurityData {
24 pub encryption_methods: EncryptionMethod,
25 pub ext_encryption_methods: u32,
26}
27
28impl ClientSecurityData {
29 const NAME: &'static str = "ClientSecurityData";
30
31 const FIXED_PART_SIZE: usize = CLIENT_ENCRYPTION_METHODS_SIZE + CLIENT_EXT_ENCRYPTION_METHODS_SIZE;
32
33 pub fn no_security() -> Self {
34 Self {
35 encryption_methods: EncryptionMethod::empty(),
36 ext_encryption_methods: 0,
37 }
38 }
39}
40
41impl Encode for ClientSecurityData {
42 fn encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()> {
43 ensure_size!(in: dst, size: self.size());
44
45 dst.write_u32(self.encryption_methods.bits());
46 dst.write_u32(self.ext_encryption_methods);
47
48 Ok(())
49 }
50
51 fn name(&self) -> &'static str {
52 Self::NAME
53 }
54
55 fn size(&self) -> usize {
56 Self::FIXED_PART_SIZE
57 }
58}
59
60impl<'de> Decode<'de> for ClientSecurityData {
61 fn decode(src: &mut ReadCursor<'de>) -> DecodeResult<Self> {
62 ensure_fixed_part_size!(in: src);
63
64 let encryption_methods = EncryptionMethod::from_bits(src.read_u32())
65 .ok_or_else(|| invalid_field_err!("encryptionMethods", "invalid encryption methods"))?;
66 let ext_encryption_methods = src.read_u32();
67
68 Ok(Self {
69 encryption_methods,
70 ext_encryption_methods,
71 })
72 }
73}
74
75#[derive(Debug, Clone, PartialEq, Eq)]
76pub struct ServerSecurityData {
77 pub encryption_method: EncryptionMethod,
78 pub encryption_level: EncryptionLevel,
79 pub server_random: Option<[u8; SERVER_RANDOM_LEN]>,
80 pub server_cert: Vec<u8>,
81}
82
83impl ServerSecurityData {
84 const NAME: &'static str = "ServerSecurityData";
85
86 const FIXED_PART_SIZE: usize = SERVER_ENCRYPTION_METHOD_SIZE + SERVER_ENCRYPTION_LEVEL_SIZE;
87
88 pub fn no_security() -> Self {
89 Self {
90 encryption_method: EncryptionMethod::empty(),
91 encryption_level: EncryptionLevel::None,
92 server_random: None,
93 server_cert: Vec::new(),
94 }
95 }
96}
97
98impl Encode for ServerSecurityData {
99 fn encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()> {
100 ensure_size!(in: dst, size: self.size());
101
102 dst.write_u32(self.encryption_method.bits());
103 dst.write_u32(self.encryption_level.to_u32().unwrap());
104
105 if self.encryption_method.is_empty() && self.encryption_level == EncryptionLevel::None {
106 if self.server_random.is_some() || !self.server_cert.is_empty() {
107 Err(invalid_field_err!("serverRandom", "An encryption method and encryption level is none, but the server random or certificate is not empty"))
108 } else {
109 Ok(())
110 }
111 } else {
112 let server_random_len = match self.server_random {
113 Some(ref server_random) => server_random.len(),
114 None => 0,
115 };
116 dst.write_u32(cast_length!("serverRandomLen", server_random_len)?);
117 dst.write_u32(cast_length!("serverCertLen", self.server_cert.len())?);
118
119 if let Some(ref server_random) = self.server_random {
120 dst.write_slice(server_random.as_ref());
121 }
122 dst.write_slice(self.server_cert.as_ref());
123
124 Ok(())
125 }
126 }
127
128 fn name(&self) -> &'static str {
129 Self::NAME
130 }
131
132 fn size(&self) -> usize {
133 let mut size = Self::FIXED_PART_SIZE;
134
135 if let Some(ref server_random) = self.server_random {
136 size += SERVER_RANDOM_LEN_SIZE + server_random.len();
137 }
138 if !self.server_cert.is_empty() {
139 size += SERVER_CERT_LEN_SIZE + self.server_cert.len();
140 }
141
142 size
143 }
144}
145
146impl<'de> Decode<'de> for ServerSecurityData {
147 fn decode(src: &mut ReadCursor<'de>) -> DecodeResult<Self> {
148 ensure_fixed_part_size!(in: src);
149
150 let encryption_method = EncryptionMethod::from_bits(src.read_u32())
151 .ok_or_else(|| invalid_field_err!("encryptionMethod", "invalid encryption method"))?;
152 let encryption_level = EncryptionLevel::from_u32(src.read_u32())
153 .ok_or_else(|| invalid_field_err!("encryptionLevel", "invalid encryption level"))?;
154
155 let (server_random, server_cert) = if encryption_method.is_empty() && encryption_level == EncryptionLevel::None
156 {
157 (None, Vec::new())
158 } else {
159 ensure_size!(in: src, size: 4 + 4);
160
161 let server_random_len: usize = cast_length!("serverRandomLen", src.read_u32())?;
162 if server_random_len != SERVER_RANDOM_LEN {
163 return Err(invalid_field_err!("serverRandomLen", "Invalid server random length"));
164 }
165
166 let server_cert_len = cast_length!("serverCertLen", src.read_u32())?;
167
168 if server_cert_len > MAX_SERVER_CERT_LEN {
169 return Err(invalid_field_err!("serverCetLen", "Invalid server certificate length"));
170 }
171
172 ensure_size!(in: src, size: SERVER_RANDOM_LEN);
173 let server_random = src.read_array();
174
175 ensure_size!(in: src, size: server_cert_len);
176 let server_cert = src.read_slice(server_cert_len);
177
178 (Some(server_random), server_cert.into())
179 };
180
181 Ok(Self {
182 encryption_method,
183 encryption_level,
184 server_random,
185 server_cert,
186 })
187 }
188}
189
190bitflags! {
191 #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
192 pub struct EncryptionMethod: u32 {
193 const BIT_40 = 0x0000_0001;
194 const BIT_128 = 0x0000_0002;
195 const BIT_56 = 0x0000_0008;
196 const FIPS = 0x0000_0010;
197 }
198}
199
200#[derive(Debug, Copy, Clone, PartialEq, Eq, FromPrimitive, ToPrimitive)]
201pub enum EncryptionLevel {
202 None = 0,
203 Low = 1,
204 ClientCompatible = 2,
205 High = 3,
206 Fips = 4,
207}
208
209#[derive(Debug, Error)]
210pub enum SecurityDataError {
211 #[error("IO error")]
212 IOError(#[from] io::Error),
213 #[error("invalid encryption methods field")]
214 InvalidEncryptionMethod,
215 #[error("invalid encryption level field")]
216 InvalidEncryptionLevel,
217 #[error("invalid server random length field: {0}")]
218 InvalidServerRandomLen(u32),
219 #[error("invalid input: {0}")]
220 InvalidInput(String),
221 #[error("invalid server certificate length: {0}")]
222 InvalidServerCertificateLen(u32),
223}