xgboost-rust 0.1.0

Rust bindings for XGBoost, a gradient boosting library for machine learning. Downloads XGBoost binaries at build time for cross-platform compatibility.
#![cfg(feature = "polars")]

use polars::prelude::*;
use std::time::Instant;
use xgboost_rust::{Booster, BoosterPolarsExt};

/// Builds a flat buffer of `rows * cols` synthetic feature values.
///
/// Values cycle through `0.0, 0.1, …, 9.9` by flat index (`index % 100` scaled by 1/10).
/// Index `r * cols + c` holds the same value that `create_polars_dataframe`
/// places at row `r` of column `c`, so both inputs describe identical data.
fn create_raw_data(rows: usize, cols: usize) -> Vec<f32> {
    let total = rows * cols;
    let mut values = Vec::with_capacity(total);
    for idx in 0..total {
        let bucket = idx % 100;
        values.push(bucket as f32 / 10.0);
    }
    values
}

/// Create test data as Polars DataFrame
fn create_polars_dataframe(rows: usize, cols: usize) -> DataFrame {
    let mut columns = Vec::new();

    for col_idx in 0..cols {
        let data: Vec<f32> = (0..rows)
            .map(|i| ((i * cols + col_idx) % 100) as f32 / 10.0)
            .collect();
        let series = Series::new(&format!("feature_{}", col_idx), data);
        columns.push(series);
    }

    DataFrame::new(columns).expect("Failed to create DataFrame")
}

/// Compares per-call prediction latency of the raw-slice API against the Polars
/// `DataFrame` API across several input sizes, and prints the relative overhead.
///
/// Both paths receive identical data (see `create_raw_data` / `create_polars_dataframe`).
/// Each path is warmed up before timing, and input construction happens outside the
/// timed regions. The test is skipped (returns early) when `demo_model.json` is absent.
#[test]
#[ignore] // Requires model file
fn test_performance_comparison() {
    const ITERATIONS: u32 = 10;

    // Runs `f` `iterations` times and returns the mean wall-clock time per call in
    // microseconds. Uses floating-point division so that short runs are not truncated
    // to whole microseconds, which would distort the overhead percentage below.
    fn mean_micros(iterations: u32, mut f: impl FnMut()) -> f64 {
        let start = Instant::now();
        for _ in 0..iterations {
            f();
        }
        start.elapsed().as_secs_f64() * 1e6 / f64::from(iterations)
    }

    let model_path = "demo_model.json";
    if !std::path::Path::new(model_path).exists() {
        println!("Skipping performance test - no model file found");
        return;
    }

    let booster = Booster::load(model_path).expect("Failed to load model");

    let test_sizes = [
        (100, 10, "100 rows"),
        (1000, 10, "1k rows"),
        (10000, 10, "10k rows"),
    ];

    println!("\n╔════════════════════════════════════════════════════════════╗");
    println!("║          XGBoost Polars Performance Comparison             ║");
    println!("╠════════════════════════════════════════════════════════════╣");

    for (rows, cols, name) in test_sizes {
        println!("║ Test: {} x {} features ({})", rows, cols, name);
        println!("╟────────────────────────────────────────────────────────────╢");

        // Build both inputs up front so only prediction is inside the timed loops.
        let raw_data = create_raw_data(rows, cols);
        let df = create_polars_dataframe(rows, cols);

        // Warm up BOTH paths. Warming only one would let the other absorb one-time
        // costs in its timed loop and bias the overhead figure.
        let _ = booster
            .predict(&raw_data, rows, cols, 0, false)
            .expect("Warm-up prediction failed");
        let _ = booster
            .predict_dataframe(&df, 0, false)
            .expect("Warm-up prediction failed");

        // Test 1: Raw array prediction
        let raw_avg = mean_micros(ITERATIONS, || {
            let _ = booster
                .predict(&raw_data, rows, cols, 0, false)
                .expect("Prediction failed");
        });
        println!("║   {:<18}{:>10.1} μs/iter", "Raw Array:", raw_avg);

        // Test 2: Polars DataFrame prediction
        let polars_avg = mean_micros(ITERATIONS, || {
            let _ = booster
                .predict_dataframe(&df, 0, false)
                .expect("Prediction failed");
        });
        println!("║   {:<18}{:>10.1} μs/iter", "Polars DataFrame:", polars_avg);

        // Relative overhead of the DataFrame path; guard against a zero baseline.
        let overhead_pct = if raw_avg > 0.0 {
            (polars_avg - raw_avg) / raw_avg * 100.0
        } else {
            0.0
        };

        println!("║   {:<18}{:>10.2}%", "Overhead:", overhead_pct);
        println!("╟────────────────────────────────────────────────────────────╢");
    }

    println!("╚════════════════════════════════════════════════════════════╝\n");
}

/// Times repeated Polars `DataFrame` predictions on a 10k x 10 dataset under the
/// process's current `POLARS_MAX_THREADS` setting.
///
/// The test does not change the thread count itself. To compare configurations,
/// run it separately with different `POLARS_MAX_THREADS` values; the value in effect
/// is printed alongside the timing. It is skipped (returns early) when
/// `demo_model.json` is absent.
#[test]
#[ignore] // Requires model file
fn test_threading_performance() {
    const ITERATIONS: u32 = 5;

    let model_path = "demo_model.json";
    if !std::path::Path::new(model_path).exists() {
        println!("Skipping threading test - no model file found");
        return;
    }

    let booster = Booster::load(model_path).expect("Failed to load model");
    let rows = 10000;
    let cols = 10;

    println!("\n╔════════════════════════════════════════════════════════════╗");
    println!("║          Threading Impact on Polars Predictions            ║");
    println!("╠════════════════════════════════════════════════════════════╣");

    let df = create_polars_dataframe(rows, cols);

    // Record which thread configuration this run actually measured, so that output
    // from runs with different settings can be told apart.
    let polars_threads = std::env::var_os("POLARS_MAX_THREADS").map_or_else(
        || "unset (Polars default)".to_string(),
        |v| v.to_string_lossy().into_owned(),
    );

    println!("║ Dataset: {} rows x {} features", rows, cols);
    println!("║ POLARS_MAX_THREADS: {}", polars_threads);
    println!("╟────────────────────────────────────────────────────────────╢");

    // Warm up once so one-time costs are not attributed to the timed loop.
    let _ = booster
        .predict_dataframe(&df, 0, false)
        .expect("Warm-up prediction failed");

    let start = Instant::now();
    for _ in 0..ITERATIONS {
        let _ = booster
            .predict_dataframe(&df, 0, false)
            .expect("Prediction failed");
    }
    let duration = start.elapsed();
    let total_ms = duration.as_secs_f64() * 1e3;
    // Floating-point average: integer `as_millis() / n` would report sub-millisecond
    // predictions as "0 ms/iter".
    let avg_ms = total_ms / f64::from(ITERATIONS);

    println!("║   Average time: {:.3} ms/iter", avg_ms);
    println!(
        "║   Total time:   {:.3} ms for {} iterations",
        total_ms, ITERATIONS
    );
    println!("╚════════════════════════════════════════════════════════════╝\n");

    println!("Note: To test single-threaded mode, set:");
    println!("  export POLARS_MAX_THREADS=1");
    println!("  before starting the test process, then re-run this test");
}