1use crate::error::{SslError, SslResult};
23use crate::handle::LcgRng;
24
/// Layer widths for a [`SimSiam`] head.
///
/// Every dimension must be non-zero; [`SimSiam::new`] rejects a config in which
/// any of them is `0`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SimSiamConfig {
    /// Length of the encoder feature vectors accepted by `SimSiam::project`.
    pub d_encoder: usize,
    /// Hidden width of the projector MLP.
    pub d_projector: usize,
    /// Hidden width of the predictor MLP (`SimSiam::set_identity_predictor`
    /// additionally requires this to equal `2 * d_out`).
    pub d_predictor: usize,
    /// Length of the projected and predicted embeddings.
    pub d_out: usize,
}
39
40impl Default for SimSiamConfig {
41 fn default() -> Self {
42 Self {
43 d_encoder: 64,
44 d_projector: 128,
45 d_predictor: 64,
46 d_out: 32,
47 }
48 }
49}
50
51#[derive(Debug, Clone)]
57pub struct SimSiam {
58 proj_w1: Vec<f32>,
60 proj_b1: Vec<f32>,
62 proj_w2: Vec<f32>,
64 proj_b2: Vec<f32>,
66 pred_w1: Vec<f32>,
68 pred_b1: Vec<f32>,
70 pred_w2: Vec<f32>,
72 pred_b2: Vec<f32>,
74 config: SimSiamConfig,
76}
77
78impl SimSiam {
79 pub fn new(config: SimSiamConfig, rng: &mut LcgRng) -> SslResult<Self> {
84 if config.d_encoder == 0 {
85 return Err(SslError::InvalidParameter {
86 name: "d_encoder".into(),
87 reason: "must be > 0".into(),
88 });
89 }
90 if config.d_projector == 0 {
91 return Err(SslError::InvalidParameter {
92 name: "d_projector".into(),
93 reason: "must be > 0".into(),
94 });
95 }
96 if config.d_predictor == 0 {
97 return Err(SslError::InvalidParameter {
98 name: "d_predictor".into(),
99 reason: "must be > 0".into(),
100 });
101 }
102 if config.d_out == 0 {
103 return Err(SslError::InvalidParameter {
104 name: "d_out".into(),
105 reason: "must be > 0".into(),
106 });
107 }
108
109 let proj_w1 = kaiming_init(config.d_projector, config.d_encoder, rng);
110 let proj_b1 = vec![0.0_f32; config.d_projector];
111 let proj_w2 = kaiming_init(config.d_out, config.d_projector, rng);
112 let proj_b2 = vec![0.0_f32; config.d_out];
113
114 let pred_w1 = kaiming_init(config.d_predictor, config.d_out, rng);
115 let pred_b1 = vec![0.0_f32; config.d_predictor];
116 let pred_w2 = kaiming_init(config.d_out, config.d_predictor, rng);
117 let pred_b2 = vec![0.0_f32; config.d_out];
118
119 Ok(Self {
120 proj_w1,
121 proj_b1,
122 proj_w2,
123 proj_b2,
124 pred_w1,
125 pred_b1,
126 pred_w2,
127 pred_b2,
128 config,
129 })
130 }
131
132 pub fn project(&self, z: &[f32]) -> SslResult<Vec<f32>> {
142 let d = self.config.d_encoder;
143 if z.len() != d {
144 return Err(SslError::DimensionMismatch {
145 expected: d,
146 got: z.len(),
147 });
148 }
149 let hidden = linear_relu(&self.proj_w1, &self.proj_b1, z, d, self.config.d_projector);
150 let out = linear(
151 &self.proj_w2,
152 &self.proj_b2,
153 &hidden,
154 self.config.d_projector,
155 self.config.d_out,
156 );
157 Ok(l2_normalize(out))
158 }
159
160 pub fn predict(&self, p: &[f32]) -> SslResult<Vec<f32>> {
170 let d = self.config.d_out;
171 if p.len() != d {
172 return Err(SslError::DimensionMismatch {
173 expected: d,
174 got: p.len(),
175 });
176 }
177 let hidden = linear_relu(&self.pred_w1, &self.pred_b1, p, d, self.config.d_predictor);
178 let out = linear(
179 &self.pred_w2,
180 &self.pred_b2,
181 &hidden,
182 self.config.d_predictor,
183 self.config.d_out,
184 );
185 Ok(l2_normalize(out))
186 }
187
188 pub fn loss(&self, z1: &[f32], z2: &[f32]) -> SslResult<f32> {
201 let z1_proj = self.project(z1)?;
202 let z2_proj = self.project(z2)?;
203 let p1 = self.predict(&z1_proj)?;
204 let p2 = self.predict(&z2_proj)?;
205
206 let d1 = neg_dot(&p1, &z2_proj);
208 let d2 = neg_dot(&p2, &z1_proj);
209 Ok((d1 + d2) * 0.5)
210 }
211
212 #[inline]
214 #[must_use]
215 pub fn d_out(&self) -> usize {
216 self.config.d_out
217 }
218
219 pub fn set_identity_predictor(&mut self) -> SslResult<()> {
240 let d_out = self.config.d_out;
241 let d_pred = self.config.d_predictor;
242 if d_pred != 2 * d_out {
243 return Err(SslError::InvalidParameter {
244 name: "d_predictor".into(),
245 reason: "identity predictor requires d_predictor == 2 * d_out".into(),
246 });
247 }
248
249 let mut pred_w1 = vec![0.0_f32; d_pred * d_out];
251 for i in 0..d_out {
252 pred_w1[i * d_out + i] = 1.0;
253 pred_w1[(d_out + i) * d_out + i] = -1.0;
254 }
255 let mut pred_w2 = vec![0.0_f32; d_out * d_pred];
257 for i in 0..d_out {
258 pred_w2[i * d_pred + i] = 1.0;
259 pred_w2[i * d_pred + (d_out + i)] = -1.0;
260 }
261
262 self.pred_w1 = pred_w1;
263 self.pred_b1 = vec![0.0_f32; d_pred];
264 self.pred_w2 = pred_w2;
265 self.pred_b2 = vec![0.0_f32; d_out];
266 Ok(())
267 }
268}
269
270fn kaiming_init(out_dim: usize, in_dim: usize, rng: &mut LcgRng) -> Vec<f32> {
276 let scale = (2.0_f32 / in_dim as f32).sqrt();
277 let mut w = vec![0.0_f32; out_dim * in_dim];
278 rng.fill_normal(&mut w);
279 for v in w.iter_mut() {
280 *v *= scale;
281 }
282 w
283}
284
/// Dense layer `out[i] = b[i] + Σ_j w[i * in_dim + j] * x[j]` for `i < out_dim`.
///
/// `w` is row-major `out_dim × in_dim`. Each row is accumulated starting from
/// its bias and then in ascending `j`, which keeps results bit-reproducible.
/// Panics if `w`, `b` or `x` is shorter than the dimensions require.
fn linear(w: &[f32], b: &[f32], x: &[f32], in_dim: usize, out_dim: usize) -> Vec<f32> {
    (0..out_dim)
        .map(|row| {
            let weights = &w[row * in_dim..(row + 1) * in_dim];
            weights
                .iter()
                .zip(&x[..in_dim])
                .fold(b[row], |acc, (&wij, &xj)| acc + wij * xj)
        })
        .collect()
}
300
301fn linear_relu(w: &[f32], b: &[f32], x: &[f32], in_dim: usize, out_dim: usize) -> Vec<f32> {
303 let mut out = linear(w, b, x, in_dim, out_dim);
304 for v in out.iter_mut() {
305 *v = v.max(0.0);
306 }
307 out
308}
309
/// Scales `v` to unit Euclidean length and returns it.
///
/// The norm is clamped to at least `1e-12`, so an all-zero vector comes back
/// unchanged instead of turning into NaNs.
fn l2_normalize(mut v: Vec<f32>) -> Vec<f32> {
    let sum_sq: f32 = v.iter().map(|&x| x * x).sum();
    let denom = sum_sq.sqrt().max(1e-12);
    v.iter_mut().for_each(|x| *x /= denom);
    v
}
320
/// Negated dot product `-(a · b)`; with unit vectors this is the negative
/// cosine similarity.
///
/// Elements are paired with `zip`, so any extra tail of the longer slice is
/// ignored.
fn neg_dot(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    -dot
}
328
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::SslError;
    use crate::handle::LcgRng;

    /// Small head (16 → 32 → 8, predictor hidden 16) seeded with `seed`.
    ///
    /// `d_predictor == 2 * d_out`, so `set_identity_predictor` succeeds on it.
    fn make_simsiam(seed: u64) -> SimSiam {
        let mut rng = LcgRng::new(seed);
        SimSiam::new(
            SimSiamConfig {
                d_encoder: 16,
                d_projector: 32,
                d_predictor: 16,
                d_out: 8,
            },
            &mut rng,
        )
        .expect("value should be present")
    }

    /// Deterministic length-`n` vector filled by `LcgRng::fill_normal`.
    fn random_vec(n: usize, seed: u64) -> Vec<f32> {
        let mut rng = LcgRng::new(seed);
        let mut v = vec![0.0_f32; n];
        rng.fill_normal(&mut v);
        v
    }

    #[test]
    fn project_shape() {
        let ss = make_simsiam(1);
        let z = random_vec(16, 2);
        let out = ss.project(&z).expect("project should succeed");
        assert_eq!(out.len(), 8, "project output must have len == d_out");
    }

    #[test]
    fn predict_shape() {
        let ss = make_simsiam(3);
        let p = random_vec(8, 4);
        let out = ss.predict(&p).expect("predict should succeed");
        assert_eq!(out.len(), 8, "predict output must have len == d_out");
    }

    #[test]
    fn project_rejects_wrong_length() {
        let ss = make_simsiam(30);
        let result = ss.project(&random_vec(15, 31));
        assert!(
            matches!(
                result,
                Err(SslError::DimensionMismatch {
                    expected: 16,
                    got: 15
                })
            ),
            "project must reject len != d_encoder, got {result:?}"
        );
    }

    #[test]
    fn predict_rejects_wrong_length() {
        let ss = make_simsiam(32);
        let result = ss.predict(&random_vec(9, 33));
        assert!(
            matches!(
                result,
                Err(SslError::DimensionMismatch {
                    expected: 8,
                    got: 9
                })
            ),
            "predict must reject len != d_out, got {result:?}"
        );
    }

    #[test]
    fn loss_propagates_dimension_mismatch() {
        let ss = make_simsiam(34);
        let good = random_vec(16, 35);
        let bad = random_vec(4, 36);
        assert!(ss.loss(&good, &bad).is_err(), "bad second view must error");
        assert!(ss.loss(&bad, &good).is_err(), "bad first view must error");
    }

    #[test]
    fn loss_finite() {
        let ss = make_simsiam(5);
        let z1 = random_vec(16, 6);
        let z2 = random_vec(16, 7);
        let l = ss.loss(&z1, &z2).expect("loss should succeed");
        assert!(l.is_finite(), "loss must be finite, got {l}");
    }

    #[test]
    fn loss_in_range() {
        let ss = make_simsiam(8);
        let z1 = random_vec(16, 9);
        let z2 = random_vec(16, 10);
        let l = ss.loss(&z1, &z2).expect("loss should succeed");
        assert!(
            (-1.0 - 1e-5..=1.0 + 1e-5).contains(&l),
            "loss={l} must be in [-1, 1]"
        );
    }

    #[test]
    fn loss_symmetric() {
        let ss = make_simsiam(11);
        let z1 = random_vec(16, 12);
        let z2 = random_vec(16, 13);
        let l12 = ss.loss(&z1, &z2).expect("loss should succeed");
        let l21 = ss.loss(&z2, &z1).expect("loss should succeed");
        assert!(
            (l12 - l21).abs() < 1e-5,
            "loss(z1,z2)={l12} != loss(z2,z1)={l21}"
        );
    }

    #[test]
    fn different_views_different_projections() {
        let ss = make_simsiam(14);
        let z1 = random_vec(16, 15);
        let z2 = random_vec(16, 16);
        let p1 = ss.project(&z1).expect("project should succeed");
        let p2 = ss.project(&z2).expect("project should succeed");
        let diff: f32 = p1.iter().zip(p2.iter()).map(|(a, b)| (a - b).abs()).sum();
        assert!(
            diff > 1e-6,
            "projections of different inputs must differ, diff={diff}"
        );
    }

    #[test]
    fn identical_views_low_loss() {
        let mut ss = make_simsiam(17);
        // With an identity predictor, p == z for both views. Identical views
        // then give -<z, z> = -1 for unit-norm z.
        ss.set_identity_predictor()
            .expect("config has d_predictor == 2 * d_out");
        let z = random_vec(16, 18);
        let l = ss.loss(&z, &z).expect("loss should succeed");
        assert!(
            (l - (-1.0)).abs() < 1e-5,
            "with a direction-preserving predictor, loss for identical views must be -1, got {l}"
        );
    }

    #[test]
    fn identity_predictor_is_direction_preserving() {
        let mut ss = make_simsiam(27);
        ss.set_identity_predictor()
            .expect("config has d_predictor == 2 * d_out");
        for seed in 0..6_u64 {
            let z = random_vec(16, seed + 200);
            let zp = ss.project(&z).expect("project should succeed");
            let p = ss.predict(&zp).expect("predict should succeed");
            let max_diff = zp
                .iter()
                .zip(p.iter())
                .map(|(a, b)| (a - b).abs())
                .fold(0.0_f32, f32::max);
            assert!(
                max_diff < 1e-5,
                "identity predictor must reproduce input, max|p-zp|={max_diff} (seed={seed})"
            );
        }
    }

    #[test]
    fn set_identity_predictor_requires_double_hidden() {
        let mut rng = LcgRng::new(28);
        let mut ss = SimSiam::new(
            SimSiamConfig {
                d_encoder: 16,
                d_projector: 32,
                d_predictor: 8,
                d_out: 8,
            },
            &mut rng,
        )
        .expect("value should be present");
        assert!(
            ss.set_identity_predictor().is_err(),
            "identity predictor with d_predictor != 2*d_out must return Err"
        );
    }

    #[test]
    fn failed_identity_predictor_leaves_predictor_unchanged() {
        let mut rng = LcgRng::new(37);
        let mut ss = SimSiam::new(
            SimSiamConfig {
                d_encoder: 16,
                d_projector: 32,
                d_predictor: 8,
                d_out: 8,
            },
            &mut rng,
        )
        .expect("value should be present");
        let p = random_vec(8, 38);
        let before = ss.predict(&p).expect("predict should succeed");
        assert!(ss.set_identity_predictor().is_err());
        let after = ss.predict(&p).expect("predict should succeed");
        assert_eq!(before, after, "a rejected call must not modify the predictor");
    }

    #[test]
    fn d_out_0_error() {
        let mut rng = LcgRng::new(19);
        let result = SimSiam::new(
            SimSiamConfig {
                d_encoder: 8,
                d_projector: 16,
                d_predictor: 8,
                d_out: 0,
            },
            &mut rng,
        );
        assert!(result.is_err(), "d_out=0 must return Err");
    }

    #[test]
    fn d_projector_0_error() {
        let mut rng = LcgRng::new(29);
        let result = SimSiam::new(
            SimSiamConfig {
                d_encoder: 8,
                d_projector: 0,
                d_predictor: 8,
                d_out: 8,
            },
            &mut rng,
        );
        assert!(
            matches!(result, Err(SslError::InvalidParameter { .. })),
            "d_projector=0 must return InvalidParameter"
        );
    }

    #[test]
    fn d_predictor_0_error() {
        let mut rng = LcgRng::new(39);
        let result = SimSiam::new(
            SimSiamConfig {
                d_encoder: 8,
                d_projector: 16,
                d_predictor: 0,
                d_out: 8,
            },
            &mut rng,
        );
        assert!(
            matches!(result, Err(SslError::InvalidParameter { .. })),
            "d_predictor=0 must return InvalidParameter"
        );
    }

    #[test]
    fn project_output_normalized() {
        let ss = make_simsiam(20);
        let z = random_vec(16, 21);
        let out = ss.project(&z).expect("project should succeed");
        let norm: f32 = out.iter().map(|&x| x * x).sum::<f32>().sqrt();
        assert!(
            (norm - 1.0).abs() < 1e-5,
            "project output must be unit-norm, norm={norm}"
        );
    }

    #[test]
    fn loss_stop_grad_invariant() {
        // This module computes only the forward loss, so there is no gradient
        // to stop. The test checks that the loss stays finite across seeds.
        let ss = make_simsiam(22);
        for seed in 0..8_u64 {
            let z1 = random_vec(16, seed * 2 + 100);
            let z2 = random_vec(16, seed * 2 + 101);
            let l = ss.loss(&z1, &z2).expect("loss should succeed");
            assert!(
                l.is_finite(),
                "loss must be finite for seed={seed}, got {l}"
            );
        }
    }

    #[test]
    fn d_encoder_0_error() {
        let mut rng = LcgRng::new(23);
        assert!(
            SimSiam::new(
                SimSiamConfig {
                    d_encoder: 0,
                    d_projector: 16,
                    d_predictor: 8,
                    d_out: 8
                },
                &mut rng
            )
            .is_err()
        );
    }

    #[test]
    fn predict_output_normalized() {
        let ss = make_simsiam(24);
        let p = random_vec(8, 25);
        let out = ss.predict(&p).expect("predict should succeed");
        let norm: f32 = out.iter().map(|&x| x * x).sum::<f32>().sqrt();
        assert!(
            (norm - 1.0).abs() < 1e-5,
            "predict output must be unit-norm, norm={norm}"
        );
    }

    #[test]
    fn d_out_accessor() {
        let ss = make_simsiam(26);
        assert_eq!(ss.d_out(), 8);
    }
}
561}