1use super::techniques::PromptingTechnique;
7use anyhow::Result;
8use brainwires_knowledge::knowledge::bks_pks::{
9 BehavioralKnowledgeCache, BehavioralTruth, TruthCategory, TruthSource,
10};
11use chrono::Utc;
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::sync::Arc;
15use tokio::sync::Mutex;
16
/// Default minimum success rate a technique must reach before it is eligible
/// for promotion. `PromptingLearningCoordinator::new` casts this to `f32` and
/// stores it as the coordinator's promotion threshold.
17const DEFAULT_PROMOTION_THRESHOLD: f64 = 0.8;
19
/// One recorded outcome of applying a single prompting technique to a task.
///
/// `PromptingLearningCoordinator::record_outcome` appends one of these per
/// technique passed to it.
20#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct TechniqueEffectivenessRecord {
    /// Prompting technique that was applied.
23 pub technique: PromptingTechnique,
    /// Identifier of the task cluster the outcome was recorded under.
25 pub cluster_id: String,
    /// Caller-supplied description of the task.
27 pub task_description: String,
    /// Whether the caller reported the task as successful.
29 pub success: bool,
    /// Iteration count the caller reported for the task.
31 pub iterations_used: u32,
    /// Quality score the caller reported. The range is not enforced here
    /// (presumably 0.0..=1.0 — confirm against callers).
33 pub quality_score: f32,
    /// Unix timestamp in seconds, taken once per `record_outcome` call.
35 pub timestamp: i64,
37}
38
/// Aggregated statistics for one (cluster, technique) pair.
///
/// Counters and averages are maintained by `TechniqueStats::update`.
39#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct TechniqueStats {
    /// Number of updates reported as successful.
42 pub success_count: u32,
    /// Number of updates reported as failed.
44 pub failure_count: u32,
    /// Exponential moving average of the iteration counts passed to `update`.
46 pub avg_iterations: f32,
    /// Exponential moving average of the quality scores passed to `update`.
48 pub avg_quality: f32,
    /// Unix timestamp in seconds of creation or of the most recent `update`.
