1use crate::estimate::{AxisEstimate, InferenceSource, MAX_INFERRED_CONFIDENCE};
14use chrono::{DateTime, Utc};
15use serde::{Deserialize, Serialize};
16
/// Mean/variance belief about a single axis value, used as the starting point
/// of a Bayesian update.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Prior {
    /// Expected value of the axis. The constructors in this module keep it in `[0.0, 1.0]`.
    pub mean: f32,
    /// Uncertainty of the belief; larger means less certain.
    /// `BayesianUpdater::update` inverts it into a precision (weight).
    pub variance: f32,
    /// Human-readable explanation of where this prior came from
    /// (e.g. "neutral default", "previous estimate").
    pub reason: String,
}
27
28impl Prior {
29 pub fn neutral() -> Self {
31 Self {
32 mean: 0.5,
33 variance: 0.25, reason: "neutral default".to_string(),
35 }
36 }
37
38 pub fn from_value(value: f32, confidence: f32, reason: impl Into<String>) -> Self {
40 let variance = (1.0 - confidence).powi(2) * 0.25 + 0.01;
42 Self {
43 mean: value.clamp(0.0, 1.0),
44 variance,
45 reason: reason.into(),
46 }
47 }
48
49 pub fn biased_low(reason: impl Into<String>) -> Self {
51 Self {
52 mean: 0.3,
53 variance: 0.15,
54 reason: reason.into(),
55 }
56 }
57
58 pub fn biased_high(reason: impl Into<String>) -> Self {
60 Self {
61 mean: 0.7,
62 variance: 0.15,
63 reason: reason.into(),
64 }
65 }
66}
67
68impl Default for Prior {
69 fn default() -> Self {
70 Self::neutral()
71 }
72}
73
/// A single noisy measurement of an axis value.
#[derive(Clone, Debug)]
pub struct Observation {
    /// Measured value; `Observation::new` clamps it into `[0.0, 1.0]`.
    pub value: f32,
    /// Variance of the measurement noise. Smaller values carry more weight in an update.
    /// `Observation::new` floors it at `0.001`.
    pub noise_variance: f32,
    /// Where the measurement came from.
    pub source: InferenceSource,
    /// Creation time; `Observation::new` sets it to `Utc::now()`.
    pub timestamp: DateTime<Utc>,
}
86
87impl Observation {
88 pub fn new(value: f32, noise_variance: f32, source: InferenceSource) -> Self {
90 Self {
91 value: value.clamp(0.0, 1.0),
92 noise_variance: noise_variance.max(0.001),
93 source,
94 timestamp: Utc::now(),
95 }
96 }
97
98 pub fn from_linguistic(value: f32, features_used: Vec<String>) -> Self {
102 Self::new(
103 value,
104 0.04, InferenceSource::Linguistic {
106 features_used,
107 feature_values: std::collections::HashMap::new(),
108 },
109 )
110 }
111
112 pub fn from_delta(value: f32, z_score: f32, metric: String, baseline_messages: usize) -> Self {
114 let noise = (0.1 / (1.0 + z_score.abs())).max(0.02);
116 Self::new(
117 value,
118 noise,
119 InferenceSource::Delta {
120 baseline_messages,
121 z_score,
122 metric,
123 },
124 )
125 }
126
127 pub fn from_self_report(value: f32) -> Self {
129 Self::new(value, 0.001, InferenceSource::SelfReport)
130 }
131}
132
/// Tuning parameters for [`BayesianUpdater`].
#[derive(Clone, Debug)]
pub struct BayesianConfig {
    /// Largest change in the mean that a single observation may cause.
    pub max_update: f32,
    /// Lower bound applied to every posterior variance.
    pub min_variance: f32,
    /// Variance added per elapsed second by `BayesianUpdater::grow_uncertainty`.
    pub variance_growth_rate: f32,
    /// Upper bound on the confidence of estimates that are not self-reports.
    pub max_inferred_confidence: f32,
}
145
146impl Default for BayesianConfig {
147 fn default() -> Self {
148 Self {
149 max_update: 0.3, min_variance: 0.001, variance_growth_rate: 0.0001, max_inferred_confidence: MAX_INFERRED_CONFIDENCE,
153 }
154 }
155}
156
/// Applies Bayesian updates to axis estimates according to a [`BayesianConfig`].
#[derive(Clone, Debug, Default)]
pub struct BayesianUpdater {
    /// Parameters controlling step size, variance floor/growth and confidence cap.
    config: BayesianConfig,
}
165
166impl BayesianUpdater {
167 pub fn new() -> Self {
169 Self::default()
170 }
171
172 pub fn with_config(config: BayesianConfig) -> Self {
174 Self { config }
175 }
176
177 pub fn update(&self, axis: &str, prior: &Prior, observation: &Observation) -> AxisEstimate {
181 if observation.source.is_self_report() {
183 return AxisEstimate::self_report(axis, observation.value);
184 }
185
186 let prior_precision = 1.0 / prior.variance;
191 let obs_precision = 1.0 / observation.noise_variance;
192
193 let posterior_precision = prior_precision + obs_precision;
194 let posterior_variance = (1.0 / posterior_precision).max(self.config.min_variance);
195
196 let posterior_mean =
197 posterior_variance * (prior.mean * prior_precision + observation.value * obs_precision);
198
199 let clamped_mean = if (posterior_mean - prior.mean).abs() > self.config.max_update {
201 if posterior_mean > prior.mean {
202 prior.mean + self.config.max_update
203 } else {
204 prior.mean - self.config.max_update
205 }
206 } else {
207 posterior_mean
208 };
209
210 let final_mean = clamped_mean.clamp(0.0, 1.0);
212
213 let confidence = AxisEstimate::variance_to_confidence(posterior_variance)
215 .min(self.config.max_inferred_confidence);
216
217 AxisEstimate {
218 axis: axis.to_string(),
219 value: final_mean,
220 confidence,
221 variance: posterior_variance,
222 source: observation.source.clone(),
223 timestamp: observation.timestamp,
224 }
225 }
226
227 pub fn update_estimate(
231 &self,
232 existing: &AxisEstimate,
233 observation: &Observation,
234 ) -> AxisEstimate {
235 if observation.source.is_self_report() {
237 return AxisEstimate::self_report(&existing.axis, observation.value);
238 }
239
240 if existing.source.is_self_report() {
242 return existing.clone();
243 }
244
245 let prior = Prior {
247 mean: existing.value,
248 variance: existing.variance,
249 reason: "previous estimate".to_string(),
250 };
251
252 self.update(&existing.axis, &prior, observation)
253 }
254
255 pub fn grow_uncertainty(&self, estimate: &AxisEstimate, elapsed_seconds: f64) -> AxisEstimate {
260 if estimate.source.is_self_report() {
261 return estimate.clone();
263 }
264
265 let growth = self.config.variance_growth_rate * elapsed_seconds as f32;
266 let new_variance = (estimate.variance + growth).min(0.25); let new_confidence = AxisEstimate::variance_to_confidence(new_variance)
269 .min(self.config.max_inferred_confidence);
270
271 AxisEstimate {
272 axis: estimate.axis.clone(),
273 value: estimate.value,
274 confidence: new_confidence,
275 variance: new_variance,
276 source: InferenceSource::Decayed {
277 original: Box::new(estimate.source.clone()),
278 age_seconds: elapsed_seconds as u64,
279 decay_factor: estimate.variance / new_variance,
280 },
281 timestamp: estimate.timestamp,
282 }
283 }
284
285 pub fn combine_observations(
289 &self,
290 axis: &str,
291 prior: &Prior,
292 observations: &[Observation],
293 ) -> AxisEstimate {
294 if observations.is_empty() {
295 return AxisEstimate::prior(axis, prior.mean, 0.5, &prior.reason);
296 }
297
298 if let Some(sr) = observations.iter().find(|o| o.source.is_self_report()) {
300 return AxisEstimate::self_report(axis, sr.value);
301 }
302
303 let mut current = AxisEstimate {
305 axis: axis.to_string(),
306 value: prior.mean,
307 confidence: AxisEstimate::variance_to_confidence(prior.variance),
308 variance: prior.variance,
309 source: InferenceSource::Prior {
310 reason: prior.reason.clone(),
311 },
312 timestamp: Utc::now(),
313 };
314
315 let sources: Vec<InferenceSource> = observations.iter().map(|o| o.source.clone()).collect();
316 let weights: Vec<f32> = observations
317 .iter()
318 .map(|o| 1.0 / o.noise_variance)
319 .collect();
320
321 for obs in observations {
322 current = self.update_estimate(¤t, obs);
323 }
324
325 AxisEstimate {
327 source: InferenceSource::Combined { sources, weights },
328 ..current
329 }
330 }
331}
332
#[cfg(test)]
mod tests {
    use super::*;

