basic_pattern_matcher/
lib.rs

1use std::collections::{HashMap, HashSet};
2
/// One node of the pattern trie.
///
/// Every edge corresponds to a single `.`-separated pattern segment: a literal
/// segment, the single-level wildcard `*`, or the multi-level wildcard `**`.
#[derive(Debug, Default)]
pub struct TrieNode {
    /// Children reached through a literal segment (e.g. `"stock"`, `"nyse"`).
    children: HashMap<String, TrieNode>,
    /// Child reached through `*`, which consumes exactly one topic segment.
    star_child: Option<Box<TrieNode>>,
    /// Child reached through `**`, which consumes zero or more topic segments.
    /// `**` may appear at the end of a pattern or in the middle of it.
    double_star_child: Option<Box<TrieNode>>,
    /// Indices into `PatternMatcher::patterns_with_data` of the patterns that
    /// end exactly at this node.
    pattern_indices: Vec<usize>,
}

/// Matches dot-separated topics against registered wildcard patterns and
/// returns the data associated with every matching pattern.
///
/// Pattern syntax (segments are separated by `.`):
/// * a literal segment matches that exact topic segment,
/// * `*` matches exactly one topic segment,
/// * `**` matches zero or more topic segments.
pub struct PatternMatcher<T> {
    /// Root of the trie built from all registered patterns.
    root: TrieNode,
    /// `(pattern_string, associated_data)` in registration order; trie nodes
    /// refer to entries of this vector by index.
    patterns_with_data: Vec<(String, T)>,
}

// Implemented by hand rather than derived so that `T` does not have to
// implement `Default` for an empty matcher to be constructible.
impl<T> Default for PatternMatcher<T> {
    fn default() -> Self {
        Self {
            root: TrieNode::default(),
            patterns_with_data: Vec::new(),
        }
    }
}

impl<T> PatternMatcher<T> {
    /// Creates an empty matcher with no registered patterns.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `pattern` together with its associated `data`.
    ///
    /// The pattern is split on `.`; `*` and `**` segments are wildcards, every
    /// other segment (including an empty one, as in `"a..b"`) is a literal.
    /// An empty pattern is ignored and `data` is dropped. Registering the same
    /// pattern twice stores two independent entries.
    pub fn add_pattern(&mut self, pattern: &str, data: T) {
        if pattern.is_empty() {
            return;
        }

        // The index of the new entry is what the trie stores at the terminal node.
        let pattern_index = self.patterns_with_data.len();
        self.patterns_with_data.push((pattern.to_string(), data));

        // Walk (and create on demand) one trie edge per pattern segment.
        let mut node = &mut self.root;
        for segment in pattern.split('.') {
            node = match segment {
                "*" => &mut **node.star_child.get_or_insert_with(Box::default),
                "**" => &mut **node.double_star_child.get_or_insert_with(Box::default),
                literal => node.children.entry(literal.to_string()).or_default(),
            };
        }
        node.pattern_indices.push(pattern_index);
    }

    /// Returns `(pattern, data)` for every registered pattern matching `topic`.
    ///
    /// Results are ordered by registration order and each registered entry is
    /// reported at most once. An empty `topic` never matches anything.
    pub fn match_topic(&self, topic: &str) -> Vec<(&str, &T)> {
        if topic.is_empty() {
            return Vec::new();
        }

        let segments: Vec<&str> = topic.split('.').collect();

        // A set is used because several trie paths (e.g. through `*` and `**`)
        // can reach the same terminal node.
        let mut matched_indices = HashSet::new();
        Self::find_matches_recursive(&self.root, &segments, &mut matched_indices);

        // Sort so the output is deterministic (registration order) instead of
        // depending on hash-set iteration order.
        let mut ordered: Vec<usize> = matched_indices.into_iter().collect();
        ordered.sort_unstable();

        ordered
            .into_iter()
            .map(|index| {
                let (pattern_str, data) = &self.patterns_with_data[index];
                (pattern_str.as_str(), data)
            })
            .collect()
    }

    /// Collects into `matched_indices` the indices of all patterns that are
    /// reachable from `node` and match exactly the topic segments in `remaining`.
    ///
    /// A pattern only matches when the walk ends on its terminal node with no
    /// topic segments left, so segments that follow a `**` in the pattern are
    /// still required to match. Each `**` edge tries every possible number of
    /// consumed segments, so patterns with several `**` cost proportionally more.
    fn find_matches_recursive(
        node: &TrieNode,
        remaining: &[&str],
        matched_indices: &mut HashSet<usize>,
    ) {
        // `**`: let it swallow 0, 1, ..., all of the remaining segments and
        // continue matching the rest of the pattern after each choice.
        if let Some(ds_child) = node.double_star_child.as_deref() {
            for consumed in 0..=remaining.len() {
                Self::find_matches_recursive(ds_child, &remaining[consumed..], matched_indices);
            }
        }

        match remaining.split_first() {
            // Topic exhausted: patterns ending exactly here are matches.
            None => matched_indices.extend(node.pattern_indices.iter().copied()),
            Some((segment, rest)) => {
                // Literal edge for the current segment.
                if let Some(child) = node.children.get(*segment) {
                    Self::find_matches_recursive(child, rest, matched_indices);
                }
                // `*` edge: consumes exactly this one segment.
                if let Some(star_child) = node.star_child.as_deref() {
                    Self::find_matches_recursive(star_child, rest, matched_indices);
                }
            }
        }
    }
}