1use chrono::{DateTime, Utc};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
/// Upper bound on the confidence of any estimate that is not a direct self-report.
///
/// [`AxisEstimate::inferred`] and [`AxisEstimate::prior`] cap the confidence they are given at
/// this value, so an inference can never reach the `1.0` confidence that
/// [`AxisEstimate::self_report`] assigns.
pub const MAX_INFERRED_CONFIDENCE: f32 = 0.7;
16
/// Confidence multiplier based on how many words an estimate was derived from.
///
/// Returns `0.5` for fewer than 10 words. From 10 words it rises linearly, reaching `1.0`
/// at 50 words and staying there for any larger count.
pub fn word_count_confidence_factor(word_count: usize) -> f32 {
    const MIN_WORDS: f32 = 10.0;
    const STABLE_WORDS: f32 = 50.0;
    const FLOOR: f32 = 0.5;

    // Position between the two thresholds. Below MIN_WORDS this is negative, and the
    // clamp folds it back to 0.0, which yields FLOOR, so no separate branch is needed.
    let progress =
        ((word_count as f32 - MIN_WORDS) / (STABLE_WORDS - MIN_WORDS)).clamp(0.0, 1.0);

    FLOOR + (1.0 - FLOOR) * progress
}
38
/// Per-axis confidence ceiling for inferred estimates of `axis`.
///
/// Known axes fall into fixed tiers (0.7, 0.6, 0.5, 0.4). Any axis name not listed gets
/// `0.5`. No tier exceeds [`MAX_INFERRED_CONFIDENCE`].
///
/// NOTE(review): `AxisEstimate::inferred` does not apply this ceiling itself. Presumably
/// callers cap with it; confirm at call sites.
pub fn max_confidence_for_axis(axis: &str) -> f32 {
    // Tiers ordered from highest ceiling to lowest.
    const TIERS: [(&[&str], f32); 4] = [
        (&["formality", "emotional_intensity"], 0.7),
        (&["anxiety_level", "assertiveness", "directness_preference"], 0.6),
        (&["urgency_sensitivity", "warmth", "ritual_need"], 0.5),
        (&["tolerance_for_complexity", "verbosity_preference"], 0.4),
    ];
    const UNKNOWN_AXIS_CEILING: f32 = 0.5;

    TIERS
        .iter()
        .find(|(names, _)| names.iter().any(|&name| name == axis))
        .map_or(UNKNOWN_AXIS_CEILING, |&(_, ceiling)| ceiling)
}
67
/// Provenance of an [`AxisEstimate`]: how its value was obtained.
///
/// Serialized as an internally tagged enum. A `"type"` field holds the variant name in
/// snake_case (e.g. `"self_report"`, `"linguistic"`, `"decayed"`) next to the variant's fields.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum InferenceSource {
    /// The user stated the value directly. `InferredState::update` and
    /// `InferredState::merge` never let an inferred estimate replace one of these.
    SelfReport,

    /// Inferred from linguistic features of the user's text.
    Linguistic {
        /// Names of the features that contributed; `summary` lists them.
        features_used: Vec<String>,
        /// Feature values; presumably keyed by the names in `features_used` — TODO confirm.
        feature_values: HashMap<String, f32>,
    },

    /// Inferred from a change in some metric.
    /// NOTE(review): "relative to a per-user baseline" is inferred from the field names —
    /// confirm with the producer of this variant.
    Delta {
        /// Presumably the number of messages the baseline was computed from.
        baseline_messages: usize,
        /// z-score of the observed metric (shown by `summary` to two decimals).
        z_score: f32,
        /// Name of the metric that was measured.
        metric: String,
    },

    /// Blend of several sources.
    Combined {
        /// The contributing sources; `summary` reports only how many there are.
        sources: Vec<InferenceSource>,
        /// Presumably one weight per entry in `sources`. Lengths are not checked here.
        weights: Vec<f32>,
    },

    /// Produced by `AxisEstimate::decay`: the estimate's confidence was reduced because of
    /// its age.
    Decayed {
        /// The source the estimate had before decay was applied.
        original: Box<InferenceSource>,
        /// Age of the estimate, in whole seconds, when it was decayed.
        age_seconds: u64,
        /// Multiplier applied to the confidence (before `decay` applies its lower floor).
        decay_factor: f32,
    },

    /// An assumed value, created by `AxisEstimate::prior`.
    Prior {
        /// Human-readable justification; `summary` includes it.
        reason: String,
    },
}
117
118impl InferenceSource {
119 pub fn is_self_report(&self) -> bool {
121 matches!(self, Self::SelfReport)
122 }
123
124 pub fn is_inferred(&self) -> bool {
126 !self.is_self_report()
127 }
128
129 pub fn summary(&self) -> String {
131 match self {
132 Self::SelfReport => "self-report".to_string(),
133 Self::Linguistic { features_used, .. } => {
134 format!("linguistic({})", features_used.join(", "))
135 }
136 Self::Delta {
137 metric, z_score, ..
138 } => {
139 format!("delta({}: z={:.2})", metric, z_score)
140 }
141 Self::Combined { sources, .. } => {
142 format!("combined({})", sources.len())
143 }
144 Self::Decayed {
145 original,
146 decay_factor,
147 ..
148 } => {
149 format!(
150 "decayed({}, factor={:.2})",
151 original.summary(),
152 decay_factor
153 )
154 }
155 Self::Prior { reason } => format!("prior({})", reason),
156 }
157 }
158}
159
/// One estimate for a single axis: its value, how confident we are, and where it came from.
///
/// The constructors ([`AxisEstimate::inferred`], [`AxisEstimate::self_report`],
/// [`AxisEstimate::prior`]) clamp `value` to `[0, 1]` and derive `variance` from
/// `confidence`. The fields are public, so values built by hand bypass those checks.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AxisEstimate {
    /// Name of the axis (e.g. `"warmth"`). `InferredState` uses it as the map key.
    pub axis: String,

    /// Estimated position on the axis. The constructors clamp it to `[0, 1]`.
    pub value: f32,

    /// Confidence in `value`. Self-reports use `1.0`. The inferred and prior constructors
    /// cap it at [`MAX_INFERRED_CONFIDENCE`].
    pub confidence: f32,

    /// Uncertainty of `value`. The constructors and `decay` keep it consistent with
    /// `confidence` via [`AxisEstimate::confidence_to_variance`]. Self-reports use `0.001`.
    pub variance: f32,

    /// How the estimate was obtained.
    pub source: InferenceSource,

    /// Set to `Utc::now()` at construction. `decay` preserves it, and `is_stale`
    /// measures age from it.
    pub timestamp: DateTime<Utc>,
}
188
189impl AxisEstimate {
190 pub fn inferred(
194 axis: impl Into<String>,
195 value: f32,
196 confidence: f32,
197 source: InferenceSource,
198 ) -> Self {
199 debug_assert!(
200 source.is_inferred(),
201 "Use self_report() for self-report values"
202 );
203 Self {
204 axis: axis.into(),
205 value: value.clamp(0.0, 1.0),
206 confidence: confidence.min(MAX_INFERRED_CONFIDENCE),
207 variance: Self::confidence_to_variance(confidence.min(MAX_INFERRED_CONFIDENCE)),
208 source,
209 timestamp: Utc::now(),
210 }
211 }
212
213 pub fn self_report(axis: impl Into<String>, value: f32) -> Self {
217 Self {
218 axis: axis.into(),
219 value: value.clamp(0.0, 1.0),
220 confidence: 1.0,
221 variance: 0.001, source: InferenceSource::SelfReport,
223 timestamp: Utc::now(),
224 }
225 }
226
227 pub fn prior(
229 axis: impl Into<String>,
230 value: f32,
231 confidence: f32,
232 reason: impl Into<String>,
233 ) -> Self {
234 Self {
235 axis: axis.into(),
236 value: value.clamp(0.0, 1.0),
237 confidence: confidence.min(MAX_INFERRED_CONFIDENCE),
238 variance: Self::confidence_to_variance(confidence.min(MAX_INFERRED_CONFIDENCE)),
239 source: InferenceSource::Prior {
240 reason: reason.into(),
241 },
242 timestamp: Utc::now(),
243 }
244 }
245
246 pub fn confidence_to_variance(confidence: f32) -> f32 {
250 let conf = confidence.clamp(0.0, 1.0);
253 (1.0 - conf).powi(2) + 0.001
254 }
255
256 pub fn variance_to_confidence(variance: f32) -> f32 {
258 (1.0 - (variance - 0.001).max(0.0).sqrt()).clamp(0.0, 1.0)
259 }
260
261 pub fn decay(&self, half_life_seconds: f64) -> Self {
266 let age = Utc::now()
267 .signed_duration_since(self.timestamp)
268 .num_seconds() as f64;
269
270 if age <= 0.0 || self.source.is_self_report() {
271 return self.clone();
272 }
273
274 let decay_factor = 0.5_f64.powf(age / half_life_seconds) as f32;
276 let new_confidence = (self.confidence * decay_factor).max(0.1); Self {
279 axis: self.axis.clone(),
280 value: self.value,
281 confidence: new_confidence,
282 variance: Self::confidence_to_variance(new_confidence),
283 source: InferenceSource::Decayed {
284 original: Box::new(self.source.clone()),
285 age_seconds: age as u64,
286 decay_factor,
287 },
288 timestamp: self.timestamp,
289 }
290 }
291
292 pub fn is_stale(&self, max_age_seconds: i64) -> bool {
294 let age = Utc::now()
295 .signed_duration_since(self.timestamp)
296 .num_seconds();
297 age > max_age_seconds
298 }
299}
300
/// Collection of per-axis estimates, keyed by axis name, with at most one estimate per axis.
///
/// Self-reports take precedence over inferred estimates when estimates are added through
/// the methods on this type.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct InferredState {
    /// Axis name -> current estimate for that axis.
    estimates: HashMap<String, AxisEstimate>,
}
306
307impl InferredState {
308 pub fn new() -> Self {
310 Self::default()
311 }
312
313 pub fn update(&mut self, estimate: AxisEstimate) {
319 let dominated = self.estimates.get(&estimate.axis).is_some_and(|existing| {
320 existing.source.is_self_report() && estimate.source.is_inferred()
321 });
322
323 if !dominated {
324 self.estimates.insert(estimate.axis.clone(), estimate);
325 }
326 }
327
328 pub fn get(&self, axis: &str) -> Option<&AxisEstimate> {
330 self.estimates.get(axis)
331 }
332
333 pub fn all(&self) -> impl Iterator<Item = &AxisEstimate> {
335 self.estimates.values()
336 }
337
338 pub fn axes(&self) -> impl Iterator<Item = &str> {
340 self.estimates.keys().map(|s| s.as_str())
341 }
342
343 pub fn len(&self) -> usize {
345 self.estimates.len()
346 }
347
348 pub fn is_empty(&self) -> bool {
350 self.estimates.is_empty()
351 }
352
353 pub fn override_with_self_report(&mut self, axis: impl Into<String>, value: f32) {
358 let axis = axis.into();
359 self.estimates
360 .insert(axis.clone(), AxisEstimate::self_report(axis, value));
361 }
362
363 pub fn decay_all(&mut self, half_life_seconds: f64) {
365 for estimate in self.estimates.values_mut() {
366 if estimate.source.is_inferred() {
367 *estimate = estimate.decay(half_life_seconds);
368 }
369 }
370 }
371
372 pub fn prune_stale(&mut self, max_age_seconds: i64) {
374 self.estimates.retain(|_, e| !e.is_stale(max_age_seconds));
375 }
376
377 pub fn merge(&mut self, other: InferredState) {
382 for (axis, new_estimate) in other.estimates {
383 match self.estimates.get(&axis) {
384 Some(existing) if existing.source.is_self_report() => {
385 continue;
387 }
388 Some(_existing) if new_estimate.source.is_self_report() => {
389 self.estimates.insert(axis, new_estimate);
391 }
392 Some(existing) if new_estimate.confidence > existing.confidence => {
393 self.estimates.insert(axis, new_estimate);
395 }
396 Some(_) => {
397 continue;
399 }
400 None => {
401 self.estimates.insert(axis, new_estimate);
402 }
403 }
404 }
405 }
406}
407
#[cfg(test)]
mod tests {
    use super::*;

