Skip to main content

jacquard_common/
cowstr.rs

1use alloc::borrow::Cow;
2use alloc::boxed::Box;
3use alloc::string::String;
4use core::fmt;
5use core::hash::{Hash, Hasher};
6use core::ops::Deref;
7
8use serde::{Deserialize, Serialize};
9use smol_str::SmolStr;
10
11use crate::IntoStatic;
12
13/// A copy-on-write immutable string type that uses [`SmolStr`] for
14/// the "owned" variant.
15///
16/// The standard [`Cow`] type cannot be used, since
17/// `<str as ToOwned>::Owned` is `String`, and not `SmolStr`.
18///
19/// Shamelessly ported from [merde](https://github.com/bearcove/merde)
20#[derive(Clone)]
21pub enum CowStr<'s> {
22    /// &str varaiant
23    Borrowed(&'s str),
24    /// Smolstr variant
25    Owned(SmolStr),
26}
27
28impl CowStr<'static> {
29    /// Create a new `CowStr` by copying from a `&str` — this might allocate
30    /// if the string is longer than `MAX_INLINE_SIZE`.
31    pub fn copy_from_str(s: &str) -> Self {
32        Self::Owned(SmolStr::from(s))
33    }
34
35    /// Create a new owned `CowStr` from a static &str without allocating
36    pub fn new_static(s: &'static str) -> Self {
37        Self::Owned(SmolStr::new_static(s))
38    }
39}
40
41impl<'s> CowStr<'s> {
42    #[inline]
43    /// Borrow and decode a byte slice as utf8 into a CowStr
44    pub fn from_utf8(s: &'s [u8]) -> Result<Self, core::str::Utf8Error> {
45        Ok(Self::Borrowed(core::str::from_utf8(s)?))
46    }
47
48    #[inline]
49    /// Take bytes and decode them as utf8 into an owned CowStr. Might allocate.
50    pub fn from_utf8_owned(s: impl AsRef<[u8]>) -> Result<Self, core::str::Utf8Error> {
51        Ok(Self::Owned(SmolStr::new(core::str::from_utf8(s.as_ref())?)))
52    }
53
54    #[inline]
55    /// Take bytes and decode them as utf8, skipping invalid characters, taking ownership.
56    /// Will allocate, uses String::from_utf8_lossy() internally for now.
57    pub fn from_utf8_lossy(s: &'s [u8]) -> Self {
58        Self::Owned(String::from_utf8_lossy(&s).into())
59    }
60
61    /// # Safety
62    ///
63    /// This function is unsafe because it does not check that the bytes are valid UTF-8.
64    #[inline]
65    pub unsafe fn from_utf8_unchecked(s: &'s [u8]) -> Self {
66        unsafe { Self::Owned(SmolStr::new(core::str::from_utf8_unchecked(s))) }
67    }
68
69    /// Returns a reference to the underlying string slice.
70    #[inline]
71    pub fn as_str(&self) -> &str {
72        match self {
73            CowStr::Borrowed(s) => s,
74            CowStr::Owned(s) => s.as_str(),
75        }
76    }
77}
78
79impl AsRef<str> for CowStr<'_> {
80    #[inline]
81    fn as_ref(&self) -> &str {
82        match self {
83            CowStr::Borrowed(s) => s,
84            CowStr::Owned(s) => s.as_str(),
85        }
86    }
87}
88
89impl Deref for CowStr<'_> {
90    type Target = str;
91
92    #[inline]
93    fn deref(&self) -> &Self::Target {
94        match self {
95            CowStr::Borrowed(s) => s,
96            CowStr::Owned(s) => s.as_str(),
97        }
98    }
99}
100
101impl<'a> From<Cow<'a, str>> for CowStr<'a> {
102    #[inline]
103    fn from(s: Cow<'a, str>) -> Self {
104        match s {
105            Cow::Borrowed(s) => CowStr::Borrowed(s),
106            #[allow(clippy::useless_conversion)]
107            Cow::Owned(s) => CowStr::Owned(s.into()),
108        }
109    }
110}
111
112impl<'s> From<&'s str> for CowStr<'s> {
113    #[inline]
114    fn from(s: &'s str) -> Self {
115        CowStr::Borrowed(s)
116    }
117}
118
119impl Default for CowStr<'_> {
120    #[inline]
121    fn default() -> Self {
122        CowStr::new_static("")
123    }
124}
125
126impl From<String> for CowStr<'_> {
127    #[inline]
128    fn from(s: String) -> Self {
129        #[allow(clippy::useless_conversion)]
130        CowStr::Owned(s.into())
131    }
132}
133
134impl From<Box<str>> for CowStr<'_> {
135    #[inline]
136    fn from(s: Box<str>) -> Self {
137        CowStr::Owned(s.into())
138    }
139}
140
141impl<'s> From<&'s String> for CowStr<'s> {
142    #[inline]
143    fn from(s: &'s String) -> Self {
144        CowStr::Borrowed(s.as_str())
145    }
146}
147
148impl From<CowStr<'_>> for String {
149    #[inline]
150    fn from(s: CowStr<'_>) -> Self {
151        match s {
152            CowStr::Borrowed(s) => s.into(),
153            #[allow(clippy::useless_conversion)]
154            CowStr::Owned(s) => s.into(),
155        }
156    }
157}
158
159impl From<CowStr<'_>> for SmolStr {
160    #[inline]
161    fn from(s: CowStr<'_>) -> Self {
162        match s {
163            CowStr::Borrowed(s) => SmolStr::new(s),
164            CowStr::Owned(s) => SmolStr::new(s),
165        }
166    }
167}
168
169impl From<SmolStr> for CowStr<'_> {
170    #[inline]
171    fn from(s: SmolStr) -> Self {
172        CowStr::Owned(s)
173    }
174}
175
176impl From<CowStr<'_>> for Box<str> {
177    #[inline]
178    fn from(s: CowStr<'_>) -> Self {
179        match s {
180            CowStr::Borrowed(s) => s.into(),
181            CowStr::Owned(s) => String::from(s).into_boxed_str(),
182        }
183    }
184}
185
186impl<'a> PartialEq<CowStr<'a>> for CowStr<'_> {
187    #[inline]
188    fn eq(&self, other: &CowStr<'a>) -> bool {
189        self.deref() == other.deref()
190    }
191}
192
193impl PartialEq<&str> for CowStr<'_> {
194    #[inline]
195    fn eq(&self, other: &&str) -> bool {
196        self.deref() == *other
197    }
198}
199
200impl PartialEq<CowStr<'_>> for &str {
201    #[inline]
202    fn eq(&self, other: &CowStr<'_>) -> bool {
203        *self == other.deref()
204    }
205}
206
207impl PartialEq<String> for CowStr<'_> {
208    #[inline]
209    fn eq(&self, other: &String) -> bool {
210        self.deref() == other.as_str()
211    }
212}
213
214impl PartialEq<CowStr<'_>> for String {
215    #[inline]
216    fn eq(&self, other: &CowStr<'_>) -> bool {
217        self.as_str() == other.deref()
218    }
219}
220
221impl PartialOrd for CowStr<'_> {
222    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
223        Some(match (self, other) {
224            (CowStr::Borrowed(s1), CowStr::Borrowed(s2)) => s1.cmp(s2),
225            (CowStr::Borrowed(s1), CowStr::Owned(s2)) => s1.cmp(&s2.as_ref()),
226            (CowStr::Owned(s1), CowStr::Borrowed(s2)) => s1.as_str().cmp(s2),
227            (CowStr::Owned(s1), CowStr::Owned(s2)) => s1.cmp(s2),
228        })
229    }
230}
231
232impl Ord for CowStr<'_> {
233    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
234        match (self, other) {
235            (CowStr::Borrowed(s1), CowStr::Borrowed(s2)) => s1.cmp(s2),
236            (CowStr::Borrowed(s1), CowStr::Owned(s2)) => s1.cmp(&s2.as_ref()),
237            (CowStr::Owned(s1), CowStr::Borrowed(s2)) => s1.as_str().cmp(s2),
238            (CowStr::Owned(s1), CowStr::Owned(s2)) => s1.cmp(s2),
239        }
240    }
241}
242
243impl Eq for CowStr<'_> {}
244
245impl Hash for CowStr<'_> {
246    #[inline]
247    fn hash<H: Hasher>(&self, state: &mut H) {
248        self.deref().hash(state)
249    }
250}
251
252impl fmt::Debug for CowStr<'_> {
253    #[inline]
254    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
255        self.deref().fmt(f)
256    }
257}
258
259impl fmt::Display for CowStr<'_> {
260    #[inline]
261    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
262        self.deref().fmt(f)
263    }
264}
265
266// TODO(bos-migration): Change Output to SmolStr once types are parameterised by S: Bos<str>.
267impl IntoStatic for CowStr<'_> {
268    type Output = CowStr<'static>;
269
270    #[inline]
271    fn into_static(self) -> Self::Output {
272        match self {
273            CowStr::Borrowed(s) => CowStr::Owned((*s).into()),
274            CowStr::Owned(s) => CowStr::Owned(s),
275        }
276    }
277}
278
279impl Serialize for CowStr<'_> {
280    #[inline]
281    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
282    where
283        S: serde::Serializer,
284    {
285        serializer.serialize_str(self)
286    }
287}
288
/// Deserialization helper for things that wrap a [`CowStr`].
///
/// This is a serde visitor that turns string input into a `CowStr`. It is public
/// so that wrapper types can reuse it in their own `Deserialize` impls. It is a
/// stateless unit struct, so it is `Copy` and cheap to pass around.
#[derive(Debug, Clone, Copy)]
pub struct CowStrVisitor;
291
292impl<'de> serde::de::Visitor<'de> for CowStrVisitor {
293    type Value = CowStr<'de>;
294
295    #[inline]
296    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
297        write!(formatter, "a string")
298    }
299
300    #[inline]
301    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
302    where
303        E: serde::de::Error,
304    {
305        Ok(CowStr::copy_from_str(v))
306    }
307
308    #[inline]
309    fn visit_borrowed_str<E>(self, v: &'de str) -> Result<Self::Value, E>
310    where
311        E: serde::de::Error,
312    {
313        Ok(CowStr::Borrowed(v))
314    }
315
316    #[inline]
317    fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
318    where
319        E: serde::de::Error,
320    {
321        Ok(v.into())
322    }
323}
324
325impl<'de, 'a> Deserialize<'de> for CowStr<'a>
326where
327    'de: 'a,
328{
329    #[inline]
330    fn deserialize<D>(deserializer: D) -> Result<CowStr<'a>, D::Error>
331    where
332        D: serde::Deserializer<'de>,
333    {
334        deserializer.deserialize_str(CowStrVisitor)
335    }
336}
337
338impl core::str::FromStr for CowStr<'_> {
339    type Err = core::convert::Infallible;
340
341    fn from_str(s: &str) -> Result<Self, Self::Err> {
342        Ok(CowStr::copy_from_str(s))
343    }
344}
345
346/// Convert to a CowStr.
347pub trait ToCowStr {
348    /// Convert to a CowStr.
349    fn to_cowstr(&self) -> CowStr<'_>;
350}
351
352impl<T> ToCowStr for T
353where
354    T: fmt::Display + ?Sized,
355{
356    fn to_cowstr(&self) -> CowStr<'_> {
357        CowStr::Owned(smol_str::format_smolstr!("{}", self))
358    }
359}
360
#[cfg(test)]
mod tests {
    use super::*;
    use core::cmp::Ordering;

