use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Upper bound on the confidence of any estimate that is not a self-report.
///
/// `AxisEstimate::inferred` and `AxisEstimate::prior` cap their `confidence`
/// argument at this value; only `AxisEstimate::self_report` produces a higher
/// confidence (1.0).
pub const MAX_INFERRED_CONFIDENCE: f32 = 0.7;
/// Maps a message's word count to a confidence scaling factor in `[0.5, 1.0]`.
///
/// Messages shorter than 10 words get the minimum factor `0.5`. From 10 to 50
/// words the factor rises linearly from `0.5` to `1.0`, and it stays at `1.0`
/// for anything longer.
pub fn word_count_confidence_factor(word_count: usize) -> f32 {
    const MIN_WORDS: usize = 10;
    const STABLE_WORDS: usize = 50;

    if word_count >= MIN_WORDS {
        // Fraction of the way from MIN_WORDS to STABLE_WORDS, saturating at 1.
        let span = (STABLE_WORDS - MIN_WORDS) as f32;
        let progress = (word_count as f32 - MIN_WORDS as f32) / span;
        0.5 + 0.5 * progress.clamp(0.0, 1.0)
    } else {
        0.5
    }
}
/// Returns the confidence ceiling associated with a named axis.
///
/// Known axes are grouped into tiers of `0.7`, `0.6`, `0.5` and `0.4`; any
/// axis name that is not listed falls back to `0.5`.
pub fn max_confidence_for_axis(axis: &str) -> f32 {
    const FALLBACK: f32 = 0.5;
    // Each tier pairs a set of axis names with the ceiling they share.
    const TIERS: [(&[&str], f32); 4] = [
        (&["formality", "emotional_intensity"], 0.7),
        (&["anxiety_level", "assertiveness", "directness_preference"], 0.6),
        (&["urgency_sensitivity", "warmth", "ritual_need"], 0.5),
        (&["tolerance_for_complexity", "verbosity_preference"], 0.4),
    ];

    for (names, ceiling) in TIERS.iter() {
        if names.iter().any(|&name| name == axis) {
            return *ceiling;
        }
    }
    FALLBACK
}
/// Records how an axis estimate was obtained.
///
/// Serialized by serde as an internally tagged value: the variant name, in
/// `snake_case`, is stored under the `"type"` key.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum InferenceSource {
/// The value was reported directly. This is the only variant for which
/// `is_self_report` returns `true`; every other variant counts as inferred.
SelfReport,
/// Estimate attributed to a set of named features.
/// NOTE(review): the code that produces this variant is not in this file.
Linguistic {
/// Names of the features; `summary` lists them joined by `", "`.
features_used: Vec<String>,
/// Numeric values associated with string keys.
/// NOTE(review): presumably keyed by feature name — confirm with the producer.
feature_values: HashMap<String, f32>,
},
/// Estimate attributed to a deviation in a named metric.
/// NOTE(review): field semantics below are inferred from names — confirm.
Delta {
/// Presumably the number of messages that formed the baseline.
baseline_messages: usize,
/// Shown by `summary` with two decimal places.
z_score: f32,
/// Name of the metric; shown by `summary`.
metric: String,
},
/// Estimate attributed to several sources at once.
Combined {
/// The contributing sources; `summary` reports only how many there are.
sources: Vec<InferenceSource>,
/// NOTE(review): presumably one weight per entry of `sources`; nothing in
/// this file enforces equal lengths.
weights: Vec<f32>,
},
/// An estimate whose confidence was reduced over time; built by
/// `AxisEstimate::decay`.
Decayed {
/// The source that was wrapped when decay was applied.
original: Box<InferenceSource>,
/// Age of the estimate, in whole seconds, when decay was applied.
age_seconds: u64,
/// Decay factor recorded by `AxisEstimate::decay`; shown by `summary`
/// with two decimal places.
decay_factor: f32,
},
/// A default assumption rather than an observation; built by
/// `AxisEstimate::prior`.
Prior {
/// Free-text justification for the prior; shown by `summary`.
reason: String,
},
}
impl InferenceSource {
    /// Returns `true` only for the `SelfReport` variant.
    pub fn is_self_report(&self) -> bool {
        matches!(self, Self::SelfReport)
    }

