orkhon 0.2.3

Machine Learning Inference Framework and Server Runtime
Documentation
use std::collections::HashMap;
use std::hash::Hash;
use crate::errors::*;
use std::collections::hash_map::Entry;
use std::fmt::Debug;
use tract_core::itertools::Itertools;

/// Encodes categorical values as ordinal integer codes.
///
/// Fitting collects the distinct input values, orders them with their `Ord`
/// implementation and assigns every value its zero-based position in that
/// order. `transform` maps values to codes and `inverse_transform` maps codes
/// back to values.
#[derive(Debug, Clone)]
pub struct OrdinalEncoder<I> {
    /// Category -> ordinal code, as learned by the most recent `fit`.
    seen: HashMap<I, usize>,
    /// Distinct categories in ascending order; the index of a category is its code.
    categories: Vec<I>,
}

impl<I> Default for OrdinalEncoder<I>
where
    I: Eq + Hash,
{
    /// Creates an unfitted encoder that knows no categories.
    fn default() -> Self {
        Self {
            seen: HashMap::default(),
            categories: Vec::default(),
        }
    }
}

impl<I> OrdinalEncoder<I>
where
    I: Eq + Hash + Clone + Debug + Ord,
{
    /// Creates an unfitted encoder that knows no categories.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a copy of the fitted categories in ascending order.
    ///
    /// The position of a category in the returned vector is its ordinal code.
    pub fn categories(&self) -> Vec<I> {
        self.categories.clone()
    }

    /// Learns the categories contained in `x`, discarding any previous fit.
    ///
    /// Duplicates in `x` are collapsed; the remaining distinct values are
    /// numbered `0..n` in ascending `Ord` order. The encoder is consumed and
    /// returned so calls can be chained.
    pub fn fit(mut self, x: Vec<I>) -> Self {
        self.seen.clear();
        self.categories.clear();

        let mut values = x;
        // Stable sort: among values that compare equal, the first one from the
        // input is the one that ends up stored as the category.
        values.sort();

        for value in values {
            // Codes are dense, so the next free code is the number of
            // categories recorded so far.
            let code = self.categories.len();
            if let Entry::Vacant(slot) = self.seen.entry(value) {
                // Only the first occurrence of each value is cloned.
                self.categories.push(slot.key().clone());
                slot.insert(code);
            }
        }

        self
    }

    /// Maps every value in `x` to its ordinal code.
    ///
    /// # Errors
    ///
    /// Returns `OrkhonError::Preprocessing` for the first value that was not
    /// present in the data passed to `fit`.
    pub fn transform(&self, x: Vec<I>) -> Result<Vec<usize>> {
        x.iter()
            .map(|r| {
                self.seen.get(r).copied().ok_or_else(|| {
                    OrkhonError::Preprocessing(format!(
                        "Found unknown category '{:?}' during transform",
                        r
                    ))
                })
            })
            .collect()
    }

    /// Maps every ordinal code in `y` back to the category it stands for.
    ///
    /// # Errors
    ///
    /// Returns `OrkhonError::Preprocessing` for the first code that is not
    /// smaller than the number of fitted categories.
    pub fn inverse_transform(&self, y: Vec<usize>) -> Result<Vec<I>> {
        y.iter()
            .map(|r| {
                // A code is the index of its category, so a checked index
                // lookup replaces any search.
                self.categories.get(*r).cloned().ok_or_else(|| {
                    // The wording ("during transform") is kept unchanged
                    // because existing callers may match on this message.
                    OrkhonError::Preprocessing(format!(
                        "Found unknown label '{:?}' during transform",
                        r
                    ))
                })
            })
            .collect()
    }
}

#[cfg(test)]
mod tests_preprocessing {
    use super::*;
    use serde::*;

    /// Layout of the JSON fixtures read from `testdata/`: an object with a
    /// single `data` array of integers.
    #[derive(Deserialize)]
    struct EncoderBlock {
        data: Vec<i32>,
    }

