use super::StreamingMamba;
/// Builds the readout feature vector for the Mamba variant selected by
/// `model.config.version`.
///
/// This is a pure dispatcher; the variant-specific builders do the work:
///
/// * `V1` → `build_readout_features_v1`
/// * `V3` and `V3Mimo` → `build_readout_features_v3`
/// * `V3Exp` → `build_readout_features_v3exp` (the only builder that receives
///   `raw_input`)
/// * `BlockDiagonal` → `build_readout_features_bd`, forwarding the variant's
///   `block_size`
///
/// `raw_input` is ignored for every variant except `V3Exp`.
pub(super) fn build_readout_features(
    model: &StreamingMamba,
    gated_output: &[f64],
    state: &[f64],
    raw_input: &[f64],
) -> Vec<f64> {
    use crate::ssm::mamba_config::MambaVersion;

    // Match by reference so the version value is only inspected, never moved.
    match &model.config.version {
        MambaVersion::BlockDiagonal { block_size } => {
            build_readout_features_bd(model, gated_output, state, *block_size)
        }
        MambaVersion::V3Exp { .. } => {
            build_readout_features_v3exp(model, gated_output, state, raw_input)
        }
        MambaVersion::V3 | MambaVersion::V3Mimo { .. } => {
            build_readout_features_v3(model, gated_output, state)
        }
        MambaVersion::V1 => build_readout_features_v1(model, gated_output, state),
    }
}
/// Builds the V1 readout features: `gated_output` followed by one state-norm
/// term per input channel.
///
/// `state` is read as `d_in` consecutive runs of `n_state` values. For each
/// run the Euclidean norm (square root of the sum of squares) is appended, so
/// the result has `gated_output.len() + d_in` entries.
///
/// # Panics
///
/// Panics if `state` holds fewer than `d_in * n_state` values, because each
/// run is taken with an unchecked-length slice index.
fn build_readout_features_v1(
    model: &StreamingMamba,
    gated_output: &[f64],
    state: &[f64],
) -> Vec<f64> {
    let channels = model.config.d_in;
    let width = model.config.n_state;

    // Capacity hint only; presumably `gated_output` has `d_in` entries, but
    // nothing here depends on that.
    let mut features = Vec::with_capacity(channels * 2);
    features.extend_from_slice(gated_output);

    features.extend((0..channels).map(|ch| {
        let lo = ch * width;
        let sum_sq: f64 = state[lo..lo + width].iter().map(|&v| v * v).sum::<f64>();
        sum_sq.sqrt()
    }));

    features
}
/// Builds the V3 readout features: `gated_output` followed by one state-norm
/// term per group.
///
/// `state` is split into `n_groups` equal runs of `state.len() / n_groups`
/// values (integer division; any leftover tail is ignored). The Euclidean norm
/// of each run is appended, so the result has `gated_output.len() + n_groups`
/// entries. With `n_groups == 0` only `gated_output` is returned; with an
/// empty `state` every group norm is zero.
fn build_readout_features_v3(
    model: &StreamingMamba,
    gated_output: &[f64],
    state: &[f64],
) -> Vec<f64> {
    let groups = model.config.n_groups;
    let total = state.len();
    // `checked_div` yields `None` only for zero groups; treat that as width 0.
    let span = total.checked_div(groups).unwrap_or(0);

    let mut features = Vec::with_capacity(model.config.d_in + groups);
    features.extend_from_slice(gated_output);

    features.extend((0..groups).map(|g| {
        let lo = g * span;
        let hi = total.min(lo + span);
        // An out-of-range window degrades to an empty slice (norm 0).
        let chunk: &[f64] = state.get(lo..hi).unwrap_or(&[]);
        chunk.iter().map(|&v| v * v).sum::<f64>().sqrt()
    }));

    features
}
/// Builds the V3Exp readout features as up to four concatenated sections.
///
/// 1. `gated_output`, copied verbatim.
/// 2. One Euclidean norm per group, where `state` is split into `n_groups`
///    equal runs of `state.len() / n_groups` values (leftover tail ignored).
/// 3. For each of the `n_groups * n_state` pairs, reading `state` as
///    interleaved `(re, im)` values from index 0: `re`, `im`, the modulus and
///    the squared modulus. Indices past the end of `state` read as `0.0`.
/// 4. Only when both `model.lift_weights` and `model.lift_bias` are present:
///    for each of the first `n_lift` biases, `tanh(row · z + bias)`, where
///    `row` is the matching `lift_input_dim`-wide slice of the weights and
///    `z` is `2x - 1` for the first `d_in` values of `raw_input`, zero-padded
///    up to `lift_input_dim`.
///
/// # Panics
///
/// Panics if the lift weights are too short to supply a full
/// `lift_input_dim`-wide row for every bias that is used. In debug builds it
/// also asserts that section 3 has the expected size and that
/// `lift_input_dim == d_in`.
fn build_readout_features_v3exp(
    model: &StreamingMamba,
    gated_output: &[f64],
    state: &[f64],
    raw_input: &[f64],
) -> Vec<f64> {
    let d_in = model.config.d_in;
    let groups = model.config.n_groups;
    let modes = model.config.n_state;
    let complex_total = groups * modes;
    let lift_rows = model.n_lift;

    let mut features = Vec::with_capacity(d_in + groups + 4 * complex_total + lift_rows);

    // Section 1: gated output.
    features.extend_from_slice(gated_output);

    // Section 2: per-group state norms.
    let total = state.len();
    let span = total.checked_div(groups).unwrap_or(0);
    features.extend((0..groups).map(|g| {
        let lo = g * span;
        let hi = total.min(lo + span);
        let chunk: &[f64] = state.get(lo..hi).unwrap_or(&[]);
        chunk.iter().map(|&v| v * v).sum::<f64>().sqrt()
    }));

    // Section 3: per-pair (re, im, |z|, |z|^2). The flat pair index equals
    // `g * n_state + n` from a nested (group, mode) walk.
    let complex_start = features.len();
    for pair in 0..complex_total {
        let re_idx = pair * 2;
        let re = state.get(re_idx).copied().unwrap_or(0.0);
        let im = state.get(re_idx + 1).copied().unwrap_or(0.0);
        let mod_sq = re * re + im * im;
        features.extend_from_slice(&[re, im, mod_sq.sqrt(), mod_sq]);
    }
    debug_assert_eq!(
        features.len(),
        complex_start + 4 * complex_total,
        "V3Exp tertiary block size mismatch"
    );

    // Section 4: optional lifted projection of the raw input.
    if let Some((weights, biases)) = model
        .lift_weights
        .as_deref()
        .zip(model.lift_bias.as_deref())
    {
        let lift_dim = model.lift_input_dim;
        debug_assert_eq!(
            lift_dim, d_in,
            "V3Exp lift input dim must match d_in for raw-input projection"
        );

        // Map each raw value x to 2x - 1, then zero-pad up to the lift width.
        let mut bipolar: Vec<f64> = Vec::with_capacity(lift_dim);
        bipolar.extend(raw_input.iter().take(d_in).map(|&x| 2.0 * x - 1.0));
        if bipolar.len() < lift_dim {
            bipolar.resize(lift_dim, 0.0);
        }

        features.extend(
            biases
                .iter()
                .take(lift_rows)
                .enumerate()
                .map(|(row_idx, &bias)| {
                    let lo = row_idx * lift_dim;
                    let row = &weights[lo..lo + lift_dim];
                    let dot: f64 = row.iter().zip(bipolar.iter()).map(|(w, z)| w * z).sum();
                    (dot + bias).tanh()
                }),
        );
    }

    features
}
/// Builds the block-diagonal readout features: `gated_output` followed by one
/// state-norm term per block.
///
/// There are `d_in / block_size` blocks (integer division), each covering
/// `n_state * block_size` consecutive values of `state`. The Euclidean norm of
/// each block is appended, so the result has
/// `gated_output.len() + d_in / block_size` entries.
///
/// If `state` is shorter than the blocks require, the missing values are
/// treated as absent: a partially covered block uses only the values that
/// exist, and a block lying entirely past the end contributes a norm of zero.
///
/// # Panics
///
/// Panics if `block_size` is zero (integer division by zero).
fn build_readout_features_bd(
    model: &StreamingMamba,
    gated_output: &[f64],
    state: &[f64],
    block_size: usize,
) -> Vec<f64> {
    let d_in = model.config.d_in;
    let n_blocks = d_in / block_size;
    let block_state_size = model.config.n_state * block_size;
    let state_len = state.len();

    let mut rf = Vec::with_capacity(d_in + n_blocks);
    rf.extend_from_slice(gated_output);

    rf.extend((0..n_blocks).map(|b| {
        // Clamp both ends: clamping only `end` leaves `start > end` for blocks
        // past the end of a short `state`, which makes the slice index panic.
        let start = (b * block_state_size).min(state_len);
        let end = (start + block_state_size).min(state_len);
        state[start..end].iter().map(|&s| s * s).sum::<f64>().sqrt()
    }));

    rf
}