50 pub last_used: i64,
52}
53
54impl TechniqueStats {
55 pub fn new() -> Self {
57 Self {
58 success_count: 0,
59 failure_count: 0,
60 avg_iterations: 0.0,
61 avg_quality: 0.0,
62 last_used: Utc::now().timestamp(),
63 }
64 }
65
66 pub fn reliability(&self) -> f32 {
68 let total = self.success_count + self.failure_count;
69 if total == 0 {
70 0.0
71 } else {
72 self.success_count as f32 / total as f32
73 }
74 }
75
76 pub fn total_uses(&self) -> u32 {
78 self.success_count + self.failure_count
79 }
80
81 pub fn update(&mut self, success: bool, iterations: u32, quality: f32) {
83 if success {
84 self.success_count += 1;
85 } else {
86 self.failure_count += 1;
87 }
88
89 let alpha = 0.3; self.avg_iterations = alpha * iterations as f32 + (1.0 - alpha) * self.avg_iterations;
91 self.avg_quality = alpha * quality + (1.0 - alpha) * self.avg_quality;
92 self.last_used = Utc::now().timestamp();
93 }
94}
95
96impl Default for TechniqueStats {
97 fn default() -> Self {
98 Self::new()
99 }
100}
101
/// Tracks how prompting techniques perform per task cluster and submits the
/// ones that meet the promotion criteria to the behavioral knowledge cache.
102pub struct PromptingLearningCoordinator {
    /// Log of recorded outcomes in insertion order, one entry per technique
    /// per `record_outcome` call; trimmed from the front by
    /// `prune_old_records`.
104 records: Vec<TechniqueEffectivenessRecord>,
106
    /// Shared, async-locked handle to the cache that receives promoted
    /// techniques via `queue_submission`.
107 bks_cache: Arc<Mutex<BehavioralKnowledgeCache>>,
109
    /// Aggregated statistics keyed by `(cluster_id, technique)`.
110 technique_stats: HashMap<(String, PromptingTechnique), TechniqueStats>,
112
    /// Minimum `TechniqueStats::reliability` required for promotion.
113 promotion_threshold: f32,
115
    /// Minimum `TechniqueStats::total_uses` required for promotion.
116 min_uses_for_promotion: u32,
118}
119
120impl PromptingLearningCoordinator {
121 pub fn new(bks_cache: Arc<Mutex<BehavioralKnowledgeCache>>) -> Self {
123 Self {
124 records: Vec::new(),
125 bks_cache,
126 technique_stats: HashMap::new(),
127 promotion_threshold: DEFAULT_PROMOTION_THRESHOLD as f32,
128 min_uses_for_promotion: 5,
129 }
130 }
131
132 pub fn with_thresholds(
134 bks_cache: Arc<Mutex<BehavioralKnowledgeCache>>,
135 promotion_threshold: f32,
136 min_uses: u32,
137 ) -> Self {
138 Self {
139 records: Vec::new(),
140 bks_cache,
141 technique_stats: HashMap::new(),
142 promotion_threshold,
143 min_uses_for_promotion: min_uses,
144 }
145 }
146
147 pub fn record_outcome(
159 &mut self,
160 cluster_id: String,
161 techniques: Vec<PromptingTechnique>,
162 task_description: String,
163 success: bool,
164 iterations: u32,
165 quality_score: f32,
166 ) {
167 let timestamp = Utc::now().timestamp();
168
169 for technique in techniques {
170 let record = TechniqueEffectivenessRecord {
172 technique: technique.clone(),
173 cluster_id: cluster_id.clone(),
174 task_description: task_description.clone(),
175 success,
176 iterations_used: iterations,
177 quality_score,
178 timestamp,
179 };
180
181 self.records.push(record);
182
183 self.update_stats(&cluster_id, &technique, success, iterations, quality_score);
185 }
186 }
187
188 fn update_stats(
190 &mut self,
191 cluster_id: &str,
192 technique: &PromptingTechnique,
193 success: bool,
194 iterations: u32,
195 quality: f32,
196 ) {
197 let key = (cluster_id.to_string(), technique.clone());
198 let stats = self.technique_stats.entry(key).or_default();
199 stats.update(success, iterations, quality);
200 }
201
202 pub fn should_promote(&self, cluster_id: &str, technique: &PromptingTechnique) -> bool {
211 if let Some(stats) = self
212 .technique_stats
213 .get(&(cluster_id.to_string(), technique.clone()))
214 {
215 let reliability = stats.reliability();
216 let uses = stats.total_uses();
217
218 reliability >= self.promotion_threshold && uses >= self.min_uses_for_promotion
219 } else {
220 false
221 }
222 }
223
224 pub async fn promote_technique_to_bks(
229 &mut self,
230 cluster_id: &str,
231 technique: &PromptingTechnique,
232 ) -> Result<bool> {
233 if !self.should_promote(cluster_id, technique) {
234 return Ok(false);
235 }
236
237 let stats = self
238 .technique_stats
239 .get(&(cluster_id.to_string(), technique.clone()))
240 .expect("should_promote verified this entry exists");
241
242 let reliability = stats.reliability();
243 let uses = stats.total_uses();
244
245 let truth = BehavioralTruth::new(
247 TruthCategory::PromptingTechnique,
248 cluster_id.to_string(), format!(
250 "Use {:?} for {} tasks (achieves {:.1}% success rate)",
251 technique,
252 cluster_id,
253 reliability * 100.0
254 ), format!(
256 "Learned from {} executions with avg quality {:.2}. \
257 Average iterations: {:.1}. \
258 This technique has proven effective for this task cluster.",
259 uses, stats.avg_quality, stats.avg_iterations
260 ), TruthSource::SuccessPattern,
262 None, );
264
265 let mut bks = self.bks_cache.lock().await;
267 let queued = bks.queue_submission(truth)?;
268
269 if queued {
270 tracing::debug!(
271 ?technique,
272 %cluster_id,
273 reliability_pct = reliability * 100.0,
274 uses,
275 "Adaptive Prompting: Promoted technique for cluster"
276 );
277 }
278
279 Ok(queued)
280 }
281
282 pub async fn check_and_promote_all(&mut self) -> Result<Vec<(String, PromptingTechnique)>> {
287 let mut promoted = Vec::new();
288
289 let eligible: Vec<_> = self
291 .technique_stats
292 .keys()
293 .filter(|(cluster_id, technique)| self.should_promote(cluster_id, technique))
294 .cloned()
295 .collect();
296
297 for (cluster_id, technique) in eligible {
299 if self
300 .promote_technique_to_bks(&cluster_id, &technique)
301 .await?
302 {
303 promoted.push((cluster_id, technique));
304 }
305 }
306
307 Ok(promoted)
308 }
309
310 pub fn get_stats(
312 &self,
313 cluster_id: &str,
314 technique: &PromptingTechnique,
315 ) -> Option<&TechniqueStats> {
316 self.technique_stats
317 .get(&(cluster_id.to_string(), technique.clone()))
318 }
319
320 pub fn get_all_stats(&self) -> &HashMap<(String, PromptingTechnique), TechniqueStats> {
322 &self.technique_stats
323 }
324
325 pub fn get_recent_records(&self, count: usize) -> Vec<&TechniqueEffectivenessRecord> {
327 self.records.iter().rev().take(count).collect()
328 }
329
330 pub fn get_cluster_summary(&self, cluster_id: &str) -> ClusterSummary {
332 let mut summary = ClusterSummary {
333 cluster_id: cluster_id.to_string(),
334 total_executions: 0,
335 techniques: HashMap::new(),
336 };
337
338 for ((cid, technique), stats) in &self.technique_stats {
339 if cid == cluster_id {
340 summary.total_executions += stats.total_uses();
341 summary.techniques.insert(technique.clone(), stats.clone());
342 }
343 }
344
345 summary
346 }
347
348 pub fn prune_old_records(&mut self, keep_count: usize) {
350 if self.records.len() > keep_count {
351 let excess = self.records.len() - keep_count;
352 self.records.drain(0..excess);
353 }
354 }
355
356 pub fn get_thresholds(&self) -> (f32, u32) {
358 (self.promotion_threshold, self.min_uses_for_promotion)
359 }
360}
361
/// Snapshot of the statistics recorded for a single task cluster, as built by
/// `PromptingLearningCoordinator::get_cluster_summary`.
362#[derive(Debug, Clone)]
364pub struct ClusterSummary {
    /// Identifier of the summarised cluster.
365 pub cluster_id: String,
    /// Sum of `total_uses()` over all techniques in `techniques`.
367 pub total_executions: u32,
    /// Cloned statistics for each technique recorded in this cluster.
369 pub techniques: HashMap<PromptingTechnique, TechniqueStats>,
371}
372
373impl ClusterSummary {
374 pub fn best_technique(&self) -> Option<(&PromptingTechnique, &TechniqueStats)> {
376 self.techniques
377 .iter()
378 .filter(|(_, stats)| stats.total_uses() >= 3) .max_by(|(_, a), (_, b)| {
380 a.reliability()
382 .partial_cmp(&b.reliability())
383 .unwrap_or(std::cmp::Ordering::Equal)
384 .then(
385 a.avg_quality
386 .partial_cmp(&b.avg_quality)
387 .unwrap_or(std::cmp::Ordering::Equal),
388 )
389 })
390 }
391
392 pub fn promotable_techniques(&self, threshold: f32, min_uses: u32) -> Vec<&PromptingTechnique> {
394 self.techniques
395 .iter()
396 .filter(|(_, stats)| stats.reliability() >= threshold && stats.total_uses() >= min_uses)
397 .map(|(technique, _)| technique)
398 .collect()
399 }
400}
401
#[cfg(test)]
mod tests {
    use super::*;

