1use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct DecisionPattern {
13 pub pattern_id: String,
15 pub description: String,
17 pub feature_weights: Vec<f32>,
19 pub confidence: f32,
21 pub category: String,
23}
24
25impl DecisionPattern {
26 #[must_use]
28 pub fn new(
29 pattern_id: impl Into<String>,
30 description: impl Into<String>,
31 feature_weights: Vec<f32>,
32 confidence: f32,
33 category: impl Into<String>,
34 ) -> Self {
35 Self {
36 pattern_id: pattern_id.into(),
37 description: description.into(),
38 feature_weights,
39 confidence: confidence.clamp(0.0, 1.0),
40 category: category.into(),
41 }
42 }
43}
44
45#[derive(Debug, Clone, Default)]
67pub struct PatternStore {
68 patterns: HashMap<String, DecisionPattern>,
69}
70
71impl PatternStore {
72 #[must_use]
74 pub fn new() -> Self {
75 Self { patterns: HashMap::new() }
76 }
77
78 pub fn add_pattern(&mut self, pattern: DecisionPattern) {
82 self.patterns.insert(pattern.pattern_id.clone(), pattern);
83 }
84
85 #[must_use]
87 pub fn get_pattern(&self, id: &str) -> Option<&DecisionPattern> {
88 self.patterns.get(id)
89 }
90
91 #[must_use]
100 pub fn search(&self, query_features: &[f32], top_k: usize) -> Vec<&DecisionPattern> {
101 let mut scored: Vec<(f32, &DecisionPattern)> = self
102 .patterns
103 .values()
104 .map(|p| {
105 let sim = cosine_similarity(query_features, &p.feature_weights);
106 (sim, p)
107 })
108 .collect();
109
110 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
112 scored.truncate(top_k);
113 scored.into_iter().map(|(_, p)| p).collect()
114 }
115
116 #[must_use]
118 pub fn list_patterns(&self) -> Vec<&DecisionPattern> {
119 self.patterns.values().collect()
120 }
121
122 pub fn remove_pattern(&mut self, id: &str) -> Option<DecisionPattern> {
124 self.patterns.remove(id)
125 }
126
127 #[must_use]
129 pub fn len(&self) -> usize {
130 self.patterns.len()
131 }
132
133 #[must_use]
135 pub fn is_empty(&self) -> bool {
136 self.patterns.is_empty()
137 }
138}
139
/// Cosine similarity between `a` and `b`, in `[-1.0, 1.0]`.
///
/// Returns `0.0` for degenerate inputs:
/// - slices of different length;
/// - an all-zero (or empty) vector;
/// - components that make the result undefined (e.g. NaN or infinite values).
///
/// Sums are accumulated in `f64`, so squaring very large `f32` components cannot
/// overflow to infinity, and squaring very small ones cannot underflow to a zero norm.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    let mut dot = 0.0_f64;
    let mut norm_a = 0.0_f64;
    let mut norm_b = 0.0_f64;
    for (&x, &y) in a.iter().zip(b) {
        let (x, y) = (f64::from(x), f64::from(y));
        dot += x * y;
        norm_a += x * x;
        norm_b += y * y;
    }

    // Taking each square root separately keeps the product within range.
    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom == 0.0 {
        return 0.0;
    }

    let sim = dot / denom;
    if sim.is_nan() {
        0.0
    } else {
        // Rounding can push |sim| a hair past 1; keep the result in range.
        sim.clamp(-1.0, 1.0) as f32
    }
}
165
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a pattern with description `desc_{id}` and confidence 0.8.
    fn make_pattern(id: &str, weights: Vec<f32>, category: &str) -> DecisionPattern {
        DecisionPattern::new(id, format!("desc_{id}"), weights, 0.8, category)
    }

    #[test]
    fn test_add_and_get_pattern() {
        let mut store = PatternStore::new();
        store.add_pattern(make_pattern("p1", vec![1.0, 0.0], "cat_a"));

        let retrieved = store.get_pattern("p1").expect("p1 was just added");
        assert_eq!(retrieved.pattern_id, "p1");
        assert_eq!(retrieved.category, "cat_a");
    }

    #[test]
    fn test_get_nonexistent_pattern() {
        let store = PatternStore::new();
        assert!(store.get_pattern("missing").is_none());
    }

    #[test]
    fn test_add_replaces_existing() {
        let mut store = PatternStore::new();
        store.add_pattern(make_pattern("p1", vec![1.0], "old"));
        store.add_pattern(make_pattern("p1", vec![2.0], "new"));

        assert_eq!(store.len(), 1);
        let stored = store.get_pattern("p1").expect("p1 was added");
        assert_eq!(stored.category, "new");
        assert_eq!(stored.feature_weights, vec![2.0]);
    }

    #[test]
    fn test_remove_pattern() {
        let mut store = PatternStore::new();
        store.add_pattern(make_pattern("p1", vec![1.0], "a"));
        store.add_pattern(make_pattern("p2", vec![2.0], "b"));

        let removed = store.remove_pattern("p1").expect("p1 was present");
        assert_eq!(removed.pattern_id, "p1");
        assert_eq!(store.len(), 1);
        assert!(store.get_pattern("p1").is_none());
        assert!(store.get_pattern("p2").is_some());
    }

    #[test]
    fn test_remove_nonexistent() {
        let mut store = PatternStore::new();
        assert!(store.remove_pattern("ghost").is_none());
    }

    #[test]
    fn test_list_patterns() {
        let mut store = PatternStore::new();
        store.add_pattern(make_pattern("p1", vec![1.0], "a"));
        store.add_pattern(make_pattern("p2", vec![2.0], "b"));

        let list = store.list_patterns();
        assert_eq!(list.len(), 2);
        let ids: Vec<&str> = list.iter().map(|p| p.pattern_id.as_str()).collect();
        assert!(ids.contains(&"p1"));
        assert!(ids.contains(&"p2"));
    }

    #[test]
    fn test_len_and_is_empty() {
        let mut store = PatternStore::new();
        assert!(store.is_empty());
        assert_eq!(store.len(), 0);

        store.add_pattern(make_pattern("p1", vec![1.0], "a"));
        assert!(!store.is_empty());
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let sim = cosine_similarity(&[1.0, 2.0, 3.0], &[1.0, 2.0, 3.0]);
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let sim = cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]);
        assert!(sim.abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let sim = cosine_similarity(&[1.0, 0.0], &[-1.0, 0.0]);
        assert!((sim + 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_different_lengths() {
        let sim = cosine_similarity(&[1.0, 2.0], &[1.0]);
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        let sim = cosine_similarity(&[0.0, 0.0], &[1.0, 2.0]);
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn test_search_returns_top_k() {
        let mut store = PatternStore::new();
        store.add_pattern(make_pattern("close", vec![0.9, 0.1, 0.0], "a"));
        store.add_pattern(make_pattern("exact", vec![1.0, 0.0, 0.0], "b"));
        store.add_pattern(make_pattern("far", vec![0.0, 0.0, 1.0], "c"));

        let results = store.search(&[1.0, 0.0, 0.0], 2);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].pattern_id, "exact");
        assert_eq!(results[1].pattern_id, "close");
    }

    #[test]
    fn test_search_orders_by_descending_similarity() {
        let mut store = PatternStore::new();
        store.add_pattern(make_pattern("far", vec![0.0, 1.0], "a"));
        store.add_pattern(make_pattern("mid", vec![1.0, 1.0], "b"));
        store.add_pattern(make_pattern("near", vec![1.0, 0.0], "c"));

        let results = store.search(&[1.0, 0.0], 3);
        let ids: Vec<&str> = results.iter().map(|p| p.pattern_id.as_str()).collect();
        assert_eq!(ids, ["near", "mid", "far"]);
    }

    #[test]
    fn test_search_top_k_zero() {
        let mut store = PatternStore::new();
        store.add_pattern(make_pattern("p1", vec![1.0, 0.0], "a"));

        assert!(store.search(&[1.0, 0.0], 0).is_empty());
    }

    #[test]
    fn test_search_top_k_larger_than_store() {
        let mut store = PatternStore::new();
        store.add_pattern(make_pattern("p1", vec![1.0, 0.0], "a"));

        let results = store.search(&[1.0, 0.0], 10);
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn test_search_empty_store() {
        let store = PatternStore::new();
        let results = store.search(&[1.0, 0.0], 5);
        assert!(results.is_empty());
    }

    #[test]
    fn test_search_with_mismatched_dimensions() {
        let mut store = PatternStore::new();
        store.add_pattern(make_pattern("p1", vec![1.0, 0.0, 0.0], "a"));

        // Mismatched lengths score 0.0 but the pattern is still returned.
        let results = store.search(&[1.0, 0.0], 5);
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn test_confidence_clamped() {
        let p = DecisionPattern::new("id", "desc", vec![], 1.5, "cat");
        assert_eq!(p.confidence, 1.0);

        let p2 = DecisionPattern::new("id2", "desc", vec![], -0.5, "cat");
        assert_eq!(p2.confidence, 0.0);
    }

    #[test]
    fn test_default_store() {
        let store = PatternStore::default();
        assert!(store.is_empty());
    }
}