1use alloc::format;
2use alloc::string::String;
3use alloc::vec::Vec;
4use core::fmt;
5
6#[cfg(feature = "std")]
7use regex::Regex;
8
/// Selects string paths using any mix of exact paths, predicates and (with the
/// `std` feature) regular expressions.
///
/// Criteria are OR-combined: a path is accepted as soon as one criterion
/// accepts it. A filter with no criteria accepts nothing, while a filter with
/// `match_all` set accepts everything.
#[derive(Debug, Clone, Default)]
pub struct PathFilter {
    /// Compiled regular expressions tested against the path.
    #[cfg(feature = "std")]
    regex_patterns: Vec<Regex>,

    /// Paths that must be equal to the candidate path to count as a match.
    exact_paths: Vec<String>,

    /// Callbacks invoked as `predicate(path, container_path)`.
    predicates: Vec<fn(&str, &str) -> bool>,

    /// When `true`, every path matches and the other criteria are not consulted.
    match_all: bool,
}

impl PathFilter {
    /// Creates a filter without any criteria; it matches nothing until some are added.
    pub fn new() -> Self {
        Default::default()
    }

    /// Creates a filter that accepts every path.
    pub fn all() -> Self {
        Self::new().match_all()
    }

    /// Creates a filter that accepts no path (same value as [`PathFilter::new`]).
    pub fn none() -> Self {
        Self::new()
    }

    /// Adds one regular expression criterion.
    ///
    /// A pattern that fails to compile is silently skipped.
    #[cfg(feature = "std")]
    pub fn with_regex<S: AsRef<str>>(mut self, pattern: S) -> Self {
        // `Option` is an iterator of zero or one items, so a failed compile adds nothing.
        self.regex_patterns.extend(Regex::new(pattern.as_ref()).ok());
        self
    }

    /// Adds several regular expression criteria.
    ///
    /// Patterns that fail to compile are silently skipped; the valid ones are kept.
    #[cfg(feature = "std")]
    pub fn with_regexes<I, S>(mut self, patterns: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        let compiled = patterns
            .into_iter()
            .filter_map(|pattern| Regex::new(pattern.as_ref()).ok());
        self.regex_patterns.extend(compiled);
        self
    }

    /// Adds one exact path criterion.
    pub fn with_full_path<S: Into<String>>(mut self, path: S) -> Self {
        let owned: String = path.into();
        self.exact_paths.push(owned);
        self
    }

    /// Adds several exact path criteria.
    pub fn with_full_paths<I, S>(mut self, paths: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.exact_paths.extend(paths.into_iter().map(Into::into));
        self
    }

    /// Adds one predicate; it is later called as `predicate(path, container_path)`.
    pub fn with_predicate(self, predicate: fn(&str, &str) -> bool) -> Self {
        let mut updated = self;
        updated.predicates.push(predicate);
        updated
    }

    /// Adds several predicates; each is later called as `predicate(path, container_path)`.
    pub fn with_predicates<I>(self, predicates: I) -> Self
    where
        I: IntoIterator<Item = fn(&str, &str) -> bool>,
    {
        let mut updated = self;
        updated.predicates.extend(predicates);
        updated
    }

    /// Turns this filter into one that accepts every path, keeping the stored criteria.
    pub fn match_all(self) -> Self {
        Self {
            match_all: true,
            ..self
        }
    }

    /// Tests `path` against the filter, passing an empty container path to predicates.
    pub fn matches(&self, path: &str) -> bool {
        self.matches_with_container(path, "")
    }

    /// Tests `path` against the filter, handing `container_type` to predicates as
    /// their second argument.
    pub fn matches_with_container(&self, path: &str, container_type: &str) -> bool {
        self.matches_with_container_path_str(path, container_type)
    }

    /// Tests a segmented path: both slices are joined with `.` before matching.
    pub fn matches_with_container_path(&self, path: &[String], container_stack: &[String]) -> bool {
        self.matches_with_container_path_str(&path.join("."), &container_stack.join("."))
    }

