use ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use num_traits::AsPrimitive;
use rayon::iter::{ParallelBridge, ParallelIterator};
use serde::{Deserialize, Serialize};
use std::{fs::File, iter::zip, marker::PhantomData, ops::Add, path::Path};
use crate::{
Error, Sample,
util::{argmax_by, argsort_by, max_per_row},
};
/// Runs a difference-of-means DPA over `traces`, processing them in batches in parallel.
///
/// `traces` and `metadata` are cut into chunks of at most `batch_size` entries along their
/// first axis and paired up chunk by chunk. The chunk pairs are bridged onto the rayon
/// thread pool; every fold accumulator is a [`DpaProcessor`], and the partial accumulators
/// are added together before the result is finalized.
///
/// # Arguments
///
/// * `traces` - One trace per row; the number of columns is used as the number of samples.
/// * `metadata` - One metadata item per trace, handed to `selection_function`.
/// * `guess_range` - Guesses `0..guess_range` are evaluated.
/// * `selection_function` - Called with a metadata item and a guess; its boolean result
///   decides into which of the two groups the trace is accumulated for that guess.
/// * `batch_size` - Maximum number of traces per chunk.
///
/// # Panics
///
/// Panics if `batch_size` is zero, or if the parallel reduction produces no accumulator.
pub fn dpa<T, M, F>(
    traces: ArrayView2<T>,
    metadata: ArrayView1<M>,
    guess_range: usize,
    selection_function: F,
    batch_size: usize,
) -> Dpa
where
    T: Sample + Copy + Sync,
    <T as Sample>::Container: Send,
    M: Clone + Send + Sync,
    F: Fn(M, usize) -> bool + Send + Sync + Copy,
{
    assert!(batch_size > 0);

    let num_samples = traces.shape()[1];
    let trace_batches = traces.axis_chunks_iter(Axis(0), batch_size);
    let metadata_batches = metadata.axis_chunks_iter(Axis(0), batch_size);

    let merged = zip(trace_batches, metadata_batches)
        .par_bridge()
        .fold(
            // Every rayon split starts from its own empty accumulator.
            || DpaProcessor::new(num_samples, guess_range),
            |mut partial, (trace_batch, metadata_batch)| {
                partial.batch_update(trace_batch, metadata_batch, &selection_function);
                partial
            },
        )
        .reduce_with(|lhs, rhs| lhs + rhs)
        .unwrap();

    merged.finalize()
}
/// Result of a DPA computation, holding one differential curve per guess.
#[derive(Debug)]
pub struct Dpa {
/// Absolute difference between the mean traces of the two selection groups. As built by
/// `DpaProcessor::finalize`, there is one row per guess and one column per sample.
differential_curves: Array2<f32>,
}
impl Dpa {
    /// Returns the guess indices in the order produced by `argsort_by` when the per-guess
    /// maxima of the differential curves are compared with [`f32::total_cmp`].
    pub fn rank(&self) -> Array1<usize> {
        let maxima = self.max_differential_curves().to_vec();
        let order = argsort_by(&maxima[..], f32::total_cmp);
        Array1::from_vec(order)
    }

