Skip to main content

brainwires_knowledge/prompting/
learning.rs

1//! Learning & Optimization
2//!
3//! This module tracks technique effectiveness and learns from outcomes,
4//! promoting successful patterns to BKS for collective learning.
5
6use super::techniques::PromptingTechnique;
7#[cfg(feature = "knowledge")]
8use crate::knowledge::bks_pks::{
9    BehavioralKnowledgeCache, BehavioralTruth, TruthCategory, TruthSource,
10};
11use anyhow::Result;
12use chrono::Utc;
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::sync::Arc;
16use tokio::sync::Mutex;
17
/// Success rate (0.0–1.0) a technique must reach before it is promoted to BKS.
///
/// Stored as `f64` and narrowed to `f32` where the coordinator's default is built.
const DEFAULT_PROMOTION_THRESHOLD: f64 = 0.80_f64;
20
21/// Record of technique effectiveness for a specific task execution
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct TechniqueEffectivenessRecord {
24    /// The prompting technique that was used.
25    pub technique: PromptingTechnique,
26    /// The cluster this task belongs to.
27    pub cluster_id: String,
28    /// Description of the task that was executed.
29    pub task_description: String,
30    /// Whether the task completed successfully.
31    pub success: bool,
32    /// Number of iterations consumed.
33    pub iterations_used: u32,
34    /// Quality score from 0.0 to 1.0.
35    pub quality_score: f32,
36    /// Unix timestamp of the execution.
37    pub timestamp: i64,
38}
39
40/// Statistics for a technique in a specific cluster
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct TechniqueStats {
43    /// Number of successful executions.
44    pub success_count: u32,
45    /// Number of failed executions.
46    pub failure_count: u32,
47    /// Average iterations used across executions.
48    pub avg_iterations: f32,
49    /// Average quality score across executions.
50    pub avg_quality: f32,
51    /// Unix timestamp of the last execution.
52    pub last_used: i64,
53}
54
55impl TechniqueStats {
56    /// Create new stats with initial values
57    pub fn new() -> Self {
58        Self {
59            success_count: 0,
60            failure_count: 0,
61            avg_iterations: 0.0,
62            avg_quality: 0.0,
63            last_used: Utc::now().timestamp(),
64        }
65    }
66
67    /// Calculate reliability (success rate)
68    pub fn reliability(&self) -> f32 {
69        let total = self.success_count + self.failure_count;
70        if total == 0 {
71            0.0
72        } else {
73            self.success_count as f32 / total as f32
74        }
75    }
76
77    /// Total uses
78    pub fn total_uses(&self) -> u32 {
79        self.success_count + self.failure_count
80    }
81
82    /// Update stats with new outcome (using EMA for averages)
83    pub fn update(&mut self, success: bool, iterations: u32, quality: f32) {
84        if success {
85            self.success_count += 1;
86        } else {
87            self.failure_count += 1;
88        }
89
90        let alpha = 0.3; // EMA weight
91        self.avg_iterations = alpha * iterations as f32 + (1.0 - alpha) * self.avg_iterations;
92        self.avg_quality = alpha * quality + (1.0 - alpha) * self.avg_quality;
93        self.last_used = Utc::now().timestamp();
94    }
95}
96
97impl Default for TechniqueStats {
98    fn default() -> Self {
99        Self::new()
100    }
101}
102
103/// Coordinates learning and promotion of technique effectiveness
104pub struct PromptingLearningCoordinator {
105    /// Historical records of technique effectiveness
106    records: Vec<TechniqueEffectivenessRecord>,
107
108    /// BKS cache for promoting effective techniques
109    bks_cache: Arc<Mutex<BehavioralKnowledgeCache>>,
110
111    /// Aggregated statistics per (cluster_id, technique)
112    technique_stats: HashMap<(String, PromptingTechnique), TechniqueStats>,
113
114    /// Minimum reliability for BKS promotion (default: 0.8)
115    promotion_threshold: f32,
116
117    /// Minimum uses before promotion (default: 5)
118    min_uses_for_promotion: u32,
119}
120
121impl PromptingLearningCoordinator {
122    /// Create a new learning coordinator
123    pub fn new(bks_cache: Arc<Mutex<BehavioralKnowledgeCache>>) -> Self {
124        Self {
125            records: Vec::new(),
126            bks_cache,
127            technique_stats: HashMap::new(),
128            promotion_threshold: DEFAULT_PROMOTION_THRESHOLD as f32,
129            min_uses_for_promotion: 5,
130        }
131    }
132
133    /// Create with custom promotion thresholds
134    pub fn with_thresholds(
135        bks_cache: Arc<Mutex<BehavioralKnowledgeCache>>,
136        promotion_threshold: f32,
137        min_uses: u32,
138    ) -> Self {
139        Self {
140            records: Vec::new(),
141            bks_cache,
142            technique_stats: HashMap::new(),
143            promotion_threshold,
144            min_uses_for_promotion: min_uses,
145        }
146    }
147
148    /// Record outcome of using specific techniques
149    ///
150    /// This is called after task completion to track which techniques worked.
151    ///
152    /// # Arguments
153    /// * `cluster_id` - The cluster that was matched
154    /// * `techniques` - The techniques that were used
155    /// * `task_description` - Description of the task
156    /// * `success` - Whether the task completed successfully
157    /// * `iterations` - Number of iterations used
158    /// * `quality_score` - Quality score (0.0-1.0)
159    pub fn record_outcome(
160        &mut self,
161        cluster_id: String,
162        techniques: Vec<PromptingTechnique>,
163        task_description: String,
164        success: bool,
165        iterations: u32,
166        quality_score: f32,
167    ) {
168        let timestamp = Utc::now().timestamp();
169
170        for technique in techniques {
171            // Create record
172            let record = TechniqueEffectivenessRecord {
173                technique: technique.clone(),
174                cluster_id: cluster_id.clone(),
175                task_description: task_description.clone(),
176                success,
177                iterations_used: iterations,
178                quality_score,
179                timestamp,
180            };
181
182            self.records.push(record);
183
184            // Update aggregated stats
185            self.update_stats(&cluster_id, &technique, success, iterations, quality_score);
186        }
187    }
188
189    /// Update aggregated statistics for a technique
190    fn update_stats(
191        &mut self,
192        cluster_id: &str,
193        technique: &PromptingTechnique,
194        success: bool,
195        iterations: u32,
196        quality: f32,
197    ) {
198        let key = (cluster_id.to_string(), technique.clone());
199        let stats = self.technique_stats.entry(key).or_default();
200        stats.update(success, iterations, quality);
201    }
202
203    /// Check if technique should be promoted to BKS
204    ///
205    /// Promotion criteria (same as SEAL patterns):
206    /// - Reliability > threshold (default: 0.8 / 80%)
207    /// - Total uses > min_uses (default: 5)
208    ///
209    /// # Returns
210    /// * `true` if technique qualifies for promotion
211    pub fn should_promote(&self, cluster_id: &str, technique: &PromptingTechnique) -> bool {
212        if let Some(stats) = self
213            .technique_stats
214            .get(&(cluster_id.to_string(), technique.clone()))
215        {
216            let reliability = stats.reliability();
217            let uses = stats.total_uses();
218
219            reliability >= self.promotion_threshold && uses >= self.min_uses_for_promotion
220        } else {
221            false
222        }
223    }
224
225    /// Promote technique to BKS
226    ///
227    /// Creates a BehavioralTruth with effectiveness information and submits to BKS.
228    /// This allows other users to benefit from the learned effectiveness.
229    pub async fn promote_technique_to_bks(
230        &mut self,
231        cluster_id: &str,
232        technique: &PromptingTechnique,
233    ) -> Result<bool> {
234        if !self.should_promote(cluster_id, technique) {
235            return Ok(false);
236        }
237
238        let stats = self
239            .technique_stats
240            .get(&(cluster_id.to_string(), technique.clone()))
241            .expect("should_promote verified this entry exists");
242
243        let reliability = stats.reliability();
244        let uses = stats.total_uses();
245
246        // Create BehavioralTruth
247        let truth = BehavioralTruth::new(
248            TruthCategory::PromptingTechnique,
249            cluster_id.to_string(), // context_pattern
250            format!(
251                "Use {:?} for {} tasks (achieves {:.1}% success rate)",
252                technique,
253                cluster_id,
254                reliability * 100.0
255            ), // rule
256            format!(
257                "Learned from {} executions with avg quality {:.2}. \
258                Average iterations: {:.1}. \
259                This technique has proven effective for this task cluster.",
260                uses, stats.avg_quality, stats.avg_iterations
261            ), // rationale
262            TruthSource::SuccessPattern,
263            None, // No specific user attribution
264        );
265
266        // Submit to BKS
267        let mut bks = self.bks_cache.lock().await;
268        let queued = bks.queue_submission(truth)?;
269
270        if queued {
271            tracing::debug!(
272                ?technique,
273                %cluster_id,
274                reliability_pct = reliability * 100.0,
275                uses,
276                "Adaptive Prompting: Promoted technique for cluster"
277            );
278        }
279
280        Ok(queued)
281    }
282
283    /// Check and promote all eligible techniques
284    ///
285    /// This should be called periodically (e.g., after each task completion)
286    /// to promote techniques that have reached the threshold.
287    pub async fn check_and_promote_all(&mut self) -> Result<Vec<(String, PromptingTechnique)>> {
288        let mut promoted = Vec::new();
289
290        // Collect eligible techniques (to avoid borrowing issues)
291        let eligible: Vec<_> = self
292            .technique_stats
293            .keys()
294            .filter(|(cluster_id, technique)| self.should_promote(cluster_id, technique))
295            .cloned()
296            .collect();
297
298        // Promote each eligible technique
299        for (cluster_id, technique) in eligible {
300            if self
301                .promote_technique_to_bks(&cluster_id, &technique)
302                .await?
303            {
304                promoted.push((cluster_id, technique));
305            }
306        }
307
308        Ok(promoted)
309    }
310
311    /// Get statistics for a specific technique in a cluster
312    pub fn get_stats(
313        &self,
314        cluster_id: &str,
315        technique: &PromptingTechnique,
316    ) -> Option<&TechniqueStats> {
317        self.technique_stats
318            .get(&(cluster_id.to_string(), technique.clone()))
319    }
320
321    /// Get all statistics
322    pub fn get_all_stats(&self) -> &HashMap<(String, PromptingTechnique), TechniqueStats> {
323        &self.technique_stats
324    }
325
326    /// Get recent records (last N)
327    pub fn get_recent_records(&self, count: usize) -> Vec<&TechniqueEffectivenessRecord> {
328        self.records.iter().rev().take(count).collect()
329    }
330
331    /// Get statistics summary for a cluster
332    pub fn get_cluster_summary(&self, cluster_id: &str) -> ClusterSummary {
333        let mut summary = ClusterSummary {
334            cluster_id: cluster_id.to_string(),
335            total_executions: 0,
336            techniques: HashMap::new(),
337        };
338
339        for ((cid, technique), stats) in &self.technique_stats {
340            if cid == cluster_id {
341                summary.total_executions += stats.total_uses();
342                summary.techniques.insert(technique.clone(), stats.clone());
343            }
344        }
345
346        summary
347    }
348
349    /// Clear old records (keep only recent N records)
350    pub fn prune_old_records(&mut self, keep_count: usize) {
351        if self.records.len() > keep_count {
352            let excess = self.records.len() - keep_count;
353            self.records.drain(0..excess);
354        }
355    }
356
357    /// Get promotion thresholds
358    pub fn get_thresholds(&self) -> (f32, u32) {
359        (self.promotion_threshold, self.min_uses_for_promotion)
360    }
361}
362
363/// Summary of technique performance for a cluster
364#[derive(Debug, Clone)]
365pub struct ClusterSummary {
366    /// The cluster identifier.
367    pub cluster_id: String,
368    /// Total number of task executions in this cluster.
369    pub total_executions: u32,
370    /// Per-technique performance statistics.
371    pub techniques: HashMap<PromptingTechnique, TechniqueStats>,
372}
373
374impl ClusterSummary {
375    /// Get the most effective technique in this cluster
376    pub fn best_technique(&self) -> Option<(&PromptingTechnique, &TechniqueStats)> {
377        self.techniques
378            .iter()
379            .filter(|(_, stats)| stats.total_uses() >= 3) // Minimum sample size
380            .max_by(|(_, a), (_, b)| {
381                // Compare by reliability, then by quality
382                a.reliability()
383                    .partial_cmp(&b.reliability())
384                    .unwrap_or(std::cmp::Ordering::Equal)
385                    .then(
386                        a.avg_quality
387                            .partial_cmp(&b.avg_quality)
388                            .unwrap_or(std::cmp::Ordering::Equal),
389                    )
390            })
391    }
392
393    /// Get techniques eligible for promotion
394    pub fn promotable_techniques(&self, threshold: f32, min_uses: u32) -> Vec<&PromptingTechnique> {
395        self.techniques
396            .iter()
397            .filter(|(_, stats)| stats.reliability() >= threshold && stats.total_uses() >= min_uses)
398            .map(|(technique, _)| technique)
399            .collect()
400    }
401}
402
#[cfg(test)]
mod tests {
    use super::*;

