Skip to main content

rustpython_vm/
codecs.rs

1use rustpython_common::{
2    borrow::BorrowedValue,
3    encodings::{
4        CodecContext, DecodeContext, DecodeErrorHandler, EncodeContext, EncodeErrorHandler,
5        EncodeReplace, StrBuffer, StrSize, errors,
6    },
7    str::StrKind,
8    wtf8::{CodePoint, Wtf8, Wtf8Buf},
9};
10
11use crate::common::lock::OnceCell;
12use crate::{
13    AsObject, Context, Py, PyObject, PyObjectRef, PyResult, TryFromBorrowedObject, TryFromObject,
14    VirtualMachine,
15    builtins::{
16        PyBaseExceptionRef, PyBytes, PyBytesRef, PyStr, PyStrRef, PyTuple, PyTupleRef, PyUtf8Str,
17        PyUtf8StrRef,
18    },
19    common::{ascii, lock::PyRwLock},
20    convert::ToPyObject,
21    function::{ArgBytesLike, PyMethodDef},
22};
23use alloc::borrow::Cow;
24use core::ops::{self, Range};
25use std::collections::HashMap;
26
/// Registry of codec search functions, resolved codecs and named error handlers.
pub struct CodecsRegistry {
    inner: PyRwLock<RegistryInner>,
}

/// The state guarded by [`CodecsRegistry`]'s read/write lock.
struct RegistryInner {
    // Codec search functions in registration order; `lookup` calls them in this order.
    search_path: Vec<PyObjectRef>,
    // Resolved codecs keyed by normalized encoding name (see `normalize_encoding_name`).
    search_cache: HashMap<String, PyCodec>,
    // Error handlers keyed by the name passed as `errors=`.
    errors: HashMap<String, PyObjectRef>,
}

