use std::collections::HashMap;
use axonml_autograd::Variable;
use axonml_nn::{GRUCell, Linear, Module, Parameter};
use axonml_tensor::Tensor;
use super::{BiometricModality, DimensionContribution, ForensicReport, ModalityForensic};
/// A pairwise disagreement between the scores of two biometric modalities.
///
/// Instances are produced by [`ThemisFusion::detect_conflicts`] for every
/// pair of input scores whose absolute difference exceeds that function's
/// conflict threshold.
#[derive(Debug, Clone)]
pub struct ModalityConflict {
/// Modality of the pair that appears earlier in the input slice.
pub modality_a: BiometricModality,
/// Modality of the pair that appears later in the input slice.
pub modality_b: BiometricModality,
/// Score that was supplied for `modality_a`.
pub score_a: f32,
/// Score that was supplied for `modality_b`.
pub score_b: f32,
/// Absolute difference `|score_a - score_b|`.
pub severity: f32,
}
/// Fusion model that combines face, fingerprint, voice and iris embeddings
/// into one identity embedding, a match probability and a recurrent belief
/// state.
pub struct ThemisFusion {
/// Linear projection of the face embedding (`face_dim -> fusion_dim`).
face_proj: Linear,
/// Linear projection of the fingerprint embedding (`finger_dim -> fusion_dim`).
finger_proj: Linear,
/// Linear projection of the voice embedding (`voice_dim -> fusion_dim`).
voice_proj: Linear,
/// Linear projection of the iris embedding (`iris_dim -> fusion_dim`).
iris_proj: Linear,
/// First layer of the consistency network (`4 * fusion_dim -> 64`), fed with
/// the concatenation of the four projections.
consistency_fc1: Linear,
/// Second layer of the consistency network (`64 -> 4`); its sigmoid output
/// provides one consistency score per modality.
consistency_fc2: Linear,
/// Recurrent cell that turns the fused embedding and the previous belief
/// into the new belief state (`fusion_dim -> fusion_dim`).
belief_gru: GRUCell,
/// Maps the belief state to a single match logit (`fusion_dim -> 1`).
decision_head: Linear,
/// Maps the belief state to the identity embedding (`fusion_dim -> fusion_dim`).
identity_head: Linear,
/// Width of the shared space every modality is projected into.
fusion_dim: usize,
/// Factor the log-variance is multiplied by inside the uncertainty gate.
temperature: f32,
/// Per-modality reliability score, initialised to 1.0 for all four
/// modalities, updated by `update_reliability` and multiplied into the
/// fusion weights by `fuse_forensic`.
reliability_scores: HashMap<BiometricModality, f32>,
}
/// The default model is the one built by [`ThemisFusion::new`].
impl Default for ThemisFusion {
fn default() -> Self {
// Delegate so the default configuration is defined in exactly one place.
Self::new()
}
}
impl ThemisFusion {
/// Creates a model with the default configuration: 64-dimensional face and
/// voice embeddings, 128-dimensional fingerprint and iris embeddings, a
/// 48-dimensional fusion space and an uncertainty temperature of 2.0.
pub fn new() -> Self {
    const FACE_DIM: usize = 64;
    const FINGER_DIM: usize = 128;
    const VOICE_DIM: usize = 64;
    const IRIS_DIM: usize = 128;
    const FUSION_DIM: usize = 48;
    const TEMPERATURE: f32 = 2.0;

    Self::with_config(
        FACE_DIM,
        FINGER_DIM,
        VOICE_DIM,
        IRIS_DIM,
        FUSION_DIM,
        TEMPERATURE,
    )
}
/// Creates a model with explicit embedding sizes.
///
/// * `face_dim`, `finger_dim`, `voice_dim`, `iris_dim` - input width of the
///   respective modality projection.
/// * `fusion_dim` - width of the shared space all modalities are projected
///   into; also the width of the belief state and the identity embedding.
/// * `temperature` - factor applied to each log-variance inside the
///   uncertainty gate.
///
/// Every modality starts with a reliability score of 1.0.
pub fn with_config(
    face_dim: usize,
    finger_dim: usize,
    voice_dim: usize,
    iris_dim: usize,
    fusion_dim: usize,
    temperature: f32,
) -> Self {
    // Hidden width of the consistency network.
    const CONSISTENCY_HIDDEN: usize = 64;
    // One consistency score per supported modality.
    const MODALITY_COUNT: usize = 4;

    // Field initialisers run top to bottom, so the layers are constructed in
    // the same order as the fields are listed here.
    Self {
        face_proj: Linear::new(face_dim, fusion_dim),
        finger_proj: Linear::new(finger_dim, fusion_dim),
        voice_proj: Linear::new(voice_dim, fusion_dim),
        iris_proj: Linear::new(iris_dim, fusion_dim),
        consistency_fc1: Linear::new(MODALITY_COUNT * fusion_dim, CONSISTENCY_HIDDEN),
        consistency_fc2: Linear::new(CONSISTENCY_HIDDEN, MODALITY_COUNT),
        belief_gru: GRUCell::new(fusion_dim, fusion_dim),
        decision_head: Linear::new(fusion_dim, 1),
        identity_head: Linear::new(fusion_dim, fusion_dim),
        fusion_dim,
        temperature,
        reliability_scores: HashMap::from([
            (BiometricModality::Face, 1.0),
            (BiometricModality::Fingerprint, 1.0),
            (BiometricModality::Voice, 1.0),
            (BiometricModality::Iris, 1.0),
        ]),
    }
}
pub fn fuse(
&self,
face: Option<(&Variable, f32)>,
finger: Option<(&Variable, f32)>,
voice: Option<(&Variable, f32)>,
iris: Option<(&Variable, f32)>,
belief_state: Option<&Variable>,
) -> (Variable, f32, f32, Variable) {
let batch = 1;
let zero_proj = Variable::new(Tensor::zeros(&[batch, self.fusion_dim]), false);
let (face_proj, face_unc) = if let Some((emb, logvar)) = face {
(
self.face_proj.forward(emb),
Self::uncertainty_gate(logvar, self.temperature),
)
} else {
(zero_proj.clone(), 0.0)
};
let (finger_proj, finger_unc) = if let Some((emb, logvar)) = finger {
(
self.finger_proj.forward(emb),
Self::uncertainty_gate(logvar, self.temperature),
)
} else {
(zero_proj.clone(), 0.0)
};
let (voice_proj, voice_unc) = if let Some((emb, logvar)) = voice {
(
self.voice_proj.forward(emb),
Self::uncertainty_gate(logvar, self.temperature),
)
} else {
(zero_proj.clone(), 0.0)
};
let (iris_proj, iris_unc) = if let Some((emb, logvar)) = iris {
(
self.iris_proj.forward(emb),
Self::uncertainty_gate(logvar, self.temperature),
)
} else {
(zero_proj, 0.0)
};
let unc_weights = [face_unc, finger_unc, voice_unc, iris_unc];
let concat = Variable::cat(&[&face_proj, &finger_proj, &voice_proj, &iris_proj], 1);
let consistency_h = self.consistency_fc1.forward(&concat).relu();
let consistency_logits = self.consistency_fc2.forward(&consistency_h).sigmoid();
let consistency_data = consistency_logits.data().to_vec();
let combined_weights: Vec<f32> = (0..4)
.map(|i| unc_weights[i] * consistency_data[i])
.collect();
let total_weight: f32 = combined_weights.iter().sum::<f32>().max(1e-8);
let fused = face_proj
.mul_scalar(combined_weights[0] / total_weight)
.add_var(&finger_proj.mul_scalar(combined_weights[1] / total_weight))
.add_var(&voice_proj.mul_scalar(combined_weights[2] / total_weight))
.add_var(&iris_proj.mul_scalar(combined_weights[3] / total_weight));
let belief = match belief_state {
Some(b) => b.clone(),
None => Variable::new(Tensor::zeros(&[batch, self.fusion_dim]), false),
};
let new_belief = self.belief_gru.forward_step(&fused, &belief);
let decision = self.decision_head.forward(&new_belief).sigmoid();
let match_prob = decision.data().to_vec()[0];
let identity_raw = self.identity_head.forward(&new_belief);
let fused_identity = Self::l2_normalize(&identity_raw);
let active_count = unc_weights.iter().filter(|&&w| w > 1e-6).count();
let confidence = if active_count == 0 {
0.0
} else {
unc_weights.iter().sum::<f32>() / active_count as f32
};
(fused_identity, match_prob, confidence, new_belief)
}
/// Like `fuse`, but scales the previous belief state by `1 - decay_rate`
/// before the fusion step.
///
/// `decay_rate` is clamped to `[0, 1]`: 0.0 keeps the belief unchanged and
/// 1.0 zeroes it. When `belief_state` is `None` the call is identical to
/// `fuse` with no belief.
pub fn fuse_with_decay(
    &self,
    face: Option<(&Variable, f32)>,
    finger: Option<(&Variable, f32)>,
    voice: Option<(&Variable, f32)>,
    iris: Option<(&Variable, f32)>,
    belief_state: Option<&Variable>,
    decay_rate: f32,
) -> (Variable, f32, f32, Variable) {
    // Fraction of the previous belief that survives the decay.
    let retained = 1.0 - decay_rate.clamp(0.0, 1.0);
    let faded = match belief_state {
        Some(previous) => Some(previous.mul_scalar(retained)),
        None => None,
    };
    self.fuse(face, finger, voice, iris, faded.as_ref())
}
pub fn fuse_forensic(
&self,
face: Option<(&Variable, f32)>,
finger: Option<(&Variable, f32)>,
voice: Option<(&Variable, f32)>,
iris: Option<(&Variable, f32)>,
belief_state: Option<&Variable>,
) -> (Variable, f32, f32, Variable, ForensicReport) {
let batch = 1;
let zero_proj = Variable::new(Tensor::zeros(&[batch, self.fusion_dim]), false);
let modalities_info: [(BiometricModality, bool, f32); 4] = [
(
BiometricModality::Face,
face.is_some(),
face.map_or(0.0, |(_, lv)| lv),
),
(
BiometricModality::Fingerprint,
finger.is_some(),
finger.map_or(0.0, |(_, lv)| lv),
),
(
BiometricModality::Voice,
voice.is_some(),
voice.map_or(0.0, |(_, lv)| lv),
),
(
BiometricModality::Iris,
iris.is_some(),
iris.map_or(0.0, |(_, lv)| lv),
),
];
let (face_proj, face_unc) = if let Some((emb, logvar)) = face {
(
self.face_proj.forward(emb),
Self::uncertainty_gate(logvar, self.temperature),
)
} else {
(zero_proj.clone(), 0.0)
};
let (finger_proj, finger_unc) = if let Some((emb, logvar)) = finger {
(
self.finger_proj.forward(emb),
Self::uncertainty_gate(logvar, self.temperature),
)
} else {
(zero_proj.clone(), 0.0)
};
let (voice_proj, voice_unc) = if let Some((emb, logvar)) = voice {
(
self.voice_proj.forward(emb),
Self::uncertainty_gate(logvar, self.temperature),
)
} else {
(zero_proj.clone(), 0.0)
};
let (iris_proj, iris_unc) = if let Some((emb, logvar)) = iris {
(
self.iris_proj.forward(emb),
Self::uncertainty_gate(logvar, self.temperature),
)
} else {
(zero_proj, 0.0)
};
let unc_weights = [face_unc, finger_unc, voice_unc, iris_unc];
let concat = Variable::cat(&[&face_proj, &finger_proj, &voice_proj, &iris_proj], 1);
let consistency_h = self.consistency_fc1.forward(&concat).relu();
let consistency_logits = self.consistency_fc2.forward(&consistency_h).sigmoid();
let consistency_data = consistency_logits.data().to_vec();
let modality_keys = [
BiometricModality::Face,
BiometricModality::Fingerprint,
BiometricModality::Voice,
BiometricModality::Iris,
];
let combined_weights: Vec<f32> = (0..4)
.map(|i| {
let reliability = self
.reliability_scores
.get(&modality_keys[i])
.copied()
.unwrap_or(1.0);
unc_weights[i] * consistency_data[i] * reliability
})
.collect();
let total_weight: f32 = combined_weights.iter().sum::<f32>().max(1e-8);
let normalized_weights: Vec<f32> =
combined_weights.iter().map(|w| w / total_weight).collect();
let fused = face_proj
.mul_scalar(normalized_weights[0])
.add_var(&finger_proj.mul_scalar(normalized_weights[1]))
.add_var(&voice_proj.mul_scalar(normalized_weights[2]))
.add_var(&iris_proj.mul_scalar(normalized_weights[3]));
let belief = match belief_state {
Some(b) => b.clone(),
None => Variable::new(Tensor::zeros(&[batch, self.fusion_dim]), false),
};
let new_belief = self.belief_gru.forward_step(&fused, &belief);
let decision = self.decision_head.forward(&new_belief).sigmoid();
let match_prob = decision.data().to_vec()[0];
let identity_raw = self.identity_head.forward(&new_belief);
let fused_identity = Self::l2_normalize(&identity_raw);
let active_count = unc_weights.iter().filter(|&&w| w > 1e-6).count();
let confidence = if active_count == 0 {
0.0
} else {
unc_weights.iter().sum::<f32>() / active_count as f32
};
let mut modality_reports = Vec::new();
let mut dominant_modality: Option<BiometricModality> = None;
let mut weakest_modality: Option<BiometricModality> = None;
let mut max_weight = -1.0f32;
let mut min_weight = 2.0f32;
for i in 0..4 {
let (modality, present, logvar) = &modalities_info[i];
if *present {
let raw_score = unc_weights[i]; let report = ModalityForensic {
modality: *modality,
raw_score,
uncertainty: *logvar,
fusion_weight: normalized_weights[i],
agrees_with_decision: true, };
modality_reports.push(report);
if normalized_weights[i] > max_weight {
max_weight = normalized_weights[i];
dominant_modality = Some(*modality);
}
if normalized_weights[i] < min_weight {
min_weight = normalized_weights[i];
weakest_modality = Some(*modality);
}
}
}
let active_consistency: Vec<f32> = (0..4)
.filter(|&i| modalities_info[i].1)
.map(|i| consistency_data[i])
.collect();
let cross_modal_consistency = if active_consistency.is_empty() {
0.0
} else {
active_consistency.iter().sum::<f32>() / active_consistency.len() as f32
};
let identity_data = fused_identity.data().to_vec();
let mut dim_contributions: Vec<DimensionContribution> = identity_data
.iter()
.enumerate()
.map(|(dim, &val)| {
let owning_modality = dominant_modality.unwrap_or(BiometricModality::Face);
DimensionContribution {
dimension: dim,
contribution: val,
modality: owning_modality,
}
})
.collect();
dim_contributions.sort_by(|a, b| {
b.contribution
.abs()
.partial_cmp(&a.contribution.abs())
.unwrap()
});
dim_contributions.truncate(10);
let forensic = ForensicReport {
modality_reports,
cross_modal_consistency,
dominant_modality,
weakest_modality,
top_contributing_dimensions: dim_contributions,
};
(fused_identity, match_prob, confidence, new_belief, forensic)
}
/// Splits a predicted log-variance into aleatoric and epistemic parts.
///
/// * `logvar` - predicted log-variance; clamped to `[-20, 20]` before
///   exponentiation so the result stays finite.
/// * `n_observations` - number of observations collected so far.
///
/// Returns `(aleatoric, epistemic)` where `aleatoric = exp(logvar)` and
/// `epistemic = aleatoric / (n_observations + 1)`.
pub fn evidential_uncertainty(logvar: f32, n_observations: usize) -> (f32, f32) {
    let aleatoric = logvar.clamp(-20.0, 20.0).exp();
    // The +1 keeps the divisor non-zero when nothing has been observed yet.
    let divisor = (n_observations as f32) + 1.0;
    (aleatoric, aleatoric / divisor)
}
/// Finds pairs of modalities whose scores disagree.
///
/// Every unordered pair of entries in `modality_scores` is compared; a pair
/// is reported when the absolute score difference is strictly greater than
/// 0.3. The result is sorted by severity, largest first; pairs with equal
/// severity keep their discovery order.
pub fn detect_conflicts(modality_scores: &[(BiometricModality, f32)]) -> Vec<ModalityConflict> {
    const CONFLICT_THRESHOLD: f32 = 0.3;

    let mut conflicts: Vec<ModalityConflict> = modality_scores
        .iter()
        .enumerate()
        .flat_map(|(idx, &(modality_a, score_a))| {
            // Pair each entry only with the entries after it.
            modality_scores[idx + 1..]
                .iter()
                .filter_map(move |&(modality_b, score_b)| {
                    let severity = (score_a - score_b).abs();
                    (severity > CONFLICT_THRESHOLD).then(|| ModalityConflict {
                        modality_a,
                        modality_b,
                        score_a,
                        score_b,
                        severity,
                    })
                })
        })
        .collect();

    // Only severities above the threshold are stored, so none of them is NaN
    // and the total order agrees with the usual numeric order.
    conflicts.sort_by(|lhs, rhs| rhs.severity.total_cmp(&lhs.severity));
    conflicts
}
/// Updates the reliability score of `modality` with an exponential moving
/// average.
///
/// The new score is `alpha * outcome + (1 - alpha) * old`, where `outcome`
/// is 1.0 for `success` and 0.0 otherwise, `alpha` is clamped to `[0, 1]`
/// and a modality without a stored score starts from 1.0.
pub fn update_reliability(&mut self, modality: BiometricModality, success: bool, alpha: f32) {
    let rate = alpha.clamp(0.0, 1.0);
    let target = if success { 1.0 } else { 0.0 };
    let score = self.reliability_scores.entry(modality).or_insert(1.0);
    *score = rate * target + (1.0 - rate) * *score;
}
/// Returns the current reliability score of `modality`, or 1.0 when no
/// score is stored for it.
pub fn reliability(&self, modality: &BiometricModality) -> f32 {
    match self.reliability_scores.get(modality) {
        Some(score) => *score,
        None => 1.0,
    }
}
/// Returns a read-only view of the stored reliability score of every modality.
pub fn reliability_scores(&self) -> &HashMap<BiometricModality, f32> {
&self.reliability_scores
}
/// Scales `v` by the reciprocal of its Euclidean norm, computed over all of
/// its elements. The norm is floored at `1e-8` so a zero vector is not
/// divided by zero.
fn l2_normalize(v: &Variable) -> Variable {
    let squared_sum = v.data().to_vec().iter().map(|x| x * x).sum::<f32>();
    let norm = squared_sum.sqrt().max(1e-8);
    v.mul_scalar(1.0 / norm)
}
/// Maps a log-variance to a weight in `[0, 1]`:
/// `1 / (1 + exp(log_variance * temperature))`.
///
/// For a positive `temperature`, a low log-variance gives a weight near 1
/// and a high log-variance a weight near 0.
fn uncertainty_gate(log_variance: f32, temperature: f32) -> f32 {
    let scaled = log_variance * temperature;
    (1.0 + scaled.exp()).recip()
}
/// Collects the parameters of every layer, in this order: the four modality
/// projections (face, fingerprint, voice, iris), the two consistency layers,
/// the belief GRU, the decision head and the identity head.
pub fn parameters(&self) -> Vec<Parameter> {
    self.face_proj
        .parameters()
        .into_iter()
        .chain(self.finger_proj.parameters())
        .chain(self.voice_proj.parameters())
        .chain(self.iris_proj.parameters())
        .chain(self.consistency_fc1.parameters())
        .chain(self.consistency_fc2.parameters())
        .chain(self.belief_gru.parameters())
        .chain(self.decision_head.parameters())
        .chain(self.identity_head.parameters())
        .collect()
}
/// Returns the width of the shared fusion space the model was built with.
pub fn fusion_dim(&self) -> usize {
self.fusion_dim
}
}
impl Module for ThemisFusion {
fn forward(&self, input: &Variable) -> Variable {
self.decision_head.forward(input).sigmoid()
}
fn parameters(&self) -> Vec<Parameter> {
self.parameters()
}
}
/// Unit tests for `ThemisFusion`.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `[1, dim]` embedding with every element set to `value`.
    fn embedding(value: f32, dim: usize) -> Variable {
        Variable::new(
            Tensor::from_vec(vec![value; dim], &[1, dim]).unwrap(),
            false,
        )
    }

    /// Euclidean norm over all elements of `v`.
    fn l2_norm(v: &Variable) -> f32 {
        v.data().to_vec().iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    /// Sum of absolute element-wise differences between two belief states.
    fn l1_distance(a: &Variable, b: &Variable) -> f32 {
        a.data()
            .to_vec()
            .iter()
            .zip(b.data().to_vec().iter())
            .map(|(x, y)| (x - y).abs())
            .sum()
    }

    /// True when `p` lies in the closed interval `[0, 1]`.
    fn is_probability(p: f32) -> bool {
        (0.0..=1.0).contains(&p)
    }

    /// Fusion weight reported for the face modality.
    fn face_weight(report: &ForensicReport) -> f32 {
        report
            .modality_reports
            .iter()
            .find(|r| r.modality == BiometricModality::Face)
            .unwrap()
            .fusion_weight
    }

    #[test]
    fn test_themis_creation() {
        let model = ThemisFusion::new();
        assert_eq!(model.fusion_dim(), 48);
    }

    #[test]
    fn test_themis_param_count() {
        let model = ThemisFusion::new();
        let total: usize = model
            .parameters()
            .iter()
            .map(|p| p.variable().data().to_vec().len())
            .sum();
        assert!(total < 80_000, "Params {} exceeds 80K budget", total);
        assert!(total > 20_000, "Params {} seems too low", total);
        println!("Themis params: {}", total);
    }

    #[test]
    fn test_themis_single_modality() {
        let model = ThemisFusion::new();
        let face_emb = embedding(0.5, 64);
        let (fused, match_prob, confidence, _belief) =
            model.fuse(Some((&face_emb, -1.0)), None, None, None, None);
        assert_eq!(fused.shape(), &[1, 48]);
        assert!(is_probability(match_prob));
        assert!(
            confidence > 0.0,
            "Single modality should have positive confidence"
        );
    }

    #[test]
    fn test_themis_multi_modality() {
        let model = ThemisFusion::new();
        let face_emb = embedding(0.5, 64);
        let voice_emb = embedding(0.3, 64);
        let (fused, match_prob, confidence, _belief) = model.fuse(
            Some((&face_emb, -1.0)),
            None,
            Some((&voice_emb, -0.5)),
            None,
            None,
        );
        assert_eq!(fused.shape(), &[1, 48]);
        assert!(is_probability(match_prob));
        assert!(confidence > 0.0);
    }

    #[test]
    fn test_themis_graceful_degradation() {
        let model = ThemisFusion::new();
        let (fused, _match_prob, confidence, _belief) = model.fuse(None, None, None, None, None);
        assert_eq!(fused.shape(), &[1, 48]);
        assert_eq!(confidence, 0.0, "No modalities = zero confidence");
    }

    #[test]
    fn test_themis_all_modalities() {
        let model = ThemisFusion::new();
        let face = embedding(0.5, 64);
        let finger = embedding(0.3, 128);
        let voice = embedding(0.2, 64);
        let iris = embedding(0.4, 128);
        let (fused, match_prob, confidence, belief) = model.fuse(
            Some((&face, -1.0)),
            Some((&finger, -0.5)),
            Some((&voice, -0.8)),
            Some((&iris, -1.2)),
            None,
        );
        assert_eq!(fused.shape(), &[1, 48]);
        assert!(is_probability(match_prob));
        assert!(confidence > 0.0);

        // A follow-up step must accept the belief produced by the first one.
        let (fused2, _match_prob2, _conf2, _belief2) =
            model.fuse(Some((&face, -1.0)), None, None, None, Some(&belief));
        assert_eq!(fused2.shape(), &[1, 48]);
    }

    #[test]
    fn test_uncertainty_gate() {
        let w1 = ThemisFusion::uncertainty_gate(-2.0, 2.0);
        let w2 = ThemisFusion::uncertainty_gate(2.0, 2.0);
        assert!(
            w1 > w2,
            "Low uncertainty should give higher weight: {} vs {}",
            w1,
            w2
        );
        assert!(w1 > 0.9, "Low uncertainty weight should be near 1: {}", w1);
        assert!(w2 < 0.1, "High uncertainty weight should be near 0: {}", w2);
    }

    #[test]
    fn test_themis_belief_accumulation() {
        let model = ThemisFusion::new();
        let face = embedding(0.5, 64);
        let (_fused1, _prob1, _conf1, belief1) =
            model.fuse(Some((&face, -1.0)), None, None, None, None);
        let (_fused2, _prob2, _conf2, belief2) =
            model.fuse(Some((&face, -1.0)), None, None, None, Some(&belief1));
        let diff = l1_distance(&belief1, &belief2);
        assert!(
            diff > 1e-6,
            "Belief should change with accumulation, diff={}",
            diff
        );
    }

    #[test]
    fn test_themis_l2_normalized_output() {
        let model = ThemisFusion::new();
        let face = embedding(0.5, 64);
        let (fused, _, _, _) = model.fuse(Some((&face, -1.0)), None, None, None, None);
        let norm = l2_norm(&fused);
        assert!(
            (norm - 1.0).abs() < 0.01,
            "Fused identity should be L2-normalized, norm={}",
            norm
        );
    }

    #[test]
    fn test_themis_uncertainty_weighting() {
        let model = ThemisFusion::new();
        let face = embedding(1.0, 64);
        let voice = embedding(-1.0, 64);
        let (_, _, confidence, _) = model.fuse(
            Some((&face, -5.0)),
            None,
            Some((&voice, 5.0)),
            None,
            None,
        );
        assert!(confidence > 0.0);
    }

    #[test]
    fn test_evidential_uncertainty_basic() {
        let (aleatoric, epistemic) = ThemisFusion::evidential_uncertainty(0.0, 1);
        assert!((aleatoric - 1.0).abs() < 1e-5, "aleatoric={}", aleatoric);
        assert!((epistemic - 0.5).abs() < 1e-5, "epistemic={}", epistemic);
    }

    #[test]
    fn test_evidential_more_observations_lower_epistemic() {
        let (_, ep1) = ThemisFusion::evidential_uncertainty(-1.0, 1);
        let (_, ep5) = ThemisFusion::evidential_uncertainty(-1.0, 5);
        let (_, ep50) = ThemisFusion::evidential_uncertainty(-1.0, 50);
        let (_, ep500) = ThemisFusion::evidential_uncertainty(-1.0, 500);
        assert!(ep1 > ep5, "ep1={} > ep5={}", ep1, ep5);
        assert!(ep5 > ep50, "ep5={} > ep50={}", ep5, ep50);
        assert!(ep50 > ep500, "ep50={} > ep500={}", ep50, ep500);
        assert!(ep500 < 0.01, "ep500={} should be near zero", ep500);
    }

    #[test]
    fn test_evidential_aleatoric_independent_of_observations() {
        let (al1, _) = ThemisFusion::evidential_uncertainty(1.0, 1);
        let (al100, _) = ThemisFusion::evidential_uncertainty(1.0, 100);
        assert!(
            (al1 - al100).abs() < 1e-5,
            "Aleatoric should not depend on n_observations: {} vs {}",
            al1,
            al100
        );
    }

    #[test]
    fn test_evidential_zero_observations() {
        let (aleatoric, epistemic) = ThemisFusion::evidential_uncertainty(0.0, 0);
        assert!(
            (aleatoric - epistemic).abs() < 1e-5,
            "Zero observations: epistemic should equal aleatoric: {} vs {}",
            aleatoric,
            epistemic
        );
    }

    #[test]
    fn test_evidential_numerical_stability_extreme_positive() {
        let (aleatoric, epistemic) = ThemisFusion::evidential_uncertainty(100.0, 10);
        assert!(
            aleatoric.is_finite(),
            "aleatoric should be finite: {}",
            aleatoric
        );
        assert!(
            epistemic.is_finite(),
            "epistemic should be finite: {}",
            epistemic
        );
        assert!(
            (aleatoric - 20.0f32.exp()).abs() < 1.0,
            "aleatoric={}",
            aleatoric
        );
    }

    #[test]
    fn test_evidential_numerical_stability_extreme_negative() {
        let (aleatoric, epistemic) = ThemisFusion::evidential_uncertainty(-100.0, 10);
        assert!(
            aleatoric.is_finite(),
            "aleatoric should be finite: {}",
            aleatoric
        );
        assert!(
            epistemic.is_finite(),
            "epistemic should be finite: {}",
            epistemic
        );
        assert!(
            aleatoric < 1e-6,
            "Very negative logvar -> tiny aleatoric: {}",
            aleatoric
        );
    }

    #[test]
    fn test_fuse_with_decay_zero_rate() {
        let model = ThemisFusion::new();
        let face = embedding(0.5, 64);
        let belief = embedding(1.0, 48);
        let (_, prob_no_decay, _, _) =
            model.fuse(Some((&face, -1.0)), None, None, None, Some(&belief));
        let (_, prob_zero_decay, _, _) =
            model.fuse_with_decay(Some((&face, -1.0)), None, None, None, Some(&belief), 0.0);
        assert!(
            (prob_no_decay - prob_zero_decay).abs() < 1e-5,
            "Zero decay should equal no decay: {} vs {}",
            prob_no_decay,
            prob_zero_decay
        );
    }

    #[test]
    fn test_fuse_with_decay_full_rate() {
        let model = ThemisFusion::new();
        let face = embedding(0.5, 64);
        let belief = embedding(1.0, 48);
        let (_, prob_no_belief, _, _) = model.fuse(Some((&face, -1.0)), None, None, None, None);
        let (_, prob_full_decay, _, _) =
            model.fuse_with_decay(Some((&face, -1.0)), None, None, None, Some(&belief), 1.0);
        assert!(
            (prob_no_belief - prob_full_decay).abs() < 0.05,
            "Full decay should approximate no belief: {} vs {}",
            prob_no_belief,
            prob_full_decay
        );
    }

    #[test]
    fn test_fuse_with_decay_belief_shrinks() {
        let model = ThemisFusion::new();
        let face = embedding(0.5, 64);
        let (_, _, _, belief) = model.fuse(Some((&face, -1.0)), None, None, None, None);
        let belief_norm = l2_norm(&belief);
        let (_, _, _, decayed_belief) =
            model.fuse_with_decay(Some((&face, -1.0)), None, None, None, Some(&belief), 0.9);
        let decayed_norm = l2_norm(&decayed_belief);
        // The decayed belief still passes through the GRU, so the norm of the
        // returned belief is not guaranteed to be smaller than the original;
        // only numerical sanity of both states is asserted here.
        assert!(
            belief_norm.is_finite(),
            "Original belief norm should be finite"
        );
        assert!(
            decayed_norm.is_finite(),
            "Decayed belief norm should be finite"
        );
    }

    #[test]
    fn test_detect_conflicts_agreeing() {
        let scores = [
            (BiometricModality::Face, 0.85),
            (BiometricModality::Fingerprint, 0.90),
            (BiometricModality::Voice, 0.80),
        ];
        let conflicts = ThemisFusion::detect_conflicts(&scores);
        assert!(
            conflicts.is_empty(),
            "Agreeing modalities should produce no conflicts, got {}",
            conflicts.len()
        );
    }

    #[test]
    fn test_detect_conflicts_disagreeing() {
        let scores = [
            (BiometricModality::Face, 0.95),
            (BiometricModality::Fingerprint, 0.15),
        ];
        let conflicts = ThemisFusion::detect_conflicts(&scores);
        assert_eq!(conflicts.len(), 1, "Should detect one conflict");
        assert_eq!(conflicts[0].modality_a, BiometricModality::Face);
        assert_eq!(conflicts[0].modality_b, BiometricModality::Fingerprint);
        assert!(conflicts[0].severity > 0.3);
        assert!((conflicts[0].severity - 0.80).abs() < 0.01);
    }

    #[test]
    fn test_detect_conflicts_multiple() {
        let scores = [
            (BiometricModality::Face, 0.95),
            (BiometricModality::Fingerprint, 0.10),
            (BiometricModality::Voice, 0.20),
            (BiometricModality::Iris, 0.90),
        ];
        let conflicts = ThemisFusion::detect_conflicts(&scores);
        assert!(
            conflicts.len() >= 2,
            "Should detect multiple conflicts, got {}",
            conflicts.len()
        );
        for pair in conflicts.windows(2) {
            assert!(
                pair[0].severity >= pair[1].severity,
                "Conflicts should be sorted by severity desc"
            );
        }
    }

    #[test]
    fn test_detect_conflicts_single_modality() {
        let scores = [(BiometricModality::Face, 0.85)];
        let conflicts = ThemisFusion::detect_conflicts(&scores);
        assert!(conflicts.is_empty(), "Single modality cannot conflict");
    }

    #[test]
    fn test_detect_conflicts_empty() {
        let conflicts = ThemisFusion::detect_conflicts(&[]);
        assert!(conflicts.is_empty());
    }

    #[test]
    fn test_detect_conflicts_all_contradicting() {
        let scores = [
            (BiometricModality::Face, 0.95),
            (BiometricModality::Fingerprint, 0.05),
            (BiometricModality::Voice, 0.50),
            (BiometricModality::Iris, 0.02),
        ];
        let conflicts = ThemisFusion::detect_conflicts(&scores);
        assert!(
            conflicts.len() >= 4,
            "Most pairs should conflict, got {}",
            conflicts.len()
        );
    }

    #[test]
    fn test_reliability_initial() {
        let model = ThemisFusion::new();
        assert_eq!(model.reliability(&BiometricModality::Face), 1.0);
        assert_eq!(model.reliability(&BiometricModality::Fingerprint), 1.0);
        assert_eq!(model.reliability(&BiometricModality::Voice), 1.0);
        assert_eq!(model.reliability(&BiometricModality::Iris), 1.0);
    }

    #[test]
    fn test_reliability_decreases_on_failure() {
        let mut model = ThemisFusion::new();
        let before = model.reliability(&BiometricModality::Face);
        model.update_reliability(BiometricModality::Face, false, 0.1);
        let after = model.reliability(&BiometricModality::Face);
        assert!(
            after < before,
            "Reliability should decrease on failure: {} -> {}",
            before,
            after
        );
    }

    #[test]
    fn test_reliability_increases_on_success() {
        let mut model = ThemisFusion::new();
        model.update_reliability(BiometricModality::Voice, false, 0.5);
        let low = model.reliability(&BiometricModality::Voice);
        model.update_reliability(BiometricModality::Voice, true, 0.5);
        let after = model.reliability(&BiometricModality::Voice);
        assert!(
            after > low,
            "Reliability should increase on success: {} -> {}",
            low,
            after
        );
    }

    #[test]
    fn test_reliability_repeated_failures() {
        let mut model = ThemisFusion::new();
        for _ in 0..20 {
            model.update_reliability(BiometricModality::Iris, false, 0.2);
        }
        let r = model.reliability(&BiometricModality::Iris);
        assert!(
            r < 0.1,
            "Many failures should drive reliability near 0: {}",
            r
        );
    }

    #[test]
    fn test_reliability_independent_per_modality() {
        let mut model = ThemisFusion::new();
        model.update_reliability(BiometricModality::Face, false, 0.5);
        assert!(model.reliability(&BiometricModality::Face) < 1.0);
        assert_eq!(model.reliability(&BiometricModality::Fingerprint), 1.0);
        assert_eq!(model.reliability(&BiometricModality::Voice), 1.0);
        assert_eq!(model.reliability(&BiometricModality::Iris), 1.0);
    }

    #[test]
    fn test_forensic_report_all_fields_populated() {
        let model = ThemisFusion::new();
        let face = embedding(0.5, 64);
        let finger = embedding(0.3, 128);
        let (fused, match_prob, confidence, belief, forensic) =
            model.fuse_forensic(Some((&face, -1.0)), Some((&finger, -0.5)), None, None, None);
        assert_eq!(fused.shape(), &[1, 48]);
        assert!(is_probability(match_prob));
        assert!(confidence > 0.0);
        assert_eq!(belief.shape(), &[1, 48]);
        assert_eq!(forensic.modality_reports.len(), 2);
        assert!(forensic.dominant_modality.is_some());
        assert!(forensic.weakest_modality.is_some());
        assert!(forensic.cross_modal_consistency >= 0.0);
        assert!(!forensic.top_contributing_dimensions.is_empty());
    }

    #[test]
    fn test_forensic_dominant_modality_identified() {
        let model = ThemisFusion::new();
        let face = embedding(0.5, 64);
        let voice = embedding(0.3, 64);
        let (_, _, _, _, forensic) = model.fuse_forensic(
            Some((&face, -5.0)),
            None,
            Some((&voice, 5.0)),
            None,
            None,
        );
        assert_eq!(
            forensic.dominant_modality,
            Some(BiometricModality::Face),
            "Face should dominate with much lower uncertainty"
        );
    }

    #[test]
    fn test_forensic_single_modality() {
        let model = ThemisFusion::new();
        let iris = embedding(0.4, 128);
        let (_, _, _, _, forensic) =
            model.fuse_forensic(None, None, None, Some((&iris, -1.0)), None);
        assert_eq!(forensic.modality_reports.len(), 1);
        assert_eq!(
            forensic.modality_reports[0].modality,
            BiometricModality::Iris
        );
        assert_eq!(forensic.dominant_modality, Some(BiometricModality::Iris));
    }

    #[test]
    fn test_forensic_dimension_contributions() {
        let model = ThemisFusion::new();
        let face = embedding(0.5, 64);
        let (_, _, _, _, forensic) =
            model.fuse_forensic(Some((&face, -1.0)), None, None, None, None);
        let dims = &forensic.top_contributing_dimensions;
        assert!(!dims.is_empty());
        for pair in dims.windows(2) {
            assert!(
                pair[0].contribution.abs() >= pair[1].contribution.abs(),
                "Dimensions should be sorted by |contribution| desc"
            );
        }
    }

    #[test]
    fn test_single_modality_high_confidence() {
        let model = ThemisFusion::new();
        let face = embedding(0.9, 64);
        let (fused, match_prob, confidence, _) =
            model.fuse(Some((&face, -10.0)), None, None, None, None);
        assert_eq!(fused.shape(), &[1, 48]);
        assert!(is_probability(match_prob));
        assert!(
            confidence > 0.99,
            "Extremely confident modality: conf={}",
            confidence
        );
    }

    #[test]
    fn test_batch_processing_through_module() {
        let model = ThemisFusion::new();
        let input = Variable::new(
            Tensor::from_vec(vec![0.1f32; 48 * 3], &[3, 48]).unwrap(),
            false,
        );
        let output = Module::forward(&model, &input);
        assert_eq!(output.shape(), &[3, 1]);
        for val in output.data().to_vec().iter() {
            assert!(
                is_probability(*val),
                "Sigmoid output should be [0,1]: {}",
                val
            );
        }
    }

    #[test]
    fn test_belief_state_convergence() {
        let model = ThemisFusion::new();
        let face = embedding(0.5, 64);
        let mut belief: Option<Variable> = None;
        let mut deltas = Vec::new();
        for _ in 0..20 {
            let (_, _, _, new_belief) =
                model.fuse(Some((&face, -1.0)), None, None, None, belief.as_ref());
            // From the second step on, record how far the belief moved.
            if let Some(previous) = &belief {
                deltas.push(l1_distance(previous, &new_belief));
            }
            belief = Some(new_belief);
        }
        if deltas.len() >= 6 {
            let early_avg: f32 = deltas[..3].iter().sum::<f32>() / 3.0;
            let late_avg: f32 = deltas[deltas.len() - 3..].iter().sum::<f32>() / 3.0;
            assert!(
                late_avg <= early_avg + 1e-3,
                "Belief should converge (early_delta={}, late_delta={})",
                early_avg,
                late_avg
            );
        }
    }

    #[test]
    fn test_reliability_affects_forensic_weights() {
        let mut model = ThemisFusion::new();
        let face = embedding(0.5, 64);
        let finger = embedding(0.3, 128);
        let (_, _, _, _, forensic_full) =
            model.fuse_forensic(Some((&face, -1.0)), Some((&finger, -1.0)), None, None, None);
        let face_weight_full = face_weight(&forensic_full);

        for _ in 0..10 {
            model.update_reliability(BiometricModality::Face, false, 0.3);
        }
        assert!(model.reliability(&BiometricModality::Face) < 0.2);

        let (_, _, _, _, forensic_degraded) =
            model.fuse_forensic(Some((&face, -1.0)), Some((&finger, -1.0)), None, None, None);
        let face_weight_degraded = face_weight(&forensic_degraded);
        assert!(
            face_weight_degraded < face_weight_full,
            "Degraded reliability should lower fusion weight: {} -> {}",
            face_weight_full,
            face_weight_degraded
        );
    }

    #[test]
    fn test_decay_with_no_belief_state() {
        let model = ThemisFusion::new();
        let face = embedding(0.5, 64);
        let (fused, prob, conf, _) =
            model.fuse_with_decay(Some((&face, -1.0)), None, None, None, None, 0.5);
        assert_eq!(fused.shape(), &[1, 48]);
        assert!(is_probability(prob));
        assert!(conf > 0.0);
    }

    #[test]
    fn test_conflict_severity_proportional() {
        let scores_close = [
            (BiometricModality::Face, 0.80),
            (BiometricModality::Fingerprint, 0.45),
        ];
        let scores_far = [
            (BiometricModality::Face, 0.95),
            (BiometricModality::Fingerprint, 0.05),
        ];
        let conflicts_close = ThemisFusion::detect_conflicts(&scores_close);
        let conflicts_far = ThemisFusion::detect_conflicts(&scores_far);
        assert_eq!(conflicts_close.len(), 1);
        assert_eq!(conflicts_far.len(), 1);
        assert!(
            conflicts_far[0].severity > conflicts_close[0].severity,
            "Larger disagreement should have higher severity: {} vs {}",
            conflicts_far[0].severity,
            conflicts_close[0].severity
        );
    }

    #[test]
    fn test_evidential_high_variance_high_aleatoric() {
        let (al_low, _) = ThemisFusion::evidential_uncertainty(-2.0, 5);
        let (al_high, _) = ThemisFusion::evidential_uncertainty(2.0, 5);
        assert!(
            al_high > al_low,
            "Higher logvar should give higher aleatoric: {} vs {}",
            al_high,
            al_low
        );
    }
}