    /// Cluster id shared by the synchronous tests.
    const CLUSTER: &str = "test_cluster";

    /// Fresh in-memory BKS cache (capacity 100) for a single test.
    fn in_memory_cache() -> Arc<Mutex<BehavioralKnowledgeCache>> {
        let cache = BehavioralKnowledgeCache::in_memory(100).unwrap();
        Arc::new(Mutex::new(cache))
    }

    /// Record `times` identical single-technique outcomes, each using 5 iterations.
    fn record_times(
        coordinator: &mut PromptingLearningCoordinator,
        cluster_id: &str,
        technique: &PromptingTechnique,
        task: &str,
        success: bool,
        quality: f32,
        times: usize,
    ) {
        (0..times).for_each(|_| {
            coordinator.record_outcome(
                cluster_id.to_string(),
                vec![technique.clone()],
                task.to_string(),
                success,
                5,
                quality,
            )
        });
    }

    #[test]
    fn test_technique_stats_update() {
        let mut stats = TechniqueStats::new();
        for _ in 0..5 {
            stats.update(true, 10, 0.9);
        }

        assert_eq!(
            (stats.success_count, stats.failure_count, stats.total_uses()),
            (5, 0, 5)
        );
        assert_eq!(stats.reliability(), 1.0);
        // Five 0.9-quality samples must lift the running quality average past 0.7.
        assert!(stats.avg_quality > 0.7);
    }