    /// Returns `true` for every variant other than `SelfReport`.
    ///
    /// Wrapper variants (`Combined`, `Decayed`) are not inspected: they count
    /// as inferred regardless of what they contain.
    pub fn is_inferred(&self) -> bool {
        !matches!(self, Self::SelfReport)
    }

    /// Produces a short, human-readable description of this source.
    ///
    /// `Decayed` recurses into the wrapped source; `Combined` reports only the
    /// number of contributing sources.
    pub fn summary(&self) -> String {
        match self {
            Self::SelfReport => String::from("self-report"),
            Self::Linguistic { features_used, .. } => {
                let features = features_used.join(", ");
                format!("linguistic({features})")
            }
            Self::Delta {
                metric, z_score, ..
            } => format!("delta({metric}: z={z_score:.2})"),
            Self::Combined { sources, .. } => {
                let count = sources.len();
                format!("combined({count})")
            }
            Self::Decayed {
                original,
                decay_factor,
                ..
            } => {
                let inner = original.summary();
                format!("decayed({inner}, factor={decay_factor:.2})")
            }
            Self::Prior { reason } => format!("prior({reason})"),
        }
    }
}
/// A single estimate for one named axis, with its uncertainty and provenance.
///
/// All fields are public, so the invariants below hold only for values built
/// through the constructors (`inferred`, `self_report`, `prior`) and `decay`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AxisEstimate {
/// Name of the axis this estimate describes.
pub axis: String,
/// Estimated position on the axis; the constructors clamp it to `[0.0, 1.0]`.
pub value: f32,
/// Confidence in `value`. Self-reports use `1.0`; the other constructors cap
/// it at `MAX_INFERRED_CONFIDENCE`.
pub confidence: f32,
/// Uncertainty paired with `confidence`. The constructors derive it via
/// `confidence_to_variance`, except `self_report`, which uses `0.001`.
pub variance: f32,
/// How the estimate was obtained.
pub source: InferenceSource,
/// When the estimate was created (`Utc::now()` in the constructors). `decay`
/// keeps the original timestamp, so age is always measured from creation.
pub timestamp: DateTime<Utc>,
}
impl AxisEstimate {
    /// Lowest confidence that [`AxisEstimate::decay`] will reduce an estimate to.
    const DECAY_CONFIDENCE_FLOOR: f32 = 0.1;

    /// Restricts a caller-supplied confidence to `[0.0, MAX_INFERRED_CONFIDENCE]`.
    ///
    /// `NaN` maps to `0.0`: `f32::max` returns the non-NaN operand, so an
    /// unusable confidence is treated as "no confidence" rather than being
    /// promoted to the cap.
    fn bounded_inferred_confidence(confidence: f32) -> f32 {
        confidence.max(0.0).min(MAX_INFERRED_CONFIDENCE)
    }

    /// Builds an estimate obtained by inference (anything but a self-report).
    ///
    /// `value` is clamped to `[0.0, 1.0]` and `confidence` to
    /// `[0.0, MAX_INFERRED_CONFIDENCE]`; the variance is derived from the
    /// bounded confidence so the two fields always agree. The timestamp is
    /// the current UTC time.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `source` is `InferenceSource::SelfReport`;
    /// use [`AxisEstimate::self_report`] for those.
    pub fn inferred(
        axis: impl Into<String>,
        value: f32,
        confidence: f32,
        source: InferenceSource,
    ) -> Self {
        debug_assert!(
            source.is_inferred(),
            "Use self_report() for self-report values"
        );
        let confidence = Self::bounded_inferred_confidence(confidence);
        Self {
            axis: axis.into(),
            value: value.clamp(0.0, 1.0),
            confidence,
            variance: Self::confidence_to_variance(confidence),
            source,
            timestamp: Utc::now(),
        }
    }

    /// Builds a self-reported estimate: confidence `1.0`, variance `0.001`,
    /// `value` clamped to `[0.0, 1.0]`, timestamped with the current UTC time.
    pub fn self_report(axis: impl Into<String>, value: f32) -> Self {
        Self {
            axis: axis.into(),
            value: value.clamp(0.0, 1.0),
            confidence: 1.0,
            variance: 0.001,
            source: InferenceSource::SelfReport,
            timestamp: Utc::now(),
        }
    }

