use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use bytes::Bytes;
use crate::{ChunkManifest, Codec, CodecError, CodecKind, CompressTelemetry, looks_like_oom};
/// Returns `true` when `kind` is one of the GPU-backed codec kinds (the nvCOMP variants and
/// DietGPU ANS) and `false` for every other kind.
///
/// The telemetry-reporting dispatch methods use this to decide whether a call is recorded with
/// `CompressTelemetry::gpu` (timed) or `CompressTelemetry::cpu`.
fn is_gpu_kind(kind: CodecKind) -> bool {
    const GPU_KINDS: [CodecKind; 5] = [
        CodecKind::NvcompZstd,
        CodecKind::NvcompBitcomp,
        CodecKind::NvcompGans,
        CodecKind::NvcompGDeflate,
        CodecKind::DietGpuAns,
    ];
    GPU_KINDS.contains(&kind)
}
pub struct CodecRegistry {
codecs: HashMap<CodecKind, Arc<dyn Codec>>,
default: CodecKind,
}
impl CodecRegistry {
pub fn new(default: CodecKind) -> Self {
Self {
codecs: HashMap::new(),
default,
}
}
pub fn register(&mut self, codec: Arc<dyn Codec>) -> &mut Self {
self.codecs.insert(codec.kind(), codec);
self
}
#[must_use]
pub fn with(mut self, codec: Arc<dyn Codec>) -> Self {
self.register(codec);
self
}
pub fn kinds(&self) -> impl Iterator<Item = CodecKind> + '_ {
self.codecs.keys().copied()
}
pub fn default_kind(&self) -> CodecKind {
self.default
}
fn lookup(&self, kind: CodecKind) -> Result<&Arc<dyn Codec>, CodecError> {
self.codecs
.get(&kind)
.ok_or(CodecError::UnregisteredCodec(kind))
}
pub async fn compress(
&self,
input: Bytes,
kind: CodecKind,
) -> Result<(Bytes, ChunkManifest), CodecError> {
let codec = self.lookup(kind)?;
codec.compress(input).await
}
pub async fn decompress(
&self,
input: Bytes,
manifest: &ChunkManifest,
) -> Result<Bytes, CodecError> {
let codec = self.lookup(manifest.codec)?;
codec.decompress(input, manifest).await
}
pub async fn compress_with_telemetry(
&self,
input: Bytes,
kind: CodecKind,
) -> (
Result<(Bytes, ChunkManifest), CodecError>,
CompressTelemetry,
) {
let bytes_in = input.len() as u64;
let codec = match self.lookup(kind) {
Ok(c) => c,
Err(e) => {
let tel = CompressTelemetry {
codec: kind.as_str(),
bytes_in,
bytes_out: 0,
gpu_seconds: None,
oom: false,
};
return (Err(e), tel);
}
};
let is_gpu = is_gpu_kind(kind);
let started = Instant::now();
let result = codec.compress(input).await;
let elapsed = started.elapsed().as_secs_f64();
match &result {
Ok((out, _manifest)) => {
let bytes_out = out.len() as u64;
let tel = if is_gpu {
CompressTelemetry::gpu(kind.as_str(), bytes_in, bytes_out, elapsed)
} else {
CompressTelemetry::cpu(kind.as_str(), bytes_in, bytes_out)
};
(result, tel)
}
Err(e) => {
let mut tel = if is_gpu {
CompressTelemetry::gpu(kind.as_str(), bytes_in, 0, elapsed)
} else {
CompressTelemetry::cpu(kind.as_str(), bytes_in, 0)
};
if looks_like_oom(e) {
tel = tel.with_oom();
}
(result, tel)
}
}
}
pub async fn decompress_with_telemetry(
&self,
input: Bytes,
manifest: &ChunkManifest,
) -> (Result<Bytes, CodecError>, CompressTelemetry) {
let bytes_in = input.len() as u64;
let kind = manifest.codec;
let codec = match self.lookup(kind) {
Ok(c) => c,
Err(e) => {
let tel = CompressTelemetry {
codec: kind.as_str(),
bytes_in,
bytes_out: 0,
gpu_seconds: None,
oom: false,
};
return (Err(e), tel);
}
};
let is_gpu = is_gpu_kind(kind);
let started = Instant::now();
let result = codec.decompress(input, manifest).await;
let elapsed = started.elapsed().as_secs_f64();
match &result {
Ok(out) => {
let bytes_out = out.len() as u64;
let tel = if is_gpu {
CompressTelemetry::gpu(kind.as_str(), bytes_in, bytes_out, elapsed)
} else {
CompressTelemetry::cpu(kind.as_str(), bytes_in, bytes_out)
};
(result, tel)
}
Err(e) => {
let mut tel = if is_gpu {
CompressTelemetry::gpu(kind.as_str(), bytes_in, 0, elapsed)
} else {
CompressTelemetry::cpu(kind.as_str(), bytes_in, 0)
};
if looks_like_oom(e) {
tel = tel.with_oom();
}
(result, tel)
}
}
}
}
impl std::fmt::Debug for CodecRegistry {
    /// Shows the default kind and every registered kind.
    ///
    /// Registered kinds are ordered by their string name so the output does not depend on
    /// `HashMap` iteration order.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut registered: Vec<CodecKind> = self.kinds().collect();
        registered.sort_unstable_by(|a, b| a.as_str().cmp(b.as_str()));
        let mut builder = f.debug_struct("CodecRegistry");
        builder.field("default", &self.default);
        builder.field("registered", &registered);
        builder.finish()
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for `CodecRegistry` dispatch and the telemetry-reporting wrappers, using only
    // the CPU-side `Passthrough` and `CpuZstd` codecs.
    use super::*;
    use crate::cpu_zstd::CpuZstd;
    use crate::passthrough::Passthrough;

