architect_api/utils/
str.rs

//! Global, permanent, packed, hashconsed, short string storage.
//!
//! * supports strings up to 255 bytes
//! * derefs to a &str, but uses only 1 word on the stack and len + 1 bytes on the heap
//! * the actual bytes are stored packed into 1 MiB allocations to
//!   avoid the overhead of lots of small mallocs
//! * Copy!
//! * hashconsed, the same &str will always produce a pointer to the same memory
//!
//! CAN NEVER BE DEALLOCATED
11
12use anyhow::bail;
13use fxhash::FxHashSet;
14#[cfg(feature = "netidx")]
15use netidx::{
16    chars::Chars,
17    pack::{decode_varint, encode_varint, varint_len, Pack, PackError},
18};
19use once_cell::sync::Lazy;
20use parking_lot::Mutex;
21use schemars::JsonSchema;
22use serde::{Deserialize, Serialize};
23use std::{
24    borrow::{Borrow, Cow},
25    collections::HashSet,
26    fmt,
27    hash::Hash,
28    mem,
29    ops::Deref,
30    slice, str,
31};
32
/// Top bit of the word: set when a `Str` carries its bytes inline rather than
/// holding an address. Only representable on 64 bit targets.
const TAG_MASK: usize = 1 << 63;
/// The seven bits directly below the tag; shifted down by 56 they give the
/// byte length of an inline string.
const LEN_MASK: usize = 0x7F << 56;
/// Size in bytes of every packed allocation (1 MiB).
const CHUNK_SIZE: usize = 1 << 20;
36
37struct Chunk {
38    data: Vec<u8>,
39    pos: usize,
40}
41
42impl Chunk {
43    #[cfg(target_pointer_width = "64")]
44    fn new() -> &'static mut Self {
45        let res = Box::leak(Box::new(Chunk { data: vec![0; CHUNK_SIZE], pos: 0 }));
46        assert!((res as *mut Self as usize) & TAG_MASK == 0);
47        res
48    }
49
50    fn insert(&mut self, str: &[u8]) -> (*mut Chunk, Str) {
51        let mut t = self;
52        loop {
53            if CHUNK_SIZE - t.pos > str.len() {
54                t.data[t.pos] = str.len() as u8;
55                t.data[t.pos + 1..t.pos + 1 + str.len()].copy_from_slice(str);
56                let res = Str(t.data.as_ptr().wrapping_add(t.pos) as usize);
57                t.pos += 1 + str.len();
58                break (t, res);
59            } else {
60                t = Self::new();
61            }
62        }
63    }
64}
65
66struct Root {
67    all: FxHashSet<Str>,
68    root: *mut Chunk,
69}
70
71unsafe impl Send for Root {}
72unsafe impl Sync for Root {}
73
74static ROOT: Lazy<Mutex<Root>> =
75    Lazy::new(|| Mutex::new(Root { all: HashSet::default(), root: Chunk::new() }));
76
77#[allow(dead_code)]
78struct StrVisitor;
79
80impl<'de> serde::de::Visitor<'de> for StrVisitor {
81    type Value = Str;
82
83    fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
84        write!(f, "expecting a string")
85    }
86
87    fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
88    where
89        E: serde::de::Error,
90    {
91        Str::try_from(s).map_err(|e| E::custom(e.to_string()))
92    }
93}
94
95#[allow(dead_code)]
96#[derive(JsonSchema)]
97struct AsStr(&'static str);
98
99/// This is either an immediate containing the string data, if the
100/// length is less than 8, or a pointer into static memory that holds
101/// the actual str slice if the data length is greater than 7.
102///
103/// Either way it is 1 word on the stack. In the case of an immediate
104/// the length as well as all the bytes are stored in that word, and
105/// there is no allocation on the heap. Otherwise the length, as well
106/// as the actual bytes of the string are stored on the heap in a
107/// compact allocation along with other strings of this type.
108///
109/// The maximum length of strings of this type is 255
110/// characters. try_from will fail if a larger string is specified.
111///
112/// In either case Deref should be quite cheap, there is no locking to
113/// deref.
114///
115/// In the case of immediates there is never any locking. Otherwise, a
116/// global lock must be taken to hashcons the string and, if it isn't
117/// already present, insert it in the packed allocation.
118#[derive(Clone, Copy, Deserialize, JsonSchema)]
119#[serde(try_from = "Cow<str>")]
120#[serde(into = "&str")]
121#[repr(transparent)]
122#[cfg_attr(feature = "juniper", derive(juniper::GraphQLScalar))]
123#[cfg_attr(feature = "juniper", graphql(description = "A String type"))]
124pub struct Str(#[schemars(with = "AsStr")] usize);
125
126unsafe impl Send for Str {}
127unsafe impl Sync for Str {}
128
129impl Str {
130    pub fn as_str<'a>(&'a self) -> &'a str {
131        unsafe {
132            if self.0 & TAG_MASK > 0 {
133                #[cfg(target_endian = "little")]
134                {
135                    let len = (self.0 & LEN_MASK) >> 56;
136                    let ptr = self as *const Self as *const u8;
137                    let slice = slice::from_raw_parts(ptr, len);
138                    str::from_utf8_unchecked(slice)
139                }
140                #[cfg(target_endian = "big")]
141                {
142                    let len = (self.0 & LEN_MASK) >> 56;
143                    let ptr = (self as *const Self as *const u8).wrapping_add(1);
144                    let slice = slice::from_raw_parts(ptr, len);
145                    str::from_utf8_unchecked(slice)
146                }
147            } else {
148                let t = self.0 as *const u8;
149                let len = *t as usize;
150                let ptr = t.wrapping_add(1);
151                let slice = slice::from_raw_parts(ptr, len);
152                str::from_utf8_unchecked(slice)
153            }
154        }
155    }
156
157    /// return a static str ref unless self is an immediate
158    pub fn as_static_str(&self) -> Option<&'static str> {
159        unsafe {
160            if self.0 & TAG_MASK > 0 {
161                None
162            } else {
163                Some(mem::transmute::<&str, &'static str>(self.as_str()))
164            }
165        }
166    }
167
168    /// return true if this Str is an immediate
169    pub fn is_immediate(&self) -> bool {
170        self.0 & TAG_MASK > 0
171    }
172
173    #[cfg(feature = "netidx")]
174    pub fn as_chars(&self) -> Chars {
175        match self.as_static_str() {
176            Some(s) => Chars::from(s),
177            None => Chars::from(String::from(self.as_str())),
178        }
179    }
180}
181
182impl fmt::Debug for Str {
183    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
184        write!(f, "{}", &**self)
185    }
186}
187
188impl fmt::Display for Str {
189    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
190        write!(f, "{}", &**self)
191    }
192}
193
194impl Serialize for Str {
195    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
196    where
197        S: serde::Serializer,
198    {
199        serializer.serialize_str(self.as_str())
200    }
201}
202
/// Wire format: a varint byte length followed by the raw UTF-8 bytes.
#[cfg(feature = "netidx")]
impl Pack for Str {
    /// Number of bytes `encode` will write.
    fn encoded_len(&self) -> usize {
        let len = self.len();
        varint_len(len as u64) + len
    }

