use std::borrow::Cow;
use std::sync::Arc;
use zarrs_plugin::{PluginCreateError, ZarrVersion};
use zfp_sys::{
zfp_compress,
zfp_stream_maximum_size,
zfp_stream_rewind,
zfp_stream_set_bit_stream,
zfp_write_header,
};
use super::zfp_bitstream::ZfpBitstream;
use super::zfp_field::ZfpField;
use super::zfp_stream::ZfpStream;
use super::{
ZfpCodecConfiguration, ZfpCodecConfigurationV1, ZfpDataTypeExt, promote_before_zfp_encoding,
zfp_decode, zfp_native_type_to_sys,
};
use crate::array::{BytesRepresentation, DataType, FillValue};
use std::num::NonZeroU64;
use zarrs_codec::{
ArrayBytes, ArrayBytesRaw, ArrayCodecTraits, ArrayToBytesCodecTraits, CodecError,
CodecMetadataOptions, CodecOptions, CodecTraits, PartialDecoderCapability,
PartialEncoderCapability, RecommendedConcurrency,
};
use zarrs_metadata::Configuration;
use zarrs_metadata_ext::codec::zfp::ZfpMode;
/// A `zfp` array-to-bytes codec.
///
/// Holds the zfp compression mode and a flag controlling whether a zfp header is
/// written in front of the compressed stream.
#[derive(Clone, Copy, Debug)]
pub struct ZfpCodec {
/// The zfp compression mode used when creating zfp streams for encoding/decoding.
mode: ZfpMode,
/// Whether encoding writes a full zfp header (`ZFP_HEADER_FULL`) before the compressed
/// data. Every public constructor sets this to `false`; it can be enabled via the
/// crate-internal `with_write_header`.
write_header: bool,
}
impl ZfpCodec {
    /// Create a new `zfp` codec in expert mode.
    ///
    /// The arguments are stored unchanged in [`ZfpMode::Expert`]; see the zfp
    /// documentation for their meaning. No header is written when encoding.
    #[must_use]
    pub const fn new_expert(minbits: u32, maxbits: u32, maxprec: u32, minexp: i32) -> Self {
        Self::from_mode(ZfpMode::Expert {
            minbits,
            maxbits,
            maxprec,
            minexp,
        })
    }

    /// Create a new `zfp` codec in fixed rate mode ([`ZfpMode::FixedRate`]).
    ///
    /// No header is written when encoding.
    #[must_use]
    pub const fn new_fixed_rate(rate: f64) -> Self {
        Self::from_mode(ZfpMode::FixedRate { rate })
    }

    /// Create a new `zfp` codec in fixed precision mode ([`ZfpMode::FixedPrecision`]).
    ///
    /// No header is written when encoding.
    #[must_use]
    pub const fn new_fixed_precision(precision: u32) -> Self {
        Self::from_mode(ZfpMode::FixedPrecision { precision })
    }

    /// Create a new `zfp` codec in fixed accuracy mode ([`ZfpMode::FixedAccuracy`]).
    ///
    /// No header is written when encoding.
    #[must_use]
    pub const fn new_fixed_accuracy(tolerance: f64) -> Self {
        Self::from_mode(ZfpMode::FixedAccuracy { tolerance })
    }

    /// Create a new `zfp` codec in reversible mode ([`ZfpMode::Reversible`]).
    ///
    /// No header is written when encoding.
    #[must_use]
    pub const fn new_reversible() -> Self {
        Self::from_mode(ZfpMode::Reversible)
    }

    /// Return the configured zfp mode.
    #[must_use]
    pub(crate) const fn mode(&self) -> ZfpMode {
        self.mode
    }

    /// Set whether a full zfp header is written before the compressed stream.
    #[must_use]
    pub(crate) const fn with_write_header(mut self, write_header: bool) -> Self {
        self.write_header = write_header;
        self
    }

    /// Create a new `zfp` codec from a configuration.
    ///
    /// Only the V1 configuration is supported; the resulting codec does not write a header.
    ///
    /// # Errors
    /// Returns [`PluginCreateError::Other`] if the configuration variant is not supported.
    pub fn new_with_configuration(
        configuration: &ZfpCodecConfiguration,
    ) -> Result<Self, PluginCreateError> {
        match configuration {
            // `ZfpMode` is `Copy`, so the mode can be read straight out of the borrowed
            // configuration without cloning it.
            ZfpCodecConfiguration::V1(configuration) => Ok(Self::from_mode(configuration.mode)),
            _ => Err(PluginCreateError::Other(
                "this zfp codec configuration variant is unsupported".to_string(),
            )),
        }
    }

