1#![allow(clippy::multiple_crate_versions)] use std::{borrow::Cow, fmt};
23
24use ndarray::{Array, Array1, ArrayBase, Data, Dimension, ShapeError};
25use numcodecs::{
26 AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
27 Codec, StaticCodec, StaticCodecConfig,
28};
29use schemars::JsonSchema;
30use serde::{Deserialize, Serialize};
31use thiserror::Error;
32
33use ::zstd_sys as _;
36
37#[cfg(test)]
38use ::serde_json as _;
39
40#[derive(Clone, Serialize, Deserialize, JsonSchema)]
41#[schemars(deny_unknown_fields)]
43pub struct Sz3Codec {
45 #[serde(default = "default_predictor")]
47 pub predictor: Option<Sz3Predictor>,
48 #[serde(flatten)]
50 pub error_bound: Sz3ErrorBound,
51}
52
/// Error bound applied by the SZ3 compression.
///
/// (De)serialized as an internally tagged map: the `eb_mode` field names the
/// variant and the `eb_*` fields carry its parameters.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "eb_mode")]
#[serde(deny_unknown_fields)]
pub enum Sz3ErrorBound {
    /// Absolute *and* relative error bound, forwarded to SZ3 as
    /// `ErrorBound::AbsoluteAndRelative`.
    #[serde(rename = "abs-and-rel")]
    AbsoluteAndRelative {
        /// Absolute error bound.
        #[serde(rename = "eb_abs")]
        abs: f64,
        /// Relative error bound.
        #[serde(rename = "eb_rel")]
        rel: f64,
    },
    /// Absolute *or* relative error bound, forwarded to SZ3 as
    /// `ErrorBound::AbsoluteOrRelative`.
    #[serde(rename = "abs-or-rel")]
    AbsoluteOrRelative {
        /// Absolute error bound.
        #[serde(rename = "eb_abs")]
        abs: f64,
        /// Relative error bound.
        #[serde(rename = "eb_rel")]
        rel: f64,
    },
    /// Absolute error bound, forwarded to SZ3 as `ErrorBound::Absolute`.
    #[serde(rename = "abs")]
    Absolute {
        /// Absolute error bound.
        #[serde(rename = "eb_abs")]
        abs: f64,
    },
    /// Relative error bound, forwarded to SZ3 as `ErrorBound::Relative`.
    #[serde(rename = "rel")]
    Relative {
        /// Relative error bound.
        #[serde(rename = "eb_rel")]
        rel: f64,
    },
    /// Peak signal-to-noise ratio bound, forwarded to SZ3 as
    /// `ErrorBound::PSNR`.
    #[serde(rename = "psnr")]
    PS2NR {
        /// Peak signal-to-noise ratio error bound.
        #[serde(rename = "eb_psnr")]
        psnr: f64,
    },
    /// L2 norm bound, forwarded to SZ3 as `ErrorBound::L2Norm`.
    #[serde(rename = "l2")]
    L2Norm {
        /// L2 norm error bound.
        #[serde(rename = "eb_l2")]
        l2: f64,
    },
}
109
/// Predictor that SZ3 uses during compression.
///
/// Each variant selects an SZ3 compression algorithm and, for the
/// interpolation-based ones, the interpolation algorithm.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub enum Sz3Predictor {
    /// Interpolation compression with linear interpolation.
    #[serde(rename = "linear-interpolation")]
    LinearInterpolation,
    /// Interpolation compression with cubic interpolation.
    #[serde(rename = "cubic-interpolation")]
    CubicInterpolation,
    /// Interpolation-Lorenzo compression with linear interpolation.
    #[serde(rename = "linear-interpolation-lorenzo")]
    LinearInterpolationLorenzo,
    /// Interpolation-Lorenzo compression with cubic interpolation.
    #[serde(rename = "cubic-interpolation-lorenzo")]
    CubicInterpolationLorenzo,
    /// Lorenzo-regression compression; no interpolation algorithm is set.
    #[serde(rename = "lorenzo-regression")]
    LorenzoRegression,
}
130
/// Default for [`Sz3Codec::predictor`] when the field is missing from the
/// configuration: cubic interpolation combined with Lorenzo.
// The `Option` wrapper is required because serde's `default = "..."` function
// must return the field's exact type, hence the expected clippy lint.
#[expect(clippy::unnecessary_wraps)]
const fn default_predictor() -> Option<Sz3Predictor> {
    Some(Sz3Predictor::CubicInterpolationLorenzo)
}
135
136impl Codec for Sz3Codec {
137 type Error = Sz3CodecError;
138
139 fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
140 match data {
141 AnyCowArray::I32(data) => Ok(AnyArray::U8(
142 Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
143 .into_dyn(),
144 )),
145 AnyCowArray::I64(data) => Ok(AnyArray::U8(
146 Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
147 .into_dyn(),
148 )),
149 AnyCowArray::F32(data) => Ok(AnyArray::U8(
150 Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
151 .into_dyn(),
152 )),
153 AnyCowArray::F64(data) => Ok(AnyArray::U8(
154 Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
155 .into_dyn(),
156 )),
157 encoded => Err(Sz3CodecError::UnsupportedDtype(encoded.dtype())),
158 }
159 }
160
161 fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
162 let AnyCowArray::U8(encoded) = encoded else {
163 return Err(Sz3CodecError::EncodedDataNotBytes {
164 dtype: encoded.dtype(),
165 });
166 };
167
168 if !matches!(encoded.shape(), [_]) {
169 return Err(Sz3CodecError::EncodedDataNotOneDimensional {
170 shape: encoded.shape().to_vec(),
171 });
172 }
173
174 decompress(&AnyCowArray::U8(encoded).as_bytes())
175 }
176
177 fn decode_into(
178 &self,
179 encoded: AnyArrayView,
180 mut decoded: AnyArrayViewMut,
181 ) -> Result<(), Self::Error> {
182 let decoded_in = self.decode(encoded.cow())?;
183
184 Ok(decoded.assign(&decoded_in)?)
185 }
186}
187
impl StaticCodec for Sz3Codec {
    /// Identifier of this codec.
    const CODEC_ID: &'static str = "sz3";

