ironrdp_pdu/gcc/
network_data.rs

1use std::borrow::Cow;
2use std::{io, str};
3
4use bitflags::bitflags;
5use ironrdp_core::{
6    cast_length, ensure_fixed_part_size, ensure_size, invalid_field_err, read_padding, write_padding, Decode,
7    DecodeResult, Encode, EncodeResult, ReadCursor, WriteCursor,
8};
9use num_integer::Integer as _;
10use thiserror::Error;
11
12const CHANNELS_MAX: usize = 31;
13
14const CLIENT_CHANNEL_OPTIONS_SIZE: usize = 4;
15const CLIENT_CHANNEL_SIZE: usize = ChannelName::SIZE + CLIENT_CHANNEL_OPTIONS_SIZE;
16
17const SERVER_IO_CHANNEL_SIZE: usize = 2;
18const SERVER_CHANNEL_COUNT_SIZE: usize = 2;
19const SERVER_CHANNEL_SIZE: usize = 2;
20
/// Fixed-size channel identifier: up to seven ANSI characters followed by a mandatory
/// null terminator, 8 bytes in total.
///
/// In RDP an ANSI character is an 8-bit Windows-1252 code unit. Every value from 0 to 255 is
/// a valid code unit, so any byte may appear in the name as long as the final byte is zero.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ChannelName {
    // Raw name bytes; every constructor guarantees the last byte is `0`.
    inner: Cow<'static, [u8; Self::SIZE]>,
}

impl ChannelName {
    /// Length in bytes of a channel name, terminator included.
    pub const SIZE: usize = 8;

    /// Builds a channel name from raw bytes.
    ///
    /// The first seven bytes are kept verbatim; the last byte is unconditionally replaced
    /// with the null terminator.
    pub const fn new(value: [u8; Self::SIZE]) -> Self {
        let mut bytes = value;
        bytes[Self::SIZE - 1] = 0;

        Self {
            inner: Cow::Owned(bytes),
        }
    }

    /// Builds a channel name from the first seven characters of `value`.
    ///
    /// Characters beyond the seventh are ignored; shorter inputs are zero-padded.
    /// Returns `None` if any of the (at most seven) copied characters is not ASCII.
    pub fn from_utf8(value: &str) -> Option<Self> {
        let mut bytes = [0u8; Self::SIZE];

        for (slot, ch) in bytes.iter_mut().zip(value.chars().take(Self::SIZE - 1)) {
            // Only characters in the ASCII range survive both the narrowing and the filter.
            *slot = u8::try_from(ch).ok().filter(u8::is_ascii)?;
        }

        Some(Self {
            inner: Cow::Owned(bytes),
        })
    }

    /// Wraps a `'static` byte array without copying it.
    ///
    /// # Panics
    ///
    /// Panics if the last byte of `value` is not the null terminator.
    pub const fn from_static(value: &'static [u8; Self::SIZE]) -> Self {
        match value[Self::SIZE - 1] {
            0 => Self {
                inner: Cow::Borrowed(value),
            },
            _ => panic!("channel name must be null-terminated"),
        }
    }

    /// Returns the raw 8-byte representation, terminator included.
    pub fn as_bytes(&self) -> &[u8; Self::SIZE] {
        &self.inner
    }

