use super::clustering::TaskCluster;
#[cfg(feature = "knowledge")]
use crate::knowledge::bks_pks::{
BehavioralKnowledgeCache, BehavioralTruth, TruthCategory, TruthSource,
};
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
/// Running performance statistics for a single (cluster, temperature) bucket.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct TemperaturePerformance {
    /// Smoothed success indicator (successes count as 1.0, failures as 0.0).
    pub success_rate: f32,
    /// Smoothed quality value reported with each outcome.
    pub avg_quality: f32,
    /// Number of outcomes folded into this record.
    pub sample_count: u32,
    /// Unix timestamp (seconds, UTC) of the most recent modification.
    pub last_updated: i64,
}
impl TemperaturePerformance {
    /// Smoothing factor for the exponential moving averages: each new
    /// observation contributes 30% and the previous value keeps 70%.
    const EMA_ALPHA: f32 = 0.3;

    /// Creates a neutral record (50% success rate, 0.5 average quality,
    /// no samples) stamped with the current UTC time.
    pub fn new() -> Self {
        Self {
            // Neutral priors so an unseen temperature is neither favoured nor penalised.
            success_rate: 0.5,
            avg_quality: 0.5,
            sample_count: 0,
            last_updated: chrono::Utc::now().timestamp(),
        }
    }

    /// Folds one observed outcome into the running statistics.
    ///
    /// `success_rate` and `avg_quality` are exponential moving averages using
    /// [`Self::EMA_ALPHA`]. `sample_count` saturates at `u32::MAX` instead of
    /// overflowing: a plain `+= 1` would panic in debug builds. In release
    /// builds it would wrap to 0 and drop the entry back below any
    /// `min_samples` threshold.
    pub fn update(&mut self, success: bool, quality: f32) {
        let alpha = Self::EMA_ALPHA;
        let observed_success = if success { 1.0 } else { 0.0 };
        self.success_rate = alpha * observed_success + (1.0 - alpha) * self.success_rate;
        self.avg_quality = alpha * quality + (1.0 - alpha) * self.avg_quality;
        self.sample_count = self.sample_count.saturating_add(1);
        self.last_updated = chrono::Utc::now().timestamp();
    }

    /// Combined ranking score: 60% success rate plus 40% average quality.
    pub fn score(&self) -> f32 {
        0.6 * self.success_rate + 0.4 * self.avg_quality
    }
}
impl Default for TemperaturePerformance {
fn default() -> Self {
Self::new()
}
}
/// Chooses a sampling temperature per task cluster.
///
/// It combines promoted BKS knowledge, locally recorded outcomes and
/// description-based heuristics.
///
/// NOTE(review): the BKS types used by `bks_cache` are imported under
/// `#[cfg(feature = "knowledge")]`, but this struct uses them unconditionally.
/// Confirm that the whole module is only compiled with that feature enabled.
pub struct TemperatureOptimizer {
    /// Temperatures the optimizer will consider, in preference order for ties.
    candidates: Vec<f32>,
    /// Minimum samples before a local bucket's statistics are trusted.
    min_samples: u32,
    /// Outcome statistics keyed by (cluster id, temperature rounded to tenths).
    performance_map: HashMap<(String, i32), TemperaturePerformance>,
    /// Optional shared behavioural-knowledge cache consulted before local data.
    bks_cache: Option<Arc<Mutex<BehavioralKnowledgeCache>>>,
}
impl TemperatureOptimizer {
    /// Creates an optimizer with no BKS cache attached and a minimum of 5
    /// samples per bucket.
    ///
    /// The candidate grid is `0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.3`.
    pub fn new() -> Self {
        Self {
            performance_map: HashMap::new(),
            bks_cache: None,
            candidates: vec![0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.3],
            min_samples: 5,
        }
    }

    /// Quantizes a temperature to tenths (`0.6 -> 6`, `1.3 -> 13`) so it can be
    /// part of a hash-map key; `f32` itself is neither `Eq` nor `Hash`.
    fn temp_to_key(temp: f32) -> i32 {
        (temp * 10.0).round() as i32
    }

