use crate::estimate::{AxisEstimate, InferenceSource, MAX_INFERRED_CONFIDENCE};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
/// A belief about an axis value before seeing an observation: a mean plus a
/// variance describing how uncertain that mean is.
///
/// Used as the starting point for `BayesianUpdater::update`, which weights it by
/// its precision (`1 / variance`).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Prior {
/// Expected axis value. The provided constructors keep it in `[0, 1]`.
pub mean: f32,
/// Uncertainty of `mean`; larger means less certain. It is inverted into a
/// precision by the updater, so it should be strictly positive.
pub variance: f32,
/// Human-readable explanation of where this prior came from.
pub reason: String,
}
impl Prior {
    /// The uninformative default prior: mean `0.5`, variance `0.25`.
    pub fn neutral() -> Self {
        Self {
            mean: 0.5,
            variance: 0.25,
            reason: "neutral default".to_string(),
        }
    }

    /// Builds a prior centred on `value` whose spread reflects `confidence`.
    ///
    /// `value` is clamped into `[0, 1]`. `confidence` is clamped into `[0, 1]` and
    /// mapped to `variance = (1 - confidence)^2 * 0.25 + 0.01`. As a result, the
    /// variance falls monotonically from `0.26` (no confidence) to `0.01` (full
    /// confidence) and never reaches zero.
    pub fn from_value(value: f32, confidence: f32, reason: impl Into<String>) -> Self {
        // Without this clamp `(1 - confidence)^2` grows again outside `[0, 1]`, so an
        // "over-confident" caller (e.g. 2.0) silently got the widest possible prior.
        let confidence = confidence.clamp(0.0, 1.0);
        let variance = (1.0 - confidence).powi(2) * 0.25 + 0.01;
        Self {
            mean: value.clamp(0.0, 1.0),
            variance,
            reason: reason.into(),
        }
    }

    /// A prior leaning toward the low end: mean `0.3`, variance `0.15`.
    pub fn biased_low(reason: impl Into<String>) -> Self {
        Self {
            mean: 0.3,
            variance: 0.15,
            reason: reason.into(),
        }
    }

    /// A prior leaning toward the high end: mean `0.7`, variance `0.15`.
    pub fn biased_high(reason: impl Into<String>) -> Self {
        Self {
            mean: 0.7,
            variance: 0.15,
            reason: reason.into(),
        }
    }
}
impl Default for Prior {
fn default() -> Self {
Self::neutral()
}
}
/// A single noisy measurement of an axis value, fed to `BayesianUpdater`.
#[derive(Clone, Debug)]
pub struct Observation {
/// Measured value; `Observation::new` clamps it into `[0, 1]`.
pub value: f32,
/// Measurement noise variance; `Observation::new` floors it at `0.001`. The
/// updater weights the observation by its inverse (precision).
pub noise_variance: f32,
/// Where the measurement came from. Self-reports bypass Bayesian fusion.
pub source: InferenceSource,
/// Creation time; `Observation::new` sets it to `Utc::now()`.
pub timestamp: DateTime<Utc>,
}
impl Observation {
pub fn new(value: f32, noise_variance: f32, source: InferenceSource) -> Self {
Self {
value: value.clamp(0.0, 1.0),
noise_variance: noise_variance.max(0.001),
source,
timestamp: Utc::now(),
}
}
pub fn from_linguistic(value: f32, features_used: Vec<String>) -> Self {
Self::new(
value,
0.04, InferenceSource::Linguistic {
features_used,
feature_values: std::collections::HashMap::new(),
},
)
}
pub fn from_delta(value: f32, z_score: f32, metric: String, baseline_messages: usize) -> Self {
let noise = (0.1 / (1.0 + z_score.abs())).max(0.02);
Self::new(
value,
noise,
InferenceSource::Delta {
baseline_messages,
z_score,
metric,
},
)
}
pub fn from_self_report(value: f32) -> Self {
Self::new(value, 0.001, InferenceSource::SelfReport)
}
}
/// Tuning knobs for `BayesianUpdater`.
#[derive(Clone, Debug)]
pub struct BayesianConfig {
/// Largest absolute change of the mean that a single `update` may apply.
pub max_update: f32,
/// Lower bound applied to the posterior variance produced by `update`.
pub min_variance: f32,
/// Variance added per elapsed second by `grow_uncertainty`.
pub variance_growth_rate: f32,
/// Upper bound on the confidence of estimates that do not come from a self-report.
pub max_inferred_confidence: f32,
}
impl Default for BayesianConfig {
fn default() -> Self {
Self {
max_update: 0.3, min_variance: 0.001, variance_growth_rate: 0.0001, max_inferred_confidence: MAX_INFERRED_CONFIDENCE,
}
}
}
/// Fuses priors and observations into `AxisEstimate`s.
///
/// Holds only a `BayesianConfig`; every method takes `&self` and returns a new
/// estimate rather than mutating the updater.
#[derive(Clone, Debug, Default)]
pub struct BayesianUpdater {
/// Tuning parameters applied by every update.
config: BayesianConfig,
}
impl BayesianUpdater {
    /// Creates an updater using `BayesianConfig::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates an updater with an explicit configuration.
    pub fn with_config(config: BayesianConfig) -> Self {
        Self { config }
    }

