/*!
This crate defines the structure of `.tangram` files using the `buffalo` crate.
*/

pub use self::{
	binary_classifier::*, features::*, grid::*, model_train_options::*, multiclass_classifier::*,
	regressor::*, stats::*,
};
use anyhow::{bail, Result};
use fnv::FnvHashMap;
use num::ToPrimitive;
use std::{convert::TryInto, io::prelude::*, path::Path};

mod binary_classifier;
mod features;
mod grid;
mod model_train_options;
mod multiclass_classifier;
mod regressor;
mod stats;

/// A .tangram file is prefixed with this magic number followed by a 4-byte little endian revision number.
const MAGIC_NUMBER: &[u8] = b"tangram\0";
/// This is the revision number that this version of tangram_model writes.
const CURRENT_REVISION: u32 = 0;
/// This is the oldest revision number that this version of tangram_model can read.
const MIN_SUPPORTED_REVISION: u32 = 0;

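/// Read a model from the bytes of a `.tangram` file, verifying the magic number and the revision
/// number before handing the remaining bytes to `buffalo`.
///
/// A minimal usage sketch; the file path below is hypothetical:
///
/// ```no_run
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let bytes = std::fs::read("heart_disease.tangram")?;
/// let model = tangram_model::from_bytes(&bytes)?;
/// # drop(model);
/// # Ok(())
/// # }
/// ```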
pub fn from_bytes(bytes: &[u8]) -> Result<ModelReader> {
	// Verify that there are enough bytes to hold the magic number and the revision number.
	if bytes.len() < MAGIC_NUMBER.len() + 4 {
		bail!("This model is too short to contain the tangram magic number and revision number. Are you sure it is a .tangram file?");
	}
	// Verify the magic number.
	let magic_number = &bytes[0..MAGIC_NUMBER.len()];
	if magic_number != MAGIC_NUMBER {
		bail!("This model did not start with the tangram magic number. Are you sure it is a .tangram file?");
	}
	let bytes = &bytes[MAGIC_NUMBER.len()..];
	// Read the 4-byte little endian revision number.
	let revision = u32::from_le_bytes(bytes[0..4].try_into().unwrap());
	#[allow(clippy::absurd_extreme_comparisons)]
	if revision > CURRENT_REVISION {
		bail!("This model has a revision number of {}, which is greater than the revision number of {} used by this version of tangram. Your model is from the future! Please update to the latest version of tangram to use it.", revision, CURRENT_REVISION);
	}
	#[allow(clippy::absurd_extreme_comparisons)]
	if revision < MIN_SUPPORTED_REVISION {
		bail!("This model has a revision number of {}, which is lower than the minumum supported revision number of {} for this version of tangram. Please downgrade to an earlier version of tangram to use it.", revision, MIN_SUPPORTED_REVISION);
	}
	let bytes = &bytes[4..];
	let model = buffalo::read::<ModelReader>(bytes);
	Ok(model)
}

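/// Write a `.tangram` file to `path`: the magic number, then the current revision number as a
/// little endian `u32`, then the buffalo-encoded model `bytes`.
///
/// A minimal usage sketch; `serialized_model` stands in for the buffalo-encoded bytes of a
/// `Model` and is hypothetical here:
///
/// ```no_run
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// # let serialized_model: Vec<u8> = Vec::new();
/// tangram_model::to_path(std::path::Path::new("model.tangram"), &serialized_model)?;
/// # Ok(())
/// # }
/// ```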
pub fn to_path(path: &Path, bytes: &[u8]) -> Result<()> {
	// Create the file.
	let mut file = std::fs::File::create(path)?;
	// Write the magic number.
	file.write_all(MAGIC_NUMBER)?;
	// Write the revision number.
	file.write_all(&CURRENT_REVISION.to_le_bytes())?;
	// Write the bytes.
	file.write_all(bytes)?;
	Ok(())
}

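/// A `Model` is the root record stored in a `.tangram` file. It carries an id, a version string,
/// a date, and the task-specific model in `inner`.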
#[derive(buffalo::Read, buffalo::Write)]
#[buffalo(size = "dynamic")]
pub struct Model {
	#[buffalo(id = 0, required)]
	pub id: String,
	#[buffalo(id = 1, required)]
	pub version: String,
	#[buffalo(id = 2, required)]
	pub date: String,
	#[buffalo(id = 3, required)]
	pub inner: ModelInner,
}

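/// The task-specific model stored in a `Model`: a regressor, a binary classifier, or a
/// multiclass classifier.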
#[derive(buffalo::Read, buffalo::Write)]
#[buffalo(size = "static", value_size = 8)]
#[allow(clippy::large_enum_variant)]
pub enum ModelInner {
	#[buffalo(id = 0)]
	Regressor(Regressor),
	#[buffalo(id = 1)]
	BinaryClassifier(BinaryClassifier),
	#[buffalo(id = 2)]
	MulticlassClassifier(MulticlassClassifier),
}

impl<'a> ColumnStatsReader<'a> {
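	/// Return the name of the column these statistics describe, regardless of the column type.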
	pub fn column_name(&self) -> &str {
		match self {
			ColumnStatsReader::UnknownColumn(c) => c.read().column_name(),
			ColumnStatsReader::NumberColumn(c) => c.read().column_name(),
			ColumnStatsReader::EnumColumn(c) => c.read().column_name(),
			ColumnStatsReader::TextColumn(c) => c.read().column_name(),
		}
	}
}

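// Display an n-gram as its tokens: a unigram prints its single token, a bigram prints its two
// tokens separated by a space.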
impl<'a> std::fmt::Display for NGramReader<'a> {
	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
		match self {
			NGramReader::Unigram(token) => {
				let token = token.read();
				write!(f, "{}", token)
			}
			NGramReader::Bigram(token) => {
				let token = token.read();
				write!(f, "{} {}", token.0, token.1)
			}
		}
	}
}

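// Convert the serialized tokenizer settings into a `tangram_text::Tokenizer`.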
impl<'a> From<TokenizerReader<'a>> for tangram_text::Tokenizer {
	fn from(value: TokenizerReader<'a>) -> Self {
		tangram_text::Tokenizer {
			lowercase: value.lowercase(),
			alphanumeric: value.alphanumeric(),
		}
	}
}

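// Convert a serialized n-gram into a `tangram_text::NGramRef` that borrows from the reader.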
impl<'a> From<NGramReader<'a>> for tangram_text::NGramRef<'a> {
	fn from(value: NGramReader<'a>) -> Self {
		match value {
			NGramReader::Unigram(token) => {
				let token = token.read();
				tangram_text::NGramRef::Unigram((*token).into())
			}
			NGramReader::Bigram(bigram) => {
				let bigram = bigram.read();
				tangram_text::NGramRef::Bigram(bigram.0.into(), bigram.1.into())
			}
		}
	}
}

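// Rebuild a `tangram_text::WordEmbeddingModel` from its serialized form: the size, the
// word-to-index map, and the buffer of values.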
impl<'a> From<WordEmbeddingModelReader<'a>> for tangram_text::WordEmbeddingModel {
	fn from(value: WordEmbeddingModelReader<'a>) -> Self {
		let size = value.size().to_usize().unwrap();
		let words = value
			.words()
			.iter()
			.map(|(word, index)| (word.to_owned(), index.to_usize().unwrap()))
			.collect::<FnvHashMap<_, _>>();
		let values = value.values().iter().collect();
		tangram_text::WordEmbeddingModel {
			size,
			words,
			values,
		}
	}
}
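
// A minimal sketch of a test for the header layout written by `to_path` and read back by
// `from_bytes`: the magic number followed immediately by the 4-byte little endian revision
// number.
#[cfg(test)]
mod tests {
	use super::*;

	#[test]
	fn test_header_layout() {
		let mut header = Vec::new();
		header.extend_from_slice(MAGIC_NUMBER);
		header.extend_from_slice(&CURRENT_REVISION.to_le_bytes());
		assert_eq!(&header[..MAGIC_NUMBER.len()], b"tangram\0");
		assert_eq!(&header[MAGIC_NUMBER.len()..], &CURRENT_REVISION.to_le_bytes()[..]);
	}
}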