use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;
use crate::error::Result;
/// A Neural ODE model: the time derivative of the state is parameterized by a
/// small one-hidden-layer network, and trajectories are produced by numerically
/// integrating that network over a fixed grid of time points.
#[derive(Debug)]
pub struct NeuralODE<F: Float + Debug + scirs2_core::ndarray::ScalarOperand> {
/// Flattened network weights stored as a single row (shape `(1, total_params)`);
/// layout is W1, b1, W2, b2 — see `extract_ode_weights` for how it is sliced.
parameters: Array2<F>,
/// Monotone grid of time points at which the trajectory is evaluated.
time_steps: Array1<F>,
/// Integration method, step size, and tolerance used by `forward`.
solver_config: ODESolverConfig<F>,
/// Dimensionality of the ODE state (network input and output size).
input_dim: usize,
/// Width of the hidden layer of the dynamics network.
hidden_dim: usize,
}
/// Configuration for the numerical ODE integrator used by [`NeuralODE`].
#[derive(Debug, Clone)]
pub struct ODESolverConfig<F: Float + Debug> {
/// Which integration scheme to use for each time step.
method: IntegrationMethod,
/// Nominal integrator step size (currently unused by the fixed-grid solver).
#[allow(dead_code)]
step_size: F,
/// Error tolerance for adaptive methods (currently unused).
#[allow(dead_code)]
tolerance: F,
}
/// Numerical integration schemes supported by the solver.
#[derive(Debug, Clone)]
pub enum IntegrationMethod {
/// First-order forward Euler.
Euler,
/// Classic fourth-order Runge-Kutta.
RungeKutta4,
/// Runge-Kutta-Fehlberg 4(5) (adaptive; currently falls back to RK4 internally).
RKF45,
}
impl<F: Float + Debug + Clone + FromPrimitive + scirs2_core::ndarray::ScalarOperand> NeuralODE<F> {
    /// Builds a Neural ODE whose dynamics network is a one-hidden-layer MLP
    /// (`input_dim -> hidden_dim -> input_dim`, tanh activation).
    ///
    /// Weights are initialized deterministically: a simple modular sequence in
    /// `[-0.5, 0.5)` scaled by a He-style factor `sqrt(2 / input_dim)`.
    pub fn new(
        input_dim: usize,
        hidden_dim: usize,
        time_steps: Array1<F>,
        solver_config: ODESolverConfig<F>,
    ) -> Self {
        // Parameter layout consumed by `extract_ode_weights`:
        //   W1: hidden_dim x input_dim, b1: hidden_dim,
        //   W2: input_dim x hidden_dim, b2: input_dim.
        // BUG FIX: the original used `2 * hidden_dim` for the two bias vectors,
        // but b2 has `input_dim` entries — whenever input_dim > hidden_dim the
        // parameter vector was too short and `extract_ode_weights` panicked
        // with an out-of-bounds index.
        let total_params =
            input_dim * hidden_dim + hidden_dim + hidden_dim * input_dim + input_dim;
        let scale = F::from(2.0).expect("Failed to convert constant to float")
            / F::from(input_dim).expect("Failed to convert to float");
        let std_dev = scale.sqrt();
        let mut parameters = Array2::zeros((1, total_params));
        for i in 0..total_params {
            // Deterministic pseudo-random value in [-0.5, 0.5).
            let val = ((i * 23) % 1000) as f64 / 1000.0 - 0.5;
            parameters[[0, i]] = F::from(val).expect("Failed to convert to float") * std_dev;
        }
        Self {
            parameters,
            time_steps,
            solver_config,
            input_dim,
            hidden_dim,
        }
    }

    /// Integrates the ODE from `initial_state` across `self.time_steps`.
    ///
    /// Returns a `(num_time_steps, input_dim)` trajectory whose first row is
    /// the initial state. An empty time grid yields an empty trajectory.
    ///
    /// # Panics
    /// Panics if `initial_state` has fewer than `input_dim` elements.
    pub fn forward(&self, initial_state: &Array1<F>) -> Result<Array2<F>> {
        let num_times = self.time_steps.len();
        let mut trajectory = Array2::zeros((num_times, self.input_dim));
        // Guard the row-0 write so an empty time grid doesn't panic.
        if num_times == 0 {
            return Ok(trajectory);
        }
        for i in 0..self.input_dim {
            trajectory[[0, i]] = initial_state[i];
        }
        for t in 1..num_times {
            let dt = self.time_steps[t] - self.time_steps[t - 1];
            let current_state = trajectory.row(t - 1).to_owned();
            // BUG FIX: the original source was corrupted by mojibake
            // (`¤t_state` in place of `&current_state`) and did not compile.
            let next_state = match self.solver_config.method {
                IntegrationMethod::Euler => self.euler_step(&current_state, dt)?,
                IntegrationMethod::RungeKutta4 => self.rk4_step(&current_state, dt)?,
                IntegrationMethod::RKF45 => self.rkf45_step(&current_state, dt)?,
            };
            for i in 0..self.input_dim {
                trajectory[[t, i]] = next_state[i];
            }
        }
        Ok(trajectory)
    }