    /// Core matching routine.
    ///
    /// Returns `true` when `match_all` is set, or when any exact path equals
    /// `path`, any regex (`std` only) reports a match on `path`, or any
    /// predicate returns `true` for `(path, container_path)`. Criteria are
    /// checked in that order and evaluation stops at the first hit.
    pub fn matches_with_container_path_str(&self, path: &str, container_path: &str) -> bool {
        if self.match_all {
            return true;
        }

        // With no criteria every check below is vacuously false, so an empty
        // filter falls through to `false` without a dedicated branch.
        if self.exact_paths.iter().any(|candidate| candidate == path) {
            return true;
        }

        #[cfg(feature = "std")]
        {
            if self.regex_patterns.iter().any(|re| re.is_match(path)) {
                return true;
            }
        }

        self.predicates
            .iter()
            .any(|accepts| accepts(path, container_path))
    }

    /// Returns `true` when the filter has no criteria and `match_all` is not set.
    pub fn is_empty(&self) -> bool {
        #[cfg(feature = "std")]
        let has_regex = !self.regex_patterns.is_empty();
        #[cfg(not(feature = "std"))]
        let has_regex = false;

        !(self.match_all
            || has_regex
            || !self.exact_paths.is_empty()
            || !self.predicates.is_empty())
    }

    /// Number of stored criteria; a match-all filter always reports `1`.
    pub fn criteria_count(&self) -> usize {
        if self.match_all {
            return 1;
        }

        #[cfg(feature = "std")]
        let regex_count = self.regex_patterns.len();
        #[cfg(not(feature = "std"))]
        let regex_count = 0;

        self.exact_paths.len() + self.predicates.len() + regex_count
    }

    /// Removes every regex criterion.
    #[cfg(feature = "std")]
    pub fn clear_regex(&mut self) -> &mut Self {
        self.regex_patterns.clear();
        self
    }

    /// Removes every exact path criterion.
    pub fn clear_paths(&mut self) -> &mut Self {
        self.exact_paths.clear();
        self
    }

    /// Removes every predicate criterion.
    pub fn clear_predicates(&mut self) -> &mut Self {
        self.predicates.clear();
        self
    }

    /// Removes all criteria and resets `match_all`, leaving an empty filter.
    pub fn clear(&mut self) -> &mut Self {
        #[cfg(feature = "std")]
        self.regex_patterns.clear();

        self.exact_paths.clear();
        self.predicates.clear();
        self.match_all = false;
        self
    }

    /// Builds a filter from regex patterns; invalid patterns are silently skipped.
    #[cfg(feature = "std")]
    pub fn from_regex_patterns<I, S>(patterns: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        Self::default().with_regexes(patterns)
    }

    /// Builds a filter from exact paths.
    pub fn from_paths<I, S>(paths: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        Self::default().with_full_paths(paths)
    }

    /// Builds a filter from a single predicate.
    pub fn from_predicate(predicate: fn(&str, &str) -> bool) -> Self {
        Self::default().with_predicate(predicate)
    }

