#![allow(rustdoc::bare_urls)]
#![doc = include_str!("../README.md")]
use std::{
alloc::{alloc, dealloc, handle_alloc_error, Layout},
borrow::Borrow,
cmp::Ordering,
fmt::{self, Debug},
hash::{Hash, Hasher},
ops::Deref,
ptr::{copy_nonoverlapping, NonNull},
str::FromStr,
};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use thiserror::Error;
/// A heap-allocated UTF-8 string whose handle is a single pointer.
///
/// The allocation is laid out as one length byte followed by that many bytes
/// of string data, which is why the content is limited to 255 bytes.
pub struct CompactString {
// Points at the length byte; the string data begins one byte after it.
ptr: NonNull<u8>,
}
/// Error returned when an input is too long to be stored in a `CompactString`.
///
/// The wrapped value is the byte length of the rejected input.
#[derive(Error, Debug, PartialEq)]
#[error("Invalid string length: {0}")]
pub struct CompactStringLengthError(usize);
/// Error returned by `CompactString::from_utf8`.
#[derive(Error, Debug, PartialEq)]
pub enum ParseCompactStringError {
/// The bytes were valid UTF-8 but too long to store.
#[error(transparent)]
LengthError(#[from] CompactStringLengthError),
/// The bytes were not valid UTF-8.
#[error(transparent)]
Utf8Error(#[from] std::str::Utf8Error),
}
impl CompactString {
pub fn try_new(data: &str) -> Result<CompactString, CompactStringLengthError> {
let data_len = data.len();
if data_len > 255 {
return Err(CompactStringLengthError(data_len));
}
let layout = Self::memory_layout(data_len);
let ptr = unsafe {
let alloc_ptr = alloc(layout);
if alloc_ptr.is_null() {
handle_alloc_error(layout);
}
*alloc_ptr.add(0) = data_len as u8;
copy_nonoverlapping(data.as_ptr(), alloc_ptr.add(1), data_len);
NonNull::new_unchecked(alloc_ptr)
};
Ok(CompactString { ptr })
}
#[inline]
pub fn len(&self) -> usize {
unsafe { *self.ptr.as_ptr() as usize }
}
#[inline]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
#[inline]
pub fn from_utf8<B: AsRef<[u8]>>(bytes: B) -> Result<Self, ParseCompactStringError> {
let s = std::str::from_utf8(bytes.as_ref())?;
Ok(Self::try_new(s)?)
}
#[inline]
pub fn as_bytes(&self) -> &[u8] {
let data_len = self.len();
unsafe { std::slice::from_raw_parts(self.ptr.as_ptr().add(1), data_len) }
}
#[inline]
pub fn as_str(&self) -> &str {
unsafe { std::str::from_utf8_unchecked(self.as_bytes()) }
}
#[inline]
fn memory_layout(data_len: usize) -> Layout {
Layout::array::<u8>(data_len + 1).unwrap()
}
}
impl Drop for CompactString {
    /// Frees the backing allocation.
    fn drop(&mut self) {
        // Recompute the layout from the stored length: it must match the one
        // the buffer was allocated with.
        let layout = Self::memory_layout(self.len());
        let raw = self.ptr.as_ptr();
        // SAFETY: `raw` was returned by `alloc` for a buffer of `len() + 1`
        // bytes, which is exactly what `layout` describes.
        unsafe { dealloc(raw, layout) };
    }
}
// SAFETY: the buffer behind `ptr` is allocated per value in `try_new` and
// freed in `Drop`; nothing in this file shares it between two values, so
// moving a `CompactString` to another thread moves sole ownership with it.
unsafe impl Send for CompactString {}
// SAFETY: the methods in this file only read through `ptr` behind `&self`;
// none of them writes to the buffer after construction, so concurrent shared
// access is read-only.
unsafe impl Sync for CompactString {}
impl Hash for CompactString {
    /// Hashes the string contents exactly as `str` does.
    ///
    /// Matching `str`'s hash is what lets a `HashMap<CompactString, _>` be
    /// queried with a plain `&str` through the `Borrow<str>` impl.
    fn hash<H: Hasher>(&self, state: &mut H) {
        let text: &str = self.as_str();
        Hash::hash(text, state);
    }
}
impl Borrow<str> for CompactString {
fn borrow(&self) -> &str {
self.as_str()
}
}
impl Clone for CompactString {
#[inline]
fn clone(&self) -> Self {
CompactString::try_new(self.as_str()).unwrap()
}
}
impl FromStr for CompactString {
type Err = CompactStringLengthError;
#[inline]
fn from_str(s: &str) -> Result<Self, Self::Err> {
CompactString::try_new(s)
}
}
impl TryFrom<String> for CompactString {
type Error = CompactStringLengthError;
#[inline]
fn try_from(s: String) -> Result<Self, Self::Error> {
CompactString::try_new(&s)
}
}
impl AsRef<str> for CompactString {
    /// Views the contents as a `&str`.
    #[inline]
    fn as_ref(&self) -> &str {
        Self::as_str(self)
    }
}
impl Deref for CompactString {
    type Target = str;

    /// Dereferences to the string contents, exposing the `str` API.
    #[inline]
    fn deref(&self) -> &str {
        Self::as_str(self)
    }
}
impl fmt::Display for CompactString {
    /// Writes the string contents, honouring the caller's format spec.
    ///
    /// Delegating to `str`'s `Display` impl (rather than going through
    /// `write!(f, "{}", ..)`, which starts a fresh format spec) keeps width,
    /// fill, alignment and precision working, e.g. `format!("{:>10}", s)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(self.as_str(), f)
    }
}
impl Debug for CompactString {
    /// Writes the raw string contents.
    ///
    /// Unlike `str`'s `Debug` output, the text is neither quoted nor escaped,
    /// and padding options on the formatter are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.as_str())
    }
}
/// `CompactString == str`: compares the text content.
impl PartialEq<str> for CompactString {
    fn eq(&self, other: &str) -> bool {
        other == self.as_str()
    }
}

