mod model;
pub mod url_heuristics;
use std::sync::OnceLock;
pub use url_heuristics::classify_url;
/// Number of numeric (non-text) features that [`classify_ml`] expects in its
/// `numeric_features` argument; any other length makes it panic.
pub const N_NUMERIC_FEATURES: usize = 89;
/// Coarse category of a web page, as returned by [`classify_ml`] and
/// [`classify_url`].
///
/// Each variant has a canonical lowercase name ([`PageType::as_str`]), which is
/// also what `Display` prints and what [`PageType::parse`] / `FromStr` accept.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum PageType {
    Article,
    Collection,
    Documentation,
    Forum,
    Listing,
    Product,
    Service,
}

impl PageType {
    /// Every spelling accepted by [`PageType::parse`] (compared ignoring ASCII
    /// case) and the variant it maps to: the canonical names plus the aliases
    /// `"category"` and `"docs"`.
    const SPELLINGS: [(&'static str, PageType); 9] = [
        ("article", PageType::Article),
        ("collection", PageType::Collection),
        ("category", PageType::Collection),
        ("documentation", PageType::Documentation),
        ("docs", PageType::Documentation),
        ("forum", PageType::Forum),
        ("listing", PageType::Listing),
        ("product", PageType::Product),
        ("service", PageType::Service),
    ];

    /// Parses a page-type name, ignoring ASCII case.
    ///
    /// Accepts every canonical name returned by [`PageType::as_str`], plus the
    /// aliases `"category"` (→ [`PageType::Collection`]) and `"docs"`
    /// (→ [`PageType::Documentation`]). Surrounding whitespace is *not* trimmed.
    /// Returns `None` for any other input.
    #[must_use]
    pub fn parse(s: &str) -> Option<Self> {
        // `eq_ignore_ascii_case` avoids allocating a lowercased copy of `s`.
        Self::SPELLINGS
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case(s))
            .map(|&(_, page_type)| page_type)
    }

    /// Returns the canonical lowercase name of this page type.
    ///
    /// The returned name always round-trips through [`PageType::parse`].
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Article => "article",
            Self::Collection => "collection",
            Self::Documentation => "documentation",
            Self::Forum => "forum",
            Self::Listing => "listing",
            Self::Product => "product",
            Self::Service => "service",
        }
    }
}

impl std::fmt::Display for PageType {
    // `pad` (unlike `write_str`) honours width/fill/alignment/precision flags,
    // so e.g. `format!("{:>10}", page_type)` is padded like a plain `&str`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.as_str())
    }
}

impl std::str::FromStr for PageType {
    type Err = String;

