Skip to main content

jugar_ai/
system.rs

//! AI system for running .apr models in game entities.
//!
//! Per spec Section 5.3: Aprender AI Integration.
//!
//! # Example
//!
//! ```ignore
//! let mut ai_system = AiSystem::new();
//! ai_system.load_model_from_file("ghost", "models/smart-ghost.apr")?;
//! let inputs = AiInputs::from_positions(ghost_position, player_position, 0.016);
//! let outputs = ai_system.infer("ghost", &inputs)?;
//! ```
12
13use crate::{AiError, Result};
14use glam::Vec2;
15use jugar_apr::{AprModel, ModelArchitecture, ModelData};
16use std::collections::HashMap;
17
18/// AI component attached to entities
19#[derive(Debug, Clone)]
20pub struct AiComponent {
21    /// Model identifier (either path or builtin name)
22    pub model_id: String,
23    /// Current behavior state
24    pub state: BehaviorState,
25    /// Difficulty level (1-10, affects model parameters)
26    pub difficulty: u8,
27}
28
29impl AiComponent {
30    /// Create a new AI component with the given model
31    #[must_use]
32    pub fn new(model_id: impl Into<String>) -> Self {
33        Self {
34            model_id: model_id.into(),
35            state: BehaviorState::default(),
36            difficulty: 5,
37        }
38    }
39
40    /// Set difficulty level (1-10)
41    #[must_use]
42    pub const fn with_difficulty(mut self, difficulty: u8) -> Self {
43        // Manual clamp for const fn (clamp is not const in stable Rust)
44        self.difficulty = if difficulty < 1 {
45            1
46        } else if difficulty > 10 {
47            10
48        } else {
49            difficulty
50        };
51        self
52    }
53}
54
55/// Current state of a behavior
56#[derive(Debug, Clone, Default)]
57pub struct BehaviorState {
58    /// Current direction of movement
59    pub direction: Vec2,
60    /// Time in current state
61    pub state_time: f32,
62    /// Patrol waypoint index (for patrol behavior)
63    pub waypoint_index: usize,
64    /// Internal state value for deterministic behaviors
65    pub internal_state: f32,
66}
67
68/// Input data for AI inference
69#[derive(Debug, Clone, Default)]
70pub struct AiInputs {
71    /// Entity's current position
72    pub position: Vec2,
73    /// Target position (usually player)
74    pub target_position: Vec2,
75    /// Distance to target
76    pub distance_to_target: f32,
77    /// Normalized direction to target
78    pub direction_to_target: Vec2,
79    /// Delta time
80    pub dt: f32,
81}
82
83impl AiInputs {
84    /// Create inputs from positions
85    #[must_use]
86    pub fn from_positions(position: Vec2, target: Vec2, dt: f32) -> Self {
87        let delta = target - position;
88        let distance = delta.length();
89        let direction = if distance > 0.001 {
90            delta / distance
91        } else {
92            Vec2::ZERO
93        };
94
95        Self {
96            position,
97            target_position: target,
98            distance_to_target: distance,
99            direction_to_target: direction,
100            dt,
101        }
102    }
103
104    /// Convert to input vector for MLP inference
105    #[must_use]
106    pub fn to_vector(&self) -> Vec<f32> {
107        vec![
108            self.direction_to_target.x,
109            self.direction_to_target.y,
110            self.distance_to_target / 100.0, // Normalize distance
111            self.dt,
112        ]
113    }
114}
115
116/// Output from AI inference
117#[derive(Debug, Clone, Default)]
118pub struct AiOutputs {
119    /// Desired movement direction (normalized)
120    pub movement: Vec2,
121    /// Speed multiplier (0.0-1.0)
122    pub speed: f32,
123    /// Should trigger action (e.g., attack)
124    pub action: bool,
125}
126
127impl AiOutputs {
128    /// Create from raw output values
129    #[must_use]
130    pub fn from_raw(values: &[f32]) -> Self {
131        let movement = if values.len() >= 2 {
132            Vec2::new(values[0], values[1]).normalize_or_zero()
133        } else {
134            Vec2::ZERO
135        };
136
137        let speed = if values.len() >= 3 {
138            values[2].clamp(0.0, 1.0)
139        } else {
140            1.0
141        };
142
143        let action = values.len() >= 4 && values[3] > 0.5;
144
145        Self {
146            movement,
147            speed,
148            action,
149        }
150    }
151}
152
153/// The AI system that manages and runs AI models
154#[derive(Debug, Default)]
155pub struct AiSystem {
156    /// Loaded models by ID
157    models: HashMap<String, LoadedModel>,
158}
159
160/// A loaded model ready for inference
161#[derive(Debug, Clone)]
162struct LoadedModel {
163    /// The underlying APR model
164    model: AprModel,
165    /// Cached layer weights for fast inference
166    layer_weights: Vec<LayerWeights>,
167}
168
/// Parameters of one fully-connected layer.
#[derive(Debug, Clone)]
struct LayerWeights {
    /// `output_size` rows of `input_size` weights, flattened row-major.
    weights: Vec<f32>,
    /// One bias per output neuron.
    biases: Vec<f32>,
    /// Number of inputs the layer consumes.
    input_size: usize,
    /// Number of outputs the layer produces.
    output_size: usize,
}
181
182impl AiSystem {
183    /// Create a new AI system
184    #[must_use]
185    pub fn new() -> Self {
186        Self::default()
187    }
188
189    /// Load a model from an APR file path
190    ///
191    /// # Errors
192    ///
193    /// Returns error if file cannot be read or model is invalid
194    pub fn load_model_from_file(&mut self, id: &str, path: &str) -> Result<()> {
195        let bytes = std::fs::read(path).map_err(|e| AiError::PreconditionsNotMet(e.to_string()))?;
196
197        let apr_file = jugar_apr::AprFile::from_bytes(&bytes)
198            .map_err(|e| AiError::PreconditionsNotMet(e.to_string()))?;
199
200        self.register_model(id, apr_file.model)
201    }
202
203    /// Load a builtin model
204    ///
205    /// # Errors
206    ///
207    /// Returns error if builtin name is unknown
208    pub fn load_builtin(&mut self, id: &str, builtin_name: &str) -> Result<()> {
209        let model = AprModel::builtin(builtin_name)
210            .map_err(|e| AiError::PreconditionsNotMet(e.to_string()))?;
211
212        self.register_model(id, model)
213    }
214
215    /// Register a model directly
216    ///
217    /// # Errors
218    ///
219    /// Returns error if model architecture is invalid
220    pub fn register_model(&mut self, id: &str, model: AprModel) -> Result<()> {
221        let layer_weights = Self::prepare_weights(&model.data)?;
222        let loaded = LoadedModel {
223            model,
224            layer_weights,
225        };
226        let _ = self.models.insert(id.to_string(), loaded);
227        Ok(())
228    }
229
230    /// Prepare layer weights from model data
231    fn prepare_weights(data: &ModelData) -> Result<Vec<LayerWeights>> {
232        match &data.architecture {
233            ModelArchitecture::Mlp { layers } => {
234                if layers.len() < 2 {
235                    return Err(AiError::PreconditionsNotMet(
236                        "MLP needs at least 2 layers".to_string(),
237                    ));
238                }
239
240                let mut result = Vec::new();
241                let mut weight_offset = 0;
242                let mut bias_offset = 0;
243
244                for i in 0..layers.len() - 1 {
245                    let input_size = layers[i];
246                    let output_size = layers[i + 1];
247                    let weight_count = input_size * output_size;
248
249                    let weights = if weight_offset + weight_count <= data.weights.len() {
250                        data.weights[weight_offset..weight_offset + weight_count].to_vec()
251                    } else {
252                        // Use default weights if not enough in model
253                        vec![0.1; weight_count]
254                    };
255
256                    let biases = if bias_offset + output_size <= data.biases.len() {
257                        data.biases[bias_offset..bias_offset + output_size].to_vec()
258                    } else {
259                        // Use default biases
260                        vec![0.0; output_size]
261                    };
262
263                    result.push(LayerWeights {
264                        weights,
265                        biases,
266                        input_size,
267                        output_size,
268                    });
269
270                    weight_offset += weight_count;
271                    bias_offset += output_size;
272                }
273
274                Ok(result)
275            }
276            ModelArchitecture::BehaviorTree { .. } => {
277                // Behavior trees don't need weight preparation
278                Ok(Vec::new())
279            }
280        }
281    }
282
283    /// Run inference on a model
284    ///
285    /// # Errors
286    ///
287    /// Returns error if model is not found
288    pub fn infer(&self, model_id: &str, inputs: &AiInputs) -> Result<AiOutputs> {
289        let loaded = self
290            .models
291            .get(model_id)
292            .ok_or_else(|| AiError::PreconditionsNotMet(format!("Model not found: {model_id}")))?;
293
294        match &loaded.model.data.architecture {
295            ModelArchitecture::Mlp { .. } => {
296                let raw_outputs =
297                    Self::run_mlp_inference(&loaded.layer_weights, &inputs.to_vector());
298                Ok(AiOutputs::from_raw(&raw_outputs))
299            }
300            ModelArchitecture::BehaviorTree { .. } => {
301                // Behavior trees use special inference based on model name
302                Self::run_behavior_inference(&loaded.model.metadata.name, inputs)
303            }
304        }
305    }
306
307    /// Run MLP forward pass
308    fn run_mlp_inference(layers: &[LayerWeights], input: &[f32]) -> Vec<f32> {
309        let mut current = input.to_vec();
310
311        for layer in layers {
312            let mut output = vec![0.0; layer.output_size];
313
314            // Matrix multiplication: output = weights * input + bias
315            for (i, out) in output.iter_mut().enumerate() {
316                let mut sum = layer.biases.get(i).copied().unwrap_or(0.0);
317                for (j, &inp) in current.iter().enumerate() {
318                    let weight_idx = i * layer.input_size + j;
319                    let weight = layer.weights.get(weight_idx).copied().unwrap_or(0.0);
320                    sum += weight * inp;
321                }
322                // ReLU activation
323                *out = sum.max(0.0);
324            }
325
326            current = output;
327        }
328
329        // Final output uses tanh for bounded output
330        current.iter().map(|&x| x.tanh()).collect()
331    }
332
333    /// Run behavior tree based inference
334    fn run_behavior_inference(behavior_name: &str, inputs: &AiInputs) -> Result<AiOutputs> {
335        match behavior_name {
336            "builtin-chase" => Ok(AiOutputs {
337                movement: inputs.direction_to_target,
338                speed: 1.0,
339                action: inputs.distance_to_target < 50.0,
340            }),
341            "builtin-patrol" => {
342                // Simple left-right patrol
343                let phase = (inputs.position.x / 100.0).sin();
344                Ok(AiOutputs {
345                    movement: Vec2::new(phase.signum(), 0.0),
346                    speed: 0.5,
347                    action: false,
348                })
349            }
350            "builtin-wander" => {
351                // Pseudo-random wander using position as seed
352                #[allow(clippy::suboptimal_flops)]
353                let angle = (inputs.position.x * 0.1 + inputs.position.y * 0.07).sin()
354                    * core::f32::consts::PI;
355                Ok(AiOutputs {
356                    movement: Vec2::new(angle.cos(), angle.sin()),
357                    speed: 0.3,
358                    action: false,
359                })
360            }
361            _ => Err(AiError::PreconditionsNotMet(format!(
362                "Unknown behavior: {behavior_name}"
363            ))),
364        }
365    }
366
367    /// Check if a model is loaded
368    #[must_use]
369    pub fn has_model(&self, id: &str) -> bool {
370        self.models.contains_key(id)
371    }
372
373    /// Get model count
374    #[must_use]
375    pub fn model_count(&self) -> usize {
376        self.models.len()
377    }
378
379    /// Remove a model
380    pub fn unload_model(&mut self, id: &str) -> bool {
381        self.models.remove(id).is_some()
382    }
383}
384
/// Translates AI keywords found in YAML into model ids loaded in an [`AiSystem`].
///
/// Per spec Section 5.3: Maps YAML keywords to AI behaviors.
#[derive(Debug, Default)]
pub struct YamlAiBridge {
    /// User-registered YAML keys and the model file path each one points to.
    custom_models: HashMap<String, String>,
}
393
394impl YamlAiBridge {
395    /// Create a new YAML-AI bridge
396    #[must_use]
397    pub fn new() -> Self {
398        Self::default()
399    }
400
401    /// Register a custom model path
402    pub fn register_custom(&mut self, yaml_key: &str, path: &str) {
403        let _ = self
404            .custom_models
405            .insert(yaml_key.to_string(), path.to_string());
406    }
407
408    /// Resolve a YAML AI keyword to a behavior
409    ///
410    /// # Examples
411    ///
412    /// - `"builtin:chase"` -> Builtin chase behavior
413    /// - `"builtin:patrol"` -> Builtin patrol behavior
414    /// - `"builtin:wander"` -> Builtin wander behavior
415    /// - `"models/ghost.apr"` -> Custom .apr model
416    ///
417    /// # Errors
418    ///
419    /// Returns error if keyword cannot be resolved
420    pub fn resolve(&self, yaml_key: &str, system: &mut AiSystem) -> Result<String> {
421        // Check for builtin prefix
422        if let Some(builtin) = yaml_key.strip_prefix("builtin:") {
423            let id = format!("builtin-{builtin}");
424            if !system.has_model(&id) {
425                system.load_builtin(&id, builtin)?;
426            }
427            return Ok(id);
428        }
429
430        // Check for .apr file path (case-insensitive)
431        if std::path::Path::new(yaml_key)
432            .extension()
433            .is_some_and(|ext| ext.eq_ignore_ascii_case("apr"))
434        {
435            let id = yaml_key.replace(['/', '\\', '.'], "_");
436            if !system.has_model(&id) {
437                system.load_model_from_file(&id, yaml_key)?;
438            }
439            return Ok(id);
440        }
441
442        // Check custom mappings
443        if let Some(path) = self.custom_models.get(yaml_key) {
444            let id = yaml_key.to_string();
445            if !system.has_model(&id) {
446                system.load_model_from_file(&id, path)?;
447            }
448            return Ok(id);
449        }
450
451        // Try as a direct builtin name
452        if matches!(yaml_key, "chase" | "patrol" | "wander") {
453            let id = format!("builtin-{yaml_key}");
454            if !system.has_model(&id) {
455                system.load_builtin(&id, yaml_key)?;
456            }
457            return Ok(id);
458        }
459
460        Err(AiError::PreconditionsNotMet(format!(
461            "Unknown AI behavior: {yaml_key}"
462        )))
463    }
464}
465
#[cfg(test)]
// Tests deliberately panic on unexpected errors, so unwrap/expect are fine here.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    mod ai_component_tests {
        use super::*;

        #[test]
        fn test_ai_component_new() {
            let component = AiComponent::new("builtin:chase");
            assert_eq!(component.model_id, "builtin:chase");
            assert_eq!(component.difficulty, 5);
        }

        #[test]
        fn test_ai_component_with_difficulty() {
            let component = AiComponent::new("chase").with_difficulty(8);
            assert_eq!(component.difficulty, 8);
        }

        #[test]
        fn test_ai_component_difficulty_clamped() {
            let low = AiComponent::new("chase").with_difficulty(0);
            assert_eq!(low.difficulty, 1);

            let high = AiComponent::new("chase").with_difficulty(100);
            assert_eq!(high.difficulty, 10);
        }
    }

    mod ai_inputs_tests {
        use super::*;

        #[test]
        fn test_from_positions() {
            let inputs =
                AiInputs::from_positions(Vec2::new(0.0, 0.0), Vec2::new(100.0, 0.0), 0.016);

            assert!((inputs.distance_to_target - 100.0).abs() < 0.01);
            assert!((inputs.direction_to_target.x - 1.0).abs() < 0.01);
            assert!(inputs.direction_to_target.y.abs() < 0.01);
        }

        #[test]
        fn test_from_positions_same_point() {
            let inputs =
                AiInputs::from_positions(Vec2::new(50.0, 50.0), Vec2::new(50.0, 50.0), 0.016);

            assert!(inputs.distance_to_target < 0.001);
            assert_eq!(inputs.direction_to_target, Vec2::ZERO);
        }

        #[test]
        fn test_to_vector() {
            let inputs = AiInputs::from_positions(Vec2::ZERO, Vec2::new(100.0, 0.0), 0.016);

            let vec = inputs.to_vector();
            assert_eq!(vec.len(), 4);
            assert!((vec[0] - 1.0).abs() < 0.01); // direction x
            assert!(vec[1].abs() < 0.01); // direction y
            assert!((vec[2] - 1.0).abs() < 0.01); // normalized distance
        }
    }

    mod ai_outputs_tests {
        use super::*;

        #[test]
        fn test_from_raw() {
            let outputs = AiOutputs::from_raw(&[0.5, 0.5, 0.8, 0.9]);

            assert!(outputs.movement.length() > 0.0);
            assert!((outputs.speed - 0.8).abs() < 0.01);
            assert!(outputs.action);
        }

        #[test]
        fn test_from_raw_empty() {
            let outputs = AiOutputs::from_raw(&[]);

            assert_eq!(outputs.movement, Vec2::ZERO);
            assert!((outputs.speed - 1.0).abs() < 0.01);
            assert!(!outputs.action);
        }

        #[test]
        fn test_from_raw_speed_clamped() {
            let outputs = AiOutputs::from_raw(&[0.0, 0.0, 2.0]);
            assert!((outputs.speed - 1.0).abs() < 0.01);

            let outputs2 = AiOutputs::from_raw(&[0.0, 0.0, -1.0]);
            assert!(outputs2.speed.abs() < 0.01);
        }
    }

    mod ai_system_tests {
        use super::*;

        #[test]
        fn test_new_system() {
            let system = AiSystem::new();
            assert_eq!(system.model_count(), 0);
        }

        #[test]
        fn test_load_builtin_chase() {
            let mut system = AiSystem::new();
            system.load_builtin("chase", "chase").unwrap();

            assert!(system.has_model("chase"));
            assert_eq!(system.model_count(), 1);
        }

        #[test]
        fn test_load_builtin_patrol() {
            let mut system = AiSystem::new();
            system.load_builtin("patrol", "patrol").unwrap();

            assert!(system.has_model("patrol"));
        }

        #[test]
        fn test_load_builtin_wander() {
            let mut system = AiSystem::new();
            system.load_builtin("wander", "wander").unwrap();

            assert!(system.has_model("wander"));
        }

        #[test]
        fn test_load_unknown_builtin() {
            let mut system = AiSystem::new();
            let result = system.load_builtin("unknown", "unknown");

            assert!(result.is_err());
        }

        #[test]
        fn test_register_model() {
            let mut system = AiSystem::new();
            let model = AprModel::new_test_model();

            system.register_model("test", model).unwrap();
            assert!(system.has_model("test"));
        }

        #[test]
        fn test_unload_model() {
            let mut system = AiSystem::new();
            system.load_builtin("chase", "chase").unwrap();

            assert!(system.unload_model("chase"));
            assert!(!system.has_model("chase"));
        }

        #[test]
        fn test_infer_chase() {
            let mut system = AiSystem::new();
            system.load_builtin("chase", "chase").unwrap();

            let inputs =
                AiInputs::from_positions(Vec2::new(0.0, 0.0), Vec2::new(100.0, 0.0), 0.016);

            let outputs = system.infer("chase", &inputs).unwrap();

            // Chase should move toward target
            assert!(outputs.movement.x > 0.0);
            assert!((outputs.speed - 1.0).abs() < 0.01);
        }

        #[test]
        fn test_infer_patrol() {
            let mut system = AiSystem::new();
            system.load_builtin("patrol", "patrol").unwrap();

            let inputs = AiInputs::from_positions(Vec2::new(50.0, 0.0), Vec2::new(0.0, 0.0), 0.016);

            let outputs = system.infer("patrol", &inputs).unwrap();

            // Patrol should have some movement
            assert!(outputs.movement.length() > 0.0);
            assert!((outputs.speed - 0.5).abs() < 0.01);
        }

        #[test]
        fn test_infer_wander() {
            let mut system = AiSystem::new();
            system.load_builtin("wander", "wander").unwrap();

            let inputs =
                AiInputs::from_positions(Vec2::new(25.0, 75.0), Vec2::new(0.0, 0.0), 0.016);

            let outputs = system.infer("wander", &inputs).unwrap();

            // Wander should have some movement
            assert!(outputs.movement.length() > 0.0);
            assert!((outputs.speed - 0.3).abs() < 0.01);
        }

        #[test]
        fn test_infer_mlp_model() {
            let mut system = AiSystem::new();
            let model = AprModel::new_test_model();
            system.register_model("mlp", model).unwrap();

            let inputs =
                AiInputs::from_positions(Vec2::new(0.0, 0.0), Vec2::new(50.0, 50.0), 0.016);

            let outputs = system.infer("mlp", &inputs).unwrap();

            // The decoded output must be well-formed: movement finite and either
            // zero or unit length, speed inside 0..=1.
            assert!(outputs.movement.is_finite());
            let len = outputs.movement.length();
            assert!(
                len < 1e-6 || (len - 1.0).abs() < 1e-4,
                "unexpected movement length {len}"
            );
            assert!(
                (0.0..=1.0).contains(&outputs.speed),
                "speed out of range: {}",
                outputs.speed
            );
        }

        #[test]
        fn test_infer_unknown_model() {
            let system = AiSystem::new();
            let inputs = AiInputs::default();

            let result = system.infer("nonexistent", &inputs);
            assert!(result.is_err());
        }
    }

    mod yaml_bridge_tests {
        use super::*;

        #[test]
        fn test_resolve_builtin_prefix() {
            let bridge = YamlAiBridge::new();
            let mut system = AiSystem::new();

            let id = bridge.resolve("builtin:chase", &mut system).unwrap();

            assert_eq!(id, "builtin-chase");
            assert!(system.has_model("builtin-chase"));
        }

        #[test]
        fn test_resolve_simple_builtin() {
            let bridge = YamlAiBridge::new();
            let mut system = AiSystem::new();

            let id = bridge.resolve("patrol", &mut system).unwrap();

            assert_eq!(id, "builtin-patrol");
            assert!(system.has_model("builtin-patrol"));
        }

        #[test]
        fn test_resolve_all_builtins() {
            let bridge = YamlAiBridge::new();
            let mut system = AiSystem::new();

            let _ = bridge.resolve("chase", &mut system).unwrap();
            let _ = bridge.resolve("patrol", &mut system).unwrap();
            let _ = bridge.resolve("wander", &mut system).unwrap();

            assert_eq!(system.model_count(), 3);
        }

        #[test]
        fn test_resolve_unknown() {
            let bridge = YamlAiBridge::new();
            let mut system = AiSystem::new();

            let result = bridge.resolve("unknown_behavior", &mut system);
            assert!(result.is_err());
        }

        #[test]
        fn test_resolve_caches_model() {
            let bridge = YamlAiBridge::new();
            let mut system = AiSystem::new();

            // Resolve twice
            let _ = bridge.resolve("builtin:chase", &mut system).unwrap();
            let _ = bridge.resolve("builtin:chase", &mut system).unwrap();

            // Should only have one model loaded
            assert_eq!(system.model_count(), 1);
        }

        #[test]
        fn test_register_custom() {
            let mut bridge = YamlAiBridge::new();
            bridge.register_custom("smart-ghost", "models/ghost.apr");

            // Can't test file loading without a file, but registration works
            assert!(!bridge.custom_models.is_empty());
        }
    }

    mod mlp_inference_tests {
        use super::*;

        #[test]
        fn test_simple_mlp() {
            // Simple 2->2 identity-like network
            let layers = vec![LayerWeights {
                weights: vec![1.0, 0.0, 0.0, 1.0], // Identity matrix
                biases: vec![0.0, 0.0],
                input_size: 2,
                output_size: 2,
            }];

            let input = vec![0.5, -0.5];
            let output = AiSystem::run_mlp_inference(&layers, &input);

            // With ReLU, negative becomes 0
            assert!(output[0] > 0.0);
            assert!(output[1].abs() < 0.01);
        }

        #[test]
        fn test_multi_layer_mlp() {
            let layers = vec![
                LayerWeights {
                    weights: vec![0.5, 0.5, 0.5, 0.5],
                    biases: vec![0.0, 0.0],
                    input_size: 2,
                    output_size: 2,
                },
                LayerWeights {
                    weights: vec![1.0, 1.0],
                    biases: vec![0.0],
                    input_size: 2,
                    output_size: 1,
                },
            ];

            let input = vec![1.0, 1.0];
            let output = AiSystem::run_mlp_inference(&layers, &input);

            assert_eq!(output.len(), 1);
            assert!(output[0] > 0.0);
        }
    }
}