ironrdp_pdu/
utils.rs

1use core::fmt::Debug;
2use core::ops::Add;
3
4use byteorder::{LittleEndian, ReadBytesExt as _};
5use ironrdp_core::{ensure_size, invalid_field_err, other_err, ReadCursor, WriteCursor};
6use num_derive::{FromPrimitive, ToPrimitive};
7
8use crate::{DecodeResult, EncodeResult};
9
/// Splits `value` into its 32-bit halves, returned as `(low, high)`.
///
/// This is the inverse of `combine_u64`.
pub fn split_u64(value: u64) -> (u32, u32) {
    // Masking/shifting leaves exactly 32 significant bits, so the casts cannot truncate data.
    let low = (value & 0xFFFF_FFFF) as u32;
    let high = (value >> 32) as u32;
    (low, high)
}
18
/// Joins a low and a high 32-bit half into one `u64`.
///
/// This is the inverse of `split_u64`: `combine_u64(lo, hi)` puts `hi` in the upper 32 bits.
pub fn combine_u64(lo: u32, hi: u32) -> u64 {
    (u64::from(hi) << 32) | u64::from(lo)
}
25
/// Encodes `value` as UTF-16 with little-endian byte order.
///
/// Characters outside the Basic Multilingual Plane are emitted as surrogate pairs
/// (4 bytes). No null terminator is appended.
pub fn to_utf16_bytes(value: &str) -> Vec<u8> {
    // `u16::to_le_bytes` returns a `[u8; 2]`, which is iterable by value, so each code
    // unit is flattened directly instead of going through a per-unit heap `Vec`.
    value.encode_utf16().flat_map(u16::to_le_bytes).collect()
}
32
/// Decodes little-endian UTF-16 bytes into a `String`, lossily.
///
/// Bytes are consumed in pairs; a trailing odd byte is ignored. Invalid UTF-16
/// (e.g. an unpaired surrogate) is replaced with U+FFFD REPLACEMENT CHARACTER.
pub fn from_utf16_bytes(value: &[u8]) -> String {
    let code_units: Vec<u16> = value
        .chunks_exact(2)
        .map(|pair| u16::from_le_bytes([pair[0], pair[1]]))
        .collect();

    String::from_utf16_lossy(&code_units)
}
41
42#[derive(Debug, Copy, Clone, PartialEq, Eq, FromPrimitive, ToPrimitive)]
43pub enum CharacterSet {
44    Ansi = 1,
45    Unicode = 2,
46}
47
48// Read a string from the cursor, using the specified character set.
49//
50// If read_null_terminator is true, the string will be read until a null terminator is found.
51// Otherwise, the string will be read until the end of the cursor. If the next character is a null
52// terminator, an empty string will be returned (without consuming the null terminator).
53pub fn read_string_from_cursor(
54    cursor: &mut ReadCursor<'_>,
55    character_set: CharacterSet,
56    read_null_terminator: bool,
57) -> DecodeResult<String> {
58    let size = if character_set == CharacterSet::Unicode {
59        let code_units = if read_null_terminator {
60            // Find null or read all if null is not found
61            cursor
62                .remaining()
63                .chunks_exact(2)
64                .position(|chunk| chunk == [0, 0])
65                .map(|null_terminator_pos| null_terminator_pos + 1) // Read null code point
66                .unwrap_or(cursor.len() / 2)
67        } else {
68            // UTF16 uses 2 bytes per code unit, so we need to read an even number of bytes
69            cursor.len() / 2
70        };
71
72        code_units * 2
73    } else if read_null_terminator {
74        // Find null or read all if null is not found
75        cursor
76            .remaining()
77            .iter()
78            .position(|&i| i == 0)
79            .map(|null_terminator_pos| null_terminator_pos + 1) // Read null code point
80            .unwrap_or(cursor.len())
81    } else {
82        // Read all
83        cursor.len()
84    };
85
86    // Empty string, nothing to do
87    if size == 0 {
88        return Ok(String::new());
89    }
90
91    let result = match character_set {
92        CharacterSet::Unicode => {
93            ensure_size!(ctx: "Decode string (UTF-16)", in: cursor, size: size);
94            let mut slice = cursor.read_slice(size);
95
96            let str_buffer = &mut slice;
97            let mut u16_buffer = vec![0u16; str_buffer.len() / 2];
98
99            str_buffer
100                .read_u16_into::<LittleEndian>(u16_buffer.as_mut())
101                .expect("BUG: str_buffer is always even for UTF16");
102
103            String::from_utf16(&u16_buffer)
104                .map_err(|_| invalid_field_err!("UTF16 decode", "buffer", "Failed to decode UTF16 string"))?
105        }
106        CharacterSet::Ansi => {
107            ensure_size!(ctx: "Decode string (UTF-8)", in: cursor, size: size);
108            let slice = cursor.read_slice(size);
109            String::from_utf8(slice.to_vec())
110                .map_err(|_| invalid_field_err!("UTF8 decode", "buffer", "Failed to decode UTF8 string"))?
111        }
112    };
113
114    Ok(result.trim_end_matches('\0').into())
115}
116
117pub fn decode_string(src: &[u8], character_set: CharacterSet, read_null_terminator: bool) -> DecodeResult<String> {
118    read_string_from_cursor(&mut ReadCursor::new(src), character_set, read_null_terminator)
119}
120
121pub fn read_multistring_from_cursor(
122    cursor: &mut ReadCursor<'_>,
123    character_set: CharacterSet,
124) -> DecodeResult<Vec<String>> {
125    let mut strings = Vec::new();
126
127    loop {
128        let string = read_string_from_cursor(cursor, character_set, true)?;
129        if string.is_empty() {
130            // empty string indicates the end of the multi-string array
131            // (we hit two null terminators in a row)
132            break;
133        }
134
135        strings.push(string);
136    }
137
138    Ok(strings)
139}
140
141pub fn encode_string(
142    dst: &mut [u8],
143    value: &str,
144    character_set: CharacterSet,
145    write_null_terminator: bool,
146) -> EncodeResult<usize> {
147    let (buffer, ctx) = match character_set {
148        CharacterSet::Unicode => {
149            let mut buffer = to_utf16_bytes(value);
150            if write_null_terminator {
151                buffer.extend_from_slice(&[0, 0]);
152            }
153            (buffer, "Encode string (UTF-16)")
154        }
155        CharacterSet::Ansi => {
156            let mut buffer = value.as_bytes().to_vec();
157            if write_null_terminator {
158                buffer.push(0);
159            }
160            (buffer, "Encode string (UTF-8)")
161        }
162    };
163
164    let len = buffer.len();
165
166    ensure_size!(ctx: ctx, in: dst, size: len);
167    dst[..len].copy_from_slice(&buffer);
168
169    Ok(len)
170}
171
172pub fn write_string_to_cursor(
173    cursor: &mut WriteCursor<'_>,
174    value: &str,
175    character_set: CharacterSet,
176    write_null_terminator: bool,
177) -> EncodeResult<()> {
178    let len = encode_string(cursor.remaining_mut(), value, character_set, write_null_terminator)?;
179    cursor.advance(len);
180    Ok(())
181}
182
183pub fn write_multistring_to_cursor(
184    cursor: &mut WriteCursor<'_>,
185    strings: &[String],
186    character_set: CharacterSet,
187) -> EncodeResult<()> {
188    // Write each string to cursor, separated by a null terminator
189    for string in strings {
190        write_string_to_cursor(cursor, string, character_set, true)?;
191    }
192
193    // Write final null terminator signifying the end of the multi-string
194    match character_set {
195        CharacterSet::Unicode => {
196            ensure_size!(ctx: "Encode multistring (UTF-16)", in: cursor, size: 2);
197            cursor.write_u16(0)
198        }
199        CharacterSet::Ansi => {
200            ensure_size!(ctx: "Encode multistring (UTF-8)", in: cursor, size: 1);
201            cursor.write_u8(0)
202        }
203    }
204
205    Ok(())
206}
207
208/// Returns the length in bytes of the encoded value
209/// based on the passed CharacterSet and with_null_terminator flag.
210pub fn encoded_str_len(value: &str, character_set: CharacterSet, with_null_terminator: bool) -> usize {
211    match character_set {
212        CharacterSet::Ansi => value.len() + if with_null_terminator { 1 } else { 0 },
213        CharacterSet::Unicode => value.encode_utf16().count() * 2 + if with_null_terminator { 2 } else { 0 },
214    }
215}
216
217/// Returns the length in bytes of the encoded multi-string
218/// based on the passed CharacterSet.
219pub fn encoded_multistring_len(strings: &[String], character_set: CharacterSet) -> usize {
220    strings
221        .iter()
222        .map(|s| encoded_str_len(s, character_set, true))
223        .sum::<usize>()
224        + if character_set == CharacterSet::Unicode { 2 } else { 1 }
225}
226
// FIXME: legacy trait
/// Splits a slice-like value in two, detaching its front part.
pub trait SplitTo {
    /// Returns the first `n` elements and leaves `self` referring to the remainder.
    ///
    /// # Panics
    ///
    /// The slice implementations below panic if `n` exceeds the length.
    #[must_use]
    fn split_to(&mut self, n: usize) -> Self;
}