    /// Fuses `prior` with one `observation` and returns the posterior estimate for `axis`.
    ///
    /// Self-reported observations are authoritative. They skip fusion and are returned via
    /// `AxisEstimate::self_report`. Any other observation is combined with the prior by
    /// precision weighting (precision = `1 / variance`):
    ///
    /// * The posterior mean is the precision-weighted average of `prior.mean` and
    ///   `observation.value`. It is then limited to at most `config.max_update` away from
    ///   `prior.mean` and clamped to `[0, 1]`.
    /// * The posterior variance is `1 / (prior precision + observation precision)`,
    ///   floored at `config.min_variance`.
    /// * Confidence is `AxisEstimate::variance_to_confidence(variance)`, capped at
    ///   `config.max_inferred_confidence`.
    ///
    /// The result carries the observation's source and timestamp.
    pub fn update(&self, axis: &str, prior: &Prior, observation: &Observation) -> AxisEstimate {
        if observation.source.is_self_report() {
            return AxisEstimate::self_report(axis, observation.value);
        }

        // A zero, negative or NaN variance would give an infinite/NaN precision and poison
        // the mean, so such degenerate inputs are treated as sitting at the configured floor.
        let usable_variance = |variance: f32| {
            if variance > 0.0 {
                variance
            } else {
                self.config.min_variance
            }
        };
        let prior_precision = 1.0 / usable_variance(prior.variance);
        let obs_precision = 1.0 / usable_variance(observation.noise_variance);
        let posterior_precision = prior_precision + obs_precision;

        // Divide by the *exact* posterior precision. Scaling by the floored variance instead
        // turns the weighted average into an inflated sum whenever the floor is active, so
        // estimates sitting at `min_variance` drifted upward on every further observation.
        let posterior_mean = (prior.mean * prior_precision + observation.value * obs_precision)
            / posterior_precision;
        let posterior_variance = (1.0 / posterior_precision).max(self.config.min_variance);

        // Limit how far a single observation may move the estimate.
        let max_step = self.config.max_update;
        let limited_mean = if (posterior_mean - prior.mean).abs() > max_step {
            if posterior_mean > prior.mean {
                prior.mean + max_step
            } else {
                prior.mean - max_step
            }
        } else {
            posterior_mean
        };
        let final_mean = limited_mean.clamp(0.0, 1.0);

        let confidence = AxisEstimate::variance_to_confidence(posterior_variance)
            .min(self.config.max_inferred_confidence);

        AxisEstimate {
            axis: axis.to_string(),
            value: final_mean,
            confidence,
            variance: posterior_variance,
            source: observation.source.clone(),
            timestamp: observation.timestamp,
        }
    }

