Skip to main content

oxicuda_ssl/ssl/
jem.rs

1//! JEM — Joint Energy Model (Grathwohl et al. 2019).
2//!
3//! Interprets a classifier `f(x)` as an energy model:
4//! ```text
5//!   p(x) ∝ exp(logsumexp(f(x)))
6//!   p(y|x) = softmax(f(x))[y]
7//!   E(x) = -logsumexp(f(x))
8//! ```
9//!
10//! Training uses contrastive divergence: maximise `log p(x_data) - log p(x_mcmc)`
11//! where `x_mcmc` samples are generated via SGLD (Stochastic Gradient Langevin
12//! Dynamics).
13//!
14//! ## SGLD update rule
15//! ```text
16//!   x_{t+1} = x_t - (step_size/2) * ∇_x E(x_t) + noise * N(0,1)
17//! ```
18//! The energy gradient is approximated with central finite differences:
19//! ```text
20//!   ∂E/∂x_i ≈ (E(x + ε·eᵢ) − E(x − ε·eᵢ)) / (2ε)
21//! ```
22//!
23//! A persistent replay buffer stores MCMC chains between training steps so that
24//! the distribution slowly converges to `p_θ(x)`.
25//!
26//! Reference: "Your Classifier is Secretly an Energy Based Model and You Should
27//! Treat it Like One", Grathwohl et al., ICLR 2020.
28
29use crate::error::{SslError, SslResult};
30use crate::handle::LcgRng;
31
32// ─── Configuration ────────────────────────────────────────────────────────────
33
/// Tunable settings for a [`Jem`] model: network shape plus SGLD sampler
/// parameters.
#[derive(Debug, Clone)]
pub struct JemConfig {
    /// Length of each input (feature / latent) vector.
    pub d_input: usize,
    /// Number of classes, i.e. the length of the logit vector.
    pub n_classes: usize,
    /// Width of the single hidden layer of the MLP.
    pub n_hidden: usize,
    /// How many Langevin updates one [`Jem::sgld_step`] call performs.
    pub sgld_steps: usize,
    /// SGLD step size `α`; the gradient term of each update is scaled by `α / 2`.
    pub sgld_step_size: f32,
    /// Scale `σ` multiplied onto the normal noise added in each SGLD update.
    pub sgld_noise: f32,
    /// Number of persistent MCMC chains kept in the replay buffer
    /// ([`Jem::new`] allocates at least one chain even if this is `0`).
    pub buffer_size: usize,
}

