// burn_store/filter.rs

1use alloc::format;
2use alloc::string::String;
3use alloc::vec::Vec;
4use core::fmt;
5
6#[cfg(feature = "std")]
7use regex::Regex;
8
/// A path filter that supports multiple matching strategies.
///
/// The filter uses OR logic: a path is included if it matches ANY of the configured criteria
/// (exact full paths, regex patterns, or predicates). When `match_all` is set it overrides every
/// other criterion and the filter accepts all paths. A filter with no criteria matches nothing.
///
/// Regex patterns are only available with the `std` feature.
///
/// # Examples
///
/// ```rust,no_run
/// # use burn_store::PathFilter;
/// // Create a filter that matches encoder paths or any weight path
/// let filter = PathFilter::new()
///     .with_regex(r"^encoder\..*")
///     .with_regex(r".*\.weight$")
///     .with_full_path("special_tensor");
///
/// // Check if a path should be included
/// if filter.matches("encoder.layer1.weight") {
///     // This will match due to both regex patterns
/// }
/// ```
#[derive(Debug, Clone, Default)]
pub struct PathFilter {
    /// Compiled regex patterns matched against the tensor path (requires the `std` feature).
    #[cfg(feature = "std")]
    regex_patterns: Vec<Regex>,

    /// Exact full paths to match (compared by string equality).
    exact_paths: Vec<String>,

    /// Predicates for custom matching logic; each is called as `predicate(path, container_path)`.
    ///
    /// Plain `fn` pointers (rather than boxed closures) are `Copy`, which keeps the whole filter
    /// `Clone` and `Debug`; non-capturing closures coerce to this type.
    predicates: Vec<fn(&str, &str) -> bool>,

