use std::collections::HashMap;
/// A serialized gradient-boosted tree classifier together with the
/// preprocessing state (feature scaler and TF-IDF vectorizer) needed to
/// run it. Built via [`Model::from_bytes`]; fields holding internal
/// preprocessing state are private.
pub struct Model {
    /// Number of output classes; trees contribute round-robin to classes
    /// (tree `i` scores class `i % n_classes`).
    pub n_classes: usize,
    /// All boosted trees, in serialization order.
    pub trees: Vec<Tree>,
    /// Per-feature mean subtracted during scaling.
    scaler_mean: Vec<f64>,
    /// Per-feature divisor; a non-positive scale makes the scaled value 0.0.
    scaler_scale: Vec<f64>,
    /// TF-IDF vocabulary; entries containing a space are matched as phrases
    /// (substring search) rather than as single tokens.
    tfidf_vocabulary: Vec<String>,
    /// IDF weight for each vocabulary entry (parallel to `tfidf_vocabulary`).
    tfidf_idf: Vec<f64>,
    /// Human-readable label for each class index.
    pub class_labels: Vec<String>,
}
/// A single decision tree stored as a flat array of nodes; index 0 is the
/// root and child links are indices into `nodes`.
pub struct Tree {
    nodes: Vec<Node>,
}
/// One tree node (20 bytes in the serialized format). A negative `feature`
/// marks a leaf, in which case `threshold` holds the leaf's output value;
/// otherwise traversal goes `left` when `features[feature] < threshold`,
/// else `right`.
struct Node {
    feature: i32, threshold: f64, left: i32, right: i32, }