    /// Shared constructor: wrap `mode` with header writing disabled.
    #[must_use]
    const fn from_mode(mode: ZfpMode) -> Self {
        Self {
            mode,
            write_header: false,
        }
    }
}
impl CodecTraits for ZfpCodec {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Serialise the codec as a V1 `zfp` configuration.
    ///
    /// Only the mode is recorded; the `write_header` flag is not part of the configuration.
    fn configuration(
        &self,
        _version: ZarrVersion,
        _options: &CodecMetadataOptions,
    ) -> Option<Configuration> {
        let v1 = ZfpCodecConfigurationV1 { mode: self.mode };
        let configuration: Configuration = ZfpCodecConfiguration::V1(v1).into();
        Some(configuration)
    }

    /// Neither partial reads nor partial decoding are supported.
    fn partial_decoder_capability(&self) -> PartialDecoderCapability {
        let (partial_read, partial_decode) = (false, false);
        PartialDecoderCapability {
            partial_read,
            partial_decode,
        }
    }

    /// Partial encoding is not supported.
    fn partial_encoder_capability(&self) -> PartialEncoderCapability {
        let partial_encode = false;
        PartialEncoderCapability { partial_encode }
    }
}
impl ArrayCodecTraits for ZfpCodec {
    /// Recommend at most one concurrent task, independent of shape and data type.
    ///
    /// The encode path runs a single `zfp_compress` call on one stream, so there is no
    /// internal work to split.
    fn recommended_concurrency(
        &self,
        _shape: &[NonZeroU64],
        _data_type: &DataType,
    ) -> Result<RecommendedConcurrency, CodecError> {
        let single_task = RecommendedConcurrency::new_maximum(1);
        Ok(single_task)
    }
}
/// Convert a chunk shape into the `usize` dimensions expected by the zfp field API.
///
/// # Errors
/// Returns a [`CodecError`] if a dimension does not fit in `usize` (only possible on
/// targets where `usize` is narrower than 64 bits).
fn shape_to_zfp_dims(shape: &[NonZeroU64]) -> Result<Vec<usize>, CodecError> {
    shape
        .iter()
        .map(|dim| {
            usize::try_from(dim.get())
                .map_err(|_| CodecError::from("zfp array dimension does not fit in usize"))
        })
        .collect()
}

#[cfg_attr(
    all(feature = "async", not(target_arch = "wasm32")),
    async_trait::async_trait
)]
#[cfg_attr(all(feature = "async", target_arch = "wasm32"), async_trait::async_trait(?Send))]
impl ArrayToBytesCodecTraits for ZfpCodec {
    fn into_dyn(self: Arc<Self>) -> Arc<dyn ArrayToBytesCodecTraits> {
        self as Arc<dyn ArrayToBytesCodecTraits>
    }