    /// Attaches a shared BKS cache, consulted before local statistics.
    pub fn with_bks(mut self, bks_cache: Arc<Mutex<BehavioralKnowledgeCache>>) -> Self {
        self.bks_cache = Some(bks_cache);
        self
    }

    /// Overrides how many samples a bucket needs before `get_local_optimal`
    /// will consider it.
    pub fn with_min_samples(mut self, min_samples: u32) -> Self {
        self.min_samples = min_samples;
        self
    }

    /// Returns the temperature to use for `cluster`, trying sources in order:
    /// 1. a `TaskStrategy` truth in the BKS cache naming a candidate temperature;
    /// 2. the best-scoring local candidate that has at least `min_samples` samples;
    /// 3. the description-based heuristic in `get_default_temperature`.
    pub async fn get_optimal_temperature(&self, cluster: &TaskCluster) -> f32 {
        if let Some(bks_temp) = self.query_bks_temperature(&cluster.id).await {
            return bks_temp;
        }
        if let Some(local_temp) = self.get_local_optimal(&cluster.id) {
            return local_temp;
        }
        self.get_default_temperature(cluster)
    }

    /// Returns the candidate with the highest `score()` among buckets for
    /// `cluster_id` that have at least `min_samples` samples.
    ///
    /// Returns `None` if no bucket qualifies.
    fn get_local_optimal(&self, cluster_id: &str) -> Option<f32> {
        // Allocate the owned cluster id once; only the temperature part varies.
        let mut key = (cluster_id.to_string(), 0);
        let mut best_temp = None;
        let mut best_score = f32::NEG_INFINITY;
        for &temp in &self.candidates {
            key.1 = Self::temp_to_key(temp);
            if let Some(perf) = self.performance_map.get(&key)
                && perf.sample_count >= self.min_samples
            {
                let score = perf.score();
                // Strict `>` keeps the earliest candidate (in `candidates`
                // order) on ties, and never selects a NaN score.
                if score > best_score {
                    best_score = score;
                    best_temp = Some(temp);
                }
            }
        }
        best_temp
    }

    /// Looks up `TaskStrategy` truths matching `cluster_id` in the BKS cache.
    ///
    /// Returns the first temperature that can be parsed from one of them, or
    /// `None` if no cache is attached or nothing usable is found.
    async fn query_bks_temperature(&self, cluster_id: &str) -> Option<f32> {
        let Some(bks_cache) = &self.bks_cache else {
            return None;
        };
        let bks = bks_cache.lock().await;
        let truths = bks.get_matching_truths(cluster_id);
        for truth in truths {
            if truth.category == TruthCategory::TaskStrategy
                && let Some(temp) = self.parse_temperature_from_truth(truth)
            {
                return Some(temp);
            }
        }
        None
    }

    /// Extracts a candidate temperature from a truth's rule and rationale text.
    ///
    /// The search is case-insensitive and finds the first occurrence of
    /// "temperature". It skips the token containing that keyword, then takes
    /// the first later token that reads as a finite number once surrounding
    /// punctuation is stripped (so "0.2," or "(0.2)." both work). That number
    /// is mapped to a candidate sharing its tenth-rounded key, the same
    /// bucketing `performance_map` uses. If the first number is not a
    /// candidate, the result is `None`: the scan deliberately does not go on.
    /// Otherwise unrelated figures later in the text, such as the
    /// "0.80 avg quality" in a promoted rationale, could be mistaken for the
    /// temperature.
    fn parse_temperature_from_truth(&self, truth: &BehavioralTruth) -> Option<f32> {
        // ASCII lowercasing keeps byte offsets valid for slicing below.
        let text = format!("{} {}", truth.rule, truth.rationale).to_ascii_lowercase();
        let idx = text.find("temperature")?;
        let value = text[idx..]
            .split_whitespace()
            .skip(1)
            .find_map(|token| {
                let token = token
                    .trim_start_matches(|c: char| {
                        c.is_ascii_punctuation() && !matches!(c, '.' | '-' | '+')
                    })
                    .trim_end_matches(|c: char| c.is_ascii_punctuation());
                // Words such as "nan"/"inf" parse as floats; treat them as text.
                token.parse::<f32>().ok().filter(|v| v.is_finite())
            })?;
        let key = Self::temp_to_key(value);
        self.candidates
            .iter()
            .copied()
            .find(|&candidate| Self::temp_to_key(candidate) == key)
    }