/// Default encoding name; presumably used by callers when no encoding is given.
pub const DEFAULT_ENCODING: &str = "utf-8";
38
/// A codec as produced by a codec search function.
///
/// It is a 4-tuple (enforced by `from_tuple`) whose item 0 is the encoder and item 1 is the
/// decoder. `repr(transparent)` keeps it layout-identical to the wrapped tuple reference.
#[derive(Clone)]
#[repr(transparent)]
pub struct PyCodec(PyTupleRef);
42impl PyCodec {
43    #[inline]
44    pub fn from_tuple(tuple: PyTupleRef) -> Result<Self, PyTupleRef> {
45        if tuple.len() == 4 {
46            Ok(Self(tuple))
47        } else {
48            Err(tuple)
49        }
50    }
51    #[inline]
52    pub fn into_tuple(self) -> PyTupleRef {
53        self.0
54    }
55    #[inline]
56    pub fn as_tuple(&self) -> &Py<PyTuple> {
57        &self.0
58    }
59
60    #[inline]
61    pub fn get_encode_func(&self) -> &PyObject {
62        &self.0[0]
63    }
64    #[inline]
65    pub fn get_decode_func(&self) -> &PyObject {
66        &self.0[1]
67    }
68
69    pub fn is_text_codec(&self, vm: &VirtualMachine) -> PyResult<bool> {
70        let is_text = vm.get_attribute_opt(self.0.clone().into(), "_is_text_encoding")?;
71        is_text.map_or(Ok(true), |is_text| is_text.try_to_bool(vm))
72    }
73
74    pub fn encode(
75        &self,
76        obj: PyObjectRef,
77        errors: Option<PyUtf8StrRef>,
78        vm: &VirtualMachine,
79    ) -> PyResult {
80        let args = match errors {
81            Some(errors) => vec![obj, errors.into_wtf8().into()],
82            None => vec![obj],
83        };
84        let res = self.get_encode_func().call(args, vm)?;
85        let res = res
86            .downcast::<PyTuple>()
87            .ok()
88            .filter(|tuple| tuple.len() == 2)
89            .ok_or_else(|| vm.new_type_error("encoder must return a tuple (object, integer)"))?;
90        // we don't actually care about the integer
91        Ok(res[0].clone())
92    }
93
94    pub fn decode(
95        &self,
96        obj: PyObjectRef,
97        errors: Option<PyUtf8StrRef>,
98        vm: &VirtualMachine,
99    ) -> PyResult {
100        let args = match errors {
101            Some(errors) => vec![obj, errors.into_wtf8().into()],
102            None => vec![obj],
103        };
104        let res = self.get_decode_func().call(args, vm)?;
105        let res = res
106            .downcast::<PyTuple>()
107            .ok()
108            .filter(|tuple| tuple.len() == 2)
109            .ok_or_else(|| vm.new_type_error("decoder must return a tuple (object,integer)"))?;
110        // we don't actually care about the integer
111        Ok(res[0].clone())
112    }
113
114    pub fn get_incremental_encoder(
115        &self,
116        errors: Option<PyStrRef>,
117        vm: &VirtualMachine,
118    ) -> PyResult {
119        let args = match errors {
120            Some(e) => vec![e.into()],
121            None => vec![],
122        };
123        vm.call_method(self.0.as_object(), "incrementalencoder", args)
124    }
125
126    pub fn get_incremental_decoder(
127        &self,
128        errors: Option<PyStrRef>,
129        vm: &VirtualMachine,
130    ) -> PyResult {
131        let args = match errors {
132            Some(e) => vec![e.into()],
133            None => vec![],
134        };
135        vm.call_method(self.0.as_object(), "incrementaldecoder", args)
136    }
137}
138
139impl TryFromObject for PyCodec {
140    fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult<Self> {
141        obj.downcast::<PyTuple>()
142            .ok()
143            .and_then(|tuple| Self::from_tuple(tuple).ok())
144            .ok_or_else(|| vm.new_type_error("codec search functions must return 4-tuples"))
145    }
146}
147
148impl ToPyObject for PyCodec {
149    #[inline]
150    fn to_pyobject(self, _vm: &VirtualMachine) -> PyObjectRef {
151        self.0.into()
152    }
153}
154
155impl CodecsRegistry {
    /// Reset the inner RwLock to unlocked state after fork().
    ///
    /// Only compiled on Unix with the `threading` feature enabled.
    ///
    /// # Safety
    /// Must only be called after fork() in the child process when no other
    /// threads exist.
    #[cfg(all(unix, feature = "threading"))]
    pub(crate) unsafe fn reinit_after_fork(&self) {
        // SAFETY: the caller guarantees this runs in a freshly forked child with no other
        // threads, so no other thread can be holding or waiting on `self.inner`.
        unsafe { crate::common::lock::reinit_rwlock_after_fork(&self.inner) };
    }
165
166    pub(crate) fn new(ctx: &Context) -> Self {
167        ::rustpython_vm::common::static_cell! {
168            static METHODS: Box<[PyMethodDef]>;
169        }
170
171        let methods = METHODS.get_or_init(|| {
172            crate::define_methods![
173                "strict_errors" => strict_errors as EMPTY,
174                "ignore_errors" => ignore_errors as EMPTY,
175                "replace_errors" => replace_errors as EMPTY,
176                "xmlcharrefreplace_errors" => xmlcharrefreplace_errors as EMPTY,
177                "backslashreplace_errors" => backslashreplace_errors as EMPTY,
178                "namereplace_errors" => namereplace_errors as EMPTY,
179                "surrogatepass_errors" => surrogatepass_errors as EMPTY,
180                "surrogateescape_errors" => surrogateescape_errors as EMPTY
181            ]
182            .into_boxed_slice()
183        });
184
185        let errors = [
186            ("strict", methods[0].build_function(ctx)),
187            ("ignore", methods[1].build_function(ctx)),
188            ("replace", methods[2].build_function(ctx)),
189            ("xmlcharrefreplace", methods[3].build_function(ctx)),
190            ("backslashreplace", methods[4].build_function(ctx)),
191            ("namereplace", methods[5].build_function(ctx)),
192            ("surrogatepass", methods[6].build_function(ctx)),
193            ("surrogateescape", methods[7].build_function(ctx)),
194        ];
195        let errors = errors
196            .into_iter()
197            .map(|(name, f)| (name.to_owned(), f.into()))
198            .collect();
199        let inner = RegistryInner {
200            search_path: Vec::new(),
201            search_cache: HashMap::new(),
202            errors,
203        };
204        Self {
205            inner: PyRwLock::new(inner),
206        }
207    }
208
209    pub fn register(&self, search_function: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
210        if !search_function.is_callable() {
211            return Err(vm.new_type_error("argument must be callable"));
212        }
213        self.inner.write().search_path.push(search_function);
214        Ok(())
215    }
216
217    pub fn unregister(&self, search_function: PyObjectRef) -> PyResult<()> {
218        let mut inner = self.inner.write();
219        // Do nothing if search_path is not created yet or was cleared.
220        if inner.search_path.is_empty() {
221            return Ok(());
222        }
223        for (i, item) in inner.search_path.iter().enumerate() {
224            if item.get_id() == search_function.get_id() {
225                if !inner.search_cache.is_empty() {
226                    inner.search_cache.clear();
227                }
228                inner.search_path.remove(i);
229                return Ok(());
230            }
231        }
232        Ok(())
233    }
234
235    pub(crate) fn register_manual(&self, name: &str, codec: PyCodec) -> PyResult<()> {
236        let name = normalize_encoding_name(name);
237        self.inner
238            .write()
239            .search_cache
240            .insert(name.into_owned(), codec);
241        Ok(())
242    }
243
244    pub fn lookup(&self, encoding: &str, vm: &VirtualMachine) -> PyResult<PyCodec> {
245        let encoding = normalize_encoding_name(encoding);
246        let search_path = {
247            let inner = self.inner.read();
248            if let Some(codec) = inner.search_cache.get(encoding.as_ref()) {
249                // hit cache
250                return Ok(codec.clone());
251            }
252            inner.search_path.clone()
253        };
254        let encoding: PyUtf8StrRef = vm.ctx.new_utf8_str(encoding.as_ref());
255        for func in search_path {
256            let res = func.call((encoding.clone(),), vm)?;
257            let res: Option<PyCodec> = res.try_into_value(vm)?;
258            if let Some(codec) = res {
259                let mut inner = self.inner.write();
260                // someone might have raced us to this, so use theirs
261                let codec = inner
262                    .search_cache
263                    .entry(encoding.as_str().to_owned())
264                    .or_insert(codec);
265                return Ok(codec.clone());
266            }
267        }
268        Err(vm.new_lookup_error(format!("unknown encoding: {encoding}")))
269    }
270
271    fn _lookup_text_encoding(
272        &self,
273        encoding: &str,
274        generic_func: &str,
275        vm: &VirtualMachine,
276    ) -> PyResult<PyCodec> {
277        let codec = self.lookup(encoding, vm)?;
278        if codec.is_text_codec(vm)? {
279            Ok(codec)
280        } else {
281            Err(vm.new_lookup_error(format!(
282                "'{encoding}' is not a text encoding; use {generic_func} to handle arbitrary codecs"
283            )))
284        }
285    }
286
287    pub fn forget(&self, encoding: &str) -> Option<PyCodec> {
288        let encoding = normalize_encoding_name(encoding);
289        self.inner.write().search_cache.remove(encoding.as_ref())
290    }
291
292    pub fn encode(
293        &self,
294        obj: PyObjectRef,
295        encoding: &str,
296        errors: Option<PyUtf8StrRef>,
297        vm: &VirtualMachine,
298    ) -> PyResult {
299        let codec = self.lookup(encoding, vm)?;
300        codec.encode(obj, errors, vm).inspect_err(|exc| {
301            Self::add_codec_note(exc, "encoding", encoding, vm);
302        })
303    }
304
305    pub fn decode(
306        &self,
307        obj: PyObjectRef,
308        encoding: &str,
309        errors: Option<PyUtf8StrRef>,
310        vm: &VirtualMachine,
311    ) -> PyResult {
312        let codec = self.lookup(encoding, vm)?;
313        codec.decode(obj, errors, vm).inspect_err(|exc| {
314            Self::add_codec_note(exc, "decoding", encoding, vm);
315        })
316    }
317
318    pub fn encode_text(
319        &self,
320        obj: PyStrRef,
321        encoding: &str,
322        errors: Option<PyUtf8StrRef>,
323        vm: &VirtualMachine,
324    ) -> PyResult<PyBytesRef> {
325        let codec = self._lookup_text_encoding(encoding, "codecs.encode()", vm)?;
326        codec
327            .encode(obj.into(), errors, vm)
328            .inspect_err(|exc| {
329                Self::add_codec_note(exc, "encoding", encoding, vm);
330            })?
331            .downcast()
332            .map_err(|obj| {
333                vm.new_type_error(format!(
334                    "'{}' encoder returned '{}' instead of 'bytes'; use codecs.encode() to \
335                     encode to arbitrary types",
336                    encoding,
337                    obj.class().name(),
338                ))
339            })
340    }
341
342    pub fn decode_text(
343        &self,
344        obj: PyObjectRef,
345        encoding: &str,
346        errors: Option<PyUtf8StrRef>,
347        vm: &VirtualMachine,
348    ) -> PyResult<PyStrRef> {
349        let codec = self._lookup_text_encoding(encoding, "codecs.decode()", vm)?;
350        codec
351            .decode(obj, errors, vm)
352            .inspect_err(|exc| {
353                Self::add_codec_note(exc, "decoding", encoding, vm);
354            })?
355            .downcast()
356            .map_err(|obj| {
357                vm.new_type_error(format!(
358                    "'{}' decoder returned '{}' instead of 'str'; use codecs.decode() to \
359                 decode to arbitrary types",
360                    encoding,
361                    obj.class().name(),
362                ))
363            })
364    }
365
366    fn add_codec_note(
367        exc: &crate::builtins::PyBaseExceptionRef,
368        operation: &str,
369        encoding: &str,
370        vm: &VirtualMachine,
371    ) {
372        let note = format!("{operation} with '{encoding}' codec failed");
373        let _ = vm.call_method(exc.as_object(), "add_note", (vm.ctx.new_str(note),));
374    }
375
376    pub fn register_error(&self, name: String, handler: PyObjectRef) -> Option<PyObjectRef> {
377        self.inner.write().errors.insert(name, handler)
378    }
379
380    pub fn unregister_error(&self, name: &str, vm: &VirtualMachine) -> PyResult<bool> {
381        const BUILTIN_ERROR_HANDLERS: &[&str] = &[
382            "strict",
383            "ignore",
384            "replace",
385            "xmlcharrefreplace",
386            "backslashreplace",
387            "namereplace",
388            "surrogatepass",
389            "surrogateescape",
390        ];
391        if BUILTIN_ERROR_HANDLERS.contains(&name) {
392            return Err(vm.new_value_error(format!(
393                "cannot un-register built-in error handler '{name}'"
394            )));
395        }
396        Ok(self.inner.write().errors.remove(name).is_some())
397    }
398
399    pub fn lookup_error_opt(&self, name: &str) -> Option<PyObjectRef> {
400        self.inner.read().errors.get(name).cloned()
401    }
402
403    pub fn lookup_error(&self, name: &str, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
404        self.lookup_error_opt(name)
405            .ok_or_else(|| vm.new_lookup_error(format!("unknown error handler name '{name}'")))
406    }
407}
408
/// Normalizes an encoding name the way `_Py_normalize_encoding` does.
///
/// ASCII letters are lowercased, while ASCII digits and `.` are kept. Every run of other
/// characters (including all non-ASCII ones) collapses into a single `_`. Leading and
/// trailing runs are dropped entirely. Names that are already normal are returned borrowed.
fn normalize_encoding_name(encoding: &str) -> Cow<'_, str> {
    let already_normal = encoding
        .bytes()
        .all(|b| b == b'.' || b.is_ascii_lowercase() || b.is_ascii_digit());
    if already_normal {
        return Cow::Borrowed(encoding);
    }
    let mut normalized = String::with_capacity(encoding.len());
    let mut pending_sep = false;
    for ch in encoding.chars() {
        if !(ch.is_ascii_alphanumeric() || ch == '.') {
            pending_sep = true;
            continue;
        }
        // A separator is only emitted between two kept characters, never at the start.
        if pending_sep && !normalized.is_empty() {
            normalized.push('_');
        }
        pending_sep = false;
        normalized.push(ch.to_ascii_lowercase());
    }
    Cow::Owned(normalized)
}
433
/// The UTF encodings whose byte layout the `surrogatepass` handler knows how to produce.
#[derive(Eq, PartialEq)]
enum StandardEncoding {
    Utf8,
    Utf16Be,
    Utf16Le,
    Utf32Be,
    Utf32Le,
}

