Skip to main content

proof_engine/ai/
utility.rs

1//! Utility AI — scoring-based decision making.
2//!
3//! Each action has a set of considerations (response curves) that evaluate
4//! world state and produce a [0, 1] score. The action with the highest
5//! combined score is selected each tick.
6
7use std::collections::HashMap;
8
9// ── ResponseCurve ─────────────────────────────────────────────────────────────
10
/// A response curve that maps a raw input in `[0, 1]` to a utility score.
///
/// Inputs are always clamped to `[0, 1]` before evaluation. Every variant
/// except [`ResponseCurve::Table`] produces a value in `[0, 1]`; a table returns
/// its own `y` values unchanged, so keep those in `[0, 1]` as well.
#[derive(Debug, Clone)]
pub enum ResponseCurve {
    /// `slope * x + intercept`, clamped to `[0, 1]`.
    Linear { slope: f32, intercept: f32 },
    /// `a * x² + b`, clamped to `[0, 1]`.
    Quadratic { a: f32, b: f32 },
    /// Logistic sigmoid `1 / (1 + e^(-k * (x - x0)))`: centred on `x0`, steepness `k`.
    Logistic { k: f32, x0: f32 },
    /// `1.0` when `x >= threshold`, otherwise `0.0`.
    Step { threshold: f32 },
    /// Normalised exponential `(base^(k * x) - 1) / (base^k - 1)`.
    ///
    /// The curve passes through `(0, 0)` and `(1, 1)` for any positive `base`
    /// (growing or decaying) and any `k`, and is clamped to `[0, 1]`. When
    /// `base^k` is (numerically) 1 the curve degenerates, and its limit
    /// `y = x` is returned instead. Expects `base > 0`.
    Exponential { base: f32, k: f32 },
    /// `sin(π * x + phase) * amplitude`, then mapped via `v * 0.5 + 0.5` and
    /// clamped to `[0, 1]`.
    Sine { amplitude: f32, phase: f32 },
    /// Piecewise-linear lookup over `(x, y)` points sorted by ascending `x`.
    ///
    /// Inputs before the first point or after the last return that endpoint's
    /// `y`. An empty table returns `0.0`.
    Table(Vec<(f32, f32)>),
}

