use std::ops::{Mul, MulAssign};
use num_traits::One;
use crate::data::linear_algebra::traits::{SparseComparator, SparseElement};
use crate::data::linear_algebra::vector::{SparseVector, Vector};
use crate::data::linear_program::elements::RangedConstraintRelation;
use crate::data::linear_program::general_form::GeneralForm;
mod rational;
/// Interface for objects that can be rescaled in place.
///
/// `T` is the element type of the factors stored in the [`Scaling`] that is exchanged between the
/// two methods.
pub trait Scalable<T> {
/// Rescales `self` (it is borrowed mutably) and returns the [`Scaling`] that describes it.
///
/// The return value is marked `#[must_use]`: as the attribute message states, the scaling is
/// needed to transform the solution afterwards.
#[must_use = "Use the scaling to transform the solution."]
fn scale(&mut self) -> Scaling<T>;
/// Consumes a [`Scaling`] while having mutable access to `self`.
///
/// NOTE(review): presumably this reverts what `scale` did using the factors in `scale_info`; the
/// exact contract is left to the implementor — confirm against the implementations.
fn scale_back(&mut self, scale_info: Scaling<T>);
}
/// Text printed by the fallback `Scalable` implementation, which does not scale anything.
const WARNING_MESSAGE: &str =
    "WARNING: Not scaling. \
     Does your number type fulfill the necessary bound constraints?";
/// Fallback [`Scalable`] implementation for a [`GeneralForm`] over any suitable number type.
///
/// Both methods are declared `default fn`, so a more specific implementation can override them.
/// This fallback performs no scaling at all: it only emits [`WARNING_MESSAGE`].
///
/// The warning is written to standard error rather than standard output, so that diagnostics
/// from this library never end up interleaved with the primary output of the calling program.
impl<T: One + SparseElement<T> + SparseComparator + Clone> Scalable<T> for GeneralForm<T> {
    /// Leaves the program untouched and returns an identity scaling.
    ///
    /// The returned [`Scaling`] has a cost factor of one, one unit factor per active constraint
    /// and one unit factor per active variable.
    default fn scale(&mut self) -> Scaling<T> {
        eprintln!("{}", WARNING_MESSAGE);

        Scaling {
            cost_factor: T::one(),
            constraint_row_factors: vec![T::one(); self.nr_active_constraints()],
            constraint_column_factors: vec![T::one(); self.nr_active_variables()],
        }
    }

    /// Ignores `_scale_info` and leaves the program untouched; only the warning is emitted.
    default fn scale_back(&mut self, _scale_info: Scaling<T>) {
        eprintln!("{}", WARNING_MESSAGE);
    }
}
/// The factors that make up one scaling of a linear program.
///
/// The fields are read by the `scale` helper in this module and by `Scaling::scale_back`.
#[derive(Eq, PartialEq, Debug)]
pub struct Scaling<R> {
/// Factor that `scale` multiplies with a column factor before applying the product to that
/// column's variable cost.
cost_factor: R,
/// Per-row factors, indexed by constraint row index in `scale`.
constraint_row_factors: Vec<R>,
/// Per-column factors, indexed by column (variable) index in `scale` and by the entry index in
/// `Scaling::scale_back`.
constraint_column_factors: Vec<R>,
}
impl<T> Scaling<T> {
    /// Multiplies every stored entry of `vector`, in place, by the column factor found at that
    /// entry's index.
    ///
    /// The element type `S` of the vector may differ from the factor type `T`, as long as an `S`
    /// can be multiplied in place by a `&T`.
    ///
    /// In debug builds, the length of `vector` is asserted to equal the number of column factors.
    pub fn scale_back<S>(&self, vector: &mut SparseVector<S, S>)
    where
        for<'r> S: MulAssign<&'r T>,
        S: SparseElement<S> + SparseComparator,
    {
        let column_factors = &self.constraint_column_factors;
        debug_assert_eq!(vector.len(), column_factors.len());

        vector
            .iter_mut()
            .for_each(|(index, entry)| *entry *= &column_factors[*index]);
    }
}
/// Applies the factors in `scaling` to the numeric data of `general_form`, in place.
///
/// How a factor is combined with a value is decided by the two closures, each of which receives
/// the value to modify and the factor to use:
///
/// * `op` is called on each variable cost (with `cost_factor * column_factor`), on each stored
///   constraint coefficient (with `row_factor * column_factor`), on each entry of `b` and on the
///   value held by a `RangedConstraintRelation::Range` (both with the row factor);
/// * `inverse_op` is called on each lower and upper variable bound that is present (with the
///   column factor).
///
/// # Panics
///
/// Indexing panics when `scaling` holds fewer column or row factors than the indices
/// encountered while walking the columns, their entries and `b`.
fn scale<T, F, G>(
    general_form: &mut GeneralForm<T>, scaling: &Scaling<T>,
    op: F, inverse_op: G,
)
where
    T: SparseElement<T> + SparseComparator,
    for<'r> &'r T: Mul<&'r T, Output=T>,
    F: Fn(&mut T, &T),
    G: Fn(&mut T, &T),
{
    let cost_factor = &scaling.cost_factor;
    let row_factors = &scaling.constraint_row_factors;
    let column_factors = &scaling.constraint_column_factors;

    // Column-wise pass: the variable belonging to the column first, then the column's entries.
    for (j, column) in general_form.constraints.data.iter_mut().enumerate() {
        let column_factor = &column_factors[j];
        let variable = &mut general_form.variables[j];

        let cost_scale = cost_factor * column_factor;
        op(&mut variable.cost, &cost_scale);

        // Bounds are handled with the inverse operation, lower bound before upper bound.
        if let Some(lower) = variable.lower_bound.as_mut() {
            inverse_op(lower, column_factor);
        }
        if let Some(upper) = variable.upper_bound.as_mut() {
            inverse_op(upper, column_factor);
        }

        for (i, coefficient) in column {
            let entry_scale = &row_factors[*i] * column_factor;
            op(coefficient, &entry_scale);
        }
    }

    // Row-wise pass: the entry of `b` and, for range constraints, the stored range value.
    for (i, rhs) in general_form.b.iter_mut().enumerate() {
        let row_factor = &row_factors[i];
        op(rhs, row_factor);

        if let RangedConstraintRelation::Range(range) = &mut general_form.constraint_types[i] {
            op(range, row_factor);
        }
    }
}