use std::borrow::Cow;
use std::sync::Arc;
use zarrs_plugin::{PluginCreateError, ZarrVersion};
use zstd::zstd_safe;
use super::{ZstdCodecConfiguration, ZstdCodecConfigurationV1};
use crate::array::{ArrayBytesRaw, BytesRepresentation};
use zarrs_codec::{
BytesToBytesCodecTraits, CodecError, CodecMetadataOptions, CodecOptions, CodecTraits,
PartialDecoderCapability, PartialEncoderCapability, RecommendedConcurrency,
};
use zarrs_metadata::Configuration;
/// A `zstd` bytes-to-bytes codec.
///
/// Holds the compression level used when encoding and whether encoded output carries a zstd
/// content checksum.
#[derive(Clone, Debug)]
pub struct ZstdCodec {
/// Compression level passed to `zstd::bulk::Compressor::new` when encoding.
compression: zstd_safe::CompressionLevel,
/// Whether the encoder is asked to include a content checksum (`include_checksum`).
checksum: bool,
}
impl ZstdCodec {
    /// Create a new `zstd` codec.
    ///
    /// `compression` is the zstd compression level used when encoding, and `checksum` selects
    /// whether encoded output includes a zstd content checksum.
    #[must_use]
    pub const fn new(compression: zstd_safe::CompressionLevel, checksum: bool) -> Self {
        Self {
            compression,
            checksum,
        }
    }

    /// Create a new `zstd` codec from a [`ZstdCodecConfiguration`].
    ///
    /// For the `V1` variant both the level and the checksum flag are taken from the
    /// configuration. For the `Numcodecs` variant only the level is read and checksums are
    /// disabled.
    ///
    /// # Errors
    /// Returns [`PluginCreateError::Other`] if the configuration variant is not supported.
    pub fn new_with_configuration(
        configuration: &ZstdCodecConfiguration,
    ) -> Result<Self, PluginCreateError> {
        let (compression, checksum) = match configuration {
            ZstdCodecConfiguration::V1(configuration) => {
                (configuration.level, configuration.checksum)
            }
            ZstdCodecConfiguration::Numcodecs(configuration) => (configuration.level, false),
            // NOTE(review): this catch-all is presumably required because the configuration
            // enum is defined outside this file and may gain variants — confirm.
            _ => {
                return Err(PluginCreateError::Other(
                    "this zstd codec configuration variant is unsupported".to_string(),
                ));
            }
        };
        Ok(Self {
            compression: compression.into(),
            checksum,
        })
    }
}
impl CodecTraits for ZstdCodec {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Describe this codec as a `V1` zstd configuration holding the level and checksum flag.
    ///
    /// The Zarr version and metadata options are ignored; the same shape is always produced.
    fn configuration(
        &self,
        _version: ZarrVersion,
        _options: &CodecMetadataOptions,
    ) -> Option<Configuration> {
        let v1 = ZstdCodecConfigurationV1 {
            level: self.compression.into(),
            checksum: self.checksum,
        };
        let wrapped = ZstdCodecConfiguration::V1(v1);
        Some(wrapped.into())
    }

    /// Report that neither partial reads nor partial decodes are supported.
    fn partial_decoder_capability(&self) -> PartialDecoderCapability {
        let partial_read = false;
        let partial_decode = false;
        PartialDecoderCapability {
            partial_read,
            partial_decode,
        }
    }

    /// Report that partial encoding is not supported.
    fn partial_encoder_capability(&self) -> PartialEncoderCapability {
        let partial_encode = false;
        PartialEncoderCapability { partial_encode }
    }
}
#[cfg_attr(
    all(feature = "async", not(target_arch = "wasm32")),
    async_trait::async_trait
)]
#[cfg_attr(all(feature = "async", target_arch = "wasm32"), async_trait::async_trait(?Send))]
impl BytesToBytesCodecTraits for ZstdCodec {
    fn into_dyn(self: Arc<Self>) -> Arc<dyn BytesToBytesCodecTraits> {
        self as Arc<dyn BytesToBytesCodecTraits>
    }

