1use std::{borrow::Cow, fmt};
21
22use ndarray::{Array, Array1, ArrayBase, Data, Dimension, ShapeError};
23use numcodecs::{
24 AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
25 Codec, StaticCodec, StaticCodecConfig,
26};
27use schemars::JsonSchema;
28use serde::{Deserialize, Serialize};
29use thiserror::Error;
30
31use ::zstd_sys as _;
34
35#[cfg(test)]
36use ::serde_json as _;
37
38#[derive(Clone, Serialize, Deserialize, JsonSchema)]
39#[schemars(deny_unknown_fields)]
41pub struct Sz3Codec {
43 #[serde(default = "default_predictor")]
45 pub predictor: Option<Sz3Predictor>,
46 #[serde(flatten)]
48 pub error_bound: Sz3ErrorBound,
49 #[serde(default = "default_encoder")]
51 pub encoder: Option<Sz3Encoder>,
52 #[serde(default = "default_lossless_compressor")]
54 pub lossless: Option<Sz3LosslessCompressor>,
55}
56
57#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
59#[serde(tag = "eb_mode")]
60#[serde(deny_unknown_fields)]
61pub enum Sz3ErrorBound {
62 #[serde(rename = "abs-and-rel")]
65 AbsoluteAndRelative {
66 #[serde(rename = "eb_abs")]
68 abs: f64,
69 #[serde(rename = "eb_rel")]
71 rel: f64,
72 },
73 #[serde(rename = "abs-or-rel")]
76 AbsoluteOrRelative {
77 #[serde(rename = "eb_abs")]
79 abs: f64,
80 #[serde(rename = "eb_rel")]
82 rel: f64,
83 },
84 #[serde(rename = "abs")]
86 Absolute {
87 #[serde(rename = "eb_abs")]
89 abs: f64,
90 },
91 #[serde(rename = "rel")]
93 Relative {
94 #[serde(rename = "eb_rel")]
96 rel: f64,
97 },
98 #[serde(rename = "psnr")]
100 PS2NR {
101 #[serde(rename = "eb_psnr")]
103 psnr: f64,
104 },
105 #[serde(rename = "l2")]
107 L2Norm {
108 #[serde(rename = "eb_l2")]
110 l2: f64,
111 },
112}
113
114#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
116#[serde(deny_unknown_fields)]
117pub enum Sz3Predictor {
118 #[serde(rename = "linear-interpolation")]
120 LinearInterpolation,
121 #[serde(rename = "cubic-interpolation")]
123 CubicInterpolation,
124 #[serde(rename = "linear-interpolation-lorenzo")]
126 LinearInterpolationLorenzo,
127 #[serde(rename = "cubic-interpolation-lorenzo")]
129 CubicInterpolationLorenzo,
130 #[serde(rename = "lorenzo-regression")]
132 LorenzoRegression,
133}
134
135#[allow(clippy::unnecessary_wraps)]
136const fn default_predictor() -> Option<Sz3Predictor> {
137 Some(Sz3Predictor::CubicInterpolationLorenzo)
138}
139
140#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
142#[serde(deny_unknown_fields)]
143pub enum Sz3Encoder {
144 #[serde(rename = "huffman")]
146 Huffman,
147 #[serde(rename = "arithmetic")]
149 Arithmetic,
150}
151
152#[allow(clippy::unnecessary_wraps)]
153const fn default_encoder() -> Option<Sz3Encoder> {
154 Some(Sz3Encoder::Huffman)
155}
156
157#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
159#[serde(deny_unknown_fields)]
160pub enum Sz3LosslessCompressor {
161 #[serde(rename = "zstd")]
163 Zstd,
164}
165
166#[allow(clippy::unnecessary_wraps)]
167const fn default_lossless_compressor() -> Option<Sz3LosslessCompressor> {
168 Some(Sz3LosslessCompressor::Zstd)
169}
170
171impl Codec for Sz3Codec {
172 type Error = Sz3CodecError;
173
174 fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
175 match data {
176 AnyCowArray::I32(data) => Ok(AnyArray::U8(
177 Array1::from(compress(
178 data,
179 self.predictor.as_ref(),
180 &self.error_bound,
181 self.encoder.as_ref(),
182 self.lossless.as_ref(),
183 )?)
184 .into_dyn(),
185 )),
186 AnyCowArray::I64(data) => Ok(AnyArray::U8(
187 Array1::from(compress(
188 data,
189 self.predictor.as_ref(),
190 &self.error_bound,
191 self.encoder.as_ref(),
192 self.lossless.as_ref(),
193 )?)
194 .into_dyn(),
195 )),
196 AnyCowArray::F32(data) => Ok(AnyArray::U8(
197 Array1::from(compress(
198 data,
199 self.predictor.as_ref(),
200 &self.error_bound,
201 self.encoder.as_ref(),
202 self.lossless.as_ref(),
203 )?)
204 .into_dyn(),
205 )),
206 AnyCowArray::F64(data) => Ok(AnyArray::U8(
207 Array1::from(compress(
208 data,
209 self.predictor.as_ref(),
210 &self.error_bound,
211 self.encoder.as_ref(),
212 self.lossless.as_ref(),
213 )?)
214 .into_dyn(),
215 )),
216 encoded => Err(Sz3CodecError::UnsupportedDtype(encoded.dtype())),
217 }
218 }
219
220 fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
221 let AnyCowArray::U8(encoded) = encoded else {
222 return Err(Sz3CodecError::EncodedDataNotBytes {
223 dtype: encoded.dtype(),
224 });
225 };
226
227 if !matches!(encoded.shape(), [_]) {
228 return Err(Sz3CodecError::EncodedDataNotOneDimensional {
229 shape: encoded.shape().to_vec(),
230 });
231 }
232
233 decompress(&AnyCowArray::U8(encoded).as_bytes())
234 }
235
236 fn decode_into(
237 &self,
238 encoded: AnyArrayView,
239 mut decoded: AnyArrayViewMut,
240 ) -> Result<(), Self::Error> {
241 let decoded_in = self.decode(encoded.cow())?;
242
243 Ok(decoded.assign(&decoded_in)?)
244 }
245}
246
247impl StaticCodec for Sz3Codec {
248 const CODEC_ID: &'static str = "sz3";
249
250 type Config<'de> = Self;
251
252 fn from_config(config: Self::Config<'_>) -> Self {
253 config
254 }
255
256 fn get_config(&self) -> StaticCodecConfig<Self> {
257 StaticCodecConfig::from(self)
258 }
259}
260
261#[derive(Debug, Error)]
262pub enum Sz3CodecError {
264 #[error("Sz3 does not support the dtype {0}")]
266 UnsupportedDtype(AnyArrayDType),
267 #[error("Sz3 failed to encode the header")]
269 HeaderEncodeFailed {
270 source: Sz3HeaderError,
272 },
273 #[error("Sz3 cannot encode an array of shape {shape:?}")]
275 InvalidEncodeShape {
276 source: Sz3CodingError,
278 shape: Vec<usize>,
280 },
281 #[error("Sz3 failed to encode the data")]
283 Sz3EncodeFailed {
284 source: Sz3CodingError,
286 },
287 #[error(
290 "Sz3 can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
291 )]
292 EncodedDataNotBytes {
293 dtype: AnyArrayDType,
295 },
296 #[error("Sz3 can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}")]
299 EncodedDataNotOneDimensional {
300 shape: Vec<usize>,
302 },
303 #[error("Sz3 failed to decode the header")]
305 HeaderDecodeFailed {
306 source: Sz3HeaderError,
308 },
309 #[error("Sz3 decoded an invalid array shape header which does not fit the decoded data")]
312 DecodeInvalidShapeHeader {
313 #[from]
315 source: ShapeError,
316 },
317 #[error("Sz3 cannot decode into the provided array")]
319 MismatchedDecodeIntoArray {
320 #[from]
322 source: AnyArrayAssignError,
323 },
324}
325
326#[derive(Debug, Error)]
327#[error(transparent)]
328pub struct Sz3HeaderError(postcard::Error);
330
331#[derive(Debug, Error)]
332#[error(transparent)]
333pub struct Sz3CodingError(sz3::SZ3Error);
335
336#[allow(clippy::needless_pass_by_value)]
337pub fn compress<T: Sz3Element, S: Data<Elem = T>, D: Dimension>(
348 data: ArrayBase<S, D>,
349 predictor: Option<&Sz3Predictor>,
350 error_bound: &Sz3ErrorBound,
351 encoder: Option<&Sz3Encoder>,
352 lossless: Option<&Sz3LosslessCompressor>,
353) -> Result<Vec<u8>, Sz3CodecError> {
354 let mut encoded_bytes = postcard::to_extend(
355 &CompressionHeader {
356 dtype: <T as Sz3Element>::DTYPE,
357 shape: Cow::Borrowed(data.shape()),
358 },
359 Vec::new(),
360 )
361 .map_err(|err| Sz3CodecError::HeaderEncodeFailed {
362 source: Sz3HeaderError(err),
363 })?;
364
365 if data.is_empty() {
367 return Ok(encoded_bytes);
368 }
369
370 #[allow(clippy::option_if_let_else)]
371 let data_cow = if let Some(data) = data.as_slice() {
372 Cow::Borrowed(data)
373 } else {
374 Cow::Owned(data.iter().copied().collect())
375 };
376 let mut builder = sz3::DimensionedData::build(&data_cow);
377
378 for length in data.shape() {
379 if *length > 1 {
383 builder = builder
384 .dim(*length)
385 .map_err(|err| Sz3CodecError::InvalidEncodeShape {
386 source: Sz3CodingError(err),
387 shape: data.shape().to_vec(),
388 })?;
389 }
390 }
391
392 if data.len() == 1 {
393 builder = builder
396 .dim(1)
397 .map_err(|err| Sz3CodecError::InvalidEncodeShape {
398 source: Sz3CodingError(err),
399 shape: data.shape().to_vec(),
400 })?;
401 }
402
403 let data = builder
404 .finish()
405 .map_err(|err| Sz3CodecError::InvalidEncodeShape {
406 source: Sz3CodingError(err),
407 shape: data.shape().to_vec(),
408 })?;
409
410 let error_bound = match error_bound {
412 Sz3ErrorBound::AbsoluteAndRelative { abs, rel } => sz3::ErrorBound::AbsoluteAndRelative {
413 absolute_bound: *abs,
414 relative_bound: *rel,
415 },
416 Sz3ErrorBound::AbsoluteOrRelative { abs, rel } => sz3::ErrorBound::AbsoluteOrRelative {
417 absolute_bound: *abs,
418 relative_bound: *rel,
419 },
420 Sz3ErrorBound::Absolute { abs } => sz3::ErrorBound::Absolute(*abs),
421 Sz3ErrorBound::Relative { rel } => sz3::ErrorBound::Relative(*rel),
422 Sz3ErrorBound::PS2NR { psnr } => sz3::ErrorBound::PSNR(*psnr),
423 Sz3ErrorBound::L2Norm { l2 } => sz3::ErrorBound::L2Norm(*l2),
424 };
425 let mut config = sz3::Config::new(error_bound);
426
427 let interpolation = match predictor {
429 Some(Sz3Predictor::LinearInterpolation | Sz3Predictor::LinearInterpolationLorenzo) => {
430 Some(sz3::InterpolationAlgorithm::Linear)
431 }
432 Some(Sz3Predictor::CubicInterpolation | Sz3Predictor::CubicInterpolationLorenzo) => {
433 Some(sz3::InterpolationAlgorithm::Cubic)
434 }
435 Some(Sz3Predictor::LorenzoRegression) | None => None,
436 };
437 if let Some(interpolation) = interpolation {
438 config = config.interpolation_algorithm(interpolation);
439 }
440
441 let predictor = match predictor {
443 Some(Sz3Predictor::LinearInterpolation | Sz3Predictor::CubicInterpolation) => {
444 sz3::CompressionAlgorithm::Interpolation
445 }
446 Some(
447 Sz3Predictor::LinearInterpolationLorenzo | Sz3Predictor::CubicInterpolationLorenzo,
448 ) => sz3::CompressionAlgorithm::InterpolationLorenzo,
449 Some(Sz3Predictor::LorenzoRegression) => sz3::CompressionAlgorithm::lorenzo_regression(),
450 None => sz3::CompressionAlgorithm::NoPrediction,
451 };
452 config = config.compression_algorithm(predictor);
453
454 let encoder = match encoder {
456 None => sz3::Encoder::SkipEncoder,
457 Some(Sz3Encoder::Huffman) => sz3::Encoder::HuffmanEncoder,
458 Some(Sz3Encoder::Arithmetic) => sz3::Encoder::ArithmeticEncoder,
459 };
460 config = config.encoder(encoder);
461
462 let lossless = match lossless {
464 None => sz3::LossLess::LossLessBypass,
465 Some(Sz3LosslessCompressor::Zstd) => sz3::LossLess::ZSTD,
466 };
467 config = config.lossless(lossless);
468
469 let compressed = sz3::compress_with_config(&data, &config).map_err(|err| {
471 Sz3CodecError::Sz3EncodeFailed {
472 source: Sz3CodingError(err),
473 }
474 })?;
475 encoded_bytes.extend_from_slice(&compressed);
476
477 Ok(encoded_bytes)
478}
479
480pub fn decompress(encoded: &[u8]) -> Result<AnyArray, Sz3CodecError> {
487 let (header, data) =
488 postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
489 Sz3CodecError::HeaderDecodeFailed {
490 source: Sz3HeaderError(err),
491 }
492 })?;
493
494 let decoded = if header.shape.iter().copied().product::<usize>() == 0 {
495 match header.dtype {
496 Sz3DType::I32 => {
497 AnyArray::I32(Array::from_shape_vec(&*header.shape, Vec::new())?.into_dyn())
498 }
499 Sz3DType::I64 => {
500 AnyArray::I64(Array::from_shape_vec(&*header.shape, Vec::new())?.into_dyn())
501 }
502 Sz3DType::F32 => {
503 AnyArray::F32(Array::from_shape_vec(&*header.shape, Vec::new())?.into_dyn())
504 }
505 Sz3DType::F64 => {
506 AnyArray::F64(Array::from_shape_vec(&*header.shape, Vec::new())?.into_dyn())
507 }
508 }
509 } else {
510 match header.dtype {
512 Sz3DType::I32 => AnyArray::I32(Array::from_shape_vec(
513 &*header.shape,
514 Vec::from(sz3::decompress(data).1.data()),
515 )?),
516 Sz3DType::I64 => AnyArray::I64(Array::from_shape_vec(
517 &*header.shape,
518 Vec::from(sz3::decompress(data).1.data()),
519 )?),
520 Sz3DType::F32 => AnyArray::F32(Array::from_shape_vec(
521 &*header.shape,
522 Vec::from(sz3::decompress(data).1.data()),
523 )?),
524 Sz3DType::F64 => AnyArray::F64(Array::from_shape_vec(
525 &*header.shape,
526 Vec::from(sz3::decompress(data).1.data()),
527 )?),
528 }
529 };
530
531 Ok(decoded)
532}
533
534pub trait Sz3Element: Copy + sz3::SZ3Compressible {
536 const DTYPE: Sz3DType;
538}
539
540impl Sz3Element for i32 {
541 const DTYPE: Sz3DType = Sz3DType::I32;
542}
543
544impl Sz3Element for i64 {
545 const DTYPE: Sz3DType = Sz3DType::I64;
546}
547
548impl Sz3Element for f32 {
549 const DTYPE: Sz3DType = Sz3DType::F32;
550}
551
552impl Sz3Element for f64 {
553 const DTYPE: Sz3DType = Sz3DType::F64;
554}
555
556#[derive(Serialize, Deserialize)]
557struct CompressionHeader<'a> {
558 dtype: Sz3DType,
559 #[serde(borrow)]
560 shape: Cow<'a, [usize]>,
561}
562
563#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
565#[allow(missing_docs)]
566pub enum Sz3DType {
567 #[serde(rename = "i32", alias = "int32")]
568 I32,
569 #[serde(rename = "i64", alias = "int64")]
570 I64,
571 #[serde(rename = "f32", alias = "float32")]
572 F32,
573 #[serde(rename = "f64", alias = "float64")]
574 F64,
575}
576
577impl fmt::Display for Sz3DType {
578 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
579 fmt.write_str(match self {
580 Self::I32 => "i32",
581 Self::I64 => "i64",
582 Self::F32 => "f32",
583 Self::F64 => "f64",
584 })
585 }
586}
587
#[cfg(test)]
mod tests {
    use ndarray::ArrayView1;