impl Default for JemConfig {
    /// Small toy configuration: a `4 → 16 → 2` network, 20 SGLD steps with
    /// step size `0.01` and noise scale `0.005`, and 64 replay chains.
    fn default() -> Self {
        // Network shape.
        let (d_input, n_hidden, n_classes) = (4, 16, 2);
        // Sampler settings.
        let (sgld_steps, sgld_step_size, sgld_noise) = (20, 0.01, 0.005);

        JemConfig {
            d_input,
            n_classes,
            n_hidden,
            sgld_steps,
            sgld_step_size,
            sgld_noise,
            buffer_size: 64,
        }
    }
}
66
67// ─── JEM model ────────────────────────────────────────────────────────────────
68
69/// Joint Energy Model that combines a discriminative classifier with an
70/// energy-based generative model via a shared 2-layer MLP.
71///
72/// Architecture: `d_input → (Linear + ReLU) → n_hidden → Linear → n_classes`
73#[derive(Debug, Clone)]
74pub struct Jem {
75    /// First layer weights `[n_hidden × d_input]`.
76    w1: Vec<f32>,
77    /// First layer bias `[n_hidden]`.
78    b1: Vec<f32>,
79    /// Second layer weights `[n_classes × n_hidden]`.
80    w2: Vec<f32>,
81    /// Second layer bias `[n_classes]`.
82    b2: Vec<f32>,
83    /// Persistent MCMC replay buffer; each entry is a `[d_input]` vector.
84    replay_buffer: Vec<Vec<f32>>,
85    /// Model configuration.
86    config: JemConfig,
87}
88
89impl Jem {
90    /// Create a new [`Jem`] model with Kaiming-initialised weights and a
91    /// small-noise replay buffer.
92    ///
93    /// # Errors
94    /// [`SslError::InvalidParameter`] if `d_input`, `n_classes`, or `n_hidden` is zero.
95    pub fn new(config: JemConfig, rng: &mut LcgRng) -> SslResult<Self> {
96        if config.d_input == 0 {
97            return Err(SslError::InvalidParameter {
98                name: "d_input".into(),
99                reason: "must be > 0".into(),
100            });
101        }
102        if config.n_classes == 0 {
103            return Err(SslError::InvalidParameter {
104                name: "n_classes".into(),
105                reason: "must be > 0".into(),
106            });
107        }
108        if config.n_hidden == 0 {
109            return Err(SslError::InvalidParameter {
110                name: "n_hidden".into(),
111                reason: "must be > 0".into(),
112            });
113        }
114
115        let w1 = kaiming_init(config.n_hidden, config.d_input, rng);
116        let b1 = vec![0.0_f32; config.n_hidden];
117        let w2 = kaiming_init(config.n_classes, config.n_hidden, rng);
118        let b2 = vec![0.0_f32; config.n_classes];
119
120        // Initialise replay buffer with small normal noise in [-0.01, 0.01] range.
121        let buf_size = config.buffer_size.max(1);
122        let mut replay_buffer = Vec::with_capacity(buf_size);
123        for _ in 0..buf_size {
124            let mut entry = vec![0.0_f32; config.d_input];
125            rng.fill_normal(&mut entry);
126            for v in entry.iter_mut() {
127                *v *= 0.01;
128            }
129            replay_buffer.push(entry);
130        }
131
132        Ok(Self {
133            w1,
134            b1,
135            w2,
136            b2,
137            replay_buffer,
138            config,
139        })
140    }
141
142    /// Compute the logit vector `f(x)` of shape `[n_classes]`.
143    ///
144    /// Performs `x → W1·x + b1 → ReLU → W2·h + b2 → logits`.
145    ///
146    /// # Errors
147    /// [`SslError::DimensionMismatch`] when `x.len() != d_input`.
148    pub fn logits(&self, x: &[f32]) -> SslResult<Vec<f32>> {
149        let d = self.config.d_input;
150        if x.len() != d {
151            return Err(SslError::DimensionMismatch {
152                expected: d,
153                got: x.len(),
154            });
155        }
156        let h = linear_relu(&self.w1, &self.b1, x, d, self.config.n_hidden);
157        Ok(linear(
158            &self.w2,
159            &self.b2,
160            &h,
161            self.config.n_hidden,
162            self.config.n_classes,
163        ))
164    }
165
166    /// Compute the scalar energy `E(x) = -logsumexp(f(x))`.
167    ///
168    /// Uses the numerically stable max-shift trick for `logsumexp`.
169    ///
170    /// # Errors
171    /// Propagates errors from [`Self::logits`].
172    pub fn energy(&self, x: &[f32]) -> SslResult<f32> {
173        let logits = self.logits(x)?;
174        Ok(-logsumexp(&logits))
175    }
176
177    /// Cross-entropy classification loss for a single sample.
178    ///
179    /// `CE(x, y) = -log softmax(f(x))[y] = logsumexp(f(x)) - f(x)[y]`
180    ///
181    /// # Errors
182    /// - Propagates errors from [`Self::logits`].
183    /// - [`SslError::InvalidParameter`] when `y >= n_classes`.
184    pub fn classify_loss(&self, x: &[f32], y: usize) -> SslResult<f32> {
185        if y >= self.config.n_classes {
186            return Err(SslError::InvalidParameter {
187                name: "y".into(),
188                reason: "class index must be < n_classes".into(),
189            });
190        }
191        let logits = self.logits(x)?;
192        let lse = logsumexp(&logits);
193        Ok(lse - logits[y])
194    }
195
196    /// Approximate `∇_x E(x)` via central finite differences.
197    ///
198    /// `grad[i] = (E(x + ε·eᵢ) − E(x − ε·eᵢ)) / (2ε)`
199    ///
200    /// # Arguments
201    /// * `x`   — input `[d_input]`.
202    /// * `eps` — finite-difference step size (e.g. `1e-3`).
203    ///
204    /// # Errors
205    /// Propagates errors from [`Self::energy`].
206    pub fn energy_grad(&self, x: &[f32], eps: f32) -> SslResult<Vec<f32>> {
207        let d = self.config.d_input;
208        if x.len() != d {
209            return Err(SslError::DimensionMismatch {
210                expected: d,
211                got: x.len(),
212            });
213        }
214        let two_eps = 2.0 * eps;
215        let mut grad = vec![0.0_f32; d];
216        let mut x_pos = x.to_vec();
217        let mut x_neg = x.to_vec();
218        for i in 0..d {
219            x_pos[i] = x[i] + eps;
220            x_neg[i] = x[i] - eps;
221            let e_pos = self.energy(&x_pos)?;
222            let e_neg = self.energy(&x_neg)?;
223            grad[i] = (e_pos - e_neg) / two_eps;
224            x_pos[i] = x[i];
225            x_neg[i] = x[i];
226        }
227        Ok(grad)
228    }
229
230    /// Run SGLD for `config.sgld_steps` steps starting from `x_init`.
231    ///
232    /// ```text
233    /// x_{t+1} = x_t − (step_size / 2) · ∇_x E(x_t) + sgld_noise · N(0, 1)
234    /// ```
235    ///
236    /// Finite-difference gradient approximation is used with `eps = 1e-3`.
237    ///
238    /// # Errors
239    /// Propagates errors from [`Self::energy_grad`].
240    pub fn sgld_step(&self, x_init: &[f32], rng: &mut LcgRng) -> SslResult<Vec<f32>> {
241        let d = self.config.d_input;
242        if x_init.len() != d {
243            return Err(SslError::DimensionMismatch {
244                expected: d,
245                got: x_init.len(),
246            });
247        }
248        let half_step = self.config.sgld_step_size * 0.5;
249        let noise_scale = self.config.sgld_noise;
250        let fd_eps = 1e-3_f32;
251
252        let mut x = x_init.to_vec();
253        for _ in 0..self.config.sgld_steps {
254            let grad = self.energy_grad(&x, fd_eps)?;
255            let mut noise = vec![0.0_f32; d];
256            rng.fill_normal(&mut noise);
257            for i in 0..d {
258                x[i] -= half_step * grad[i];
259                x[i] += noise_scale * noise[i];
260            }
261        }
262        Ok(x)
263    }
264
265    /// Compute the contrastive divergence loss `E(x_mcmc) − E(x_data)`.
266    ///
267    /// 1. Randomly selects an entry from the replay buffer as the MCMC seed.
268    /// 2. Runs [`Self::sgld_step`] to advance the chain.
269    /// 3. Replaces the selected buffer entry with the updated MCMC sample.
270    /// 4. Returns `E(x_mcmc_updated) − E(x_data)`.
271    ///
272    /// # Errors
273    /// Propagates errors from [`Self::sgld_step`] and [`Self::energy`].
274    pub fn cd_loss(&mut self, x_data: &[f32], rng: &mut LcgRng) -> SslResult<f32> {
275        let buf_len = self.replay_buffer.len();
276        // Pick a random replay buffer index.
277        let idx = rng.next_usize(buf_len);
278        let x_mcmc_init = self.replay_buffer[idx].clone();
279        // Run SGLD.
280        let x_mcmc = self.sgld_step(&x_mcmc_init, rng)?;
281        // Update the replay buffer slot.
282        self.replay_buffer[idx] = x_mcmc.clone();
283        // Compute energies.
284        let e_mcmc = self.energy(&x_mcmc)?;
285        let e_data = self.energy(x_data)?;
286        Ok(e_mcmc - e_data)
287    }
288
289    /// Return the input dimension.
290    #[inline]
291    #[must_use]
292    pub fn d_input(&self) -> usize {
293        self.config.d_input
294    }
295
296    /// Return the number of output classes.
297    #[inline]
298    #[must_use]
299    pub fn n_classes(&self) -> usize {
300        self.config.n_classes
301    }
302}
303
304// ─── Internal helpers ────────────────────────────────────────────────────────
305
306/// Kaiming (He) normal weight init: `scale = sqrt(2 / fan_in)`.
307fn kaiming_init(out_dim: usize, in_dim: usize, rng: &mut LcgRng) -> Vec<f32> {
308    let scale = (2.0_f32 / in_dim as f32).sqrt();
309    let mut w = vec![0.0_f32; out_dim * in_dim];
310    rng.fill_normal(&mut w);
311    for v in w.iter_mut() {
312        *v *= scale;
313    }
314    w
315}
316
/// Affine map `out = W·x + b` for a row-major weight matrix.
///
/// Row `r` of `w` occupies `w[r·in_dim .. (r+1)·in_dim]`; each output starts
/// from `b[r]` and accumulates the products left to right:
/// `out[r] = b[r] + Σ_j w[r·in_dim + j] * x[j]`.
///
/// Panics if `w`, `b` or `x` is shorter than the given dimensions require.
fn linear(w: &[f32], b: &[f32], x: &[f32], in_dim: usize, out_dim: usize) -> Vec<f32> {
    (0..out_dim)
        .map(|r| {
            let weights = &w[r * in_dim..(r + 1) * in_dim];
            let input = &x[..in_dim];
            weights
                .iter()
                .zip(input)
                .fold(b[r], |acc, (&w_rj, &x_j)| acc + w_rj * x_j)
        })
        .collect()
}
330
331/// `linear` followed by element-wise ReLU.
332fn linear_relu(w: &[f32], b: &[f32], x: &[f32], in_dim: usize, out_dim: usize) -> Vec<f32> {
333    let mut out = linear(w, b, x, in_dim, out_dim);
334    for v in out.iter_mut() {
335        *v = v.max(0.0);
336    }
337    out
338}
339
/// Numerically stable `logsumexp` using the max-shift trick:
/// `log Σ exp(v_i) = m + log Σ exp(v_i − m)` with `m = max(v)`.
///
/// Special cases:
/// - empty input returns `0.0` (convention kept for callers; model code
///   never passes an empty logit vector because `n_classes > 0`);
/// - any NaN element yields NaN, so a corrupted logit is never mistaken
///   for a valid value;
/// - if the maximum is `±∞` it is returned directly, avoiding `∞ − ∞`.
fn logsumexp(v: &[f32]) -> f32 {
    if v.is_empty() {
        return 0.0;
    }
    // `f32::max` drops NaN operands, so without this check an all-NaN input
    // would report `-inf` and `[+inf, NaN]` would report `+inf`.
    if v.iter().any(|x| x.is_nan()) {
        return f32::NAN;
    }
    let max = v.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    if max.is_infinite() {
        return max;
    }
    let sum_exp: f32 = v.iter().map(|&x| (x - max).exp()).sum();
    max + sum_exp.ln()
}
352
353// ─── Tests ────────────────────────────────────────────────────────────────────
354
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handle::LcgRng;

    /// Input width of `JemConfig::default()`; every fixture vector uses it.
    const DIM: usize = 4;

    /// Default-config model whose weights are drawn from a fresh RNG seeded
    /// with `seed`.
    fn make_jem(seed: u64) -> Jem {
        Jem::new(JemConfig::default(), &mut LcgRng::new(seed)).expect("value should be present")
    }

    /// Vector of `n` values filled by a fresh RNG seeded with `seed`.
    fn random_vec(n: usize, seed: u64) -> Vec<f32> {
        let mut out = vec![0.0_f32; n];
        LcgRng::new(seed).fill_normal(&mut out);
        out
    }

    /// Asserts that constructing a model from `config` fails, reporting `msg`
    /// otherwise.
    fn assert_rejected(seed: u64, config: JemConfig, msg: &str) {
        let outcome = Jem::new(config, &mut LcgRng::new(seed));
        assert!(outcome.is_err(), "{msg}");
    }

    // ── Forward pass ────────────────────────────────────────────────────

    #[test]
    fn logits_shape() {
        let model = make_jem(1);
        let input = random_vec(DIM, 2);
        let out = model.logits(&input).expect("logits should succeed");
        assert_eq!(out.len(), model.n_classes(), "logits len must equal n_classes");
    }

    #[test]
    fn energy_finite() {
        let model = make_jem(3);
        let input = random_vec(DIM, 4);
        let e = model.energy(&input).expect("energy should succeed");
        assert!(e.is_finite(), "energy must be finite, got {e}");
    }

    // ── Classification loss ─────────────────────────────────────────────

    #[test]
    fn classify_loss_finite() {
        let model = make_jem(5);
        let input = random_vec(DIM, 6);
        let ce = model
            .classify_loss(&input, 0)
            .expect("classify_loss should succeed");
        assert!(ce.is_finite(), "classify_loss must be finite, got {ce}");
    }

    #[test]
    fn classify_loss_nonneg() {
        let model = make_jem(7);
        let input = random_vec(DIM, 8);
        let ce = model
            .classify_loss(&input, 1)
            .expect("classify_loss should succeed");
        assert!(ce >= 0.0, "cross-entropy must be >= 0, got {ce}");
    }

    #[test]
    fn classify_loss_invalid_class_error() {
        let model = make_jem(18);
        let input = random_vec(DIM, 19);
        // The default config has two classes, so label 2 is out of range.
        let outcome = model.classify_loss(&input, 2);
        assert!(outcome.is_err(), "y >= n_classes must return Err");
    }

    // ── Sampling and contrastive divergence ─────────────────────────────

    #[test]
    fn cd_loss_finite() {
        // The same RNG builds the model and then drives the sampler.
        let mut rng = LcgRng::new(9);
        let mut model = Jem::new(JemConfig::default(), &mut rng).expect("value should be present");
        let data = random_vec(DIM, 10);
        let cd = model.cd_loss(&data, &mut rng).expect("cd_loss should succeed");
        assert!(cd.is_finite(), "cd_loss must be finite, got {cd}");
    }

    #[test]
    fn sgld_moves_from_init() {
        let mut rng = LcgRng::new(11);
        let model = Jem::new(JemConfig::default(), &mut rng).expect("value should be present");
        let start = random_vec(DIM, 12);
        let end = model
            .sgld_step(&start, &mut rng)
            .expect("sgld_step should succeed");
        // L1 distance between the initial and the final chain state.
        let diff: f32 = start
            .iter()
            .zip(end.iter())
            .map(|(before, after)| (before - after).abs())
            .sum();
        assert!(diff > 1e-8, "SGLD must move from init, diff={diff}");
    }

    #[test]
    fn energy_grad_finite() {
        let model = make_jem(13);
        let input = random_vec(DIM, 14);
        let g = model
            .energy_grad(&input, 1e-3)
            .expect("energy_grad should succeed");
        assert_eq!(g.len(), DIM, "gradient must have len == d_input");
        let has_non_finite = g.iter().any(|v| !v.is_finite());
        assert!(!has_non_finite, "gradient must be all-finite");
    }

    // ── Configuration validation ────────────────────────────────────────

    #[test]
    fn d_input_0_error() {
        let config = JemConfig {
            d_input: 0,
            ..JemConfig::default()
        };
        assert_rejected(15, config, "d_input=0 must return Err");
    }

    #[test]
    fn n_classes_0_error() {
        let config = JemConfig {
            n_classes: 0,
            ..JemConfig::default()
        };
        assert_rejected(16, config, "n_classes=0 must return Err");
    }

    #[test]
    fn n_hidden_0_error() {
        let config = JemConfig {
            n_hidden: 0,
            ..JemConfig::default()
        };
        assert_rejected(17, config, "n_hidden=0 must return Err");
    }

    // ── Accessors ───────────────────────────────────────────────────────

    #[test]
    fn d_input_n_classes_accessors() {
        let model = make_jem(20);
        assert_eq!(model.d_input(), 4);
        assert_eq!(model.n_classes(), 2);
    }
}