    /// The codec is its own (de)serializable configuration.
    type Config<'de> = Self;

    /// Builds the codec from its configuration; since the configuration is
    /// the codec itself, it is returned unchanged.
    fn from_config(config: Self::Config<'_>) -> Self {
        config
    }

    /// Returns the configuration of this codec, created from a borrow of
    /// `self`.
    fn get_config(&self) -> StaticCodecConfig<Self> {
        StaticCodecConfig::from(self)
    }
}
201
/// Errors that may occur when encoding or decoding with the [`Sz3Codec`].
#[derive(Debug, Error)]
pub enum Sz3CodecError {
    /// The array to encode has a dtype other than `i32`, `i64`, `f32`, or
    /// `f64`.
    #[error("Sz3 does not support the dtype {0}")]
    UnsupportedDtype(AnyArrayDType),
    /// Serializing the dtype-and-shape header failed.
    #[error("Sz3 failed to encode the header")]
    HeaderEncodeFailed {
        /// Opaque source error
        source: Sz3HeaderError,
    },
    /// SZ3 rejected the dimensions of the array to encode.
    #[error("Sz3 cannot encode an array of shape {shape:?}")]
    InvalidEncodeShape {
        /// Opaque source error
        source: Sz3CodingError,
        /// The shape of the array that could not be encoded
        shape: Vec<usize>,
    },
    /// SZ3 failed while compressing the data.
    #[error("Sz3 failed to encode the data")]
    Sz3EncodeFailed {
        /// Opaque source error
        source: Sz3CodingError,
    },
    /// The array to decode is not an array of bytes.
    #[error(
        "Sz3 can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
    )]
    EncodedDataNotBytes {
        /// The unexpected dtype of the encoded array
        dtype: AnyArrayDType,
    },
    /// The byte array to decode is not one-dimensional.
    #[error("Sz3 can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}")]
    EncodedDataNotOneDimensional {
        /// The unexpected shape of the encoded array
        shape: Vec<usize>,
    },
    /// Deserializing the dtype-and-shape header failed.
    #[error("Sz3 failed to decode the header")]
    HeaderDecodeFailed {
        /// Opaque source error
        source: Sz3HeaderError,
    },
    /// The shape stored in the header does not fit the decoded data.
    #[error("Sz3 decoded an invalid array shape header which does not fit the decoded data")]
    DecodeInvalidShapeHeader {
        /// Source error
        #[from]
        source: ShapeError,
    },
    /// The decoded array could not be assigned to the provided output array.
    #[error("Sz3 cannot decode into the provided array")]
    MismatchedDecodeIntoArray {
        /// Source error
        #[from]
        source: AnyArrayAssignError,
    },
}
266
/// Opaque error raised when (de)serializing the compression header fails.
///
/// Wraps the underlying `postcard` error and displays it transparently.
#[derive(Debug, Error)]
#[error(transparent)]
pub struct Sz3HeaderError(postcard::Error);
271
/// Opaque error raised when SZ3 rejects the input or fails to compress it.
///
/// Wraps the underlying `sz3` error and displays it transparently.
#[derive(Debug, Error)]
#[error(transparent)]
pub struct Sz3CodingError(sz3::SZ3Error);
276
277#[expect(clippy::needless_pass_by_value)]
278pub fn compress<T: Sz3Element, S: Data<Elem = T>, D: Dimension>(
289 data: ArrayBase<S, D>,
290 predictor: Option<&Sz3Predictor>,
291 error_bound: &Sz3ErrorBound,
292) -> Result<Vec<u8>, Sz3CodecError> {
293 let mut encoded_bytes = postcard::to_extend(
294 &CompressionHeader {
295 dtype: <T as Sz3Element>::DTYPE,
296 shape: Cow::Borrowed(data.shape()),
297 },
298 Vec::new(),
299 )
300 .map_err(|err| Sz3CodecError::HeaderEncodeFailed {
301 source: Sz3HeaderError(err),
302 })?;
303
304 if data.is_empty() {
306 return Ok(encoded_bytes);
307 }
308
309 #[expect(clippy::option_if_let_else)]
310 let data_cow = if let Some(data) = data.as_slice() {
311 Cow::Borrowed(data)
312 } else {
313 Cow::Owned(data.iter().copied().collect())
314 };
315 let mut builder = sz3::DimensionedData::build(&data_cow);
316
317 for length in data.shape() {
318 if *length > 1 {
322 builder = builder
323 .dim(*length)
324 .map_err(|err| Sz3CodecError::InvalidEncodeShape {
325 source: Sz3CodingError(err),
326 shape: data.shape().to_vec(),
327 })?;
328 }
329 }
330
331 if data.len() == 1 {
332 builder = builder
335 .dim(1)
336 .map_err(|err| Sz3CodecError::InvalidEncodeShape {
337 source: Sz3CodingError(err),
338 shape: data.shape().to_vec(),
339 })?;
340 }
341
342 let data = builder
343 .finish()
344 .map_err(|err| Sz3CodecError::InvalidEncodeShape {
345 source: Sz3CodingError(err),
346 shape: data.shape().to_vec(),
347 })?;
348
349 let error_bound = match error_bound {
351 Sz3ErrorBound::AbsoluteAndRelative { abs, rel } => sz3::ErrorBound::AbsoluteAndRelative {
352 absolute_bound: *abs,
353 relative_bound: *rel,
354 },
355 Sz3ErrorBound::AbsoluteOrRelative { abs, rel } => sz3::ErrorBound::AbsoluteOrRelative {
356 absolute_bound: *abs,
357 relative_bound: *rel,
358 },
359 Sz3ErrorBound::Absolute { abs } => sz3::ErrorBound::Absolute(*abs),
360 Sz3ErrorBound::Relative { rel } => sz3::ErrorBound::Relative(*rel),
361 Sz3ErrorBound::PS2NR { psnr } => sz3::ErrorBound::PSNR(*psnr),
362 Sz3ErrorBound::L2Norm { l2 } => sz3::ErrorBound::L2Norm(*l2),
363 };
364 let mut config = sz3::Config::new(error_bound);
365
366 let interpolation = match predictor {
368 Some(Sz3Predictor::LinearInterpolation | Sz3Predictor::LinearInterpolationLorenzo) => {
369 Some(sz3::InterpolationAlgorithm::Linear)
370 }
371 Some(Sz3Predictor::CubicInterpolation | Sz3Predictor::CubicInterpolationLorenzo) => {
372 Some(sz3::InterpolationAlgorithm::Cubic)
373 }
374 Some(Sz3Predictor::LorenzoRegression) | None => None,
375 };
376 if let Some(interpolation) = interpolation {
377 config = config.interpolation_algorithm(interpolation);
378 }
379
380 let predictor = match predictor {
382 Some(Sz3Predictor::LinearInterpolation | Sz3Predictor::CubicInterpolation) => {
383 sz3::CompressionAlgorithm::Interpolation
384 }
385 Some(
386 Sz3Predictor::LinearInterpolationLorenzo | Sz3Predictor::CubicInterpolationLorenzo,
387 ) => sz3::CompressionAlgorithm::InterpolationLorenzo,
388 Some(Sz3Predictor::LorenzoRegression) => sz3::CompressionAlgorithm::lorenzo_regression(),
389 None => sz3::CompressionAlgorithm::NoPrediction,
390 };
391 config = config.compression_algorithm(predictor);
392
393 let compressed = sz3::compress_with_config(&data, &config).map_err(|err| {
395 Sz3CodecError::Sz3EncodeFailed {
396 source: Sz3CodingError(err),
397 }
398 })?;
399 encoded_bytes.extend_from_slice(&compressed);
400
401 Ok(encoded_bytes)
402}
403
404pub fn decompress(encoded: &[u8]) -> Result<AnyArray, Sz3CodecError> {
411 let (header, data) =
412 postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
413 Sz3CodecError::HeaderDecodeFailed {
414 source: Sz3HeaderError(err),
415 }
416 })?;
417
418 let decoded = if header.shape.iter().copied().product::<usize>() == 0 {
419 match header.dtype {
420 Sz3DType::I32 => {
421 AnyArray::I32(Array::from_shape_vec(&*header.shape, Vec::new())?.into_dyn())
422 }
423 Sz3DType::I64 => {
424 AnyArray::I64(Array::from_shape_vec(&*header.shape, Vec::new())?.into_dyn())
425 }
426 Sz3DType::F32 => {
427 AnyArray::F32(Array::from_shape_vec(&*header.shape, Vec::new())?.into_dyn())
428 }
429 Sz3DType::F64 => {
430 AnyArray::F64(Array::from_shape_vec(&*header.shape, Vec::new())?.into_dyn())
431 }
432 }
433 } else {
434 match header.dtype {
436 Sz3DType::I32 => AnyArray::I32(Array::from_shape_vec(
437 &*header.shape,
438 Vec::from(sz3::decompress(data).1.data()),
439 )?),
440 Sz3DType::I64 => AnyArray::I64(Array::from_shape_vec(
441 &*header.shape,
442 Vec::from(sz3::decompress(data).1.data()),
443 )?),
444 Sz3DType::F32 => AnyArray::F32(Array::from_shape_vec(
445 &*header.shape,
446 Vec::from(sz3::decompress(data).1.data()),
447 )?),
448 Sz3DType::F64 => AnyArray::F64(Array::from_shape_vec(
449 &*header.shape,
450 Vec::from(sz3::decompress(data).1.data()),
451 )?),
452 }
453 };
454
455 Ok(decoded)
456}
457
/// Array element types which can be compressed with SZ3.
pub trait Sz3Element: Copy + sz3::SZ3Compressible {
    /// The dtype tag written to the compression header for this element
    /// type.
    const DTYPE: Sz3DType;
}
463
// Each supported element type is paired with the `Sz3DType` tag that
// `compress` writes to the header.