    /// Combines two filters so the result accepts whatever either one accepts.
    ///
    /// If either side is match-all, a fresh match-all filter is returned and
    /// the stored criteria of both sides are dropped.
    pub fn or(self, other: Self) -> Self {
        if self.match_all || other.match_all {
            return Self::all();
        }

        let mut merged = self;

        #[cfg(feature = "std")]
        merged.regex_patterns.extend(other.regex_patterns);

        merged.exact_paths.extend(other.exact_paths);
        merged.predicates.extend(other.predicates);
        merged
    }
}
285
286impl fmt::Display for PathFilter {
287 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
288 if self.match_all {
289 return write!(f, "PathFilter::all()");
290 }
291
292 if self.is_empty() {
293 return write!(f, "PathFilter::none()");
294 }
295
296 write!(f, "PathFilter[")?;
297
298 let mut parts = Vec::new();
299
300 #[cfg(feature = "std")]
301 if !self.regex_patterns.is_empty() {
302 parts.push(format!("regex: {:?}", self.regex_patterns));
303 }
304
305 if !self.exact_paths.is_empty() {
306 parts.push(format!("paths: {:?}", self.exact_paths));
307 }
308
309 if !self.predicates.is_empty() {
310 parts.push(format!("predicates: {}", self.predicates.len()));
311 }
312
313 write!(f, "{}]", parts.join(", "))
314 }
315}
316
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `filter.matches(path)` against the expected outcome for each case.
    fn check_paths(filter: &PathFilter, cases: &[(&str, bool)]) {
        for &(path, expected) in cases {
            assert_eq!(filter.matches(path), expected, "path: {}", path);
        }
    }

    /// Checks `filter.matches_with_container(path, container_type)` for each case.
    fn check_container_type(filter: &PathFilter, cases: &[(&str, &str, bool)]) {
        for &(path, container_type, expected) in cases {
            assert_eq!(
                filter.matches_with_container(path, container_type),
                expected,
                "path: {}, container type: {}",
                path,
                container_type
            );
        }
    }

    /// Checks `filter.matches_with_container_path_str(path, container_path)` for each case.
    fn check_container_path(filter: &PathFilter, cases: &[(&str, &str, bool)]) {
        for &(path, container_path, expected) in cases {
            assert_eq!(
                filter.matches_with_container_path_str(path, container_path),
                expected,
                "path: {}, container path: {}",
                path,
                container_path
            );
        }
    }

    /// Last `.`-separated segment of a container path.
    fn innermost(container_path: &str) -> Option<&str> {
        container_path.split('.').next_back()
    }

    // A filter without criteria is empty and rejects everything.
    #[test]
    fn empty_filter() {
        let filter = PathFilter::new();
        assert!(filter.is_empty());
        check_paths(
            &filter,
            &[("encoder.weight", false), ("decoder.bias", false)],
        );
    }

    // A match-all filter is not empty and accepts everything.
    #[test]
    fn match_all() {
        let filter = PathFilter::all();
        assert!(!filter.is_empty());
        check_paths(
            &filter,
            &[
                ("encoder.weight", true),
                ("decoder.bias", true),
                ("anything", true),
            ],
        );
    }

    // Exact paths only match when the whole path is equal.
    #[test]
    fn exact_paths() {
        let filter = PathFilter::new()
            .with_full_path("encoder.weight")
            .with_full_path("decoder.bias");

        check_paths(
            &filter,
            &[
                ("encoder.weight", true),
                ("decoder.bias", true),
                ("encoder.bias", false),
                ("decoder.weight", false),
            ],
        );
    }

    // A path matching any of the regexes is accepted.
    #[test]
    #[cfg(feature = "std")]
    fn regex_patterns() {
        let filter = PathFilter::new()
            .with_regex(r"^encoder\..*")
            .with_regex(r".*\.weight$");

        check_paths(
            &filter,
            &[
                ("encoder.layer1.bias", true),
                ("decoder.weight", true),
                ("encoder.weight", true),
                ("decoder.bias", false),
            ],
        );
    }

    // A path accepted by any predicate is accepted.
    #[test]
    fn predicates() {
        let filter = PathFilter::new()
            .with_predicate(|path, _| path.contains("norm"))
            .with_predicate(|path, _| path.len() < 10);

        check_paths(
            &filter,
            &[
                ("norm.weight", true),
                ("layer.norm.bias", true),
                ("bias", true),
                ("encoder.decoder.weight.long.name", false),
            ],
        );
    }

    // Exact paths, predicates and (with `std`) regexes are OR-combined.
    #[test]
    fn combined_filters() {
        fn mentions_attention(path: &str, _container_path: &str) -> bool {
            path.contains("attention")
        }

        let filter = PathFilter::new()
            .with_full_path("special.tensor")
            .with_predicate(mentions_attention);

        #[cfg(feature = "std")]
        let filter = filter.with_regex(r"^encoder\..*");

        check_paths(
            &filter,
            &[("special.tensor", true), ("self_attention.query", true)],
        );

        #[cfg(feature = "std")]
        check_paths(&filter, &[("encoder.anything", true)]);

        check_paths(&filter, &[("decoder.weight", false)]);
    }

    // `or` yields a filter accepting what either operand accepts.
    #[test]
    fn or_combination() {
        let lhs = PathFilter::new().with_full_path("encoder.weight");
        let rhs = PathFilter::new().with_full_path("decoder.bias");

        let either = lhs.or(rhs);

        check_paths(
            &either,
            &[
                ("encoder.weight", true),
                ("decoder.bias", true),
                ("model.head.weight", false),
            ],
        );
    }

    // Typical regex recipes: prefix, suffix and selected layer indices.
    #[test]
    #[cfg(feature = "std")]
    fn common_patterns() {
        let encoder = PathFilter::new().with_regex(r"^encoder\..*");
        check_paths(
            &encoder,
            &[("encoder.weight", true), ("decoder.weight", false)],
        );

        let weights = PathFilter::new().with_regex(r".*\.weight$");
        check_paths(
            &weights,
            &[
                ("encoder.weight", true),
                ("decoder.weight", true),
                ("encoder.bias", false),
            ],
        );

        let layers = PathFilter::new()
            .with_regex(r"(^|.*\.)layers\.0\.")
            .with_regex(r"(^|.*\.)layers\.2\.")
            .with_regex(r"(^|.*\.)layers\.4\.");
        check_paths(
            &layers,
            &[
                ("model.layers.0.weight", true),
                ("layers.2.bias", true),
                ("layers.1.weight", false),
            ],
        );
    }

    // Every stored criterion counts once; the regex only exists with `std`.
    #[test]
    fn criteria_count() {
        let filter = PathFilter::new()
            .with_full_path("path1")
            .with_full_path("path2")
            .with_predicate(|_, _| true);

        #[cfg(feature = "std")]
        let filter = filter.with_regex(".*");

        let expected = if cfg!(feature = "std") { 4 } else { 3 };
        assert_eq!(filter.criteria_count(), expected);
    }

    // Clearing paths removes their matches; `clear` leaves an empty filter.
    #[test]
    fn clear_operations() {
        let mut filter = PathFilter::new().with_full_path("test");

        filter.clear_paths();
        assert!(!filter.matches("test"));

        filter.clear();
        assert!(filter.is_empty());
    }

    // Predicates receive the container type passed to `matches_with_container`.
    #[test]
    fn container_predicates() {
        fn linear_weight(path: &str, container_path: &str) -> bool {
            innermost(container_path) == Some("Linear") && path.ends_with(".weight")
        }
        fn any_conv(_path: &str, container_path: &str) -> bool {
            matches!(
                innermost(container_path),
                Some("Conv2d") | Some("ConvTranspose2d")
            )
        }
        fn under_encoder(path: &str, _container_path: &str) -> bool {
            path.starts_with("encoder.")
        }
        fn batch_norm(_path: &str, container_path: &str) -> bool {
            innermost(container_path) == Some("BatchNorm2d")
        }

        let linear_weights = PathFilter::new().with_predicate(linear_weight);
        check_container_type(
            &linear_weights,
            &[
                ("layer1.weight", "Linear", true),
                ("layer1.weight", "Conv2d", false),
                ("layer1.bias", "Linear", false),
            ],
        );

        let conv_only = PathFilter::new().with_predicate(any_conv);
        check_container_type(
            &conv_only,
            &[
                ("encoder.weight", "Conv2d", true),
                ("decoder.weight", "ConvTranspose2d", true),
                ("fc.weight", "Linear", false),
            ],
        );

        let combined = PathFilter::new()
            .with_predicate(under_encoder)
            .with_predicate(batch_norm);
        check_container_type(
            &combined,
            &[
                ("encoder.layer1", "Linear", true),
                ("decoder.bn", "BatchNorm2d", true),
                ("decoder.layer", "Linear", false),
            ],
        );
    }

    // Regexes and container-aware predicates are OR-combined (body needs `std`).
    #[test]
    fn container_predicate_with_regex() {
        #[cfg(feature = "std")]
        {
            fn linear_bias(path: &str, container_path: &str) -> bool {
                innermost(container_path) == Some("Linear") && path.contains(".bias")
            }

            let filter = PathFilter::new()
                .with_regex(r"^encoder\..*")
                .with_predicate(linear_bias);

            check_container_type(
                &filter,
                &[
                    ("encoder.layer1.weight", "Conv2d", true),
                    ("decoder.fc.bias", "Linear", true),
                    ("decoder.conv.weight", "Conv2d", false),
                ],
            );
        }
    }

    // Predicates may inspect the whole dotted container path.
    #[test]
    fn container_stack_predicates() {
        fn model_block_linear(_path: &str, container_path: &str) -> bool {
            let mut segments = container_path.split('.');
            segments.next() == Some("Model")
                && segments.next() == Some("TransformerBlock")
                && segments.next() == Some("Linear")
        }
        fn linear_at_depth_four(_path: &str, container_path: &str) -> bool {
            container_path.split('.').count() == 4
                && container_path.split('.').nth(2) == Some("Linear")
        }
        fn mentions_linear(_path: &str, container_path: &str) -> bool {
            container_path.contains("Linear")
        }

        let nested_filter = PathFilter::new().with_predicate(model_block_linear);
        check_container_path(
            &nested_filter,
            &[
                ("encoder.weight", "Model.TransformerBlock.Linear.Param", true),
                ("decoder.weight", "Model.Decoder.Linear.Param", false),
                ("encoder.weight", "Model.TransformerBlock.Conv2d.Param", false),
            ],
        );

        let depth_filter = PathFilter::new().with_predicate(linear_at_depth_four);
        check_container_path(
            &depth_filter,
            &[
                (
                    "model.layer.weight",
                    "Model.TransformerBlock.Linear.Param",
                    true,
                ),
                ("model.weight", "Model.TransformerBlock.Conv2d", false),
            ],
        );

        let any_linear = PathFilter::new().with_predicate(mentions_linear);
        check_container_path(
            &any_linear,
            &[
                ("some.path", "Model.TransformerBlock.Linear.Param", true),
                ("other.path", "Model.Decoder.Linear.Param", true),
                ("conv.path", "Model.TransformerBlock.Conv2d.Param", false),
            ],
        );
    }

    // Prefix and substring checks on the dotted container path.
    #[test]
    fn container_path_dot_notation() {
        fn in_transformer_block(_path: &str, container_path: &str) -> bool {
            container_path.starts_with("Model.TransformerBlock")
        }
        fn block_linear_or_conv(_path: &str, container_path: &str) -> bool {
            container_path.contains("Block.Linear") || container_path.contains("Block.Conv")
        }
        fn block_linear_weight(path: &str, container_path: &str) -> bool {
            path.ends_with(".weight")
                && container_path.contains("Block")
                && innermost(container_path) == Some("Linear")
        }

        let dot_filter = PathFilter::new().with_predicate(in_transformer_block);
        check_container_path(
            &dot_filter,
            &[
                ("weight", "Model.TransformerBlock.Linear", true),
                ("weight", "Model.Decoder.Linear", false),
            ],
        );

        let pattern_filter = PathFilter::new().with_predicate(block_linear_or_conv);
        check_container_path(
            &pattern_filter,
            &[
                ("weight", "Model.TransformerBlock.Linear", true),
                ("weight", "Model.ResBlock.Conv2d", true),
                ("weight", "Model.Linear.Param", false),
            ],
        );

        let combined = PathFilter::new().with_predicate(block_linear_weight);
        check_container_path(
            &combined,
            &[
                ("layer.weight", "Model.TransformerBlock.Linear", true),
                ("layer.bias", "Model.TransformerBlock.Linear", false),
                ("layer.weight", "Model.Decoder.Linear", false),
            ],
        );
    }
}