use std::fs::File;
use std::hash::Hash;
use conllx::graph::Sentence;
use failure::{Fallible, ResultExt};
use crate::config::{Config, EncoderType, LabelerType};
use crate::serialization::CborRead;
use sticker::depparse::{RelativePOSEncoder, RelativePositionEncoder};
use sticker::tensorflow::{Tagger, TaggerGraph};
use sticker::Tag;
use sticker::{CategoricalEncoder, LayerEncoder, Numberer, SentVectorizer, SentenceDecoder};
/// Tagging interface that can be used behind a trait object.
///
/// Every [`Tag`] implementor gets this trait through a blanket
/// implementation, so taggers with different concrete types can be stored
/// uniformly as `Box<dyn TagRef + ...>`.
///
/// NOTE(review): presumably introduced because `Tag` itself could not (or
/// should not) be used as the boxed trait directly — confirm.
trait TagRef {
    /// Tag the given batch of sentences, which are passed mutably.
    ///
    /// Returns an error if the underlying tagger fails.
    fn tag_sentences_ref(&self, batch: &mut [&mut Sentence]) -> Fallible<()>;
}
impl<T> TagRef for T
where
T: Tag,
{
fn tag_sentences_ref(&self, sentences: &mut [&mut Sentence]) -> Fallible<()> {
self.tag_sentences(sentences)
}
}
pub struct TaggerWrapper {
inner: Box<TagRef + Send + Sync>,
}
impl TaggerWrapper {
    /// Construct a tagger from a configuration.
    ///
    /// This loads the word embeddings and reads the computation graph from
    /// `config.model.graph`. It then builds a tagger whose decoder is selected
    /// by `config.labeler.labeler_type`:
    ///
    /// * `LabelerType::Sequence(layer)` → a [`LayerEncoder`] for `layer`;
    /// * `LabelerType::Parser(EncoderType::RelativePOS)` → [`RelativePOSEncoder`];
    /// * `LabelerType::Parser(EncoderType::RelativePosition)` → [`RelativePositionEncoder`].
    ///
    /// # Errors
    ///
    /// Returns an error, with a descriptive context message, if any of these fail:
    ///
    /// * loading the embeddings;
    /// * opening or loading the computation graph;
    /// * loading the label file;
    /// * constructing the tagger from the model parameters.
    pub fn new(config: &Config) -> Fallible<Self> {
        let embeddings = config
            .embeddings
            .load_embeddings()
            .with_context(|e| format!("Cannot load embeddings: {}", e))?;
        let vectorizer = SentVectorizer::new(embeddings);

        let graph_reader = File::open(&config.model.graph).with_context(|e| {
            format!(
                "Cannot open computation graph '{}' for reading: {}",
                config.model.graph, e
            )
        })?;
        let graph = TaggerGraph::load_graph(graph_reader, &config.model)
            .with_context(|e| format!("Cannot load computation graph: {}", e))?;

        // `config` is already a `&Config`; pass it through as-is rather than
        // taking another reference to it.
        match config.labeler.labeler_type {
            LabelerType::Sequence(ref layer) => {
                Self::new_with_decoder(config, vectorizer, graph, LayerEncoder::new(layer.clone()))
            }
            LabelerType::Parser(EncoderType::RelativePOS) => {
                Self::new_with_decoder(config, vectorizer, graph, RelativePOSEncoder)
            }
            LabelerType::Parser(EncoderType::RelativePosition) => {
                Self::new_with_decoder(config, vectorizer, graph, RelativePositionEncoder)
            }
        }
    }

    /// Build a tagger for a specific decoder type and erase that type.
    ///
    /// The label numberer is loaded from the configured label file and
    /// combined with `decoder` into a [`CategoricalEncoder`]. Model weights are
    /// then loaded from `config.model.parameters` into `graph`.
    ///
    /// # Errors
    ///
    /// Returns an error if the label file cannot be loaded or the tagger
    /// cannot be constructed from the graph and parameters.
    fn new_with_decoder<D>(
        config: &Config,
        vectorizer: SentVectorizer,
        graph: TaggerGraph,
        decoder: D,
    ) -> Fallible<Self>
    where
        // The resulting tagger is stored as `Box<dyn TagRef + Send + Sync>`,
        // which carries an implicit `'static` bound, hence these constraints.
        D: 'static + Send + SentenceDecoder + Sync,
        D::Encoding: Clone + Eq + Hash + Send + Sync,
        // Labels are deserialized from the label file via CBOR.
        Numberer<D::Encoding>: CborRead,
    {
        let labels = config.labeler.load_labels().with_context(|e| {
            format!("Cannot load label file '{}': {}", config.labeler.labels, e)
        })?;
        let categorical_decoder = CategoricalEncoder::new(decoder, labels);

        let tagger = Tagger::load_weights(
            graph,
            categorical_decoder,
            vectorizer,
            &config.model.parameters,
        )
        .with_context(|e| format!("Cannot construct tagger: {}", e))?;

        Ok(TaggerWrapper {
            inner: Box::new(tagger),
        })
    }

    /// Tag a batch of sentences using the wrapped tagger.
    ///
    /// The sentences are passed mutably to the underlying tagger. Any error
    /// it reports is propagated to the caller.
    pub fn tag_sentences(&self, sentences: &mut [&mut Sentence]) -> Fallible<()> {
        self.inner.tag_sentences_ref(sentences)
    }
}