use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// A decision pattern: an identifier, human-readable metadata, and a vector
/// of feature weights.
///
/// The type derives serde's `Serialize` and `Deserialize`, so it can be
/// converted to and from serialized form.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DecisionPattern {
/// Identifier of the pattern; `PatternStore::add_pattern` uses a clone of
/// this value as the map key.
pub pattern_id: String,
/// Human-readable description of the pattern.
pub description: String,
/// Feature weights; `PatternStore::search` compares these against the query
/// vector via cosine similarity.
pub feature_weights: Vec<f32>,
/// Confidence value. `DecisionPattern::new` clamps it to `[0.0, 1.0]`; the
/// field is public, so values written directly are not clamped by any code
/// in this file.
pub confidence: f32,
/// Category label, stored as a plain string.
pub category: String,
}
impl DecisionPattern {
    /// Builds a pattern from its parts.
    ///
    /// `pattern_id`, `description` and `category` accept anything convertible
    /// into a `String`. `feature_weights` is stored as given.
    ///
    /// `confidence` is forced into `[0.0, 1.0]`: values below `0.0` become
    /// `0.0`, values above `1.0` become `1.0`, and NaN becomes `0.0`.
    #[must_use]
    pub fn new(
        pattern_id: impl Into<String>,
        description: impl Into<String>,
        feature_weights: Vec<f32>,
        confidence: f32,
        category: impl Into<String>,
    ) -> Self {
        // `f32::clamp` returns NaN for a NaN input, which would let an
        // out-of-range confidence slip through; treat NaN as "no confidence".
        let confidence = if confidence.is_nan() {
            0.0
        } else {
            confidence.clamp(0.0, 1.0)
        };
        Self {
            pattern_id: pattern_id.into(),
            description: description.into(),
            feature_weights,
            confidence,
            category: category.into(),
        }
    }
}
/// In-memory collection of `DecisionPattern` values, looked up by their
/// `pattern_id`.
#[derive(Debug, Clone, Default)]
pub struct PatternStore {
/// Maps a pattern's `pattern_id` to the pattern itself; `add_pattern`
/// inserts each pattern under a clone of its own id.
patterns: HashMap<String, DecisionPattern>,
}
impl PatternStore {
    /// Creates an empty store.
    #[must_use]
    pub fn new() -> Self {
        Self { patterns: HashMap::new() }
    }

    /// Inserts `pattern` under its `pattern_id`.
    ///
    /// A pattern already stored under the same id is replaced.
    pub fn add_pattern(&mut self, pattern: DecisionPattern) {
        self.patterns.insert(pattern.pattern_id.clone(), pattern);
    }

    /// Returns the pattern stored under `id`, or `None` if there is none.
    #[must_use]
    pub fn get_pattern(&self, id: &str) -> Option<&DecisionPattern> {
        self.patterns.get(id)
    }