    /// Compress `bytes` with zfp.
    ///
    /// The data is promoted with `promote_before_zfp_encoding` and then compressed into a
    /// buffer sized by `zfp_stream_maximum_size`. If the codec writes headers, a full zfp
    /// header (`ZFP_HEADER_FULL`) precedes the compressed data. The buffer is truncated to
    /// the size reported by `zfp_compress`.
    ///
    /// # Errors
    /// Returns a [`CodecError`] if the input is not fixed-size, promotion fails, a dimension
    /// does not fit in `usize`, a zfp field/stream/bitstream cannot be created, the header
    /// cannot be written, or `zfp_compress` reports zero bytes.
    fn encode<'a>(
        &self,
        bytes: ArrayBytes<'a>,
        shape: &[NonZeroU64],
        data_type: &DataType,
        _fill_value: &FillValue,
        _options: &CodecOptions,
    ) -> Result<ArrayBytesRaw<'a>, CodecError> {
        let bytes = bytes.into_fixed()?;
        let mut bytes_promoted = promote_before_zfp_encoding(&bytes, data_type)?;
        let zfp_type = bytes_promoted.zfp_type();
        let dims = shape_to_zfp_dims(shape)?;
        let field = ZfpField::new(&mut bytes_promoted, &dims)
            .ok_or_else(|| CodecError::from("failed to create zfp field"))?;
        let stream = ZfpStream::new(&self.mode, zfp_type)
            .ok_or_else(|| CodecError::from("failed to create zfp stream"))?;

        // SAFETY: `stream` and `field` were successfully created above and are still alive.
        let bufsize =
            unsafe { zfp_stream_maximum_size(stream.as_zfp_stream(), field.as_zfp_field()) };
        let mut encoded_value: Vec<u8> = vec![0; bufsize];
        let bitstream = ZfpBitstream::new(&mut encoded_value)
            .ok_or_else(|| CodecError::from("failed to create zfp bitstream"))?;

        // SAFETY: `stream` and `bitstream` are valid; the buffer behind `bitstream`
        // (`encoded_value`) is not resized until compression has finished.
        unsafe {
            zfp_stream_set_bit_stream(stream.as_zfp_stream(), bitstream.as_bitstream());
            zfp_stream_rewind(stream.as_zfp_stream());
        }

        if self.write_header {
            // SAFETY: `stream` (with its bitstream attached) and `field` are valid.
            let header_bits = unsafe {
                zfp_write_header(
                    stream.as_zfp_stream(),
                    field.as_zfp_field(),
                    zfp_sys::ZFP_HEADER_FULL,
                )
            };
            // Per the zfp API, the return value is the number of bits written, or zero on
            // failure; without a header the stream could not be decoded with headers enabled.
            if header_bits == 0 {
                return Err(CodecError::from("failed to write zfp header"));
            }
        }

        // SAFETY: `stream` (with its bitstream attached) and `field` are valid.
        let size = unsafe { zfp_compress(stream.as_zfp_stream(), field.as_zfp_field()) };
        if size == 0 {
            Err(CodecError::from("zfp compression failed"))
        } else {
            encoded_value.truncate(size);
            Ok(Cow::Owned(encoded_value))
        }
    }

    /// Decompress zfp-encoded `bytes` by delegating to `zfp_decode`.
    ///
    /// # Errors
    /// Propagates any [`CodecError`] returned by `zfp_decode`.
    fn decode<'a>(
        &self,
        bytes: ArrayBytesRaw<'a>,
        shape: &[NonZeroU64],
        data_type: &DataType,
        _fill_value: &FillValue,
        _options: &CodecOptions,
    ) -> Result<ArrayBytes<'a>, CodecError> {
        // `zfp_decode` takes the encoded bytes by mutable reference. `into_owned` reuses the
        // allocation when `bytes` is already owned, instead of always copying.
        let mut encoded = bytes.into_owned();
        zfp_decode(
            &self.mode,
            self.write_header,
            &mut encoded,
            shape,
            data_type,
            false,
        )
        .map(ArrayBytes::from)
    }

    /// Return an upper bound on the encoded size, as given by `zfp_stream_maximum_size`.
    ///
    /// # Errors
    /// Returns a [`CodecError`] if the data type is unsupported by zfp, a dimension does
    /// not fit in `usize`, or the zfp field/stream cannot be created.
    fn encoded_representation(
        &self,
        shape: &[NonZeroU64],
        data_type: &DataType,
        _fill_value: &FillValue,
    ) -> Result<BytesRepresentation, CodecError> {
        let encoding = data_type.codec_zfp()?.zfp_encoding();
        let zfp_type = zfp_native_type_to_sys(encoding.native_type());
        let dims = shape_to_zfp_dims(shape)?;
        // SAFETY: this field is only used below to query the maximum compressed size; it is
        // never passed to a compression or decompression routine.
        let field = unsafe { ZfpField::new_empty(zfp_type, &dims) }
            .ok_or_else(|| CodecError::from("failed to create zfp field"))?;
        let stream = ZfpStream::new(&self.mode, zfp_type)
            .ok_or_else(|| CodecError::from("failed to create zfp stream"))?;
        // SAFETY: `stream` and `field` were successfully created above and are still alive.
        let bufsize =
            unsafe { zfp_stream_maximum_size(stream.as_zfp_stream(), field.as_zfp_field()) };
        let bufsize = u64::try_from(bufsize)
            .map_err(|_| CodecError::from("zfp maximum encoded size does not fit in u64"))?;
        Ok(BytesRepresentation::BoundedSize(bufsize))
    }
}