Skip to main content

nt_string/unicode_string/
str.rs

1// Copyright 2023-2026 Colin Finck <colin@reactos.org>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use core::cmp::Ordering;
5use core::iter::Copied;
6use core::marker::PhantomData;
7use core::slice::Iter;
8use core::{fmt, mem, slice};
9
10use widestring::{U16CStr, U16Str};
11
12use crate::error::{NtStringError, Result};
13use crate::helpers::{cmp_iter, RawNtString};
14
15use super::iter::{Chars, CharsLossy};
16
17/// An immutable reference to a `UNICODE_STRING` (equivalent of `&str`).
18///
19/// See the [module-level documentation](super) for more details.
20#[derive(Clone, Copy, Debug)]
21#[repr(transparent)]
22pub struct NtUnicodeStr<'a> {
23    raw: RawNtString<*const u16>,
24    _lifetime: PhantomData<&'a ()>,
25}
26
27impl<'a> NtUnicodeStr<'a> {
28    /// Returns a `*const NtUnicodeStr` pointer
29    /// (mainly for non-Rust interfaces that expect an immutable `UNICODE_STRING*`).
30    pub fn as_ptr(&self) -> *const Self {
31        self as *const Self
32    }
33
34    /// Returns a slice to the raw [`u16`] codepoints of the string.
35    pub fn as_slice(&self) -> &'a [u16] {
36        unsafe { slice::from_raw_parts(self.raw.buffer, self.len_in_elements()) }
37    }
38
39    /// Returns a [`U16Str`] reference for this string.
40    ///
41    /// The [`U16Str`] will only contain the characters and not the NUL terminator.
42    pub fn as_u16str(&self) -> &'a U16Str {
43        U16Str::from_slice(self.as_slice())
44    }
45
46    /// Returns the capacity (also known as "maximum length") of this string, in bytes.
47    pub fn capacity(&self) -> u16 {
48        self.raw.maximum_length
49    }
50
51    /// Returns the capacity (also known as "maximum length") of this string, in elements.
52    #[allow(unused)]
53    pub(crate) fn capacity_in_elements(&self) -> usize {
54        usize::from(self.raw.maximum_length) / mem::size_of::<u16>()
55    }
56
57    /// Returns an iterator over the [`char`]s of this string.
58    ///
59    /// As the string may contain invalid UTF-16 characters (unpaired surrogates), the returned iterator is an
60    /// iterator over `Result<char>`.
61    /// Unpaired surrogates are returned as an [`NtStringError::UnpairedUtf16Surrogate`] error.
62    /// If you would like a lossy iterator over [`char`]s directly, use [`chars_lossy`] instead.
63    ///
64    /// [`chars_lossy`]: Self::chars_lossy
65    pub fn chars(&self) -> Chars<'_> {
66        Chars::new(self)
67    }
68
69    /// Returns an iterator over the [`char`]s of this string.
70    ///
71    /// Any invalid UTF-16 characters (unpaired surrogates) are automatically replaced by
72    /// [`U+FFFD REPLACEMENT CHARACTER`] (�).
73    /// If you would like to treat them differently, use [`chars`] instead.
74    ///
75    /// [`chars`]: Self::chars
76    /// [`U+FFFD REPLACEMENT CHARACTER`]: std::char::REPLACEMENT_CHARACTER
77    pub fn chars_lossy(&self) -> CharsLossy<'_> {
78        CharsLossy::new(self)
79    }
80
81    /// Creates an [`NtUnicodeStr`] from a [`u16`] string buffer, a byte length of the string,
82    /// and a buffer capacity in bytes (also known as "maximum length").
83    ///
84    /// The string is expected to consist of valid UTF-16 characters.
85    /// The buffer may or may not be NUL-terminated.
86    /// In any case, `length` does NOT include the terminating NUL character.
87    ///
88    /// This function is `unsafe` and you are advised to use any of the safe `try_from_*`
89    /// functions over this one if possible.
90    ///
91    /// # Safety
92    ///
93    /// Behavior is undefined if any of the following conditions are violated:
94    ///
95    /// * `length` must be less than or equal to `maximum_length`.
96    /// * `buffer` must be valid for at least `maximum_length` bytes.
97    /// * `buffer` must point to `length` consecutive properly initialized bytes.
98    /// * `buffer` must be valid for the duration of lifetime `'a`.
99    ///
100    /// [`try_from_u16`]: Self::try_from_u16
101    /// [`try_from_u16_until_nul`]: Self::try_from_u16_until_nul
102    pub const unsafe fn from_raw_parts(
103        buffer: *const u16,
104        length: u16,
105        maximum_length: u16,
106    ) -> Self {
107        debug_assert!(length <= maximum_length);
108
109        Self {
110            raw: RawNtString {
111                length,
112                maximum_length,
113                buffer,
114            },
115            _lifetime: PhantomData,
116        }
117    }
118
119    /// Returns `true` if this string has a length of zero, and `false` otherwise.
120    pub fn is_empty(&self) -> bool {
121        self.raw.length == 0
122    }
123
124    /// Returns the length of this string, in bytes.
125    ///
126    /// Note that a single character may occupy more than one byte.
127    /// In other words, the returned value might not be what a human considers the length of the string.
128    pub fn len(&self) -> u16 {
129        self.raw.length
130    }
131
132    /// Returns the length of this string, in elements.
133    ///
134    /// Note that a single character may occupy more than one element.
135    /// In other words, the returned value might not be what a human considers the length of the string.
136    pub(crate) fn len_in_elements(&self) -> usize {
137        usize::from(self.raw.length) / mem::size_of::<u16>()
138    }
139
140    /// Returns the remaining capacity of this string, in bytes.
141    #[allow(unused)]
142    pub(crate) fn remaining_capacity(&self) -> u16 {
143        debug_assert!(self.raw.maximum_length >= self.raw.length);
144        self.raw.maximum_length - self.raw.length
145    }
146
147    /// Creates an [`NtUnicodeStr`] from an existing [`u16`] string buffer without a terminating NUL character.
148    ///
149    /// The string is expected to consist of valid UTF-16 characters.
150    ///
151    /// The given buffer becomes the internal buffer of the [`NtUnicodeStr`] and therefore won't be NUL-terminated.
152    /// See the [module-level documentation](super) for the implications of that.
153    ///
154    /// This function has *O*(1) complexity.
155    ///
156    /// If you have a NUL-terminated buffer, either use [`try_from_u16_until_nul`] or convert from a [`U16CStr`]
157    /// using the corresponding [`TryFrom`] implementation.
158    ///
159    /// [`try_from_u16_until_nul`]: Self::try_from_u16_until_nul
160    pub fn try_from_u16(buffer: &'a [u16]) -> Result<Self> {
161        let length = Self::try_length_from_u16(buffer)?;
162
163        Ok(Self {
164            raw: RawNtString {
165                length,
166                maximum_length: length,
167                buffer: buffer.as_ptr(),
168            },
169            _lifetime: PhantomData,
170        })
171    }
172
173    /// Creates an [`NtUnicodeStr`] from an existing [`u16`] string buffer that contains at least one NUL character.
174    ///
175    /// The string is expected to consist of valid UTF-16 characters.
176    ///
177    /// The string will be terminated at the NUL character.
178    /// An [`NtStringError::NulNotFound`] error is returned if no NUL character could be found.
179    /// As a consequence, this function has *O*(*n*) complexity.
180    ///
181    /// The resulting internal `buffer` of [`NtUnicodeStr`] will be NUL-terminated.
182    /// See the [module-level documentation](super) for the implications of that.
183    ///
184    /// Use [`try_from_u16`] if you have a buffer that is not NUL-terminated.
185    /// You can also convert from a NUL-terminated [`U16CStr`] in *O*(1) via the corresponding [`TryFrom`] implementation.
186    ///
187    /// [`try_from_u16`]: Self::try_from_u16
188    pub fn try_from_u16_until_nul(buffer: &'a [u16]) -> Result<Self> {
189        let (length, maximum_length) = Self::try_length_from_u16_until_nul(buffer)?;
190
191        Ok(Self {
192            raw: RawNtString {
193                length,
194                maximum_length,
195                buffer: buffer.as_ptr(),
196            },
197            _lifetime: PhantomData,
198        })
199    }
200
201    pub(crate) fn try_length_from_u16(buffer: &[u16]) -> Result<u16> {
202        let elements = buffer.len();
203        let length_usize = elements
204            .checked_mul(mem::size_of::<u16>())
205            .ok_or(NtStringError::BufferSizeExceedsU16)?;
206        let length =
207            u16::try_from(length_usize).map_err(|_| NtStringError::BufferSizeExceedsU16)?;
208
209        Ok(length)
210    }
211
212    pub(crate) fn try_length_from_u16_cstr(u16cstr: &U16CStr) -> Result<(u16, u16)> {
213        let buffer = u16cstr.as_slice_with_nul();
214
215        // Include the terminating NUL character in `maximum_length` ...
216        let maximum_length_in_elements = buffer.len();
217        let maximum_length_in_bytes = maximum_length_in_elements
218            .checked_mul(mem::size_of::<u16>())
219            .ok_or(NtStringError::BufferSizeExceedsU16)?;
220        let maximum_length = u16::try_from(maximum_length_in_bytes)
221            .map_err(|_| NtStringError::BufferSizeExceedsU16)?;
222
223        // ... but not in `length`
224        debug_assert!(maximum_length >= mem::size_of::<u16>() as u16);
225        let length = maximum_length - mem::size_of::<u16>() as u16;
226
227        Ok((length, maximum_length))
228    }
229
230    pub(crate) fn try_length_from_u16_until_nul(buffer: &[u16]) -> Result<(u16, u16)> {
231        match buffer.iter().position(|x| *x == 0) {
232            Some(nul_pos) => {
233                // Include the terminating NUL character in `maximum_length` ...
234                let maximum_elements = nul_pos
235                    .checked_add(1)
236                    .ok_or(NtStringError::BufferSizeExceedsU16)?;
237                let maximum_length_usize = maximum_elements
238                    .checked_mul(mem::size_of::<u16>())
239                    .ok_or(NtStringError::BufferSizeExceedsU16)?;
240                let maximum_length = u16::try_from(maximum_length_usize)
241                    .map_err(|_| NtStringError::BufferSizeExceedsU16)?;
242
243                // ... but not in `length`
244                let length = maximum_length - mem::size_of::<u16>() as u16;
245
246                Ok((length, maximum_length))
247            }
248            None => Err(NtStringError::NulNotFound),
249        }
250    }
251
252    pub(crate) fn u16_iter(&'a self) -> Copied<Iter<'a, u16>> {
253        self.as_slice().iter().copied()
254    }
255}
256
257impl<'a> fmt::Display for NtUnicodeStr<'a> {
258    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
259        for single_char in self.chars_lossy() {
260            single_char.fmt(f)?;
261        }
262
263        Ok(())
264    }
265}
266
267impl<'a> Eq for NtUnicodeStr<'a> {}
268
269impl<'a> Ord for NtUnicodeStr<'a> {
270    fn cmp(&self, other: &Self) -> Ordering {
271        cmp_iter(self.u16_iter(), other.u16_iter())
272    }
273}
274
275impl<'a, 'b> PartialEq<NtUnicodeStr<'a>> for NtUnicodeStr<'b> {
276    /// Checks that two strings are a (case-sensitive!) match.
277    fn eq(&self, other: &NtUnicodeStr<'a>) -> bool {
278        self.as_slice() == other.as_slice()
279    }
280}
281
282impl<'a> PartialEq<str> for NtUnicodeStr<'a> {
283    fn eq(&self, other: &str) -> bool {
284        cmp_iter(self.u16_iter(), other.encode_utf16()) == Ordering::Equal
285    }
286}
287
288impl<'a> PartialEq<NtUnicodeStr<'a>> for str {
289    fn eq(&self, other: &NtUnicodeStr<'a>) -> bool {
290        cmp_iter(self.encode_utf16(), other.u16_iter()) == Ordering::Equal
291    }
292}
293
294impl<'a> PartialEq<&str> for NtUnicodeStr<'a> {
295    fn eq(&self, other: &&str) -> bool {
296        cmp_iter(self.u16_iter(), other.encode_utf16()) == Ordering::Equal
297    }
298}
299
300impl<'a> PartialEq<NtUnicodeStr<'a>> for &str {
301    fn eq(&self, other: &NtUnicodeStr<'a>) -> bool {
302        cmp_iter(self.encode_utf16(), other.u16_iter()) == Ordering::Equal
303    }
304}
305
306impl<'a> PartialOrd for NtUnicodeStr<'a> {
307    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
308        Some(self.cmp(other))
309    }
310}
311
312impl<'a> PartialOrd<str> for NtUnicodeStr<'a> {
313    fn partial_cmp(&self, other: &str) -> Option<Ordering> {
314        Some(cmp_iter(self.u16_iter(), other.encode_utf16()))
315    }
316}
317
318impl<'a> PartialOrd<NtUnicodeStr<'a>> for str {
319    fn partial_cmp(&self, other: &NtUnicodeStr<'a>) -> Option<Ordering> {
320        Some(cmp_iter(self.encode_utf16(), other.u16_iter()))
321    }
322}
323
324impl<'a> PartialOrd<&str> for NtUnicodeStr<'a> {
325    fn partial_cmp(&self, other: &&str) -> Option<Ordering> {
326        Some(cmp_iter(self.u16_iter(), other.encode_utf16()))
327    }
328}
329
330impl<'a> PartialOrd<NtUnicodeStr<'a>> for &str {
331    fn partial_cmp(&self, other: &NtUnicodeStr<'a>) -> Option<Ordering> {
332        Some(cmp_iter(self.encode_utf16(), other.u16_iter()))
333    }
334}
335
336impl<'a> TryFrom<&'a U16CStr> for NtUnicodeStr<'a> {
337    type Error = NtStringError;
338
339    /// Converts a [`U16CStr`] reference into an [`NtUnicodeStr`].
340    ///
341    /// The internal buffer will be NUL-terminated.
342    /// See the [module-level documentation](super) for the implications of that.
343    fn try_from(value: &'a U16CStr) -> Result<Self> {
344        let (length, maximum_length) = Self::try_length_from_u16_cstr(value)?;
345
346        Ok(Self {
347            raw: RawNtString {
348                length,
349                maximum_length,
350                buffer: value.as_ptr(),
351            },
352            _lifetime: PhantomData,
353        })
354    }
355}
356
357impl<'a> TryFrom<&'a U16Str> for NtUnicodeStr<'a> {
358    type Error = NtStringError;
359
360    /// Converts a [`U16Str`] reference into an [`NtUnicodeStr`].
361    ///
362    /// The internal buffer will NOT be NUL-terminated.
363    /// See the [module-level documentation](super) for the implications of that.
364    fn try_from(value: &'a U16Str) -> Result<Self> {
365        Self::try_from_u16(value.as_slice())
366    }
367}