impl ResponseCurve {
    /// Evaluates the curve at `x`.
    ///
    /// `x` is clamped to `[0, 1]` first. A NaN input is treated as `0.0`, the
    /// same value a missing input receives when an action is scored.
    /// Degenerate parameters that would yield NaN (e.g. a negative
    /// exponential base) also produce `0.0`, so the result is never NaN.
    pub fn evaluate(&self, x: f32) -> f32 {
        // `f32::clamp` propagates NaN, so it must be handled explicitly.
        let x = if x.is_nan() { 0.0 } else { x.clamp(0.0, 1.0) };
        let y = match self {
            ResponseCurve::Linear { slope, intercept } => (slope * x + intercept).clamp(0.0, 1.0),
            ResponseCurve::Quadratic { a, b } => (a * x * x + b).clamp(0.0, 1.0),
            ResponseCurve::Logistic { k, x0 } => 1.0 / (1.0 + (-k * (x - x0)).exp()),
            ResponseCurve::Step { threshold } => {
                if x >= *threshold {
                    1.0
                } else {
                    0.0
                }
            }
            ResponseCurve::Exponential { base, k } => {
                // The denominator is negative for decaying curves (base < 1 or
                // k < 0); that is valid and keeps the ratio in [0, 1]. Only a
                // (near-)zero denominator is degenerate. There, the ratio is
                // dominated by rounding noise, and its analytic limit is x.
                let denom = base.powf(*k) - 1.0;
                if denom.abs() < 1e-6 {
                    x
                } else {
                    ((base.powf(k * x) - 1.0) / denom).clamp(0.0, 1.0)
                }
            }
            ResponseCurve::Sine { amplitude, phase } => {
                let v = (x * std::f32::consts::PI + phase).sin() * amplitude;
                (v * 0.5 + 0.5).clamp(0.0, 1.0)
            }
            ResponseCurve::Table(pts) => {
                // Index of the first point strictly to the right of x
                // (requires the points to be sorted by x).
                let idx = pts.partition_point(|&(px, _)| px <= x);
                if pts.is_empty() {
                    0.0
                } else if idx == 0 {
                    pts[0].1
                } else if idx >= pts.len() {
                    pts[pts.len() - 1].1
                } else {
                    let (x0, y0) = pts[idx - 1];
                    let (x1, y1) = pts[idx];
                    if (x1 - x0).abs() < 1e-6 {
                        y1
                    } else {
                        let t = (x - x0) / (x1 - x0);
                        y0 + t * (y1 - y0)
                    }
                }
            }
        };
        if y.is_nan() {
            0.0
        } else {
            y
        }
    }
}
61
62// ── Consideration ─────────────────────────────────────────────────────────────
63
64/// A single input consideration for an action.
65#[derive(Debug, Clone)]
66pub struct Consideration {
67    pub name:   String,
68    /// How to evaluate this consideration's raw input.
69    pub curve:  ResponseCurve,
70    /// Weight in the final multiplication.
71    pub weight: f32,
72}
73
74impl Consideration {
75    pub fn new(name: &str, curve: ResponseCurve) -> Self {
76        Self { name: name.to_string(), curve, weight: 1.0 }
77    }
78
79    pub fn evaluate(&self, raw_input: f32) -> f32 {
80        self.curve.evaluate(raw_input) * self.weight
81    }
82}
83
84// ── UtilityAction ─────────────────────────────────────────────────────────────
85
86/// An action with a set of considerations.
87pub struct UtilityAction<W> {
88    pub name:           String,
89    pub considerations: Vec<Consideration>,
90    /// Minimum score threshold to select this action.
91    pub min_threshold:  f32,
92    /// Bonus score added when this action was selected last tick (momentum).
93    pub momentum:       f32,
94    /// Score normalisation: geometric mean vs product.
95    pub use_geo_mean:   bool,
96    /// The action to execute when selected.
97    pub execute:        Box<dyn Fn(&mut W, f32) -> bool + Send + Sync>,
98    /// Input provider: maps W → [0, 1] per consideration name.
99    pub input_provider: Box<dyn Fn(&W) -> HashMap<String, f32> + Send + Sync>,
100}
101
102impl<W> UtilityAction<W> {
103    pub fn new(
104        name: &str,
105        execute: impl Fn(&mut W, f32) -> bool + Send + Sync + 'static,
106        input_provider: impl Fn(&W) -> HashMap<String, f32> + Send + Sync + 'static,
107    ) -> Self {
108        Self {
109            name: name.to_string(),
110            considerations: Vec::new(),
111            min_threshold: 0.0,
112            momentum: 0.0,
113            use_geo_mean: true,
114            execute: Box::new(execute),
115            input_provider: Box::new(input_provider),
116        }
117    }
118
119    pub fn with_consideration(mut self, c: Consideration) -> Self {
120        self.considerations.push(c);
121        self
122    }
123
124    /// Evaluate this action's combined utility score.
125    pub fn score(&self, world: &W, is_current: bool) -> f32 {
126        if self.considerations.is_empty() { return 0.5; }
127        let inputs = (self.input_provider)(world);
128        let n = self.considerations.len() as f32;
129
130        let product: f32 = self.considerations.iter().map(|c| {
131            let raw = inputs.get(&c.name).copied().unwrap_or(0.0);
132            c.evaluate(raw)
133        }).product();
134
135        let score = if self.use_geo_mean && n > 1.0 {
136            product.powf(1.0 / n)
137        } else { product };
138
139        let momentum_bonus = if is_current { self.momentum } else { 0.0 };
140        (score + momentum_bonus).clamp(0.0, 1.0)
141    }
142}
143
144// ── UtilitySelector ───────────────────────────────────────────────────────────
145
146/// Selects the highest-scoring action each tick.
147pub struct UtilitySelector<W> {
148    pub name:        String,
149    actions:         Vec<UtilityAction<W>>,
150    current_action:  Option<usize>,
151    pub last_scores: Vec<f32>,
152}
153
154impl<W> UtilitySelector<W> {
155    pub fn new(name: &str) -> Self {
156        Self { name: name.to_string(), actions: Vec::new(), current_action: None, last_scores: Vec::new() }
157    }
158
159    pub fn add_action(mut self, action: UtilityAction<W>) -> Self {
160        self.actions.push(action);
161        self
162    }
163
164    /// Evaluate all actions and execute the best.
165    /// Returns true if an action was executed.
166    pub fn tick(&mut self, world: &mut W, dt: f32) -> bool {
167        if self.actions.is_empty() { return false; }
168
169        // Score all actions
170        self.last_scores = self.actions.iter().enumerate().map(|(i, a)| {
171            a.score(world, self.current_action == Some(i))
172        }).collect();
173
174        // Find best above threshold
175        let best = self.last_scores.iter().enumerate()
176            .filter(|(i, &s)| s >= self.actions[*i].min_threshold)
177            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap());
178
179        if let Some((idx, _score)) = best {
180            self.current_action = Some(idx);
181            (self.actions[idx].execute)(world, dt)
182        } else {
183            self.current_action = None;
184            false
185        }
186    }
187
188    pub fn current_action_name(&self) -> Option<&str> {
189        self.current_action.map(|i| self.actions[i].name.as_str())
190    }
191
192    pub fn action_count(&self) -> usize { self.actions.len() }
193}
194
195// ── Tests ─────────────────────────────────────────────────────────────────────
196
#[cfg(test)]
mod tests {
    use super::*;

