use super::TripGenerator;
use super::error::TripGenerationError;
use crate::zone::Zone;
/// Coefficients of a linear regression trip-generation model.
///
/// A zone's value is
/// `intercept + pop_coeff * population + emp_coeff * employment
/// + hh_coeff * households + income_coeff * avg_income`,
/// which `RegressionGenerator` then clamps at zero.
#[derive(Debug, Clone, PartialEq)]
pub struct RegressionCoefficients {
    /// Constant term added for every zone.
    pub intercept: f64,
    /// Multiplies the zone's `population`.
    pub pop_coeff: f64,
    /// Multiplies the zone's `employment`.
    pub emp_coeff: f64,
    /// Multiplies the zone's `households`.
    pub hh_coeff: f64,
    /// Multiplies the zone's `avg_income`.
    pub income_coeff: f64,
}

impl Default for RegressionCoefficients {
    /// A population-only model: `0.5 * population`, every other term zero.
    ///
    /// Note: this is *not* an all-zero model, and it differs from both
    /// coefficient sets used by `RegressionGenerator::new`.
    fn default() -> Self {
        Self {
            intercept: 0.0,
            pop_coeff: 0.5,
            emp_coeff: 0.0,
            hh_coeff: 0.0,
            income_coeff: 0.0,
        }
    }
}
/// Trip generator that evaluates one linear regression model for trip
/// productions and a separate one for trip attractions.
#[derive(Debug, Clone)]
pub struct RegressionGenerator {
    /// Model evaluated per zone to obtain productions.
    pub production_coeffs: RegressionCoefficients,
    /// Model evaluated per zone to obtain attractions.
    pub attraction_coeffs: RegressionCoefficients,
}

impl RegressionGenerator {
    /// Creates a generator with the built-in coefficient sets:
    ///
    /// - productions: `0.5 * population + 0.1 * employment`
    /// - attractions: `0.1 * population + 0.8 * employment`
    ///
    /// The intercept, household and income terms are zero in both.
    pub fn new() -> Self {
        let production_coeffs = RegressionCoefficients {
            intercept: 0.0,
            pop_coeff: 0.5,
            emp_coeff: 0.1,
            hh_coeff: 0.0,
            income_coeff: 0.0,
        };
        let attraction_coeffs = RegressionCoefficients {
            intercept: 0.0,
            pop_coeff: 0.1,
            emp_coeff: 0.8,
            hh_coeff: 0.0,
            income_coeff: 0.0,
        };
        Self {
            production_coeffs,
            attraction_coeffs,
        }
    }

    /// Creates a generator from caller-supplied production and attraction models.
    pub fn with_coefficients(
        production: RegressionCoefficients,
        attraction: RegressionCoefficients,
    ) -> Self {
        Self {
            attraction_coeffs: attraction,
            production_coeffs: production,
        }
    }

    /// Evaluates `coeffs` for one zone and clamps the result at zero.
    ///
    /// Terms are summed left to right, starting from the intercept. `f64::max`
    /// returns the non-NaN operand, so a NaN sum (e.g. from a NaN zone
    /// attribute) also yields `0.0`.
    fn compute_value(zone: &Zone, coeffs: &RegressionCoefficients) -> f64 {
        let terms = [
            (coeffs.pop_coeff, zone.population),
            (coeffs.emp_coeff, zone.employment),
            (coeffs.hh_coeff, zone.households),
            (coeffs.income_coeff, zone.avg_income),
        ];
        let raw = terms
            .iter()
            .fold(coeffs.intercept, |acc, &(coeff, attr)| acc + coeff * attr);
        f64::max(raw, 0.0)
    }
}

impl Default for RegressionGenerator {
    /// Same as [`RegressionGenerator::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl TripGenerator for RegressionGenerator {
fn generate(&self, zones: &[Zone]) -> Result<(Vec<f64>, Vec<f64>), TripGenerationError> {
if zones.is_empty() {
return Err(TripGenerationError::NoZones);
}
let productions: Vec<f64> = zones
.iter()
.map(|z| Self::compute_value(z, &self.production_coeffs))
.collect();
let attractions: Vec<f64> = zones
.iter()
.map(|z| Self::compute_value(z, &self.attraction_coeffs))
.collect();
Ok((productions, attractions))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    const EPS: f64 = 1e-10;

    /// Asserts `actual` is within `EPS` of `expected`, reporting both values on failure.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPS,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn default_coefficients_single_zone() {
        let trip_gen = RegressionGenerator::new();
        let zones = vec![
            Zone::new(1)
                .with_population(10000.0)
                .with_employment(5000.0)
                .build(),
        ];
        let (prods, attrs) = trip_gen.generate(&zones).unwrap();
        // 0.5 * 10000 + 0.1 * 5000
        assert_close(prods[0], 5500.0);
        // 0.1 * 10000 + 0.8 * 5000
        assert_close(attrs[0], 5000.0);
    }