    /// Picks a starting temperature from keywords in the cluster description.
    ///
    /// The description is lowercased and rules are checked in order: logic → 0.0,
    /// creative → 1.3, numerical → 0.2, code → 0.6, otherwise 0.7.
    ///
    /// NOTE(review): because the creative rule runs first, a description like
    /// "code generation" matches "generation" and gets 1.3. The 0.7 fallback
    /// is also not in the default candidate grid, so outcomes recorded at 0.7
    /// are never examined by `get_local_optimal`. Confirm both are intended.
    fn get_default_temperature(&self, cluster: &TaskCluster) -> f32 {
        /// True if `text` contains any of `keywords` as a substring.
        fn mentions_any(text: &str, keywords: &[&str]) -> bool {
            keywords.iter().any(|kw| text.contains(*kw))
        }

        let desc = cluster.description.to_lowercase();
        if mentions_any(&desc, &["logic", "boolean", "reasoning", "puzzle", "deduction"]) {
            return 0.0;
        }
        if mentions_any(&desc, &["creative", "linguistic", "story", "writing", "generation"]) {
            return 1.3;
        }
        if mentions_any(&desc, &["numerical", "calculation", "math", "arithmetic"]) {
            return 0.2;
        }
        if mentions_any(&desc, &["code", "programming", "implementation", "algorithm"]) {
            return 0.6;
        }
        0.7
    }

    /// Records one outcome for `cluster_id` at `temperature`.
    ///
    /// The bucket is created on first use; temperatures are bucketed to tenths.
    pub fn record_temperature_outcome(
        &mut self,
        cluster_id: String,
        temperature: f32,
        success: bool,
        quality: f32,
    ) {
        let temp_key = Self::temp_to_key(temperature);
        let key = (cluster_id, temp_key);
        let perf = self.performance_map.entry(key).or_default();
        perf.update(success, quality);
    }

    /// Queues a `TaskStrategy` truth to the BKS cache when a bucket qualifies.
    ///
    /// The bucket for `(cluster_id, temperature)` qualifies when it has at
    /// least `min_samples` samples and a `score()` of at least `min_score`.
    /// The `min_samples` argument here is independent of the optimizer's own
    /// `min_samples` setting. This is a no-op returning `Ok(())` when no BKS
    /// cache is attached or the thresholds are not met.
    ///
    /// NOTE(review): `temperature` is not required to be a candidate. A
    /// promoted non-candidate value will not be accepted by
    /// `query_bks_temperature`.
    ///
    /// # Errors
    /// Propagates any error returned by `queue_submission`.
    pub async fn check_and_promote_temperature(
        &self,
        cluster_id: &str,
        temperature: f32,
        min_score: f32,
        min_samples: u32,
    ) -> Result<()> {
        let temp_key = Self::temp_to_key(temperature);
        let key = (cluster_id.to_string(), temp_key);
        if let Some(perf) = self.performance_map.get(&key)
            && perf.sample_count >= min_samples
            && perf.score() >= min_score
        {
            if let Some(ref bks_cache) = self.bks_cache {
                let truth = BehavioralTruth::new(
                    TruthCategory::TaskStrategy,
                    cluster_id.to_string(),
                    format!(
                        "For {} tasks, use temperature {} for optimal results",
                        cluster_id, temperature
                    ),
                    format!(
                        "Learned from {} executions with {:.1}% success rate and {:.2} avg quality",
                        perf.sample_count,
                        perf.success_rate * 100.0,
                        perf.avg_quality
                    ),
                    TruthSource::SuccessPattern,
                    None,
                );
                let mut bks = bks_cache.lock().await;
                bks.queue_submission(truth)?;
            }
        }
        Ok(())
    }

    /// Returns every recorded bucket, keyed by (cluster id, temperature in tenths).
    pub fn get_all_performance(&self) -> &HashMap<(String, i32), TemperaturePerformance> {
        &self.performance_map
    }

