codex_memory/memory/
importance_assessment.rs

1use crate::embedding::EmbeddingService;
2use crate::memory::MemoryError;
3use anyhow::Result;
4use chrono::{DateTime, Utc};
5use prometheus::{Counter, Histogram, IntCounter, IntGauge, Registry};
6use regex::Regex;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11use thiserror::Error;
12use tokio::sync::RwLock;
13use tokio::time::timeout;
14use tracing::{debug, error, info, warn};
15
16#[derive(Debug, Error)]
17pub enum ImportanceAssessmentError {
18    #[error("Stage 1 pattern matching failed: {0}")]
19    Stage1Failed(String),
20
21    #[error("Stage 2 semantic analysis failed: {0}")]
22    Stage2Failed(String),
23
24    #[error("Stage 3 LLM scoring failed: {0}")]
25    Stage3Failed(String),
26
27    #[error("Circuit breaker is open: {0}")]
28    CircuitBreakerOpen(String),
29
30    #[error("Timeout exceeded: {0}")]
31    Timeout(String),
32
33    #[error("Configuration error: {0}")]
34    Configuration(String),
35
36    #[error("Memory operation failed: {0}")]
37    Memory(#[from] MemoryError),
38
39    #[error("Cache operation failed: {0}")]
40    Cache(String),
41}
42
43/// Configuration for the importance assessment pipeline
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct ImportanceAssessmentConfig {
46    /// Stage 1: Pattern matching configuration
47    pub stage1: Stage1Config,
48
49    /// Stage 2: Semantic similarity configuration
50    pub stage2: Stage2Config,
51
52    /// Stage 3: LLM scoring configuration
53    pub stage3: Stage3Config,
54
55    /// Circuit breaker configuration
56    pub circuit_breaker: CircuitBreakerConfig,
57
58    /// Performance thresholds
59    pub performance: PerformanceConfig,
60}
61
62#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct Stage1Config {
64    /// Confidence threshold to pass to Stage 2 (0.0-1.0)
65    pub confidence_threshold: f64,
66
67    /// Pattern library for keyword/phrase matching
68    pub pattern_library: Vec<ImportancePattern>,
69
70    /// Maximum processing time in milliseconds
71    pub max_processing_time_ms: u64,
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct Stage2Config {
76    /// Confidence threshold to pass to Stage 3 (0.0-1.0)
77    pub confidence_threshold: f64,
78
79    /// Maximum processing time in milliseconds
80    pub max_processing_time_ms: u64,
81
82    /// Cache TTL for embeddings in seconds
83    pub embedding_cache_ttl_seconds: u64,
84
85    /// Similarity threshold for semantic matching
86    pub similarity_threshold: f32,
87
88    /// Reference embeddings for importance patterns
89    pub reference_embeddings: Vec<ReferenceEmbedding>,
90}
91
92#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct Stage3Config {
94    /// Maximum processing time in milliseconds
95    pub max_processing_time_ms: u64,
96
97    /// LLM endpoint configuration
98    pub llm_endpoint: String,
99
100    /// Maximum concurrent LLM requests
101    pub max_concurrent_requests: usize,
102
103    /// Prompt template for LLM scoring
104    pub prompt_template: String,
105
106    /// Target percentage of evaluations that should reach Stage 3
107    pub target_usage_percentage: f64,
108}
109
110#[derive(Debug, Clone, Serialize, Deserialize)]
111pub struct CircuitBreakerConfig {
112    /// Failure threshold before opening the circuit
113    pub failure_threshold: usize,
114
115    /// Time window for failure counting in seconds
116    pub failure_window_seconds: u64,
117
118    /// Recovery timeout in seconds
119    pub recovery_timeout_seconds: u64,
120
121    /// Minimum requests before evaluating failures
122    pub minimum_requests: usize,
123}
124
125#[derive(Debug, Clone, Serialize, Deserialize)]
126pub struct PerformanceConfig {
127    /// Stage 1 target time in milliseconds
128    pub stage1_target_ms: u64,
129
130    /// Stage 2 target time in milliseconds
131    pub stage2_target_ms: u64,
132
133    /// Stage 3 target time in milliseconds
134    pub stage3_target_ms: u64,
135}
136
137#[derive(Debug, Clone, Serialize, Deserialize)]
138pub struct ImportancePattern {
139    /// Pattern name for metrics and debugging
140    pub name: String,
141
142    /// Regular expression pattern
143    pub pattern: String,
144
145    /// Weight/importance score for this pattern (0.0-1.0)
146    pub weight: f64,
147
148    /// Context words that boost the pattern's importance
149    pub context_boosters: Vec<String>,
150
151    /// Category of the pattern (e.g., "command", "preference", "memory")
152    pub category: String,
153}
154
155#[derive(Debug, Clone, Serialize, Deserialize)]
156pub struct ReferenceEmbedding {
157    /// Name of the reference pattern
158    pub name: String,
159
160    /// Pre-computed embedding vector
161    pub embedding: Vec<f32>,
162
163    /// Importance weight for this reference
164    pub weight: f64,
165
166    /// Category of the reference
167    pub category: String,
168}
169
170/// Result of the importance assessment pipeline
171#[derive(Debug, Clone, Serialize, Deserialize)]
172pub struct ImportanceAssessmentResult {
173    /// Final importance score (0.0-1.0)
174    pub importance_score: f64,
175
176    /// Which stage provided the final score
177    pub final_stage: AssessmentStage,
178
179    /// Results from each stage
180    pub stage_results: Vec<StageResult>,
181
182    /// Total processing time in milliseconds
183    pub total_processing_time_ms: u64,
184
185    /// Assessment timestamp
186    pub assessed_at: DateTime<Utc>,
187
188    /// Confidence in the assessment (0.0-1.0)
189    pub confidence: f64,
190
191    /// Explanation of the assessment
192    pub explanation: Option<String>,
193}
194
195#[derive(Debug, Clone, Serialize, Deserialize)]
196pub struct StageResult {
197    /// Which stage this result is from
198    pub stage: AssessmentStage,
199
200    /// Score from this stage (0.0-1.0)
201    pub score: f64,
202
203    /// Confidence in this stage's result (0.0-1.0)
204    pub confidence: f64,
205
206    /// Processing time for this stage in milliseconds
207    pub processing_time_ms: u64,
208
209    /// Whether this stage passed its confidence threshold
210    pub passed_threshold: bool,
211
212    /// Stage-specific details
213    pub details: StageDetails,
214}
215
216#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
217pub enum AssessmentStage {
218    Stage1PatternMatching,
219    Stage2SemanticSimilarity,
220    Stage3LLMScoring,
221}
222
223#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
224pub enum StageDetails {
225    Stage1 {
226        matched_patterns: Vec<MatchedPattern>,
227        total_patterns_checked: usize,
228    },
229    Stage2 {
230        similarity_scores: Vec<SimilarityScore>,
231        cache_hit: bool,
232        embedding_generation_time_ms: Option<u64>,
233    },
234    Stage3 {
235        llm_response: String,
236        prompt_tokens: Option<usize>,
237        completion_tokens: Option<usize>,
238        model_used: String,
239    },
240}
241
242#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
243pub struct MatchedPattern {
244    pub pattern_name: String,
245    pub pattern_category: String,
246    pub match_text: String,
247    pub match_position: usize,
248    pub weight: f64,
249    pub context_boost: f64,
250}
251
252#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
253pub struct SimilarityScore {
254    pub reference_name: String,
255    pub reference_category: String,
256    pub similarity: f32,
257    pub weight: f64,
258    pub weighted_score: f64,
259}
260
261/// Circuit breaker states
262#[derive(Debug, Clone, PartialEq)]
263enum CircuitBreakerState {
264    Closed,
265    Open(DateTime<Utc>),
266    HalfOpen,
267}
268
269/// Circuit breaker for LLM calls
270#[derive(Debug)]
271struct CircuitBreaker {
272    state: RwLock<CircuitBreakerState>,
273    config: CircuitBreakerConfig,
274    failure_count: RwLock<usize>,
275    last_failure_time: RwLock<Option<DateTime<Utc>>>,
276    request_count: RwLock<usize>,
277}
278
279impl CircuitBreaker {
280    fn new(config: CircuitBreakerConfig) -> Self {
281        Self {
282            state: RwLock::new(CircuitBreakerState::Closed),
283            config,
284            failure_count: RwLock::new(0),
285            last_failure_time: RwLock::new(None),
286            request_count: RwLock::new(0),
287        }
288    }
289
290    async fn can_execute(&self) -> Result<bool, ImportanceAssessmentError> {
291        let state = self.state.read().await;
292        match *state {
293            CircuitBreakerState::Closed => Ok(true),
294            CircuitBreakerState::Open(opened_at) => {
295                let now = Utc::now();
296                let recovery_time = opened_at
297                    + chrono::Duration::seconds(self.config.recovery_timeout_seconds as i64);
298
299                if now >= recovery_time {
300                    drop(state);
301                    let mut state = self.state.write().await;
302                    *state = CircuitBreakerState::HalfOpen;
303                    Ok(true)
304                } else {
305                    Err(ImportanceAssessmentError::CircuitBreakerOpen(format!(
306                        "Circuit breaker is open until {}",
307                        recovery_time
308                    )))
309                }
310            }
311            CircuitBreakerState::HalfOpen => Ok(true),
312        }
313    }
314
315    async fn record_success(&self) {
316        let mut state = self.state.write().await;
317        *state = CircuitBreakerState::Closed;
318
319        let mut failure_count = self.failure_count.write().await;
320        *failure_count = 0;
321
322        let mut last_failure_time = self.last_failure_time.write().await;
323        *last_failure_time = None;
324    }
325
326    async fn record_failure(&self) {
327        let now = Utc::now();
328
329        {
330            let mut request_count = self.request_count.write().await;
331            *request_count += 1;
332        }
333
334        {
335            let mut failure_count = self.failure_count.write().await;
336            let mut last_failure_time = self.last_failure_time.write().await;
337
338            // Reset failure count if outside the failure window
339            if let Some(last_failure) = *last_failure_time {
340                let window_start =
341                    now - chrono::Duration::seconds(self.config.failure_window_seconds as i64);
342                if last_failure < window_start {
343                    *failure_count = 0;
344                }
345            }
346
347            *failure_count += 1;
348            *last_failure_time = Some(now);
349        }
350
351        // Check if we should open the circuit
352        let failure_count = *self.failure_count.read().await;
353        let request_count = *self.request_count.read().await;
354
355        if request_count >= self.config.minimum_requests
356            && failure_count >= self.config.failure_threshold
357        {
358            let mut state = self.state.write().await;
359            *state = CircuitBreakerState::Open(now);
360            warn!(
361                "Circuit breaker opened due to {} failures out of {} requests",
362                failure_count, request_count
363            );
364        }
365    }
366}
367
368/// Cached embedding with TTL
369#[derive(Debug, Clone)]
370struct CachedEmbedding {
371    embedding: Vec<f32>,
372    cached_at: DateTime<Utc>,
373    ttl_seconds: u64,
374}
375
376impl CachedEmbedding {
377    fn new(embedding: Vec<f32>, ttl_seconds: u64) -> Self {
378        Self {
379            embedding,
380            cached_at: Utc::now(),
381            ttl_seconds,
382        }
383    }
384
385    fn is_expired(&self) -> bool {
386        let now = Utc::now();
387        let expiry = self.cached_at + chrono::Duration::seconds(self.ttl_seconds as i64);
388        now >= expiry
389    }
390}
391
392/// Metrics for the importance assessment pipeline
393#[derive(Debug)]
394pub struct ImportanceAssessmentMetrics {
395    // Stage progression counters
396    pub stage1_executions: IntCounter,
397    pub stage2_executions: IntCounter,
398    pub stage3_executions: IntCounter,
399
400    // Stage timing histograms
401    pub stage1_duration: Histogram,
402    pub stage2_duration: Histogram,
403    pub stage3_duration: Histogram,
404
405    // Pipeline completion counters
406    pub completed_at_stage1: IntCounter,
407    pub completed_at_stage2: IntCounter,
408    pub completed_at_stage3: IntCounter,
409
410    // Performance metrics
411    pub stage1_threshold_violations: IntCounter,
412    pub stage2_threshold_violations: IntCounter,
413    pub stage3_threshold_violations: IntCounter,
414
415    // Cache metrics
416    pub embedding_cache_hits: IntCounter,
417    pub embedding_cache_misses: IntCounter,
418    pub embedding_cache_size: IntGauge,
419
420    // Circuit breaker metrics
421    pub circuit_breaker_opened: Counter,
422    pub circuit_breaker_half_open: Counter,
423    pub circuit_breaker_closed: Counter,
424    pub llm_call_failures: IntCounter,
425    pub llm_call_successes: IntCounter,
426
427    // Quality metrics
428    pub assessment_confidence: Histogram,
429    pub final_importance_scores: Histogram,
430}
431
432impl ImportanceAssessmentMetrics {
433    pub fn new(registry: &Registry) -> Result<Self> {
434        let stage1_executions = IntCounter::new(
435            "importance_assessment_stage1_executions_total",
436            "Total number of Stage 1 executions",
437        )?;
438        registry.register(Box::new(stage1_executions.clone()))?;
439
440        let stage2_executions = IntCounter::new(
441            "importance_assessment_stage2_executions_total",
442            "Total number of Stage 2 executions",
443        )?;
444        registry.register(Box::new(stage2_executions.clone()))?;
445
446        let stage3_executions = IntCounter::new(
447            "importance_assessment_stage3_executions_total",
448            "Total number of Stage 3 executions",
449        )?;
450        registry.register(Box::new(stage3_executions.clone()))?;
451
452        let stage1_duration = Histogram::with_opts(
453            prometheus::HistogramOpts::new(
454                "importance_assessment_stage1_duration_seconds",
455                "Duration of Stage 1 processing",
456            )
457            .buckets(vec![0.001, 0.005, 0.01, 0.02, 0.05, 0.1]),
458        )?;
459        registry.register(Box::new(stage1_duration.clone()))?;
460
461        let stage2_duration = Histogram::with_opts(
462            prometheus::HistogramOpts::new(
463                "importance_assessment_stage2_duration_seconds",
464                "Duration of Stage 2 processing",
465            )
466            .buckets(vec![0.01, 0.05, 0.1, 0.2, 0.5, 1.0]),
467        )?;
468        registry.register(Box::new(stage2_duration.clone()))?;
469
470        let stage3_duration = Histogram::with_opts(
471            prometheus::HistogramOpts::new(
472                "importance_assessment_stage3_duration_seconds",
473                "Duration of Stage 3 processing",
474            )
475            .buckets(vec![0.1, 0.5, 1.0, 2.0, 5.0, 10.0]),
476        )?;
477        registry.register(Box::new(stage3_duration.clone()))?;
478
479        let completed_at_stage1 = IntCounter::new(
480            "importance_assessment_completed_at_stage1_total",
481            "Total assessments completed at Stage 1",
482        )?;
483        registry.register(Box::new(completed_at_stage1.clone()))?;
484
485        let completed_at_stage2 = IntCounter::new(
486            "importance_assessment_completed_at_stage2_total",
487            "Total assessments completed at Stage 2",
488        )?;
489        registry.register(Box::new(completed_at_stage2.clone()))?;
490
491        let completed_at_stage3 = IntCounter::new(
492            "importance_assessment_completed_at_stage3_total",
493            "Total assessments completed at Stage 3",
494        )?;
495        registry.register(Box::new(completed_at_stage3.clone()))?;
496
497        let stage1_threshold_violations = IntCounter::new(
498            "importance_assessment_stage1_threshold_violations_total",
499            "Total Stage 1 performance threshold violations",
500        )?;
501        registry.register(Box::new(stage1_threshold_violations.clone()))?;
502
503        let stage2_threshold_violations = IntCounter::new(
504            "importance_assessment_stage2_threshold_violations_total",
505            "Total Stage 2 performance threshold violations",
506        )?;
507        registry.register(Box::new(stage2_threshold_violations.clone()))?;
508
509        let stage3_threshold_violations = IntCounter::new(
510            "importance_assessment_stage3_threshold_violations_total",
511            "Total Stage 3 performance threshold violations",
512        )?;
513        registry.register(Box::new(stage3_threshold_violations.clone()))?;
514
515        let embedding_cache_hits = IntCounter::new(
516            "importance_assessment_embedding_cache_hits_total",
517            "Total embedding cache hits",
518        )?;
519        registry.register(Box::new(embedding_cache_hits.clone()))?;
520
521        let embedding_cache_misses = IntCounter::new(
522            "importance_assessment_embedding_cache_misses_total",
523            "Total embedding cache misses",
524        )?;
525        registry.register(Box::new(embedding_cache_misses.clone()))?;
526
527        let embedding_cache_size = IntGauge::new(
528            "importance_assessment_embedding_cache_size",
529            "Current size of embedding cache",
530        )?;
531        registry.register(Box::new(embedding_cache_size.clone()))?;
532
533        let circuit_breaker_opened = Counter::new(
534            "importance_assessment_circuit_breaker_opened_total",
535            "Total times circuit breaker opened",
536        )?;
537        registry.register(Box::new(circuit_breaker_opened.clone()))?;
538
539        let circuit_breaker_half_open = Counter::new(
540            "importance_assessment_circuit_breaker_half_open_total",
541            "Total times circuit breaker went half-open",
542        )?;
543        registry.register(Box::new(circuit_breaker_half_open.clone()))?;
544
545        let circuit_breaker_closed = Counter::new(
546            "importance_assessment_circuit_breaker_closed_total",
547            "Total times circuit breaker closed",
548        )?;
549        registry.register(Box::new(circuit_breaker_closed.clone()))?;
550
551        let llm_call_failures = IntCounter::new(
552            "importance_assessment_llm_call_failures_total",
553            "Total LLM call failures",
554        )?;
555        registry.register(Box::new(llm_call_failures.clone()))?;
556
557        let llm_call_successes = IntCounter::new(
558            "importance_assessment_llm_call_successes_total",
559            "Total LLM call successes",
560        )?;
561        registry.register(Box::new(llm_call_successes.clone()))?;
562
563        let assessment_confidence = Histogram::with_opts(
564            prometheus::HistogramOpts::new(
565                "importance_assessment_confidence",
566                "Confidence scores of assessments",
567            )
568            .buckets(vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]),
569        )?;
570        registry.register(Box::new(assessment_confidence.clone()))?;
571
572        let final_importance_scores = Histogram::with_opts(
573            prometheus::HistogramOpts::new(
574                "importance_assessment_final_scores",
575                "Final importance scores from assessments",
576            )
577            .buckets(vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]),
578        )?;
579        registry.register(Box::new(final_importance_scores.clone()))?;
580
581        Ok(Self {
582            stage1_executions,
583            stage2_executions,
584            stage3_executions,
585            stage1_duration,
586            stage2_duration,
587            stage3_duration,
588            completed_at_stage1,
589            completed_at_stage2,
590            completed_at_stage3,
591            stage1_threshold_violations,
592            stage2_threshold_violations,
593            stage3_threshold_violations,
594            embedding_cache_hits,
595            embedding_cache_misses,
596            embedding_cache_size,
597            circuit_breaker_opened,
598            circuit_breaker_half_open,
599            circuit_breaker_closed,
600            llm_call_failures,
601            llm_call_successes,
602            assessment_confidence,
603            final_importance_scores,
604        })
605    }
606}
607
608/// Main importance assessment pipeline
609pub struct ImportanceAssessmentPipeline {
610    config: ImportanceAssessmentConfig,
611    stage1_patterns: Vec<(Regex, ImportancePattern)>,
612    embedding_service: Arc<dyn EmbeddingService>,
613    embedding_cache: RwLock<HashMap<String, CachedEmbedding>>,
614    circuit_breaker: CircuitBreaker,
615    metrics: ImportanceAssessmentMetrics,
616    http_client: reqwest::Client,
617}
618
619impl ImportanceAssessmentPipeline {
620    pub fn new(
621        config: ImportanceAssessmentConfig,
622        embedding_service: Arc<dyn EmbeddingService>,
623        metrics_registry: &Registry,
624    ) -> Result<Self> {
625        // Compile regex patterns for Stage 1
626        let mut stage1_patterns = Vec::new();
627        for pattern in &config.stage1.pattern_library {
628            match Regex::new(&pattern.pattern) {
629                Ok(regex) => stage1_patterns.push((regex, pattern.clone())),
630                Err(e) => {
631                    error!(
632                        "Failed to compile regex pattern '{}': {}",
633                        pattern.pattern, e
634                    );
635                    return Err(ImportanceAssessmentError::Configuration(format!(
636                        "Invalid regex pattern '{}': {}",
637                        pattern.pattern, e
638                    ))
639                    .into());
640                }
641            }
642        }
643
644        let metrics = ImportanceAssessmentMetrics::new(metrics_registry)?;
645
646        let circuit_breaker = CircuitBreaker::new(config.circuit_breaker.clone());
647
648        let http_client = reqwest::Client::builder()
649            .timeout(Duration::from_millis(config.stage3.max_processing_time_ms))
650            .build()?;
651
652        Ok(Self {
653            config,
654            stage1_patterns,
655            embedding_service,
656            embedding_cache: RwLock::new(HashMap::new()),
657            circuit_breaker,
658            metrics,
659            http_client,
660        })
661    }
662
663    /// Assess the importance of a memory content string
664    pub async fn assess_importance(
665        &self,
666        content: &str,
667    ) -> Result<ImportanceAssessmentResult, ImportanceAssessmentError> {
668        let assessment_start = Instant::now();
669        let mut stage_results = Vec::new();
670
671        info!(
672            "Starting importance assessment for content length: {}",
673            content.len()
674        );
675
676        // Stage 1: Pattern matching
677        let stage1_result = self.execute_stage1(content).await?;
678        let stage1_passed = stage1_result.passed_threshold;
679        stage_results.push(stage1_result.clone());
680
681        if stage1_passed {
682            debug!("Stage 1 passed threshold, proceeding to Stage 2");
683
684            // Stage 2: Semantic similarity
685            let stage2_result = self.execute_stage2(content).await?;
686            let stage2_passed = stage2_result.passed_threshold;
687            stage_results.push(stage2_result.clone());
688
689            if stage2_passed {
690                debug!("Stage 2 passed threshold, proceeding to Stage 3");
691
692                // Stage 3: LLM scoring
693                let stage3_result = self.execute_stage3(content).await?;
694                stage_results.push(stage3_result.clone());
695
696                self.metrics.completed_at_stage3.inc();
697
698                let final_score = stage3_result.score;
699                let confidence = stage3_result.confidence;
700
701                let result = ImportanceAssessmentResult {
702                    importance_score: final_score,
703                    final_stage: AssessmentStage::Stage3LLMScoring,
704                    stage_results,
705                    total_processing_time_ms: assessment_start.elapsed().as_millis() as u64,
706                    assessed_at: Utc::now(),
707                    confidence,
708                    explanation: self.extract_explanation_from_stage3(&stage3_result),
709                };
710
711                self.record_final_metrics(&result);
712                return Ok(result);
713            } else {
714                self.metrics.completed_at_stage2.inc();
715
716                let final_score = stage2_result.score;
717                let confidence = stage2_result.confidence;
718
719                let result = ImportanceAssessmentResult {
720                    importance_score: final_score,
721                    final_stage: AssessmentStage::Stage2SemanticSimilarity,
722                    stage_results,
723                    total_processing_time_ms: assessment_start.elapsed().as_millis() as u64,
724                    assessed_at: Utc::now(),
725                    confidence,
726                    explanation: Some(
727                        "Assessment completed at Stage 2 based on semantic similarity".to_string(),
728                    ),
729                };
730
731                self.record_final_metrics(&result);
732                return Ok(result);
733            }
734        } else {
735            self.metrics.completed_at_stage1.inc();
736
737            let final_score = stage1_result.score;
738            let confidence = stage1_result.confidence;
739
740            let result = ImportanceAssessmentResult {
741                importance_score: final_score,
742                final_stage: AssessmentStage::Stage1PatternMatching,
743                stage_results,
744                total_processing_time_ms: assessment_start.elapsed().as_millis() as u64,
745                assessed_at: Utc::now(),
746                confidence,
747                explanation: Some(
748                    "Assessment completed at Stage 1 based on pattern matching".to_string(),
749                ),
750            };
751
752            self.record_final_metrics(&result);
753            return Ok(result);
754        }
755    }
756
757    async fn execute_stage1(
758        &self,
759        content: &str,
760    ) -> Result<StageResult, ImportanceAssessmentError> {
761        let stage_start = Instant::now();
762        self.metrics.stage1_executions.inc();
763
764        let timeout_duration = Duration::from_millis(self.config.stage1.max_processing_time_ms);
765
766        let result = timeout(timeout_duration, async {
767            let mut matched_patterns = Vec::new();
768            let mut total_score = 0.0;
769            let mut max_weight: f64 = 0.0;
770
771            for (regex, pattern) in &self.stage1_patterns {
772                for mat in regex.find_iter(content) {
773                    let match_text = mat.as_str().to_string();
774                    let match_position = mat.start();
775
776                    // Calculate context boost
777                    let context_boost = self.calculate_context_boost(
778                        content,
779                        match_position,
780                        &pattern.context_boosters,
781                    );
782                    let effective_weight = pattern.weight * (1.0 + context_boost);
783
784                    matched_patterns.push(MatchedPattern {
785                        pattern_name: pattern.name.clone(),
786                        pattern_category: pattern.category.clone(),
787                        match_text,
788                        match_position,
789                        weight: pattern.weight,
790                        context_boost,
791                    });
792
793                    total_score += effective_weight;
794                    max_weight = max_weight.max(effective_weight);
795                }
796            }
797
798            // Normalize score to 0.0-1.0 range
799            let normalized_score = if matched_patterns.is_empty() {
800                0.0
801            } else {
802                (total_score / (matched_patterns.len() as f64)).min(1.0)
803            };
804
805            // Calculate confidence based on pattern diversity and strength
806            let confidence = if matched_patterns.is_empty() {
807                0.1 // Low confidence for no matches
808            } else {
809                let pattern_diversity = matched_patterns
810                    .iter()
811                    .map(|m| m.pattern_category.clone())
812                    .collect::<std::collections::HashSet<_>>()
813                    .len() as f64;
814                let base_confidence =
815                    (pattern_diversity / self.config.stage1.pattern_library.len() as f64).min(1.0);
816                let strength_boost = (max_weight / 1.0).min(0.3); // Max 30% boost from pattern strength
817                (base_confidence + strength_boost).min(1.0)
818            };
819
820            let passed_threshold = confidence >= self.config.stage1.confidence_threshold;
821
822            StageResult {
823                stage: AssessmentStage::Stage1PatternMatching,
824                score: normalized_score,
825                confidence,
826                processing_time_ms: stage_start.elapsed().as_millis() as u64,
827                passed_threshold,
828                details: StageDetails::Stage1 {
829                    matched_patterns,
830                    total_patterns_checked: self.stage1_patterns.len(),
831                },
832            }
833        })
834        .await;
835
836        match result {
837            Ok(stage_result) => {
838                let duration_seconds = stage_start.elapsed().as_secs_f64();
839                self.metrics.stage1_duration.observe(duration_seconds);
840
841                // Check performance threshold
842                if stage_result.processing_time_ms > self.config.performance.stage1_target_ms {
843                    self.metrics.stage1_threshold_violations.inc();
844                    warn!(
845                        "Stage 1 exceeded target time: {}ms > {}ms",
846                        stage_result.processing_time_ms, self.config.performance.stage1_target_ms
847                    );
848                }
849
850                debug!(
851                    "Stage 1 completed in {}ms with score {:.3} and confidence {:.3}",
852                    stage_result.processing_time_ms, stage_result.score, stage_result.confidence
853                );
854
855                Ok(stage_result)
856            }
857            Err(_) => {
858                self.metrics.stage1_threshold_violations.inc();
859                Err(ImportanceAssessmentError::Timeout(format!(
860                    "Stage 1 exceeded maximum processing time of {}ms",
861                    self.config.stage1.max_processing_time_ms
862                )))
863            }
864        }
865    }
866
867    async fn execute_stage2(
868        &self,
869        content: &str,
870    ) -> Result<StageResult, ImportanceAssessmentError> {
871        let stage_start = Instant::now();
872        self.metrics.stage2_executions.inc();
873
874        let timeout_duration = Duration::from_millis(self.config.stage2.max_processing_time_ms);
875
876        let stage2_result = async {
877            // Check cache first
878            let content_hash = format!("{:x}", md5::compute(content.as_bytes()));
879            let cached_embedding = {
880                let cache = self.embedding_cache.read().await;
881                cache.get(&content_hash).cloned()
882            };
883
884            let (content_embedding, cache_hit, embedding_time) = if let Some(cached) =
885                cached_embedding
886            {
887                if !cached.is_expired() {
888                    self.metrics.embedding_cache_hits.inc();
889                    (cached.embedding, true, None)
890                } else {
891                    // Cache expired, remove and generate new
892                    {
893                        let mut cache = self.embedding_cache.write().await;
894                        cache.remove(&content_hash);
895                        self.metrics.embedding_cache_size.set(cache.len() as i64);
896                    }
897                    self.metrics.embedding_cache_misses.inc();
898                    let embed_start = Instant::now();
899                    let embedding = match self.embedding_service.generate_embedding(content).await {
900                        Ok(emb) => emb,
901                        Err(e) => {
902                            return Err(ImportanceAssessmentError::Stage2Failed(format!(
903                                "Embedding generation failed: {}",
904                                e
905                            )))
906                        }
907                    };
908                    let embed_time = embed_start.elapsed().as_millis() as u64;
909
910                    // Cache the new embedding
911                    {
912                        let mut cache = self.embedding_cache.write().await;
913                        cache.insert(
914                            content_hash,
915                            CachedEmbedding::new(
916                                embedding.clone(),
917                                self.config.stage2.embedding_cache_ttl_seconds,
918                            ),
919                        );
920                        self.metrics.embedding_cache_size.set(cache.len() as i64);
921                    }
922
923                    (embedding, false, Some(embed_time))
924                }
925            } else {
926                self.metrics.embedding_cache_misses.inc();
927                let embed_start = Instant::now();
928                let embedding = match self.embedding_service.generate_embedding(content).await {
929                    Ok(emb) => emb,
930                    Err(e) => {
931                        return Err(ImportanceAssessmentError::Stage2Failed(format!(
932                            "Embedding generation failed: {}",
933                            e
934                        )))
935                    }
936                };
937                let embed_time = embed_start.elapsed().as_millis() as u64;
938
939                // Cache the new embedding
940                {
941                    let mut cache = self.embedding_cache.write().await;
942                    cache.insert(
943                        content_hash,
944                        CachedEmbedding::new(
945                            embedding.clone(),
946                            self.config.stage2.embedding_cache_ttl_seconds,
947                        ),
948                    );
949                    self.metrics.embedding_cache_size.set(cache.len() as i64);
950                }
951
952                (embedding, false, Some(embed_time))
953            };
954
955            // Calculate similarity scores with reference embeddings
956            let mut similarity_scores = Vec::new();
957            let mut total_weighted_score = 0.0;
958            let mut total_weight = 0.0;
959
960            for reference in &self.config.stage2.reference_embeddings {
961                let similarity =
962                    self.calculate_cosine_similarity(&content_embedding, &reference.embedding);
963
964                if similarity >= self.config.stage2.similarity_threshold {
965                    let weighted_score = similarity as f64 * reference.weight;
966
967                    similarity_scores.push(SimilarityScore {
968                        reference_name: reference.name.clone(),
969                        reference_category: reference.category.clone(),
970                        similarity,
971                        weight: reference.weight,
972                        weighted_score,
973                    });
974
975                    total_weighted_score += weighted_score;
976                    total_weight += reference.weight;
977                }
978            }
979
980            // Normalize score to 0.0-1.0 range
981            let normalized_score = if total_weight > 0.0 {
982                (total_weighted_score / total_weight).min(1.0)
983            } else {
984                0.0
985            };
986
987            // Calculate confidence based on number of matches and their strength
988            let confidence = if similarity_scores.is_empty() {
989                0.1 // Low confidence for no semantic matches
990            } else {
991                let match_ratio = similarity_scores.len() as f64
992                    / self.config.stage2.reference_embeddings.len() as f64;
993                let avg_similarity = similarity_scores
994                    .iter()
995                    .map(|s| s.similarity as f64)
996                    .sum::<f64>()
997                    / similarity_scores.len() as f64;
998                (match_ratio * 0.5 + avg_similarity * 0.5).min(1.0)
999            };
1000
1001            let passed_threshold = confidence >= self.config.stage2.confidence_threshold;
1002
1003            Ok(StageResult {
1004                stage: AssessmentStage::Stage2SemanticSimilarity,
1005                score: normalized_score,
1006                confidence,
1007                processing_time_ms: stage_start.elapsed().as_millis() as u64,
1008                passed_threshold,
1009                details: StageDetails::Stage2 {
1010                    similarity_scores,
1011                    cache_hit,
1012                    embedding_generation_time_ms: embedding_time,
1013                },
1014            })
1015        };
1016
1017        let result = timeout(timeout_duration, stage2_result).await;
1018
1019        match result {
1020            Ok(Ok(stage_result)) => {
1021                let duration_seconds = stage_start.elapsed().as_secs_f64();
1022                self.metrics.stage2_duration.observe(duration_seconds);
1023
1024                // Check performance threshold
1025                if stage_result.processing_time_ms > self.config.performance.stage2_target_ms {
1026                    self.metrics.stage2_threshold_violations.inc();
1027                    warn!(
1028                        "Stage 2 exceeded target time: {}ms > {}ms",
1029                        stage_result.processing_time_ms, self.config.performance.stage2_target_ms
1030                    );
1031                }
1032
1033                debug!(
1034                    "Stage 2 completed in {}ms with score {:.3} and confidence {:.3}",
1035                    stage_result.processing_time_ms, stage_result.score, stage_result.confidence
1036                );
1037
1038                Ok(stage_result)
1039            }
1040            Ok(Err(e)) => Err(e),
1041            Err(_) => {
1042                self.metrics.stage2_threshold_violations.inc();
1043                Err(ImportanceAssessmentError::Timeout(format!(
1044                    "Stage 2 exceeded maximum processing time of {}ms",
1045                    self.config.stage2.max_processing_time_ms
1046                )))
1047            }
1048        }
1049    }
1050
1051    async fn execute_stage3(
1052        &self,
1053        content: &str,
1054    ) -> Result<StageResult, ImportanceAssessmentError> {
1055        let stage_start = Instant::now();
1056        self.metrics.stage3_executions.inc();
1057
1058        // Check circuit breaker
1059        if !self.circuit_breaker.can_execute().await? {
1060            return Err(ImportanceAssessmentError::CircuitBreakerOpen(
1061                "LLM circuit breaker is open".to_string(),
1062            ));
1063        }
1064
1065        let timeout_duration = Duration::from_millis(self.config.stage3.max_processing_time_ms);
1066
1067        let result = timeout(timeout_duration, async {
1068            // Prepare LLM prompt
1069            let prompt = self
1070                .config
1071                .stage3
1072                .prompt_template
1073                .replace("{content}", content)
1074                .replace("{timestamp}", &Utc::now().to_rfc3339());
1075
1076            // Make LLM request
1077            let llm_response = self.call_llm(&prompt).await?;
1078
1079            // Parse LLM response to extract importance score and confidence
1080            let (importance_score, confidence) = self.parse_llm_response(&llm_response)?;
1081
1082            let passed_threshold = true; // Stage 3 is the final stage
1083
1084            Ok::<StageResult, ImportanceAssessmentError>(StageResult {
1085                stage: AssessmentStage::Stage3LLMScoring,
1086                score: importance_score,
1087                confidence,
1088                processing_time_ms: stage_start.elapsed().as_millis() as u64,
1089                passed_threshold,
1090                details: StageDetails::Stage3 {
1091                    llm_response,
1092                    prompt_tokens: Some(prompt.len() / 4), // Rough token estimate
1093                    completion_tokens: None,               // Would need to be provided by LLM API
1094                    model_used: "configured-model".to_string(),
1095                },
1096            })
1097        })
1098        .await;
1099
1100        match result {
1101            Ok(Ok(stage_result)) => {
1102                let duration_seconds = stage_start.elapsed().as_secs_f64();
1103                self.metrics.stage3_duration.observe(duration_seconds);
1104                self.metrics.llm_call_successes.inc();
1105                self.circuit_breaker.record_success().await;
1106
1107                // Check performance threshold
1108                if stage_result.processing_time_ms > self.config.performance.stage3_target_ms {
1109                    self.metrics.stage3_threshold_violations.inc();
1110                    warn!(
1111                        "Stage 3 exceeded target time: {}ms > {}ms",
1112                        stage_result.processing_time_ms, self.config.performance.stage3_target_ms
1113                    );
1114                }
1115
1116                debug!(
1117                    "Stage 3 completed in {}ms with score {:.3} and confidence {:.3}",
1118                    stage_result.processing_time_ms, stage_result.score, stage_result.confidence
1119                );
1120
1121                Ok(stage_result)
1122            }
1123            Ok(Err(e)) => {
1124                self.metrics.llm_call_failures.inc();
1125                self.circuit_breaker.record_failure().await;
1126                Err(e)
1127            }
1128            Err(_) => {
1129                self.metrics.stage3_threshold_violations.inc();
1130                self.metrics.llm_call_failures.inc();
1131                self.circuit_breaker.record_failure().await;
1132                Err(ImportanceAssessmentError::Timeout(format!(
1133                    "Stage 3 exceeded maximum processing time of {}ms",
1134                    self.config.stage3.max_processing_time_ms
1135                )))
1136            }
1137        }
1138    }
1139
1140    fn calculate_context_boost(
1141        &self,
1142        content: &str,
1143        match_position: usize,
1144        boosters: &[String],
1145    ) -> f64 {
1146        let window_size = 100; // Characters to check around the match
1147        let start = match_position.saturating_sub(window_size);
1148        let end = (match_position + window_size).min(content.len());
1149        let context = &content[start..end].to_lowercase();
1150
1151        let mut boost: f64 = 0.0;
1152        for booster in boosters {
1153            if context.contains(&booster.to_lowercase()) {
1154                boost += 0.1; // 10% boost per context word
1155            }
1156        }
1157
1158        boost.min(0.5) // Maximum 50% boost
1159    }
1160
1161    fn calculate_cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
1162        if a.len() != b.len() {
1163            return 0.0;
1164        }
1165
1166        let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
1167        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
1168        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
1169
1170        if norm_a == 0.0 || norm_b == 0.0 {
1171            return 0.0;
1172        }
1173
1174        dot_product / (norm_a * norm_b)
1175    }
1176
1177    async fn call_llm(&self, prompt: &str) -> Result<String, ImportanceAssessmentError> {
1178        // This is a placeholder implementation. In a real system, this would call
1179        // an actual LLM service like OpenAI, Anthropic, or a local model.
1180
1181        let request_body = serde_json::json!({
1182            "prompt": prompt,
1183            "max_tokens": 100,
1184            "temperature": 0.1
1185        });
1186
1187        let response = self
1188            .http_client
1189            .post(&self.config.stage3.llm_endpoint)
1190            .json(&request_body)
1191            .send()
1192            .await
1193            .map_err(|e| {
1194                ImportanceAssessmentError::Stage3Failed(format!("LLM request failed: {}", e))
1195            })?;
1196
1197        if !response.status().is_success() {
1198            return Err(ImportanceAssessmentError::Stage3Failed(format!(
1199                "LLM service returned status: {}",
1200                response.status()
1201            )));
1202        }
1203
1204        let response_body: serde_json::Value = response.json().await.map_err(|e| {
1205            ImportanceAssessmentError::Stage3Failed(format!("Failed to parse LLM response: {}", e))
1206        })?;
1207
1208        response_body["choices"][0]["text"]
1209            .as_str()
1210            .ok_or_else(|| {
1211                ImportanceAssessmentError::Stage3Failed("Invalid LLM response format".to_string())
1212            })
1213            .map(|s| s.to_string())
1214    }
1215
1216    fn parse_llm_response(&self, response: &str) -> Result<(f64, f64), ImportanceAssessmentError> {
1217        // Parse LLM response to extract importance score and confidence
1218        // This is a simplified parser - in practice, you'd want more robust parsing
1219
1220        let lines: Vec<&str> = response.lines().collect();
1221        let mut importance_score = 0.5; // Default
1222        let mut confidence = 0.7; // Default
1223
1224        for line in lines {
1225            let line = line.trim().to_lowercase();
1226
1227            // Look for importance score
1228            if line.contains("importance:") || line.contains("score:") {
1229                if let Some(score_str) = line.split(':').nth(1) {
1230                    if let Ok(score) = score_str.trim().parse::<f64>() {
1231                        importance_score = score.clamp(0.0, 1.0);
1232                    }
1233                }
1234            }
1235
1236            // Look for confidence
1237            if line.contains("confidence:") {
1238                if let Some(conf_str) = line.split(':').nth(1) {
1239                    if let Ok(conf) = conf_str.trim().parse::<f64>() {
1240                        confidence = conf.clamp(0.0, 1.0);
1241                    }
1242                }
1243            }
1244        }
1245
1246        Ok((importance_score, confidence))
1247    }
1248
1249    fn extract_explanation_from_stage3(&self, stage_result: &StageResult) -> Option<String> {
1250        if let StageDetails::Stage3 { llm_response, .. } = &stage_result.details {
1251            Some(llm_response.clone())
1252        } else {
1253            None
1254        }
1255    }
1256
1257    fn record_final_metrics(&self, result: &ImportanceAssessmentResult) {
1258        self.metrics
1259            .assessment_confidence
1260            .observe(result.confidence);
1261        self.metrics
1262            .final_importance_scores
1263            .observe(result.importance_score);
1264
1265        info!(
1266            "Importance assessment completed: score={:.3}, confidence={:.3}, stage={:?}, time={}ms",
1267            result.importance_score,
1268            result.confidence,
1269            result.final_stage,
1270            result.total_processing_time_ms
1271        );
1272    }
1273
1274    /// Get current pipeline statistics
1275    pub async fn get_statistics(&self) -> PipelineStatistics {
1276        let cache_size = self.embedding_cache.read().await.len();
1277
1278        PipelineStatistics {
1279            cache_size,
1280            stage1_executions: self.metrics.stage1_executions.get(),
1281            stage2_executions: self.metrics.stage2_executions.get(),
1282            stage3_executions: self.metrics.stage3_executions.get(),
1283            completed_at_stage1: self.metrics.completed_at_stage1.get(),
1284            completed_at_stage2: self.metrics.completed_at_stage2.get(),
1285            completed_at_stage3: self.metrics.completed_at_stage3.get(),
1286            cache_hits: self.metrics.embedding_cache_hits.get(),
1287            cache_misses: self.metrics.embedding_cache_misses.get(),
1288            circuit_breaker_state: format!("{:?}", *self.circuit_breaker.state.read().await),
1289            llm_success_rate: {
1290                let successes = self.metrics.llm_call_successes.get() as f64;
1291                let failures = self.metrics.llm_call_failures.get() as f64;
1292                let total = successes + failures;
1293                if total > 0.0 {
1294                    successes / total
1295                } else {
1296                    1.0
1297                }
1298            },
1299        }
1300    }
1301
1302    /// Clear the embedding cache
1303    pub async fn clear_cache(&self) {
1304        let mut cache = self.embedding_cache.write().await;
1305        cache.clear();
1306        self.metrics.embedding_cache_size.set(0);
1307        info!("Embedding cache cleared");
1308    }
1309
1310    /// Get cache hit ratio
1311    pub fn get_cache_hit_ratio(&self) -> f64 {
1312        let hits = self.metrics.embedding_cache_hits.get() as f64;
1313        let misses = self.metrics.embedding_cache_misses.get() as f64;
1314        let total = hits + misses;
1315        if total > 0.0 {
1316            hits / total
1317        } else {
1318            0.0
1319        }
1320    }
1321}
1322
/// Point-in-time snapshot of the importance-assessment pipeline, as assembled by
/// `get_statistics`. Field order determines the serialized (and `Debug`) field order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PipelineStatistics {
    /// Number of entries in the embedding cache when the snapshot was taken.
    pub cache_size: usize,
    /// Value of the Stage 1 execution counter.
    pub stage1_executions: u64,
    /// Value of the Stage 2 execution counter (incremented each time Stage 2 starts).
    pub stage2_executions: u64,
    /// Value of the Stage 3 execution counter (incremented each time Stage 3 starts,
    /// before the circuit-breaker check).
    pub stage3_executions: u64,
    /// Value of the `completed_at_stage1` counter. Presumably this counts
    /// assessments that finished at Stage 1; the increment site is not shown here,
    /// so verify.
    pub completed_at_stage1: u64,
    /// Value of the `completed_at_stage2` counter (presumably assessments that
    /// finished at Stage 2 — verify).
    pub completed_at_stage2: u64,
    /// Value of the `completed_at_stage3` counter (presumably assessments that
    /// finished at Stage 3 — verify).
    pub completed_at_stage3: u64,
    /// Embedding-cache hits counted by Stage 2.
    pub cache_hits: u64,
    /// Embedding-cache misses counted by Stage 2; expired entries count as misses.
    pub cache_misses: u64,
    /// `Debug` rendering of the circuit breaker's current state.
    pub circuit_breaker_state: String,
    /// LLM successes / (successes + failures); 1.0 when no LLM call has been recorded.
    pub llm_success_rate: f64,
}
1337
1338impl Default for ImportanceAssessmentConfig {
1339    fn default() -> Self {
1340        Self {
1341            stage1: Stage1Config {
1342                confidence_threshold: 0.6,
1343                pattern_library: vec![
1344                    ImportancePattern {
1345                        name: "remember_command".to_string(),
1346                        pattern: r"(?i)\b(remember|recall|don't forget)\b".to_string(),
1347                        weight: 0.8,
1348                        context_boosters: vec!["important".to_string(), "critical".to_string()],
1349                        category: "memory".to_string(),
1350                    },
1351                    ImportancePattern {
1352                        name: "preference_statement".to_string(),
1353                        pattern: r"(?i)\b(prefer|like|want|choose)\b".to_string(),
1354                        weight: 0.7,
1355                        context_boosters: vec!["always".to_string(), "usually".to_string()],
1356                        category: "preference".to_string(),
1357                    },
1358                    ImportancePattern {
1359                        name: "decision_making".to_string(),
1360                        pattern: r"(?i)\b(decide|decision|choose|select)\b".to_string(),
1361                        weight: 0.75,
1362                        context_boosters: vec!["final".to_string(), "official".to_string()],
1363                        category: "decision".to_string(),
1364                    },
1365                    ImportancePattern {
1366                        name: "correction".to_string(),
1367                        pattern: r"(?i)\b(correct|fix|wrong|mistake|error)\b".to_string(),
1368                        weight: 0.6,
1369                        context_boosters: vec!["actually".to_string(), "should".to_string()],
1370                        category: "correction".to_string(),
1371                    },
1372                    ImportancePattern {
1373                        name: "importance_marker".to_string(),
1374                        pattern: r"(?i)\b(important|critical|crucial|vital|essential)\b".to_string(),
1375                        weight: 0.9,
1376                        context_boosters: vec!["very".to_string(), "extremely".to_string()],
1377                        category: "importance".to_string(),
1378                    },
1379                ],
1380                max_processing_time_ms: 10,
1381            },
1382            stage2: Stage2Config {
1383                confidence_threshold: 0.7,
1384                max_processing_time_ms: 100,
1385                embedding_cache_ttl_seconds: 3600, // 1 hour
1386                similarity_threshold: 0.7,
1387                reference_embeddings: vec![], // Would be populated with pre-computed embeddings
1388            },
1389            stage3: Stage3Config {
1390                max_processing_time_ms: 1000,
1391                llm_endpoint: "http://localhost:8080/generate".to_string(),
1392                max_concurrent_requests: 5,
1393                prompt_template: "Assess the importance of this content on a scale of 0.0 to 1.0. Consider context, user intent, and actionability.\n\nContent: {content}\n\nProvide your assessment as:\nImportance: [score]\nConfidence: [confidence]\nReasoning: [explanation]".to_string(),
1394                target_usage_percentage: 20.0,
1395            },
1396            circuit_breaker: CircuitBreakerConfig {
1397                failure_threshold: 5,
1398                failure_window_seconds: 60,
1399                recovery_timeout_seconds: 30,
1400                minimum_requests: 3,
1401            },
1402            performance: PerformanceConfig {
1403                stage1_target_ms: 10,
1404                stage2_target_ms: 100,
1405                stage3_target_ms: 1000,
1406            },
1407        }
1408    }
1409}