1use super::techniques::PromptingTechnique;
7#[cfg(feature = "knowledge")]
8use crate::knowledge::bks_pks::{
9 BehavioralKnowledgeCache, BehavioralTruth, TruthCategory, TruthSource,
10};
11use anyhow::Result;
12use chrono::Utc;
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::sync::Arc;
16use tokio::sync::Mutex;
17
/// Default minimum reliability (fraction of recorded uses that succeeded) a
/// technique must reach before it is eligible for promotion.
///
/// Declared as `f64` but cast to `f32` when a coordinator is built with
/// `PromptingLearningCoordinator::new`.
const DEFAULT_PROMOTION_THRESHOLD: f64 = 0.8;
20
/// One observed outcome of applying a single prompting technique to a task.
///
/// `PromptingLearningCoordinator::record_outcome` creates one record per
/// technique used in an execution, so an execution that used several
/// techniques produces several records sharing the same outcome fields and
/// timestamp.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TechniqueEffectivenessRecord {
    /// The technique this record describes.
    pub technique: PromptingTechnique,
    /// Identifier of the task cluster the execution belonged to.
    pub cluster_id: String,
    /// Free-form task description, as supplied by the caller.
    pub task_description: String,
    /// Whether the caller reported the execution as successful.
    pub success: bool,
    /// Number of iterations the caller reported for the execution.
    pub iterations_used: u32,
    /// Caller-reported quality score. The range is not enforced here
    /// (presumably 0.0..=1.0 — TODO confirm with callers).
    pub quality_score: f32,
    /// Unix timestamp in seconds (`Utc::now().timestamp()`) taken when the
    /// outcome was recorded.
    pub timestamp: i64,
}
39
/// Aggregated effectiveness statistics for a prompting technique.
///
/// The coordinator keeps one instance per `(cluster_id, technique)` pair.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TechniqueStats {
    /// Number of recorded successful uses.
    pub success_count: u32,
    /// Number of recorded failed uses.
    pub failure_count: u32,
    /// Exponential moving average of iterations per use, maintained by `update`.
    pub avg_iterations: f32,
    /// Exponential moving average of the reported quality score, maintained by `update`.
    pub avg_quality: f32,
    /// Unix timestamp (seconds) of the most recent update, or of creation if
    /// the stats have never been updated.
    pub last_used: i64,
}
54
55impl TechniqueStats {
56 pub fn new() -> Self {
58 Self {
59 success_count: 0,
60 failure_count: 0,
61 avg_iterations: 0.0,
62 avg_quality: 0.0,
63 last_used: Utc::now().timestamp(),
64 }
65 }
66
67 pub fn reliability(&self) -> f32 {
69 let total = self.success_count + self.failure_count;
70 if total == 0 {
71 0.0
72 } else {
73 self.success_count as f32 / total as f32
74 }
75 }
76
77 pub fn total_uses(&self) -> u32 {
79 self.success_count + self.failure_count
80 }
81
82 pub fn update(&mut self, success: bool, iterations: u32, quality: f32) {
84 if success {
85 self.success_count += 1;
86 } else {
87 self.failure_count += 1;
88 }
89
90 let alpha = 0.3; self.avg_iterations = alpha * iterations as f32 + (1.0 - alpha) * self.avg_iterations;
92 self.avg_quality = alpha * quality + (1.0 - alpha) * self.avg_quality;
93 self.last_used = Utc::now().timestamp();
94 }
95}
96
97impl Default for TechniqueStats {
98 fn default() -> Self {
99 Self::new()
100 }
101}
102
/// Tracks how prompting techniques perform per task cluster and promotes
/// consistently successful ones to the shared `BehavioralKnowledgeCache`.
pub struct PromptingLearningCoordinator {
    /// Raw per-technique outcome log in insertion order (oldest first);
    /// trimmed only by `prune_old_records`.
    records: Vec<TechniqueEffectivenessRecord>,

    /// Shared cache that promoted techniques are submitted to.
    bks_cache: Arc<Mutex<BehavioralKnowledgeCache>>,

    /// Aggregated statistics keyed by `(cluster_id, technique)`.
    technique_stats: HashMap<(String, PromptingTechnique), TechniqueStats>,

    /// Minimum reliability (success fraction) required for promotion; compared with `>=`.
    promotion_threshold: f32,

