brainwires_cognition/prompting/
temperature.rs1use super::clustering::TaskCluster;
11#[cfg(feature = "knowledge")]
12use crate::knowledge::bks_pks::{
13 BehavioralKnowledgeCache, BehavioralTruth, TruthCategory, TruthSource,
14};
15use anyhow::Result;
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use std::sync::Arc;
19use tokio::sync::Mutex;
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct TemperaturePerformance {
24 pub success_rate: f32,
26 pub avg_quality: f32,
28 pub sample_count: u32,
30 pub last_updated: i64,
32}
33
34impl TemperaturePerformance {
35 pub fn new() -> Self {
37 Self {
38 success_rate: 0.5, avg_quality: 0.5,
40 sample_count: 0,
41 last_updated: chrono::Utc::now().timestamp(),
42 }
43 }
44
45 pub fn update(&mut self, success: bool, quality: f32) {
47 let alpha = 0.3;
48 self.success_rate =
49 alpha * (if success { 1.0 } else { 0.0 }) + (1.0 - alpha) * self.success_rate;
50 self.avg_quality = alpha * quality + (1.0 - alpha) * self.avg_quality;
51 self.sample_count += 1;
52 self.last_updated = chrono::Utc::now().timestamp();
53 }
54
55 pub fn score(&self) -> f32 {
57 0.6 * self.success_rate + 0.4 * self.avg_quality
58 }
59}
60
61impl Default for TemperaturePerformance {
62 fn default() -> Self {
63 Self::new()
64 }
65}
66
67pub struct TemperatureOptimizer {
69 performance_map: HashMap<(String, i32), TemperaturePerformance>,
72 bks_cache: Option<Arc<Mutex<BehavioralKnowledgeCache>>>,
74 candidates: Vec<f32>,
76 min_samples: u32,
78}
79
80impl TemperatureOptimizer {
81 pub fn new() -> Self {
83 Self {
84 performance_map: HashMap::new(),
85 bks_cache: None,
86 candidates: vec![0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.3],
87 min_samples: 5,
88 }
89 }
90
91 fn temp_to_key(temp: f32) -> i32 {
93 (temp * 10.0).round() as i32
94 }
95
96 pub fn with_bks(mut self, bks_cache: Arc<Mutex<BehavioralKnowledgeCache>>) -> Self {
98 self.bks_cache = Some(bks_cache);
99 self
100 }
101
102 pub fn with_min_samples(mut self, min_samples: u32) -> Self {
104 self.min_samples = min_samples;
105 self
106 }
107
108 pub async fn get_optimal_temperature(&self, cluster: &TaskCluster) -> f32 {
115 if let Some(bks_temp) = self.query_bks_temperature(&cluster.id).await {
117 return bks_temp;
118 }
119
120 if let Some(local_temp) = self.get_local_optimal(&cluster.id) {
122 return local_temp;
123 }
124
125 self.get_default_temperature(cluster)
127 }
128
129 fn get_local_optimal(&self, cluster_id: &str) -> Option<f32> {
131 let mut best_temp = None;
132 let mut best_score = f32::NEG_INFINITY;
133
134 for &temp in &self.candidates {
135 let temp_key = Self::temp_to_key(temp);
136 if let Some(perf) = self
137 .performance_map
138 .get(&(cluster_id.to_string(), temp_key))
139 && perf.sample_count >= self.min_samples
140 {
141 let score = perf.score();
142 if score > best_score {
143 best_score = score;
144 best_temp = Some(temp);
145 }
146 }
147 }
148
149 best_temp
150 }
151
152 async fn query_bks_temperature(&self, cluster_id: &str) -> Option<f32> {
154 if let Some(ref bks_cache) = self.bks_cache {
155 let bks = bks_cache.lock().await;
156
157 let truths = bks.get_matching_truths(cluster_id);
160
161 for truth in truths {
164 if truth.category == TruthCategory::TaskStrategy
166 && let Some(temp) = self.parse_temperature_from_truth(truth)
167 {
168 return Some(temp);
169 }
170 }
171 }
172
173 None
174 }
175
176 fn parse_temperature_from_truth(&self, truth: &BehavioralTruth) -> Option<f32> {
178 let text = format!("{} {}", truth.rule, truth.rationale);
180
181 if let Some(idx) = text.find("temperature") {
183 let substr = &text[idx..];
184 let parts: Vec<&str> = substr.split_whitespace().collect();
186 for part in parts.iter().skip(1) {
187 if let Ok(temp) = part.parse::<f32>()
188 && self.candidates.contains(&temp)
189 {
190 return Some(temp);
191 }
192 }
193 }
194
195 None
196 }
197
198 fn get_default_temperature(&self, cluster: &TaskCluster) -> f32 {
200 let desc = cluster.description.to_lowercase();
201
202 if desc.contains("logic")
204 || desc.contains("boolean")
205 || desc.contains("reasoning")
206 || desc.contains("puzzle")
207 || desc.contains("deduction")
208 {
209 return 0.0;
210 }
211
212 if desc.contains("creative")
214 || desc.contains("linguistic")
215 || desc.contains("story")
216 || desc.contains("writing")
217 || desc.contains("generation")
218 {
219 return 1.3;
220 }
221
222 if desc.contains("numerical")
224 || desc.contains("calculation")
225 || desc.contains("math")
226 || desc.contains("arithmetic")
227 {
228 return 0.2;
229 }
230
231 if desc.contains("code")
233 || desc.contains("programming")
234 || desc.contains("implementation")
235 || desc.contains("algorithm")
236 {
237 return 0.6;
238 }
239
240 0.7
242 }
243
244 pub fn record_temperature_outcome(
246 &mut self,
247 cluster_id: String,
248 temperature: f32,
249 success: bool,
250 quality: f32,
251 ) {
252 let temp_key = Self::temp_to_key(temperature);
253 let key = (cluster_id, temp_key);
254 let perf = self.performance_map.entry(key).or_default();
255
256 perf.update(success, quality);
257 }
258
259 pub async fn check_and_promote_temperature(
261 &self,
262 cluster_id: &str,
263 temperature: f32,
264 min_score: f32,
265 min_samples: u32,
266 ) -> Result<()> {
267 let temp_key = Self::temp_to_key(temperature);
268 let key = (cluster_id.to_string(), temp_key);
269
270 if let Some(perf) = self.performance_map.get(&key)
271 && perf.sample_count >= min_samples
272 && perf.score() >= min_score
273 {
274 if let Some(ref bks_cache) = self.bks_cache {
276 let truth = BehavioralTruth::new(
277 TruthCategory::TaskStrategy,
278 cluster_id.to_string(),
279 format!(
280 "For {} tasks, use temperature {} for optimal results",
281 cluster_id, temperature
282 ),
283 format!(
284 "Learned from {} executions with {:.1}% success rate and {:.2} avg quality",
285 perf.sample_count,
286 perf.success_rate * 100.0,
287 perf.avg_quality
288 ),
289 TruthSource::SuccessPattern,
290 None,
291 );
292
293 let mut bks = bks_cache.lock().await;
294 bks.queue_submission(truth)?;
295 }
296 }
297
298 Ok(())
299 }
300
301 pub fn get_all_performance(&self) -> &HashMap<(String, i32), TemperaturePerformance> {
303 &self.performance_map
304 }
305
306 pub fn get_performance(
308 &self,
309 cluster_id: &str,
310 temperature: f32,
311 ) -> Option<&TemperaturePerformance> {
312 let temp_key = Self::temp_to_key(temperature);
313 self.performance_map
314 .get(&(cluster_id.to_string(), temp_key))
315 }
316}
317
318impl Default for TemperatureOptimizer {
319 fn default() -> Self {
320 Self::new()
321 }
322}
323
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prompting::techniques::PromptingTechnique;

    /// Cluster id shared by the recording tests.
    const CLUSTER: &str = "test_cluster";

    /// Builds a cluster with a constant 768-dimensional embedding, one technique and
    /// no examples.
    fn cluster(id: &str, description: &str, technique: PromptingTechnique) -> TaskCluster {
        TaskCluster::new(
            id.to_string(),
            description.to_string(),
            vec![0.5; 768],
            vec![technique],
            vec![],
        )
    }

    #[test]
    fn test_temperature_performance_update() {
        let mut perf = TemperaturePerformance::new();
        assert_eq!(perf.sample_count, 0);
        assert_eq!(perf.success_rate, 0.5);

        // A success pulls the smoothed rate above its 0.5 prior.
        perf.update(true, 0.9);
        assert_eq!(perf.sample_count, 1);
        assert!(perf.success_rate > 0.5);

        // A low-quality failure keeps the quality average below the earlier 0.9.
        perf.update(false, 0.3);
        assert_eq!(perf.sample_count, 2);
        assert!(perf.avg_quality < 0.9);
    }

    #[test]
    fn test_temperature_performance_score() {
        let perf = TemperaturePerformance {
            success_rate: 0.8,
            avg_quality: 0.7,
            ..TemperaturePerformance::new()
        };
        // 0.6 * 0.8 + 0.4 * 0.7 = 0.76
        assert!((perf.score() - 0.76).abs() < 0.01);
    }

    #[test]
    fn test_default_temperature_heuristics() {
        let optimizer = TemperatureOptimizer::new();
        let cases: [(TaskCluster, f32); 3] = [
            (
                cluster(
                    "logic_task",
                    "Boolean logic and reasoning puzzles",
                    PromptingTechnique::LogicOfThought,
                ),
                0.0,
            ),
            (
                cluster(
                    "creative_task",
                    "Creative writing and story generation",
                    PromptingTechnique::RolePlaying,
                ),
                1.3,
            ),
            (
                cluster(
                    "code_task",
                    "Code implementation and algorithm design",
                    PromptingTechnique::PlanAndSolve,
                ),
                0.6,
            ),
        ];

        for (task, expected) in &cases {
            assert_eq!(optimizer.get_default_temperature(task), *expected);
        }
    }

    #[test]
    fn test_record_and_retrieve_local_optimal() {
        let mut optimizer = TemperatureOptimizer::new();

        // 0.0 always succeeds with high quality; 0.6 always fails.
        for _ in 0..10 {
            optimizer.record_temperature_outcome(CLUSTER.to_string(), 0.0, true, 0.9);
            optimizer.record_temperature_outcome(CLUSTER.to_string(), 0.6, false, 0.5);
        }

        assert_eq!(optimizer.get_local_optimal(CLUSTER), Some(0.0));
    }

    #[test]
    fn test_min_samples_requirement() {
        let mut optimizer = TemperatureOptimizer::new().with_min_samples(5);

        // 3 samples are not enough; 2 more reach the threshold of 5.
        let batches: [(u32, Option<f32>); 2] = [(3, None), (2, Some(0.0))];
        for (batch, expected) in batches {
            for _ in 0..batch {
                optimizer.record_temperature_outcome(CLUSTER.to_string(), 0.0, true, 0.95);
            }
            assert_eq!(optimizer.get_local_optimal(CLUSTER), expected);
        }
    }

    #[tokio::test]
    async fn test_get_optimal_temperature_fallback() {
        let optimizer = TemperatureOptimizer::new();
        let task = cluster(
            "logic_test",
            "Boolean logic problems",
            PromptingTechnique::LogicOfThought,
        );

        // No BKS and no local history: the keyword heuristic decides.
        assert_eq!(optimizer.get_optimal_temperature(&task).await, 0.0);
    }
}