    /// Returns a read-only view of the differential curves (one row per guess).
    pub fn differential_curves(&self) -> ArrayView2<'_, f32> {
        self.differential_curves.view()
    }

    /// Returns the guess selected by `argmax_by` over the per-guess maxima of the
    /// differential curves, compared with [`f32::total_cmp`].
    pub fn best_guess(&self) -> usize {
        let maxima = self.max_differential_curves();
        argmax_by(maxima.view(), f32::total_cmp)
    }

    /// Reduces every differential curve to a single value with `max_per_row`, giving one
    /// entry per guess.
    pub fn max_differential_curves(&self) -> Array1<f32> {
        let curves = self.differential_curves();
        max_per_row(curves)
    }
}
/// Incremental accumulator for a difference-of-means DPA.
///
/// For every guess, traces are summed into one of two groups depending on the result of the
/// selection function, and the number of traces in each group is counted. The processor can
/// be serialized and deserialized with serde, provided the container type of `T` can.
#[derive(Serialize, Deserialize)]
pub struct DpaProcessor<T, M>
where
T: Sample,
{
/// Number of samples per trace; second dimension of `sum_0` and `sum_1` as allocated by `new`.
num_samples: usize,
/// Number of guesses; guesses `0..guess_range` are evaluated on every update.
guess_range: usize,
/// Per-guess sum of the traces for which the selection function returned `false`.
#[serde(bound(serialize = "<T as Sample>::Container: Serialize"))]
#[serde(bound(deserialize = "<T as Sample>::Container: Deserialize<'de>"))]
sum_0: Array2<<T as Sample>::Container>,
/// Per-guess sum of the traces for which the selection function returned `true`.
#[serde(bound(serialize = "<T as Sample>::Container: Serialize"))]
#[serde(bound(deserialize = "<T as Sample>::Container: Deserialize<'de>"))]
sum_1: Array2<<T as Sample>::Container>,
/// Per-guess number of traces accumulated into `sum_0`.
count_0: Array1<usize>,
/// Per-guess number of traces accumulated into `sum_1`.
count_1: Array1<usize>,
/// Total number of traces passed to `update`.
num_traces: usize,
/// Ties the processor to the metadata type `M` without storing a value of it.
_metadata: PhantomData<M>,
}
impl<T, M> DpaProcessor<T, M>
where
    T: Sample + Copy,
    M: Clone,
{
    /// Creates an empty processor for traces of `num_samples` samples and guesses in
    /// `0..guess_range`. All sums and counts start at zero.
    pub fn new(num_samples: usize, guess_range: usize) -> Self {
        let accumulator_shape = (guess_range, num_samples);
        Self {
            num_samples,
            guess_range,
            sum_0: Array2::zeros(accumulator_shape),
            sum_1: Array2::zeros(accumulator_shape),
            count_0: Array1::zeros(guess_range),
            count_1: Array1::zeros(guess_range),
            num_traces: 0,
            _metadata: PhantomData,
        }
    }

    /// Accumulates a single trace.
    ///
    /// For every guess, `selection_function(metadata, guess)` decides whether the trace is
    /// added to the `sum_1`/`count_1` group (`true`) or to the `sum_0`/`count_0` group
    /// (`false`).
    ///
    /// In debug builds, a trace whose length differs from `num_samples` triggers an
    /// assertion failure.
    pub fn update<F>(&mut self, trace: ArrayView1<T>, metadata: M, selection_function: F)
    where
        F: Fn(M, usize) -> bool,
    {
        debug_assert_eq!(trace.len(), self.num_samples);

        for guess in 0..self.guess_range {
            let selected = selection_function(metadata.clone(), guess);
            // Pick the accumulator pair once, then run a single summation loop.
            let (sums, counts) = if selected {
                (&mut self.sum_1, &mut self.count_1)
            } else {
                (&mut self.sum_0, &mut self.count_0)
            };

            for sample in 0..self.num_samples {
                sums[[guess, sample]] += trace[sample].into();
            }
            counts[guess] += 1;
        }

        self.num_traces += 1;
    }

    /// Accumulates a batch of traces by calling [`Self::update`] on every row of
    /// `trace_batch` paired with the matching entry of `metadata_batch`.
    ///
    /// In debug builds, batches of different lengths trigger an assertion failure; otherwise
    /// the pairing stops at the shorter of the two.
    pub fn batch_update<F>(
        &mut self,
        trace_batch: ArrayView2<T>,
        metadata_batch: ArrayView1<M>,
        selection_function: &F,
    ) where
        F: Fn(M, usize) -> bool,
    {
        debug_assert_eq!(trace_batch.shape()[0], metadata_batch.shape()[0]);

        trace_batch
            .rows()
            .into_iter()
            .zip(metadata_batch.iter())
            .for_each(|(trace, metadata)| {
                self.update(trace, metadata.clone(), selection_function)
            });
    }

    /// Computes the differential curves from the accumulated sums and counts.
    ///
    /// For every guess and sample, the value is the absolute difference between the mean of
    /// group 0 and the mean of group 1. If a group of some guess received no trace, the
    /// division by its zero count yields a non-finite value for that guess.
    pub fn finalize(&self) -> Dpa {
        let mut differential_curves = Array2::zeros((self.guess_range, self.num_samples));

        for guess in 0..self.guess_range {
            for sample in 0..self.num_samples {
                let cell = [guess, sample];
                let mean_unselected = self.sum_0[cell].as_() / self.count_0[guess] as f32;
                let mean_selected = self.sum_1[cell].as_() / self.count_1[guess] as f32;

                differential_curves[cell] = f32::abs(mean_unselected - mean_selected);
            }
        }

        Dpa {
            differential_curves,
        }
    }

    /// Returns `true` when both processors were created for the same number of samples and
    /// the same guess range, which is what `add` expects of its operands.
    fn is_compatible_with(&self, other: &Self) -> bool {
        (self.num_samples, self.guess_range) == (other.num_samples, other.guess_range)
    }
}
impl<T, M> DpaProcessor<T, M>
where
    T: Sample,
    <T as Sample>::Container: Serialize,
{
    /// Serializes the processor as JSON into the file at `path`.
    ///
    /// The file is opened with [`File::create`], so an existing file is overwritten.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be created, if serialization fails, or if the
    /// buffered output cannot be flushed to the file.
    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), Error> {
        // The serializer issues many small writes; buffer them instead of forwarding each
        // one to the file individually.
        let mut writer = std::io::BufWriter::new(File::create(path)?);
        serde_json::to_writer(&mut writer, self)?;
        // Flush explicitly: an error raised while dropping a `BufWriter` is silently lost.
        std::io::Write::flush(&mut writer)?;
        Ok(())
    }
}
impl<T, M> DpaProcessor<T, M>
where
    T: Sample,
    <T as Sample>::Container: for<'de> Deserialize<'de>,
{
    /// Reads a processor from the JSON file at `path`.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be opened or if its content cannot be
    /// deserialized into a processor.
    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, Error> {
        // `serde_json::from_reader` does not buffer its input, so wrap the file to avoid
        // many short reads.
        let reader = std::io::BufReader::new(File::open(path)?);
        let processor: DpaProcessor<T, M> = serde_json::from_reader(reader)?;
        Ok(processor)
    }
}
impl<T, M> Add for DpaProcessor<T, M>
where
T: Sample + Copy,
M: Clone,
{
type Output = Self;
fn add(self, rhs: Self) -> Self::Output {
debug_assert!(self.is_compatible_with(&rhs));
Self {
num_samples: self.num_samples,
guess_range: self.guess_range,
sum_0: self.sum_0 + rhs.sum_0,
sum_1: self.sum_1 + rhs.sum_1,
count_0: self.count_0 + rhs.count_0,
count_1: self.count_1 + rhs.count_1,
num_traces: self.num_traces + rhs.num_traces,
_metadata: PhantomData,
}
}
}
#[cfg(test)]
mod tests {
    use super::{DpaProcessor, dpa};
    use ndarray::{Array1, Array2, ArrayView1, array};
    use serde::Deserialize;

