Skip to main content

attuned_infer/
bayesian.rs

1//! Bayesian state estimation for axis values.
2//!
3//! This module implements principled uncertainty tracking using
4//! Bayesian updating. Each axis is modeled as a latent variable
5//! with a posterior distribution that updates with each observation.
6//!
7//! Key properties:
8//! - Single observations can't swing estimates too wildly
9//! - Uncertainty is explicit and quantified
10//! - Self-report sets variance to near-zero (nuclear override)
11//! - Old inferences decay naturally
12
13use crate::estimate::{AxisEstimate, InferenceSource, MAX_INFERRED_CONFIDENCE};
14use chrono::{DateTime, Utc};
15use serde::{Deserialize, Serialize};
16
/// Prior distribution for an axis.
///
/// A prior is the belief about an axis value before an observation is
/// folded in, expressed as a mean/variance pair plus a free-text
/// justification. The constructors in this module produce means inside
/// `[0.0, 1.0]`; the fields are public, so hand-built values are not checked.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Prior {
    /// Prior mean (expected value before observation).
    pub mean: f32,
    /// Prior variance (uncertainty before observation); larger means less certain.
    pub variance: f32,
    /// Human-readable reason for this prior.
    pub reason: String,
}
27
28impl Prior {
29    /// Create a neutral prior (0.5 with high uncertainty).
30    pub fn neutral() -> Self {
31        Self {
32            mean: 0.5,
33            variance: 0.25, // High uncertainty
34            reason: "neutral default".to_string(),
35        }
36    }
37
38    /// Create a prior from a specific value with given confidence.
39    pub fn from_value(value: f32, confidence: f32, reason: impl Into<String>) -> Self {
40        // Map confidence to variance: high confidence → low variance
41        let variance = (1.0 - confidence).powi(2) * 0.25 + 0.01;
42        Self {
43            mean: value.clamp(0.0, 1.0),
44            variance,
45            reason: reason.into(),
46        }
47    }
48
49    /// Create a prior biased toward low values.
50    pub fn biased_low(reason: impl Into<String>) -> Self {
51        Self {
52            mean: 0.3,
53            variance: 0.15,
54            reason: reason.into(),
55        }
56    }
57
58    /// Create a prior biased toward high values.
59    pub fn biased_high(reason: impl Into<String>) -> Self {
60        Self {
61            mean: 0.7,
62            variance: 0.15,
63            reason: reason.into(),
64        }
65    }
66}
67
68impl Default for Prior {
69    fn default() -> Self {
70        Self::neutral()
71    }
72}
73
/// An observation of an axis value.
///
/// Pairs a measured value with how noisy that measurement is, where it came
/// from, and when it was taken. `Observation::new` clamps `value` into
/// `[0.0, 1.0]` and floors `noise_variance` at `0.001`; the fields are public,
/// so values built by hand are not checked.
#[derive(Clone, Debug)]
pub struct Observation {
    /// Observed value.
    pub value: f32,
    /// Observation noise (measurement uncertainty); lower means more trusted.
    pub noise_variance: f32,
    /// Source of this observation.
    pub source: InferenceSource,
    /// When this observation was made.
    pub timestamp: DateTime<Utc>,
}
86
87impl Observation {
88    /// Create a new observation.
89    pub fn new(value: f32, noise_variance: f32, source: InferenceSource) -> Self {
90        Self {
91            value: value.clamp(0.0, 1.0),
92            noise_variance: noise_variance.max(0.001),
93            source,
94            timestamp: Utc::now(),
95        }
96    }
97
98    /// Create an observation from linguistic inference.
99    ///
100    /// Linguistic observations have moderate noise.
101    pub fn from_linguistic(value: f32, features_used: Vec<String>) -> Self {
102        Self::new(
103            value,
104            0.04, // Moderate noise
105            InferenceSource::Linguistic {
106                features_used,
107                feature_values: std::collections::HashMap::new(),
108            },
109        )
110    }
111
112    /// Create an observation from delta analysis.
113    pub fn from_delta(value: f32, z_score: f32, metric: String, baseline_messages: usize) -> Self {
114        // Higher |z_score| = more confidence = less noise
115        let noise = (0.1 / (1.0 + z_score.abs())).max(0.02);
116        Self::new(
117            value,
118            noise,
119            InferenceSource::Delta {
120                baseline_messages,
121                z_score,
122                metric,
123            },
124        )
125    }
126
127    /// Create a self-report observation (very low noise).
128    pub fn from_self_report(value: f32) -> Self {
129        Self::new(value, 0.001, InferenceSource::SelfReport)
130    }
131}
132
/// Configuration for the Bayesian updater.
///
/// All fields are plain tuning knobs read by `BayesianUpdater`; none of them
/// is validated when the config is constructed.
#[derive(Clone, Debug)]
pub struct BayesianConfig {
    /// Maximum update per observation (prevents wild swings).
    ///
    /// Upper bound on how far a single observation may move the mean away
    /// from the prior mean.
    pub max_update: f32,
    /// Minimum variance (prevents overconfidence).
    ///
    /// The posterior variance is never reported below this value.
    pub min_variance: f32,
    /// Variance added per second without observation (uncertainty grows).
    pub variance_growth_rate: f32,
    /// Maximum confidence for inferred values.
    ///
    /// Confidence derived from variance is capped at this value for anything
    /// that is not a self-report.
    pub max_inferred_confidence: f32,
}
145
146impl Default for BayesianConfig {
147    fn default() -> Self {
148        Self {
149            max_update: 0.3,              // Max 0.3 shift per observation
150            min_variance: 0.001,          // Never fully certain
151            variance_growth_rate: 0.0001, // Slow uncertainty growth
152            max_inferred_confidence: MAX_INFERRED_CONFIDENCE,
153        }
154    }
155}
156
/// Bayesian state updater for a single axis.
///
/// Computes a posterior distribution (mean, variance) from a prior and an
/// observation using standard Bayesian updating. The updater itself stores
/// only its configuration; the prior or existing estimate is passed in on
/// every call and the posterior is returned to the caller.
#[derive(Clone, Debug, Default)]
pub struct BayesianUpdater {
    // Tuning knobs applied to every update performed by this updater.
    config: BayesianConfig,
}
165
166impl BayesianUpdater {
167    /// Create a new updater with default configuration.
168    pub fn new() -> Self {
169        Self::default()
170    }
171
172    /// Create an updater with custom configuration.
173    pub fn with_config(config: BayesianConfig) -> Self {
174        Self { config }
175    }
176
177    /// Perform a Bayesian update given prior and observation.
178    ///
179    /// Returns the posterior distribution as an AxisEstimate.
180    pub fn update(&self, axis: &str, prior: &Prior, observation: &Observation) -> AxisEstimate {
181        // Special case: self-report is authoritative
182        if observation.source.is_self_report() {
183            return AxisEstimate::self_report(axis, observation.value);
184        }
185
186        // Standard Bayesian update for Gaussian:
187        // posterior_var = 1 / (1/prior_var + 1/obs_var)
188        // posterior_mean = posterior_var * (prior_mean/prior_var + obs_value/obs_var)
189
190        let prior_precision = 1.0 / prior.variance;
191        let obs_precision = 1.0 / observation.noise_variance;
192
193        let posterior_precision = prior_precision + obs_precision;
194        let posterior_variance = (1.0 / posterior_precision).max(self.config.min_variance);
195
196        let posterior_mean =
197            posterior_variance * (prior.mean * prior_precision + observation.value * obs_precision);
198
199        // Apply max update constraint
200        let clamped_mean = if (posterior_mean - prior.mean).abs() > self.config.max_update {
201            if posterior_mean > prior.mean {
202                prior.mean + self.config.max_update
203            } else {
204                prior.mean - self.config.max_update
205            }
206        } else {
207            posterior_mean
208        };
209
210        // Clamp to valid range
211        let final_mean = clamped_mean.clamp(0.0, 1.0);
212
213        // Convert variance to confidence
214        let confidence = AxisEstimate::variance_to_confidence(posterior_variance)
215            .min(self.config.max_inferred_confidence);
216
217        AxisEstimate {
218            axis: axis.to_string(),
219            value: final_mean,
220            confidence,
221            variance: posterior_variance,
222            source: observation.source.clone(),
223            timestamp: observation.timestamp,
224        }
225    }
226
227    /// Update an existing estimate with a new observation.
228    ///
229    /// The existing estimate serves as the prior.
230    pub fn update_estimate(
231        &self,
232        existing: &AxisEstimate,
233        observation: &Observation,
234    ) -> AxisEstimate {
235        // Special case: self-report always wins
236        if observation.source.is_self_report() {
237            return AxisEstimate::self_report(&existing.axis, observation.value);
238        }
239
240        // Can't override self-report with inference
241        if existing.source.is_self_report() {
242            return existing.clone();
243        }
244
245        // Use existing estimate as prior
246        let prior = Prior {
247            mean: existing.value,
248            variance: existing.variance,
249            reason: "previous estimate".to_string(),
250        };
251
252        self.update(&existing.axis, &prior, observation)
253    }
254
255    /// Grow uncertainty over time without observations.
256    ///
257    /// Variance increases linearly, representing growing uncertainty
258    /// about stale estimates.
259    pub fn grow_uncertainty(&self, estimate: &AxisEstimate, elapsed_seconds: f64) -> AxisEstimate {
260        if estimate.source.is_self_report() {
261            // Self-report doesn't decay (user's stated preference persists)
262            return estimate.clone();
263        }
264
265        let growth = self.config.variance_growth_rate * elapsed_seconds as f32;
266        let new_variance = (estimate.variance + growth).min(0.25); // Cap at neutral uncertainty
267
268        let new_confidence = AxisEstimate::variance_to_confidence(new_variance)
269            .min(self.config.max_inferred_confidence);
270
271        AxisEstimate {
272            axis: estimate.axis.clone(),
273            value: estimate.value,
274            confidence: new_confidence,
275            variance: new_variance,
276            source: InferenceSource::Decayed {
277                original: Box::new(estimate.source.clone()),
278                age_seconds: elapsed_seconds as u64,
279                decay_factor: estimate.variance / new_variance,
280            },
281            timestamp: estimate.timestamp,
282        }
283    }
284
285    /// Combine multiple observations into a single posterior.
286    ///
287    /// Useful when multiple signals about the same axis arrive at once.
288    pub fn combine_observations(
289        &self,
290        axis: &str,
291        prior: &Prior,
292        observations: &[Observation],
293    ) -> AxisEstimate {
294        if observations.is_empty() {
295            return AxisEstimate::prior(axis, prior.mean, 0.5, &prior.reason);
296        }
297
298        // Check for self-report (dominates everything)
299        if let Some(sr) = observations.iter().find(|o| o.source.is_self_report()) {
300            return AxisEstimate::self_report(axis, sr.value);
301        }
302
303        // Iteratively update with each observation
304        let mut current = AxisEstimate {
305            axis: axis.to_string(),
306            value: prior.mean,
307            confidence: AxisEstimate::variance_to_confidence(prior.variance),
308            variance: prior.variance,
309            source: InferenceSource::Prior {
310                reason: prior.reason.clone(),
311            },
312            timestamp: Utc::now(),
313        };
314
315        let sources: Vec<InferenceSource> = observations.iter().map(|o| o.source.clone()).collect();
316        let weights: Vec<f32> = observations
317            .iter()
318            .map(|o| 1.0 / o.noise_variance)
319            .collect();
320
321        for obs in observations {
322            current = self.update_estimate(&current, obs);
323        }
324
325        // Update source to reflect combination
326        AxisEstimate {
327            source: InferenceSource::Combined { sources, weights },
328            ..current
329        }
330    }
331}
332
#[cfg(test)]
mod tests {
    use super::*;

