strafe-plot 0.1.1

Statistical plotting for Rust statistics based on R
// "Whatever you do, work at it with all your heart, as working for the Lord,
// not for human masters, since you know that you will receive an inheritance
// from the Lord as a reward. It is the Lord Christ you are serving."
// (Col 3:23-24)

use plotters::{
    coord::Shift,
    drawing::DrawingArea,
    prelude::{DrawingBackend, RED},
    style::full_palette::{GREY_300, GREY_400, GREY_900},
};
use r2rs_stats::{
    funcs::{lowess, ppoints},
    traits::StatArray,
};
use strafe_distribution::{Distribution, NormalBuilder};
use strafe_trait::Model;
use strafe_type::{reexports::ToPrimitive, FloatConstraint};

use crate::{
    plot::Plot,
    plot_options::PlotOptions,
    plottable::{
        confidence_interval::ConfidenceInterval, horizontal_line::HorizontalLine, line::Line,
        points::Points, vertical_line::VerticalLine,
    },
};

/// Fit and regression-diagnostic plots for a model.
///
/// Every method renders onto `root`, starting from `options`; the diagnostic plots substitute
/// their own axis labels. A blanket implementation is provided for every [`Model`].
///
/// # Panics
///
/// The blanket implementation panics if the model cannot supply a quantity a plot needs
/// (predictions, standardized residuals, ...) or if drawing fails.
pub trait ModelPlot {
    /// Draws the observed data together with the fitted values and their confidence interval.
    fn plot_fit<B>(&mut self, root: &DrawingArea<B, Shift>, options: &PlotOptions)
    where
        B: DrawingBackend,
        B::ErrorType: 'static;

    /// Residuals-vs-fitted diagnostic: standardized residuals with a lowess smoother and
    /// horizontal reference lines.
    fn plot_residual_fit<B>(&mut self, root: &DrawingArea<B, Shift>, options: &PlotOptions)
    where
        B: DrawingBackend,
        B::ErrorType: 'static;

    /// Normal Q-Q diagnostic: sorted standardized residuals against theoretical normal
    /// quantiles, with a pointwise confidence band.
    fn plot_quantile_quantile<B>(&mut self, root: &DrawingArea<B, Shift>, options: &PlotOptions)
    where
        B: DrawingBackend,
        B::ErrorType: 'static;

    /// Scale-location diagnostic: `sqrt(|standardized residual|)` with a lowess smoother.
    fn plot_scale_location<B>(&mut self, root: &DrawingArea<B, Shift>, options: &PlotOptions)
    where
        B: DrawingBackend,
        B::ErrorType: 'static;

    /// Residuals-vs-leverage diagnostic with a lowess smoother and Cook's-distance contours.
    fn plot_residual_leverage<B>(&mut self, root: &DrawingArea<B, Shift>, options: &PlotOptions)
    where
        B: DrawingBackend,
        B::ErrorType: 'static;
}

/// Multiplier for the normal reference bands drawn by the diagnostic plots: the 0.975
/// quantile of N(0, 1) (about 1.96), i.e. the half-width of a two-sided 95% normal interval
/// measured in standard deviations.
const NORMAL_BAR_SD: f64 = 1.96_f64;

impl<M: Model> ModelPlot for M {
    fn plot_fit<B: DrawingBackend>(&mut self, root: &DrawingArea<B, Shift>, options: &PlotOptions)
    where
        B::ErrorType: 'static,
    {
        let x_mm = self.get_x();
        let y_mm = self.get_y();

        let x = x_mm.matrix().as_slice().to_vec();
        let y = y_mm.matrix().as_slice().to_vec();

        let ci = self.predictions().unwrap().confidence_interval().matrix();
        let ci_lower = ci.column(0).into_iter().cloned().collect::<Vec<_>>();
        let ci_upper = ci.column(1).into_iter().cloned().collect::<Vec<_>>();

        let estimates = self
            .predictions()
            .unwrap()
            .estimate()
            .matrix()
            .into_iter()
            .cloned()
            .collect::<Vec<_>>();

        Plot::new()
            .with_options(options.clone())
            .with_plottable(Points {
                x: x.clone(),
                y: y.clone(),
                ..Default::default()
            })
            .with_plottable(Line {
                x: x.clone(),
                y: estimates,
                ..Default::default()
            })
            .with_plottable(ConfidenceInterval {
                x,
                lower_y: ci_lower,
                upper_y: ci_upper,
                ..Default::default()
            })
            .plot(root)
            .unwrap();
    }