    /// Returns up to `top_k` patterns ranked by cosine similarity between
    /// `query_features` and each pattern's `feature_weights`, best first.
    ///
    /// Ordering rules:
    /// * higher similarity comes first;
    /// * a NaN similarity (e.g. from NaN weights) ranks below every real score;
    /// * equal scores are ordered by ascending `pattern_id`, so the result does
    ///   not depend on hash-map iteration order.
    ///
    /// Patterns are never filtered out: a pattern whose similarity is `0.0`
    /// (which is what `cosine_similarity` reports for mismatched lengths or
    /// zero vectors) is still eligible for the result.
    #[must_use]
    pub fn search(&self, query_features: &[f32], top_k: usize) -> Vec<&DecisionPattern> {
        if top_k == 0 {
            return Vec::new();
        }

        let mut scored: Vec<(f32, &DecisionPattern)> = self
            .patterns
            .values()
            .map(|pattern| {
                let sim = cosine_similarity(query_features, &pattern.feature_weights);
                // NaN has no place in an ordering; demote it below any real
                // score so the comparator below is a genuine total order.
                let rank = if sim.is_nan() { f32::NEG_INFINITY } else { sim };
                (rank, pattern)
            })
            .collect();

        // Descending by score, then ascending by id. Ids are unique map keys,
        // so no two entries compare equal and an unstable sort is safe.
        scored.sort_unstable_by(|a, b| {
            b.0.partial_cmp(&a.0)
                // Unreachable now that NaN has been mapped away; kept as a
                // non-panicking fallback.
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| a.1.pattern_id.cmp(&b.1.pattern_id))
        });
        scored.truncate(top_k);
        scored.into_iter().map(|(_, pattern)| pattern).collect()
    }

    /// Returns references to every stored pattern.
    ///
    /// The order follows the underlying `HashMap` iteration order and is
    /// therefore unspecified.
    #[must_use]
    pub fn list_patterns(&self) -> Vec<&DecisionPattern> {
        self.patterns.values().collect()
    }

    /// Removes and returns the pattern stored under `id`, or `None` if absent.
    pub fn remove_pattern(&mut self, id: &str) -> Option<DecisionPattern> {
        self.patterns.remove(id)
    }

    /// Number of stored patterns.
    #[must_use]
    pub fn len(&self) -> usize {
        self.patterns.len()
    }

    /// `true` when the store holds no patterns.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.patterns.is_empty()
    }
}
/// Cosine similarity of `a` and `b`, in `[-1.0, 1.0]`.
///
/// Returns `0.0` whenever the similarity is undefined:
/// * the slices have different lengths,
/// * either vector has zero magnitude (including two empty slices),
/// * the computation is not finite (NaN or infinite components).
///
/// Sums are accumulated in `f64` so that squaring very large or very small
/// `f32` components neither overflows to infinity nor underflows to zero.
// The final narrowing cast is applied to a value already clamped to [-1, 1].
#[allow(clippy::cast_possible_truncation)]
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    let (dot, norm_a, norm_b) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, norm_a, norm_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, norm_a + x * x, norm_b + y * y)
        },
    );

    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom == 0.0 {
        return 0.0;
    }

    let sim = dot / denom;
    if sim.is_finite() {
        // Rounding can push the ratio a hair outside the mathematical range.
        sim.clamp(-1.0, 1.0) as f32
    } else {
        0.0
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: a pattern with confidence `0.8` whose description is derived
    /// from `id`.
    fn make_pattern(id: &str, weights: Vec<f32>, category: &str) -> DecisionPattern {
        let description = format!("desc_{id}");
        DecisionPattern::new(id, description, weights, 0.8, category)
    }

    /// Fixture: a store holding `patterns`, inserted in the given order.
    fn store_with(patterns: Vec<DecisionPattern>) -> PatternStore {
        let mut store = PatternStore::new();
        for pattern in patterns {
            store.add_pattern(pattern);
        }
        store
    }

    /// Asserts that `actual` is within `1e-6` of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < 1e-6);
    }

    #[test]
    fn test_add_and_get_pattern() {
        let store = store_with(vec![make_pattern("p1", vec![1.0, 0.0], "cat_a")]);
        let found = store.get_pattern("p1").expect("operation should succeed");
        assert_eq!(found.pattern_id, "p1");
        assert_eq!(found.category, "cat_a");
    }

    #[test]
    fn test_get_nonexistent_pattern() {
        let empty = PatternStore::new();
        assert!(empty.get_pattern("missing").is_none());
    }

    #[test]
    fn test_add_replaces_existing() {
        let store = store_with(vec![
            make_pattern("p1", vec![1.0], "old"),
            make_pattern("p1", vec![2.0], "new"),
        ]);
        assert_eq!(store.len(), 1);
        let current = store.get_pattern("p1").expect("operation should succeed");
        assert_eq!(current.category, "new");
    }

    #[test]
    fn test_remove_pattern() {
        let mut store = store_with(vec![
            make_pattern("p1", vec![1.0], "a"),
            make_pattern("p2", vec![2.0], "b"),
        ]);
        let taken = store.remove_pattern("p1").expect("operation should succeed");
        assert_eq!(taken.pattern_id, "p1");
        assert_eq!(store.len(), 1);
        assert!(store.get_pattern("p1").is_none());
    }

    #[test]
    fn test_remove_nonexistent() {
        let mut empty = PatternStore::new();
        assert!(empty.remove_pattern("ghost").is_none());
    }

    #[test]
    fn test_list_patterns() {
        let store = store_with(vec![
            make_pattern("p1", vec![1.0], "a"),
            make_pattern("p2", vec![2.0], "b"),
        ]);
        // Listing order is unspecified, so compare the ids after sorting.
        let mut ids: Vec<&str> = store
            .list_patterns()
            .into_iter()
            .map(|p| p.pattern_id.as_str())
            .collect();
        ids.sort_unstable();
        assert_eq!(ids, ["p1", "p2"]);
    }

    #[test]
    fn test_len_and_is_empty() {
        let mut store = PatternStore::new();
        assert_eq!(store.len(), 0);
        assert!(store.is_empty());

        store.add_pattern(make_pattern("p1", vec![1.0], "a"));
        assert_eq!(store.len(), 1);
        assert!(!store.is_empty());
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let v = [1.0, 2.0, 3.0];
        assert_close(cosine_similarity(&v, &v), 1.0);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        assert_close(cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]), 0.0);
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        assert_close(cosine_similarity(&[1.0, 0.0], &[-1.0, 0.0]), -1.0);
    }

    #[test]
    fn test_cosine_similarity_different_lengths() {
        assert_eq!(cosine_similarity(&[1.0, 2.0], &[1.0]), 0.0);
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 2.0]), 0.0);
    }

    #[test]
    fn test_search_returns_top_k() {
        let store = store_with(vec![
            make_pattern("close", vec![0.9, 0.1, 0.0], "a"),
            make_pattern("exact", vec![1.0, 0.0, 0.0], "b"),
            make_pattern("far", vec![0.0, 0.0, 1.0], "c"),
        ]);
        let ranked: Vec<&str> = store
            .search(&[1.0, 0.0, 0.0], 2)
            .into_iter()
            .map(|p| p.pattern_id.as_str())
            .collect();
        assert_eq!(ranked, ["exact", "close"]);
    }

    #[test]
    fn test_search_top_k_larger_than_store() {
        let store = store_with(vec![make_pattern("p1", vec![1.0, 0.0], "a")]);
        assert_eq!(store.search(&[1.0, 0.0], 10).len(), 1);
    }

    #[test]
    fn test_search_empty_store() {
        let empty = PatternStore::new();
        assert!(empty.search(&[1.0, 0.0], 5).is_empty());
    }

    #[test]
    fn test_search_with_mismatched_dimensions() {
        // A length mismatch lowers the score but does not exclude the pattern.
        let store = store_with(vec![make_pattern("p1", vec![1.0, 0.0, 0.0], "a")]);
        assert_eq!(store.search(&[1.0, 0.0], 5).len(), 1);
    }

    #[test]
    fn test_confidence_clamped() {
        let above = DecisionPattern::new("id", "desc", vec![], 1.5, "cat");
        let below = DecisionPattern::new("id2", "desc", vec![], -0.5, "cat");
        assert_eq!(above.confidence, 1.0);
        assert_eq!(below.confidence, 0.0);
    }

    #[test]
    fn test_default_store() {
        assert!(PatternStore::default().is_empty());
    }
}