use crate::tensor::Tensor;
use crate::config::SubjectLayersConfig;
/// Subject-conditioned linear layer: one `[in_channels, out_channels]`
/// weight matrix (and optional per-output bias) per subject, selected at
/// forward time by per-batch-element subject indices.
#[derive(Debug, Clone)]
pub struct SubjectLayers {
    // Stacked per-subject weights, shape [num_weight_subjects, in, out]
    // (see `new`, which allocates them zero-initialized).
    pub weights: Tensor,
    // Stacked per-subject biases, shape [num_weight_subjects, out];
    // `None` when `config.bias` is false.
    pub bias: Option<Tensor>,
    // Owned copy of the configuration this layer was built from.
    pub config: SubjectLayersConfig,
}
impl SubjectLayers {
    /// Build a zero-initialized subject layer.
    ///
    /// Allocates one `[in_channels, out_channels]` weight matrix per weight
    /// subject (count taken from `config.num_weight_subjects()`), plus one
    /// length-`out_channels` bias per subject when `config.bias` is set.
    pub fn new(in_channels: usize, out_channels: usize, config: &SubjectLayersConfig) -> Self {
        let n = config.num_weight_subjects();
        Self {
            weights: Tensor::zeros(&[n, in_channels, out_channels]),
            bias: if config.bias {
                Some(Tensor::zeros(&[n, out_channels]))
            } else {
                None
            },
            config: config.clone(),
        }
    }

    /// Apply the subject-specific linear map along the channel axis.
    ///
    /// `x` is indexed as `[batch, channels, time]` (this is what the
    /// `einsum_bct_cd_bdt` call and the flat-index arithmetic below assume);
    /// the result is `[batch, out_channels, time]`. `subjects` supplies one
    /// subject index per batch element; `None` (or a missing entry) behaves
    /// like subject 0.
    ///
    /// Fast paths:
    /// - `config.average_subjects`: the averaged weight slot stored one past
    ///   the per-subject weights (index `config.n_subjects`) is applied to
    ///   the whole batch in a single einsum.
    /// - all batch elements share one subject: a single einsum covers the
    ///   batch.
    /// Otherwise each batch element is multiplied with its own weight slice.
    pub fn forward(&self, x: &Tensor, subjects: Option<&[usize]>) -> Tensor {
        let (b, c, t) = (x.shape[0], x.shape[1], x.shape[2]);
        let d = self.weights.shape[2];

        if self.config.average_subjects {
            // Averaged weights live in the extra slot after the per-subject ones.
            return self.forward_single_subject(x, self.config.n_subjects, c, d);
        }

        let subj = subjects.unwrap_or(&[0]);
        // Uniform-subject fast path: one matrix multiply for the whole batch.
        if b == 1 || (subj.len() == b && subj.windows(2).all(|w| w[0] == w[1])) {
            let idx = subj.first().copied().unwrap_or(0);
            return self.forward_single_subject(x, idx, c, d);
        }

        // Mixed subjects: compute each batch element against its own weight
        // slice. Entries past the end of `subj` default to subject 0.
        let mut out_data = vec![0.0f32; b * d * t];
        for bi in 0..b {
            let idx = subj.get(bi).copied().unwrap_or(0);
            let w_offset = idx * c * d;
            for di in 0..d {
                for ti in 0..t {
                    let mut sum = 0.0f32;
                    for ci in 0..c {
                        sum += x.data[bi * c * t + ci * t + ti]
                            * self.weights.data[w_offset + ci * d + di];
                    }
                    out_data[bi * d * t + di * t + ti] = sum;
                }
            }
            if let Some(ref bias) = self.bias {
                let b_off = idx * d;
                for di in 0..d {
                    let bv = bias.data[b_off + di];
                    for ti in 0..t {
                        out_data[bi * d * t + di * t + ti] += bv;
                    }
                }
            }
        }
        Tensor::from_vec(out_data, vec![b, d, t])
    }

    /// Apply the weights (and bias, if present) of the single subject `idx`
    /// to the whole batch. `c`/`d` are the input/output channel counts.
    ///
    /// Shared by the `average_subjects` path and the uniform-subject fast
    /// path, which previously duplicated this slicing/einsum/bias sequence.
    /// Panics (slice out of range) if `idx` exceeds the allocated subjects,
    /// matching the original behavior.
    fn forward_single_subject(&self, x: &Tensor, idx: usize, c: usize, d: usize) -> Tensor {
        let w_offset = idx * c * d;
        let w_slice = Tensor::from_vec(
            self.weights.data[w_offset..w_offset + c * d].to_vec(),
            vec![c, d],
        );
        let out = x.einsum_bct_cd_bdt(&w_slice);
        match self.bias {
            Some(ref bias) => {
                let b_offset = idx * d;
                // Pass the bias slice directly — no need to copy it into a Vec.
                self.add_bias_3d(&out, &bias.data[b_offset..b_offset + d])
            }
            None => out,
        }
    }

    /// Broadcast-add a length-`d` bias vector over every `[d, t]` plane of
    /// the `[b, d, t]` tensor `x`, returning a new tensor.
    fn add_bias_3d(&self, x: &Tensor, bias_data: &[f32]) -> Tensor {
        let (b, d, t) = (x.shape[0], x.shape[1], x.shape[2]);
        let mut data = x.data.clone();
        for bi in 0..b {
            for di in 0..d {
                let bv = bias_data[di];
                for ti in 0..t {
                    data[bi * d * t + di * t + ti] += bv;
                }
            }
        }
        Tensor::from_vec(data, x.shape.clone())
    }
}