    #[test]
    fn test_should_promote() {
        let mut coordinator = PromptingLearningCoordinator::new(in_memory_cache());
        let technique = PromptingTechnique::ChainOfThought;

        // Six clean successes clear both the reliability bar and the usage bar.
        record_times(&mut coordinator, CLUSTER, &technique, "test task", true, 0.9, 6);

        assert!(coordinator.should_promote(CLUSTER, &technique));
    }

    #[test]
    fn test_not_enough_uses() {
        let mut coordinator = PromptingLearningCoordinator::new(in_memory_cache());
        let technique = PromptingTechnique::ChainOfThought;

        // Perfect reliability, but three uses is below the default minimum of five.
        record_times(&mut coordinator, CLUSTER, &technique, "test task", true, 0.9, 3);

        assert!(!coordinator.should_promote(CLUSTER, &technique));
    }

    #[test]
    fn test_reliability_too_low() {
        let mut coordinator = PromptingLearningCoordinator::new(in_memory_cache());
        let technique = PromptingTechnique::ChainOfThought;

        // Three wins and three losses: 50% reliability, under the 80% threshold.
        record_times(&mut coordinator, CLUSTER, &technique, "test task", true, 0.9, 3);
        record_times(&mut coordinator, CLUSTER, &technique, "test task", false, 0.5, 3);

        assert!(!coordinator.should_promote(CLUSTER, &technique));
    }

