Skip to main content

rustpython_common/
encodings.rs

1use core::ops::{self, Range};
2
3use num_traits::ToPrimitive;
4
5use crate::str::StrKind;
6use crate::wtf8::{CodePoint, Wtf8, Wtf8Buf};
7
8pub trait StrBuffer: AsRef<Wtf8> {
9    fn is_compatible_with(&self, kind: StrKind) -> bool {
10        let s = self.as_ref();
11        match kind {
12            StrKind::Ascii => s.is_ascii(),
13            StrKind::Utf8 => s.is_utf8(),
14            StrKind::Wtf8 => true,
15        }
16    }
17}
18
19pub trait CodecContext: Sized {
20    type Error;
21    type StrBuf: StrBuffer;
22    type BytesBuf: AsRef<[u8]>;
23
24    fn string(&self, s: Wtf8Buf) -> Self::StrBuf;
25    fn bytes(&self, b: Vec<u8>) -> Self::BytesBuf;
26}
27
28pub trait EncodeContext: CodecContext {
29    fn full_data(&self) -> &Wtf8;
30    fn data_len(&self) -> StrSize;
31
32    fn remaining_data(&self) -> &Wtf8;
33    fn position(&self) -> StrSize;
34
35    fn restart_from(&mut self, pos: StrSize) -> Result<(), Self::Error>;
36
37    fn error_encoding(&self, range: Range<StrSize>, reason: Option<&str>) -> Self::Error;
38
39    fn handle_error<E>(
40        &mut self,
41        errors: &E,
42        range: Range<StrSize>,
43        reason: Option<&str>,
44    ) -> Result<EncodeReplace<Self>, Self::Error>
45    where
46        E: EncodeErrorHandler<Self>,
47    {
48        let (replace, restart) = errors.handle_encode_error(self, range, reason)?;
49        self.restart_from(restart)?;
50        Ok(replace)
51    }
52}
53
54pub trait DecodeContext: CodecContext {
55    fn full_data(&self) -> &[u8];
56
57    fn remaining_data(&self) -> &[u8];
58    fn position(&self) -> usize;
59
60    fn advance(&mut self, by: usize);
61
62    fn restart_from(&mut self, pos: usize) -> Result<(), Self::Error>;
63
64    fn error_decoding(&self, byte_range: Range<usize>, reason: Option<&str>) -> Self::Error;
65
66    fn handle_error<E>(
67        &mut self,
68        errors: &E,
69        byte_range: Range<usize>,
70        reason: Option<&str>,
71    ) -> Result<Self::StrBuf, Self::Error>
72    where
73        E: DecodeErrorHandler<Self>,
74    {
75        let (replace, restart) = errors.handle_decode_error(self, byte_range, reason)?;
76        self.restart_from(restart)?;
77        Ok(replace)
78    }
79}
80
81pub trait EncodeErrorHandler<Ctx: EncodeContext> {
82    fn handle_encode_error(
83        &self,
84        ctx: &mut Ctx,
85        range: Range<StrSize>,
86        reason: Option<&str>,
87    ) -> Result<(EncodeReplace<Ctx>, StrSize), Ctx::Error>;
88}
89pub trait DecodeErrorHandler<Ctx: DecodeContext> {
90    fn handle_decode_error(
91        &self,
92        ctx: &mut Ctx,
93        byte_range: Range<usize>,
94        reason: Option<&str>,
95    ) -> Result<(Ctx::StrBuf, usize), Ctx::Error>;
96}
97
98pub enum EncodeReplace<Ctx: CodecContext> {
99    Str(Ctx::StrBuf),
100    Bytes(Ctx::BytesBuf),
101}
102
/// A size or offset within a string, tracked in two units at once.
#[derive(Clone, Copy, Debug, Default)]
pub struct StrSize {
    /// Measurement in bytes.
    pub bytes: usize,
    /// Measurement in code points.
    pub chars: usize,
}
108
109fn iter_code_points(w: &Wtf8) -> impl Iterator<Item = (StrSize, CodePoint)> {
110    w.code_point_indices()
111        .enumerate()
112        .map(|(chars, (bytes, c))| (StrSize { bytes, chars }, c))
113}
114
115impl ops::Add for StrSize {
116    type Output = Self;
117    fn add(self, rhs: Self) -> Self::Output {
118        Self {
119            bytes: self.bytes + rhs.bytes,
120            chars: self.chars + rhs.chars,
121        }
122    }
123}
124
125impl ops::AddAssign for StrSize {
126    fn add_assign(&mut self, rhs: Self) {
127        self.bytes += rhs.bytes;
128        self.chars += rhs.chars;
129    }
130}
131
/// Outcome of a failed decode attempt over a byte slice.
struct DecodeError<'a> {
    /// Leading part of the input that decoded successfully.
    valid_prefix: &'a str,
    /// Everything after the valid prefix, starting at the offending byte.
    rest: &'a [u8],
    /// Length of the invalid sequence as reported by the decoder, if any.
    err_len: Option<usize>,
}