    #[test]
    fn test_partialeq_with_str() {
        let cow_str1 = CowStr::Borrowed("hello");
        let cow_str2 = CowStr::Borrowed("hello");
        let cow_str3 = CowStr::Borrowed("world");

        assert_eq!(cow_str1, "hello");
        assert_eq!("hello", cow_str1);
        assert_eq!(cow_str1, cow_str2);
        assert_ne!(cow_str1, "world");
        assert_ne!("world", cow_str1);
        assert_ne!(cow_str1, cow_str3);
    }

    /// Borrowed and owned values with the same text must be indistinguishable
    /// to `Eq`/`Ord`, since those impls compare contents only.
    #[test]
    fn borrowed_and_owned_compare_by_contents() {
        let borrowed = CowStr::Borrowed("hello");
        let owned = CowStr::copy_from_str("hello");

        assert_eq!(borrowed, owned);
        assert_eq!(owned, "hello");
        assert_eq!(borrowed.cmp(&owned), Ordering::Equal);
        assert_eq!(borrowed.partial_cmp(&owned), Some(Ordering::Equal));
        assert!(CowStr::Borrowed("abc") < CowStr::copy_from_str("abd"));
        assert!(CowStr::copy_from_str("b") > CowStr::Borrowed("a"));
    }

    /// Converting to `SmolStr` preserves the text for both variants.
    #[test]
    fn smolstr_conversion_preserves_text() {
        let owned = CowStr::new_static("a static string longer than the inline capacity");
        let smol: SmolStr = owned.clone().into();
        assert_eq!(smol.as_str(), owned.as_str());

        let from_borrowed: SmolStr = CowStr::Borrowed("hi").into();
        assert_eq!(from_borrowed.as_str(), "hi");
    }
}