use crate::{Error, Sample, processors::MeanVar};
use ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use num_traits::AsPrimitive;
use rayon::iter::{ParallelBridge, ParallelIterator};
use serde::{Deserialize, Serialize};
use std::{fs::File, iter::zip, ops::Add, path::Path};
/// Computes the per-sample signal-to-noise ratio (SNR) of `traces` in parallel.
///
/// The rows of `traces` are cut into batches of `batch_size` rows. Each batch is
/// accumulated into an [`SnrProcessor`] on the rayon thread pool, the partial
/// processors are merged with `+`, and [`SnrProcessor::snr`] is evaluated on the result.
///
/// # Arguments
/// * `traces` - one trace per row.
/// * `classes` - number of classes; every value returned by `get_class` must be below it.
/// * `get_class` - maps the global row index of a trace in `traces` to its class.
/// * `batch_size` - number of rows handled per parallel work item.
///
/// # Panics
/// Panics if `batch_size == 0`, or if `get_class` returns a class `>= classes`
/// (the processor panics on the out-of-range index).
///
/// If `traces` has no rows, the result is whatever a freshly created, unused
/// [`SnrProcessor`] reports.
pub fn snr<T, F>(
    traces: ArrayView2<T>,
    classes: usize,
    get_class: F,
    batch_size: usize,
) -> Array1<f32>
where
    T: Sample + Copy + Sync,
    <T as Sample>::Container: Send,
    F: Fn(usize) -> usize + Sync,
{
    assert!(batch_size > 0);
    let size = traces.shape()[1];
    traces
        .axis_chunks_iter(Axis(0), batch_size)
        // Enumerate before bridging so each batch keeps its sequential position.
        .enumerate()
        .par_bridge()
        .fold(
            || SnrProcessor::new(size, classes),
            |mut acc, (batch_idx, trace_batch)| {
                // Every batch except possibly the last holds exactly `batch_size` rows,
                // so this recovers the global row index expected by `get_class`.
                let first_row = batch_idx * batch_size;
                for (offset, trace) in trace_batch.rows().into_iter().enumerate() {
                    acc.process(trace, get_class(first_row + offset));
                }
                acc
            },
        )
        .reduce_with(|a, b| a + b)
        // `None` only happens when there were no batches at all (zero rows); fall back
        // to an empty processor instead of panicking on `unwrap`.
        .unwrap_or_else(|| SnrProcessor::new(size, classes))
        .snr()
}
/// Streaming accumulator for the per-sample signal-to-noise ratio (SNR).
///
/// Traces are added one at a time with `process`, partial accumulators can be
/// merged with `+`, and the statistic is read with `snr`.
#[derive(Serialize, Deserialize)]
pub struct SnrProcessor<T: Sample> {
    /// Running mean/variance over every processed trace, regardless of class.
    #[serde(bound(
        serialize = "<T as Sample>::Container: Serialize",
        deserialize = "<T as Sample>::Container: Deserialize<'de>"
    ))]
    mean_var: MeanVar<T>,
    /// Per-class running sum of the traces, one row per class.
    #[serde(bound(
        serialize = "<T as Sample>::Container: Serialize",
        deserialize = "<T as Sample>::Container: Deserialize<'de>"
    ))]
    classes_sum: Array2<<T as Sample>::Container>,
    /// Number of traces accumulated for each class.
    classes_count: Array1<usize>,
}
impl<T> SnrProcessor<T>
where
T: Sample + Copy,
{
pub fn new(size: usize, num_classes: usize) -> Self {
Self {
mean_var: MeanVar::new(size),
classes_sum: Array2::zeros((num_classes, size)),
classes_count: Array1::zeros(num_classes),
}
}
pub fn process(&mut self, trace: ArrayView1<T>, class: usize) {
debug_assert!(trace.len() == self.size());
debug_assert!(class < self.num_classes());
self.mean_var.process(trace);
for i in 0..self.size() {
self.classes_sum[[class, i]] += trace[i].into();
}
self.classes_count[class] += 1;
}
pub fn snr(&self) -> Array1<f32> {
let mean = self.mean_var.mean();
let mut velx = Array1::zeros(self.size());
for class in 0..self.num_classes() {
let class_count = self.classes_count[class];
if self.classes_count[class] == 0 {
continue;
}
let class_mean = self
.classes_sum
.row(class)
.mapv(|x| x.as_() / class_count as f32);
velx += &((class_mean - &mean).mapv(|d| d * d) * class_count as f32
/ self.mean_var.count() as f32);
}
let var = self.mean_var.var();
velx.clone() / (var - velx)
}
pub fn size(&self) -> usize {
self.classes_sum.shape()[1]
}
pub fn num_classes(&self) -> usize {
self.classes_count.len()
}
fn is_compatible_with(&self, other: &Self) -> bool {
self.size() == other.size() && self.num_classes() == other.num_classes()
}
}
impl<T> SnrProcessor<T>
where
    T: Sample,
    <T as Sample>::Container: Serialize,
{
    /// Writes the accumulator state to `path` as JSON, creating or truncating the file.
    ///
    /// # Errors
    /// Returns an error if the file cannot be created, if serialization fails, or if
    /// the buffered output cannot be written out.
    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), Error> {
        // serde_json does not buffer its output; avoid one syscall per small write.
        let mut writer = std::io::BufWriter::new(File::create(path)?);
        serde_json::to_writer(&mut writer, self)?;
        // Dropping a `BufWriter` discards flush errors, so flush explicitly.
        std::io::Write::flush(&mut writer)?;
        Ok(())
    }
}
impl<T> SnrProcessor<T>
where
    T: Sample,
    <T as Sample>::Container: for<'de> Deserialize<'de>,
{
    /// Reads an accumulator from the JSON file at `path`.
    ///
    /// # Errors
    /// Returns an error if the file cannot be opened or its contents do not
    /// deserialize into a processor.
    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, Error> {
        // serde_json does not buffer its input, and unbuffered reads from a `File`
        // are slow, so wrap it in a `BufReader`.
        let reader = std::io::BufReader::new(File::open(path)?);
        let processor = serde_json::from_reader(reader)?;
        Ok(processor)
    }
}
impl<T> Add for SnrProcessor<T>
where
    T: Sample + Copy,
{
    type Output = Self;

    /// Merges two accumulators by combining their global statistics, per-class sums
    /// and per-class counts.
    ///
    /// # Panics
    /// Panics if the operands disagree on trace size or number of classes. The check
    /// runs in release builds too: otherwise ndarray may broadcast a single-class
    /// operand across all classes and silently produce wrong statistics.
    fn add(self, rhs: Self) -> Self::Output {
        assert!(
            self.is_compatible_with(&rhs),
            "cannot merge SnrProcessor instances with different size or number of classes"
        );
        Self {
            mean_var: self.mean_var + rhs.mean_var,
            classes_sum: self.classes_sum + rhs.classes_sum,
            classes_count: self.classes_count + rhs.classes_count,
        }
    }
}
/// Computes the per-sample normalized inter-class variance (NICV) of `traces` in
/// parallel.
///
/// The rows of `traces` are cut into batches of `batch_size` rows. Each batch is
/// accumulated into a [`NicvProcessor`] on the rayon thread pool, the partial
/// processors are merged with `+`, and [`NicvProcessor::nicv`] is evaluated on the result.
///
/// # Arguments
/// * `traces` - one trace per row.
/// * `classes` - number of classes; every value returned by `get_class` must be below it.
/// * `get_class` - maps the global row index of a trace in `traces` to its class.
/// * `batch_size` - number of rows handled per parallel work item.
///
/// # Panics
/// Panics if `batch_size == 0`, or if `get_class` returns a class `>= classes`
/// (the processor panics on the out-of-range index).
///
/// If `traces` has no rows, the result is whatever a freshly created, unused
/// [`NicvProcessor`] reports.
pub fn nicv<T, F>(
    traces: ArrayView2<T>,
    classes: usize,
    get_class: F,
    batch_size: usize,
) -> Array1<f32>
where
    T: Sample + Copy + Sync,
    <T as Sample>::Container: Send,
    F: Fn(usize) -> usize + Sync,
{
    assert!(batch_size > 0);
    let size = traces.shape()[1];
    traces
        .axis_chunks_iter(Axis(0), batch_size)
        // Enumerate before bridging so each batch keeps its sequential position.
        .enumerate()
        .par_bridge()
        .fold(
            || NicvProcessor::new(size, classes),
            |mut acc, (batch_idx, trace_batch)| {
                // Every batch except possibly the last holds exactly `batch_size` rows,
                // so this recovers the global row index expected by `get_class`.
                let first_row = batch_idx * batch_size;
                for (offset, trace) in trace_batch.rows().into_iter().enumerate() {
                    acc.process(trace, get_class(first_row + offset));
                }
                acc
            },
        )
        .reduce_with(|a, b| a + b)
        // `None` only happens when there were no batches at all (zero rows); fall back
        // to an empty processor instead of panicking on `unwrap`.
        .unwrap_or_else(|| NicvProcessor::new(size, classes))
        .nicv()
}
/// Streaming accumulator for the per-sample normalized inter-class variance (NICV).
///
/// Traces are added one at a time with `process`, partial accumulators can be
/// merged with `+`, and the statistic is read with `nicv`.
#[derive(Serialize, Deserialize)]
pub struct NicvProcessor<T: Sample> {
    /// Running mean/variance over every processed trace, regardless of class.
    #[serde(bound(
        serialize = "<T as Sample>::Container: Serialize",
        deserialize = "<T as Sample>::Container: Deserialize<'de>"
    ))]
    mean_var: MeanVar<T>,
    /// Per-class running sum of the traces, one row per class.
    #[serde(bound(
        serialize = "<T as Sample>::Container: Serialize",
        deserialize = "<T as Sample>::Container: Deserialize<'de>"
    ))]
    classes_sum: Array2<<T as Sample>::Container>,
    /// Number of traces accumulated for each class.
    classes_count: Array1<usize>,
}
impl<T> NicvProcessor<T>
where
T: Sample + Copy,
{
pub fn new(size: usize, num_classes: usize) -> Self {
Self {
mean_var: MeanVar::new(size),
classes_sum: Array2::zeros((num_classes, size)),
classes_count: Array1::zeros(num_classes),
}
}
pub fn process(&mut self, trace: ArrayView1<T>, class: usize) {
debug_assert!(trace.len() == self.size());
debug_assert!(class < self.num_classes());
self.mean_var.process(trace);
for i in 0..self.size() {
self.classes_sum[[class, i]] += trace[i].into();
}
self.classes_count[class] += 1;
}
pub fn nicv(&self) -> Array1<f32> {
let mean = self.mean_var.mean();
let mut velx = Array1::zeros(self.size());
for class in 0..self.num_classes() {
let class_count = self.classes_count[class];
if class_count == 0 {
continue;
}
let class_mean = self
.classes_sum
.row(class)
.mapv(|x| x.as_() / class_count as f32);
velx += &((class_mean - &mean).mapv(|d| d * d) * class_count as f32
/ self.mean_var.count() as f32);
}
velx / self.mean_var.var()
}
pub fn size(&self) -> usize {
self.classes_sum.shape()[1]
}
pub fn num_classes(&self) -> usize {
self.classes_count.len()
}
fn is_compatible_with(&self, other: &Self) -> bool {
self.size() == other.size() && self.num_classes() == other.num_classes()
}
}
impl<T> Add for NicvProcessor<T>
where
    T: Sample + Copy,
{
    type Output = Self;

    /// Merges two accumulators by combining their global statistics, per-class sums
    /// and per-class counts.
    ///
    /// # Panics
    /// Panics if the operands disagree on trace size or number of classes. The check
    /// runs in release builds too: otherwise ndarray may broadcast a single-class
    /// operand across all classes and silently produce wrong statistics.
    fn add(self, rhs: Self) -> Self::Output {
        assert!(
            self.is_compatible_with(&rhs),
            "cannot merge NicvProcessor instances with different size or number of classes"
        );
        Self {
            mean_var: self.mean_var + rhs.mean_var,
            classes_sum: self.classes_sum + rhs.classes_sum,
            classes_count: self.classes_count + rhs.classes_count,
        }
    }
}
/// Computes the per-sample two-group t-statistic of `traces` in parallel.
///
/// `trace_classes[i]` selects the group of row `i` of `traces`. Rows are cut into
/// batches of `batch_size` rows, each batch is accumulated into a [`TTestProcessor`]
/// on the rayon thread pool, and the partial processors are merged with `+` before
/// [`TTestProcessor::ttest`] is evaluated.
///
/// # Panics
/// Panics if `traces` and `trace_classes` have different numbers of rows, or if
/// `batch_size == 0`.
///
/// If `traces` has no rows, the result is whatever a freshly created, unused
/// [`TTestProcessor`] reports.
pub fn ttest<T>(
    traces: ArrayView2<T>,
    trace_classes: ArrayView1<bool>,
    batch_size: usize,
) -> Array1<f32>
where
    T: Sample + Copy + Sync,
    <T as Sample>::Container: Send,
{
    assert_eq!(
        traces.shape()[0],
        trace_classes.shape()[0],
        "traces and trace_classes must have the same number of rows"
    );
    assert!(batch_size > 0);
    let size = traces.shape()[1];
    zip(
        traces.axis_chunks_iter(Axis(0), batch_size),
        trace_classes.axis_chunks_iter(Axis(0), batch_size),
    )
    .par_bridge()
    .fold(
        || TTestProcessor::new(size),
        |mut acc, (trace_batch, class_batch)| {
            for (trace, &class) in zip(trace_batch.rows(), class_batch) {
                acc.process(trace, class);
            }
            acc
        },
    )
    .reduce_with(|a, b| a + b)
    // `None` only happens when there were no batches at all (zero rows); fall back to
    // an empty processor instead of panicking on `unwrap`.
    .unwrap_or_else(|| TTestProcessor::new(size))
    .ttest()
}
/// Streaming accumulator for a two-group t-test evaluated independently per sample.
///
/// Traces are added one at a time with `process`, partial accumulators can be
/// merged with `+`, and the statistic is read with `ttest`.
#[derive(Serialize, Deserialize)]
pub struct TTestProcessor<T: Sample> {
    /// Running mean/variance of the traces processed with `class == false`.
    #[serde(bound(
        serialize = "<T as Sample>::Container: Serialize",
        deserialize = "<T as Sample>::Container: Deserialize<'de>"
    ))]
    mean_var_1: MeanVar<T>,
    /// Running mean/variance of the traces processed with `class == true`.
    #[serde(bound(
        serialize = "<T as Sample>::Container: Serialize",
        deserialize = "<T as Sample>::Container: Deserialize<'de>"
    ))]
    mean_var_2: MeanVar<T>,
}
impl<T> TTestProcessor<T>
where
    T: Sample + Copy,
{
    /// Creates an empty accumulator for traces of `size` samples.
    pub fn new(size: usize) -> Self {
        let mean_var_1 = MeanVar::new(size);
        let mean_var_2 = MeanVar::new(size);
        Self {
            mean_var_1,
            mean_var_2,
        }
    }

    /// Adds one trace to the first group (`class == false`) or the second group
    /// (`class == true`).
    pub fn process(&mut self, trace: ArrayView1<T>, class: bool) {
        debug_assert!(trace.len() == self.size());
        let group = match class {
            true => &mut self.mean_var_2,
            false => &mut self.mean_var_1,
        };
        group.process(trace);
    }

    /// Returns, per sample, the difference of the group means divided by
    /// `sqrt(var_1 / n_1 + var_2 / n_2)`.
    pub fn ttest(&self) -> Array1<f32> {
        let (first, second) = (&self.mean_var_1, &self.mean_var_2);
        let mean_diff = first.mean() - second.mean();
        let squared_error =
            (first.var() / first.count() as f32) + (second.var() / second.count() as f32);
        let std_error = squared_error.mapv(f32::sqrt);
        mean_diff / std_error
    }

    /// Number of samples per trace.
    pub fn size(&self) -> usize {
        self.mean_var_1.size()
    }

    /// True when `other` accumulates traces of the same size, i.e. the two
    /// accumulators can be merged.
    fn is_compatible_with(&self, other: &Self) -> bool {
        other.size() == self.size()
    }
}
impl<T> TTestProcessor<T>
where
    T: Sample,
    <T as Sample>::Container: Serialize,
{
    /// Writes the accumulator state to `path` as JSON, creating or truncating the file.
    ///
    /// # Errors
    /// Returns an error if the file cannot be created, if serialization fails, or if
    /// the buffered output cannot be written out.
    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), Error> {
        // serde_json does not buffer its output; avoid one syscall per small write.
        let mut writer = std::io::BufWriter::new(File::create(path)?);
        serde_json::to_writer(&mut writer, self)?;
        // Dropping a `BufWriter` discards flush errors, so flush explicitly.
        std::io::Write::flush(&mut writer)?;
        Ok(())
    }
}
impl<T> TTestProcessor<T>
where
    T: Sample,
    <T as Sample>::Container: for<'de> Deserialize<'de>,
{
    /// Reads an accumulator from the JSON file at `path`.
    ///
    /// # Errors
    /// Returns an error if the file cannot be opened or its contents do not
    /// deserialize into a processor.
    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, Error> {
        // serde_json does not buffer its input, and unbuffered reads from a `File`
        // are slow, so wrap it in a `BufReader`.
        let reader = std::io::BufReader::new(File::open(path)?);
        let processor = serde_json::from_reader(reader)?;
        Ok(processor)
    }
}
impl<T> Add for TTestProcessor<T>
where
    T: Sample + Copy,
{
    type Output = Self;

    /// Merges two accumulators group by group.
    ///
    /// # Panics
    /// Panics if the operands accumulate traces of different sizes. The check runs in
    /// release builds too, so a mismatch cannot silently corrupt the merged statistics.
    fn add(self, rhs: Self) -> Self::Output {
        assert!(
            self.is_compatible_with(&rhs),
            "cannot merge TTestProcessor instances with different trace sizes"
        );
        Self {
            mean_var_1: self.mean_var_1 + rhs.mean_var_1,
            mean_var_2: self.mean_var_2 + rhs.mean_var_2,
        }
    }
}
// Unit tests: each streaming processor is compared against its parallel helper
// (`snr`, `nicv`, `ttest`), plus a fixed-value t-test check and an NICV range check.
#[cfg(test)]
mod tests {
use super::{NicvProcessor, SnrProcessor, TTestProcessor, nicv, snr, ttest};
use ndarray::array;
// Sequential SnrProcessor vs. the parallel `snr` helper (batch_size 2) on ten 4-sample
// traces. Only classes 1..=3 of 256 are used, so empty classes are exercised too.
// Exact equality presumably relies on the accumulated sums being order-independent for
// these integer samples. NOTE(review): confirm the `Sample::Container` type involved.
#[test]
fn test_snr_helper() {
let traces = array![
[77, 137, 51, 91],
[72, 61, 91, 83],
[39, 49, 52, 23],
[26, 114, 63, 45],
[30, 8, 97, 91],
[13, 68, 7, 45],
[17, 181, 60, 34],
[43, 88, 76, 78],
[0, 36, 35, 0],
[93, 191, 49, 26],
];
let classes = [1, 3, 1, 2, 3, 2, 2, 1, 3, 1];
let mut processor = SnrProcessor::new(traces.shape()[1], 256);
for (trace, class) in std::iter::zip(traces.rows(), classes.iter()) {
processor.process(trace, *class);
}
assert_eq!(processor.snr(), snr(traces.view(), 256, |i| classes[i], 2));
}
// Fixed-value t-test: rows 0, 3, 6 and 9 (`i % 3 == 0`) form the `true` group; the
// expected per-sample t-statistics are hard-coded f32 values.
#[test]
fn test_ttest() {
let traces = [
array![77, 137, 51, 91],
array![72, 61, 91, 83],
array![39, 49, 52, 23],
array![26, 114, 63, 45],
array![30, 8, 97, 91],
array![13, 68, 7, 45],
array![17, 181, 60, 34],
array![43, 88, 76, 78],
array![0, 36, 35, 0],
array![93, 191, 49, 26],
];
let mut processor = TTestProcessor::new(4);
for (i, trace) in traces.iter().enumerate() {
processor.process(trace.view(), i % 3 == 0);
}
assert_eq!(
processor.ttest(),
array![-1.0910345, -5.5249214, 0.29385296, 0.23308459]
);
}
// Sequential TTestProcessor vs. the parallel `ttest` helper (batch_size 2), using the
// same group split as `test_ttest` (`true` at rows 0, 3, 6, 9).
#[test]
fn test_ttest_helper() {
let traces = array![
[77, 137, 51, 91],
[72, 61, 91, 83],
[39, 49, 52, 23],
[26, 114, 63, 45],
[30, 8, 97, 91],
[13, 68, 7, 45],
[17, 181, 60, 34],
[43, 88, 76, 78],
[0, 36, 35, 0],
[93, 191, 49, 26],
];
let trace_classes = array![
true, false, false, true, false, false, true, false, false, true
];
let mut processor = TTestProcessor::new(4);
for (i, trace) in traces.rows().into_iter().enumerate() {
processor.process(trace, trace_classes[i]);
}
assert_eq!(
processor.ttest(),
ttest(traces.view(), trace_classes.view(), 2)
);
}
// Sequential NicvProcessor vs. the parallel `nicv` helper (batch_size 2); mirrors
// `test_snr_helper`.
#[test]
fn test_nicv_helper() {
let traces = array![
[77, 137, 51, 91],
[72, 61, 91, 83],
[39, 49, 52, 23],
[26, 114, 63, 45],
[30, 8, 97, 91],
[13, 68, 7, 45],
[17, 181, 60, 34],
[43, 88, 76, 78],
[0, 36, 35, 0],
[93, 191, 49, 26],
];
let classes = [1, 3, 1, 2, 3, 2, 2, 1, 3, 1];
let mut processor = NicvProcessor::new(traces.shape()[1], 256);
for (trace, class) in std::iter::zip(traces.rows(), classes.iter()) {
processor.process(trace, *class);
}
assert_eq!(
processor.nicv(),
nicv(traces.view(), 256, |i| classes[i], 2)
);
}
// NICV is a ratio of between-class variance to total variance, so every sample should
// lie in [0, 1]; the 1e-7 slack absorbs f32 rounding.
#[test]
fn test_nicv_bounds() {
let traces = array![
[10f32, 20f32, 30f32, 40f32],
[11f32, 19f32, 29f32, 41f32],
[9f32, 21f32, 31f32, 39f32],
[10.5f32, 20.5f32, 30.5f32, 40.5f32],
];
let classes = [0usize, 1usize, 0usize, 1usize];
let result = nicv(traces.view(), 2, |i| classes[i], 2);
for v in result.iter() {
assert!(*v >= 0.0 - 1e-7 && *v <= 1.0 + 1e-7);
}
}
}