use crate::algorithms::online::{FractionalStep, Step};
use crate::config::{Config, FractionalConfig};
use crate::model::{ModelOutputFailure, ModelOutputSuccess};
use crate::numerics::convex_optimization::{
find_minimizer, find_minimizer_of_hitting_cost, WrappedObjective,
};
use crate::problem::{FractionalSmoothedConvexOptimization, Online, Problem};
use crate::result::{Failure, Result};
use crate::schedule::FractionalSchedule;
use crate::utils::assert;
use pyo3::prelude::*;
/// Parameters of the [`robd`] algorithm, exposed to Python via `#[pyclass]`.
///
/// `robd` forwards the three fields unchanged to `build_parameters`, which
/// derives the two regularization weights (`lambda_1`, `lambda_2`) from them.
// NOTE(review): the field names suggest the constants of the R-OBD analysis
// (`m`: strong convexity of the hitting costs; `alpha` / `beta`: strong
// convexity and smoothness tied to the switching cost) — TODO confirm
// against the paper/crate documentation before relying on this reading.
#[pyclass]
#[derive(Clone)]
pub struct Options {
/// Passed to `build_parameters` as `m`; appears there in the denominator `alpha * m`.
pub m: f64,
/// Passed to `build_parameters` as `alpha`; appears there in the denominator `alpha * m`.
pub alpha: f64,
/// Passed to `build_parameters` as `beta`; used there squared and as a divisor.
pub beta: f64,
}
/// Placeholder `Default` implementation for [`Options`].
///
/// No default values are defined for `m`, `alpha` and `beta`; build the
/// options explicitly instead of calling `Options::default()`.
// NOTE(review): presumably this impl only exists to satisfy a `Default`
// bound elsewhere in the crate — TODO confirm.
impl Default for Options {
/// # Panics
///
/// Always panics, because the body is `unimplemented!()`.
fn default() -> Self {
unimplemented!()
}
}
#[pymethods]
impl Options {
#[new]
fn constructor(m: f64, alpha: f64, beta: f64) -> Self {
Options { m, alpha, beta }
}
}
/// Data handed to the objective closure that `robd` minimizes.
///
/// The closure built in `robd` reads all of its inputs from this struct.
#[derive(Clone)]
struct RegularizationFunctionObjectiveData<'a, C, D> {
/// Online problem instance; the closure uses `o.p.hit_cost` and `o.p.switching_cost`.
o: Online<FractionalSmoothedConvexOptimization<'a, C, D>>,
/// Time index passed to `hit_cost` when evaluating the objective.
t: i32,
/// Weight of the switching cost evaluated at `x - prev_x`.
lambda_1: f64,
/// Weight of the switching cost evaluated at `x - v`.
lambda_2: f64,
/// Configuration `robd` obtains from the schedule via `now_with_default`.
prev_x: FractionalConfig,
/// Minimizer of the hitting cost at `t`, as returned by `find_minimizer_of_hitting_cost`.
v: FractionalConfig,
}
/// Performs one step of the `robd` algorithm for time `t`.
///
/// The returned configuration minimizes, within `o.p.bounds`, the sum of
///
/// * the hitting cost of the candidate at time `t`,
/// * `lambda_1` times the switching cost of `candidate - prev_x`, where
///   `prev_x` comes from `xs.now_with_default(..)` with an all-`0.` default
///   of dimension `o.p.d`, and
/// * `lambda_2` times the switching cost of `candidate - v`, where `v` is
///   the minimizer of the hitting cost at `t`.
///
/// `lambda_1` and `lambda_2` are derived from the options `m`, `alpha` and
/// `beta` by `build_parameters`.
///
/// # Errors
///
/// Propagates the failure of the initial `assert`, which is given
/// `Failure::UnsupportedPredictionWindow(o.w)` for the case `o.w != 0`.
pub fn robd<C, D>(
    o: Online<FractionalSmoothedConvexOptimization<C, D>>,
    t: i32,
    xs: &FractionalSchedule,
    _: (),
    Options { m, alpha, beta }: Options,
) -> Result<FractionalStep<()>>
where
    C: ModelOutputSuccess,
    D: ModelOutputFailure,
{
    assert(o.w == 0, Failure::UnsupportedPredictionWindow(o.w))?;

    let (lambda_1, lambda_2) = build_parameters(m, alpha, beta);
    let prev_x = xs.now_with_default(Config::repeat(0., o.p.d));

    // Minimizer of the hitting cost alone; it anchors the second
    // regularization term.
    let hitting_cost_minimizer = find_minimizer_of_hitting_cost(
        t,
        o.p.hitting_cost.clone(),
        o.p.bounds.clone(),
    )
    .0;
    let v = Config::new(hitting_cost_minimizer);

    // The bounds have to be copied before `o` moves into the objective data.
    let bounds = o.p.bounds.clone();
    let objective_data = RegularizationFunctionObjectiveData {
        o,
        t,
        lambda_1,
        lambda_2,
        prev_x,
        v,
    };

    let objective = WrappedObjective::new(objective_data, |raw_x, d| {
        let x = Config::new(raw_x.to_vec());
        let hitting = d.o.p.hit_cost(d.t, x.clone()).cost;
        let from_previous =
            (d.o.p.switching_cost)(x.clone() - d.prev_x.clone()).raw();
        let from_minimizer = (d.o.p.switching_cost)(x - d.v.clone()).raw();
        hitting + d.lambda_1 * from_previous + d.lambda_2 * from_minimizer
    });

    let next = Config::new(find_minimizer(objective, bounds).0);
    Ok(Step(next, None))
}
/// Derives the regularization weights `(lambda_1, lambda_2)` from `m`,
/// `alpha` and `beta`.
///
/// With `r = 1 + sqrt(1 + 4 * beta^2 / (alpha * m))`, the weight `lambda_2`
/// that belongs to a given `lambda_1` is `(lambda_1 * m / 2 * r - m) / beta`.
///
/// * The preferred pair is `(2 / r, 0)`. It is returned when the `lambda_2`
///   computed for `lambda_1 = 2 / r` is within `f64::EPSILON` of zero.
/// * Otherwise `lambda_1` is fixed to `1` and the matching `lambda_2` is
///   returned.
///
/// The inputs are not validated; for example `beta == 0` produces a NaN
/// `lambda_2`.
fn build_parameters(m: f64, alpha: f64, beta: f64) -> (f64, f64) {
    // Shared term `r`; both formulas below depend on it.
    let root_term = 1. + (1. + 4. * beta.powi(2) / (alpha * m)).sqrt();
    let lambda_2_for = |lambda_1: f64| (lambda_1 * m / 2. * root_term - m) / beta;

    let preferred_lambda_1 = 2. / root_term;
    if lambda_2_for(preferred_lambda_1).abs() < f64::EPSILON {
        (preferred_lambda_1, 0.)
    } else {
        (1., lambda_2_for(1.))
    }
}