nt_string/unicode_string/
str.rs

1// Copyright 2023 Colin Finck <colin@reactos.org>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use core::cmp::Ordering;
5use core::iter::Copied;
6use core::marker::PhantomData;
7use core::slice::Iter;
8use core::{fmt, mem, slice};
9
10use widestring::{U16CStr, U16Str};
11
12use crate::error::{NtStringError, Result};
13use crate::helpers::{cmp_iter, RawNtString};
14
15use super::iter::{Chars, CharsLossy};
16
/// An immutable reference to a `UNICODE_STRING` (equivalent of `&str`).
///
/// The type does not own its buffer: it only stores a raw pointer plus the byte length and
/// byte capacity, and borrows the pointed-to memory for the lifetime `'a`.
///
/// See the [module-level documentation](super) for more details.
#[derive(Clone, Copy, Debug)]
// `repr(transparent)` gives this type the exact layout of its only non-zero-sized field (`raw`),
// so a pointer to it can be handed out as a pointer to the raw structure (see `as_ptr`).
// NOTE(review): this presumes `RawNtString` itself mirrors the C layout of `UNICODE_STRING`;
// it is defined in `crate::helpers` and not visible here -- confirm there.
#[repr(transparent)]
pub struct NtUnicodeStr<'a> {
    /// Raw parts of the string: length in bytes, maximum length in bytes, and buffer pointer.
    raw: RawNtString<*const u16>,
    /// Zero-sized marker tying the memory behind `raw.buffer` to the lifetime `'a`.
    _lifetime: PhantomData<&'a ()>,
}
26
27impl<'a> NtUnicodeStr<'a> {
28    /// Returns a `*const NtUnicodeStr` pointer
29    /// (mainly for non-Rust interfaces that expect an immutable `UNICODE_STRING*`).
30    pub fn as_ptr(&self) -> *const Self {
31        self as *const Self
32    }
33
34    /// Returns a slice to the raw [`u16`] codepoints of the string.
35    pub fn as_slice(&self) -> &'a [u16] {
36        unsafe { slice::from_raw_parts(self.raw.buffer, self.len_in_elements()) }
37    }
38
39    /// Returns a [`U16Str`] reference for this string.
40    ///
41    /// The [`U16Str`] will only contain the characters and not the NUL terminator.
42    pub fn as_u16str(&self) -> &'a U16Str {
43        U16Str::from_slice(self.as_slice())
44    }
45
46    /// Returns the capacity (also known as "maximum length") of this string, in bytes.
47    pub fn capacity(&self) -> u16 {
48        self.raw.maximum_length
49    }
50
51    /// Returns the capacity (also known as "maximum length") of this string, in elements.
52    #[allow(unused)]
53    pub(crate) fn capacity_in_elements(&self) -> usize {
54        usize::from(self.raw.maximum_length) / mem::size_of::<u16>()
55    }
56
57    /// Returns an iterator over the [`char`]s of this string.
58    ///
59    /// As the string may contain invalid UTF-16 characters (unpaired surrogates), the returned iterator is an
60    /// iterator over `Result<char>`.
61    /// Unpaired surrogates are returned as an [`NtStringError::UnpairedUtf16Surrogate`] error.
62    /// If you would like a lossy iterator over [`char`]s directly, use [`chars_lossy`] instead.
63    ///
64    /// [`chars_lossy`]: Self::chars_lossy
65    pub fn chars(&self) -> Chars {
66        Chars::new(self)
67    }
68
69    /// Returns an iterator over the [`char`]s of this string.
70    ///
71    /// Any invalid UTF-16 characters (unpaired surrogates) are automatically replaced by
72    /// [`U+FFFD REPLACEMENT CHARACTER`] (�).
73    /// If you would like to treat them differently, use [`chars`] instead.
74    ///
75    /// [`chars`]: Self::chars
76    /// [`U+FFFD REPLACEMENT CHARACTER`]: std::char::REPLACEMENT_CHARACTER
77    pub fn chars_lossy(&self) -> CharsLossy {
78        CharsLossy::new(self)
79    }
80
81    /// Creates an [`NtUnicodeStr`] from a [`u16`] string buffer, a byte length of the string,
82    /// and a buffer capacity in bytes (also known as "maximum length").
83    ///
84    /// The string is expected to consist of valid UTF-16 characters.
85    /// The buffer may or may not be NUL-terminated.
86    /// In any case, `length` does NOT include the terminating NUL character.
87    ///
88    /// This function is `unsafe` and you are advised to use any of the safe `try_from_*`
89    /// functions over this one if possible.
90    ///
91    /// # Safety
92    ///
93    /// Behavior is undefined if any of the following conditions are violated:
94    ///
95    /// * `length` must be less than or equal to `maximum_length`.
96    /// * `buffer` must be valid for at least `maximum_length` bytes.
97    /// * `buffer` must point to `length` consecutive properly initialized bytes.
98    /// * `buffer` must be valid for the duration of lifetime `'a`.
99    ///
100    /// [`try_from_u16`]: Self::try_from_u16
101    /// [`try_from_u16_until_nul`]: Self::try_from_u16_until_nul
102    pub const unsafe fn from_raw_parts(
103        buffer: *const u16,
104        length: u16,
105        maximum_length: u16,
106    ) -> Self {
107        debug_assert!(length <= maximum_length);
108
109        Self {
110            raw: RawNtString {
111                length,
112                maximum_length,
113                buffer,
114            },
115            _lifetime: PhantomData,
116        }
117    }
118
119    /// Returns `true` if this string has a length of zero, and `false` otherwise.
120    pub fn is_empty(&self) -> bool {
121        self.raw.length == 0
122    }
123
124    /// Returns the length of this string, in bytes.
125    ///
126    /// Note that a single character may occupy more than one byte.
127    /// In other words, the returned value might not be what a human considers the length of the string.
128    pub fn len(&self) -> u16 {
129        self.raw.length
130    }
131
132    /// Returns the length of this string, in elements.
133    ///
134    /// Note that a single character may occupy more than one element.
135    /// In other words, the returned value might not be what a human considers the length of the string.
136    pub(crate) fn len_in_elements(&self) -> usize {
137        usize::from(self.raw.length) / mem::size_of::<u16>()
138    }
139
140    /// Returns the remaining capacity of this string, in bytes.
141    #[allow(unused)]
142    pub(crate) fn remaining_capacity(&self) -> u16 {
143        debug_assert!(self.raw.maximum_length >= self.raw.length);
144        self.raw.maximum_length - self.raw.length
145    }
146
147    /// Creates an [`NtUnicodeStr`] from an existing [`u16`] string buffer without a terminating NUL character.
148    ///
149    /// The string is expected to consist of valid UTF-16 characters.
150    ///
151    /// The given buffer becomes the internal buffer of the [`NtUnicodeStr`] and therefore won't be NUL-terminated.
152    /// See the [module-level documentation](super) for the implications of that.
153    ///
154    /// This function has *O*(1) complexity.
155    ///
156    /// If you have a NUL-terminated buffer, either use [`try_from_u16_until_nul`] or convert from a [`U16CStr`]
157    /// using the corresponding [`TryFrom`] implementation.
158    ///
159    /// [`try_from_u16_until_nul`]: Self::try_from_u16_until_nul
160    pub fn try_from_u16(buffer: &'a [u16]) -> Result<Self> {
161        let elements = buffer.len();
162        let length_usize = elements
163            .checked_mul(mem::size_of::<u16>())
164            .ok_or(NtStringError::BufferSizeExceedsU16)?;
165        let length =
166            u16::try_from(length_usize).map_err(|_| NtStringError::BufferSizeExceedsU16)?;
167
168        Ok(Self {
169            raw: RawNtString {
170                length,
171                maximum_length: length,
172                buffer: buffer.as_ptr(),
173            },
174            _lifetime: PhantomData,
175        })
176    }
177
178    /// Creates an [`NtUnicodeStr`] from an existing [`u16`] string buffer that contains at least one NUL character.
179    ///
180    /// The string is expected to consist of valid UTF-16 characters.
181    ///
182    /// The string will be terminated at the NUL character.
183    /// An [`NtStringError::NulNotFound`] error is returned if no NUL character could be found.
184    /// As a consequence, this function has *O*(*n*) complexity.
185    ///
186    /// The resulting internal `buffer` of [`NtUnicodeStr`] will be NUL-terminated.
187    /// See the [module-level documentation](super) for the implications of that.
188    ///
189    /// Use [`try_from_u16`] if you have a buffer that is not NUL-terminated.
190    /// You can also convert from a NUL-terminated [`U16CStr`] in *O*(1) via the corresponding [`TryFrom`] implementation.
191    ///
192    /// [`try_from_u16`]: Self::try_from_u16
193    pub fn try_from_u16_until_nul(buffer: &'a [u16]) -> Result<Self> {
194        let length;
195        let maximum_length;
196
197        match buffer.iter().position(|x| *x == 0) {
198            Some(nul_pos) => {
199                // Include the terminating NUL character in `maximum_length` ...
200                let maximum_elements = nul_pos
201                    .checked_add(1)
202                    .ok_or(NtStringError::BufferSizeExceedsU16)?;
203                let maximum_length_usize = maximum_elements
204                    .checked_mul(mem::size_of::<u16>())
205                    .ok_or(NtStringError::BufferSizeExceedsU16)?;
206                maximum_length = u16::try_from(maximum_length_usize)
207                    .map_err(|_| NtStringError::BufferSizeExceedsU16)?;
208
209                // ... but not in `length`
210                length = maximum_length - mem::size_of::<u16>() as u16;
211            }
212            None => return Err(NtStringError::NulNotFound),
213        };
214
215        Ok(Self {
216            raw: RawNtString {
217                length,
218                maximum_length,
219                buffer: buffer.as_ptr(),
220            },
221            _lifetime: PhantomData,
222        })
223    }
224
225    pub(crate) fn u16_iter(&'a self) -> Copied<Iter<'a, u16>> {
226        self.as_slice().iter().copied()
227    }
228}
229
230impl<'a> fmt::Display for NtUnicodeStr<'a> {
231    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
232        for single_char in self.chars_lossy() {
233            single_char.fmt(f)?;
234        }
235
236        Ok(())
237    }
238}
239
240impl<'a> Eq for NtUnicodeStr<'a> {}
241
242impl<'a> Ord for NtUnicodeStr<'a> {
243    fn cmp(&self, other: &Self) -> Ordering {
244        cmp_iter(self.u16_iter(), other.u16_iter())
245    }
246}
247
248impl<'a, 'b> PartialEq<NtUnicodeStr<'a>> for NtUnicodeStr<'b> {
249    /// Checks that two strings are a (case-sensitive!) match.
250    fn eq(&self, other: &NtUnicodeStr<'a>) -> bool {
251        self.as_slice() == other.as_slice()
252    }
253}
254
255impl<'a> PartialEq<str> for NtUnicodeStr<'a> {
256    fn eq(&self, other: &str) -> bool {
257        cmp_iter(self.u16_iter(), other.encode_utf16()) == Ordering::Equal
258    }
259}
260
261impl<'a> PartialEq<NtUnicodeStr<'a>> for str {
262    fn eq(&self, other: &NtUnicodeStr<'a>) -> bool {
263        cmp_iter(self.encode_utf16(), other.u16_iter()) == Ordering::Equal
264    }
265}
266
267impl<'a> PartialEq<&str> for NtUnicodeStr<'a> {
268    fn eq(&self, other: &&str) -> bool {
269        cmp_iter(self.u16_iter(), other.encode_utf16()) == Ordering::Equal
270    }
271}
272
273impl<'a> PartialEq<NtUnicodeStr<'a>> for &str {
274    fn eq(&self, other: &NtUnicodeStr<'a>) -> bool {
275        cmp_iter(self.encode_utf16(), other.u16_iter()) == Ordering::Equal
276    }
277}
278
279impl<'a> PartialOrd for NtUnicodeStr<'a> {
280    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
281        Some(self.cmp(other))
282    }
283}
284
285impl<'a> PartialOrd<str> for NtUnicodeStr<'a> {
286    fn partial_cmp(&self, other: &str) -> Option<Ordering> {
287        Some(cmp_iter(self.u16_iter(), other.encode_utf16()))
288    }
289}
290
291impl<'a> PartialOrd<NtUnicodeStr<'a>> for str {
292    fn partial_cmp(&self, other: &NtUnicodeStr<'a>) -> Option<Ordering> {
293        Some(cmp_iter(self.encode_utf16(), other.u16_iter()))
294    }
295}
296
297impl<'a> PartialOrd<&str> for NtUnicodeStr<'a> {
298    fn partial_cmp(&self, other: &&str) -> Option<Ordering> {
299        Some(cmp_iter(self.u16_iter(), other.encode_utf16()))
300    }
301}
302
303impl<'a> PartialOrd<NtUnicodeStr<'a>> for &str {
304    fn partial_cmp(&self, other: &NtUnicodeStr<'a>) -> Option<Ordering> {
305        Some(cmp_iter(self.encode_utf16(), other.u16_iter()))
306    }
307}
308
309impl<'a> TryFrom<&'a U16CStr> for NtUnicodeStr<'a> {
310    type Error = NtStringError;
311
312    /// Converts a [`U16CStr`] reference into an [`NtUnicodeStr`].
313    ///
314    /// The internal buffer will be NUL-terminated.
315    /// See the [module-level documentation](super) for the implications of that.
316    fn try_from(value: &'a U16CStr) -> Result<Self> {
317        let buffer = value.as_slice_with_nul();
318
319        // Include the terminating NUL character in `maximum_length` ...
320        let maximum_length_in_elements = buffer.len();
321        let maximum_length_in_bytes = maximum_length_in_elements
322            .checked_mul(mem::size_of::<u16>())
323            .ok_or(NtStringError::BufferSizeExceedsU16)?;
324        let maximum_length = u16::try_from(maximum_length_in_bytes)
325            .map_err(|_| NtStringError::BufferSizeExceedsU16)?;
326
327        // ... but not in `length`
328        debug_assert!(maximum_length >= mem::size_of::<u16>() as u16);
329        let length = maximum_length - mem::size_of::<u16>() as u16;
330
331        Ok(Self {
332            raw: RawNtString {
333                length,
334                maximum_length,
335                buffer: buffer.as_ptr(),
336            },
337            _lifetime: PhantomData,
338        })
339    }
340}
341
342impl<'a> TryFrom<&'a U16Str> for NtUnicodeStr<'a> {
343    type Error = NtStringError;
344
345    /// Converts a [`U16Str`] reference into an [`NtUnicodeStr`].
346    ///
347    /// The internal buffer will NOT be NUL-terminated.
348    /// See the [module-level documentation](super) for the implications of that.
349    fn try_from(value: &'a U16Str) -> Result<Self> {
350        Self::try_from_u16(value.as_slice())
351    }
352}