// web-page-classifier 0.1.0
//
// Fast web page type classification using XGBoost with a compact binary model.
//! Fast web page type classification.
//!
//! Classifies web pages into 7 types using a compact XGBoost model:
//! Article, Forum, Product, Collection, Listing, Documentation, Service.
//!
//! # Quick Start
//!
//! ```
//! use web_page_classifier::{PageType, classify_url};
//!
//! let page_type = classify_url("https://docs.example.com/api/reference");
//! assert_eq!(page_type, PageType::Documentation);
//! ```
//!
//! # ML Classification
//!
//! For higher accuracy, extract numeric features from the HTML DOM and pass
//! them along with title/description text:
//!
//! ```
//! use web_page_classifier::{classify_ml, N_NUMERIC_FEATURES};
//!
//! let features = vec![0.0f64; N_NUMERIC_FEATURES]; // your extracted features
//! let (page_type, confidence) = classify_ml(&features, "Example Article Title");
//! ```

mod model;
pub mod url_heuristics;

use std::sync::OnceLock;

pub use url_heuristics::classify_url;

/// Number of numeric features expected by the ML model.
pub const N_NUMERIC_FEATURES: usize = 89;

/// Web page type classification.
/// Web page type classification.
///
/// One of the 7 classes produced by the classifier (see the crate-level docs).
/// NOTE: the derived `PartialOrd`/`Ord` follow declaration order (alphabetical
/// here), so reordering variants would silently change comparison results.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum PageType {
    Article,
    Collection,
    Documentation,
    Forum,
    Listing,
    Product,
    Service,
}

impl PageType {
    /// Parse from string (case-insensitive). Returns None for unknown types.
    #[must_use]
    pub fn parse(s: &str) -> Option<Self> {
        match s.to_ascii_lowercase().as_str() {
            "article" => Some(Self::Article),
            "collection" | "category" => Some(Self::Collection),
            "documentation" | "docs" => Some(Self::Documentation),
            "forum" => Some(Self::Forum),
            "listing" => Some(Self::Listing),
            "product" => Some(Self::Product),
            "service" => Some(Self::Service),
            _ => None,
        }
    }

    /// Return the type name as a lowercase string.
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Article => "article",
            Self::Collection => "collection",
            Self::Documentation => "documentation",
            Self::Forum => "forum",
            Self::Listing => "listing",
            Self::Product => "product",
            Self::Service => "service",
        }
    }
}

impl std::fmt::Display for PageType {
    /// Formats as the lowercase type name (same text as [`PageType::as_str`]).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.as_str())
    }
}

impl std::str::FromStr for PageType {
    type Err = String;

    /// Delegates to [`PageType::parse`]; unknown names yield a descriptive error.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match Self::parse(s) {
            Some(page_type) => Ok(page_type),
            None => Err(format!("unknown page type: {s}")),
        }
    }
}

/// Number of features expected by the quality predictor
/// (see [`predict_quality`] for the feature order).
pub const N_QUALITY_FEATURES: usize = 27;

/// Embedded classifier model binary (`include_bytes!` path is relative to
/// this source file).
static MODEL_BYTES: &[u8] = include_bytes!("xgboost_v2.bin");

/// Embedded quality predictor model binary.
static QUALITY_MODEL_BYTES: &[u8] = include_bytes!("quality_model_v1.bin");

/// Lazily-initialized classifier model; deserialized from `MODEL_BYTES`
/// on first use (thread-safe via `OnceLock`).
static MODEL: OnceLock<model::Model> = OnceLock::new();

/// Lazily-initialized quality model; deserialized on first use.
static QUALITY_MODEL: OnceLock<model::QualityModel> = OnceLock::new();

/// Returns the process-wide classifier model, deserializing the embedded
/// bytes on first call. Panics only if the embedded binary is corrupt,
/// which would be a build-time defect.
fn get_model() -> &'static model::Model {
    MODEL.get_or_init(|| {
        let parsed = model::Model::from_bytes(MODEL_BYTES);
        parsed.expect("embedded classifier model is valid")
    })
}

/// Returns the process-wide quality model, deserializing the embedded
/// bytes on first call. Panics only if the embedded binary is corrupt.
fn get_quality_model() -> &'static model::QualityModel {
    QUALITY_MODEL.get_or_init(|| {
        let parsed = model::QualityModel::from_bytes(QUALITY_MODEL_BYTES);
        parsed.expect("embedded quality model is valid")
    })
}

/// Classify a web page using the ML model.
///
/// # Arguments
/// * `numeric_features` - Raw (unscaled) numeric features; must contain
///   exactly [`N_NUMERIC_FEATURES`] values. Scaling happens inside the model.
/// * `title_meta` - Concatenated title + description text, used for the
///   TF-IDF text features.
///
/// # Returns
/// The predicted [`PageType`] together with a confidence in `[0.0, 1.0]`.
///
/// # Panics
/// Panics if `numeric_features.len() != N_NUMERIC_FEATURES`.
#[must_use]
pub fn classify_ml(numeric_features: &[f64], title_meta: &str) -> (PageType, f64) {
    assert_eq!(
        numeric_features.len(),
        N_NUMERIC_FEATURES,
        "Expected {} numeric features, got {}",
        N_NUMERIC_FEATURES,
        numeric_features.len()
    );

    let model = get_model();

    // Full feature vector = scaled numerics followed by TF-IDF text features.
    let numeric = model.scale_features(numeric_features);
    let text = model.compute_tfidf(title_meta);
    let mut combined = Vec::with_capacity(numeric.len() + text.len());
    combined.extend_from_slice(&numeric);
    combined.extend_from_slice(&text);

    let (winner, confidence) = model.predict(&combined);

    // Map the winning class index back to a PageType; fall back to Article
    // if the model reports an index or label we do not recognize.
    let label = model
        .class_labels
        .get(winner)
        .and_then(|name| PageType::parse(name));

    (label.unwrap_or(PageType::Article), confidence)
}