    /// Minimum number of recorded uses required for promotion; compared with `>=`.
    min_uses_for_promotion: u32,
}
120
121impl PromptingLearningCoordinator {
122 pub fn new(bks_cache: Arc<Mutex<BehavioralKnowledgeCache>>) -> Self {
124 Self {
125 records: Vec::new(),
126 bks_cache,
127 technique_stats: HashMap::new(),
128 promotion_threshold: DEFAULT_PROMOTION_THRESHOLD as f32,
129 min_uses_for_promotion: 5,
130 }
131 }
132
133 pub fn with_thresholds(
135 bks_cache: Arc<Mutex<BehavioralKnowledgeCache>>,
136 promotion_threshold: f32,
137 min_uses: u32,
138 ) -> Self {
139 Self {
140 records: Vec::new(),
141 bks_cache,
142 technique_stats: HashMap::new(),
143 promotion_threshold,
144 min_uses_for_promotion: min_uses,
145 }
146 }
147
148 pub fn record_outcome(
160 &mut self,
161 cluster_id: String,
162 techniques: Vec<PromptingTechnique>,
163 task_description: String,
164 success: bool,
165 iterations: u32,
166 quality_score: f32,
167 ) {
168 let timestamp = Utc::now().timestamp();
169
170 for technique in techniques {
171 let record = TechniqueEffectivenessRecord {
173 technique: technique.clone(),
174 cluster_id: cluster_id.clone(),
175 task_description: task_description.clone(),
176 success,
177 iterations_used: iterations,
178 quality_score,
179 timestamp,
180 };
181
182 self.records.push(record);
183
184 self.update_stats(&cluster_id, &technique, success, iterations, quality_score);
186 }
187 }
188
189 fn update_stats(
191 &mut self,
192 cluster_id: &str,
193 technique: &PromptingTechnique,
194 success: bool,
195 iterations: u32,
196 quality: f32,
197 ) {
198 let key = (cluster_id.to_string(), technique.clone());
199 let stats = self.technique_stats.entry(key).or_default();
200 stats.update(success, iterations, quality);
201 }
202
203 pub fn should_promote(&self, cluster_id: &str, technique: &PromptingTechnique) -> bool {
212 if let Some(stats) = self
213 .technique_stats
214 .get(&(cluster_id.to_string(), technique.clone()))
215 {
216 let reliability = stats.reliability();
217 let uses = stats.total_uses();
218
219 reliability >= self.promotion_threshold && uses >= self.min_uses_for_promotion
220 } else {
221 false
222 }
223 }
224
225 pub async fn promote_technique_to_bks(
230 &mut self,
231 cluster_id: &str,
232 technique: &PromptingTechnique,
233 ) -> Result<bool> {
234 if !self.should_promote(cluster_id, technique) {
235 return Ok(false);
236 }
237
238 let stats = self
239 .technique_stats
240 .get(&(cluster_id.to_string(), technique.clone()))
241 .expect("should_promote verified this entry exists");
242
243 let reliability = stats.reliability();
244 let uses = stats.total_uses();
245
246 let truth = BehavioralTruth::new(
248 TruthCategory::PromptingTechnique,
249 cluster_id.to_string(), format!(
251 "Use {:?} for {} tasks (achieves {:.1}% success rate)",
252 technique,
253 cluster_id,
254 reliability * 100.0
255 ), format!(
257 "Learned from {} executions with avg quality {:.2}. \
258 Average iterations: {:.1}. \
259 This technique has proven effective for this task cluster.",
260 uses, stats.avg_quality, stats.avg_iterations
261 ), TruthSource::SuccessPattern,
263 None, );
265
266 let mut bks = self.bks_cache.lock().await;
268 let queued = bks.queue_submission(truth)?;
269
270 if queued {
271 tracing::debug!(
272 ?technique,
273 %cluster_id,
274 reliability_pct = reliability * 100.0,
275 uses,
276 "Adaptive Prompting: Promoted technique for cluster"
277 );
278 }
279
280 Ok(queued)
281 }
282
283 pub async fn check_and_promote_all(&mut self) -> Result<Vec<(String, PromptingTechnique)>> {
288 let mut promoted = Vec::new();
289
290 let eligible: Vec<_> = self
292 .technique_stats
293 .keys()
294 .filter(|(cluster_id, technique)| self.should_promote(cluster_id, technique))
295 .cloned()
296 .collect();
297
298 for (cluster_id, technique) in eligible {
300 if self
301 .promote_technique_to_bks(&cluster_id, &technique)
302 .await?
303 {
304 promoted.push((cluster_id, technique));
305 }
306 }
307
308 Ok(promoted)
309 }
310
311 pub fn get_stats(
313 &self,
314 cluster_id: &str,
315 technique: &PromptingTechnique,
316 ) -> Option<&TechniqueStats> {
317 self.technique_stats
318 .get(&(cluster_id.to_string(), technique.clone()))
319 }
320
321 pub fn get_all_stats(&self) -> &HashMap<(String, PromptingTechnique), TechniqueStats> {
323 &self.technique_stats
324 }
325
326 pub fn get_recent_records(&self, count: usize) -> Vec<&TechniqueEffectivenessRecord> {
328 self.records.iter().rev().take(count).collect()
329 }
330
331 pub fn get_cluster_summary(&self, cluster_id: &str) -> ClusterSummary {
333 let mut summary = ClusterSummary {
334 cluster_id: cluster_id.to_string(),
335 total_executions: 0,
336 techniques: HashMap::new(),
337 };
338
339 for ((cid, technique), stats) in &self.technique_stats {
340 if cid == cluster_id {
341 summary.total_executions += stats.total_uses();
342 summary.techniques.insert(technique.clone(), stats.clone());
343 }
344 }
345
346 summary
347 }
348
349 pub fn prune_old_records(&mut self, keep_count: usize) {
351 if self.records.len() > keep_count {
352 let excess = self.records.len() - keep_count;
353 self.records.drain(0..excess);
354 }
355 }
356
357 pub fn get_thresholds(&self) -> (f32, u32) {
359 (self.promotion_threshold, self.min_uses_for_promotion)
360 }
361}
362
/// Snapshot of the technique statistics recorded for a single task cluster,
/// as built by `PromptingLearningCoordinator::get_cluster_summary`.
#[derive(Debug, Clone)]
pub struct ClusterSummary {
    /// The cluster this summary describes.
    pub cluster_id: String,

