use std::sync::Arc;
use zarrs_plugin::{PluginCreateError, ZarrVersion};
use super::{
BitroundCodecConfiguration, BitroundCodecConfigurationV1, BitroundDataTypeExt,
bitround_codec_partial, round_bytes,
};
use crate::array::{DataType, FillValue};
use std::num::NonZeroU64;
use zarrs_codec::{
ArrayBytes, ArrayCodecTraits, ArrayPartialDecoderTraits, ArrayPartialEncoderTraits,
ArrayToArrayCodecTraits, CodecError, CodecMetadataOptions, CodecOptions, CodecTraits,
PartialDecoderCapability, PartialEncoderCapability, RecommendedConcurrency,
};
#[cfg(feature = "async")]
use zarrs_codec::{AsyncArrayPartialDecoderTraits, AsyncArrayPartialEncoderTraits};
use zarrs_metadata::Configuration;
/// A `bitround` codec, parameterised by the `keepbits` value it applies when encoding.
///
/// The derived [`Default`] yields a codec with `keepbits == 0`.
#[derive(Clone, Debug, Default)]
pub struct BitroundCodec {
// Forwarded unchanged to `round_bytes` and to the partial encoder/decoder.
// NOTE(review): presumably the number of bits retained when rounding — confirm against
// the `round_bytes` documentation.
keepbits: u32,
}
impl BitroundCodec {
    /// Create a `bitround` codec that uses the given `keepbits` value.
    #[must_use]
    pub const fn new(keepbits: u32) -> Self {
        Self { keepbits }
    }

    /// Create a `bitround` codec from a codec configuration.
    ///
    /// # Errors
    /// Returns [`PluginCreateError::Other`] if `configuration` is any variant other than
    /// [`BitroundCodecConfiguration::V1`].
    pub fn new_with_configuration(
        configuration: &BitroundCodecConfiguration,
    ) -> Result<Self, PluginCreateError> {
        // Only the V1 layout is understood; bail out early for anything else.
        let keepbits = match configuration {
            BitroundCodecConfiguration::V1(v1) => v1.keepbits,
            _ => {
                let message =
                    "this bitround codec configuration variant is unsupported".to_string();
                return Err(PluginCreateError::Other(message));
            }
        };
        Ok(Self::new(keepbits))
    }
}
impl CodecTraits for BitroundCodec {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    // Produces the V1 configuration carrying this codec's `keepbits`, but only when the
    // metadata options ask for encode-only codecs to be stored; otherwise yields `None`.
    fn configuration(
        &self,
        _version: ZarrVersion,
        options: &CodecMetadataOptions,
    ) -> Option<Configuration> {
        if !options.codec_store_metadata_if_encode_only() {
            return None;
        }
        let v1 = BitroundCodecConfigurationV1 {
            keepbits: self.keepbits,
        };
        let metadata: Configuration = BitroundCodecConfiguration::V1(v1).into();
        Some(metadata)
    }

    // Both partial-read and partial-decode are reported as supported.
    fn partial_decoder_capability(&self) -> PartialDecoderCapability {
        let capability = PartialDecoderCapability {
            partial_read: true,
            partial_decode: true,
        };
        capability
    }

    // Partial encoding is reported as supported.
    fn partial_encoder_capability(&self) -> PartialEncoderCapability {
        let capability = PartialEncoderCapability {
            partial_encode: true,
        };
        capability
    }
}
impl ArrayCodecTraits for BitroundCodec {
    // The recommendation is independent of the shape and data type: it is always
    // `RecommendedConcurrency::new_maximum(1)`.
    fn recommended_concurrency(
        &self,
        _shape: &[NonZeroU64],
        _data_type: &DataType,
    ) -> Result<RecommendedConcurrency, CodecError> {
        let recommendation = RecommendedConcurrency::new_maximum(1);
        Ok(recommendation)
    }
}
#[cfg_attr(
    all(feature = "async", not(target_arch = "wasm32")),
    async_trait::async_trait
)]
#[cfg_attr(all(feature = "async", target_arch = "wasm32"), async_trait::async_trait(?Send))]
impl ArrayToArrayCodecTraits for BitroundCodec {
    // Type-erase the shared codec handle; the unsizing coercion happens on assignment.
    fn into_dyn(self: Arc<Self>) -> Arc<dyn ArrayToArrayCodecTraits> {
        let erased: Arc<dyn ArrayToArrayCodecTraits> = self;
        erased
    }

