1use anyhow::bail;
13use fxhash::FxHashSet;
14#[cfg(feature = "netidx")]
15use netidx::{
16 chars::Chars,
17 pack::{decode_varint, encode_varint, varint_len, Pack, PackError},
18};
19use once_cell::sync::Lazy;
20use parking_lot::Mutex;
21use schemars::JsonSchema;
22use serde::{Deserialize, Serialize};
23use std::{
24 borrow::{Borrow, Cow},
25 collections::HashSet,
26 fmt,
27 hash::Hash,
28 mem,
29 ops::Deref,
30 slice, str,
31};
32
/// Top bit of a `Str` word: set when the string is stored inline ("immediate") rather
/// than as the address of an entry in an interning chunk.
const TAG_MASK: usize = 1 << 63;
/// Bits of an immediate `Str` word that hold the inline string's byte length.
const LEN_MASK: usize = 0x7F << 56;
/// Size in bytes of each leaked interning chunk (1 MiB).
const CHUNK_SIZE: usize = 1 << 20;
36
37struct Chunk {
38 data: Vec<u8>,
39 pos: usize,
40}
41
42impl Chunk {
43 #[cfg(target_pointer_width = "64")]
44 fn new() -> &'static mut Self {
45 let res = Box::leak(Box::new(Chunk { data: vec![0; CHUNK_SIZE], pos: 0 }));
46 assert!((res as *mut Self as usize) & TAG_MASK == 0);
47 res
48 }
49
50 fn insert(&mut self, str: &[u8]) -> (*mut Chunk, Str) {
51 let mut t = self;
52 loop {
53 if CHUNK_SIZE - t.pos > str.len() {
54 t.data[t.pos] = str.len() as u8;
55 t.data[t.pos + 1..t.pos + 1 + str.len()].copy_from_slice(str);
56 let res = Str(t.data.as_ptr().wrapping_add(t.pos) as usize);
57 t.pos += 1 + str.len();
58 break (t, res);
59 } else {
60 t = Self::new();
61 }
62 }
63 }
64}
65
66struct Root {
67 all: FxHashSet<Str>,
68 root: *mut Chunk,
69}
70
71unsafe impl Send for Root {}
72unsafe impl Sync for Root {}
73
74static ROOT: Lazy<Mutex<Root>> =
75 Lazy::new(|| Mutex::new(Root { all: HashSet::default(), root: Chunk::new() }));
76
77#[allow(dead_code)]
78struct StrVisitor;
79
80impl<'de> serde::de::Visitor<'de> for StrVisitor {
81 type Value = Str;
82
83 fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
84 write!(f, "expecting a string")
85 }
86
87 fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
88 where
89 E: serde::de::Error,
90 {
91 Str::try_from(s).map_err(|e| E::custom(e.to_string()))
92 }
93}
94
/// JSON Schema stand-in for `Str`'s field: `Str` stores a `usize`, but it must be
/// described as a string, so its field is annotated `#[schemars(with = "AsStr")]`.
/// This type is only used for its schema and is never constructed, hence `dead_code`.
#[allow(dead_code)]
#[derive(JsonSchema)]
struct AsStr(&'static str);
98
/// A `Copy` string of at most 255 bytes that occupies a single machine word.
///
/// When the top bit (`TAG_MASK`) is set, the string is stored inline ("immediate"). Its
/// byte length sits in the `LEN_MASK` bits and its bytes are packed into the remaining
/// seven bytes of the word; this is used for strings shorter than 8 bytes.
///
/// Otherwise the word is the address of a length-prefixed entry in a leaked global
/// interning chunk. Longer strings are looked up in the interner before being stored, so
/// equal strings always produce equal words.
#[derive(Clone, Copy, Deserialize, JsonSchema)]
#[serde(try_from = "Cow<str>")]
#[serde(into = "&str")]
#[repr(transparent)]
#[cfg_attr(feature = "juniper", derive(juniper::GraphQLScalar))]
#[cfg_attr(feature = "juniper", graphql(description = "A String type"))]
pub struct Str(#[schemars(with = "AsStr")] usize);
125
126unsafe impl Send for Str {}
127unsafe impl Sync for Str {}
128
129impl Str {
130 pub fn as_str<'a>(&'a self) -> &'a str {
131 unsafe {
132 if self.0 & TAG_MASK > 0 {
133 #[cfg(target_endian = "little")]
134 {
135 let len = (self.0 & LEN_MASK) >> 56;
136 let ptr = self as *const Self as *const u8;
137 let slice = slice::from_raw_parts(ptr, len);
138 str::from_utf8_unchecked(slice)
139 }
140 #[cfg(target_endian = "big")]
141 {
142 let len = (self.0 & LEN_MASK) >> 56;
143 let ptr = (self as *const Self as *const u8).wrapping_add(1);
144 let slice = slice::from_raw_parts(ptr, len);
145 str::from_utf8_unchecked(slice)
146 }
147 } else {
148 let t = self.0 as *const u8;
149 let len = *t as usize;
150 let ptr = t.wrapping_add(1);
151 let slice = slice::from_raw_parts(ptr, len);
152 str::from_utf8_unchecked(slice)
153 }
154 }
155 }
156
157 pub fn as_static_str(&self) -> Option<&'static str> {
159 unsafe {
160 if self.0 & TAG_MASK > 0 {
161 None
162 } else {
163 Some(mem::transmute::<&str, &'static str>(self.as_str()))
164 }
165 }
166 }
167
168 pub fn is_immediate(&self) -> bool {
170 self.0 & TAG_MASK > 0
171 }
172
173 #[cfg(feature = "netidx")]
174 pub fn as_chars(&self) -> Chars {
175 match self.as_static_str() {
176 Some(s) => Chars::from(s),
177 None => Chars::from(String::from(self.as_str())),
178 }
179 }
180}
181
182impl fmt::Debug for Str {
183 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
184 write!(f, "{}", &**self)
185 }
186}
187
188impl fmt::Display for Str {
189 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
190 write!(f, "{}", &**self)
191 }
192}
193
194impl Serialize for Str {
195 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
196 where
197 S: serde::Serializer,
198 {
199 serializer.serialize_str(self.as_str())
200 }
201}
202
/// Wire format: a varint byte length followed by the UTF-8 bytes.
#[cfg(feature = "netidx")]
impl Pack for Str {
    fn encoded_len(&self) -> usize {
        let len = self.len();
        varint_len(len as u64) + len
    }