    /// Cluster id shared by the synchronous coordinator tests.
    const CLUSTER: &str = "test_cluster";

    /// Creates a fresh in-memory cache wrapped for sharing.
    fn new_cache() -> Arc<Mutex<BehavioralKnowledgeCache>> {
        Arc::new(Mutex::new(
            BehavioralKnowledgeCache::in_memory(100).unwrap(),
        ))
    }

    /// Creates a coordinator with default thresholds and its own cache.
    fn new_coordinator() -> PromptingLearningCoordinator {
        PromptingLearningCoordinator::new(new_cache())
    }

    /// Records `runs` identical outcomes (5 iterations each) for `technique`
    /// in [`CLUSTER`].
    fn record_runs(
        coordinator: &mut PromptingLearningCoordinator,
        technique: &PromptingTechnique,
        runs: usize,
        success: bool,
        quality: f32,
    ) {
        for _ in 0..runs {
            coordinator.record_outcome(
                CLUSTER.to_string(),
                vec![technique.clone()],
                "test task".to_string(),
                success,
                5,
                quality,
            );
        }
    }

    #[test]
    fn test_technique_stats_update() {
        let mut stats = TechniqueStats::new();

        // Five successful runs with identical inputs.
        (0..5).for_each(|_| stats.update(true, 10, 0.9));

        assert_eq!(stats.success_count, 5);
        assert_eq!(stats.failure_count, 0);
        assert_eq!(stats.reliability(), 1.0);
        assert_eq!(stats.total_uses(), 5);
        // The moving average must have moved most of the way toward 0.9.
        assert!(stats.avg_quality > 0.7);
    }

