ironrdp_pdu/basic_output/
fast_path.rs

1#[cfg(test)]
2mod tests;
3
4use bit_field::BitField as _;
5use bitflags::bitflags;
6use ironrdp_core::{
7    decode_cursor, ensure_fixed_part_size, ensure_size, invalid_field_err, Decode, DecodeError, DecodeResult, Encode,
8    EncodeResult, InvalidFieldErr as _, ReadCursor, WriteCursor,
9};
10use num_derive::{FromPrimitive, ToPrimitive};
11use num_traits::{FromPrimitive as _, ToPrimitive as _};
12
13use super::bitmap::BitmapUpdateData;
14use super::pointer::PointerUpdateData;
15use super::surface_commands::{SurfaceCommand, SURFACE_COMMAND_HEADER_SIZE};
16use crate::per;
17use crate::rdp::client_info::CompressionType;
18use crate::rdp::headers::{CompressionFlags, SHARE_DATA_HEADER_COMPRESSION_MASK};
19
20/// Implements the Fast-Path RDP message header PDU.
21/// TS_FP_UPDATE_PDU
22#[derive(Debug, Clone, PartialEq, Eq)]
23pub struct FastPathHeader {
24    pub flags: EncryptionFlags,
25    pub data_length: usize,
26    forced_long_length: bool,
27}
28
29impl FastPathHeader {
30    const NAME: &'static str = "TS_FP_UPDATE_PDU header";
31    const FIXED_PART_SIZE: usize = 1 /* EncryptionFlags */;
32
33    pub fn new(flags: EncryptionFlags, data_length: usize) -> Self {
34        Self {
35            flags,
36            data_length,
37            forced_long_length: false,
38        }
39    }
40
41    fn minimal_size(&self) -> usize {
42        // it may then be +2 if > 0x7f
43        let len = self.data_length + Self::FIXED_PART_SIZE + 1;
44
45        Self::FIXED_PART_SIZE + per::sizeof_length(len as u16)
46    }
47}
48
49impl Encode for FastPathHeader {
50    fn encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()> {
51        ensure_size!(in: dst, size: self.size());
52
53        let mut header = 0u8;
54        header.set_bits(0..2, 0); // fast-path action
55        header.set_bits(6..8, self.flags.bits());
56        dst.write_u8(header);
57
58        let length = self.data_length + self.size();
59        if length > u16::MAX as usize {
60            return Err(invalid_field_err!("length", "fastpath PDU length is too big"));
61        }
62
63        if self.forced_long_length {
64            // Preserve same layout for header as received
65            per::write_long_length(dst, length as u16);
66        } else {
67            per::write_length(dst, length as u16);
68        }
69
70        Ok(())
71    }
72
73    fn name(&self) -> &'static str {
74        Self::NAME
75    }
76
77    fn size(&self) -> usize {
78        if self.forced_long_length {
79            Self::FIXED_PART_SIZE + per::U16_SIZE
80        } else {
81            self.minimal_size()
82        }
83    }
84}
85
86impl<'de> Decode<'de> for FastPathHeader {
87    fn decode(src: &mut ReadCursor<'de>) -> DecodeResult<Self> {
88        ensure_fixed_part_size!(in: src);
89
90        let header = src.read_u8();
91        let flags = EncryptionFlags::from_bits_truncate(header.get_bits(6..8));
92
93        let (length, sizeof_length) = per::read_length(src).map_err(|e| {
94            DecodeError::invalid_field("", "length", "Invalid encoded fast path PDU length").with_source(e)
95        })?;
96        if (length as usize) < sizeof_length + Self::FIXED_PART_SIZE {
97            return Err(invalid_field_err!(
98                "length",
99                "received fastpath PDU length is smaller than header size"
100            ));
101        }
102        let data_length = length as usize - sizeof_length - Self::FIXED_PART_SIZE;
103        // Detect case, when received packet has non-optimal packet length packing
104        let forced_long_length = per::sizeof_length(length) != sizeof_length;
105
106        Ok(FastPathHeader {
107            flags,
108            data_length,
109            forced_long_length,
110        })
111    }
112}
113
114/// TS_FP_UPDATE
115#[derive(Debug, Clone, PartialEq, Eq)]
116pub struct FastPathUpdatePdu<'a> {
117    pub fragmentation: Fragmentation,
118    pub update_code: UpdateCode,
119    pub compression_flags: Option<CompressionFlags>,
120    // NOTE: always Some when compression flags is Some
121    pub compression_type: Option<CompressionType>,
122    pub data: &'a [u8],
123}
124
125impl FastPathUpdatePdu<'_> {
126    const NAME: &'static str = "TS_FP_UPDATE";
127    const FIXED_PART_SIZE: usize = 1 /* header */;
128}
129
130impl Encode for FastPathUpdatePdu<'_> {
131    fn encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()> {
132        ensure_size!(in: dst, size: self.size());
133
134        if self.data.len() > u16::MAX as usize {
135            return Err(invalid_field_err!("data", "fastpath PDU data is too big"));
136        }
137
138        let mut header = 0u8;
139        header.set_bits(0..4, self.update_code.to_u8().unwrap());
140        header.set_bits(4..6, self.fragmentation.to_u8().unwrap());
141
142        dst.write_u8(header);
143
144        if self.compression_flags.is_some() {
145            header.set_bits(6..8, Compression::COMPRESSION_USED.bits());
146            let compression_flags_with_type = self.compression_flags.map(|f| f.bits()).unwrap_or(0)
147                | self.compression_type.and_then(|f| f.to_u8()).unwrap_or(0);
148            dst.write_u8(compression_flags_with_type);
149        }
150
151        dst.write_u16(self.data.len() as u16);
152        dst.write_slice(self.data);
153
154        Ok(())
155    }
156
157    fn name(&self) -> &'static str {
158        Self::NAME
159    }
160
161    fn size(&self) -> usize {
162        let compression_flags_size = if self.compression_flags.is_some() { 1 } else { 0 };
163
164        Self::FIXED_PART_SIZE + compression_flags_size + 2 /* len */ + self.data.len()
165    }
166}
167
168impl<'de> Decode<'de> for FastPathUpdatePdu<'de> {
169    fn decode(src: &mut ReadCursor<'de>) -> DecodeResult<Self> {
170        ensure_fixed_part_size!(in: src);
171
172        let header = src.read_u8();
173
174        let update_code = header.get_bits(0..4);
175        let update_code = UpdateCode::from_u8(update_code)
176            .ok_or_else(|| invalid_field_err!("updateHeader", "Invalid update code"))?;
177
178        let fragmentation = header.get_bits(4..6);
179        let fragmentation = Fragmentation::from_u8(fragmentation)
180            .ok_or_else(|| invalid_field_err!("updateHeader", "Invalid fragmentation"))?;
181
182        let compression = Compression::from_bits_truncate(header.get_bits(6..8));
183
184        let (compression_flags, compression_type) = if compression.contains(Compression::COMPRESSION_USED) {
185            let expected_size = 1 /* flags_with_type */ + 2 /* len */;
186            ensure_size!(in: src, size: expected_size);
187
188            let compression_flags_with_type = src.read_u8();
189            let compression_flags =
190                CompressionFlags::from_bits_truncate(compression_flags_with_type & !SHARE_DATA_HEADER_COMPRESSION_MASK);
191            let compression_type =
192                CompressionType::from_u8(compression_flags_with_type & SHARE_DATA_HEADER_COMPRESSION_MASK)
193                    .ok_or_else(|| invalid_field_err!("compressionFlags", "invalid compression type"))?;
194
195            (Some(compression_flags), Some(compression_type))
196        } else {
197            let expected_size = 2 /* len */;
198            ensure_size!(in: src, size: expected_size);
199
200            (None, None)
201        };
202
203        let data_length = src.read_u16() as usize;
204        ensure_size!(in: src, size: data_length);
205        let data = src.read_slice(data_length);
206
207        Ok(Self {
208            fragmentation,
209            update_code,
210            compression_flags,
211            compression_type,
212            data,
213        })
214    }
215}
216
217/// TS_FP_UPDATE data
218#[derive(Debug, Clone, PartialEq, Eq)]
219pub enum FastPathUpdate<'a> {
220    SurfaceCommands(Vec<SurfaceCommand<'a>>),
221    Bitmap(BitmapUpdateData<'a>),
222    Pointer(PointerUpdateData<'a>),
223}
224
225impl<'a> FastPathUpdate<'a> {
226    const NAME: &'static str = "TS_FP_UPDATE data";
227
228    pub fn decode_with_code(src: &'a [u8], code: UpdateCode) -> DecodeResult<Self> {
229        let mut cursor = ReadCursor::<'a>::new(src);
230        Self::decode_cursor_with_code(&mut cursor, code)
231    }
232
233    pub fn decode_cursor_with_code(src: &mut ReadCursor<'a>, code: UpdateCode) -> DecodeResult<Self> {
234        match code {
235            UpdateCode::SurfaceCommands => {
236                let mut commands = Vec::with_capacity(1);
237                while src.len() >= SURFACE_COMMAND_HEADER_SIZE {
238                    commands.push(decode_cursor::<SurfaceCommand<'_>>(src)?);
239                }
240
241                Ok(Self::SurfaceCommands(commands))
242            }
243            UpdateCode::Bitmap => Ok(Self::Bitmap(decode_cursor(src)?)),
244            UpdateCode::HiddenPointer => Ok(Self::Pointer(PointerUpdateData::SetHidden)),
245            UpdateCode::DefaultPointer => Ok(Self::Pointer(PointerUpdateData::SetDefault)),
246            UpdateCode::PositionPointer => Ok(Self::Pointer(PointerUpdateData::SetPosition(decode_cursor(src)?))),
247            UpdateCode::ColorPointer => {
248                let color = decode_cursor(src)?;
249                Ok(Self::Pointer(PointerUpdateData::Color(color)))
250            }
251            UpdateCode::CachedPointer => Ok(Self::Pointer(PointerUpdateData::Cached(decode_cursor(src)?))),
252            UpdateCode::NewPointer => Ok(Self::Pointer(PointerUpdateData::New(decode_cursor(src)?))),
253            UpdateCode::LargePointer => Ok(Self::Pointer(PointerUpdateData::Large(decode_cursor(src)?))),
254            _ => Err(invalid_field_err!("updateCode", "Invalid fast path update code")),
255        }
256    }
257
258    pub fn as_short_name(&self) -> &str {
259        match self {
260            Self::SurfaceCommands(_) => "Surface Commands",
261            Self::Bitmap(_) => "Bitmap",
262            Self::Pointer(_) => "Pointer",
263        }
264    }
265}
266
267impl Encode for FastPathUpdate<'_> {
268    fn encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()> {
269        ensure_size!(in: dst, size: self.size());
270
271        match self {
272            Self::SurfaceCommands(commands) => {
273                for command in commands {
274                    command.encode(dst)?;
275                }
276            }
277            Self::Bitmap(bitmap) => {
278                bitmap.encode(dst)?;
279            }
280            Self::Pointer(pointer) => match pointer {
281                PointerUpdateData::SetHidden => {}
282                PointerUpdateData::SetDefault => {}
283                PointerUpdateData::SetPosition(inner) => inner.encode(dst)?,
284                PointerUpdateData::Color(inner) => inner.encode(dst)?,
285                PointerUpdateData::Cached(inner) => inner.encode(dst)?,
286                PointerUpdateData::New(inner) => inner.encode(dst)?,
287                PointerUpdateData::Large(inner) => inner.encode(dst)?,
288            },
289        }
290
291        Ok(())
292    }
293
294    fn name(&self) -> &'static str {
295        Self::NAME
296    }
297
298    fn size(&self) -> usize {
299        match self {
300            Self::SurfaceCommands(commands) => commands.iter().map(|c| c.size()).sum::<usize>(),
301            Self::Bitmap(bitmap) => bitmap.size(),
302            Self::Pointer(pointer) => match pointer {
303                PointerUpdateData::SetHidden => 0,
304                PointerUpdateData::SetDefault => 0,
305                PointerUpdateData::SetPosition(inner) => inner.size(),
306                PointerUpdateData::Color(inner) => inner.size(),
307                PointerUpdateData::Cached(inner) => inner.size(),
308                PointerUpdateData::New(inner) => inner.size(),
309                PointerUpdateData::Large(inner) => inner.size(),
310            },
311        }
312    }
313}
314
315#[derive(Debug, Copy, Clone, PartialEq, Eq, FromPrimitive, ToPrimitive)]
316pub enum UpdateCode {
317    Orders = 0x0,
318    Bitmap = 0x1,
319    Palette = 0x2,
320    Synchronize = 0x3,
321    SurfaceCommands = 0x4,
322    HiddenPointer = 0x5,
323    DefaultPointer = 0x6,
324    PositionPointer = 0x8,
325    ColorPointer = 0x9,
326    CachedPointer = 0xa,
327    NewPointer = 0xb,
328    LargePointer = 0xc,
329}
330
331impl From<&FastPathUpdate<'_>> for UpdateCode {
332    fn from(update: &FastPathUpdate<'_>) -> Self {
333        match update {
334            FastPathUpdate::SurfaceCommands(_) => Self::SurfaceCommands,
335            FastPathUpdate::Bitmap(_) => Self::Bitmap,
336            FastPathUpdate::Pointer(action) => match action {
337                PointerUpdateData::SetHidden => Self::HiddenPointer,
338                PointerUpdateData::SetDefault => Self::DefaultPointer,
339                PointerUpdateData::SetPosition(_) => Self::PositionPointer,
340                PointerUpdateData::Color(_) => Self::ColorPointer,
341                PointerUpdateData::Cached(_) => Self::CachedPointer,
342                PointerUpdateData::New(_) => Self::NewPointer,
343                PointerUpdateData::Large(_) => Self::LargePointer,
344            },
345        }
346    }
347}
348
349#[derive(Debug, Copy, Clone, PartialEq, Eq, FromPrimitive, ToPrimitive)]
350pub enum Fragmentation {
351    Single = 0x0,
352    Last = 0x1,
353    First = 0x2,
354    Next = 0x3,
355}
356
357bitflags! {
358    #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
359    pub struct EncryptionFlags: u8 {
360        const SECURE_CHECKSUM = 0x1;
361        const ENCRYPTED = 0x2;
362    }
363}
364
365bitflags! {
366    #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
367    pub struct Compression: u8 {
368        const COMPRESSION_USED = 0x2;
369    }
370}