impl<T> SplitTo for &[T] {
    fn split_to(&mut self, n: usize) -> Self {
        assert!(n <= self.len());

        // Shared slices are `Copy`: take a copy, then re-point `self` at the tail.
        let whole = *self;
        *self = &whole[n..];

        &whole[..n]
    }
}

impl<T> SplitTo for &mut [T] {
    fn split_to(&mut self, n: usize) -> Self {
        assert!(n <= self.len());

        // Move the slice out (leaving an empty one behind) so it can be split without
        // holding a borrow through `self`.
        let whole = core::mem::take(self);
        let (head, tail) = whole.split_at_mut(n);
        *self = tail;

        head
    }
}
254
/// Addition that reports overflow as `None` instead of wrapping or panicking.
///
/// Used as the element bound of `checked_sum` and `strict_sum`.
pub trait CheckedAdd: Sized + Add<Output = Self> {
    /// Returns `self + rhs`, or `None` if the addition overflows.
    fn checked_add(self, rhs: Self) -> Option<Self>;
}

// Implements `CheckedAdd` for primitive integers by forwarding to their inherent
// `checked_add`. Inherent functions take precedence over trait methods in
// `<T>::checked_add` paths, so the forwarding call does not recurse.
macro_rules! impl_checked_add {
    ($($ty:ty),*) => {
        $(
            impl CheckedAdd for $ty {
                fn checked_add(self, rhs: Self) -> Option<Self> {
                    <$ty>::checked_add(self, rhs)
                }
            }
        )*
    };
}

impl_checked_add!(usize, u32);
271
272// Utility function for checked addition that returns a PduResult
273pub fn checked_sum<T>(values: &[T]) -> DecodeResult<T>
274where
275    T: CheckedAdd + Copy + Debug,
276{
277    values.split_first().map_or_else(
278        || Err(other_err!("empty array provided to checked_sum")),
279        |(&first, rest)| {
280            rest.iter().try_fold(first, |acc, &val| {
281                acc.checked_add(val)
282                    .ok_or_else(|| other_err!("overflow detected during addition"))
283            })
284        },
285    )
286}
287
288// Utility function that panics on overflow
289pub fn strict_sum<T>(values: &[T]) -> T
290where
291    T: CheckedAdd + Copy + Debug,
292{
293    checked_sum::<T>(values).expect("overflow detected during addition")
294}