Skip to main content

attuned_infer/
engine.rs

//! Main inference engine that combines all signals.
//!
//! The engine orchestrates:
//! 1. Linguistic feature extraction
//! 2. Delta analysis (deviation from baseline)
//! 3. Bayesian state estimation
//!
//! All inferences are declared, bounded, and subordinate to self-report.
9
10use std::collections::HashMap;
11
12use crate::bayesian::{BayesianConfig, BayesianUpdater, Observation, Prior};
13use crate::delta::{Baseline, DeltaAnalyzer};
14use crate::estimate::{
15    max_confidence_for_axis, word_count_confidence_factor, InferenceSource, InferredState,
16};
17use crate::features::{LinguisticExtractor, LinguisticFeatures};
18
19/// Configuration for the inference engine.
20#[derive(Clone, Debug)]
21pub struct InferenceConfig {
22    /// Weight for linguistic signals vs delta signals.
23    pub linguistic_weight: f32,
24    /// Weight for delta (deviation) signals.
25    pub delta_weight: f32,
26    /// Minimum confidence to include an axis estimate.
27    pub min_confidence: f32,
28    /// Whether to track baselines per user.
29    pub enable_delta_analysis: bool,
30    /// Baseline window size.
31    pub baseline_window: usize,
32    /// Default priors for each axis.
33    pub default_priors: HashMap<String, Prior>,
34}
35
36impl Default for InferenceConfig {
37    fn default() -> Self {
38        Self {
39            linguistic_weight: 0.6,
40            delta_weight: 0.4,
41            min_confidence: 0.3,
42            enable_delta_analysis: true,
43            baseline_window: 50,
44            default_priors: Self::standard_priors(),
45        }
46    }
47}
48
49impl InferenceConfig {
50    /// Standard neutral priors for all canonical axes.
51    fn standard_priors() -> HashMap<String, Prior> {
52        // Most axes start at neutral 0.5
53        // Some have biased priors based on typical user behavior
54        let mut priors = HashMap::new();
55
56        // Cognitive - default to moderate load
57        priors.insert(
58            "cognitive_load".into(),
59            Prior::from_value(0.4, 0.3, "typical user"),
60        );
61        priors.insert(
62            "decision_fatigue".into(),
63            Prior::from_value(0.3, 0.3, "typical user"),
64        );
65        priors.insert("tolerance_for_complexity".into(), Prior::neutral());
66        priors.insert(
67            "urgency_sensitivity".into(),
68            Prior::from_value(0.3, 0.3, "most queries not urgent"),
69        );
70
71        // Emotional - default to stable
72        priors.insert("emotional_intensity".into(), Prior::neutral());
73        priors.insert(
74            "emotional_stability".into(),
75            Prior::from_value(0.6, 0.3, "assume stable"),
76        );
77        priors.insert(
78            "anxiety_level".into(),
79            Prior::from_value(0.3, 0.3, "assume calm"),
80        );
81        priors.insert("need_for_reassurance".into(), Prior::neutral());
82
83        // Social - slightly warm/casual for digital interactions
84        priors.insert("warmth".into(), Prior::from_value(0.5, 0.3, "neutral"));
85        priors.insert(
86            "formality".into(),
87            Prior::from_value(0.4, 0.3, "digital tends casual"),
88        );
89        priors.insert("boundary_strength".into(), Prior::neutral());
90        priors.insert("assertiveness".into(), Prior::neutral());
91        priors.insert("reciprocity_expectation".into(), Prior::neutral());
92
93        // Preferences - neutral
94        priors.insert("ritual_need".into(), Prior::neutral());
95        priors.insert("transactional_preference".into(), Prior::neutral());
96        priors.insert("verbosity_preference".into(), Prior::neutral());
97        priors.insert("directness_preference".into(), Prior::neutral());
98
99        // Control - slightly prefer autonomy
100        priors.insert(
101            "autonomy_preference".into(),
102            Prior::from_value(0.6, 0.3, "users prefer control"),
103        );
104        priors.insert("suggestion_tolerance".into(), Prior::neutral());
105        priors.insert(
106            "interruption_tolerance".into(),
107            Prior::from_value(0.4, 0.3, "low by default"),
108        );
109        priors.insert("reflection_vs_action_bias".into(), Prior::neutral());
110
111        // Safety - moderate defaults
112        priors.insert("stakes_awareness".into(), Prior::neutral());
113        priors.insert(
114            "privacy_sensitivity".into(),
115            Prior::from_value(0.6, 0.3, "assume privacy-conscious"),
116        );
117
118        priors
119    }
120}
121
122/// Main inference engine.
123///
124/// Combines linguistic features, delta analysis, and Bayesian updating
125/// to produce axis estimates with full provenance.
126#[derive(Clone, Debug)]
127pub struct InferenceEngine {
128    config: InferenceConfig,
129    extractor: LinguisticExtractor,
130    bayesian: BayesianUpdater,
131    delta_analyzer: DeltaAnalyzer,
132}
133
134impl InferenceEngine {
135    /// Create a new engine with default configuration.
136    pub fn new() -> Self {
137        Self::default()
138    }
139
140    /// Create an engine with custom configuration.
141    pub fn with_config(config: InferenceConfig) -> Self {
142        Self {
143            config,
144            extractor: LinguisticExtractor::new(),
145            bayesian: BayesianUpdater::new(),
146            delta_analyzer: DeltaAnalyzer::new(),
147        }
148    }
149
150    /// Create an engine with custom Bayesian configuration.
151    pub fn with_bayesian_config(bayesian_config: BayesianConfig) -> Self {
152        Self {
153            config: InferenceConfig::default(),
154            extractor: LinguisticExtractor::new(),
155            bayesian: BayesianUpdater::with_config(bayesian_config),
156            delta_analyzer: DeltaAnalyzer::new(),
157        }
158    }
159
160    /// Infer state from a single message (no baseline).
161    ///
162    /// Uses linguistic features only.
163    pub fn infer(&self, text: &str) -> InferredState {
164        let features = self.extractor.extract(text);
165        self.infer_from_features(&features)
166    }
167
168    /// Infer from pre-extracted features.
169    ///
170    /// Applies research-validated confidence scaling (TASK-016):
171    /// - Word count scaling (shorter texts = lower confidence)
172    /// - Axis-specific caps (based on evidence strength)
173    pub fn infer_from_features(&self, features: &LinguisticFeatures) -> InferredState {
174        let mut state = InferredState::new();
175
176        // Apply word count confidence scaling (TASK-016)
177        // Research: ~100 words needed for stable inference
178        let word_count_factor = word_count_confidence_factor(features.word_count);
179
180        // Map linguistic features to axis estimates
181        let mappings = self.linguistic_to_axes(features);
182
183        for (axis, value, features_used) in mappings {
184            let prior = self.get_prior(&axis);
185            let obs = Observation::from_linguistic(value, features_used.clone());
186            let mut estimate = self.bayesian.update(&axis, &prior, &obs);
187
188            // Apply word count scaling to confidence
189            estimate.confidence *= word_count_factor;
190
191            // Apply axis-specific confidence cap (TASK-016)
192            let axis_cap = max_confidence_for_axis(&axis);
193            estimate.confidence = estimate.confidence.min(axis_cap);
194
195            // Recalculate variance from adjusted confidence
196            estimate.variance =
197                crate::estimate::AxisEstimate::confidence_to_variance(estimate.confidence);
198
199            if estimate.confidence >= self.config.min_confidence {
200                state.update(estimate);
201            }
202        }
203
204        state
205    }
206
207    /// Infer with baseline context (enables delta analysis).
208    ///
209    /// This is the full pipeline: linguistic + delta + Bayesian.
210    pub fn infer_with_baseline(
211        &self,
212        text: &str,
213        baseline: &mut Baseline,
214        current_state: Option<&InferredState>,
215    ) -> InferredState {
216        let features = self.extractor.extract(text);
217        self.infer_with_features_and_baseline(&features, baseline, current_state)
218    }
219
220    /// Full inference from features with baseline.
221    pub fn infer_with_features_and_baseline(
222        &self,
223        features: &LinguisticFeatures,
224        baseline: &mut Baseline,
225        current_state: Option<&InferredState>,
226    ) -> InferredState {
227        let mut state = InferredState::new();
228
229        // Get linguistic mappings
230        let linguistic_mappings = self.linguistic_to_axes(features);
231
232        // Get delta signals if baseline is ready
233        let delta_signals = if self.config.enable_delta_analysis && baseline.is_ready() {
234            Some(self.delta_analyzer.analyze_and_update(baseline, features))
235        } else {
236            if self.config.enable_delta_analysis {
237                baseline.add(features);
238            }
239            None
240        };
241
242        // Process each axis we have signal for
243        let mut axis_observations: HashMap<String, Vec<Observation>> = HashMap::new();
244
245        // Add linguistic observations
246        for (axis, value, features_used) in linguistic_mappings {
247            let obs = Observation::new(
248                value,
249                0.04, // Linguistic observation noise
250                InferenceSource::Linguistic {
251                    features_used,
252                    feature_values: HashMap::new(),
253                },
254            );
255            axis_observations.entry(axis).or_default().push(obs);
256        }
257
258        // Add delta observations
259        if let Some(ref signals) = delta_signals {
260            let adjustments = self.delta_analyzer.to_axis_adjustments(signals);
261            for (axis, adjustment) in adjustments {
262                // Delta gives relative adjustment, need to convert to absolute
263                let base_value = current_state
264                    .and_then(|s| s.get(axis))
265                    .map(|e| e.value)
266                    .unwrap_or(0.5);
267
268                let delta_value = (base_value + adjustment).clamp(0.0, 1.0);
269
270                let (metric, z) = signals.max_deviation();
271                let obs = Observation::from_delta(
272                    delta_value,
273                    z,
274                    metric.to_string(),
275                    signals.baseline_size,
276                );
277                axis_observations
278                    .entry(axis.to_string())
279                    .or_default()
280                    .push(obs);
281            }
282        }
283
284        // Combine observations through Bayesian updater
285        for (axis, observations) in axis_observations {
286            // Use current estimate as prior if available
287            let prior = current_state
288                .and_then(|s| s.get(&axis))
289                .map(|e| Prior {
290                    mean: e.value,
291                    variance: e.variance,
292                    reason: "previous estimate".to_string(),
293                })
294                .unwrap_or_else(|| self.get_prior(&axis));
295
296            let estimate = self
297                .bayesian
298                .combine_observations(&axis, &prior, &observations);
299
300            if estimate.confidence >= self.config.min_confidence {
301                state.update(estimate);
302            }
303        }
304
305        state
306    }
307
308    /// Map linguistic features to axis values.
309    ///
310    /// Returns (axis_name, value, features_used) tuples.
311    fn linguistic_to_axes(&self, f: &LinguisticFeatures) -> Vec<(String, f32, Vec<String>)> {
312        let mut mappings = Vec::new();
313
314        // Cognitive axes
315        // Higher complexity = higher tolerance for complexity
316        mappings.push((
317            "tolerance_for_complexity".into(),
318            f.complexity_score(),
319            vec![
320                "reading_grade_level".into(),
321                "avg_sentence_length".into(),
322                "long_word_ratio".into(),
323            ],
324        ));
325
326        // Urgency
327        mappings.push((
328            "urgency_sensitivity".into(),
329            f.urgency_score(),
330            vec![
331                "urgency_word_count".into(),
332                "imperative_count".into(),
333                "exclamation_ratio".into(),
334            ],
335        ));
336
337        // Emotional axes
338        // High emotional intensity (exclamation, caps)
339        // Renamed from emotional_openness - we measure intensity, not openness (TASK-016)
340        if f.emotional_intensity() > 0.3 {
341            mappings.push((
342                "emotional_intensity".into(),
343                f.emotional_intensity(),
344                vec!["exclamation_ratio".into(), "caps_ratio".into()],
345            ));
346        }
347
348        // Anxiety/stress - using research-validated score (TASK-016)
349        // Now incorporates negative emotions + first-person + uncertainty + absolutist
350        // Validated on Dreaddit: F1 improved 16.7% over uncertainty_score alone
351        let anxiety = f.anxiety_score();
352        if anxiety > 0.2 {
353            mappings.push((
354                "anxiety_level".into(),
355                anxiety,
356                vec![
357                    "negative_emotion_density".into(),
358                    "first_person_ratio".into(),
359                    "hedge_density".into(),
360                    "absolutist_density".into(),
361                ],
362            ));
363        }
364
365        // Social axes
366        // Formality
367        mappings.push((
368            "formality".into(),
369            f.formality_score(),
370            vec!["contraction_ratio".into(), "complexity_score".into()],
371        ));
372
373        // Warmth - informal + positive emotional signals = warmth
374        let warmth = (1.0 - f.formality_score()) * 0.5
375            + f.emotional_intensity() * 0.3
376            + (f.politeness_count as f32 / 3.0).clamp(0.0, 1.0) * 0.2;
377        if f.politeness_count > 0 || f.emotional_intensity() > 0.2 {
378            mappings.push((
379                "warmth".into(),
380                warmth.clamp(0.0, 1.0),
381                vec![
382                    "politeness_count".into(),
383                    "emotional_intensity".into(),
384                    "formality".into(),
385                ],
386            ));
387        }
388
389        // Assertiveness - certainty markers, imperatives, low hedging
390        let assertiveness = f.certainty_count as f32 / 2.0 * 0.4
391            + f.imperative_count as f32 / 2.0 * 0.3
392            + (1.0 - f.uncertainty_score()) * 0.3;
393        if f.certainty_count > 0 || f.imperative_count > 0 {
394            mappings.push((
395                "assertiveness".into(),
396                assertiveness.clamp(0.0, 1.0),
397                vec![
398                    "certainty_count".into(),
399                    "imperative_count".into(),
400                    "hedge_count".into(),
401                ],
402            ));
403        }
404
405        // Preferences
406        // Verbosity - based on message length relative to typical
407        let verbosity = (f.word_count as f32 / 50.0).clamp(0.0, 1.0);
408        mappings.push((
409            "verbosity_preference".into(),
410            verbosity,
411            vec!["word_count".into()],
412        ));
413
414        // Directness - low hedging, high certainty, imperative usage
415        let directness = (1.0 - f.uncertainty_score()) * 0.5
416            + (f.certainty_count as f32 / 2.0).clamp(0.0, 1.0) * 0.3
417            + (f.imperative_count as f32 / 2.0).clamp(0.0, 1.0) * 0.2;
418        mappings.push((
419            "directness_preference".into(),
420            directness.clamp(0.0, 1.0),
421            vec!["hedge_density".into(), "certainty_count".into()],
422        ));
423
424        // Ritual need - politeness markers, formal greeting patterns
425        if f.politeness_count > 0 {
426            mappings.push((
427                "ritual_need".into(),
428                (f.politeness_count as f32 / 3.0).clamp(0.0, 1.0),
429                vec!["politeness_count".into()],
430            ));
431        }
432
433        mappings
434    }
435
436    /// Get prior for an axis.
437    fn get_prior(&self, axis: &str) -> Prior {
438        self.config
439            .default_priors
440            .get(axis)
441            .cloned()
442            .unwrap_or_else(Prior::neutral)
443    }
444
445    /// Extract features without inference (useful for external analysis).
446    pub fn extract_features(&self, text: &str) -> LinguisticFeatures {
447        self.extractor.extract(text)
448    }
449
450    /// Create a new baseline tracker.
451    pub fn new_baseline(&self) -> Baseline {
452        Baseline::new(self.config.baseline_window)
453    }
454}
455
456impl Default for InferenceEngine {
457    fn default() -> Self {
458        Self {
459            config: InferenceConfig::default(),
460            extractor: LinguisticExtractor::new(),
461            bayesian: BayesianUpdater::new(),
462            delta_analyzer: DeltaAnalyzer::new(),
463        }
464    }
465}
466
467/// Quick inference function for simple use cases.
468///
469/// Creates a default engine and infers state from text.
470pub fn infer(text: &str) -> InferredState {
471    InferenceEngine::new().infer(text)
472}
473
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic_inference() {
        // Long enough to clear the word-count confidence scaling.
        let message = "Hello, how are you doing today? I hope everything is going well with your project. \
                       I wanted to reach out and see if you have any updates on the proposal we discussed \
                       last week. Please let me know when you have a moment to chat about it.";
        let state = InferenceEngine::new().infer(message);

        // At least one axis should clear the confidence floor.
        assert!(!state.is_empty());
    }

    #[test]
    fn test_urgent_message() {
        // Longer urgent text so confidence survives word-count scaling.
        let message = "URGENT! I need help immediately! This is absolutely critical and cannot wait! \
                       The system is down and customers are affected. Please respond ASAP! We need \
                       to fix this right now before it gets worse! This is an emergency situation!";
        let state = InferenceEngine::new().infer(message);

        assert!(matches!(
            state.get("urgency_sensitivity"),
            Some(urgency) if urgency.value > 0.5
        ));
    }

    #[test]
    fn test_formal_message() {
        let message = "Dear Sir or Madam, I am writing to formally inquire about the current status \
                       of my application for the senior developer position. I submitted my application \
                       materials on the first of the month and would greatly appreciate any update \
                       you could provide regarding the review process. I would be most grateful for \
                       your prompt response to this matter. Yours sincerely and with respect.";
        let state = InferenceEngine::new().infer(message);

        assert!(matches!(
            state.get("formality"),
            Some(formality) if formality.value > 0.5
        ));
    }

    #[test]
    fn test_anxious_hedging() {
        let message = "I think maybe this might be a problem? I'm not really sure but perhaps \
                       we should probably look into it, if that's okay? I'm worried this could \
                       cause issues later. I feel anxious about the whole situation and I'm \
                       struggling to figure out what to do. Maybe I'm overthinking it though?";
        let state = InferenceEngine::new().infer(message);

        // Threshold kept low because confidence scaling pulls values toward the prior.
        assert!(matches!(
            state.get("anxiety_level"),
            Some(anxiety) if anxiety.value > 0.3
        ));
    }

    #[test]
    fn test_baseline_integration() {
        let engine = InferenceEngine::new();
        let mut baseline = engine.new_baseline();

        // Warm up the baseline; early calls produce no delta signals.
        for _ in 0..10 {
            let _ = engine.infer_with_baseline(
                "Here is a normal question about your service.",
                &mut baseline,
                None,
            );
        }

        // A sharply different message should still yield estimates.
        let state =
            engine.infer_with_baseline("HELP! Everything is broken!!!", &mut baseline, None);
        assert!(!state.is_empty());
    }

    #[test]
    fn test_all_estimates_have_source() {
        let state = InferenceEngine::new()
            .infer("Please help me understand this complex topic in detail.");

        for estimate in state.all() {
            assert!(estimate.source.is_inferred());
            assert!(!estimate.source.summary().is_empty());
        }
    }

    #[test]
    fn test_confidence_bounded() {
        let state = InferenceEngine::new().infer("URGENT URGENT URGENT!!! HELP NOW!!!");

        assert!(state
            .all()
            .into_iter()
            .all(|e| e.confidence <= crate::estimate::MAX_INFERRED_CONFIDENCE));
    }

    #[test]
    fn test_quick_inference_function() {
        // Enough words to pass the word-count confidence threshold.
        let message = "Hello world, this is a test message to verify the inference function works \
                       correctly with enough words to pass the confidence threshold for analysis.";
        assert!(!infer(message).is_empty());
    }
}