    /// Linguistic source carrying no feature details, for tests that only
    /// care about the numeric behaviour.
    fn featureless_linguistic() -> InferenceSource {
        InferenceSource::Linguistic {
            features_used: vec![],
            feature_values: std::collections::HashMap::new(),
        }
    }

    #[test]
    fn test_neutral_prior() {
        let Prior { mean, variance, .. } = Prior::neutral();

        assert_eq!(mean, 0.5);
        // High uncertainty
        assert!(variance > 0.1);
    }

    #[test]
    fn test_basic_update() {
        let prior = Prior::neutral();
        let observation = Observation::from_linguistic(0.8, vec!["warmth".into()]);

        let posterior = BayesianUpdater::new().update("warmth", &prior, &observation);

        // The estimate moves toward the observation, but not all the way.
        assert!(posterior.value > prior.mean);
        assert!(posterior.value < 0.8);
        // Seeing data makes the estimate more certain.
        assert!(posterior.variance < prior.variance);
    }

    #[test]
    fn test_self_report_dominates() {
        let strong_prior = Prior::from_value(0.2, 0.8, "strong belief in low value");
        let stated = Observation::from_self_report(0.9);

        let posterior = BayesianUpdater::new().update("warmth", &strong_prior, &stated);

        assert_eq!(posterior.value, 0.9);
        assert_eq!(posterior.confidence, 1.0);
    }

