use std::collections::BTreeMap;
use std::io;
use super::dictionary::Dictionary;
use crate::dataset::Instance;
/// Kind of a CRF feature.
///
/// The discriminants are fixed explicitly, so `FeatureType::State as u32 == 0` and
/// `FeatureType::Transition as u32 == 1`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FeatureType {
    /// Links an item attribute to the label at the same position
    /// (`src` = attribute id, `dst` = label id).
    State = 0,
    /// Links the label at the previous position to the current label
    /// (`src` = previous label id, `dst` = current label id).
    Transition = 1,
}
/// A single CRF feature together with its current model weight.
#[derive(Clone, Debug)]
pub struct Feature {
    /// Whether this is a state or a transition feature.
    pub ftype: FeatureType,
    /// Source id: an attribute id for [`FeatureType::State`], the previous label id for
    /// [`FeatureType::Transition`].
    pub src: u32,
    /// Destination id: the label id the feature is associated with.
    pub dst: u32,
    /// Model weight; features are generated with `0.0` and overwritten by
    /// [`FeatureGenerator::set_weights`].
    pub weight: f64,
}
/// List of feature ids that share a common source (an attribute or a label).
#[derive(Clone, Debug, Default)]
pub struct FeatureRefs {
    /// Feature ids (indices into [`FeatureGenerator::features`]) in the order the features
    /// were generated.
    pub fids: Vec<u32>,
}
/// Features observed in a training set, plus lookup tables from feature sources to feature ids.
pub struct FeatureGenerator {
    /// Every generated feature; a feature's index in this vector is its feature id.
    pub features: Vec<Feature>,
    /// `attr_refs[a]` holds the ids of the state features whose source is attribute `a`.
    pub attr_refs: Vec<FeatureRefs>,
    /// `label_refs[l]` holds the ids of the transition features whose source is label `l`.
    pub label_refs: Vec<FeatureRefs>,
}
impl FeatureGenerator {
pub fn generate(
instances: &[Instance],
attrs: &Dictionary,
labels: &Dictionary,
min_freq: f64,
) -> io::Result<Self> {
let num_labels = labels.len();
let num_attrs = attrs.len();
let mut state_counts: BTreeMap<(u32, u32, u32), f64> = BTreeMap::new();
let mut trans_counts: BTreeMap<(u32, u32, u32), f64> = BTreeMap::new();
for inst in instances {
let seq_len = inst.num_items as usize;
let inst_weight = inst.weight;
for t in 0..seq_len {
let label = inst.labels[t];
for attr in &inst.items[t] {
let key = (FeatureType::State as u32, attr.id, label);
*state_counts.entry(key).or_insert(0.0) += attr.value * inst_weight;
}
}
for t in 1..seq_len {
let prev_label = inst.labels[t - 1];
let label = inst.labels[t];
let key = (FeatureType::Transition as u32, prev_label, label);
*trans_counts.entry(key).or_insert(0.0) += inst_weight;
}
}
let mut features = Vec::new();
let mut attr_refs = vec![FeatureRefs::default(); num_attrs];
let mut label_refs = vec![FeatureRefs::default(); num_labels];
for ((_, aid, lid), freq) in state_counts {
if freq >= min_freq {
let fid = u32::try_from(features.len())
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "too many features"))?;
features.push(Feature {
ftype: FeatureType::State,
src: aid,
dst: lid,
weight: 0.0,
});
attr_refs[aid as usize].fids.push(fid);
}
}
for ((_, prev_lid, lid), freq) in trans_counts {
if freq >= min_freq {
let fid = u32::try_from(features.len())
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "too many features"))?;
features.push(Feature {
ftype: FeatureType::Transition,
src: prev_lid,
dst: lid,
weight: 0.0,
});
label_refs[prev_lid as usize].fids.push(fid);
}
}
Ok(Self {
features,
attr_refs,
label_refs,
})
}
pub fn num_features(&self) -> usize {
self.features.len()
}
pub fn set_weights(&mut self, weights: &[f64]) {
assert_eq!(
weights.len(),
self.features.len(),
"weights length ({}) must equal number of features ({})",
weights.len(),
self.features.len()
);
for (i, feature) in self.features.iter_mut().enumerate() {
feature.weight = weights[i];
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dataset::Attribute;

    /// Id of the feature `(ftype, src, dst)` in `fgen.features`, if it was generated.
    fn find_feature(
        fgen: &FeatureGenerator,
        ftype: FeatureType,
        src: u32,
        dst: u32,
    ) -> Option<u32> {
        fgen.features
            .iter()
            .position(|f| f.ftype == ftype && f.src == src && f.dst == dst)
            .map(|i| u32::try_from(i).unwrap())
    }

    /// Returns `ids` sorted, so reference lists can be compared without fixing id values.
    fn sorted(mut ids: Vec<u32>) -> Vec<u32> {
        ids.sort_unstable();
        ids
    }

    /// Generates features (with `min_freq = 0.0`) for one instance:
    /// walk/sunny, shop/sunny, walk/rainy.
    fn weather_generator() -> FeatureGenerator {
        let mut attrs = Dictionary::new();
        let mut labels = Dictionary::new();
        let walk_id = attrs.get_or_insert("walk");
        let shop_id = attrs.get_or_insert("shop");
        let sunny_id = labels.get_or_insert("sunny");
        let rainy_id = labels.get_or_insert("rainy");
        let mut inst = Instance::with_capacity(3);
        inst.push(vec![Attribute::new(walk_id, 1.0)], sunny_id);
        inst.push(vec![Attribute::new(shop_id, 1.0)], sunny_id);
        inst.push(vec![Attribute::new(walk_id, 1.0)], rainy_id);
        FeatureGenerator::generate(&[inst], &attrs, &labels, 0.0).unwrap()
    }

    #[test]
    fn test_feature_generation() {
        let mut attrs = Dictionary::new();
        let mut labels = Dictionary::new();
        let walk_id = attrs.get_or_insert("walk");
        let shop_id = attrs.get_or_insert("shop");
        let sunny_id = labels.get_or_insert("sunny");
        let rainy_id = labels.get_or_insert("rainy");
        let mut inst = Instance::with_capacity(3);
        inst.push(vec![Attribute::new(walk_id, 1.0)], sunny_id);
        inst.push(vec![Attribute::new(shop_id, 1.0)], sunny_id);
        inst.push(vec![Attribute::new(walk_id, 1.0)], rainy_id);
        let instances = vec![inst];
        let fgen = FeatureGenerator::generate(&instances, &attrs, &labels, 0.0).unwrap();

        // Three distinct (attribute, label) pairs and two distinct label bigrams;
        // state features are emitted before transition features.
        assert_eq!(fgen.num_features(), 5);
        assert!(fgen.features[..3]
            .iter()
            .all(|f| f.ftype == FeatureType::State));
        assert!(fgen.features[3..]
            .iter()
            .all(|f| f.ftype == FeatureType::Transition));
        assert!(fgen.features.iter().all(|f| f.weight == 0.0));
        assert_eq!(fgen.attr_refs.len(), attrs.len());
        assert_eq!(fgen.label_refs.len(), labels.len());

        let state = |aid: u32, lid: u32| {
            find_feature(&fgen, FeatureType::State, aid, lid).expect("state feature missing")
        };
        let trans = |prev: u32, lid: u32| {
            find_feature(&fgen, FeatureType::Transition, prev, lid)
                .expect("transition feature missing")
        };

        assert_eq!(
            sorted(fgen.attr_refs[walk_id as usize].fids.clone()),
            sorted(vec![state(walk_id, sunny_id), state(walk_id, rainy_id)])
        );
        assert_eq!(
            fgen.attr_refs[shop_id as usize].fids,
            vec![state(shop_id, sunny_id)]
        );
        assert_eq!(
            sorted(fgen.label_refs[sunny_id as usize].fids.clone()),
            sorted(vec![trans(sunny_id, sunny_id), trans(sunny_id, rainy_id)])
        );
        // "rainy" is never followed by another label.
        assert!(fgen.label_refs[rainy_id as usize].fids.is_empty());
    }

    #[test]
    fn test_min_freq_filters_rare_features() {
        let mut attrs = Dictionary::new();
        let mut labels = Dictionary::new();
        let walk_id = attrs.get_or_insert("walk");
        let shop_id = attrs.get_or_insert("shop");
        let sunny_id = labels.get_or_insert("sunny");
        let rainy_id = labels.get_or_insert("rainy");

        let mut first = Instance::with_capacity(2);
        first.push(vec![Attribute::new(walk_id, 1.0)], sunny_id);
        first.push(vec![Attribute::new(shop_id, 1.0)], rainy_id);
        let mut second = Instance::with_capacity(2);
        second.push(vec![Attribute::new(walk_id, 1.0)], sunny_id);
        second.push(vec![Attribute::new(walk_id, 1.0)], rainy_id);
        // Set weights explicitly so the expected frequencies do not depend on the default.
        first.weight = 1.0;
        second.weight = 1.0;
        let instances = vec![first, second];

        // Only (walk, sunny) and (sunny -> rainy) are seen twice.
        let fgen = FeatureGenerator::generate(&instances, &attrs, &labels, 2.0).unwrap();
        assert_eq!(fgen.num_features(), 2);

        let state = &fgen.features[0];
        assert_eq!(state.ftype, FeatureType::State);
        assert_eq!((state.src, state.dst), (walk_id, sunny_id));
        let trans = &fgen.features[1];
        assert_eq!(trans.ftype, FeatureType::Transition);
        assert_eq!((trans.src, trans.dst), (sunny_id, rainy_id));

        assert_eq!(fgen.attr_refs[walk_id as usize].fids, vec![0]);
        assert!(fgen.attr_refs[shop_id as usize].fids.is_empty());
        assert_eq!(fgen.label_refs[sunny_id as usize].fids, vec![1]);
        assert!(fgen.label_refs[rainy_id as usize].fids.is_empty());
    }

    #[test]
    fn test_no_instances_yields_no_features() {
        let mut attrs = Dictionary::new();
        let mut labels = Dictionary::new();
        attrs.get_or_insert("walk");
        labels.get_or_insert("sunny");
        let fgen = FeatureGenerator::generate(&[], &attrs, &labels, 0.0).unwrap();
        assert_eq!(fgen.num_features(), 0);
        assert_eq!(fgen.attr_refs.len(), attrs.len());
        assert_eq!(fgen.label_refs.len(), labels.len());
        assert!(fgen.attr_refs.iter().all(|r| r.fids.is_empty()));
        assert!(fgen.label_refs.iter().all(|r| r.fids.is_empty()));
    }

    #[test]
    fn test_set_weights_assigns_in_feature_order() {
        let mut fgen = weather_generator();
        let weights: Vec<f64> = (0..fgen.num_features()).map(|i| i as f64 + 0.5).collect();
        fgen.set_weights(&weights);
        for (feature, &expected) in fgen.features.iter().zip(&weights) {
            assert_eq!(feature.weight, expected);
        }
    }

    #[test]
    #[should_panic(expected = "must equal number of features")]
    fn test_set_weights_rejects_length_mismatch() {
        let mut fgen = weather_generator();
        fgen.set_weights(&[1.0]);
    }
}