use burn::prelude::*;
use burn::module::{Param, ParamId};
use crate::config::SubjectLayersConfig;

/// `SubjectLayers`: per-subject linear layers plus an average-subjects
/// inference mode.
///
/// weights: [num_weight_subjects, C, D], bias: [num_weight_subjects, D].
/// The last row of `weights` (index `n_subjects`) holds the average/dropout
/// subject used by `forward_average`.
#[derive(Module, Debug)]
pub struct SubjectLayers<B: Backend> {
    pub weights: Param<Tensor<B, 3>>,
    pub bias: Option<Param<Tensor<B, 2>>>,
    /// Pre-transposed average-subject weight: [1, D, C].
    /// Stored ready for `w_avg_t.matmul(x)` so every forward skips the
    /// runtime slice → squeeze → transpose → unsqueeze chain on the full
    /// [N, C, D] weight tensor.
    pub w_avg_t: Param<Tensor<B, 3>>,
    pub n_subjects: usize,
    pub has_dropout: bool,
}

impl<B: Backend> SubjectLayers<B> {
    pub fn new(
        in_ch: usize,
        out_ch: usize,
        config: &SubjectLayersConfig,
        device: &B::Device,
    ) -> Self {
        let n = config.num_weight_subjects();
        let weights_data = Tensor::zeros([n, in_ch, out_ch], device);
        // Pre-compute and store the average-subject weight transposed to
        // [1, D, C]. The avg subject is always the last row (index n_subjects).
        let idx = config.n_subjects;
        let w_avg_t = weights_data
            .clone()
            .slice([idx..idx + 1]) // [1, C, D]
            .squeeze::<2>(0) // [C, D]
            .transpose() // [D, C]
            .unsqueeze::<3>(); // [1, D, C]
        Self {
            weights: Param::initialized(ParamId::new(), weights_data),
            bias: if config.bias {
                Some(Param::initialized(
                    ParamId::new(),
                    Tensor::zeros([n, out_ch], device),
                ))
            } else {
                None
            },
            w_avg_t: Param::initialized(ParamId::new(), w_avg_t),
            n_subjects: config.n_subjects,
            has_dropout: config.subject_dropout.is_some(),
        }
    }
    /// Average-subjects mode: use the pre-transposed average-subject weight.
    /// x: [B, C, T] → [B, D, T]
    ///
    /// `w_avg_t` is stored as [1, D, C] and re-synced by `rebuild_w_avg_t`
    /// whenever the weights change.
    /// matmul: [1, D, C] @ [B, C, T] = [B, D, T]
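    ///
    /// Shape sketch (`model`, `x`, and the sizes are assumed for
    /// illustration):
    ///
    /// ```ignore
    /// // x: [batch = 2, C = 8, T = 50]
    /// let y = model.forward_average(x); // y: [2, D, 50]
    /// ```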
    pub fn forward_average(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        // [1, D, C] @ [B, C, T] = [B, D, T]; the leading 1 broadcasts over B.
        let out = self.w_avg_t.val().matmul(x);
        if let Some(ref bias) = self.bias {
            let idx = self.n_subjects;
            let bv: Tensor<B, 1> = bias.val().slice([idx..idx + 1]).squeeze(0);
            out + bv.unsqueeze_dim::<2>(0).unsqueeze_dim::<3>(2) // [1, D, 1]
        } else {
            out
        }
    }
    /// Call this after loading pretrained weights to re-sync `w_avg_t`
    /// with `weights[n_subjects]`. During benchmarking with random weights
    /// this is not needed, but it must be called in the real inference path
    /// after `model = model.load_record(record)`.
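    ///
    /// Sketch of the intended call order (`record` is assumed to have been
    /// loaded elsewhere; `load_record` is the standard Burn API):
    ///
    /// ```ignore
    /// let model = model.load_record(record).rebuild_w_avg_t();
    /// ```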
    pub fn rebuild_w_avg_t(mut self) -> Self {
        let idx = self.n_subjects;
        let w = self.weights.val();
        let w_t = w
            .slice([idx..idx + 1]) // [1, C, D]
            .squeeze::<2>(0) // [C, D]
            .transpose() // [D, C]
            .unsqueeze::<3>(); // [1, D, C]
        self.w_avg_t = Param::initialized(ParamId::new(), w_t);
        self
    }
    /// Per-subject gather mode: each batch element uses its own subject's
    /// weight row. x: [B, C, T], one subject id per element → [B, D, T].
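    ///
    /// Shape sketch (`model`, `x`, and the ids are assumed for illustration):
    ///
    /// ```ignore
    /// // two samples, from subjects 0 and 3
    /// let y = model.forward_subjects(x, &[0, 3]); // [2, D, T]
    /// ```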
    pub fn forward_subjects(&self, x: Tensor<B, 3>, subject_ids: &[usize]) -> Tensor<B, 3> {
        let [b, _c, _t] = x.dims();
        let w = self.weights.val();
        let device = x.device();
        // One index per batch element; `subject_ids.len()` must equal B.
        let idx_t = Tensor::<B, 1, Int>::from_data(
            TensorData::new(
                subject_ids.iter().map(|&i| i as i64).collect::<Vec<_>>(),
                [b],
            ),
            &device,
        );
        // w_sel: [B, C, D] → swap the last two dims → [B, D, C]
        let w_sel = w.select(0, idx_t.clone());
        let w_sel_t = w_sel.swap_dims(1, 2); // [B, D, C]
        // [B, D, C] @ [B, C, T] = [B, D, T]
        let out = w_sel_t.matmul(x);
        if let Some(ref bias) = self.bias {
            let b_sel = bias.val().select(0, idx_t); // [B, D]
            out + b_sel.unsqueeze_dim::<3>(2) // [B, D, 1]
        } else {
            out
        }
    }
}
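
// Shape smoke test, as a sketch: it assumes `SubjectLayersConfig` can be
// built as a plain struct literal with the fields used above (`n_subjects`,
// `bias`, `subject_dropout`) and that `num_weight_subjects()` returns
// `n_subjects + 1`; adjust to the real constructor in `crate::config`.
// `NdArray` (behind the `ndarray` feature) stands in for any CPU backend.
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    #[test]
    fn forward_shapes() {
        type B = NdArray;
        let device = Default::default();
        let config = SubjectLayersConfig {
            n_subjects: 4,
            bias: true,
            subject_dropout: Some(0.1),
        };
        // C = 8 input channels, D = 16 output features.
        let model = SubjectLayers::<B>::new(8, 16, &config, &device);
        let x = Tensor::<B, 3>::zeros([2, 8, 50], &device);
        // Average-subjects mode: [2, 8, 50] → [2, 16, 50].
        assert_eq!(model.forward_average(x.clone()).dims(), [2, 16, 50]);
        // Per-subject gather: one subject id per batch element.
        assert_eq!(model.forward_subjects(x, &[0, 3]).dims(), [2, 16, 50]);
    }
}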