use super::techniques::PromptingTechnique;
#[cfg(feature = "knowledge")]
use crate::knowledge::bks_pks::{
BehavioralKnowledgeCache, BehavioralTruth, TruthCategory, TruthSource,
};
use anyhow::Result;
use chrono::Utc;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
/// Default minimum reliability (success ratio) a technique must reach to be promoted.
///
/// `PromptingLearningCoordinator::new` casts this to `f32` and stores it as its
/// promotion threshold.
const DEFAULT_PROMOTION_THRESHOLD: f64 = 0.8;
/// One recorded outcome of applying a single prompting technique to a task.
///
/// `PromptingLearningCoordinator::record_outcome` creates one record per technique
/// it is given, so a task that used several techniques yields several records.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TechniqueEffectivenessRecord {
/// Technique this observation is about.
pub technique: PromptingTechnique,
/// Caller-supplied identifier of the task cluster the task belongs to.
pub cluster_id: String,
/// Caller-supplied description of the task that was executed.
pub task_description: String,
/// Whether the execution was reported as successful.
pub success: bool,
/// Number of iterations the execution was reported to have used.
pub iterations_used: u32,
/// Caller-reported quality score; no range is enforced in this file.
pub quality_score: f32,
/// Value of `Utc::now().timestamp()` taken when the outcome was recorded.
pub timestamp: i64,
}
/// Running aggregate of how one technique has performed.
///
/// The coordinator keeps one instance per `(cluster_id, technique)` pair.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TechniqueStats {
/// Number of updates reported as successful.
pub success_count: u32,
/// Number of updates reported as failed.
pub failure_count: u32,
/// Exponentially weighted moving average of the reported iteration counts.
pub avg_iterations: f32,
/// Exponentially weighted moving average of the reported quality scores.
pub avg_quality: f32,
/// `Utc::now().timestamp()` of construction or of the most recent update.
pub last_used: i64,
}
impl TechniqueStats {
    /// Creates empty statistics: zero counts, zero averages, and `last_used`
    /// set to the current UTC timestamp.
    pub fn new() -> Self {
        Self {
            success_count: 0,
            failure_count: 0,
            avg_iterations: 0.0,
            avg_quality: 0.0,
            last_used: Utc::now().timestamp(),
        }
    }

    /// Fraction of recorded uses that succeeded, in `0.0..=1.0`.
    ///
    /// Returns `0.0` when nothing has been recorded yet, so an unused technique
    /// never looks reliable.
    pub fn reliability(&self) -> f32 {
        match self.total_uses() {
            0 => 0.0,
            total => self.success_count as f32 / total as f32,
        }
    }

    /// Total number of recorded uses (successes plus failures).
    pub fn total_uses(&self) -> u32 {
        self.success_count + self.failure_count
    }

