entrenar/citl/pattern_store/
store.rs1use super::{ChunkId, FixPattern, FixSuggestion, PatternStoreConfig, PatternStoreData};
7use std::collections::HashMap;
8use std::path::Path;
9use trueno_rag::{
10 chunk::FixedSizeChunker, embed::MockEmbedder, fusion::FusionStrategy,
11 pipeline::RagPipelineBuilder, rerank::NoOpReranker, Document, RagPipeline,
12};
13
14pub struct DecisionPatternStore {
35 pipeline: RagPipeline<MockEmbedder, NoOpReranker>,
37 patterns: HashMap<ChunkId, FixPattern>,
39 error_index: HashMap<String, Vec<ChunkId>>,
41 config: PatternStoreConfig,
43}
44
45impl DecisionPatternStore {
46 pub fn new() -> Result<Self, crate::Error> {
48 Self::with_config(PatternStoreConfig::default())
49 }
50
51 pub fn with_config(config: PatternStoreConfig) -> Result<Self, crate::Error> {
53 let pipeline = RagPipelineBuilder::new()
54 .chunker(FixedSizeChunker::new(config.chunk_size, config.chunk_size / 8))
55 .embedder(MockEmbedder::new(config.embedding_dim))
56 .reranker(NoOpReranker::new())
57 .fusion(FusionStrategy::RRF { k: config.rrf_k })
58 .build()
59 .map_err(|e| crate::Error::ConfigError(format!("RAG pipeline error: {e}")))?;
60
61 Ok(Self { pipeline, patterns: HashMap::new(), error_index: HashMap::new(), config })
62 }
63
64 pub fn index_fix(&mut self, pattern: FixPattern) -> Result<(), crate::Error> {
66 let chunk_id = pattern.id;
67 let error_code = pattern.error_code.clone();
68
69 let doc = Document::new(pattern.to_searchable_text())
71 .with_title(format!("Fix for {}", pattern.error_code));
72
73 self.pipeline
75 .index_document(&doc)
76 .map_err(|e| crate::Error::ConfigError(format!("Indexing error: {e}")))?;
77
78 self.error_index.entry(error_code).or_default().push(chunk_id);
80
81 self.patterns.insert(chunk_id, pattern);
83
84 Ok(())
85 }
86
87 pub fn suggest_fix(
99 &self,
100 error_code: &str,
101 decision_context: &[String],
102 k: usize,
103 ) -> Result<Vec<FixSuggestion>, crate::Error> {
104 let context_str = decision_context.join(" ");
106 let query = format!("{error_code} {context_str}");
107
108 let results = self
110 .pipeline
111 .query(&query, k * 2) .map_err(|e| crate::Error::ConfigError(format!("Query error: {e}")))?;
113
114 let relevant_patterns: Vec<_> = if let Some(pattern_ids) = self.error_index.get(error_code)
116 {
117 pattern_ids.iter().filter_map(|id| self.patterns.get(id)).collect()
118 } else {
119 self.patterns.values().collect()
121 };
122
123 let mut suggestions: Vec<FixSuggestion> = Vec::new();
125
126 for (rank, result) in results.iter().enumerate() {
127 for pattern in &relevant_patterns {
129 let pattern_text = pattern.to_searchable_text();
130 if result.chunk.content.contains(&pattern.error_code)
131 || pattern_text.contains(&result.chunk.content)
132 {
133 suggestions.push(FixSuggestion::new(
134 (*pattern).clone(),
135 result.best_score(),
136 rank,
137 ));
138 break;
139 }
140 }
141 }
142
143 if suggestions.is_empty() && !relevant_patterns.is_empty() {
145 for (rank, pattern) in relevant_patterns.iter().take(k).enumerate() {
146 suggestions.push(FixSuggestion::new(
147 (*pattern).clone(),
148 1.0 - (rank as f32 * 0.1),
149 rank,
150 ));
151 }
152 }
153
154 suggestions.sort_by(|a, b| {
156 b.weighted_score().partial_cmp(&a.weighted_score()).unwrap_or(std::cmp::Ordering::Equal)
157 });
158 suggestions.truncate(k);
159
160 for (rank, suggestion) in suggestions.iter_mut().enumerate() {
162 suggestion.rank = rank;
163 }
164
165 Ok(suggestions)
166 }
167
168 #[must_use]
170 pub fn len(&self) -> usize {
171 self.patterns.len()
172 }
173
174 #[must_use]
176 pub fn is_empty(&self) -> bool {
177 self.patterns.is_empty()
178 }
179
180 #[must_use]
182 pub fn get(&self, id: &ChunkId) -> Option<&FixPattern> {
183 self.patterns.get(id)
184 }
185
186 pub fn get_mut(&mut self, id: &ChunkId) -> Option<&mut FixPattern> {
188 self.patterns.get_mut(id)
189 }
190
191 pub fn record_outcome(&mut self, id: &ChunkId, success: bool) {
193 if let Some(pattern) = self.patterns.get_mut(id) {
194 if success {
195 pattern.record_success();
196 } else {
197 pattern.record_failure();
198 }
199 }
200 }
201
202 #[must_use]
204 pub fn patterns_for_error(&self, error_code: &str) -> Vec<&FixPattern> {
205 self.error_index
206 .get(error_code)
207 .map(|ids| ids.iter().filter_map(|id| self.patterns.get(id)).collect())
208 .unwrap_or_default()
209 }
210
211 #[must_use]
213 pub fn config(&self) -> &PatternStoreConfig {
214 &self.config
215 }
216
217 pub fn export_json(&self) -> Result<String, crate::Error> {
219 let patterns: Vec<_> = self.patterns.values().collect();
220 serde_json::to_string_pretty(&patterns)
221 .map_err(|e| crate::Error::Serialization(format!("JSON export error: {e}")))
222 }
223
224 pub fn import_json(&mut self, json: &str) -> Result<usize, crate::Error> {
226 let patterns: Vec<FixPattern> = serde_json::from_str(json)
227 .map_err(|e| crate::Error::Serialization(format!("JSON import error: {e}")))?;
228
229 let count = patterns.len();
230 for pattern in patterns {
231 self.index_fix(pattern)?;
232 }
233
234 Ok(count)
235 }
236
237 pub fn save_apr(&self, path: impl AsRef<Path>) -> Result<(), crate::Error> {
255 use aprender::format::{save, Compression, ModelType, SaveOptions};
256
257 let patterns: Vec<FixPattern> = self.patterns.values().cloned().collect();
259 let wrapper = PatternStoreData { version: 1, config: self.config.clone(), patterns };
260
261 save(
262 &wrapper,
263 ModelType::Custom,
264 path,
265 SaveOptions::default().with_compression(Compression::ZstdDefault),
266 )
267 .map_err(|e| crate::Error::Serialization(format!("APR save error: {e}")))
268 }
269
270 pub fn load_apr(path: impl AsRef<Path>) -> Result<Self, crate::Error> {
283 use aprender::format::{load, ModelType};
284
285 let wrapper: PatternStoreData = load(path, ModelType::Custom)
286 .map_err(|e| crate::Error::Serialization(format!("APR load error: {e}")))?;
287
288 let mut store = Self::with_config(wrapper.config)?;
290
291 for pattern in wrapper.patterns {
293 store.index_fix(pattern)?;
294 }
295
296 Ok(store)
297 }
298}
299
300impl std::fmt::Debug for DecisionPatternStore {
301 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
302 f.debug_struct("DecisionPatternStore")
303 .field("pattern_count", &self.patterns.len())
304 .field("error_codes", &self.error_index.keys().collect::<Vec<_>>())
305 .field("config", &self.config)
306 .finish_non_exhaustive()
307 }
308}