    #[test]
    fn test_should_promote() {
        let mut coordinator = new_coordinator();
        let technique = PromptingTechnique::ChainOfThought;

        // Six successes: above both the use count and reliability minimums.
        record_runs(&mut coordinator, &technique, 6, true, 0.9);

        assert!(coordinator.should_promote(CLUSTER, &technique));
    }

    #[test]
    fn test_not_enough_uses() {
        let mut coordinator = new_coordinator();
        let technique = PromptingTechnique::ChainOfThought;

        // Perfect reliability, but only three uses (default minimum is 5).
        record_runs(&mut coordinator, &technique, 3, true, 0.9);

        assert!(!coordinator.should_promote(CLUSTER, &technique));
    }

    #[test]
    fn test_reliability_too_low() {
        let mut coordinator = new_coordinator();
        let technique = PromptingTechnique::ChainOfThought;

        // Enough uses, but only a 50% success rate.
        record_runs(&mut coordinator, &technique, 3, true, 0.9);
        record_runs(&mut coordinator, &technique, 3, false, 0.5);

        assert!(!coordinator.should_promote(CLUSTER, &technique));
    }

    #[test]
    fn test_cluster_summary() {
        let mut coordinator = new_coordinator();

        // One outcome each for two different techniques in the same cluster.
        coordinator.record_outcome(
            CLUSTER.to_string(),
            vec![PromptingTechnique::ChainOfThought],
            "task 1".to_string(),
            true,
            5,
            0.9,
        );
        coordinator.record_outcome(
            CLUSTER.to_string(),
            vec![PromptingTechnique::PlanAndSolve],
            "task 2".to_string(),
            true,
            8,
            0.85,
        );

        let summary = coordinator.get_cluster_summary(CLUSTER);
        assert_eq!(summary.cluster_id, CLUSTER);
        assert_eq!(summary.total_executions, 2);
        assert_eq!(summary.techniques.len(), 2);
    }

    #[tokio::test]
    async fn test_promotion_to_bks() {
        let bks_cache = new_cache();
        let mut coordinator = PromptingLearningCoordinator::new(bks_cache.clone());

        let cluster_id = "numerical_reasoning";
        let technique = PromptingTechnique::ChainOfThought;

        // Six successes make the technique eligible for promotion.
        for _ in 0..6 {
            coordinator.record_outcome(
                cluster_id.to_string(),
                vec![technique.clone()],
                "calculate primes".to_string(),
                true,
                5,
                0.9,
            );
        }

        let promoted = coordinator
            .promote_technique_to_bks(cluster_id, &technique)
            .await
            .unwrap();
        assert!(promoted);

        // Inspect the cache through the shared handle.
        let bks = bks_cache.lock().await;
        let _truths: Vec<_> = bks.all_truths().collect();

        let pending = bks.pending_submissions();
        assert!(!pending.is_empty());

        let truth = &pending[0].truth;
        assert_eq!(truth.category, TruthCategory::PromptingTechnique);
        assert_eq!(truth.context_pattern, cluster_id);
        assert!(truth.rule.contains("ChainOfThought"));
    }
}