    #[test]
    fn test_cluster_summary() {
        let mut coordinator = PromptingLearningCoordinator::new(in_memory_cache());

        // One outcome each for two different techniques in the same cluster.
        let outcomes = [
            (PromptingTechnique::ChainOfThought, "task 1", 5, 0.9),
            (PromptingTechnique::PlanAndSolve, "task 2", 8, 0.85),
        ];
        for (technique, task, iterations, quality) in outcomes {
            coordinator.record_outcome(
                CLUSTER.to_string(),
                vec![technique],
                task.to_string(),
                true,
                iterations,
                quality,
            );
        }

        let summary = coordinator.get_cluster_summary(CLUSTER);
        assert_eq!(summary.cluster_id, CLUSTER);
        assert_eq!(summary.total_executions, 2);
        assert_eq!(summary.techniques.len(), 2);
    }

    #[tokio::test]
    async fn test_promotion_to_bks() {
        let cache = in_memory_cache();
        let mut coordinator = PromptingLearningCoordinator::new(Arc::clone(&cache));

        let cluster_id = "numerical_reasoning";
        let technique = PromptingTechnique::ChainOfThought;

        record_times(&mut coordinator, cluster_id, &technique, "calculate primes", true, 0.9, 6);

        let promoted = coordinator
            .promote_technique_to_bks(cluster_id, &technique)
            .await
            .unwrap();
        assert!(promoted);

        // Inspect the cache directly: the truth must be waiting in the submission queue.
        let bks = cache.lock().await;
        let _all_truths: Vec<_> = bks.all_truths().collect();

        let pending = bks.pending_submissions();
        assert!(!pending.is_empty());

        let truth = &pending[0].truth;
        assert_eq!(truth.category, TruthCategory::PromptingTechnique);
        assert_eq!(truth.context_pattern, cluster_id);
        assert!(truth.rule.contains("ChainOfThought"));
    }
}
579        assert_eq!(truth.category, TruthCategory::PromptingTechnique);
580        assert_eq!(truth.context_pattern, cluster_id);
581        assert!(truth.rule.contains("ChainOfThought"));
582    }
583}