    /// Identity curve (`y = x`) shared by the scoring and selector tests.
    fn identity() -> ResponseCurve {
        ResponseCurve::Linear { slope: 1.0, intercept: 0.0 }
    }

    #[test]
    fn test_linear_curve() {
        let c = identity();
        assert!((c.evaluate(0.5) - 0.5).abs() < 0.001);
        assert!((c.evaluate(1.5) - 1.0).abs() < 0.001); // clamped
    }

    #[test]
    fn test_logistic_curve() {
        let c = ResponseCurve::Logistic { k: 10.0, x0: 0.5 };
        assert!(c.evaluate(0.5) > 0.45 && c.evaluate(0.5) < 0.55);
        assert!(c.evaluate(0.9) > 0.9);
        assert!(c.evaluate(0.1) < 0.1);
    }

    #[test]
    fn test_step_curve() {
        let c = ResponseCurve::Step { threshold: 0.5 };
        assert_eq!(c.evaluate(0.3), 0.0);
        assert_eq!(c.evaluate(0.7), 1.0);
    }

    #[test]
    fn test_exponential_curve_is_normalised() {
        let c = ResponseCurve::Exponential { base: 2.0, k: 1.0 };
        assert!(c.evaluate(0.0).abs() < 1e-6);
        assert!((c.evaluate(1.0) - 1.0).abs() < 1e-6);
        assert!((c.evaluate(0.5) - (2.0_f32.sqrt() - 1.0)).abs() < 1e-5);
    }

    #[test]
    fn test_table_curve() {
        let c = ResponseCurve::Table(vec![(0.0, 0.0), (0.5, 0.8), (1.0, 1.0)]);
        assert!((c.evaluate(0.25) - 0.4).abs() < 0.001); // midpoint between (0,0) and (0.5,0.8)
    }

    #[test]
    fn test_score_geometric_mean_product_and_momentum() {
        let mut action = UtilityAction::new(
            "pair",
            |_: &mut (), _| true,
            |_: &()| {
                let mut m = HashMap::new();
                m.insert("x".to_string(), 0.25_f32);
                m.insert("y".to_string(), 1.0_f32);
                m
            },
        )
        .with_consideration(Consideration::new("x", identity()))
        .with_consideration(Consideration::new("y", identity()));

        // Geometric mean (default): sqrt(0.25 * 1.0) = 0.5.
        assert!((action.score(&(), false) - 0.5).abs() < 1e-5);

        // Plain product.
        action.use_geo_mean = false;
        assert!((action.score(&(), false) - 0.25).abs() < 1e-5);

        // Momentum only applies to the currently selected action.
        action.momentum = 0.1;
        assert!((action.score(&(), true) - 0.35).abs() < 1e-5);
        assert!((action.score(&(), false) - 0.25).abs() < 1e-5);
    }

