Skip to main content

brainwires_prompting/
temperature.rs

//! Temperature Optimization
//!
//! This module provides adaptive temperature selection per task cluster,
//! based on the paper's findings:
//! - Low temp (0.0): Best for logical tasks (Zebra Puzzles, Web of Lies, Boolean Expressions)
//! - High temp (1.3): Best for linguistic tasks (Hyperbaton - adjective order judgment)
//!
//! Temperature performance is tracked per cluster and can be shared via BKS/PKS.

10use super::clustering::TaskCluster;
11use anyhow::Result;
12use brainwires_knowledge::knowledge::bks_pks::{
13    BehavioralKnowledgeCache, BehavioralTruth, TruthCategory, TruthSource,
14};
15use serde::{Deserialize, Serialize};
16use std::collections::HashMap;
17use std::sync::Arc;
18use tokio::sync::Mutex;
19
20/// Tracks performance metrics for a specific temperature setting
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct TemperaturePerformance {
23    /// Success rate (0.0-1.0) using EMA
24    pub success_rate: f32,
25    /// Average quality score (0.0-1.0) using EMA
26    pub avg_quality: f32,
27    /// Number of samples collected
28    pub sample_count: u32,
29    /// Last updated timestamp
30    pub last_updated: i64,
31}
32
33impl TemperaturePerformance {
34    /// Create a new performance record with neutral defaults.
35    pub fn new() -> Self {
36        Self {
37            success_rate: 0.5, // Neutral starting point
38            avg_quality: 0.5,
39            sample_count: 0,
40            last_updated: chrono::Utc::now().timestamp(),
41        }
42    }
43
44    /// Update metrics with new outcome using EMA (alpha = 0.3)
45    pub fn update(&mut self, success: bool, quality: f32) {
46        let alpha = 0.3;
47        self.success_rate =
48            alpha * (if success { 1.0 } else { 0.0 }) + (1.0 - alpha) * self.success_rate;
49        self.avg_quality = alpha * quality + (1.0 - alpha) * self.avg_quality;
50        self.sample_count += 1;
51        self.last_updated = chrono::Utc::now().timestamp();
52    }
53
54    /// Combined score for ranking (60% success rate, 40% quality)
55    pub fn score(&self) -> f32 {
56        0.6 * self.success_rate + 0.4 * self.avg_quality
57    }
58}
59
60impl Default for TemperaturePerformance {
61    fn default() -> Self {
62        Self::new()
63    }
64}
65
66/// Manages adaptive temperature selection per task cluster
67pub struct TemperatureOptimizer {
68    /// Maps (cluster_id, temperature_int) → performance stats
69    /// Temperature stored as i32 (multiply by 10: 0.0 → 0, 0.2 → 2, 1.3 → 13)
70    performance_map: HashMap<(String, i32), TemperaturePerformance>,
71    /// BKS cache for querying shared temperature preferences
72    bks_cache: Option<Arc<Mutex<BehavioralKnowledgeCache>>>,
73    /// Candidate temperatures to test (from paper)
74    candidates: Vec<f32>,
75    /// Minimum samples before trusting a temperature setting
76    min_samples: u32,
77}
78
79impl TemperatureOptimizer {
80    /// Create a new temperature optimizer
81    pub fn new() -> Self {
82        Self {
83            performance_map: HashMap::new(),
84            bks_cache: None,
85            candidates: vec![0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.3],
86            min_samples: 5,
87        }
88    }
89
90    /// Convert temperature f32 to i32 for HashMap key
91    fn temp_to_key(temp: f32) -> i32 {
92        (temp * 10.0).round() as i32
93    }
94
95    /// Set BKS cache for querying shared temperature knowledge
96    pub fn with_bks(mut self, bks_cache: Arc<Mutex<BehavioralKnowledgeCache>>) -> Self {
97        self.bks_cache = Some(bks_cache);
98        self
99    }
100
101    /// Set minimum samples required before trusting a temperature
102    pub fn with_min_samples(mut self, min_samples: u32) -> Self {
103        self.min_samples = min_samples;
104        self
105    }
106
107    /// Get optimal temperature for a cluster
108    ///
109    /// Selection order:
110    /// 1. BKS shared knowledge (if available)
111    /// 2. Local learned temperature (if enough samples)
112    /// 3. Default heuristic based on cluster characteristics
113    pub async fn get_optimal_temperature(&self, cluster: &TaskCluster) -> f32 {
114        // Source 1: BKS shared knowledge
115        if let Some(bks_temp) = self.query_bks_temperature(&cluster.id).await {
116            return bks_temp;
117        }
118
119        // Source 2: Local learned temperature
120        if let Some(local_temp) = self.get_local_optimal(&cluster.id) {
121            return local_temp;
122        }
123
124        // Source 3: Default heuristic based on cluster characteristics
125        self.get_default_temperature(cluster)
126    }
127
128    /// Get locally learned optimal temperature
129    fn get_local_optimal(&self, cluster_id: &str) -> Option<f32> {
130        let mut best_temp = None;
131        let mut best_score = f32::NEG_INFINITY;
132
133        for &temp in &self.candidates {
134            let temp_key = Self::temp_to_key(temp);
135            if let Some(perf) = self
136                .performance_map
137                .get(&(cluster_id.to_string(), temp_key))
138                && perf.sample_count >= self.min_samples
139            {
140                let score = perf.score();
141                if score > best_score {
142                    best_score = score;
143                    best_temp = Some(temp);
144                }
145            }
146        }
147
148        best_temp
149    }
150
151    /// Query BKS for shared temperature knowledge
152    async fn query_bks_temperature(&self, cluster_id: &str) -> Option<f32> {
153        if let Some(ref bks_cache) = self.bks_cache {
154            let bks = bks_cache.lock().await;
155
156            // Query for temperature truths for this cluster
157            // get_matching_truths takes just a context string
158            let truths = bks.get_matching_truths(cluster_id);
159
160            // Parse temperature from truth content
161            // Example: "For numerical_reasoning, use temperature 0.0 for optimal results"
162            for truth in truths {
163                // Filter for TaskStrategy category
164                if truth.category == TruthCategory::TaskStrategy
165                    && let Some(temp) = self.parse_temperature_from_truth(truth)
166                {
167                    return Some(temp);
168                }
169            }
170        }
171
172        None
173    }
174
175    /// Parse temperature value from BKS truth
176    fn parse_temperature_from_truth(&self, truth: &BehavioralTruth) -> Option<f32> {
177        // Look for "temperature X.X" pattern in rule or rationale
178        let text = format!("{} {}", truth.rule, truth.rationale);
179
180        // Simple regex-like parsing
181        if let Some(idx) = text.find("temperature") {
182            let substr = &text[idx..];
183            // Find first number after "temperature"
184            let parts: Vec<&str> = substr.split_whitespace().collect();
185            for part in parts.iter().skip(1) {
186                if let Ok(temp) = part.parse::<f32>()
187                    && self.candidates.contains(&temp)
188                {
189                    return Some(temp);
190                }
191            }
192        }
193
194        None
195    }
196
197    /// Get default temperature based on cluster characteristics (heuristic)
198    fn get_default_temperature(&self, cluster: &TaskCluster) -> f32 {
199        let desc = cluster.description.to_lowercase();
200
201        // Logic/reasoning tasks: Low temperature (0.0)
202        if desc.contains("logic")
203            || desc.contains("boolean")
204            || desc.contains("reasoning")
205            || desc.contains("puzzle")
206            || desc.contains("deduction")
207        {
208            return 0.0;
209        }
210
211        // Creative/linguistic tasks: High temperature (1.3)
212        if desc.contains("creative")
213            || desc.contains("linguistic")
214            || desc.contains("story")
215            || desc.contains("writing")
216            || desc.contains("generation")
217        {
218            return 1.3;
219        }
220
221        // Numerical/calculation tasks: Low temperature (0.2)
222        if desc.contains("numerical")
223            || desc.contains("calculation")
224            || desc.contains("math")
225            || desc.contains("arithmetic")
226        {
227            return 0.2;
228        }
229
230        // Code generation: Moderate temperature (0.6)
231        if desc.contains("code")
232            || desc.contains("programming")
233            || desc.contains("implementation")
234            || desc.contains("algorithm")
235        {
236            return 0.6;
237        }
238
239        // Default: Moderate temperature
240        0.7
241    }
242
243    /// Record outcome for a temperature setting
244    pub fn record_temperature_outcome(
245        &mut self,
246        cluster_id: String,
247        temperature: f32,
248        success: bool,
249        quality: f32,
250    ) {
251        let temp_key = Self::temp_to_key(temperature);
252        let key = (cluster_id, temp_key);
253        let perf = self.performance_map.entry(key).or_default();
254
255        perf.update(success, quality);
256    }
257
258    /// Check if temperature should be promoted to BKS
259    pub async fn check_and_promote_temperature(
260        &self,
261        cluster_id: &str,
262        temperature: f32,
263        min_score: f32,
264        min_samples: u32,
265    ) -> Result<()> {
266        let temp_key = Self::temp_to_key(temperature);
267        let key = (cluster_id.to_string(), temp_key);
268
269        if let Some(perf) = self.performance_map.get(&key)
270            && perf.sample_count >= min_samples
271            && perf.score() >= min_score
272        {
273            // Promote to BKS
274            if let Some(ref bks_cache) = self.bks_cache {
275                let truth = BehavioralTruth::new(
276                    TruthCategory::TaskStrategy,
277                    cluster_id.to_string(),
278                    format!(
279                        "For {} tasks, use temperature {} for optimal results",
280                        cluster_id, temperature
281                    ),
282                    format!(
283                        "Learned from {} executions with {:.1}% success rate and {:.2} avg quality",
284                        perf.sample_count,
285                        perf.success_rate * 100.0,
286                        perf.avg_quality
287                    ),
288                    TruthSource::SuccessPattern,
289                    None,
290                );
291
292                let mut bks = bks_cache.lock().await;
293                bks.queue_submission(truth)?;
294            }
295        }
296
297        Ok(())
298    }
299
300    /// Get all performance data (for debugging/inspection)
301    pub fn get_all_performance(&self) -> &HashMap<(String, i32), TemperaturePerformance> {
302        &self.performance_map
303    }
304
305    /// Get performance for a specific cluster and temperature
306    pub fn get_performance(
307        &self,
308        cluster_id: &str,
309        temperature: f32,
310    ) -> Option<&TemperaturePerformance> {
311        let temp_key = Self::temp_to_key(temperature);
312        self.performance_map
313            .get(&(cluster_id.to_string(), temp_key))
314    }
315}
316
317impl Default for TemperatureOptimizer {
318    fn default() -> Self {
319        Self::new()
320    }
321}
322
#[cfg(test)]
mod tests {
    use super::*;
    use crate::techniques::PromptingTechnique;