    use super::*;

    /// Compresses `data` with the default predictor, encoder and lossless
    /// compressor and immediately decompresses the result again.
    fn roundtrip<T: Sz3Element, S: Data<Elem = T>, D: Dimension>(
        data: ArrayBase<S, D>,
        error_bound: &Sz3ErrorBound,
    ) -> Result<AnyArray, Sz3CodecError> {
        let encoded = compress(
            data,
            default_predictor().as_ref(),
            error_bound,
            default_encoder().as_ref(),
            default_lossless_compressor().as_ref(),
        )?;

        decompress(&encoded)
    }

    /// An array with a zero-length axis keeps its dtype and shape.
    #[test]
    fn zero_length() -> Result<(), Sz3CodecError> {
        let empty = Array::<f32, _>::from_shape_vec([1, 27, 0].as_slice(), vec![])?;

        let decoded = roundtrip(empty, &Sz3ErrorBound::L2Norm { l2: 27.0 })?;

        assert_eq!(decoded.dtype(), AnyArrayDType::F32);
        assert!(decoded.is_empty());
        assert_eq!(decoded.shape(), &[1, 27, 0]);

        Ok(())
    }

    /// Axes of length one survive the round trip.
    #[test]
    fn one_dimension() -> Result<(), Sz3CodecError> {
        let data = Array::from_shape_vec([2_usize, 1, 2, 1].as_slice(), vec![1, 2, 3, 4])?;

        let decoded = roundtrip(data.view(), &Sz3ErrorBound::Absolute { abs: 0.1 })?;

        assert_eq!(decoded, AnyArray::I32(data));

        Ok(())
    }

    /// Very small one-dimensional inputs, including the empty one, round-trip.
    #[test]
    fn small_state() -> Result<(), Sz3CodecError> {
        let cases: [&[f64]; 5] = [
            &[],
            &[0.0],
            &[0.0, 1.0],
            &[0.0, 1.0, 0.0],
            &[0.0, 1.0, 0.0, 1.0],
        ];

        for data in cases {
            let decoded = roundtrip(
                ArrayView1::from(data),
                &Sz3ErrorBound::Absolute { abs: 0.1 },
            )?;

            assert_eq!(
                decoded,
                AnyArray::F64(Array1::from_vec(data.to_vec()).into_dyn())
            );
        }

        Ok(())
    }
}
656}