pub struct SimSiam { /* private fields */ }
Struct-based SimSiam model that owns its projector and predictor weights.
All weight matrices use Kaiming (He) initialisation with scale = sqrt(2 / fan_in).
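As a rough illustration of the scale above, the sketch below computes sqrt(2 / fan_in) and uses it to fill a row-major weight matrix. `TinyLcg`, `kaiming_scale`, and `init_weights` are hypothetical names, not part of this crate, and the uniform draw in [-scale, scale) is an assumption; the crate's `LcgRng` and its sampling distribution may differ.

```rust
/// Kaiming (He) scale for a layer with `fan_in` inputs: sqrt(2 / fan_in).
fn kaiming_scale(fan_in: usize) -> f32 {
    (2.0 / fan_in as f32).sqrt()
}

/// Hypothetical stand-in for the crate's `LcgRng`; its real API is not shown here.
struct TinyLcg(u64);

impl TinyLcg {
    /// Next pseudo-random value in [-1, 1).
    fn next_signed(&mut self) -> f32 {
        // Multiplier and increment from Knuth's MMIX linear congruential generator.
        self.0 = self
            .0
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        // Keep the top 24 bits so the integer converts to f32 without rounding.
        let bits = (self.0 >> 40) as f32; // in [0, 2^24)
        bits / (1u32 << 23) as f32 - 1.0
    }
}

/// Fill a `rows x cols` row-major matrix; the fan-in of each output row is `cols`.
/// Assumption: values are drawn uniformly and multiplied by the Kaiming scale.
fn init_weights(rows: usize, cols: usize, rng: &mut TinyLcg) -> Vec<f32> {
    let scale = kaiming_scale(cols);
    (0..rows * cols).map(|_| rng.next_signed() * scale).collect()
}
```

With `fan_in = 8` the scale is 0.5, so every entry of `init_weights(4, 8, &mut rng)` lies in [-0.5, 0.5).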
Implementations
impl SimSiam
pub fn new(config: SimSiamConfig, rng: &mut LcgRng) -> SslResult<Self>
Create a new SimSiam model with Kaiming-initialised weights.
§Errors
SslError::InvalidParameter when any dimension in config is zero.
pub fn project(&self, z: &[f32]) -> SslResult<Vec<f32>>
Project a single encoder output vector.
Computes L2_norm(proj_w2 · ReLU(proj_w1 · z + proj_b1) + proj_b2), where z is the encoder output passed in.
§Arguments
z: encoder output, shape [d_encoder].
§Errors
SslError::DimensionMismatch when z.len() != d_encoder.
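The two-layer head described above (affine, ReLU, affine, L2 normalisation) can be sketched as follows. This is a minimal standalone illustration, not the crate's code: `affine`, `l2_normalize`, and `mlp_head` are hypothetical helper names, matrices are assumed to be row-major, and the small epsilon guard against a zero norm is an assumption. The predictor follows the same shape of computation with its own weights.

```rust
/// Row-major affine map: `w` is `b.len() x x.len()`, returns `w · x + b`.
fn affine(w: &[f32], b: &[f32], x: &[f32]) -> Vec<f32> {
    let cols = x.len();
    b.iter()
        .enumerate()
        .map(|(r, bias)| {
            let row = &w[r * cols..(r + 1) * cols];
            row.iter().zip(x).map(|(wi, xi)| wi * xi).sum::<f32>() + bias
        })
        .collect()
}

/// Scale a vector to unit L2 norm.
/// Assumption: a tiny floor on the norm avoids division by zero.
fn l2_normalize(v: &[f32]) -> Vec<f32> {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let denom = norm.max(1e-12);
    v.iter().map(|x| x / denom).collect()
}

/// L2_norm(w2 · ReLU(w1 · x + b1) + b2).
fn mlp_head(w1: &[f32], b1: &[f32], w2: &[f32], b2: &[f32], x: &[f32]) -> Vec<f32> {
    let hidden: Vec<f32> = affine(w1, b1, x)
        .into_iter()
        .map(|h| h.max(0.0)) // ReLU
        .collect();
    l2_normalize(&affine(w2, b2, &hidden))
}
```

With 2x2 identity weights and zero biases, the input `[3.0, 4.0]` maps to approximately `[0.6, 0.8]`, while `[3.0, -4.0]` maps to `[1.0, 0.0]` because the ReLU zeroes the negative component before normalisation.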
pub fn predict(&self, p: &[f32]) -> SslResult<Vec<f32>>
Apply the predictor to a projected representation.
Computes L2_norm(pred_w2 · ReLU(pred_w1 · p + pred_b1) + pred_b2), where p is the projected representation passed in.
§Arguments
p: projected representation, shape [d_out].
§Errors
SslError::DimensionMismatch when p.len() != d_out.
pub fn loss(&self, z1: &[f32], z2: &[f32]) -> SslResult<f32>
Compute the symmetric SimSiam loss for two encoder outputs.
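Implements L = (D(p1, sg(z2_p)) + D(p2, sg(z1_p))) / 2, where z1_p and z2_p are
the projections of z1 and z2, p1 and p2 are the predictor outputs for z1_p and z2_p,
D(a, b) = -(a · b) for unit-norm vectors, and sg denotes stop-gradient
(a no-op in this pure-Rust implementation since there is no autograd engine).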
§Arguments
z1: encoder output from view 1, shape [d_encoder].
z2: encoder output from view 2, shape [d_encoder].
§Errors
Propagates dimension mismatch errors from Self::project and Self::predict.
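The symmetric loss formula can be illustrated with a small standalone sketch. It assumes all four inputs are already unit-norm, so the negative dot product equals the negative cosine similarity; `neg_cosine` and `symmetric_loss` are hypothetical names and do not reflect the crate's internals.

```rust
/// D(a, b) = -(a · b); equals negative cosine similarity for unit-norm inputs.
fn neg_cosine(a: &[f32], b: &[f32]) -> f32 {
    -a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>()
}

/// L = (D(p1, z2_p) + D(p2, z1_p)) / 2.
/// Stop-gradient has no effect here because nothing is differentiated.
fn symmetric_loss(p1: &[f32], z1_p: &[f32], p2: &[f32], z2_p: &[f32]) -> f32 {
    0.5 * (neg_cosine(p1, z2_p) + neg_cosine(p2, z1_p))
}
```

When every vector is the same unit vector the result is approximately -1 (the minimum), mutually orthogonal predictions and targets give 0, and opposite directions give 1.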
pub fn d_out(&self) -> usize
Return the output dimension of the projector (= predictor I/O dim).
pub fn set_identity_predictor(&mut self) -> SslResult<()>
Overwrite the predictor with an exact direction-preserving identity map.
The predictor MLP L2(W2 · ReLU(W1·p + b1) + b2) is normally a learned,
randomly-initialised non-linear transform, so for a random predictor the
SimSiam loss of two identical views is some arbitrary value in [-1, 1].
SimSiam’s negative-cosine loss only attains its minimum of -1 for
identical views when the predictor preserves the projection’s direction.
This installs such an identity predictor. The ReLU non-linearity is bridged
with the standard positive/negative split: the hidden layer computes
[ReLU(p), ReLU(-p)] and the output layer reconstructs p = p⁺ - p⁻,
reproducing the input exactly. The trailing L2-norm then leaves an
already unit-norm projection unchanged, so predict(project(z)) == project(z)
and loss(z, z) == -1.
This requires the predictor hidden dimension to be exactly twice the output dimension so the two halves can hold the positive and negative parts.
§Errors
SslError::InvalidParameter when d_predictor != 2 * d_out.
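The positive/negative split can be sketched in isolation. The code below builds weights for a hidden layer of width 2·d and checks that the map reproduces its input before any normalisation. The row-major layout, the helper names `identity_predictor` and `forward`, and the ordering of the two halves are assumptions made for illustration; the crate's internal layout may differ.

```rust
/// Build (w1, b1, w2, b2) so that w2 · ReLU(w1 · p + b1) + b2 == p.
/// w1 is (2d x d) and w2 is (d x 2d), both row-major.
fn identity_predictor(d: usize) -> (Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>) {
    let h = 2 * d;
    let mut w1 = vec![0.0; h * d];
    let mut w2 = vec![0.0; d * h];
    for i in 0..d {
        w1[i * d + i] = 1.0; // hidden[i] = ReLU(p_i), the positive part
        w1[(d + i) * d + i] = -1.0; // hidden[d + i] = ReLU(-p_i), the negative part
        w2[i * h + i] = 1.0; // out_i = positive part ...
        w2[i * h + d + i] = -1.0; // ... minus negative part
    }
    (w1, vec![0.0; h], w2, vec![0.0; d])
}

/// Run the two-layer map (without the trailing L2 normalisation).
fn forward(d: usize, p: &[f32]) -> Vec<f32> {
    let (w1, b1, w2, b2) = identity_predictor(d);
    let h = 2 * d;
    let hidden: Vec<f32> = (0..h)
        .map(|r| {
            let s: f32 = (0..d).map(|c| w1[r * d + c] * p[c]).sum::<f32>() + b1[r];
            s.max(0.0) // ReLU
        })
        .collect();
    (0..d)
        .map(|r| (0..h).map(|c| w2[r * h + c] * hidden[c]).sum::<f32>() + b2[r])
        .collect()
}
```

For `p = [0.6, -0.8]` the hidden layer holds `[0.6, 0.0, 0.0, 0.8]` and the output layer recovers `[0.6, -0.8]`, which is why the hidden width must be exactly twice the output width.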