    /// Write the length prefix and then the string bytes to `buf`.
    fn encode(
        &self,
        buf: &mut impl bytes::BufMut,
    ) -> Result<(), netidx::pack::PackError> {
        let s = self.as_str();
        encode_varint(s.len() as u64, buf);
        buf.put_slice(s.as_bytes());
        Ok(())
    }

    /// Read one string from `buf`.
    ///
    /// # Errors
    ///
    /// * `TooBig` if the encoded length exceeds 255 bytes
    /// * `BufferShort` if `buf` holds fewer bytes than the encoded length
    /// * `InvalidFormat` if the bytes are not valid UTF-8
    fn decode(buf: &mut impl bytes::Buf) -> Result<Self, netidx::pack::PackError> {
        let len = decode_varint(buf)? as usize;
        if len > u8::MAX as usize {
            return Err(PackError::TooBig);
        }
        // `copy_to_slice` panics when the source is too short, so truncated
        // input must be rejected before copying.
        if buf.remaining() < len {
            return Err(PackError::BufferShort);
        }
        // The length is bounded by 255, so a stack buffer is always enough.
        let mut scratch = [0u8; u8::MAX as usize];
        let bytes = &mut scratch[..len];
        buf.copy_to_slice(bytes);
        let s = str::from_utf8(&*bytes).map_err(|_| PackError::InvalidFormat)?;
        // Cannot fail for len <= 255; mapped rather than unwrapped so decode
        // never panics.
        Str::try_from(s).map_err(|_| PackError::TooBig)
    }
}
240
241impl Deref for Str {
242    type Target = str;
243
244    fn deref(&self) -> &Self::Target {
245        self.as_str()
246    }
247}
248
249impl Borrow<str> for Str {
250    fn borrow(&self) -> &str {
251        self.as_str()
252    }
253}
254
255impl Borrow<str> for &Str {
256    fn borrow(&self) -> &str {
257        self.as_str()
258    }
259}
260
261impl AsRef<str> for Str {
262    fn as_ref(&self) -> &str {
263        self.as_str()
264    }
265}
266
267impl Hash for Str {
268    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
269        self.as_str().hash(state)
270    }
271}
272
273impl PartialEq for Str {
274    fn eq(&self, other: &Self) -> bool {
275        self.0 == other.0
276    }
277}
278
279impl PartialEq<&str> for Str {
280    fn eq(&self, other: &&str) -> bool {
281        self.as_str() == *other
282    }
283}
284
285impl Eq for Str {}
286
287impl PartialOrd for Str {
288    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
289        self.as_str().partial_cmp(other.as_str())
290    }
291}
292
293impl PartialOrd<&str> for Str {
294    fn partial_cmp(&self, other: &&str) -> Option<std::cmp::Ordering> {
295        self.as_str().partial_cmp(*other)
296    }
297}
298
299impl Ord for Str {
300    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
301        self.as_str().cmp(other.as_str())
302    }
303}
304
305impl TryFrom<String> for Str {
306    type Error = anyhow::Error;
307
308    fn try_from(s: String) -> Result<Self, Self::Error> {
309        s.as_str().try_into()
310    }
311}
312
313impl TryFrom<&str> for Str {
314    type Error = anyhow::Error;
315
316    fn try_from(s: &str) -> Result<Self, Self::Error> {
317        unsafe {
318            let len = s.len();
319            if len > u8::MAX as usize {
320                bail!("string is too long")
321            } else if len < 8 {
322                #[cfg(target_endian = "little")]
323                {
324                    let s = s.as_bytes();
325                    let mut i = 0;
326                    let mut res: usize = TAG_MASK;
327                    res |= len << 56;
328                    while i < len {
329                        res |= (s[i] as usize) << (i << 3);
330                        i += 1;
331                    }
332                    Ok(Str(res))
333                }
334                #[cfg(target_endian = "big")]
335                {
336                    let s = s.as_bytes();
337                    let mut i = 0;
338                    let mut res: usize = TAG_MASK;
339                    res |= len << 56;
340                    while i < len {
341                        res |= (s[i] as usize) << (48 - (i << 3));
342                        i += 1;
343                    }
344                    Ok(Str(res))
345                }
346            } else {
347                let mut root = ROOT.lock();
348                match root.all.get(s) {
349                    Some(t) => Ok(*t),
350                    None => {
351                        let (r, t) = (*root.root).insert(s.as_bytes());
352                        root.root = r;
353                        root.all.insert(t);
354                        Ok(t)
355                    }
356                }
357            }
358        }
359    }
360}
361
362impl TryFrom<Cow<'_, str>> for Str {
363    type Error = anyhow::Error;
364
365    fn try_from(s: Cow<str>) -> Result<Self, Self::Error> {
366        match s {
367            Cow::Borrowed(s) => Str::try_from(s),
368            Cow::Owned(s) => Str::try_from(s.as_str()),
369        }
370    }
371}
372
/// Helper functions for the juniper `GraphQLScalar` derive on `Str`.
#[cfg(feature = "juniper")]
impl Str {
    /// Produce the GraphQL output value: an owned copy of the string.
    fn to_output<S: juniper::ScalarValue>(&self) -> juniper::Value<S> {
        let owned = self.as_str().to_string();
        juniper::Value::scalar(owned)
    }

