Skip to main content

attuned_infer/
estimate.rs

//! Axis estimates with confidence and provenance tracking.
//!
//! Every inferred value carries metadata about where it came from,
//! how confident we are in it, and when it was computed. This enables
//! full auditability and transparent override by self-report.
6
7use chrono::{DateTime, Utc};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
/// Upper bound on the confidence of any inferred (non-self-report) value.
///
/// Self-reported values carry confidence 1.0. Keeping every inference
/// strictly below that ensures a self-report outranks inference whenever
/// both are compared by confidence.
pub const MAX_INFERRED_CONFIDENCE: f32 = 0.7;
16
/// Confidence multiplier derived from how much text was analyzed.
///
/// The original authors note that formal-setting research points to roughly
/// 100 words for stable style inference, while chat messages usually run
/// 20-50 words, so a gentler ramp is used here: the multiplier stays at 0.5
/// up to 10 words and rises linearly to 1.0 at 50 words.
///
/// # Arguments
/// * `word_count` - Number of words in the analyzed text
///
/// # Returns
/// A multiplier in [0.5, 1.0] to apply to base confidence
pub fn word_count_confidence_factor(word_count: usize) -> f32 {
    // Word count at or below which the multiplier sits at its floor.
    const MIN_WORDS: f32 = 10.0;
    // Word count at or above which the multiplier reaches 1.0.
    const STABLE_WORDS: f32 = 50.0;
    // Multiplier used for very short texts: reduced, but not too harsh.
    const FLOOR: f32 = 0.5;

    // Position of `word_count` along the MIN_WORDS..STABLE_WORDS ramp. The
    // clamp pins shorter texts to 0.0 and longer texts to 1.0.
    let words = word_count as f32;
    let progress = ((words - MIN_WORDS) / (STABLE_WORDS - MIN_WORDS)).clamp(0.0, 1.0);

    FLOOR + (1.0 - FLOOR) * progress
}
38
/// Confidence ceiling for a single axis, by strength of research evidence.
///
/// Axes differ in how well text-based inference is validated in the
/// literature, so each gets its own cap. The tiers follow the project's
/// DEEP_RESEARCH.md notes as recorded by the original authors:
/// - Strong: formality, emotional_intensity (r > 0.3, multiple studies)
/// - Moderate-Strong: anxiety_level, assertiveness (r = 0.2-0.3)
/// - Moderate: urgency, warmth (context-dependent)
/// - Weak: cognitive load proxies (no direct text signal)
///
/// Axis names are matched exactly (case-sensitive). Any name not listed
/// falls back to a conservative 0.5.
pub fn max_confidence_for_axis(axis: &str) -> f32 {
    // Each row pairs a group of axis names with the cap shared by the group.
    const TIERS: [(&[&str], f32); 4] = [
        // Strong evidence (validated across multiple studies)
        (&["formality", "emotional_intensity"], 0.7),
        // Moderate-strong evidence (Dreaddit validated)
        (&["anxiety_level", "assertiveness", "directness_preference"], 0.6),
        // Moderate/context-dependent evidence
        (&["urgency_sensitivity", "warmth", "ritual_need"], 0.5),
        // Weak evidence - style-dependent, hard to infer from text alone
        (&["tolerance_for_complexity", "verbosity_preference"], 0.4),
    ];
    // Unknown axis - use conservative default
    const UNKNOWN_AXIS_CAP: f32 = 0.5;

    TIERS
        .iter()
        .find(|(names, _)| names.contains(&axis))
        .map_or(UNKNOWN_AXIS_CAP, |&(_, cap)| cap)
}
67
68/// Source of an axis inference with full provenance.
69#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
70#[serde(tag = "type", rename_all = "snake_case")]
71pub enum InferenceSource {
72    /// User explicitly provided this value.
73    SelfReport,
74
75    /// Inferred from linguistic features of text.
76    Linguistic {
77        /// Which features contributed to this inference.
78        features_used: Vec<String>,
79        /// The raw feature values that drove the inference.
80        feature_values: HashMap<String, f32>,
81    },
82
83    /// Inferred from deviation from user's baseline behavior.
84    Delta {
85        /// How many messages in the baseline window.
86        baseline_messages: usize,
87        /// The z-score (standard deviations from baseline).
88        z_score: f32,
89        /// Which metric showed the deviation.
90        metric: String,
91    },
92
93    /// Combined from multiple inference sources.
94    Combined {
95        /// The sources that were combined.
96        sources: Vec<InferenceSource>,
97        /// Weights given to each source.
98        weights: Vec<f32>,
99    },
100
101    /// Confidence has decayed over time from original inference.
102    Decayed {
103        /// The original inference source.
104        original: Box<InferenceSource>,
105        /// How much time has passed.
106        age_seconds: u64,
107        /// The decay factor applied.
108        decay_factor: f32,
109    },
110
111    /// Default/prior value (no observation).
112    Prior {
113        /// Description of why this prior was chosen.
114        reason: String,
115    },
116}
117
118impl InferenceSource {
119    /// Returns true if this is a self-report (highest authority).
120    pub fn is_self_report(&self) -> bool {
121        matches!(self, Self::SelfReport)
122    }
123
124    /// Returns true if this is any form of inference (not self-report).
125    pub fn is_inferred(&self) -> bool {
126        !self.is_self_report()
127    }
128
129    /// Get a human-readable summary of this source.
130    pub fn summary(&self) -> String {
131        match self {
132            Self::SelfReport => "self-report".to_string(),
133            Self::Linguistic { features_used, .. } => {
134                format!("linguistic({})", features_used.join(", "))
135            }
136            Self::Delta {
137                metric, z_score, ..
138            } => {
139                format!("delta({}: z={:.2})", metric, z_score)
140            }
141            Self::Combined { sources, .. } => {
142                format!("combined({})", sources.len())
143            }
144            Self::Decayed {
145                original,
146                decay_factor,
147                ..
148            } => {
149                format!(
150                    "decayed({}, factor={:.2})",
151                    original.summary(),
152                    decay_factor
153                )
154            }
155            Self::Prior { reason } => format!("prior({})", reason),
156        }
157    }
158}
159
160/// A single axis estimate with full metadata.
161#[derive(Clone, Debug, Serialize, Deserialize)]
162pub struct AxisEstimate {
163    /// The axis name (must be a canonical axis).
164    pub axis: String,
165
166    /// Estimated value in [0.0, 1.0].
167    pub value: f32,
168
169    /// Confidence in this estimate.
170    ///
171    /// - 1.0 for self-report
172    /// - ≤0.7 for inference (capped by MAX_INFERRED_CONFIDENCE)
173    /// - Decays over time without new observations
174    pub confidence: f32,
175
176    /// Variance of the estimate (for Bayesian updates).
177    ///
178    /// Lower variance = more certain.
179    /// Self-report sets variance to near-zero.
180    pub variance: f32,
181
182    /// How this estimate was derived.
183    pub source: InferenceSource,
184
185    /// When this estimate was computed.
186    pub timestamp: DateTime<Utc>,
187}
188
189impl AxisEstimate {
190    /// Create a new estimate from inference.
191    ///
192    /// Confidence is automatically capped at MAX_INFERRED_CONFIDENCE.
193    pub fn inferred(
194        axis: impl Into<String>,
195        value: f32,
196        confidence: f32,
197        source: InferenceSource,
198    ) -> Self {
199        debug_assert!(
200            source.is_inferred(),
201            "Use self_report() for self-report values"
202        );
203        Self {
204            axis: axis.into(),
205            value: value.clamp(0.0, 1.0),
206            confidence: confidence.min(MAX_INFERRED_CONFIDENCE),
207            variance: Self::confidence_to_variance(confidence.min(MAX_INFERRED_CONFIDENCE)),
208            source,
209            timestamp: Utc::now(),
210        }
211    }
212
213    /// Create a new estimate from self-report.
214    ///
215    /// Self-report has confidence 1.0 and near-zero variance.
216    pub fn self_report(axis: impl Into<String>, value: f32) -> Self {
217        Self {
218            axis: axis.into(),
219            value: value.clamp(0.0, 1.0),
220            confidence: 1.0,
221            variance: 0.001, // Near-zero but not exactly zero for numerical stability
222            source: InferenceSource::SelfReport,
223            timestamp: Utc::now(),
224        }
225    }
226
227    /// Create a prior estimate (default before any observation).
228    pub fn prior(
229        axis: impl Into<String>,
230        value: f32,
231        confidence: f32,
232        reason: impl Into<String>,
233    ) -> Self {
234        Self {
235            axis: axis.into(),
236            value: value.clamp(0.0, 1.0),
237            confidence: confidence.min(MAX_INFERRED_CONFIDENCE),
238            variance: Self::confidence_to_variance(confidence.min(MAX_INFERRED_CONFIDENCE)),
239            source: InferenceSource::Prior {
240                reason: reason.into(),
241            },
242            timestamp: Utc::now(),
243        }
244    }
245
246    /// Convert confidence to variance for Bayesian math.
247    ///
248    /// High confidence → low variance, low confidence → high variance.
249    pub fn confidence_to_variance(confidence: f32) -> f32 {
250        // Map confidence [0,1] to variance [1, 0.001]
251        // Using exponential mapping for better numerical properties
252        let conf = confidence.clamp(0.0, 1.0);
253        (1.0 - conf).powi(2) + 0.001
254    }
255
256    /// Convert variance back to confidence.
257    pub fn variance_to_confidence(variance: f32) -> f32 {
258        (1.0 - (variance - 0.001).max(0.0).sqrt()).clamp(0.0, 1.0)
259    }
260
261    /// Apply time-based decay to this estimate.
262    ///
263    /// Confidence decreases over time, representing increasing uncertainty
264    /// about stale inferences.
265    pub fn decay(&self, half_life_seconds: f64) -> Self {
266        let age = Utc::now()
267            .signed_duration_since(self.timestamp)
268            .num_seconds() as f64;
269
270        if age <= 0.0 || self.source.is_self_report() {
271            return self.clone();
272        }
273
274        // Exponential decay: confidence * 0.5^(age/half_life)
275        let decay_factor = 0.5_f64.powf(age / half_life_seconds) as f32;
276        let new_confidence = (self.confidence * decay_factor).max(0.1); // Floor at 0.1
277
278        Self {
279            axis: self.axis.clone(),
280            value: self.value,
281            confidence: new_confidence,
282            variance: Self::confidence_to_variance(new_confidence),
283            source: InferenceSource::Decayed {
284                original: Box::new(self.source.clone()),
285                age_seconds: age as u64,
286                decay_factor,
287            },
288            timestamp: self.timestamp,
289        }
290    }
291
292    /// Check if this estimate should be considered stale.
293    pub fn is_stale(&self, max_age_seconds: i64) -> bool {
294        let age = Utc::now()
295            .signed_duration_since(self.timestamp)
296            .num_seconds();
297        age > max_age_seconds
298    }
299}
300
301/// Complete inferred state across multiple axes.
302#[derive(Clone, Debug, Default, Serialize, Deserialize)]
303pub struct InferredState {
304    estimates: HashMap<String, AxisEstimate>,
305}
306
307impl InferredState {
308    /// Create empty state.
309    pub fn new() -> Self {
310        Self::default()
311    }
312
313    /// Add or update an axis estimate.
314    ///
315    /// If an estimate already exists, the new one wins if:
316    /// - It's self-report (always wins), or
317    /// - It has higher confidence than existing non-self-report
318    pub fn update(&mut self, estimate: AxisEstimate) {
319        let dominated = self.estimates.get(&estimate.axis).is_some_and(|existing| {
320            existing.source.is_self_report() && estimate.source.is_inferred()
321        });
322
323        if !dominated {
324            self.estimates.insert(estimate.axis.clone(), estimate);
325        }
326    }
327
328    /// Get estimate for an axis.
329    pub fn get(&self, axis: &str) -> Option<&AxisEstimate> {
330        self.estimates.get(axis)
331    }
332
333    /// Get all estimates.
334    pub fn all(&self) -> impl Iterator<Item = &AxisEstimate> {
335        self.estimates.values()
336    }
337
338    /// Get all axis names with estimates.
339    pub fn axes(&self) -> impl Iterator<Item = &str> {
340        self.estimates.keys().map(|s| s.as_str())
341    }
342
343    /// Number of axes with estimates.
344    pub fn len(&self) -> usize {
345        self.estimates.len()
346    }
347
348    /// Returns true if no estimates.
349    pub fn is_empty(&self) -> bool {
350        self.estimates.is_empty()
351    }
352
353    /// Apply an override from self-report.
354    ///
355    /// This sets the axis to the self-reported value with confidence 1.0,
356    /// regardless of any existing inference.
357    pub fn override_with_self_report(&mut self, axis: impl Into<String>, value: f32) {
358        let axis = axis.into();
359        self.estimates
360            .insert(axis.clone(), AxisEstimate::self_report(axis, value));
361    }
362
363    /// Decay all inferred estimates.
364    pub fn decay_all(&mut self, half_life_seconds: f64) {
365        for estimate in self.estimates.values_mut() {
366            if estimate.source.is_inferred() {
367                *estimate = estimate.decay(half_life_seconds);
368            }
369        }
370    }
371
372    /// Remove stale estimates.
373    pub fn prune_stale(&mut self, max_age_seconds: i64) {
374        self.estimates.retain(|_, e| !e.is_stale(max_age_seconds));
375    }
376
377    /// Merge another state into this one.
378    ///
379    /// Self-report always wins. For inference vs inference,
380    /// higher confidence wins.
381    pub fn merge(&mut self, other: InferredState) {
382        for (axis, new_estimate) in other.estimates {
383            match self.estimates.get(&axis) {
384                Some(existing) if existing.source.is_self_report() => {
385                    // Existing self-report dominates
386                    continue;
387                }
388                Some(_existing) if new_estimate.source.is_self_report() => {
389                    // New self-report dominates
390                    self.estimates.insert(axis, new_estimate);
391                }
392                Some(existing) if new_estimate.confidence > existing.confidence => {
393                    // Higher confidence wins
394                    self.estimates.insert(axis, new_estimate);
395                }
396                Some(_) => {
397                    // Existing has higher or equal confidence
398                    continue;
399                }
400                None => {
401                    self.estimates.insert(axis, new_estimate);
402                }
403            }
404        }
405    }
406}
407
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Linguistic` source naming `features`, with no raw feature values.
    fn linguistic(features: &[&str]) -> InferenceSource {
        InferenceSource::Linguistic {
            features_used: features.iter().map(|&name| String::from(name)).collect(),
            feature_values: HashMap::new(),
        }
    }

    #[test]
    fn test_inferred_confidence_cap() {
        // 0.95 is above the inference cap and must be reduced to it.
        let capped =
            AxisEstimate::inferred("warmth", 0.8, 0.95, linguistic(&["exclamation_ratio"]));

        assert!(capped.confidence <= MAX_INFERRED_CONFIDENCE);
    }

    #[test]
    fn test_self_report_full_confidence() {
        let reported = AxisEstimate::self_report("warmth", 0.8);

        assert_eq!(reported.confidence, 1.0);
        assert!(reported.variance < 0.01);
    }

    #[test]
    fn test_self_report_dominates() {
        let mut state = InferredState::new();

        // Start from an inference, then override it with a self-report.
        state.update(AxisEstimate::inferred("warmth", 0.3, 0.6, linguistic(&[])));
        state.override_with_self_report("warmth", 0.9);

        let stored = state.get("warmth").unwrap();
        assert_eq!(stored.value, 0.9);
        assert!(stored.source.is_self_report());
    }

    #[test]
    fn test_inference_cannot_override_self_report() {
        let mut state = InferredState::new();

        // A self-report is stored first; a later inference must not displace it.
        state.update(AxisEstimate::self_report("warmth", 0.9));
        state.update(AxisEstimate::inferred("warmth", 0.3, 0.7, linguistic(&[])));

        let stored = state.get("warmth").unwrap();
        assert_eq!(stored.value, 0.9);
        assert!(stored.source.is_self_report());
    }

    #[test]
    fn test_source_summary() {
        let source = linguistic(&["hedge_words", "sentence_length"]);

        assert_eq!(source.summary(), "linguistic(hedge_words, sentence_length)");
    }
}