    /// Treats `existing` as the prior and folds `observation` into it via [`Self::update`].
    ///
    /// A self-reported observation replaces the estimate outright. An existing self-report
    /// is never overridden by an inferred observation and is returned unchanged.
    pub fn update_estimate(
        &self,
        existing: &AxisEstimate,
        observation: &Observation,
    ) -> AxisEstimate {
        if observation.source.is_self_report() {
            return AxisEstimate::self_report(&existing.axis, observation.value);
        }
        if existing.source.is_self_report() {
            return existing.clone();
        }
        let prior = Prior {
            mean: existing.value,
            variance: existing.variance,
            reason: "previous estimate".to_string(),
        };
        self.update(&existing.axis, &prior, observation)
    }

    /// Widens an estimate's variance to reflect `elapsed_seconds` without new evidence.
    ///
    /// Variance grows by `config.variance_growth_rate` per second, capped at `0.25` (the
    /// `Prior::neutral` variance). It never decreases. Self-reports are returned unchanged.
    /// The result's source wraps the original in `InferenceSource::Decayed`.
    pub fn grow_uncertainty(&self, estimate: &AxisEstimate, elapsed_seconds: f64) -> AxisEstimate {
        // Same as the `Prior::neutral` variance.
        const MAX_AGED_VARIANCE: f32 = 0.25;

        if estimate.source.is_self_report() {
            return estimate.clone();
        }

        // Negative elapsed time (e.g. clock skew) or NaN must not shrink the variance.
        let elapsed_seconds = elapsed_seconds.max(0.0);
        let growth = self.config.variance_growth_rate * elapsed_seconds as f32;
        // Cap the growth, but never go below the current variance: aging must not make an
        // estimate *more* certain (a variance already above the cap used to be pulled down).
        let new_variance = (estimate.variance + growth)
            .min(MAX_AGED_VARIANCE)
            .max(estimate.variance);
        let new_confidence = AxisEstimate::variance_to_confidence(new_variance)
            .min(self.config.max_inferred_confidence);
        // `new_variance` is zero only when the estimate already had zero variance and no
        // growth applied. Report "no decay" instead of 0/0 = NaN.
        let decay_factor = if new_variance > 0.0 {
            estimate.variance / new_variance
        } else {
            1.0
        };

        AxisEstimate {
            axis: estimate.axis.clone(),
            value: estimate.value,
            confidence: new_confidence,
            variance: new_variance,
            source: InferenceSource::Decayed {
                original: Box::new(estimate.source.clone()),
                age_seconds: elapsed_seconds as u64,
                decay_factor,
            },
            timestamp: estimate.timestamp,
        }
    }

