mssf_pal/
strings.rs

1// ------------------------------------------------------------
2// Copyright (c) Microsoft Corporation.  All rights reserved.
3// Licensed under the MIT License (MIT). See License.txt in the repo root for license information.
4// ------------------------------------------------------------
5
6use std::fmt::Write;
7
/// Raw pointer to a null-terminated string of 16-bit (UTF-16) code units, mirroring the
/// `PCWSTR` type from `windows_core`.
///
/// Only the pointer is stored: it may be null, and the type tracks neither length nor validity.
#[repr(transparent)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PCWSTR(pub *const u16);
11
12impl AsRef<PCWSTR> for PCWSTR {
13    fn as_ref(&self) -> &Self {
14        self
15    }
16}
17
18impl windows_core::TypeKind for PCWSTR {
19    type TypeKind = windows_core::CopyType;
20}
21
22// Copied minimal impl from windows_core crate which is not available on linux.
23// This is used on windows as well instead of the original defs if you use mssf-pal.
24impl PCWSTR {
25    /// Construct a new `PCWSTR` from a raw pointer
26    pub const fn from_raw(ptr: *const u16) -> Self {
27        Self(ptr)
28    }
29
30    /// Construct a null `PCWSTR`
31    pub const fn null() -> Self {
32        Self(core::ptr::null())
33    }
34
35    /// Returns a raw pointer to the `PCWSTR`
36    pub const fn as_ptr(&self) -> *const u16 {
37        self.0
38    }
39
40    /// Checks whether the `PCWSTR` is null
41    pub fn is_null(&self) -> bool {
42        self.0.is_null()
43    }
44
45    /// String length without the trailing 0
46    ///
47    /// # Safety
48    ///
49    /// The `PCWSTR`'s pointer needs to be valid for reads up until and including the next `\0`.
50    pub unsafe fn len(&self) -> usize {
51        let mut len = 0;
52        let mut ptr = self.0;
53        while unsafe { ptr.read() } != 0 {
54            len += 1;
55            ptr = unsafe { ptr.add(1) };
56        }
57        len
58    }
59
60    /// Returns `true` if the string length is zero, and `false` otherwise.
61    ///
62    /// # Safety
63    ///
64    /// The `PCWSTR`'s pointer needs to be valid for reads up until and including the next `\0`.
65    pub unsafe fn is_empty(&self) -> bool {
66        unsafe { self.len() == 0 }
67    }
68
69    /// String data without the trailing 0
70    ///
71    /// # Safety
72    ///
73    /// The `PCWSTR`'s pointer needs to be valid for reads up until and including the next `\0`.
74    pub unsafe fn as_wide(&self) -> &[u16] {
75        unsafe { core::slice::from_raw_parts(self.0, self.len()) }
76    }
77}
78
79impl Default for PCWSTR {
80    fn default() -> Self {
81        Self::null()
82    }
83}
84
/// Raw pointer to an 8-bit character string, the narrow counterpart of [`PCWSTR`].
///
/// Only the pointer is stored: it may be null, and the type tracks neither length nor validity.
#[repr(transparent)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PCSTR(pub *const u8);
88
89impl AsRef<PCSTR> for PCSTR {
90    fn as_ref(&self) -> &Self {
91        self
92    }
93}
94
95impl windows_core::TypeKind for PCSTR {
96    type TypeKind = windows_core::CopyType;
97}
98
/// WString is an owned UTF-16 string, similar to `std::wstring` in C++.
/// It is used for passing UTF-16 string buffers between Rust and COM.
// Representation: `None` is the empty string. `Some(v)` holds the non-empty contents
// followed by an appended `0` terminator, so `v.as_ptr()` is a valid null-terminated
// string pointer.
#[derive(Clone, PartialEq, Eq, Default, Hash)]
pub struct WString(Option<Vec<u16>>);

