Skip to main content

brainwires_cognition/prompting/
temperature.rs

1//! Temperature Optimization
2//!
3//! This module provides adaptive temperature selection per task cluster,
4//! based on the paper's findings:
5//! - Low temp (0.0): Best for logical tasks (Zebra Puzzles, Web of Lies, Boolean Expressions)
6//! - High temp (1.3): Best for linguistic tasks (Hyperbaton - adjective order judgment)
7//!
8//! Temperature performance is tracked per cluster and can be shared via BKS/PKS.
9
10use super::clustering::TaskCluster;
11#[cfg(feature = "knowledge")]
12use crate::knowledge::bks_pks::{
13    BehavioralKnowledgeCache, BehavioralTruth, TruthCategory, TruthSource,
14};
15use anyhow::Result;
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use std::sync::Arc;
19use tokio::sync::Mutex;
20
21/// Tracks performance metrics for a specific temperature setting
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct TemperaturePerformance {
24    /// Success rate (0.0-1.0) using EMA
25    pub success_rate: f32,
26    /// Average quality score (0.0-1.0) using EMA
27    pub avg_quality: f32,
28    /// Number of samples collected
29    pub sample_count: u32,
30    /// Last updated timestamp
31    pub last_updated: i64,
32}
33
34impl TemperaturePerformance {
35    /// Create a new performance record with neutral defaults.
36    pub fn new() -> Self {
37        Self {
38            success_rate: 0.5, // Neutral starting point
39            avg_quality: 0.5,
40            sample_count: 0,
41            last_updated: chrono::Utc::now().timestamp(),
42        }
43    }
44
45    /// Update metrics with new outcome using EMA (alpha = 0.3)
46    pub fn update(&mut self, success: bool, quality: f32) {
47        let alpha = 0.3;
48        self.success_rate =
49            alpha * (if success { 1.0 } else { 0.0 }) + (1.0 - alpha) * self.success_rate;
50        self.avg_quality = alpha * quality + (1.0 - alpha) * self.avg_quality;
51        self.sample_count += 1;
52        self.last_updated = chrono::Utc::now().timestamp();
53    }
54
55    /// Combined score for ranking (60% success rate, 40% quality)
56    pub fn score(&self) -> f32 {
57        0.6 * self.success_rate + 0.4 * self.avg_quality
58    }
59}
60
61impl Default for TemperaturePerformance {
62    fn default() -> Self {
63        Self::new()
64    }
65}
66
67/// Manages adaptive temperature selection per task cluster
68pub struct TemperatureOptimizer {
69    /// Maps (cluster_id, temperature_int) → performance stats
70    /// Temperature stored as i32 (multiply by 10: 0.0 → 0, 0.2 → 2, 1.3 → 13)
71    performance_map: HashMap<(String, i32), TemperaturePerformance>,
72    /// BKS cache for querying shared temperature preferences
73    bks_cache: Option<Arc<Mutex<BehavioralKnowledgeCache>>>,
74    /// Candidate temperatures to test (from paper)
75    candidates: Vec<f32>,
76    /// Minimum samples before trusting a temperature setting
77    min_samples: u32,
78}
79
80impl TemperatureOptimizer {
81    /// Create a new temperature optimizer
82    pub fn new() -> Self {
83        Self {
84            performance_map: HashMap::new(),
85            bks_cache: None,
86            candidates: vec![0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.3],
87            min_samples: 5,
88        }
89    }
90
91    /// Convert temperature f32 to i32 for HashMap key
92    fn temp_to_key(temp: f32) -> i32 {
93        (temp * 10.0).round() as i32
94    }
95
96    /// Set BKS cache for querying shared temperature knowledge
97    pub fn with_bks(mut self, bks_cache: Arc<Mutex<BehavioralKnowledgeCache>>) -> Self {
98        self.bks_cache = Some(bks_cache);
99        self
100    }
101
102    /// Set minimum samples required before trusting a temperature
103    pub fn with_min_samples(mut self, min_samples: u32) -> Self {
104        self.min_samples = min_samples;
105        self
106    }
107
108    /// Get optimal temperature for a cluster
109    ///
110    /// Selection order:
111    /// 1. BKS shared knowledge (if available)
112    /// 2. Local learned temperature (if enough samples)
113    /// 3. Default heuristic based on cluster characteristics
114    pub async fn get_optimal_temperature(&self, cluster: &TaskCluster) -> f32 {
115        // Source 1: BKS shared knowledge
116        if let Some(bks_temp) = self.query_bks_temperature(&cluster.id).await {
117            return bks_temp;
118        }
119
120        // Source 2: Local learned temperature
121        if let Some(local_temp) = self.get_local_optimal(&cluster.id) {
122            return local_temp;
123        }
124
125        // Source 3: Default heuristic based on cluster characteristics
126        self.get_default_temperature(cluster)
127    }
128
129    /// Get locally learned optimal temperature
130    fn get_local_optimal(&self, cluster_id: &str) -> Option<f32> {
131        let mut best_temp = None;
132        let mut best_score = f32::NEG_INFINITY;
133
134        for &temp in &self.candidates {
135            let temp_key = Self::temp_to_key(temp);
136            if let Some(perf) = self
137                .performance_map
138                .get(&(cluster_id.to_string(), temp_key))
139                && perf.sample_count >= self.min_samples
140            {
141                let score = perf.score();
142                if score > best_score {
143                    best_score = score;
144                    best_temp = Some(temp);
145                }
146            }
147        }
148
149        best_temp
150    }
151
152    /// Query BKS for shared temperature knowledge
153    async fn query_bks_temperature(&self, cluster_id: &str) -> Option<f32> {
154        if let Some(ref bks_cache) = self.bks_cache {
155            let bks = bks_cache.lock().await;
156
157            // Query for temperature truths for this cluster
158            // get_matching_truths takes just a context string
159            let truths = bks.get_matching_truths(cluster_id);
160
161            // Parse temperature from truth content
162            // Example: "For numerical_reasoning, use temperature 0.0 for optimal results"
163            for truth in truths {
164                // Filter for TaskStrategy category
165                if truth.category == TruthCategory::TaskStrategy
166                    && let Some(temp) = self.parse_temperature_from_truth(truth)
167                {
168                    return Some(temp);
169                }
170            }
171        }
172
173        None
174    }
175
176    /// Parse temperature value from BKS truth
177    fn parse_temperature_from_truth(&self, truth: &BehavioralTruth) -> Option<f32> {
178        // Look for "temperature X.X" pattern in rule or rationale
179        let text = format!("{} {}", truth.rule, truth.rationale);
180
181        // Simple regex-like parsing
182        if let Some(idx) = text.find("temperature") {
183            let substr = &text[idx..];
184            // Find first number after "temperature"
185            let parts: Vec<&str> = substr.split_whitespace().collect();
186            for part in parts.iter().skip(1) {
187                if let Ok(temp) = part.parse::<f32>()
188                    && self.candidates.contains(&temp)
189                {
190                    return Some(temp);
191                }
192            }
193        }
194
195        None
196    }
197
198    /// Get default temperature based on cluster characteristics (heuristic)
199    fn get_default_temperature(&self, cluster: &TaskCluster) -> f32 {
200        let desc = cluster.description.to_lowercase();
201
202        // Logic/reasoning tasks: Low temperature (0.0)
203        if desc.contains("logic")
204            || desc.contains("boolean")
205            || desc.contains("reasoning")
206            || desc.contains("puzzle")
207            || desc.contains("deduction")
208        {
209            return 0.0;
210        }
211
212        // Creative/linguistic tasks: High temperature (1.3)
213        if desc.contains("creative")
214            || desc.contains("linguistic")
215            || desc.contains("story")
216            || desc.contains("writing")
217            || desc.contains("generation")
218        {
219            return 1.3;
220        }
221
222        // Numerical/calculation tasks: Low temperature (0.2)
223        if desc.contains("numerical")
224            || desc.contains("calculation")
225            || desc.contains("math")
226            || desc.contains("arithmetic")
227        {
228            return 0.2;
229        }
230
231        // Code generation: Moderate temperature (0.6)
232        if desc.contains("code")
233            || desc.contains("programming")
234            || desc.contains("implementation")
235            || desc.contains("algorithm")
236        {
237            return 0.6;
238        }
239
240        // Default: Moderate temperature
241        0.7
242    }
243
244    /// Record outcome for a temperature setting
245    pub fn record_temperature_outcome(
246        &mut self,
247        cluster_id: String,
248        temperature: f32,
249        success: bool,
250        quality: f32,
251    ) {
252        let temp_key = Self::temp_to_key(temperature);
253        let key = (cluster_id, temp_key);
254        let perf = self.performance_map.entry(key).or_default();
255
256        perf.update(success, quality);
257    }
258
259    /// Check if temperature should be promoted to BKS
260    pub async fn check_and_promote_temperature(
261        &self,
262        cluster_id: &str,
263        temperature: f32,
264        min_score: f32,
265        min_samples: u32,
266    ) -> Result<()> {
267        let temp_key = Self::temp_to_key(temperature);
268        let key = (cluster_id.to_string(), temp_key);
269
270        if let Some(perf) = self.performance_map.get(&key)
271            && perf.sample_count >= min_samples
272            && perf.score() >= min_score
273        {
274            // Promote to BKS
275            if let Some(ref bks_cache) = self.bks_cache {
276                let truth = BehavioralTruth::new(
277                    TruthCategory::TaskStrategy,
278                    cluster_id.to_string(),
279                    format!(
280                        "For {} tasks, use temperature {} for optimal results",
281                        cluster_id, temperature
282                    ),
283                    format!(
284                        "Learned from {} executions with {:.1}% success rate and {:.2} avg quality",
285                        perf.sample_count,
286                        perf.success_rate * 100.0,
287                        perf.avg_quality
288                    ),
289                    TruthSource::SuccessPattern,
290                    None,
291                );
292
293                let mut bks = bks_cache.lock().await;
294                bks.queue_submission(truth)?;
295            }
296        }
297
298        Ok(())
299    }
300
301    /// Get all performance data (for debugging/inspection)
302    pub fn get_all_performance(&self) -> &HashMap<(String, i32), TemperaturePerformance> {
303        &self.performance_map
304    }
305
306    /// Get performance for a specific cluster and temperature
307    pub fn get_performance(
308        &self,
309        cluster_id: &str,
310        temperature: f32,
311    ) -> Option<&TemperaturePerformance> {
312        let temp_key = Self::temp_to_key(temperature);
313        self.performance_map
314            .get(&(cluster_id.to_string(), temp_key))
315    }
316}
317
318impl Default for TemperatureOptimizer {
319    fn default() -> Self {
320        Self::new()
321    }
322}
323
324#[cfg(test)]
325mod tests {
326    use super::*;
327    use crate::prompting::techniques::PromptingTechnique;
328
329    #[test]
330    fn test_temperature_performance_update() {
331        let mut perf = TemperaturePerformance::new();
332        assert_eq!(perf.success_rate, 0.5);
333        assert_eq!(perf.sample_count, 0);
334
335        // Record success
336        perf.update(true, 0.9);
337        assert!(perf.success_rate > 0.5); // Should increase
338        assert_eq!(perf.sample_count, 1);
339
340        // Record failure
341        perf.update(false, 0.3);
342        assert_eq!(perf.sample_count, 2);
343        assert!(perf.avg_quality < 0.9); // Should decrease
344    }
345
346    #[test]
347    fn test_temperature_performance_score() {
348        let mut perf = TemperaturePerformance::new();
349        perf.success_rate = 0.8;
350        perf.avg_quality = 0.7;
351
352        let score = perf.score();
353        assert!((score - 0.76).abs() < 0.01); // 0.6 * 0.8 + 0.4 * 0.7 = 0.76
354    }
355
356    #[test]
357    fn test_default_temperature_heuristics() {
358        let optimizer = TemperatureOptimizer::new();
359
360        // Logic task: Low temperature
361        let logic_cluster = TaskCluster::new(
362            "logic_task".to_string(),
363            "Boolean logic and reasoning puzzles".to_string(),
364            vec![0.5; 768],
365            vec![PromptingTechnique::LogicOfThought],
366            vec![],
367        );
368        assert_eq!(optimizer.get_default_temperature(&logic_cluster), 0.0);
369
370        // Creative task: High temperature
371        let creative_cluster = TaskCluster::new(
372            "creative_task".to_string(),
373            "Creative writing and story generation".to_string(),
374            vec![0.5; 768],
375            vec![PromptingTechnique::RolePlaying],
376            vec![],
377        );
378        assert_eq!(optimizer.get_default_temperature(&creative_cluster), 1.3);
379
380        // Code task: Moderate temperature
381        let code_cluster = TaskCluster::new(
382            "code_task".to_string(),
383            "Code implementation and algorithm design".to_string(),
384            vec![0.5; 768],
385            vec![PromptingTechnique::PlanAndSolve],
386            vec![],
387        );
388        assert_eq!(optimizer.get_default_temperature(&code_cluster), 0.6);
389    }
390
391    #[test]
392    fn test_record_and_retrieve_local_optimal() {
393        let mut optimizer = TemperatureOptimizer::new();
394
395        // Record outcomes for different temperatures
396        for _ in 0..10 {
397            optimizer.record_temperature_outcome("test_cluster".to_string(), 0.0, true, 0.9);
398            optimizer.record_temperature_outcome("test_cluster".to_string(), 0.6, false, 0.5);
399        }
400
401        // Get optimal (should be 0.0 due to high success rate)
402        let optimal = optimizer.get_local_optimal("test_cluster");
403        assert_eq!(optimal, Some(0.0));
404    }
405
406    #[test]
407    fn test_min_samples_requirement() {
408        let mut optimizer = TemperatureOptimizer::new().with_min_samples(5);
409
410        // Record only 3 samples
411        for _ in 0..3 {
412            optimizer.record_temperature_outcome("test_cluster".to_string(), 0.0, true, 0.95);
413        }
414
415        // Should not return optimal (not enough samples)
416        assert_eq!(optimizer.get_local_optimal("test_cluster"), None);
417
418        // Add 2 more samples
419        for _ in 0..2 {
420            optimizer.record_temperature_outcome("test_cluster".to_string(), 0.0, true, 0.95);
421        }
422
423        // Now should return optimal
424        assert_eq!(optimizer.get_local_optimal("test_cluster"), Some(0.0));
425    }
426
427    #[tokio::test]
428    async fn test_get_optimal_temperature_fallback() {
429        let optimizer = TemperatureOptimizer::new();
430
431        // No BKS, no local data → should use heuristic
432        let cluster = TaskCluster::new(
433            "logic_test".to_string(),
434            "Boolean logic problems".to_string(),
435            vec![0.5; 768],
436            vec![PromptingTechnique::LogicOfThought],
437            vec![],
438        );
439
440        let temp = optimizer.get_optimal_temperature(&cluster).await;
441        assert_eq!(temp, 0.0); // Heuristic for logic tasks
442    }
443}