    /// Sequentially folds `observations` (in slice order) into `prior` for `axis`.
    ///
    /// * With no observations, returns `AxisEstimate::prior` built from the prior's mean
    ///   and reason.
    /// * If any self-report is present, the first one wins and is returned as a self-report.
    /// * Otherwise each observation is applied with [`Self::update_estimate`]. The final
    ///   estimate is tagged `InferenceSource::Combined`, carrying every observation's source
    ///   and its precision (`1 / noise_variance`) as the weight.
    pub fn combine_observations(
        &self,
        axis: &str,
        prior: &Prior,
        observations: &[Observation],
    ) -> AxisEstimate {
        if observations.is_empty() {
            // NOTE(review): this passes a fixed 0.5 (presumably a confidence), while the
            // non-empty path starts from `variance_to_confidence(prior.variance)`. Confirm
            // the mismatch is intended.
            return AxisEstimate::prior(axis, prior.mean, 0.5, &prior.reason);
        }
        if let Some(sr) = observations.iter().find(|o| o.source.is_self_report()) {
            return AxisEstimate::self_report(axis, sr.value);
        }

        let start = AxisEstimate {
            axis: axis.to_string(),
            value: prior.mean,
            confidence: AxisEstimate::variance_to_confidence(prior.variance),
            variance: prior.variance,
            source: InferenceSource::Prior {
                reason: prior.reason.clone(),
            },
            timestamp: Utc::now(),
        };
        let sources: Vec<InferenceSource> = observations.iter().map(|o| o.source.clone()).collect();
        let weights: Vec<f32> = observations.iter().map(|o| 1.0 / o.noise_variance).collect();

        let posterior = observations
            .iter()
            .fold(start, |current, obs| self.update_estimate(&current, obs));

        AxisEstimate {
            source: InferenceSource::Combined { sources, weights },
            ..posterior
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Featureless linguistic source, for tests that only need *some* inferred source.
    fn empty_linguistic_source() -> InferenceSource {
        InferenceSource::Linguistic {
            features_used: Vec::new(),
            feature_values: std::collections::HashMap::new(),
        }
    }

    #[test]
    fn test_neutral_prior() {
        let neutral = Prior::neutral();
        // Centred on the midpoint with a wide spread.
        assert_eq!(neutral.mean, 0.5);
        assert!(neutral.variance > 0.1);
    }

    #[test]
    fn test_basic_update() {
        let prior = Prior::neutral();
        let observation = Observation::from_linguistic(0.8, vec!["warmth".into()]);
        let posterior = BayesianUpdater::new().update("warmth", &prior, &observation);
        // The posterior lands strictly between the prior mean and the observation,
        // and is more certain than the prior.
        assert!(posterior.value > prior.mean);
        assert!(posterior.value < 0.8);
        assert!(posterior.variance < prior.variance);
    }

    #[test]
    fn test_self_report_dominates() {
        let strong_low = Prior::from_value(0.2, 0.8, "strong belief in low value");
        let report = Observation::from_self_report(0.9);
        let posterior = BayesianUpdater::new().update("warmth", &strong_low, &report);
        assert_eq!(posterior.value, 0.9);
        assert_eq!(posterior.confidence, 1.0);
    }

    #[test]
    fn test_max_update_constraint() {
        let config = BayesianConfig {
            max_update: 0.1,
            ..Default::default()
        };
        let updater = BayesianUpdater::with_config(config);
        let prior = Prior::from_value(0.2, 0.5, "prior");
        let observation = Observation::from_linguistic(0.9, vec![]);
        let posterior = updater.update("warmth", &prior, &observation);
        // At most prior mean + max_update, with a little float slack.
        assert!(posterior.value <= 0.3 + 0.01);
    }

    #[test]
    fn test_cannot_override_self_report() {
        let reported = AxisEstimate::self_report("warmth", 0.9);
        let contradicting = Observation::from_linguistic(0.2, vec![]);
        let outcome = BayesianUpdater::new().update_estimate(&reported, &contradicting);
        assert_eq!(outcome.value, 0.9);
        assert!(outcome.source.is_self_report());
    }

    #[test]
    fn test_confidence_capping() {
        let near_exact = Observation::new(0.8, 0.001, empty_linguistic_source());
        let posterior = BayesianUpdater::new().update("warmth", &Prior::neutral(), &near_exact);
        assert!(posterior.confidence <= MAX_INFERRED_CONFIDENCE);
    }

    #[test]
    fn test_uncertainty_growth() {
        let fresh = AxisEstimate::inferred("warmth", 0.7, 0.6, empty_linguistic_source());
        let one_hour = 3600.0;
        let aged = BayesianUpdater::new().grow_uncertainty(&fresh, one_hour);
        assert!(aged.variance > fresh.variance);
        assert!(aged.confidence < fresh.confidence);
    }

    #[test]
    fn test_combine_observations() {
        let readings = [
            Observation::from_linguistic(0.7, vec!["feat1".into()]),
            Observation::from_linguistic(0.8, vec!["feat2".into()]),
        ];
        let combined =
            BayesianUpdater::new().combine_observations("warmth", &Prior::neutral(), &readings);
        assert!(combined.value > 0.5);
        assert!(matches!(combined.source, InferenceSource::Combined { .. }));
    }
}