// proof_engine/ml/ai_opponent.rs

//! AI opponent with policy/value networks, adaptive difficulty, and playstyle tracking.

use super::tensor::Tensor;
use super::model::{Model, Sequential};

/// Possible actions for the AI opponent.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Action {
    Move,
    Attack,
    UseAbility,
    UseItem,
    Wait,
}

impl Action {
    /// Every action, in canonical index order.
    pub const ALL: [Action; 5] = [
        Action::Move,
        Action::Attack,
        Action::UseAbility,
        Action::UseItem,
        Action::Wait,
    ];

    /// Map an arbitrary index onto an action; indices wrap modulo the action count.
    pub fn from_index(idx: usize) -> Self {
        let slot = idx % Self::ALL.len();
        Self::ALL[slot]
    }

    /// Canonical index of this action (its position in [`Action::ALL`]).
    pub fn index(&self) -> usize {
        Self::ALL
            .iter()
            .position(|candidate| candidate == self)
            .expect("every Action variant appears in ALL")
    }
}

/// Compact game state encoded as a tensor.
#[derive(Debug, Clone)]
pub struct GameState {
    /// Flattened state vector. Layout:
    /// [player_hp, player_mp, player_x, player_y,
    ///  enemy_hp, enemy_mp, enemy_x, enemy_y,
    ///  ...additional features...]
    /// Stored with shape `[1, n]` (see `GameState::new`).
    pub features: Tensor,
}

50impl GameState {
51    /// Create a game state from raw feature values.
52    pub fn new(features: Vec<f32>) -> Self {
53        let n = features.len();
54        Self { features: Tensor::from_vec(features, vec![1, n]) }
55    }
56
57    /// Create a default 16-feature game state.
58    pub fn default_state() -> Self {
59        Self::new(vec![
60            100.0, 50.0, 5.0, 5.0,   // player: hp, mp, x, y
61            100.0, 50.0, 10.0, 10.0,  // enemy: hp, mp, x, y
62            0.0, 0.0, 0.0, 0.0,       // ability cooldowns
63            1.0, 0.0, 0.0, 0.0,       // flags (phase, items, etc.)
64        ])
65    }
66
67    pub fn feature_dim(&self) -> usize {
68        self.features.data.len()
69    }
70}
71
/// AI brain with policy and value networks.
pub struct AIBrain {
    /// Maps a state to one logit per `Action` (see `AIBrain::select_action`).
    pub policy_net: Model,
    /// Maps a state to a single scalar state-value estimate.
    pub value_net: Model,
}

impl AIBrain {
    /// Create an AI brain for a given state dimension.
    ///
    /// Both networks share a 64 -> 32 hidden topology; the policy head emits
    /// one logit per `Action`, the value head a single scalar.
    pub fn new(state_dim: usize) -> Self {
        let policy_net = Sequential::new("policy")
            .dense(state_dim, 64)
            .relu()
            .dense(64, 32)
            .relu()
            .dense(32, Action::ALL.len())
            .build();

        let value_net = Sequential::new("value")
            .dense(state_dim, 64)
            .relu()
            .dense(64, 32)
            .relu()
            .dense(32, 1)
            .build();

        Self { policy_net, value_net }
    }

    /// Select an action using softmax sampling with temperature.
    /// Higher temperature = more random; lower = more greedy.
    ///
    /// NOTE(review): the "random" sample is seeded purely from the state
    /// features, so the same state and temperature always yield the same
    /// action — deterministic by construction, not a true RNG draw.
    pub fn select_action(&self, state: &GameState, temperature: f32) -> Action {
        let logits = self.policy_net.forward(&state.features);
        // Apply temperature; floored at 0.01 to avoid a blow-up as T -> 0.
        let scaled: Vec<f32> = logits.data.iter().map(|&v| v / temperature.max(0.01)).collect();
        let scaled_tensor = Tensor::from_vec(scaled, logits.shape.clone());
        // Softmax over axis 1 for a [batch, n] tensor, axis 0 for a 1-D one
        // (exact axis semantics depend on Tensor::softmax — defined elsewhere).
        let probs = scaled_tensor.softmax(if scaled_tensor.shape.len() > 1 { 1 } else { 0 });

        // Sample from the distribution using a simple RNG based on state data:
        // hash the feature bit patterns (Knuth-style multiplicative constant),
        // then one xorshift64 round.
        let seed: u64 = state.features.data.iter()
            .map(|v| (v.to_bits() as u64).wrapping_mul(2654435761))
            .fold(0u64, |a, b| a.wrapping_add(b));
        let mut rng_state = seed.wrapping_add(1);
        rng_state ^= rng_state << 13;
        rng_state ^= rng_state >> 7;
        rng_state ^= rng_state << 17;
        // Uniform value in [0, 1] taken from the low 32 bits.
        let sample = (rng_state as u32 as f32) / (u32::MAX as f32);

        let prob_data = &probs.data;
        let num_actions = Action::ALL.len();
        // Find the probabilities for the last num_actions elements
        // (tolerates a flattened leading batch dimension in the output).
        let start = prob_data.len().saturating_sub(num_actions);
        let action_probs = &prob_data[start..];

        // Inverse-CDF sampling over the action probabilities.
        let mut cumulative = 0.0f32;
        for (i, &p) in action_probs.iter().enumerate() {
            cumulative += p;
            if sample <= cumulative {
                return Action::from_index(i);
            }
        }
        // Fallback if floating-point rounding keeps the cumulative sum below `sample`.
        Action::Wait
    }