    /// Recommend a maximum concurrency of 1.
    ///
    /// The encoder below does not enable zstd's multithreaded compression.
    fn recommended_concurrency(
        &self,
        _decoded_representation: &BytesRepresentation,
    ) -> Result<RecommendedConcurrency, CodecError> {
        Ok(RecommendedConcurrency::new_maximum(1))
    }

    /// Compress `decoded_value` in one shot with `zstd::bulk::Compressor` at the configured
    /// level, requesting a content checksum if enabled.
    ///
    /// # Errors
    /// Returns a [`CodecError`] if the compressor cannot be created or configured, or if
    /// compression fails.
    fn encode<'a>(
        &self,
        decoded_value: ArrayBytesRaw<'a>,
        _options: &CodecOptions,
    ) -> Result<ArrayBytesRaw<'a>, CodecError> {
        let mut compressor = zstd::bulk::Compressor::new(self.compression)?;
        compressor.include_checksum(self.checksum)?;
        let result = compressor.compress(&decoded_value)?;
        Ok(Cow::Owned(result))
    }

    /// Decompress `encoded_value`.
    ///
    /// If zstd can derive an upper bound on the decompressed size from the input, the data is
    /// decompressed in one shot into a buffer of at most that size. Otherwise it falls back to
    /// streaming decompression with `zstd::decode_all`.
    ///
    /// # Errors
    /// Returns a [`CodecError`] if decompression fails.
    fn decode<'a>(
        &self,
        encoded_value: ArrayBytesRaw<'a>,
        _decoded_representation: &BytesRepresentation,
        _options: &CodecOptions,
    ) -> Result<ArrayBytesRaw<'a>, CodecError> {
        // NOTE(review): the bound is derived from the (possibly untrusted) frame header, so a
        // crafted header could request a very large allocation; consider capping it with
        // `_decoded_representation.size()` if that is guaranteed to be a true upper bound.
        if let Some(upper_bound) = zstd::bulk::Decompressor::upper_bound(&encoded_value) {
            let result = zstd::bulk::decompress(&encoded_value, upper_bound)?;
            Ok(Cow::Owned(result))
        } else {
            zstd::decode_all(std::io::Cursor::new(&encoded_value))
                .map_err(CodecError::from)
                .map(Cow::Owned)
        }
    }

    /// Worst-case encoded size for an input of known size, or unbounded otherwise.
    ///
    /// The bound has three parts:
    /// - the input size;
    /// - the frame magic number, the maximum frame header and an optional checksum;
    /// - a 3-byte block header per block.
    ///
    /// Blocks are counted in 1000-byte units, which is below zstd's 1 KiB minimum window and
    /// so errs high. Every zstd frame holds at least one block, so even an empty input is
    /// charged one block header.
    fn encoded_representation(
        &self,
        decoded_representation: &BytesRepresentation,
    ) -> BytesRepresentation {
        decoded_representation
            .size()
            .map_or(BytesRepresentation::UnboundedSize, |size| {
                // Magic number (4) + maximum frame header (14) + content checksum (4).
                const HEADER_TRAILER_OVERHEAD: u64 = 4 + 14 + 4;
                // Smaller than the 1 KiB (1024-byte) minimum window, so the block count
                // over-estimates.
                const MIN_WINDOW_SIZE: u64 = 1000;
                // Size of each block header.
                const BLOCK_OVERHEAD: u64 = 3;
                // A frame always contains at least one block, even for an empty input.
                let blocks = size.div_ceil(MIN_WINDOW_SIZE).max(1);
                let blocks_overhead = BLOCK_OVERHEAD * blocks;
                BytesRepresentation::BoundedSize(size + HEADER_TRAILER_OVERHEAD + blocks_overhead)
            })
    }
}