1pub mod deflate;
2pub mod fletcher32;
3#[cfg(feature = "lz4")]
4pub mod lz4;
5pub mod nbit;
6pub mod scaleoffset;
7pub mod shuffle;
8
9use std::collections::HashMap;
10
11use crate::error::{Error, Result};
12use crate::messages::filter_pipeline::FilterDescription;
13
/// Filter id dispatched to the [`deflate`] module by the built-in dispatcher.
pub const FILTER_DEFLATE: u16 = 1;
/// Filter id dispatched to the [`shuffle`] module by the built-in dispatcher.
pub const FILTER_SHUFFLE: u16 = 2;
/// Filter id dispatched to the [`fletcher32`] module by the built-in dispatcher.
pub const FILTER_FLETCHER32: u16 = 3;
/// Filter id for szip. It is recognised by the built-in dispatcher only to be
/// rejected with `Error::UnsupportedFilter`, and is not registered by
/// `FilterRegistry::new`.
pub const FILTER_SZIP: u16 = 4;
/// Filter id dispatched to the [`nbit`] module by the built-in dispatcher.
pub const FILTER_NBIT: u16 = 5;
/// Filter id dispatched to the [`scaleoffset`] module by the built-in dispatcher.
pub const FILTER_SCALEOFFSET: u16 = 6;
/// Filter id for LZ4. It is only handled (and only registered as a built-in)
/// when the `lz4` cargo feature is enabled.
pub const FILTER_LZ4: u16 = 32004;
23
/// Boxed callback type for user-supplied filters.
///
/// The registry invokes the callback with the filter description, the input
/// bytes and the element size, and expects the transformed bytes back. The
/// callback must be `Send + Sync`.
pub type FilterFn = Box<dyn Fn(&FilterDescription, &[u8], usize) -> Result<Vec<u8>> + Send + Sync>;
29
/// How a registered filter id is executed.
enum FilterImplementation {
    /// Routed to the built-in dispatcher, which selects the codec by filter id.
    Builtin,
    /// Routed to a caller-provided callback.
    Custom(FilterFn),
}
34
/// Table mapping filter ids to the implementation used to apply them.
pub struct FilterRegistry {
    /// Registered implementations, keyed by filter id.
    filters: HashMap<u16, FilterImplementation>,
}
42
43impl FilterRegistry {
44 pub fn new() -> Self {
46 let mut registry = FilterRegistry {
47 filters: HashMap::new(),
48 };
49 registry.register_builtin(FILTER_DEFLATE);
50 registry.register_builtin(FILTER_SHUFFLE);
51 registry.register_builtin(FILTER_FLETCHER32);
52 registry.register_builtin(FILTER_NBIT);
53 registry.register_builtin(FILTER_SCALEOFFSET);
54 #[cfg(feature = "lz4")]
55 registry.register_builtin(FILTER_LZ4);
56 registry
57 }
58
59 fn register_builtin(&mut self, id: u16) {
60 self.filters.insert(id, FilterImplementation::Builtin);
61 }
62
63 pub fn register(&mut self, id: u16, f: FilterFn) {
67 self.filters.insert(id, FilterImplementation::Custom(f));
68 }
69
70 pub fn apply(
72 &self,
73 filter: &FilterDescription,
74 data: &[u8],
75 element_size: usize,
76 ) -> Result<Vec<u8>> {
77 self.apply_with_limit(filter, data, element_size, None)
78 }
79
80 pub fn apply_with_limit(
83 &self,
84 filter: &FilterDescription,
85 data: &[u8],
86 element_size: usize,
87 max_output_len: Option<usize>,
88 ) -> Result<Vec<u8>> {
89 match self.filters.get(&filter.id) {
90 Some(FilterImplementation::Builtin) => {
91 apply_builtin_filter_with_limit(filter, data, element_size, max_output_len)
92 }
93 Some(FilterImplementation::Custom(f)) => f(filter, data, element_size),
94 None => Err(Error::UnsupportedFilter(format!("filter id {}", filter.id))),
95 }
96 }
97}
98
99impl Default for FilterRegistry {
100 fn default() -> Self {
101 Self::new()
102 }
103}
104
105pub fn apply_pipeline(
114 data: &[u8],
115 filters: &[FilterDescription],
116 filter_mask: u32,
117 element_size: usize,
118 registry: Option<&FilterRegistry>,
119) -> Result<Vec<u8>> {
120 apply_pipeline_with_limit(data, filters, filter_mask, element_size, registry, None)
121}
122
123pub fn apply_pipeline_with_limit(
127 data: &[u8],
128 filters: &[FilterDescription],
129 filter_mask: u32,
130 element_size: usize,
131 registry: Option<&FilterRegistry>,
132 max_output_len: Option<usize>,
133) -> Result<Vec<u8>> {
134 let active_count = filters
136 .iter()
137 .enumerate()
138 .rev()
139 .filter(|(i, _)| filter_mask & (1 << i) == 0)
140 .count();
141
142 if active_count == 0 {
143 validate_output_limit("filter pipeline", data.len(), max_output_len)?;
144 return Ok(data.to_vec());
145 }
146
147 if active_count == 1 {
149 for (i, filter) in filters.iter().enumerate().rev() {
150 if filter_mask & (1 << i) != 0 {
151 continue;
152 }
153 let output = if let Some(reg) = registry {
154 reg.apply_with_limit(filter, data, element_size, max_output_len)?
155 } else {
156 apply_builtin_filter_with_limit(filter, data, element_size, max_output_len)?
157 };
158 validate_output_limit("filter pipeline", output.len(), max_output_len)?;
159 return Ok(output);
160 }
161 }
162
163 let mut owned: Option<Vec<u8>> = None;
168
169 for (i, filter) in filters.iter().enumerate().rev() {
170 if filter_mask & (1 << i) != 0 {
171 continue;
172 }
173
174 let input: &[u8] = match &owned {
175 Some(buf) => buf,
176 None => data,
177 };
178
179 owned = Some(if let Some(reg) = registry {
180 reg.apply_with_limit(filter, input, element_size, max_output_len)?
181 } else {
182 apply_builtin_filter_with_limit(filter, input, element_size, max_output_len)?
183 });
184 }
185
186 let output = owned.unwrap_or_else(|| data.to_vec());
187 validate_output_limit("filter pipeline", output.len(), max_output_len)?;
188 Ok(output)
189}
190
191fn apply_builtin_filter_with_limit(
192 filter: &FilterDescription,
193 data: &[u8],
194 element_size: usize,
195 max_output_len: Option<usize>,
196) -> Result<Vec<u8>> {
197 match filter.id {
198 FILTER_DEFLATE => match max_output_len {
199 Some(max_output_len) => deflate::decompress_with_limit(data, max_output_len),
200 None => deflate::decompress(data),
201 },
202 FILTER_SHUFFLE => Ok(shuffle::unshuffle(data, element_size)),
203 FILTER_FLETCHER32 => fletcher32::verify_and_strip(data),
204 FILTER_SZIP => Err(Error::UnsupportedFilter("szip".into())),
205 FILTER_NBIT => match max_output_len {
206 Some(max_output_len) => {
207 nbit::decompress_with_limit(data, &filter.client_data, max_output_len)
208 }
209 None => nbit::decompress(data, &filter.client_data),
210 },
211 FILTER_SCALEOFFSET => match max_output_len {
212 Some(max_output_len) => {
213 scaleoffset::decompress_with_limit(data, &filter.client_data, max_output_len)
214 }
215 None => scaleoffset::decompress(data, &filter.client_data),
216 },
217 #[cfg(feature = "lz4")]
218 FILTER_LZ4 => match max_output_len {
219 Some(max_output_len) => lz4::decompress_with_limit(data, max_output_len),
220 None => lz4::decompress(data),
221 },
222 id => Err(Error::UnsupportedFilter(format!("filter id {}", id))),
223 }
224}
225
226fn validate_output_limit(context: &str, len: usize, max_output_len: Option<usize>) -> Result<()> {
227 if let Some(max_output_len) = max_output_len {
228 if len > max_output_len {
229 return Err(Error::DecompressionError(format!(
230 "{context} decoded to {len} bytes, limit {max_output_len}"
231 )));
232 }
233 }
234 Ok(())
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use flate2::write::ZlibEncoder;
    use flate2::Compression;
    use std::io::Write;

    /// A freshly built registry contains every built-in filter id.
    #[test]
    fn filter_registry_default() {
        let registry = FilterRegistry::new();
        assert!(registry.filters.contains_key(&FILTER_DEFLATE));
        assert!(registry.filters.contains_key(&FILTER_SHUFFLE));
        assert!(registry.filters.contains_key(&FILTER_FLETCHER32));
        assert!(registry.filters.contains_key(&FILTER_NBIT));
        assert!(registry.filters.contains_key(&FILTER_SCALEOFFSET));
        // LZ4 is only registered when the feature is compiled in.
        #[cfg(feature = "lz4")]
        assert!(registry.filters.contains_key(&FILTER_LZ4));
    }

    /// A custom callback registered under a new id is invoked by `apply`.
    #[test]
    fn filter_registry_custom() {
        let mut registry = FilterRegistry::new();
        registry.register(32000, Box::new(|_, data, _| Ok(data.to_vec())));
        let filter = FilterDescription {
            id: 32000,
            name: None,
            client_data: Vec::new(),
        };
        let result = registry.apply(&filter, &[1, 2, 3], 1).unwrap();
        assert_eq!(result, vec![1, 2, 3]);
    }

    /// An id that was never registered is reported as unsupported.
    #[test]
    fn filter_registry_unknown() {
        let registry = FilterRegistry::new();
        let filter = FilterDescription {
            id: 9999,
            name: None,
            client_data: Vec::new(),
        };
        let err = registry.apply(&filter, &[1, 2, 3], 1).unwrap_err();
        assert!(matches!(err, Error::UnsupportedFilter(_)));
    }

    /// The limit reaches the deflate codec when the pipeline goes through a
    /// registry: the decoded output is capped at the requested length.
    #[test]
    fn apply_pipeline_with_limit_caps_registry_deflate_output() {
        let original = vec![0u8; 4096];
        let mut encoder = ZlibEncoder::new(Vec::new(), Compression::default());
        encoder.write_all(&original).unwrap();
        let compressed = encoder.finish().unwrap();
        let filter = FilterDescription {
            id: FILTER_DEFLATE,
            name: None,
            client_data: vec![6],
        };
        let registry = FilterRegistry::new();

        let decoded =
            apply_pipeline_with_limit(&compressed, &[filter], 0, 1, Some(&registry), Some(65))
                .unwrap();

        assert_eq!(decoded.len(), 65);
        assert_eq!(decoded, original[..65]);
    }

    /// A stage that ignores the limit (shuffle) is still caught by the final
    /// size check on the pipeline output.
    #[test]
    fn apply_pipeline_with_limit_rejects_oversized_final_output() {
        let filter = FilterDescription {
            id: FILTER_SHUFFLE,
            name: None,
            client_data: Vec::new(),
        };

        let err =
            apply_pipeline_with_limit(&[1, 2, 3, 4], &[filter], 0, 1, None, Some(3)).unwrap_err();

        assert!(err.to_string().contains("limit 3"));
    }
}