    /// Build a `TaskCluster` with the same placeholder vectors every test uses.
    fn make_cluster(id: &str, description: &str, technique: PromptingTechnique) -> TaskCluster {
        TaskCluster::new(
            id.to_string(),
            description.to_string(),
            vec![0.5; 768],
            vec![technique],
            vec![],
        )
    }

    #[test]
    fn test_temperature_performance_update() {
        let mut record = TemperaturePerformance::new();
        assert_eq!(record.success_rate, 0.5);
        assert_eq!(record.sample_count, 0);

        // A success pulls the success rate above its neutral start.
        record.update(true, 0.9);
        assert_eq!(record.sample_count, 1);
        assert!(record.success_rate > 0.5);

        // A poor-quality failure drags the quality average back down.
        record.update(false, 0.3);
        assert_eq!(record.sample_count, 2);
        assert!(record.avg_quality < 0.9);
    }

    #[test]
    fn test_temperature_performance_score() {
        let record = TemperaturePerformance {
            success_rate: 0.8,
            avg_quality: 0.7,
            ..TemperaturePerformance::new()
        };

        // 0.6 * 0.8 + 0.4 * 0.7 = 0.76
        assert!((record.score() - 0.76).abs() < 0.01);
    }

    #[test]
    fn test_default_temperature_heuristics() {
        let optimizer = TemperatureOptimizer::new();

        let cases: [(&str, TaskCluster, f32); 3] = [
            (
                "logic → low",
                make_cluster(
                    "logic_task",
                    "Boolean logic and reasoning puzzles",
                    PromptingTechnique::LogicOfThought,
                ),
                0.0,
            ),
            (
                "creative → high",
                make_cluster(
                    "creative_task",
                    "Creative writing and story generation",
                    PromptingTechnique::RolePlaying,
                ),
                1.3,
            ),
            (
                "code → moderate",
                make_cluster(
                    "code_task",
                    "Code implementation and algorithm design",
                    PromptingTechnique::PlanAndSolve,
                ),
                0.6,
            ),
        ];

        for (label, cluster, expected) in &cases {
            assert_eq!(
                optimizer.get_default_temperature(cluster),
                *expected,
                "{label}"
            );
        }
    }

