1use crate::error::{SslError, SslResult};
30use crate::handle::LcgRng;
31
/// Hyperparameters for a [`Jem`] model: network shape plus sampler settings.
///
/// `Jem::new` rejects a zero `d_input`, `n_classes` or `n_hidden`; the other
/// fields are accepted as given.
#[derive(Debug, Clone)]
pub struct JemConfig {
    /// Length of every input vector accepted by the model.
    pub d_input: usize,
    /// Number of logits produced per input.
    pub n_classes: usize,
    /// Width of the single hidden (ReLU) layer.
    pub n_hidden: usize,
    /// Number of update iterations performed by one `Jem::sgld_step` call.
    pub sgld_steps: usize,
    /// Sampler step size; the gradient term of each update is scaled by half
    /// of this value.
    pub sgld_step_size: f32,
    /// Multiplier applied to the random noise added in each sampler update.
    pub sgld_noise: f32,
    /// Requested number of replay-buffer entries; `Jem::new` uses at least 1
    /// even when this is 0.
    pub buffer_size: usize,
}
52
53impl Default for JemConfig {
54 fn default() -> Self {
55 Self {
56 d_input: 4,
57 n_classes: 2,
58 n_hidden: 16,
59 sgld_steps: 20,
60 sgld_step_size: 0.01,
61 sgld_noise: 0.005,
62 buffer_size: 64,
63 }
64 }
65}
66
/// Two-layer classifier (linear → ReLU → linear) whose logits also define an
/// energy `E(x) = -logsumexp(logits(x))`, together with a replay buffer of
/// samples used by `cd_loss`.
#[derive(Debug, Clone)]
pub struct Jem {
    /// Hidden-layer weights, row-major with `n_hidden` rows of `d_input` values.
    w1: Vec<f32>,
    /// Hidden-layer biases (`n_hidden` values, zero at construction).
    b1: Vec<f32>,
    /// Output-layer weights, row-major with `n_classes` rows of `n_hidden` values.
    w2: Vec<f32>,
    /// Output-layer biases (`n_classes` values, zero at construction).
    b2: Vec<f32>,
    /// Persistent sampler states, each of length `d_input`; never empty after
    /// `Jem::new`.
    replay_buffer: Vec<Vec<f32>>,
    /// The configuration this model was built from.
    config: JemConfig,
}
88
89impl Jem {
90 pub fn new(config: JemConfig, rng: &mut LcgRng) -> SslResult<Self> {
96 if config.d_input == 0 {
97 return Err(SslError::InvalidParameter {
98 name: "d_input".into(),
99 reason: "must be > 0".into(),
100 });
101 }
102 if config.n_classes == 0 {
103 return Err(SslError::InvalidParameter {
104 name: "n_classes".into(),
105 reason: "must be > 0".into(),
106 });
107 }
108 if config.n_hidden == 0 {
109 return Err(SslError::InvalidParameter {
110 name: "n_hidden".into(),
111 reason: "must be > 0".into(),
112 });
113 }
114
115 let w1 = kaiming_init(config.n_hidden, config.d_input, rng);
116 let b1 = vec![0.0_f32; config.n_hidden];
117 let w2 = kaiming_init(config.n_classes, config.n_hidden, rng);
118 let b2 = vec![0.0_f32; config.n_classes];
119
120 let buf_size = config.buffer_size.max(1);
122 let mut replay_buffer = Vec::with_capacity(buf_size);
123 for _ in 0..buf_size {
124 let mut entry = vec![0.0_f32; config.d_input];
125 rng.fill_normal(&mut entry);
126 for v in entry.iter_mut() {
127 *v *= 0.01;
128 }
129 replay_buffer.push(entry);
130 }
131
132 Ok(Self {
133 w1,
134 b1,
135 w2,
136 b2,
137 replay_buffer,
138 config,
139 })
140 }
141
142 pub fn logits(&self, x: &[f32]) -> SslResult<Vec<f32>> {
149 let d = self.config.d_input;
150 if x.len() != d {
151 return Err(SslError::DimensionMismatch {
152 expected: d,
153 got: x.len(),
154 });
155 }
156 let h = linear_relu(&self.w1, &self.b1, x, d, self.config.n_hidden);
157 Ok(linear(
158 &self.w2,
159 &self.b2,
160 &h,
161 self.config.n_hidden,
162 self.config.n_classes,
163 ))
164 }
165
166 pub fn energy(&self, x: &[f32]) -> SslResult<f32> {
173 let logits = self.logits(x)?;
174 Ok(-logsumexp(&logits))
175 }
176
177 pub fn classify_loss(&self, x: &[f32], y: usize) -> SslResult<f32> {
185 if y >= self.config.n_classes {
186 return Err(SslError::InvalidParameter {
187 name: "y".into(),
188 reason: "class index must be < n_classes".into(),
189 });
190 }
191 let logits = self.logits(x)?;
192 let lse = logsumexp(&logits);
193 Ok(lse - logits[y])
194 }
195
196 pub fn energy_grad(&self, x: &[f32], eps: f32) -> SslResult<Vec<f32>> {
207 let d = self.config.d_input;
208 if x.len() != d {
209 return Err(SslError::DimensionMismatch {
210 expected: d,
211 got: x.len(),
212 });
213 }
214 let two_eps = 2.0 * eps;
215 let mut grad = vec![0.0_f32; d];
216 let mut x_pos = x.to_vec();
217 let mut x_neg = x.to_vec();
218 for i in 0..d {
219 x_pos[i] = x[i] + eps;
220 x_neg[i] = x[i] - eps;
221 let e_pos = self.energy(&x_pos)?;
222 let e_neg = self.energy(&x_neg)?;
223 grad[i] = (e_pos - e_neg) / two_eps;
224 x_pos[i] = x[i];
225 x_neg[i] = x[i];
226 }
227 Ok(grad)
228 }
229
230 pub fn sgld_step(&self, x_init: &[f32], rng: &mut LcgRng) -> SslResult<Vec<f32>> {
241 let d = self.config.d_input;
242 if x_init.len() != d {
243 return Err(SslError::DimensionMismatch {
244 expected: d,
245 got: x_init.len(),
246 });
247 }
248 let half_step = self.config.sgld_step_size * 0.5;
249 let noise_scale = self.config.sgld_noise;
250 let fd_eps = 1e-3_f32;
251
252 let mut x = x_init.to_vec();
253 for _ in 0..self.config.sgld_steps {
254 let grad = self.energy_grad(&x, fd_eps)?;
255 let mut noise = vec![0.0_f32; d];
256 rng.fill_normal(&mut noise);
257 for i in 0..d {
258 x[i] -= half_step * grad[i];
259 x[i] += noise_scale * noise[i];
260 }
261 }
262 Ok(x)
263 }
264
265 pub fn cd_loss(&mut self, x_data: &[f32], rng: &mut LcgRng) -> SslResult<f32> {
275 let buf_len = self.replay_buffer.len();
276 let idx = rng.next_usize(buf_len);
278 let x_mcmc_init = self.replay_buffer[idx].clone();
279 let x_mcmc = self.sgld_step(&x_mcmc_init, rng)?;
281 self.replay_buffer[idx] = x_mcmc.clone();
283 let e_mcmc = self.energy(&x_mcmc)?;
285 let e_data = self.energy(x_data)?;
286 Ok(e_mcmc - e_data)
287 }
288
289 #[inline]
291 #[must_use]
292 pub fn d_input(&self) -> usize {
293 self.config.d_input
294 }
295
296 #[inline]
298 #[must_use]
299 pub fn n_classes(&self) -> usize {
300 self.config.n_classes
301 }
302}
303
304fn kaiming_init(out_dim: usize, in_dim: usize, rng: &mut LcgRng) -> Vec<f32> {
308 let scale = (2.0_f32 / in_dim as f32).sqrt();
309 let mut w = vec![0.0_f32; out_dim * in_dim];
310 rng.fill_normal(&mut w);
311 for v in w.iter_mut() {
312 *v *= scale;
313 }
314 w
315}
316
/// Affine map `out[i] = b[i] + Σ_j w[i * in_dim + j] * x[j]` for
/// `i in 0..out_dim`, with `w` stored row-major.
///
/// Panics if `b`, `w` or `x` is too short for the requested dimensions.
fn linear(w: &[f32], b: &[f32], x: &[f32], in_dim: usize, out_dim: usize) -> Vec<f32> {
    (0..out_dim)
        .map(|unit| {
            let start = unit * in_dim;
            let weights = &w[start..start + in_dim];
            let inputs = &x[..in_dim];
            // Start from the bias and add products left to right so the
            // floating-point accumulation order stays fixed.
            weights
                .iter()
                .zip(inputs)
                .fold(b[unit], |acc, (&wj, &xj)| acc + wj * xj)
        })
        .collect()
}
330
331fn linear_relu(w: &[f32], b: &[f32], x: &[f32], in_dim: usize, out_dim: usize) -> Vec<f32> {
333 let mut out = linear(w, b, x, in_dim, out_dim);
334 for v in out.iter_mut() {
335 *v = v.max(0.0);
336 }
337 out
338}
339
/// Numerically stable `ln(Σ exp(v_i))`.
///
/// The maximum is subtracted before exponentiating so large inputs do not
/// overflow. Special cases:
/// - an empty slice yields `-inf` (the log of an empty sum, i.e. `ln 0`);
/// - if every element is `-inf` the result is `-inf`;
/// - if any element is `+inf` the result is `+inf`.
fn logsumexp(v: &[f32]) -> f32 {
    // Folding from -inf means an empty slice reports -inf here as well.
    let max = v.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    if max.is_infinite() {
        // Avoids `inf - inf = NaN` in the shifted exponent below.
        return max;
    }
    let sum_exp: f32 = v.iter().map(|&x| (x - max).exp()).sum();
    max + sum_exp.ln()
}
352
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handle::LcgRng;

    /// Builds a default-configured model from a fresh RNG seeded with `seed`.
    fn default_model(seed: u64) -> Jem {
        Jem::new(JemConfig::default(), &mut LcgRng::new(seed)).expect("value should be present")
    }

    /// Returns `len` values drawn via `fill_normal` from a fresh RNG seeded
    /// with `seed`.
    fn sample_input(len: usize, seed: u64) -> Vec<f32> {
        let mut values = vec![0.0_f32; len];
        LcgRng::new(seed).fill_normal(&mut values);
        values
    }

    /// Reports whether `Jem::new` refuses `config` (RNG seeded with `seed`).
    fn rejects(seed: u64, config: JemConfig) -> bool {
        Jem::new(config, &mut LcgRng::new(seed)).is_err()
    }

    // --- forward pass -----------------------------------------------------

    #[test]
    fn logits_shape() {
        let model = default_model(1);
        let out = model
            .logits(&sample_input(4, 2))
            .expect("logits should succeed");
        assert_eq!(
            out.len(),
            model.n_classes(),
            "logits len must equal n_classes"
        );
    }

    #[test]
    fn energy_finite() {
        let input = sample_input(4, 4);
        let e = default_model(3)
            .energy(&input)
            .expect("energy should succeed");
        assert!(e.is_finite(), "energy must be finite, got {e}");
    }

    // --- classification loss ----------------------------------------------

    #[test]
    fn classify_loss_finite() {
        let input = sample_input(4, 6);
        let ce = default_model(5)
            .classify_loss(&input, 0)
            .expect("classify_loss should succeed");
        assert!(ce.is_finite(), "classify_loss must be finite, got {ce}");
    }

    #[test]
    fn classify_loss_nonneg() {
        let input = sample_input(4, 8);
        let ce = default_model(7)
            .classify_loss(&input, 1)
            .expect("classify_loss should succeed");
        assert!(ce >= 0.0, "cross-entropy must be >= 0, got {ce}");
    }

    // --- sampling ---------------------------------------------------------

    #[test]
    fn cd_loss_finite() {
        // The same RNG builds the model and then drives the sampler.
        let mut rng = LcgRng::new(9);
        let mut model =
            Jem::new(JemConfig::default(), &mut rng).expect("value should be present");
        let data = sample_input(4, 10);
        let cd = model
            .cd_loss(&data, &mut rng)
            .expect("cd_loss should succeed");
        assert!(cd.is_finite(), "cd_loss must be finite, got {cd}");
    }

    #[test]
    fn sgld_moves_from_init() {
        let mut rng = LcgRng::new(11);
        let model = Jem::new(JemConfig::default(), &mut rng).expect("value should be present");
        let start = sample_input(4, 12);
        let end = model
            .sgld_step(&start, &mut rng)
            .expect("sgld_step should succeed");
        // L1 distance between the initial and final sampler states.
        let diff = start
            .iter()
            .zip(&end)
            .map(|(before, after)| (before - after).abs())
            .sum::<f32>();
        assert!(diff > 1e-8, "SGLD must move from init, diff={diff}");
    }

    #[test]
    fn energy_grad_finite() {
        let input = sample_input(4, 14);
        let gradient = default_model(13)
            .energy_grad(&input, 1e-3)
            .expect("energy_grad should succeed");
        assert_eq!(gradient.len(), 4, "gradient must have len == d_input");
        assert!(
            gradient.iter().copied().all(f32::is_finite),
            "gradient must be all-finite"
        );
    }

    // --- invalid configuration / arguments --------------------------------

    #[test]
    fn d_input_0_error() {
        let config = JemConfig {
            d_input: 0,
            ..JemConfig::default()
        };
        assert!(rejects(15, config), "d_input=0 must return Err");
    }

    #[test]
    fn n_classes_0_error() {
        let config = JemConfig {
            n_classes: 0,
            ..JemConfig::default()
        };
        assert!(rejects(16, config), "n_classes=0 must return Err");
    }

    #[test]
    fn n_hidden_0_error() {
        let config = JemConfig {
            n_hidden: 0,
            ..JemConfig::default()
        };
        assert!(rejects(17, config), "n_hidden=0 must return Err");
    }

    #[test]
    fn classify_loss_invalid_class_error() {
        // The default config has two classes, so index 2 is out of range.
        let input = sample_input(4, 19);
        let outcome = default_model(18).classify_loss(&input, 2);
        assert!(outcome.is_err(), "y >= n_classes must return Err");
    }

    // --- accessors --------------------------------------------------------

    #[test]
    fn d_input_n_classes_accessors() {
        let model = default_model(20);
        assert_eq!(model.d_input(), 4);
        assert_eq!(model.n_classes(), 2);
    }
}