// numcodecs_python/adapter.rs

1use std::sync::Arc;
2
3use ndarray::{ArrayBase, DataMut, Dimension};
4use numcodecs::{
5    AnyArray, AnyArrayBase, AnyArrayView, AnyArrayViewMut, AnyCowArray, Codec, DynCodec,
6    DynCodecType,
7};
8use numpy::{Element, PyArray, PyArrayDyn, PyArrayMethods, PyUntypedArrayMethods};
9use pyo3::{
10    exceptions::{PyTypeError, PyValueError},
11    intern,
12    prelude::*,
13    types::{IntoPyDict, PyDict, PyDictMethods},
14};
15use pythonize::{Depythonizer, Pythonizer};
16use schemars::Schema;
17use serde::{Deserializer, Serializer};
18use serde_transcode::transcode;
19
20use crate::{
21    export::{RustCodec, RustCodecType},
22    schema::schema_from_codec_class,
23    utils::numpy_asarray,
24    PyCodec, PyCodecClass, PyCodecClassMethods, PyCodecMethods, PyCodecRegistry,
25};
26
27/// Wrapper around [`PyCodec`]s to use the [`Codec`] API.
28pub struct PyCodecAdapter {
29    codec: Py<PyCodec>,
30    class: Py<PyCodecClass>,
31    codec_id: Arc<String>,
32    codec_config_schema: Arc<Schema>,
33}
34
35impl PyCodecAdapter {
36    /// Instantiate a codec from the [`PyCodecRegistry`] with a serialized
37    /// `config`uration.
38    ///
39    /// The config *must* include the `id` field with the
40    /// [`PyCodecClassMethods::codec_id`].
41    ///
42    /// # Errors
43    ///
44    /// Errors if no codec with a matching `id` has been registered, or if
45    /// constructing the codec fails.
46    pub fn from_registry_with_config<'de, D: Deserializer<'de>>(
47        config: D,
48    ) -> Result<Self, D::Error> {
49        Python::with_gil(|py| {
50            let config = transcode(config, Pythonizer::new(py))?;
51            let config: Bound<PyDict> = config.extract()?;
52
53            let codec = PyCodecRegistry::get_codec(config.as_borrowed())?;
54
55            Self::from_codec(codec)
56        })
57        .map_err(serde::de::Error::custom)
58    }
59
60    /// Wraps a [`PyCodec`] to use the [`Codec`] API.
61    ///
62    /// # Errors
63    ///
64    /// Errors if the `codec`'s class does not provide an identifier.
65    pub fn from_codec(codec: Bound<PyCodec>) -> Result<Self, PyErr> {
66        let class = codec.class();
67        let codec_id = class.codec_id()?;
68        let codec_config_schema = schema_from_codec_class(class.py(), &class).map_err(|err| {
69            PyTypeError::new_err(format!(
70                "failed to extract the {codec_id} codec config schema: {err}"
71            ))
72        })?;
73
74        Ok(Self {
75            codec: codec.unbind(),
76            class: class.unbind(),
77            codec_id: Arc::new(codec_id),
78            codec_config_schema: Arc::new(codec_config_schema),
79        })
80    }
81
82    /// Access the wrapped [`PyCodec`] to use its [`PyCodecMethods`] API.
83    #[must_use]
84    pub fn as_codec<'py>(&self, py: Python<'py>) -> &Bound<'py, PyCodec> {
85        self.codec.bind(py)
86    }
87
88    /// Unwrap the [`PyCodec`] to use its [`PyCodecMethods`] API.
89    #[must_use]
90    pub fn into_codec(self, py: Python) -> Bound<PyCodec> {
91        self.codec.into_bound(py)
92    }
93
94    /// Try to [`clone`][`Clone::clone`] this codec.
95    ///
96    /// # Errors
97    ///
98    /// Errors if extracting this codec's config or creating a new codec from
99    /// the config fails.
100    pub fn try_clone(&self, py: Python) -> Result<Self, PyErr> {
101        let config = self.codec.bind(py).get_config()?;
102
103        // removing the `id` field may fail if the config doesn't contain it
104        let _ = config.del_item(intern!(py, "id"));
105
106        let codec = self
107            .class
108            .bind(py)
109            .codec_from_config(config.as_borrowed())?;
110
111        Ok(Self {
112            codec: codec.unbind(),
113            class: self.class.clone_ref(py),
114            codec_id: self.codec_id.clone(),
115            codec_config_schema: self.codec_config_schema.clone(),
116        })
117    }
118
119    /// If `codec` represents an exported [`DynCodec`] `T`, i.e. its class was
120    /// initially created with [`crate::export_codec_class`], the `with` closure
121    /// provides access to the instance of type `T`.
122    ///
123    /// If `codec` is not an instance of `T`, the `with` closure is *not* run
124    /// and `None` is returned.
125    pub fn with_downcast<T: DynCodec, O>(
126        codec: &Bound<PyCodec>,
127        with: impl for<'a> FnOnce(&'a T) -> O,
128    ) -> Option<O> {
129        let Ok(codec) = codec.downcast::<RustCodec>() else {
130            return None;
131        };
132
133        let codec = codec.get().downcast()?;
134
135        Some(with(codec))
136    }
137}
138
139impl Codec for PyCodecAdapter {
140    type Error = PyErr;
141
142    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
143        Python::with_gil(|py| {
144            self.with_any_array_view_as_ndarray(py, &data.view(), |data| {
145                let encoded = self.codec.bind(py).encode(data.as_borrowed())?;
146
147                Self::any_array_from_ndarray_like(py, encoded.as_borrowed())
148            })
149        })
150    }
151
152    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
153        Python::with_gil(|py| {
154            self.with_any_array_view_as_ndarray(py, &encoded.view(), |encoded| {
155                let decoded = self.codec.bind(py).decode(encoded.as_borrowed(), None)?;
156
157                Self::any_array_from_ndarray_like(py, decoded.as_borrowed())
158            })
159        })
160    }
161
162    fn decode_into(
163        &self,
164        encoded: AnyArrayView,
165        mut decoded: AnyArrayViewMut,
166    ) -> Result<(), Self::Error> {
167        Python::with_gil(|py| {
168            let decoded_out = self.with_any_array_view_as_ndarray(py, &encoded, |encoded| {
169                self.with_any_array_view_mut_as_ndarray(py, &mut decoded, |decoded_in| {
170                    let decoded_out = self
171                        .codec
172                        .bind(py)
173                        .decode(encoded.as_borrowed(), Some(decoded_in.as_borrowed()))?;
174
175                    // Ideally, all codecs should just use the provided out array
176                    if decoded_out.is(decoded_in) {
177                        Ok(Ok(()))
178                    } else {
179                        Ok(Err(decoded_out.unbind()))
180                    }
181                })
182            })?;
183            let decoded_out = match decoded_out {
184                Ok(()) => return Ok(()),
185                Err(decoded_out) => decoded_out.into_bound(py),
186            };
187
188            // Otherwise, we force-copy the output into the decoded array
189            Self::copy_into_any_array_view_mut_from_ndarray_like(
190                py,
191                &mut decoded,
192                decoded_out.as_borrowed(),
193            )
194        })
195    }
196}
197
198impl PyCodecAdapter {
199    fn with_any_array_view_as_ndarray<T>(
200        &self,
201        py: Python,
202        view: &AnyArrayView,
203        with: impl for<'a> FnOnce(&'a Bound<PyAny>) -> Result<T, PyErr>,
204    ) -> Result<T, PyErr> {
205        let this = self.codec.bind(py).clone().into_any();
206
207        #[allow(unsafe_code)] // FIXME: we trust Python code to not store this array
208        let ndarray = unsafe {
209            match &view {
210                AnyArrayBase::U8(v) => PyArray::borrow_from_array(v, this).into_any(),
211                AnyArrayBase::U16(v) => PyArray::borrow_from_array(v, this).into_any(),
212                AnyArrayBase::U32(v) => PyArray::borrow_from_array(v, this).into_any(),
213                AnyArrayBase::U64(v) => PyArray::borrow_from_array(v, this).into_any(),
214                AnyArrayBase::I8(v) => PyArray::borrow_from_array(v, this).into_any(),
215                AnyArrayBase::I16(v) => PyArray::borrow_from_array(v, this).into_any(),
216                AnyArrayBase::I32(v) => PyArray::borrow_from_array(v, this).into_any(),
217                AnyArrayBase::I64(v) => PyArray::borrow_from_array(v, this).into_any(),
218                AnyArrayBase::F32(v) => PyArray::borrow_from_array(v, this).into_any(),
219                AnyArrayBase::F64(v) => PyArray::borrow_from_array(v, this).into_any(),
220                _ => {
221                    return Err(PyTypeError::new_err(format!(
222                        "unsupported type {} of read-only array view",
223                        view.dtype()
224                    )))
225                }
226            }
227        };
228
229        // create a fully-immutable view of the data that is safe to pass to Python
230        ndarray.call_method(
231            intern!(py, "setflags"),
232            (),
233            Some(&[(intern!(py, "write"), false)].into_py_dict(py)?),
234        )?;
235        let view = ndarray.call_method0(intern!(py, "view"))?;
236
237        with(&view)
238    }
239
240    fn with_any_array_view_mut_as_ndarray<T>(
241        &self,
242        py: Python,
243        view_mut: &mut AnyArrayViewMut,
244        with: impl for<'a> FnOnce(&'a Bound<PyAny>) -> Result<T, PyErr>,
245    ) -> Result<T, PyErr> {
246        let this = self.codec.bind(py).clone().into_any();
247
248        #[allow(unsafe_code)] // FIXME: we trust Python code to not store this array
249        let ndarray = unsafe {
250            match &view_mut {
251                AnyArrayBase::U8(v) => PyArray::borrow_from_array(v, this).into_any(),
252                AnyArrayBase::U16(v) => PyArray::borrow_from_array(v, this).into_any(),
253                AnyArrayBase::U32(v) => PyArray::borrow_from_array(v, this).into_any(),
254                AnyArrayBase::U64(v) => PyArray::borrow_from_array(v, this).into_any(),
255                AnyArrayBase::I8(v) => PyArray::borrow_from_array(v, this).into_any(),
256                AnyArrayBase::I16(v) => PyArray::borrow_from_array(v, this).into_any(),
257                AnyArrayBase::I32(v) => PyArray::borrow_from_array(v, this).into_any(),
258                AnyArrayBase::I64(v) => PyArray::borrow_from_array(v, this).into_any(),
259                AnyArrayBase::F32(v) => PyArray::borrow_from_array(v, this).into_any(),
260                AnyArrayBase::F64(v) => PyArray::borrow_from_array(v, this).into_any(),
261                _ => {
262                    return Err(PyTypeError::new_err(format!(
263                        "unsupported type {} of read-only array view",
264                        view_mut.dtype()
265                    )))
266                }
267            }
268        };
269
270        with(&ndarray)
271    }
272
273    fn any_array_from_ndarray_like(
274        py: Python,
275        array_like: Borrowed<PyAny>,
276    ) -> Result<AnyArray, PyErr> {
277        let ndarray = numpy_asarray(py, array_like)?;
278
279        let array = if let Ok(e) = ndarray.downcast::<PyArrayDyn<u8>>() {
280            AnyArrayBase::U8(e.try_readonly()?.to_owned_array())
281        } else if let Ok(e) = ndarray.downcast::<PyArrayDyn<u16>>() {
282            AnyArrayBase::U16(e.try_readonly()?.to_owned_array())
283        } else if let Ok(e) = ndarray.downcast::<PyArrayDyn<u32>>() {
284            AnyArrayBase::U32(e.try_readonly()?.to_owned_array())
285        } else if let Ok(e) = ndarray.downcast::<PyArrayDyn<u64>>() {
286            AnyArrayBase::U64(e.try_readonly()?.to_owned_array())
287        } else if let Ok(e) = ndarray.downcast::<PyArrayDyn<i8>>() {
288            AnyArrayBase::I8(e.try_readonly()?.to_owned_array())
289        } else if let Ok(e) = ndarray.downcast::<PyArrayDyn<i16>>() {
290            AnyArrayBase::I16(e.try_readonly()?.to_owned_array())
291        } else if let Ok(e) = ndarray.downcast::<PyArrayDyn<i32>>() {
292            AnyArrayBase::I32(e.try_readonly()?.to_owned_array())
293        } else if let Ok(e) = ndarray.downcast::<PyArrayDyn<i64>>() {
294            AnyArrayBase::I64(e.try_readonly()?.to_owned_array())
295        } else if let Ok(e) = ndarray.downcast::<PyArrayDyn<f32>>() {
296            AnyArrayBase::F32(e.try_readonly()?.to_owned_array())
297        } else if let Ok(e) = ndarray.downcast::<PyArrayDyn<f64>>() {
298            AnyArrayBase::F64(e.try_readonly()?.to_owned_array())
299        } else {
300            return Err(PyTypeError::new_err(format!(
301                "unsupported dtype {} of array-like",
302                ndarray.dtype()
303            )));
304        };
305
306        Ok(array)
307    }
308
309    fn copy_into_any_array_view_mut_from_ndarray_like(
310        py: Python,
311        view_mut: &mut AnyArrayViewMut,
312        array_like: Borrowed<PyAny>,
313    ) -> Result<(), PyErr> {
314        fn shape_checked_assign<
315            T: Copy + Element,
316            S2: DataMut<Elem = T>,
317            D1: Dimension,
318            D2: Dimension,
319        >(
320            src: &Bound<PyArray<T, D1>>,
321            dst: &mut ArrayBase<S2, D2>,
322        ) -> Result<(), PyErr> {
323            #[allow(clippy::unit_arg)]
324            if src.shape() == dst.shape() {
325                Ok(dst.assign(&src.try_readonly()?.as_array()))
326            } else {
327                Err(PyValueError::new_err(format!(
328                    "mismatching shape {:?} of array-like, expected {:?}",
329                    src.shape(),
330                    dst.shape(),
331                )))
332            }
333        }
334
335        let ndarray = numpy_asarray(py, array_like)?;
336
337        #[allow(clippy::unit_arg)]
338        if let Ok(d) = ndarray.downcast::<PyArrayDyn<u8>>() {
339            if let AnyArrayBase::U8(ref mut view_mut) = view_mut {
340                return shape_checked_assign(d, view_mut);
341            }
342        } else if let Ok(d) = ndarray.downcast::<PyArrayDyn<u16>>() {
343            if let AnyArrayBase::U16(ref mut view_mut) = view_mut {
344                return shape_checked_assign(d, view_mut);
345            }
346        } else if let Ok(d) = ndarray.downcast::<PyArrayDyn<u32>>() {
347            if let AnyArrayBase::U32(ref mut view_mut) = view_mut {
348                return shape_checked_assign(d, view_mut);
349            }
350        } else if let Ok(d) = ndarray.downcast::<PyArrayDyn<u64>>() {
351            if let AnyArrayBase::U64(ref mut view_mut) = view_mut {
352                return shape_checked_assign(d, view_mut);
353            }
354        } else if let Ok(d) = ndarray.downcast::<PyArrayDyn<i8>>() {
355            if let AnyArrayBase::I8(ref mut view_mut) = view_mut {
356                return shape_checked_assign(d, view_mut);
357            }
358        } else if let Ok(d) = ndarray.downcast::<PyArrayDyn<i16>>() {
359            if let AnyArrayBase::I16(ref mut view_mut) = view_mut {
360                return shape_checked_assign(d, view_mut);
361            }
362        } else if let Ok(d) = ndarray.downcast::<PyArrayDyn<i32>>() {
363            if let AnyArrayBase::I32(ref mut view_mut) = view_mut {
364                return shape_checked_assign(d, view_mut);
365            }
366        } else if let Ok(d) = ndarray.downcast::<PyArrayDyn<i64>>() {
367            if let AnyArrayBase::I64(ref mut view_mut) = view_mut {
368                return shape_checked_assign(d, view_mut);
369            }
370        } else if let Ok(d) = ndarray.downcast::<PyArrayDyn<f32>>() {
371            if let AnyArrayBase::F32(ref mut view_mut) = view_mut {
372                return shape_checked_assign(d, view_mut);
373            }
374        } else if let Ok(d) = ndarray.downcast::<PyArrayDyn<f64>>() {
375            if let AnyArrayBase::F64(ref mut view_mut) = view_mut {
376                return shape_checked_assign(d, view_mut);
377            }
378        } else {
379            return Err(PyTypeError::new_err(format!(
380                "unsupported dtype {} of array-like",
381                ndarray.dtype()
382            )));
383        };
384
385        Err(PyTypeError::new_err(format!(
386            "mismatching dtype {} of array-like, expected {}",
387            ndarray.dtype(),
388            view_mut.dtype(),
389        )))
390    }
391}
392
393impl Clone for PyCodecAdapter {
394    fn clone(&self) -> Self {
395        #[allow(clippy::expect_used)] // clone is *not* fallible
396        Python::with_gil(|py| {
397            self.try_clone(py)
398                .expect("cloning a PyCodec should not fail")
399        })
400    }
401}
402
403impl DynCodec for PyCodecAdapter {
404    type Type = PyCodecClassAdapter;
405
406    fn ty(&self) -> Self::Type {
407        Python::with_gil(|py| PyCodecClassAdapter {
408            class: self.class.clone_ref(py),
409            codec_id: self.codec_id.clone(),
410            codec_config_schema: self.codec_config_schema.clone(),
411        })
412    }
413
414    fn get_config<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
415        Python::with_gil(|py| {
416            let config = self
417                .codec
418                .bind(py)
419                .get_config()
420                .map_err(serde::ser::Error::custom)?;
421
422            transcode(&mut Depythonizer::from_object(config.as_any()), serializer)
423        })
424    }
425}
426
427/// Wrapper around [`PyCodecClass`]es to use the [`DynCodecType`] API.
428pub struct PyCodecClassAdapter {
429    class: Py<PyCodecClass>,
430    codec_id: Arc<String>,
431    codec_config_schema: Arc<Schema>,
432}
433
434impl PyCodecClassAdapter {
435    /// Wraps a [`PyCodecClass`] to use the [`DynCodecType`] API.
436    ///
437    /// # Errors
438    ///
439    /// Errors if the codec `class` does not provide an identifier.
440    pub fn from_codec_class(class: Bound<PyCodecClass>) -> Result<Self, PyErr> {
441        let codec_id = class.codec_id()?;
442
443        let codec_config_schema = schema_from_codec_class(class.py(), &class).map_err(|err| {
444            PyTypeError::new_err(format!(
445                "failed to extract the {codec_id} codec config schema: {err}"
446            ))
447        })?;
448
449        Ok(Self {
450            class: class.unbind(),
451            codec_id: Arc::new(codec_id),
452            codec_config_schema: Arc::new(codec_config_schema),
453        })
454    }
455
456    /// Access the wrapped [`PyCodecClass`] to use its [`PyCodecClassMethods`]
457    /// API.
458    #[must_use]
459    pub fn as_codec_class<'py>(&self, py: Python<'py>) -> &Bound<'py, PyCodecClass> {
460        self.class.bind(py)
461    }
462
463    /// Unwrap the [`PyCodecClass`] to use its [`PyCodecClassMethods`] API.
464    #[must_use]
465    pub fn into_codec_class(self, py: Python) -> Bound<PyCodecClass> {
466        self.class.into_bound(py)
467    }
468
469    /// If `class` represents an exported [`DynCodecType`] `T`, i.e. it was
470    /// initially created with [`crate::export_codec_class`], the `with` closure
471    /// provides access to the instance of type `T`.
472    ///
473    /// If `class` is not an instance of `T`, the `with` closure is *not* run
474    /// and `None` is returned.
475    pub fn with_downcast<T: DynCodecType, O>(
476        class: &Bound<PyCodecClass>,
477        with: impl for<'a> FnOnce(&'a T) -> O,
478    ) -> Option<O> {
479        let Ok(ty) = class.getattr(intern!(class.py(), RustCodec::TYPE_ATTRIBUTE)) else {
480            return None;
481        };
482
483        let Ok(ty) = ty.downcast_into_exact::<RustCodecType>() else {
484            return None;
485        };
486
487        let ty: &T = ty.get().downcast()?;
488
489        Some(with(ty))
490    }
491}
492
493impl DynCodecType for PyCodecClassAdapter {
494    type Codec = PyCodecAdapter;
495
496    fn codec_id(&self) -> &str {
497        &self.codec_id
498    }
499
500    fn codec_config_schema(&self) -> Schema {
501        (*self.codec_config_schema).clone()
502    }
503
504    fn codec_from_config<'de, D: Deserializer<'de>>(
505        &self,
506        config: D,
507    ) -> Result<Self::Codec, D::Error> {
508        Python::with_gil(|py| {
509            let config =
510                transcode(config, Pythonizer::new(py)).map_err(serde::de::Error::custom)?;
511            let config: Bound<PyDict> = config.extract().map_err(serde::de::Error::custom)?;
512
513            let codec = self
514                .class
515                .bind(py)
516                .codec_from_config(config.as_borrowed())
517                .map_err(serde::de::Error::custom)?;
518
519            Ok(PyCodecAdapter {
520                codec: codec.unbind(),
521                class: self.class.clone_ref(py),
522                codec_id: self.codec_id.clone(),
523                codec_config_schema: self.codec_config_schema.clone(),
524            })
525        })
526    }
527}