// basic_pattern_matcher/lib.rs
1use std::collections::{HashMap, HashSet};
2
/// One node of the pattern trie; each edge corresponds to one dot-separated
/// pattern segment.
#[derive(Debug, Default)]
pub struct TrieNode {
    /// Children reached by a literal segment (e.g. `"stock"`, `"nyse"`).
    children: HashMap<String, TrieNode>,
    /// Child reached by the single-level wildcard segment `*`.
    star_child: Option<Box<TrieNode>>,
    /// Child reached by the multi-level wildcard segment `**`. A `**` may be
    /// the last segment of a pattern or appear in the middle of one, in which
    /// case further segments hang below this child.
    double_star_child: Option<Box<TrieNode>>,
    /// Indices into the matcher's `patterns_with_data` list of every pattern
    /// whose final segment ends at this node.
    pattern_indices: Vec<usize>,
}
16
17#[derive(Default)]
18pub struct PatternMatcher<T> { // Make struct generic over T
19 root: TrieNode,
20 patterns_with_data: Vec<(String, T)>, // Store (pattern_string, associated_data)
21}
22
23// Implement methods for the generic PatternMatcher<T>
24impl<T> PatternMatcher<T> {
25 pub fn new() -> Self {
26 PatternMatcher {
27 root: TrieNode::default(),
28 patterns_with_data: Vec::new(),
29 }
30 }
31
32 /// Adds a subscription pattern and its associated data to the matcher.
33 pub fn add_pattern(&mut self, pattern: &str, data: T) { // Accept data T
34 if pattern.is_empty() {
35 return; // Or handle as needed
36 }
37
38 // Store the pattern and data, get its index
39 let pattern_index = self.patterns_with_data.len();
40 self.patterns_with_data.push((pattern.to_string(), data));
41
42 let segments: Vec<&str> = pattern.split('.').collect();
43 let mut current_node = &mut self.root;
44
45 for (i, segment) in segments.iter().enumerate() {
46 match *segment {
47 "*" => {
48 current_node = current_node.star_child.get_or_insert_with(Default::default);
49 }
50 "**" => {
51 if i != segments.len() - 1 {
52 // Allow intermediate '**' structurally
53 }
54 current_node = current_node.double_star_child.get_or_insert_with(Default::default);
55 }
56 exact => {
57 current_node = current_node.children.entry(exact.to_string()).or_default();
58 }
59 }
60 }
61 // Mark the end of the pattern using its index
62 current_node.pattern_indices.push(pattern_index);
63 }
64
65 /// Finds all patterns that match the given topic and returns pairs of (pattern, data).
66 pub fn match_topic(&self, topic: &str) -> Vec<(&str, &T)> { // Return Vec<(&str, &T)>
67 if topic.is_empty() {
68 return vec![];
69 }
70
71 let segments: Vec<&str> = topic.split('.').collect();
72 let mut matched_indices = HashSet::new(); // Still collect indices
73
74 // Start the recursive search (logic remains the same)
75 self.find_matches_recursive(&self.root, &segments, 0, &mut matched_indices);
76
77 // Convert indices back to (pattern string, data) references
78 matched_indices
79 .into_iter()
80 .map(|index| {
81 let (pattern_str, data) = &self.patterns_with_data[index];
82 (pattern_str.as_str(), data) // Return refs: (&str, &T)
83 })
84 .collect()
85 }
86
87 // Recursive helper function for matching - signature stays the same
88 // It only populates matched_indices (Vec<usize>)
89 fn find_matches_recursive(
90 &self,
91 node: &TrieNode,
92 segments: &[&str],
93 segment_index: usize,
94 matched_indices: &mut HashSet<usize>,
95 ) {
96 // --- Match patterns involving '**' ---
97 if let Some(ds_child) = &node.double_star_child {
98 // 1. '**' matches everything from current segment_index onwards.
99 self.collect_all_terminal_patterns(ds_child, matched_indices);
100
101 // 2. '**' matches zero or more segments, then the rest of the pattern.
102 if segment_index < segments.len() {
103 self.find_matches_recursive(ds_child, segments, segment_index, matched_indices);
104 }
105 // Case: Pattern like "a.**" matching topic "a"
106 // If the topic ends exactly where '**' begins in the pattern.
107 else if segment_index == segments.len() {
108 self.collect_all_terminal_patterns(ds_child, matched_indices);
109 }
110 }
111
112 // --- Base Case: End of topic reached ---
113 if segment_index == segments.len() {
114 // Add patterns ending exactly at this node
115 matched_indices.extend(node.pattern_indices.iter().cloned());
116
117 // Also, if a pattern ending in '**' led here, that '**' matches zero
118 // remaining segments. Check the double_star_child's patterns.
119 // This case is subtly handled by the collect_all_terminal_patterns call
120 // at the beginning of the function if the '**' node was reached *before*
121 // exhausting the topic segments. If we arrive *at* the end of the topic
122 // and the current node has a '**' child, that '**' child represents patterns
123 // ending in '**' which should match.
124 if let Some(ds_child) = &node.double_star_child {
125 // Add patterns ending *exactly* at the double star node itself.
126 // Patterns deeper within the double_star tree were handled by collect_all_terminal_patterns
127 // at the top if ds_child existed.
128 matched_indices.extend(ds_child.pattern_indices.iter().cloned());
129 }
130 return;
131 }
132
133
134 // --- Recursive Step: Match current segment ---
135 let current_segment = segments[segment_index];
136
137 // 1. Match exact segment
138 if let Some(child) = node.children.get(current_segment) {
139 self.find_matches_recursive(child, segments, segment_index + 1, matched_indices);
140 }
141
142 // 2. Match single-level wildcard '*'
143 if let Some(star_child) = &node.star_child {
144 self.find_matches_recursive(star_child, segments, segment_index + 1, matched_indices);
145 }
146
147 // 3. Match multi-level wildcard '**' (already handled at the start of the function)
148 // The logic at the start covers the '**' matching one or more segments.
149 }
150
151
152 // Helper to collect all pattern indices in the subtree rooted at 'node'
153 // Signature stays the same, works with indices.
154 fn collect_all_terminal_patterns(
155 &self,
156 node: &TrieNode,
157 matched_indices: &mut HashSet<usize>,
158 ) {
159 // Add patterns ending at this node
160 matched_indices.extend(node.pattern_indices.iter().cloned());
161
162 // Recursively explore children
163 for child in node.children.values() {
164 self.collect_all_terminal_patterns(child, matched_indices);
165 }
166 if let Some(star_child) = &node.star_child {
167 self.collect_all_terminal_patterns(star_child, matched_indices);
168 }
169 if let Some(ds_child) = &node.double_star_child {
170 self.collect_all_terminal_patterns(ds_child, matched_indices);
171 }
172 }
173}