1use std::collections::HashMap;
8
/// Response curve that shapes a consideration's raw input into a utility value.
///
/// [`ResponseCurve::evaluate`] clamps its input to `[0, 1]` before applying the
/// curve, so all parameters below are interpreted over that interval.
#[derive(Debug, Clone)]
pub enum ResponseCurve {
    /// `slope * x + intercept`, clamped to `[0, 1]`.
    Linear { slope: f32, intercept: f32 },
    /// `a * x * x + b`, clamped to `[0, 1]`.
    Quadratic { a: f32, b: f32 },
    /// Sigmoid `1 / (1 + exp(-k * (x - x0)))`; yields `0.5` at `x == x0`.
    Logistic { k: f32, x0: f32 },
    /// `1.0` when `x >= threshold`, otherwise `0.0`.
    Step { threshold: f32 },
    /// `(base^(k * x) - 1) / (base^k - 1)`, clamped to `[0, 1]`.
    ///
    /// Goes from 0 at `x = 0` to 1 at `x = 1` for both growth (`base^k > 1`)
    /// and decay (`base^k < 1`) parameters. When `base^k` is (numerically) 1
    /// the curve degenerates to the identity `x`.
    Exponential { base: f32, k: f32 },
    /// `sin(x * PI + phase) * amplitude`, remapped with `v * 0.5 + 0.5` and
    /// clamped to `[0, 1]`.
    Sine { amplitude: f32, phase: f32 },
    /// Piecewise-linear lookup over `(x, y)` points.
    ///
    /// The points are binary-searched on `x`, so they must be sorted by
    /// ascending `x`. The interpolated `y` is returned as-is (not clamped).
    Table(Vec<(f32, f32)>),
}

impl ResponseCurve {
    /// Evaluates the curve at `x`.
    ///
    /// `x` is clamped to `[0, 1]` first. See each variant for the formula and
    /// for whether the result is clamped as well.
    pub fn evaluate(&self, x: f32) -> f32 {
        // Below this magnitude `base^k - 1` is indistinguishable from zero in
        // f32 and dividing by it would only amplify rounding noise.
        const FLAT_SPAN_EPS: f32 = 1e-6;

        let x = x.clamp(0.0, 1.0);
        match self {
            ResponseCurve::Linear { slope, intercept } => (slope * x + intercept).clamp(0.0, 1.0),
            ResponseCurve::Quadratic { a, b } => (a * x * x + b).clamp(0.0, 1.0),
            ResponseCurve::Logistic { k, x0 } => 1.0 / (1.0 + (-k * (x - x0)).exp()),
            ResponseCurve::Step { threshold } => {
                if x >= *threshold {
                    1.0
                } else {
                    0.0
                }
            }
            ResponseCurve::Exponential { base, k } => {
                let span = base.powf(*k) - 1.0;
                if span.abs() < FLAT_SPAN_EPS {
                    // Limit of the normalised curve as base^k -> 1.
                    x
                } else {
                    // Keep the sign of `span`: for decay curves numerator and
                    // denominator are both negative and the ratio stays in [0, 1].
                    ((base.powf(k * x) - 1.0) / span).clamp(0.0, 1.0)
                }
            }
            ResponseCurve::Sine { amplitude, phase } => {
                let wave = (x * std::f32::consts::PI + phase).sin() * amplitude;
                (wave * 0.5 + 0.5).clamp(0.0, 1.0)
            }
            ResponseCurve::Table(points) => Self::interpolate(points, x),
        }
    }