    fn encode(&self, buf: &mut impl bytes::BufMut) -> Result<(), PackError> {
        let s = self.as_str();
        encode_varint(s.len() as u64, buf);
        buf.put_slice(s.as_bytes());
        Ok(())
    }

    /// Decode a length-prefixed string and intern it.
    ///
    /// # Errors
    ///
    /// - `TooBig` if the declared length exceeds 255 bytes or the remaining input.
    /// - `InvalidFormat` if the bytes are not UTF-8.
    fn decode(buf: &mut impl bytes::Buf) -> Result<Self, PackError> {
        use std::cell::RefCell;
        thread_local! {
            // Scratch space reused across calls so decoding does not allocate per string.
            static BUF: RefCell<Vec<u8>> = RefCell::new(Vec::new());
        }
        let len = decode_varint(buf)?;
        if len > u8::MAX as u64 {
            return Err(PackError::TooBig);
        }
        let len = len as usize;
        // `copy_to_slice` panics on a short buffer; truncated or malicious input must
        // produce an error instead of crashing the decoder.
        if buf.remaining() < len {
            return Err(PackError::TooBig);
        }
        BUF.with(|tmp| {
            let mut tmp = tmp.borrow_mut();
            tmp.resize(len, 0);
            buf.copy_to_slice(&mut tmp[..]);
            let s = str::from_utf8(&tmp[..]).map_err(|_| PackError::InvalidFormat)?;
            // Cannot fail: the length was already checked against the 255-byte limit.
            Str::try_from(s).map_err(|_| PackError::TooBig)
        })
    }
}
240
241impl Deref for Str {
242 type Target = str;
243
244 fn deref(&self) -> &Self::Target {
245 self.as_str()
246 }
247}
248
249impl Borrow<str> for Str {
250 fn borrow(&self) -> &str {
251 self.as_str()
252 }
253}
254
255impl Borrow<str> for &Str {
256 fn borrow(&self) -> &str {
257 self.as_str()
258 }
259}
260
261impl AsRef<str> for Str {
262 fn as_ref(&self) -> &str {
263 self.as_str()
264 }
265}
266
267impl Hash for Str {
268 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
269 self.as_str().hash(state)
270 }
271}
272
273impl PartialEq for Str {
274 fn eq(&self, other: &Self) -> bool {
275 self.0 == other.0
276 }
277}
278
279impl PartialEq<&str> for Str {
280 fn eq(&self, other: &&str) -> bool {
281 self.as_str() == *other
282 }
283}
284
285impl Eq for Str {}
286
287impl PartialOrd for Str {
288 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
289 self.as_str().partial_cmp(other.as_str())
290 }
291}
292
293impl PartialOrd<&str> for Str {
294 fn partial_cmp(&self, other: &&str) -> Option<std::cmp::Ordering> {
295 self.as_str().partial_cmp(*other)
296 }
297}
298
299impl Ord for Str {
300 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
301 self.as_str().cmp(other.as_str())
302 }
303}
304
305impl TryFrom<String> for Str {
306 type Error = anyhow::Error;
307
308 fn try_from(s: String) -> Result<Self, Self::Error> {
309 s.as_str().try_into()
310 }
311}
312
313impl TryFrom<&str> for Str {
314 type Error = anyhow::Error;
315
316 fn try_from(s: &str) -> Result<Self, Self::Error> {
317 unsafe {
318 let len = s.len();
319 if len > u8::MAX as usize {
320 bail!("string is too long")
321 } else if len < 8 {
322 #[cfg(target_endian = "little")]
323 {
324 let s = s.as_bytes();
325 let mut i = 0;
326 let mut res: usize = TAG_MASK;
327 res |= len << 56;
328 while i < len {
329 res |= (s[i] as usize) << (i << 3);
330 i += 1;
331 }
332 Ok(Str(res))
333 }
334 #[cfg(target_endian = "big")]
335 {
336 let s = s.as_bytes();
337 let mut i = 0;
338 let mut res: usize = TAG_MASK;
339 res |= len << 56;
340 while i < len {
341 res |= (s[i] as usize) << (48 - (i << 3));
342 i += 1;
343 }
344 Ok(Str(res))
345 }
346 } else {
347 let mut root = ROOT.lock();
348 match root.all.get(s) {
349 Some(t) => Ok(*t),
350 None => {
351 let (r, t) = (*root.root).insert(s.as_bytes());
352 root.root = r;
353 root.all.insert(t);
354 Ok(t)
355 }
356 }
357 }
358 }
359 }
360}
361
362impl TryFrom<Cow<'_, str>> for Str {
363 type Error = anyhow::Error;
364
365 fn try_from(s: Cow<str>) -> Result<Self, Self::Error> {
366 match s {
367 Cow::Borrowed(s) => Str::try_from(s),
368 Cow::Owned(s) => Str::try_from(s.as_str()),
369 }
370 }
371}
372
/// Conversions used by the `juniper::GraphQLScalar` derive on `Str`.
#[cfg(feature = "juniper")]
impl Str {
    /// Emit the string as a GraphQL string scalar.
    fn to_output<S: juniper::ScalarValue>(&self) -> juniper::Value<S> {
        let owned: String = self.as_str().to_owned();
        juniper::Value::scalar(owned)
    }