    /// Evaluate how favorable a state is (higher = better for AI).
    pub fn evaluate_state(&self, state: &GameState) -> f32 {
        let value = self.value_net.forward(&state.features);
        // Return the scalar output (use tanh to bound in [-1, 1])
        value.data.last().copied().unwrap_or(0.0).tanh()
    }
}

/// Adaptive AI that blends optimal and random actions based on difficulty.
pub struct AdaptiveAI {
    pub brain: AIBrain,
    /// Difficulty in [0, 1]. 0 = fully random, 1 = fully optimal.
    pub difficulty: f32,
    /// Running score differential to auto-adjust difficulty.
    /// Positive = player has been winning recently; decays after each update.
    score_differential: f32,
    /// Step size applied when nudging `difficulty` after each encounter.
    pub adaptation_rate: f32,
}

153impl AdaptiveAI {
154    pub fn new(state_dim: usize, difficulty: f32) -> Self {
155        Self {
156            brain: AIBrain::new(state_dim),
157            difficulty: difficulty.clamp(0.0, 1.0),
158            score_differential: 0.0,
159            adaptation_rate: 0.05,
160        }
161    }
162
163    /// Select an action, blending optimal with random based on difficulty.
164    pub fn select_action(&self, state: &GameState) -> Action {
165        // Use high temperature (random) for low difficulty, low temperature (greedy) for high
166        let temperature = 0.1 + (1.0 - self.difficulty) * 5.0;
167        self.brain.select_action(state, temperature)
168    }
169
170    /// Update difficulty based on whether the player won or lost the last encounter.
171    /// `player_won`: true if player won.
172    pub fn update_difficulty(&mut self, player_won: bool) {
173        if player_won {
174            // Player is winning: increase difficulty
175            self.score_differential += 1.0;
176        } else {
177            // AI is winning: decrease difficulty
178            self.score_differential -= 1.0;
179        }
180        // Adjust difficulty towards balancing the score differential
181        self.difficulty += self.adaptation_rate * self.score_differential.signum() * 0.1;
182        self.difficulty = self.difficulty.clamp(0.0, 1.0);
183        // Decay differential
184        self.score_differential *= 0.9;
185    }
186}
187
/// Tracks player behavior patterns over time.
pub struct PlaystyleTracker {
    /// Counts of each action type observed from the player,
    /// indexed by `Action::index()`.
    pub action_counts: [u32; 5],
    /// Total actions observed.
    pub total_actions: u32,
    /// Running aggression score (attacks / total).
    pub aggression: f32,
    /// Running caution score (waits and items / total).
    pub caution: f32,
    /// Ability usage rate.
    pub ability_usage: f32,
    /// History window for recent actions (oldest first, bounded by `history_max`).
    pub history: Vec<Action>,
    pub history_max: usize,
}

205impl PlaystyleTracker {
206    pub fn new() -> Self {
207        Self {
208            action_counts: [0; 5],
209            total_actions: 0,
210            aggression: 0.0,
211            caution: 0.0,
212            ability_usage: 0.0,
213            history: Vec::new(),
214            history_max: 100,
215        }
216    }
217
218    /// Record a player action and update statistics.
219    pub fn record(&mut self, action: Action) {
220        self.action_counts[action.index()] += 1;
221        self.total_actions += 1;
222        self.history.push(action);
223        if self.history.len() > self.history_max {
224            self.history.remove(0);
225        }
226        self.update_stats();
227    }
228
229    fn update_stats(&mut self) {
230        let total = self.total_actions as f32;
231        if total == 0.0 { return; }
232        self.aggression = self.action_counts[Action::Attack.index()] as f32 / total;
233        self.caution = (self.action_counts[Action::Wait.index()] as f32
234            + self.action_counts[Action::UseItem.index()] as f32) / total;
235        self.ability_usage = self.action_counts[Action::UseAbility.index()] as f32 / total;
236    }
237
238    /// Get a feature vector summarizing the playstyle.
239    pub fn as_features(&self) -> Vec<f32> {
240        vec![
241            self.aggression,
242            self.caution,
243            self.ability_usage,
244            self.action_counts[Action::Move.index()] as f32 / self.total_actions.max(1) as f32,
245            self.total_actions as f32,
246        ]
247    }
248}
249
/// Parameters controlling AI behavior.
#[derive(Debug, Clone)]
pub struct AIParameters {
    /// Bias toward attacking (positive = more aggressive).
    pub aggression_bias: f32,
    /// Bias toward defensive play.
    pub defense_bias: f32,
    /// Preference for leaning on abilities.
    pub ability_preference: f32,
    /// Willingness to wait for openings.
    pub patience: f32,
}