    fn plot_residual_fit<B: DrawingBackend>(
        &mut self,
        root: &DrawingArea<B, Shift>,
        options: &PlotOptions,
    ) where
        B::ErrorType: 'static,
    {
        let x_mm = self.get_x();
        let x_m = x_mm.matrix();
        let x = x_m.as_slice().to_vec();

        let standardized_residuals_mm = self.standardized_residuals().unwrap();
        let standardized_residuals_m = standardized_residuals_mm.matrix();
        let standardized_residuals = standardized_residuals_m.as_slice().to_vec();

        let (lowess_x, lowess_y) = lowess(&x, &standardized_residuals);

        Plot::new()
            .with_options(PlotOptions {
                x_axis_label: "Fitted Values".to_string(),
                y_axis_label: "Standardized Residuals".to_string(),
                ..options.clone()
            })
            .with_plottable(HorizontalLine {
                y: 0.0,
                color: GREY_300,
                ..Default::default()
            })
            .with_plottable(HorizontalLine {
                y: NORMAL_BAR_SD,
                color: GREY_300,
                ..Default::default()
            })
            .with_plottable(HorizontalLine {
                y: -NORMAL_BAR_SD,
                color: GREY_300,
                ..Default::default()
            })
            .with_plottable(HorizontalLine {
                y: 2.0 * NORMAL_BAR_SD,
                color: GREY_300,
                ..Default::default()
            })
            .with_plottable(HorizontalLine {
                y: -2.0 * NORMAL_BAR_SD,
                color: GREY_300,
                ..Default::default()
            })
            .with_plottable(HorizontalLine {
                y: 3.0 * NORMAL_BAR_SD,
                color: GREY_300,
                ..Default::default()
            })
            .with_plottable(HorizontalLine {
                y: -3.0 * NORMAL_BAR_SD,
                color: GREY_300,
                ..Default::default()
            })
            .with_plottable(Line {
                x: lowess_x,
                y: lowess_y,
                color: RED,
                ..Default::default()
            })
            .with_plottable(Points {
                x,
                y: standardized_residuals,
                ..Default::default()
            })
            .plot(root)
            .unwrap();
    }

    fn plot_quantile_quantile<B: DrawingBackend>(
        &mut self,
        root: &DrawingArea<B, Shift>,
        options: &PlotOptions,
    ) where
        B::ErrorType: 'static,
    {
        let standardized_residuals_mm = self.standardized_residuals().unwrap();
        let standardized_residuals_m = standardized_residuals_mm.matrix();
        let mut standardized_residuals = standardized_residuals_m.as_slice().to_vec();
        standardized_residuals.sort_by(|r1, r2| r1.partial_cmp(r2).unwrap());

        let norm = NormalBuilder::new().build();
        let q = ppoints(standardized_residuals.len(), None)
            .into_iter()
            .map(|p| norm.quantile(p.to_f64().unwrap(), true).unwrap())
            .collect::<Vec<_>>();

        let norm = NormalBuilder::new().build();
        let se = |qi: &f64| {
            let pnorm = norm.probability(qi, true).unwrap();
            let dnorm = norm.density(qi).unwrap();
            let n = standardized_residuals.len() as f64;
            (pnorm * (1.0 - pnorm) / n).sqrt() / dnorm
        };

        let lower_bound = q
            .iter()
            .map(|qi| qi - NORMAL_BAR_SD * se(qi))
            .collect::<Vec<_>>();
        let upper_bound = q
            .iter()
            .map(|qi| qi + NORMAL_BAR_SD * se(qi))
            .collect::<Vec<_>>();

        Plot::new()
            .with_options(PlotOptions {
                x_axis_label: "Theoretical Quantiles".to_string(),
                y_axis_label: "Standardized Residuals".to_string(),
                ..options.clone()
            })
            .with_plottable(Line {
                x: q.clone(),
                y: q.clone(),
                color: GREY_400,
                ..Default::default()
            })
            .with_plottable(ConfidenceInterval {
                x: q.clone(),
                lower_y: lower_bound,
                upper_y: upper_bound,
                color: GREY_900,
                ..Default::default()
            })
            .with_plottable(Points {
                x: q,
                y: standardized_residuals,
                ..Default::default()
            })
            .plot(root)
            .unwrap();
    }