    /// If true, matches all paths (overrides every other criterion).
    match_all: bool,
}
45
46impl PathFilter {
47    /// Create a new empty filter (matches nothing by default)
48    pub fn new() -> Self {
49        Self::default()
50    }
51
52    /// Create a filter that matches all paths
53    pub fn all() -> Self {
54        Self {
55            match_all: true,
56            ..Default::default()
57        }
58    }
59
60    /// Create a filter that matches nothing
61    pub fn none() -> Self {
62        Self::default()
63    }
64
65    /// Add a regex pattern for matching paths
66    #[cfg(feature = "std")]
67    pub fn with_regex<S: AsRef<str>>(mut self, pattern: S) -> Self {
68        if let Ok(regex) = Regex::new(pattern.as_ref()) {
69            self.regex_patterns.push(regex);
70        }
71        // TODO: Consider returning Result to handle regex compilation errors
72        self
73    }
74
75    /// Add multiple regex patterns
76    #[cfg(feature = "std")]
77    pub fn with_regexes<I, S>(mut self, patterns: I) -> Self
78    where
79        I: IntoIterator<Item = S>,
80        S: AsRef<str>,
81    {
82        for pattern in patterns {
83            if let Ok(regex) = Regex::new(pattern.as_ref()) {
84                self.regex_patterns.push(regex);
85            }
86        }
87        self
88    }
89
90    /// Add an exact full path to match
91    pub fn with_full_path<S: Into<String>>(mut self, path: S) -> Self {
92        self.exact_paths.push(path.into());
93        self
94    }
95
96    /// Add multiple exact full paths
97    pub fn with_full_paths<I, S>(mut self, paths: I) -> Self
98    where
99        I: IntoIterator<Item = S>,
100        S: Into<String>,
101    {
102        self.exact_paths.extend(paths.into_iter().map(|p| p.into()));
103        self
104    }
105
106    /// Add a predicate function for custom matching based on path and container path
107    pub fn with_predicate(mut self, predicate: fn(&str, &str) -> bool) -> Self {
108        self.predicates.push(predicate);
109        self
110    }
111
112    /// Add multiple predicates
113    pub fn with_predicates<I>(mut self, predicates: I) -> Self
114    where
115        I: IntoIterator<Item = fn(&str, &str) -> bool>,
116    {
117        self.predicates.extend(predicates);
118        self
119    }
120
121    /// Set to match all paths
122    pub fn match_all(mut self) -> Self {
123        self.match_all = true;
124        self
125    }
126
127    /// Check if a path matches this filter (assumes empty container path for backward compatibility)
128    pub fn matches(&self, path: &str) -> bool {
129        self.matches_with_container_path_str(path, "")
130    }
131
132    /// Check if a path and container type match this filter (for backward compatibility)
133    pub fn matches_with_container(&self, path: &str, container_type: &str) -> bool {
134        // For backward compatibility, treat single container type as the full path
135        self.matches_with_container_path_str(path, container_type)
136    }
137
138    /// Check if a path and container path match this filter
139    pub fn matches_with_container_path(&self, path: &[String], container_stack: &[String]) -> bool {
140        let path_str = path.join(".");
141        let container_path = container_stack.join(".");
142        self.matches_with_container_path_str(&path_str, &container_path)
143    }
144
145    /// Check if a path and container path (dot-notated strings) match this filter
146    pub fn matches_with_container_path_str(&self, path: &str, container_path: &str) -> bool {
147        // If match_all is set, always return true
148        if self.match_all {
149            return true;
150        }
151
152        // If no filters are configured, match nothing
153        if self.is_empty() {
154            return false;
155        }
156
157        // Check exact path matches
158        if self.exact_paths.iter().any(|p| p == path) {
159            return true;
160        }
161
162        // Check regex patterns (on the path)
163        #[cfg(feature = "std")]
164        {
165            for regex in &self.regex_patterns {
166                if regex.is_match(path) {
167                    return true;
168                }
169            }
170        }
171
172        // Check predicates with container path
173        if self
174            .predicates
175            .iter()
176            .any(|pred| pred(path, container_path))
177        {
178            return true;
179        }
180
181        false
182    }
183
184    /// Check if the filter is empty (matches nothing)
185    pub fn is_empty(&self) -> bool {
186        if self.match_all {
187            return false;
188        }
189
190        #[cfg(feature = "std")]
191        let regex_empty = self.regex_patterns.is_empty();
192        #[cfg(not(feature = "std"))]
193        let regex_empty = true;
194
195        self.exact_paths.is_empty() && self.predicates.is_empty() && regex_empty
196    }
197
198    /// Get the number of filter criteria configured
199    pub fn criteria_count(&self) -> usize {
200        if self.match_all {
201            return 1;
202        }
203
204        #[allow(unused_mut)]
205        let mut count = self.exact_paths.len() + self.predicates.len();
206
207        #[cfg(feature = "std")]
208        {
209            count += self.regex_patterns.len();
210        }
211
212        count
213    }
214
215    /// Clear all regex patterns
216    #[cfg(feature = "std")]
217    pub fn clear_regex(&mut self) -> &mut Self {
218        self.regex_patterns.clear();
219        self
220    }
221
222    /// Clear all exact paths
223    pub fn clear_paths(&mut self) -> &mut Self {
224        self.exact_paths.clear();
225        self
226    }
227
228    /// Clear all predicates
229    pub fn clear_predicates(&mut self) -> &mut Self {
230        self.predicates.clear();
231        self
232    }
233
234    /// Clear all filters
235    pub fn clear(&mut self) -> &mut Self {
236        #[cfg(feature = "std")]
237        self.clear_regex();
238
239        self.clear_paths().clear_predicates();
240        self.match_all = false;
241        self
242    }
243
244    /// Create a filter from regex patterns only
245    #[cfg(feature = "std")]
246    pub fn from_regex_patterns<I, S>(patterns: I) -> Self
247    where
248        I: IntoIterator<Item = S>,
249        S: AsRef<str>,
250    {
251        Self::new().with_regexes(patterns)
252    }
253
254    /// Create a filter from exact paths only
255    pub fn from_paths<I, S>(paths: I) -> Self
256    where
257        I: IntoIterator<Item = S>,
258        S: Into<String>,
259    {
260        Self::new().with_full_paths(paths)
261    }
262
263    /// Create a filter from a single predicate
264    pub fn from_predicate(predicate: fn(&str, &str) -> bool) -> Self {
265        Self::new().with_predicate(predicate)
266    }
267
268    /// Combine with another filter using OR logic
269    pub fn or(mut self, other: Self) -> Self {
270        if self.match_all || other.match_all {
271            return Self::all();
272        }
273
274        #[cfg(feature = "std")]
275        {
276            self.regex_patterns.extend(other.regex_patterns);
277        }
278
279        self.exact_paths.extend(other.exact_paths);
280        self.predicates.extend(other.predicates);
281
282        self
283    }
284}
285
286impl fmt::Display for PathFilter {
287    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
288        if self.match_all {
289            return write!(f, "PathFilter::all()");
290        }
291
292        if self.is_empty() {
293            return write!(f, "PathFilter::none()");
294        }
295
296        write!(f, "PathFilter[")?;
297
298        let mut parts = Vec::new();
299
300        #[cfg(feature = "std")]
301        if !self.regex_patterns.is_empty() {
302            parts.push(format!("regex: {:?}", self.regex_patterns));
303        }
304
305        if !self.exact_paths.is_empty() {
306            parts.push(format!("paths: {:?}", self.exact_paths));
307        }
308
309        if !self.predicates.is_empty() {
310            parts.push(format!("predicates: {}", self.predicates.len()));
311        }
312
313        write!(f, "{}]", parts.join(", "))
314    }
315}
316
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_filter() {
        let filter = PathFilter::new();
        assert!(filter.is_empty());
        assert!(!filter.matches("encoder.weight"));
        assert!(!filter.matches("decoder.bias"));
    }

    #[test]
    fn none_matches_nothing() {
        let filter = PathFilter::none();
        assert!(filter.is_empty());
        assert_eq!(filter.criteria_count(), 0);
        assert!(!filter.matches("anything"));
    }

    #[test]
    fn match_all() {
        let filter = PathFilter::all();
        assert!(!filter.is_empty());
        assert!(filter.matches("encoder.weight"));
        assert!(filter.matches("decoder.bias"));
        assert!(filter.matches("anything"));
    }

    #[test]
    fn exact_paths() {
        let filter = PathFilter::new()
            .with_full_path("encoder.weight")
            .with_full_path("decoder.bias");

        assert!(filter.matches("encoder.weight"));
        assert!(filter.matches("decoder.bias"));
        assert!(!filter.matches("encoder.bias"));
        assert!(!filter.matches("decoder.weight"));
    }

    #[test]
    #[cfg(feature = "std")]
    fn regex_patterns() {
        let filter = PathFilter::new()
            .with_regex(r"^encoder\..*")
            .with_regex(r".*\.weight$");

        assert!(filter.matches("encoder.layer1.bias"));
        assert!(filter.matches("decoder.weight"));
        assert!(filter.matches("encoder.weight"));
        assert!(!filter.matches("decoder.bias"));
    }

    #[test]
    #[cfg(feature = "std")]
    fn invalid_regex_is_ignored() {
        // Patterns that fail to compile are silently dropped.
        let filter = PathFilter::new().with_regex("(unclosed");
        assert!(filter.is_empty());
        assert!(!filter.matches("(unclosed"));

        let filter = PathFilter::from_regex_patterns([r"^encoder\.", "(unclosed", r"\.bias$"]);
        assert_eq!(filter.criteria_count(), 2);
        assert!(filter.matches("encoder.weight"));
        assert!(filter.matches("decoder.bias"));
        assert!(!filter.matches("decoder.weight"));
    }

    #[test]
    fn predicates() {
        fn contains_norm(path: &str, _container_path: &str) -> bool {
            path.contains("norm")
        }

        fn is_short(path: &str, _container_path: &str) -> bool {
            path.len() < 10
        }

        let filter = PathFilter::new()
            .with_predicate(contains_norm)
            .with_predicate(is_short);

        assert!(filter.matches("norm.weight"));
        assert!(filter.matches("layer.norm.bias"));
        assert!(filter.matches("bias"));
        assert!(!filter.matches("encoder.decoder.weight.long.name"));
    }

    #[test]
    fn combined_filters() {
        let filter = PathFilter::new()
            .with_full_path("special.tensor")
            .with_predicate(|path, _container_path| path.contains("attention"));

        #[cfg(feature = "std")]
        let filter = filter.with_regex(r"^encoder\..*");

        assert!(filter.matches("special.tensor"));
        assert!(filter.matches("self_attention.query"));

        #[cfg(feature = "std")]
        assert!(filter.matches("encoder.anything"));

        assert!(!filter.matches("decoder.weight"));
    }

    #[test]
    fn or_combination() {
        let encoder_filter = PathFilter::new().with_full_path("encoder.weight");
        let decoder_filter = PathFilter::new().with_full_path("decoder.bias");

        let combined = encoder_filter.or(decoder_filter);

        assert!(combined.matches("encoder.weight"));
        assert!(combined.matches("decoder.bias"));
        assert!(!combined.matches("model.head.weight"));
    }

    #[test]
    fn or_with_match_all() {
        // A match-all operand on either side widens the result to match everything.
        let left = PathFilter::new()
            .with_full_path("encoder.weight")
            .or(PathFilter::all());
        let right = PathFilter::all().or(PathFilter::new().with_full_path("encoder.weight"));

        for filter in [left, right] {
            assert!(filter.matches("anything"));
            assert_eq!(filter.criteria_count(), 1);
        }
    }

    #[test]
    fn from_constructors() {
        let paths = PathFilter::from_paths(["encoder.weight", "decoder.bias"]);
        assert_eq!(paths.criteria_count(), 2);
        assert!(paths.matches("encoder.weight"));
        assert!(paths.matches("decoder.bias"));
        assert!(!paths.matches("encoder.bias"));

        let predicate =
            PathFilter::from_predicate(|path, _container_path| path.ends_with(".bias"));
        assert_eq!(predicate.criteria_count(), 1);
        assert!(predicate.matches("decoder.bias"));
        assert!(!predicate.matches("decoder.weight"));
    }

    #[test]
    #[cfg(feature = "std")]
    fn common_patterns() {
        // Test encoder pattern
        let encoder = PathFilter::new().with_regex(r"^encoder\..*");
        assert!(encoder.matches("encoder.weight"));
        assert!(!encoder.matches("decoder.weight"));

        // Test weights-only pattern
        let weights = PathFilter::new().with_regex(r".*\.weight$");
        assert!(weights.matches("encoder.weight"));
        assert!(weights.matches("decoder.weight"));
        assert!(!weights.matches("encoder.bias"));

        // Test layer-specific patterns
        let layers = PathFilter::new()
            .with_regex(r"(^|.*\.)layers\.0\.")
            .with_regex(r"(^|.*\.)layers\.2\.")
            .with_regex(r"(^|.*\.)layers\.4\.");
        assert!(layers.matches("model.layers.0.weight"));
        assert!(layers.matches("layers.2.bias"));
        assert!(!layers.matches("layers.1.weight"));
    }

    #[test]
    fn criteria_count() {
        let filter = PathFilter::new()
            .with_full_path("path1")
            .with_full_path("path2")
            .with_predicate(|_, _| true);

        #[cfg(feature = "std")]
        let filter = filter.with_regex(".*");

        #[cfg(feature = "std")]
        assert_eq!(filter.criteria_count(), 4);

        #[cfg(not(feature = "std"))]
        assert_eq!(filter.criteria_count(), 3);
    }

    #[test]
    fn clear_operations() {
        let mut filter = PathFilter::new().with_full_path("test");

        filter.clear_paths();
        assert!(!filter.matches("test"));

        filter.clear();
        assert!(filter.is_empty());
    }

    #[test]
    fn clear_resets_match_all() {
        let mut filter = PathFilter::all();
        filter.clear();
        assert!(filter.is_empty());
        assert!(!filter.matches("anything"));
    }

    #[test]
    fn display() {
        assert_eq!(format!("{}", PathFilter::all()), "PathFilter::all()");
        assert_eq!(format!("{}", PathFilter::none()), "PathFilter::none()");

        let filter = PathFilter::new()
            .with_full_path("encoder.weight")
            .with_predicate(|_, _| true);
        assert_eq!(
            format!("{}", filter),
            "PathFilter[paths: [\"encoder.weight\"], predicates: 1]"
        );
    }

    #[test]
    fn container_predicates() {
        // Filter that matches only Linear module weights
        let linear_weights = PathFilter::new().with_predicate(|path, container_path| {
            container_path.split('.').next_back() == Some("Linear") && path.ends_with(".weight")
        });

        assert!(linear_weights.matches_with_container("layer1.weight", "Linear"));
        assert!(!linear_weights.matches_with_container("layer1.weight", "Conv2d"));
        assert!(!linear_weights.matches_with_container("layer1.bias", "Linear"));

        // Filter for specific container types
        let conv_only = PathFilter::new().with_predicate(|_path, container_path| {
            let last = container_path.split('.').next_back();
            last == Some("Conv2d") || last == Some("ConvTranspose2d")
        });

        assert!(conv_only.matches_with_container("encoder.weight", "Conv2d"));
        assert!(conv_only.matches_with_container("decoder.weight", "ConvTranspose2d"));
        assert!(!conv_only.matches_with_container("fc.weight", "Linear"));

        // Combine path and container predicates
        let combined = PathFilter::new()
            .with_predicate(|path, _container_path| path.starts_with("encoder."))
            .with_predicate(|_path, container_path| {
                container_path.split('.').next_back() == Some("BatchNorm2d")
            });

        // Should match either condition (OR logic)
        assert!(combined.matches_with_container("encoder.layer1", "Linear"));
        assert!(combined.matches_with_container("decoder.bn", "BatchNorm2d"));
        assert!(!combined.matches_with_container("decoder.layer", "Linear"));
    }

    #[test]
    #[cfg(feature = "std")]
    fn container_predicate_with_regex() {
        // Combine regex patterns with container predicates
        let filter = PathFilter::new()
            .with_regex(r"^encoder\..*")
            .with_predicate(|path, container_path| {
                container_path.split('.').next_back() == Some("Linear") && path.contains(".bias")
            });

        // Matches due to regex
        assert!(filter.matches_with_container("encoder.layer1.weight", "Conv2d"));
        // Matches due to container predicate
        assert!(filter.matches_with_container("decoder.fc.bias", "Linear"));
        // Doesn't match either
        assert!(!filter.matches_with_container("decoder.conv.weight", "Conv2d"));
    }

    #[test]
    fn container_path_slices() {
        // Path and container stack segments are joined with '.' before matching.
        let filter = PathFilter::new().with_predicate(|path, container_path| {
            path == "encoder.weight" && container_path == "Model.Linear"
        });

        let path = [String::from("encoder"), String::from("weight")];
        let linear = [String::from("Model"), String::from("Linear")];
        let conv = [String::from("Model"), String::from("Conv2d")];

        assert!(filter.matches_with_container_path(&path, &linear));
        assert!(!filter.matches_with_container_path(&path, &conv));
    }

    #[test]
    fn container_stack_predicates() {
        // Filter using full container path - only tensors nested in a specific hierarchy
        let nested_filter = PathFilter::new().with_predicate(|_path, container_path| {
            // Check if tensor is nested within: Model -> TransformerBlock -> Linear
            let parts: Vec<&str> = container_path.split('.').collect();
            parts.len() >= 3
                && parts[0] == "Model"
                && parts[1] == "TransformerBlock"
                && parts[2] == "Linear"
        });

        assert!(nested_filter.matches_with_container_path_str(
            "encoder.weight",
            "Model.TransformerBlock.Linear.Param"
        ));
        assert!(
            !nested_filter
                .matches_with_container_path_str("decoder.weight", "Model.Decoder.Linear.Param")
        );
        assert!(!nested_filter.matches_with_container_path_str(
            "encoder.weight",
            "Model.TransformerBlock.Conv2d.Param"
        ));

        // Filter that checks for specific depth in hierarchy
        let depth_filter = PathFilter::new().with_predicate(|_path, container_path| {
            let parts: Vec<&str> = container_path.split('.').collect();
            parts.len() == 4 && parts.get(2) == Some(&"Linear")
        });

        assert!(depth_filter.matches_with_container_path_str(
            "model.layer.weight",
            "Model.TransformerBlock.Linear.Param"
        ));
        assert!(
            !depth_filter
                .matches_with_container_path_str("model.weight", "Model.TransformerBlock.Conv2d")
        ); // Too shallow

        // Filter that checks any Linear in the path (not just the last)
        let any_linear = PathFilter::new()
            .with_predicate(|_path, container_path| container_path.contains("Linear"));

        assert!(
            any_linear.matches_with_container_path_str(
                "some.path",
                "Model.TransformerBlock.Linear.Param"
            )
        );
        assert!(
            any_linear.matches_with_container_path_str("other.path", "Model.Decoder.Linear.Param")
        );
        assert!(
            !any_linear.matches_with_container_path_str(
                "conv.path",
                "Model.TransformerBlock.Conv2d.Param"
            )
        );
    }

    #[test]
    fn container_path_dot_notation() {
        // Filter using dot-notated container path
        let dot_filter = PathFilter::new().with_predicate(|_path, container_path| {
            container_path.starts_with("Model.TransformerBlock")
        });

        // Test with matches_with_container_path
        assert!(
            dot_filter.matches_with_container_path_str("weight", "Model.TransformerBlock.Linear")
        );
        assert!(!dot_filter.matches_with_container_path_str("weight", "Model.Decoder.Linear"));

        // Filter that checks for specific patterns in container path
        let pattern_filter = PathFilter::new().with_predicate(|_path, container_path| {
            // Match any path that has Linear after a block
            container_path.contains("Block.Linear") || container_path.contains("Block.Conv")
        });

        assert!(
            pattern_filter
                .matches_with_container_path_str("weight", "Model.TransformerBlock.Linear")
        );
        assert!(pattern_filter.matches_with_container_path_str("weight", "Model.ResBlock.Conv2d"));
        assert!(!pattern_filter.matches_with_container_path_str("weight", "Model.Linear.Param"));

        // Filter combining path and container path patterns
        let combined = PathFilter::new().with_predicate(|path, container_path| {
            // Only weights in Linear layers that are inside blocks
            path.ends_with(".weight")
                && container_path.contains("Block")
                && container_path.split('.').next_back() == Some("Linear")
        });

        assert!(
            combined
                .matches_with_container_path_str("layer.weight", "Model.TransformerBlock.Linear")
        );
        assert!(
            !combined
                .matches_with_container_path_str("layer.bias", "Model.TransformerBlock.Linear")
        );
        assert!(!combined.matches_with_container_path_str("layer.weight", "Model.Decoder.Linear"));
    }
}
623        assert!(!combined.matches_with_container_path_str("layer.weight", "Model.Decoder.Linear"));
624    }
625}