Skip to main content

brainwires_seal/
learning.rs

1//! Self-Evolving Learning Mechanism
2//!
3//! Enables the system to learn from successful interactions without retraining.
4//! Implements both local (per-session) and global (cross-session) memory.
5//!
6//! ## Architecture
7//!
8//! ```text
9//! ┌─────────────────────────────────────┐
10//! │         Learning Coordinator        │
11//! │                                     │
12//! │  ┌─────────────┐  ┌─────────────┐  │
13//! │  │Local Memory │  │Global Memory│  │
14//! │  │ (Session)   │  │ (Persisted) │  │
15//! │  └─────────────┘  └─────────────┘  │
16//! └─────────────────────────────────────┘
17//! ```
18//!
19//! ## Local Memory (Per-Session)
20//!
21//! - Tracks entities mentioned in the current conversation
22//! - Maintains coreference resolution history
23//! - Stores query patterns used in this session
24//! - Focus stack for active entities
25//!
26//! ## Global Memory (Cross-Session)
27//!
28//! - Template library organized by question type
29//! - Query patterns with success/failure statistics
30//! - Resolution patterns that worked well
31//! - Persisted to LanceDB for retrieval
32//!
33//! ## Learning Flow
34//!
35//! 1. User query is processed
36//! 2. Query core is extracted
37//! 3. Check global memory for similar patterns
38//! 4. Execute query
39//! 5. Record outcome (success/failure, result count)
40//! 6. If successful: generalize pattern, add to global memory
41//! 7. If failed: record failure for pattern avoidance
42
43use super::query_core::{QueryCore, QuestionType};
44use brainwires_core::confidence::ResponseConfidence;
45use brainwires_core::graph::EntityType;
46use brainwires_tool_runtime::{ToolErrorCategory, ToolOutcome};
47use chrono::Utc;
48use std::collections::HashMap;
49
50/// A tracked entity in local memory
51#[derive(Debug, Clone)]
52pub struct TrackedEntity {
53    /// Entity name
54    pub name: String,
55    /// Entity type
56    pub entity_type: EntityType,
57    /// Turn numbers when mentioned
58    pub mention_turns: Vec<u32>,
59    /// Whether this entity was queried about
60    pub was_queried: bool,
61    /// Whether this entity was modified
62    pub was_modified: bool,
63    /// Relationships discovered for this entity
64    pub discovered_relations: Vec<(String, String)>, // (relation_type, target)
65}
66
67impl TrackedEntity {
68    /// Create a new tracked entity
69    pub fn new(name: String, entity_type: EntityType, turn: u32) -> Self {
70        Self {
71            name,
72            entity_type,
73            mention_turns: vec![turn],
74            was_queried: false,
75            was_modified: false,
76            discovered_relations: Vec::new(),
77        }
78    }
79
80    /// Record a mention
81    pub fn record_mention(&mut self, turn: u32) {
82        if !self.mention_turns.contains(&turn) {
83            self.mention_turns.push(turn);
84        }
85    }
86
87    /// Get the frequency of mentions
88    pub fn frequency(&self) -> usize {
89        self.mention_turns.len()
90    }
91}
92
93/// Record of a coreference resolution
94#[derive(Debug, Clone)]
95pub struct CoreferenceRecord {
96    /// The original reference text
97    pub reference: String,
98    /// The resolved entity
99    pub resolved_to: String,
100    /// Confidence of the resolution
101    pub confidence: f32,
102    /// Turn when resolved
103    pub turn: u32,
104    /// Whether the resolution was confirmed correct
105    pub confirmed: Option<bool>,
106}
107
108/// Record of a query execution
109#[derive(Debug, Clone)]
110pub struct QueryRecord {
111    /// Original query text
112    pub original: String,
113    /// Resolved query (after coreference)
114    pub resolved: String,
115    /// Question type
116    pub question_type: QuestionType,
117    /// Query core S-expression
118    pub query_sexp: Option<String>,
119    /// Turn when executed
120    pub turn: u32,
121    /// Whether successful
122    pub success: bool,
123    /// Number of results
124    pub result_count: usize,
125    /// Execution time in ms
126    pub execution_time_ms: u64,
127}
128
129/// Local memory for a single conversation session
130#[derive(Debug)]
131pub struct LocalMemory {
132    /// Conversation ID
133    pub conversation_id: String,
134    /// Tracked entities
135    pub entities: HashMap<String, TrackedEntity>,
136    /// Coreference resolution history
137    pub coreference_log: Vec<CoreferenceRecord>,
138    /// Query history
139    pub query_history: Vec<QueryRecord>,
140    /// Current focus stack (entity names)
141    pub focus_stack: Vec<String>,
142    /// Current turn number
143    pub current_turn: u32,
144}
145
146impl LocalMemory {
147    /// Create new local memory for a conversation
148    pub fn new(conversation_id: String) -> Self {
149        Self {
150            conversation_id,
151            entities: HashMap::new(),
152            coreference_log: Vec::new(),
153            query_history: Vec::new(),
154            focus_stack: Vec::new(),
155            current_turn: 0,
156        }
157    }
158
159    /// Advance to the next turn
160    pub fn next_turn(&mut self) {
161        self.current_turn += 1;
162    }
163
164    /// Track an entity mention
165    pub fn track_entity(&mut self, name: &str, entity_type: EntityType) {
166        if let Some(entity) = self.entities.get_mut(name) {
167            entity.record_mention(self.current_turn);
168        } else {
169            self.entities.insert(
170                name.to_string(),
171                TrackedEntity::new(name.to_string(), entity_type, self.current_turn),
172            );
173        }
174
175        // Update focus stack
176        self.focus_stack.retain(|n| n != name);
177        self.focus_stack.insert(0, name.to_string());
178        if self.focus_stack.len() > 20 {
179            self.focus_stack.truncate(20);
180        }
181    }
182
183    /// Record a coreference resolution
184    pub fn record_coreference(&mut self, reference: &str, resolved_to: &str, confidence: f32) {
185        self.coreference_log.push(CoreferenceRecord {
186            reference: reference.to_string(),
187            resolved_to: resolved_to.to_string(),
188            confidence,
189            turn: self.current_turn,
190            confirmed: None,
191        });
192    }
193
194    /// Record a query execution
195    #[allow(clippy::too_many_arguments)]
196    pub fn record_query(
197        &mut self,
198        original: &str,
199        resolved: &str,
200        question_type: QuestionType,
201        query_sexp: Option<String>,
202        success: bool,
203        result_count: usize,
204        execution_time_ms: u64,
205    ) {
206        self.query_history.push(QueryRecord {
207            original: original.to_string(),
208            resolved: resolved.to_string(),
209            question_type,
210            query_sexp,
211            turn: self.current_turn,
212            success,
213            result_count,
214            execution_time_ms,
215        });
216    }
217
218    /// Get entities by frequency (most frequent first)
219    pub fn get_frequent_entities(&self, limit: usize) -> Vec<&TrackedEntity> {
220        let mut entities: Vec<_> = self.entities.values().collect();
221        entities.sort_by_key(|e| std::cmp::Reverse(e.frequency()));
222        entities.into_iter().take(limit).collect()
223    }
224
225    /// Get recent coreference resolutions
226    pub fn get_recent_coreferences(&self, count: usize) -> Vec<&CoreferenceRecord> {
227        self.coreference_log.iter().rev().take(count).collect()
228    }
229
230    /// Get success rate for a question type
231    pub fn get_success_rate(&self, question_type: &QuestionType) -> f32 {
232        let relevant: Vec<_> = self
233            .query_history
234            .iter()
235            .filter(|q| &q.question_type == question_type)
236            .collect();
237
238        if relevant.is_empty() {
239            return 0.5; // No data, neutral assumption
240        }
241
242        let successes = relevant.iter().filter(|q| q.success).count();
243        successes as f32 / relevant.len() as f32
244    }
245}
246
247/// A learned query pattern
248#[derive(Debug, Clone)]
249pub struct QueryPattern {
250    /// Unique pattern ID
251    pub id: String,
252    /// Question type this pattern applies to
253    pub question_type: QuestionType,
254    /// Template for the query (with placeholders)
255    pub template: String,
256    /// Required entity types
257    pub required_types: Vec<EntityType>,
258    /// Number of successful uses
259    pub success_count: u32,
260    /// Number of failed uses
261    pub failure_count: u32,
262    /// Average number of results
263    pub avg_results: f32,
264    /// When this pattern was created
265    pub created_at: i64,
266    /// When this pattern was last used
267    pub last_used_at: i64,
268}
269
270impl QueryPattern {
271    /// Create a new query pattern
272    pub fn new(
273        question_type: QuestionType,
274        template: String,
275        required_types: Vec<EntityType>,
276    ) -> Self {
277        let now = Utc::now().timestamp();
278        Self {
279            id: uuid::Uuid::new_v4().to_string(),
280            question_type,
281            template,
282            required_types,
283            success_count: 0,
284            failure_count: 0,
285            avg_results: 0.0,
286            created_at: now,
287            last_used_at: now,
288        }
289    }
290
291    /// Compute reliability score (0.0 - 1.0)
292    pub fn reliability(&self) -> f32 {
293        let total = self.success_count + self.failure_count;
294        if total == 0 {
295            return 0.5; // No data, neutral
296        }
297        self.success_count as f32 / total as f32
298    }
299
300    /// Record a successful use
301    pub fn record_success(&mut self, result_count: usize) {
302        self.success_count += 1;
303        self.last_used_at = Utc::now().timestamp();
304
305        // Update average results with exponential moving average
306        let alpha = 0.3;
307        self.avg_results = alpha * result_count as f32 + (1.0 - alpha) * self.avg_results;
308    }
309
310    /// Record a failed use
311    pub fn record_failure(&mut self) {
312        self.failure_count += 1;
313        self.last_used_at = Utc::now().timestamp();
314    }
315
316    /// Check if this pattern matches the given entity types
317    pub fn matches_types(&self, types: &[EntityType]) -> bool {
318        self.required_types.iter().all(|rt| types.contains(rt))
319    }
320}
321
322/// A learned coreference resolution pattern
323#[derive(Debug, Clone)]
324pub struct ResolutionPattern {
325    /// Reference type (e.g., "it", "the file")
326    pub reference_type: String,
327    /// Entity type that was resolved
328    pub entity_type: EntityType,
329    /// Context pattern (what typically precedes this reference)
330    pub context_pattern: Option<String>,
331    /// Success count
332    pub success_count: u32,
333    /// Failure count
334    pub failure_count: u32,
335}
336
337/// A learned tool error pattern for avoiding repeated failures
338#[derive(Debug, Clone)]
339pub struct ToolErrorPattern {
340    /// Tool name this pattern applies to
341    pub tool_name: String,
342    /// Error category (serialized for storage)
343    pub error_category: String,
344    /// Number of times this error occurred
345    pub occurrence_count: u32,
346    /// Last occurrence timestamp
347    pub last_occurred: i64,
348    /// Suggested fix or avoidance strategy
349    pub suggested_fix: Option<String>,
350    /// Input patterns that led to this error (for prevention)
351    pub input_patterns: Vec<String>,
352}
353
354impl ToolErrorPattern {
355    /// Create a new error pattern
356    pub fn new(tool_name: &str, error_category: &ToolErrorCategory) -> Self {
357        Self {
358            tool_name: tool_name.to_string(),
359            error_category: error_category.category_name().to_string(),
360            occurrence_count: 1,
361            last_occurred: Utc::now().timestamp(),
362            suggested_fix: error_category.get_suggestion(),
363            input_patterns: Vec::new(),
364        }
365    }
366
367    /// Record another occurrence
368    pub fn record_occurrence(&mut self) {
369        self.occurrence_count += 1;
370        self.last_occurred = Utc::now().timestamp();
371    }
372
373    /// Check if this pattern is frequent (warrants attention)
374    pub fn is_frequent(&self) -> bool {
375        self.occurrence_count >= 3
376    }
377}
378
379/// Tool execution statistics for learning
380#[derive(Debug, Clone, Default)]
381pub struct ToolStats {
382    /// Number of successful executions
383    pub success_count: u32,
384    /// Number of failed executions
385    pub failure_count: u32,
386    /// Total retries needed
387    pub total_retries: u32,
388    /// Average execution time in ms
389    pub avg_execution_time_ms: f64,
390    /// Last used timestamp
391    pub last_used: i64,
392}
393
394impl ToolStats {
395    /// Record a successful execution
396    pub fn record_success(&mut self, retries: u32, execution_time_ms: u64) {
397        self.success_count += 1;
398        self.total_retries += retries;
399        self.last_used = Utc::now().timestamp();
400
401        // Update average execution time with exponential moving average
402        let alpha = 0.3;
403        self.avg_execution_time_ms =
404            alpha * execution_time_ms as f64 + (1.0 - alpha) * self.avg_execution_time_ms;
405    }
406
407    /// Record a failed execution
408    pub fn record_failure(&mut self, retries: u32, execution_time_ms: u64) {
409        self.failure_count += 1;
410        self.total_retries += retries;
411        self.last_used = Utc::now().timestamp();
412
413        let alpha = 0.3;
414        self.avg_execution_time_ms =
415            alpha * execution_time_ms as f64 + (1.0 - alpha) * self.avg_execution_time_ms;
416    }
417
418    /// Get the success rate
419    pub fn success_rate(&self) -> f64 {
420        let total = self.success_count + self.failure_count;
421        if total == 0 {
422            0.5 // Neutral when no data
423        } else {
424            self.success_count as f64 / total as f64
425        }
426    }
427
428    /// Get average retries per execution
429    pub fn avg_retries(&self) -> f64 {
430        let total = self.success_count + self.failure_count;
431        if total == 0 {
432            0.0
433        } else {
434            self.total_retries as f64 / total as f64
435        }
436    }
437}
438
439/// Response confidence statistics for learning prompt patterns
440#[derive(Debug, Clone, Default)]
441pub struct ConfidenceStats {
442    /// Total samples recorded
443    pub sample_count: u32,
444    /// Sum of confidence scores
445    pub confidence_sum: f64,
446    /// Number of low confidence responses
447    pub low_confidence_count: u32,
448    /// Number of high confidence responses
449    pub high_confidence_count: u32,
450}
451
452impl ConfidenceStats {
453    /// Record a confidence sample
454    pub fn record_sample(&mut self, confidence: &ResponseConfidence) {
455        self.sample_count += 1;
456        self.confidence_sum += confidence.score;
457
458        if confidence.is_low_confidence() {
459            self.low_confidence_count += 1;
460        } else if confidence.is_high_confidence() {
461            self.high_confidence_count += 1;
462        }
463    }
464
465    /// Get average confidence
466    pub fn avg_confidence(&self) -> f64 {
467        if self.sample_count == 0 {
468            0.5
469        } else {
470            self.confidence_sum / self.sample_count as f64
471        }
472    }
473
474    /// Get the ratio of low confidence responses
475    pub fn low_confidence_ratio(&self) -> f64 {
476        if self.sample_count == 0 {
477            0.0
478        } else {
479            self.low_confidence_count as f64 / self.sample_count as f64
480        }
481    }
482}
483
484/// A structured hint derived from behavioral knowledge (BKS)
485#[derive(Debug, Clone)]
486pub struct PatternHint {
487    /// Context pattern describing when this hint applies
488    pub context_pattern: String,
489    /// The learned rule or guideline
490    pub rule: String,
491    /// Confidence of the source truth (0.0-1.0)
492    pub confidence: f64,
493    /// Source system that produced this hint (e.g. "bks", "seal")
494    pub source: String,
495}
496
497/// Global memory for cross-session learning
498#[derive(Debug, Default)]
499pub struct GlobalMemory {
500    /// Query patterns organized by question type
501    pub query_patterns: HashMap<QuestionType, Vec<QueryPattern>>,
502    /// Coreference resolution patterns
503    pub resolution_patterns: Vec<ResolutionPattern>,
504    /// Tool error patterns for learning failure modes
505    pub tool_error_patterns: HashMap<String, ToolErrorPattern>,
506    /// Tool execution statistics
507    pub tool_stats: HashMap<String, ToolStats>,
508    /// Response confidence statistics
509    pub confidence_stats: ConfidenceStats,
510    /// Structured hints from behavioral knowledge
511    pub pattern_hints: Vec<PatternHint>,
512}
513
514impl GlobalMemory {
515    /// Create new global memory
516    pub fn new() -> Self {
517        Self::default()
518    }
519
520    /// Add a structured pattern hint from behavioral knowledge
521    pub fn add_pattern_hint(&mut self, hint: PatternHint) {
522        self.pattern_hints.push(hint);
523    }
524
525    /// Get all stored pattern hints
526    pub fn get_pattern_hints(&self) -> &[PatternHint] {
527        &self.pattern_hints
528    }
529
530    /// Add a query pattern
531    pub fn add_pattern(&mut self, pattern: QueryPattern) {
532        self.query_patterns
533            .entry(pattern.question_type.clone())
534            .or_default()
535            .push(pattern);
536    }
537
538    /// Get patterns for a question type, sorted by reliability
539    pub fn get_patterns(&self, question_type: &QuestionType) -> Vec<&QueryPattern> {
540        if let Some(patterns) = self.query_patterns.get(question_type) {
541            let mut sorted: Vec<_> = patterns.iter().collect();
542            sorted.sort_by(|a, b| {
543                b.reliability()
544                    .partial_cmp(&a.reliability())
545                    .unwrap_or(std::cmp::Ordering::Equal)
546            });
547            sorted
548        } else {
549            Vec::new()
550        }
551    }
552
553    /// Get the best pattern for a question type and entity types
554    pub fn get_best_pattern(
555        &self,
556        question_type: &QuestionType,
557        entity_types: &[EntityType],
558    ) -> Option<&QueryPattern> {
559        self.get_patterns(question_type)
560            .into_iter()
561            .find(|p| p.matches_types(entity_types))
562    }
563
564    /// Get a pattern by ID
565    pub fn get_pattern_mut(&mut self, id: &str) -> Option<&mut QueryPattern> {
566        for patterns in self.query_patterns.values_mut() {
567            if let Some(pattern) = patterns.iter_mut().find(|p| p.id == id) {
568                return Some(pattern);
569            }
570        }
571        None
572    }
573
574    /// Remove low-reliability patterns
575    pub fn prune_patterns(&mut self, min_reliability: f32, min_uses: u32) {
576        for patterns in self.query_patterns.values_mut() {
577            patterns.retain(|p| {
578                let total_uses = p.success_count + p.failure_count;
579                total_uses < min_uses || p.reliability() >= min_reliability
580            });
581        }
582    }
583
584    /// Record a tool outcome for learning
585    pub fn record_tool_outcome(&mut self, outcome: &ToolOutcome) {
586        let stats = self
587            .tool_stats
588            .entry(outcome.tool_name.clone())
589            .or_default();
590
591        if outcome.success {
592            stats.record_success(outcome.retries, outcome.execution_time_ms);
593        } else {
594            stats.record_failure(outcome.retries, outcome.execution_time_ms);
595
596            // Also record error pattern if we have error info
597            if let Some(ref error_category) = outcome.error_category {
598                let key = format!("{}:{}", outcome.tool_name, error_category.category_name());
599
600                if let Some(pattern) = self.tool_error_patterns.get_mut(&key) {
601                    pattern.record_occurrence();
602                } else {
603                    self.tool_error_patterns.insert(
604                        key,
605                        ToolErrorPattern::new(&outcome.tool_name, error_category),
606                    );
607                }
608            }
609        }
610    }
611
612    /// Record a response confidence sample
613    pub fn record_confidence(&mut self, confidence: &ResponseConfidence) {
614        self.confidence_stats.record_sample(confidence);
615    }
616
617    /// Get common errors for a tool
618    pub fn get_common_errors(&self, tool_name: &str) -> Vec<&ToolErrorPattern> {
619        self.tool_error_patterns
620            .values()
621            .filter(|p| p.tool_name == tool_name && p.is_frequent())
622            .collect()
623    }
624
625    /// Get error prevention hints for prompts
626    pub fn get_error_prevention_hints(&self, tool_name: &str) -> Option<String> {
627        let common_errors = self.get_common_errors(tool_name);
628        if common_errors.is_empty() {
629            return None;
630        }
631
632        let hints: Vec<String> = common_errors
633            .iter()
634            .filter_map(|e| e.suggested_fix.clone())
635            .collect();
636
637        if hints.is_empty() {
638            None
639        } else {
640            Some(format!(
641                "Common pitfalls for {}: {}",
642                tool_name,
643                hints.join("; ")
644            ))
645        }
646    }
647
648    /// Get tool reliability score
649    pub fn get_tool_reliability(&self, tool_name: &str) -> Option<f64> {
650        self.tool_stats.get(tool_name).map(|s| s.success_rate())
651    }
652}
653
654/// Learning coordinator that manages both local and global memory
655#[derive(Debug)]
656pub struct LearningCoordinator {
657    /// Local memory for current session
658    pub local: LocalMemory,
659    /// Global memory for cross-session patterns
660    pub global: GlobalMemory,
661    /// Learning rate for pattern updates
662    _learning_rate: f32,
663    /// Minimum successes before pattern is trusted
664    min_successes: u32,
665}
666
667impl LearningCoordinator {
668    /// Create a new learning coordinator
669    pub fn new(conversation_id: String) -> Self {
670        Self {
671            local: LocalMemory::new(conversation_id),
672            global: GlobalMemory::new(),
673            _learning_rate: 0.3,
674            min_successes: 3,
675        }
676    }
677
678    /// Process a query through the learning system
679    pub fn process_query(
680        &mut self,
681        _original: &str,
682        _resolved: &str,
683        core: Option<QueryCore>,
684        turn: u32,
685    ) -> Option<&QueryPattern> {
686        self.local.current_turn = turn;
687
688        if let Some(ref c) = core {
689            // Get entity types from the core
690            let entity_types: Vec<_> = c.entities.iter().map(|(_, t)| t.clone()).collect();
691
692            // Check for matching pattern in global memory
693            if let Some(pattern) = self
694                .global
695                .get_best_pattern(&c.question_type, &entity_types)
696            {
697                return Some(pattern);
698            }
699        }
700
701        None
702    }
703
704    /// Record the outcome of a query execution
705    pub fn record_outcome(
706        &mut self,
707        pattern_id: Option<&str>,
708        success: bool,
709        result_count: usize,
710        query_core: Option<&QueryCore>,
711        execution_time_ms: u64,
712    ) {
713        // Update pattern statistics if we used one
714        if let Some(id) = pattern_id
715            && let Some(pattern) = self.global.get_pattern_mut(id)
716        {
717            if success {
718                pattern.record_success(result_count);
719            } else {
720                pattern.record_failure();
721            }
722        }
723
724        // Record in local memory
725        if let Some(core) = query_core {
726            self.local.record_query(
727                &core.original,
728                core.resolved.as_deref().unwrap_or(&core.original),
729                core.question_type.clone(),
730                Some(core.to_sexp()),
731                success,
732                result_count,
733                execution_time_ms,
734            );
735
736            // If successful and we don't have a pattern, create one
737            if success && pattern_id.is_none() && result_count > 0 {
738                let _ = self.learn_pattern(core, result_count);
739            }
740        }
741    }
742
743    /// Learn a new pattern from a successful query
744    pub fn learn_pattern(&mut self, query: &QueryCore, result_count: usize) -> Option<String> {
745        // Only learn from queries with reasonable results
746        if result_count == 0 || result_count > 100 {
747            return None;
748        }
749
750        // Generalize the query to a template
751        let template = self.generalize_query(query);
752
753        // Get required entity types
754        let required_types: Vec<_> = query.entities.iter().map(|(_, t)| t.clone()).collect();
755
756        // Check if we already have a similar pattern
757        if let Some(existing) = self
758            .global
759            .get_best_pattern(&query.question_type, &required_types)
760            && existing.template == template
761        {
762            return None; // Already have this pattern
763        }
764
765        // Create and add the new pattern
766        let mut pattern = QueryPattern::new(query.question_type.clone(), template, required_types);
767        pattern.record_success(result_count);
768
769        let id = pattern.id.clone();
770        self.global.add_pattern(pattern);
771
772        Some(id)
773    }
774
775    /// Generalize a query to a template (replace specific entities with placeholders)
776    fn generalize_query(&self, query: &QueryCore) -> String {
777        let mut template = query.to_sexp();
778
779        // Replace entity names with type placeholders
780        for (name, entity_type) in &query.entities {
781            let placeholder = format!("${{{}}}", entity_type.as_str().to_uppercase());
782            template = template.replace(&format!("\"{}\"", name), &placeholder);
783        }
784
785        template
786    }
787
788    /// Get context for prompt injection
789    pub fn get_context_for_prompt(&self) -> String {
790        let mut context = String::new();
791
792        // Add frequently used entities
793        let frequent = self.local.get_frequent_entities(5);
794        if !frequent.is_empty() {
795            context.push_str("Frequently referenced entities:\n");
796            for entity in frequent {
797                context.push_str(&format!(
798                    "- {} ({}): {} mentions\n",
799                    entity.name,
800                    entity.entity_type.as_str(),
801                    entity.frequency()
802                ));
803            }
804            context.push('\n');
805        }
806
807        // Add recent successful patterns
808        for question_type in [
809            QuestionType::Definition,
810            QuestionType::Location,
811            QuestionType::Dependency,
812        ] {
813            let patterns = self.global.get_patterns(&question_type);
814            let good_patterns: Vec<_> = patterns
815                .iter()
816                .filter(|p| p.reliability() > 0.7 && p.success_count >= self.min_successes)
817                .take(2)
818                .collect();
819
820            if !good_patterns.is_empty() {
821                context.push_str(&format!("Effective {:?} patterns:\n", question_type));
822                for pattern in good_patterns {
823                    context.push_str(&format!(
824                        "- {} ({}% reliable)\n",
825                        pattern.template,
826                        (pattern.reliability() * 100.0) as u32
827                    ));
828                }
829                context.push('\n');
830            }
831        }
832
833        context
834    }
835
836    /// Get all promotable patterns (high reliability, enough uses)
837    ///
838    /// Returns patterns that meet the criteria for promotion to BKS
839    pub fn get_promotable_patterns(
840        &self,
841        min_reliability: f32,
842        min_uses: u32,
843    ) -> Vec<&QueryPattern> {
844        let mut promotable = Vec::new();
845
846        for patterns in self.global.query_patterns.values() {
847            for pattern in patterns {
848                let total_uses = pattern.success_count + pattern.failure_count;
849                if pattern.reliability() >= min_reliability && total_uses >= min_uses {
850                    promotable.push(pattern);
851                }
852            }
853        }
854
855        // Sort by reliability descending
856        promotable.sort_by(|a, b| {
857            b.reliability()
858                .partial_cmp(&a.reliability())
859                .unwrap_or(std::cmp::Ordering::Equal)
860        });
861
862        promotable
863    }
864
865    /// Get learning statistics
866    pub fn get_stats(&self) -> LearningStats {
867        let total_patterns: usize = self.global.query_patterns.values().map(|v| v.len()).sum();
868
869        let mut total_successes = 0u32;
870        let mut total_failures = 0u32;
871        for patterns in self.global.query_patterns.values() {
872            for pattern in patterns {
873                total_successes += pattern.success_count;
874                total_failures += pattern.failure_count;
875            }
876        }
877
878        LearningStats {
879            session_queries: self.local.query_history.len(),
880            session_entities: self.local.entities.len(),
881            session_coreferences: self.local.coreference_log.len(),
882            global_patterns: total_patterns,
883            global_successes: total_successes,
884            global_failures: total_failures,
885            overall_reliability: if total_successes + total_failures > 0 {
886                total_successes as f32 / (total_successes + total_failures) as f32
887            } else {
888                0.5
889            },
890        }
891    }
892
893    // =====================
894    // Tool Learning Methods
895    // =====================
896
897    /// Record a tool execution outcome (delegates to global memory)
898    pub fn record_tool_outcome(&mut self, outcome: &ToolOutcome) {
899        self.global.record_tool_outcome(outcome);
900    }
901
902    /// Record a response confidence sample (delegates to global memory)
903    pub fn record_confidence(&mut self, confidence: &ResponseConfidence) {
904        self.global.record_confidence(confidence);
905    }
906
907    /// Get error prevention hints for a tool (delegates to global memory)
908    pub fn get_error_prevention_hints(&self, tool_name: &str) -> Option<String> {
909        self.global.get_error_prevention_hints(tool_name)
910    }
911
912    /// Get tool reliability score (delegates to global memory)
913    pub fn get_tool_reliability(&self, tool_name: &str) -> Option<f64> {
914        self.global.get_tool_reliability(tool_name)
915    }
916
917    /// Get common errors for a tool (delegates to global memory)
918    pub fn get_common_errors(&self, tool_name: &str) -> Vec<&ToolErrorPattern> {
919        self.global.get_common_errors(tool_name)
920    }
921
922    /// Get the average response confidence
923    pub fn get_avg_confidence(&self) -> f64 {
924        self.global.confidence_stats.avg_confidence()
925    }
926
927    /// Check if responses are frequently low confidence
928    pub fn has_confidence_issues(&self) -> bool {
929        self.global.confidence_stats.low_confidence_ratio() > 0.3
930    }
931}
932
/// Snapshot of learning statistics across session-local and global memory.
///
/// Returned by the learning coordinator's `get_stats`. The `session_*` fields
/// describe the current conversation only; the `global_*` fields aggregate
/// over every query pattern held in global memory.
#[derive(Debug, Clone, PartialEq)]
pub struct LearningStats {
    /// Number of queries recorded in the current session's query history
    pub session_queries: usize,
    /// Number of distinct entities tracked in the current session
    pub session_entities: usize,
    /// Number of entries in the current session's coreference log
    pub session_coreferences: usize,
    /// Number of query patterns stored in global memory
    pub global_patterns: usize,
    /// Sum of success counts over all global query patterns
    pub global_successes: u32,
    /// Sum of failure counts over all global query patterns
    pub global_failures: u32,
    /// Fraction of pattern uses that succeeded,
    /// `global_successes / (global_successes + global_failures)`, in `[0.0, 1.0]`.
    ///
    /// Set to the neutral value `0.5` when no pattern has been used yet.
    pub overall_reliability: f32,
}
951
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tracked_entity() {
        let mut entity = TrackedEntity::new("main.rs".to_string(), EntityType::File, 1);
        assert_eq!(entity.frequency(), 1);

        entity.record_mention(2);
        entity.record_mention(3);
        assert_eq!(entity.frequency(), 3);

        // Duplicate mention should not increase frequency
        entity.record_mention(2);
        assert_eq!(entity.frequency(), 3);
    }

    #[test]
    fn test_local_memory() {
        let mut local = LocalMemory::new("test-conv".to_string());

        local.track_entity("main.rs", EntityType::File);
        local.next_turn();
        local.track_entity("config.toml", EntityType::File);
        local.track_entity("main.rs", EntityType::File); // Mention again on a later turn

        assert_eq!(local.entities.len(), 2);
        // main.rs was mentioned on two distinct turns
        assert_eq!(local.entities["main.rs"].frequency(), 2);

        // main.rs was re-mentioned after config.toml, so it is the most recent
        // entity and must sit at the front of the focus stack.
        assert_eq!(local.focus_stack[0], "main.rs");
    }

    #[test]
    fn test_query_pattern_reliability() {
        let mut pattern =
            QueryPattern::new(QuestionType::Definition, "template".to_string(), vec![]);

        assert_eq!(pattern.reliability(), 0.5); // No data

        pattern.record_success(5);
        pattern.record_success(3);
        pattern.record_failure();

        // 2 successes, 1 failure = 2/3 reliability
        assert!((pattern.reliability() - 2.0 / 3.0).abs() < 0.01);
    }

    #[test]
    fn test_global_memory_patterns() {
        let mut global = GlobalMemory::new();

        let mut pattern1 =
            QueryPattern::new(QuestionType::Definition, "template1".to_string(), vec![]);
        pattern1.record_success(5);
        pattern1.record_success(5);

        let mut pattern2 =
            QueryPattern::new(QuestionType::Definition, "template2".to_string(), vec![]);
        pattern2.record_failure();

        global.add_pattern(pattern1);
        global.add_pattern(pattern2);

        // Get patterns should return them sorted by reliability
        let patterns = global.get_patterns(&QuestionType::Definition);
        assert_eq!(patterns.len(), 2);
        assert!(patterns[0].reliability() > patterns[1].reliability());
    }

    #[test]
    fn test_learning_coordinator() {
        let mut coordinator = LearningCoordinator::new("test-conv".to_string());

        // Record some queries
        let core = QueryCore::new(
            QuestionType::Definition,
            crate::query_core::QueryExpr::var("x"),
            vec![("main.rs".to_string(), EntityType::File)],
            "What is main.rs?".to_string(),
        );

        coordinator.record_outcome(None, true, 1, Some(&core), 0);

        let stats = coordinator.get_stats();
        assert_eq!(stats.session_queries, 1);
        assert_eq!(stats.global_patterns, 1); // Should have learned a pattern
    }

    #[test]
    fn test_pattern_matching() {
        let pattern = QueryPattern::new(
            QuestionType::Definition,
            "template".to_string(),
            vec![EntityType::File],
        );

        assert!(pattern.matches_types(&[EntityType::File]));
        assert!(pattern.matches_types(&[EntityType::File, EntityType::Function]));
        assert!(!pattern.matches_types(&[EntityType::Function]));
    }

    #[test]
    fn test_prune_patterns() {
        let mut global = GlobalMemory::new();

        let mut good_pattern =
            QueryPattern::new(QuestionType::Definition, "good".to_string(), vec![]);
        for _ in 0..10 {
            good_pattern.record_success(5);
        }

        let mut bad_pattern =
            QueryPattern::new(QuestionType::Definition, "bad".to_string(), vec![]);
        for _ in 0..10 {
            bad_pattern.record_failure();
        }

        global.add_pattern(good_pattern);
        global.add_pattern(bad_pattern);

        assert_eq!(global.get_patterns(&QuestionType::Definition).len(), 2);

        global.prune_patterns(0.5, 5);

        // Bad pattern should be removed
        assert_eq!(global.get_patterns(&QuestionType::Definition).len(), 1);
    }

    #[test]
    fn test_get_context_for_prompt() {
        let mut coordinator = LearningCoordinator::new("test".to_string());

        // Mentions within a single turn are de-duplicated (see
        // test_tracked_entity), so advance the turn between the two main.rs
        // mentions to make it genuinely repeated.
        coordinator.local.track_entity("main.rs", EntityType::File);
        coordinator.local.next_turn();
        coordinator.local.track_entity("main.rs", EntityType::File);
        coordinator
            .local
            .track_entity("config.toml", EntityType::File);
        assert_eq!(coordinator.local.entities["main.rs"].frequency(), 2);

        let context = coordinator.get_context_for_prompt();

        // Should mention main.rs as frequently referenced
        assert!(context.contains("main.rs") || context.contains("Frequently"));
    }
}