    /// Returns the name up to (not including) the first null byte, provided every one of
    /// the 8 bytes is ASCII; otherwise returns `None`.
    pub fn as_str(&self) -> Option<&str> {
        let bytes = self.as_bytes();

        if !bytes.is_ascii() {
            return None;
        }

        // The terminator invariant guarantees at least one zero byte.
        let len = bytes
            .iter()
            .position(|&b| b == 0)
            .expect("null-terminated ASCII string");

        Some(str::from_utf8(&bytes[..len]).expect("ASCII characters"))
    }
}
96
97#[derive(Debug, Clone, PartialEq, Eq)]
98pub struct ClientNetworkData {
99    pub channels: Vec<ChannelDef>,
100}
101
102impl ClientNetworkData {
103    const NAME: &'static str = "ClientNetworkData";
104
105    const FIXED_PART_SIZE: usize = 4 /* channelCount */;
106}
107
108impl Encode for ClientNetworkData {
109    fn encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()> {
110        ensure_fixed_part_size!(in: dst);
111
112        dst.write_u32(cast_length!("channelCount", self.channels.len())?);
113
114        for channel in self.channels.iter().take(CHANNELS_MAX) {
115            channel.encode(dst)?;
116        }
117
118        Ok(())
119    }
120
121    fn name(&self) -> &'static str {
122        Self::NAME
123    }
124
125    fn size(&self) -> usize {
126        Self::FIXED_PART_SIZE + self.channels.len() * CLIENT_CHANNEL_SIZE
127    }
128}
129
130impl<'de> Decode<'de> for ClientNetworkData {
131    fn decode(src: &mut ReadCursor<'de>) -> DecodeResult<Self> {
132        ensure_fixed_part_size!(in: src);
133
134        let channel_count = cast_length!("channelCount", src.read_u32())?;
135
136        if channel_count > CHANNELS_MAX {
137            return Err(invalid_field_err!("channelCount", "invalid channel count"));
138        }
139
140        let mut channels = Vec::with_capacity(channel_count);
141        for _ in 0..channel_count {
142            channels.push(ChannelDef::decode(src)?);
143        }
144
145        Ok(Self { channels })
146    }
147}
148
149#[derive(Debug, Clone, PartialEq, Eq)]
150pub struct ServerNetworkData {
151    pub channel_ids: Vec<u16>,
152    pub io_channel: u16,
153}
154
155impl ServerNetworkData {
156    const NAME: &'static str = "ServerNetworkData";
157
158    const FIXED_PART_SIZE: usize = SERVER_IO_CHANNEL_SIZE + SERVER_CHANNEL_COUNT_SIZE;
159
160    fn padding_needed(&self) -> bool {
161        self.channel_ids.len().is_odd()
162    }
163}
164
165impl Encode for ServerNetworkData {
166    fn encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()> {
167        ensure_size!(in: dst, size: self.size());
168
169        dst.write_u16(self.io_channel);
170        dst.write_u16(cast_length!("channelIdLen", self.channel_ids.len())?);
171
172        for channel_id in self.channel_ids.iter() {
173            dst.write_u16(*channel_id);
174        }
175
176        // The size in bytes of the Server Network Data structure MUST be a multiple of 4.
177        // If the channelCount field contains an odd value, then the size of the channelIdArray
178        // (and by implication the entire Server Network Data structure) will not be a multiple of 4.
179        // In this scenario, the Pad field MUST be present and it is used to add an additional
180        // 2 bytes to the size of the Server Network Data structure.
181        if self.padding_needed() {
182            write_padding!(dst, 2);
183        }
184
185        Ok(())
186    }
187
188    fn name(&self) -> &'static str {
189        Self::NAME
190    }
191
192    fn size(&self) -> usize {
193        let padding_size = if self.padding_needed() { 2 } else { 0 };
194
195        Self::FIXED_PART_SIZE + self.channel_ids.len() * SERVER_CHANNEL_SIZE + padding_size
196    }
197}
198
199impl<'de> Decode<'de> for ServerNetworkData {
200    fn decode(src: &mut ReadCursor<'de>) -> DecodeResult<Self> {
201        ensure_fixed_part_size!(in: src);
202
203        let io_channel = src.read_u16();
204        let channel_count = cast_length!("channelCount", src.read_u16())?;
205
206        ensure_size!(in: src, size: channel_count * 2);
207        let mut channel_ids = Vec::with_capacity(channel_count);
208        for _ in 0..channel_count {
209            channel_ids.push(src.read_u16());
210        }
211
212        let result = Self {
213            io_channel,
214            channel_ids,
215        };
216
217        if src.len() >= 2 {
218            read_padding!(src, 2);
219        }
220
221        Ok(result)
222    }
223}
224
225/// Channel Definition Structure (CHANNEL_DEF)
226#[derive(Debug, Clone, PartialEq, Eq)]
227pub struct ChannelDef {
228    pub name: ChannelName,
229    pub options: ChannelOptions,
230}
231
232impl ChannelDef {
233    const NAME: &'static str = "ChannelDef";
234
235    const FIXED_PART_SIZE: usize = CLIENT_CHANNEL_SIZE;
236}
237
238impl Encode for ChannelDef {
239    fn encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()> {
240        ensure_fixed_part_size!(in: dst);
241
242        dst.write_slice(self.name.as_bytes());
243        dst.write_u32(self.options.bits());
244
245        Ok(())
246    }
247
248    fn name(&self) -> &'static str {
249        Self::NAME
250    }
251
252    fn size(&self) -> usize {
253        Self::FIXED_PART_SIZE
254    }
255}
256
257impl<'de> Decode<'de> for ChannelDef {
258    fn decode(src: &mut ReadCursor<'de>) -> DecodeResult<Self> {
259        ensure_fixed_part_size!(in: src);
260
261        let name = src.read_array();
262        let name = ChannelName::new(name);
263
264        let options = ChannelOptions::from_bits(src.read_u32())
265            .ok_or_else(|| invalid_field_err!("options", "invalid channel options"))?;
266
267        Ok(Self { name, options })
268    }
269}
270
271bitflags! {
272    #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
273    pub struct ChannelOptions: u32 {
274        const INITIALIZED = 0x8000_0000;
275        const ENCRYPT_RDP = 0x4000_0000;
276        const ENCRYPT_SC = 0x2000_0000;
277        const ENCRYPT_CS = 0x1000_0000;
278        const PRI_HIGH = 0x0800_0000;
279        const PRI_MED = 0x0400_0000;
280        const PRI_LOW = 0x0200_0000;
281        const COMPRESS_RDP = 0x0080_0000;
282        const COMPRESS = 0x0040_0000;
283        const SHOW_PROTOCOL = 0x0020_0000;
284        const REMOTE_CONTROL_PERSISTENT = 0x0010_0000;
285    }
286}
287
288#[derive(Debug, Error)]
289pub enum NetworkDataError {
290    #[error("IO error")]
291    IOError(#[from] io::Error),
292    #[error("UTF-8 error")]
293    Utf8Error(#[from] str::Utf8Error),
294    #[error("invalid channel options field")]
295    InvalidChannelOptions,
296    #[error("invalid channel count field")]
297    InvalidChannelCount,
298}