impl StandardEncoding {
    // Native-endian variants, used when no `be`/`le` suffix is given.
    #[cfg(target_endian = "little")]
    const UTF_16_NE: Self = Self::Utf16Le;
    #[cfg(target_endian = "big")]
    const UTF_16_NE: Self = Self::Utf16Be;

    #[cfg(target_endian = "little")]
    const UTF_32_NE: Self = Self::Utf32Le;
    #[cfg(target_endian = "big")]
    const UTF_32_NE: Self = Self::Utf32Be;

    /// Recognizes names of the form `utf[-_]{8|16|32}[[-_]{be|le}]`, matched
    /// case-insensitively, plus the exact name `cp65001` (UTF-8).
    ///
    /// `be`/`le` suffixes apply only to 16 and 32. Anything else yields `None`.
    fn parse(encoding: &str) -> Option<Self> {
        let lowered = encoding.to_lowercase();
        let Some(after_utf) = lowered.strip_prefix("utf") else {
            // `cp65001` is compared exactly, i.e. case-sensitively.
            return (encoding == "cp65001").then_some(Self::Utf8);
        };
        let rest = after_utf.strip_prefix(['-', '_']).unwrap_or(after_utf);
        if rest == "8" {
            return Some(Self::Utf8);
        }
        let (suffix, native, big, little) = if let Some(suffix) = rest.strip_prefix("16") {
            (suffix, Self::UTF_16_NE, Self::Utf16Be, Self::Utf16Le)
        } else if let Some(suffix) = rest.strip_prefix("32") {
            (suffix, Self::UTF_32_NE, Self::Utf32Be, Self::Utf32Le)
        } else {
            return None;
        };
        if suffix.is_empty() {
            return Some(native);
        }
        match suffix.strip_prefix(['-', '_']).unwrap_or(suffix) {
            "be" => Some(big),
            "le" => Some(little),
            _ => None,
        }
    }
}
491
492struct SurrogatePass;
493
494impl<'a> EncodeErrorHandler<PyEncodeContext<'a>> for SurrogatePass {
495    fn handle_encode_error(
496        &self,
497        ctx: &mut PyEncodeContext<'a>,
498        range: Range<StrSize>,
499        reason: Option<&str>,
500    ) -> PyResult<(EncodeReplace<PyEncodeContext<'a>>, StrSize)> {
501        let standard_encoding = StandardEncoding::parse(ctx.encoding)
502            .ok_or_else(|| ctx.error_encoding(range.clone(), reason))?;
503        let err_str = &ctx.full_data()[range.start.bytes..range.end.bytes];
504        let num_chars = range.end.chars - range.start.chars;
505        let mut out: Vec<u8> = Vec::with_capacity(num_chars * 4);
506        for ch in err_str.code_points() {
507            let c = ch.to_u32();
508            let 0xd800..=0xdfff = c else {
509                // Not a surrogate, fail with original exception
510                return Err(ctx.error_encoding(range, reason));
511            };
512            match standard_encoding {
513                StandardEncoding::Utf8 => out.extend(ch.encode_wtf8(&mut [0; 4]).as_bytes()),
514                StandardEncoding::Utf16Le => out.extend((c as u16).to_le_bytes()),
515                StandardEncoding::Utf16Be => out.extend((c as u16).to_be_bytes()),
516                StandardEncoding::Utf32Le => out.extend(c.to_le_bytes()),
517                StandardEncoding::Utf32Be => out.extend(c.to_be_bytes()),
518            }
519        }
520        Ok((EncodeReplace::Bytes(ctx.bytes(out)), range.end))
521    }
522}
523
524impl<'a> DecodeErrorHandler<PyDecodeContext<'a>> for SurrogatePass {
525    fn handle_decode_error(
526        &self,
527        ctx: &mut PyDecodeContext<'a>,
528        byte_range: Range<usize>,
529        reason: Option<&str>,
530    ) -> PyResult<(PyStrRef, usize)> {
531        let standard_encoding = StandardEncoding::parse(ctx.encoding)
532            .ok_or_else(|| ctx.error_decoding(byte_range.clone(), reason))?;
533
534        let s = ctx.full_data();
535        debug_assert!(byte_range.start <= 0.max(s.len() - 1));
536        debug_assert!(byte_range.end >= 1.min(s.len()));
537        debug_assert!(byte_range.end <= s.len());
538
539        // Try decoding a single surrogate character. If there are more,
540        // let the codec call us again.
541        let p = &s[byte_range.start..];
542
543        fn slice<const N: usize>(p: &[u8]) -> Option<[u8; N]> {
544            p.first_chunk().copied()
545        }
546
547        let c = match standard_encoding {
548            StandardEncoding::Utf8 => {
549                // it's a three-byte code
550                slice::<3>(p)
551                    .filter(|&[a, b, c]| {
552                        (u32::from(a) & 0xf0) == 0xe0
553                            && (u32::from(b) & 0xc0) == 0x80
554                            && (u32::from(c) & 0xc0) == 0x80
555                    })
556                    .map(|[a, b, c]| {
557                        ((u32::from(a) & 0x0f) << 12)
558                            + ((u32::from(b) & 0x3f) << 6)
559                            + (u32::from(c) & 0x3f)
560                    })
561            }
562            StandardEncoding::Utf16Le => slice(p).map(u16::from_le_bytes).map(u32::from),
563            StandardEncoding::Utf16Be => slice(p).map(u16::from_be_bytes).map(u32::from),
564            StandardEncoding::Utf32Le => slice(p).map(u32::from_le_bytes),
565            StandardEncoding::Utf32Be => slice(p).map(u32::from_be_bytes),
566        };
567        let byte_length = match standard_encoding {
568            StandardEncoding::Utf8 => 3,
569            StandardEncoding::Utf16Be | StandardEncoding::Utf16Le => 2,
570            StandardEncoding::Utf32Be | StandardEncoding::Utf32Le => 4,
571        };
572
573        // !Py_UNICODE_IS_SURROGATE
574        let c = c
575            .and_then(CodePoint::from_u32)
576            .filter(|c| matches!(c.to_u32(), 0xd800..=0xdfff))
577            .ok_or_else(|| ctx.error_decoding(byte_range.clone(), reason))?;
578
579        Ok((ctx.string(c.into()), byte_range.start + byte_length))
580    }
581}
582
583pub struct PyEncodeContext<'a> {
584    vm: &'a VirtualMachine,
585    encoding: &'a str,
586    data: &'a Py<PyStr>,
587    pos: StrSize,
588    exception: OnceCell<PyBaseExceptionRef>,
589}
590
591impl<'a> PyEncodeContext<'a> {
592    pub fn new(encoding: &'a str, data: &'a Py<PyStr>, vm: &'a VirtualMachine) -> Self {
593        Self {
594            vm,
595            encoding,
596            data,
597            pos: StrSize::default(),
598            exception: OnceCell::new(),
599        }
600    }
601}
602
603impl CodecContext for PyEncodeContext<'_> {
604    type Error = PyBaseExceptionRef;
605    type StrBuf = PyStrRef;
606    type BytesBuf = PyBytesRef;
607
608    fn string(&self, s: Wtf8Buf) -> Self::StrBuf {
609        self.vm.ctx.new_str(s)
610    }
611
612    fn bytes(&self, b: Vec<u8>) -> Self::BytesBuf {
613        self.vm.ctx.new_bytes(b)
614    }
615}
616impl EncodeContext for PyEncodeContext<'_> {
617    fn full_data(&self) -> &Wtf8 {
618        self.data.as_wtf8()
619    }
620
621    fn data_len(&self) -> StrSize {
622        StrSize {
623            bytes: self.data.byte_len(),
624            chars: self.data.char_len(),
625        }
626    }
627
628    fn remaining_data(&self) -> &Wtf8 {
629        &self.full_data()[self.pos.bytes..]
630    }
631
632    fn position(&self) -> StrSize {
633        self.pos
634    }
635
636    fn restart_from(&mut self, pos: StrSize) -> Result<(), Self::Error> {
637        if pos.chars > self.data.char_len() {
638            return Err(self.vm.new_index_error(format!(
639                "position {} from error handler out of bounds",
640                pos.chars
641            )));
642        }
643        assert!(
644            self.data.as_wtf8().is_code_point_boundary(pos.bytes),
645            "invalid pos {pos:?} for {:?}",
646            self.data.as_wtf8()
647        );
648        self.pos = pos;
649        Ok(())
650    }
651
652    fn error_encoding(&self, range: Range<StrSize>, reason: Option<&str>) -> Self::Error {
653        let vm = self.vm;
654        match self.exception.get() {
655            Some(exc) => {
656                match update_unicode_error_attrs(
657                    exc.as_object(),
658                    range.start.chars,
659                    range.end.chars,
660                    reason,
661                    vm,
662                ) {
663                    Ok(()) => exc.clone(),
664                    Err(e) => e,
665                }
666            }
667            None => self
668                .exception
669                .get_or_init(|| {
670                    let reason = reason.expect(
671                        "should only ever pass reason: None if an exception is already set",
672                    );
673                    vm.new_unicode_encode_error_real(
674                        vm.ctx.new_str(self.encoding),
675                        self.data.to_owned(),
676                        range.start.chars,
677                        range.end.chars,
678                        vm.ctx.new_str(reason),
679                    )
680                })
681                .clone(),
682        }
683    }
684}
685
/// Decoding state shared with error handlers while decoding a bytes-like object.
pub struct PyDecodeContext<'a> {
    vm: &'a VirtualMachine,
    // Codec name; used when building the exception and by `StandardEncoding::parse`.
    encoding: &'a str,
    // The bytes being decoded.
    data: PyDecodeData<'a>,
    // The input as a `bytes` object when it was one; `error_decoding` reuses it as the
    // exception's object instead of copying `data`.
    orig_bytes: Option<&'a Py<PyBytes>>,
    // Current byte offset into `data`.
    pos: usize,
    // The UnicodeDecodeError, created on first use and afterwards updated in place.
    exception: OnceCell<PyBaseExceptionRef>,
}
/// The bytes a [`PyDecodeContext`] reads from.
enum PyDecodeData<'a> {
    // A borrow of the caller's buffer (what `PyDecodeContext::new` stores).
    Original(BorrowedValue<'a, [u8]>),
    // NOTE(review): not constructed in the visible code; presumably used when an error
    // handler replaces the data being decoded — confirm.
    Modified(PyBytesRef),
}
698impl ops::Deref for PyDecodeData<'_> {
699    type Target = [u8];
700    fn deref(&self) -> &Self::Target {
701        match self {
702            PyDecodeData::Original(data) => data,
703            PyDecodeData::Modified(data) => data,
704        }
705    }
706}
707
708impl<'a> PyDecodeContext<'a> {
709    pub fn new(encoding: &'a str, data: &'a ArgBytesLike, vm: &'a VirtualMachine) -> Self {
710        Self {
711            vm,
712            encoding,
713            data: PyDecodeData::Original(data.borrow_buf()),
714            orig_bytes: data.as_object().downcast_ref(),
715            pos: 0,
716            exception: OnceCell::new(),
717        }
718    }
719}
720
721impl CodecContext for PyDecodeContext<'_> {
722    type Error = PyBaseExceptionRef;
723    type StrBuf = PyStrRef;
724    type BytesBuf = PyBytesRef;
725
726    fn string(&self, s: Wtf8Buf) -> Self::StrBuf {
727        self.vm.ctx.new_str(s)
728    }
729
730    fn bytes(&self, b: Vec<u8>) -> Self::BytesBuf {
731        self.vm.ctx.new_bytes(b)
732    }
733}
734impl DecodeContext for PyDecodeContext<'_> {
735    fn full_data(&self) -> &[u8] {
736        &self.data
737    }
738
739    fn remaining_data(&self) -> &[u8] {
740        &self.data[self.pos..]
741    }
742
743    fn position(&self) -> usize {
744        self.pos
745    }
746
747    fn advance(&mut self, by: usize) {
748        self.pos += by;
749    }
750
751    fn restart_from(&mut self, pos: usize) -> Result<(), Self::Error> {
752        if pos > self.data.len() {
753            return Err(self
754                .vm
755                .new_index_error(format!("position {pos} from error handler out of bounds",)));
756        }
757        self.pos = pos;
758        Ok(())
759    }
760
761    fn error_decoding(&self, byte_range: Range<usize>, reason: Option<&str>) -> Self::Error {
762        let vm = self.vm;
763
764        match self.exception.get() {
765            Some(exc) => {
766                match update_unicode_error_attrs(
767                    exc.as_object(),
768                    byte_range.start,
769                    byte_range.end,
770                    reason,
771                    vm,
772                ) {
773                    Ok(()) => exc.clone(),
774                    Err(e) => e,
775                }
776            }
777            None => self
778                .exception
779                .get_or_init(|| {
780                    let reason = reason.expect(
781                        "should only ever pass reason: None if an exception is already set",
782                    );
783                    let data = if let Some(bytes) = self.orig_bytes {
784                        bytes.to_owned()
785                    } else {
786                        vm.ctx.new_bytes(self.data.to_vec())
787                    };
788                    vm.new_unicode_decode_error_real(
789                        vm.ctx.new_str(self.encoding),
790                        data,
791                        byte_range.start,
792                        byte_range.end,
793                        vm.ctx.new_str(reason),
794                    )
795                })
796                .clone(),
797        }
798    }
799}
800
/// Built-in error handlers that are dispatched natively instead of through the registry.
///
/// Parsed from the all-lowercase variant name (e.g. `"xmlcharrefreplace"`) via strum's
/// `EnumString`. `namereplace` has no variant, so `ErrorsHandler::resolve` looks it up in
/// the codec registry instead.
#[derive(strum_macros::EnumString)]
#[strum(serialize_all = "lowercase")]
enum StandardError {
    Strict,
    Ignore,
    Replace,
    XmlCharRefReplace,
    BackslashReplace,
    SurrogatePass,
    SurrogateEscape,
}
812
813impl<'a> EncodeErrorHandler<PyEncodeContext<'a>> for StandardError {
814    fn handle_encode_error(
815        &self,
816        ctx: &mut PyEncodeContext<'a>,
817        range: Range<StrSize>,
818        reason: Option<&str>,
819    ) -> PyResult<(EncodeReplace<PyEncodeContext<'a>>, StrSize)> {
820        use StandardError::*;
821        // use errors::*;
822        match self {
823            Strict => errors::Strict.handle_encode_error(ctx, range, reason),
824            Ignore => errors::Ignore.handle_encode_error(ctx, range, reason),
825            Replace => errors::Replace.handle_encode_error(ctx, range, reason),
826            XmlCharRefReplace => errors::XmlCharRefReplace.handle_encode_error(ctx, range, reason),
827            BackslashReplace => errors::BackslashReplace.handle_encode_error(ctx, range, reason),
828            SurrogatePass => SurrogatePass.handle_encode_error(ctx, range, reason),
829            SurrogateEscape => errors::SurrogateEscape.handle_encode_error(ctx, range, reason),
830        }
831    }
832}
833
834impl<'a> DecodeErrorHandler<PyDecodeContext<'a>> for StandardError {
835    fn handle_decode_error(
836        &self,
837        ctx: &mut PyDecodeContext<'a>,
838        byte_range: Range<usize>,
839        reason: Option<&str>,
840    ) -> PyResult<(PyStrRef, usize)> {
841        use StandardError::*;
842        match self {
843            Strict => errors::Strict.handle_decode_error(ctx, byte_range, reason),
844            Ignore => errors::Ignore.handle_decode_error(ctx, byte_range, reason),
845            Replace => errors::Replace.handle_decode_error(ctx, byte_range, reason),
846            XmlCharRefReplace => Err(ctx
847                .vm
848                .new_type_error("don't know how to handle UnicodeDecodeError in error callback")),
849            BackslashReplace => {
850                errors::BackslashReplace.handle_decode_error(ctx, byte_range, reason)
851            }
852            SurrogatePass => self::SurrogatePass.handle_decode_error(ctx, byte_range, reason),
853            SurrogateEscape => errors::SurrogateEscape.handle_decode_error(ctx, byte_range, reason),
854        }
855    }
856}
857
/// An `errors=` argument, resolved to a concrete handler on first use (or eagerly for the
/// default `strict`).
pub struct ErrorsHandler<'a> {
    // The handler name as supplied by the caller (or `strict` when none was given).
    errors: &'a Py<PyUtf8Str>,
    // Cached result of `resolve`.
    resolved: OnceCell<ResolvedError>,
}
/// What an error-handler name resolved to.
enum ResolvedError {
    // One of the built-in handlers, dispatched natively.
    Standard(StandardError),
    // A handler object looked up in the codec registry.
    Handler(PyObjectRef),
}
866
867impl<'a> ErrorsHandler<'a> {
868    #[inline]
869    pub fn new(errors: Option<&'a Py<PyUtf8Str>>, vm: &VirtualMachine) -> Self {
870        match errors {
871            Some(errors) => Self {
872                errors,
873                resolved: OnceCell::new(),
874            },
875            None => Self {
876                errors: identifier_utf8!(vm, strict),
877                resolved: OnceCell::from(ResolvedError::Standard(StandardError::Strict)),
878            },
879        }
880    }
881    #[inline]
882    fn resolve(&self, vm: &VirtualMachine) -> PyResult<&ResolvedError> {
883        if let Some(val) = self.resolved.get() {
884            return Ok(val);
885        }
886        let errors_str = self.errors.as_str();
887        let val = if let Ok(standard) = errors_str.parse() {
888            ResolvedError::Standard(standard)
889        } else {
890            vm.state
891                .codec_registry
892                .lookup_error(errors_str, vm)
893                .map(ResolvedError::Handler)?
894        };
895        let _ = self.resolved.set(val);
896        Ok(self.resolved.get().unwrap())
897    }
898}
899impl StrBuffer for PyStrRef {
900    fn is_compatible_with(&self, kind: StrKind) -> bool {
901        self.kind() <= kind
902    }
903}
904impl<'a> EncodeErrorHandler<PyEncodeContext<'a>> for ErrorsHandler<'_> {
905    fn handle_encode_error(
906        &self,
907        ctx: &mut PyEncodeContext<'a>,
908        range: Range<StrSize>,
909        reason: Option<&str>,
910    ) -> PyResult<(EncodeReplace<PyEncodeContext<'a>>, StrSize)> {
911        let vm = ctx.vm;
912        let handler = match self.resolve(vm)? {
913            ResolvedError::Standard(standard) => {
914                return standard.handle_encode_error(ctx, range, reason);
915            }
916            ResolvedError::Handler(handler) => handler,
917        };
918        let encode_exc = ctx.error_encoding(range.clone(), reason);
919        let res = handler.call((encode_exc.clone(),), vm)?;
920        let tuple_err =
921            || vm.new_type_error("encoding error handler must return (str/bytes, int) tuple");
922        let (replace, restart) = match res.downcast_ref::<PyTuple>().map(|tup| tup.as_slice()) {
923            Some([replace, restart]) => (replace.clone(), restart),
924            _ => return Err(tuple_err()),
925        };
926        let replace = match_class!(match replace {
927            s @ PyStr => EncodeReplace::Str(s),
928            b @ PyBytes => EncodeReplace::Bytes(b),
929            _ => return Err(tuple_err()),
930        });
931        let restart = isize::try_from_borrowed_object(vm, restart).map_err(|_| tuple_err())?;
932        let restart = if restart < 0 {
933            // will still be out of bounds if it underflows ¯\_(ツ)_/¯
934            ctx.data.char_len().wrapping_sub(restart.unsigned_abs())
935        } else {
936            restart as usize
937        };
938        let restart = if restart == range.end.chars {
939            range.end
940        } else {
941            StrSize {
942                chars: restart,
943                bytes: ctx
944                    .data
945                    .as_wtf8()
946                    .code_point_indices()
947                    .nth(restart)
948                    .map_or(ctx.data.byte_len(), |(i, _)| i),
949            }
950        };
951        Ok((replace, restart))
952    }
953}
954impl<'a> DecodeErrorHandler<PyDecodeContext<'a>> for ErrorsHandler<'_> {
955    fn handle_decode_error(
956        &self,
957        ctx: &mut PyDecodeContext<'a>,
958        byte_range: Range<usize>,
959        reason: Option<&str>,
960    ) -> PyResult<(PyStrRef, usize)> {
961        let vm = ctx.vm;
962        let handler = match self.resolve(vm)? {
963            ResolvedError::Standard(standard) => {
964                return standard.handle_decode_error(ctx, byte_range, reason);
965            }
966            ResolvedError::Handler(handler) => handler,
967        };
968        let decode_exc = ctx.error_decoding(byte_range.clone(), reason);
969        let data_bytes: PyObjectRef = decode_exc.as_object().get_attr("object", vm)?;
970        let res = handler.call((decode_exc.clone(),), vm)?;
971        let new_data = decode_exc.as_object().get_attr("object", vm)?;
972        if !new_data.is(&data_bytes) {
973            let new_data: PyBytesRef = new_data
974                .downcast()
975                .map_err(|_| vm.new_type_error("object attribute must be bytes"))?;
976            ctx.data = PyDecodeData::Modified(new_data);
977        }
978        let data = &*ctx.data;
979        let tuple_err = || vm.new_type_error("decoding error handler must return (str, int) tuple");
980        match res.downcast_ref::<PyTuple>().map(|tup| tup.as_slice()) {
981            Some([replace, restart]) => {
982                let replace = replace
983                    .downcast_ref::<PyStr>()
984                    .ok_or_else(tuple_err)?
985                    .to_owned();
986                let restart =
987                    isize::try_from_borrowed_object(vm, restart).map_err(|_| tuple_err())?;
988                let restart = if restart < 0 {
989                    // will still be out of bounds if it underflows ¯\_(ツ)_/¯
990                    data.len().wrapping_sub(restart.unsigned_abs())
991                } else {
992                    restart as usize
993                };
994                Ok((replace, restart))
995            }
996            _ => Err(tuple_err()),
997        }
998    }
999}
1000
1001fn call_native_encode_error<E>(
1002    handler: E,
1003    err: PyObjectRef,
1004    vm: &VirtualMachine,
1005) -> PyResult<(PyObjectRef, usize)>
1006where
1007    for<'a> E: EncodeErrorHandler<PyEncodeContext<'a>>,
1008{
1009    // let err = err.
1010    let range = extract_unicode_error_range(&err, vm)?;
1011    let s = PyStrRef::try_from_object(vm, err.get_attr("object", vm)?)?;
1012    let s_encoding = PyUtf8StrRef::try_from_object(vm, err.get_attr("encoding", vm)?)?;
1013    let mut ctx = PyEncodeContext {
1014        vm,
1015        encoding: s_encoding.as_str(),
1016        data: &s,
1017        pos: StrSize::default(),
1018        exception: OnceCell::from(err.downcast().unwrap()),
1019    };
1020    let mut iter = s.as_wtf8().code_point_indices();
1021    let start = StrSize {
1022        chars: range.start,
1023        bytes: iter.nth(range.start).unwrap().0,
1024    };
1025    let end = StrSize {
1026        chars: range.end,
1027        bytes: if let Some(n) = range.len().checked_sub(1) {
1028            iter.nth(n).map_or(s.byte_len(), |(i, _)| i)
1029        } else {
1030            start.bytes
1031        },
1032    };
1033    let (replace, restart) = handler.handle_encode_error(&mut ctx, start..end, None)?;
1034    let replace = match replace {
1035        EncodeReplace::Str(s) => s.into(),
1036        EncodeReplace::Bytes(b) => b.into(),
1037    };
1038    Ok((replace, restart.chars))
1039}
1040
1041fn call_native_decode_error<E>(
1042    handler: E,
1043    err: PyObjectRef,
1044    vm: &VirtualMachine,
1045) -> PyResult<(PyObjectRef, usize)>
1046where
1047    for<'a> E: DecodeErrorHandler<PyDecodeContext<'a>>,
1048{
1049    let range = extract_unicode_error_range(&err, vm)?;
1050    let s = ArgBytesLike::try_from_object(vm, err.get_attr("object", vm)?)?;
1051    let s_encoding = PyUtf8StrRef::try_from_object(vm, err.get_attr("encoding", vm)?)?;
1052    let mut ctx = PyDecodeContext {
1053        vm,
1054        encoding: s_encoding.as_str(),
1055        data: PyDecodeData::Original(s.borrow_buf()),
1056        orig_bytes: s.as_object().downcast_ref(),
1057        pos: 0,
1058        exception: OnceCell::from(err.downcast().unwrap()),
1059    };
1060    let (replace, restart) = handler.handle_decode_error(&mut ctx, range, None)?;
1061    Ok((replace.into(), restart))
1062}
1063
1064// this is a hack, for now
1065fn call_native_translate_error<E>(
1066    handler: E,
1067    err: PyObjectRef,
1068    vm: &VirtualMachine,
1069) -> PyResult<(PyObjectRef, usize)>
1070where
1071    for<'a> E: EncodeErrorHandler<PyEncodeContext<'a>>,
1072{
1073    // let err = err.
1074    let range = extract_unicode_error_range(&err, vm)?;
1075    let s = PyStrRef::try_from_object(vm, err.get_attr("object", vm)?)?;
1076    let mut ctx = PyEncodeContext {
1077        vm,
1078        encoding: "",
1079        data: &s,
1080        pos: StrSize::default(),
1081        exception: OnceCell::from(err.downcast().unwrap()),
1082    };
1083    let mut iter = s.as_wtf8().code_point_indices();
1084    let start = StrSize {
1085        chars: range.start,
1086        bytes: iter.nth(range.start).unwrap().0,
1087    };
1088    let end = StrSize {
1089        chars: range.end,
1090        bytes: if let Some(n) = range.len().checked_sub(1) {
1091            iter.nth(n).map_or(s.byte_len(), |(i, _)| i)
1092        } else {
1093            start.bytes
1094        },
1095    };
1096    let (replace, restart) = handler.handle_encode_error(&mut ctx, start..end, None)?;
1097    let replace = match replace {
1098        EncodeReplace::Str(s) => s.into(),
1099        EncodeReplace::Bytes(b) => b.into(),
1100    };
1101    Ok((replace, restart.chars))
1102}
1103
1104// TODO: exceptions with custom payloads
1105fn extract_unicode_error_range(err: &PyObject, vm: &VirtualMachine) -> PyResult<Range<usize>> {
1106    let start = err.get_attr("start", vm)?;
1107    let start = start.try_into_value(vm)?;
1108    let end = err.get_attr("end", vm)?;
1109    let end = end.try_into_value(vm)?;
1110    Ok(Range { start, end })
1111}
1112
1113fn update_unicode_error_attrs(
1114    err: &PyObject,
1115    start: usize,
1116    end: usize,
1117    reason: Option<&str>,
1118    vm: &VirtualMachine,
1119) -> PyResult<()> {
1120    err.set_attr("start", start.to_pyobject(vm), vm)?;
1121    err.set_attr("end", end.to_pyobject(vm), vm)?;
1122    if let Some(reason) = reason {
1123        err.set_attr("reason", reason.to_pyobject(vm), vm)?;
1124    }
1125    Ok(())
1126}
1127
1128#[inline]
1129fn is_encode_err(err: &PyObject, vm: &VirtualMachine) -> bool {
1130    err.fast_isinstance(vm.ctx.exceptions.unicode_encode_error)
1131}
1132#[inline]
1133fn is_decode_err(err: &PyObject, vm: &VirtualMachine) -> bool {
1134    err.fast_isinstance(vm.ctx.exceptions.unicode_decode_error)
1135}
1136#[inline]
1137fn is_translate_err(err: &PyObject, vm: &VirtualMachine) -> bool {
1138    err.fast_isinstance(vm.ctx.exceptions.unicode_translate_error)
1139}
1140
1141fn bad_err_type(err: PyObjectRef, vm: &VirtualMachine) -> PyBaseExceptionRef {
1142    vm.new_type_error(format!(
1143        "don't know how to handle {} in error callback",
1144        err.class().name()
1145    ))
1146}
1147
1148fn strict_errors(err: PyObjectRef, vm: &VirtualMachine) -> PyResult {
1149    let err = err
1150        .downcast()
1151        .unwrap_or_else(|_| vm.new_type_error("codec must pass exception instance"));
1152    Err(err)
1153}
1154
1155fn ignore_errors(err: PyObjectRef, vm: &VirtualMachine) -> PyResult<(PyObjectRef, usize)> {
1156    if is_encode_err(&err, vm) || is_decode_err(&err, vm) || is_translate_err(&err, vm) {
1157        let range = extract_unicode_error_range(&err, vm)?;
1158        Ok((vm.ctx.new_str(ascii!("")).into(), range.end))
1159    } else {
1160        Err(bad_err_type(err, vm))
1161    }
1162}
1163
1164fn replace_errors(err: PyObjectRef, vm: &VirtualMachine) -> PyResult<(PyObjectRef, usize)> {
1165    if is_encode_err(&err, vm) {
1166        call_native_encode_error(errors::Replace, err, vm)
1167    } else if is_decode_err(&err, vm) {
1168        call_native_decode_error(errors::Replace, err, vm)
1169    } else if is_translate_err(&err, vm) {
1170        // char::REPLACEMENT_CHARACTER as a str
1171        let replacement_char = "\u{FFFD}";
1172        let range = extract_unicode_error_range(&err, vm)?;
1173        let replace = replacement_char.repeat(range.end - range.start);
1174        Ok((replace.to_pyobject(vm), range.end))
1175    } else {
1176        Err(bad_err_type(err, vm))
1177    }
1178}
1179
1180fn xmlcharrefreplace_errors(
1181    err: PyObjectRef,
1182    vm: &VirtualMachine,
1183) -> PyResult<(PyObjectRef, usize)> {
1184    if is_encode_err(&err, vm) {
1185        call_native_encode_error(errors::XmlCharRefReplace, err, vm)
1186    } else {
1187        Err(bad_err_type(err, vm))
1188    }
1189}
1190
1191fn backslashreplace_errors(
1192    err: PyObjectRef,
1193    vm: &VirtualMachine,
1194) -> PyResult<(PyObjectRef, usize)> {
1195    if is_decode_err(&err, vm) {
1196        call_native_decode_error(errors::BackslashReplace, err, vm)
1197    } else if is_encode_err(&err, vm) {
1198        call_native_encode_error(errors::BackslashReplace, err, vm)
1199    } else if is_translate_err(&err, vm) {
1200        call_native_translate_error(errors::BackslashReplace, err, vm)
1201    } else {
1202        Err(bad_err_type(err, vm))
1203    }
1204}
1205
1206fn namereplace_errors(err: PyObjectRef, vm: &VirtualMachine) -> PyResult<(PyObjectRef, usize)> {
1207    if is_encode_err(&err, vm) {
1208        call_native_encode_error(errors::NameReplace, err, vm)
1209    } else {
1210        Err(bad_err_type(err, vm))
1211    }
1212}
1213
1214fn surrogatepass_errors(err: PyObjectRef, vm: &VirtualMachine) -> PyResult<(PyObjectRef, usize)> {
1215    if is_encode_err(&err, vm) {
1216        call_native_encode_error(SurrogatePass, err, vm)
1217    } else if is_decode_err(&err, vm) {
1218        call_native_decode_error(SurrogatePass, err, vm)
1219    } else {
1220        Err(bad_err_type(err, vm))
1221    }
1222}
1223
1224fn surrogateescape_errors(err: PyObjectRef, vm: &VirtualMachine) -> PyResult<(PyObjectRef, usize)> {
1225    if is_encode_err(&err, vm) {
1226        call_native_encode_error(errors::SurrogateEscape, err, vm)
1227    } else if is_decode_err(&err, vm) {
1228        call_native_decode_error(errors::SurrogateEscape, err, vm)
1229    } else {
1230        Err(bad_err_type(err, vm))
1231    }
1232}