1use super::tensor::Tensor;
4use super::model::{Model, Sequential};
5
/// Discrete actions an agent can take on a turn.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Action {
    Move,
    Attack,
    UseAbility,
    UseItem,
    Wait,
}

impl Action {
    /// Every variant, ordered so that `ALL[a.index()] == a`.
    pub const ALL: [Action; 5] = [
        Action::Move,
        Action::Attack,
        Action::UseAbility,
        Action::UseItem,
        Action::Wait,
    ];

    /// Maps an arbitrary index to an action; out-of-range values wrap
    /// via modulo, so this never panics.
    pub fn from_index(idx: usize) -> Self {
        Self::ALL[idx % Self::ALL.len()]
    }

    /// Position of this action within `ALL` (inverse of `from_index`
    /// for in-range indices).
    pub fn index(&self) -> usize {
        Self::ALL
            .iter()
            .position(|a| a == self)
            .expect("every Action variant appears in ALL")
    }
}
39
/// Observable game snapshot, flattened into a single feature tensor
/// of shape [1, n] (see `GameState::new`).
#[derive(Debug, Clone)]
pub struct GameState {
    /// Row tensor of scalar features describing the current situation.
    pub features: Tensor,
}
49
50impl GameState {
51 pub fn new(features: Vec<f32>) -> Self {
53 let n = features.len();
54 Self { features: Tensor::from_vec(features, vec![1, n]) }
55 }
56
57 pub fn default_state() -> Self {
59 Self::new(vec![
60 100.0, 50.0, 5.0, 5.0, 100.0, 50.0, 10.0, 10.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, ])
65 }
66
67 pub fn feature_dim(&self) -> usize {
68 self.features.data.len()
69 }
70}
71
/// Policy/value network pair used for action selection and state scoring.
pub struct AIBrain {
    /// Maps state features to one logit per `Action` variant.
    pub policy_net: Model,
    /// Maps state features to a single scalar value estimate.
    pub value_net: Model,
}
77
78impl AIBrain {
79 pub fn new(state_dim: usize) -> Self {
81 let policy_net = Sequential::new("policy")
82 .dense(state_dim, 64)
83 .relu()
84 .dense(64, 32)
85 .relu()
86 .dense(32, Action::ALL.len())
87 .build();
88
89 let value_net = Sequential::new("value")
90 .dense(state_dim, 64)
91 .relu()
92 .dense(64, 32)
93 .relu()
94 .dense(32, 1)
95 .build();
96
97 Self { policy_net, value_net }
98 }
99
100 pub fn select_action(&self, state: &GameState, temperature: f32) -> Action {
103 let logits = self.policy_net.forward(&state.features);
104 let scaled: Vec<f32> = logits.data.iter().map(|&v| v / temperature.max(0.01)).collect();
106 let scaled_tensor = Tensor::from_vec(scaled, logits.shape.clone());
107 let probs = scaled_tensor.softmax(if scaled_tensor.shape.len() > 1 { 1 } else { 0 });
108
109 let seed: u64 = state.features.data.iter()
111 .map(|v| (v.to_bits() as u64).wrapping_mul(2654435761))
112 .fold(0u64, |a, b| a.wrapping_add(b));
113 let mut rng_state = seed.wrapping_add(1);
114 rng_state ^= rng_state << 13;
115 rng_state ^= rng_state >> 7;
116 rng_state ^= rng_state << 17;
117 let sample = (rng_state as u32 as f32) / (u32::MAX as f32);
118
119 let prob_data = &probs.data;
120 let num_actions = Action::ALL.len();
121 let start = prob_data.len().saturating_sub(num_actions);
123 let action_probs = &prob_data[start..];
124
125 let mut cumulative = 0.0f32;
126 for (i, &p) in action_probs.iter().enumerate() {
127 cumulative += p;
128 if sample <= cumulative {
129 return Action::from_index(i);
130 }
131 }
132 Action::Wait
133 }
134
135 pub fn evaluate_state(&self, state: &GameState) -> f32 {
137 let value = self.value_net.forward(&state.features);
138 value.data.last().copied().unwrap_or(0.0).tanh()
140 }
141}
142
/// AI wrapper that tunes its difficulty based on match outcomes.
pub struct AdaptiveAI {
    pub brain: AIBrain,
    /// Current difficulty in [0, 1]; higher means lower sampling
    /// temperature (more decisive play).
    pub difficulty: f32,
    /// Decaying running tally: +1 per player win, -1 per loss,
    /// multiplied by 0.9 after each update.
    score_differential: f32,
    /// Step size used when nudging `difficulty` after a match.
    pub adaptation_rate: f32,
}
152
153impl AdaptiveAI {
154 pub fn new(state_dim: usize, difficulty: f32) -> Self {
155 Self {
156 brain: AIBrain::new(state_dim),
157 difficulty: difficulty.clamp(0.0, 1.0),
158 score_differential: 0.0,
159 adaptation_rate: 0.05,
160 }
161 }
162
163 pub fn select_action(&self, state: &GameState) -> Action {
165 let temperature = 0.1 + (1.0 - self.difficulty) * 5.0;
167 self.brain.select_action(state, temperature)
168 }
169
170 pub fn update_difficulty(&mut self, player_won: bool) {
173 if player_won {
174 self.score_differential += 1.0;
176 } else {
177 self.score_differential -= 1.0;
179 }
180 self.difficulty += self.adaptation_rate * self.score_differential.signum() * 0.1;
182 self.difficulty = self.difficulty.clamp(0.0, 1.0);
183 self.score_differential *= 0.9;
185 }
186}
187
/// Aggregates observed player actions into simple style statistics.
pub struct PlaystyleTracker {
    /// Per-action counters, indexed by `Action::index()`.
    pub action_counts: [u32; 5],
    /// Total actions ever recorded (never trimmed, unlike `history`).
    pub total_actions: u32,
    /// Fraction of recorded actions that were `Attack`.
    pub aggression: f32,
    /// Fraction that were `Wait` or `UseItem`.
    pub caution: f32,
    /// Fraction that were `UseAbility`.
    pub ability_usage: f32,
    /// Most recent actions, oldest first, capped at `history_max`.
    pub history: Vec<Action>,
    /// Maximum length of `history`.
    pub history_max: usize,
}
204
205impl PlaystyleTracker {
206 pub fn new() -> Self {
207 Self {
208 action_counts: [0; 5],
209 total_actions: 0,
210 aggression: 0.0,
211 caution: 0.0,
212 ability_usage: 0.0,
213 history: Vec::new(),
214 history_max: 100,
215 }
216 }
217
218 pub fn record(&mut self, action: Action) {
220 self.action_counts[action.index()] += 1;
221 self.total_actions += 1;
222 self.history.push(action);
223 if self.history.len() > self.history_max {
224 self.history.remove(0);
225 }
226 self.update_stats();
227 }
228
229 fn update_stats(&mut self) {
230 let total = self.total_actions as f32;
231 if total == 0.0 { return; }
232 self.aggression = self.action_counts[Action::Attack.index()] as f32 / total;
233 self.caution = (self.action_counts[Action::Wait.index()] as f32
234 + self.action_counts[Action::UseItem.index()] as f32) / total;
235 self.ability_usage = self.action_counts[Action::UseAbility.index()] as f32 / total;
236 }
237
238 pub fn as_features(&self) -> Vec<f32> {
240 vec![
241 self.aggression,
242 self.caution,
243 self.ability_usage,
244 self.action_counts[Action::Move.index()] as f32 / self.total_actions.max(1) as f32,
245 self.total_actions as f32,
246 ]
247 }
248}
249
/// Strategy biases derived from an opponent's observed playstyle
/// (see `counter_strategy`); `balanced()` gives the neutral preset.
#[derive(Debug, Clone)]
pub struct AIParameters {
    pub aggression_bias: f32,
    pub defense_bias: f32,
    pub ability_preference: f32,
    pub patience: f32,
}
258
259impl AIParameters {
260 pub fn balanced() -> Self {
261 Self { aggression_bias: 0.0, defense_bias: 0.0, ability_preference: 0.0, patience: 0.5 }
262 }
263}
264
265pub fn counter_strategy(tracker: &PlaystyleTracker) -> AIParameters {
267 let mut params = AIParameters::balanced();
268 if tracker.aggression > 0.4 {
270 params.defense_bias = 0.5;
271 params.patience = 0.8;
272 }
273 if tracker.caution > 0.3 {
275 params.aggression_bias = 0.6;
276 params.patience = 0.2;
277 }
278 if tracker.ability_usage > 0.3 {
280 params.defense_bias = 0.3;
281 params.aggression_bias = 0.2;
282 }
283 params
284}
285
/// Fixed-capacity FIFO store of (state, action, reward) transitions.
pub struct ExperienceBuffer {
    /// Stored states, oldest first; kept in lockstep with the other vecs.
    pub states: Vec<GameState>,
    pub actions: Vec<Action>,
    pub rewards: Vec<f32>,
    /// Maximum number of transitions kept; oldest are evicted first.
    pub capacity: usize,
}
293
294impl ExperienceBuffer {
295 pub fn new(capacity: usize) -> Self {
296 Self {
297 states: Vec::new(),
298 actions: Vec::new(),
299 rewards: Vec::new(),
300 capacity,
301 }
302 }
303
304 pub fn push(&mut self, state: GameState, action: Action, reward: f32) {
305 if self.states.len() >= self.capacity {
306 self.states.remove(0);
307 self.actions.remove(0);
308 self.rewards.remove(0);
309 }
310 self.states.push(state);
311 self.actions.push(action);
312 self.rewards.push(reward);
313 }
314
315 pub fn len(&self) -> usize {
316 self.states.len()
317 }
318
319 pub fn is_empty(&self) -> bool {
320 self.states.is_empty()
321 }
322
323 pub fn sample_indices(&self, batch_size: usize, rng_seed: u64) -> Vec<usize> {
325 let n = self.len();
326 if n == 0 { return vec![]; }
327 let batch_size = batch_size.min(n);
328 let mut indices = Vec::with_capacity(batch_size);
329 let mut state = rng_seed.wrapping_add(1);
330 for _ in 0..batch_size {
331 state ^= state << 13;
332 state ^= state >> 7;
333 state ^= state << 17;
334 indices.push((state as usize) % n);
335 }
336 indices
337 }
338
339 pub fn compute_returns(&self, gamma: f32) -> Vec<f32> {
341 let n = self.rewards.len();
342 let mut returns = vec![0.0f32; n];
343 if n == 0 { return returns; }
344 returns[n - 1] = self.rewards[n - 1];
345 for i in (0..n - 1).rev() {
346 returns[i] = self.rewards[i] + gamma * returns[i + 1];
347 }
348 returns
349 }
350
351 pub fn clear(&mut self) {
352 self.states.clear();
353 self.actions.clear();
354 self.rewards.clear();
355 }
356
357 pub fn mean_reward(&self) -> f32 {
359 if self.rewards.is_empty() { return 0.0; }
360 self.rewards.iter().sum::<f32>() / self.rewards.len() as f32
361 }
362}
363
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_action_roundtrip() {
        // index() and from_index() must be inverses for every variant.
        for a in Action::ALL {
            assert_eq!(Action::from_index(a.index()), a);
        }
    }

    #[test]
    fn test_game_state() {
        // The canned default state exposes 16 features.
        let state = GameState::default_state();
        assert_eq!(state.feature_dim(), 16);
    }

    #[test]
    fn test_brain_select_action() {
        // Whatever the sampled action is, it must be a real variant.
        let brain = AIBrain::new(16);
        let picked = brain.select_action(&GameState::default_state(), 1.0);
        assert!(Action::ALL.contains(&picked));
    }

    #[test]
    fn test_brain_evaluate_state() {
        // tanh squashing keeps the value estimate inside [-1, 1].
        let brain = AIBrain::new(16);
        let v = brain.evaluate_state(&GameState::default_state());
        assert!((-1.0..=1.0).contains(&v));
    }

    #[test]
    fn test_adaptive_ai() {
        let mut ai = AdaptiveAI::new(16, 0.5);
        let _action = ai.select_action(&GameState::default_state());

        // A single player win must not move difficulty down by >= 0.1.
        let before = ai.difficulty;
        ai.update_difficulty(true);
        assert!(ai.difficulty >= before || (ai.difficulty - before).abs() < 0.1);
    }

    #[test]
    fn test_playstyle_tracker() {
        let mut tracker = PlaystyleTracker::new();
        for _ in 0..10 {
            tracker.record(Action::Attack);
        }
        for _ in 0..5 {
            tracker.record(Action::Wait);
        }
        assert_eq!(tracker.total_actions, 15);
        // 10 of 15 attacks, 5 of 15 waits.
        assert!((tracker.aggression - 10.0 / 15.0).abs() < 1e-5);
        assert!((tracker.caution - 5.0 / 15.0).abs() < 1e-5);
    }

    #[test]
    fn test_counter_strategy_aggressive() {
        // An all-attack opponent should trigger the defensive response.
        let mut tracker = PlaystyleTracker::new();
        for _ in 0..10 {
            tracker.record(Action::Attack);
        }
        let params = counter_strategy(&tracker);
        assert!(params.defense_bias > 0.0);
        assert!(params.patience > 0.5);
    }

    #[test]
    fn test_counter_strategy_cautious() {
        // An all-wait opponent should trigger the aggressive response.
        let mut tracker = PlaystyleTracker::new();
        for _ in 0..10 {
            tracker.record(Action::Wait);
        }
        let params = counter_strategy(&tracker);
        assert!(params.aggression_bias > 0.0);
    }

    #[test]
    fn test_experience_buffer() {
        // Pushing past capacity evicts oldest entries, capping length.
        let mut buf = ExperienceBuffer::new(5);
        for i in 0..7 {
            buf.push(GameState::default_state(), Action::Attack, i as f32);
        }
        assert_eq!(buf.len(), 5);
        assert!(!buf.is_empty());
    }

    #[test]
    fn test_experience_buffer_returns() {
        let mut buf = ExperienceBuffer::new(100);
        buf.push(GameState::default_state(), Action::Move, 1.0);
        buf.push(GameState::default_state(), Action::Attack, 2.0);
        buf.push(GameState::default_state(), Action::Wait, 3.0);
        let returns = buf.compute_returns(0.9);
        // G[2] = 3, G[1] = 2 + 0.9*3 = 4.7, G[0] = 1 + 0.9*4.7 = 5.23.
        assert!((returns[2] - 3.0).abs() < 1e-5);
        assert!((returns[1] - 4.7).abs() < 1e-5);
        assert!((returns[0] - 5.23).abs() < 1e-3);
    }

    #[test]
    fn test_sample_indices() {
        let mut buf = ExperienceBuffer::new(100);
        for i in 0..20 {
            buf.push(GameState::default_state(), Action::Move, i as f32);
        }
        let indices = buf.sample_indices(5, 42);
        assert_eq!(indices.len(), 5);
        // Every sampled index must be in range.
        assert!(indices.iter().all(|&idx| idx < 20));
    }

    #[test]
    fn test_mean_reward() {
        let mut buf = ExperienceBuffer::new(100);
        buf.push(GameState::default_state(), Action::Move, 2.0);
        buf.push(GameState::default_state(), Action::Move, 4.0);
        assert!((buf.mean_reward() - 3.0).abs() < 1e-5);
    }
}