use crate::error::{AnomalyError, AnomalyResult};
use crate::handle::LcgRng;
/// Symmetric bound used by `clip_grad`: values are clamped to `[-GRAD_CLIP, GRAD_CLIP]`.
const GRAD_CLIP: f64 = 1.0;
/// Builds a flat weight buffer of `fan_in * fan_out` values using the Xavier/Glorot uniform
/// bound `sqrt(6 / (fan_in + fan_out))`.
///
/// Each value is `u * 2 * bound - bound`, where `u` is the next `rng.next_f32()` sample.
/// NOTE(review): this lands in `[-bound, bound)` only if `next_f32` yields values in `[0, 1)` —
/// confirm against `LcgRng`.
fn xavier_init_f64(fan_in: usize, fan_out: usize, rng: &mut LcgRng) -> Vec<f64> {
    let bound = (6.0_f64 / (fan_in + fan_out) as f64).sqrt();
    let count = fan_in * fan_out;
    let mut weights = Vec::with_capacity(count);
    for _ in 0..count {
        let sample = rng.next_f32() as f64;
        weights.push(sample * 2.0 * bound - bound);
    }
    weights
}
/// Hyper-parameters consumed by `lstm_ae_fit`.
#[derive(Debug, Clone)]
pub struct LstmAeConfig {
/// Number of consecutive timesteps in each training window.
pub window_size: usize,
/// Features per timestep.
/// NOTE(review): `lstm_ae_fit` takes the feature count from its `d` argument and does not
/// read this field — confirm whether the two are meant to be cross-checked.
pub input_dim: usize,
/// Width of the encoder and decoder hidden state.
pub hidden_dim: usize,
/// Step size multiplied into every gradient before it is subtracted from a weight.
pub lr: f64,
/// Number of full passes over all training windows.
pub n_epochs: usize,
}
/// Default hyper-parameters: windows of 10 timesteps with 1 feature each, 16 hidden units,
/// learning rate `1e-3`, 20 epochs.
impl Default for LstmAeConfig {
fn default() -> Self {
Self {
window_size: 10,
input_dim: 1,
hidden_dim: 16,
lr: 1e-3,
n_epochs: 20,
}
}
}
/// Trained parameters of the encoder/decoder autoencoder together with the dimensions needed
/// to interpret the flat weight buffers.
///
/// All matrices are stored row-major with one row per output unit.
#[derive(Debug, Clone)]
pub struct LstmAeFit {
/// Encoder input-to-hidden weights, indexed as `[o * input_dim + i]`.
pub enc_wx: Vec<f64>,
/// Encoder hidden-to-hidden weights, indexed as `[o * hidden_dim + j]`.
pub enc_wh: Vec<f64>,
/// Encoder bias, one entry per hidden unit.
pub enc_b: Vec<f64>,
/// Decoder input-to-hidden weights, indexed as `[o * input_dim + i]`.
pub dec_wx: Vec<f64>,
/// Decoder hidden-to-hidden weights, indexed as `[o * hidden_dim + j]`.
pub dec_wh: Vec<f64>,
/// Decoder bias, one entry per hidden unit.
pub dec_b: Vec<f64>,
/// Output projection weights (hidden state -> reconstruction), indexed as
/// `[o * hidden_dim + j]` with `o` ranging over the `input_dim` outputs.
pub dec_out_w: Vec<f64>,
/// Output projection bias, one entry per reconstructed feature.
pub dec_out_b: Vec<f64>,
/// Number of timesteps per window the model was trained on.
pub window_size: usize,
/// Number of features per timestep.
pub input_dim: usize,
/// Width of the encoder and decoder hidden state.
pub hidden_dim: usize,
}
/// Advances a plain tanh recurrent cell by one step.
///
/// For every hidden unit `u` this computes
/// `tanh(b[u] + wx[u, ..] · x + wh[u, ..] · h_prev)`, where `wx` is a row-major
/// `hidden_dim x input_dim` matrix and `wh` a row-major `hidden_dim x hidden_dim` matrix.
/// The bias is added first, then the input terms in index order, then the recurrent terms in
/// index order.
///
/// # Panics
/// Panics if any slice is shorter than the dimensions imply.
fn rnn_step(
    x: &[f64],
    h_prev: &[f64],
    wx: &[f64],
    wh: &[f64],
    b: &[f64],
    input_dim: usize,
    hidden_dim: usize,
) -> Vec<f64> {
    (0..hidden_dim)
        .map(|unit| {
            let in_row = &wx[unit * input_dim..(unit + 1) * input_dim];
            let rec_row = &wh[unit * hidden_dim..(unit + 1) * hidden_dim];
            // Fold left-to-right starting from the bias so the summation order is fixed.
            let with_input = in_row
                .iter()
                .zip(&x[..input_dim])
                .fold(b[unit], |sum, (&w, &xi)| sum + w * xi);
            let pre_activation = rec_row
                .iter()
                .zip(&h_prev[..hidden_dim])
                .fold(with_input, |sum, (&w, &hj)| sum + w * hj);
            pre_activation.tanh()
        })
        .collect()
}
/// Applies the affine map `w · h + b`, where `w` is a row-major `out_dim x hidden_dim` matrix.
///
/// Each output starts from its bias and adds the products in index order.
///
/// # Panics
/// Panics if `h`, `w` or `b` is shorter than the dimensions imply.
fn linear_proj(h: &[f64], w: &[f64], b: &[f64], hidden_dim: usize, out_dim: usize) -> Vec<f64> {
    (0..out_dim)
        .map(|row| {
            let weights = &w[row * hidden_dim..(row + 1) * hidden_dim];
            weights
                .iter()
                .zip(&h[..hidden_dim])
                .fold(b[row], |sum, (&wi, &hi)| sum + wi * hi)
        })
        .collect()
}
/// Runs the encoder over one window, starting from an all-zero hidden state.
///
/// `window` is read as `fit.window_size` consecutive frames of `fit.input_dim` values each.
/// Returns every hidden state in time order together with a copy of the last one, which
/// serves as the context vector handed to the decoder.
///
/// # Panics
/// Panics if `fit.window_size` is zero or `window` is shorter than
/// `window_size * input_dim`.
fn encode_window(window: &[f64], fit: &LstmAeFit) -> (Vec<Vec<f64>>, Vec<f64>) {
    let steps = fit.window_size;
    let dim = fit.input_dim;
    let hidden = fit.hidden_dim;

    let mut states: Vec<Vec<f64>> = Vec::with_capacity(steps);
    let mut carry = vec![0.0_f64; hidden];
    for frame in 0..steps {
        let x = &window[frame * dim..(frame + 1) * dim];
        carry = rnn_step(x, &carry, &fit.enc_wx, &fit.enc_wh, &fit.enc_b, dim, hidden);
        states.push(carry.clone());
    }

    let context = states[steps - 1].clone();
    (states, context)
}
/// Unrolls the decoder for `fit.window_size` steps, starting from `context`.
///
/// Decoder step `k` reconstructs window timestep `window_size - 1 - k`, i.e. the window is
/// emitted in reverse time order. The first step always receives an all-zero input. What the
/// following steps receive depends on `teacher_inputs`:
///
/// - `None` (free-running): the decoder's own previous reconstruction.
/// - `Some(window)` (teacher forcing): the true value the *previous* step was reconstructing,
///   which is exactly the input schedule `lstm_ae_fit` trains with.
///
/// Returns `(hidden_states, reconstructions)`, both in decoder-step order.
///
/// # Panics
/// Panics if `teacher_inputs` is shorter than `window_size * input_dim`, or if the weight
/// buffers in `fit` are shorter than its dimensions imply.
fn decode_window(
    context: &[f64],
    fit: &LstmAeFit,
    teacher_inputs: Option<&[f64]>,
) -> (Vec<Vec<f64>>, Vec<Vec<f64>>) {
    let t = fit.window_size;
    let d = fit.input_dim;
    let h = fit.hidden_dim;
    let mut dec_hidden: Vec<Vec<f64>> = Vec::with_capacity(t);
    let mut recons: Vec<Vec<f64>> = Vec::with_capacity(t);
    let mut s_prev = context.to_vec();
    // Input for the current step; step 0 sees zeros in both modes.
    let mut x_in = vec![0.0_f64; d];
    for step in 0..t {
        let s_new = rnn_step(&x_in, &s_prev, &fit.dec_wx, &fit.dec_wh, &fit.dec_b, d, h);
        let x_hat = linear_proj(&s_new, &fit.dec_out_w, &fit.dec_out_b, h, d);
        // Prepare the NEXT step's input. Under teacher forcing this is the ground truth for the
        // timestep just reconstructed; feeding the current step its own target (as the previous
        // implementation did) would leak the answer and disagree with training.
        match teacher_inputs {
            Some(inputs) => {
                let rev_t = t - 1 - step;
                x_in.copy_from_slice(&inputs[rev_t * d..(rev_t + 1) * d]);
            }
            None => x_in.copy_from_slice(&x_hat),
        }
        dec_hidden.push(s_new.clone());
        recons.push(x_hat);
        s_prev = s_new;
    }
    (dec_hidden, recons)
}
/// Limits `g` to the closed interval `[-GRAD_CLIP, GRAD_CLIP]`.
///
/// A NaN input fails both comparisons and is returned unchanged.
#[inline]
fn clip_grad(g: f64) -> f64 {
    if g > GRAD_CLIP {
        GRAD_CLIP
    } else if g < -GRAD_CLIP {
        -GRAD_CLIP
    } else {
        g
    }
}
/// Trains the encoder/decoder autoencoder on every sliding window of `series`.
///
/// `series` is read as `n` consecutive timesteps of `d` values each. For each window of
/// `cfg.window_size` timesteps the encoder produces a context vector, the decoder reconstructs
/// the window in reverse time order with teacher forcing, gradients are obtained by
/// backpropagation through time, and every parameter is updated once with step size `cfg.lr`.
/// This is repeated for `cfg.n_epochs` passes over all windows.
///
/// Despite the `lstm` name, the recurrent cell implemented here is a plain tanh unit.
///
/// The feature count is taken from the `d` argument; `cfg.input_dim` is not read here.
///
/// # Errors
/// - `AnomalyError::InvalidFeatureCount` if `d == 0`.
/// - `AnomalyError::InvalidLayerDims` if `cfg.window_size == 0`.
/// - `AnomalyError::InsufficientSamples` if `n < cfg.window_size`.
/// - `AnomalyError::DimensionMismatch` if `series.len() != n * d`.
// The allow keeps the index-based loops below, which mirror the matrix index arithmetic.
#[allow(clippy::needless_range_loop)]
pub fn lstm_ae_fit(
series: &[f64],
n: usize,
d: usize,
cfg: &LstmAeConfig,
seed: u64,
) -> AnomalyResult<LstmAeFit> {
// Validate arguments before any allocation or indexing.
if d == 0 {
return Err(AnomalyError::InvalidFeatureCount { n: 0 });
}
if cfg.window_size == 0 {
return Err(AnomalyError::InvalidLayerDims {
msg: "window_size must be > 0".into(),
});
}
if n < cfg.window_size {
return Err(AnomalyError::InsufficientSamples {
need: cfg.window_size,
got: n,
});
}
if series.len() != n * d {
return Err(AnomalyError::DimensionMismatch {
expected: n * d,
got: series.len(),
});
}
let t = cfg.window_size;
let h = cfg.hidden_dim;
let lr = cfg.lr;
// Weight matrices come from the seeded generator; all biases start at zero.
let mut rng = LcgRng::new(seed);
let enc_wx = xavier_init_f64(d, h, &mut rng);
let enc_wh = xavier_init_f64(h, h, &mut rng);
let enc_b = vec![0.0_f64; h];
let dec_wx = xavier_init_f64(d, h, &mut rng);
let dec_wh = xavier_init_f64(h, h, &mut rng);
let dec_b = vec![0.0_f64; h];
let dec_out_w = xavier_init_f64(h, d, &mut rng);
let dec_out_b = vec![0.0_f64; d];
let mut fit = LstmAeFit {
enc_wx,
enc_wh,
enc_b,
dec_wx,
dec_wh,
dec_b,
dec_out_w,
dec_out_b,
window_size: t,
input_dim: d,
hidden_dim: h,
};
// Number of length-`t` windows when sliding one timestep at a time.
let n_windows = n - t + 1;
for _epoch in 0..cfg.n_epochs {
for w_start in 0..n_windows {
let window = &series[w_start * d..(w_start + t) * d];
// ---- Encoder forward pass; pre-activations and states are kept for backprop.
let mut enc_pre: Vec<Vec<f64>> = Vec::with_capacity(t);
let mut enc_h: Vec<Vec<f64>> = Vec::with_capacity(t);
let mut h_prev = vec![0.0_f64; h];
for step in 0..t {
let x_t = &window[step * d..(step + 1) * d];
let mut pre = vec![0.0_f64; h];
for o in 0..h {
let mut acc = fit.enc_b[o];
for i in 0..d {
acc += fit.enc_wx[o * d + i] * x_t[i];
}
for j in 0..h {
acc += fit.enc_wh[o * h + j] * h_prev[j];
}
pre[o] = acc;
}
let h_new: Vec<f64> = pre.iter().map(|&v| v.tanh()).collect();
enc_pre.push(pre);
enc_h.push(h_new.clone());
h_prev = h_new;
}
// The last encoder state is the context that seeds the decoder.
let context = enc_h[t - 1].clone();
// ---- Decoder forward pass with teacher forcing. Decoder step `k` reconstructs window
// timestep `t - 1 - k`, so the window is emitted in reverse order.
let mut dec_pre: Vec<Vec<f64>> = Vec::with_capacity(t);
let mut dec_h: Vec<Vec<f64>> = Vec::with_capacity(t);
let mut dec_out: Vec<Vec<f64>> = Vec::with_capacity(t);
let mut s_prev = context.clone();
// The first decoder step receives an all-zero input.
let mut x_in = vec![0.0_f64; d];
for step in 0..t {
let mut pre = vec![0.0_f64; h];
for o in 0..h {
let mut acc = fit.dec_b[o];
for i in 0..d {
acc += fit.dec_wx[o * d + i] * x_in[i];
}
for j in 0..h {
acc += fit.dec_wh[o * h + j] * s_prev[j];
}
pre[o] = acc;
}
let s_new: Vec<f64> = pre.iter().map(|&v| v.tanh()).collect();
let x_hat = linear_proj(&s_new, &fit.dec_out_w, &fit.dec_out_b, h, d);
// The next step's input is the true value this step was reconstructing.
let rev_t = t - 1 - step;
x_in = window[rev_t * d..(rev_t + 1) * d].to_vec();
dec_pre.push(pre);
dec_h.push(s_new.clone());
dec_out.push(x_hat);
s_prev = s_new;
}
// Scale so the loss is the mean over all `t * d` reconstructed values.
let inv_td = 1.0 / (t * d) as f64;
// Gradient accumulators, one per parameter buffer, zeroed for every window.
let mut d_enc_wx = vec![0.0_f64; h * d];
let mut d_enc_wh = vec![0.0_f64; h * h];
let mut d_enc_b = vec![0.0_f64; h];
let mut d_dec_wx = vec![0.0_f64; h * d];
let mut d_dec_wh = vec![0.0_f64; h * h];
let mut d_dec_b = vec![0.0_f64; h];
let mut d_dec_out_w = vec![0.0_f64; d * h];
let mut d_dec_out_b = vec![0.0_f64; d];
// Gradient flowing into a decoder state from the step after it (zero for the last step).
let mut d_s_next = vec![0.0_f64; h];
// Rebuild the input each decoder step saw: zeros first, then the reversed window
// shifted by one step.
let mut dec_inputs: Vec<Vec<f64>> = Vec::with_capacity(t);
dec_inputs.push(vec![0.0_f64; d]); for step in 0..t - 1 {
let rev_t = t - 1 - step;
dec_inputs.push(window[rev_t * d..(rev_t + 1) * d].to_vec());
}
let mut d_context = vec![0.0_f64; h];
// ---- Backprop through the decoder, last step first.
for step in (0..t).rev() {
let rev_t = t - 1 - step; let target = &window[rev_t * d..(rev_t + 1) * d];
let x_hat = &dec_out[step];
// Gradient of the mean squared error w.r.t. the output; the residual is clamped
// before scaling.
let d_xhat: Vec<f64> = x_hat
.iter()
.zip(target.iter())
.map(|(&xh, &xt)| 2.0 * inv_td * clip_grad(xh - xt))
.collect();
for o in 0..d {
d_dec_out_b[o] += d_xhat[o];
for j in 0..h {
d_dec_out_w[o * h + j] += d_xhat[o] * dec_h[step][j];
}
}
// Gradient reaching this hidden state: through its own output projection plus from
// the following decoder step.
let mut d_s_new = vec![0.0_f64; h];
for j in 0..h {
for o in 0..d {
d_s_new[j] += d_xhat[o] * fit.dec_out_w[o * h + j];
}
}
for j in 0..h {
d_s_new[j] += d_s_next[j];
}
// Through tanh: the derivative is `1 - tanh(z)^2`; the result is clamped per element.
let mut d_pre = vec![0.0_f64; h];
for j in 0..h {
let tanh_val = dec_pre[step][j].tanh();
d_pre[j] = clip_grad(d_s_new[j] * (1.0 - tanh_val * tanh_val));
}
let x_in_step = &dec_inputs[step];
// State that fed this step: the encoder context for step 0, otherwise the previous
// decoder state.
let s_prev_step: &[f64] = if step == 0 {
&context
} else {
&dec_h[step - 1]
};
for o in 0..h {
d_dec_b[o] += d_pre[o];
for i in 0..d {
d_dec_wx[o * d + i] += d_pre[o] * x_in_step[i];
}
for j in 0..h {
d_dec_wh[o * h + j] += d_pre[o] * s_prev_step[j];
}
}
let mut d_s_prev = vec![0.0_f64; h];
for j in 0..h {
for o in 0..h {
d_s_prev[j] += d_pre[o] * fit.dec_wh[o * h + j];
}
}
// Step 0 was fed by the encoder context, so its incoming-state gradient is handed to
// the encoder; other steps pass it to the preceding decoder step.
if step == 0 {
for j in 0..h {
d_context[j] += d_s_prev[j];
}
} else {
d_s_next = d_s_prev;
}
}
// ---- Backprop through the encoder, starting from the context gradient at the last step.
let mut d_h_next = d_context;
for step in (0..t).rev() {
let mut d_pre = vec![0.0_f64; h];
for j in 0..h {
let tanh_val = enc_pre[step][j].tanh();
d_pre[j] = clip_grad(d_h_next[j] * (1.0 - tanh_val * tanh_val));
}
let x_t = &window[step * d..(step + 1) * d];
// Step 0 started from an all-zero state, so it adds nothing to `d_enc_wh` and no
// gradient flows further back; the empty slice is never indexed.
let h_prev_step: &[f64] = if step == 0 {
&[]
} else {
&enc_h[step - 1]
};
for o in 0..h {
d_enc_b[o] += d_pre[o];
for i in 0..d {
d_enc_wx[o * d + i] += d_pre[o] * x_t[i];
}
if step > 0 {
for j in 0..h {
d_enc_wh[o * h + j] += d_pre[o] * h_prev_step[j];
}
}
}
let mut d_h_prev = vec![0.0_f64; h];
if step > 0 {
for j in 0..h {
for o in 0..h {
d_h_prev[j] += d_pre[o] * fit.enc_wh[o * h + j];
}
}
}
d_h_next = d_h_prev;
}
// ---- Gradient-descent update of every parameter with step size `lr`.
for (w, &g) in fit.enc_wx.iter_mut().zip(d_enc_wx.iter()) {
*w -= lr * g;
}
for (w, &g) in fit.enc_wh.iter_mut().zip(d_enc_wh.iter()) {
*w -= lr * g;
}
for (w, &g) in fit.enc_b.iter_mut().zip(d_enc_b.iter()) {
*w -= lr * g;
}
for (w, &g) in fit.dec_wx.iter_mut().zip(d_dec_wx.iter()) {
*w -= lr * g;
}
for (w, &g) in fit.dec_wh.iter_mut().zip(d_dec_wh.iter()) {
*w -= lr * g;
}
for (w, &g) in fit.dec_b.iter_mut().zip(d_dec_b.iter()) {
*w -= lr * g;
}
for (w, &g) in fit.dec_out_w.iter_mut().zip(d_dec_out_w.iter()) {
*w -= lr * g;
}
for (w, &g) in fit.dec_out_b.iter_mut().zip(d_dec_out_b.iter()) {
*w -= lr * g;
}
}
}
Ok(fit)
}
/// Computes one reconstruction-error score per timestep of `series`.
///
/// `series` is read as `n` consecutive timesteps of `fit.input_dim` values each. Every sliding
/// window of `fit.window_size` timesteps is encoded and then decoded free-running (the decoder
/// is fed its own previous reconstruction). For each timestep in the window the mean squared
/// error over its features is recorded, and a timestep's final score is the average of that
/// error over all windows that contain it.
///
/// # Errors
/// - `AnomalyError::InvalidFeatureCount` if `fit.input_dim == 0`.
/// - `AnomalyError::InvalidLayerDims` if `fit.window_size == 0`.
/// - `AnomalyError::InsufficientSamples` if `n < fit.window_size`.
/// - `AnomalyError::DimensionMismatch` if `series.len() != n * fit.input_dim`.
///
/// # Panics
/// Panics if the weight buffers in `fit` are shorter than its dimensions imply.
pub fn lstm_ae_score(fit: &LstmAeFit, series: &[f64], n: usize) -> AnomalyResult<Vec<f64>> {
    let d = fit.input_dim;
    let t = fit.window_size;
    // `LstmAeFit` has public fields, so a hand-built value can violate what `lstm_ae_fit`
    // guarantees. Reject it up front: a zero feature count would yield 0/0 = NaN scores and a
    // zero window size would underflow `t - 1` while encoding.
    if d == 0 {
        return Err(AnomalyError::InvalidFeatureCount { n: 0 });
    }
    if t == 0 {
        return Err(AnomalyError::InvalidLayerDims {
            msg: "window_size must be > 0".into(),
        });
    }
    if n < t {
        return Err(AnomalyError::InsufficientSamples { need: t, got: n });
    }
    if series.len() != n * d {
        return Err(AnomalyError::DimensionMismatch {
            expected: n * d,
            got: series.len(),
        });
    }
    let n_windows = n - t + 1;
    let mut score_sum = vec![0.0_f64; n];
    let mut score_cnt = vec![0_usize; n];
    for w_start in 0..n_windows {
        let window = &series[w_start * d..(w_start + t) * d];
        let (_, context) = encode_window(window, fit);
        let (_, recons) = decode_window(&context, fit, None);
        for (step, x_hat) in recons.iter().enumerate() {
            // The decoder emits the window back-to-front: step 0 is the last timestep.
            let offset = t - 1 - step;
            let target = &window[offset * d..(offset + 1) * d];
            let mse: f64 = target
                .iter()
                .zip(x_hat.iter())
                .map(|(&xt, &xh)| (xt - xh).powi(2))
                .sum::<f64>()
                / d as f64;
            let actual_t = w_start + offset;
            score_sum[actual_t] += mse;
            score_cnt[actual_t] += 1;
        }
    }
    // Average over however many windows covered each timestep.
    let scores: Vec<f64> = score_sum
        .iter()
        .zip(score_cnt.iter())
        .map(|(&s, &c)| if c > 0 { s / c as f64 } else { 0.0 })
        .collect();
    Ok(scores)
}
/// Flags every timestep whose score from `lstm_ae_score` is at least `threshold`.
///
/// Returns one boolean per timestep, `true` meaning "anomalous".
///
/// # Errors
/// Forwards any error returned by `lstm_ae_score`.
pub fn lstm_ae_predict(
    fit: &LstmAeFit,
    series: &[f64],
    n: usize,
    threshold: f64,
) -> AnomalyResult<Vec<bool>> {
    lstm_ae_score(fit, series, n).map(|scores| {
        scores
            .into_iter()
            .map(|score| score >= threshold)
            .collect()
    })
}
// Unit tests for fitting, scoring and thresholding, including the argument-validation errors.
#[cfg(test)]
mod tests {
use super::*;
/// Small univariate configuration shared by most tests.
fn default_cfg() -> LstmAeConfig {
LstmAeConfig {
window_size: 5,
input_dim: 1,
hidden_dim: 4,
lr: 1e-2,
n_epochs: 5,
}
}
/// Univariate series of `n` timesteps that all equal `val`.
fn constant_series(n: usize, val: f64) -> Vec<f64> {
vec![val; n]
}
/// Scoring returns exactly one score per timestep.
#[test]
fn scores_have_correct_length() {
let cfg = default_cfg();
let n = 20_usize;
let series = constant_series(n, 0.5);
let fit =
lstm_ae_fit(&series, n, cfg.input_dim, &cfg, 1).expect("lstm_ae_fit should succeed");
let scores = lstm_ae_score(&fit, &series, n).expect("lstm_ae_score should succeed");
assert_eq!(scores.len(), n, "expected {n} scores, got {}", scores.len());
}
/// Scores on a sine series are finite and non-negative.
#[test]
fn scores_finite_nonneg() {
let cfg = default_cfg();
let n = 20_usize;
let series: Vec<f64> = (0..n).map(|i| (i as f64 * 0.1).sin()).collect();
let fit =
lstm_ae_fit(&series, n, cfg.input_dim, &cfg, 2).expect("lstm_ae_fit should succeed");
let scores = lstm_ae_score(&fit, &series, n).expect("lstm_ae_score should succeed");
for (i, &s) in scores.iter().enumerate() {
assert!(s.is_finite(), "score[{i}] = {s} not finite");
assert!(s >= 0.0, "score[{i}] = {s} is negative");
}
}
/// A single large spike in an otherwise constant series scores above the mean of the rest.
#[test]
fn spike_anomaly_scores_higher() {
let n = 30_usize;
let mut cfg = default_cfg();
cfg.n_epochs = 20;
cfg.hidden_dim = 8;
cfg.lr = 5e-3;
let spike_t = 15_usize;
let mut series = constant_series(n, 0.1);
series[spike_t] = 5.0;
let fit =
lstm_ae_fit(&series, n, cfg.input_dim, &cfg, 3).expect("lstm_ae_fit should succeed");
let scores = lstm_ae_score(&fit, &series, n).expect("lstm_ae_score should succeed");
let spike_score = scores[spike_t];
// Mean score over every timestep except the spike itself.
let normal_mean: f64 = scores
.iter()
.enumerate()
.filter(|&(i, _)| i != spike_t)
.map(|(_, &s)| s)
.sum::<f64>()
/ (n - 1) as f64;
assert!(
spike_score.is_finite(),
"spike score not finite: {spike_score}"
);
assert!(
spike_score > normal_mean,
"spike score {spike_score} should be > normal mean {normal_mean}"
);
}
/// Prediction returns exactly one flag per timestep.
#[test]
fn predict_correct_length() {
let cfg = default_cfg();
let n = 15_usize;
let series = constant_series(n, 0.3);
let fit =
lstm_ae_fit(&series, n, cfg.input_dim, &cfg, 4).expect("lstm_ae_fit should succeed");
let preds = lstm_ae_predict(&fit, &series, n, 0.1).expect("lstm_ae_predict should succeed");
assert_eq!(preds.len(), n);
}
/// Fitting fails with `InsufficientSamples` when the series is shorter than one window.
#[test]
fn error_series_shorter_than_window() {
let cfg = LstmAeConfig {
window_size: 10,
..default_cfg()
};
let series = constant_series(5, 0.0); let result = lstm_ae_fit(&series, 5, cfg.input_dim, &cfg, 5);
assert!(
matches!(
result,
Err(AnomalyError::InsufficientSamples { need: 10, got: 5 })
),
"expected InsufficientSamples, got: {result:?}"
);
}
/// Fitting fails with `InvalidFeatureCount` when the feature count argument is zero.
#[test]
fn error_input_dim_zero() {
let cfg = LstmAeConfig {
input_dim: 0,
..default_cfg()
};
let series: Vec<f64> = vec![];
let result = lstm_ae_fit(&series, 0, 0, &cfg, 6);
assert!(
matches!(result, Err(AnomalyError::InvalidFeatureCount { .. })),
"expected InvalidFeatureCount, got: {result:?}"
);
}
/// The smallest legal window (one timestep) fits and scores without error.
#[test]
fn window_size_one_works() {
let cfg = LstmAeConfig {
window_size: 1,
input_dim: 2,
hidden_dim: 4,
lr: 1e-2,
n_epochs: 3,
};
let n = 10_usize;
let series: Vec<f64> = (0..n * 2).map(|i| i as f64 * 0.1).collect();
let fit = lstm_ae_fit(&series, n, 2, &cfg, 7).expect("lstm_ae_fit should succeed");
let scores = lstm_ae_score(&fit, &series, n).expect("lstm_ae_score should succeed");
assert_eq!(scores.len(), n);
for &s in &scores {
assert!(s.is_finite() && s >= 0.0, "score = {s}");
}
}
/// A three-feature series fits and yields finite scores.
#[test]
fn multivariate_series_works() {
let d = 3_usize;
let n = 20_usize;
let cfg = LstmAeConfig {
window_size: 4,
input_dim: d,
hidden_dim: 6,
lr: 1e-2,
n_epochs: 5,
};
let series: Vec<f64> = (0..n * d).map(|i| (i as f64 * 0.05).cos()).collect();
let fit = lstm_ae_fit(&series, n, d, &cfg, 8).expect("lstm_ae_fit should succeed");
let scores = lstm_ae_score(&fit, &series, n).expect("lstm_ae_score should succeed");
assert_eq!(scores.len(), n);
for (i, &s) in scores.iter().enumerate() {
assert!(s.is_finite(), "score[{i}] = {s} not finite");
}
}
/// With a threshold of zero, every strictly positive score is flagged.
#[test]
fn predict_threshold_zero() {
let cfg = default_cfg();
let n = 15_usize;
let series = constant_series(n, 0.5);
let fit =
lstm_ae_fit(&series, n, cfg.input_dim, &cfg, 9).expect("lstm_ae_fit should succeed");
let scores = lstm_ae_score(&fit, &series, n).expect("lstm_ae_score should succeed");
let preds = lstm_ae_predict(&fit, &series, n, 0.0).expect("lstm_ae_predict should succeed");
for (i, (&s, &p)) in scores.iter().zip(preds.iter()).enumerate() {
if s > 0.0 {
assert!(p, "timestep {i} score={s} should be flagged at threshold 0");
}
}
}
/// Scoring fails with `InsufficientSamples` when the series is shorter than the fitted window.
#[test]
fn error_on_score_too_short() {
let cfg = default_cfg();
let n = 20_usize;
let series = constant_series(n, 0.5);
let fit =
lstm_ae_fit(&series, n, cfg.input_dim, &cfg, 10).expect("lstm_ae_fit should succeed");
let short = constant_series(3, 0.5);
let result = lstm_ae_score(&fit, &short, 3);
assert!(
matches!(result, Err(AnomalyError::InsufficientSamples { .. })),
"expected InsufficientSamples, got: {result:?}"
);
}
/// With the same seed, 50 epochs must not leave the total score more than 10% above the
/// total after a single epoch.
#[test]
fn reconstruction_improves_over_training() {
let d = 1_usize;
let n = 25_usize;
let series: Vec<f64> = (0..n).map(|i| (i as f64 * 0.2).sin() * 0.5 + 0.5).collect();
let cfg_few = LstmAeConfig {
window_size: 5,
input_dim: d,
hidden_dim: 8,
lr: 5e-3,
n_epochs: 1,
};
let cfg_many = LstmAeConfig {
n_epochs: 50,
..cfg_few.clone()
};
let fit_few =
lstm_ae_fit(&series, n, d, &cfg_few, 200).expect("lstm_ae_fit should succeed");
let fit_many =
lstm_ae_fit(&series, n, d, &cfg_many, 200).expect("lstm_ae_fit should succeed");
let score_few: f64 = lstm_ae_score(&fit_few, &series, n)
.expect("lstm_ae_score should succeed")
.iter()
.sum();
let score_many: f64 = lstm_ae_score(&fit_many, &series, n)
.expect("lstm_ae_score should succeed")
.iter()
.sum();
assert!(
score_few.is_finite() && score_many.is_finite(),
"scores not finite: few={score_few}, many={score_many}"
);
assert!(
score_many <= score_few * 1.1,
"expected more epochs to not increase score beyond 10%: few={score_few}, many={score_many}"
);
}
/// A window as long as the whole series (exactly one window) fits and scores without error.
#[test]
fn window_size_equals_n() {
let n = 8_usize;
let d = 1_usize;
let cfg = LstmAeConfig {
window_size: n,
input_dim: d,
hidden_dim: 4,
lr: 1e-2,
n_epochs: 3,
};
let series: Vec<f64> = (0..n).map(|i| i as f64 * 0.1).collect();
let fit = lstm_ae_fit(&series, n, d, &cfg, 12).expect("lstm_ae_fit should succeed");
let scores = lstm_ae_score(&fit, &series, n).expect("lstm_ae_score should succeed");
assert_eq!(scores.len(), n);
for &s in &scores {
assert!(s.is_finite() && s >= 0.0, "score = {s}");
}
}
/// Fitting fails with `DimensionMismatch` when the buffer has one element more than `n * d`.
#[test]
fn error_on_series_dim_mismatch() {
let d = 2_usize;
let n = 10_usize;
let cfg = LstmAeConfig {
window_size: 3,
input_dim: d,
hidden_dim: 4,
lr: 1e-2,
n_epochs: 2,
};
let series = vec![0.0_f64; n * d + 1];
let result = lstm_ae_fit(&series, n, d, &cfg, 13);
assert!(
matches!(result, Err(AnomalyError::DimensionMismatch { .. })),
"expected DimensionMismatch, got: {result:?}"
);
}
}