nt_string/unicode_string/
string.rs1use ::alloc::alloc::{self, Layout};
5use ::alloc::string::String;
6use widestring::{U16CStr, U16Str};
7
8use core::cmp::Ordering;
9use core::iter::once;
10use core::ops::{Add, AddAssign, Deref, DerefMut};
11use core::{fmt, mem, ptr};
12
13use crate::error::{NtStringError, Result};
14use crate::helpers::RawNtString;
15use crate::traits::TryExtend;
16
17use super::{impl_eq, impl_partial_cmp, NtUnicodeStr, NtUnicodeStrMut};
18
19#[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
23#[derive(Debug)]
24#[repr(transparent)]
25pub struct NtUnicodeString {
26 raw: RawNtString<*mut u16>,
27}
28
29impl NtUnicodeString {
30 pub fn new() -> Self {
35 Self {
36 raw: RawNtString {
37 length: 0,
38 maximum_length: 0,
39 buffer: ptr::null_mut(),
40 },
41 }
42 }
43
44 pub fn as_unicode_str_mut(&mut self) -> &mut NtUnicodeStrMut<'static> {
46 self.deref_mut()
47 }
48
49 fn layout(&self) -> Layout {
50 Layout::array::<u16>(self.capacity_in_elements()).unwrap()
51 }
52
53 pub fn try_from_u16(buffer: &[u16]) -> Result<Self> {
67 let unicode_str = NtUnicodeStr::try_from_u16(buffer)?;
68 Ok(Self::from(&unicode_str))
69 }
70
71 pub fn try_from_u16_until_nul(buffer: &[u16]) -> Result<Self> {
87 let unicode_str = NtUnicodeStr::try_from_u16_until_nul(buffer)?;
88 Ok(Self::from(&unicode_str))
89 }
90
91 pub fn try_push(&mut self, c: char) -> Result<()> {
102 let encoded_length = c.len_utf16();
107 let additional_elements = encoded_length + 1;
108
109 let additional = (additional_elements * mem::size_of::<u16>()) as u16;
111 self.try_reserve(additional)?;
112
113 let end_index = self.len_in_elements();
115 self.raw.length += additional;
116
117 let dest_slice = &mut self.as_mut_slice()[end_index..];
118 c.encode_utf16(dest_slice);
119
120 dest_slice[encoded_length] = 0;
122
123 self.raw.length -= mem::size_of::<u16>() as u16;
125
126 Ok(())
127 }
128
129 pub fn try_push_str(&mut self, s: &str) -> Result<()> {
140 let additional_elements = s
145 .encode_utf16()
146 .count()
147 .checked_add(1)
148 .ok_or(NtStringError::BufferSizeExceedsU16)?;
149
150 let additional_bytes = additional_elements
152 .checked_mul(mem::size_of::<u16>())
153 .ok_or(NtStringError::BufferSizeExceedsU16)?;
154 let additional =
155 u16::try_from(additional_bytes).map_err(|_| NtStringError::BufferSizeExceedsU16)?;
156 self.try_reserve(additional)?;
157
158 let end_index = self.len_in_elements();
160 self.raw.length += additional;
161
162 for (string_item, utf16_item) in self.as_mut_slice()[end_index..]
163 .iter_mut()
164 .zip(s.encode_utf16().chain(once(0)))
165 {
166 *string_item = utf16_item;
167 }
168
169 self.raw.length -= mem::size_of::<u16>() as u16;
171
172 Ok(())
173 }
174
175 pub fn try_push_u16(&mut self, buffer: &[u16]) -> Result<()> {
195 let additional_elements = buffer
200 .len()
201 .checked_add(1)
202 .ok_or(NtStringError::BufferSizeExceedsU16)?;
203
204 let additional_bytes = additional_elements
206 .checked_mul(mem::size_of::<u16>())
207 .ok_or(NtStringError::BufferSizeExceedsU16)?;
208 let additional =
209 u16::try_from(additional_bytes).map_err(|_| NtStringError::BufferSizeExceedsU16)?;
210 self.try_reserve(additional)?;
211
212 let end_index = self.len_in_elements();
214 self.raw.length += additional;
215
216 let dest_slice = &mut self.as_mut_slice()[end_index..];
217 dest_slice[..buffer.len()].copy_from_slice(buffer);
218
219 dest_slice[buffer.len()] = 0;
221
222 self.raw.length -= mem::size_of::<u16>() as u16;
224
225 Ok(())
226 }
227
228 pub fn try_push_u16_until_nul(&mut self, buffer: &[u16]) -> Result<()> {
252 match buffer.iter().position(|x| *x == 0) {
253 Some(nul_pos) => self.try_push_u16(&buffer[..nul_pos]),
254 None => Err(NtStringError::NulNotFound),
255 }
256 }
257
258 pub fn try_push_u16cstr(&mut self, u16cstr: &U16CStr) -> Result<()> {
271 self.try_push_u16(u16cstr.as_slice())
272 }
273
274 pub fn try_push_u16str(&mut self, u16str: &U16Str) -> Result<()> {
287 self.try_push_u16(u16str.as_slice())
288 }
289
290 pub fn try_reserve(&mut self, additional: u16) -> Result<()> {
297 if self.remaining_capacity() >= additional {
298 return Ok(());
299 }
300
301 let new_capacity = self
302 .len()
303 .checked_add(additional)
304 .ok_or(NtStringError::BufferSizeExceedsU16)?;
305
306 if self.raw.buffer.is_null() {
307 self.raw.maximum_length = new_capacity;
308 let new_layout = self.layout();
309
310 self.raw.buffer = unsafe { alloc::alloc(new_layout) } as *mut u16;
311 } else {
312 let old_layout = self.layout();
313
314 self.raw.buffer = unsafe {
315 alloc::realloc(
316 self.raw.buffer as *mut u8,
317 old_layout,
318 usize::from(new_capacity),
319 )
320 } as *mut u16;
321
322 self.raw.maximum_length = new_capacity;
323 }
324
325 Ok(())
326 }
327
328 pub fn with_capacity(capacity: u16) -> Self {
335 let mut string = Self::new();
336 string.try_reserve(capacity).unwrap();
337 string
338 }
339}
340
341impl Add<&str> for NtUnicodeString {
342 type Output = NtUnicodeString;
343
344 fn add(mut self, rhs: &str) -> Self::Output {
345 if let Err(e) = self.try_push_str(rhs) {
346 panic!("{e}");
347 }
348
349 self
350 }
351}
352
353impl Add<&U16CStr> for NtUnicodeString {
354 type Output = NtUnicodeString;
355
356 fn add(mut self, rhs: &U16CStr) -> Self::Output {
357 if let Err(e) = self.try_push_u16cstr(rhs) {
358 panic!("{e}");
359 }
360
361 self
362 }
363}
364
365impl Add<&U16Str> for NtUnicodeString {
366 type Output = NtUnicodeString;
367
368 fn add(mut self, rhs: &U16Str) -> Self::Output {
369 if let Err(e) = self.try_push_u16str(rhs) {
370 panic!("{e}");
371 }
372
373 self
374 }
375}
376
377impl AddAssign<&str> for NtUnicodeString {
378 fn add_assign(&mut self, rhs: &str) {
379 if let Err(e) = self.try_push_str(rhs) {
380 panic!("{e}");
381 }
382 }
383}
384
385impl AddAssign<&U16CStr> for NtUnicodeString {
386 fn add_assign(&mut self, rhs: &U16CStr) {
387 if let Err(e) = self.try_push_u16cstr(rhs) {
388 panic!("{e}");
389 }
390 }
391}
392
393impl AddAssign<&U16Str> for NtUnicodeString {
394 fn add_assign(&mut self, rhs: &U16Str) {
395 if let Err(e) = self.try_push_u16str(rhs) {
396 panic!("{e}");
397 }
398 }
399}
400
401impl Clone for NtUnicodeString {
402 fn clone(&self) -> Self {
406 NtUnicodeString::from(self.as_unicode_str())
407 }
408}
409
410impl Default for NtUnicodeString {
411 fn default() -> Self {
412 Self::new()
413 }
414}
415
416impl Deref for NtUnicodeString {
417 type Target = NtUnicodeStrMut<'static>;
420
421 fn deref(&self) -> &Self::Target {
422 unsafe { mem::transmute(self) }
425 }
426}
427
428impl DerefMut for NtUnicodeString {
429 fn deref_mut(&mut self) -> &mut Self::Target {
430 unsafe { mem::transmute(self) }
433 }
434}
435
436impl fmt::Display for NtUnicodeString {
437 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
438 fmt::Display::fmt(self.deref(), f)
439 }
440}
441
442impl Drop for NtUnicodeString {
443 fn drop(&mut self) {
444 if !self.raw.buffer.is_null() {
445 let layout = self.layout();
446 unsafe { alloc::dealloc(self.raw.buffer as *mut u8, layout) }
447 }
448 }
449}
450
451impl Eq for NtUnicodeString {}
452
453impl From<char> for NtUnicodeString {
454 fn from(c: char) -> Self {
456 let mut string = Self::new();
457 string.try_push(c).unwrap();
458 string
459 }
460}
461
462impl<'a> From<&NtUnicodeStr<'a>> for NtUnicodeString {
463 fn from(unicode_str: &NtUnicodeStr) -> Self {
467 let mut new_string = Self::with_capacity(unicode_str.capacity());
468
469 if !unicode_str.is_empty() {
470 new_string.raw.length = unicode_str.len();
471 new_string
472 .as_mut_slice()
473 .copy_from_slice(unicode_str.as_slice());
474 }
475
476 new_string
477 }
478}
479
480impl Ord for NtUnicodeString {
481 fn cmp(&self, other: &Self) -> Ordering {
482 Ord::cmp(self.deref(), other.deref())
483 }
484}
485
486impl_eq! { NtUnicodeString, NtUnicodeString }
487impl_eq! { NtUnicodeStr<'a>, NtUnicodeString }
488impl_eq! { NtUnicodeString, NtUnicodeStr<'a> }
489impl_eq! { NtUnicodeStrMut<'a>, NtUnicodeString }
490impl_eq! { NtUnicodeString, NtUnicodeStrMut<'a> }
491impl_eq! { NtUnicodeString, str }
492impl_eq! { str, NtUnicodeString }
493impl_eq! { NtUnicodeString, &str }
494impl_eq! { &str, NtUnicodeString }
495
496impl_partial_cmp! { NtUnicodeString, NtUnicodeString }
497impl_partial_cmp! { NtUnicodeStr<'a>, NtUnicodeString }
498impl_partial_cmp! { NtUnicodeString, NtUnicodeStr<'a> }
499impl_partial_cmp! { NtUnicodeStrMut<'a>, NtUnicodeString }
500impl_partial_cmp! { NtUnicodeString, NtUnicodeStrMut<'a> }
501impl_partial_cmp! { NtUnicodeString, str }
502impl_partial_cmp! { str, NtUnicodeString }
503impl_partial_cmp! { NtUnicodeString, &str }
504impl_partial_cmp! { &str, NtUnicodeString }
505
506impl TryExtend<char> for NtUnicodeString {
507 type Error = NtStringError;
508
509 fn try_extend<I: IntoIterator<Item = char>>(&mut self, iter: I) -> Result<()> {
510 let iterator = iter.into_iter();
511 let (lower_bound, _) = iterator.size_hint();
512
513 let additional_elements = lower_bound + 1;
516
517 let additional_bytes = u16::try_from(additional_elements * mem::size_of::<u16>())
519 .map_err(|_| NtStringError::BufferSizeExceedsU16)?;
520 self.try_reserve(additional_bytes)?;
521
522 for ch in iterator {
523 self.try_push(ch)?;
524 }
525
526 Ok(())
527 }
528
529 fn try_extend_one(&mut self, item: char) -> Result<()> {
530 self.try_push(item)
531 }
532}
533
534impl<'a> TryExtend<&'a str> for NtUnicodeString {
535 type Error = NtStringError;
536
537 fn try_extend<I: IntoIterator<Item = &'a str>>(&mut self, iter: I) -> Result<()> {
538 for s in iter.into_iter() {
539 self.try_push_str(s)?;
540 }
541
542 Ok(())
543 }
544}
545
546impl<'a> TryExtend<&'a U16CStr> for NtUnicodeString {
547 type Error = NtStringError;
548
549 fn try_extend<I: IntoIterator<Item = &'a U16CStr>>(&mut self, iter: I) -> Result<()> {
550 for s in iter.into_iter() {
551 self.try_push_u16cstr(s)?;
552 }
553
554 Ok(())
555 }
556}
557
558impl<'a> TryExtend<&'a U16Str> for NtUnicodeString {
559 type Error = NtStringError;
560
561 fn try_extend<I: IntoIterator<Item = &'a U16Str>>(&mut self, iter: I) -> Result<()> {
562 for s in iter.into_iter() {
563 self.try_push_u16str(s)?;
564 }
565
566 Ok(())
567 }
568}
569
570impl TryFrom<&str> for NtUnicodeString {
571 type Error = NtStringError;
572
573 fn try_from(s: &str) -> Result<Self> {
578 let mut string = Self::new();
579 string.try_push_str(s)?;
580 Ok(string)
581 }
582}
583
584impl TryFrom<String> for NtUnicodeString {
585 type Error = NtStringError;
586
587 fn try_from(s: String) -> Result<Self> {
592 NtUnicodeString::try_from(s.as_str())
593 }
594}
595
596impl TryFrom<&String> for NtUnicodeString {
597 type Error = NtStringError;
598
599 fn try_from(s: &String) -> Result<Self> {
604 NtUnicodeString::try_from(s.as_str())
605 }
606}
607
608impl TryFrom<&U16CStr> for NtUnicodeString {
609 type Error = NtStringError;
610
611 fn try_from(value: &U16CStr) -> Result<Self> {
616 let unicode_str = NtUnicodeStr::try_from(value)?;
617 Ok(Self::from(&unicode_str))
618 }
619}
620
621impl TryFrom<&U16Str> for NtUnicodeString {
622 type Error = NtStringError;
623
624 fn try_from(value: &U16Str) -> Result<Self> {
629 let unicode_str = NtUnicodeStr::try_from(value)?;
630 Ok(Self::from(&unicode_str))
631 }
632}
633
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;

    use crate::error::NtStringError;
    use crate::traits::TryExtend;
    use crate::unicode_string::NtUnicodeString;

    #[test]
    fn test_add() {
        let mut string = NtUnicodeString::new();
        string += "😀";
        assert_eq!(string, "😀");

        let string2 = string + "😀";
        assert_eq!(string2, "😀😀");
    }

    #[test]
    fn test_chars() {
        // ASCII: one UTF-16 unit per char; capacity includes the NUL terminator.
        let moin = NtUnicodeString::try_from("Moin").unwrap();
        assert_eq!(moin.capacity(), 10);
        assert_eq!(moin.len(), 8);
        let vec = moin.chars_lossy().collect::<Vec<char>>();
        assert_eq!(vec, ['M', 'o', 'i', 'n']);

        // Basic Multilingual Plane: still one UTF-16 unit per char.
        let konnichiwa = NtUnicodeString::try_from("今日は").unwrap();
        assert_eq!(konnichiwa.capacity(), 8);
        assert_eq!(konnichiwa.len(), 6);
        let vec = konnichiwa.chars_lossy().collect::<Vec<char>>();
        assert_eq!(vec, ['今', '日', 'は']);

        // Outside the BMP (emoji): a surrogate pair, i.e. two UTF-16 units.
        let smile = NtUnicodeString::try_from("😀").unwrap();
        assert_eq!(smile.capacity(), 6);
        assert_eq!(smile.len(), 4);
        let vec = smile.chars_lossy().collect::<Vec<char>>();
        assert_eq!(vec, ['😀']);
    }

    #[test]
    fn test_cmp() {
        let a = NtUnicodeString::try_from("a").unwrap();
        let b = NtUnicodeString::try_from("b").unwrap();
        assert!(a < b);
    }

    #[test]
    fn test_eq() {
        let hello = NtUnicodeString::try_from("Hello").unwrap();
        let hello_again = NtUnicodeString::try_from("Hello again").unwrap();
        assert_ne!(hello, hello_again);

        let mut hello_clone = hello.clone();
        assert_eq!(hello, hello_clone);

        // Capacity does not take part in equality.
        hello_clone.try_reserve(42).unwrap();
        assert_eq!(hello, hello_clone);
    }

    #[test]
    fn test_extend_and_pop() {
        // 32766 chars + NUL = 65534 bytes: the largest even capacity that fits a u16.
        let a_string = "a".repeat(32766);
        let mut string = NtUnicodeString::try_from(a_string).unwrap();
        assert_eq!(string.capacity(), 65534);
        assert_eq!(string.len(), 65532);

        // Another char plus NUL would need 65536 bytes.
        assert_eq!(
            string.try_extend(Some('b')),
            Err(NtStringError::BufferSizeExceedsU16)
        );

        assert_eq!(string.pop(), Some(Ok('a')));
        assert_eq!(string.capacity(), 65534);
        assert_eq!(string.len(), 65530);
        string.try_extend_one('c').unwrap();
        assert_eq!(string.capacity(), 65534);
        assert_eq!(string.len(), 65532);

        assert_eq!(string.pop(), Some(Ok('c')));
        assert_eq!(string.pop(), Some(Ok('a')));
        assert_eq!(string.capacity(), 65534);
        assert_eq!(string.len(), 65528);
        // A surrogate pair occupies 4 bytes.
        string.try_extend_one('😀').unwrap();
        assert_eq!(string.capacity(), 65534);
        assert_eq!(string.len(), 65532);

        assert_eq!(string.pop(), Some(Ok('😀')));
        assert_eq!(string.pop(), Some(Ok('a')));
        assert_eq!(string.capacity(), 65534);
        assert_eq!(string.len(), 65526);
        string.try_extend("def".chars()).unwrap();
        assert_eq!(string.capacity(), 65534);
        assert_eq!(string.len(), 65532);
    }

    #[test]
    fn test_from_u16() {
        let mut a_vec = "a".repeat(32768).encode_utf16().collect::<Vec<u16>>();
        assert_eq!(
            NtUnicodeString::try_from_u16(&a_vec),
            Err(NtStringError::BufferSizeExceedsU16)
        );

        a_vec.pop();
        let string = NtUnicodeString::try_from_u16(&a_vec).unwrap();
        assert_eq!(string.capacity(), 65534);
        assert_eq!(string.len(), 65534);

        a_vec[4] = 0;
        let string = NtUnicodeString::try_from_u16_until_nul(&a_vec).unwrap();
        assert_eq!(string.capacity(), 10);
        assert_eq!(string.len(), 8);
        assert_eq!(string, "aaaa");
    }

    #[test]
    fn test_push_str() {
        let mut string = NtUnicodeString::new();
        string.try_push_str("Hey").unwrap();
        assert_eq!(string, "Hey");
        assert_eq!(string.capacity(), 8);
        assert_eq!(string.len(), 6);

        string.try_push_str("Ho").unwrap();
        assert_eq!(string, "HeyHo");
        assert_eq!(string.capacity(), 12);
        assert_eq!(string.len(), 10);
    }

    #[test]
    fn test_reserve() {
        let mut string = NtUnicodeString::new();
        assert_eq!(string.capacity(), 0);

        string.try_reserve(5).unwrap();
        assert_eq!(string.capacity(), 5);

        string.try_reserve(3).unwrap();
        assert_eq!(string.capacity(), 5);

        string.try_push_str("a").unwrap();
        assert_eq!(string, "a");
        assert_eq!(string.capacity(), 5);

        string.try_push_str("b").unwrap();
        assert_eq!(string, "ab");
        assert_eq!(string.capacity(), 6);
    }

    #[test]
    fn test_reserve_odd_capacity() {
        // Odd byte capacities must survive repeated reallocation and the final deallocation.
        let mut string = NtUnicodeString::new();
        string.try_reserve(3).unwrap();
        assert_eq!(string.capacity(), 3);

        string.try_reserve(7).unwrap();
        assert_eq!(string.capacity(), 7);

        string.try_push_str("abc").unwrap();
        assert_eq!(string, "abc");
        assert_eq!(string.capacity(), 8);
        assert_eq!(string.len(), 6);
    }
}