impl Sz3Element for i32 {
    const DTYPE: Sz3DType = Sz3DType::I32;
}

impl Sz3Element for i64 {
    const DTYPE: Sz3DType = Sz3DType::I64;
}

impl Sz3Element for f32 {
    const DTYPE: Sz3DType = Sz3DType::F32;
}

impl Sz3Element for f64 {
    const DTYPE: Sz3DType = Sz3DType::F64;
}
479
/// Header that is serialized with postcard in front of the SZ3 payload.
///
/// It records what is needed to rebuild the array on decode: the element
/// dtype and the full shape.
#[derive(Serialize, Deserialize)]
struct CompressionHeader<'a> {
    /// Element type of the compressed array.
    dtype: Sz3DType,
    /// Shape of the compressed array; borrowed from the array when encoding.
    #[serde(borrow)]
    shape: Cow<'a, [usize]>,
}
486
/// Element dtypes supported by the SZ3 codec.
// The variants are deliberately left without doc comments: the
// `#[expect(missing_docs)]` below is only fulfilled while they stay
// undocumented.
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
#[expect(missing_docs)]
pub enum Sz3DType {
    #[serde(rename = "i32", alias = "int32")]
    I32,
    #[serde(rename = "i64", alias = "int64")]
    I64,
    #[serde(rename = "f32", alias = "float32")]
    F32,
    #[serde(rename = "f64", alias = "float64")]
    F64,
}
500
501impl fmt::Display for Sz3DType {
502 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
503 fmt.write_str(match self {
504 Self::I32 => "i32",
505 Self::I64 => "i64",
506 Self::F32 => "f32",
507 Self::F64 => "f64",
508 })
509 }
510}
511
#[cfg(test)]
mod tests {
    use ndarray::ArrayView1;

