architect_api/utils/
str.rs1use anyhow::bail;
13use fxhash::FxHashSet;
14use once_cell::sync::Lazy;
15use parking_lot::Mutex;
16use schemars::JsonSchema;
17use serde::{Deserialize, Serialize};
18use std::{
19    borrow::{Borrow, Cow},
20    collections::HashSet,
21    fmt,
22    hash::Hash,
23    mem,
24    ops::Deref,
25    slice, str,
26};
27
/// Set on a `Str` whose bytes are packed inline in the word ("immediate");
/// clear when the word is the address of an interned, length-prefixed entry.
const TAG_MASK: usize = 1 << 63;
/// Bits holding the byte length of an immediate `Str` (extracted with `>> 56`).
const LEN_MASK: usize = 0x7F << 56;
/// Size in bytes of each arena chunk that interned strings are copied into.
const CHUNK_SIZE: usize = 1 << 20;
31
32struct Chunk {
33    data: Vec<u8>,
34    pos: usize,
35}
36
37impl Chunk {
38    #[cfg(target_pointer_width = "64")]
39    fn new() -> &'static mut Self {
40        let res = Box::leak(Box::new(Chunk { data: vec![0; CHUNK_SIZE], pos: 0 }));
41        assert!((res as *mut Self as usize) & TAG_MASK == 0);
42        res
43    }
44
45    fn insert(&mut self, str: &[u8]) -> (*mut Chunk, Str) {
46        let mut t = self;
47        loop {
48            if CHUNK_SIZE - t.pos > str.len() {
49                t.data[t.pos] = str.len() as u8;
50                t.data[t.pos + 1..t.pos + 1 + str.len()].copy_from_slice(str);
51                let res = Str(t.data.as_ptr().wrapping_add(t.pos) as usize);
52                t.pos += 1 + str.len();
53                break (t, res);
54            } else {
55                t = Self::new();
56            }
57        }
58    }
59}
60
61struct Root {
62    all: FxHashSet<Str>,
63    root: *mut Chunk,
64}
65
66unsafe impl Send for Root {}
67unsafe impl Sync for Root {}
68
69static ROOT: Lazy<Mutex<Root>> =
70    Lazy::new(|| Mutex::new(Root { all: HashSet::default(), root: Chunk::new() }));
71
72#[allow(dead_code)]
73struct StrVisitor;
74
75impl serde::de::Visitor<'_> for StrVisitor {
76    type Value = Str;
77
78    fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
79        write!(f, "expecting a string")
80    }
81
82    fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
83    where
84        E: serde::de::Error,
85    {
86        Str::try_from(s).map_err(|e| E::custom(e.to_string()))
87    }
88}
89
// Schema-only stand-in type. `Str`'s field is a tagged `usize`, so it is annotated
// `#[schemars(with = "AsStr")]` to describe it as a string in the generated JSON schema.
// Never constructed at runtime, hence `dead_code` is allowed.
#[allow(dead_code)]
#[derive(JsonSchema)]
struct AsStr(&'static str);
93
// A `Copy`-able string handle stored in a single `usize`.
//
// The `TAG_MASK` bit selects one of two encodings:
// - Tag set ("immediate"): strings shorter than 8 bytes are packed into the word itself.
//   The byte length lives in the `LEN_MASK` bits and the UTF-8 bytes in the rest.
// - Tag clear: the word is the address of a length-prefixed copy of the bytes in a leaked
//   arena chunk. Such strings are deduplicated through `ROOT`, so equal text always
//   yields the same address.
// Because the encoding is canonical, `PartialEq` compares the raw word. Strings longer
// than 255 bytes cannot be represented and are rejected by `TryFrom`.
#[derive(Clone, Copy, Deserialize, JsonSchema)]
#[serde(try_from = "Cow<str>")]
// NOTE(review): `Serialize` is implemented by hand, so `into` is not used by a serde
// derive here; confirm whether another derive (e.g. schemars) relies on it before removing.
#[serde(into = "&str")]
#[repr(transparent)]
#[cfg_attr(feature = "juniper", derive(juniper::GraphQLScalar))]
#[cfg_attr(feature = "juniper", graphql(description = "A String type"))]
pub struct Str(#[schemars(with = "AsStr")] usize);

// SAFETY: a `Str` is either inline data or the address of arena bytes that are never
// mutated or freed after insertion (chunks are leaked and only appended to). Sending or
// sharing it across threads therefore cannot race. `usize` is already `Send + Sync`;
// these impls state that intent explicitly.
unsafe impl Send for Str {}
unsafe impl Sync for Str {}
123
124impl Str {
125    pub fn as_str(&self) -> &str {
126        unsafe {
127            if self.0 & TAG_MASK > 0 {
128                #[cfg(target_endian = "little")]
129                {
130                    let len = (self.0 & LEN_MASK) >> 56;
131                    let ptr = self as *const Self as *const u8;
132                    let slice = slice::from_raw_parts(ptr, len);
133                    str::from_utf8_unchecked(slice)
134                }
135                #[cfg(target_endian = "big")]
136                {
137                    let len = (self.0 & LEN_MASK) >> 56;
138                    let ptr = (self as *const Self as *const u8).wrapping_add(1);
139                    let slice = slice::from_raw_parts(ptr, len);
140                    str::from_utf8_unchecked(slice)
141                }
142            } else {
143                let t = self.0 as *const u8;
144                let len = *t as usize;
145                let ptr = t.wrapping_add(1);
146                let slice = slice::from_raw_parts(ptr, len);
147                str::from_utf8_unchecked(slice)
148            }
149        }
150    }
151
152    pub fn as_static_str(&self) -> Option<&'static str> {
154        unsafe {
155            if self.0 & TAG_MASK > 0 {
156                None
157            } else {
158                Some(mem::transmute::<&str, &'static str>(self.as_str()))
159            }
160        }
161    }
162
163    pub fn is_immediate(&self) -> bool {
165        self.0 & TAG_MASK > 0
166    }
167}
168
169impl fmt::Debug for Str {
170    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
171        write!(f, "{}", &**self)
172    }
173}
174
175impl fmt::Display for Str {
176    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
177        write!(f, "{}", &**self)
178    }
179}
180
// Serializes a `Str` as a plain string: the text returned by `as_str`.
impl Serialize for Str {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        serializer.serialize_str(self.as_str())
    }
}
189
190impl Deref for Str {
191    type Target = str;
192
193    fn deref(&self) -> &Self::Target {
194        self.as_str()
195    }
196}
197
198impl Borrow<str> for Str {
199    fn borrow(&self) -> &str {
200        self.as_str()
201    }
202}
203
204impl Borrow<str> for &Str {
205    fn borrow(&self) -> &str {
206        self.as_str()
207    }
208}
209
210impl AsRef<str> for Str {
211    fn as_ref(&self) -> &str {
212        self.as_str()
213    }
214}
215
216impl Hash for Str {
217    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
218        self.as_str().hash(state)
219    }
220}
221
222impl PartialEq for Str {
223    fn eq(&self, other: &Self) -> bool {
224        self.0 == other.0
225    }
226}
227
228impl PartialEq<&str> for Str {
229    fn eq(&self, other: &&str) -> bool {
230        self.as_str() == *other
231    }
232}
233
234impl Eq for Str {}
235
236impl PartialOrd for Str {
237    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
238        Some(self.cmp(other))
239    }
240}
241
242impl PartialOrd<&str> for Str {
243    fn partial_cmp(&self, other: &&str) -> Option<std::cmp::Ordering> {
244        self.as_str().partial_cmp(*other)
245    }
246}
247
248impl Ord for Str {
249    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
250        self.as_str().cmp(other.as_str())
251    }
252}
253
254impl TryFrom<String> for Str {
255    type Error = anyhow::Error;
256
257    fn try_from(s: String) -> Result<Self, Self::Error> {
258        s.as_str().try_into()
259    }
260}
261
262impl TryFrom<&str> for Str {
263    type Error = anyhow::Error;
264
265    fn try_from(s: &str) -> Result<Self, Self::Error> {
266        unsafe {
267            let len = s.len();
268            if len > u8::MAX as usize {
269                bail!("string is too long")
270            } else if len < 8 {
271                #[cfg(target_endian = "little")]
272                {
273                    let s = s.as_bytes();
274                    let mut i = 0;
275                    let mut res: usize = TAG_MASK;
276                    res |= len << 56;
277                    while i < len {
278                        res |= (s[i] as usize) << (i << 3);
279                        i += 1;
280                    }
281                    Ok(Str(res))
282                }
283                #[cfg(target_endian = "big")]
284                {
285                    let s = s.as_bytes();
286                    let mut i = 0;
287                    let mut res: usize = TAG_MASK;
288                    res |= len << 56;
289                    while i < len {
290                        res |= (s[i] as usize) << (48 - (i << 3));
291                        i += 1;
292                    }
293                    Ok(Str(res))
294                }
295            } else {
296                let mut root = ROOT.lock();
297                match root.all.get(s) {
298                    Some(t) => Ok(*t),
299                    None => {
300                        let (r, t) = (*root.root).insert(s.as_bytes());
301                        root.root = r;
302                        root.all.insert(t);
303                        Ok(t)
304                    }
305                }
306            }
307        }
308    }
309}
310
311impl TryFrom<Cow<'_, str>> for Str {
312    type Error = anyhow::Error;
313
314    fn try_from(s: Cow<str>) -> Result<Self, Self::Error> {
315        match s {
316            Cow::Borrowed(s) => Str::try_from(s),
317            Cow::Owned(s) => Str::try_from(s.as_str()),
318        }
319    }
320}
321
// String conversions used with the feature-gated `GraphQLScalar` derive on `Str`.
// NOTE(review): these are presumably picked up by that derive by name; they have no
// other callers in this file.
#[cfg(feature = "juniper")]
impl Str {
    #[allow(clippy::wrong_self_convention)]
    fn to_output<S: juniper::ScalarValue>(&self) -> juniper::Value<S> {
        let text: String = self.as_str().into();
        juniper::Value::scalar(text)
    }

