1use std::{any::Any, ffi::CString};
2
3use ndarray::{ArrayViewD, ArrayViewMutD, CowArray};
4use numcodecs::{
5 AnyArray, AnyArrayView, AnyArrayViewMut, AnyCowArray, Codec, DynCodec, DynCodecType,
6};
7use numpy::{
8 IxDyn, PyArray, PyArrayDescrMethods, PyArrayDyn, PyArrayMethods, PyUntypedArrayMethods,
9};
10use pyo3::{
11 exceptions::PyTypeError,
12 intern,
13 prelude::*,
14 types::{IntoPyDict, PyDict, PyString, PyType},
15 PyTypeInfo,
16};
17use pyo3_error::PyErrChain;
18use pythonize::{pythonize, Depythonizer, Pythonizer};
19
20use crate::{
21 schema::{docs_from_schema, signature_from_schema},
22 utils::numpy_asarray,
23 PyCodec, PyCodecClass, PyCodecClassAdapter, PyCodecRegistry,
24};
25
26pub fn export_codec_class<'py, T: DynCodecType>(
34 py: Python<'py>,
35 ty: T,
36 module: Borrowed<'_, 'py, PyModule>,
37) -> Result<Bound<'py, PyCodecClass>, PyErr> {
38 let codec_id = String::from(ty.codec_id());
39 let codec_class_name = convert_case::Casing::to_case(&codec_id, convert_case::Case::Pascal);
40
41 let codec_class: Bound<PyCodecClass> =
42 if let Some(adapter) = (&ty as &dyn Any).downcast_ref::<PyCodecClassAdapter>() {
44 adapter.as_codec_class(py).clone()
45 } else {
46 let codec_config_schema = ty.codec_config_schema();
47
48 let codec_class_bases = (
49 RustCodec::type_object(py),
50 PyCodec::type_object(py),
51 );
52
53 let codec_class_namespace = [
54 (intern!(py, "__module__"), module.name()?.into_any()),
55 (
56 intern!(py, "__doc__"),
57 docs_from_schema(&codec_config_schema, &codec_id).into_pyobject(py)?,
58 ),
59 (
60 intern!(py, RustCodec::TYPE_ATTRIBUTE),
61 Bound::new(py, RustCodecType { ty: Box::new(ty) })?.into_any(),
62 ),
63 (
64 intern!(py, "codec_id"),
65 PyString::new(py, &codec_id).into_any(),
66 ),
67 (
68 intern!(py, RustCodec::SCHEMA_ATTRIBUTE),
69 pythonize(py, &codec_config_schema)?,
70 ),
71 (
72 intern!(py, "__init__"),
73 py.eval(&CString::new(format!(
74 "lambda {}: None",
75 signature_from_schema(&codec_config_schema),
76 ))?, None, None)?,
77 ),
78 ]
79 .into_py_dict(py)?;
80
81 PyType::type_object(py)
82 .call1((&codec_class_name, codec_class_bases, codec_class_namespace))?
83 .extract()?
84 };
85
86 module.add(codec_class_name.as_str(), &codec_class)?;
87
88 PyCodecRegistry::register_codec(codec_class.as_borrowed(), None)?;
89
90 Ok(codec_class)
91}
92
93#[allow(clippy::redundant_pub_crate)]
94#[pyclass(frozen)]
95pub(crate) struct RustCodecType {
96 ty: Box<dyn 'static + Send + Sync + AnyCodecType>,
97}
98
99impl RustCodecType {
100 pub fn downcast<T: DynCodecType>(&self) -> Option<&T> {
101 self.ty.as_any().downcast_ref()
102 }
103}
104
105trait AnyCodec {
106 fn encode(&self, py: Python, data: AnyCowArray) -> Result<AnyArray, PyErr>;
107
108 fn decode(&self, py: Python, encoded: AnyCowArray) -> Result<AnyArray, PyErr>;
109
110 fn decode_into(
111 &self,
112 py: Python,
113 encoded: AnyArrayView,
114 decoded: AnyArrayViewMut,
115 ) -> Result<(), PyErr>;
116
117 fn get_config<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyDict>, PyErr>;
118
119 fn as_any(&self) -> &dyn Any;
120}
121
122impl<T: DynCodec> AnyCodec for T {
123 fn encode(&self, py: Python, data: AnyCowArray) -> Result<AnyArray, PyErr> {
124 <T as Codec>::encode(self, data).map_err(|err| PyErrChain::pyerr_from_err(py, err))
125 }
126
127 fn decode(&self, py: Python, encoded: AnyCowArray) -> Result<AnyArray, PyErr> {
128 <T as Codec>::decode(self, encoded).map_err(|err| PyErrChain::pyerr_from_err(py, err))
129 }
130
131 fn decode_into(
132 &self,
133 py: Python,
134 encoded: AnyArrayView,
135 decoded: AnyArrayViewMut,
136 ) -> Result<(), PyErr> {
137 <T as Codec>::decode_into(self, encoded, decoded)
138 .map_err(|err| PyErrChain::pyerr_from_err(py, err))
139 }
140
141 fn get_config<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyDict>, PyErr> {
142 <T as DynCodec>::get_config(self, Pythonizer::new(py))?.extract()
143 }
144
145 fn as_any(&self) -> &dyn Any {
146 self
147 }
148}
149
150trait AnyCodecType {
151 fn codec_from_config<'py>(
152 &self,
153 config: Bound<'py, PyDict>,
154 ) -> Result<Box<dyn 'static + Send + Sync + AnyCodec>, PyErr>;
155
156 fn as_any(&self) -> &dyn Any;
157}
158
159impl<T: DynCodecType> AnyCodecType for T {
160 fn codec_from_config<'py>(
161 &self,
162 config: Bound<'py, PyDict>,
163 ) -> Result<Box<dyn 'static + Send + Sync + AnyCodec>, PyErr> {
164 match <T as DynCodecType>::codec_from_config(
165 self,
166 &mut Depythonizer::from_object(config.as_any()),
167 ) {
168 Ok(codec) => Ok(Box::new(codec)),
169 Err(err) => Err(err.into()),
170 }
171 }
172
173 fn as_any(&self) -> &dyn Any {
174 self
175 }
176}
177
178#[allow(clippy::redundant_pub_crate)]
179#[pyclass(subclass, frozen)]
180pub(crate) struct RustCodec {
181 cls_module: String,
182 cls_name: String,
183 codec: Box<dyn 'static + Send + Sync + AnyCodec>,
184}
185
186impl RustCodec {
187 pub const SCHEMA_ATTRIBUTE: &'static str = "__schema__";
188 pub const TYPE_ATTRIBUTE: &'static str = "_ty";
189
190 pub fn downcast<T: DynCodec>(&self) -> Option<&T> {
191 self.codec.as_any().downcast_ref()
192 }
193}
194
195#[pymethods]
196impl RustCodec {
197 #[new]
198 #[classmethod]
199 #[pyo3(signature = (**kwargs))]
200 fn new<'py>(
201 cls: &Bound<'py, PyType>,
202 py: Python<'py>,
203 kwargs: Option<Bound<'py, PyDict>>,
204 ) -> Result<Self, PyErr> {
205 let cls: &Bound<PyCodecClass> = cls.downcast()?;
206 let cls_module: String = cls.getattr(intern!(py, "__module__"))?.extract()?;
207 let cls_name: String = cls.getattr(intern!(py, "__name__"))?.extract()?;
208
209 let ty: Bound<RustCodecType> = cls
210 .getattr(intern!(py, RustCodec::TYPE_ATTRIBUTE))
211 .map_err(|_| {
212 PyTypeError::new_err(format!(
213 "{cls_module}.{cls_name} is not linked to a Rust codec type"
214 ))
215 })?
216 .extract()?;
217 let ty: PyRef<RustCodecType> = ty.try_borrow()?;
218
219 let codec = ty
220 .ty
221 .codec_from_config(kwargs.unwrap_or_else(|| PyDict::new(py)))?;
222
223 Ok(Self {
224 cls_module,
225 cls_name,
226 codec,
227 })
228 }
229
230 fn encode<'py>(
231 &self,
232 py: Python<'py>,
233 buf: &Bound<'py, PyAny>,
234 ) -> Result<Bound<'py, PyAny>, PyErr> {
235 self.process(
236 py,
237 buf.as_borrowed(),
238 AnyCodec::encode,
239 &format!("{}.{}::encode", self.cls_module, self.cls_name),
240 )
241 }
242
243 #[pyo3(signature = (buf, out=None))]
244 fn decode<'py>(
245 &self,
246 py: Python<'py>,
247 buf: &Bound<'py, PyAny>,
248 out: Option<Bound<'py, PyAny>>,
249 ) -> Result<Bound<'py, PyAny>, PyErr> {
250 let class_method = &format!("{}.{}::decode", self.cls_module, self.cls_name);
251 if let Some(out) = out {
252 self.process_into(
253 py,
254 buf.as_borrowed(),
255 out.as_borrowed(),
256 AnyCodec::decode_into,
257 class_method,
258 )?;
259 Ok(out)
260 } else {
261 self.process(py, buf.as_borrowed(), AnyCodec::decode, class_method)
262 }
263 }
264
265 fn get_config<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyDict>, PyErr> {
266 self.codec.get_config(py)
267 }
268
269 #[classmethod]
270 fn from_config<'py>(
271 cls: &Bound<'py, PyType>,
272 config: &Bound<'py, PyDict>,
273 ) -> Result<Bound<'py, PyCodec>, PyErr> {
274 let cls: Bound<PyCodecClass> = cls.extract()?;
275
276 cls.call((), Some(config))?.extract()
278 }
279
280 fn __repr__(this: PyRef<Self>, py: Python) -> Result<String, PyErr> {
281 let config = this.get_config(py)?;
282 let py_this = this.into_pyobject(py)?;
284
285 let mut repr = py_this.get_type().name()?.to_cow()?.into_owned();
286 repr.push('(');
287
288 let mut first = true;
289
290 for (name, value) in config.iter() {
291 let name: String = name.extract()?;
292
293 if name == "id" {
294 continue;
296 }
297
298 let value_repr: String = value.repr()?.extract()?;
299
300 if !first {
301 repr.push_str(", ");
302 }
303 first = false;
304
305 repr.push_str(&name);
306 repr.push('=');
307 repr.push_str(&value_repr);
308 }
309
310 repr.push(')');
311
312 Ok(repr)
313 }
314}
315
316impl RustCodec {
317 fn process<'py>(
318 &self,
319 py: Python<'py>,
320 buf: Borrowed<'_, 'py, PyAny>,
321 process: impl FnOnce(
322 &(dyn 'static + Send + Sync + AnyCodec),
323 Python,
324 AnyCowArray,
325 ) -> Result<AnyArray, PyErr>,
326 class_method: &str,
327 ) -> Result<Bound<'py, PyAny>, PyErr> {
328 Self::with_pyarraylike_as_cow(py, buf, class_method, |data| {
329 let processed = process(&*self.codec, py, data)?;
330 Self::any_array_into_pyarray(py, processed, class_method)
331 })
332 }
333
334 fn process_into<'py>(
335 &self,
336 py: Python<'py>,
337 buf: Borrowed<'_, 'py, PyAny>,
338 out: Borrowed<'_, 'py, PyAny>,
339 process: impl FnOnce(
340 &(dyn 'static + Send + Sync + AnyCodec),
341 Python,
342 AnyArrayView,
343 AnyArrayViewMut,
344 ) -> Result<(), PyErr>,
345 class_method: &str,
346 ) -> Result<(), PyErr> {
347 Self::with_pyarraylike_as_view(py, buf, class_method, |data| {
348 Self::with_pyarraylike_as_view_mut(py, out, class_method, |data_out| {
349 process(&*self.codec, py, data, data_out)
350 })
351 })
352 }
353
354 fn with_pyarraylike_as_cow<'py, O>(
355 py: Python<'py>,
356 buf: Borrowed<'_, 'py, PyAny>,
357 class_method: &str,
358 with: impl for<'a> FnOnce(AnyCowArray<'a>) -> Result<O, PyErr>,
359 ) -> Result<O, PyErr> {
360 fn with_pyarraylike_as_cow_inner<T: numpy::Element, O>(
361 data: Borrowed<PyArrayDyn<T>>,
362 with: impl for<'a> FnOnce(CowArray<'a, T, IxDyn>) -> Result<O, PyErr>,
363 ) -> Result<O, PyErr> {
364 let readonly_data = data.try_readonly()?;
365 with(readonly_data.as_array().into())
366 }
367
368 let data = numpy_asarray(py, buf)?;
369 let dtype = data.dtype();
370
371 if dtype.is_equiv_to(&numpy::dtype::<u8>(py)) {
372 with_pyarraylike_as_cow_inner(data.downcast::<PyArrayDyn<u8>>()?.into(), |a| {
373 with(AnyCowArray::U8(a))
374 })
375 } else if dtype.is_equiv_to(&numpy::dtype::<u16>(py)) {
376 with_pyarraylike_as_cow_inner(data.downcast::<PyArrayDyn<u16>>()?.into(), |a| {
377 with(AnyCowArray::U16(a))
378 })
379 } else if dtype.is_equiv_to(&numpy::dtype::<u32>(py)) {
380 with_pyarraylike_as_cow_inner(data.downcast::<PyArrayDyn<u32>>()?.into(), |a| {
381 with(AnyCowArray::U32(a))
382 })
383 } else if dtype.is_equiv_to(&numpy::dtype::<u64>(py)) {
384 with_pyarraylike_as_cow_inner(data.downcast::<PyArrayDyn<u64>>()?.into(), |a| {
385 with(AnyCowArray::U64(a))
386 })
387 } else if dtype.is_equiv_to(&numpy::dtype::<i8>(py)) {
388 with_pyarraylike_as_cow_inner(data.downcast::<PyArrayDyn<i8>>()?.into(), |a| {
389 with(AnyCowArray::I8(a))
390 })
391 } else if dtype.is_equiv_to(&numpy::dtype::<i16>(py)) {
392 with_pyarraylike_as_cow_inner(data.downcast::<PyArrayDyn<i16>>()?.into(), |a| {
393 with(AnyCowArray::I16(a))
394 })
395 } else if dtype.is_equiv_to(&numpy::dtype::<i32>(py)) {
396 with_pyarraylike_as_cow_inner(data.downcast::<PyArrayDyn<i32>>()?.into(), |a| {
397 with(AnyCowArray::I32(a))
398 })
399 } else if dtype.is_equiv_to(&numpy::dtype::<i64>(py)) {
400 with_pyarraylike_as_cow_inner(data.downcast::<PyArrayDyn<i64>>()?.into(), |a| {
401 with(AnyCowArray::I64(a))
402 })
403 } else if dtype.is_equiv_to(&numpy::dtype::<f32>(py)) {
404 with_pyarraylike_as_cow_inner(data.downcast::<PyArrayDyn<f32>>()?.into(), |a| {
405 with(AnyCowArray::F32(a))
406 })
407 } else if dtype.is_equiv_to(&numpy::dtype::<f64>(py)) {
408 with_pyarraylike_as_cow_inner(data.downcast::<PyArrayDyn<f64>>()?.into(), |a| {
409 with(AnyCowArray::F64(a))
410 })
411 } else {
412 Err(PyTypeError::new_err(format!(
413 "{class_method} received buffer of unsupported dtype `{dtype}`",
414 )))
415 }
416 }
417
418 fn with_pyarraylike_as_view<'py, O>(
419 py: Python<'py>,
420 buf: Borrowed<'_, 'py, PyAny>,
421 class_method: &str,
422 with: impl for<'a> FnOnce(AnyArrayView<'a>) -> Result<O, PyErr>,
423 ) -> Result<O, PyErr> {
424 fn with_pyarraylike_as_view_inner<T: numpy::Element, O>(
425 data: Borrowed<PyArrayDyn<T>>,
426 with: impl for<'a> FnOnce(ArrayViewD<'a, T>) -> Result<O, PyErr>,
427 ) -> Result<O, PyErr> {
428 let readonly_data = data.try_readonly()?;
429 with(readonly_data.as_array())
430 }
431
432 let data = numpy_asarray(py, buf)?;
433 let dtype = data.dtype();
434
435 if dtype.is_equiv_to(&numpy::dtype::<u8>(py)) {
436 with_pyarraylike_as_view_inner(data.downcast::<PyArrayDyn<u8>>()?.into(), |a| {
437 with(AnyArrayView::U8(a))
438 })
439 } else if dtype.is_equiv_to(&numpy::dtype::<u16>(py)) {
440 with_pyarraylike_as_view_inner(data.downcast::<PyArrayDyn<u16>>()?.into(), |a| {
441 with(AnyArrayView::U16(a))
442 })
443 } else if dtype.is_equiv_to(&numpy::dtype::<u32>(py)) {
444 with_pyarraylike_as_view_inner(data.downcast::<PyArrayDyn<u32>>()?.into(), |a| {
445 with(AnyArrayView::U32(a))
446 })
447 } else if dtype.is_equiv_to(&numpy::dtype::<u64>(py)) {
448 with_pyarraylike_as_view_inner(data.downcast::<PyArrayDyn<u64>>()?.into(), |a| {
449 with(AnyArrayView::U64(a))
450 })
451 } else if dtype.is_equiv_to(&numpy::dtype::<i8>(py)) {
452 with_pyarraylike_as_view_inner(data.downcast::<PyArrayDyn<i8>>()?.into(), |a| {
453 with(AnyArrayView::I8(a))
454 })
455 } else if dtype.is_equiv_to(&numpy::dtype::<i16>(py)) {
456 with_pyarraylike_as_view_inner(data.downcast::<PyArrayDyn<i16>>()?.into(), |a| {
457 with(AnyArrayView::I16(a))
458 })
459 } else if dtype.is_equiv_to(&numpy::dtype::<i32>(py)) {
460 with_pyarraylike_as_view_inner(data.downcast::<PyArrayDyn<i32>>()?.into(), |a| {
461 with(AnyArrayView::I32(a))
462 })
463 } else if dtype.is_equiv_to(&numpy::dtype::<i64>(py)) {
464 with_pyarraylike_as_view_inner(data.downcast::<PyArrayDyn<i64>>()?.into(), |a| {
465 with(AnyArrayView::I64(a))
466 })
467 } else if dtype.is_equiv_to(&numpy::dtype::<f32>(py)) {
468 with_pyarraylike_as_view_inner(data.downcast::<PyArrayDyn<f32>>()?.into(), |a| {
469 with(AnyArrayView::F32(a))
470 })
471 } else if dtype.is_equiv_to(&numpy::dtype::<f64>(py)) {
472 with_pyarraylike_as_view_inner(data.downcast::<PyArrayDyn<f64>>()?.into(), |a| {
473 with(AnyArrayView::F64(a))
474 })
475 } else {
476 Err(PyTypeError::new_err(format!(
477 "{class_method} received buffer of unsupported dtype `{dtype}`",
478 )))
479 }
480 }
481
482 fn with_pyarraylike_as_view_mut<'py, O>(
483 py: Python<'py>,
484 buf: Borrowed<'_, 'py, PyAny>,
485 class_method: &str,
486 with: impl for<'a> FnOnce(AnyArrayViewMut<'a>) -> Result<O, PyErr>,
487 ) -> Result<O, PyErr> {
488 fn with_pyarraylike_as_view_mut_inner<T: numpy::Element, O>(
489 data: Borrowed<PyArrayDyn<T>>,
490 with: impl for<'a> FnOnce(ArrayViewMutD<'a, T>) -> Result<O, PyErr>,
491 ) -> Result<O, PyErr> {
492 let mut readwrite_data = data.try_readwrite()?;
493 with(readwrite_data.as_array_mut())
494 }
495
496 let data = numpy_asarray(py, buf)?;
497 let dtype = data.dtype();
498
499 if dtype.is_equiv_to(&numpy::dtype::<u8>(py)) {
500 with_pyarraylike_as_view_mut_inner(data.downcast::<PyArrayDyn<u8>>()?.into(), |a| {
501 with(AnyArrayViewMut::U8(a))
502 })
503 } else if dtype.is_equiv_to(&numpy::dtype::<u16>(py)) {
504 with_pyarraylike_as_view_mut_inner(data.downcast::<PyArrayDyn<u16>>()?.into(), |a| {
505 with(AnyArrayViewMut::U16(a))
506 })
507 } else if dtype.is_equiv_to(&numpy::dtype::<u32>(py)) {
508 with_pyarraylike_as_view_mut_inner(data.downcast::<PyArrayDyn<u32>>()?.into(), |a| {
509 with(AnyArrayViewMut::U32(a))
510 })
511 } else if dtype.is_equiv_to(&numpy::dtype::<u64>(py)) {
512 with_pyarraylike_as_view_mut_inner(data.downcast::<PyArrayDyn<u64>>()?.into(), |a| {
513 with(AnyArrayViewMut::U64(a))
514 })
515 } else if dtype.is_equiv_to(&numpy::dtype::<i8>(py)) {
516 with_pyarraylike_as_view_mut_inner(data.downcast::<PyArrayDyn<i8>>()?.into(), |a| {
517 with(AnyArrayViewMut::I8(a))
518 })
519 } else if dtype.is_equiv_to(&numpy::dtype::<i16>(py)) {
520 with_pyarraylike_as_view_mut_inner(data.downcast::<PyArrayDyn<i16>>()?.into(), |a| {
521 with(AnyArrayViewMut::I16(a))
522 })
523 } else if dtype.is_equiv_to(&numpy::dtype::<i32>(py)) {
524 with_pyarraylike_as_view_mut_inner(data.downcast::<PyArrayDyn<i32>>()?.into(), |a| {
525 with(AnyArrayViewMut::I32(a))
526 })
527 } else if dtype.is_equiv_to(&numpy::dtype::<i64>(py)) {
528 with_pyarraylike_as_view_mut_inner(data.downcast::<PyArrayDyn<i64>>()?.into(), |a| {
529 with(AnyArrayViewMut::I64(a))
530 })
531 } else if dtype.is_equiv_to(&numpy::dtype::<f32>(py)) {
532 with_pyarraylike_as_view_mut_inner(data.downcast::<PyArrayDyn<f32>>()?.into(), |a| {
533 with(AnyArrayViewMut::F32(a))
534 })
535 } else if dtype.is_equiv_to(&numpy::dtype::<f64>(py)) {
536 with_pyarraylike_as_view_mut_inner(data.downcast::<PyArrayDyn<f64>>()?.into(), |a| {
537 with(AnyArrayViewMut::F64(a))
538 })
539 } else {
540 Err(PyTypeError::new_err(format!(
541 "{class_method} received buffer of unsupported dtype `{dtype}`",
542 )))
543 }
544 }
545
546 fn any_array_into_pyarray<'py>(
547 py: Python<'py>,
548 array: AnyArray,
549 class_method: &str,
550 ) -> Result<Bound<'py, PyAny>, PyErr> {
551 match array {
552 AnyArray::U8(a) => Ok(PyArray::from_owned_array(py, a).into_any()),
553 AnyArray::U16(a) => Ok(PyArray::from_owned_array(py, a).into_any()),
554 AnyArray::U32(a) => Ok(PyArray::from_owned_array(py, a).into_any()),
555 AnyArray::U64(a) => Ok(PyArray::from_owned_array(py, a).into_any()),
556 AnyArray::I8(a) => Ok(PyArray::from_owned_array(py, a).into_any()),
557 AnyArray::I16(a) => Ok(PyArray::from_owned_array(py, a).into_any()),
558 AnyArray::I32(a) => Ok(PyArray::from_owned_array(py, a).into_any()),
559 AnyArray::I64(a) => Ok(PyArray::from_owned_array(py, a).into_any()),
560 AnyArray::F32(a) => Ok(PyArray::from_owned_array(py, a).into_any()),
561 AnyArray::F64(a) => Ok(PyArray::from_owned_array(py, a).into_any()),
562 array => Err(PyTypeError::new_err(format!(
563 "{class_method} returned unsupported dtype `{}`",
564 array.dtype(),
565 ))),
566 }
567 }
568}