1use crate::{AiError, Result};
14use glam::Vec2;
15use jugar_apr::{AprModel, ModelArchitecture, ModelData};
16use std::collections::HashMap;
17
18#[derive(Debug, Clone)]
20pub struct AiComponent {
21 pub model_id: String,
23 pub state: BehaviorState,
25 pub difficulty: u8,
27}
28
29impl AiComponent {
30 #[must_use]
32 pub fn new(model_id: impl Into<String>) -> Self {
33 Self {
34 model_id: model_id.into(),
35 state: BehaviorState::default(),
36 difficulty: 5,
37 }
38 }
39
40 #[must_use]
42 pub const fn with_difficulty(mut self, difficulty: u8) -> Self {
43 self.difficulty = if difficulty < 1 {
45 1
46 } else if difficulty > 10 {
47 10
48 } else {
49 difficulty
50 };
51 self
52 }
53}
54
/// Mutable per-entity AI state carried by an `AiComponent`.
///
/// NOTE(review): nothing in this file reads or updates these fields beyond `Default`;
/// the descriptions below are inferred from the field names — confirm against callers.
#[derive(Debug, Clone, Default)]
pub struct BehaviorState {
    /// Current movement direction (presumably a unit vector — TODO confirm).
    pub direction: Vec2,
    /// Time spent in the current state (units not established here, presumably seconds).
    pub state_time: f32,
    /// Index of the current waypoint, for waypoint-following behaviors.
    pub waypoint_index: usize,
    /// Free-form scalar for behavior-specific state.
    pub internal_state: f32,
}
67
68#[derive(Debug, Clone, Default)]
70pub struct AiInputs {
71 pub position: Vec2,
73 pub target_position: Vec2,
75 pub distance_to_target: f32,
77 pub direction_to_target: Vec2,
79 pub dt: f32,
81}
82
83impl AiInputs {
84 #[must_use]
86 pub fn from_positions(position: Vec2, target: Vec2, dt: f32) -> Self {
87 let delta = target - position;
88 let distance = delta.length();
89 let direction = if distance > 0.001 {
90 delta / distance
91 } else {
92 Vec2::ZERO
93 };
94
95 Self {
96 position,
97 target_position: target,
98 distance_to_target: distance,
99 direction_to_target: direction,
100 dt,
101 }
102 }
103
104 #[must_use]
106 pub fn to_vector(&self) -> Vec<f32> {
107 vec![
108 self.direction_to_target.x,
109 self.direction_to_target.y,
110 self.distance_to_target / 100.0, self.dt,
112 ]
113 }
114}
115
116#[derive(Debug, Clone, Default)]
118pub struct AiOutputs {
119 pub movement: Vec2,
121 pub speed: f32,
123 pub action: bool,
125}
126
127impl AiOutputs {
128 #[must_use]
130 pub fn from_raw(values: &[f32]) -> Self {
131 let movement = if values.len() >= 2 {
132 Vec2::new(values[0], values[1]).normalize_or_zero()
133 } else {
134 Vec2::ZERO
135 };
136
137 let speed = if values.len() >= 3 {
138 values[2].clamp(0.0, 1.0)
139 } else {
140 1.0
141 };
142
143 let action = values.len() >= 4 && values[3] > 0.5;
144
145 Self {
146 movement,
147 speed,
148 action,
149 }
150 }
151}
152
/// Registry of loaded AI models, keyed by the id each was registered under, that runs
/// MLP or behavior-tree inference on request.
#[derive(Debug, Default)]
pub struct AiSystem {
    /// Model id -> registered model together with its pre-sliced layer weights.
    models: HashMap<String, LoadedModel>,
}
159
/// A registered model plus its weights pre-sliced into per-layer matrices.
#[derive(Debug, Clone)]
struct LoadedModel {
    /// The model as registered; inference reads its architecture and metadata name.
    model: AprModel,
    /// One entry per MLP layer transition; empty for behavior-tree models.
    layer_weights: Vec<LayerWeights>,
}
168
/// Weights and biases of one fully connected layer.
#[derive(Debug, Clone)]
struct LayerWeights {
    /// Row-major `output_size x input_size` matrix: the weight from input `j` to output
    /// `i` is stored at index `i * input_size + j`.
    weights: Vec<f32>,
    /// One bias per output neuron.
    biases: Vec<f32>,
    /// Number of inputs consumed by this layer.
    input_size: usize,
    /// Number of outputs (neurons) produced by this layer.
    output_size: usize,
}
181
182impl AiSystem {
183 #[must_use]
185 pub fn new() -> Self {
186 Self::default()
187 }
188
189 pub fn load_model_from_file(&mut self, id: &str, path: &str) -> Result<()> {
195 let bytes = std::fs::read(path).map_err(|e| AiError::PreconditionsNotMet(e.to_string()))?;
196
197 let apr_file = jugar_apr::AprFile::from_bytes(&bytes)
198 .map_err(|e| AiError::PreconditionsNotMet(e.to_string()))?;
199
200 self.register_model(id, apr_file.model)
201 }
202
203 pub fn load_builtin(&mut self, id: &str, builtin_name: &str) -> Result<()> {
209 let model = AprModel::builtin(builtin_name)
210 .map_err(|e| AiError::PreconditionsNotMet(e.to_string()))?;
211
212 self.register_model(id, model)
213 }
214
215 pub fn register_model(&mut self, id: &str, model: AprModel) -> Result<()> {
221 let layer_weights = Self::prepare_weights(&model.data)?;
222 let loaded = LoadedModel {
223 model,
224 layer_weights,
225 };
226 let _ = self.models.insert(id.to_string(), loaded);
227 Ok(())
228 }
229
230 fn prepare_weights(data: &ModelData) -> Result<Vec<LayerWeights>> {
232 match &data.architecture {
233 ModelArchitecture::Mlp { layers } => {
234 if layers.len() < 2 {
235 return Err(AiError::PreconditionsNotMet(
236 "MLP needs at least 2 layers".to_string(),
237 ));
238 }
239
240 let mut result = Vec::new();
241 let mut weight_offset = 0;
242 let mut bias_offset = 0;
243
244 for i in 0..layers.len() - 1 {
245 let input_size = layers[i];
246 let output_size = layers[i + 1];
247 let weight_count = input_size * output_size;
248
249 let weights = if weight_offset + weight_count <= data.weights.len() {
250 data.weights[weight_offset..weight_offset + weight_count].to_vec()
251 } else {
252 vec![0.1; weight_count]
254 };
255
256 let biases = if bias_offset + output_size <= data.biases.len() {
257 data.biases[bias_offset..bias_offset + output_size].to_vec()
258 } else {
259 vec![0.0; output_size]
261 };
262
263 result.push(LayerWeights {
264 weights,
265 biases,
266 input_size,
267 output_size,
268 });
269
270 weight_offset += weight_count;
271 bias_offset += output_size;
272 }
273
274 Ok(result)
275 }
276 ModelArchitecture::BehaviorTree { .. } => {
277 Ok(Vec::new())
279 }
280 }
281 }
282
283 pub fn infer(&self, model_id: &str, inputs: &AiInputs) -> Result<AiOutputs> {
289 let loaded = self
290 .models
291 .get(model_id)
292 .ok_or_else(|| AiError::PreconditionsNotMet(format!("Model not found: {model_id}")))?;
293
294 match &loaded.model.data.architecture {
295 ModelArchitecture::Mlp { .. } => {
296 let raw_outputs =
297 Self::run_mlp_inference(&loaded.layer_weights, &inputs.to_vector());
298 Ok(AiOutputs::from_raw(&raw_outputs))
299 }
300 ModelArchitecture::BehaviorTree { .. } => {
301 Self::run_behavior_inference(&loaded.model.metadata.name, inputs)
303 }
304 }
305 }
306
307 fn run_mlp_inference(layers: &[LayerWeights], input: &[f32]) -> Vec<f32> {
309 let mut current = input.to_vec();
310
311 for layer in layers {
312 let mut output = vec![0.0; layer.output_size];
313
314 for (i, out) in output.iter_mut().enumerate() {
316 let mut sum = layer.biases.get(i).copied().unwrap_or(0.0);
317 for (j, &inp) in current.iter().enumerate() {
318 let weight_idx = i * layer.input_size + j;
319 let weight = layer.weights.get(weight_idx).copied().unwrap_or(0.0);
320 sum += weight * inp;
321 }
322 *out = sum.max(0.0);
324 }
325
326 current = output;
327 }
328
329 current.iter().map(|&x| x.tanh()).collect()
331 }
332
333 fn run_behavior_inference(behavior_name: &str, inputs: &AiInputs) -> Result<AiOutputs> {
335 match behavior_name {
336 "builtin-chase" => Ok(AiOutputs {
337 movement: inputs.direction_to_target,
338 speed: 1.0,
339 action: inputs.distance_to_target < 50.0,
340 }),
341 "builtin-patrol" => {
342 let phase = (inputs.position.x / 100.0).sin();
344 Ok(AiOutputs {
345 movement: Vec2::new(phase.signum(), 0.0),
346 speed: 0.5,
347 action: false,
348 })
349 }
350 "builtin-wander" => {
351 #[allow(clippy::suboptimal_flops)]
353 let angle = (inputs.position.x * 0.1 + inputs.position.y * 0.07).sin()
354 * core::f32::consts::PI;
355 Ok(AiOutputs {
356 movement: Vec2::new(angle.cos(), angle.sin()),
357 speed: 0.3,
358 action: false,
359 })
360 }
361 _ => Err(AiError::PreconditionsNotMet(format!(
362 "Unknown behavior: {behavior_name}"
363 ))),
364 }
365 }
366
367 #[must_use]
369 pub fn has_model(&self, id: &str) -> bool {
370 self.models.contains_key(id)
371 }
372
373 #[must_use]
375 pub fn model_count(&self) -> usize {
376 self.models.len()
377 }
378
379 pub fn unload_model(&mut self, id: &str) -> bool {
381 self.models.remove(id).is_some()
382 }
383}
384
385#[derive(Debug, Default)]
389pub struct YamlAiBridge {
390 custom_models: HashMap<String, String>,
392}
393
394impl YamlAiBridge {
395 #[must_use]
397 pub fn new() -> Self {
398 Self::default()
399 }
400
401 pub fn register_custom(&mut self, yaml_key: &str, path: &str) {
403 let _ = self
404 .custom_models
405 .insert(yaml_key.to_string(), path.to_string());
406 }
407
408 pub fn resolve(&self, yaml_key: &str, system: &mut AiSystem) -> Result<String> {
421 if let Some(builtin) = yaml_key.strip_prefix("builtin:") {
423 let id = format!("builtin-{builtin}");
424 if !system.has_model(&id) {
425 system.load_builtin(&id, builtin)?;
426 }
427 return Ok(id);
428 }
429
430 if std::path::Path::new(yaml_key)
432 .extension()
433 .is_some_and(|ext| ext.eq_ignore_ascii_case("apr"))
434 {
435 let id = yaml_key.replace(['/', '\\', '.'], "_");
436 if !system.has_model(&id) {
437 system.load_model_from_file(&id, yaml_key)?;
438 }
439 return Ok(id);
440 }
441
442 if let Some(path) = self.custom_models.get(yaml_key) {
444 let id = yaml_key.to_string();
445 if !system.has_model(&id) {
446 system.load_model_from_file(&id, path)?;
447 }
448 return Ok(id);
449 }
450
451 if matches!(yaml_key, "chase" | "patrol" | "wander") {
453 let id = format!("builtin-{yaml_key}");
454 if !system.has_model(&id) {
455 system.load_builtin(&id, yaml_key)?;
456 }
457 return Ok(id);
458 }
459
460 Err(AiError::PreconditionsNotMet(format!(
461 "Unknown AI behavior: {yaml_key}"
462 )))
463 }
464}
465
// Unit tests, grouped per type: the component, inputs/outputs decoding, the model
// registry, the YAML bridge, and raw MLP inference.
#[cfg(test)]
// unwrap/expect are fine here: a panic simply fails the test.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    mod ai_component_tests {
        use super::*;

        #[test]
        fn test_ai_component_new() {
            let component = AiComponent::new("builtin:chase");
            assert_eq!(component.model_id, "builtin:chase");
            // `new` defaults to mid difficulty.
            assert_eq!(component.difficulty, 5);
        }

        #[test]
        fn test_ai_component_with_difficulty() {
            let component = AiComponent::new("chase").with_difficulty(8);
            assert_eq!(component.difficulty, 8);
        }

        #[test]
        fn test_ai_component_difficulty_clamped() {
            // Out-of-range values are clamped into 1..=10.
            let low = AiComponent::new("chase").with_difficulty(0);
            assert_eq!(low.difficulty, 1);

            let high = AiComponent::new("chase").with_difficulty(100);
            assert_eq!(high.difficulty, 10);
        }
    }

    mod ai_inputs_tests {
        use super::*;

        #[test]
        fn test_from_positions() {
            let inputs =
                AiInputs::from_positions(Vec2::new(0.0, 0.0), Vec2::new(100.0, 0.0), 0.016);

            assert!((inputs.distance_to_target - 100.0).abs() < 0.01);
            assert!((inputs.direction_to_target.x - 1.0).abs() < 0.01);
            assert!(inputs.direction_to_target.y.abs() < 0.01);
        }

        #[test]
        fn test_from_positions_same_point() {
            // Coincident points must yield a zero direction rather than NaN.
            let inputs =
                AiInputs::from_positions(Vec2::new(50.0, 50.0), Vec2::new(50.0, 50.0), 0.016);

            assert!(inputs.distance_to_target < 0.001);
            assert_eq!(inputs.direction_to_target, Vec2::ZERO);
        }

        #[test]
        fn test_to_vector() {
            let inputs = AiInputs::from_positions(Vec2::ZERO, Vec2::new(100.0, 0.0), 0.016);

            let vec = inputs.to_vector();
            assert_eq!(vec.len(), 4);
            // direction.x
            assert!((vec[0] - 1.0).abs() < 0.01);
            // direction.y
            assert!(vec[1].abs() < 0.01);
            // distance / 100
            assert!((vec[2] - 1.0).abs() < 0.01);
        }
    }

    mod ai_outputs_tests {
        use super::*;

        #[test]
        fn test_from_raw() {
            let outputs = AiOutputs::from_raw(&[0.5, 0.5, 0.8, 0.9]);

            assert!(outputs.movement.length() > 0.0);
            assert!((outputs.speed - 0.8).abs() < 0.01);
            assert!(outputs.action);
        }

        #[test]
        fn test_from_raw_empty() {
            // No values: no movement, full speed, no action.
            let outputs = AiOutputs::from_raw(&[]);

            assert_eq!(outputs.movement, Vec2::ZERO);
            assert!((outputs.speed - 1.0).abs() < 0.01);
            assert!(!outputs.action);
        }

        #[test]
        fn test_from_raw_speed_clamped() {
            let outputs = AiOutputs::from_raw(&[0.0, 0.0, 2.0]);
            assert!((outputs.speed - 1.0).abs() < 0.01);

            let outputs2 = AiOutputs::from_raw(&[0.0, 0.0, -1.0]);
            assert!(outputs2.speed.abs() < 0.01);
        }
    }

    mod ai_system_tests {
        use super::*;

        #[test]
        fn test_new_system() {
            let system = AiSystem::new();
            assert_eq!(system.model_count(), 0);
        }

        #[test]
        fn test_load_builtin_chase() {
            let mut system = AiSystem::new();
            system.load_builtin("chase", "chase").unwrap();

            assert!(system.has_model("chase"));
            assert_eq!(system.model_count(), 1);
        }

        #[test]
        fn test_load_builtin_patrol() {
            let mut system = AiSystem::new();
            system.load_builtin("patrol", "patrol").unwrap();

            assert!(system.has_model("patrol"));
        }

        #[test]
        fn test_load_builtin_wander() {
            let mut system = AiSystem::new();
            system.load_builtin("wander", "wander").unwrap();

            assert!(system.has_model("wander"));
        }

        #[test]
        fn test_load_unknown_builtin() {
            let mut system = AiSystem::new();
            let result = system.load_builtin("unknown", "unknown");

            assert!(result.is_err());
        }

        #[test]
        fn test_register_model() {
            let mut system = AiSystem::new();
            let model = AprModel::new_test_model();

            system.register_model("test", model).unwrap();
            assert!(system.has_model("test"));
        }

        #[test]
        fn test_unload_model() {
            let mut system = AiSystem::new();
            system.load_builtin("chase", "chase").unwrap();

            assert!(system.unload_model("chase"));
            assert!(!system.has_model("chase"));
        }

        #[test]
        fn test_infer_chase() {
            let mut system = AiSystem::new();
            system.load_builtin("chase", "chase").unwrap();

            let inputs =
                AiInputs::from_positions(Vec2::new(0.0, 0.0), Vec2::new(100.0, 0.0), 0.016);

            let outputs = system.infer("chase", &inputs).unwrap();

            // Target lies along +x, so chase must move in +x at full speed.
            assert!(outputs.movement.x > 0.0);
            assert!((outputs.speed - 1.0).abs() < 0.01);
        }

        #[test]
        fn test_infer_patrol() {
            let mut system = AiSystem::new();
            system.load_builtin("patrol", "patrol").unwrap();

            let inputs = AiInputs::from_positions(Vec2::new(50.0, 0.0), Vec2::new(0.0, 0.0), 0.016);

            let outputs = system.infer("patrol", &inputs).unwrap();

            assert!(outputs.movement.length() > 0.0);
            // Patrol runs at half speed.
            assert!((outputs.speed - 0.5).abs() < 0.01);
        }

        #[test]
        fn test_infer_wander() {
            let mut system = AiSystem::new();
            system.load_builtin("wander", "wander").unwrap();

            let inputs =
                AiInputs::from_positions(Vec2::new(25.0, 75.0), Vec2::new(0.0, 0.0), 0.016);

            let outputs = system.infer("wander", &inputs).unwrap();

            assert!(outputs.movement.length() > 0.0);
            // Wander runs at 0.3 speed.
            assert!((outputs.speed - 0.3).abs() < 0.01);
        }

        #[test]
        fn test_infer_mlp_model() {
            let mut system = AiSystem::new();
            let model = AprModel::new_test_model();
            system.register_model("mlp", model).unwrap();

            let inputs =
                AiInputs::from_positions(Vec2::new(0.0, 0.0), Vec2::new(50.0, 50.0), 0.016);

            let outputs = system.infer("mlp", &inputs).unwrap();

            // Smoke test: this effectively only asserts that inference succeeds and the
            // movement length is not NaN.
            assert!(outputs.movement.length() >= 0.0);
        }

        #[test]
        fn test_infer_unknown_model() {
            let system = AiSystem::new();
            let inputs = AiInputs::default();

            let result = system.infer("nonexistent", &inputs);
            assert!(result.is_err());
        }
    }

    mod yaml_bridge_tests {
        use super::*;

        #[test]
        fn test_resolve_builtin_prefix() {
            let bridge = YamlAiBridge::new();
            let mut system = AiSystem::new();

            let id = bridge.resolve("builtin:chase", &mut system).unwrap();

            assert_eq!(id, "builtin-chase");
            assert!(system.has_model("builtin-chase"));
        }

        #[test]
        fn test_resolve_simple_builtin() {
            // A bare name resolves to the same id as its `builtin:` form.
            let bridge = YamlAiBridge::new();
            let mut system = AiSystem::new();

            let id = bridge.resolve("patrol", &mut system).unwrap();

            assert_eq!(id, "builtin-patrol");
            assert!(system.has_model("builtin-patrol"));
        }

        #[test]
        fn test_resolve_all_builtins() {
            let bridge = YamlAiBridge::new();
            let mut system = AiSystem::new();

            let _ = bridge.resolve("chase", &mut system).unwrap();
            let _ = bridge.resolve("patrol", &mut system).unwrap();
            let _ = bridge.resolve("wander", &mut system).unwrap();

            assert_eq!(system.model_count(), 3);
        }

        #[test]
        fn test_resolve_unknown() {
            let bridge = YamlAiBridge::new();
            let mut system = AiSystem::new();

            let result = bridge.resolve("unknown_behavior", &mut system);
            assert!(result.is_err());
        }

        #[test]
        fn test_resolve_caches_model() {
            let bridge = YamlAiBridge::new();
            let mut system = AiSystem::new();

            let _ = bridge.resolve("builtin:chase", &mut system).unwrap();
            // Resolving again must reuse the registered model, not add a second one.
            let _ = bridge.resolve("builtin:chase", &mut system).unwrap();

            assert_eq!(system.model_count(), 1);
        }

        #[test]
        fn test_register_custom() {
            // Registration only records the mapping; the file is not read here.
            let mut bridge = YamlAiBridge::new();
            bridge.register_custom("smart-ghost", "models/ghost.apr");

            assert!(!bridge.custom_models.is_empty());
        }
    }

    mod mlp_inference_tests {
        use super::*;

        #[test]
        fn test_simple_mlp() {
            let layers = vec![LayerWeights {
                // Identity matrix, zero biases.
                weights: vec![1.0, 0.0, 0.0, 1.0],
                biases: vec![0.0, 0.0],
                input_size: 2,
                output_size: 2,
            }];

            let input = vec![0.5, -0.5];
            let output = AiSystem::run_mlp_inference(&layers, &input);

            assert!(output[0] > 0.0);
            // ReLU zeroes the negative input before tanh, so the output is ~0.
            assert!(output[1].abs() < 0.01);
        }

        #[test]
        fn test_multi_layer_mlp() {
            let layers = vec![
                LayerWeights {
                    weights: vec![0.5, 0.5, 0.5, 0.5],
                    biases: vec![0.0, 0.0],
                    input_size: 2,
                    output_size: 2,
                },
                LayerWeights {
                    weights: vec![1.0, 1.0],
                    biases: vec![0.0],
                    input_size: 2,
                    output_size: 1,
                },
            ];

            let input = vec![1.0, 1.0];
            let output = AiSystem::run_mlp_inference(&layers, &input);

            // Output width follows the last layer's `output_size`.
            assert_eq!(output.len(), 1);
            assert!(output[0] > 0.0);
        }
    }
}