/// Lone null terminator that empty strings point at.
///
/// This is a `static` rather than a `const` because `WString::as_ptr` hands out pointers
/// into it. A `static` has a single address that lives for the whole program, whereas a
/// `const` is re-instantiated at every use and would rely on implicit promotion to stay alive.
static EMPTY: [u16; 1] = [0];
105
106impl WString {
107    /// creates an empty string
108    pub const fn new() -> Self {
109        Self(None)
110    }
111
112    /// returns if the string is empty
113    pub const fn is_empty(&self) -> bool {
114        self.0.is_none()
115    }
116
117    /// len is the utf16 len not including the null terminator bytes
118    pub fn len(&self) -> usize {
119        match self.0.as_ref() {
120            Some(v) => v.len() - 1,
121            None => 0,
122        }
123    }
124
125    /// Get the string as 16-bit wide characters (wchars).
126    pub fn as_wide(&self) -> &[u16] {
127        match self.0.as_ref() {
128            Some(v) => {
129                // remove the last null terminator
130                v.as_slice().split_last().unwrap().1
131            }
132            None => &[],
133        }
134    }
135
136    /// Get the contents of this `WString` as a String lossily.
137    pub fn to_string_lossy(&self) -> String {
138        String::from_utf16_lossy(self.as_wide())
139    }
140
141    /// Returns a raw pointer to the `WString` buffer.
142    pub fn as_ptr(&self) -> *const u16 {
143        match self.0.as_ref() {
144            Some(v) => v.as_ptr(),
145            None => EMPTY.as_ptr(), // This is not null pointer.
146        }
147    }
148
149    /// Returns the `PCWSTR` representation of this `WString` for FFI calls.
150    pub fn as_pcwstr(&self) -> PCWSTR {
151        match self.0.as_ref() {
152            Some(v) => PCWSTR::from_raw(v.as_ptr()),
153            None => PCWSTR::null(),
154        }
155    }
156
157    /// From slice without the null terminator.
158    pub fn from_wide(value: &[u16]) -> Self {
159        // TODO: avoid the clone for the iter.
160        unsafe { Self::from_wide_iter(value.iter().cloned(), value.len()) }
161    }
162
163    unsafe fn from_wide_iter<I: Iterator<Item = u16>>(iter: I, len: usize) -> Self {
164        if len == 0 {
165            return Self::new();
166        }
167        // append a null terminator. collect should allocate efficiently from iter.
168        let iter = iter.chain(EMPTY.as_ref().iter().cloned());
169        let v = iter.collect::<Vec<_>>();
170        Self(Some(v))
171    }
172}
173
174impl From<&str> for WString {
175    fn from(value: &str) -> Self {
176        unsafe { Self::from_wide_iter(value.encode_utf16(), value.len()) }
177    }
178}
179
180impl From<String> for WString {
181    fn from(value: String) -> Self {
182        value.as_str().into()
183    }
184}
185impl From<&String> for WString {
186    fn from(value: &String) -> Self {
187        value.as_str().into()
188    }
189}
190
191impl From<&PCWSTR> for WString {
192    /// Requires value points to valid memory location
193    /// Null is ok.
194    fn from(value: &PCWSTR) -> Self {
195        if value.is_null() {
196            Self::new()
197        } else {
198            Self::from_wide(unsafe { value.as_wide() })
199        }
200    }
201}
202
203impl From<PCWSTR> for WString {
204    fn from(value: PCWSTR) -> Self {
205        Self::from(&value)
206    }
207}
208
209impl core::fmt::Display for WString {
210    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
211        // convert u16 to char gracefully and write to formatter.
212        let wit = core::char::decode_utf16(self.as_wide().iter().cloned());
213        for c in wit {
214            match c {
215                Ok(c) => f.write_char(c)?,
216                Err(_) => f.write_char(core::char::REPLACEMENT_CHARACTER)?,
217            }
218        }
219        Ok(())
220    }
221}
222
223impl core::fmt::Debug for WString {
224    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
225        write!(f, "\"{self}\"")
226    }
227}
228
#[cfg(test)]
mod tests {
    use crate::PCWSTR;

    use super::WString;

    #[test]
    fn string_test() {
        let test_case = |s: &str| {
            // Expected UTF-16 code units. `str::len` counts UTF-8 bytes, which only
            // matches the UTF-16 length for ASCII input.
            let wide: Vec<u16> = s.encode_utf16().collect();
            let h = WString::from(s);
            assert_eq!(wide.len(), h.len());
            assert_eq!(h.as_wide(), wide.as_slice());
            assert_eq!(s.is_empty(), h.is_empty());
            assert_eq!(format!("{h}"), s);
            assert_eq!(s, h.to_string_lossy());
            // Round-trip through the raw null-terminated pointer.
            let raw = h.as_ptr();
            let h2 = WString::from(PCWSTR(raw));
            assert_eq!(s, h2.to_string_lossy());
            assert_eq!(h, h2);
            assert_ne!(h, WString::from("dummy"));
        };

        test_case("hello");
        test_case("s");
        test_case("");
        // Non-ASCII: multi-byte UTF-8 characters, and a surrogate pair in UTF-16.
        test_case("h\u{e9}llo w\u{f6}rld");
        test_case("\u{1F980}");
    }

    #[test]
    fn empty_string_pointers() {
        let empty = WString::new();
        // `as_ptr` never returns null for an empty string: it points at a lone terminator.
        assert!(!empty.as_ptr().is_null());
        assert_eq!(unsafe { *empty.as_ptr() }, 0);
        // `as_pcwstr` maps the empty string to a null `PCWSTR`.
        assert!(empty.as_pcwstr().is_null());
        // A null `PCWSTR` converts back to the empty string.
        assert_eq!(WString::from(PCWSTR::null()), empty);
    }
}
255}