lmm 0.1.2

A language agnostic framework for emulating reality.
Documentation
use crate::equation::Expression;
use crate::error::Result;
use crate::simulation::Simulator;
use crate::tensor::Tensor;
use crate::traits::Simulatable;
use std::collections::HashMap;

pub struct WorldModel {
    pub current_state: Tensor,
    pub prediction_errors: Vec<f64>,
}

impl WorldModel {
    pub fn new(initial_state: Tensor) -> Self {
        Self {
            current_state: initial_state,
            prediction_errors: Vec::new(),
        }
    }

    pub fn step(&mut self, action: &Tensor) -> Result<Tensor> {
        let next_state = (&self.current_state + action)?;
        self.current_state = next_state.clone();
        Ok(next_state)
    }

    pub fn physics_step<M: Simulatable>(
        &mut self,
        physics: &M,
        action: &Tensor,
        step_size: f64,
    ) -> Result<Tensor> {
        let sim = Simulator { step_size };
        let physics_next = sim.rk4_step(physics, &self.current_state)?;
        let next_state = (&physics_next + action)?;
        self.current_state = next_state.clone();
        Ok(next_state)
    }

    pub fn equation_step(
        &mut self,
        equation: &Expression,
        var_names: &[&str],
        step_size: f64,
    ) -> Result<Tensor> {
        let bindings: HashMap<String, f64> = var_names
            .iter()
            .zip(self.current_state.data.iter())
            .map(|(k, &v)| ((*k).to_string(), v))
            .collect();
        let derivative = equation.evaluate(&bindings)?;
        let delta = Tensor::fill(self.current_state.shape.clone(), derivative * step_size);
        let next_state = (&self.current_state + &delta)?;
        self.current_state = next_state.clone();
        Ok(next_state)
    }

    pub fn predict_horizon(&self, actions: &[Tensor]) -> Result<Vec<Tensor>> {
        let mut predictions = Vec::with_capacity(actions.len());
        let mut temp_state = self.current_state.clone();
        for action in actions {
            temp_state = (&temp_state + action)?;
            predictions.push(temp_state.clone());
        }
        Ok(predictions)
    }

    pub fn predict_horizon_physics<M: Simulatable>(
        &self,
        physics: &M,
        n_steps: usize,
        step_size: f64,
    ) -> Result<Vec<Tensor>> {
        let sim = Simulator { step_size };
        sim.simulate_trajectory(physics, &self.current_state, n_steps)
    }

    pub fn record_error(&mut self, predicted: &Tensor, actual: &Tensor) -> Result<f64> {
        let diff = (predicted - actual)?;
        let mse = diff.data.iter().map(|x| x * x).sum::<f64>() / diff.data.len() as f64;
        self.prediction_errors.push(mse);
        Ok(mse)
    }

    pub fn mean_prediction_error(&self) -> f64 {
        if self.prediction_errors.is_empty() {
            return 0.0;
        }
        self.prediction_errors.iter().sum::<f64>() / self.prediction_errors.len() as f64
    }
}