    use super::*;

    /// An array with a zero-length axis round-trips with its dtype and full
    /// shape intact.
    #[test]
    fn zero_length() -> Result<(), Sz3CodecError> {
        let predictor = default_predictor();

        let bytes = compress(
            Array::<f32, _>::from_shape_vec([1, 27, 0].as_slice(), vec![])?,
            predictor.as_ref(),
            &Sz3ErrorBound::L2Norm { l2: 27.0 },
        )?;
        let roundtrip = decompress(&bytes)?;

        assert_eq!(roundtrip.dtype(), AnyArrayDType::F32);
        assert!(roundtrip.is_empty());
        assert_eq!(roundtrip.shape(), &[1, 27, 0]);

        Ok(())
    }

    /// Axes of length one are preserved by the round trip.
    #[test]
    fn one_dimension() -> Result<(), Sz3CodecError> {
        let original = Array::from_shape_vec([2_usize, 1, 2, 1].as_slice(), vec![1, 2, 3, 4])?;
        let predictor = default_predictor();

        let bytes = compress(
            original.view(),
            predictor.as_ref(),
            &Sz3ErrorBound::Absolute { abs: 0.1 },
        )?;
        let roundtrip = decompress(&bytes)?;

        assert_eq!(roundtrip, AnyArray::I32(original));

        Ok(())
    }

    /// Very small one-dimensional inputs, including the empty one, round-trip
    /// exactly.
    #[test]
    fn small_state() -> Result<(), Sz3CodecError> {
        let cases: [&[f64]; 5] = [
            &[],
            &[0.0],
            &[0.0, 1.0],
            &[0.0, 1.0, 0.0],
            &[0.0, 1.0, 0.0, 1.0],
        ];
        let predictor = default_predictor();

        for case in cases {
            let bytes = compress(
                ArrayView1::from(case),
                predictor.as_ref(),
                &Sz3ErrorBound::Absolute { abs: 0.1 },
            )?;
            let roundtrip = decompress(&bytes)?;

            assert_eq!(
                roundtrip,
                AnyArray::F64(Array1::from_vec(case.to_vec()).into_dyn())
            );
        }

        Ok(())
    }
}