    #[test]
    fn test_max_update_constraint() {
        let config = BayesianConfig {
            max_update: 0.1,
            ..Default::default()
        };
        let updater = BayesianUpdater::with_config(config);

        let prior = Prior::from_value(0.2, 0.5, "prior");
        // Big jump relative to the prior mean.
        let far_away = Observation::from_linguistic(0.9, vec![]);

        let posterior = updater.update("warmth", &prior, &far_away);

        // Shift is limited to max_update: 0.2 + 0.1, with an epsilon.
        assert!(posterior.value <= 0.3 + 0.01);
    }

    #[test]
    fn test_cannot_override_self_report() {
        let stated = AxisEstimate::self_report("warmth", 0.9);
        let inferred = Observation::from_linguistic(0.2, vec![]);

        let result = BayesianUpdater::new().update_estimate(&stated, &inferred);

        // The self-report survives the inferred observation untouched.
        assert_eq!(result.value, 0.9);
        assert!(result.source.is_self_report());
    }

    #[test]
    fn test_confidence_capping() {
        let prior = Prior::neutral();

        // A single inferred observation with the lowest allowed noise.
        let sharp = Observation::new(0.8, 0.001, featureless_linguistic());

        let posterior = BayesianUpdater::new().update("warmth", &prior, &sharp);

        // Inferred confidence is still capped.
        assert!(posterior.confidence <= MAX_INFERRED_CONFIDENCE);
    }

    #[test]
    fn test_uncertainty_growth() {
        let fresh = AxisEstimate::inferred("warmth", 0.7, 0.6, featureless_linguistic());

        // One hour without observations.
        let aged = BayesianUpdater::new().grow_uncertainty(&fresh, 3600.0);

        assert!(aged.variance > fresh.variance);
        assert!(aged.confidence < fresh.confidence);
    }

    #[test]
    fn test_combine_observations() {
        let prior = Prior::neutral();
        let signals = vec![
            Observation::from_linguistic(0.7, vec!["feat1".into()]),
            Observation::from_linguistic(0.8, vec!["feat2".into()]),
        ];

        let combined = BayesianUpdater::new().combine_observations("warmth", &prior, &signals);

        // Pulled above the neutral prior toward the observations.
        assert!(combined.value > 0.5);
        // The source records that several signals were merged.
        assert!(matches!(combined.source, InferenceSource::Combined { .. }));
    }
}
459}