architect_api/utils/
str.rs1use anyhow::bail;
13use fxhash::FxHashSet;
14use once_cell::sync::Lazy;
15use parking_lot::Mutex;
16use schemars::JsonSchema;
17use serde::{Deserialize, Serialize};
18use std::{
19 borrow::{Borrow, Cow},
20 collections::HashSet,
21 fmt,
22 hash::Hash,
23 mem,
24 ops::Deref,
25 slice, str,
26};
27
/// Top bit of a `Str` word: set for immediate (inline) strings, clear for interned addresses.
const TAG_MASK: usize = 1 << 63;
/// Bits 56..=62 of an immediate `Str` word, which hold the inline string's byte length.
const LEN_MASK: usize = 0x7F << 56;
/// Size in bytes of each interner chunk (1 MiB).
const CHUNK_SIZE: usize = 1 << 20;
31
/// A fixed-size arena block that interned strings are appended to.
///
/// Each entry is stored as a single length byte followed by the string's bytes, packed back
/// to back starting at offset 0.
struct Chunk {
    /// Backing buffer of `CHUNK_SIZE` bytes. It is only written in place (never pushed to),
    /// so its address stays fixed for the chunk's lifetime.
    data: Vec<u8>,
    /// Offset of the first unused byte in `data`.
    pos: usize,
}
36
37impl Chunk {
38 #[cfg(target_pointer_width = "64")]
39 fn new() -> &'static mut Self {
40 let res = Box::leak(Box::new(Chunk { data: vec![0; CHUNK_SIZE], pos: 0 }));
41 assert!((res as *mut Self as usize) & TAG_MASK == 0);
42 res
43 }
44
45 fn insert(&mut self, str: &[u8]) -> (*mut Chunk, Str) {
46 let mut t = self;
47 loop {
48 if CHUNK_SIZE - t.pos > str.len() {
49 t.data[t.pos] = str.len() as u8;
50 t.data[t.pos + 1..t.pos + 1 + str.len()].copy_from_slice(str);
51 let res = Str(t.data.as_ptr().wrapping_add(t.pos) as usize);
52 t.pos += 1 + str.len();
53 break (t, res);
54 } else {
55 t = Self::new();
56 }
57 }
58 }
59}
60
/// Global interner state, guarded by the `ROOT` mutex.
struct Root {
    /// Every interned (non-immediate) `Str`, used to deduplicate strings on insertion.
    all: FxHashSet<Str>,
    /// The leaked chunk that new entries are currently appended to.
    root: *mut Chunk,
}
65
// SAFETY: in this file a `Root` is only created inside the `ROOT` mutex, so the raw chunk
// pointer is never accessed by two threads at once. The chunks it points to are leaked and
// never freed, so moving the pointer to another thread cannot leave it dangling.
unsafe impl Send for Root {}
unsafe impl Sync for Root {}
68
69static ROOT: Lazy<Mutex<Root>> =
70 Lazy::new(|| Mutex::new(Root { all: HashSet::default(), root: Chunk::new() }));
71
/// Serde visitor that converts a string into a [`Str`].
///
/// Nothing else in this file references it (`Str` deserializes through
/// `#[serde(try_from = "Cow<str>")]`), which is why `dead_code` is allowed.
#[allow(dead_code)]
struct StrVisitor;
74
75impl serde::de::Visitor<'_> for StrVisitor {
76 type Value = Str;
77
78 fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
79 write!(f, "expecting a string")
80 }
81
82 fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
83 where
84 E: serde::de::Error,
85 {
86 Str::try_from(s).map_err(|e| E::custom(e.to_string()))
87 }
88}
89
// Schema-only stand-in for `Str`. Its `usize` field is annotated with
// `#[schemars(with = "AsStr")]`, so the JSON schema describes a string instead of an
// integer. It is not constructed anywhere in this file, hence `dead_code`.
// NOTE(review): these are plain `//` comments on purpose. schemars typically turns `///` doc
// comments into schema descriptions; confirm before converting.
#[allow(dead_code)]
#[derive(JsonSchema)]
struct AsStr(&'static str);
93
// Compact, `Copy` string handle that fits in one machine word.
//
// Encoding of the wrapped `usize`:
// - top bit (`TAG_MASK`) set: an *immediate* string of 0-7 bytes stored inline in the word,
//   with its byte length in the `LEN_MASK` bits;
// - top bit clear: the address of a length-prefixed entry in a leaked interner chunk
//   (strings of 8-255 bytes).
//
// Deserialization goes through `Cow<str>` and `TryFrom`, so over-long strings are rejected.
// NOTE(review): `#[serde(into = "&str")]` only affects a derived `Serialize`, but `Serialize`
// is implemented by hand in this file, so the attribute appears to have no effect. Confirm
// before removing it.
// NOTE(review): these are plain `//` comments on purpose. `///` doc comments would likely be
// picked up by the `JsonSchema` derive as a schema description.
#[derive(Clone, Copy, Deserialize, JsonSchema)]
#[serde(try_from = "Cow<str>")]
#[serde(into = "&str")]
#[repr(transparent)]
#[cfg_attr(feature = "juniper", derive(juniper::GraphQLScalar))]
#[cfg_attr(feature = "juniper", graphql(description = "A String type"))]
pub struct Str(#[schemars(with = "AsStr")] usize);
120
121unsafe impl Send for Str {}
122unsafe impl Sync for Str {}
123
124impl Str {
125 pub fn as_str(&self) -> &str {
126 unsafe {
127 if self.0 & TAG_MASK > 0 {
128 #[cfg(target_endian = "little")]
129 {
130 let len = (self.0 & LEN_MASK) >> 56;
131 let ptr = self as *const Self as *const u8;
132 let slice = slice::from_raw_parts(ptr, len);
133 str::from_utf8_unchecked(slice)
134 }
135 #[cfg(target_endian = "big")]
136 {
137 let len = (self.0 & LEN_MASK) >> 56;
138 let ptr = (self as *const Self as *const u8).wrapping_add(1);
139 let slice = slice::from_raw_parts(ptr, len);
140 str::from_utf8_unchecked(slice)
141 }
142 } else {
143 let t = self.0 as *const u8;
144 let len = *t as usize;
145 let ptr = t.wrapping_add(1);
146 let slice = slice::from_raw_parts(ptr, len);
147 str::from_utf8_unchecked(slice)
148 }
149 }
150 }
151
152 pub fn as_static_str(&self) -> Option<&'static str> {
154 unsafe {
155 if self.0 & TAG_MASK > 0 {
156 None
157 } else {
158 Some(mem::transmute::<&str, &'static str>(self.as_str()))
159 }
160 }
161 }
162
163 pub fn is_immediate(&self) -> bool {
165 self.0 & TAG_MASK > 0
166 }
167}
168
169impl fmt::Debug for Str {
170 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
171 write!(f, "{}", &**self)
172 }
173}
174
175impl fmt::Display for Str {
176 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
177 write!(f, "{}", &**self)
178 }
179}
180
181impl Serialize for Str {
182 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
183 where
184 S: serde::Serializer,
185 {
186 serializer.serialize_str(self.as_str())
187 }
188}
189
190impl Deref for Str {
191 type Target = str;
192
193 fn deref(&self) -> &Self::Target {
194 self.as_str()
195 }
196}
197
198impl Borrow<str> for Str {
199 fn borrow(&self) -> &str {
200 self.as_str()
201 }
202}
203
204impl Borrow<str> for &Str {
205 fn borrow(&self) -> &str {
206 self.as_str()
207 }
208}
209
210impl AsRef<str> for Str {
211 fn as_ref(&self) -> &str {
212 self.as_str()
213 }
214}
215
216impl Hash for Str {
217 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
218 self.as_str().hash(state)
219 }
220}
221
222impl PartialEq for Str {
223 fn eq(&self, other: &Self) -> bool {
224 self.0 == other.0
225 }
226}
227
228impl PartialEq<&str> for Str {
229 fn eq(&self, other: &&str) -> bool {
230 self.as_str() == *other
231 }
232}
233
// Word equality (see `PartialEq`) is reflexive, symmetric and transitive. It also coincides
// with content equality: strings under 8 bytes have one inline encoding, and longer strings
// are deduplicated through the interner set, so equal contents always produce the same word.
impl Eq for Str {}
235
236impl PartialOrd for Str {
237 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
238 Some(self.cmp(other))
239 }
240}
241
242impl PartialOrd<&str> for Str {
243 fn partial_cmp(&self, other: &&str) -> Option<std::cmp::Ordering> {
244 self.as_str().partial_cmp(*other)
245 }
246}
247
248impl Ord for Str {
249 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
250 self.as_str().cmp(other.as_str())
251 }
252}
253
254impl TryFrom<String> for Str {
255 type Error = anyhow::Error;
256
257 fn try_from(s: String) -> Result<Self, Self::Error> {
258 s.as_str().try_into()
259 }
260}
261
262impl TryFrom<&str> for Str {
263 type Error = anyhow::Error;
264
265 fn try_from(s: &str) -> Result<Self, Self::Error> {
266 unsafe {
267 let len = s.len();
268 if len > u8::MAX as usize {
269 bail!("string is too long")
270 } else if len < 8 {
271 #[cfg(target_endian = "little")]
272 {
273 let s = s.as_bytes();
274 let mut i = 0;
275 let mut res: usize = TAG_MASK;
276 res |= len << 56;
277 while i < len {
278 res |= (s[i] as usize) << (i << 3);
279 i += 1;
280 }
281 Ok(Str(res))
282 }
283 #[cfg(target_endian = "big")]
284 {
285 let s = s.as_bytes();
286 let mut i = 0;
287 let mut res: usize = TAG_MASK;
288 res |= len << 56;
289 while i < len {
290 res |= (s[i] as usize) << (48 - (i << 3));
291 i += 1;
292 }
293 Ok(Str(res))
294 }
295 } else {
296 let mut root = ROOT.lock();
297 match root.all.get(s) {
298 Some(t) => Ok(*t),
299 None => {
300 let (r, t) = (*root.root).insert(s.as_bytes());
301 root.root = r;
302 root.all.insert(t);
303 Ok(t)
304 }
305 }
306 }
307 }
308 }
309}
310
311impl TryFrom<Cow<'_, str>> for Str {
312 type Error = anyhow::Error;
313
314 fn try_from(s: Cow<str>) -> Result<Self, Self::Error> {
315 match s {
316 Cow::Borrowed(s) => Str::try_from(s),
317 Cow::Owned(s) => Str::try_from(s.as_str()),
318 }
319 }
320}
321
#[cfg(feature = "juniper")]
impl Str {
    /// Converts the value into a GraphQL scalar holding an owned copy of the string.
    #[allow(clippy::wrong_self_convention)]
    fn to_output<S: juniper::ScalarValue>(&self) -> juniper::Value<S> {
        let owned: String = self.as_str().to_owned();
        juniper::Value::scalar(owned)
    }

