Skip to main content

brainwires_prompting/
learning.rs

1//! Learning & Optimization
2//!
3//! This module tracks technique effectiveness and learns from outcomes,
4//! promoting successful patterns to BKS for collective learning.
5
6use super::techniques::PromptingTechnique;
7use anyhow::Result;
8use brainwires_knowledge::knowledge::bks_pks::{
9    BehavioralKnowledgeCache, BehavioralTruth, TruthCategory, TruthSource,
10};
11use chrono::Utc;
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::sync::Arc;
15use tokio::sync::Mutex;
16
/// Default minimum reliability threshold for promoting a technique to BKS.
///
/// Declared as `f64`; `PromptingLearningCoordinator::new` narrows it to `f32`
/// when it initialises the coordinator's `promotion_threshold`.
const DEFAULT_PROMOTION_THRESHOLD: f64 = 0.8;
19
/// Record of technique effectiveness for a specific task execution.
///
/// `PromptingLearningCoordinator::record_outcome` appends one record per technique
/// passed to it, so a task that used several techniques yields several records
/// sharing the same cluster, description, outcome and timestamp.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TechniqueEffectivenessRecord {
    /// The prompting technique that was used.
    pub technique: PromptingTechnique,
    /// The cluster this task belongs to.
    pub cluster_id: String,
    /// Description of the task that was executed.
    pub task_description: String,
    /// Whether the task completed successfully.
    pub success: bool,
    /// Number of iterations consumed.
    pub iterations_used: u32,
    /// Quality score from 0.0 to 1.0.
    // NOTE(review): the range is not enforced anywhere in this module — confirm callers clamp it.
    pub quality_score: f32,
    /// Unix timestamp of the execution (taken from `Utc::now().timestamp()` when recorded).
    pub timestamp: i64,
}
38
/// Statistics for a technique in a specific cluster.
///
/// The counters are exact tallies; the two `avg_*` fields are exponential moving
/// averages maintained by `TechniqueStats::update`, not arithmetic means.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TechniqueStats {
    /// Number of successful executions.
    pub success_count: u32,
    /// Number of failed executions.
    pub failure_count: u32,
    /// Average iterations used across executions (exponential moving average).
    pub avg_iterations: f32,
    /// Average quality score across executions (exponential moving average).
    pub avg_quality: f32,
    /// Unix timestamp of the last execution.
    ///
    /// Initialised to the creation time by `TechniqueStats::new`, then refreshed
    /// on every `update`.
    pub last_used: i64,
}
53
54impl TechniqueStats {
55    /// Create new stats with initial values
56    pub fn new() -> Self {
57        Self {
58            success_count: 0,
59            failure_count: 0,
60            avg_iterations: 0.0,
61            avg_quality: 0.0,
62            last_used: Utc::now().timestamp(),
63        }
64    }
65
66    /// Calculate reliability (success rate)
67    pub fn reliability(&self) -> f32 {
68        let total = self.success_count + self.failure_count;
69        if total == 0 {
70            0.0
71        } else {
72            self.success_count as f32 / total as f32
73        }
74    }
75
76    /// Total uses
77    pub fn total_uses(&self) -> u32 {
78        self.success_count + self.failure_count
79    }
80
81    /// Update stats with new outcome (using EMA for averages)
82    pub fn update(&mut self, success: bool, iterations: u32, quality: f32) {
83        if success {
84            self.success_count += 1;
85        } else {
86            self.failure_count += 1;
87        }
88
89        let alpha = 0.3; // EMA weight
90        self.avg_iterations = alpha * iterations as f32 + (1.0 - alpha) * self.avg_iterations;
91        self.avg_quality = alpha * quality + (1.0 - alpha) * self.avg_quality;
92        self.last_used = Utc::now().timestamp();
93    }
94}
95
/// `Default` delegates to [`TechniqueStats::new`], so a defaulted value carries
/// zeroed counters and averages and a `last_used` of "now".
impl Default for TechniqueStats {
    fn default() -> Self {
        Self::new()
    }
}
101
/// Coordinates learning and promotion of technique effectiveness.
///
/// Keeps a raw log of per-execution records plus aggregated statistics keyed by
/// `(cluster_id, technique)`, and submits techniques that meet the promotion
/// criteria to the shared BKS cache.
pub struct PromptingLearningCoordinator {
    /// Historical records of technique effectiveness
    /// (append-only until `prune_old_records` is called).
    records: Vec<TechniqueEffectivenessRecord>,

