Skip to main content

entrenar/citl/pattern_store/
store.rs

1//! Decision pattern store implementation with hybrid retrieval.
2//!
3//! Uses trueno-rag for BM25 lexical search combined with dense embeddings
4//! and Reciprocal Rank Fusion (RRF) for optimal fix suggestions.
5
6use super::{ChunkId, FixPattern, FixSuggestion, PatternStoreConfig, PatternStoreData};
7use std::collections::HashMap;
8use std::path::Path;
9use trueno_rag::{
10    chunk::FixedSizeChunker, embed::MockEmbedder, fusion::FusionStrategy,
11    pipeline::RagPipelineBuilder, rerank::NoOpReranker, Document, RagPipeline,
12};
13
14/// Store for decision patterns with hybrid retrieval
15///
16/// Uses trueno-rag for BM25 + dense embedding retrieval with RRF fusion.
17///
18/// # Example
19///
20/// ```ignore
21/// use entrenar::citl::{DecisionPatternStore, FixPattern};
22///
23/// let mut store = DecisionPatternStore::new()?;
24///
25/// // Index a fix pattern
26/// let pattern = FixPattern::new("E0308", "- let x: i32 = \"hello\";\n+ let x: &str = \"hello\";")
27///     .with_decision("type_mismatch_detected")
28///     .with_decision("infer_correct_type");
29/// store.index_fix(pattern)?;
30///
31/// // Get fix suggestions
32/// let suggestions = store.suggest_fix("E0308", &["type_mismatch"], 5)?;
33/// ```
34pub struct DecisionPatternStore {
35    /// RAG pipeline for hybrid retrieval
36    pipeline: RagPipeline<MockEmbedder, NoOpReranker>,
37    /// Pattern storage indexed by chunk ID
38    patterns: HashMap<ChunkId, FixPattern>,
39    /// Error code index for fast filtering
40    error_index: HashMap<String, Vec<ChunkId>>,
41    /// Configuration
42    config: PatternStoreConfig,
43}
44
45impl DecisionPatternStore {
46    /// Create a new pattern store with default configuration
47    pub fn new() -> Result<Self, crate::Error> {
48        Self::with_config(PatternStoreConfig::default())
49    }
50
51    /// Create a new pattern store with custom configuration
52    pub fn with_config(config: PatternStoreConfig) -> Result<Self, crate::Error> {
53        let pipeline = RagPipelineBuilder::new()
54            .chunker(FixedSizeChunker::new(config.chunk_size, config.chunk_size / 8))
55            .embedder(MockEmbedder::new(config.embedding_dim))
56            .reranker(NoOpReranker::new())
57            .fusion(FusionStrategy::RRF { k: config.rrf_k })
58            .build()
59            .map_err(|e| crate::Error::ConfigError(format!("RAG pipeline error: {e}")))?;
60
61        Ok(Self { pipeline, patterns: HashMap::new(), error_index: HashMap::new(), config })
62    }
63
64    /// Index a fix pattern for later retrieval
65    pub fn index_fix(&mut self, pattern: FixPattern) -> Result<(), crate::Error> {
66        let chunk_id = pattern.id;
67        let error_code = pattern.error_code.clone();
68
69        // Create searchable document
70        let doc = Document::new(pattern.to_searchable_text())
71            .with_title(format!("Fix for {}", pattern.error_code));
72
73        // Index in RAG pipeline
74        self.pipeline
75            .index_document(&doc)
76            .map_err(|e| crate::Error::ConfigError(format!("Indexing error: {e}")))?;
77
78        // Update error index
79        self.error_index.entry(error_code).or_default().push(chunk_id);
80
81        // Store pattern
82        self.patterns.insert(chunk_id, pattern);
83
84        Ok(())
85    }
86
87    /// Suggest fixes for a given error code and decision context
88    ///
89    /// # Arguments
90    ///
91    /// * `error_code` - The error code to find fixes for
92    /// * `decision_context` - Recent decisions that led to the error
93    /// * `k` - Maximum number of suggestions to return
94    ///
95    /// # Returns
96    ///
97    /// Vector of fix suggestions ranked by relevance
98    pub fn suggest_fix(
99        &self,
100        error_code: &str,
101        decision_context: &[String],
102        k: usize,
103    ) -> Result<Vec<FixSuggestion>, crate::Error> {
104        // Build query from error code and decision context
105        let context_str = decision_context.join(" ");
106        let query = format!("{error_code} {context_str}");
107
108        // Retrieve from RAG pipeline
109        let results = self
110            .pipeline
111            .query(&query, k * 2) // Over-fetch for filtering
112            .map_err(|e| crate::Error::ConfigError(format!("Query error: {e}")))?;
113
114        // Filter by error code if we have patterns for it
115        let relevant_patterns: Vec<_> = if let Some(pattern_ids) = self.error_index.get(error_code)
116        {
117            pattern_ids.iter().filter_map(|id| self.patterns.get(id)).collect()
118        } else {
119            // Return any patterns if no exact error code match
120            self.patterns.values().collect()
121        };
122
123        // Match RAG results with our patterns (by content similarity)
124        let mut suggestions: Vec<FixSuggestion> = Vec::new();
125
126        for (rank, result) in results.iter().enumerate() {
127            // Find matching pattern by comparing content
128            for pattern in &relevant_patterns {
129                let pattern_text = pattern.to_searchable_text();
130                if result.chunk.content.contains(&pattern.error_code)
131                    || pattern_text.contains(&result.chunk.content)
132                {
133                    suggestions.push(FixSuggestion::new(
134                        (*pattern).clone(),
135                        result.best_score(),
136                        rank,
137                    ));
138                    break;
139                }
140            }
141        }
142
143        // If no RAG matches, fall back to error index
144        if suggestions.is_empty() && !relevant_patterns.is_empty() {
145            for (rank, pattern) in relevant_patterns.iter().take(k).enumerate() {
146                suggestions.push(FixSuggestion::new(
147                    (*pattern).clone(),
148                    1.0 - (rank as f32 * 0.1),
149                    rank,
150                ));
151            }
152        }
153
154        // Sort by weighted score and limit
155        suggestions.sort_by(|a, b| {
156            b.weighted_score().partial_cmp(&a.weighted_score()).unwrap_or(std::cmp::Ordering::Equal)
157        });
158        suggestions.truncate(k);
159
160        // Re-assign ranks after sorting
161        for (rank, suggestion) in suggestions.iter_mut().enumerate() {
162            suggestion.rank = rank;
163        }
164
165        Ok(suggestions)
166    }
167
168    /// Get the number of indexed patterns
169    #[must_use]
170    pub fn len(&self) -> usize {
171        self.patterns.len()
172    }
173
174    /// Check if the store is empty
175    #[must_use]
176    pub fn is_empty(&self) -> bool {
177        self.patterns.is_empty()
178    }
179
180    /// Get a pattern by ID
181    #[must_use]
182    pub fn get(&self, id: &ChunkId) -> Option<&FixPattern> {
183        self.patterns.get(id)
184    }
185
186    /// Get a mutable pattern by ID
187    pub fn get_mut(&mut self, id: &ChunkId) -> Option<&mut FixPattern> {
188        self.patterns.get_mut(id)
189    }
190
191    /// Update a pattern's success/failure count
192    pub fn record_outcome(&mut self, id: &ChunkId, success: bool) {
193        if let Some(pattern) = self.patterns.get_mut(id) {
194            if success {
195                pattern.record_success();
196            } else {
197                pattern.record_failure();
198            }
199        }
200    }
201
202    /// Get all patterns for an error code
203    #[must_use]
204    pub fn patterns_for_error(&self, error_code: &str) -> Vec<&FixPattern> {
205        self.error_index
206            .get(error_code)
207            .map(|ids| ids.iter().filter_map(|id| self.patterns.get(id)).collect())
208            .unwrap_or_default()
209    }
210
211    /// Get the configuration
212    #[must_use]
213    pub fn config(&self) -> &PatternStoreConfig {
214        &self.config
215    }
216
217    /// Export all patterns to JSON
218    pub fn export_json(&self) -> Result<String, crate::Error> {
219        let patterns: Vec<_> = self.patterns.values().collect();
220        serde_json::to_string_pretty(&patterns)
221            .map_err(|e| crate::Error::Serialization(format!("JSON export error: {e}")))
222    }
223
224    /// Import patterns from JSON
225    pub fn import_json(&mut self, json: &str) -> Result<usize, crate::Error> {
226        let patterns: Vec<FixPattern> = serde_json::from_str(json)
227            .map_err(|e| crate::Error::Serialization(format!("JSON import error: {e}")))?;
228
229        let count = patterns.len();
230        for pattern in patterns {
231            self.index_fix(pattern)?;
232        }
233
234        Ok(count)
235    }
236
237    /// Save patterns to .apr format (aprender model format)
238    ///
239    /// Uses `ModelType::Custom` with compressed MessagePack serialization.
240    /// The .apr format provides:
241    /// - CRC32 checksum (integrity)
242    /// - Optional zstd compression
243    /// - Compatible with aprender ecosystem
244    ///
245    /// # Example
246    ///
247    /// ```ignore
248    /// use entrenar::citl::DecisionPatternStore;
249    ///
250    /// let store = DecisionPatternStore::new()?;
251    /// // ... index patterns ...
252    /// store.save_apr("decision_patterns.apr")?;
253    /// ```
254    pub fn save_apr(&self, path: impl AsRef<Path>) -> Result<(), crate::Error> {
255        use aprender::format::{save, Compression, ModelType, SaveOptions};
256
257        // Collect patterns into serializable wrapper
258        let patterns: Vec<FixPattern> = self.patterns.values().cloned().collect();
259        let wrapper = PatternStoreData { version: 1, config: self.config.clone(), patterns };
260
261        save(
262            &wrapper,
263            ModelType::Custom,
264            path,
265            SaveOptions::default().with_compression(Compression::ZstdDefault),
266        )
267        .map_err(|e| crate::Error::Serialization(format!("APR save error: {e}")))
268    }
269
270    /// Load patterns from .apr format
271    ///
272    /// Restores patterns and rebuilds the RAG index.
273    ///
274    /// # Example
275    ///
276    /// ```ignore
277    /// use entrenar::citl::DecisionPatternStore;
278    ///
279    /// let store = DecisionPatternStore::load_apr("decision_patterns.apr")?;
280    /// let suggestions = store.suggest_fix("E0308", &["type_mismatch".into()], 5)?;
281    /// ```
282    pub fn load_apr(path: impl AsRef<Path>) -> Result<Self, crate::Error> {
283        use aprender::format::{load, ModelType};
284
285        let wrapper: PatternStoreData = load(path, ModelType::Custom)
286            .map_err(|e| crate::Error::Serialization(format!("APR load error: {e}")))?;
287
288        // Rebuild store with loaded config
289        let mut store = Self::with_config(wrapper.config)?;
290
291        // Re-index all patterns
292        for pattern in wrapper.patterns {
293            store.index_fix(pattern)?;
294        }
295
296        Ok(store)
297    }
298}
299
300impl std::fmt::Debug for DecisionPatternStore {
301    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
302        f.debug_struct("DecisionPatternStore")
303            .field("pattern_count", &self.patterns.len())
304            .field("error_codes", &self.error_index.keys().collect::<Vec<_>>())
305            .field("config", &self.config)
306            .finish_non_exhaustive()
307    }
308}