Skip to main content

lexe_common/
byte_str.rs

1use std::{borrow::Borrow, fmt, ops};
2
3use bytes::Bytes;
4use serde::{de, ser};
5use thiserror::Error;
6
7/// `ByteStr` is just a tokio [`Bytes`], but it maintains the internal
8/// invariant that the inner [`Bytes`] must be a valid utf8 string.
9#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
10pub struct ByteStr(Bytes);
11
12#[derive(Debug, Error)]
13#[error("not a valid utf8 string")]
14pub struct Utf8Error;
15
16impl ByteStr {
17    /// Creates a new empty `ByteStr`. This does not allocate.
18    #[inline]
19    pub const fn new() -> Self {
20        Self(Bytes::new())
21    }
22
23    #[inline]
24    pub const fn from_static(s: &'static str) -> Self {
25        // INVARIANT: `s` is a string, so must be valid utf8
26        Self(Bytes::from_static(s.as_bytes()))
27    }
28
29    #[inline]
30    fn from_utf8_unchecked(b: Bytes) -> Self {
31        if cfg!(debug_assertions) {
32            match std::str::from_utf8(b.as_ref()) {
33                Ok(_) => (),
34                Err(err) => {
35                    panic!("input is not valid utf8: err: {err}, bytes: {b:?}")
36                }
37            }
38        }
39
40        Self(b)
41    }
42
43    #[inline]
44    pub fn as_str(&self) -> &str {
45        let b = self.0.as_ref();
46        // SAFETY: the internal invariant guarantees that `b` is valid utf8
47        unsafe { std::str::from_utf8_unchecked(b) }
48    }
49
50    pub fn try_from_bytes(b: Bytes) -> Result<Self, Utf8Error> {
51        if std::str::from_utf8(b.as_ref()).is_ok() {
52            // INVARIANT: we've just verified that `b` is valid utf8
53            Ok(Self::from_utf8_unchecked(b))
54        } else {
55            Err(Utf8Error)
56        }
57    }
58}
59
60impl Default for ByteStr {
61    #[inline]
62    fn default() -> Self {
63        Self::new()
64    }
65}
66
67impl fmt::Display for ByteStr {
68    #[inline]
69    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
70        fmt::Display::fmt(self.as_str(), f)
71    }
72}
73
74impl fmt::Debug for ByteStr {
75    #[inline]
76    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
77        fmt::Debug::fmt(self.as_str(), f)
78    }
79}
80
81impl ops::Deref for ByteStr {
82    type Target = str;
83
84    #[inline]
85    fn deref(&self) -> &Self::Target {
86        self.as_str()
87    }
88}
89
90impl AsRef<str> for ByteStr {
91    #[inline]
92    fn as_ref(&self) -> &str {
93        self.as_str()
94    }
95}
96
97impl AsRef<[u8]> for ByteStr {
98    #[inline]
99    fn as_ref(&self) -> &[u8] {
100        self.0.as_ref()
101    }
102}
103
104impl Borrow<str> for ByteStr {
105    #[inline]
106    fn borrow(&self) -> &str {
107        self.as_str()
108    }
109}
110
111impl From<String> for ByteStr {
112    #[inline]
113    fn from(s: String) -> Self {
114        // INVARIANT: `s` is a String, so must be valid utf8
115        Self::from_utf8_unchecked(Bytes::from(s))
116    }
117}
118
119impl<'a> From<&'a str> for ByteStr {
120    #[inline]
121    fn from(s: &'a str) -> Self {
122        // INVARIANT: `s` is a &str, so must be valid utf8
123        Self::from_utf8_unchecked(Bytes::copy_from_slice(s.as_bytes()))
124    }
125}
126
127impl From<ByteStr> for Bytes {
128    #[inline]
129    fn from(bs: ByteStr) -> Self {
130        bs.0
131    }
132}
133
134impl ser::Serialize for ByteStr {
135    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
136    where
137        S: serde::Serializer,
138    {
139        serializer.serialize_str(self.as_str())
140    }
141}
142
143impl<'de> de::Deserialize<'de> for ByteStr {
144    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
145    where
146        D: serde::Deserializer<'de>,
147    {
148        struct ByteStrVisitor;
149
150        impl de::Visitor<'_> for ByteStrVisitor {
151            type Value = ByteStr;
152
153            fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
154                f.write_str("string")
155            }
156
157            #[inline]
158            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
159            where
160                E: de::Error,
161            {
162                Ok(ByteStr::from(v))
163            }
164
165            #[inline]
166            fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
167            where
168                E: de::Error,
169            {
170                Ok(ByteStr::from(v))
171            }
172        }
173
174        deserializer.deserialize_string(ByteStrVisitor)
175    }
176}
177
#[cfg(any(test, feature = "test-utils"))]
mod arbitrary_impl {
    use proptest::{
        arbitrary::Arbitrary,
        strategy::{BoxedStrategy, Strategy},
    };

    use super::*;
    use crate::test_utils::arbitrary;

    /// Generates a `ByteStr` by wrapping a string produced by
    /// `arbitrary::any_string()`.
    impl Arbitrary for ByteStr {
        type Parameters = ();
        type Strategy = BoxedStrategy<Self>;

        fn arbitrary_with((): Self::Parameters) -> Self::Strategy {
            let strings = arbitrary::any_string();
            strings.prop_map(Self::from).boxed()
        }
    }
}
197
#[cfg(test)]
mod test {
    use proptest::{
        arbitrary::any, prop_assert, prop_assert_eq, prop_oneof, proptest,
        strategy::Strategy,
    };

    use super::*;
    use crate::test_utils::arbitrary;

    /// Strategy for [`Bytes`] that picks, with equal weight, either a fully
    /// random byte vector or the bytes of an arbitrary `String`. Roughly
    /// half of all generated cases are therefore guaranteed valid utf8.
    fn arb_bytes() -> impl Strategy<Value = Bytes> {
        let random = any::<Vec<u8>>().prop_map(Bytes::from);
        let valid_utf8 = arbitrary::any_string().prop_map(Bytes::from);
        prop_oneof![random, valid_utf8]
    }

    /// `ByteStr::try_from_bytes` must accept exactly the inputs that
    /// `std::str::from_utf8` accepts, and yield the same string for them.
    #[test]
    fn str_from_utf8_equiv() {
        proptest!(|(bytes in arb_bytes())| {
            let res1 = ByteStr::try_from_bytes(bytes.clone());
            let res2 = std::str::from_utf8(&bytes);

            // Both must accept, or both must reject.
            let agree = res1.is_ok() == res2.is_ok();
            prop_assert!(agree, "res1 ({res1:?}) != res2 ({res2:?})");

            // When both accept, the decoded strings must match.
            if let (Ok(s1), Ok(s2)) = (&res1, &res2) {
                prop_assert_eq!(&s1.as_str(), s2);
            }
        })
    }
}