pub use self::{
binary_classifier::*, features::*, grid::*, model_train_options::*, multiclass_classifier::*,
regressor::*, stats::*,
};
use anyhow::{bail, Result};
use fnv::FnvHashMap;
use num::ToPrimitive;
use std::{convert::TryInto, io::prelude::*, path::Path};
mod binary_classifier;
mod features;
mod grid;
mod model_train_options;
mod multiclass_classifier;
mod regressor;
mod stats;
/// Every `.tangram` file starts with these eight bytes.
const MAGIC_NUMBER: &[u8] = b"tangram\0";
/// The format revision written by this version of the crate.
const CURRENT_REVISION: u32 = 0;
/// The oldest format revision this version of the crate can still read.
const MIN_SUPPORTED_REVISION: u32 = 0;
/// Deserialize a model from its on-disk byte representation.
///
/// The layout is: the 8-byte magic number (`b"tangram\0"`), a little-endian
/// `u32` revision number, then the buffalo-encoded model payload.
///
/// # Errors
///
/// Returns an error if the buffer is too short to contain the header, if the
/// magic number does not match, or if the revision number is outside the
/// range supported by this version of tangram.
pub fn from_bytes(bytes: &[u8]) -> Result<ModelReader> {
    // Guard against truncated input before slicing: the header is the magic
    // number followed by a 4-byte revision. Without this check, a short
    // buffer would panic instead of returning the friendly error below.
    if bytes.len() < MAGIC_NUMBER.len() + 4 {
        bail!("This model file is too short to contain the tangram header. Are you sure it is a .tangram file?");
    }
    let magic_number = &bytes[0..MAGIC_NUMBER.len()];
    if magic_number != MAGIC_NUMBER {
        bail!("This model did not start with the tangram magic number. Are you sure it is a .tangram file?");
    }
    let bytes = &bytes[MAGIC_NUMBER.len()..];
    // The slice is exactly 4 bytes, so the conversion cannot fail.
    let revision = u32::from_le_bytes(bytes[0..4].try_into().unwrap());
    // While CURRENT_REVISION and MIN_SUPPORTED_REVISION are both 0 these
    // comparisons are trivially decidable; the allows silence clippy until
    // the revision numbers diverge.
    #[allow(clippy::absurd_extreme_comparisons)]
    if revision > CURRENT_REVISION {
        bail!("This model has a revision number of {}, which is greater than the revision number of {} used by this version of tangram. Your model is from the future! Please update to the latest version of tangram to use it.", revision, CURRENT_REVISION);
    }
    #[allow(clippy::absurd_extreme_comparisons)]
    if revision < MIN_SUPPORTED_REVISION {
        bail!("This model has a revision number of {}, which is lower than the minimum supported revision number of {} for this version of tangram. Please downgrade to an earlier version of tangram to use it.", revision, MIN_SUPPORTED_REVISION);
    }
    let bytes = &bytes[4..];
    let model = buffalo::read::<ModelReader>(bytes);
    Ok(model)
}
/// Write a model's serialized payload to `path`, prefixed with the tangram
/// header (magic number followed by the current revision in little-endian).
///
/// # Errors
///
/// Returns an error if the file cannot be created or if any write fails.
pub fn to_path(path: &Path, bytes: &[u8]) -> Result<()> {
    let mut file = std::fs::File::create(path)?;
    let revision = CURRENT_REVISION.to_le_bytes();
    // Header first, then the payload.
    for chunk in &[MAGIC_NUMBER, &revision[..], bytes] {
        file.write_all(chunk)?;
    }
    Ok(())
}
/// The top-level serialized representation of a trained tangram model.
#[derive(buffalo::Read, buffalo::Write)]
#[buffalo(size = "dynamic")]
pub struct Model {
// A unique identifier for this model.
#[buffalo(id = 0, required)]
pub id: String,
// NOTE(review): presumably the tangram version that produced the model — confirm against the writer.
#[buffalo(id = 1, required)]
pub version: String,
// NOTE(review): presumably the training timestamp; format not established here — confirm.
#[buffalo(id = 2, required)]
pub date: String,
// The task-specific model data (regression or classification).
#[buffalo(id = 3, required)]
pub inner: ModelInner,
}
/// The task-specific payload of a [`Model`], one variant per task type.
#[derive(buffalo::Read, buffalo::Write)]
#[buffalo(size = "static", value_size = 8)]
#[allow(clippy::large_enum_variant)]
pub enum ModelInner {
// A model that predicts continuous values.
#[buffalo(id = 0)]
Regressor(Regressor),
// A model that predicts one of two classes.
#[buffalo(id = 1)]
BinaryClassifier(BinaryClassifier),
// A model that predicts one of three or more classes.
#[buffalo(id = 2)]
MulticlassClassifier(MulticlassClassifier),
}
impl<'a> ColumnStatsReader<'a> {
pub fn column_name(&self) -> &str {
match &self {
ColumnStatsReader::UnknownColumn(c) => c.read().column_name(),
ColumnStatsReader::NumberColumn(c) => c.read().column_name(),
ColumnStatsReader::EnumColumn(c) => c.read().column_name(),
ColumnStatsReader::TextColumn(c) => c.read().column_name(),
}
}
}
impl<'a> std::fmt::Display for NGramReader<'a> {
    /// Format an n-gram as its tokens separated by single spaces.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            NGramReader::Unigram(reader) => write!(f, "{}", reader.read()),
            NGramReader::Bigram(reader) => {
                let pair = reader.read();
                write!(f, "{} {}", pair.0, pair.1)
            }
        }
    }
}
impl<'a> From<TokenizerReader<'a>> for tangram_text::Tokenizer {
    /// Build a tokenizer configuration from its serialized representation.
    fn from(reader: TokenizerReader<'a>) -> Self {
        let lowercase = reader.lowercase();
        let alphanumeric = reader.alphanumeric();
        tangram_text::Tokenizer {
            lowercase,
            alphanumeric,
        }
    }
}
impl<'a> From<NGramReader<'a>> for tangram_text::NGramRef<'a> {
    /// Convert a serialized n-gram into its in-memory borrowed form.
    fn from(reader: NGramReader<'a>) -> Self {
        match reader {
            NGramReader::Unigram(unigram) => {
                tangram_text::NGramRef::Unigram((*unigram.read()).into())
            }
            NGramReader::Bigram(bigram) => {
                let pair = bigram.read();
                tangram_text::NGramRef::Bigram(pair.0.into(), pair.1.into())
            }
        }
    }
}
impl<'a> From<WordEmbeddingModelReader<'a>> for tangram_text::WordEmbeddingModel {
    /// Materialize a serialized word embedding model into owned data.
    fn from(reader: WordEmbeddingModelReader<'a>) -> Self {
        // Rebuild the word -> index table as an owned map.
        let words: FnvHashMap<_, _> = reader
            .words()
            .iter()
            .map(|(word, index)| (word.to_owned(), index.to_usize().unwrap()))
            .collect();
        tangram_text::WordEmbeddingModel {
            size: reader.size().to_usize().unwrap(),
            words,
            values: reader.values().iter().collect(),
        }
    }
}