    // Converts the input to fixed-length bytes, hands a mutable view of them to
    // `round_bytes` together with the data type and `keepbits`, and returns the result.
    // Fails if the bytes are not fixed-length or if `round_bytes` reports an error.
    fn encode<'a>(
        &self,
        bytes: ArrayBytes<'a>,
        _shape: &[NonZeroU64],
        data_type: &DataType,
        _fill_value: &FillValue,
        _options: &CodecOptions,
    ) -> Result<ArrayBytes<'a>, CodecError> {
        let mut fixed = bytes.into_fixed()?;
        let keepbits = self.keepbits;
        round_bytes(fixed.to_mut(), data_type, keepbits)?;
        let rounded = ArrayBytes::from(fixed);
        Ok(rounded)
    }

    // Decoding is the identity: the encoded bytes are returned untouched.
    fn decode<'a>(
        &self,
        bytes: ArrayBytes<'a>,
        _shape: &[NonZeroU64],
        _data_type: &DataType,
        _fill_value: &FillValue,
        _options: &CodecOptions,
    ) -> Result<ArrayBytes<'a>, CodecError> {
        Ok(bytes)
    }

    // Wraps `input_handle` in a `BitroundCodecPartial` configured with this codec's
    // `keepbits`; construction errors are propagated.
    fn partial_decoder(
        self: Arc<Self>,
        input_handle: Arc<dyn ArrayPartialDecoderTraits>,
        _shape: &[NonZeroU64],
        data_type: &DataType,
        _fill_value: &FillValue,
        _options: &CodecOptions,
    ) -> Result<Arc<dyn ArrayPartialDecoderTraits>, CodecError> {
        let keepbits = self.keepbits;
        let partial =
            bitround_codec_partial::BitroundCodecPartial::new(input_handle, data_type, keepbits)?;
        Ok(Arc::new(partial))
    }

    // Wraps `input_output_handle` in a `BitroundCodecPartial` configured with this codec's
    // `keepbits`; construction errors are propagated.
    fn partial_encoder(
        self: Arc<Self>,
        input_output_handle: Arc<dyn ArrayPartialEncoderTraits>,
        _shape: &[NonZeroU64],
        data_type: &DataType,
        _fill_value: &FillValue,
        _options: &CodecOptions,
    ) -> Result<Arc<dyn ArrayPartialEncoderTraits>, CodecError> {
        let keepbits = self.keepbits;
        let partial = bitround_codec_partial::BitroundCodecPartial::new(
            input_output_handle,
            data_type,
            keepbits,
        )?;
        Ok(Arc::new(partial))
    }

    // Async counterpart of `partial_decoder`: same wrapping, async handle type.
    #[cfg(feature = "async")]
    async fn async_partial_decoder(
        self: Arc<Self>,
        input_handle: Arc<dyn AsyncArrayPartialDecoderTraits>,
        _shape: &[NonZeroU64],
        data_type: &DataType,
        _fill_value: &FillValue,
        _options: &CodecOptions,
    ) -> Result<Arc<dyn AsyncArrayPartialDecoderTraits>, CodecError> {
        let keepbits = self.keepbits;
        let partial =
            bitround_codec_partial::BitroundCodecPartial::new(input_handle, data_type, keepbits)?;
        Ok(Arc::new(partial))
    }

    // Async counterpart of `partial_encoder`: same wrapping, async handle type.
    #[cfg(feature = "async")]
    async fn async_partial_encoder(
        self: Arc<Self>,
        input_output_handle: Arc<dyn AsyncArrayPartialEncoderTraits>,
        _shape: &[NonZeroU64],
        data_type: &DataType,
        _fill_value: &FillValue,
        _options: &CodecOptions,
    ) -> Result<Arc<dyn AsyncArrayPartialEncoderTraits>, CodecError> {
        let keepbits = self.keepbits;
        let partial = bitround_codec_partial::BitroundCodecPartial::new(
            input_output_handle,
            data_type,
            keepbits,
        )?;
        Ok(Arc::new(partial))
    }

    // The encoded data type is a clone of the decoded one, provided the
    // `codec_bitround()` lookup on that data type succeeds; its error is propagated
    // otherwise.
    fn encoded_data_type(&self, decoded_data_type: &DataType) -> Result<DataType, CodecError> {
        decoded_data_type.codec_bitround()?;
        let encoded = decoded_data_type.clone();
        Ok(encoded)
    }
}