    #[test]
    fn multiple_zones() {
        let trip_gen = RegressionGenerator::new();
        let zones = vec![
            Zone::new(1)
                .with_population(6000.0)
                .with_employment(500.0)
                .build(),
            Zone::new(2)
                .with_population(2000.0)
                .with_employment(3000.0)
                .build(),
        ];
        let (prods, attrs) = trip_gen.generate(&zones).unwrap();
        // Exactly one output per input zone.
        assert_eq!(prods.len(), zones.len());
        assert_eq!(attrs.len(), zones.len());
        assert_close(prods[0], 3050.0);
        assert_close(prods[1], 1300.0);
        assert_close(attrs[0], 1000.0);
        assert_close(attrs[1], 2600.0);
    }

    #[test]
    fn custom_coefficients() {
        let trip_gen = RegressionGenerator::with_coefficients(
            RegressionCoefficients {
                intercept: 10.0,
                pop_coeff: 0.3,
                emp_coeff: 0.0,
                hh_coeff: 0.0,
                income_coeff: 0.0,
            },
            RegressionCoefficients {
                intercept: 0.0,
                pop_coeff: 0.0,
                emp_coeff: 1.2,
                hh_coeff: 0.0,
                income_coeff: 0.0,
            },
        );
        let zones = vec![
            Zone::new(1)
                .with_population(1000.0)
                .with_employment(500.0)
                .build(),
        ];
        let (prods, attrs) = trip_gen.generate(&zones).unwrap();
        assert_close(prods[0], 310.0);
        assert_close(attrs[0], 600.0);
    }

    #[test]
    fn negative_result_clamped_to_zero() {
        let trip_gen = RegressionGenerator::with_coefficients(
            RegressionCoefficients {
                intercept: -1000.0,
                pop_coeff: 0.1,
                emp_coeff: 0.0,
                hh_coeff: 0.0,
                income_coeff: 0.0,
            },
            RegressionCoefficients::default(),
        );
        let zones = vec![
            Zone::new(1).with_population(100.0).build(),
        ];
        let (prods, _) = trip_gen.generate(&zones).unwrap();
        assert_eq!(prods[0], 0.0);
    }

    #[test]
    fn negative_attraction_clamped_to_zero() {
        // Clamping must apply to the attraction model as well as productions.
        let trip_gen = RegressionGenerator::with_coefficients(
            RegressionCoefficients::default(),
            RegressionCoefficients {
                intercept: -500.0,
                pop_coeff: 0.0,
                emp_coeff: 0.5,
                hh_coeff: 0.0,
                income_coeff: 0.0,
            },
        );
        let zones = vec![
            Zone::new(1).with_employment(200.0).build(),
        ];
        let (_, attrs) = trip_gen.generate(&zones).unwrap();
        assert_eq!(attrs[0], 0.0);
    }

    #[test]
    fn empty_zones_returns_error() {
        let trip_gen = RegressionGenerator::new();
        let result = trip_gen.generate(&[]);
        // Pin the specific variant rather than accepting any error.
        assert!(matches!(result, Err(TripGenerationError::NoZones)));
    }

    #[test]
    fn households_and_income_coefficients() {
        let trip_gen = RegressionGenerator::with_coefficients(
            RegressionCoefficients {
                intercept: 0.0,
                pop_coeff: 0.0,
                emp_coeff: 0.0,
                hh_coeff: 2.0,
                income_coeff: 0.001,
            },
            RegressionCoefficients::default(),
        );
        let zones = vec![
            Zone::new(1)
                .with_households(500.0)
                .with_avg_income(50000.0)
                .build(),
        ];
        let (prods, _) = trip_gen.generate(&zones).unwrap();
        // 2.0 * 500 + 0.001 * 50000
        assert_close(prods[0], 1050.0);
    }
}