    /// A linguistic source with no features, for tests where source details do not matter.
    fn blank_linguistic() -> InferenceSource {
        InferenceSource::Linguistic {
            features_used: vec![],
            feature_values: std::collections::HashMap::new(),
        }
    }

    #[test]
    fn test_neutral_prior() {
        let neutral = Prior::neutral();
        assert_eq!(neutral.mean, 0.5);
        // The neutral prior must be noticeably uncertain.
        assert!(neutral.variance > 0.1);
    }

    #[test]
    fn test_basic_update() {
        let prior = Prior::neutral();
        let observation = Observation::from_linguistic(0.8, vec!["warmth".into()]);

        let posterior = BayesianUpdater::new().update("warmth", &prior, &observation);

        // Pulled toward the observation without reaching it...
        assert!(posterior.value > prior.mean);
        assert!(posterior.value < 0.8);
        // ...and more certain than the prior.
        assert!(posterior.variance < prior.variance);
    }

    #[test]
    fn test_self_report_dominates() {
        let strong_low = Prior::from_value(0.2, 0.8, "strong belief in low value");
        let report = Observation::from_self_report(0.9);

        let posterior = BayesianUpdater::new().update("warmth", &strong_low, &report);

        assert_eq!(posterior.value, 0.9);
        assert_eq!(posterior.confidence, 1.0);
    }