    /// Returns the bucket for `cluster_id` at `temperature` (rounded to tenths), if any.
    pub fn get_performance(
        &self,
        cluster_id: &str,
        temperature: f32,
    ) -> Option<&TemperaturePerformance> {
        let temp_key = Self::temp_to_key(temperature);
        self.performance_map
            .get(&(cluster_id.to_string(), temp_key))
    }
}
impl Default for TemperatureOptimizer {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prompting::techniques::PromptingTechnique;

    /// Cluster id shared by the tests that only exercise local statistics.
    const CLUSTER_ID: &str = "test_cluster";

    /// Builds a cluster with a constant 768-element embedding, a single
    /// prompting technique and an empty final argument.
    fn make_cluster(id: &str, description: &str, technique: PromptingTechnique) -> TaskCluster {
        TaskCluster::new(
            id.to_string(),
            description.to_string(),
            vec![0.5; 768],
            vec![technique],
            vec![],
        )
    }

    #[test]
    fn test_temperature_performance_update() {
        let mut perf = TemperaturePerformance::new();
        assert_eq!(perf.success_rate, 0.5);
        assert_eq!(perf.sample_count, 0);

        // A success lifts the rate above its neutral prior.
        perf.update(true, 0.9);
        assert!(perf.success_rate > 0.5);
        assert_eq!(perf.sample_count, 1);

        // A low-quality failure pulls the quality average back down.
        perf.update(false, 0.3);
        assert_eq!(perf.sample_count, 2);
        assert!(perf.avg_quality < 0.9);
    }

    #[test]
    fn test_temperature_performance_score() {
        let perf = TemperaturePerformance {
            success_rate: 0.8,
            avg_quality: 0.7,
            ..TemperaturePerformance::new()
        };
        // 0.6 * 0.8 + 0.4 * 0.7 = 0.76
        let expected = 0.76;
        assert!((perf.score() - expected).abs() < 0.01);
    }

    #[test]
    fn test_default_temperature_heuristics() {
        let optimizer = TemperatureOptimizer::new();
        let cases = [
            (
                "logic_task",
                "Boolean logic and reasoning puzzles",
                PromptingTechnique::LogicOfThought,
                0.0_f32,
            ),
            (
                "creative_task",
                "Creative writing and story generation",
                PromptingTechnique::RolePlaying,
                1.3,
            ),
            (
                "code_task",
                "Code implementation and algorithm design",
                PromptingTechnique::PlanAndSolve,
                0.6,
            ),
        ];
        for (id, description, technique, expected) in cases {
            let cluster = make_cluster(id, description, technique);
            assert_eq!(optimizer.get_default_temperature(&cluster), expected);
        }
    }

    #[test]
    fn test_record_and_retrieve_local_optimal() {
        let mut optimizer = TemperatureOptimizer::new();
        for _ in 0..10 {
            // 0.0 consistently succeeds; 0.6 consistently fails.
            optimizer.record_temperature_outcome(CLUSTER_ID.to_string(), 0.0, true, 0.9);
            optimizer.record_temperature_outcome(CLUSTER_ID.to_string(), 0.6, false, 0.5);
        }
        assert_eq!(optimizer.get_local_optimal(CLUSTER_ID), Some(0.0));
    }

    #[test]
    fn test_min_samples_requirement() {
        let mut optimizer = TemperatureOptimizer::new().with_min_samples(5);

        // Three samples: still below the threshold.
        (0..3).for_each(|_| {
            optimizer.record_temperature_outcome(CLUSTER_ID.to_string(), 0.0, true, 0.95)
        });
        assert_eq!(optimizer.get_local_optimal(CLUSTER_ID), None);

        // Two more reach exactly five.
        (0..2).for_each(|_| {
            optimizer.record_temperature_outcome(CLUSTER_ID.to_string(), 0.0, true, 0.95)
        });
        assert_eq!(optimizer.get_local_optimal(CLUSTER_ID), Some(0.0));
    }

    #[tokio::test]
    async fn test_get_optimal_temperature_fallback() {
        // No BKS cache and no local data: the heuristic default must win.
        let optimizer = TemperatureOptimizer::new();
        let cluster = make_cluster(
            "logic_test",
            "Boolean logic problems",
            PromptingTechnique::LogicOfThought,
        );
        assert_eq!(optimizer.get_optimal_temperature(&cluster).await, 0.0);
    }
}