use std::hash::Hash;
use std::marker::PhantomData;
use numberer::Numberer;
use serde_derive::{Deserialize, Serialize};
use udgraph::graph::Sentence;
use crate::categorical::{ImmutableNumberer, MutableNumberer, Number};
use crate::{EncodingProb, SentenceDecoder, SentenceEncoder};
/// A [`CategoricalEncoder`] whose numberer is wrapped in an `ImmutableNumberer`.
///
/// NOTE(review): presumably the label inventory is frozen, so labels not seen
/// before are not assigned new numbers during encoding — confirm against
/// `ImmutableNumberer`.
pub type ImmutableCategoricalEncoder<E, V> = CategoricalEncoder<E, V, ImmutableNumberer<V>>;
/// A [`CategoricalEncoder`] whose numberer is wrapped in a `MutableNumberer`.
///
/// NOTE(review): presumably previously unseen labels are assigned new numbers
/// while encoding (the round-trip test relies on this) — confirm against
/// `MutableNumberer`.
pub type MutableCategoricalEncoder<E, V> = CategoricalEncoder<E, V, MutableNumberer<V>>;
/// Wraps an inner encoder/decoder and maps its labels to numbers.
///
/// When used as a [`SentenceEncoder`], the inner encoder's labels are replaced
/// by their numbers from `numberer`. When used as a [`SentenceDecoder`],
/// numbers are mapped back to labels before they are handed to the inner
/// decoder.
#[derive(Deserialize, Serialize)]
pub struct CategoricalEncoder<E, V, M>
where
V: Clone + Eq + Hash,
M: Number<V>,
{
/// The wrapped encoder and/or decoder that produces/consumes labels of type `V`.
inner: E,
/// Bidirectional mapping between labels of type `V` and numbers.
numberer: M,
// `V` is not used by any data field, so a marker is needed to make the type
// parameter legal. It carries no data and is skipped by serde.
#[serde(skip)]
_phantom: PhantomData<V>,
}
impl<E, V, M> CategoricalEncoder<E, V, M>
where
    V: Clone + Eq + Hash,
    M: Number<V>,
{
    /// Create a categorical encoder around `encoder`.
    ///
    /// The plain `numberer` is converted into the numberer flavour `M`
    /// through `Number::new`.
    pub fn new(encoder: E, numberer: Numberer<V>) -> Self {
        let numberer = M::new(numberer);

        Self {
            inner: encoder,
            numberer,
            _phantom: PhantomData,
        }
    }
}
impl<D, M> CategoricalEncoder<D, D::Encoding, M>
where
    D: SentenceDecoder,
    // `ToOwned` is not required: every `Clone` type is `ToOwned` through the
    // standard library's blanket impl.
    D::Encoding: Clone + Eq + Hash,
    M: Number<D::Encoding>,
{
    /// Map numbered label distributions back to labels of the inner decoder.
    ///
    /// `labels` holds one distribution per token. Each distribution is a
    /// slice of numeric labels with probabilities. Every number is looked up
    /// in the numberer, and its probability is carried over unchanged. The
    /// inner decoder is *not* invoked.
    ///
    /// # Panics
    ///
    /// Panics if a number has no corresponding label in the numberer. The
    /// panic message includes the offending number.
    pub fn decode_without_inner<S>(&self, labels: &[S]) -> Vec<Vec<EncodingProb<D::Encoding>>>
    where
        S: AsRef<[EncodingProb<usize>]>,
    {
        labels
            .iter()
            .map(|token_probs| {
                token_probs
                    .as_ref()
                    .iter()
                    .map(|encoding_prob| {
                        let number = *encoding_prob.encoding();
                        // Build the panic message lazily so the success path
                        // does no formatting work.
                        let label = self
                            .numberer
                            .value(number)
                            .unwrap_or_else(|| panic!("Unknown label: {}", number));
                        EncodingProb::new(label, encoding_prob.prob())
                    })
                    .collect::<Vec<_>>()
            })
            .collect::<Vec<_>>()
    }
}
impl<E, V, M> CategoricalEncoder<E, V, M>
where
    V: Clone + Eq + Hash,
    M: Number<V>,
{
    /// Returns `true` when [`len`](Self::len) is zero.
    pub fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }

    /// Returns the length reported by the underlying numberer.
    ///
    /// Note that a numberer created with an offset may already report a
    /// non-zero length before any label is added (the tests expect `1` for
    /// `Numberer::new(1)`).
    pub fn len(&self) -> usize {
        let numberer = &self.numberer;
        numberer.len()
    }
}
impl<E, M> SentenceEncoder for CategoricalEncoder<E, E::Encoding, M>
where
    E: SentenceEncoder,
    E::Encoding: Clone + Eq + Hash,
    M: Number<E::Encoding>,
{
    type Encoding = usize;

    type Error = E::Error;

    /// Encode `sentence` with the inner encoder, then number every label.
    ///
    /// If the numberer does not provide a number for a label, `0` is used.
    /// NOTE(review): `0` is presumably the slot reserved when the `Numberer`
    /// is created with an offset of 1 — confirm.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the inner encoder.
    fn encode(&self, sentence: &Sentence) -> Result<Vec<Self::Encoding>, Self::Error> {
        let labels = self.inner.encode(sentence)?;

        Ok(labels
            .into_iter()
            .map(|label| self.numberer.number(label).unwrap_or(0))
            .collect())
    }
}
impl<D, M> SentenceDecoder for CategoricalEncoder<D, D::Encoding, M>
where
    D: SentenceDecoder,
    D::Encoding: Clone + Eq + Hash,
    M: Number<D::Encoding>,
{
    type Encoding = usize;

    type Error = D::Error;

    /// Translate numbered label distributions back to the inner label type and
    /// let the inner decoder update `sentence`.
    ///
    /// # Panics
    ///
    /// Panics if a number is unknown to the numberer.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the inner decoder.
    fn decode<S>(&self, labels: &[S], sentence: &mut Sentence) -> Result<(), Self::Error>
    where
        S: AsRef<[EncodingProb<Self::Encoding>]>,
    {
        let inner_labels = self.decode_without_inner(labels);
        self.inner.decode(&inner_labels, sentence)
    }
}
#[cfg(test)]
mod tests {
    use std::fs::File;
    use std::io::BufReader;
    use std::path::Path;

