// ironrdp_pdu/rdp/session_info/logon_extended.rs

1use bitflags::bitflags;
2use ironrdp_core::{
3    cast_length, ensure_fixed_part_size, ensure_size, invalid_field_err, read_padding, Decode, DecodeResult, Encode,
4    EncodeResult, ReadCursor, WriteCursor,
5};
6use num_derive::{FromPrimitive, ToPrimitive};
7use num_traits::{FromPrimitive as _, ToPrimitive as _};
8
/// Size of the 16-bit length field that opens the extended logon info structure.
const LOGON_EX_LENGTH_FIELD_SIZE: usize = 2;
/// Size of the 32-bit field that flags which optional logon fields are present.
const LOGON_EX_FLAGS_FIELD_SIZE: usize = 4;
/// Number of padding bytes that follow the logon fields.
const LOGON_EX_PADDING_SIZE: usize = 570;
/// All-zero buffer emitted as the padding when encoding.
const LOGON_EX_PADDING_BUFFER: [u8; LOGON_EX_PADDING_SIZE] = [0u8; LOGON_EX_PADDING_SIZE];

/// Size of the 32-bit field-data length that prefixes each optional logon field.
const LOGON_INFO_FIELD_DATA_SIZE: usize = 4;
/// The only auto-reconnect packet version that is encoded and accepted on decode.
const AUTO_RECONNECT_VERSION_1: u32 = 0x0000_0001;
/// Auto-reconnect packet size: length, version and logon ID (4 bytes each) + random bits.
const AUTO_RECONNECT_PACKET_SIZE: usize = 28;
/// Number of random bytes carried in the auto-reconnect cookie.
const AUTO_RECONNECT_RANDOM_BITS_SIZE: usize = 16;
/// Size of the logon errors payload: two 32-bit fields (type and data).
const LOGON_ERRORS_INFO_SIZE: usize = 8;
19
20#[derive(Debug, Clone, PartialEq, Eq)]
21pub struct LogonInfoExtended {
22    pub present_fields_flags: LogonExFlags,
23    pub auto_reconnect: Option<ServerAutoReconnect>,
24    pub errors_info: Option<LogonErrorsInfo>,
25}
26
27impl LogonInfoExtended {
28    const NAME: &'static str = "LogonInfoExtended";
29
30    const FIXED_PART_SIZE: usize = LOGON_EX_LENGTH_FIELD_SIZE + LOGON_EX_FLAGS_FIELD_SIZE;
31
32    fn get_internal_size(&self) -> usize {
33        let reconnect_size = self.auto_reconnect.as_ref().map(|r| r.size()).unwrap_or(0);
34
35        let errors_size = self.errors_info.as_ref().map(|r| r.size()).unwrap_or(0);
36
37        Self::FIXED_PART_SIZE + reconnect_size + errors_size
38    }
39}
40
41impl Encode for LogonInfoExtended {
42    fn encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()> {
43        ensure_size!(in: dst, size: self.size());
44
45        dst.write_u16(cast_length!("internalSize", self.get_internal_size())?);
46        dst.write_u32(self.present_fields_flags.bits());
47
48        if let Some(ref reconnect) = self.auto_reconnect {
49            reconnect.encode(dst)?;
50        }
51        if let Some(ref errors) = self.errors_info {
52            errors.encode(dst)?;
53        }
54
55        dst.write_slice(LOGON_EX_PADDING_BUFFER.as_ref());
56
57        Ok(())
58    }
59
60    fn name(&self) -> &'static str {
61        Self::NAME
62    }
63
64    fn size(&self) -> usize {
65        self.get_internal_size() + LOGON_EX_PADDING_SIZE
66    }
67}
68
69impl<'de> Decode<'de> for LogonInfoExtended {
70    fn decode(src: &mut ReadCursor<'de>) -> DecodeResult<Self> {
71        ensure_fixed_part_size!(in: src);
72
73        let _self_length = src.read_u16();
74        let present_fields_flags = LogonExFlags::from_bits_truncate(src.read_u32());
75
76        let auto_reconnect = if present_fields_flags.contains(LogonExFlags::AUTO_RECONNECT_COOKIE) {
77            Some(ServerAutoReconnect::decode(src)?)
78        } else {
79            None
80        };
81
82        let errors_info = if present_fields_flags.contains(LogonExFlags::LOGON_ERRORS) {
83            Some(LogonErrorsInfo::decode(src)?)
84        } else {
85            None
86        };
87
88        ensure_size!(in: src, size: LOGON_EX_PADDING_SIZE);
89        read_padding!(src, LOGON_EX_PADDING_SIZE);
90
91        Ok(Self {
92            present_fields_flags,
93            auto_reconnect,
94            errors_info,
95        })
96    }
97}
98
99#[derive(Debug, Clone, PartialEq, Eq)]
100pub struct ServerAutoReconnect {
101    pub logon_id: u32,
102    pub random_bits: [u8; AUTO_RECONNECT_RANDOM_BITS_SIZE],
103}
104
105impl ServerAutoReconnect {
106    const NAME: &'static str = "ServerAutoReconnect";
107
108    const FIXED_PART_SIZE: usize = AUTO_RECONNECT_PACKET_SIZE + LOGON_INFO_FIELD_DATA_SIZE;
109}
110
111impl Encode for ServerAutoReconnect {
112    fn encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()> {
113        ensure_fixed_part_size!(in: dst);
114
115        dst.write_u32(AUTO_RECONNECT_PACKET_SIZE as u32);
116        dst.write_u32(AUTO_RECONNECT_PACKET_SIZE as u32);
117        dst.write_u32(AUTO_RECONNECT_VERSION_1);
118        dst.write_u32(self.logon_id);
119        dst.write_slice(self.random_bits.as_ref());
120
121        Ok(())
122    }
123
124    fn name(&self) -> &'static str {
125        Self::NAME
126    }
127
128    fn size(&self) -> usize {
129        Self::FIXED_PART_SIZE
130    }
131}
132
133impl<'de> Decode<'de> for ServerAutoReconnect {
134    fn decode(src: &mut ReadCursor<'de>) -> DecodeResult<Self> {
135        ensure_fixed_part_size!(in: src);
136
137        let _data_length = src.read_u32();
138        let packet_length = src.read_u32();
139        if packet_length != AUTO_RECONNECT_PACKET_SIZE as u32 {
140            return Err(invalid_field_err!("packetLen", "invalid auto-reconnect packet size"));
141        }
142
143        let version = src.read_u32();
144        if version != AUTO_RECONNECT_VERSION_1 {
145            return Err(invalid_field_err!("version", "invalid auto-reconnect version"));
146        }
147
148        let logon_id = src.read_u32();
149        let random_bits = src.read_array();
150
151        Ok(Self { logon_id, random_bits })
152    }
153}
154
155/// TS_LOGON_ERRORS_INFO
156///
157/// [Doc](https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-rdpbcgr/845eb789-6edf-453a-8b0e-c976823d1f72)
158#[derive(Debug, Clone, PartialEq, Eq)]
159pub struct LogonErrorsInfo {
160    pub error_type: LogonErrorNotificationType,
161    pub error_data: LogonErrorNotificationData,
162}
163
164impl LogonErrorsInfo {
165    const NAME: &'static str = "LogonErrorsInfo";
166
167    const FIXED_PART_SIZE: usize = LOGON_ERRORS_INFO_SIZE + LOGON_INFO_FIELD_DATA_SIZE;
168}
169
170impl Encode for LogonErrorsInfo {
171    fn encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()> {
172        ensure_fixed_part_size!(in: dst);
173
174        dst.write_u32(LOGON_ERRORS_INFO_SIZE as u32);
175        dst.write_u32(self.error_type.to_u32().unwrap());
176        dst.write_u32(self.error_data.to_u32());
177
178        Ok(())
179    }
180
181    fn name(&self) -> &'static str {
182        Self::NAME
183    }
184
185    fn size(&self) -> usize {
186        Self::FIXED_PART_SIZE
187    }
188}
189
190impl<'de> Decode<'de> for LogonErrorsInfo {
191    fn decode(src: &mut ReadCursor<'de>) -> DecodeResult<Self> {
192        ensure_fixed_part_size!(in: src);
193
194        let _data_length = src.read_u32();
195        let error_type = LogonErrorNotificationType::from_u32(src.read_u32())
196            .ok_or_else(|| invalid_field_err!("errorType", "invalid logon error type"))?;
197
198        let error_notification_data = src.read_u32();
199        let error_data = LogonErrorNotificationDataErrorCode::from_u32(error_notification_data)
200            .map(LogonErrorNotificationData::ErrorCode)
201            .unwrap_or(LogonErrorNotificationData::SessionId(error_notification_data));
202
203        Ok(Self { error_type, error_data })
204    }
205}
206
207bitflags! {
208    #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
209    pub struct LogonExFlags: u32 {
210        const AUTO_RECONNECT_COOKIE = 0x0000_0001;
211        const LOGON_ERRORS = 0x0000_0002;
212    }
213}
214
215#[repr(u32)]
216#[derive(Debug, Clone, PartialEq, Eq, FromPrimitive, ToPrimitive)]
217pub enum LogonErrorNotificationType {
218    SessionBusyOptions = 0xFFFF_FFF8,
219    DisconnectRefused = 0xFFFF_FFF9,
220    NoPermission = 0xFFFF_FFFA,
221    BumpOptions = 0xFFFF_FFFB,
222    ReconnectOptions = 0xFFFF_FFFC,
223    SessionTerminate = 0xFFFF_FFFD,
224    SessionContinue = 0xFFFF_FFFE,
225    AccessDenied = 0xFFFF_FFFF,
226}
227
228#[repr(u32)]
229#[derive(Debug, Clone, PartialEq, Eq, FromPrimitive, ToPrimitive)]
230pub enum LogonErrorNotificationDataErrorCode {
231    FailedBadPassword = 0x0000_0000,
232    FailedUpdatePassword = 0x0000_0001,
233    FailedOther = 0x0000_0002,
234    Warning = 0x0000_0003,
235}
236
237#[derive(Debug, Clone, PartialEq, Eq)]
238pub enum LogonErrorNotificationData {
239    ErrorCode(LogonErrorNotificationDataErrorCode),
240    SessionId(u32),
241}
242
243impl LogonErrorNotificationData {
244    pub fn to_u32(&self) -> u32 {
245        match self {
246            LogonErrorNotificationData::ErrorCode(code) => code.to_u32().unwrap(),
247            LogonErrorNotificationData::SessionId(id) => *id,
248        }
249    }
250}