use faer::Col;
/// Selects the scale on which predictions are reported.
///
/// [`Default`] yields [`PredictionType::Response`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum PredictionType {
    /// The default variant. By its name, predictions on the response scale
    /// (NOTE(review): presumably with any inverse link applied; confirm in the predict code).
    #[default]
    Response,
    /// By its name, predictions on the link (linear-predictor) scale. Confirm in the caller.
    Link,
}
/// Selects which kind of interval accompanies a prediction.
///
/// [`Default`] yields [`IntervalType::Prediction`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum IntervalType {
    /// By its name, an interval for the expected value. NOTE(review): confirm in the caller.
    Confidence,
    /// The default variant. By its name, an interval for a new observation.
    #[default]
    Prediction,
}
/// Container for point predictions plus per-prediction interval bounds and
/// standard errors, each stored as a `faer` column of `f64`.
///
/// When built with `PredictionResult::point_only`, the `lower`, `upper` and
/// `se` columns are zero-filled placeholders rather than computed values.
#[derive(Clone, Debug)]
pub struct PredictionResult {
    /// Point predictions. Its row count is what `len()` reports.
    pub fit: Col<f64>,
    /// Lower interval bound for each prediction.
    pub lower: Col<f64>,
    /// Upper interval bound for each prediction.
    pub upper: Col<f64>,
    /// Standard error for each prediction (by its name; confirm with the producer).
    pub se: Col<f64>,
}
impl PredictionResult {
    /// Builds a result that carries only point predictions.
    ///
    /// `lower`, `upper` and `se` are filled with zeros of length
    /// `fit.nrows()`, so every column has the same number of rows. The zeros
    /// are placeholders, not computed bounds or standard errors.
    #[must_use]
    pub fn point_only(fit: Col<f64>) -> Self {
        let n = fit.nrows();
        Self {
            fit,
            lower: Col::zeros(n),
            upper: Col::zeros(n),
            se: Col::zeros(n),
        }
    }

    /// Builds a result from point predictions together with interval bounds
    /// and standard errors.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `lower`, `upper` or `se` does not have the
    /// same number of rows as `fit`. Release builds perform no check.
    #[must_use]
    pub fn with_intervals(fit: Col<f64>, lower: Col<f64>, upper: Col<f64>, se: Col<f64>) -> Self {
        // `len()` only inspects `fit`, so a mismatched companion column would
        // otherwise go unnoticed until some later per-row access misaligns.
        let n = fit.nrows();
        debug_assert_eq!(lower.nrows(), n, "`lower` must have as many rows as `fit`");
        debug_assert_eq!(upper.nrows(), n, "`upper` must have as many rows as `fit`");
        debug_assert_eq!(se.nrows(), n, "`se` must have as many rows as `fit`");
        Self {
            fit,
            lower,
            upper,
            se,
        }
    }

    /// Number of predictions, i.e. the row count of `fit`.
    #[must_use]
    pub fn len(&self) -> usize {
        self.fit.nrows()
    }

    /// Returns `true` when the result holds no predictions.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the prediction option enums and `PredictionResult`.
    use super::*;

    #[test]
    fn test_point_only() {
        let fit = Col::from_fn(5, |i| i as f64);
        let result = PredictionResult::point_only(fit);
        assert_eq!(result.len(), 5);
        assert!(!result.is_empty());
        // The zero-filled placeholder columns must line up with `fit`.
        assert_eq!(result.lower.nrows(), 5);
        assert_eq!(result.upper.nrows(), 5);
        assert_eq!(result.se.nrows(), 5);
    }

    #[test]
    fn test_with_intervals() {
        let fit = Col::from_fn(3, |i| i as f64);
        let lower = Col::from_fn(3, |i| (i as f64) - 1.0);
        let upper = Col::from_fn(3, |i| (i as f64) + 1.0);
        let se = Col::from_fn(3, |_| 0.5);
        let result = PredictionResult::with_intervals(fit, lower, upper, se);
        assert_eq!(result.len(), 3);
        assert!(!result.is_empty());
        assert_eq!(result.lower.nrows(), 3);
        assert_eq!(result.upper.nrows(), 3);
        assert_eq!(result.se.nrows(), 3);
    }

    #[test]
    fn test_empty_result() {
        let fit = Col::<f64>::zeros(0);
        let result = PredictionResult::point_only(fit);
        assert!(result.is_empty());
        assert_eq!(result.len(), 0);
        // Placeholders for an empty fit must be empty too.
        assert_eq!(result.lower.nrows(), 0);
        assert_eq!(result.upper.nrows(), 0);
        assert_eq!(result.se.nrows(), 0);
    }

    #[test]
    fn test_interval_type_default() {
        let interval = IntervalType::default();
        assert_eq!(interval, IntervalType::Prediction);
    }

    #[test]
    fn test_prediction_type_default() {
        assert_eq!(PredictionType::default(), PredictionType::Response);
        assert_ne!(PredictionType::Response, PredictionType::Link);
    }
}