Skip to main content

anyxml_encoding/
utf16.rs

1use std::iter::once;
2
3use crate::{DecodeError, Decoder, EncodeError, Encoder};
4
5pub const UTF16_NAME: &str = "UTF-16";
6
7#[derive(Debug, Default)]
8pub struct UTF16Encoder {
9    init: bool,
10}
11impl Encoder for UTF16Encoder {
12    fn name(&self) -> &'static str {
13        UTF16_NAME
14    }
15
16    fn encode(
17        &mut self,
18        src: &str,
19        dst: &mut [u8],
20        finish: bool,
21    ) -> Result<(usize, usize), EncodeError> {
22        if src.is_empty() {
23            return Err(EncodeError::InputIsEmpty);
24        }
25        if dst.len() < 4 {
26            return Err(EncodeError::OutputTooShort);
27        }
28
29        if !self.init {
30            self.init = true;
31            // Write BOM as LE
32            dst[0] = 0xFF;
33            dst[1] = 0xFE;
34            return Ok((0, 2));
35        }
36        UTF16LEEncoder.encode(src, dst, finish)
37    }
38}
39
40pub struct UTF16Decoder {
41    read: usize,
42    top: [u8; 2],
43    be: bool,
44}
45impl Decoder for UTF16Decoder {
46    fn name(&self) -> &'static str {
47        UTF16_NAME
48    }
49
50    fn decode(
51        &mut self,
52        mut src: &[u8],
53        dst: &mut String,
54        finish: bool,
55    ) -> Result<(usize, usize), DecodeError> {
56        if src.is_empty() {
57            return Err(DecodeError::InputIsEmpty);
58        }
59        if dst.capacity() - dst.len() < 4 {
60            return Err(DecodeError::OutputTooShort);
61        }
62
63        let mut base = 0;
64        if self.read < 2 {
65            let orig = src.len();
66            while self.read < 2 && !src.is_empty() {
67                self.top[self.read] = src[0];
68                src = &src[1..];
69                self.read += 1;
70            }
71            base = orig - src.len();
72            if self.read == 2 {
73                // If the first 2 bytes of the buffer are 0xFF, 0xFE, it is LE; otherwise, it is BE.
74                if matches!(self.top[..], [0xFF, 0xFE]) {
75                    self.be = false;
76                    return Ok((base, 0));
77                } else if matches!(self.top[..], [0xFE, 0xFF]) {
78                    self.be = true;
79                    return Ok((base, 0));
80                } else {
81                    self.be = true;
82                    // Since the first two bytes were not BOM,
83                    // try decoding using the first two bytes that have already been acquired.
84                };
85            } else {
86                return Ok((base, 0));
87            }
88        }
89
90        if self.be && !matches!(self.top[..], [0xFE, 0xFF]) {
91            let mut read = 0;
92            let mut write = 0;
93            for c in char::decode_utf16(
94                once(((self.top[0] as u16) << 8) | self.top[1] as u16).chain(
95                    src.chunks_exact(2)
96                        .map(|v| ((v[0] as u16) << 8) | v[1] as u16),
97                ),
98            ) {
99                if let Ok(c) = c {
100                    read += c.len_utf16() * 2;
101                    write += c.len_utf8();
102                    dst.push(c);
103                } else {
104                    let rem = src.len() - (read - 2);
105                    if !finish && rem < 4 {
106                        // If this is not the last buffer and the unread buffer is less than 2 bytes,
107                        // return `Ok` because the corresponding surrogate pair may be at the beginning of the next buffer to be input.
108                        break;
109                    } else {
110                        // If this is the last buffer, or if there is sufficient data to form a surrogate pair but an error occurs,
111                        // it is simply an invalid byte sequence.
112                        return Err(DecodeError::Malformed {
113                            read: read + 2,
114                            write,
115                            length: 2,
116                            offset: 0,
117                        });
118                    }
119                }
120
121                if dst.capacity() - dst.len() < 4 {
122                    break;
123                }
124            }
125            return if read > 0 {
126                self.top = [0xFE, 0xFF];
127                read -= 2 - base;
128                Ok((read, write))
129            } else {
130                Ok((base, 0))
131            };
132        }
133
134        if self.be {
135            UTF16BEDecoder.decode(src, dst, finish)
136        } else {
137            UTF16LEDecoder.decode(src, dst, finish)
138        }
139    }
140}
141
142impl Default for UTF16Decoder {
143    fn default() -> Self {
144        Self {
145            read: 0,
146            top: [0; 2],
147            be: true,
148        }
149    }
150}
151
152pub const UTF16BE_NAME: &str = "UTF-16BE";
153
154pub struct UTF16BEEncoder;
155impl Encoder for UTF16BEEncoder {
156    fn name(&self) -> &'static str {
157        UTF16BE_NAME
158    }
159
160    fn encode(
161        &mut self,
162        src: &str,
163        mut dst: &mut [u8],
164        _finish: bool,
165    ) -> Result<(usize, usize), EncodeError> {
166        if src.is_empty() {
167            return Err(EncodeError::InputIsEmpty);
168        }
169        if dst.len() < 4 {
170            return Err(EncodeError::OutputTooShort);
171        }
172
173        let mut buf = [0u16; 2];
174        let mut read = 0;
175        let mut write = 0;
176        for c in src.chars() {
177            read += c.len_utf8();
178            let b = c.encode_utf16(&mut buf);
179            dst[..2].copy_from_slice(&b[0].to_be_bytes());
180            dst = &mut dst[2..];
181            write += 2;
182            if b.len() == 2 {
183                dst[..2].copy_from_slice(&b[1].to_be_bytes());
184                dst = &mut dst[2..];
185                write += 2;
186            }
187            if dst.len() < 4 {
188                break;
189            }
190        }
191        Ok((read, write))
192    }
193}
194
195pub struct UTF16BEDecoder;
196impl Decoder for UTF16BEDecoder {
197    fn name(&self) -> &'static str {
198        UTF16BE_NAME
199    }
200
201    fn decode(
202        &mut self,
203        src: &[u8],
204        dst: &mut String,
205        finish: bool,
206    ) -> Result<(usize, usize), DecodeError> {
207        if src.is_empty() {
208            return Err(DecodeError::InputIsEmpty);
209        }
210        let cap = dst.capacity() - dst.len();
211        if cap < 4 {
212            return Err(DecodeError::OutputTooShort);
213        }
214
215        let mut read = 0;
216        let mut write = 0;
217        for c in char::decode_utf16(
218            src.chunks_exact(2)
219                .map(|v| u16::from_be_bytes([v[0], v[1]])),
220        ) {
221            if let Ok(c) = c {
222                read += c.len_utf16() * 2;
223                write += c.len_utf8();
224                dst.push(c);
225            } else {
226                let rem = src.len() - read;
227                if !finish && rem < 4 {
228                    break;
229                } else {
230                    return Err(DecodeError::Malformed {
231                        read: read + 2,
232                        write,
233                        length: 2,
234                        offset: 0,
235                    });
236                }
237            }
238
239            if dst.capacity() - dst.len() < 4 {
240                break;
241            }
242        }
243
244        Ok((read, write))
245    }
246}
247
248pub const UTF16LE_NAME: &str = "UTF-16LE";
249
250pub struct UTF16LEEncoder;
251impl Encoder for UTF16LEEncoder {
252    fn name(&self) -> &'static str {
253        UTF16LE_NAME
254    }
255
256    fn encode(
257        &mut self,
258        src: &str,
259        mut dst: &mut [u8],
260        _finish: bool,
261    ) -> Result<(usize, usize), EncodeError> {
262        if src.is_empty() {
263            return Err(EncodeError::InputIsEmpty);
264        }
265        if dst.len() < 4 {
266            return Err(EncodeError::OutputTooShort);
267        }
268
269        let mut buf = [0u16; 2];
270        let mut read = 0;
271        let mut write = 0;
272        for c in src.chars() {
273            read += c.len_utf8();
274            let b = c.encode_utf16(&mut buf);
275            dst[..2].copy_from_slice(&b[0].to_le_bytes());
276            dst = &mut dst[2..];
277            write += 2;
278            if b.len() == 2 {
279                dst[..2].copy_from_slice(&b[1].to_le_bytes());
280                dst = &mut dst[2..];
281                write += 2;
282            }
283            if dst.len() < 4 {
284                break;
285            }
286        }
287        Ok((read, write))
288    }
289}
290
291pub struct UTF16LEDecoder;
292impl Decoder for UTF16LEDecoder {
293    fn name(&self) -> &'static str {
294        UTF16LE_NAME
295    }
296
297    fn decode(
298        &mut self,
299        src: &[u8],
300        dst: &mut String,
301        finish: bool,
302    ) -> Result<(usize, usize), DecodeError> {
303        if src.is_empty() {
304            return Err(DecodeError::InputIsEmpty);
305        }
306        let cap = dst.capacity() - dst.len();
307        if cap < 4 {
308            return Err(DecodeError::OutputTooShort);
309        }
310
311        let mut read = 0;
312        let mut write = 0;
313        for c in char::decode_utf16(
314            src.chunks_exact(2)
315                .map(|v| u16::from_le_bytes([v[0], v[1]])),
316        ) {
317            if let Ok(c) = c {
318                read += c.len_utf16() * 2;
319                write += c.len_utf8();
320                dst.push(c);
321            } else {
322                let rem = src.len() - read;
323                if !finish && rem < 4 {
324                    break;
325                } else {
326                    return Err(DecodeError::Malformed {
327                        read: read + 2,
328                        write,
329                        length: 2,
330                        offset: 0,
331                    });
332                }
333            }
334
335            if dst.capacity() - dst.len() < 4 {
336                break;
337            }
338        }
339
340        Ok((read, write))
341    }
342}