nt_string/unicode_string/str.rs
1// Copyright 2023 Colin Finck <colin@reactos.org>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use core::cmp::Ordering;
5use core::iter::Copied;
6use core::marker::PhantomData;
7use core::slice::Iter;
8use core::{fmt, mem, slice};
9
10use widestring::{U16CStr, U16Str};
11
12use crate::error::{NtStringError, Result};
13use crate::helpers::{cmp_iter, RawNtString};
14
15use super::iter::{Chars, CharsLossy};
16
17/// An immutable reference to a `UNICODE_STRING` (equivalent of `&str`).
18///
19/// See the [module-level documentation](super) for more details.
20#[derive(Clone, Copy, Debug)]
21#[repr(transparent)]
22pub struct NtUnicodeStr<'a> {
23 raw: RawNtString<*const u16>,
24 _lifetime: PhantomData<&'a ()>,
25}
26
27impl<'a> NtUnicodeStr<'a> {
28 /// Returns a `*const NtUnicodeStr` pointer
29 /// (mainly for non-Rust interfaces that expect an immutable `UNICODE_STRING*`).
30 pub fn as_ptr(&self) -> *const Self {
31 self as *const Self
32 }
33
34 /// Returns a slice to the raw [`u16`] codepoints of the string.
35 pub fn as_slice(&self) -> &'a [u16] {
36 unsafe { slice::from_raw_parts(self.raw.buffer, self.len_in_elements()) }
37 }
38
39 /// Returns a [`U16Str`] reference for this string.
40 ///
41 /// The [`U16Str`] will only contain the characters and not the NUL terminator.
42 pub fn as_u16str(&self) -> &'a U16Str {
43 U16Str::from_slice(self.as_slice())
44 }
45
46 /// Returns the capacity (also known as "maximum length") of this string, in bytes.
47 pub fn capacity(&self) -> u16 {
48 self.raw.maximum_length
49 }
50
51 /// Returns the capacity (also known as "maximum length") of this string, in elements.
52 #[allow(unused)]
53 pub(crate) fn capacity_in_elements(&self) -> usize {
54 usize::from(self.raw.maximum_length) / mem::size_of::<u16>()
55 }
56
57 /// Returns an iterator over the [`char`]s of this string.
58 ///
59 /// As the string may contain invalid UTF-16 characters (unpaired surrogates), the returned iterator is an
60 /// iterator over `Result<char>`.
61 /// Unpaired surrogates are returned as an [`NtStringError::UnpairedUtf16Surrogate`] error.
62 /// If you would like a lossy iterator over [`char`]s directly, use [`chars_lossy`] instead.
63 ///
64 /// [`chars_lossy`]: Self::chars_lossy
65 pub fn chars(&self) -> Chars {
66 Chars::new(self)
67 }
68
69 /// Returns an iterator over the [`char`]s of this string.
70 ///
71 /// Any invalid UTF-16 characters (unpaired surrogates) are automatically replaced by
72 /// [`U+FFFD REPLACEMENT CHARACTER`] (�).
73 /// If you would like to treat them differently, use [`chars`] instead.
74 ///
75 /// [`chars`]: Self::chars
76 /// [`U+FFFD REPLACEMENT CHARACTER`]: std::char::REPLACEMENT_CHARACTER
77 pub fn chars_lossy(&self) -> CharsLossy {
78 CharsLossy::new(self)
79 }
80
81 /// Creates an [`NtUnicodeStr`] from a [`u16`] string buffer, a byte length of the string,
82 /// and a buffer capacity in bytes (also known as "maximum length").
83 ///
84 /// The string is expected to consist of valid UTF-16 characters.
85 /// The buffer may or may not be NUL-terminated.
86 /// In any case, `length` does NOT include the terminating NUL character.
87 ///
88 /// This function is `unsafe` and you are advised to use any of the safe `try_from_*`
89 /// functions over this one if possible.
90 ///
91 /// # Safety
92 ///
93 /// Behavior is undefined if any of the following conditions are violated:
94 ///
95 /// * `length` must be less than or equal to `maximum_length`.
96 /// * `buffer` must be valid for at least `maximum_length` bytes.
97 /// * `buffer` must point to `length` consecutive properly initialized bytes.
98 /// * `buffer` must be valid for the duration of lifetime `'a`.
99 ///
100 /// [`try_from_u16`]: Self::try_from_u16
101 /// [`try_from_u16_until_nul`]: Self::try_from_u16_until_nul
102 pub const unsafe fn from_raw_parts(
103 buffer: *const u16,
104 length: u16,
105 maximum_length: u16,
106 ) -> Self {
107 debug_assert!(length <= maximum_length);
108
109 Self {
110 raw: RawNtString {
111 length,
112 maximum_length,
113 buffer,
114 },
115 _lifetime: PhantomData,
116 }
117 }
118
119 /// Returns `true` if this string has a length of zero, and `false` otherwise.
120 pub fn is_empty(&self) -> bool {
121 self.raw.length == 0
122 }
123
124 /// Returns the length of this string, in bytes.
125 ///
126 /// Note that a single character may occupy more than one byte.
127 /// In other words, the returned value might not be what a human considers the length of the string.
128 pub fn len(&self) -> u16 {
129 self.raw.length
130 }
131
132 /// Returns the length of this string, in elements.
133 ///
134 /// Note that a single character may occupy more than one element.
135 /// In other words, the returned value might not be what a human considers the length of the string.
136 pub(crate) fn len_in_elements(&self) -> usize {
137 usize::from(self.raw.length) / mem::size_of::<u16>()
138 }
139
140 /// Returns the remaining capacity of this string, in bytes.
141 #[allow(unused)]
142 pub(crate) fn remaining_capacity(&self) -> u16 {
143 debug_assert!(self.raw.maximum_length >= self.raw.length);
144 self.raw.maximum_length - self.raw.length
145 }
146
147 /// Creates an [`NtUnicodeStr`] from an existing [`u16`] string buffer without a terminating NUL character.
148 ///
149 /// The string is expected to consist of valid UTF-16 characters.
150 ///
151 /// The given buffer becomes the internal buffer of the [`NtUnicodeStr`] and therefore won't be NUL-terminated.
152 /// See the [module-level documentation](super) for the implications of that.
153 ///
154 /// This function has *O*(1) complexity.
155 ///
156 /// If you have a NUL-terminated buffer, either use [`try_from_u16_until_nul`] or convert from a [`U16CStr`]
157 /// using the corresponding [`TryFrom`] implementation.
158 ///
159 /// [`try_from_u16_until_nul`]: Self::try_from_u16_until_nul
160 pub fn try_from_u16(buffer: &'a [u16]) -> Result<Self> {
161 let elements = buffer.len();
162 let length_usize = elements
163 .checked_mul(mem::size_of::<u16>())
164 .ok_or(NtStringError::BufferSizeExceedsU16)?;
165 let length =
166 u16::try_from(length_usize).map_err(|_| NtStringError::BufferSizeExceedsU16)?;
167
168 Ok(Self {
169 raw: RawNtString {
170 length,
171 maximum_length: length,
172 buffer: buffer.as_ptr(),
173 },
174 _lifetime: PhantomData,
175 })
176 }
177
178 /// Creates an [`NtUnicodeStr`] from an existing [`u16`] string buffer that contains at least one NUL character.
179 ///
180 /// The string is expected to consist of valid UTF-16 characters.
181 ///
182 /// The string will be terminated at the NUL character.
183 /// An [`NtStringError::NulNotFound`] error is returned if no NUL character could be found.
184 /// As a consequence, this function has *O*(*n*) complexity.
185 ///
186 /// The resulting internal `buffer` of [`NtUnicodeStr`] will be NUL-terminated.
187 /// See the [module-level documentation](super) for the implications of that.
188 ///
189 /// Use [`try_from_u16`] if you have a buffer that is not NUL-terminated.
190 /// You can also convert from a NUL-terminated [`U16CStr`] in *O*(1) via the corresponding [`TryFrom`] implementation.
191 ///
192 /// [`try_from_u16`]: Self::try_from_u16
193 pub fn try_from_u16_until_nul(buffer: &'a [u16]) -> Result<Self> {
194 let length;
195 let maximum_length;
196
197 match buffer.iter().position(|x| *x == 0) {
198 Some(nul_pos) => {
199 // Include the terminating NUL character in `maximum_length` ...
200 let maximum_elements = nul_pos
201 .checked_add(1)
202 .ok_or(NtStringError::BufferSizeExceedsU16)?;
203 let maximum_length_usize = maximum_elements
204 .checked_mul(mem::size_of::<u16>())
205 .ok_or(NtStringError::BufferSizeExceedsU16)?;
206 maximum_length = u16::try_from(maximum_length_usize)
207 .map_err(|_| NtStringError::BufferSizeExceedsU16)?;
208
209 // ... but not in `length`
210 length = maximum_length - mem::size_of::<u16>() as u16;
211 }
212 None => return Err(NtStringError::NulNotFound),
213 };
214
215 Ok(Self {
216 raw: RawNtString {
217 length,
218 maximum_length,
219 buffer: buffer.as_ptr(),
220 },
221 _lifetime: PhantomData,
222 })
223 }
224
225 pub(crate) fn u16_iter(&'a self) -> Copied<Iter<'a, u16>> {
226 self.as_slice().iter().copied()
227 }
228}
229
230impl<'a> fmt::Display for NtUnicodeStr<'a> {
231 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
232 for single_char in self.chars_lossy() {
233 single_char.fmt(f)?;
234 }
235
236 Ok(())
237 }
238}
239
240impl<'a> Eq for NtUnicodeStr<'a> {}
241
242impl<'a> Ord for NtUnicodeStr<'a> {
243 fn cmp(&self, other: &Self) -> Ordering {
244 cmp_iter(self.u16_iter(), other.u16_iter())
245 }
246}
247
248impl<'a, 'b> PartialEq<NtUnicodeStr<'a>> for NtUnicodeStr<'b> {
249 /// Checks that two strings are a (case-sensitive!) match.
250 fn eq(&self, other: &NtUnicodeStr<'a>) -> bool {
251 self.as_slice() == other.as_slice()
252 }
253}
254
255impl<'a> PartialEq<str> for NtUnicodeStr<'a> {
256 fn eq(&self, other: &str) -> bool {
257 cmp_iter(self.u16_iter(), other.encode_utf16()) == Ordering::Equal
258 }
259}
260
261impl<'a> PartialEq<NtUnicodeStr<'a>> for str {
262 fn eq(&self, other: &NtUnicodeStr<'a>) -> bool {
263 cmp_iter(self.encode_utf16(), other.u16_iter()) == Ordering::Equal
264 }
265}
266
267impl<'a> PartialEq<&str> for NtUnicodeStr<'a> {
268 fn eq(&self, other: &&str) -> bool {
269 cmp_iter(self.u16_iter(), other.encode_utf16()) == Ordering::Equal
270 }
271}
272
273impl<'a> PartialEq<NtUnicodeStr<'a>> for &str {
274 fn eq(&self, other: &NtUnicodeStr<'a>) -> bool {
275 cmp_iter(self.encode_utf16(), other.u16_iter()) == Ordering::Equal
276 }
277}
278
279impl<'a> PartialOrd for NtUnicodeStr<'a> {
280 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
281 Some(self.cmp(other))
282 }
283}
284
285impl<'a> PartialOrd<str> for NtUnicodeStr<'a> {
286 fn partial_cmp(&self, other: &str) -> Option<Ordering> {
287 Some(cmp_iter(self.u16_iter(), other.encode_utf16()))
288 }
289}
290
291impl<'a> PartialOrd<NtUnicodeStr<'a>> for str {
292 fn partial_cmp(&self, other: &NtUnicodeStr<'a>) -> Option<Ordering> {
293 Some(cmp_iter(self.encode_utf16(), other.u16_iter()))
294 }
295}
296
297impl<'a> PartialOrd<&str> for NtUnicodeStr<'a> {
298 fn partial_cmp(&self, other: &&str) -> Option<Ordering> {
299 Some(cmp_iter(self.u16_iter(), other.encode_utf16()))
300 }
301}
302
303impl<'a> PartialOrd<NtUnicodeStr<'a>> for &str {
304 fn partial_cmp(&self, other: &NtUnicodeStr<'a>) -> Option<Ordering> {
305 Some(cmp_iter(self.encode_utf16(), other.u16_iter()))
306 }
307}
308
309impl<'a> TryFrom<&'a U16CStr> for NtUnicodeStr<'a> {
310 type Error = NtStringError;
311
312 /// Converts a [`U16CStr`] reference into an [`NtUnicodeStr`].
313 ///
314 /// The internal buffer will be NUL-terminated.
315 /// See the [module-level documentation](super) for the implications of that.
316 fn try_from(value: &'a U16CStr) -> Result<Self> {
317 let buffer = value.as_slice_with_nul();
318
319 // Include the terminating NUL character in `maximum_length` ...
320 let maximum_length_in_elements = buffer.len();
321 let maximum_length_in_bytes = maximum_length_in_elements
322 .checked_mul(mem::size_of::<u16>())
323 .ok_or(NtStringError::BufferSizeExceedsU16)?;
324 let maximum_length = u16::try_from(maximum_length_in_bytes)
325 .map_err(|_| NtStringError::BufferSizeExceedsU16)?;
326
327 // ... but not in `length`
328 debug_assert!(maximum_length >= mem::size_of::<u16>() as u16);
329 let length = maximum_length - mem::size_of::<u16>() as u16;
330
331 Ok(Self {
332 raw: RawNtString {
333 length,
334 maximum_length,
335 buffer: buffer.as_ptr(),
336 },
337 _lifetime: PhantomData,
338 })
339 }
340}
341
342impl<'a> TryFrom<&'a U16Str> for NtUnicodeStr<'a> {
343 type Error = NtStringError;
344
345 /// Converts a [`U16Str`] reference into an [`NtUnicodeStr`].
346 ///
347 /// The internal buffer will NOT be NUL-terminated.
348 /// See the [module-level documentation](super) for the implications of that.
349 fn try_from(value: &'a U16Str) -> Result<Self> {
350 Self::try_from_u16(value.as_slice())
351 }
352}