use std::io;
use crate::attribute::Attribute;
use crate::context::{Context, Flag, Reset, ViterbiState};
use crate::dataset::{self, Instance, Item};
use crate::model::Model;
/// Predicts label sequences for attribute sequences using a loaded [`Model`].
///
/// The model is borrowed, not owned. `Tagger::new` fills the transition part of
/// `context` once; `Tagger::tag` then reuses it for every call.
#[derive(Debug, Clone)]
pub struct Tagger<'a> {
    /// Model providing features, attribute ids and label strings.
    model: &'a Model<'a>,
    /// Decoding workspace created with the `VITERBI | MARGINALS` flags.
    context: Context,
    /// Label count, cached from `Model::num_labels`; used as the row width of
    /// the flat score buffers.
    num_labels: u32,
}
impl<'a> Tagger<'a> {
    /// Creates a tagger for `model`.
    ///
    /// Builds a [`Context`] sized for the model's label count with the
    /// `VITERBI` and `MARGINALS` flags and resets its transition part. It then
    /// fills the transition tables from the model's transition features and
    /// calls `Context::exp_transition` on the result.
    ///
    /// # Errors
    ///
    /// Propagates lookup errors from the model. Returns an
    /// [`io::ErrorKind::InvalidData`] error if a transition feature targets a
    /// label id outside `0..num_labels`.
    pub(crate) fn new(model: &'a Model<'a>) -> io::Result<Self> {
        let num_labels = model.num_labels();
        let mut context = Context::new(Flag::VITERBI | Flag::MARGINALS, num_labels, 0);
        context.reset(Reset::TRANS);
        let mut tagger = Self {
            model,
            context,
            num_labels,
        };
        tagger.transition_score()?;
        tagger.context.exp_transition();
        Ok(tagger)
    }

    /// Decodes the most likely label sequence for `xseq`.
    ///
    /// Each element of `xseq` is one item, described by its attributes.
    /// Attribute names the model does not know are silently dropped. The
    /// result is the id sequence decoded by `Context::viterbi`, mapped to
    /// strings through `Model::to_label`.
    ///
    /// An empty `xseq` returns an empty `Vec` without consulting the model.
    ///
    /// # Errors
    ///
    /// Propagates lookup errors from the model. Returns an
    /// [`io::ErrorKind::InvalidData`] error if a state feature targets an
    /// out-of-range label, or if a decoded label id has no string in the model.
    pub fn tag<T: AsRef<[Attribute]>>(&self, xseq: &[T]) -> io::Result<Vec<&str>> {
        if xseq.is_empty() {
            return Ok(Vec::new());
        }
        let mut instance = Instance::with_capacity(xseq.len());
        for item in xseq {
            let item: Item = item
                .as_ref()
                .iter()
                .filter_map(|x| {
                    self.model
                        .to_attr_id(&x.name)
                        .map(|id| dataset::Attribute::new(id, x.value))
                })
                .collect();
            // No gold label is known at tagging time; 0 is presumably only a
            // placeholder here — TODO confirm against `Instance::push`.
            instance.push(item, 0);
        }
        let mut vstate = ViterbiState::new(self.num_labels, instance.num_items);
        self.state_score(&instance, &mut vstate)?;
        let (label_ids, _score) = self.context.viterbi(&mut vstate);
        // A missing label string means the model file is inconsistent with its
        // own label count; report it to the caller instead of panicking.
        label_ids
            .into_iter()
            .map(|id| {
                self.model.to_label(id).ok_or_else(|| {
                    io::Error::new(
                        io::ErrorKind::InvalidData,
                        format!("label id {} has no entry in the model", id),
                    )
                })
            })
            .collect()
    }

    /// Fills the context's transition tables from the model's transition features.
    ///
    /// For each source label `i`, every feature listed by `label_ref(i)` sets
    /// `trans[i][j]` and its transpose `trans_t[j][i]` to the feature weight,
    /// where `j` is the feature's target label. Both tables are flat,
    /// row-major buffers of width `num_labels`.
    fn transition_score(&mut self) -> io::Result<()> {
        let l = self.num_labels as usize;
        for i in 0..l {
            let trans = &mut self.context.trans[l * i..];
            let edge = self.model.label_ref(i as u32)?;
            for r in 0..edge.num_features as usize {
                let fid = edge.get(r)?;
                let feature = self.model.feature(fid)?;
                let j = feature.target as usize;
                // Without this check a bad target writes into another row.
                if j >= l {
                    return Err(Self::invalid_target(j, l));
                }
                trans[j] = feature.weight;
                self.context.trans_t[l * j + i] = feature.weight;
            }
        }
        Ok(())
    }

    /// Adds the state (attribute) feature scores of `instance` into `vstate.state`.
    ///
    /// For item `t`, each attribute's features add `weight * value` to the
    /// score of the feature's target label in row `t`. `vstate.state` is a
    /// flat, row-major buffer of width `num_labels`.
    fn state_score(&self, instance: &Instance, vstate: &mut ViterbiState) -> io::Result<()> {
        let l = self.num_labels as usize;
        for t in 0..instance.num_items as usize {
            let item = &instance.items[t];
            let state_slice = &mut vstate.state[l * t..];
            for attr in item {
                let attr_ref = self.model.attr_ref(attr.id)?;
                let value = attr.value;
                for r in 0..attr_ref.num_features as usize {
                    let fid = attr_ref.get(r)?;
                    let feature = self.model.feature(fid)?;
                    let j = feature.target as usize;
                    // Without this check a bad target bleeds into item t + 1.
                    if j >= l {
                        return Err(Self::invalid_target(j, l));
                    }
                    state_slice[j] += feature.weight * value;
                }
            }
        }
        Ok(())
    }

    /// Builds the error returned when a feature targets a label id outside
    /// `0..num_labels`, which indicates a malformed model.
    fn invalid_target(target: usize, num_labels: usize) -> io::Error {
        io::Error::new(
            io::ErrorKind::InvalidData,
            format!(
                "feature targets label {} but the model has only {} labels",
                target, num_labels
            ),
        )
    }
}