/// Errors produced while parsing a serialized model buffer.
#[derive(Debug)]
pub enum ModelError {
    /// The buffer ended early: `expected` total bytes were needed (offset
    /// plus declared content size) but only `actual` were available.
    Truncated { expected: usize, actual: usize },
    /// The leading 4-byte magic was not recognized.
    InvalidMagic,
    /// A header field was out of its allowed range; the message carries the
    /// offending value(s).
    InvalidHeader(String),
}
impl std::fmt::Display for ModelError {
    /// Render one human-readable line per variant (diagnostics use the
    /// derived `Debug` instead).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidMagic => f.write_str("invalid model magic (expected XGBF)"),
            Self::InvalidHeader(msg) => write!(f, "invalid model header: {msg}"),
            Self::Truncated { expected, actual } => {
                write!(f, "model data truncated: need {expected} bytes, have {actual}")
            }
        }
    }
}
impl Model {
pub fn from_bytes(data: &[u8]) -> Result<Self, ModelError> {
let mut pos = 0;
check_remaining(data, pos, 28)?;
if &data[pos..pos + 4] != b"XGBF" {
return Err(ModelError::InvalidMagic);
}
pos += 4;
let _version = read_u32(data, &mut pos);
let n_classes = read_u32(data, &mut pos) as usize;
let n_trees = read_u32(data, &mut pos) as usize;
let n_numeric = read_u32(data, &mut pos) as usize;
let n_tfidf = read_u32(data, &mut pos) as usize;
let _n_nodes_total = read_u32(data, &mut pos);
if n_classes == 0 || n_classes > 100 {
return Err(ModelError::InvalidHeader(format!("n_classes={n_classes}")));
}
if n_trees == 0 || n_trees > 100_000 {
return Err(ModelError::InvalidHeader(format!("n_trees={n_trees}")));
}
if n_numeric > 10_000 || n_tfidf > 10_000 {
return Err(ModelError::InvalidHeader(format!("n_numeric={n_numeric}, n_tfidf={n_tfidf}")));
}
check_remaining(data, pos, n_numeric * 8 * 2)?;
let mut scaler_mean = Vec::with_capacity(n_numeric);
for _ in 0..n_numeric {
scaler_mean.push(read_f64(data, &mut pos));
}
let mut scaler_scale = Vec::with_capacity(n_numeric);
for _ in 0..n_numeric {
scaler_scale.push(read_f64(data, &mut pos));
}
check_remaining(data, pos, n_tfidf * 4)?;
let mut vocab_lens = Vec::with_capacity(n_tfidf);
for _ in 0..n_tfidf {
vocab_lens.push(read_u32(data, &mut pos) as usize);
}
let total_vocab_bytes: usize = vocab_lens.iter().sum();
check_remaining(data, pos, total_vocab_bytes)?;
let mut tfidf_vocabulary = Vec::with_capacity(n_tfidf);
for &len in &vocab_lens {
let word = std::str::from_utf8(&data[pos..pos + len])
.unwrap_or("")
.to_string();
pos += len;
tfidf_vocabulary.push(word);
}
check_remaining(data, pos, n_tfidf * 8)?;
let mut tfidf_idf = Vec::with_capacity(n_tfidf);
for _ in 0..n_tfidf {
tfidf_idf.push(read_f64(data, &mut pos));
}
check_remaining(data, pos, n_classes * 4)?;
let mut label_lens = Vec::with_capacity(n_classes);
for _ in 0..n_classes {
label_lens.push(read_u32(data, &mut pos) as usize);
}
let total_label_bytes: usize = label_lens.iter().sum();
check_remaining(data, pos, total_label_bytes)?;
let mut class_labels = Vec::with_capacity(n_classes);
for &len in &label_lens {
let label = std::str::from_utf8(&data[pos..pos + len])
.unwrap_or("")
.to_string();
pos += len;
class_labels.push(label);
}
let mut trees = Vec::with_capacity(n_trees);
for _ in 0..n_trees {
check_remaining(data, pos, 4)?;
let n_nodes = read_u32(data, &mut pos) as usize;
if n_nodes > 1_000_000 {
return Err(ModelError::InvalidHeader(format!("tree with {n_nodes} nodes")));
}
check_remaining(data, pos, n_nodes * 20)?; let mut nodes = Vec::with_capacity(n_nodes);
for _ in 0..n_nodes {
let feature = read_i32(data, &mut pos);
let threshold = read_f64(data, &mut pos);
let left = read_i32(data, &mut pos);
let right = read_i32(data, &mut pos);
nodes.push(Node { feature, threshold, left, right });
}
trees.push(Tree { nodes });
}
Ok(Self {
n_classes,
trees,
scaler_mean,
scaler_scale,
tfidf_vocabulary,
tfidf_idf,
class_labels,
})
}
pub fn scale_features(&self, raw: &[f64]) -> Vec<f64> {
raw.iter()
.zip(self.scaler_mean.iter().zip(self.scaler_scale.iter()))
.map(|(&x, (&mean, &scale))| {
if scale > 0.0 { (x - mean) / scale } else { 0.0 }
})
.collect()
}
pub fn compute_tfidf(&self, text: &str) -> Vec<f64> {
let n_tfidf = self.tfidf_idf.len();
let mut result = vec![0.0f64; n_tfidf];
let text_lower = text.to_ascii_lowercase();
if text_lower.is_empty() {
return result;
}
let words: Vec<&str> = text_lower
.split(|c: char| !c.is_alphanumeric())
.filter(|w| !w.is_empty())
.collect();
if words.is_empty() {
return result;
}
let mut tf: HashMap<usize, u32> = HashMap::new();
for word in &words {
if let Some(idx) = self.tfidf_vocabulary.iter().position(|v| v == word) {
*tf.entry(idx).or_insert(0) += 1;
}
}
for (idx, vocab_word) in self.tfidf_vocabulary.iter().enumerate() {
if vocab_word.contains(' ') && text_lower.contains(vocab_word.as_str()) {
tf.entry(idx).or_insert(1);
}
}
let n_words = words.len() as f64;
for (idx, count) in tf {
let tf_val = (count as f64) / n_words;
result[idx] = tf_val * self.tfidf_idf[idx];
}
result
}
pub fn predict(&self, features: &[f64]) -> (usize, f64) {
let n_classes = self.n_classes;
let mut class_scores = vec![0.0f64; n_classes];
for (i, tree) in self.trees.iter().enumerate() {
let class_idx = i % n_classes;
let score = tree.evaluate(features);
class_scores[class_idx] += score;
}
let max_score = class_scores.iter().copied().fold(f64::NEG_INFINITY, f64::max);
let exp_scores: Vec<f64> = class_scores.iter().map(|&s| (s - max_score).exp()).collect();
let sum_exp: f64 = exp_scores.iter().sum();
let probabilities: Vec<f64> = exp_scores.iter().map(|&e| e / sum_exp).collect();
let (best_idx, &best_prob) = probabilities
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.unwrap_or((0, &0.0));
(best_idx, best_prob)
}
}
impl Tree {
    /// Walk the tree from the root (index 0) and return the leaf value.
    ///
    /// A node with a negative `feature` is a leaf whose `threshold` field
    /// holds the output. Traversal is capped at `nodes.len()` steps so a
    /// malformed (cyclic) tree still terminates; an empty tree, an
    /// out-of-range child index, or an exhausted step budget yields 0.0.
    /// Features missing from `features` are treated as 0.0.
    fn evaluate(&self, features: &[f64]) -> f64 {
        let node_count = self.nodes.len();
        let mut current = 0usize;
        let mut budget = node_count;
        while budget > 0 {
            budget -= 1;
            let node = match self.nodes.get(current) {
                Some(n) => n,
                None => return 0.0,
            };
            if node.feature < 0 {
                // Leaf: the threshold slot stores the leaf output.
                return node.threshold;
            }
            let value = features.get(node.feature as usize).copied().unwrap_or(0.0);
            let child = if value < node.threshold { node.left } else { node.right };
            if child < 0 || child as usize >= node_count {
                return 0.0;
            }
            current = child as usize;
        }
        0.0
    }
}
/// A gradient-boosted regression model (the `XGBQ` format) together with
/// its feature scaler. Built via [`QualityModel::from_bytes`].
pub struct QualityModel {
    /// All boosted trees; their outputs are summed onto a 0.5 base score.
    trees: Vec<Tree>,
    /// Per-feature mean subtracted during scaling.
    scaler_mean: Vec<f64>,
    /// Per-feature divisor; a non-positive scale makes the scaled value 0.0.
    scaler_scale: Vec<f64>,
}
impl QualityModel {
    /// Parse a regression model from the `XGBQ` binary format.
    ///
    /// Layout (little-endian): magic `b"XGBQ"`, four `u32`s (version,
    /// n_trees, n_features, n_nodes_total), `n_features` scaler means and
    /// scales (`f64`), `n_features` length-prefixed feature names (skipped),
    /// then `n_trees` trees, each a `u32` node count plus 20-byte nodes.
    ///
    /// # Errors
    /// `InvalidMagic`, `InvalidHeader` for out-of-range counts, or
    /// `Truncated` when the buffer is too short for the declared contents.
    pub fn from_bytes(data: &[u8]) -> Result<Self, ModelError> {
        let mut cursor = 0usize;
        // Magic (4 bytes) + four u32 header fields (16 bytes).
        check_remaining(data, cursor, 20)?;
        if &data[cursor..cursor + 4] != b"XGBQ" {
            return Err(ModelError::InvalidMagic);
        }
        cursor += 4;
        let _version = read_u32(data, &mut cursor);
        let n_trees = read_u32(data, &mut cursor) as usize;
        let n_features = read_u32(data, &mut cursor) as usize;
        let _n_nodes_total = read_u32(data, &mut cursor);
        if n_trees == 0 || n_trees > 10_000 || n_features > 10_000 {
            return Err(ModelError::InvalidHeader(format!(
                "n_trees={n_trees}, n_features={n_features}"
            )));
        }
        // Scaler parameters: all means first, then all scales.
        check_remaining(data, cursor, n_features * 8 * 2)?;
        let scaler_mean: Vec<f64> =
            (0..n_features).map(|_| read_f64(data, &mut cursor)).collect();
        let scaler_scale: Vec<f64> =
            (0..n_features).map(|_| read_f64(data, &mut cursor)).collect();
        // Feature names are stored in the format but unused at inference
        // time; validate each length and skip past the bytes.
        for _ in 0..n_features {
            check_remaining(data, cursor, 4)?;
            let name_len = read_u32(data, &mut cursor) as usize;
            check_remaining(data, cursor, name_len)?;
            cursor += name_len;
        }
        let mut trees = Vec::with_capacity(n_trees);
        for _ in 0..n_trees {
            check_remaining(data, cursor, 4)?;
            let node_count = read_u32(data, &mut cursor) as usize;
            if node_count > 1_000_000 {
                return Err(ModelError::InvalidHeader(format!("tree with {node_count} nodes")));
            }
            // Each node record is 20 bytes: i32 + f64 + i32 + i32.
            check_remaining(data, cursor, node_count * 20)?;
            let mut nodes = Vec::with_capacity(node_count);
            for _ in 0..node_count {
                let feature = read_i32(data, &mut cursor);
                let threshold = read_f64(data, &mut cursor);
                let left = read_i32(data, &mut cursor);
                let right = read_i32(data, &mut cursor);
                nodes.push(Node { feature, threshold, left, right });
            }
            trees.push(Tree { nodes });
        }
        Ok(Self { trees, scaler_mean, scaler_scale })
    }