    /// Registry with `Passthrough` and `CpuZstd` registered and `CpuZstd` as the default.
    fn registry() -> CodecRegistry {
        CodecRegistry::new(CodecKind::CpuZstd)
            .with(Arc::new(Passthrough))
            .with(Arc::new(CpuZstd::default()))
    }

    #[tokio::test]
    async fn dispatches_compress_by_kind() {
        let r = registry();
        let input = Bytes::from(vec![b'a'; 1024]);
        let (compressed_pt, manifest_pt) = r
            .compress(input.clone(), CodecKind::Passthrough)
            .await
            .unwrap();
        assert_eq!(manifest_pt.codec, CodecKind::Passthrough);
        assert_eq!(compressed_pt.len(), input.len());
        let (compressed_zstd, manifest_zstd) =
            r.compress(input.clone(), CodecKind::CpuZstd).await.unwrap();
        assert_eq!(manifest_zstd.codec, CodecKind::CpuZstd);
        assert!(compressed_zstd.len() < input.len() / 5);
    }

    #[tokio::test]
    async fn dispatches_decompress_by_manifest() {
        let r = registry();
        let input = Bytes::from(vec![b'a'; 1024]);
        let (compressed, manifest) = r.compress(input.clone(), CodecKind::CpuZstd).await.unwrap();
        let decompressed = r.decompress(compressed, &manifest).await.unwrap();
        assert_eq!(decompressed, input);
    }

    #[tokio::test]
    async fn unregistered_codec_yields_error() {
        let r = registry();
        let bogus_manifest = ChunkManifest {
            codec: CodecKind::NvcompBitcomp,
            original_size: 10,
            compressed_size: 10,
            crc32c: 0,
        };
        let err = r
            .decompress(Bytes::from_static(b"0123456789"), &bogus_manifest)
            .await
            .unwrap_err();
        assert!(matches!(
            err,
            CodecError::UnregisteredCodec(CodecKind::NvcompBitcomp)
        ));
    }