    #[test]
    fn test_max_update_constraint() {
        let config = BayesianConfig {
            max_update: 0.1,
            ..Default::default()
        };
        let updater = BayesianUpdater::with_config(config);
        let prior = Prior::from_value(0.2, 0.5, "prior");
        let far_observation = Observation::from_linguistic(0.9, vec![]);

        let posterior = updater.update("warmth", &prior, &far_observation);

        // At most `max_update` above the prior mean of 0.2, with a small tolerance.
        assert!(posterior.value <= 0.3 + 0.01);
    }

    #[test]
    fn test_cannot_override_self_report() {
        let reported = AxisEstimate::self_report("warmth", 0.9);
        let contradicting = Observation::from_linguistic(0.2, vec![]);

        let result = BayesianUpdater::new().update_estimate(&reported, &contradicting);

        assert_eq!(result.value, 0.9);
        assert!(result.source.is_self_report());
    }

    #[test]
    fn test_confidence_capping() {
        // Even a very precise inferred observation must not exceed the confidence cap.
        let precise = Observation::new(0.8, 0.001, blank_linguistic());

        let posterior = BayesianUpdater::new().update("warmth", &Prior::neutral(), &precise);

        assert!(posterior.confidence <= MAX_INFERRED_CONFIDENCE);
    }

    #[test]
    fn test_uncertainty_growth() {
        let fresh = AxisEstimate::inferred("warmth", 0.7, 0.6, blank_linguistic());

        // One hour later.
        let aged = BayesianUpdater::new().grow_uncertainty(&fresh, 3600.0);

        assert!(aged.variance > fresh.variance);
        assert!(aged.confidence < fresh.confidence);
    }

    #[test]
    fn test_combine_observations() {
        let observations = [
            Observation::from_linguistic(0.7, vec!["feat1".into()]),
            Observation::from_linguistic(0.8, vec!["feat2".into()]),
        ];

        let combined =
            BayesianUpdater::new().combine_observations("warmth", &Prior::neutral(), &observations);

        assert!(combined.value > 0.5);
        assert!(matches!(combined.source, InferenceSource::Combined { .. }));
    }
}