    /// Categories come out sorted and deduplicated; codes follow that order
    /// and map back to the original values.
    #[test]
    fn test_ordinal_encoder() {
        let data = vec!["Nonbinary", "Male", "Female", "Male", "Nonbinary"];
        let enc = OrdinalEncoder::new();

        let enc = enc.fit(data);
        assert_eq!(enc.categories(), vec!["Female", "Male", "Nonbinary"]);

        assert_eq!(enc.transform(vec!["Nonbinary", "Nonbinary", "Female"]).unwrap(), vec![2, 2, 0]);
        assert_eq!(enc.transform(vec!["Male", "Nonbinary", "Female"]).unwrap(), vec![1, 2, 0]);

        assert_eq!(
            enc.inverse_transform(vec![2, 0, 1]).unwrap(),
            vec!["Nonbinary", "Female", "Male"]
        );
    }

    /// Codes depend on the sort order of the values, not on input order.
    #[test]
    fn test_ordinal_encoder_pi() {
        let data = vec!["Create Opp", "Log Activity: T", "Create L", "Log Activity: C", "Log Activity: E", "Assign L"];
        let enc = OrdinalEncoder::new();

        let enc = enc.fit(data.clone());

        assert_eq!(enc.transform(data).unwrap(), vec![2, 5, 1, 3, 4, 0]);
    }

    /// Fits on the integer fixture and compares the transform output with the
    /// expected codes stored in a second fixture.
    #[test]
    fn test_ordinal_encoder_long_i32() {
        let content = std::fs::read_to_string("testdata/ordinal_encoder_data.json").unwrap();
        let v: EncoderBlock = serde_json::from_str(&content).unwrap();

        let data = v.data;
        let enc = OrdinalEncoder::<i32>::new();

        let enc = enc.fit(data.clone());
        assert_eq!(enc.categories(), vec![1, 2, 5, 7, 35, 36]);

        let tcontent = std::fs::read_to_string("testdata/ordinal_encoder_transformed.json").unwrap();
        let expected: EncoderBlock = serde_json::from_str(&tcontent).unwrap();
        assert_eq!(
            enc.transform(data).unwrap(),
            expected.data.iter().map(|e| *e as usize).collect::<Vec<usize>>()
        );
    }

    /// A value that was not part of the fitted data yields a preprocessing error.
    ///
    /// The error value is inspected directly instead of matching on the panic
    /// text of `unwrap()`, because that text is the `Debug` rendering of the
    /// error and its quote escaping is not stable across Rust releases.
    #[test]
    fn test_ordinal_encoder_unknown_category() {
        let data = vec!["Nonbinary", "Male", "Female", "Male", "Nonbinary"];
        let enc = OrdinalEncoder::new();

        let enc = enc.fit(data);
        assert_eq!(enc.categories(), vec!["Female", "Male", "Nonbinary"]);

        let err = enc.transform(vec!["CategoryNA", "Nonbinary", "Female"]).unwrap_err();
        match err {
            OrkhonError::Preprocessing(msg) => {
                assert_eq!(msg, "Found unknown category '\"CategoryNA\"' during transform")
            }
            #[allow(unreachable_patterns)] // harmless if the enum has a single variant
            other => panic!("expected a preprocessing error, got {:?}", other),
        }
    }

    /// A code outside the fitted range yields a preprocessing error naming the
    /// first offending label.
    #[test]
    fn test_ordinal_encoder_unknown_label() {
        let data = vec!["Nonbinary", "Male", "Female", "Male", "Nonbinary"];
        let enc = OrdinalEncoder::new();

        let enc = enc.fit(data);
        assert_eq!(enc.categories(), vec!["Female", "Male", "Nonbinary"]);

        let err = enc.inverse_transform(vec![4, 5, 6]).unwrap_err();
        match err {
            OrkhonError::Preprocessing(msg) => {
                assert_eq!(msg, "Found unknown label '4' during transform")
            }
            #[allow(unreachable_patterns)] // harmless if the enum has a single variant
            other => panic!("expected a preprocessing error, got {:?}", other),
        }
    }
}