    /// Standard-scale raw features: `(x - mean) / scale`, emitting 0.0 for a
    /// non-positive scale. Output truncates to the shortest input vector.
    pub fn scale_features(&self, raw: &[f64]) -> Vec<f64> {
        let pairs = self.scaler_mean.iter().zip(self.scaler_scale.iter());
        raw.iter()
            .zip(pairs)
            .map(|(&x, (&mean, &scale))| {
                if scale > 0.0 { (x - mean) / scale } else { 0.0 }
            })
            .collect()
    }

    /// Predict a quality score: a 0.5 base plus the sum of all tree outputs.
    pub fn predict(&self, features: &[f64]) -> f64 {
        self.trees
            .iter()
            .fold(0.5, |acc, tree| acc + tree.evaluate(features))
    }
}
/// Verify that at least `needed` bytes remain in `data` past offset `pos`.
///
/// Uses checked arithmetic: with the original `pos + needed > data.len()`
/// test, a huge `needed` derived from attacker-controlled length fields
/// could wrap `usize` and make the check pass spuriously. Overflow is now
/// reported as `Truncated` (with a saturated `expected`) instead.
///
/// # Errors
/// `ModelError::Truncated` when fewer than `needed` bytes remain or the
/// required total overflows `usize`.
fn check_remaining(data: &[u8], pos: usize, needed: usize) -> Result<(), ModelError> {
    match pos.checked_add(needed) {
        Some(end) if end <= data.len() => Ok(()),
        _ => Err(ModelError::Truncated {
            expected: pos.saturating_add(needed),
            actual: data.len(),
        }),
    }
}
/// Read a little-endian `u32` at `*pos` and advance the cursor by 4 bytes.
/// The caller must have bounds-checked via `check_remaining`; out-of-range
/// reads panic.
fn read_u32(data: &[u8], pos: &mut usize) -> u32 {
    let mut raw = [0u8; 4];
    raw.copy_from_slice(&data[*pos..*pos + 4]);
    *pos += 4;
    u32::from_le_bytes(raw)
}
/// Read a little-endian `i32` at `*pos` and advance the cursor by 4 bytes.
/// The caller must have bounds-checked via `check_remaining`; out-of-range
/// reads panic.
fn read_i32(data: &[u8], pos: &mut usize) -> i32 {
    let mut raw = [0u8; 4];
    raw.copy_from_slice(&data[*pos..*pos + 4]);
    *pos += 4;
    i32::from_le_bytes(raw)
}
/// Read a little-endian IEEE-754 `f64` at `*pos` and advance the cursor by
/// 8 bytes. The caller must have bounds-checked via `check_remaining`;
/// out-of-range reads panic.
fn read_f64(data: &[u8], pos: &mut usize) -> f64 {
    let mut raw = [0u8; 8];
    raw.copy_from_slice(&data[*pos..*pos + 8]);
    *pos += 8;
    f64::from_le_bytes(raw)
}
#[cfg(test)]
mod tests {
    use super::*;

