nt_string/unicode_string/
string.rs1use ::alloc::alloc::{self, Layout};
5use ::alloc::string::String;
6use widestring::{U16CStr, U16Str};
7
8use core::cmp::Ordering;
9use core::iter::once;
10use core::ops::{Add, AddAssign, Deref, DerefMut};
11use core::{fmt, mem, ptr};
12
13use crate::error::{NtStringError, Result};
14use crate::helpers::RawNtString;
15use crate::traits::TryExtend;
16
17use super::{impl_eq, impl_partial_cmp, NtUnicodeStr, NtUnicodeStrMut};
18
19#[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
23#[derive(Debug)]
24#[repr(transparent)]
25pub struct NtUnicodeString {
26 raw: RawNtString<*mut u16>,
27}
28
29impl NtUnicodeString {
30 pub fn new() -> Self {
35 Self {
36 raw: RawNtString {
37 length: 0,
38 maximum_length: 0,
39 buffer: ptr::null_mut(),
40 },
41 }
42 }
43
44 pub fn as_unicode_str_mut(&mut self) -> &mut NtUnicodeStrMut<'static> {
46 self.deref_mut()
47 }
48
49 fn layout(&self) -> Layout {
50 Layout::array::<u16>(self.capacity_in_elements()).unwrap()
51 }
52
53 pub fn try_from_u16(buffer: &[u16]) -> Result<Self> {
67 let unicode_str = NtUnicodeStr::try_from_u16(buffer)?;
68 Ok(Self::from(&unicode_str))
69 }
70
71 pub fn try_from_u16_until_nul(buffer: &[u16]) -> Result<Self> {
87 let unicode_str = NtUnicodeStr::try_from_u16_until_nul(buffer)?;
88 Ok(Self::from(&unicode_str))
89 }
90
91 pub fn try_push(&mut self, c: char) -> Result<()> {
102 let encoded_length = c.len_utf16();
107 let additional_elements = encoded_length + 1;
108
109 let additional = (additional_elements * mem::size_of::<u16>()) as u16;
111 self.try_reserve(additional)?;
112
113 let end_index = self.len_in_elements();
115 self.raw.length += additional;
116
117 let dest_slice = &mut self.as_mut_slice()[end_index..];
118 c.encode_utf16(dest_slice);
119
120 dest_slice[encoded_length] = 0;
122
123 self.raw.length -= mem::size_of::<u16>() as u16;
125
126 Ok(())
127 }
128
129 pub fn try_push_str(&mut self, s: &str) -> Result<()> {
140 let additional_elements = s
145 .encode_utf16()
146 .count()
147 .checked_add(1)
148 .ok_or(NtStringError::BufferSizeExceedsU16)?;
149
150 let additional_bytes = additional_elements
152 .checked_mul(mem::size_of::<u16>())
153 .ok_or(NtStringError::BufferSizeExceedsU16)?;
154 let additional =
155 u16::try_from(additional_bytes).map_err(|_| NtStringError::BufferSizeExceedsU16)?;
156 self.try_reserve(additional)?;
157
158 let end_index = self.len_in_elements();
160 self.raw.length += additional;
161
162 for (string_item, utf16_item) in self.as_mut_slice()[end_index..]
163 .iter_mut()
164 .zip(s.encode_utf16().chain(once(0)))
165 {
166 *string_item = utf16_item;
167 }
168
169 self.raw.length -= mem::size_of::<u16>() as u16;
171
172 Ok(())
173 }
174
175 pub fn try_push_u16(&mut self, buffer: &[u16]) -> Result<()> {
195 let additional_elements = buffer
200 .len()
201 .checked_add(1)
202 .ok_or(NtStringError::BufferSizeExceedsU16)?;
203
204 let additional_bytes = additional_elements
206 .checked_mul(mem::size_of::<u16>())
207 .ok_or(NtStringError::BufferSizeExceedsU16)?;
208 let additional =
209 u16::try_from(additional_bytes).map_err(|_| NtStringError::BufferSizeExceedsU16)?;
210 self.try_reserve(additional)?;
211
212 let end_index = self.len_in_elements();
214 self.raw.length += additional;
215
216 let dest_slice = &mut self.as_mut_slice()[end_index..];
217 dest_slice[..buffer.len()].copy_from_slice(buffer);
218
219 dest_slice[buffer.len()] = 0;
221
222 self.raw.length -= mem::size_of::<u16>() as u16;
224
225 Ok(())
226 }
227
228 pub fn try_push_u16_until_nul(&mut self, buffer: &[u16]) -> Result<()> {
252 match buffer.iter().position(|x| *x == 0) {
253 Some(nul_pos) => self.try_push_u16(&buffer[..nul_pos]),
254 None => Err(NtStringError::NulNotFound),
255 }
256 }
257
258 pub fn try_push_u16cstr(&mut self, u16cstr: &U16CStr) -> Result<()> {
271 self.try_push_u16(u16cstr.as_slice())
272 }
273
274 pub fn try_push_u16str(&mut self, u16str: &U16Str) -> Result<()> {
287 self.try_push_u16(u16str.as_slice())
288 }
289
290 pub fn try_reserve(&mut self, additional: u16) -> Result<()> {
297 if self.remaining_capacity() >= additional {
298 return Ok(());
299 }
300
301 let new_capacity = self
302 .len()
303 .checked_add(additional)
304 .ok_or(NtStringError::BufferSizeExceedsU16)?;
305
306 if self.raw.buffer.is_null() {
307 self.raw.maximum_length = new_capacity;
308 let new_layout = self.layout();
309
310 self.raw.buffer = unsafe { alloc::alloc(new_layout) } as *mut u16;
311 } else {
312 let old_layout = self.layout();
313
314 self.raw.buffer = unsafe {
315 alloc::realloc(
316 self.raw.buffer as *mut u8,
317 old_layout,
318 usize::from(new_capacity),
319 )
320 } as *mut u16;
321
322 self.raw.maximum_length = new_capacity;
323 }
324
325 Ok(())
326 }
327
328 pub fn with_capacity(capacity: u16) -> Self {
335 let mut string = Self::new();
336 string.try_reserve(capacity).unwrap();
337 string
338 }
339}
340
341impl Add<&str> for NtUnicodeString {
342 type Output = NtUnicodeString;
343
344 fn add(mut self, rhs: &str) -> Self::Output {
345 if let Err(e) = self.try_push_str(rhs) {
346 panic!("{e}");
347 }
348
349 self
350 }
351}
352
353impl Add<&U16CStr> for NtUnicodeString {
354 type Output = NtUnicodeString;
355
356 fn add(mut self, rhs: &U16CStr) -> Self::Output {
357 if let Err(e) = self.try_push_u16cstr(rhs) {
358 panic!("{e}");
359 }
360
361 self
362 }
363}
364
365impl Add<&U16Str> for NtUnicodeString {
366 type Output = NtUnicodeString;
367
368 fn add(mut self, rhs: &U16Str) -> Self::Output {
369 if let Err(e) = self.try_push_u16str(rhs) {
370 panic!("{e}");
371 }
372
373 self
374 }
375}
376
377impl AddAssign<&str> for NtUnicodeString {
378 fn add_assign(&mut self, rhs: &str) {
379 if let Err(e) = self.try_push_str(rhs) {
380 panic!("{e}");
381 }
382 }
383}
384
385impl AddAssign<&U16CStr> for NtUnicodeString {
386 fn add_assign(&mut self, rhs: &U16CStr) {
387 if let Err(e) = self.try_push_u16cstr(rhs) {
388 panic!("{e}");
389 }
390 }
391}
392
393impl AddAssign<&U16Str> for NtUnicodeString {
394 fn add_assign(&mut self, rhs: &U16Str) {
395 if let Err(e) = self.try_push_u16str(rhs) {
396 panic!("{e}");
397 }
398 }
399}
400
401impl Clone for NtUnicodeString {
402 fn clone(&self) -> Self {
406 NtUnicodeString::from(self.as_unicode_str())
407 }
408}
409
410impl Default for NtUnicodeString {
411 fn default() -> Self {
412 Self::new()
413 }
414}
415
416impl Deref for NtUnicodeString {
417 type Target = NtUnicodeStrMut<'static>;
420
421 fn deref(&self) -> &Self::Target {
422 unsafe { mem::transmute(self) }
425 }
426}
427
428impl DerefMut for NtUnicodeString {
429 fn deref_mut(&mut self) -> &mut Self::Target {
430 unsafe { mem::transmute(self) }
433 }
434}
435
436impl fmt::Display for NtUnicodeString {
437 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
438 fmt::Display::fmt(self.deref(), f)
439 }
440}
441
442impl Drop for NtUnicodeString {
443 fn drop(&mut self) {
444 if !self.raw.buffer.is_null() {
445 let layout = self.layout();
446 unsafe { alloc::dealloc(self.raw.buffer as *mut u8, layout) }
447 }
448 }
449}
450
451impl Eq for NtUnicodeString {}
452
453impl From<char> for NtUnicodeString {
454 fn from(c: char) -> Self {
456 let mut string = Self::new();
457 string.try_push(c).unwrap();
458 string
459 }
460}
461
462impl<'a> From<&NtUnicodeStr<'a>> for NtUnicodeString {
463 fn from(unicode_str: &NtUnicodeStr) -> Self {
467 let mut new_string = Self::with_capacity(unicode_str.capacity());
468
469 if !unicode_str.is_empty() {
470 new_string.raw.length = unicode_str.len();
471 new_string
472 .as_mut_slice()
473 .copy_from_slice(unicode_str.as_slice());
474 }
475
476 new_string
477 }
478}
479
480impl Ord for NtUnicodeString {
481 fn cmp(&self, other: &Self) -> Ordering {
482 Ord::cmp(self.deref(), other.deref())
483 }
484}
485
486impl PartialOrd<NtUnicodeString> for NtUnicodeString {
487 fn partial_cmp(&self, other: &NtUnicodeString) -> Option<Ordering> {
488 Some(self.cmp(other))
489 }
490}
491
492impl_eq! { NtUnicodeString, NtUnicodeString }
493impl_eq! { NtUnicodeStr<'a>, NtUnicodeString }
494impl_eq! { NtUnicodeString, NtUnicodeStr<'a> }
495impl_eq! { NtUnicodeStrMut<'a>, NtUnicodeString }
496impl_eq! { NtUnicodeString, NtUnicodeStrMut<'a> }
497impl_eq! { NtUnicodeString, str }
498impl_eq! { str, NtUnicodeString }
499impl_eq! { NtUnicodeString, &str }
500impl_eq! { &str, NtUnicodeString }
501
502impl_partial_cmp! { NtUnicodeStr<'a>, NtUnicodeString }
503impl_partial_cmp! { NtUnicodeString, NtUnicodeStr<'a> }
504impl_partial_cmp! { NtUnicodeStrMut<'a>, NtUnicodeString }
505impl_partial_cmp! { NtUnicodeString, NtUnicodeStrMut<'a> }
506impl_partial_cmp! { NtUnicodeString, str }
507impl_partial_cmp! { str, NtUnicodeString }
508impl_partial_cmp! { NtUnicodeString, &str }
509impl_partial_cmp! { &str, NtUnicodeString }
510
511impl TryExtend<char> for NtUnicodeString {
512 type Error = NtStringError;
513
514 fn try_extend<I: IntoIterator<Item = char>>(&mut self, iter: I) -> Result<()> {
515 let iterator = iter.into_iter();
516 let (lower_bound, _) = iterator.size_hint();
517
518 let additional_elements = lower_bound + 1;
521
522 let additional_bytes = u16::try_from(additional_elements * mem::size_of::<u16>())
524 .map_err(|_| NtStringError::BufferSizeExceedsU16)?;
525 self.try_reserve(additional_bytes)?;
526
527 for ch in iterator {
528 self.try_push(ch)?;
529 }
530
531 Ok(())
532 }
533
534 fn try_extend_one(&mut self, item: char) -> Result<()> {
535 self.try_push(item)
536 }
537}
538
539impl<'a> TryExtend<&'a str> for NtUnicodeString {
540 type Error = NtStringError;
541
542 fn try_extend<I: IntoIterator<Item = &'a str>>(&mut self, iter: I) -> Result<()> {
543 for s in iter.into_iter() {
544 self.try_push_str(s)?;
545 }
546
547 Ok(())
548 }
549}
550
551impl<'a> TryExtend<&'a U16CStr> for NtUnicodeString {
552 type Error = NtStringError;
553
554 fn try_extend<I: IntoIterator<Item = &'a U16CStr>>(&mut self, iter: I) -> Result<()> {
555 for s in iter.into_iter() {
556 self.try_push_u16cstr(s)?;
557 }
558
559 Ok(())
560 }
561}
562
563impl<'a> TryExtend<&'a U16Str> for NtUnicodeString {
564 type Error = NtStringError;
565
566 fn try_extend<I: IntoIterator<Item = &'a U16Str>>(&mut self, iter: I) -> Result<()> {
567 for s in iter.into_iter() {
568 self.try_push_u16str(s)?;
569 }
570
571 Ok(())
572 }
573}
574
575impl TryFrom<&str> for NtUnicodeString {
576 type Error = NtStringError;
577
578 fn try_from(s: &str) -> Result<Self> {
583 let mut string = Self::new();
584 string.try_push_str(s)?;
585 Ok(string)
586 }
587}
588
589impl TryFrom<String> for NtUnicodeString {
590 type Error = NtStringError;
591
592 fn try_from(s: String) -> Result<Self> {
597 NtUnicodeString::try_from(s.as_str())
598 }
599}
600
601impl TryFrom<&String> for NtUnicodeString {
602 type Error = NtStringError;
603
604 fn try_from(s: &String) -> Result<Self> {
609 NtUnicodeString::try_from(s.as_str())
610 }
611}
612
613impl TryFrom<&U16CStr> for NtUnicodeString {
614 type Error = NtStringError;
615
616 fn try_from(value: &U16CStr) -> Result<Self> {
621 let unicode_str = NtUnicodeStr::try_from(value)?;
622 Ok(Self::from(&unicode_str))
623 }
624}
625
626impl TryFrom<&U16Str> for NtUnicodeString {
627 type Error = NtStringError;
628
629 fn try_from(value: &U16Str) -> Result<Self> {
634 let unicode_str = NtUnicodeStr::try_from(value)?;
635 Ok(Self::from(&unicode_str))
636 }
637}
638
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;

    use crate::error::NtStringError;
    use crate::traits::TryExtend;
    use crate::unicode_string::NtUnicodeString;

    /// `+=` and `+` with `&str` append to the string.
    #[test]
    fn test_add() {
        // NOTE(review): the original emoji literals were lost to encoding corruption;
        // any characters work for this test.
        let mut string = NtUnicodeString::new();
        string += "👍";
        assert_eq!(string, "👍");

        let string2 = string + "😀";
        assert_eq!(string2, "👍😀");
    }

    /// Lengths and capacities are byte counts of the UTF-16 encoding; the capacity also
    /// covers the 2-byte NUL terminator.
    #[test]
    fn test_chars() {
        // ASCII: 4 characters * 2 bytes (+ 2 bytes for the NUL terminator).
        let moin = NtUnicodeString::try_from("Moin").unwrap();
        assert_eq!(moin.capacity(), 10);
        assert_eq!(moin.len(), 8);
        let vec = moin.chars_lossy().collect::<Vec<char>>();
        assert_eq!(vec, ['M', 'o', 'i', 'n']);

        // Non-ASCII characters inside the Basic Multilingual Plane: still 2 bytes each.
        let konnichiwa = NtUnicodeString::try_from("今日は").unwrap();
        assert_eq!(konnichiwa.capacity(), 8);
        assert_eq!(konnichiwa.len(), 6);
        let vec = konnichiwa.chars_lossy().collect::<Vec<char>>();
        assert_eq!(vec, ['今', '日', 'は']);

        // A character outside the Basic Multilingual Plane: a surrogate pair, 4 bytes.
        let smile = NtUnicodeString::try_from("😀").unwrap();
        assert_eq!(smile.capacity(), 6);
        assert_eq!(smile.len(), 4);
        let vec = smile.chars_lossy().collect::<Vec<char>>();
        assert_eq!(vec, ['😀']);
    }

    /// Strings are ordered by their content.
    #[test]
    fn test_cmp() {
        let a = NtUnicodeString::try_from("a").unwrap();
        let b = NtUnicodeString::try_from("b").unwrap();
        assert!(a < b);
    }

    /// Equality depends on the content only, not on the reserved capacity.
    #[test]
    fn test_eq() {
        let hello = NtUnicodeString::try_from("Hello").unwrap();
        let hello_again = NtUnicodeString::try_from("Hello again").unwrap();
        assert_ne!(hello, hello_again);

        let mut hello_clone = hello.clone();
        assert_eq!(hello, hello_clone);

        hello_clone.try_reserve(42).unwrap();
        assert_eq!(hello, hello_clone);
    }

    /// Extending fails once the `u16` size limit would be exceeded, and popped space can
    /// be reused without changing the capacity.
    #[test]
    fn test_extend_and_pop() {
        // 32766 characters * 2 bytes = 65532 bytes, plus 2 bytes for the NUL terminator:
        // the largest string that fits.
        let a_string = "a".repeat(32766);
        let mut string = NtUnicodeString::try_from(a_string).unwrap();
        assert_eq!(string.capacity(), 65534);
        assert_eq!(string.len(), 65532);

        // There is no room for another character.
        assert_eq!(
            string.try_extend(Some('b')),
            Err(NtStringError::BufferSizeExceedsU16)
        );

        // Pop one character and push a 2-byte character into the freed space.
        assert_eq!(string.pop(), Some(Ok('a')));
        assert_eq!(string.capacity(), 65534);
        assert_eq!(string.len(), 65530);
        string.try_extend_one('c').unwrap();
        assert_eq!(string.capacity(), 65534);
        assert_eq!(string.len(), 65532);

        // Pop two characters and push a 4-byte (surrogate pair) character instead.
        assert_eq!(string.pop(), Some(Ok('c')));
        assert_eq!(string.pop(), Some(Ok('a')));
        assert_eq!(string.capacity(), 65534);
        assert_eq!(string.len(), 65528);
        string.try_extend_one('😀').unwrap();
        assert_eq!(string.capacity(), 65534);
        assert_eq!(string.len(), 65532);

        // Pop the 4-byte character and one more, then push three 2-byte characters.
        assert_eq!(string.pop(), Some(Ok('😀')));
        assert_eq!(string.pop(), Some(Ok('a')));
        assert_eq!(string.capacity(), 65534);
        assert_eq!(string.len(), 65526);
        string.try_extend("def".chars()).unwrap();
        assert_eq!(string.capacity(), 65534);
        assert_eq!(string.len(), 65532);
    }

    /// Construction from raw `u16` buffers, with and without NUL detection.
    #[test]
    fn test_from_u16() {
        // 32768 elements = 65536 bytes, one byte count more than a `u16` can hold.
        let mut a_vec = "a".repeat(32768).encode_utf16().collect::<Vec<u16>>();
        assert_eq!(
            NtUnicodeString::try_from_u16(&a_vec),
            Err(NtStringError::BufferSizeExceedsU16)
        );

        // 32767 elements = 65534 bytes fit; the whole capacity is used by the content.
        a_vec.pop();
        let string = NtUnicodeString::try_from_u16(&a_vec).unwrap();
        assert_eq!(string.capacity(), 65534);
        assert_eq!(string.len(), 65534);

        // With a NUL at index 4, only the first four elements become the content.
        a_vec[4] = 0;
        let string = NtUnicodeString::try_from_u16_until_nul(&a_vec).unwrap();
        assert_eq!(string.capacity(), 10);
        assert_eq!(string.len(), 8);
        assert_eq!(string, "aaaa");
    }

    /// Pushing grows the capacity to exactly the content plus the NUL terminator.
    #[test]
    fn test_push_str() {
        let mut string = NtUnicodeString::new();
        string.try_push_str("Hey").unwrap();
        assert_eq!(string, "Hey");
        assert_eq!(string.capacity(), 8);
        assert_eq!(string.len(), 6);

        string.try_push_str("Ho").unwrap();
        assert_eq!(string, "HeyHo");
        assert_eq!(string.capacity(), 12);
        assert_eq!(string.len(), 10);
    }

    /// Reserving only grows the capacity when the remaining capacity is insufficient.
    #[test]
    fn test_reserve() {
        let mut string = NtUnicodeString::new();
        assert_eq!(string.capacity(), 0);

        string.try_reserve(5).unwrap();
        assert_eq!(string.capacity(), 5);

        // 3 bytes already fit into the 5 reserved bytes.
        string.try_reserve(3).unwrap();
        assert_eq!(string.capacity(), 5);

        // "a" needs 4 bytes including the terminator, which still fit.
        string.try_push_str("a").unwrap();
        assert_eq!(string, "a");
        assert_eq!(string.capacity(), 5);

        // "b" needs 4 more bytes behind the 2 bytes in use: the capacity grows to 6.
        string.try_push_str("b").unwrap();
        assert_eq!(string, "ab");
        assert_eq!(string.capacity(), 6);
    }
}