    /// Accept a GraphQL string input, rejecting non-strings and over-long strings.
    fn from_input<S>(v: &juniper::InputValue<S>) -> Result<Self, String>
    where
        S: juniper::ScalarValue,
    {
        let s = v
            .as_string_value()
            .ok_or_else(|| format!("Expected `String`, found: {v}"))?;
        Self::try_from(s).map_err(|e| e.to_string())
    }

    /// Parse a scalar token exactly as a `String` would be parsed.
    fn parse_token<S>(value: juniper::ScalarToken<'_>) -> juniper::ParseScalarResult<S>
    where
        S: juniper::ScalarValue,
    {
        <String as juniper::ParseScalarValue<S>>::from_str(value)
    }
}
396
/// Binds a `Str` as a Postgres parameter by delegating to the string slice's `ToSql`.
#[cfg(feature = "postgres-types")]
impl postgres_types::ToSql for Str {
    postgres_types::to_sql_checked!();

    /// Encode the string contents into `out` for the parameter type `ty`.
    fn to_sql(
        &self,
        ty: &postgres_types::Type,
        out: &mut bytes::BytesMut,
    ) -> Result<postgres_types::IsNull, Box<dyn std::error::Error + Sync + Send>> {
        self.as_str().to_sql(ty, out)
    }

