ironrdp_pdu/
utils.rs

1use core::fmt::Debug;
2use core::ops::Add;
3
4use byteorder::{LittleEndian, ReadBytesExt as _};
5use ironrdp_core::{ensure_size, invalid_field_err, other_err, ReadCursor, WriteCursor};
6use num_derive::FromPrimitive;
7
8use crate::{DecodeResult, EncodeResult};
9
/// Splits a 64-bit value into its `(low, high)` 32-bit halves.
///
/// The first element of the returned tuple holds the 32 least significant bits of `value`,
/// the second element holds the 32 most significant bits.
pub fn split_u64(value: u64) -> (u32, u32) {
    // In little-endian order the first four bytes form the low half and the last four the high half.
    let [b0, b1, b2, b3, b4, b5, b6, b7] = value.to_le_bytes();

    let low = u32::from_le_bytes([b0, b1, b2, b3]);
    let high = u32::from_le_bytes([b4, b5, b6, b7]);

    (low, high)
}
20
/// Builds a 64-bit value from its low (`lo`) and high (`hi`) 32-bit halves.
///
/// `lo` ends up in the 32 least significant bits and `hi` in the 32 most significant bits.
pub fn combine_u64(lo: u32, hi: u32) -> u64 {
    let low_bits = u64::from(lo);
    let high_bits = u64::from(hi) << 32;

    high_bits | low_bits
}
27
/// Encodes `value` as UTF-16 and returns the code units as little-endian bytes.
///
/// No byte-order mark and no null terminator are added, so the result is exactly
/// two bytes per UTF-16 code unit of `value`.
pub fn to_utf16_bytes(value: &str) -> Vec<u8> {
    // `[u8; 2]` is iterable by value, so each code unit is flattened without a temporary `Vec`.
    value.encode_utf16().flat_map(u16::to_le_bytes).collect()
}
34
/// Decodes little-endian UTF-16 bytes into a `String`, lossily.
///
/// Bytes are consumed two at a time; a trailing odd byte is ignored. Code units that do not
/// form valid UTF-16 (unpaired surrogates) are replaced with U+FFFD, so this never fails.
pub fn from_utf16_bytes(value: &[u8]) -> String {
    // `chunks_exact(2)` only yields two-byte chunks, so the fallback arm is never taken;
    // it exists solely to keep this path free of indexing and panics.
    let code_units = value.chunks_exact(2).filter_map(|pair| match pair {
        &[lo, hi] => Some(u16::from_le_bytes([lo, hi])),
        _ => None,
    });

    // Same replacement policy as `String::from_utf16_lossy`, without the intermediate `Vec<u16>`.
    char::decode_utf16(code_units)
        .map(|decoded| decoded.unwrap_or(char::REPLACEMENT_CHARACTER))
        .collect()
}
45
46#[repr(u16)]
47#[derive(Debug, Copy, Clone, PartialEq, Eq, FromPrimitive)]
48pub enum CharacterSet {
49    Ansi = 1,
50    Unicode = 2,
51}
52
53impl CharacterSet {
54    #[expect(
55        clippy::as_conversions,
56        reason = "guarantees discriminant layout, and as is the only way to cast enum -> primitive"
57    )]
58    pub fn as_u16(self) -> u16 {
59        self as u16
60    }
61}
62
63// Read a string from the cursor, using the specified character set.
64//
65// If read_null_terminator is true, the string will be read until a null terminator is found.
66// Otherwise, the string will be read until the end of the cursor. If the next character is a null
67// terminator, an empty string will be returned (without consuming the null terminator).
68pub fn read_string_from_cursor(
69    cursor: &mut ReadCursor<'_>,
70    character_set: CharacterSet,
71    read_null_terminator: bool,
72) -> DecodeResult<String> {
73    let size = if character_set == CharacterSet::Unicode {
74        let code_units = if read_null_terminator {
75            // Find null or read all if null is not found
76            cursor
77                .remaining()
78                .chunks_exact(2)
79                .position(|chunk| chunk == [0, 0])
80                .map(|null_terminator_pos| null_terminator_pos + 1) // Read null code point
81                .unwrap_or(cursor.len() / 2)
82        } else {
83            // UTF16 uses 2 bytes per code unit, so we need to read an even number of bytes
84            cursor.len() / 2
85        };
86
87        code_units * 2
88    } else if read_null_terminator {
89        // Find null or read all if null is not found
90        cursor
91            .remaining()
92            .iter()
93            .position(|&i| i == 0)
94            .map(|null_terminator_pos| null_terminator_pos + 1) // Read null code point
95            .unwrap_or(cursor.len())
96    } else {
97        // Read all
98        cursor.len()
99    };
100
101    // Empty string, nothing to do
102    if size == 0 {
103        return Ok(String::new());
104    }
105
106    let result = match character_set {
107        CharacterSet::Unicode => {
108            ensure_size!(ctx: "Decode string (UTF-16)", in: cursor, size: size);
109            let mut slice = cursor.read_slice(size);
110
111            let str_buffer = &mut slice;
112            let mut u16_buffer = vec![0u16; str_buffer.len() / 2];
113
114            #[expect(clippy::missing_panics_doc, reason = "unreachable panic (prior constrain)")]
115            str_buffer
116                .read_u16_into::<LittleEndian>(u16_buffer.as_mut())
117                .expect("BUG: str_buffer is always even for UTF16");
118
119            String::from_utf16(&u16_buffer)
120                .map_err(|_| invalid_field_err!("UTF16 decode", "buffer", "Failed to decode UTF16 string"))?
121        }
122        CharacterSet::Ansi => {
123            ensure_size!(ctx: "Decode string (UTF-8)", in: cursor, size: size);
124            let slice = cursor.read_slice(size);
125            String::from_utf8(slice.to_vec())
126                .map_err(|_| invalid_field_err!("UTF8 decode", "buffer", "Failed to decode UTF8 string"))?
127        }
128    };
129
130    Ok(result.trim_end_matches('\0').into())
131}
132
133pub fn decode_string(src: &[u8], character_set: CharacterSet, read_null_terminator: bool) -> DecodeResult<String> {
134    read_string_from_cursor(&mut ReadCursor::new(src), character_set, read_null_terminator)
135}
136
137pub fn read_multistring_from_cursor(
138    cursor: &mut ReadCursor<'_>,
139    character_set: CharacterSet,
140) -> DecodeResult<Vec<String>> {
141    let mut strings = Vec::new();
142
143    loop {
144        let string = read_string_from_cursor(cursor, character_set, true)?;
145        if string.is_empty() {
146            // empty string indicates the end of the multi-string array
147            // (we hit two null terminators in a row)
148            break;
149        }
150
151        strings.push(string);
152    }
153
154    Ok(strings)
155}
156
157pub fn encode_string(
158    dst: &mut [u8],
159    value: &str,
160    character_set: CharacterSet,
161    write_null_terminator: bool,
162) -> EncodeResult<usize> {
163    let (buffer, ctx) = match character_set {
164        CharacterSet::Unicode => {
165            let mut buffer = to_utf16_bytes(value);
166            if write_null_terminator {
167                buffer.extend_from_slice(&[0, 0]);
168            }
169            (buffer, "Encode string (UTF-16)")
170        }
171        CharacterSet::Ansi => {
172            let mut buffer = value.as_bytes().to_vec();
173            if write_null_terminator {
174                buffer.push(0);
175            }
176            (buffer, "Encode string (UTF-8)")
177        }
178    };
179
180    let len = buffer.len();
181
182    ensure_size!(ctx: ctx, in: dst, size: len);
183    dst[..len].copy_from_slice(&buffer);
184
185    Ok(len)
186}
187
188pub fn write_string_to_cursor(
189    cursor: &mut WriteCursor<'_>,
190    value: &str,
191    character_set: CharacterSet,
192    write_null_terminator: bool,
193) -> EncodeResult<()> {
194    let len = encode_string(cursor.remaining_mut(), value, character_set, write_null_terminator)?;
195    cursor.advance(len);
196    Ok(())
197}
198
199pub fn write_multistring_to_cursor(
200    cursor: &mut WriteCursor<'_>,
201    strings: &[String],
202    character_set: CharacterSet,
203) -> EncodeResult<()> {
204    // Write each string to cursor, separated by a null terminator
205    for string in strings {
206        write_string_to_cursor(cursor, string, character_set, true)?;
207    }
208
209    // Write final null terminator signifying the end of the multi-string
210    match character_set {
211        CharacterSet::Unicode => {
212            ensure_size!(ctx: "Encode multistring (UTF-16)", in: cursor, size: 2);
213            cursor.write_u16(0)
214        }
215        CharacterSet::Ansi => {
216            ensure_size!(ctx: "Encode multistring (UTF-8)", in: cursor, size: 1);
217            cursor.write_u8(0)
218        }
219    }
220
221    Ok(())
222}
223
224/// Returns the length in bytes of the encoded value
225/// based on the passed CharacterSet and with_null_terminator flag.
226pub fn encoded_str_len(value: &str, character_set: CharacterSet, with_null_terminator: bool) -> usize {
227    match character_set {
228        CharacterSet::Ansi => value.len() + if with_null_terminator { 1 } else { 0 },
229        CharacterSet::Unicode => value.encode_utf16().count() * 2 + if with_null_terminator { 2 } else { 0 },
230    }
231}
232
233/// Returns the length in bytes of the encoded multi-string
234/// based on the passed CharacterSet.
235pub fn encoded_multistring_len(strings: &[String], character_set: CharacterSet) -> usize {
236    strings
237        .iter()
238        .map(|s| encoded_str_len(s, character_set, true))
239        .sum::<usize>()
240        + if character_set == CharacterSet::Unicode { 2 } else { 1 }
241}
242
// FIXME: legacy trait
/// In-place splitting of a prefix off a slice reference.
pub trait SplitTo {
    /// Removes the first `n` elements from `self` and returns them; `self` keeps the rest.
    ///
    /// The slice implementations in this module panic if `n` is greater than the length.
    #[must_use]
    fn split_to(&mut self, n: usize) -> Self;
}

