Skip to main content

entrenar/decision/
pattern_store.rs

//! Decision pattern storage with cosine similarity search.
//!
//! Stores `DecisionPattern` instances indexed by a unique `pattern_id`,
//! and retrieves the top-k most similar patterns to a query feature vector
//! using cosine similarity over the pattern's `feature_weights`.
6
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10/// A decision pattern with feature weights and metadata.
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct DecisionPattern {
13    /// Unique identifier for this pattern.
14    pub pattern_id: String,
15    /// Human-readable description of this pattern.
16    pub description: String,
17    /// Feature weight vector used for similarity search.
18    pub feature_weights: Vec<f32>,
19    /// Confidence score in range [0.0, 1.0].
20    pub confidence: f32,
21    /// Category label for this pattern.
22    pub category: String,
23}
24
25impl DecisionPattern {
26    /// Create a new decision pattern.
27    #[must_use]
28    pub fn new(
29        pattern_id: impl Into<String>,
30        description: impl Into<String>,
31        feature_weights: Vec<f32>,
32        confidence: f32,
33        category: impl Into<String>,
34    ) -> Self {
35        Self {
36            pattern_id: pattern_id.into(),
37            description: description.into(),
38            feature_weights,
39            confidence: confidence.clamp(0.0, 1.0),
40            category: category.into(),
41        }
42    }
43}
44
45/// Store for decision patterns with cosine-similarity retrieval.
46///
47/// Patterns are stored in a `HashMap` keyed by `pattern_id`.
48/// The `search` method computes cosine similarity between the query
49/// feature vector and every stored pattern, returning the top-k results
50/// sorted by descending similarity.
51///
52/// # Example
53///
54/// ```
55/// use entrenar::decision::{DecisionPattern, PatternStore};
56///
57/// let mut store = PatternStore::new();
58/// store.add_pattern(DecisionPattern::new(
59///     "p1", "type fix", vec![1.0, 0.0, 0.0], 0.9, "type_error",
60/// ));
61///
62/// let results = store.search(&[1.0, 0.0, 0.0], 5);
63/// assert_eq!(results.len(), 1);
64/// assert_eq!(results[0].pattern_id, "p1");
65/// ```
66#[derive(Debug, Clone, Default)]
67pub struct PatternStore {
68    patterns: HashMap<String, DecisionPattern>,
69}
70
71impl PatternStore {
72    /// Create an empty pattern store.
73    #[must_use]
74    pub fn new() -> Self {
75        Self { patterns: HashMap::new() }
76    }
77
78    /// Add a pattern to the store.
79    ///
80    /// If a pattern with the same `pattern_id` already exists it is replaced.
81    pub fn add_pattern(&mut self, pattern: DecisionPattern) {
82        self.patterns.insert(pattern.pattern_id.clone(), pattern);
83    }
84
85    /// Retrieve a pattern by its id.
86    #[must_use]
87    pub fn get_pattern(&self, id: &str) -> Option<&DecisionPattern> {
88        self.patterns.get(id)
89    }
90
91    /// Search for the top-k patterns most similar to `query_features`.
92    ///
93    /// Similarity is measured by cosine similarity between `query_features`
94    /// and each pattern's `feature_weights`. Patterns whose feature vector
95    /// has a different length than the query, or whose norm is zero, receive
96    /// a similarity of zero.
97    ///
98    /// Returns results sorted by descending similarity, limited to `top_k`.
99    #[must_use]
100    pub fn search(&self, query_features: &[f32], top_k: usize) -> Vec<&DecisionPattern> {
101        let mut scored: Vec<(f32, &DecisionPattern)> = self
102            .patterns
103            .values()
104            .map(|p| {
105                let sim = cosine_similarity(query_features, &p.feature_weights);
106                (sim, p)
107            })
108            .collect();
109
110        // Sort descending by similarity (NaN-safe: treat NaN as less than everything).
111        scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
112        scored.truncate(top_k);
113        scored.into_iter().map(|(_, p)| p).collect()
114    }
115
116    /// List all patterns in the store (unordered).
117    #[must_use]
118    pub fn list_patterns(&self) -> Vec<&DecisionPattern> {
119        self.patterns.values().collect()
120    }
121
122    /// Remove a pattern by id. Returns the removed pattern if it existed.
123    pub fn remove_pattern(&mut self, id: &str) -> Option<DecisionPattern> {
124        self.patterns.remove(id)
125    }
126
127    /// Return the number of stored patterns.
128    #[must_use]
129    pub fn len(&self) -> usize {
130        self.patterns.len()
131    }
132
133    /// Return whether the store is empty.
134    #[must_use]
135    pub fn is_empty(&self) -> bool {
136        self.patterns.is_empty()
137    }
138}
139
/// Compute cosine similarity between two vectors.
///
/// Returns 0.0 if the vectors differ in length or either has zero norm
/// (which includes two empty vectors).
///
/// The dot product and squared norms are accumulated in `f64`. Squaring in
/// `f32` overflows to infinity for large components and underflows to zero
/// for tiny ones, which would turn the result into NaN or a spurious 0.0 even
/// though every input is finite and non-zero.
///
/// Non-finite inputs (NaN or infinity) generally yield NaN; callers that rank
/// by the result need to handle that case.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    let (dot, norm_a, norm_b) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, norm_a, norm_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, norm_a + x * x, norm_b + y * y)
        },
    );

    // Take the square roots separately so the denominator stays in range.
    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom == 0.0 {
        0.0
    } else {
        // Rounding can push the ratio a hair outside [-1, 1]; clamp before the
        // (intentionally narrowing) conversion back to f32. NaN passes through.
        (dot / denom).clamp(-1.0, 1.0) as f32
    }
}
165
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a pattern with confidence 0.8 and a description derived from `id`.
    fn make_pattern(id: &str, weights: Vec<f32>, category: &str) -> DecisionPattern {
        let description = format!("desc_{id}");
        DecisionPattern::new(id, description, weights, 0.8, category)
    }

    /// Builds a store holding `patterns`, inserted in the given order.
    fn store_of(patterns: Vec<DecisionPattern>) -> PatternStore {
        let mut store = PatternStore::new();
        for pattern in patterns {
            store.add_pattern(pattern);
        }
        store
    }

    #[test]
    fn test_add_and_get_pattern() {
        let store = store_of(vec![make_pattern("p1", vec![1.0, 0.0], "cat_a")]);

        let lookup = store.get_pattern("p1");
        assert!(lookup.is_some());
        let found = lookup.expect("operation should succeed");
        assert_eq!(found.pattern_id, "p1");
        assert_eq!(found.category, "cat_a");
    }

    #[test]
    fn test_get_nonexistent_pattern() {
        let empty = PatternStore::new();
        assert!(empty.get_pattern("missing").is_none());
    }

    #[test]
    fn test_add_replaces_existing() {
        // Both patterns share the id "p1", so the second insert overwrites the first.
        let store = store_of(vec![
            make_pattern("p1", vec![1.0], "old"),
            make_pattern("p1", vec![2.0], "new"),
        ]);

        assert_eq!(store.len(), 1);
        let survivor = store.get_pattern("p1").expect("operation should succeed");
        assert_eq!(survivor.category, "new");
    }

    #[test]
    fn test_remove_pattern() {
        let mut store = store_of(vec![
            make_pattern("p1", vec![1.0], "a"),
            make_pattern("p2", vec![2.0], "b"),
        ]);

        let taken = store.remove_pattern("p1");
        assert!(taken.is_some());
        assert_eq!(taken.expect("operation should succeed").pattern_id, "p1");
        assert_eq!(store.len(), 1);
        assert!(store.get_pattern("p1").is_none());
    }

    #[test]
    fn test_remove_nonexistent() {
        let mut empty = PatternStore::new();
        assert!(empty.remove_pattern("ghost").is_none());
    }

    #[test]
    fn test_list_patterns() {
        let store = store_of(vec![
            make_pattern("p1", vec![1.0], "a"),
            make_pattern("p2", vec![2.0], "b"),
        ]);

        let all = store.list_patterns();
        assert_eq!(all.len(), 2);

        // Listing order is unspecified, so only membership is checked.
        let ids: Vec<&str> = all.iter().map(|p| p.pattern_id.as_str()).collect();
        for expected in ["p1", "p2"].iter() {
            assert!(ids.contains(expected));
        }
    }

    #[test]
    fn test_len_and_is_empty() {
        let mut store = PatternStore::new();
        assert!(store.is_empty());
        assert_eq!(store.len(), 0);

        store.add_pattern(make_pattern("p1", vec![1.0], "a"));
        assert!(!store.is_empty());
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let v = [1.0, 2.0, 3.0];
        let sim = cosine_similarity(&v, &v);
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let sim = cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]);
        assert!(sim.abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let sim = cosine_similarity(&[1.0, 0.0], &[-1.0, 0.0]);
        assert!((sim + 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_different_lengths() {
        assert_eq!(cosine_similarity(&[1.0, 2.0], &[1.0]), 0.0);
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 2.0]), 0.0);
    }

    #[test]
    fn test_search_returns_top_k() {
        let store = store_of(vec![
            make_pattern("close", vec![0.9, 0.1, 0.0], "a"),
            make_pattern("exact", vec![1.0, 0.0, 0.0], "b"),
            make_pattern("far", vec![0.0, 0.0, 1.0], "c"),
        ]);

        let hits = store.search(&[1.0, 0.0, 0.0], 2);
        assert_eq!(hits.len(), 2);
        // Best match comes first, runner-up second.
        assert_eq!(hits[0].pattern_id, "exact");
        assert_eq!(hits[1].pattern_id, "close");
    }

    #[test]
    fn test_search_top_k_larger_than_store() {
        let store = store_of(vec![make_pattern("p1", vec![1.0, 0.0], "a")]);

        let hits = store.search(&[1.0, 0.0], 10);
        assert_eq!(hits.len(), 1);
    }

    #[test]
    fn test_search_empty_store() {
        let hits = PatternStore::new().search(&[1.0, 0.0], 5).len();
        assert_eq!(hits, 0);
    }

    #[test]
    fn test_search_with_mismatched_dimensions() {
        // The stored pattern is 3-dimensional while the query is 2-dimensional.
        let store = store_of(vec![make_pattern("p1", vec![1.0, 0.0, 0.0], "a")]);

        let hits = store.search(&[1.0, 0.0], 5);
        // The pattern is still returned, just scored with zero similarity.
        assert_eq!(hits.len(), 1);
    }

    #[test]
    fn test_confidence_clamped() {
        let too_high = DecisionPattern::new("id", "desc", vec![], 1.5, "cat");
        assert_eq!(too_high.confidence, 1.0);

        let too_low = DecisionPattern::new("id2", "desc", vec![], -0.5, "cat");
        assert_eq!(too_low.confidence, 0.0);
    }

    #[test]
    fn test_default_store() {
        assert!(PatternStore::default().is_empty());
    }
}