use std::collections::{HashMap, HashSet};
use bitvec::prelude::*;
use serde::{de::DeserializeOwned, Serialize};
use crate::{
encode::{Permute, SampleEncoder},
filter::{BuildFilter, Filter, PackedLUTFilter, PackedLUTFilterBuilder},
model::Discriminator,
sample::{Label, Sample},
};
/// Classifier that wraps a [`WisardBase`] built on [`PackedLUTFilter`]s and
/// remembers the seed used to construct its `Permute` sample encoder.
///
/// Every sample passed to `fit`, `scores` or `predict` is first run through a
/// `Permute` encoder created from `seed`, then handed to `base`.
#[derive(Clone, Debug)]
pub struct BinaryWisard<L: Label> {
/// Underlying per-label model that does the actual training and scoring.
base: WisardBase<L, PackedLUTFilter>,
/// Seed handed to `Permute::with_seed` each time a sample is encoded.
seed: [u8; 32],
}
impl<L: Label> BinaryWisard<L> {
    /// Creates a model with a seed obtained from `rand::random()`.
    ///
    /// All other arguments are forwarded unchanged to [`Self::with_seed`].
    pub fn new(
        input_size: usize,
        addr_size: usize,
        labels: HashSet<L>,
    ) -> Self {
        let seed: [u8; 32] = rand::random();
        Self::with_seed(input_size, addr_size, labels, seed)
    }

    /// Creates a model that uses the given `seed` for its `Permute` encoder.
    ///
    /// The underlying [`WisardBase`] is populated from a
    /// `PackedLUTFilterBuilder::new(addr_size, 1, 0)`, with one entry per
    /// element of `labels`.
    pub fn with_seed(
        input_size: usize,
        addr_size: usize,
        labels: HashSet<L>,
        seed: [u8; 32],
    ) -> Self {
        let base = WisardBase::from_filter_builder(
            input_size,
            addr_size,
            labels,
            &PackedLUTFilterBuilder::new(addr_size, 1, 0),
        );
        Self { seed, base }
    }

    /// Returns a copy of the seed this model encodes samples with.
    pub fn seed(&self) -> [u8; 32] {
        self.seed
    }

    /// Trains the model on `sample`.
    ///
    /// A clone of the sample is passed through a `Permute` encoder seeded with
    /// this model's seed before being forwarded to the underlying model.
    pub fn fit(&mut self, sample: &Sample<L>) {
        let permutation = <Permute>::with_seed(self.seed);
        let permuted = permutation.encode(sample.clone());
        self.base.fit(&permuted)
    }

    /// Returns the `(score, label)` pairs computed by the underlying model for
    /// the encoded form of `sample`.
    pub fn scores(&self, sample: &Sample<L>) -> Vec<(usize, L)> {
        let permutation = <Permute>::with_seed(self.seed);
        let permuted = permutation.encode(sample.clone());
        self.base.scores(&permuted)
    }