impl<T> SplitTo for &[T] {
    fn split_to(&mut self, n: usize) -> Self {
        assert!(n <= self.len());

        // Shared references are `Copy`, so the full-lifetime slice can be copied out first.
        let whole = *self;
        let (head, tail) = whole.split_at(n);
        *self = tail;

        head
    }
}

impl<T> SplitTo for &mut [T] {
    fn split_to(&mut self, n: usize) -> Self {
        assert!(n <= self.len());

        // Move the unique reference out (leaving an empty slice behind) so both halves keep
        // the original lifetime.
        let whole = core::mem::take(self);
        let (head, tail) = whole.split_at_mut(n);
        *self = tail;

        head
    }
}
270
/// Addition that reports overflow instead of wrapping or panicking.
pub trait CheckedAdd: Sized + Add<Output = Self> {
    /// Returns `self + rhs`, or `None` if the sum does not fit in `Self`.
    fn checked_add(self, rhs: Self) -> Option<Self>;
}

// Implementations for usize and u32. Inside each impl the method call resolves to the
// integer's inherent `checked_add` (inherent methods take precedence over trait methods),
// so there is no recursion.
impl CheckedAdd for usize {
    fn checked_add(self, rhs: Self) -> Option<Self> {
        self.checked_add(rhs)
    }
}

