numcodecs_python/
export.rs

1use std::{any::Any, ffi::CString};
2
3use ndarray::{ArrayViewD, ArrayViewMutD, CowArray};
4use numcodecs::{
5    AnyArray, AnyArrayView, AnyArrayViewMut, AnyCowArray, Codec, DynCodec, DynCodecType,
6};
7use numpy::{
8    IxDyn, PyArray, PyArrayDescrMethods, PyArrayDyn, PyArrayMethods, PyUntypedArrayMethods,
9};
10use pyo3::{
11    exceptions::PyTypeError,
12    intern,
13    prelude::*,
14    types::{IntoPyDict, PyDict, PyString, PyType},
15    PyTypeInfo,
16};
17use pyo3_error::PyErrChain;
18use pythonize::{pythonize, Depythonizer, Pythonizer};
19
20use crate::{
21    schema::{docs_from_schema, signature_from_schema},
22    utils::numpy_asarray,
23    PyCodec, PyCodecClass, PyCodecClassAdapter, PyCodecRegistry,
24};
25
26/// Export the [`DynCodecType`] `ty` to Python by generating a fresh
27/// [`PyCodecClass`] inside `module` and registering it with the
28/// [`PyCodecRegistry`].
29///
30/// # Errors
31///
32/// Errors if generating or exporting the fresh [`PyCodecClass`] fails.
33pub fn export_codec_class<'py, T: DynCodecType>(
34    py: Python<'py>,
35    ty: T,
36    module: Borrowed<'_, 'py, PyModule>,
37) -> Result<Bound<'py, PyCodecClass>, PyErr> {
38    let codec_id = String::from(ty.codec_id());
39    let codec_class_name = convert_case::Casing::to_case(&codec_id, convert_case::Case::Pascal);
40
41    let codec_class: Bound<PyCodecClass> =
42        // re-exporting a Python codec class should roundtrip
43        if let Some(adapter) = (&ty as &dyn Any).downcast_ref::<PyCodecClassAdapter>() {
44            adapter.as_codec_class(py).clone()
45        } else {
46            let codec_config_schema = ty.codec_config_schema();
47
48            let codec_class_bases = (
49                RustCodec::type_object(py),
50                PyCodec::type_object(py),
51            );
52
53            let codec_class_namespace = [
54                (intern!(py, "__module__"), module.name()?.into_any()),
55                (
56                    intern!(py, "__doc__"),
57                    docs_from_schema(&codec_config_schema, &codec_id).into_pyobject(py)?,
58                ),
59                (
60                    intern!(py, RustCodec::TYPE_ATTRIBUTE),
61                    Bound::new(py, RustCodecType { ty: Box::new(ty) })?.into_any(),
62                ),
63                (
64                    intern!(py, "codec_id"),
65                    PyString::new(py, &codec_id).into_any(),
66                ),
67                (
68                    intern!(py, RustCodec::SCHEMA_ATTRIBUTE),
69                    pythonize(py, &codec_config_schema)?,
70                ),
71                (
72                    intern!(py, "__init__"),
73                    py.eval(&CString::new(format!(
74                        "lambda {}: None",
75                        signature_from_schema(&codec_config_schema),
76                    ))?, None, None)?,
77                ),
78            ]
79            .into_py_dict(py)?;
80
81            PyType::type_object(py)
82                .call1((&codec_class_name, codec_class_bases, codec_class_namespace))?
83                .extract()?
84        };
85
86    module.add(codec_class_name.as_str(), &codec_class)?;
87
88    PyCodecRegistry::register_codec(codec_class.as_borrowed(), None)?;
89
90    Ok(codec_class)
91}
92
93#[allow(clippy::redundant_pub_crate)]
94#[pyclass(frozen)]
95pub(crate) struct RustCodecType {
96    ty: Box<dyn 'static + Send + Sync + AnyCodecType>,
97}
98
99impl RustCodecType {
100    pub fn downcast<T: DynCodecType>(&self) -> Option<&T> {
101        self.ty.as_any().downcast_ref()
102    }
103}
104
105trait AnyCodec {
106    fn encode(&self, py: Python, data: AnyCowArray) -> Result<AnyArray, PyErr>;
107
108    fn decode(&self, py: Python, encoded: AnyCowArray) -> Result<AnyArray, PyErr>;
109
110    fn decode_into(
111        &self,
112        py: Python,
113        encoded: AnyArrayView,
114        decoded: AnyArrayViewMut,
115    ) -> Result<(), PyErr>;
116
117    fn get_config<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyDict>, PyErr>;
118
119    fn as_any(&self) -> &dyn Any;
120}
121
122impl<T: DynCodec> AnyCodec for T {
123    fn encode(&self, py: Python, data: AnyCowArray) -> Result<AnyArray, PyErr> {
124        <T as Codec>::encode(self, data).map_err(|err| PyErrChain::pyerr_from_err(py, err))
125    }
126
127    fn decode(&self, py: Python, encoded: AnyCowArray) -> Result<AnyArray, PyErr> {
128        <T as Codec>::decode(self, encoded).map_err(|err| PyErrChain::pyerr_from_err(py, err))
129    }
130
131    fn decode_into(
132        &self,
133        py: Python,
134        encoded: AnyArrayView,
135        decoded: AnyArrayViewMut,
136    ) -> Result<(), PyErr> {
137        <T as Codec>::decode_into(self, encoded, decoded)
138            .map_err(|err| PyErrChain::pyerr_from_err(py, err))
139    }
140
141    fn get_config<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyDict>, PyErr> {
142        <T as DynCodec>::get_config(self, Pythonizer::new(py))?.extract()
143    }
144
145    fn as_any(&self) -> &dyn Any {
146        self
147    }
148}
149
150trait AnyCodecType {
151    fn codec_from_config<'py>(
152        &self,
153        config: Bound<'py, PyDict>,
154    ) -> Result<Box<dyn 'static + Send + Sync + AnyCodec>, PyErr>;
155
156    fn as_any(&self) -> &dyn Any;
157}
158
159impl<T: DynCodecType> AnyCodecType for T {
160    fn codec_from_config<'py>(
161        &self,
162        config: Bound<'py, PyDict>,
163    ) -> Result<Box<dyn 'static + Send + Sync + AnyCodec>, PyErr> {
164        match <T as DynCodecType>::codec_from_config(
165            self,
166            &mut Depythonizer::from_object(config.as_any()),
167        ) {
168            Ok(codec) => Ok(Box::new(codec)),
169            Err(err) => Err(err.into()),
170        }
171    }
172
173    fn as_any(&self) -> &dyn Any {
174        self
175    }
176}
177
178#[allow(clippy::redundant_pub_crate)]
179#[pyclass(subclass, frozen)]
180pub(crate) struct RustCodec {
181    cls_module: String,
182    cls_name: String,
183    codec: Box<dyn 'static + Send + Sync + AnyCodec>,
184}
185
186impl RustCodec {
187    pub const SCHEMA_ATTRIBUTE: &'static str = "__schema__";
188    pub const TYPE_ATTRIBUTE: &'static str = "_ty";
189
190    pub fn downcast<T: DynCodec>(&self) -> Option<&T> {
191        self.codec.as_any().downcast_ref()
192    }
193}
194
195#[pymethods]
196impl RustCodec {
197    #[new]
198    #[classmethod]
199    #[pyo3(signature = (**kwargs))]
200    fn new<'py>(
201        cls: &Bound<'py, PyType>,
202        py: Python<'py>,
203        kwargs: Option<Bound<'py, PyDict>>,
204    ) -> Result<Self, PyErr> {
205        let cls: &Bound<PyCodecClass> = cls.downcast()?;
206        let cls_module: String = cls.getattr(intern!(py, "__module__"))?.extract()?;
207        let cls_name: String = cls.getattr(intern!(py, "__name__"))?.extract()?;
208
209        let ty: Bound<RustCodecType> = cls
210            .getattr(intern!(py, RustCodec::TYPE_ATTRIBUTE))
211            .map_err(|_| {
212                PyTypeError::new_err(format!(
213                    "{cls_module}.{cls_name} is not linked to a Rust codec type"
214                ))
215            })?
216            .extract()?;
217        let ty: PyRef<RustCodecType> = ty.try_borrow()?;
218
219        let codec = ty
220            .ty
221            .codec_from_config(kwargs.unwrap_or_else(|| PyDict::new(py)))?;
222
223        Ok(Self {
224            cls_module,
225            cls_name,
226            codec,
227        })
228    }
229
230    fn encode<'py>(
231        &self,
232        py: Python<'py>,
233        buf: &Bound<'py, PyAny>,
234    ) -> Result<Bound<'py, PyAny>, PyErr> {
235        self.process(
236            py,
237            buf.as_borrowed(),
238            AnyCodec::encode,
239            &format!("{}.{}::encode", self.cls_module, self.cls_name),
240        )
241    }
242
243    #[pyo3(signature = (buf, out=None))]
244    fn decode<'py>(
245        &self,
246        py: Python<'py>,
247        buf: &Bound<'py, PyAny>,
248        out: Option<Bound<'py, PyAny>>,
249    ) -> Result<Bound<'py, PyAny>, PyErr> {
250        let class_method = &format!("{}.{}::decode", self.cls_module, self.cls_name);
251        if let Some(out) = out {
252            self.process_into(
253                py,
254                buf.as_borrowed(),
255                out.as_borrowed(),
256                AnyCodec::decode_into,
257                class_method,
258            )?;
259            Ok(out)
260        } else {
261            self.process(py, buf.as_borrowed(), AnyCodec::decode, class_method)
262        }
263    }
264
265    fn get_config<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyDict>, PyErr> {
266        self.codec.get_config(py)
267    }
268
269    #[classmethod]
270    fn from_config<'py>(
271        cls: &Bound<'py, PyType>,
272        config: &Bound<'py, PyDict>,
273    ) -> Result<Bound<'py, PyCodec>, PyErr> {
274        let cls: Bound<PyCodecClass> = cls.extract()?;
275
276        // Ensures that cls(**config) is called and an instance of cls is returned
277        cls.call((), Some(config))?.extract()
278    }
279
280    fn __repr__(this: PyRef<Self>, py: Python) -> Result<String, PyErr> {
281        let config = this.get_config(py)?;
282        // FIXME: let Ok(..) is sufficient with MSRV 1.82
283        let py_this = this.into_pyobject(py)?;
284
285        let mut repr = py_this.get_type().name()?.to_cow()?.into_owned();
286        repr.push('(');
287
288        let mut first = true;
289
290        for (name, value) in config.iter() {
291            let name: String = name.extract()?;
292
293            if name == "id" {
294                // Exclude the id config parameter from the repr
295                continue;
296            }
297
298            let value_repr: String = value.repr()?.extract()?;
299
300            if !first {
301                repr.push_str(", ");
302            }
303            first = false;
304
305            repr.push_str(&name);
306            repr.push('=');
307            repr.push_str(&value_repr);
308        }
309
310        repr.push(')');
311
312        Ok(repr)
313    }
314}
315
316impl RustCodec {
317    fn process<'py>(
318        &self,
319        py: Python<'py>,
320        buf: Borrowed<'_, 'py, PyAny>,
321        process: impl FnOnce(
322            &(dyn 'static + Send + Sync + AnyCodec),
323            Python,
324            AnyCowArray,
325        ) -> Result<AnyArray, PyErr>,
326        class_method: &str,
327    ) -> Result<Bound<'py, PyAny>, PyErr> {
328        Self::with_pyarraylike_as_cow(py, buf, class_method, |data| {
329            let processed = process(&*self.codec, py, data)?;
330            Self::any_array_into_pyarray(py, processed, class_method)
331        })
332    }
333
334    fn process_into<'py>(
335        &self,
336        py: Python<'py>,
337        buf: Borrowed<'_, 'py, PyAny>,
338        out: Borrowed<'_, 'py, PyAny>,
339        process: impl FnOnce(
340            &(dyn 'static + Send + Sync + AnyCodec),
341            Python,
342            AnyArrayView,
343            AnyArrayViewMut,
344        ) -> Result<(), PyErr>,
345        class_method: &str,
346    ) -> Result<(), PyErr> {
347        Self::with_pyarraylike_as_view(py, buf, class_method, |data| {
348            Self::with_pyarraylike_as_view_mut(py, out, class_method, |data_out| {
349                process(&*self.codec, py, data, data_out)
350            })
351        })
352    }
353
354    fn with_pyarraylike_as_cow<'py, O>(
355        py: Python<'py>,
356        buf: Borrowed<'_, 'py, PyAny>,
357        class_method: &str,
358        with: impl for<'a> FnOnce(AnyCowArray<'a>) -> Result<O, PyErr>,
359    ) -> Result<O, PyErr> {
360        fn with_pyarraylike_as_cow_inner<T: numpy::Element, O>(
361            data: Borrowed<PyArrayDyn<T>>,
362            with: impl for<'a> FnOnce(CowArray<'a, T, IxDyn>) -> Result<O, PyErr>,
363        ) -> Result<O, PyErr> {
364            let readonly_data = data.try_readonly()?;
365            with(readonly_data.as_array().into())
366        }
367
368        let data = numpy_asarray(py, buf)?;
369        let dtype = data.dtype();
370
371        if dtype.is_equiv_to(&numpy::dtype::<u8>(py)) {
372            with_pyarraylike_as_cow_inner(data.downcast::<PyArrayDyn<u8>>()?.into(), |a| {
373                with(AnyCowArray::U8(a))
374            })
375        } else if dtype.is_equiv_to(&numpy::dtype::<u16>(py)) {
376            with_pyarraylike_as_cow_inner(data.downcast::<PyArrayDyn<u16>>()?.into(), |a| {
377                with(AnyCowArray::U16(a))
378            })
379        } else if dtype.is_equiv_to(&numpy::dtype::<u32>(py)) {
380            with_pyarraylike_as_cow_inner(data.downcast::<PyArrayDyn<u32>>()?.into(), |a| {
381                with(AnyCowArray::U32(a))
382            })
383        } else if dtype.is_equiv_to(&numpy::dtype::<u64>(py)) {
384            with_pyarraylike_as_cow_inner(data.downcast::<PyArrayDyn<u64>>()?.into(), |a| {
385                with(AnyCowArray::U64(a))
386            })
387        } else if dtype.is_equiv_to(&numpy::dtype::<i8>(py)) {
388            with_pyarraylike_as_cow_inner(data.downcast::<PyArrayDyn<i8>>()?.into(), |a| {
389                with(AnyCowArray::I8(a))
390            })
391        } else if dtype.is_equiv_to(&numpy::dtype::<i16>(py)) {
392            with_pyarraylike_as_cow_inner(data.downcast::<PyArrayDyn<i16>>()?.into(), |a| {
393                with(AnyCowArray::I16(a))
394            })
395        } else if dtype.is_equiv_to(&numpy::dtype::<i32>(py)) {
396            with_pyarraylike_as_cow_inner(data.downcast::<PyArrayDyn<i32>>()?.into(), |a| {
397                with(AnyCowArray::I32(a))
398            })
399        } else if dtype.is_equiv_to(&numpy::dtype::<i64>(py)) {
400            with_pyarraylike_as_cow_inner(data.downcast::<PyArrayDyn<i64>>()?.into(), |a| {
401                with(AnyCowArray::I64(a))
402            })
403        } else if dtype.is_equiv_to(&numpy::dtype::<f32>(py)) {
404            with_pyarraylike_as_cow_inner(data.downcast::<PyArrayDyn<f32>>()?.into(), |a| {
405                with(AnyCowArray::F32(a))
406            })
407        } else if dtype.is_equiv_to(&numpy::dtype::<f64>(py)) {
408            with_pyarraylike_as_cow_inner(data.downcast::<PyArrayDyn<f64>>()?.into(), |a| {
409                with(AnyCowArray::F64(a))
410            })
411        } else {
412            Err(PyTypeError::new_err(format!(
413                "{class_method} received buffer of unsupported dtype `{dtype}`",
414            )))
415        }
416    }
417
418    fn with_pyarraylike_as_view<'py, O>(
419        py: Python<'py>,
420        buf: Borrowed<'_, 'py, PyAny>,
421        class_method: &str,
422        with: impl for<'a> FnOnce(AnyArrayView<'a>) -> Result<O, PyErr>,
423    ) -> Result<O, PyErr> {
424        fn with_pyarraylike_as_view_inner<T: numpy::Element, O>(
425            data: Borrowed<PyArrayDyn<T>>,
426            with: impl for<'a> FnOnce(ArrayViewD<'a, T>) -> Result<O, PyErr>,
427        ) -> Result<O, PyErr> {
428            let readonly_data = data.try_readonly()?;
429            with(readonly_data.as_array())
430        }
431
432        let data = numpy_asarray(py, buf)?;
433        let dtype = data.dtype();
434
435        if dtype.is_equiv_to(&numpy::dtype::<u8>(py)) {
436            with_pyarraylike_as_view_inner(data.downcast::<PyArrayDyn<u8>>()?.into(), |a| {
437                with(AnyArrayView::U8(a))
438            })
439        } else if dtype.is_equiv_to(&numpy::dtype::<u16>(py)) {
440            with_pyarraylike_as_view_inner(data.downcast::<PyArrayDyn<u16>>()?.into(), |a| {
441                with(AnyArrayView::U16(a))
442            })
443        } else if dtype.is_equiv_to(&numpy::dtype::<u32>(py)) {
444            with_pyarraylike_as_view_inner(data.downcast::<PyArrayDyn<u32>>()?.into(), |a| {
445                with(AnyArrayView::U32(a))
446            })
447        } else if dtype.is_equiv_to(&numpy::dtype::<u64>(py)) {
448            with_pyarraylike_as_view_inner(data.downcast::<PyArrayDyn<u64>>()?.into(), |a| {
449                with(AnyArrayView::U64(a))
450            })
451        } else if dtype.is_equiv_to(&numpy::dtype::<i8>(py)) {
452            with_pyarraylike_as_view_inner(data.downcast::<PyArrayDyn<i8>>()?.into(), |a| {
453                with(AnyArrayView::I8(a))
454            })
455        } else if dtype.is_equiv_to(&numpy::dtype::<i16>(py)) {
456            with_pyarraylike_as_view_inner(data.downcast::<PyArrayDyn<i16>>()?.into(), |a| {
457                with(AnyArrayView::I16(a))
458            })
459        } else if dtype.is_equiv_to(&numpy::dtype::<i32>(py)) {
460            with_pyarraylike_as_view_inner(data.downcast::<PyArrayDyn<i32>>()?.into(), |a| {
461                with(AnyArrayView::I32(a))
462            })
463        } else if dtype.is_equiv_to(&numpy::dtype::<i64>(py)) {
464            with_pyarraylike_as_view_inner(data.downcast::<PyArrayDyn<i64>>()?.into(), |a| {
465                with(AnyArrayView::I64(a))
466            })
467        } else if dtype.is_equiv_to(&numpy::dtype::<f32>(py)) {
468            with_pyarraylike_as_view_inner(data.downcast::<PyArrayDyn<f32>>()?.into(), |a| {
469                with(AnyArrayView::F32(a))
470            })
471        } else if dtype.is_equiv_to(&numpy::dtype::<f64>(py)) {
472            with_pyarraylike_as_view_inner(data.downcast::<PyArrayDyn<f64>>()?.into(), |a| {
473                with(AnyArrayView::F64(a))
474            })
475        } else {
476            Err(PyTypeError::new_err(format!(
477                "{class_method} received buffer of unsupported dtype `{dtype}`",
478            )))
479        }
480    }
481
482    fn with_pyarraylike_as_view_mut<'py, O>(
483        py: Python<'py>,
484        buf: Borrowed<'_, 'py, PyAny>,
485        class_method: &str,
486        with: impl for<'a> FnOnce(AnyArrayViewMut<'a>) -> Result<O, PyErr>,
487    ) -> Result<O, PyErr> {
488        fn with_pyarraylike_as_view_mut_inner<T: numpy::Element, O>(
489            data: Borrowed<PyArrayDyn<T>>,
490            with: impl for<'a> FnOnce(ArrayViewMutD<'a, T>) -> Result<O, PyErr>,
491        ) -> Result<O, PyErr> {
492            let mut readwrite_data = data.try_readwrite()?;
493            with(readwrite_data.as_array_mut())
494        }
495
496        let data = numpy_asarray(py, buf)?;
497        let dtype = data.dtype();
498
499        if dtype.is_equiv_to(&numpy::dtype::<u8>(py)) {
500            with_pyarraylike_as_view_mut_inner(data.downcast::<PyArrayDyn<u8>>()?.into(), |a| {
501                with(AnyArrayViewMut::U8(a))
502            })
503        } else if dtype.is_equiv_to(&numpy::dtype::<u16>(py)) {
504            with_pyarraylike_as_view_mut_inner(data.downcast::<PyArrayDyn<u16>>()?.into(), |a| {
505                with(AnyArrayViewMut::U16(a))
506            })
507        } else if dtype.is_equiv_to(&numpy::dtype::<u32>(py)) {
508            with_pyarraylike_as_view_mut_inner(data.downcast::<PyArrayDyn<u32>>()?.into(), |a| {
509                with(AnyArrayViewMut::U32(a))
510            })
511        } else if dtype.is_equiv_to(&numpy::dtype::<u64>(py)) {
512            with_pyarraylike_as_view_mut_inner(data.downcast::<PyArrayDyn<u64>>()?.into(), |a| {
513                with(AnyArrayViewMut::U64(a))
514            })
515        } else if dtype.is_equiv_to(&numpy::dtype::<i8>(py)) {
516            with_pyarraylike_as_view_mut_inner(data.downcast::<PyArrayDyn<i8>>()?.into(), |a| {
517                with(AnyArrayViewMut::I8(a))
518            })
519        } else if dtype.is_equiv_to(&numpy::dtype::<i16>(py)) {
520            with_pyarraylike_as_view_mut_inner(data.downcast::<PyArrayDyn<i16>>()?.into(), |a| {
521                with(AnyArrayViewMut::I16(a))
522            })
523        } else if dtype.is_equiv_to(&numpy::dtype::<i32>(py)) {
524            with_pyarraylike_as_view_mut_inner(data.downcast::<PyArrayDyn<i32>>()?.into(), |a| {
525                with(AnyArrayViewMut::I32(a))
526            })
527        } else if dtype.is_equiv_to(&numpy::dtype::<i64>(py)) {
528            with_pyarraylike_as_view_mut_inner(data.downcast::<PyArrayDyn<i64>>()?.into(), |a| {
529                with(AnyArrayViewMut::I64(a))
530            })
531        } else if dtype.is_equiv_to(&numpy::dtype::<f32>(py)) {
532            with_pyarraylike_as_view_mut_inner(data.downcast::<PyArrayDyn<f32>>()?.into(), |a| {
533                with(AnyArrayViewMut::F32(a))
534            })
535        } else if dtype.is_equiv_to(&numpy::dtype::<f64>(py)) {
536            with_pyarraylike_as_view_mut_inner(data.downcast::<PyArrayDyn<f64>>()?.into(), |a| {
537                with(AnyArrayViewMut::F64(a))
538            })
539        } else {
540            Err(PyTypeError::new_err(format!(
541                "{class_method} received buffer of unsupported dtype `{dtype}`",
542            )))
543        }
544    }
545
546    fn any_array_into_pyarray<'py>(
547        py: Python<'py>,
548        array: AnyArray,
549        class_method: &str,
550    ) -> Result<Bound<'py, PyAny>, PyErr> {
551        match array {
552            AnyArray::U8(a) => Ok(PyArray::from_owned_array(py, a).into_any()),
553            AnyArray::U16(a) => Ok(PyArray::from_owned_array(py, a).into_any()),
554            AnyArray::U32(a) => Ok(PyArray::from_owned_array(py, a).into_any()),
555            AnyArray::U64(a) => Ok(PyArray::from_owned_array(py, a).into_any()),
556            AnyArray::I8(a) => Ok(PyArray::from_owned_array(py, a).into_any()),
557            AnyArray::I16(a) => Ok(PyArray::from_owned_array(py, a).into_any()),
558            AnyArray::I32(a) => Ok(PyArray::from_owned_array(py, a).into_any()),
559            AnyArray::I64(a) => Ok(PyArray::from_owned_array(py, a).into_any()),
560            AnyArray::F32(a) => Ok(PyArray::from_owned_array(py, a).into_any()),
561            AnyArray::F64(a) => Ok(PyArray::from_owned_array(py, a).into_any()),
562            array => Err(PyTypeError::new_err(format!(
563                "{class_method} returned unsupported dtype `{}`",
564                array.dtype(),
565            ))),
566        }
567    }
568}