    /// Returns the label the underlying model predicts for the encoded form of
    /// `sample`.
    pub fn predict(&self, sample: &Sample<L>) -> L {
        let permutation = <Permute>::with_seed(self.seed);
        let permuted = permutation.encode(sample.clone());
        self.base.predict(&permuted)
    }
}
/// Model that keeps one [`Discriminator`] per label.
///
/// `L` is the label type used as the map key and `F` is the filter type each
/// discriminator is built from.
#[derive(Clone, Debug)]
pub struct WisardBase<L, F>
where
L: Label,
F: Filter,
{
/// Discriminator associated with each known label.
disc: HashMap<L, Discriminator<F>>,
}
impl<L, F> WisardBase<L, F>
where
    L: Label,
    F: Filter,
{
    /// Builds a model holding one discriminator per element of `labels`.
    ///
    /// Each discriminator is created with
    /// `Discriminator::from_filter_builder(input_size, addr_size, builder)`,
    /// so all of them share the same sizes and filter builder.
    pub fn from_filter_builder<B>(
        input_size: usize,
        addr_size: usize,
        labels: HashSet<L>,
        builder: &B,
    ) -> Self
    where
        B: BuildFilter<Filter = F>,
    {
        let disc = labels
            .into_iter()
            .map(|label| {
                let discriminator =
                    Discriminator::from_filter_builder(input_size, addr_size, builder);
                (label, discriminator)
            })
            .collect();
        Self { disc }
    }

    /// Trains the discriminator registered for `sample`'s label on `sample`.
    ///
    /// # Panics
    ///
    /// Panics if the sample's label was not among the labels the model was
    /// built with.
    pub fn fit<T, O>(&mut self, sample: &Sample<L, T, O>)
    where
        T: BitStore + Clone + DeserializeOwned,
        T::Mem: Serialize,
        O: BitOrder + Clone,
    {
        self.disc
            .get_mut(sample.label())
            .expect("cannot fit: the sample's label has no discriminator in this model")
            .fit(sample)
    }

    /// Scores `sample` against every discriminator.
    ///
    /// Returns one `(score, label)` pair per known label, in the map's
    /// iteration order (which `HashMap` leaves unspecified).
    pub fn scores<T, O>(&self, sample: &Sample<L, T, O>) -> Vec<(usize, L)>
    where
        T: BitStore + Clone + DeserializeOwned,
        T::Mem: Serialize,
        O: BitOrder + Clone,
    {
        // Walk the entries directly instead of iterating keys and indexing the
        // map again, which would hash every label a second time.
        self.disc
            .iter()
            .map(|(label, discriminator)| (discriminator.score(sample), *label))
            .collect()
    }

    /// Returns the label whose discriminator gives `sample` the highest score.
    ///
    /// When several labels share the highest score, the one encountered last
    /// in the map's iteration order wins; since that order is unspecified,
    /// ties are effectively broken arbitrarily.
    ///
    /// # Panics
    ///
    /// Panics if the model was built with an empty label set.
    pub fn predict<T, O>(&self, sample: &Sample<L, T, O>) -> L
    where
        T: BitStore + Clone + DeserializeOwned,
        T::Mem: Serialize,
        O: BitOrder + Clone,
    {
        self.scores(sample)
            .into_iter()
            .max_by_key(|&(score, _)| score)
            .expect("cannot predict: the model has no labels")
            .1
    }
}
#[cfg(test)]
mod tests {
    use bitvec::prelude::*;
    use serde::Deserialize;

    use super::*;
    use crate::sample::Sample;

    /// Fits a two-label model on four 8-bit samples and checks that every
    /// training sample is predicted back as its own label.
    #[test]
    fn binary_wisard_hot_cold() {
        #[derive(
            Copy, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize,
        )]
        enum Label {
            Cold,
            Hot,
        }

        let (input_size, addr_size) = (8, 2);
        let mut model = BinaryWisard::new(
            input_size,
            addr_size,
            HashSet::from([Label::Cold, Label::Hot]),
        );

        // "Cold" samples have their leading bits set, "Hot" samples their
        // trailing bits.
        let training_set: Vec<_> = [
            (bitvec![1, 1, 1, 0, 0, 0, 0, 0], Label::Cold),
            (bitvec![1, 1, 1, 1, 0, 0, 0, 0], Label::Cold),
            (bitvec![0, 0, 0, 0, 1, 1, 1, 1], Label::Hot),
            (bitvec![0, 0, 0, 0, 0, 1, 1, 1], Label::Hot),
        ]
        .into_iter()
        .map(|(bits, label)| Sample::from_raw_parts(bits, addr_size, label))
        .collect();

        // Train on everything first; predictions are only checked afterwards.
        for sample in &training_set {
            model.fit(sample);
        }
        for sample in &training_set {
            assert_eq!(&model.predict(sample), sample.label());
        }
    }
}