    /// Sum of the use counts of every technique in the cluster. An execution
    /// recorded with several techniques contributes once per technique.
    pub total_executions: u32,

    /// Cloned statistics for each technique recorded in the cluster.
    pub techniques: HashMap<PromptingTechnique, TechniqueStats>,
}
373
374impl ClusterSummary {
375 pub fn best_technique(&self) -> Option<(&PromptingTechnique, &TechniqueStats)> {
377 self.techniques
378 .iter()
379 .filter(|(_, stats)| stats.total_uses() >= 3) .max_by(|(_, a), (_, b)| {
381 a.reliability()
383 .partial_cmp(&b.reliability())
384 .unwrap_or(std::cmp::Ordering::Equal)
385 .then(
386 a.avg_quality
387 .partial_cmp(&b.avg_quality)
388 .unwrap_or(std::cmp::Ordering::Equal),
389 )
390 })
391 }
392
393 pub fn promotable_techniques(&self, threshold: f32, min_uses: u32) -> Vec<&PromptingTechnique> {
395 self.techniques
396 .iter()
397 .filter(|(_, stats)| stats.reliability() >= threshold && stats.total_uses() >= min_uses)
398 .map(|(technique, _)| technique)
399 .collect()
400 }
401}
402
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fresh in-memory behavioral knowledge cache for a test.
    fn new_cache() -> Arc<Mutex<BehavioralKnowledgeCache>> {
        Arc::new(Mutex::new(
            BehavioralKnowledgeCache::in_memory(100).unwrap(),
        ))
    }

    /// Records `count` identical single-technique outcomes (5 iterations each).
    fn record_n(
        coordinator: &mut PromptingLearningCoordinator,
        cluster_id: &str,
        technique: &PromptingTechnique,
        task: &str,
        count: usize,
        success: bool,
        quality: f32,
    ) {
        for _ in 0..count {
            coordinator.record_outcome(
                cluster_id.to_string(),
                vec![technique.clone()],
                task.to_string(),
                success,
                5,
                quality,
            );
        }
    }

    #[test]
    fn test_technique_stats_update() {
        let mut stats = TechniqueStats::new();

        for _ in 0..5 {
            stats.update(true, 10, 0.9);
        }

        assert_eq!(stats.success_count, 5);
        assert_eq!(stats.failure_count, 0);
        assert_eq!(stats.reliability(), 1.0);
        assert_eq!(stats.total_uses(), 5);
        assert!(stats.avg_quality > 0.7);
    }

    #[test]
    fn test_new_stats_have_zero_reliability() {
        let stats = TechniqueStats::new();
        assert_eq!(stats.total_uses(), 0);
        assert_eq!(stats.reliability(), 0.0);
    }

    #[test]
    fn test_should_promote() {
        let mut coordinator = PromptingLearningCoordinator::new(new_cache());
        let technique = PromptingTechnique::ChainOfThought;

        record_n(&mut coordinator, "test_cluster", &technique, "test task", 6, true, 0.9);

        assert!(coordinator.should_promote("test_cluster", &technique));
    }

    #[test]
    fn test_not_enough_uses() {
        let mut coordinator = PromptingLearningCoordinator::new(new_cache());
        let technique = PromptingTechnique::ChainOfThought;

        record_n(&mut coordinator, "test_cluster", &technique, "test task", 3, true, 0.9);

        assert!(!coordinator.should_promote("test_cluster", &technique));
    }

    #[test]
    fn test_reliability_too_low() {
        let mut coordinator = PromptingLearningCoordinator::new(new_cache());
        let technique = PromptingTechnique::ChainOfThought;

        record_n(&mut coordinator, "test_cluster", &technique, "test task", 3, true, 0.9);
        record_n(&mut coordinator, "test_cluster", &technique, "test task", 3, false, 0.5);

        assert!(!coordinator.should_promote("test_cluster", &technique));
    }

    #[test]
    fn test_unknown_pair_is_not_promoted() {
        let coordinator = PromptingLearningCoordinator::new(new_cache());
        let technique = PromptingTechnique::ChainOfThought;

        assert!(coordinator.get_stats("missing", &technique).is_none());
        assert!(!coordinator.should_promote("missing", &technique));
    }

    #[test]
    fn test_custom_thresholds() {
        let mut coordinator =
            PromptingLearningCoordinator::with_thresholds(new_cache(), 0.5, 2);
        let technique = PromptingTechnique::ChainOfThought;

        assert_eq!(coordinator.get_thresholds(), (0.5, 2));

        record_n(&mut coordinator, "test_cluster", &technique, "test task", 1, true, 0.9);
        assert!(!coordinator.should_promote("test_cluster", &technique));

        record_n(&mut coordinator, "test_cluster", &technique, "test task", 1, false, 0.5);
        assert!(coordinator.should_promote("test_cluster", &technique));
    }

    #[test]
    fn test_cluster_summary() {
        let mut coordinator = PromptingLearningCoordinator::new(new_cache());
        let cluster_id = "test_cluster";

        coordinator.record_outcome(
            cluster_id.to_string(),
            vec![PromptingTechnique::ChainOfThought],
            "task 1".to_string(),
            true,
            5,
            0.9,
        );
        coordinator.record_outcome(
            cluster_id.to_string(),
            vec![PromptingTechnique::PlanAndSolve],
            "task 2".to_string(),
            true,
            8,
            0.85,
        );

        let summary = coordinator.get_cluster_summary(cluster_id);
        assert_eq!(summary.cluster_id, cluster_id);
        assert_eq!(summary.total_executions, 2);
        assert_eq!(summary.techniques.len(), 2);
    }

    #[test]
    fn test_recent_records_and_pruning() {
        let mut coordinator = PromptingLearningCoordinator::new(new_cache());
        let technique = PromptingTechnique::ChainOfThought;

        for task in ["task 1", "task 2", "task 3"] {
            record_n(&mut coordinator, "test_cluster", &technique, task, 1, true, 0.9);
        }

        let recent: Vec<&str> = coordinator
            .get_recent_records(2)
            .iter()
            .map(|record| record.task_description.as_str())
            .collect();
        assert_eq!(recent, vec!["task 3", "task 2"]);

        coordinator.prune_old_records(10);
        assert_eq!(coordinator.get_recent_records(10).len(), 3);

        coordinator.prune_old_records(1);
        let remaining = coordinator.get_recent_records(10);
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0].task_description, "task 3");

        // Pruning records leaves the aggregated stats intact.
        assert_eq!(
            coordinator.get_stats("test_cluster", &technique).unwrap().total_uses(),
            3
        );
    }

    #[tokio::test]
    async fn test_promotion_to_bks() {
        let bks_cache = new_cache();
        let mut coordinator = PromptingLearningCoordinator::new(bks_cache.clone());

        let cluster_id = "numerical_reasoning";
        let technique = PromptingTechnique::ChainOfThought;

        record_n(&mut coordinator, cluster_id, &technique, "calculate primes", 6, true, 0.9);

        let promoted = coordinator
            .promote_technique_to_bks(cluster_id, &technique)
            .await
            .unwrap();
        assert!(promoted);

        let bks = bks_cache.lock().await;
        let pending = bks.pending_submissions();
        assert!(!pending.is_empty());

        let truth = &pending[0].truth;
        assert_eq!(truth.category, TruthCategory::PromptingTechnique);
        assert_eq!(truth.context_pattern, cluster_id);
        assert!(truth.rule.contains("ChainOfThought"));
    }

    #[tokio::test]
    async fn test_check_and_promote_all() {
        let mut coordinator = PromptingLearningCoordinator::new(new_cache());
        let cluster_id = "numerical_reasoning";
        let eligible = PromptingTechnique::ChainOfThought;
        let ineligible = PromptingTechnique::PlanAndSolve;

        record_n(&mut coordinator, cluster_id, &eligible, "calculate primes", 6, true, 0.9);
        record_n(&mut coordinator, cluster_id, &ineligible, "calculate primes", 1, true, 0.9);

        let promoted = coordinator.check_and_promote_all().await.unwrap();
        assert_eq!(promoted, vec![(cluster_id.to_string(), eligible)]);
    }
}