use crate::error::{OptimizeError, OptimizeResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
use scirs2_core::random::{Rng, RngExt};
#[derive(Debug, Clone)]
pub struct LiquidStateMachine {
pub reservoir_weights: Array2<f64>,
pub input_weights: Array2<f64>,
pub output_weights: Array2<f64>,
pub reservoir_state: Array1<f64>,
pub reservoir_size: usize,
pub input_size: usize,
pub output_size: usize,
}
impl LiquidStateMachine {
pub fn new(input_size: usize, reservoir_size: usize, output_size: usize) -> Self {
let mut reservoir_weights = Array2::zeros((reservoir_size, reservoir_size));
let mut input_weights = Array2::zeros((reservoir_size, input_size));
let output_weights = Array2::zeros((output_size, reservoir_size));
for i in 0..reservoir_size {
for j in 0..reservoir_size {
if i != j && scirs2_core::random::rng().random::<f64>() < 0.1 {
reservoir_weights[[i, j]] =
(scirs2_core::random::rng().random::<f64>() - 0.5) * 0.1;
}
}
}
for i in 0..reservoir_size {
for j in 0..input_size {
input_weights[[i, j]] = (scirs2_core::random::rng().random::<f64>() - 0.5) * 0.5;
}
}
Self {
reservoir_weights,
input_weights,
output_weights,
reservoir_state: Array1::zeros(reservoir_size),
reservoir_size,
input_size,
output_size,
}
}
pub fn update_reservoir(&mut self, input: &ArrayView1<f64>) {
let mut new_state: Array1<f64> = Array1::zeros(self.reservoir_size);
for i in 0..self.reservoir_size {
for j in 0..self.input_size.min(input.len()) {
new_state[i] += self.input_weights[[i, j]] * input[j];
}
}
for i in 0..self.reservoir_size {
for j in 0..self.reservoir_size {
new_state[i] += self.reservoir_weights[[i, j]] * self.reservoir_state[j];
}
}
for i in 0..self.reservoir_size {
new_state[i] = new_state[i].tanh();
}
self.reservoir_state = new_state;
}
pub fn compute_output(&self) -> Array1<f64> {
let mut output = Array1::zeros(self.output_size);
for i in 0..self.output_size {
for j in 0..self.reservoir_size {
output[i] += self.output_weights[[i, j]] * self.reservoir_state[j];
}
}
output
}
pub fn train_output_weights(
&mut self,
targets: &Array2<f64>,
states: &Array2<f64>,
) -> Result<(), OptimizeError> {
let state_dims = states.ncols();
let target_dims = targets.ncols();
self.output_weights = Array2::eye(state_dims.min(target_dims));
Ok(())
}
}
#[allow(dead_code)]
pub fn lsm_optimize<F>(
objective: F,
initial_params: &ArrayView1<f64>,
num_nit: usize,
) -> OptimizeResult<Array1<f64>>
where
F: Fn(&ArrayView1<f64>) -> f64,
{
let input_size = initial_params.len();
let reservoir_size = 100;
let output_size = input_size;
let mut lsm = LiquidStateMachine::new(input_size, reservoir_size, output_size);
let mut params = initial_params.to_owned();
for _iter in 0..num_nit {
lsm.update_reservoir(¶ms.view());
let updates = lsm.compute_output();
for i in 0..params.len() {
if i < updates.len() {
params[i] += 0.01 * updates[i];
}
}
let _obj_val = objective(¶ms.view());
}
Ok(params)
}
#[allow(dead_code)]
pub fn placeholder() {
}