    use conllu::io::Reader;
    use numberer::Numberer;

    use super::{EncodingProb, MutableCategoricalEncoder, SentenceDecoder, SentenceEncoder};
    use crate::layer::Layer;
    use crate::layer::LayerEncoder;

    static NON_PROJECTIVE_DATA: &str = "testdata/lassy-small-dev.conllu";

    /// Round-trip every sentence in `path` through `encoder_decoder`.
    ///
    /// Each encoded label is decoded with probability 1, and the result must
    /// equal the original sentence.
    fn test_encoding<P, E, C>(path: P, encoder_decoder: E)
    where
        P: AsRef<Path>,
        E: SentenceEncoder<Encoding = C> + SentenceDecoder<Encoding = C>,
        C: 'static + Clone,
    {
        let reader = Reader::new(BufReader::new(File::open(path).unwrap()));

        for gold in reader.into_iter().map(Result::unwrap) {
            let labels: Vec<_> = encoder_decoder
                .encode(&gold)
                .unwrap()
                .into_iter()
                .map(|label| [EncodingProb::new(label, 1.)])
                .collect();

            let mut decoded = gold.clone();
            encoder_decoder.decode(&labels, &mut decoded).unwrap();

            assert_eq!(gold, decoded);
        }
    }

    #[test]
    fn categorical_encoder() {
        let categorical_encoder =
            MutableCategoricalEncoder::new(LayerEncoder::new(Layer::XPos), Numberer::new(1));
        assert_eq!(categorical_encoder.len(), 1);

        test_encoding(NON_PROJECTIVE_DATA, categorical_encoder);
    }
}