/// `CompactString == &str`: compares the text content.
impl PartialEq<&'_ str> for CompactString {
    fn eq(&self, other: &&str) -> bool {
        let rhs: &str = other;
        self.as_str() == rhs
    }
}

/// `&str == CompactString`: mirror of the comparison above.
impl PartialEq<CompactString> for &'_ str {
    fn eq(&self, other: &CompactString) -> bool {
        let lhs: &str = self;
        lhs == other.as_str()
    }
}

/// `str == CompactString`: mirror of `CompactString == str`.
impl PartialEq<CompactString> for str {
    fn eq(&self, other: &CompactString) -> bool {
        self == other.as_str()
    }
}

/// `CompactString == CompactString`: equal when the stored bytes are equal.
impl PartialEq for CompactString {
    fn eq(&self, other: &CompactString) -> bool {
        // Same allocation means same contents; skip the byte comparison.
        if self.ptr == other.ptr {
            return true;
        }
        self.as_bytes() == other.as_bytes()
    }
}

/// `CompactString == String`: compares the text content.
impl PartialEq<String> for CompactString {
    fn eq(&self, other: &String) -> bool {
        let rhs: &str = other.as_str();
        self.as_str() == rhs
    }
}

/// `String == CompactString`: mirror of the comparison above.
impl PartialEq<CompactString> for String {
    fn eq(&self, other: &CompactString) -> bool {
        self.as_str() == other.as_str()
    }
}

