Skip to main content

rig_core/providers/anthropic/decoders/
line.rs

1use std::str;
2
/// A line decoder that incrementally splits a byte stream into text lines.
///
/// Bytes are fed in arbitrary chunks; they are buffered until a terminator
/// (`\n`, `\r\n`, or a bare `\r`) shows that a line is complete.
/// Ported from JavaScript implementation.
#[derive(Debug, Clone)]
pub struct LineDecoder {
    /// Bytes received so far that have not yet been emitted as a line.
    buffer: Vec<u8>,
    /// Offset in `buffer` just past a `\r` that has been seen but not yet
    /// resolved (it may still turn out to be the first half of `\r\n`).
    /// `None` when no carriage return is pending.
    carriage_return_index: Option<usize>,
}
9
10impl Default for LineDecoder {
11    fn default() -> Self {
12        Self::new()
13    }
14}
15
16impl LineDecoder {
17    /// Create a new LineDecoder
18    pub fn new() -> Self {
19        LineDecoder {
20            buffer: Vec::new(),
21            carriage_return_index: None,
22        }
23    }
24
25    /// Decode a chunk of data into lines
26    pub fn decode(&mut self, chunk: &[u8]) -> Vec<String> {
27        if chunk.is_empty() {
28            return Vec::new();
29        }
30
31        // Append the new chunk to the buffer
32        self.buffer.extend_from_slice(chunk);
33
34        let mut lines = Vec::new();
35
36        // Process lines while we can find newlines
37        while let Some(pattern_index) = find_newline_index(&self.buffer, self.carriage_return_index)
38        {
39            if pattern_index.carriage && self.carriage_return_index.is_none() {
40                // Skip until we either get a corresponding `\n`, a new `\r` or nothing
41                self.carriage_return_index = Some(pattern_index.index);
42                continue;
43            }
44
45            // We got double \r or \rtext\n
46            // TODO: Collapse this if statement (whenever `||` operator is supported in if-let chains).
47            #[allow(clippy::collapsible_if)]
48            if let Some(cr_index) = self.carriage_return_index {
49                if pattern_index.index != cr_index + 1 || pattern_index.carriage {
50                    if cr_index > 0 {
51                        let line = decode_text(
52                            self.buffer.get(..cr_index.saturating_sub(1)).unwrap_or(&[]),
53                        );
54                        lines.push(line);
55                    } else {
56                        // Handle edge case for carriage return at beginning
57                        lines.push(String::new());
58                    }
59
60                    if cr_index < self.buffer.len() {
61                        self.buffer = self.buffer.get(cr_index..).unwrap_or(&[]).to_vec();
62                    } else {
63                        self.buffer.clear();
64                    }
65                    self.carriage_return_index = None;
66                    continue;
67                }
68            }
69
70            let end_index = if self.carriage_return_index.is_some() {
71                pattern_index.preceding - 1
72            } else {
73                pattern_index.preceding
74            };
75
76            if end_index > 0 {
77                let line = decode_text(self.buffer.get(..end_index).unwrap_or(&[]));
78                lines.push(line);
79            } else {
80                lines.push(String::new());
81            }
82
83            if pattern_index.index < self.buffer.len() {
84                self.buffer = self
85                    .buffer
86                    .get(pattern_index.index..)
87                    .unwrap_or(&[])
88                    .to_vec();
89            } else {
90                self.buffer.clear();
91            }
92            self.carriage_return_index = None;
93        }
94
95        lines
96    }
97
98    /// Flush any remaining data in the buffer
99    pub fn flush(&mut self) -> Vec<String> {
100        if self.buffer.is_empty() {
101            return Vec::new();
102        }
103        self.decode("\n".as_bytes())
104    }
105}
106
/// Position of a single line-terminator byte (`\n` or `\r`) within a buffer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct NewlineIndex {
    /// Offset of the terminator byte itself, i.e. the exclusive end of the
    /// bytes that come before it.
    preceding: usize,
    /// Offset of the byte immediately after the terminator (`preceding + 1`).
    index: usize,
    /// `true` when the terminator byte is `\r`, `false` when it is `\n`.
    carriage: bool,
}
113
114/// Find the index of the next newline character in the buffer
115fn find_newline_index(buffer: &[u8], start_index: Option<usize>) -> Option<NewlineIndex> {
116    const NEWLINE: u8 = 0x0a; // \n
117    const CARRIAGE: u8 = 0x0d; // \r
118
119    let start = start_index.unwrap_or(0);
120
121    for (i, &byte) in buffer.iter().enumerate().skip(start) {
122        if byte == NEWLINE {
123            return Some(NewlineIndex {
124                preceding: i,
125                index: i + 1,
126                carriage: false,
127            });
128        }
129
130        if byte == CARRIAGE {
131            return Some(NewlineIndex {
132                preceding: i,
133                index: i + 1,
134                carriage: true,
135            });
136        }
137    }
138
139    None
140}
141
/// Find the index just past the first double-newline pattern in the buffer.
///
/// The recognised patterns are `\n\n`, `\r\r` and `\r\n\r\n`. The buffer is
/// scanned once from the front and the earliest pattern wins, whichever of
/// the three it is; a pattern that appears later in the buffer never takes
/// precedence over one that appears earlier.
///
/// Returns the offset of the byte following the pattern, or `-1` when the
/// buffer contains no complete pattern.
pub fn find_double_newline_index(buffer: &[u8]) -> isize {
    const NEWLINE: u8 = 0x0a; // \n
    const CARRIAGE: u8 = 0x0d; // \r

    for (start, pair) in buffer.windows(2).enumerate() {
        let pattern_len = match pair {
            [NEWLINE, NEWLINE] | [CARRIAGE, CARRIAGE] => 2,
            // `\r\n` only counts when it is immediately followed by a second `\r\n`.
            [CARRIAGE, NEWLINE]
                if buffer.get(start + 2..start + 4) == Some(&[CARRIAGE, NEWLINE][..]) =>
            {
                4
            }
            _ => continue,
        };

        // A slice never holds more than `isize::MAX` bytes, so the cast cannot wrap.
        return (start + pattern_len) as isize;
    }

    -1
}
165
/// Turn raw bytes into an owned `String`.
///
/// Valid UTF-8 is copied as-is; otherwise every invalid sequence is replaced
/// with U+FFFD (the Unicode replacement character).
fn decode_text(bytes: &[u8]) -> String {
    if let Ok(valid) = str::from_utf8(bytes) {
        return valid.to_owned();
    }

    // Not valid UTF-8: fall back to lossy decoding.
    String::from_utf8_lossy(bytes).into_owned()
}
176
177/// Decode multiple chunks of data, with an option to flush
178pub fn decode_chunks(chunks: &[&[u8]], flush: bool) -> Vec<String> {
179    let mut decoder = LineDecoder::new();
180    let mut lines = Vec::new();
181
182    for chunk in chunks {
183        lines.extend(decoder.decode(chunk));
184    }
185
186    if flush {
187        lines.extend(decoder.flush());
188    }
189
190    lines
191}
192
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes `chunks` with a fresh decoder (optionally flushing at the end)
    /// and asserts that exactly `expected` lines come out.
    #[track_caller]
    fn expect_lines(chunks: &[&str], flush: bool, expected: &[&str]) {
        let raw: Vec<&[u8]> = chunks.iter().map(|text| text.as_bytes()).collect();
        assert_eq!(decode_chunks(&raw, flush), expected);
    }

    #[test]
    fn test_basic() {
        // "baz" is held back: its line has not been terminated yet.
        expect_lines(&["foo", " bar\nbaz"], false, &["foo bar"]);
    }

    #[test]
    fn test_basic_with_cr() {
        expect_lines(&["foo", " bar\r\nbaz"], false, &["foo bar"]);
        expect_lines(&["foo", " bar\r\nbaz"], true, &["foo bar", "baz"]);
    }

    #[test]
    fn test_trailing_new_lines() {
        expect_lines(
            &["foo", " bar", "baz\n", "thing\n"],
            false,
            &["foo barbaz", "thing"],
        );
    }

    #[test]
    fn test_trailing_new_lines_with_cr() {
        expect_lines(
            &["foo", " bar", "baz\r\n", "thing\r\n"],
            false,
            &["foo barbaz", "thing"],
        );
    }

    #[test]
    fn test_escaped_new_lines() {
        // A literal backslash followed by `n` is ordinary text.
        expect_lines(&["foo", " bar\\nbaz\n"], false, &["foo bar\\nbaz"]);
    }

    #[test]
    fn test_escaped_new_lines_with_cr() {
        expect_lines(&["foo", " bar\\r\\nbaz\n"], false, &["foo bar\\r\\nbaz"]);
    }

    #[test]
    fn test_cr_and_lf_split_across_chunks() {
        expect_lines(&["foo\r", "\n", "bar"], true, &["foo", "bar"]);
    }

    #[test]
    fn test_single_cr() {
        expect_lines(&["foo\r", "bar"], true, &["foo", "bar"]);
    }

    #[test]
    fn test_double_cr() {
        expect_lines(&["foo\r", "bar\r"], true, &["foo", "bar"]);
        expect_lines(&["foo\r", "\r", "bar"], true, &["foo", "", "bar"]);
        // Implementation detail: the line ended by a lone `\r` is only yielded
        // once a further `\r` or `\n` has been seen.
        expect_lines(&["foo\r", "\r", "bar"], false, &["foo"]);
    }

    #[test]
    fn test_double_cr_then_crlf() {
        expect_lines(
            &["foo\r", "\r", "\r", "\n", "bar", "\n"],
            false,
            &["foo", "", "", "bar"],
        );
        expect_lines(
            &["foo\n", "\n", "\n", "bar", "\n"],
            false,
            &["foo", "", "", "bar"],
        );
    }

    #[test]
    fn test_double_newline() {
        // The same text, split at every possible chunk boundary around the blank line.
        let splits: &[&[&str]] = &[
            &["foo\n\nbar"],
            &["foo", "\n", "\nbar"],
            &["foo\n", "\n", "bar"],
            &["foo", "\n", "\n", "bar"],
        ];
        for chunks in splits.iter() {
            expect_lines(chunks, true, &["foo", "", "bar"]);
        }
    }

    #[test]
    fn test_multi_byte_characters_across_chunks() {
        let mut decoder = LineDecoder::new();

        // Bytes of the string 'известни', split arbitrarily so that some
        // multi-byte characters straddle chunk boundaries.
        let pieces: [&[u8]; 3] = [
            &[0xd0],
            &[0xb8, 0xd0, 0xb7, 0xd0],
            &[
                0xb2, 0xd0, 0xb5, 0xd1, 0x81, 0xd1, 0x82, 0xd0, 0xbd, 0xd0, 0xb8,
            ],
        ];
        for piece in pieces.iter() {
            assert!(decoder.decode(piece).is_empty());
        }

        assert_eq!(decoder.decode(&[0x0a]), ["известни"]);
    }

    #[test]
    fn test_flushing_trailing_newlines() {
        expect_lines(&["foo\n", "\nbar"], true, &["foo", "", "bar"]);
    }

    #[test]
    fn test_flushing_empty_buffer() {
        expect_lines(&[], true, &[]);
    }

    #[test]
    fn test_find_double_newline_index() {
        let cases: &[(&str, isize)] = &[
            // \n\n patterns
            ("foo\n\nbar", 5),
            ("\n\nbar", 2),
            ("foo\n\n", 5),
            ("\n\n", 2),
            // \r\r patterns
            ("foo\r\rbar", 5),
            ("\r\rbar", 2),
            ("foo\r\r", 5),
            ("\r\r", 2),
            // \r\n\r\n patterns
            ("foo\r\n\r\nbar", 7),
            ("\r\n\r\nbar", 4),
            ("foo\r\n\r\n", 7),
            ("\r\n\r\n", 4),
            // not found
            ("foo\nbar", -1),
            ("foo\rbar", -1),
            ("foo\r\nbar", -1),
            ("", -1),
            // incomplete patterns
            ("foo\r\n\r", -1),
            ("foo\r\n", -1),
        ];

        for (input, expected) in cases.iter() {
            assert_eq!(
                find_double_newline_index(input.as_bytes()),
                *expected,
                "input: {:?}",
                input
            );
        }
    }
}