    /// Parses a GraphQL input value.
    ///
    /// Non-string inputs and strings that `Str::try_from` rejects are both reported as error
    /// messages.
    fn from_input<S>(v: &juniper::InputValue<S>) -> Result<Self, String>
    where
        S: juniper::ScalarValue,
    {
        match v.as_string_value() {
            Some(text) => Self::try_from(text).map_err(|e| e.to_string()),
            None => Err(format!("Expected `String`, found: {v}")),
        }
    }

    /// Parses a scalar token exactly as GraphQL `String` does.
    fn parse_token<S>(value: juniper::ScalarToken<'_>) -> juniper::ParseScalarResult<S>
    where
        S: juniper::ScalarValue,
    {
        <String as juniper::ParseScalarValue<S>>::from_str(value)
    }
}
346
#[cfg(feature = "postgres-types")]
impl postgres_types::ToSql for Str {
    postgres_types::to_sql_checked!();

    /// Encodes the value exactly as a Postgres text value, by delegating to `&str`.
    fn to_sql(
        &self,
        ty: &postgres_types::Type,
        out: &mut bytes::BytesMut,
    ) -> Result<postgres_types::IsNull, Box<dyn std::error::Error + Sync + Send>> {
        // Fully qualified so the call does not depend on `ToSql` being imported. Implementing
        // a trait does not bring it into scope for method-call syntax.
        <&str as postgres_types::ToSql>::to_sql(&self.as_str(), ty, out)
    }

