ironrdp_pdu/basic_output/fast_path/
mod.rs

1#[cfg(test)]
2mod tests;
3
4use bit_field::BitField as _;
5use bitflags::bitflags;
6use ironrdp_core::{
7    cast_length, decode_cursor, ensure_fixed_part_size, ensure_size, invalid_field_err, Decode, DecodeError,
8    DecodeResult, Encode, EncodeResult, InvalidFieldErr as _, ReadCursor, WriteCursor,
9};
10use num_derive::FromPrimitive;
11use num_traits::FromPrimitive as _;
12
13use super::bitmap::BitmapUpdateData;
14use super::pointer::PointerUpdateData;
15use super::surface_commands::{SurfaceCommand, SURFACE_COMMAND_HEADER_SIZE};
16use crate::per;
17use crate::rdp::client_info::CompressionType;
18use crate::rdp::headers::{CompressionFlags, SHARE_DATA_HEADER_COMPRESSION_MASK};
19
20/// Implements the Fast-Path RDP message header PDU.
21/// TS_FP_UPDATE_PDU
22#[expect(
23    clippy::partial_pub_fields,
24    reason = "this structure is used in the match expression in the integration tests"
25)]
26#[derive(Debug, Clone, PartialEq, Eq)]
27pub struct FastPathHeader {
28    pub flags: EncryptionFlags,
29    pub data_length: usize,
30    forced_long_length: bool,
31}
32
33impl FastPathHeader {
34    const NAME: &'static str = "TS_FP_UPDATE_PDU header";
35    const FIXED_PART_SIZE: usize = 1 /* EncryptionFlags */;
36
37    pub fn new(flags: EncryptionFlags, data_length: usize) -> Self {
38        Self {
39            flags,
40            data_length,
41            forced_long_length: false,
42        }
43    }
44
45    fn minimal_size(&self) -> usize {
46        // it may then be +2 if > 0x7f
47        let len = self.data_length + Self::FIXED_PART_SIZE + 1;
48
49        Self::FIXED_PART_SIZE + per::sizeof_length(len)
50    }
51}
52
53impl Encode for FastPathHeader {
54    fn encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()> {
55        ensure_size!(in: dst, size: self.size());
56
57        let mut header = 0u8;
58        header.set_bits(0..2, 0); // fast-path action
59        header.set_bits(6..8, self.flags.bits());
60        dst.write_u8(header);
61
62        let length = self.data_length + self.size();
63        let length = cast_length!("length", length)?;
64
65        if self.forced_long_length {
66            // Preserve same layout for header as received
67            per::write_long_length(dst, length);
68        } else {
69            per::write_length(dst, length);
70        }
71
72        Ok(())
73    }
74
75    fn name(&self) -> &'static str {
76        Self::NAME
77    }
78
79    fn size(&self) -> usize {
80        if self.forced_long_length {
81            Self::FIXED_PART_SIZE + per::U16_SIZE
82        } else {
83            self.minimal_size()
84        }
85    }
86}
87
88impl<'de> Decode<'de> for FastPathHeader {
89    fn decode(src: &mut ReadCursor<'de>) -> DecodeResult<Self> {
90        ensure_fixed_part_size!(in: src);
91
92        let header = src.read_u8();
93        let flags = EncryptionFlags::from_bits_truncate(header.get_bits(6..8));
94
95        let (length, sizeof_length) = per::read_length(src).map_err(|e| {
96            DecodeError::invalid_field("", "length", "Invalid encoded fast path PDU length").with_source(e)
97        })?;
98        let length = usize::from(length);
99        if length < sizeof_length + Self::FIXED_PART_SIZE {
100            return Err(invalid_field_err!(
101                "length",
102                "received fastpath PDU length is smaller than header size"
103            ));
104        }
105        let data_length = length - sizeof_length - Self::FIXED_PART_SIZE;
106        // Detect case, when received packet has non-optimal packet length packing.
107        let forced_long_length = per::sizeof_length(length) != sizeof_length;
108
109        Ok(FastPathHeader {
110            flags,
111            data_length,
112            forced_long_length,
113        })
114    }
115}
116
117/// TS_FP_UPDATE
118#[derive(Debug, Clone, PartialEq, Eq)]
119pub struct FastPathUpdatePdu<'a> {
120    pub fragmentation: Fragmentation,
121    pub update_code: UpdateCode,
122    pub compression_flags: Option<CompressionFlags>,
123    // NOTE: always Some when compression flags is Some
124    pub compression_type: Option<CompressionType>,
125    pub data: &'a [u8],
126}
127
128impl FastPathUpdatePdu<'_> {
129    const NAME: &'static str = "TS_FP_UPDATE";
130    const FIXED_PART_SIZE: usize = 1 /* header */;
131}
132
133impl Encode for FastPathUpdatePdu<'_> {
134    fn encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()> {
135        ensure_size!(in: dst, size: self.size());
136
137        let data_len = cast_length!("data length", self.data.len())?;
138
139        let mut header = 0u8;
140        header.set_bits(0..4, self.update_code.as_u8());
141        header.set_bits(4..6, self.fragmentation.as_u8());
142
143        dst.write_u8(header);
144
145        if self.compression_flags.is_some() {
146            header.set_bits(6..8, Compression::COMPRESSION_USED.bits());
147            let compression_flags_with_type =
148                self.compression_flags.map(|f| f.bits()).unwrap_or(0) | self.compression_type.map_or(0, |f| f.as_u8());
149            dst.write_u8(compression_flags_with_type);
150        }
151
152        dst.write_u16(data_len);
153        dst.write_slice(self.data);
154
155        Ok(())
156    }
157
158    fn name(&self) -> &'static str {
159        Self::NAME
160    }
161
162    fn size(&self) -> usize {
163        let compression_flags_size = if self.compression_flags.is_some() { 1 } else { 0 };
164
165        Self::FIXED_PART_SIZE + compression_flags_size + 2 /* len */ + self.data.len()
166    }
167}
168
169impl<'de> Decode<'de> for FastPathUpdatePdu<'de> {
170    fn decode(src: &mut ReadCursor<'de>) -> DecodeResult<Self> {
171        ensure_fixed_part_size!(in: src);
172
173        let header = src.read_u8();
174
175        let update_code = header.get_bits(0..4);
176        let update_code = UpdateCode::from_u8(update_code)
177            .ok_or_else(|| invalid_field_err!("updateHeader", "Invalid update code"))?;
178
179        let fragmentation = header.get_bits(4..6);
180        let fragmentation = Fragmentation::from_u8(fragmentation)
181            .ok_or_else(|| invalid_field_err!("updateHeader", "Invalid fragmentation"))?;
182
183        let compression = Compression::from_bits_truncate(header.get_bits(6..8));
184
185        let (compression_flags, compression_type) = if compression.contains(Compression::COMPRESSION_USED) {
186            let expected_size = 1 /* flags_with_type */ + 2 /* len */;
187            ensure_size!(in: src, size: expected_size);
188
189            let compression_flags_with_type = src.read_u8();
190            let compression_flags =
191                CompressionFlags::from_bits_truncate(compression_flags_with_type & !SHARE_DATA_HEADER_COMPRESSION_MASK);
192            let compression_type =
193                CompressionType::from_u8(compression_flags_with_type & SHARE_DATA_HEADER_COMPRESSION_MASK)
194                    .ok_or_else(|| invalid_field_err!("compressionFlags", "invalid compression type"))?;
195
196            (Some(compression_flags), Some(compression_type))
197        } else {
198            let expected_size = 2 /* len */;
199            ensure_size!(in: src, size: expected_size);
200
201            (None, None)
202        };
203
204        let data_length = usize::from(src.read_u16());
205        ensure_size!(in: src, size: data_length);
206        let data = src.read_slice(data_length);
207
208        Ok(Self {
209            fragmentation,
210            update_code,
211            compression_flags,
212            compression_type,
213            data,
214        })
215    }
216}
217
218/// TS_FP_UPDATE data
219#[derive(Debug, Clone, PartialEq, Eq)]
220pub enum FastPathUpdate<'a> {
221    SurfaceCommands(Vec<SurfaceCommand<'a>>),
222    Bitmap(BitmapUpdateData<'a>),
223    Pointer(PointerUpdateData<'a>),
224}
225
226impl<'a> FastPathUpdate<'a> {
227    const NAME: &'static str = "TS_FP_UPDATE data";
228
229    pub fn decode_with_code(src: &'a [u8], code: UpdateCode) -> DecodeResult<Self> {
230        let mut cursor = ReadCursor::<'a>::new(src);
231        Self::decode_cursor_with_code(&mut cursor, code)
232    }
233
234    pub fn decode_cursor_with_code(src: &mut ReadCursor<'a>, code: UpdateCode) -> DecodeResult<Self> {
235        match code {
236            UpdateCode::SurfaceCommands => {
237                let mut commands = Vec::with_capacity(1);
238                while src.len() >= SURFACE_COMMAND_HEADER_SIZE {
239                    commands.push(decode_cursor::<SurfaceCommand<'_>>(src)?);
240                }
241
242                Ok(Self::SurfaceCommands(commands))
243            }
244            UpdateCode::Bitmap => Ok(Self::Bitmap(decode_cursor(src)?)),
245            UpdateCode::HiddenPointer => Ok(Self::Pointer(PointerUpdateData::SetHidden)),
246            UpdateCode::DefaultPointer => Ok(Self::Pointer(PointerUpdateData::SetDefault)),
247            UpdateCode::PositionPointer => Ok(Self::Pointer(PointerUpdateData::SetPosition(decode_cursor(src)?))),
248            UpdateCode::ColorPointer => {
249                let color = decode_cursor(src)?;
250                Ok(Self::Pointer(PointerUpdateData::Color(color)))
251            }
252            UpdateCode::CachedPointer => Ok(Self::Pointer(PointerUpdateData::Cached(decode_cursor(src)?))),
253            UpdateCode::NewPointer => Ok(Self::Pointer(PointerUpdateData::New(decode_cursor(src)?))),
254            UpdateCode::LargePointer => Ok(Self::Pointer(PointerUpdateData::Large(decode_cursor(src)?))),
255            _ => Err(invalid_field_err!("updateCode", "unsupported fast-path update code")),
256        }
257    }
258
259    pub fn as_short_name(&self) -> &str {
260        match self {
261            Self::SurfaceCommands(_) => "Surface Commands",
262            Self::Bitmap(_) => "Bitmap",
263            Self::Pointer(_) => "Pointer",
264        }
265    }
266}
267
268impl Encode for FastPathUpdate<'_> {
269    fn encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()> {
270        ensure_size!(in: dst, size: self.size());
271
272        match self {
273            Self::SurfaceCommands(commands) => {
274                for command in commands {
275                    command.encode(dst)?;
276                }
277            }
278            Self::Bitmap(bitmap) => {
279                bitmap.encode(dst)?;
280            }
281            Self::Pointer(pointer) => match pointer {
282                PointerUpdateData::SetHidden => {}
283                PointerUpdateData::SetDefault => {}
284                PointerUpdateData::SetPosition(inner) => inner.encode(dst)?,
285                PointerUpdateData::Color(inner) => inner.encode(dst)?,
286                PointerUpdateData::Cached(inner) => inner.encode(dst)?,
287                PointerUpdateData::New(inner) => inner.encode(dst)?,
288                PointerUpdateData::Large(inner) => inner.encode(dst)?,
289            },
290        }
291
292        Ok(())
293    }
294
295    fn name(&self) -> &'static str {
296        Self::NAME
297    }
298
299    fn size(&self) -> usize {
300        match self {
301            Self::SurfaceCommands(commands) => commands.iter().map(|c| c.size()).sum::<usize>(),
302            Self::Bitmap(bitmap) => bitmap.size(),
303            Self::Pointer(pointer) => match pointer {
304                PointerUpdateData::SetHidden => 0,
305                PointerUpdateData::SetDefault => 0,
306                PointerUpdateData::SetPosition(inner) => inner.size(),
307                PointerUpdateData::Color(inner) => inner.size(),
308                PointerUpdateData::Cached(inner) => inner.size(),
309                PointerUpdateData::New(inner) => inner.size(),
310                PointerUpdateData::Large(inner) => inner.size(),
311            },
312        }
313    }
314}
315
316#[repr(u8)]
317#[derive(Debug, Copy, Clone, PartialEq, Eq, FromPrimitive)]
318pub enum UpdateCode {
319    Orders = 0x0,
320    Bitmap = 0x1,
321    Palette = 0x2,
322    Synchronize = 0x3,
323    SurfaceCommands = 0x4,
324    HiddenPointer = 0x5,
325    DefaultPointer = 0x6,
326    PositionPointer = 0x8,
327    ColorPointer = 0x9,
328    CachedPointer = 0xa,
329    NewPointer = 0xb,
330    LargePointer = 0xc,
331}
332
333impl UpdateCode {
334    #[expect(
335        clippy::as_conversions,
336        reason = "guarantees discriminant layout, and as is the only way to cast enum -> primitive"
337    )]
338    pub fn as_u8(self) -> u8 {
339        self as u8
340    }
341}
342
343impl From<&FastPathUpdate<'_>> for UpdateCode {
344    fn from(update: &FastPathUpdate<'_>) -> Self {
345        match update {
346            FastPathUpdate::SurfaceCommands(_) => Self::SurfaceCommands,
347            FastPathUpdate::Bitmap(_) => Self::Bitmap,
348            FastPathUpdate::Pointer(action) => match action {
349                PointerUpdateData::SetHidden => Self::HiddenPointer,
350                PointerUpdateData::SetDefault => Self::DefaultPointer,
351                PointerUpdateData::SetPosition(_) => Self::PositionPointer,
352                PointerUpdateData::Color(_) => Self::ColorPointer,
353                PointerUpdateData::Cached(_) => Self::CachedPointer,
354                PointerUpdateData::New(_) => Self::NewPointer,
355                PointerUpdateData::Large(_) => Self::LargePointer,
356            },
357        }
358    }
359}
360
361#[repr(u8)]
362#[derive(Debug, Copy, Clone, PartialEq, Eq, FromPrimitive)]
363pub enum Fragmentation {
364    Single = 0x0,
365    Last = 0x1,
366    First = 0x2,
367    Next = 0x3,
368}
369
370impl Fragmentation {
371    #[expect(
372        clippy::as_conversions,
373        reason = "guarantees discriminant layout, and as is the only way to cast enum -> primitive"
374    )]
375    pub fn as_u8(self) -> u8 {
376        self as u8
377    }
378}
379
380bitflags! {
381    #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
382    pub struct EncryptionFlags: u8 {
383        const SECURE_CHECKSUM = 0x1;
384        const ENCRYPTED = 0x2;
385    }
386}
387
388bitflags! {
389    #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
390    pub struct Compression: u8 {
391        const COMPRESSION_USED = 0x2;
392    }
393}