    fn from_input<S>(v: &juniper::InputValue<S>) -> Result<Self, String>
    where
        S: juniper::ScalarValue,
    {
        match v.as_string_value() {
            Some(text) => Self::try_from(text).map_err(|e| e.to_string()),
            None => Err(format!("Expected `String`, found: {v}")),
        }
    }

    fn parse_token<S>(value: juniper::ScalarToken<'_>) -> juniper::ParseScalarResult<S>
    where
        S: juniper::ScalarValue,
    {
        <String as juniper::ParseScalarValue<S>>::from_str(value)
    }
}
346
// Sends a `Str` to Postgres as its `&str` text. It accepts exactly the column types
// that `String` accepts.
#[cfg(feature = "postgres-types")]
impl postgres_types::ToSql for Str {
    postgres_types::to_sql_checked!();

    fn to_sql(
        &self,
        ty: &postgres_types::Type,
        out: &mut bytes::BytesMut,
    ) -> Result<postgres_types::IsNull, Box<dyn std::error::Error + Sync + Send>> {
        self.as_str().to_sql(ty, out)
    }

    fn accepts(ty: &postgres_types::Type) -> bool {
        String::accepts(ty)
    }
}
363
#[cfg(test)]
mod test {
    use super::*;
    use rand::{rng, Rng};