    /// Inferred source with no features, for tests that only care about provenance kind.
    fn linguistic() -> InferenceSource {
        InferenceSource::Linguistic {
            features_used: vec![],
            feature_values: HashMap::new(),
        }
    }

    #[test]
    fn test_inferred_confidence_cap() {
        let estimate = AxisEstimate::inferred(
            "warmth",
            0.8,
            0.95,
            InferenceSource::Linguistic {
                features_used: vec!["exclamation_ratio".into()],
                feature_values: HashMap::new(),
            },
        );

        assert!(estimate.confidence <= MAX_INFERRED_CONFIDENCE);
    }

    #[test]
    fn test_self_report_full_confidence() {
        let estimate = AxisEstimate::self_report("warmth", 0.8);
        assert_eq!(estimate.confidence, 1.0);
        assert!(estimate.variance < 0.01);
    }

    #[test]
    fn test_self_report_dominates() {
        let mut state = InferredState::new();

        state.update(AxisEstimate::inferred(
            "warmth",
            0.3,
            0.6,
            InferenceSource::Linguistic {
                features_used: vec![],
                feature_values: HashMap::new(),
            },
        ));

        state.override_with_self_report("warmth", 0.9);

        let estimate = state.get("warmth").unwrap();
        assert_eq!(estimate.value, 0.9);
        assert!(estimate.source.is_self_report());
    }