    /// Accepts the same column types as `String`.
    fn accepts(ty: &postgres_types::Type) -> bool {
        <String as postgres_types::ToSql>::accepts(ty)
    }
}
363
#[cfg(test)]
mod test {
    use super::*;
    use rand::{rng, Rng};

    /// Random printable ASCII string (space through `~`, inclusive) of exactly `size` bytes.
    fn rand_ascii(size: usize) -> String {
        let mut rng = rng();
        (0..size).map(|_| rng.random_range(' '..='~')).collect()
    }

    /// Random string of `size` Unicode scalar values; its byte length is usually larger.
    fn rand_unicode(size: usize) -> String {
        let mut rng = rng();
        (0..size).map(|_| rng.random::<char>()).collect()
    }

    /// Converts `s` twice. Checks the contents round-trip, and that both conversions yield
    /// the identical word (inline encoding and interning are deterministic).
    fn check_round_trip(s: &str) {
        let t0 = Str::try_from(s).unwrap();
        assert_eq!(&*t0, s);
        let t1 = Str::try_from(s).unwrap();
        assert_eq!(t0.0, t1.0);
    }

    #[test]
    fn immediates() {
        for _ in 0..10000 {
            let len = rng().random_range(0..8);
            let s = rand_ascii(len);
            check_round_trip(&s);
            assert!(Str::try_from(s.as_str()).unwrap().is_immediate());
        }
    }

    #[test]
    fn mixed() {
        for _ in 0..10000 {
            let len = rng().random_range(0..256);
            check_round_trip(&rand_ascii(len));
        }
    }

    #[test]
    fn unicode() {
        for _ in 0..10000 {
            let s = loop {
                let len = rng().random_range(0..128);
                let s = rand_unicode(len);
                if s.len() < 256 {
                    break s;
                }
            };
            check_round_trip(&s);
        }
    }

    #[test]
    fn length_boundaries() {
        assert!(Str::try_from("1234567").unwrap().is_immediate());
        assert!(!Str::try_from("12345678").unwrap().is_immediate());
        let longest = "m".repeat(255);
        check_round_trip(&longest);
        assert!(Str::try_from("m".repeat(256)).is_err());
    }

    #[test]
    fn interned_equality() {
        let a = Str::try_from("interned string").unwrap();
        let b = Str::try_from(String::from("interned string")).unwrap();
        assert_eq!(a, b);
        assert_eq!(a, "interned string");
        assert_eq!(a.as_static_str(), Some("interned string"));
        assert_eq!(Str::try_from("short").unwrap().as_static_str(), None);
    }
}