    /// Parse a GraphQL input value; it must be a string that `Str::try_from`
    /// accepts.
    fn from_input<S>(v: &juniper::InputValue<S>) -> Result<Self, String>
    where
        S: juniper::ScalarValue,
    {
        let s = v
            .as_string_value()
            .ok_or_else(|| format!("Expected `String`, found: {v}"))?;
        Self::try_from(s).map_err(|e| e.to_string())
    }

    /// Tokens are parsed exactly as juniper parses a `String` scalar.
    fn parse_token<S>(value: juniper::ScalarToken<'_>) -> juniper::ParseScalarResult<S>
    where
        S: juniper::ScalarValue,
    {
        <String as juniper::ParseScalarValue<S>>::from_str(value)
    }
}
396
/// `Str` is sent to Postgres exactly as a `&str` would be.
#[cfg(feature = "postgres-types")]
impl postgres_types::ToSql for Str {
    postgres_types::to_sql_checked!();

    /// Encode by delegating to the `&str` implementation.
    fn to_sql(
        &self,
        ty: &postgres_types::Type,
        out: &mut bytes::BytesMut,
    ) -> Result<postgres_types::IsNull, Box<dyn std::error::Error + Sync + Send>> {
        let contents: &str = self.as_str();
        <&str as postgres_types::ToSql>::to_sql(&contents, ty, out)
    }