    #[tokio::test]
    async fn compress_with_telemetry_cpu_marks_gpu_seconds_none() {
        let r = registry();
        let input = Bytes::from(vec![b'a'; 1024]);
        let (res, tel) = r
            .compress_with_telemetry(input.clone(), CodecKind::CpuZstd)
            .await;
        let (out, _manifest) = res.expect("compress ok");
        assert_eq!(tel.codec, "cpu-zstd");
        assert_eq!(tel.bytes_in, input.len() as u64);
        assert_eq!(tel.bytes_out, out.len() as u64);
        assert!(
            tel.gpu_seconds.is_none(),
            "CPU codec must report gpu_seconds=None, got {:?}",
            tel.gpu_seconds
        );
        assert!(!tel.oom);
    }

    #[tokio::test]
    async fn compress_with_telemetry_unregistered_returns_telemetry() {
        let r = registry();
        let input = Bytes::from(vec![b'b'; 32]);
        let (res, tel) = r
            .compress_with_telemetry(input.clone(), CodecKind::NvcompBitcomp)
            .await;
        assert!(matches!(
            res,
            Err(CodecError::UnregisteredCodec(CodecKind::NvcompBitcomp))
        ));
        assert_eq!(tel.codec, "nvcomp-bitcomp");
        assert_eq!(tel.bytes_in, input.len() as u64);
        assert_eq!(tel.bytes_out, 0);
        assert!(tel.gpu_seconds.is_none());
        assert!(!tel.oom);
    }

    #[tokio::test]
    async fn decompress_with_telemetry_cpu_reports_output_size() {
        let r = registry();
        let input = Bytes::from(vec![b'c'; 1024]);
        let (compressed, manifest) = r.compress(input.clone(), CodecKind::CpuZstd).await.unwrap();
        let (res, tel) = r
            .decompress_with_telemetry(compressed.clone(), &manifest)
            .await;
        let out = res.expect("decompress ok");
        assert_eq!(out, input);
        assert_eq!(tel.codec, "cpu-zstd");
        assert_eq!(tel.bytes_in, compressed.len() as u64);
        assert_eq!(tel.bytes_out, input.len() as u64);
        assert!(tel.gpu_seconds.is_none());
    }

    // The decompress-side unregistered path was previously untested; it must mirror the
    // compress-side behaviour (error passed through plus an empty, untimed record).
    #[tokio::test]
    async fn decompress_with_telemetry_unregistered_returns_telemetry() {
        let r = registry();
        let input = Bytes::from_static(b"0123456789");
        let manifest = ChunkManifest {
            codec: CodecKind::NvcompBitcomp,
            original_size: 10,
            compressed_size: 10,
            crc32c: 0,
        };
        let (res, tel) = r.decompress_with_telemetry(input.clone(), &manifest).await;
        assert!(matches!(
            res,
            Err(CodecError::UnregisteredCodec(CodecKind::NvcompBitcomp))
        ));
        assert_eq!(tel.codec, "nvcomp-bitcomp");
        assert_eq!(tel.bytes_in, input.len() as u64);
        assert_eq!(tel.bytes_out, 0);
        assert!(tel.gpu_seconds.is_none());
        assert!(!tel.oom);
    }

    #[test]
    fn kinds_and_default_reflect_construction() {
        let r = registry();
        let kinds: std::collections::HashSet<CodecKind> = r.kinds().collect();
        assert_eq!(kinds.len(), 2);
        assert!(kinds.contains(&CodecKind::Passthrough));
        assert!(kinds.contains(&CodecKind::CpuZstd));
        assert_eq!(r.default_kind(), CodecKind::CpuZstd);
    }

    #[test]
    fn reregistering_a_kind_replaces_instead_of_duplicating() {
        let mut r = registry();
        r.register(Arc::new(Passthrough));
        assert_eq!(r.kinds().count(), 2);
    }
}