    #[test]
    fn test_inference_cannot_override_self_report() {
        let mut state = InferredState::new();

        state.update(AxisEstimate::self_report("warmth", 0.9));

        state.update(AxisEstimate::inferred(
            "warmth",
            0.3,
            0.7,
            InferenceSource::Linguistic {
                features_used: vec![],
                feature_values: HashMap::new(),
            },
        ));

        let estimate = state.get("warmth").unwrap();
        assert_eq!(estimate.value, 0.9);
        assert!(estimate.source.is_self_report());
    }

    #[test]
    fn test_source_summary() {
        let source = InferenceSource::Linguistic {
            features_used: vec!["hedge_words".into(), "sentence_length".into()],
            feature_values: HashMap::new(),
        };
        assert_eq!(source.summary(), "linguistic(hedge_words, sentence_length)");
    }

    #[test]
    fn test_nested_source_summary() {
        let source = InferenceSource::Decayed {
            original: Box::new(InferenceSource::Prior {
                reason: "default".into(),
            }),
            age_seconds: 60,
            decay_factor: 0.5,
        };
        assert_eq!(source.summary(), "decayed(prior(default), factor=0.50)");
    }

    #[test]
    fn test_word_count_factor_ramps_between_thresholds() {
        assert_eq!(word_count_confidence_factor(0), 0.5);
        assert_eq!(word_count_confidence_factor(10), 0.5);
        assert_eq!(word_count_confidence_factor(30), 0.75);
        assert_eq!(word_count_confidence_factor(50), 1.0);
        assert_eq!(word_count_confidence_factor(10_000), 1.0);
    }