    fn plot_scale_location<B: DrawingBackend>(
        &mut self,
        root: &DrawingArea<B, Shift>,
        options: &PlotOptions,
    ) where
        B::ErrorType: 'static,
    {
        let x_mm = self.get_x();
        let x_m = x_mm.matrix();
        let x = x_m.as_slice().to_vec();

        let standardized_residuals_mm = self.standardized_residuals().unwrap();
        let standardized_residuals_m = standardized_residuals_mm.matrix();
        let mut standardized_residuals = standardized_residuals_m.as_slice().to_vec();
        standardized_residuals
            .iter_mut()
            .for_each(|f| *f = f.abs().sqrt());

        let (lowess_x, lowess_y) = lowess(&x, &standardized_residuals);

        Plot::new()
            .with_options(PlotOptions {
                x_axis_label: "Fitted Values".to_string(),
                y_axis_label: "SQRT(|Standardized Residuals|)".to_string(),
                ..options.clone()
            })
            .with_plottable(Line {
                x: lowess_x,
                y: lowess_y,
                color: RED,
                ..Default::default()
            })
            .with_plottable(Points {
                x,
                y: standardized_residuals,
                ..Default::default()
            })
            .plot(root)
            .unwrap();
    }

    fn plot_residual_leverage<B: DrawingBackend>(
        &mut self,
        root: &DrawingArea<B, Shift>,
        options: &PlotOptions,
    ) where
        B::ErrorType: 'static,
    {
        let x_mm = self.get_x();
        let mut x_m = x_mm.matrix();
        if self.get_intercept() {
            x_m = x_m.insert_column(0, 1.0);
        }

        let leverage_m = (x_m.clone()
            * (x_m.transpose() * x_m.clone())
                .pseudo_inverse(f64::EPSILON)
                .unwrap()
            * x_m.transpose())
        .diagonal();
        let leverage = leverage_m.as_slice().to_vec();

        let standardized_residuals_mm = self.standardized_residuals().unwrap();
        let standardized_residuals_m = standardized_residuals_mm.matrix();
        let standardized_residuals = standardized_residuals_m.as_slice().to_vec();

        let (lowess_x, lowess_y) = lowess(&leverage, &standardized_residuals);

        let mut plot = Plot::new();
        plot.with_options(PlotOptions {
            x_axis_label: "Leverage".to_string(),
            y_axis_label: "Standardized Residuals".to_string(),
            x_min: Some(0.0),
            ..options.clone()
        })
        .with_plottable(Line {
            x: lowess_x,
            y: lowess_y,
            color: RED,
            ..Default::default()
        })
        .with_plottable(Points {
            x: leverage.clone(),
            y: standardized_residuals,
            ..Default::default()
        })
        .with_plottable(HorizontalLine {
            y: 0.0,
            ..Default::default()
        })
        .with_plottable(VerticalLine {
            x: 0.0,
            ..Default::default()
        });

        let drawing_coords = plot.get_drawing_coords(root);
        let x_right = drawing_coords.x_max + drawing_coords.x_space;
        let cook_levels = [0.5, 1.0];

        let mut hh = (0..100)
            .map(|i| leverage.min() + (i as f64 * ((x_right - leverage.min()) / 2.0)))
            .collect::<Vec<_>>();
        hh.push(x_right);

        for crit in cook_levels {
            let cl_h_pos = hh
                .iter()
                .map(|hh_i| (crit * x_m.ncols() as f64 * (1.0 - hh_i) / hh_i).sqrt())
                .collect::<Vec<_>>();
            let cl_h_neg = cl_h_pos.iter().map(|f| -f).collect::<Vec<_>>();

            plot.with_plottable(Line {
                x: hh.clone(),
                y: cl_h_pos,
                force_fit_all: false,
                color: GREY_400,
                ..Default::default()
            });
            plot.with_plottable(Line {
                x: hh.clone(),
                y: cl_h_neg,
                force_fit_all: false,
                color: GREY_400,
                ..Default::default()
            });
        }

        plot.plot(root).unwrap();
    }
}