    /// Random string of `size` printable ASCII characters, space through `~` inclusive.
    fn rand_ascii(size: usize) -> String {
        let mut s = String::new();
        for _ in 0..size {
            s.push(rng().random_range(' '..='~'))
        }
        s
    }

    /// Random string of `size` arbitrary Unicode scalar values.
    fn rand_unicode(size: usize) -> String {
        let mut s = String::new();
        for _ in 0..size {
            s.push(rng().random())
        }
        s
    }

    #[test]
    fn immediates() {
        for _ in 0..10000 {
            let len = rng().random_range(0..8);
            let s = rand_ascii(len);
            let t0 = Str::try_from(s.as_str()).unwrap();
            assert_eq!(&*t0, &*s);
            let t1 = Str::try_from(s.as_str()).unwrap();
            assert_eq!(t0.0, t1.0)
        }
    }

    #[test]
    fn mixed() {
        for _ in 0..10000 {
            let len = rng().random_range(0..256);
            let s = rand_ascii(len);
            let t0 = Str::try_from(s.as_str()).unwrap();
            assert_eq!(&*t0, &*s);
            let t1 = Str::try_from(s.as_str()).unwrap();
            assert_eq!(t0.0, t1.0)
        }
    }

    #[test]
    fn unicode() {
        for _ in 0..10000 {
            let s = loop {
                let len = rng().random_range(0..128);
                let s = rand_unicode(len);
                if s.len() < 256 {
                    break s;
                }
            };
            let t0 = Str::try_from(s.as_str()).unwrap();
            assert_eq!(&*t0, &*s);
            let t1 = Str::try_from(s.as_str()).unwrap();
            assert_eq!(t0.0, t1.0)
        }
    }

    #[test]
    fn length_limits() {
        // 7 bytes is the longest immediate; 8 bytes and up are interned.
        assert!(Str::try_from("1234567").unwrap().is_immediate());
        assert!(!Str::try_from("12345678").unwrap().is_immediate());
        // The empty string is a valid immediate.
        let empty = Str::try_from("").unwrap();
        assert!(empty.is_immediate());
        assert_eq!(&*empty, "");
        // 255 bytes is the longest representable string; 256 is rejected.
        let longest = "x".repeat(255);
        let t = Str::try_from(longest.as_str()).unwrap();
        assert_eq!(&*t, longest.as_str());
        assert!(Str::try_from("x".repeat(256).as_str()).is_err());
    }

    #[test]
    fn equality_and_ordering_follow_text() {
        let short = Str::try_from("apple").unwrap();
        let long = Str::try_from("banana-split").unwrap();
        assert!(short == "apple");
        assert!(long == "banana-split");
        assert!(short != long);
        assert!(short < long);
        assert_eq!(short.as_static_str(), None);
        assert_eq!(long.as_static_str(), Some("banana-split"));
    }
}