    /// Same rules as [`PageType::parse`]; unknown input yields the message
    /// `"unknown page type: <input>"`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Self::parse(s).ok_or_else(|| format!("unknown page type: {s}"))
    }
}
/// Number of features that [`predict_quality`] expects in its `features`
/// argument; any other length makes it panic.
pub const N_QUALITY_FEATURES: usize = 27;
// Serialized model blobs embedded into the binary at compile time. They are
// deserialized lazily, at most once each, by `get_model` / `get_quality_model`.
static MODEL_BYTES: &[u8] = include_bytes!("xgboost_v2.bin");
static QUALITY_MODEL_BYTES: &[u8] = include_bytes!("quality_model_v1.bin");
// Caches holding the deserialized models; filled on first access.
static MODEL: OnceLock<model::Model> = OnceLock::new();
static QUALITY_MODEL: OnceLock<model::QualityModel> = OnceLock::new();
/// Returns the page-type classifier, deserializing it from the embedded
/// `MODEL_BYTES` the first time it is requested and caching it in `MODEL`.
///
/// # Panics
/// Panics if the embedded model bytes fail to deserialize.
fn get_model() -> &'static model::Model {
    fn load_embedded() -> model::Model {
        model::Model::from_bytes(MODEL_BYTES).expect("embedded classifier model is valid")
    }
    MODEL.get_or_init(load_embedded)
}
/// Returns the extraction-quality model, deserializing it from the embedded
/// `QUALITY_MODEL_BYTES` the first time it is requested and caching it in
/// `QUALITY_MODEL`.
///
/// # Panics
/// Panics if the embedded model bytes fail to deserialize.
fn get_quality_model() -> &'static model::QualityModel {
    fn load_embedded() -> model::QualityModel {
        model::QualityModel::from_bytes(QUALITY_MODEL_BYTES)
            .expect("embedded quality model is valid")
    }
    QUALITY_MODEL.get_or_init(load_embedded)
}
/// Classifies a page with the embedded ML model.
///
/// `numeric_features` is passed through the model's feature scaler, and
/// `title_meta` through the model's `compute_tfidf`. The model's predictor is
/// then fed the scaled numeric features followed by the text features.
///
/// Returns the predicted [`PageType`] together with the confidence reported by
/// the model. If the predicted class index has no label, or its label is not
/// accepted by [`PageType::parse`], the type silently falls back to
/// [`PageType::Article`]. The confidence is still the model's value in that case.
///
/// # Panics
/// Panics if `numeric_features.len() != N_NUMERIC_FEATURES`, or if the
/// embedded model fails to deserialize on first use.
#[must_use]
pub fn classify_ml(numeric_features: &[f64], title_meta: &str) -> (PageType, f64) {
    let n_given = numeric_features.len();
    assert_eq!(
        n_given,
        N_NUMERIC_FEATURES,
        "Expected {} numeric features, got {}",
        N_NUMERIC_FEATURES,
        n_given
    );

    let model = get_model();
    let scaled = model.scale_features(numeric_features);
    let tfidf = model.compute_tfidf(title_meta);
    // Model input layout: scaled numeric features first, then text features.
    let combined: Vec<_> = scaled.iter().chain(tfidf.iter()).copied().collect();

    let (class_idx, confidence) = model.predict(&combined);
    let page_type = match model.class_labels.get(class_idx) {
        Some(label) => PageType::parse(label).unwrap_or(PageType::Article),
        None => PageType::Article,
    };
    (page_type, confidence)
}
/// Scores extraction quality with the embedded quality model.
///
/// The features are passed through the model's scaler, then its predictor. The
/// raw prediction is clamped to `[0.0, 1.0]`. Note that `f64::clamp` returns NaN
/// unchanged, so a NaN prediction is passed through as-is.
///
/// # Panics
/// Panics if `features.len() != N_QUALITY_FEATURES`, or if the embedded
/// quality model fails to deserialize on first use.
#[must_use]
pub fn predict_quality(features: &[f64]) -> f64 {
    let n_given = features.len();
    assert_eq!(
        n_given,
        N_QUALITY_FEATURES,
        "Expected {} quality features, got {}",
        N_QUALITY_FEATURES,
        n_given
    );

    let model = get_quality_model();
    let raw_score = model.predict(&model.scale_features(features));
    raw_score.clamp(0.0, 1.0)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `PageType` variant, for exhaustive round-trip checks.
    const ALL_PAGE_TYPES: [PageType; 7] = [
        PageType::Article,
        PageType::Collection,
        PageType::Documentation,
        PageType::Forum,
        PageType::Listing,
        PageType::Product,
        PageType::Service,
    ];

    #[test]
    fn test_model_loads() {
        let m = get_model();
        assert_eq!(m.n_classes, 7);
        assert!(!m.trees.is_empty());
    }

    #[test]
    fn test_model_class_labels_are_known_page_types() {
        // `classify_ml` silently falls back to `PageType::Article` for labels that
        // `PageType::parse` rejects, so a model/enum mismatch has to be caught here.
        let m = get_model();
        for label in m.class_labels.iter() {
            assert!(
                PageType::parse(label).is_some(),
                "model class label not recognised as a PageType: {label}"
            );
        }
    }

    #[test]
    fn test_classify_ml_returns_valid_type() {
        let features = [0.0f64; N_NUMERIC_FEATURES];
        let (page_type, confidence) = classify_ml(&features, "Example blog post about technology");
        assert!(
            (0.0..=1.0).contains(&confidence),
            "confidence should be in [0,1], got {confidence}"
        );
        assert!(!page_type.as_str().is_empty());
    }

    #[test]
    #[should_panic(expected = "numeric features, got 3")]
    fn test_classify_ml_rejects_wrong_feature_count() {
        let _ = classify_ml(&[0.0; 3], "");
    }

    #[test]
    fn test_classify_url_basic() {
        assert_eq!(classify_url("https://forum.example.com/thread/123"), PageType::Forum);
        assert_eq!(classify_url("https://docs.example.com/api"), PageType::Documentation);
        assert_eq!(classify_url("https://example.com/products/widget"), PageType::Product);
    }

    #[test]
    fn test_page_type_display() {
        assert_eq!(PageType::Article.to_string(), "article");
        assert_eq!(PageType::Forum.as_str(), "forum");
    }

    #[test]
    fn test_page_type_parse() {
        assert_eq!(PageType::parse("article"), Some(PageType::Article));
        assert_eq!(PageType::parse("FORUM"), Some(PageType::Forum));
        assert_eq!(PageType::parse("category"), Some(PageType::Collection));
        assert_eq!(PageType::parse("docs"), Some(PageType::Documentation));
        assert_eq!(PageType::parse("unknown"), None);
    }

    #[test]
    fn test_page_type_round_trip() {
        for pt in ALL_PAGE_TYPES {
            assert_eq!(PageType::parse(pt.as_str()), Some(pt));
            assert_eq!(pt.to_string().parse::<PageType>(), Ok(pt));
        }
    }

    #[test]
    fn test_quality_model_loads() {
        let _ = get_quality_model();
    }

    #[test]
    fn test_predict_quality() {
        let features = [0.0f64; N_QUALITY_FEATURES];
        let quality = predict_quality(&features);
        assert!(
            (0.0..=1.0).contains(&quality),
            "Score should be in [0,1], got {quality}"
        );
    }

    #[test]
    fn test_predict_quality_good_extraction() {
        let mut features = [0.0f64; N_QUALITY_FEATURES];
        features[0] = 0.95;
        features[1] = 8000.0;
        features[2] = 1200.0;
        features[3] = 0.55;
        features[5] = 40.0;
        features[8] = 15.0;
        features[13] = 1.0;
        features[20] = 0.8;
        let quality = predict_quality(&features);
        assert!(
            (0.0..=1.0).contains(&quality),
            "Score should be in [0,1], got {quality}"
        );
    }

    #[test]
    #[should_panic(expected = "quality features, got 0")]
    fn test_predict_quality_rejects_wrong_feature_count() {
        let _ = predict_quality(&[]);
    }

    #[test]
    fn test_page_type_from_str_trait() {
        assert_eq!("article".parse::<PageType>(), Ok(PageType::Article));
        assert_eq!(
            "unknown".parse::<PageType>(),
            Err("unknown page type: unknown".to_string())
        );
    }
}