impl Eq for CompactString {}
impl Serialize for CompactString {
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
serializer.serialize_str(self.as_str())
}
}
impl Ord for CompactString {
    /// Orders values the same way their `&str` contents are ordered.
    fn cmp(&self, other: &Self) -> Ordering {
        let (lhs, rhs) = (self.as_str(), other.as_str());
        Ord::cmp(lhs, rhs)
    }
}
impl PartialOrd for CompactString {
    /// Always returns `Some`: the ordering is total and defined by `Ord`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
impl<'de> Deserialize<'de> for CompactString {
fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
let s = String::deserialize(d)?;
Ok(CompactString::try_new(&s).unwrap())
}
}
// Unit tests for `CompactString`: construction, conversions, equality,
// hashing, formatting, serde round-trips and error reporting.
#[cfg(test)]
mod test {
use super::*;
use rand::RngCore;
use std::collections::HashMap;
use std::hash::BuildHasher;
use std::hash::RandomState;
// The handle is expected to be 8 bytes (one pointer on 64-bit targets).
#[test]
fn size_check() {
assert_eq!(std::mem::size_of::<CompactString>(), 8);
}
// `parse` followed by `clone` yields an equal value with the same contents.
#[test]
fn conversions() {
let s1: CompactString = "Hello world".parse::<CompactString>().unwrap();
let s2: CompactString = s1.clone();
assert_eq!(s1, s2);
assert_eq!(s2.as_str(), "Hello world");
}
// Contents survive construction from `&str` and from a `String` via `parse`.
#[test]
fn from_to_string() {
let s = "This is a test str";
let compact = CompactString::try_new(s).unwrap();
assert_eq!(compact.as_str(), s);
let s = String::from("This is a test string");
let compact: CompactString = s.clone().parse::<CompactString>().unwrap();
assert_eq!(compact.to_string(), s);
}
// `len` reports the byte length, including for the empty string.
#[test]
fn test_length() {
let s: CompactString = "hello world!".parse::<CompactString>().unwrap();
assert_eq!(s.len(), 12);
let s: CompactString = "ABCDEFGHIZKLMNOPQRSTUVWXYZ12345678901234"
.parse::<CompactString>()
.unwrap();
assert_eq!(s.len(), 40);
let s: CompactString = "".parse::<CompactString>().unwrap();
assert_eq!(s.len(), 0);
}
// A 256-byte input is one byte over the limit and must be rejected.
#[test]
fn test_oversize_string() {
let s: [char; 256] = ['k'; 256];
let string = String::from_iter(s);
assert_eq!(CompactString::try_new(&string).is_err(), true);
}
// Equality against `&str` works with either operand on the left.
#[test]
fn test_equal_to_str() {
let s = CompactString::try_new("hello world!").unwrap();
assert_eq!(s, "hello world!");
assert_eq!("hello world!", s);
assert_ne!(s, "foo");
assert_ne!("foo", s);
}
// Equality between two `CompactString`s is by content and is symmetric.
#[test]
fn test_equal_compact_string() {
let s1 = CompactString::try_new("test").unwrap();
let s2 = CompactString::try_new("test").unwrap();
assert_eq!(s1, s1);
assert_eq!(s1, s2);
assert_eq!(s2, s1);
let s3 = CompactString::try_new("foo").unwrap();
assert_ne!(s1, s3);
assert_ne!(s3, s1);
}
// Equality against `String` works with either operand on the left.
#[test]
fn test_equal_to_string() {
let s1 = CompactString::try_new("test").unwrap();
let s2 = String::from("test");
assert_eq!(s1, s2);
assert_eq!(s2, s1);
}
// A clone stays readable after the original has been dropped.
#[test]
fn test_ownership() {
let s1 = CompactString::try_new("test").unwrap();
let s2 = s1.clone();
assert_eq!(s1, s2);
drop(s1);
assert_eq!(s2, "test");
}
// JSON round-trip: serializes as a plain JSON string and reads back equal.
#[test]
fn test_serde() {
let s = CompactString::try_new("Hello World").unwrap();
let serialized = serde_json::to_string(&s).unwrap();
assert_eq!(serialized, "\"Hello World\"");
let deserialized: CompactString = serde_json::from_str(&serialized).unwrap();
assert_eq!(deserialized, s);
}
// Two values with equal contents occupy a single `HashMap` key.
#[test]
fn test_hash() {
let mut map = HashMap::new();
let s1 = CompactString::try_new("test").unwrap();
let s2 = CompactString::try_new("test").unwrap();
map.insert(s1, 1);
map.insert(s2, 2);
assert_eq!(map.len(), 1);
}
// A `CompactString` and a `String` with the same text hash identically.
#[test]
fn test_mix_string_hash() {
let s1 = CompactString::try_new("Test Hash").unwrap();
let s2 = String::from("Test Hash");
let hash_builder = RandomState::new();
let hash1 = hash_builder.hash_one(s1);
let hash2 = hash_builder.hash_one(s2);
assert_eq!(hash1, hash2);
}
// Map lookup using a freshly built `CompactString` key.
#[test]
fn test_search_in_hashmap() {
let mut map = HashMap::<CompactString, i32>::new();
map.insert("aaa".parse::<CompactString>().unwrap(), 17);
assert_eq!(
17,
*map.get(&"aaa".parse::<CompactString>().unwrap()).unwrap()
);
}
// Map lookup using a plain `&str`, which relies on the `Borrow<str>` impl.
#[test]
fn test_search_in_hashmap_with_str() {
let mut map = HashMap::<CompactString, i32>::new();
map.insert("aaa".parse::<CompactString>().unwrap(), 17);
assert_eq!(17, *map.get("aaa").unwrap());
}
// `Debug` output is the raw text, without surrounding quotes.
#[test]
fn test_debug() {
let s = CompactString::try_new("test").unwrap();
assert_eq!(format!("{:?}", s), "test");
}
// `AsRef<str>` exposes the contents.
#[test]
fn test_asref() {
let s = CompactString::try_new("test").unwrap();
assert_eq!(s.as_ref(), "test");
}
// `Deref` exposes the contents as `str`.
#[test]
fn test_deref() {
let s = CompactString::try_new("test").unwrap();
assert_eq!(&*s, "test");
}
// Boundary lengths: the empty string and a 255-byte string.
#[test]
fn test_edge_cases() {
let s = CompactString::try_new("").unwrap();
assert_eq!(s.len(), 0);
assert_eq!(s.as_str(), "");
assert_eq!(s.to_string(), "");
assert_eq!(s, "");
let ls = "123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345";
let s = CompactString::try_new(ls).unwrap();
assert_eq!(s.len(), 255);
assert_eq!(s.as_str(), ls);
assert_eq!(s.to_string(), ls);
assert_eq!(&s, ls);
}
// Multi-byte UTF-8 input (two 4-byte sequences) round-trips through
// `from_utf8` / `as_bytes` unchanged.
#[test]
fn test_bytes_roundtrip() {
let bytes = vec![240, 159, 166, 128, 240, 159, 146, 175];
let compact = CompactString::from_utf8(bytes.clone()).unwrap();
assert_eq!(compact.as_bytes(), &bytes);
}
// Random 255-byte buffers: `CompactString::from_utf8` must accept exactly
// the inputs `String::from_utf8` accepts, and agree on the contents.
#[test]
fn fuzz_test() {
let mut rng = rand::thread_rng();
let mut bytes = vec![0; 255];
for _ in 0..1000 {
rng.fill_bytes(&mut bytes);
let compact_str = CompactString::from_utf8(bytes.clone());
let string = String::from_utf8(bytes.clone());
match compact_str {
Ok(compact) => assert_eq!(compact, string.unwrap()),
Err(_) => assert!(string.is_err()),
}
}
}
// Each failure mode maps to its own error variant and carries the
// offending length where applicable.
#[test]
fn test_error_type() {
const INVALID_UTF8: &[u8] = b"\xC0\xAF";
const INVALID_LENGTH_BYTES: &[u8] = &[b'0'; 256];
let err = CompactString::from_utf8(INVALID_UTF8).unwrap_err();
assert!(matches!(err, ParseCompactStringError::Utf8Error(_)));
let err = CompactString::from_utf8(INVALID_LENGTH_BYTES).unwrap_err();
assert_eq!(
err,
ParseCompactStringError::LengthError(CompactStringLengthError(256))
);
let invalid_string = "0".repeat(1000);
let err = invalid_string.parse::<CompactString>().unwrap_err();
assert_eq!(err, CompactStringLengthError(1000));
}
}