/// Predict extraction quality (estimated F1 score) from post-extraction features.
///
/// Returns a value in `[0.0, 1.0]` estimating how well the extraction captured
/// the page's main content. Low scores (< 0.80) indicate the extraction may be
/// poor and should be routed to an LLM fallback.
///
/// # Arguments
/// * `features` - Raw (unscaled) quality features. Must have length
///   [`N_QUALITY_FEATURES`]. Features include content statistics, page type
///   indicators, and HTML-level signals.
///
/// # Feature order (27 features)
/// 0: heuristic_conf, 1: content_len, 2: word_count, 3: vocab_ratio,
/// 4: avg_word_len, 5: sentence_count, 6: avg_sentence_len,
/// 7: sentence_uniqueness, 8: paragraph_count, 9: avg_paragraph_len,
/// 10: link_count_in_content, 11: link_density, 12: boilerplate_keywords,
/// 13-19: is_article..is_service (one-hot page type),
/// 20: length_ratio, 21: html_size, 22: extraction_ratio,
/// 23: og_overlap, 24: script_count, 25: has_jsonld, 26: top_bigram_freq
///
/// # Panics
/// Panics if `features.len() != N_QUALITY_FEATURES`.
#[must_use]
pub fn predict_quality(features: &[f64]) -> f64 {
    assert_eq!(
        features.len(),
        N_QUALITY_FEATURES,
        "Expected {} quality features, got {}",
        N_QUALITY_FEATURES,
        features.len()
    );

    let model = get_quality_model();
    let scaled = model.scale_features(features);
    // Clamp to the valid score range in case the regressor extrapolates.
    model.predict(&scaled).clamp(0.0, 1.0)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_loads() {
        let m = get_model();
        assert_eq!(m.n_classes, 7);
        assert!(!m.trees.is_empty());
    }

    #[test]
    fn test_classify_ml_returns_valid_type() {
        let features = vec![0.0f64; N_NUMERIC_FEATURES];
        let (page_type, confidence) = classify_ml(&features, "Example blog post about technology");
        // Confidence is a probability; idiomatic range check (clippy
        // manual_range_contains) instead of a pair of comparisons.
        assert!((0.0..=1.0).contains(&confidence));
        // Whatever type is returned must map to a non-empty label
        // (assert!(!..) instead of assert_eq!(.., false)).
        assert!(!page_type.as_str().is_empty());
    }

    #[test]
    fn test_classify_url_basic() {
        assert_eq!(classify_url("https://forum.example.com/thread/123"), PageType::Forum);
        assert_eq!(classify_url("https://docs.example.com/api"), PageType::Documentation);
        assert_eq!(classify_url("https://example.com/products/widget"), PageType::Product);
    }

    #[test]
    fn test_page_type_display() {
        assert_eq!(PageType::Article.to_string(), "article");
        assert_eq!(PageType::Forum.as_str(), "forum");
    }

    #[test]
    fn test_page_type_parse() {
        assert_eq!(PageType::parse("article"), Some(PageType::Article));
        assert_eq!(PageType::parse("FORUM"), Some(PageType::Forum));
        assert_eq!(PageType::parse("category"), Some(PageType::Collection));
        assert_eq!(PageType::parse("unknown"), None);
    }

    #[test]
    fn test_quality_model_loads() {
        // Just verify it parses without panicking
        let _ = get_quality_model();
    }

    #[test]
    fn test_predict_quality() {
        let features = vec![0.0f64; N_QUALITY_FEATURES];
        let quality = predict_quality(&features);
        assert!((0.0..=1.0).contains(&quality));
    }

    #[test]
    fn test_predict_quality_good_extraction() {
        // Simulate a good article extraction
        let mut features = vec![0.0f64; N_QUALITY_FEATURES];
        features[0] = 0.95;  // heuristic_conf
        features[1] = 8000.0; // content_len
        features[2] = 1200.0; // word_count
        features[3] = 0.55;  // vocab_ratio
        features[5] = 40.0;  // sentence_count
        features[8] = 15.0;  // paragraph_count
        features[13] = 1.0;  // is_article
        features[20] = 0.8;  // length_ratio
        let quality = predict_quality(&features);
        // With partial features, score may not be high, but should be valid
        assert!((0.0..=1.0).contains(&quality), "Score should be in [0,1], got {quality}");
    }

    #[test]
    fn test_page_type_from_str_trait() {
        assert_eq!("article".parse::<PageType>(), Ok(PageType::Article));
        assert!("unknown".parse::<PageType>().is_err());
    }
}