    #[test]
    fn test_utility_selector_picks_highest() {
        struct World { health: f32, threat: f32, last_action: String }

        let flee = UtilityAction::new(
            "flee",
            |w: &mut World, _| { w.last_action = "flee".to_string(); true },
            |w: &World| { let mut m = HashMap::new(); m.insert("health".to_string(), 1.0 - w.health); m },
        ).with_consideration(Consideration::new("health", identity()));

        let attack = UtilityAction::new(
            "attack",
            |w: &mut World, _| { w.last_action = "attack".to_string(); true },
            |w: &World| { let mut m = HashMap::new(); m.insert("threat".to_string(), w.threat); m },
        ).with_consideration(Consideration::new("threat", identity()));

        let mut selector = UtilitySelector::new("combat")
            .add_action(flee)
            .add_action(attack);

        let mut world = World { health: 0.1, threat: 0.5, last_action: String::new() };
        selector.tick(&mut world, 0.016);
        assert_eq!(world.last_action, "flee", "low health should prefer fleeing");

        world.health = 0.9;
        world.threat = 0.9;
        selector.tick(&mut world, 0.016);
        assert_eq!(world.last_action, "attack", "high health + high threat → attack");
    }

    #[test]
    fn test_selector_respects_min_threshold() {
        struct World { value: f32, ran: bool }

        let mut action = UtilityAction::new(
            "act",
            |w: &mut World, _| { w.ran = true; true },
            |w: &World| { let mut m = HashMap::new(); m.insert("value".to_string(), w.value); m },
        ).with_consideration(Consideration::new("value", identity()));
        action.min_threshold = 0.5;

        let mut selector = UtilitySelector::new("gate").add_action(action);

        let mut world = World { value: 0.2, ran: false };
        assert!(!selector.tick(&mut world, 0.016), "score below threshold must not execute");
        assert!(!world.ran);
        assert_eq!(selector.current_action_name(), None);

        world.value = 0.8;
        assert!(selector.tick(&mut world, 0.016));
        assert!(world.ran);
        assert_eq!(selector.current_action_name(), Some("act"));
    }

    #[test]
    fn test_momentum_keeps_current_action() {
        struct World { a: f32, b: f32, last: &'static str }

        let mut a = UtilityAction::new(
            "a",
            |w: &mut World, _| { w.last = "a"; true },
            |w: &World| { let mut m = HashMap::new(); m.insert("in".to_string(), w.a); m },
        ).with_consideration(Consideration::new("in", identity()));
        a.momentum = 0.2;

        let b = UtilityAction::new(
            "b",
            |w: &mut World, _| { w.last = "b"; true },
            |w: &World| { let mut m = HashMap::new(); m.insert("in".to_string(), w.b); m },
        ).with_consideration(Consideration::new("in", identity()));

        let mut selector = UtilitySelector::new("sticky").add_action(a).add_action(b);

        let mut world = World { a: 0.6, b: 0.5, last: "" };
        selector.tick(&mut world, 0.016);
        assert_eq!(world.last, "a");

        // "b" now beats "a" on raw score (0.7 vs 0.6), but a's momentum (+0.2) keeps it selected.
        world.b = 0.7;
        selector.tick(&mut world, 0.016);
        assert_eq!(world.last, "a");
        assert!((selector.last_scores[0] - 0.8).abs() < 1e-5);
        assert!((selector.last_scores[1] - 0.7).abs() < 1e-5);
    }

    #[test]
    fn test_empty_selector_does_nothing() {
        let mut selector: UtilitySelector<()> = UtilitySelector::new("empty");
        assert_eq!(selector.action_count(), 0);
        assert!(!selector.tick(&mut (), 0.016));
        assert_eq!(selector.current_action_name(), None);
        assert!(selector.last_scores.is_empty());
    }
}