loro_internal/utils/
string_slice.rs

1use std::{fmt::Debug, ops::Deref};
2
3use append_only_bytes::BytesSlice;
4use generic_btree::rle::{HasLength, Mergeable, Sliceable, TryInsert};
5use rle::Mergable;
6use serde::{Deserialize, Deserializer, Serialize};
7
8use crate::{
9    container::richtext::richtext_state::{unicode_to_utf8_index, utf16_to_utf8_index},
10    delta::DeltaValue,
11};
12
13use super::utf16::{count_unicode_chars, count_utf16_len};
14
/// A UTF-8 string backed by either a `BytesSlice` or an owned `String`
/// (see `Variant`).
///
/// Equality, `Debug`, `Display` and serialization all go through `as_str`,
/// so they depend only on the text and not on which backing storage is used.
#[derive(Clone)]
pub struct StringSlice {
    // Backing storage; `as_str` treats its content as valid UTF-8.
    bytes: Variant,
}
19
20impl PartialEq for StringSlice {
21    fn eq(&self, other: &Self) -> bool {
22        self.as_str() == other.as_str()
23    }
24}
25
// Marker impl: the `PartialEq` impl compares `as_str()` values, and `str`
// equality is a full equivalence relation.
impl Eq for StringSlice {}
27
/// Backing storage of a `StringSlice`.
///
/// Note that the derived equality here is variant-sensitive: a `BytesSlice`
/// value never equals an `Owned` value, even when both hold the same text.
/// `StringSlice` implements its own `PartialEq` on the text instead.
#[derive(Clone, PartialEq, Eq)]
enum Variant {
    /// Text held in a `BytesSlice` from the `append_only_bytes` crate.
    BytesSlice(BytesSlice),
    /// Text held in a privately owned, growable `String`.
    Owned(String),
}
33
34impl From<String> for StringSlice {
35    fn from(s: String) -> Self {
36        Self {
37            bytes: Variant::Owned(s),
38        }
39    }
40}
41
42impl From<BytesSlice> for StringSlice {
43    fn from(s: BytesSlice) -> Self {
44        Self::new(s)
45    }
46}
47
48impl From<&str> for StringSlice {
49    fn from(s: &str) -> Self {
50        Self {
51            bytes: Variant::Owned(s.to_string()),
52        }
53    }
54}
55
56impl Debug for StringSlice {
57    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58        f.debug_struct("StringSlice")
59            .field("bytes", &self.as_str())
60            .finish()
61    }
62}
63
64impl StringSlice {
65    pub fn new(s: BytesSlice) -> Self {
66        std::str::from_utf8(&s).unwrap();
67        Self {
68            bytes: Variant::BytesSlice(s),
69        }
70    }
71
72    pub fn as_str(&self) -> &str {
73        match &self.bytes {
74            // SAFETY: `bytes` is always valid utf8
75            Variant::BytesSlice(s) => unsafe { std::str::from_utf8_unchecked(s) },
76            Variant::Owned(s) => s,
77        }
78    }
79
80    pub fn len_bytes(&self) -> usize {
81        match &self.bytes {
82            Variant::BytesSlice(s) => s.len(),
83            Variant::Owned(s) => s.len(),
84        }
85    }
86
87    fn bytes(&self) -> &[u8] {
88        match &self.bytes {
89            Variant::BytesSlice(s) => s.deref(),
90            Variant::Owned(s) => s.as_bytes(),
91        }
92    }
93
94    pub fn len_unicode(&self) -> usize {
95        count_unicode_chars(self.bytes())
96    }
97
98    pub fn len_utf16(&self) -> usize {
99        count_utf16_len(self.bytes())
100    }
101
102    pub fn is_empty(&self) -> bool {
103        self.bytes().is_empty()
104    }
105
106    pub fn extend(&mut self, s: &str) {
107        match &mut self.bytes {
108            Variant::BytesSlice(_) => {
109                *self = Self {
110                    bytes: Variant::Owned(format!("{}{}", self.as_str(), s)),
111                }
112            }
113            Variant::Owned(v) => {
114                v.push_str(s);
115            }
116        }
117    }
118}
119
120impl std::fmt::Display for StringSlice {
121    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122        f.write_str(self.as_str())
123    }
124}
125
126impl<'de> Deserialize<'de> for StringSlice {
127    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
128    where
129        D: Deserializer<'de>,
130    {
131        let s = String::deserialize(deserializer)?;
132        Ok(s.into())
133    }
134}
135
136impl Serialize for StringSlice {
137    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
138    where
139        S: serde::Serializer,
140    {
141        serializer.serialize_str(self.as_str())
142    }
143}
144
145impl DeltaValue for StringSlice {
146    fn value_extend(&mut self, other: Self) -> Result<(), Self> {
147        match (&mut self.bytes, &other.bytes) {
148            (Variant::BytesSlice(s), Variant::BytesSlice(o)) => match s.try_merge(o) {
149                Ok(_) => Ok(()),
150                Err(_) => Err(other),
151            },
152            (Variant::Owned(s), _) => {
153                s.push_str(other.as_str());
154                Ok(())
155            }
156            _ => Err(other),
157        }
158    }
159
160    fn take(&mut self, length: usize) -> Self {
161        let length = if cfg!(feature = "wasm") {
162            utf16_to_utf8_index(self.as_str(), length).unwrap()
163        } else {
164            unicode_to_utf8_index(self.as_str(), length).unwrap()
165        };
166
167        match &mut self.bytes {
168            Variant::BytesSlice(s) => {
169                let mut other = s.slice_clone(length..);
170                s.slice_(..length);
171                std::mem::swap(s, &mut other);
172                Self {
173                    bytes: Variant::BytesSlice(other),
174                }
175            }
176            Variant::Owned(s) => {
177                let mut other = s.split_off(length);
178                std::mem::swap(s, &mut other);
179                Self {
180                    bytes: Variant::Owned(other),
181                }
182            }
183        }
184    }
185
186    /// Unicode length of the string
187    /// Utf16 length when in WASM
188    fn length(&self) -> usize {
189        if cfg!(feature = "wasm") {
190            count_utf16_len(self.bytes())
191        } else {
192            count_unicode_chars(self.bytes())
193        }
194    }
195}
196
197impl HasLength for StringSlice {
198    fn rle_len(&self) -> usize {
199        if cfg!(feature = "wasm") {
200            count_utf16_len(self.bytes())
201        } else {
202            count_unicode_chars(self.bytes())
203        }
204    }
205}
206
207impl TryInsert for StringSlice {
208    fn try_insert(&mut self, pos: usize, elem: Self) -> Result<(), Self>
209    where
210        Self: Sized,
211    {
212        match &mut self.bytes {
213            Variant::BytesSlice(_) => Err(elem),
214            Variant::Owned(s) => {
215                if s.capacity() >= s.len() + elem.len_bytes() {
216                    let pos = if cfg!(feature = "wasm") {
217                        utf16_to_utf8_index(s.as_str(), pos).unwrap()
218                    } else {
219                        unicode_to_utf8_index(s.as_str(), pos).unwrap()
220                    };
221                    s.insert_str(pos, elem.as_str());
222                    Ok(())
223                } else {
224                    Err(elem)
225                }
226            }
227        }
228
229        // match (&mut self.bytes, &elem.bytes) {
230        //     (Variant::Owned(a), Variant::Owned(b))
231        //         // TODO: Extract magic num
232        //         if a.capacity() >= a.len() + b.len() && a.capacity() < 128 =>
233        //     {
234        //         a.insert_str(pos, b.as_str());
235        //         Ok(())
236        //     }
237        //     _ => Err(elem),
238        // }
239    }
240}
241
242impl Mergeable for StringSlice {
243    fn can_merge(&self, rhs: &Self) -> bool {
244        match (&self.bytes, &rhs.bytes) {
245            (Variant::BytesSlice(a), Variant::BytesSlice(b)) => a.can_merge(b),
246            (Variant::Owned(a), Variant::Owned(b)) => a.len() + b.len() <= a.capacity(),
247            _ => false,
248        }
249    }
250
251    fn merge_right(&mut self, rhs: &Self) {
252        match (&mut self.bytes, &rhs.bytes) {
253            (Variant::BytesSlice(a), Variant::BytesSlice(b)) => a.merge(b, &()),
254            (Variant::Owned(a), Variant::Owned(b)) => a.push_str(b.as_str()),
255            _ => {}
256        }
257    }
258
259    fn merge_left(&mut self, left: &Self) {
260        match (&mut self.bytes, &left.bytes) {
261            (Variant::BytesSlice(a), Variant::BytesSlice(b)) => {
262                let mut new = b.clone();
263                new.merge(a, &());
264                *a = new;
265            }
266            (Variant::Owned(a), Variant::Owned(b)) => {
267                a.insert_str(0, b.as_str());
268            }
269            _ => {}
270        }
271    }
272}
273
274impl Sliceable for StringSlice {
275    fn _slice(&self, range: std::ops::Range<usize>) -> Self {
276        let range = if cfg!(feature = "wasm") {
277            let start = utf16_to_utf8_index(self.as_str(), range.start).unwrap();
278            let end = utf16_to_utf8_index(self.as_str(), range.end).unwrap();
279            start..end
280        } else {
281            let start = unicode_to_utf8_index(self.as_str(), range.start).unwrap();
282            let end = unicode_to_utf8_index(self.as_str(), range.end).unwrap();
283            start..end
284        };
285
286        let bytes = match &self.bytes {
287            Variant::BytesSlice(s) => Variant::BytesSlice(s.slice_clone(range)),
288            Variant::Owned(s) => Variant::Owned(s[range].to_string()),
289        };
290
291        Self { bytes }
292    }
293
294    fn split(&mut self, pos: usize) -> Self {
295        let pos = if cfg!(feature = "wasm") {
296            utf16_to_utf8_index(self.as_str(), pos).unwrap()
297        } else {
298            unicode_to_utf8_index(self.as_str(), pos).unwrap()
299        };
300
301        let bytes = match &mut self.bytes {
302            Variant::BytesSlice(s) => {
303                let other = s.slice_clone(pos..);
304                s.slice_(..pos);
305                Variant::BytesSlice(other)
306            }
307            Variant::Owned(s) => {
308                let other = s.split_off(pos);
309                Variant::Owned(other)
310            }
311        };
312
313        Self { bytes }
314    }
315}
316
317impl Default for StringSlice {
318    fn default() -> Self {
319        StringSlice {
320            bytes: Variant::Owned(String::with_capacity(32)),
321        }
322    }
323}
324
// Empty impl of the `loro_delta` crate's `DeltaValue` trait; no items are
// overridden here.
impl loro_delta::delta_trait::DeltaValue for StringSlice {}
/// Converts a range measured in Unicode scalar values (`char`s) into the
/// corresponding byte range of `s`.
///
/// Returns `(start_byte, end_byte)`. A bound that is equal to or greater than
/// the number of chars in `s` maps to `s.len()`, so an empty range at the very
/// end of the string yields `(s.len(), s.len())`.
///
/// Callers must pass `start <= end`; this is only checked in debug builds.
pub fn unicode_range_to_byte_range(s: &str, start: usize, end: usize) -> (usize, usize) {
    debug_assert!(start <= end);
    // Both bounds default to the end of the string. A bound at or past the
    // char count is never visited by the loop below, and must not fall back
    // to offset 0 (that would turn an empty tail range into the whole string).
    let mut start_byte = s.len();
    let mut end_byte = s.len();
    for (char_index, (byte_index, _)) in s.char_indices().enumerate() {
        if char_index == start {
            start_byte = byte_index;
        }

        if char_index == end {
            end_byte = byte_index;
            // `start <= end`, so the start bound has already been resolved.
            break;
        }
    }

    (start_byte, end_byte)
}