    /// Evaluates the dynamics network: `f(state) = W2 * tanh(W1 * state + b1) + b2`.
    fn neural_network(&self, state: &Array1<F>) -> Result<Array1<F>> {
        let (w1, b1, w2, b2) = self.extract_ode_weights();
        // Hidden layer with tanh activation.
        let mut hidden = Array1::zeros(self.hidden_dim);
        for i in 0..self.hidden_dim {
            let mut sum = b1[i];
            for j in 0..self.input_dim {
                sum = sum + w1[[i, j]] * state[j];
            }
            hidden[i] = self.tanh(sum);
        }
        // Linear output layer back to the state dimension.
        let mut output = Array1::zeros(self.input_dim);
        for i in 0..self.input_dim {
            let mut sum = b2[i];
            for j in 0..self.hidden_dim {
                sum = sum + w2[[i, j]] * hidden[j];
            }
            output[i] = sum;
        }
        Ok(output)
    }

    /// Unflattens the parameter row into `(W1, b1, W2, b2)`.
    ///
    /// Reads the vector sequentially in the same order `new` allocates it.
    fn extract_ode_weights(&self) -> (Array2<F>, Array1<F>, Array2<F>, Array1<F>) {
        let param_vec = self.parameters.row(0);
        let mut idx = 0;
        let mut w1 = Array2::zeros((self.hidden_dim, self.input_dim));
        for i in 0..self.hidden_dim {
            for j in 0..self.input_dim {
                w1[[i, j]] = param_vec[idx];
                idx += 1;
            }
        }
        let mut b1 = Array1::zeros(self.hidden_dim);
        for i in 0..self.hidden_dim {
            b1[i] = param_vec[idx];
            idx += 1;
        }
        let mut w2 = Array2::zeros((self.input_dim, self.hidden_dim));
        for i in 0..self.input_dim {
            for j in 0..self.hidden_dim {
                w2[[i, j]] = param_vec[idx];
                idx += 1;
            }
        }
        let mut b2 = Array1::zeros(self.input_dim);
        for i in 0..self.input_dim {
            b2[i] = param_vec[idx];
            idx += 1;
        }
        (w1, b1, w2, b2)
    }

    /// Single forward-Euler step: `x + dt * f(x)`.
    fn euler_step(&self, state: &Array1<F>, dt: F) -> Result<Array1<F>> {
        let derivative = self.neural_network(state)?;
        let mut next_state = Array1::zeros(self.input_dim);
        for i in 0..self.input_dim {
            next_state[i] = state[i] + dt * derivative[i];
        }
        Ok(next_state)
    }

    /// Single classic RK4 step: weighted average of four derivative samples.
    fn rk4_step(&self, state: &Array1<F>, dt: F) -> Result<Array1<F>> {
        // Hoist the constant conversions out of the element loops.
        let two = F::from(2.0).expect("Failed to convert constant to float");
        let six = F::from(6.0).expect("Failed to convert constant to float");
        let k1 = self.neural_network(state)?;
        let mut temp_state = Array1::zeros(self.input_dim);
        for i in 0..self.input_dim {
            temp_state[i] = state[i] + dt * k1[i] / two;
        }
        let k2 = self.neural_network(&temp_state)?;
        for i in 0..self.input_dim {
            temp_state[i] = state[i] + dt * k2[i] / two;
        }
        let k3 = self.neural_network(&temp_state)?;
        for i in 0..self.input_dim {
            temp_state[i] = state[i] + dt * k3[i];
        }
        let k4 = self.neural_network(&temp_state)?;
        let mut next_state = Array1::zeros(self.input_dim);
        for i in 0..self.input_dim {
            next_state[i] = state[i] + dt * (k1[i] + two * k2[i] + two * k3[i] + k4[i]) / six;
        }
        Ok(next_state)
    }

    /// RKF45 step. NOTE(review): adaptive Fehlberg coefficients are not yet
    /// implemented; this currently delegates to a plain RK4 step and ignores
    /// the configured tolerance.
    fn rkf45_step(&self, state: &Array1<F>, dt: F) -> Result<Array1<F>> {
        self.rk4_step(state, dt)
    }

    /// Hyperbolic-tangent activation for the hidden layer.
    fn tanh(&self, x: F) -> F {
        x.tanh()
    }
}
impl<F: Float + Debug> ODESolverConfig<F> {
    /// Assembles a solver configuration from its raw parts.
    pub fn new(method: IntegrationMethod, step_size: F, tolerance: F) -> Self {
        Self {
            method,
            step_size,
            tolerance,
        }
    }

    /// Forward-Euler configuration with the default tolerance of `1e-6`.
    pub fn euler(step_size: F) -> Self {
        let tolerance = F::from(1e-6).expect("Failed to convert constant to float");
        Self::new(IntegrationMethod::Euler, step_size, tolerance)
    }

    /// Fourth-order Runge-Kutta configuration with the default tolerance of `1e-6`.
    pub fn runge_kutta4(step_size: F) -> Self {
        let tolerance = F::from(1e-6).expect("Failed to convert constant to float");
        Self::new(IntegrationMethod::RungeKutta4, step_size, tolerance)
    }

    /// Runge-Kutta-Fehlberg 4(5) configuration with a caller-supplied tolerance.
    pub fn rkf45(step_size: F, tolerance: F) -> Self {
        Self::new(IntegrationMethod::RKF45, step_size, tolerance)
    }
}