impl AIParameters {
    /// A neutral parameter set: no biases, moderate patience.
    pub fn balanced() -> Self {
        Self {
            aggression_bias: 0.0,
            defense_bias: 0.0,
            ability_preference: 0.0,
            patience: 0.5,
        }
    }
}

265/// Determine counter-strategy parameters based on observed playstyle.
266pub fn counter_strategy(tracker: &PlaystyleTracker) -> AIParameters {
267    let mut params = AIParameters::balanced();
268    // Counter aggressive players with defense
269    if tracker.aggression > 0.4 {
270        params.defense_bias = 0.5;
271        params.patience = 0.8;
272    }
273    // Counter cautious players with aggression
274    if tracker.caution > 0.3 {
275        params.aggression_bias = 0.6;
276        params.patience = 0.2;
277    }
278    // Counter ability-heavy players with items and positioning
279    if tracker.ability_usage > 0.3 {
280        params.defense_bias = 0.3;
281        params.aggression_bias = 0.2;
282    }
283    params
284}
285
/// Buffer for storing experience tuples for learning.
///
/// The three vectors are parallel: `states[i]`, `actions[i]`, `rewards[i]`
/// together form one experience tuple.
pub struct ExperienceBuffer {
    pub states: Vec<GameState>,
    pub actions: Vec<Action>,
    pub rewards: Vec<f32>,
    /// Maximum number of stored experiences.
    pub capacity: usize,
}

294impl ExperienceBuffer {
295    pub fn new(capacity: usize) -> Self {
296        Self {
297            states: Vec::new(),
298            actions: Vec::new(),
299            rewards: Vec::new(),
300            capacity,
301        }
302    }
303
304    pub fn push(&mut self, state: GameState, action: Action, reward: f32) {
305        if self.states.len() >= self.capacity {
306            self.states.remove(0);
307            self.actions.remove(0);
308            self.rewards.remove(0);
309        }
310        self.states.push(state);
311        self.actions.push(action);
312        self.rewards.push(reward);
313    }
314
315    pub fn len(&self) -> usize {
316        self.states.len()
317    }
318
319    pub fn is_empty(&self) -> bool {
320        self.states.is_empty()
321    }
322
323    /// Sample a random mini-batch of indices.
324    pub fn sample_indices(&self, batch_size: usize, rng_seed: u64) -> Vec<usize> {
325        let n = self.len();
326        if n == 0 { return vec![]; }
327        let batch_size = batch_size.min(n);
328        let mut indices = Vec::with_capacity(batch_size);
329        let mut state = rng_seed.wrapping_add(1);
330        for _ in 0..batch_size {
331            state ^= state << 13;
332            state ^= state >> 7;
333            state ^= state << 17;
334            indices.push((state as usize) % n);
335        }
336        indices
337    }
338
339    /// Compute discounted returns from the reward sequence.
340    pub fn compute_returns(&self, gamma: f32) -> Vec<f32> {
341        let n = self.rewards.len();
342        let mut returns = vec![0.0f32; n];
343        if n == 0 { return returns; }
344        returns[n - 1] = self.rewards[n - 1];
345        for i in (0..n - 1).rev() {
346            returns[i] = self.rewards[i] + gamma * returns[i + 1];
347        }
348        returns
349    }
350
351    pub fn clear(&mut self) {
352        self.states.clear();
353        self.actions.clear();
354        self.rewards.clear();
355    }
356
357    /// Mean reward across all stored experiences.
358    pub fn mean_reward(&self) -> f32 {
359        if self.rewards.is_empty() { return 0.0; }
360        self.rewards.iter().sum::<f32>() / self.rewards.len() as f32
361    }
362}
363
#[cfg(test)]
mod tests {
    use super::*;