    // Most tests here exercise the embedded production artifact
    // (crate::MODEL_BYTES); the dimension constants below mirror it.
    #[test]
    fn test_model_parse() {
        let model = Model::from_bytes(crate::MODEL_BYTES).expect("model should parse");
        assert_eq!(model.n_classes, 7);
        assert_eq!(model.trees.len(), 1400);
        assert_eq!(model.class_labels.len(), 7);
        assert_eq!(model.scaler_mean.len(), 89);
        assert_eq!(model.scaler_scale.len(), 89);
        assert_eq!(model.tfidf_idf.len(), 100);
        assert_eq!(model.tfidf_vocabulary.len(), 100);
    }

    // A 10-byte buffer cannot even hold the 28-byte fixed header.
    #[test]
    fn test_truncated_data() {
        let result = Model::from_bytes(&[0u8; 10]);
        assert!(result.is_err());
    }

    // A wrong magic must be reported as InvalidMagic, not Truncated.
    #[test]
    fn test_invalid_magic() {
        let mut data = vec![0u8; 100];
        data[..4].copy_from_slice(b"NOPE");
        assert!(matches!(Model::from_bytes(&data), Err(ModelError::InvalidMagic)));
    }

    // scale_features zips input against the scaler vectors, so 81 inputs
    // against 89 scaler entries yields 81 outputs.
    #[test]
    fn test_scale_features() {
        let model = Model::from_bytes(crate::MODEL_BYTES).expect("parse");
        let raw = vec![0.0f64; 81];
        let scaled = model.scale_features(&raw);
        assert_eq!(scaled.len(), 81);
    }

