use super::{ChunkId, FixPattern, FixSuggestion, PatternStoreConfig, PatternStoreData};
use std::collections::HashMap;
use std::path::Path;
use trueno_rag::{
chunk::FixedSizeChunker, embed::MockEmbedder, fusion::FusionStrategy,
pipeline::RagPipelineBuilder, rerank::NoOpReranker, Document, RagPipeline,
};
/// Searchable store of [`FixPattern`]s.
///
/// Every pattern is indexed as a document into a `trueno_rag` pipeline
/// (mock embedder, no-op reranker) for retrieval. It is also kept in an
/// id-keyed map for direct access and grouped by error code for fast lookup.
pub struct DecisionPatternStore {
    /// Retrieval pipeline that indexed patterns are added to as documents.
    pipeline: RagPipeline<MockEmbedder, NoOpReranker>,
    /// All stored patterns, keyed by their id.
    patterns: HashMap<ChunkId, FixPattern>,
    /// Error code -> ids of the patterns filed under that code.
    error_index: HashMap<String, Vec<ChunkId>>,
    /// Configuration the store (and its pipeline) was created with.
    config: PatternStoreConfig,
}
impl DecisionPatternStore {
pub fn new() -> Result<Self, crate::Error> {
Self::with_config(PatternStoreConfig::default())
}
pub fn with_config(config: PatternStoreConfig) -> Result<Self, crate::Error> {
let pipeline = RagPipelineBuilder::new()
.chunker(FixedSizeChunker::new(config.chunk_size, config.chunk_size / 8))
.embedder(MockEmbedder::new(config.embedding_dim))
.reranker(NoOpReranker::new())
.fusion(FusionStrategy::RRF { k: config.rrf_k })
.build()
.map_err(|e| crate::Error::ConfigError(format!("RAG pipeline error: {e}")))?;
Ok(Self { pipeline, patterns: HashMap::new(), error_index: HashMap::new(), config })
}
pub fn index_fix(&mut self, pattern: FixPattern) -> Result<(), crate::Error> {
let chunk_id = pattern.id;
let error_code = pattern.error_code.clone();
let doc = Document::new(pattern.to_searchable_text())
.with_title(format!("Fix for {}", pattern.error_code));
self.pipeline
.index_document(&doc)
.map_err(|e| crate::Error::ConfigError(format!("Indexing error: {e}")))?;
self.error_index.entry(error_code).or_default().push(chunk_id);
self.patterns.insert(chunk_id, pattern);
Ok(())
}
pub fn suggest_fix(
&self,
error_code: &str,
decision_context: &[String],
k: usize,
) -> Result<Vec<FixSuggestion>, crate::Error> {
let context_str = decision_context.join(" ");
let query = format!("{error_code} {context_str}");
let results = self
.pipeline
.query(&query, k * 2) .map_err(|e| crate::Error::ConfigError(format!("Query error: {e}")))?;
let relevant_patterns: Vec<_> = if let Some(pattern_ids) = self.error_index.get(error_code)
{
pattern_ids.iter().filter_map(|id| self.patterns.get(id)).collect()
} else {
self.patterns.values().collect()
};
let mut suggestions: Vec<FixSuggestion> = Vec::new();
for (rank, result) in results.iter().enumerate() {
for pattern in &relevant_patterns {
let pattern_text = pattern.to_searchable_text();
if result.chunk.content.contains(&pattern.error_code)
|| pattern_text.contains(&result.chunk.content)
{
suggestions.push(FixSuggestion::new(
(*pattern).clone(),
result.best_score(),
rank,
));
break;
}
}
}
if suggestions.is_empty() && !relevant_patterns.is_empty() {
for (rank, pattern) in relevant_patterns.iter().take(k).enumerate() {
suggestions.push(FixSuggestion::new(
(*pattern).clone(),
1.0 - (rank as f32 * 0.1),
rank,
));
}
}
suggestions.sort_by(|a, b| {
b.weighted_score().partial_cmp(&a.weighted_score()).unwrap_or(std::cmp::Ordering::Equal)
});
suggestions.truncate(k);
for (rank, suggestion) in suggestions.iter_mut().enumerate() {
suggestion.rank = rank;
}
Ok(suggestions)
}
#[must_use]
pub fn len(&self) -> usize {
self.patterns.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.patterns.is_empty()
}
#[must_use]
pub fn get(&self, id: &ChunkId) -> Option<&FixPattern> {
self.patterns.get(id)
}
pub fn get_mut(&mut self, id: &ChunkId) -> Option<&mut FixPattern> {
self.patterns.get_mut(id)
}
pub fn record_outcome(&mut self, id: &ChunkId, success: bool) {
if let Some(pattern) = self.patterns.get_mut(id) {
if success {
pattern.record_success();
} else {
pattern.record_failure();
}
}
}
#[must_use]
pub fn patterns_for_error(&self, error_code: &str) -> Vec<&FixPattern> {
self.error_index
.get(error_code)
.map(|ids| ids.iter().filter_map(|id| self.patterns.get(id)).collect())
.unwrap_or_default()
}
#[must_use]
pub fn config(&self) -> &PatternStoreConfig {
&self.config
}
pub fn export_json(&self) -> Result<String, crate::Error> {
let patterns: Vec<_> = self.patterns.values().collect();
serde_json::to_string_pretty(&patterns)
.map_err(|e| crate::Error::Serialization(format!("JSON export error: {e}")))
}
pub fn import_json(&mut self, json: &str) -> Result<usize, crate::Error> {
let patterns: Vec<FixPattern> = serde_json::from_str(json)
.map_err(|e| crate::Error::Serialization(format!("JSON import error: {e}")))?;
let count = patterns.len();
for pattern in patterns {
self.index_fix(pattern)?;
}
Ok(count)
}
pub fn save_apr(&self, path: impl AsRef<Path>) -> Result<(), crate::Error> {
use aprender::format::{save, Compression, ModelType, SaveOptions};
let patterns: Vec<FixPattern> = self.patterns.values().cloned().collect();
let wrapper = PatternStoreData { version: 1, config: self.config.clone(), patterns };
save(
&wrapper,
ModelType::Custom,
path,
SaveOptions::default().with_compression(Compression::ZstdDefault),
)
.map_err(|e| crate::Error::Serialization(format!("APR save error: {e}")))
}
pub fn load_apr(path: impl AsRef<Path>) -> Result<Self, crate::Error> {
use aprender::format::{load, ModelType};
let wrapper: PatternStoreData = load(path, ModelType::Custom)
.map_err(|e| crate::Error::Serialization(format!("APR load error: {e}")))?;
let mut store = Self::with_config(wrapper.config)?;
for pattern in wrapper.patterns {
store.index_fix(pattern)?;
}
Ok(store)
}
}
/// Compact debug summary.
///
/// Prints the pattern count, the indexed error codes and the configuration,
/// without dumping the pipeline or the patterns themselves.
impl std::fmt::Debug for DecisionPatternStore {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Sort so the output does not depend on HashMap iteration order.
        let mut error_codes: Vec<&String> = self.error_index.keys().collect();
        error_codes.sort_unstable();
        f.debug_struct("DecisionPatternStore")
            .field("pattern_count", &self.patterns.len())
            .field("error_codes", &error_codes)
            .field("config", &self.config)
            .finish_non_exhaustive()
    }
}