    /// Linearly interpolates `points` (sorted by ascending x) at `x`.
    ///
    /// Returns `0.0` for an empty table, the first `y` when `x` lies before the
    /// first point and the last `y` when `x` is at or beyond the last point.
    fn interpolate(points: &[(f32, f32)], x: f32) -> f32 {
        if points.is_empty() {
            return 0.0;
        }
        // Index of the first point whose x is strictly greater than `x`.
        let upper = points.partition_point(|&(px, _)| px <= x);
        if upper == 0 {
            return points[0].1;
        }
        if upper >= points.len() {
            return points[points.len() - 1].1;
        }
        let (x0, y0) = points[upper - 1];
        let (x1, y1) = points[upper];
        // Two points at (almost) the same x: avoid dividing by ~0.
        if (x1 - x0).abs() < 1e-6 {
            return y1;
        }
        let t = (x - x0) / (x1 - x0);
        y0 + t * (y1 - y0)
    }
}
61
62#[derive(Debug, Clone)]
66pub struct Consideration {
67 pub name: String,
68 pub curve: ResponseCurve,
70 pub weight: f32,
72}
73
74impl Consideration {
75 pub fn new(name: &str, curve: ResponseCurve) -> Self {
76 Self { name: name.to_string(), curve, weight: 1.0 }
77 }
78
79 pub fn evaluate(&self, raw_input: f32) -> f32 {
80 self.curve.evaluate(raw_input) * self.weight
81 }
82}
83
84pub struct UtilityAction<W> {
88 pub name: String,
89 pub considerations: Vec<Consideration>,
90 pub min_threshold: f32,
92 pub momentum: f32,
94 pub use_geo_mean: bool,
96 pub execute: Box<dyn Fn(&mut W, f32) -> bool + Send + Sync>,
98 pub input_provider: Box<dyn Fn(&W) -> HashMap<String, f32> + Send + Sync>,
100}
101
102impl<W> UtilityAction<W> {
103 pub fn new(
104 name: &str,
105 execute: impl Fn(&mut W, f32) -> bool + Send + Sync + 'static,
106 input_provider: impl Fn(&W) -> HashMap<String, f32> + Send + Sync + 'static,
107 ) -> Self {
108 Self {
109 name: name.to_string(),
110 considerations: Vec::new(),
111 min_threshold: 0.0,
112 momentum: 0.0,
113 use_geo_mean: true,
114 execute: Box::new(execute),
115 input_provider: Box::new(input_provider),
116 }
117 }
118
119 pub fn with_consideration(mut self, c: Consideration) -> Self {
120 self.considerations.push(c);
121 self
122 }
123
124 pub fn score(&self, world: &W, is_current: bool) -> f32 {
126 if self.considerations.is_empty() { return 0.5; }
127 let inputs = (self.input_provider)(world);
128 let n = self.considerations.len() as f32;
129
130 let product: f32 = self.considerations.iter().map(|c| {
131 let raw = inputs.get(&c.name).copied().unwrap_or(0.0);
132 c.evaluate(raw)
133 }).product();
134
135 let score = if self.use_geo_mean && n > 1.0 {
136 product.powf(1.0 / n)
137 } else { product };
138
139 let momentum_bonus = if is_current { self.momentum } else { 0.0 };
140 (score + momentum_bonus).clamp(0.0, 1.0)
141 }
142}
143
144pub struct UtilitySelector<W> {
148 pub name: String,
149 actions: Vec<UtilityAction<W>>,
150 current_action: Option<usize>,
151 pub last_scores: Vec<f32>,
152}
153
154impl<W> UtilitySelector<W> {
155 pub fn new(name: &str) -> Self {
156 Self { name: name.to_string(), actions: Vec::new(), current_action: None, last_scores: Vec::new() }
157 }
158
159 pub fn add_action(mut self, action: UtilityAction<W>) -> Self {
160 self.actions.push(action);
161 self
162 }
163
164 pub fn tick(&mut self, world: &mut W, dt: f32) -> bool {
167 if self.actions.is_empty() { return false; }
168
169 self.last_scores = self.actions.iter().enumerate().map(|(i, a)| {
171 a.score(world, self.current_action == Some(i))
172 }).collect();
173
174 let best = self.last_scores.iter().enumerate()
176 .filter(|(i, &s)| s >= self.actions[*i].min_threshold)
177 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap());
178
179 if let Some((idx, _score)) = best {
180 self.current_action = Some(idx);
181 (self.actions[idx].execute)(world, dt)
182 } else {
183 self.current_action = None;
184 false
185 }
186 }
187
188 pub fn current_action_name(&self) -> Option<&str> {
189 self.current_action.map(|i| self.actions[i].name.as_str())
190 }
191
192 pub fn action_count(&self) -> usize { self.actions.len() }
193}
194
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparing curve outputs.
    const TOLERANCE: f32 = 0.001;

    /// Asserts that `actual` is within `TOLERANCE` of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < TOLERANCE);
    }

    /// The pass-through curve `y = x`.
    fn identity_curve() -> ResponseCurve {
        ResponseCurve::Linear { slope: 1.0, intercept: 0.0 }
    }

    /// Builds an input map holding a single `key -> value` entry.
    fn single_input(key: &str, value: f32) -> HashMap<String, f32> {
        let mut inputs = HashMap::new();
        inputs.insert(key.to_string(), value);
        inputs
    }

    #[test]
    fn test_linear_curve() {
        let curve = identity_curve();
        assert_close(curve.evaluate(0.5), 0.5);
        // Inputs above 1.0 are clamped before the curve is applied.
        assert_close(curve.evaluate(1.5), 1.0);
    }

    #[test]
    fn test_logistic_curve() {
        let curve = ResponseCurve::Logistic { k: 10.0, x0: 0.5 };
        let midpoint = curve.evaluate(0.5);
        assert!(midpoint > 0.45 && midpoint < 0.55);
        assert!(curve.evaluate(0.9) > 0.9);
        assert!(curve.evaluate(0.1) < 0.1);
    }

    #[test]
    fn test_step_curve() {
        let curve = ResponseCurve::Step { threshold: 0.5 };
        assert_eq!(curve.evaluate(0.3), 0.0);
        assert_eq!(curve.evaluate(0.7), 1.0);
    }

    #[test]
    fn test_table_curve() {
        let points = vec![(0.0, 0.0), (0.5, 0.8), (1.0, 1.0)];
        let curve = ResponseCurve::Table(points);
        // Halfway between (0.0, 0.0) and (0.5, 0.8).
        assert_close(curve.evaluate(0.25), 0.4);
    }

    #[test]
    fn test_utility_selector_picks_highest() {
        struct World {
            health: f32,
            threat: f32,
            last_action: String,
        }

        // Fleeing becomes more attractive as health drops.
        let flee = UtilityAction::new(
            "flee",
            |w: &mut World, _| {
                w.last_action = "flee".to_string();
                true
            },
            |w: &World| single_input("health", 1.0 - w.health),
        )
        .with_consideration(Consideration::new("health", identity_curve()));

        // Attacking becomes more attractive as the threat rises.
        let attack = UtilityAction::new(
            "attack",
            |w: &mut World, _| {
                w.last_action = "attack".to_string();
                true
            },
            |w: &World| single_input("threat", w.threat),
        )
        .with_consideration(Consideration::new("threat", identity_curve()));

        let mut selector = UtilitySelector::new("combat")
            .add_action(flee)
            .add_action(attack);

        let mut world = World {
            health: 0.1,
            threat: 0.5,
            last_action: String::new(),
        };
        selector.tick(&mut world, 0.016);
        assert_eq!(world.last_action, "flee", "low health should prefer fleeing");

        world.health = 0.9;
        world.threat = 0.9;
        selector.tick(&mut world, 0.016);
        assert_eq!(world.last_action, "attack", "high health + high threat → attack");
    }
}