    /// Builds an estimate that represents a prior assumption.
    ///
    /// Bounds are applied exactly as in [`AxisEstimate::inferred`]; the source
    /// is `InferenceSource::Prior` carrying `reason`.
    pub fn prior(
        axis: impl Into<String>,
        value: f32,
        confidence: f32,
        reason: impl Into<String>,
    ) -> Self {
        let confidence = Self::bounded_inferred_confidence(confidence);
        Self {
            axis: axis.into(),
            value: value.clamp(0.0, 1.0),
            confidence,
            variance: Self::confidence_to_variance(confidence),
            source: InferenceSource::Prior {
                reason: reason.into(),
            },
            timestamp: Utc::now(),
        }
    }

    /// Converts a confidence into a variance: `(1 - c)^2 + 0.001`, with `c`
    /// clamped to `[0.0, 1.0]` first.
    pub fn confidence_to_variance(confidence: f32) -> f32 {
        let conf = confidence.clamp(0.0, 1.0);
        (1.0 - conf).powi(2) + 0.001
    }

    /// Inverse of [`AxisEstimate::confidence_to_variance`], clamped to
    /// `[0.0, 1.0]`.
    ///
    /// A `NaN` variance yields `0.0`. Without the explicit check, `f32::max`
    /// would discard the `NaN` and the function would report full confidence.
    pub fn variance_to_confidence(variance: f32) -> f32 {
        if variance.is_nan() {
            return 0.0;
        }
        (1.0 - (variance - 0.001).max(0.0).sqrt()).clamp(0.0, 1.0)
    }

    /// Returns a copy of this estimate with its confidence reduced according
    /// to its age and the given half-life.
    ///
    /// * Self-reports and estimates whose timestamp is not in the past are
    ///   returned unchanged.
    /// * The result's source is `InferenceSource::Decayed`, its variance is
    ///   recomputed from the new confidence, and its timestamp is unchanged.
    /// * Decay is incremental. If this estimate is already `Decayed`, only the
    ///   time elapsed since the recorded `age_seconds` is applied, and the
    ///   result still wraps the underlying (non-decayed) source with a
    ///   cumulative `decay_factor`. Repeated calls therefore do not compound
    ///   the full age each time or nest `Decayed` layers.
    /// * Confidence never increases: it is floored at `0.1`, or at the current
    ///   confidence when that is already lower.
    /// * A half-life that is zero, negative or `NaN` is treated as
    ///   instantaneous decay (factor `0`).
    pub fn decay(&self, half_life_seconds: f64) -> Self {
        let age = Utc::now()
            .signed_duration_since(self.timestamp)
            .num_seconds() as f64;
        if age <= 0.0 || self.source.is_self_report() {
            return self.clone();
        }

        // Recover what has already been applied so only the remainder is decayed.
        let (base_source, prior_age, prior_factor) = match &self.source {
            InferenceSource::Decayed {
                original,
                age_seconds,
                decay_factor,
            } => ((**original).clone(), *age_seconds as f64, *decay_factor),
            other => (other.clone(), 0.0, 1.0),
        };
        let elapsed = age - prior_age;
        if elapsed <= 0.0 {
            return self.clone();
        }

        // `half_life_seconds > 0.0` is false for NaN, so NaN takes the else branch.
        let step_factor = if half_life_seconds > 0.0 {
            0.5_f64.powf(elapsed / half_life_seconds) as f32
        } else {
            0.0
        };
        // The floor must not lift an estimate that is already below it.
        let floor = Self::DECAY_CONFIDENCE_FLOOR.min(self.confidence);
        let new_confidence = (self.confidence * step_factor).max(floor);

        Self {
            axis: self.axis.clone(),
            value: self.value,
            confidence: new_confidence,
            variance: Self::confidence_to_variance(new_confidence),
            source: InferenceSource::Decayed {
                original: Box::new(base_source),
                age_seconds: age as u64,
                decay_factor: prior_factor * step_factor,
            },
            timestamp: self.timestamp,
        }
    }

