1use std::collections::HashMap;
12use std::sync::Arc;
13
14use bytes::Bytes;
15
16use crate::{ChunkManifest, Codec, CodecError, CodecKind};
17
18pub struct CodecRegistry {
20 codecs: HashMap<CodecKind, Arc<dyn Codec>>,
21 default: CodecKind,
22}
23
24impl CodecRegistry {
25 pub fn new(default: CodecKind) -> Self {
28 Self {
29 codecs: HashMap::new(),
30 default,
31 }
32 }
33
34 pub fn register(&mut self, codec: Arc<dyn Codec>) -> &mut Self {
36 self.codecs.insert(codec.kind(), codec);
37 self
38 }
39
40 #[must_use]
42 pub fn with(mut self, codec: Arc<dyn Codec>) -> Self {
43 self.register(codec);
44 self
45 }
46
47 pub fn kinds(&self) -> impl Iterator<Item = CodecKind> + '_ {
49 self.codecs.keys().copied()
50 }
51
52 pub fn default_kind(&self) -> CodecKind {
54 self.default
55 }
56
57 fn lookup(&self, kind: CodecKind) -> Result<&Arc<dyn Codec>, CodecError> {
58 self.codecs
59 .get(&kind)
60 .ok_or(CodecError::UnregisteredCodec(kind))
61 }
62
63 pub async fn compress(
65 &self,
66 input: Bytes,
67 kind: CodecKind,
68 ) -> Result<(Bytes, ChunkManifest), CodecError> {
69 let codec = self.lookup(kind)?;
70 codec.compress(input).await
71 }
72
73 pub async fn decompress(
75 &self,
76 input: Bytes,
77 manifest: &ChunkManifest,
78 ) -> Result<Bytes, CodecError> {
79 let codec = self.lookup(manifest.codec)?;
80 codec.decompress(input, manifest).await
81 }
82}
83
84impl std::fmt::Debug for CodecRegistry {
85 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86 let mut kinds: Vec<&CodecKind> = self.codecs.keys().collect();
87 kinds.sort_unstable_by_key(|k| k.as_str());
88 f.debug_struct("CodecRegistry")
89 .field("default", &self.default)
90 .field("registered", &kinds)
91 .finish()
92 }
93}
94
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cpu_zstd::CpuZstd;
    use crate::passthrough::Passthrough;

    /// Builds a registry whose default is `CpuZstd`, holding the passthrough
    /// codec and a default-configured CPU zstd codec.
    fn registry() -> CodecRegistry {
        let passthrough: Arc<dyn Codec> = Arc::new(Passthrough);
        let zstd: Arc<dyn Codec> = Arc::new(CpuZstd::default());
        CodecRegistry::new(CodecKind::CpuZstd)
            .with(passthrough)
            .with(zstd)
    }

    /// 1 KiB of the byte `b'a'`, shared by the compression tests.
    fn sample_payload() -> Bytes {
        Bytes::from(vec![b'a'; 1024])
    }

    /// `compress` must route to the codec named by the `kind` argument.
    #[tokio::test]
    async fn dispatches_compress_by_kind() {
        let reg = registry();
        let payload = sample_payload();

        // Passthrough: the manifest names the codec and the length is unchanged.
        let (pt_bytes, pt_manifest) = reg
            .compress(payload.clone(), CodecKind::Passthrough)
            .await
            .unwrap();
        assert_eq!(pt_manifest.codec, CodecKind::Passthrough);
        assert_eq!(pt_bytes.len(), payload.len());

        // Zstd: the manifest names the codec and the output is under a fifth
        // of the input length.
        let (zstd_bytes, zstd_manifest) = reg
            .compress(payload.clone(), CodecKind::CpuZstd)
            .await
            .unwrap();
        assert_eq!(zstd_manifest.codec, CodecKind::CpuZstd);
        assert!(zstd_bytes.len() < payload.len() / 5);
    }

    /// `decompress` must pick the codec from the manifest and round-trip the data.
    #[tokio::test]
    async fn dispatches_decompress_by_manifest() {
        let reg = registry();
        let payload = sample_payload();

        let (packed, manifest) = reg
            .compress(payload.clone(), CodecKind::CpuZstd)
            .await
            .unwrap();
        let restored = reg.decompress(packed, &manifest).await.unwrap();

        assert_eq!(restored, payload);
    }

    /// A manifest naming a kind that was never registered must surface
    /// `UnregisteredCodec` carrying that kind.
    #[tokio::test]
    async fn unregistered_codec_yields_error() {
        const RAW: &[u8] = b"0123456789";

        let reg = registry();
        let bogus_manifest = ChunkManifest {
            codec: CodecKind::NvcompBitcomp,
            original_size: 10,
            compressed_size: 10,
            crc32c: 0,
        };

        let outcome = reg
            .decompress(Bytes::from_static(RAW), &bogus_manifest)
            .await;
        let err = outcome.unwrap_err();

        assert!(matches!(
            err,
            CodecError::UnregisteredCodec(CodecKind::NvcompBitcomp)
        ));
    }
}