    /// Folds one execution outcome into the statistics.
    ///
    /// * `success` - increments `success_count` when true, `failure_count` otherwise.
    /// * `iterations` - iteration count of this execution, blended into `avg_iterations`.
    /// * `quality` - quality score of this execution, blended into `avg_quality`.
    ///
    /// Both averages are exponential moving averages with a smoothing factor of
    /// 0.3. The very first sample seeds the averages directly: blending it with
    /// the zero placeholder from [`TechniqueStats::new`] would scale it by 0.3
    /// and leave the averages biased toward zero for techniques with few uses.
    /// `last_used` is refreshed to the current UTC timestamp.
    pub fn update(&mut self, success: bool, iterations: u32, quality: f32) {
        const EMA_ALPHA: f32 = 0.3;

        // Must be checked before the counters change below.
        let is_first_sample = self.total_uses() == 0;

        if success {
            self.success_count += 1;
        } else {
            self.failure_count += 1;
        }

        let iterations = iterations as f32;
        if is_first_sample {
            self.avg_iterations = iterations;
            self.avg_quality = quality;
        } else {
            self.avg_iterations = EMA_ALPHA * iterations + (1.0 - EMA_ALPHA) * self.avg_iterations;
            self.avg_quality = EMA_ALPHA * quality + (1.0 - EMA_ALPHA) * self.avg_quality;
        }

        self.last_used = Utc::now().timestamp();
    }
}
/// `Default` delegates to [`TechniqueStats::new`]; this is what lets the
/// coordinator create entries with `or_default()`.
impl Default for TechniqueStats {
fn default() -> Self {
Self::new()
}
}
/// Collects technique outcomes per task cluster and queues well-performing
/// techniques for submission to the shared behavioral knowledge cache.
pub struct PromptingLearningCoordinator {
/// Every recorded outcome in insertion order (oldest first).
records: Vec<TechniqueEffectivenessRecord>,
/// Shared cache that promoted techniques are queued into.
bks_cache: Arc<Mutex<BehavioralKnowledgeCache>>,
/// Aggregated statistics keyed by `(cluster_id, technique)`.
technique_stats: HashMap<(String, PromptingTechnique), TechniqueStats>,
/// Minimum reliability (success ratio) required for promotion.
promotion_threshold: f32,
/// Minimum number of recorded uses required for promotion.
min_uses_for_promotion: u32,
}
impl PromptingLearningCoordinator {
pub fn new(bks_cache: Arc<Mutex<BehavioralKnowledgeCache>>) -> Self {
Self {
records: Vec::new(),
bks_cache,
technique_stats: HashMap::new(),
promotion_threshold: DEFAULT_PROMOTION_THRESHOLD as f32,
min_uses_for_promotion: 5,
}
}
pub fn with_thresholds(
bks_cache: Arc<Mutex<BehavioralKnowledgeCache>>,
promotion_threshold: f32,
min_uses: u32,
) -> Self {
Self {
records: Vec::new(),
bks_cache,
technique_stats: HashMap::new(),
promotion_threshold,
min_uses_for_promotion: min_uses,
}
}
pub fn record_outcome(
&mut self,
cluster_id: String,
techniques: Vec<PromptingTechnique>,
task_description: String,
success: bool,
iterations: u32,
quality_score: f32,
) {
let timestamp = Utc::now().timestamp();
for technique in techniques {
let record = TechniqueEffectivenessRecord {
technique: technique.clone(),
cluster_id: cluster_id.clone(),
task_description: task_description.clone(),
success,
iterations_used: iterations,
quality_score,
timestamp,
};
self.records.push(record);
self.update_stats(&cluster_id, &technique, success, iterations, quality_score);
}
}
fn update_stats(
&mut self,
cluster_id: &str,
technique: &PromptingTechnique,
success: bool,
iterations: u32,
quality: f32,
) {
let key = (cluster_id.to_string(), technique.clone());
let stats = self.technique_stats.entry(key).or_default();
stats.update(success, iterations, quality);
}
pub fn should_promote(&self, cluster_id: &str, technique: &PromptingTechnique) -> bool {
if let Some(stats) = self
.technique_stats
.get(&(cluster_id.to_string(), technique.clone()))
{
let reliability = stats.reliability();
let uses = stats.total_uses();
reliability >= self.promotion_threshold && uses >= self.min_uses_for_promotion
} else {
false
}
}
pub async fn promote_technique_to_bks(
&mut self,
cluster_id: &str,
technique: &PromptingTechnique,
) -> Result<bool> {
if !self.should_promote(cluster_id, technique) {
return Ok(false);
}
let stats = self
.technique_stats
.get(&(cluster_id.to_string(), technique.clone()))
.expect("should_promote verified this entry exists");
let reliability = stats.reliability();
let uses = stats.total_uses();
let truth = BehavioralTruth::new(
TruthCategory::PromptingTechnique,
cluster_id.to_string(), format!(
"Use {:?} for {} tasks (achieves {:.1}% success rate)",
technique,
cluster_id,
reliability * 100.0
), format!(
"Learned from {} executions with avg quality {:.2}. \
Average iterations: {:.1}. \
This technique has proven effective for this task cluster.",
uses, stats.avg_quality, stats.avg_iterations
), TruthSource::SuccessPattern,
None, );
let mut bks = self.bks_cache.lock().await;
let queued = bks.queue_submission(truth)?;
if queued {
tracing::debug!(
?technique,
%cluster_id,
reliability_pct = reliability * 100.0,
uses,
"Adaptive Prompting: Promoted technique for cluster"
);
}
Ok(queued)
}
pub async fn check_and_promote_all(&mut self) -> Result<Vec<(String, PromptingTechnique)>> {
let mut promoted = Vec::new();
let eligible: Vec<_> = self
.technique_stats
.keys()
.filter(|(cluster_id, technique)| self.should_promote(cluster_id, technique))
.cloned()
.collect();
for (cluster_id, technique) in eligible {
if self
.promote_technique_to_bks(&cluster_id, &technique)
.await?
{
promoted.push((cluster_id, technique));
}
}
Ok(promoted)
}
pub fn get_stats(
&self,
cluster_id: &str,
technique: &PromptingTechnique,
) -> Option<&TechniqueStats> {
self.technique_stats
.get(&(cluster_id.to_string(), technique.clone()))
}
pub fn get_all_stats(&self) -> &HashMap<(String, PromptingTechnique), TechniqueStats> {
&self.technique_stats
}
pub fn get_recent_records(&self, count: usize) -> Vec<&TechniqueEffectivenessRecord> {
self.records.iter().rev().take(count).collect()
}
pub fn get_cluster_summary(&self, cluster_id: &str) -> ClusterSummary {
let mut summary = ClusterSummary {
cluster_id: cluster_id.to_string(),
total_executions: 0,
techniques: HashMap::new(),
};
for ((cid, technique), stats) in &self.technique_stats {
if cid == cluster_id {
summary.total_executions += stats.total_uses();
summary.techniques.insert(technique.clone(), stats.clone());
}
}
summary
}
pub fn prune_old_records(&mut self, keep_count: usize) {
if self.records.len() > keep_count {
let excess = self.records.len() - keep_count;
self.records.drain(0..excess);
}
}
pub fn get_thresholds(&self) -> (f32, u32) {
(self.promotion_threshold, self.min_uses_for_promotion)
}
}
/// Snapshot of all technique statistics tracked for one task cluster.
#[derive(Debug, Clone)]
pub struct ClusterSummary {
/// Identifier of the cluster this summary describes.
pub cluster_id: String,
/// Sum of the use counts of all techniques in `techniques`.
pub total_executions: u32,
/// Copy of the statistics for each technique recorded in this cluster.
pub techniques: HashMap<PromptingTechnique, TechniqueStats>,
}
impl ClusterSummary {
    /// Picks the best-performing technique among those with at least 3 uses.
    ///
    /// Candidates are ranked by reliability, with average quality as the
    /// tie-breaker; incomparable values (NaN) are treated as equal. Returns
    /// `None` when no technique has enough uses.
    pub fn best_technique(&self) -> Option<(&PromptingTechnique, &TechniqueStats)> {
        let rank = |lhs: &TechniqueStats, rhs: &TechniqueStats| -> std::cmp::Ordering {
            let by_reliability = lhs
                .reliability()
                .partial_cmp(&rhs.reliability())
                .unwrap_or(std::cmp::Ordering::Equal);
            by_reliability.then_with(|| {
                lhs.avg_quality
                    .partial_cmp(&rhs.avg_quality)
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
        };

        self.techniques
            .iter()
            .filter(|entry| entry.1.total_uses() >= 3)
            .max_by(|left, right| rank(left.1, right.1))
    }

    /// Lists the techniques whose reliability is at least `threshold` and whose
    /// use count is at least `min_uses`. The order follows the map's iteration
    /// order and is therefore unspecified.
    pub fn promotable_techniques(&self, threshold: f32, min_uses: u32) -> Vec<&PromptingTechnique> {
        self.techniques
            .iter()
            .filter_map(|(technique, stats)| {
                let qualifies =
                    stats.reliability() >= threshold && stats.total_uses() >= min_uses;
                if qualifies {
                    Some(technique)
                } else {
                    None
                }
            })
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates the in-memory knowledge cache shared by the coordinator tests.
    fn new_cache() -> Arc<Mutex<BehavioralKnowledgeCache>> {
        Arc::new(Mutex::new(
            BehavioralKnowledgeCache::in_memory(100).unwrap(),
        ))
    }

    /// Records `runs` identical single-technique outcomes with 5 iterations each.
    fn record_runs(
        coordinator: &mut PromptingLearningCoordinator,
        cluster_id: &str,
        technique: &PromptingTechnique,
        task: &str,
        runs: usize,
        success: bool,
        quality: f32,
    ) {
        for _ in 0..runs {
            coordinator.record_outcome(
                cluster_id.to_string(),
                vec![technique.clone()],
                task.to_string(),
                success,
                5,
                quality,
            );
        }
    }

    #[test]
    fn test_technique_stats_update() {
        let mut stats = TechniqueStats::new();
        (0..5).for_each(|_| stats.update(true, 10, 0.9));

        assert_eq!(stats.success_count, 5);
        assert_eq!(stats.failure_count, 0);
        assert_eq!(stats.reliability(), 1.0);
        assert_eq!(stats.total_uses(), 5);
        assert!(stats.avg_quality > 0.7);
    }

    #[test]
    fn test_should_promote() {
        let mut coordinator = PromptingLearningCoordinator::new(new_cache());
        let cluster_id = "test_cluster";
        let technique = PromptingTechnique::ChainOfThought;

        record_runs(&mut coordinator, cluster_id, &technique, "test task", 6, true, 0.9);

        assert!(coordinator.should_promote(cluster_id, &technique));
    }

    #[test]
    fn test_not_enough_uses() {
        let mut coordinator = PromptingLearningCoordinator::new(new_cache());
        let cluster_id = "test_cluster";
        let technique = PromptingTechnique::ChainOfThought;

        record_runs(&mut coordinator, cluster_id, &technique, "test task", 3, true, 0.9);

        assert!(!coordinator.should_promote(cluster_id, &technique));
    }

    #[test]
    fn test_reliability_too_low() {
        let mut coordinator = PromptingLearningCoordinator::new(new_cache());
        let cluster_id = "test_cluster";
        let technique = PromptingTechnique::ChainOfThought;

        // Three successes followed by three failures.
        record_runs(&mut coordinator, cluster_id, &technique, "test task", 3, true, 0.9);
        record_runs(&mut coordinator, cluster_id, &technique, "test task", 3, false, 0.5);

        assert!(!coordinator.should_promote(cluster_id, &technique));
    }

    #[test]
    fn test_cluster_summary() {
        let mut coordinator = PromptingLearningCoordinator::new(new_cache());
        let cluster_id = "test_cluster";

        coordinator.record_outcome(
            cluster_id.to_string(),
            vec![PromptingTechnique::ChainOfThought],
            "task 1".to_string(),
            true,
            5,
            0.9,
        );
        coordinator.record_outcome(
            cluster_id.to_string(),
            vec![PromptingTechnique::PlanAndSolve],
            "task 2".to_string(),
            true,
            8,
            0.85,
        );

        let summary = coordinator.get_cluster_summary(cluster_id);
        assert_eq!(summary.cluster_id, cluster_id);
        assert_eq!(summary.total_executions, 2);
        assert_eq!(summary.techniques.len(), 2);
    }

    #[tokio::test]
    async fn test_promotion_to_bks() {
        let bks_cache = new_cache();
        let mut coordinator = PromptingLearningCoordinator::new(bks_cache.clone());
        let cluster_id = "numerical_reasoning";
        let technique = PromptingTechnique::ChainOfThought;

        record_runs(
            &mut coordinator,
            cluster_id,
            &technique,
            "calculate primes",
            6,
            true,
            0.9,
        );

        let promoted = coordinator
            .promote_technique_to_bks(cluster_id, &technique)
            .await
            .unwrap();
        assert!(promoted);

        let bks = bks_cache.lock().await;
        let _truths = bks.all_truths().collect::<Vec<_>>();
        let pending = bks.pending_submissions();
        assert!(!pending.is_empty());

        let truth = &pending[0].truth;
        assert_eq!(truth.category, TruthCategory::PromptingTechnique);
        assert_eq!(truth.context_pattern, cluster_id);
        assert!(truth.rule.contains("ChainOfThought"));
    }
}