    // Action <-> index mapping must round-trip for every variant.
    #[test]
    fn test_action_roundtrip() {
        for a in Action::ALL {
            assert_eq!(Action::from_index(a.index()), a);
        }
    }

    // The default state carries exactly 16 features.
    #[test]
    fn test_game_state() {
        let state = GameState::default_state();
        assert_eq!(state.feature_dim(), 16);
    }

    // Whatever the (untrained) policy emits, it must be a valid Action.
    #[test]
    fn test_brain_select_action() {
        let brain = AIBrain::new(16);
        let state = GameState::default_state();
        let action = brain.select_action(&state, 1.0);
        assert!(Action::ALL.contains(&action));
    }

    // tanh bounds the value estimate in [-1, 1].
    #[test]
    fn test_brain_evaluate_state() {
        let brain = AIBrain::new(16);
        let state = GameState::default_state();
        let value = brain.evaluate_state(&state);
        assert!(value >= -1.0 && value <= 1.0);
    }

    #[test]
    fn test_adaptive_ai() {
        let mut ai = AdaptiveAI::new(16, 0.5);
        let state = GameState::default_state();
        let _action = ai.select_action(&state);

        let initial_diff = ai.difficulty;
        ai.update_difficulty(true); // player won
        // Difficulty should increase
        // (tolerance term covers clamping at the [0, 1] boundary).
        assert!(ai.difficulty >= initial_diff || (ai.difficulty - initial_diff).abs() < 0.1);
    }

    // Rates are simple fractions of the total recorded actions.
    #[test]
    fn test_playstyle_tracker() {
        let mut tracker = PlaystyleTracker::new();
        for _ in 0..10 { tracker.record(Action::Attack); }
        for _ in 0..5 { tracker.record(Action::Wait); }
        assert_eq!(tracker.total_actions, 15);
        assert!((tracker.aggression - 10.0 / 15.0).abs() < 1e-5);
        assert!((tracker.caution - 5.0 / 15.0).abs() < 1e-5);
    }

    // Pure attacking (aggression = 1.0) triggers the defensive counter.
    #[test]
    fn test_counter_strategy_aggressive() {
        let mut tracker = PlaystyleTracker::new();
        for _ in 0..10 { tracker.record(Action::Attack); }
        let params = counter_strategy(&tracker);
        assert!(params.defense_bias > 0.0);
        assert!(params.patience > 0.5);
    }

    // Pure waiting (caution = 1.0) triggers the aggressive counter.
    #[test]
    fn test_counter_strategy_cautious() {
        let mut tracker = PlaystyleTracker::new();
        for _ in 0..10 { tracker.record(Action::Wait); }
        let params = counter_strategy(&tracker);
        assert!(params.aggression_bias > 0.0);
    }

    // Pushing past capacity evicts the oldest entries.
    #[test]
    fn test_experience_buffer() {
        let mut buf = ExperienceBuffer::new(5);
        for i in 0..7 {
            buf.push(GameState::default_state(), Action::Attack, i as f32);
        }
        assert_eq!(buf.len(), 5); // capped at capacity
        assert!(!buf.is_empty());
    }

    // Discounted returns are accumulated backwards from the last reward.
    #[test]
    fn test_experience_buffer_returns() {
        let mut buf = ExperienceBuffer::new(100);
        buf.push(GameState::default_state(), Action::Move, 1.0);
        buf.push(GameState::default_state(), Action::Attack, 2.0);
        buf.push(GameState::default_state(), Action::Wait, 3.0);
        let returns = buf.compute_returns(0.9);
        // returns[2] = 3.0
        // returns[1] = 2.0 + 0.9*3.0 = 4.7
        // returns[0] = 1.0 + 0.9*4.7 = 5.23
        assert!((returns[2] - 3.0).abs() < 1e-5);
        assert!((returns[1] - 4.7).abs() < 1e-5);
        assert!((returns[0] - 5.23).abs() < 1e-3);
    }

    // Sampled indices always stay within the buffer bounds.
    #[test]
    fn test_sample_indices() {
        let mut buf = ExperienceBuffer::new(100);
        for i in 0..20 {
            buf.push(GameState::default_state(), Action::Move, i as f32);
        }
        let indices = buf.sample_indices(5, 42);
        assert_eq!(indices.len(), 5);
        for &idx in &indices {
            assert!(idx < 20);
        }
    }

    #[test]
    fn test_mean_reward() {
        let mut buf = ExperienceBuffer::new(100);
        buf.push(GameState::default_state(), Action::Move, 2.0);
        buf.push(GameState::default_state(), Action::Move, 4.0);
        assert!((buf.mean_reward() - 3.0).abs() < 1e-5);
    }
}