impl CheckedAdd for u32 {
    fn checked_add(self, rhs: Self) -> Option<Self> {
        self.checked_add(rhs)
    }
}
287
/// Sums all `values` using overflow-checked addition.
///
/// The first element seeds the accumulator and the remaining elements are folded into it
/// with [`CheckedAdd::checked_add`].
///
/// # Errors
///
/// Returns an error if `values` is empty or if an addition overflows.
pub fn checked_sum<T>(values: &[T]) -> DecodeResult<T>
where
    T: CheckedAdd + Copy + Debug,
{
    values.split_first().map_or_else(
        // No first element: the slice is empty.
        || Err(other_err!("empty array provided to checked_sum")),
        |(&first, rest)| {
            // `try_fold` stops at the first addition that overflows.
            rest.iter().try_fold(first, |acc, &val| {
                acc.checked_add(val)
                    .ok_or_else(|| other_err!("overflow detected during addition"))
            })
        },
    )
}
303
304/// Utility function that panics on overflow
305///
306/// # Panics
307///
308/// Panics if sum of values overflows.
309// FIXME: Is it really something we want to expose from ironrdp-pdu?
310pub fn strict_sum<T>(values: &[T]) -> T
311where
312    T: CheckedAdd + Copy + Debug,
313{
314    checked_sum::<T>(values).expect("overflow detected during addition")
315}