    // Single-word vocabulary entries should match exact tokens.
    #[test]
    fn test_tfidf_unigram() {
        let model = Model::from_bytes(crate::MODEL_BYTES).expect("parse");
        let tfidf = model.compute_tfidf("forum discussion thread");
        assert_eq!(tfidf.len(), 100);
        if let Some(idx) = model.tfidf_vocabulary.iter().position(|w| w == "forum") {
            assert!(tfidf[idx] > 0.0, "forum should have nonzero TF-IDF");
        }
    }

    // Space-containing vocabulary entries should match as substrings.
    #[test]
    fn test_tfidf_bigram() {
        let model = Model::from_bytes(crate::MODEL_BYTES).expect("parse");
        let tfidf = model.compute_tfidf("best practices for web development");
        if let Some(idx) = model.tfidf_vocabulary.iter().position(|w| w == "best practices") {
            assert!(tfidf[idx] > 0.0, "bigram 'best practices' should match");
        }
    }

    // predict must return an in-range class and a valid probability even
    // for an all-zero feature vector.
    #[test]
    fn test_predict() {
        let model = Model::from_bytes(crate::MODEL_BYTES).expect("parse");
        let features = vec![0.0f64; 181];
        let (class_idx, confidence) = model.predict(&features);
        assert!(class_idx < 7);
        assert!(confidence >= 0.0 && confidence <= 1.0);
    }

    // An empty tree must evaluate to the 0.0 fallback, not panic.
    #[test]
    fn test_tree_evaluate_bounds() {
        let tree = Tree { nodes: vec![] };
        assert_eq!(tree.evaluate(&[0.0; 10]), 0.0);
    }
}