    /// Accept the same Postgres types that `String` accepts.
    fn accepts(ty: &postgres_types::Type) -> bool {
        <String as postgres_types::ToSql>::accepts(ty)
    }
}
413
#[cfg(test)]
mod test {
    use super::*;
    use rand::{thread_rng, Rng};

    /// Random string of `size` ASCII characters drawn from `' '..'~'`.
    fn rand_ascii(size: usize) -> String {
        (0..size).map(|_| thread_rng().gen_range(' '..'~')).collect()
    }

    /// Random string of `size` arbitrary `char`s.
    fn rand_unicode(size: usize) -> String {
        (0..size).map(|_| thread_rng().gen::<char>()).collect()
    }

    /// A `Str` must read back as its input, and converting the same input
    /// twice must give the same word.
    fn assert_round_trip(s: &str) {
        let first = Str::try_from(s).unwrap();
        assert_eq!(&*first, s);
        let second = Str::try_from(s).unwrap();
        assert_eq!(first.0, second.0)
    }

    #[test]
    fn immediates() {
        for _ in 0..10000 {
            let len = thread_rng().gen_range(0..8);
            assert_round_trip(&rand_ascii(len));
        }
    }

    #[test]
    fn mixed() {
        for _ in 0..10000 {
            let len = thread_rng().gen_range(0..256);
            assert_round_trip(&rand_ascii(len));
        }
    }

    #[test]
    fn unicode() {
        for _ in 0..10000 {
            // keep drawing until the encoded form fits the 255 byte limit
            let s = loop {
                let candidate = rand_unicode(thread_rng().gen_range(0..128));
                if candidate.len() < 256 {
                    break candidate;
                }
            };
            assert_round_trip(&s);
        }
    }
}