    /// Number of guesses evaluated by the tests.
    const GUESS_RANGE: usize = 256;

    /// Ten traces of four samples each, shared by both tests.
    fn fixture_traces() -> Array2<usize> {
        array![
            [77usize, 137, 51, 91],
            [72, 61, 91, 83],
            [39, 49, 52, 23],
            [26, 114, 63, 45],
            [30, 8, 97, 91],
            [13, 68, 7, 45],
            [17, 181, 60, 34],
            [43, 88, 76, 78],
            [0, 36, 35, 0],
            [93, 191, 49, 26],
        ]
    }

    /// One single-byte plaintext per fixture trace.
    fn fixture_plaintexts() -> Array2<u8> {
        array![[1u8], [3], [1], [2], [3], [2], [2], [1], [3], [1]]
    }

    /// Selects on the least significant bit of `plaintext[0] ^ guess`.
    fn lsb_selection(plaintext: ArrayView1<'_, u8>, guess: usize) -> bool {
        (plaintext[0] as usize ^ guess) & 1 == 1
    }

    /// Feeds every fixture trace, converted to `f32`, into a fresh processor one at a time.
    fn accumulate<'a>(
        traces: &Array2<usize>,
        plaintexts: &'a Array2<u8>,
    ) -> DpaProcessor<f32, ArrayView1<'a, u8>> {
        let mut processor = DpaProcessor::new(traces.shape()[1], GUESS_RANGE);
        for (trace, plaintext) in traces.rows().into_iter().zip(plaintexts.rows()) {
            processor.update(trace.map(|&x| x as f32).view(), plaintext, lsb_selection);
        }
        processor
    }

    /// The batched, parallel `dpa` helper must agree with trace-by-trace accumulation.
    #[test]
    fn test_dpa_helper() {
        let traces = fixture_traces();
        let plaintexts = fixture_plaintexts();

        let sequential = accumulate(&traces, &plaintexts).finalize();

        let float_traces = traces.map(|&x| x as f32);
        let metadata: Array1<ArrayView1<u8>> = plaintexts.rows().into_iter().collect();
        let batched = dpa(
            float_traces.view(),
            metadata.view(),
            GUESS_RANGE,
            lsb_selection,
            2,
        );

        assert_eq!(
            sequential.differential_curves(),
            batched.differential_curves()
        );
    }

    /// A processor must survive a JSON round trip with all of its state intact.
    #[test]
    fn test_serialize_deserialize_processor() {
        let traces = fixture_traces();
        let plaintexts = fixture_plaintexts();
        let processor = accumulate(&traces, &plaintexts);

        let serialized = serde_json::to_string(&processor).unwrap();
        let mut deserializer = serde_json::Deserializer::from_str(serialized.as_str());
        let restored_processor =
            DpaProcessor::<f32, ArrayView1<u8>>::deserialize(&mut deserializer).unwrap();

        assert_eq!(processor.num_samples, restored_processor.num_samples);
        assert_eq!(processor.guess_range, restored_processor.guess_range);
        assert_eq!(processor.sum_0, restored_processor.sum_0);
        assert_eq!(processor.sum_1, restored_processor.sum_1);
        assert_eq!(processor.count_0, restored_processor.count_0);
        assert_eq!(processor.count_1, restored_processor.count_1);
        assert_eq!(processor.num_traces, restored_processor.num_traces);
    }
}