    #[test]
    fn test_axis_ceilings_never_exceed_global_cap() {
        for axis in [
            "formality",
            "anxiety_level",
            "warmth",
            "verbosity_preference",
            "unknown_axis",
        ] {
            assert!(max_confidence_for_axis(axis) <= MAX_INFERRED_CONFIDENCE);
        }
        assert_eq!(max_confidence_for_axis("formality"), 0.7);
        assert_eq!(max_confidence_for_axis("unknown_axis"), 0.5);
    }

    #[test]
    fn test_variance_confidence_round_trip() {
        for c in [0.0_f32, 0.25, 0.5, 0.7, 1.0] {
            let back =
                AxisEstimate::variance_to_confidence(AxisEstimate::confidence_to_variance(c));
            assert!((back - c).abs() < 1e-4, "{} -> {}", c, back);
        }
    }

    #[test]
    fn test_merge_prefers_self_report_then_confidence() {
        let mut state = InferredState::new();
        state.update(AxisEstimate::inferred("warmth", 0.3, 0.4, linguistic()));
        state.update(AxisEstimate::inferred("formality", 0.3, 0.6, linguistic()));

        let mut other = InferredState::new();
        other.update(AxisEstimate::self_report("warmth", 0.9));
        other.update(AxisEstimate::inferred("formality", 0.8, 0.5, linguistic()));
        other.update(AxisEstimate::inferred(
            "directness_preference",
            0.6,
            0.5,
            linguistic(),
        ));

        state.merge(other);

        // A self-report from `other` replaces an inferred estimate.
        assert!(state.get("warmth").unwrap().source.is_self_report());
        // A less confident inference does not replace a more confident one.
        assert_eq!(state.get("formality").unwrap().value, 0.3);
        // Axes present only in `other` are added.
        assert!(state.get("directness_preference").is_some());
        assert_eq!(state.len(), 3);
    }
}
489}