    #[test]
    fn test_record_and_retrieve_local_optimal() {
        let mut optimizer = TemperatureOptimizer::new();
        let cluster_id = "test_cluster";

        // Ten rounds in which 0.0 always succeeds and 0.6 always fails.
        for _ in 0..10 {
            optimizer.record_temperature_outcome(cluster_id.to_string(), 0.0, true, 0.9);
            optimizer.record_temperature_outcome(cluster_id.to_string(), 0.6, false, 0.5);
        }

        assert_eq!(optimizer.get_local_optimal(cluster_id), Some(0.0));
    }

    #[test]
    fn test_min_samples_requirement() {
        let mut optimizer = TemperatureOptimizer::new().with_min_samples(5);
        let record_hits = |opt: &mut TemperatureOptimizer, times: usize| {
            for _ in 0..times {
                opt.record_temperature_outcome("test_cluster".to_string(), 0.0, true, 0.95);
            }
        };

        // Three samples are below the threshold of five.
        record_hits(&mut optimizer, 3);
        assert_eq!(optimizer.get_local_optimal("test_cluster"), None);

        // Two more reach it.
        record_hits(&mut optimizer, 2);
        assert_eq!(optimizer.get_local_optimal("test_cluster"), Some(0.0));
    }

    #[tokio::test]
    async fn test_get_optimal_temperature_fallback() {
        // Without BKS or recorded outcomes, the description heuristic decides.
        let optimizer = TemperatureOptimizer::new();
        let cluster = make_cluster(
            "logic_test",
            "Boolean logic problems",
            PromptingTechnique::LogicOfThought,
        );

        assert_eq!(optimizer.get_optimal_temperature(&cluster).await, 0.0);
    }
}