    /// BKS cache for promoting effective techniques
    bks_cache: Arc<Mutex<BehavioralKnowledgeCache>>,

    /// Aggregated statistics per (cluster_id, technique)
    technique_stats: HashMap<(String, PromptingTechnique), TechniqueStats>,

    /// Minimum reliability for BKS promotion (default: 0.8).
    /// Compared inclusively: a reliability equal to the threshold qualifies.
    promotion_threshold: f32,

    /// Minimum uses before promotion (default: 5).
    /// Compared inclusively: exactly this many uses qualifies.
    min_uses_for_promotion: u32,
}
119
120impl PromptingLearningCoordinator {
121    /// Create a new learning coordinator
122    pub fn new(bks_cache: Arc<Mutex<BehavioralKnowledgeCache>>) -> Self {
123        Self {
124            records: Vec::new(),
125            bks_cache,
126            technique_stats: HashMap::new(),
127            promotion_threshold: DEFAULT_PROMOTION_THRESHOLD as f32,
128            min_uses_for_promotion: 5,
129        }
130    }
131
132    /// Create with custom promotion thresholds
133    pub fn with_thresholds(
134        bks_cache: Arc<Mutex<BehavioralKnowledgeCache>>,
135        promotion_threshold: f32,
136        min_uses: u32,
137    ) -> Self {
138        Self {
139            records: Vec::new(),
140            bks_cache,
141            technique_stats: HashMap::new(),
142            promotion_threshold,
143            min_uses_for_promotion: min_uses,
144        }
145    }
146
147    /// Record outcome of using specific techniques
148    ///
149    /// This is called after task completion to track which techniques worked.
150    ///
151    /// # Arguments
152    /// * `cluster_id` - The cluster that was matched
153    /// * `techniques` - The techniques that were used
154    /// * `task_description` - Description of the task
155    /// * `success` - Whether the task completed successfully
156    /// * `iterations` - Number of iterations used
157    /// * `quality_score` - Quality score (0.0-1.0)
158    pub fn record_outcome(
159        &mut self,
160        cluster_id: String,
161        techniques: Vec<PromptingTechnique>,
162        task_description: String,
163        success: bool,
164        iterations: u32,
165        quality_score: f32,
166    ) {
167        let timestamp = Utc::now().timestamp();
168
169        for technique in techniques {
170            // Create record
171            let record = TechniqueEffectivenessRecord {
172                technique: technique.clone(),
173                cluster_id: cluster_id.clone(),
174                task_description: task_description.clone(),
175                success,
176                iterations_used: iterations,
177                quality_score,
178                timestamp,
179            };
180
181            self.records.push(record);
182
183            // Update aggregated stats
184            self.update_stats(&cluster_id, &technique, success, iterations, quality_score);
185        }
186    }
187
188    /// Update aggregated statistics for a technique
189    fn update_stats(
190        &mut self,
191        cluster_id: &str,
192        technique: &PromptingTechnique,
193        success: bool,
194        iterations: u32,
195        quality: f32,
196    ) {
197        let key = (cluster_id.to_string(), technique.clone());
198        let stats = self.technique_stats.entry(key).or_default();
199        stats.update(success, iterations, quality);
200    }
201
202    /// Check if technique should be promoted to BKS
203    ///
204    /// Promotion criteria (same as SEAL patterns):
205    /// - Reliability > threshold (default: 0.8 / 80%)
206    /// - Total uses > min_uses (default: 5)
207    ///
208    /// # Returns
209    /// * `true` if technique qualifies for promotion
210    pub fn should_promote(&self, cluster_id: &str, technique: &PromptingTechnique) -> bool {
211        if let Some(stats) = self
212            .technique_stats
213            .get(&(cluster_id.to_string(), technique.clone()))
214        {
215            let reliability = stats.reliability();
216            let uses = stats.total_uses();
217
218            reliability >= self.promotion_threshold && uses >= self.min_uses_for_promotion
219        } else {
220            false
221        }
222    }
223
224    /// Promote technique to BKS
225    ///
226    /// Creates a BehavioralTruth with effectiveness information and submits to BKS.
227    /// This allows other users to benefit from the learned effectiveness.
228    pub async fn promote_technique_to_bks(
229        &mut self,
230        cluster_id: &str,
231        technique: &PromptingTechnique,
232    ) -> Result<bool> {
233        if !self.should_promote(cluster_id, technique) {
234            return Ok(false);
235        }
236
237        let stats = self
238            .technique_stats
239            .get(&(cluster_id.to_string(), technique.clone()))
240            .expect("should_promote verified this entry exists");
241
242        let reliability = stats.reliability();
243        let uses = stats.total_uses();
244
245        // Create BehavioralTruth
246        let truth = BehavioralTruth::new(
247            TruthCategory::PromptingTechnique,
248            cluster_id.to_string(), // context_pattern
249            format!(
250                "Use {:?} for {} tasks (achieves {:.1}% success rate)",
251                technique,
252                cluster_id,
253                reliability * 100.0
254            ), // rule
255            format!(
256                "Learned from {} executions with avg quality {:.2}. \
257                Average iterations: {:.1}. \
258                This technique has proven effective for this task cluster.",
259                uses, stats.avg_quality, stats.avg_iterations
260            ), // rationale
261            TruthSource::SuccessPattern,
262            None, // No specific user attribution
263        );
264
265        // Submit to BKS
266        let mut bks = self.bks_cache.lock().await;
267        let queued = bks.queue_submission(truth)?;
268
269        if queued {
270            tracing::debug!(
271                ?technique,
272                %cluster_id,
273                reliability_pct = reliability * 100.0,
274                uses,
275                "Adaptive Prompting: Promoted technique for cluster"
276            );
277        }
278
279        Ok(queued)
280    }
281
282    /// Check and promote all eligible techniques
283    ///
284    /// This should be called periodically (e.g., after each task completion)
285    /// to promote techniques that have reached the threshold.
286    pub async fn check_and_promote_all(&mut self) -> Result<Vec<(String, PromptingTechnique)>> {
287        let mut promoted = Vec::new();
288
289        // Collect eligible techniques (to avoid borrowing issues)
290        let eligible: Vec<_> = self
291            .technique_stats
292            .keys()
293            .filter(|(cluster_id, technique)| self.should_promote(cluster_id, technique))
294            .cloned()
295            .collect();
296
297        // Promote each eligible technique
298        for (cluster_id, technique) in eligible {
299            if self
300                .promote_technique_to_bks(&cluster_id, &technique)
301                .await?
302            {
303                promoted.push((cluster_id, technique));
304            }
305        }
306
307        Ok(promoted)
308    }
309
310    /// Get statistics for a specific technique in a cluster
311    pub fn get_stats(
312        &self,
313        cluster_id: &str,
314        technique: &PromptingTechnique,
315    ) -> Option<&TechniqueStats> {
316        self.technique_stats
317            .get(&(cluster_id.to_string(), technique.clone()))
318    }
319
    /// Get all statistics.
    ///
    /// Returns a borrowed view of the full map, keyed by `(cluster_id, technique)`.
    pub fn get_all_stats(&self) -> &HashMap<(String, PromptingTechnique), TechniqueStats> {
        &self.technique_stats
    }
324
325    /// Get recent records (last N)
326    pub fn get_recent_records(&self, count: usize) -> Vec<&TechniqueEffectivenessRecord> {
327        self.records.iter().rev().take(count).collect()
328    }
329
330    /// Get statistics summary for a cluster
331    pub fn get_cluster_summary(&self, cluster_id: &str) -> ClusterSummary {
332        let mut summary = ClusterSummary {
333            cluster_id: cluster_id.to_string(),
334            total_executions: 0,
335            techniques: HashMap::new(),
336        };
337
338        for ((cid, technique), stats) in &self.technique_stats {
339            if cid == cluster_id {
340                summary.total_executions += stats.total_uses();
341                summary.techniques.insert(technique.clone(), stats.clone());
342            }
343        }
344
345        summary
346    }
347
348    /// Clear old records (keep only recent N records)
349    pub fn prune_old_records(&mut self, keep_count: usize) {
350        if self.records.len() > keep_count {
351            let excess = self.records.len() - keep_count;
352            self.records.drain(0..excess);
353        }
354    }
355
    /// Get promotion thresholds.
    ///
    /// Returns `(promotion_threshold, min_uses_for_promotion)`: the minimum
    /// reliability and the minimum number of uses required for promotion.
    pub fn get_thresholds(&self) -> (f32, u32) {
        (self.promotion_threshold, self.min_uses_for_promotion)
    }
360}
361
/// Summary of technique performance for a cluster.
///
/// Produced by `PromptingLearningCoordinator::get_cluster_summary`; the stats
/// are cloned snapshots and do not track later updates.
#[derive(Debug, Clone)]
pub struct ClusterSummary {
    /// The cluster identifier.
    pub cluster_id: String,
    /// Total number of task executions in this cluster.
    ///
    /// Computed as the sum of every technique's total uses, so a task recorded
    /// with several techniques contributes once per technique.
    pub total_executions: u32,
    /// Per-technique performance statistics.
    pub techniques: HashMap<PromptingTechnique, TechniqueStats>,
}
372
373impl ClusterSummary {
374    /// Get the most effective technique in this cluster
375    pub fn best_technique(&self) -> Option<(&PromptingTechnique, &TechniqueStats)> {
376        self.techniques
377            .iter()
378            .filter(|(_, stats)| stats.total_uses() >= 3) // Minimum sample size
379            .max_by(|(_, a), (_, b)| {
380                // Compare by reliability, then by quality
381                a.reliability()
382                    .partial_cmp(&b.reliability())
383                    .unwrap_or(std::cmp::Ordering::Equal)
384                    .then(
385                        a.avg_quality
386                            .partial_cmp(&b.avg_quality)
387                            .unwrap_or(std::cmp::Ordering::Equal),
388                    )
389            })
390    }
391
392    /// Get techniques eligible for promotion
393    pub fn promotable_techniques(&self, threshold: f32, min_uses: u32) -> Vec<&PromptingTechnique> {
394        self.techniques
395            .iter()
396            .filter(|(_, stats)| stats.reliability() >= threshold && stats.total_uses() >= min_uses)
397            .map(|(technique, _)| technique)
398            .collect()
399    }
400}
401
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a default-threshold coordinator over a fresh in-memory BKS cache,
    /// returning the shared cache handle alongside it.
    fn coordinator_with_cache() -> (
        Arc<Mutex<BehavioralKnowledgeCache>>,
        PromptingLearningCoordinator,
    ) {
        let cache = Arc::new(Mutex::new(
            BehavioralKnowledgeCache::in_memory(100).unwrap(),
        ));
        let coordinator = PromptingLearningCoordinator::new(cache.clone());
        (cache, coordinator)
    }

    /// Record the same single-technique outcome `times` times, 5 iterations each.
    fn record_repeated(
        coordinator: &mut PromptingLearningCoordinator,
        cluster_id: &str,
        technique: &PromptingTechnique,
        task: &str,
        times: usize,
        success: bool,
        quality: f32,
    ) {
        for _ in 0..times {
            coordinator.record_outcome(
                cluster_id.to_string(),
                vec![technique.clone()],
                task.to_string(),
                success,
                5,
                quality,
            );
        }
    }

    #[test]
    fn test_technique_stats_update() {
        let mut stats = TechniqueStats::new();

        // Five successful runs with identical inputs.
        (0..5).for_each(|_| stats.update(true, 10, 0.9));

        assert_eq!(stats.success_count, 5);
        assert_eq!(stats.failure_count, 0);
        assert_eq!(stats.reliability(), 1.0);
        assert_eq!(stats.total_uses(), 5);
        // Five identical 0.9 samples must leave the moving average above 0.7.
        assert!(stats.avg_quality > 0.7);
    }

    #[test]
    fn test_should_promote() {
        let (_cache, mut coordinator) = coordinator_with_cache();
        let cluster_id = "test_cluster";
        let technique = PromptingTechnique::ChainOfThought;

        // Six successes: 100% reliability over more than the five required uses.
        record_repeated(
            &mut coordinator,
            cluster_id,
            &technique,
            "test task",
            6,
            true,
            0.9,
        );

        assert!(coordinator.should_promote(cluster_id, &technique));
    }

    #[test]
    fn test_not_enough_uses() {
        let (_cache, mut coordinator) = coordinator_with_cache();
        let cluster_id = "test_cluster";
        let technique = PromptingTechnique::ChainOfThought;

        // Three uses is short of the five-use minimum.
        record_repeated(
            &mut coordinator,
            cluster_id,
            &technique,
            "test task",
            3,
            true,
            0.9,
        );

        assert!(!coordinator.should_promote(cluster_id, &technique));
    }

    #[test]
    fn test_reliability_too_low() {
        let (_cache, mut coordinator) = coordinator_with_cache();
        let cluster_id = "test_cluster";
        let technique = PromptingTechnique::ChainOfThought;

        // Three successes then three failures: 50% reliability, under the 80% bar.
        record_repeated(
            &mut coordinator,
            cluster_id,
            &technique,
            "test task",
            3,
            true,
            0.9,
        );
        record_repeated(
            &mut coordinator,
            cluster_id,
            &technique,
            "test task",
            3,
            false,
            0.5,
        );

        assert!(!coordinator.should_promote(cluster_id, &technique));
    }

    #[test]
    fn test_cluster_summary() {
        let (_cache, mut coordinator) = coordinator_with_cache();
        let cluster_id = "test_cluster";

        // One outcome each for two different techniques in the same cluster.
        coordinator.record_outcome(
            cluster_id.to_string(),
            vec![PromptingTechnique::ChainOfThought],
            "task 1".to_string(),
            true,
            5,
            0.9,
        );
        coordinator.record_outcome(
            cluster_id.to_string(),
            vec![PromptingTechnique::PlanAndSolve],
            "task 2".to_string(),
            true,
            8,
            0.85,
        );

        let summary = coordinator.get_cluster_summary(cluster_id);
        assert_eq!(summary.cluster_id, cluster_id);
        assert_eq!(summary.total_executions, 2);
        assert_eq!(summary.techniques.len(), 2);
    }

    #[tokio::test]
    async fn test_promotion_to_bks() {
        let (bks_cache, mut coordinator) = coordinator_with_cache();
        let cluster_id = "numerical_reasoning";
        let technique = PromptingTechnique::ChainOfThought;

        // Six successful uses make the technique eligible.
        record_repeated(
            &mut coordinator,
            cluster_id,
            &technique,
            "calculate primes",
            6,
            true,
            0.9,
        );

        let promoted = coordinator
            .promote_technique_to_bks(cluster_id, &technique)
            .await
            .unwrap();
        assert!(promoted);

        // Inspect the cache through the shared handle.
        let bks = bks_cache.lock().await;
        let _truths = bks.all_truths().collect::<Vec<_>>();

        // The promotion must have queued at least one submission.
        let pending = bks.pending_submissions();
        assert!(!pending.is_empty());

        let truth = &pending[0].truth;
        assert_eq!(truth.category, TruthCategory::PromptingTechnique);
        assert_eq!(truth.context_pattern, cluster_id);
        assert!(truth.rule.contains("ChainOfThought"));
    }
}