use axonml_autograd::Variable;
use axonml_nn::{AdaptiveAvgPool2d, BatchNorm2d, Conv2d, GRUCell, Linear, Module, Parameter};
use axonml_tensor::Tensor;
use super::LivenessResult;
/// Depthwise-separable residual block: a 3x3 depthwise conv + BN, then a
/// 1x1 pointwise conv + BN, with a skip connection that is projected by a
/// strided 1x1 conv whenever the input/output shapes disagree.
struct BlazeBlock {
    // 3x3 depthwise convolution (groups == input channels).
    dw_conv: Conv2d,
    dw_bn: BatchNorm2d,
    // 1x1 pointwise convolution mixing channels.
    pw_conv: Conv2d,
    pw_bn: BatchNorm2d,
    // Projection for the skip path when channel count or stride changes;
    // `None` means an identity skip.
    project: Option<(Conv2d, BatchNorm2d)>,
}
impl BlazeBlock {
    /// Builds a block mapping `in_ch` -> `out_ch` feature maps at the given
    /// spatial `stride`. A projection path is created only when the residual
    /// shapes would otherwise disagree.
    fn new(in_ch: usize, out_ch: usize, stride: usize) -> Self {
        // Depthwise 3x3: one filter group per input channel.
        let dw_conv =
            Conv2d::with_groups(in_ch, in_ch, (3, 3), (stride, stride), (1, 1), true, in_ch);
        let dw_bn = BatchNorm2d::new(in_ch);
        // Pointwise 1x1 channel mixer.
        let pw_conv = Conv2d::with_options(in_ch, out_ch, (1, 1), (1, 1), (0, 0), true);
        let pw_bn = BatchNorm2d::new(out_ch);
        // The skip path needs a projection iff shape or stride changes.
        let needs_projection = in_ch != out_ch || stride != 1;
        let project = needs_projection.then(|| {
            (
                Conv2d::with_options(in_ch, out_ch, (1, 1), (stride, stride), (0, 0), true),
                BatchNorm2d::new(out_ch),
            )
        });
        Self {
            dw_conv,
            dw_bn,
            pw_conv,
            pw_bn,
            project,
        }
    }
    /// Residual forward pass: ReLU(main(x) + skip(x)).
    fn forward(&self, x: &Variable) -> Variable {
        // Skip path: project when configured, otherwise pass through.
        let skip = match self.project {
            Some((ref conv, ref bn)) => bn.forward(&conv.forward(x)),
            None => x.clone(),
        };
        let main = self.dw_bn.forward(&self.dw_conv.forward(x)).relu();
        let main = self.pw_bn.forward(&self.pw_conv.forward(&main));
        main.add_var(&skip).relu()
    }
    /// All trainable parameters of the block, in a stable order.
    fn parameters(&self) -> Vec<Parameter> {
        let mut params = self.dw_conv.parameters();
        params.extend(self.dw_bn.parameters());
        params.extend(self.pw_conv.parameters());
        params.extend(self.pw_bn.parameters());
        if let Some((ref conv, ref bn)) = self.project {
            params.extend(conv.parameters());
            params.extend(bn.parameters());
        }
        params
    }
}
/// Lightweight face-identity model: a small CNN encoder feeding a GRU that
/// "crystallizes" an identity representation over a sequence of frames.
pub struct MnemosyneIdentity {
    // CNN stem (3 -> 16 channels, stride 2 — see `with_dims`).
    stem_conv: Conv2d,
    stem_bn: BatchNorm2d,
    // Downsampling residual blocks: 16 -> 24 -> 32 -> 48 channels.
    block1: BlazeBlock,
    block2: BlazeBlock,
    block3: BlazeBlock,
    // Global average pool to 1x1 before the projection head.
    pool: AdaptiveAvgPool2d,
    // Projects pooled features (48) to the face encoding.
    face_proj: Linear,
    // Scores encoding quality; output squashed to (0,1) by `compute_quality`.
    quality_gate: Linear,
    // Recurrent accumulator over frames.
    gru: GRUCell,
    // Predicts [convergence velocity, log-variance] from the hidden state.
    convergence_head: Linear,
    hidden_dim: usize,
    encoding_dim: usize,
}
impl Default for MnemosyneIdentity {
fn default() -> Self {
Self::new()
}
}
impl MnemosyneIdentity {
/// Creates the default model: 96-d face encoding, 64-d hidden state.
pub fn new() -> Self {
    Self::with_dims(96, 64)
}
/// Builds the network with a custom face-encoding width and GRU hidden
/// width. Layers are constructed in the same order as before so that any
/// order-dependent weight initialization is reproducible.
pub fn with_dims(encoding_dim: usize, hidden_dim: usize) -> Self {
    Self {
        // Stem: RGB -> 16 channels, stride 2.
        stem_conv: Conv2d::with_options(3, 16, (3, 3), (2, 2), (1, 1), true),
        stem_bn: BatchNorm2d::new(16),
        // Three downsampling residual blocks: 16 -> 24 -> 32 -> 48.
        block1: BlazeBlock::new(16, 24, 2),
        block2: BlazeBlock::new(24, 32, 2),
        block3: BlazeBlock::new(32, 48, 2),
        pool: AdaptiveAvgPool2d::new((1, 1)),
        face_proj: Linear::new(48, encoding_dim),
        // Scalar quality gate (sigmoid applied in `compute_quality`).
        quality_gate: Linear::new(encoding_dim, 1),
        gru: GRUCell::new(encoding_dim, hidden_dim),
        // Two outputs: convergence velocity and log-variance.
        convergence_head: Linear::new(hidden_dim, 2),
        hidden_dim,
        encoding_dim,
    }
}
/// Runs a face batch through the CNN backbone and projects the pooled
/// features to a non-negative encoding of shape `[batch, encoding_dim]`.
pub fn encode_face(&self, face: &Variable) -> Variable {
    let mut feat = self.stem_bn.forward(&self.stem_conv.forward(face)).relu();
    feat = self.block1.forward(&feat);
    feat = self.block2.forward(&feat);
    feat = self.block3.forward(&feat);
    feat = self.pool.forward(&feat);
    // Pooled output is [batch, channels, 1, 1]; flatten to [batch, channels].
    let batch = feat.shape()[0];
    let channels = feat.shape()[1];
    let flat = feat.reshape(&[batch, channels]);
    self.face_proj.forward(&flat).relu()
}
/// Maps an encoding to a per-sample quality score in (0, 1).
pub fn compute_quality(&self, encoding: &Variable) -> Variable {
    let logits = self.quality_gate.forward(encoding);
    logits.sigmoid()
}
/// Performs one recurrent "crystallization" update from a face frame.
///
/// Returns `(new_hidden, velocity, log_variance, quality)`, where
/// `velocity` and `quality` are sigmoid-squashed into [0, 1]. When
/// `hidden` is `None` the recurrence starts from a zero state.
pub fn crystallize_step(
    &self,
    face: &Variable,
    hidden: Option<&Variable>,
) -> (Variable, Variable, Variable, Variable) {
    let encoding = self.encode_face(face);
    let quality = self.compute_quality(&encoding);
    let batch = encoding.shape()[0];
    // Gate the encoding by its quality score before feeding the GRU.
    let gate = quality.expand(&[batch, self.encoding_dim]);
    let gated = encoding.mul_var(&gate);
    let state = match hidden {
        Some(h) => h.clone(),
        None => Variable::new(Tensor::zeros(&[batch, self.hidden_dim]), false),
    };
    let new_hidden = self.gru.forward_step(&gated, &state);
    let head = self.convergence_head.forward(&new_hidden);
    // Column 0: convergence velocity; column 1: log-variance.
    let velocity = head.narrow(1, 0, 1).sigmoid();
    let log_variance = head.narrow(1, 1, 1);
    (new_hidden, velocity, log_variance, quality)
}
/// Extracts an L2-normalized identity embedding from `hidden`, returning
/// the zero vector when the state's norm is (near) zero.
///
/// NOTE(review): only the first `hidden_dim` values are read, which assumes
/// a batch size of 1 — confirm against callers for batched hidden states.
pub fn extract_identity(&self, hidden: &Variable) -> Vec<f32> {
    let values = hidden.data().to_vec();
    let dim = self.hidden_dim;
    let head = &values[..dim];
    let norm = head.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm < 1e-8 {
        vec![0.0; dim]
    } else {
        head.iter().map(|x| x / norm).collect()
    }
}
/// Rescales `hidden` to unit L2 norm (computed over all of its elements),
/// guarding against division by zero with a 1e-8 floor.
pub fn normalize_identity(&self, hidden: &Variable) -> Variable {
    let values = hidden.data().to_vec();
    let sum_sq: f32 = values.iter().map(|x| x * x).sum();
    let norm = sum_sq.sqrt().max(1e-8);
    hidden.mul_scalar(1.0 / norm)
}
/// Cosine similarity between two identity embeddings, with each vector
/// scaled by a shared precision derived from the pair's log-variances.
///
/// NOTE(review): the precision factor multiplies both vectors uniformly, so
/// it cancels in the cosine ratio — the log-variance inputs do not change
/// the score (up to rounding). Confirm this is the intended behavior.
///
/// # Panics
/// Panics if the embeddings have different lengths.
pub fn match_identities(
    embedding_a: &[f32],
    embedding_b: &[f32],
    logvar_a: f32,
    logvar_b: f32,
) -> f32 {
    assert_eq!(embedding_a.len(), embedding_b.len());
    let precision = 1.0 / (logvar_a.exp() + logvar_b.exp() + 1e-8);
    let mut dot = 0.0f32;
    let mut norm_a = 0.0f32;
    let mut norm_b = 0.0f32;
    for (&a, &b) in embedding_a.iter().zip(embedding_b.iter()) {
        let wa = a * precision;
        let wb = b * precision;
        dot += wa * wb;
        norm_a += wa * wa;
        norm_b += wb * wb;
    }
    dot / (norm_a.sqrt() * norm_b.sqrt()).max(1e-8)
}
/// Root-mean-square difference between two hidden-state snapshots.
///
/// # Panics
/// Panics if the slices have different lengths.
pub fn convergence_delta(hidden_prev: &[f32], hidden_curr: &[f32]) -> f32 {
    assert_eq!(hidden_prev.len(), hidden_curr.len());
    let sq_dist: f32 = hidden_prev
        .iter()
        .zip(hidden_curr)
        .map(|(a, b)| {
            let d = a - b;
            d * d
        })
        .sum();
    (sq_dist / hidden_prev.len() as f32).sqrt()
}
/// Heuristic liveness check over a face sequence.
///
/// Rolls the recurrence over the frames and inspects the trajectory of the
/// hidden state: the score rewards visible frame-to-frame variation
/// (temporal variance) and penalizes a perfectly smooth trajectory.
/// Fewer than 3 frames returns `LivenessResult::unknown()`.
pub fn assess_liveness(&self, face_sequence: &[Variable]) -> LivenessResult {
    use super::BiometricModality;
    if face_sequence.len() < 3 {
        return LivenessResult::unknown();
    }
    // Run the GRU over the sequence, recording each hidden state.
    let mut hidden: Option<Variable> = None;
    let mut hidden_states: Vec<Vec<f32>> = Vec::new();
    for frame in face_sequence {
        let (h, _velocity, _logvar, _quality) = self.crystallize_step(frame, hidden.as_ref());
        hidden_states.push(h.data().to_vec());
        hidden = Some(h);
    }
    // First differences of the hidden-state trajectory.
    let mut deltas: Vec<Vec<f32>> = Vec::new();
    for i in 1..hidden_states.len() {
        let delta: Vec<f32> = hidden_states[i]
            .iter()
            .zip(hidden_states[i - 1].iter())
            .map(|(a, b)| a - b)
            .collect();
        deltas.push(delta);
    }
    if deltas.is_empty() {
        return LivenessResult::unknown();
    }
    // L2 magnitude of each step.
    let delta_magnitudes: Vec<f32> = deltas
        .iter()
        .map(|d| {
            let sq_sum: f32 = d.iter().map(|x| x * x).sum();
            sq_sum.sqrt()
        })
        .collect();
    let mean_mag: f32 = delta_magnitudes.iter().sum::<f32>() / delta_magnitudes.len() as f32;
    // Sample variance (Bessel-corrected) of the step magnitudes.
    let temporal_variance: f32 = if delta_magnitudes.len() > 1 {
        delta_magnitudes
            .iter()
            .map(|m| (m - mean_mag) * (m - mean_mag))
            .sum::<f32>()
            / (delta_magnitudes.len() - 1) as f32
    } else {
        0.0
    };
    // Cosine similarity between consecutive steps: high values mean a very
    // smooth (possibly static/replayed) trajectory.
    let mut autocorrelations: Vec<f32> = Vec::new();
    for i in 1..deltas.len() {
        let dot: f32 = deltas[i]
            .iter()
            .zip(deltas[i - 1].iter())
            .map(|(a, b)| a * b)
            .sum();
        let norm_a: f32 = deltas[i].iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b: f32 = deltas[i - 1].iter().map(|x| x * x).sum::<f32>().sqrt();
        let denom = (norm_a * norm_b).max(1e-8);
        autocorrelations.push(dot / denom);
    }
    let trajectory_smoothness: f32 = if autocorrelations.is_empty() {
        0.0
    } else {
        autocorrelations.iter().sum::<f32>() / autocorrelations.len() as f32
    };
    // Sharp sigmoid around a small variance threshold (0.001): any visible
    // temporal variation pushes this signal toward 1.
    let variance_signal = 1.0 / (1.0 + (-50.0 * (temporal_variance - 0.001)).exp());
    let smoothness_signal = 1.0 - (trajectory_smoothness.max(0.0));
    // Weighted blend of the two signals; weights are empirical.
    let liveness_score = (0.6 * variance_signal + 0.4 * smoothness_signal).clamp(0.0, 1.0);
    let is_live = liveness_score > 0.5;
    LivenessResult {
        liveness_score,
        is_live,
        temporal_variance,
        trajectory_smoothness,
        modality_liveness: vec![(BiometricModality::Face, liveness_score)],
    }
}
/// Cosine-distance drift between the identity in `current_hidden` and a
/// previously captured embedding, clamped to [0, 2].
///
/// # Panics
/// Panics if the embedding lengths disagree.
pub fn detect_drift(&self, current_hidden: &Variable, original_embedding: &[f32]) -> f32 {
    let current_embedding = self.extract_identity(current_hidden);
    assert_eq!(
        current_embedding.len(),
        original_embedding.len(),
        "Embedding dimensions must match: current={}, original={}",
        current_embedding.len(),
        original_embedding.len()
    );
    let mut dot = 0.0f32;
    let mut norm_a = 0.0f32;
    let mut norm_b = 0.0f32;
    for (&a, &b) in current_embedding.iter().zip(original_embedding.iter()) {
        dot += a * b;
        norm_a += a * a;
        norm_b += b * b;
    }
    let cosine_sim = dot / (norm_a.sqrt() * norm_b.sqrt()).max(1e-8);
    (1.0 - cosine_sim).clamp(0.0, 2.0)
}
/// Folds a whole face sequence through the recurrence.
///
/// Returns the final hidden state, the per-frame quality scores, and the
/// convergence velocity of the last frame. Scalar reads take element 0,
/// i.e. batch size 1 is assumed when interpreting them.
///
/// # Panics
/// Panics if `faces` is empty.
pub fn crystallize_sequence(&self, faces: &[Variable]) -> (Variable, Vec<f32>, f32) {
    assert!(!faces.is_empty(), "Face sequence must not be empty");
    let mut state: Option<Variable> = None;
    let mut qualities: Vec<f32> = Vec::with_capacity(faces.len());
    let mut last_velocity: f32 = 1.0;
    for frame in faces {
        let (h, velocity, _logvar, quality) = self.crystallize_step(frame, state.as_ref());
        qualities.push(quality.data().to_vec()[0]);
        last_velocity = velocity.data().to_vec()[0];
        state = Some(h);
    }
    (state.unwrap(), qualities, last_velocity)
}
/// Combined quality estimate for a single face in [0, 1]: 70% learned
/// quality gate, 30% a sigmoid of the encoding magnitude (centered at 1.0
/// with a gentle slope).
pub fn assess_quality(&self, face: &Variable) -> f32 {
    let encoding = self.encode_face(face);
    let magnitude: f32 = encoding
        .data()
        .to_vec()
        .iter()
        .map(|x| x * x)
        .sum::<f32>()
        .sqrt();
    let magnitude_factor = 1.0 / (1.0 + (-0.1 * (magnitude - 1.0)).exp());
    let gate_score = self.compute_quality(&encoding).data().to_vec()[0];
    (gate_score * 0.7 + magnitude_factor * 0.3).clamp(0.0, 1.0)
}
/// Collects every trainable parameter of the model, in a stable order.
pub fn parameters(&self) -> Vec<Parameter> {
    let mut params = self.stem_conv.parameters();
    params.extend(self.stem_bn.parameters());
    params.extend(self.block1.parameters());
    params.extend(self.block2.parameters());
    params.extend(self.block3.parameters());
    params.extend(self.pool.parameters());
    params.extend(self.face_proj.parameters());
    params.extend(self.quality_gate.parameters());
    params.extend(self.gru.parameters());
    params.extend(self.convergence_head.parameters());
    params
}
/// Width of the GRU hidden state (and of the identity embedding).
pub fn hidden_dim(&self) -> usize {
    self.hidden_dim
}
/// Width of the face encoding produced by `encode_face`.
pub fn encoding_dim(&self) -> usize {
    self.encoding_dim
}
}
impl Module for MnemosyneIdentity {
fn forward(&self, input: &Variable) -> Variable {
self.encode_face(input)
}
fn parameters(&self) -> Vec<Parameter> {
self.parameters()
}
}
#[cfg(test)]
mod tests {
use super::*;
/// Constant-valued face tensor of shape [batch, 3, 64, 64], no grad.
fn make_face(batch: usize, fill: f32) -> Variable {
    let n = batch * 3 * 64 * 64;
    let tensor = Tensor::from_vec(vec![fill; n], &[batch, 3, 64, 64]).unwrap();
    Variable::new(tensor, false)
}
/// Constant-valued face tensor with gradient tracking enabled.
fn make_face_grad(batch: usize, fill: f32) -> Variable {
    let n = batch * 3 * 64 * 64;
    let tensor = Tensor::from_vec(vec![fill; n], &[batch, 3, 64, 64]).unwrap();
    Variable::new(tensor, true)
}
/// Deterministic pseudo-random face in [-1, 1], derived from `seed` via a
/// Knuth-style multiplicative hash of each element index.
fn make_varied_face(batch: usize, seed: u32) -> Variable {
    let n = batch * 3 * 64 * 64;
    let data: Vec<f32> = (0..n)
        .map(|i| {
            let hashed = (i as u32).wrapping_mul(2654435761).wrapping_add(seed);
            let unit = hashed as f32 / u32::MAX as f32;
            unit * 2.0 - 1.0
        })
        .collect();
    Variable::new(Tensor::from_vec(data, &[batch, 3, 64, 64]).unwrap(), false)
}
// Default constructor should report the documented dimensions.
#[test]
fn test_mnemosyne_creation() {
    let model = MnemosyneIdentity::new();
    assert_eq!(model.hidden_dim(), 64);
    assert_eq!(model.encoding_dim(), 96);
}
// Total parameter count must stay inside the mobile-friendly budget.
#[test]
fn test_mnemosyne_param_count() {
    let model = MnemosyneIdentity::new();
    let total: usize = model
        .parameters()
        .iter()
        .map(|p| p.variable().data().to_vec().len())
        .sum();
    assert!(total < 150_000, "Params {} exceeds 150K budget", total);
    assert!(total > 10_000, "Params {} seems too low", total);
}
// Forward pass yields [1, encoding_dim] with a non-degenerate activation.
#[test]
fn test_mnemosyne_forward_shape() {
    let model = MnemosyneIdentity::new();
    let input = make_face(1, 0.5);
    let output = model.forward(&input);
    assert_eq!(output.shape(), &[1, 96]);
    let data = output.data().to_vec();
    let nonzero = data.iter().filter(|&&v| v.abs() > 1e-6).count();
    assert!(nonzero > 0, "All outputs are zero — dead network");
}
// One crystallization step: checks output shapes and [0,1] ranges.
#[test]
fn test_mnemosyne_crystallize_step() {
    let model = MnemosyneIdentity::new();
    let face = make_face(1, 0.5);
    let (hidden, velocity, logvar, quality) = model.crystallize_step(&face, None);
    assert_eq!(hidden.shape(), &[1, 64]);
    assert_eq!(velocity.shape(), &[1, 1]);
    assert_eq!(logvar.shape(), &[1, 1]);
    assert_eq!(quality.shape(), &[1, 1]);
    let vel_val = velocity.data().to_vec()[0];
    assert!(
        (0.0..=1.0).contains(&vel_val),
        "Velocity {} not in [0,1]",
        vel_val
    );
    let qual_val = quality.data().to_vec()[0];
    assert!(
        (0.0..=1.0).contains(&qual_val),
        "Quality {} not in [0,1]",
        qual_val
    );
}
// Repeated steps on the same face should produce finite convergence deltas.
#[test]
fn test_mnemosyne_multi_step_crystallization() {
    let model = MnemosyneIdentity::new();
    let face = make_face(1, 0.3);
    let mut hidden = None;
    let mut prev_hidden_data: Option<Vec<f32>> = None;
    let mut convergence_deltas = Vec::new();
    for _ in 0..5 {
        let (h, _velocity, _logvar, _quality) = model.crystallize_step(&face, hidden.as_ref());
        let h_data = h.data().to_vec();
        if let Some(ref prev) = prev_hidden_data {
            let delta = MnemosyneIdentity::convergence_delta(prev, &h_data);
            convergence_deltas.push(delta);
        }
        prev_hidden_data = Some(h_data);
        hidden = Some(h);
    }
    assert_eq!(convergence_deltas.len(), 4);
    for (i, d) in convergence_deltas.iter().enumerate() {
        assert!(d.is_finite(), "Delta {} at step {} is not finite", d, i);
    }
}
// Matching: identical embeddings score ~1, opposites < 0, and high
// log-variance must not break self-matching.
#[test]
fn test_mnemosyne_identity_matching() {
    let a = vec![0.5, 0.3, 0.8, 0.1];
    let b = vec![0.5, 0.3, 0.8, 0.1];
    let score = MnemosyneIdentity::match_identities(&a, &b, -1.0, -1.0);
    assert!(score > 0.99, "Self-match score {} too low", score);
    let c = vec![-0.5, -0.3, -0.8, -0.1];
    let score2 = MnemosyneIdentity::match_identities(&a, &c, -1.0, -1.0);
    assert!(
        score2 < 0.0,
        "Opposite embedding score {} should be negative",
        score2
    );
    let score_uncertain = MnemosyneIdentity::match_identities(&a, &b, 2.0, 2.0);
    assert!(
        score_uncertain > 0.9,
        "Same embedding should still match: {}",
        score_uncertain
    );
}
// normalize_identity should yield a unit-norm vector.
#[test]
fn test_mnemosyne_normalize_identity() {
    let model = MnemosyneIdentity::new();
    let hidden = Variable::new(
        Tensor::from_vec(vec![3.0, 4.0, 0.0, 0.0], &[1, 4]).unwrap(),
        false,
    );
    let normalized = model.normalize_identity(&hidden);
    let data = normalized.data().to_vec();
    let norm: f32 = data.iter().map(|x| x * x).sum::<f32>().sqrt();
    assert!((norm - 1.0).abs() < 0.01, "Not unit norm: {}", norm);
}
// Quality gate output must be a valid probability.
#[test]
fn test_mnemosyne_quality_gate_range() {
    let model = MnemosyneIdentity::new();
    let face = make_face(1, 0.8);
    let encoding = model.encode_face(&face);
    let quality = model.compute_quality(&encoding);
    let q = quality.data().to_vec()[0];
    assert!((0.0..=1.0).contains(&q), "Quality {} not in [0,1]", q);
}
// Forward pass on grad-tracked input should give a finite scalar loss.
#[test]
fn test_mnemosyne_forward_backward() {
    let model = MnemosyneIdentity::new();
    let face = make_face_grad(1, 0.5);
    let output = model.forward(&face);
    assert_eq!(output.shape(), &[1, 96]);
    let loss = output.sum();
    let loss_val = loss.data().to_vec()[0];
    assert!(loss_val.is_finite(), "Loss should be finite: {}", loss_val);
}
// Varied (pseudo-random) frames: the liveness result must be well-formed.
#[test]
fn test_liveness_real_face_varied_input() {
    let model = MnemosyneIdentity::new();
    let sequence: Vec<Variable> = (0..8).map(|i| make_varied_face(1, i * 12345 + 7)).collect();
    let result = model.assess_liveness(&sequence);
    assert!(
        result.liveness_score >= 0.0 && result.liveness_score <= 1.0,
        "Liveness score {} out of range",
        result.liveness_score
    );
    assert!(
        result.temporal_variance.is_finite(),
        "Temporal variance should be finite"
    );
    assert!(
        result.trajectory_smoothness.is_finite(),
        "Trajectory smoothness should be finite"
    );
    assert!(
        !result.modality_liveness.is_empty(),
        "Should have modality liveness entries"
    );
}
// Identical frames (photo/replay-style spoof): result stays finite and in range.
#[test]
fn test_liveness_spoofed_constant_input() {
    let model = MnemosyneIdentity::new();
    let constant_face = make_face(1, 0.5);
    let sequence: Vec<Variable> = (0..8).map(|_| constant_face.clone()).collect();
    let result = model.assess_liveness(&sequence);
    assert!(
        result.temporal_variance.is_finite(),
        "Temporal variance should be finite"
    );
    assert!(
        result.liveness_score >= 0.0 && result.liveness_score <= 1.0,
        "Liveness score {} out of range",
        result.liveness_score
    );
}
// Constant input should have (near-)zero temporal variance.
#[test]
fn test_liveness_varied_vs_constant_variance() {
    let model = MnemosyneIdentity::new();
    let varied_seq: Vec<Variable> = (0..8)
        .map(|i| make_varied_face(1, (i as u32).wrapping_mul(999_999_937)))
        .collect();
    let varied_result = model.assess_liveness(&varied_seq);
    let const_face = make_face(1, 0.5);
    let const_seq: Vec<Variable> = (0..8).map(|_| const_face.clone()).collect();
    let const_result = model.assess_liveness(&const_seq);
    assert!(varied_result.temporal_variance.is_finite());
    assert!(const_result.temporal_variance.is_finite());
    assert!(varied_result.temporal_variance >= 0.0);
    assert!(const_result.temporal_variance >= 0.0);
    assert!(
        const_result.temporal_variance < 0.5,
        "Constant input should have low temporal variance ({}), got {}",
        0.5,
        const_result.temporal_variance
    );
}
// Fewer than 3 frames must return the "unknown" result (score 0.5, not live).
#[test]
fn test_liveness_too_few_frames() {
    let model = MnemosyneIdentity::new();
    let seq: Vec<Variable> = (0..2).map(|i| make_varied_face(1, i)).collect();
    let result = model.assess_liveness(&seq);
    assert_eq!(
        result.liveness_score, 0.5,
        "Too few frames should return unknown"
    );
    assert!(!result.is_live, "Too few frames should not be judged live");
}
// Exactly 3 frames (the minimum) yields a well-formed assessment.
#[test]
fn test_liveness_minimum_frames() {
    let model = MnemosyneIdentity::new();
    let seq: Vec<Variable> = (0..3).map(|i| make_varied_face(1, i * 5555)).collect();
    let result = model.assess_liveness(&seq);
    assert!(result.temporal_variance.is_finite());
    assert!(result.trajectory_smoothness.is_finite());
    assert!(result.liveness_score >= 0.0 && result.liveness_score <= 1.0);
}
// Smoothness is a mean cosine similarity, so it must lie in [-1, 1].
#[test]
fn test_liveness_smoothness_range() {
    let model = MnemosyneIdentity::new();
    let seq: Vec<Variable> = (0..6).map(|i| make_varied_face(1, i * 77777)).collect();
    let result = model.assess_liveness(&seq);
    assert!(
        result.trajectory_smoothness >= -1.0 && result.trajectory_smoothness <= 1.0,
        "Trajectory smoothness {} out of [-1, 1]",
        result.trajectory_smoothness
    );
}
// Drift of a hidden state against its own extracted embedding is ~0.
#[test]
fn test_drift_same_face_low_drift() {
    let model = MnemosyneIdentity::new();
    let face = make_face(1, 0.5);
    let (hidden, _v, _l, _q) = model.crystallize_step(&face, None);
    let embedding = model.extract_identity(&hidden);
    let drift = model.detect_drift(&hidden, &embedding);
    assert!(
        drift < 0.01,
        "Same-face drift should be near zero, got {}",
        drift
    );
}
// Drift against a different, multi-step identity: finite and non-negative.
#[test]
fn test_drift_different_face_high_drift() {
    let model = MnemosyneIdentity::new();
    let face_a = make_face(1, 0.1);
    let face_b = make_varied_face(1, 42);
    let (hidden_a, _v, _l, _q) = model.crystallize_step(&face_a, None);
    let embedding_a = model.extract_identity(&hidden_a);
    let mut hidden_b = None;
    for _ in 0..5 {
        let (h, _v, _l, _q) = model.crystallize_step(&face_b, hidden_b.as_ref());
        hidden_b = Some(h);
    }
    let drift = model.detect_drift(hidden_b.as_ref().unwrap(), &embedding_a);
    assert!(drift.is_finite(), "Drift should be finite, got {}", drift);
    assert!(drift >= 0.0, "Drift should be non-negative, got {}", drift);
}
// Drift is a clamped cosine distance: always in [0, 2].
#[test]
fn test_drift_range() {
    let model = MnemosyneIdentity::new();
    let face = make_varied_face(1, 123);
    let (hidden, _v, _l, _q) = model.crystallize_step(&face, None);
    let orthogonal: Vec<f32> = (0..model.hidden_dim())
        .map(|i| if i == 0 { 1.0 } else { 0.0 })
        .collect();
    let drift = model.detect_drift(&hidden, &orthogonal);
    assert!(
        (0.0..=2.0).contains(&drift),
        "Drift {} should be in [0, 2]",
        drift
    );
}
// Full-sequence crystallization: shape, per-frame quality range, velocity range.
#[test]
fn test_crystallize_sequence_basic() {
    let model = MnemosyneIdentity::new();
    let faces: Vec<Variable> = (0..5).map(|i| make_varied_face(1, i * 11111)).collect();
    let (final_hidden, qualities, final_velocity) = model.crystallize_sequence(&faces);
    assert_eq!(final_hidden.shape(), &[1, 64]);
    assert_eq!(qualities.len(), 5);
    for (i, q) in qualities.iter().enumerate() {
        assert!(
            *q >= 0.0 && *q <= 1.0,
            "Quality {} at frame {} out of [0,1]",
            q,
            i
        );
    }
    assert!(
        (0.0..=1.0).contains(&final_velocity),
        "Final velocity {} out of [0,1]",
        final_velocity
    );
}
// A single frame is a valid (degenerate) sequence.
#[test]
fn test_crystallize_sequence_single_frame() {
    let model = MnemosyneIdentity::new();
    let faces = vec![make_face(1, 0.5)];
    let (hidden, qualities, velocity) = model.crystallize_sequence(&faces);
    assert_eq!(hidden.shape(), &[1, 64]);
    assert_eq!(qualities.len(), 1);
    assert!((0.0..=1.0).contains(&velocity));
}
// Repeating one frame: all step-to-step deltas stay finite and >= 0.
#[test]
fn test_crystallize_sequence_convergence_over_time() {
    let model = MnemosyneIdentity::new();
    let face = make_face(1, 0.4);
    let faces: Vec<Variable> = (0..10).map(|_| face.clone()).collect();
    let mut prev_hidden_data: Option<Vec<f32>> = None;
    let mut hidden: Option<Variable> = None;
    let mut deltas = Vec::new();
    for frame in &faces {
        let (h, _v, _l, _q) = model.crystallize_step(frame, hidden.as_ref());
        let h_data = h.data().to_vec();
        if let Some(ref prev) = prev_hidden_data {
            deltas.push(MnemosyneIdentity::convergence_delta(prev, &h_data));
        }
        prev_hidden_data = Some(h_data);
        hidden = Some(h);
    }
    for d in &deltas {
        assert!(d.is_finite() && *d >= 0.0);
    }
}
// Empty sequences are a caller bug and must panic with the documented message.
#[test]
#[should_panic(expected = "Face sequence must not be empty")]
fn test_crystallize_sequence_empty_panics() {
    let model = MnemosyneIdentity::new();
    let _result = model.crystallize_sequence(&[]);
}
// assess_quality on a plain face stays in [0, 1].
#[test]
fn test_assess_quality_valid_input() {
    let model = MnemosyneIdentity::new();
    let face = make_face(1, 0.5);
    let quality = model.assess_quality(&face);
    assert!(
        (0.0..=1.0).contains(&quality),
        "Quality {} out of [0,1]",
        quality
    );
}
// assess_quality must be finite for pseudo-random input.
#[test]
fn test_assess_quality_nonzero() {
    let model = MnemosyneIdentity::new();
    let face = make_varied_face(1, 42);
    let quality = model.assess_quality(&face);
    assert!(quality.is_finite(), "Quality should be finite");
}
// All-zero input is still scored inside [0, 1].
#[test]
fn test_assess_quality_zero_input() {
    let model = MnemosyneIdentity::new();
    let face = make_face(1, 0.0);
    let quality = model.assess_quality(&face);
    assert!(
        (0.0..=1.0).contains(&quality),
        "Quality {} out of range for zero input",
        quality
    );
}
// Sweep of constant fills: quality range holds across brightness levels.
#[test]
fn test_assess_quality_range_across_inputs() {
    let model = MnemosyneIdentity::new();
    for fill in [0.0, 0.1, 0.5, 0.9, 1.0] {
        let face = make_face(1, fill);
        let q = model.assess_quality(&face);
        assert!(
            (0.0..=1.0).contains(&q),
            "Quality {} out of [0,1] for fill={}",
            q,
            fill
        );
    }
}
// Batched forward keeps the batch dimension.
#[test]
fn test_forward_batch() {
    let model = MnemosyneIdentity::new();
    let batch_face = make_face(4, 0.5);
    let output = model.forward(&batch_face);
    assert_eq!(output.shape(), &[4, 96]);
}
// Batched crystallize_step: every output keeps the batch dimension.
#[test]
fn test_crystallize_step_batch() {
    let model = MnemosyneIdentity::new();
    let batch_face = make_face(3, 0.5);
    let (hidden, velocity, logvar, quality) = model.crystallize_step(&batch_face, None);
    assert_eq!(hidden.shape(), &[3, 64]);
    assert_eq!(velocity.shape(), &[3, 1]);
    assert_eq!(logvar.shape(), &[3, 1]);
    assert_eq!(quality.shape(), &[3, 1]);
}
// One frame is below the 3-frame minimum: unknown result.
#[test]
fn test_single_frame_liveness() {
    let model = MnemosyneIdentity::new();
    let seq = vec![make_face(1, 0.5)];
    let result = model.assess_liveness(&seq);
    assert_eq!(result.liveness_score, 0.5);
    assert!(!result.is_live);
}
// Numerical stability: zero input must not produce NaN/inf.
#[test]
fn test_zero_input_forward() {
    let model = MnemosyneIdentity::new();
    let face = make_face(1, 0.0);
    let output = model.forward(&face);
    assert_eq!(output.shape(), &[1, 96]);
    let data = output.data().to_vec();
    for v in &data {
        assert!(v.is_finite(), "Output should be finite for zero input");
    }
}
// Zero input through the full recurrent step stays finite.
#[test]
fn test_zero_input_crystallize() {
    let model = MnemosyneIdentity::new();
    let face = make_face(1, 0.0);
    let (hidden, velocity, logvar, quality) = model.crystallize_step(&face, None);
    for v in hidden.data().to_vec() {
        assert!(v.is_finite());
    }
    assert!(velocity.data().to_vec()[0].is_finite());
    assert!(logvar.data().to_vec()[0].is_finite());
    assert!(quality.data().to_vec()[0].is_finite());
}
// Large-magnitude input must not overflow to inf/NaN.
#[test]
fn test_large_input_stability() {
    let model = MnemosyneIdentity::new();
    let face = make_face(1, 100.0);
    let output = model.forward(&face);
    let data = output.data().to_vec();
    for v in &data {
        assert!(
            v.is_finite(),
            "Output should be finite for large input, got {}",
            v
        );
    }
}
// Tiny input must not underflow into NaN.
#[test]
fn test_small_input_stability() {
    let model = MnemosyneIdentity::new();
    let face = make_face(1, 1e-7);
    let output = model.forward(&face);
    let data = output.data().to_vec();
    for v in &data {
        assert!(
            v.is_finite(),
            "Output should be finite for small input, got {}",
            v
        );
    }
}
// Negative pixel values are valid input and must stay finite.
#[test]
fn test_negative_input_stability() {
    let model = MnemosyneIdentity::new();
    let face = make_face(1, -1.0);
    let output = model.forward(&face);
    let data = output.data().to_vec();
    for v in &data {
        assert!(
            v.is_finite(),
            "Output should be finite for negative input, got {}",
            v
        );
    }
}
// extract_identity returns a unit-norm vector of hidden_dim length.
#[test]
fn test_extract_identity_l2_norm() {
    let model = MnemosyneIdentity::new();
    let face = make_face(1, 0.5);
    let (hidden, _v, _l, _q) = model.crystallize_step(&face, None);
    let embedding = model.extract_identity(&hidden);
    assert_eq!(embedding.len(), 64);
    let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
    assert!(
        (norm - 1.0).abs() < 0.01,
        "Identity embedding should be L2-normalized, got norm={}",
        norm
    );
}
// Every embedding component must be finite.
#[test]
fn test_extract_identity_finite_values() {
    let model = MnemosyneIdentity::new();
    let face = make_varied_face(1, 999);
    let (hidden, _v, _l, _q) = model.crystallize_step(&face, None);
    let embedding = model.extract_identity(&hidden);
    for (i, v) in embedding.iter().enumerate() {
        assert!(v.is_finite(), "Embedding dim {} is not finite: {}", i, v);
    }
}
// Zero hidden state hits the degenerate-norm branch: all-zero embedding.
#[test]
fn test_extract_identity_zero_hidden() {
    let model = MnemosyneIdentity::new();
    let hidden = Variable::new(Tensor::zeros(&[1, model.hidden_dim()]), false);
    let embedding = model.extract_identity(&hidden);
    assert_eq!(embedding.len(), model.hidden_dim());
    for v in &embedding {
        assert_eq!(*v, 0.0);
    }
}
// Normalization rescales but must not rotate: the 3/4 component ratio survives.
#[test]
fn test_normalize_identity_preserves_direction() {
    let model = MnemosyneIdentity::new();
    let hidden = Variable::new(
        Tensor::from_vec(vec![3.0, 4.0, 0.0, 0.0], &[1, 4]).unwrap(),
        false,
    );
    let normalized = model.normalize_identity(&hidden);
    let data = normalized.data().to_vec();
    if data[1].abs() > 1e-8 {
        let ratio = data[0] / data[1];
        assert!(
            (ratio - 0.75).abs() < 0.01,
            "Direction not preserved: ratio={}",
            ratio
        );
    }
}
// Non-default dims propagate through forward, recurrence, and embedding.
#[test]
fn test_custom_dims() {
    let model = MnemosyneIdentity::with_dims(128, 32);
    assert_eq!(model.encoding_dim(), 128);
    assert_eq!(model.hidden_dim(), 32);
    let face = make_face(1, 0.5);
    let output = model.forward(&face);
    assert_eq!(output.shape(), &[1, 128]);
    let (hidden, _v, _l, _q) = model.crystallize_step(&face, None);
    assert_eq!(hidden.shape(), &[1, 32]);
    let embedding = model.extract_identity(&hidden);
    assert_eq!(embedding.len(), 32);
}
// Liveness and sequence crystallization work on the same input.
#[test]
fn test_liveness_and_crystallize_sequence_together() {
    let model = MnemosyneIdentity::new();
    let faces: Vec<Variable> = (0..6).map(|i| make_varied_face(1, i * 31337)).collect();
    let liveness = model.assess_liveness(&faces);
    let (final_hidden, qualities, final_vel) = model.crystallize_sequence(&faces);
    assert!(liveness.liveness_score >= 0.0 && liveness.liveness_score <= 1.0);
    assert_eq!(final_hidden.shape(), &[1, 64]);
    assert_eq!(qualities.len(), 6);
    assert!((0.0..=1.0).contains(&final_vel));
}
// Self-drift of a crystallized identity is ~0.
#[test]
fn test_drift_after_crystallize_sequence() {
    let model = MnemosyneIdentity::new();
    let faces: Vec<Variable> = (0..5).map(|_| make_face(1, 0.5)).collect();
    let (hidden, _qualities, _vel) = model.crystallize_sequence(&faces);
    let embedding = model.extract_identity(&hidden);
    let drift = model.detect_drift(&hidden, &embedding);
    assert!(
        drift < 0.01,
        "Self-drift should be near zero, got {}",
        drift
    );
}
// Identical states: RMS delta is zero.
#[test]
fn test_convergence_delta_identical() {
    let a = vec![1.0, 2.0, 3.0];
    let delta = MnemosyneIdentity::convergence_delta(&a, &a);
    assert!(delta < 1e-6, "Identical states should have zero delta");
}
// Hand-computed RMS: sqrt((9 + 16 + 0) / 3).
#[test]
fn test_convergence_delta_known() {
    let a = vec![0.0, 0.0, 0.0];
    let b = vec![3.0, 4.0, 0.0];
    let delta = MnemosyneIdentity::convergence_delta(&a, &b);
    let expected = (25.0f32 / 3.0).sqrt();
    assert!(
        (delta - expected).abs() < 0.001,
        "Expected {}, got {}",
        expected,
        delta
    );
}
// End-to-end smoke training: triplet-style loss over 20 Adam steps.
// Verifies the graph is differentiable (some parameters receive non-zero
// gradients) and the loss does not blow up.
#[test]
fn test_mnemosyne_training_e2e() {
    use axonml_optim::{Adam, Optimizer};
    let model = MnemosyneIdentity::new();
    let params = model.parameters();
    println!("Mnemosyne params: {}", params.len());
    assert!(!params.is_empty(), "Model must have parameters");
    let mut optimizer = Adam::new(params, 0.001);
    let base_face = Tensor::randn(&[1, 3, 64, 64]);
    let mut losses = Vec::new();
    for step in 0..20 {
        // Anchor/positive are noisy views of the same base face; the
        // negative is an unrelated random face.
        let anchor_face = Variable::new(
            base_face
                .add(&Tensor::randn(&[1, 3, 64, 64]).mul_scalar(0.1))
                .unwrap(),
            false,
        );
        let positive_face = Variable::new(
            base_face
                .add(&Tensor::randn(&[1, 3, 64, 64]).mul_scalar(0.1))
                .unwrap(),
            false,
        );
        let negative_face = Variable::new(Tensor::randn(&[1, 3, 64, 64]), false);
        let (hidden_a, _vel_a, _, _) = model.crystallize_step(&anchor_face, None);
        let (hidden_p, _vel_p, _, _) = model.crystallize_step(&positive_face, None);
        let (hidden_n, _vel_n, _, _) = model.crystallize_step(&negative_face, None);
        let emb_a = l2_normalize_var(&hidden_a);
        let emb_p = l2_normalize_var(&hidden_p);
        let emb_n = l2_normalize_var(&hidden_n);
        // Cosine distances d = 1 - cos, then a triplet hinge with margin.
        let dot_pos = emb_a.mul_var(&emb_p).sum();
        let dot_neg = emb_a.mul_var(&emb_n).sum();
        let dist_pos = dot_pos.mul_scalar(-1.0).add_scalar(1.0);
        let dist_neg = dot_neg.mul_scalar(-1.0).add_scalar(1.0);
        let margin = 0.3;
        let loss = dist_pos.sub_var(&dist_neg).add_scalar(margin).relu();
        let loss_val = loss.data().to_vec()[0];
        losses.push(loss_val);
        if step == 0 {
            println!("Step 0: loss = {}", loss_val);
            assert!(
                loss_val.is_finite(),
                "Initial loss must be finite, got {}",
                loss_val
            );
        }
        loss.backward();
        if step == 0 {
            // Gradient audit on the first step.
            // FIX: this loop previously read `for p in ¶ms_after` —
            // mojibake (HTML-entity corruption) of `&params_after` that
            // did not compile.
            let params_after = model.parameters();
            let mut has_grad = 0;
            let mut zero_grad = 0;
            let mut no_grad = 0;
            for p in &params_after {
                let name = p.name().to_string();
                if let Some(g) = p.variable().grad() {
                    let grad_norm: f32 = g.to_vec().iter().map(|x| x * x).sum::<f32>().sqrt();
                    if grad_norm > 1e-10 {
                        has_grad += 1;
                        println!(
                            " HAS GRAD: {} shape={:?} grad_norm={:.6}",
                            name,
                            p.variable().shape(),
                            grad_norm
                        );
                    } else {
                        zero_grad += 1;
                        println!(" ZERO GRAD: {} shape={:?}", name, p.variable().shape());
                    }
                } else {
                    no_grad += 1;
                    println!(" NO GRAD: {} shape={:?}", name, p.variable().shape());
                }
            }
            println!(
                "Params with nonzero grad: {}, zero grad: {}, no grad: {}",
                has_grad, zero_grad, no_grad
            );
            assert!(
                has_grad > 0,
                "At least some parameters must have non-zero gradients"
            );
        }
        optimizer.step();
        optimizer.zero_grad();
    }
    let first_5_avg: f32 = losses[..5].iter().sum::<f32>() / 5.0;
    let last_5_avg: f32 = losses[15..].iter().sum::<f32>() / 5.0;
    println!(
        "First 5 avg loss: {:.4}, Last 5 avg loss: {:.4}",
        first_5_avg, last_5_avg
    );
    println!("All losses: {:?}", losses);
    for (i, l) in losses.iter().enumerate() {
        assert!(l.is_finite(), "Loss became non-finite at step {}: {}", i, l);
    }
    assert!(
        last_5_avg <= first_5_avg + 0.5,
        "Loss should not increase significantly: first_5={:.4} last_5={:.4}",
        first_5_avg,
        last_5_avg
    );
}
/// L2-normalizes a Variable (norm over all elements, floored at 1e-8 to
/// avoid division by zero).
fn l2_normalize_var(x: &Variable) -> Variable {
    let sum_sq = x.mul_var(x).sum();
    let norm = sum_sq.data().to_vec()[0].sqrt().max(1e-8);
    x.mul_scalar(1.0 / norm)
}
}