/// Splits `v` at `valid_up_to` into a [`DecodeError`].
///
/// # Safety
/// `v[..valid_up_to]` must be valid utf8 (which also requires
/// `valid_up_to <= v.len()`).
const unsafe fn make_decode_err(
    v: &[u8],
    valid_up_to: usize,
    err_len: Option<usize>,
) -> DecodeError<'_> {
    // SAFETY: the caller guarantees `valid_up_to` is within `v`.
    let parts = unsafe { v.split_at_unchecked(valid_up_to) };
    DecodeError {
        // SAFETY: the caller guarantees the prefix is valid utf8.
        valid_prefix: unsafe { core::str::from_utf8_unchecked(parts.0) },
        rest: parts.1,
        err_len,
    }
}
153
/// Verdict of a codec-specific classifier on a decode failure.
enum HandleResult<'a> {
    /// Stop decoding after the valid prefix without raising an error.
    Done,
    /// Report the failure to the error handler.
    Error {
        /// Length of the erroneous sequence; `None` means it extends to the
        /// end of the input.
        err_len: Option<usize>,
        /// Human-readable description of the failure.
        reason: &'a str,
    },
}
161
162fn decode_utf8_compatible<Ctx, E, DecodeF, ErrF>(
163    mut ctx: Ctx,
164    errors: &E,
165    decode: DecodeF,
166    handle_error: ErrF,
167) -> Result<(Wtf8Buf, usize), Ctx::Error>
168where
169    Ctx: DecodeContext,
170    E: DecodeErrorHandler<Ctx>,
171    DecodeF: Fn(&[u8]) -> Result<&str, DecodeError<'_>>,
172    ErrF: Fn(&[u8], Option<usize>) -> HandleResult<'static>,
173{
174    if ctx.remaining_data().is_empty() {
175        return Ok((Wtf8Buf::new(), 0));
176    }
177    let mut out = Wtf8Buf::with_capacity(ctx.remaining_data().len());
178    loop {
179        match decode(ctx.remaining_data()) {
180            Ok(decoded) => {
181                out.push_str(decoded);
182                ctx.advance(decoded.len());
183                break;
184            }
185            Err(e) => {
186                out.push_str(e.valid_prefix);
187                match handle_error(e.rest, e.err_len) {
188                    HandleResult::Done => {
189                        ctx.advance(e.valid_prefix.len());
190                        break;
191                    }
192                    HandleResult::Error { err_len, reason } => {
193                        let err_start = ctx.position() + e.valid_prefix.len();
194                        let err_end = match err_len {
195                            Some(len) => err_start + len,
196                            None => ctx.full_data().len(),
197                        };
198                        let err_range = err_start..err_end;
199                        let replace = ctx.handle_error(errors, err_range, Some(reason))?;
200                        out.push_wtf8(replace.as_ref());
201                        continue;
202                    }
203                }
204            }
205        }
206    }
207    Ok((out, ctx.position()))
208}
209
210#[inline]
211fn encode_utf8_compatible<Ctx, E>(
212    mut ctx: Ctx,
213    errors: &E,
214    err_reason: &str,
215    target_kind: StrKind,
216) -> Result<Vec<u8>, Ctx::Error>
217where
218    Ctx: EncodeContext,
219    E: EncodeErrorHandler<Ctx>,
220{
221    // let mut data = s.as_ref();
222    // let mut char_data_index = 0;
223    let mut out = Vec::<u8>::with_capacity(ctx.remaining_data().len());
224    loop {
225        let data = ctx.remaining_data();
226        let mut iter = iter_code_points(data);
227        let Some((i, _)) = iter.find(|(_, c)| !target_kind.can_encode(*c)) else {
228            break;
229        };
230
231        out.extend_from_slice(&ctx.remaining_data().as_bytes()[..i.bytes]);
232
233        let err_start = ctx.position() + i;
234        // number of non-compatible chars between the first non-compatible char and the next compatible char
235        let err_end = match { iter }.find(|(_, c)| target_kind.can_encode(*c)) {
236            Some((i, _)) => ctx.position() + i,
237            None => ctx.data_len(),
238        };
239
240        let range = err_start..err_end;
241        let replace = ctx.handle_error(errors, range.clone(), Some(err_reason))?;
242        match replace {
243            EncodeReplace::Str(s) => {
244                if s.is_compatible_with(target_kind) {
245                    out.extend_from_slice(s.as_ref().as_bytes());
246                } else {
247                    return Err(ctx.error_encoding(range, Some(err_reason)));
248                }
249            }
250            EncodeReplace::Bytes(b) => {
251                out.extend_from_slice(b.as_ref());
252            }
253        }
254    }
255    out.extend_from_slice(ctx.remaining_data().as_bytes());
256    Ok(out)
257}
258
259pub mod errors {
260    use crate::str::UnicodeEscapeCodepoint;
261
262    use super::*;
263    use core::fmt::Write;
264
265    #[derive(Clone, Copy)]
266    pub struct Strict;
267
268    impl<Ctx: EncodeContext> EncodeErrorHandler<Ctx> for Strict {
269        fn handle_encode_error(
270            &self,
271            ctx: &mut Ctx,
272            range: Range<StrSize>,
273            reason: Option<&str>,
274        ) -> Result<(EncodeReplace<Ctx>, StrSize), Ctx::Error> {
275            Err(ctx.error_encoding(range, reason))
276        }
277    }
278
279    impl<Ctx: DecodeContext> DecodeErrorHandler<Ctx> for Strict {
280        fn handle_decode_error(
281            &self,
282            ctx: &mut Ctx,
283            byte_range: Range<usize>,
284            reason: Option<&str>,
285        ) -> Result<(Ctx::StrBuf, usize), Ctx::Error> {
286            Err(ctx.error_decoding(byte_range, reason))
287        }
288    }
289
290    #[derive(Clone, Copy)]
291    pub struct Ignore;
292
293    impl<Ctx: EncodeContext> EncodeErrorHandler<Ctx> for Ignore {
294        fn handle_encode_error(
295            &self,
296            ctx: &mut Ctx,
297            range: Range<StrSize>,
298            _reason: Option<&str>,
299        ) -> Result<(EncodeReplace<Ctx>, StrSize), Ctx::Error> {
300            Ok((EncodeReplace::Bytes(ctx.bytes(b"".into())), range.end))
301        }
302    }
303
304    impl<Ctx: DecodeContext> DecodeErrorHandler<Ctx> for Ignore {
305        fn handle_decode_error(
306            &self,
307            ctx: &mut Ctx,
308            byte_range: Range<usize>,
309            _reason: Option<&str>,
310        ) -> Result<(Ctx::StrBuf, usize), Ctx::Error> {
311            Ok((ctx.string("".into()), byte_range.end))
312        }
313    }
314
315    #[derive(Clone, Copy)]
316    pub struct Replace;
317
318    impl<Ctx: EncodeContext> EncodeErrorHandler<Ctx> for Replace {
319        fn handle_encode_error(
320            &self,
321            ctx: &mut Ctx,
322            range: Range<StrSize>,
323            _reason: Option<&str>,
324        ) -> Result<(EncodeReplace<Ctx>, StrSize), Ctx::Error> {
325            let replace = "?".repeat(range.end.chars - range.start.chars);
326            Ok((EncodeReplace::Str(ctx.string(replace.into())), range.end))
327        }
328    }
329
330    impl<Ctx: DecodeContext> DecodeErrorHandler<Ctx> for Replace {
331        fn handle_decode_error(
332            &self,
333            ctx: &mut Ctx,
334            byte_range: Range<usize>,
335            _reason: Option<&str>,
336        ) -> Result<(Ctx::StrBuf, usize), Ctx::Error> {
337            Ok((
338                ctx.string(char::REPLACEMENT_CHARACTER.to_string().into()),
339                byte_range.end,
340            ))
341        }
342    }
343
344    #[derive(Clone, Copy)]
345    pub struct XmlCharRefReplace;
346
347    impl<Ctx: EncodeContext> EncodeErrorHandler<Ctx> for XmlCharRefReplace {
348        fn handle_encode_error(
349            &self,
350            ctx: &mut Ctx,
351            range: Range<StrSize>,
352            _reason: Option<&str>,
353        ) -> Result<(EncodeReplace<Ctx>, StrSize), Ctx::Error> {
354            let err_str = &ctx.full_data()[range.start.bytes..range.end.bytes];
355            let num_chars = range.end.chars - range.start.chars;
356            // capacity rough guess; assuming that the codepoints are 3 digits in decimal + the &#;
357            let mut out = String::with_capacity(num_chars * 6);
358            for c in err_str.code_points() {
359                write!(out, "&#{};", c.to_u32()).unwrap()
360            }
361            Ok((EncodeReplace::Str(ctx.string(out.into())), range.end))
362        }
363    }
364
365    #[derive(Clone, Copy)]
366    pub struct BackslashReplace;
367
368    impl<Ctx: EncodeContext> EncodeErrorHandler<Ctx> for BackslashReplace {
369        fn handle_encode_error(
370            &self,
371            ctx: &mut Ctx,
372            range: Range<StrSize>,
373            _reason: Option<&str>,
374        ) -> Result<(EncodeReplace<Ctx>, StrSize), Ctx::Error> {
375            let err_str = &ctx.full_data()[range.start.bytes..range.end.bytes];
376            let num_chars = range.end.chars - range.start.chars;
377            // minimum 4 output bytes per char: \xNN
378            let mut out = String::with_capacity(num_chars * 4);
379            for c in err_str.code_points() {
380                write!(out, "{}", UnicodeEscapeCodepoint(c)).unwrap();
381            }
382            Ok((EncodeReplace::Str(ctx.string(out.into())), range.end))
383        }
384    }
385
386    impl<Ctx: DecodeContext> DecodeErrorHandler<Ctx> for BackslashReplace {
387        fn handle_decode_error(
388            &self,
389            ctx: &mut Ctx,
390            byte_range: Range<usize>,
391            _reason: Option<&str>,
392        ) -> Result<(Ctx::StrBuf, usize), Ctx::Error> {
393            let err_bytes = &ctx.full_data()[byte_range.clone()];
394            let mut replace = String::with_capacity(4 * err_bytes.len());
395            for &c in err_bytes {
396                write!(replace, "\\x{c:02x}").unwrap();
397            }
398            Ok((ctx.string(replace.into()), byte_range.end))
399        }
400    }
401
402    #[derive(Clone, Copy)]
403    pub struct NameReplace;
404
405    impl<Ctx: EncodeContext> EncodeErrorHandler<Ctx> for NameReplace {
406        fn handle_encode_error(
407            &self,
408            ctx: &mut Ctx,
409            range: Range<StrSize>,
410            _reason: Option<&str>,
411        ) -> Result<(EncodeReplace<Ctx>, StrSize), Ctx::Error> {
412            let err_str = &ctx.full_data()[range.start.bytes..range.end.bytes];
413            let num_chars = range.end.chars - range.start.chars;
414            let mut out = String::with_capacity(num_chars * 4);
415            for c in err_str.code_points() {
416                let c_u32 = c.to_u32();
417                if let Some(c_name) = c.to_char().and_then(unicode_names2::name) {
418                    write!(out, "\\N{{{c_name}}}").unwrap();
419                } else if c_u32 >= 0x10000 {
420                    write!(out, "\\U{c_u32:08x}").unwrap();
421                } else if c_u32 >= 0x100 {
422                    write!(out, "\\u{c_u32:04x}").unwrap();
423                } else {
424                    write!(out, "\\x{c_u32:02x}").unwrap();
425                }
426            }
427            Ok((EncodeReplace::Str(ctx.string(out.into())), range.end))
428        }
429    }
430
431    #[derive(Clone, Copy)]
432    pub struct SurrogateEscape;
433
434    impl<Ctx: EncodeContext> EncodeErrorHandler<Ctx> for SurrogateEscape {
435        fn handle_encode_error(
436            &self,
437            ctx: &mut Ctx,
438            range: Range<StrSize>,
439            reason: Option<&str>,
440        ) -> Result<(EncodeReplace<Ctx>, StrSize), Ctx::Error> {
441            let err_str = &ctx.full_data()[range.start.bytes..range.end.bytes];
442            let num_chars = range.end.chars - range.start.chars;
443            let mut out = Vec::with_capacity(num_chars);
444            let mut pos = range.start;
445            for ch in err_str.code_points() {
446                let ch_u32 = ch.to_u32();
447                if !(0xdc80..=0xdcff).contains(&ch_u32) {
448                    if out.is_empty() {
449                        // Can't handle even the first character
450                        return Err(ctx.error_encoding(range, reason));
451                    }
452                    // Return partial result, restart from this character
453                    return Ok((EncodeReplace::Bytes(ctx.bytes(out)), pos));
454                }
455                out.push((ch_u32 - 0xdc00) as u8);
456                pos += StrSize {
457                    bytes: ch.len_wtf8(),
458                    chars: 1,
459                };
460            }
461            Ok((EncodeReplace::Bytes(ctx.bytes(out)), range.end))
462        }
463    }
464
465    impl<Ctx: DecodeContext> DecodeErrorHandler<Ctx> for SurrogateEscape {
466        fn handle_decode_error(
467            &self,
468            ctx: &mut Ctx,
469            byte_range: Range<usize>,
470            reason: Option<&str>,
471        ) -> Result<(Ctx::StrBuf, usize), Ctx::Error> {
472            let err_bytes = &ctx.full_data()[byte_range.clone()];
473            let mut consumed = 0;
474            let mut replace = Wtf8Buf::with_capacity(4 * byte_range.len());
475            while consumed < 4 && consumed < byte_range.len() {
476                let c = err_bytes[consumed] as u16;
477                // Refuse to escape ASCII bytes
478                if c < 128 {
479                    break;
480                }
481                replace.push(CodePoint::from(0xdc00 + c));
482                consumed += 1;
483            }
484            if consumed == 0 {
485                return Err(ctx.error_decoding(byte_range, reason));
486            }
487            Ok((ctx.string(replace), byte_range.start + consumed))
488        }
489    }
490}
491
492pub mod utf8 {
493    use super::*;
494
495    pub const ENCODING_NAME: &str = "utf-8";
496
497    #[inline]
498    pub fn encode<Ctx, E>(ctx: Ctx, errors: &E) -> Result<Vec<u8>, Ctx::Error>
499    where
500        Ctx: EncodeContext,
501        E: EncodeErrorHandler<Ctx>,
502    {
503        encode_utf8_compatible(ctx, errors, "surrogates not allowed", StrKind::Utf8)
504    }
505
506    pub fn decode<Ctx: DecodeContext, E: DecodeErrorHandler<Ctx>>(
507        ctx: Ctx,
508        errors: &E,
509        final_decode: bool,
510    ) -> Result<(Wtf8Buf, usize), Ctx::Error> {
511        decode_utf8_compatible(
512            ctx,
513            errors,
514            |v| {
515                core::str::from_utf8(v).map_err(|e| {
516                    // SAFETY: as specified in valid_up_to's documentation, input[..e.valid_up_to()]
517                    //         is valid utf8
518                    unsafe { make_decode_err(v, e.valid_up_to(), e.error_len()) }
519                })
520            },
521            |rest, err_len| {
522                let first_err = rest[0];
523                if matches!(first_err, 0x80..=0xc1 | 0xf5..=0xff) {
524                    HandleResult::Error {
525                        err_len: Some(1),
526                        reason: "invalid start byte",
527                    }
528                } else if err_len.is_none() {
529                    // error_len() == None means unexpected eof
530                    if final_decode {
531                        HandleResult::Error {
532                            err_len,
533                            reason: "unexpected end of data",
534                        }
535                    } else {
536                        HandleResult::Done
537                    }
538                } else if !final_decode && matches!(rest, [0xed, 0xa0..=0xbf]) {
539                    // truncated surrogate
540                    HandleResult::Done
541                } else {
542                    HandleResult::Error {
543                        err_len,
544                        reason: "invalid continuation byte",
545                    }
546                }
547            },
548        )
549    }
550}
551
552pub mod latin_1 {
553    use super::*;
554
555    pub const ENCODING_NAME: &str = "latin-1";
556
557    const ERR_REASON: &str = "ordinal not in range(256)";
558
559    #[inline]
560    pub fn encode<Ctx, E>(mut ctx: Ctx, errors: &E) -> Result<Vec<u8>, Ctx::Error>
561    where
562        Ctx: EncodeContext,
563        E: EncodeErrorHandler<Ctx>,
564    {
565        let mut out = Vec::<u8>::new();
566        loop {
567            let data = ctx.remaining_data();
568            let mut iter = iter_code_points(ctx.remaining_data());
569            let Some((i, ch)) = iter.find(|(_, c)| !c.is_ascii()) else {
570                break;
571            };
572            out.extend_from_slice(&data.as_bytes()[..i.bytes]);
573            let err_start = ctx.position() + i;
574            if let Some(byte) = ch.to_u32().to_u8() {
575                drop(iter);
576                out.push(byte);
577                // if the codepoint is between 128..=255, it's utf8-length is 2
578                ctx.restart_from(err_start + StrSize { bytes: 2, chars: 1 })?;
579            } else {
580                // number of non-latin_1 chars between the first non-latin_1 char and the next latin_1 char
581                let err_end = match { iter }.find(|(_, c)| c.to_u32() <= 255) {
582                    Some((i, _)) => ctx.position() + i,
583                    None => ctx.data_len(),
584                };
585                let err_range = err_start..err_end;
586                let replace = ctx.handle_error(errors, err_range.clone(), Some(ERR_REASON))?;
587                match replace {
588                    EncodeReplace::Str(s) => {
589                        if s.as_ref().code_points().any(|c| c.to_u32() > 255) {
590                            return Err(ctx.error_encoding(err_range, Some(ERR_REASON)));
591                        }
592                        out.extend(s.as_ref().code_points().map(|c| c.to_u32() as u8));
593                    }
594                    EncodeReplace::Bytes(b) => {
595                        out.extend_from_slice(b.as_ref());
596                    }
597                }
598            }
599        }
600        out.extend_from_slice(ctx.remaining_data().as_bytes());
601        Ok(out)
602    }
603
604    pub fn decode<Ctx: DecodeContext, E: DecodeErrorHandler<Ctx>>(
605        ctx: Ctx,
606        _errors: &E,
607    ) -> Result<(Wtf8Buf, usize), Ctx::Error> {
608        let out: String = ctx.remaining_data().iter().map(|c| *c as char).collect();
609        let out_len = out.len();
610        Ok((out.into(), out_len))
611    }
612}
613
614pub mod ascii {
615    use super::*;
616    use ::ascii::AsciiStr;
617
618    pub const ENCODING_NAME: &str = "ascii";
619
620    const ERR_REASON: &str = "ordinal not in range(128)";
621
622    #[inline]
623    pub fn encode<Ctx, E>(ctx: Ctx, errors: &E) -> Result<Vec<u8>, Ctx::Error>
624    where
625        Ctx: EncodeContext,
626        E: EncodeErrorHandler<Ctx>,
627    {
628        encode_utf8_compatible(ctx, errors, ERR_REASON, StrKind::Ascii)
629    }
630
631    pub fn decode<Ctx: DecodeContext, E: DecodeErrorHandler<Ctx>>(
632        ctx: Ctx,
633        errors: &E,
634    ) -> Result<(Wtf8Buf, usize), Ctx::Error> {
635        decode_utf8_compatible(
636            ctx,
637            errors,
638            |v| {
639                AsciiStr::from_ascii(v).map(|s| s.as_str()).map_err(|e| {
640                    // SAFETY: as specified in valid_up_to's documentation, input[..e.valid_up_to()]
641                    //         is valid ascii & therefore valid utf8
642                    unsafe { make_decode_err(v, e.valid_up_to(), Some(1)) }
643                })
644            },
645            |_rest, err_len| HandleResult::Error {
646                err_len,
647                reason: ERR_REASON,
648            },
649        )
650    }
651}