    /// Accept exactly the Postgres types that `String` accepts.
    fn accepts(ty: &postgres_types::Type) -> bool {
        String::accepts(ty)
    }
}
413
#[cfg(test)]
mod test {
    use super::*;
    use rand::{thread_rng, Rng};

    /// Random printable ASCII string of `size` characters, space through tilde inclusive.
    fn rand_ascii(size: usize) -> String {
        let mut rng = thread_rng();
        (0..size).map(|_| rng.gen_range(' '..='~')).collect()
    }

    /// Random string of `size` arbitrary Unicode scalar values.
    fn rand_unicode(size: usize) -> String {
        let mut rng = thread_rng();
        (0..size).map(|_| rng.gen::<char>()).collect()
    }

    /// Assert that `s` round-trips, converts to the same word twice, and uses the
    /// representation its length dictates.
    fn check_roundtrip(s: &str) {
        let t0 = Str::try_from(s).unwrap();
        assert_eq!(&*t0, s);
        let t1 = Str::try_from(s).unwrap();
        assert_eq!(t0.0, t1.0);
        assert_eq!(t0.is_immediate(), s.len() < 8);
    }

    #[test]
    fn immediates() {
        for _ in 0..10000 {
            let len = thread_rng().gen_range(0..8);
            check_roundtrip(&rand_ascii(len));
        }
    }

    #[test]
    fn mixed() {
        for _ in 0..10000 {
            let len = thread_rng().gen_range(0..256);
            check_roundtrip(&rand_ascii(len));
        }
    }

    #[test]
    fn unicode() {
        for _ in 0..10000 {
            let s = loop {
                let len = thread_rng().gen_range(0..128);
                let s = rand_unicode(len);
                if s.len() < 256 {
                    break s;
                }
            };
            check_roundtrip(&s);
        }
    }

    #[test]
    fn length_limits() {
        // Empty, largest immediate, smallest interned, and largest representable string.
        for len in [0, 7, 8, 255] {
            check_roundtrip(&"x".repeat(len));
        }
        assert!(Str::try_from("x".repeat(256).as_str()).is_err());
    }

    #[test]
    fn ordering_and_equality_match_str() {
        let mut rng = thread_rng();
        for _ in 0..1000 {
            let a = rand_ascii(rng.gen_range(0..16));
            let b = rand_ascii(rng.gen_range(0..16));
            let sa = Str::try_from(a.as_str()).unwrap();
            let sb = Str::try_from(b.as_str()).unwrap();
            assert_eq!(sa.cmp(&sb), a.cmp(&b));
            assert_eq!(sa == sb, a == b);
            assert_eq!(sa, a.as_str());
        }
    }
}