Skip to main content

hdf5_reader/filters/
mod.rs

1pub mod deflate;
2pub mod fletcher32;
3#[cfg(feature = "lz4")]
4pub mod lz4;
5pub mod nbit;
6pub mod scaleoffset;
7pub mod shuffle;
8
9use std::collections::HashMap;
10
11use crate::error::{Error, Result};
12use crate::messages::filter_pipeline::FilterDescription;
13
// Standard HDF5 filter IDs.

/// Deflate filter; decoded by the `deflate` module.
pub const FILTER_DEFLATE: u16 = 1;
/// Byte-shuffle filter; undone per element by the `shuffle` module.
pub const FILTER_SHUFFLE: u16 = 2;
/// Fletcher32 checksum filter; the `fletcher32` module verifies and strips it.
pub const FILTER_FLETCHER32: u16 = 3;
/// SZIP filter. The ID is recognised, but the built-in dispatcher reports it
/// as unsupported.
pub const FILTER_SZIP: u16 = 4;
/// N-bit filter; decoded by the `nbit` module using the filter's client data.
pub const FILTER_NBIT: u16 = 5;
/// Scale-offset filter; decoded by the `scaleoffset` module using the filter's
/// client data.
pub const FILTER_SCALEOFFSET: u16 = 6;
/// HDF5 registered LZ4 filter. Only handled by the built-in dispatcher when
/// the `lz4` feature is enabled.
pub const FILTER_LZ4: u16 = 32004;
23
/// A user-supplied filter function.
///
/// Takes the filter description, input data, and element size, then returns
/// the decoded output.
///
/// The function must be `Send + Sync`. Note that the registry does not forward
/// any maximum output length to custom functions; only the three arguments
/// above are passed.
pub type FilterFn = Box<dyn Fn(&FilterDescription, &[u8], usize) -> Result<Vec<u8>> + Send + Sync>;
29
/// How a registered filter ID is decoded.
enum FilterImplementation {
    /// Decoded by this module's built-in dispatcher, selected by filter ID.
    Builtin,
    /// Decoded by a caller-provided function.
    Custom(FilterFn),
}
34
35/// A registry of filter implementations.
36///
37/// Comes pre-loaded with deflate, shuffle, and fletcher32. Users can register
38/// additional filters (e.g., Blosc, LZ4, ZFP) before reading datasets.
39pub struct FilterRegistry {
40    filters: HashMap<u16, FilterImplementation>,
41}
42
43impl FilterRegistry {
44    /// Create a new registry with the built-in filters pre-registered.
45    pub fn new() -> Self {
46        let mut registry = FilterRegistry {
47            filters: HashMap::new(),
48        };
49        registry.register_builtin(FILTER_DEFLATE);
50        registry.register_builtin(FILTER_SHUFFLE);
51        registry.register_builtin(FILTER_FLETCHER32);
52        registry.register_builtin(FILTER_NBIT);
53        registry.register_builtin(FILTER_SCALEOFFSET);
54        #[cfg(feature = "lz4")]
55        registry.register_builtin(FILTER_LZ4);
56        registry
57    }
58
59    fn register_builtin(&mut self, id: u16) {
60        self.filters.insert(id, FilterImplementation::Builtin);
61    }
62
63    /// Register a custom filter implementation for the given filter ID.
64    ///
65    /// Overwrites any previously registered filter with the same ID.
66    pub fn register(&mut self, id: u16, f: FilterFn) {
67        self.filters.insert(id, FilterImplementation::Custom(f));
68    }
69
70    /// Apply a single filter by ID.
71    pub fn apply(
72        &self,
73        filter: &FilterDescription,
74        data: &[u8],
75        element_size: usize,
76    ) -> Result<Vec<u8>> {
77        self.apply_with_limit(filter, data, element_size, None)
78    }
79
80    /// Apply a single filter by ID, passing a maximum decoded output length to
81    /// built-in filters that can enforce it while decoding.
82    pub fn apply_with_limit(
83        &self,
84        filter: &FilterDescription,
85        data: &[u8],
86        element_size: usize,
87        max_output_len: Option<usize>,
88    ) -> Result<Vec<u8>> {
89        match self.filters.get(&filter.id) {
90            Some(FilterImplementation::Builtin) => {
91                apply_builtin_filter_with_limit(filter, data, element_size, max_output_len)
92            }
93            Some(FilterImplementation::Custom(f)) => f(filter, data, element_size),
94            None => Err(Error::UnsupportedFilter(format!("filter id {}", filter.id))),
95        }
96    }
97}
98
/// The default registry is the same as [`FilterRegistry::new`]: all built-in
/// filters are pre-registered.
impl Default for FilterRegistry {
    fn default() -> Self {
        Self::new()
    }
}
104
105/// Apply the filter pipeline in reverse (decompression direction) to a chunk.
106///
107/// HDF5 stores filters in the order they were applied during writing.
108/// On read, we apply them in reverse order.
109///
110/// If `registry` is `None`, the built-in filter set is used.
111///
112/// `filter_mask` is a bitmask where bit N being set means filter N should be skipped.
113pub fn apply_pipeline(
114    data: &[u8],
115    filters: &[FilterDescription],
116    filter_mask: u32,
117    element_size: usize,
118    registry: Option<&FilterRegistry>,
119) -> Result<Vec<u8>> {
120    apply_pipeline_with_limit(data, filters, filter_mask, element_size, registry, None)
121}
122
123/// Apply the filter pipeline in reverse (decompression direction) to a chunk,
124/// passing a maximum decoded output length to built-in filters that support
125/// bounded decompression.
126pub fn apply_pipeline_with_limit(
127    data: &[u8],
128    filters: &[FilterDescription],
129    filter_mask: u32,
130    element_size: usize,
131    registry: Option<&FilterRegistry>,
132    max_output_len: Option<usize>,
133) -> Result<Vec<u8>> {
134    // Count active filters so an all-skipped pipeline can avoid the loop.
135    let active_count = filters
136        .iter()
137        .enumerate()
138        .rev()
139        .filter(|(i, _)| filter_mask & (1 << i) == 0)
140        .count();
141
142    if active_count == 0 {
143        validate_output_limit("filter pipeline", data.len(), max_output_len)?;
144        return Ok(data.to_vec());
145    }
146
147    // For a single active filter, avoid the double-buffer loop overhead.
148    if active_count == 1 {
149        for (i, filter) in filters.iter().enumerate().rev() {
150            if filter_mask & (1 << i) != 0 {
151                continue;
152            }
153            let output = if let Some(reg) = registry {
154                reg.apply_with_limit(filter, data, element_size, max_output_len)?
155            } else {
156                apply_builtin_filter_with_limit(filter, data, element_size, max_output_len)?
157            };
158            validate_output_limit("filter pipeline", output.len(), max_output_len)?;
159            return Ok(output);
160        }
161    }
162
163    // Multi-filter pipeline: the first stage reads from the borrowed input
164    // slice (avoiding a copy), subsequent stages consume the previous output.
165    // Each filter stage necessarily allocates (output sizes are unpredictable),
166    // but we avoid the initial data.to_vec() copy.
167    let mut owned: Option<Vec<u8>> = None;
168
169    for (i, filter) in filters.iter().enumerate().rev() {
170        if filter_mask & (1 << i) != 0 {
171            continue;
172        }
173
174        let input: &[u8] = match &owned {
175            Some(buf) => buf,
176            None => data,
177        };
178
179        owned = Some(if let Some(reg) = registry {
180            reg.apply_with_limit(filter, input, element_size, max_output_len)?
181        } else {
182            apply_builtin_filter_with_limit(filter, input, element_size, max_output_len)?
183        });
184    }
185
186    let output = owned.unwrap_or_else(|| data.to_vec());
187    validate_output_limit("filter pipeline", output.len(), max_output_len)?;
188    Ok(output)
189}
190
191fn apply_builtin_filter_with_limit(
192    filter: &FilterDescription,
193    data: &[u8],
194    element_size: usize,
195    max_output_len: Option<usize>,
196) -> Result<Vec<u8>> {
197    match filter.id {
198        FILTER_DEFLATE => match max_output_len {
199            Some(max_output_len) => deflate::decompress_with_limit(data, max_output_len),
200            None => deflate::decompress(data),
201        },
202        FILTER_SHUFFLE => Ok(shuffle::unshuffle(data, element_size)),
203        FILTER_FLETCHER32 => fletcher32::verify_and_strip(data),
204        FILTER_SZIP => Err(Error::UnsupportedFilter("szip".into())),
205        FILTER_NBIT => match max_output_len {
206            Some(max_output_len) => {
207                nbit::decompress_with_limit(data, &filter.client_data, max_output_len)
208            }
209            None => nbit::decompress(data, &filter.client_data),
210        },
211        FILTER_SCALEOFFSET => match max_output_len {
212            Some(max_output_len) => {
213                scaleoffset::decompress_with_limit(data, &filter.client_data, max_output_len)
214            }
215            None => scaleoffset::decompress(data, &filter.client_data),
216        },
217        #[cfg(feature = "lz4")]
218        FILTER_LZ4 => match max_output_len {
219            Some(max_output_len) => lz4::decompress_with_limit(data, max_output_len),
220            None => lz4::decompress(data),
221        },
222        id => Err(Error::UnsupportedFilter(format!("filter id {}", id))),
223    }
224}
225
226fn validate_output_limit(context: &str, len: usize, max_output_len: Option<usize>) -> Result<()> {
227    if let Some(max_output_len) = max_output_len {
228        if len > max_output_len {
229            return Err(Error::DecompressionError(format!(
230                "{context} decoded to {len} bytes, limit {max_output_len}"
231            )));
232        }
233    }
234    Ok(())
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use flate2::write::ZlibEncoder;
    use flate2::Compression;
    use std::io::Write;

    /// Build an unnamed filter description with empty client data.
    fn bare_filter(id: u16) -> FilterDescription {
        FilterDescription {
            id,
            name: None,
            client_data: Vec::new(),
        }
    }

    /// A fresh registry knows every always-available built-in filter.
    #[test]
    fn filter_registry_default() {
        let registry = FilterRegistry::new();
        let expected_builtins = [
            FILTER_DEFLATE,
            FILTER_SHUFFLE,
            FILTER_FLETCHER32,
            FILTER_NBIT,
            FILTER_SCALEOFFSET,
        ];
        for id in &expected_builtins {
            assert!(
                registry.filters.contains_key(id),
                "built-in filter {} is not registered",
                id
            );
        }
    }

    /// A registered custom filter is invoked for its ID.
    #[test]
    fn filter_registry_custom() {
        let mut registry = FilterRegistry::new();
        // Identity filter: hands back a copy of its input.
        registry.register(32000, Box::new(|_, data, _| Ok(data.to_vec())));

        let passthrough = bare_filter(32000);
        let output = registry.apply(&passthrough, &[1, 2, 3], 1).unwrap();

        assert_eq!(output, vec![1, 2, 3]);
    }

    /// An ID with no registered implementation is reported as unsupported.
    #[test]
    fn filter_registry_unknown() {
        let registry = FilterRegistry::new();
        let unknown = bare_filter(9999);

        let err = registry.apply(&unknown, &[1, 2, 3], 1).unwrap_err();

        assert!(matches!(err, Error::UnsupportedFilter(_)));
    }

    /// The limit reaches the built-in deflate decoder when it is dispatched
    /// through a registry.
    #[test]
    fn apply_pipeline_with_limit_caps_registry_deflate_output() {
        let original = vec![0u8; 4096];
        let compressed = {
            let mut encoder = ZlibEncoder::new(Vec::new(), Compression::default());
            encoder.write_all(&original).unwrap();
            encoder.finish().unwrap()
        };
        let mut deflate_filter = bare_filter(FILTER_DEFLATE);
        deflate_filter.client_data = vec![6];
        let registry = FilterRegistry::new();

        let decoded = apply_pipeline_with_limit(
            &compressed,
            &[deflate_filter],
            0,
            1,
            Some(&registry),
            Some(65),
        )
        .unwrap();

        assert_eq!(decoded.len(), 65);
        assert_eq!(decoded, original[..65]);
    }

    /// A final output longer than the limit is rejected by the pipeline.
    #[test]
    fn apply_pipeline_with_limit_rejects_oversized_final_output() {
        let shuffle_filter = bare_filter(FILTER_SHUFFLE);

        let outcome =
            apply_pipeline_with_limit(&[1, 2, 3, 4], &[shuffle_filter], 0, 1, None, Some(3));
        let err = outcome.unwrap_err();

        assert!(err.to_string().contains("limit 3"));
    }
}
315}