1use core::fmt::Debug;
2use core::ops::Add;
3
4use byteorder::{LittleEndian, ReadBytesExt as _};
5use ironrdp_core::{ensure_size, invalid_field_err, other_err, ReadCursor, WriteCursor};
6use num_derive::{FromPrimitive, ToPrimitive};
7
8use crate::{DecodeResult, EncodeResult};
9
/// Splits a 64-bit value into its `(low, high)` 32-bit halves.
///
/// `low` holds bits 0..32 and `high` holds bits 32..64.
pub fn split_u64(value: u64) -> (u32, u32) {
    let [b0, b1, b2, b3, b4, b5, b6, b7] = value.to_le_bytes();
    let low = u32::from_le_bytes([b0, b1, b2, b3]);
    let high = u32::from_le_bytes([b4, b5, b6, b7]);
    (low, high)
}
18
/// Rebuilds a 64-bit value from its low and high 32-bit halves.
///
/// This is the inverse of `split_u64`: `lo` supplies bits 0..32 and `hi` bits 32..64.
pub fn combine_u64(lo: u32, hi: u32) -> u64 {
    (u64::from(hi) << 32) | u64::from(lo)
}
25
/// Encodes `value` as UTF-16 in little-endian byte order, without a terminator.
///
/// Characters outside the Basic Multilingual Plane become surrogate pairs
/// (four bytes).
pub fn to_utf16_bytes(value: &str) -> Vec<u8> {
    // Each UTF-8 byte yields at most one UTF-16 code unit, so `len * 2` is an
    // upper bound on the output size and the buffer never has to grow.
    let mut bytes = Vec::with_capacity(value.len() * 2);
    for unit in value.encode_utf16() {
        // Append the two bytes in place instead of allocating a Vec per code unit.
        bytes.extend_from_slice(&unit.to_le_bytes());
    }
    bytes
}
32
/// Decodes little-endian UTF-16 bytes into a `String`, replacing invalid
/// sequences (such as unpaired surrogates) with U+FFFD.
///
/// A trailing odd byte, if any, is ignored.
pub fn from_utf16_bytes(value: &[u8]) -> String {
    let code_units: Vec<u16> = value
        .chunks_exact(2)
        .map(|pair| u16::from_le_bytes([pair[0], pair[1]]))
        .collect();

    String::from_utf16_lossy(&code_units)
}
41
/// Character encoding selector used by the string helpers in this module.
///
/// NOTE(review): the explicit discriminants (1 and 2) look like wire values
/// exposed through the `FromPrimitive`/`ToPrimitive` derives — confirm against
/// the protocol specification before renumbering.
#[derive(Debug, Copy, Clone, PartialEq, Eq, FromPrimitive, ToPrimitive)]
pub enum CharacterSet {
    /// Single-byte strings; the helpers in this module decode and encode them as UTF-8.
    Ansi = 1,
    /// UTF-16 strings in little-endian byte order.
    Unicode = 2,
}
47
48pub fn read_string_from_cursor(
54 cursor: &mut ReadCursor<'_>,
55 character_set: CharacterSet,
56 read_null_terminator: bool,
57) -> DecodeResult<String> {
58 let size = if character_set == CharacterSet::Unicode {
59 let code_units = if read_null_terminator {
60 cursor
62 .remaining()
63 .chunks_exact(2)
64 .position(|chunk| chunk == [0, 0])
65 .map(|null_terminator_pos| null_terminator_pos + 1) .unwrap_or(cursor.len() / 2)
67 } else {
68 cursor.len() / 2
70 };
71
72 code_units * 2
73 } else if read_null_terminator {
74 cursor
76 .remaining()
77 .iter()
78 .position(|&i| i == 0)
79 .map(|null_terminator_pos| null_terminator_pos + 1) .unwrap_or(cursor.len())
81 } else {
82 cursor.len()
84 };
85
86 if size == 0 {
88 return Ok(String::new());
89 }
90
91 let result = match character_set {
92 CharacterSet::Unicode => {
93 ensure_size!(ctx: "Decode string (UTF-16)", in: cursor, size: size);
94 let mut slice = cursor.read_slice(size);
95
96 let str_buffer = &mut slice;
97 let mut u16_buffer = vec![0u16; str_buffer.len() / 2];
98
99 str_buffer
100 .read_u16_into::<LittleEndian>(u16_buffer.as_mut())
101 .expect("BUG: str_buffer is always even for UTF16");
102
103 String::from_utf16(&u16_buffer)
104 .map_err(|_| invalid_field_err!("UTF16 decode", "buffer", "Failed to decode UTF16 string"))?
105 }
106 CharacterSet::Ansi => {
107 ensure_size!(ctx: "Decode string (UTF-8)", in: cursor, size: size);
108 let slice = cursor.read_slice(size);
109 String::from_utf8(slice.to_vec())
110 .map_err(|_| invalid_field_err!("UTF8 decode", "buffer", "Failed to decode UTF8 string"))?
111 }
112 };
113
114 Ok(result.trim_end_matches('\0').into())
115}
116
117pub fn decode_string(src: &[u8], character_set: CharacterSet, read_null_terminator: bool) -> DecodeResult<String> {
118 read_string_from_cursor(&mut ReadCursor::new(src), character_set, read_null_terminator)
119}
120
121pub fn read_multistring_from_cursor(
122 cursor: &mut ReadCursor<'_>,
123 character_set: CharacterSet,
124) -> DecodeResult<Vec<String>> {
125 let mut strings = Vec::new();
126
127 loop {
128 let string = read_string_from_cursor(cursor, character_set, true)?;
129 if string.is_empty() {
130 break;
133 }
134
135 strings.push(string);
136 }
137
138 Ok(strings)
139}
140
141pub fn encode_string(
142 dst: &mut [u8],
143 value: &str,
144 character_set: CharacterSet,
145 write_null_terminator: bool,
146) -> EncodeResult<usize> {
147 let (buffer, ctx) = match character_set {
148 CharacterSet::Unicode => {
149 let mut buffer = to_utf16_bytes(value);
150 if write_null_terminator {
151 buffer.extend_from_slice(&[0, 0]);
152 }
153 (buffer, "Encode string (UTF-16)")
154 }
155 CharacterSet::Ansi => {
156 let mut buffer = value.as_bytes().to_vec();
157 if write_null_terminator {
158 buffer.push(0);
159 }
160 (buffer, "Encode string (UTF-8)")
161 }
162 };
163
164 let len = buffer.len();
165
166 ensure_size!(ctx: ctx, in: dst, size: len);
167 dst[..len].copy_from_slice(&buffer);
168
169 Ok(len)
170}
171
172pub fn write_string_to_cursor(
173 cursor: &mut WriteCursor<'_>,
174 value: &str,
175 character_set: CharacterSet,
176 write_null_terminator: bool,
177) -> EncodeResult<()> {
178 let len = encode_string(cursor.remaining_mut(), value, character_set, write_null_terminator)?;
179 cursor.advance(len);
180 Ok(())
181}
182
183pub fn write_multistring_to_cursor(
184 cursor: &mut WriteCursor<'_>,
185 strings: &[String],
186 character_set: CharacterSet,
187) -> EncodeResult<()> {
188 for string in strings {
190 write_string_to_cursor(cursor, string, character_set, true)?;
191 }
192
193 match character_set {
195 CharacterSet::Unicode => {
196 ensure_size!(ctx: "Encode multistring (UTF-16)", in: cursor, size: 2);
197 cursor.write_u16(0)
198 }
199 CharacterSet::Ansi => {
200 ensure_size!(ctx: "Encode multistring (UTF-8)", in: cursor, size: 1);
201 cursor.write_u8(0)
202 }
203 }
204
205 Ok(())
206}
207
208pub fn encoded_str_len(value: &str, character_set: CharacterSet, with_null_terminator: bool) -> usize {
211 match character_set {
212 CharacterSet::Ansi => value.len() + if with_null_terminator { 1 } else { 0 },
213 CharacterSet::Unicode => value.encode_utf16().count() * 2 + if with_null_terminator { 2 } else { 0 },
214 }
215}
216
217pub fn encoded_multistring_len(strings: &[String], character_set: CharacterSet) -> usize {
220 strings
221 .iter()
222 .map(|s| encoded_str_len(s, character_set, true))
223 .sum::<usize>()
224 + if character_set == CharacterSet::Unicode { 2 } else { 1 }
225}
226
/// Splits a slice-like value in two, keeping the tail in place.
///
/// The implementations in this module (for `&[T]` and `&mut [T]`) panic when
/// `n` exceeds the length.
pub trait SplitTo {
    /// Returns the first `n` elements and leaves `self` holding the remaining ones.
    #[must_use]
    fn split_to(&mut self, n: usize) -> Self;
}
232
233impl<T> SplitTo for &[T] {
234 fn split_to(&mut self, n: usize) -> Self {
235 assert!(n <= self.len());
236
237 let (a, b) = self.split_at(n);
238 *self = b;
239
240 a
241 }
242}
243
244impl<T> SplitTo for &mut [T] {
245 fn split_to(&mut self, n: usize) -> Self {
246 assert!(n <= self.len());
247
248 let (a, b) = core::mem::take(self).split_at_mut(n);
249 *self = b;
250
251 a
252 }
253}
254
/// Overflow-checked addition, abstracted so it can be used generically
/// (for example by `checked_sum` and `strict_sum`).
pub trait CheckedAdd: Sized + Add<Output = Self> {
    /// Returns `self + rhs`, or `None` if the result does not fit in `Self`.
    fn checked_add(self, rhs: Self) -> Option<Self>;
}
258
259impl CheckedAdd for usize {
261 fn checked_add(self, rhs: Self) -> Option<Self> {
262 usize::checked_add(self, rhs)
263 }
264}
265
266impl CheckedAdd for u32 {
267 fn checked_add(self, rhs: Self) -> Option<Self> {
268 u32::checked_add(self, rhs)
269 }
270}
271
/// Adds up `values`, reporting failures as errors instead of panicking.
///
/// The first element seeds the accumulator. The remaining elements are folded in
/// from left to right with `CheckedAdd::checked_add`.
///
/// # Errors
///
/// Returns an error if `values` is empty, or if any intermediate addition
/// overflows (`checked_add` returns `None`).
// NOTE(review): `other_err!` is invoked here without an explicit context argument;
// if the macro derives its context from the enclosing item path, moving these
// calls out of (or into) closures would change the reported context — confirm
// before restructuring.
pub fn checked_sum<T>(values: &[T]) -> DecodeResult<T>
where
    T: CheckedAdd + Copy + Debug,
{
    values.split_first().map_or_else(
        // Empty input: there is no seed value to start the fold from.
        || Err(other_err!("empty array provided to checked_sum")),
        |(&first, rest)| {
            rest.iter().try_fold(first, |acc, &val| {
                // Stop at the first addition whose result does not fit in `T`.
                acc.checked_add(val)
                    .ok_or_else(|| other_err!("overflow detected during addition"))
            })
        },
    )
}
287
288pub fn strict_sum<T>(values: &[T]) -> T
290where
291 T: CheckedAdd + Copy + Debug,
292{
293 checked_sum::<T>(values).expect("overflow detected during addition")
294}