    /// Returns `true` when the estimate is strictly older than
    /// `max_age_seconds`, measured in whole seconds from its timestamp to now.
    pub fn is_stale(&self, max_age_seconds: i64) -> bool {
        let age = Utc::now()
            .signed_duration_since(self.timestamp)
            .num_seconds();
        age > max_age_seconds
    }
}
/// A collection holding at most one `AxisEstimate` per axis.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct InferredState {
/// Current estimates, keyed by axis name (the `axis` field of each estimate).
estimates: HashMap<String, AxisEstimate>,
}
impl InferredState {
pub fn new() -> Self {
Self::default()
}
pub fn update(&mut self, estimate: AxisEstimate) {
let dominated = self.estimates.get(&estimate.axis).is_some_and(|existing| {
existing.source.is_self_report() && estimate.source.is_inferred()
});
if !dominated {
self.estimates.insert(estimate.axis.clone(), estimate);
}
}
pub fn get(&self, axis: &str) -> Option<&AxisEstimate> {
self.estimates.get(axis)
}
pub fn all(&self) -> impl Iterator<Item = &AxisEstimate> {
self.estimates.values()
}
pub fn axes(&self) -> impl Iterator<Item = &str> {
self.estimates.keys().map(|s| s.as_str())
}
pub fn len(&self) -> usize {
self.estimates.len()
}
pub fn is_empty(&self) -> bool {
self.estimates.is_empty()
}
pub fn override_with_self_report(&mut self, axis: impl Into<String>, value: f32) {
let axis = axis.into();
self.estimates
.insert(axis.clone(), AxisEstimate::self_report(axis, value));
}
pub fn decay_all(&mut self, half_life_seconds: f64) {
for estimate in self.estimates.values_mut() {
if estimate.source.is_inferred() {
*estimate = estimate.decay(half_life_seconds);
}
}
}
pub fn prune_stale(&mut self, max_age_seconds: i64) {
self.estimates.retain(|_, e| !e.is_stale(max_age_seconds));
}
pub fn merge(&mut self, other: InferredState) {
for (axis, new_estimate) in other.estimates {
match self.estimates.get(&axis) {
Some(existing) if existing.source.is_self_report() => {
continue;
}
Some(_existing) if new_estimate.source.is_self_report() => {
self.estimates.insert(axis, new_estimate);
}
Some(existing) if new_estimate.confidence > existing.confidence => {
self.estimates.insert(axis, new_estimate);
}
Some(_) => {
continue;
}
None => {
self.estimates.insert(axis, new_estimate);
}
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Linguistic` source naming `features`, with no feature values.
    fn linguistic(features: &[&str]) -> InferenceSource {
        InferenceSource::Linguistic {
            features_used: features.iter().map(|&name| String::from(name)).collect(),
            feature_values: HashMap::new(),
        }
    }

    /// A requested confidence above the cap must be reduced to the cap.
    #[test]
    fn test_inferred_confidence_cap() {
        let source = linguistic(&["exclamation_ratio"]);
        let estimate = AxisEstimate::inferred("warmth", 0.8, 0.95, source);
        assert!(estimate.confidence <= MAX_INFERRED_CONFIDENCE);
    }

    /// Self-reports carry full confidence and near-zero variance.
    #[test]
    fn test_self_report_full_confidence() {
        let estimate = AxisEstimate::self_report("warmth", 0.8);
        assert_eq!(estimate.confidence, 1.0);
        assert!(estimate.variance < 0.01);
    }

    /// A self-report override replaces an earlier inferred estimate.
    #[test]
    fn test_self_report_dominates() {
        let mut state = InferredState::new();
        let inferred = AxisEstimate::inferred("warmth", 0.3, 0.6, linguistic(&[]));
        state.update(inferred);
        state.override_with_self_report("warmth", 0.9);

        let estimate = state.get("warmth").unwrap();
        assert_eq!(estimate.value, 0.9);
        assert!(estimate.source.is_self_report());
    }

    /// `update` with an inferred estimate leaves an existing self-report alone.
    #[test]
    fn test_inference_cannot_override_self_report() {
        let mut state = InferredState::new();
        state.update(AxisEstimate::self_report("warmth", 0.9));
        let inferred = AxisEstimate::inferred("warmth", 0.3, 0.7, linguistic(&[]));
        state.update(inferred);

        let estimate = state.get("warmth").unwrap();
        assert_eq!(estimate.value, 0.9);
        assert!(estimate.source.is_self_report());
    }

    /// `summary` lists linguistic feature names separated by `", "`.
    #[test]
    fn test_source_summary() {
        let source = linguistic(&["hedge_words", "sentence_length"]);
        assert_eq!(source.summary(), "linguistic(hedge_words, sentence_length)");
    }
}