use nalgebra::{allocator::Allocator, Const, DefaultAllocator, DimMin, DimName, OMatrix, OVector};
use crate::setup;
use crate::solver;
use crate::types::SolverStats;
/// Control allocator that caches its problem matrix together with warm-start state.
///
/// [`ControlAllocator::new`] builds the matrix `a` once via `setup::setup_a`. Each
/// [`ControlAllocator::solve`] call only rebuilds the vector `b` via `setup::setup_b`,
/// then runs `solver::solve`, continuing from the stored `us`/`ws` state.
///
/// Type parameters:
/// - `NU`: length of the solution vector `us` (one entry per control input).
/// - `NV`: length of the `wv` weight vector and of the `v` argument to `solve`.
/// - `NC`: row count of `a`; `new` fails to compile unless `NC == NU + NV`.
#[derive(Debug, Clone)]
pub struct ControlAllocator<const NU: usize, const NV: usize, const NC: usize>
where
    Const<NC>: DimName,
    Const<NU>: DimName,
    Const<NV>: DimName,
    DefaultAllocator: Allocator<Const<NC>, Const<NU>>
        + Allocator<Const<NU>>
        + Allocator<Const<NV>>,
{
    /// NC×NU matrix produced by `setup::setup_a`; reused unchanged by every solve.
    a: OMatrix<f32, Const<NC>, Const<NU>>,
    /// Owned copy of the `wv` weights given to `new`; passed to `setup::setup_b`.
    wv: OVector<f32, Const<NV>>,
    /// The `wu` vector as left by `setup::setup_a`, which receives it mutably.
    /// NOTE(review): presumably normalised there (per the field name) — confirm in `setup`.
    wu_norm: OVector<f32, Const<NU>>,
    /// Scalar returned by `setup::setup_a`; passed to `setup::setup_b` on every solve.
    gamma: f32,
    /// Current solution; also the starting point handed to the next solve.
    us: OVector<f32, Const<NU>>,
    /// Per-input working-set flags handed mutably to `solver::solve`.
    /// They are all zero after construction or after a warm-start reset.
    /// NOTE(review): the meaning of non-zero values is defined by the solver;
    /// presumably they mark active lower/upper bounds.
    ws: [i8; NU],
}
impl<const NU: usize, const NV: usize, const NC: usize> ControlAllocator<NU, NV, NC>
where
    Const<NC>: DimName + DimMin<Const<NU>, Output = Const<NU>>,
    Const<NU>: DimName,
    Const<NV>: DimName,
    DefaultAllocator: Allocator<Const<NC>, Const<NU>>
        + Allocator<Const<NC>, Const<NC>>
        + Allocator<Const<NU>, Const<NU>>
        + Allocator<Const<NC>>
        + Allocator<Const<NU>>
        + Allocator<Const<NV>>,
{
    /// Creates an allocator and pre-computes its problem matrix.
    ///
    /// # Parameters
    /// - `g`: NV×NU matrix forwarded to `setup::setup_a`
    ///   (presumably the control-effectiveness matrix — confirm).
    /// - `wv`: weights over the NV rows; an owned copy is kept for every later `solve`.
    /// - `wu`: weights over the NU inputs. `setup::setup_a` gets mutable access to it,
    ///   and the vector it leaves behind is stored for `solve`.
    /// - `theta`, `cond_bound`: forwarded unchanged to `setup::setup_a`.
    ///
    /// The warm start begins at `us = 0` with an all-zero working set.
    ///
    /// Instantiating this constructor with `NC != NU + NV` is rejected at compile time
    /// by the inline `const` assertion.
    pub fn new(
        g: &OMatrix<f32, Const<NV>, Const<NU>>,
        wv: &OVector<f32, Const<NV>>,
        mut wu: OVector<f32, Const<NU>>,
        theta: f32,
        cond_bound: f32,
    ) -> Self {
        const { assert!(NC == NU + NV, "ControlAllocator requires NC == NU + NV") };

        // `setup_a` only borrows `wv` immutably, so snapshotting it first is equivalent.
        let wv_owned = wv.clone_owned();
        let (a, gamma) = setup::setup_a::<NU, NV, NC>(g, wv, &mut wu, theta, cond_bound);

        Self {
            a,
            gamma,
            wv: wv_owned,
            wu_norm: wu,
            us: OVector::zeros(),
            ws: [0; NU],
        }
    }

    /// Runs one allocation, starting from the current `us`/`ws` state.
    ///
    /// First builds the vector `b` with `setup::setup_b` from `v`, `ud` and the stored
    /// `wv`/`wu_norm`/`gamma`. Then calls `solver::solve` with the cached matrix `a`, the
    /// bounds `umin`/`umax`, and `imax` (presumably an iteration cap — confirm in `solver`).
    /// `us` and `ws` are handed to the solver mutably; read the result via
    /// [`Self::solution`]. Returns whatever statistics the solver reports.
    pub fn solve(
        &mut self,
        v: &OVector<f32, Const<NV>>,
        ud: &OVector<f32, Const<NU>>,
        umin: &OVector<f32, Const<NU>>,
        umax: &OVector<f32, Const<NU>>,
        imax: usize,
    ) -> SolverStats {
        // Destructure `self` so each field is borrowed independently. Cached problem
        // data goes out by shared reference; warm-start state goes out by mutable reference.
        let Self { a, wv, wu_norm, gamma, us, ws } = self;
        let rhs = setup::setup_b::<NU, NV, NC>(v, ud, &*wv, &*wu_norm, *gamma);
        solver::solve::<NU, NV, NC>(&*a, &rhs, umin, umax, us, ws, imax)
    }

    /// Borrow of the current solution vector, which also seeds the next `solve`.
    pub fn solution(&self) -> &OVector<f32, Const<NU>> {
        &self.us
    }

    /// Scalar `gamma` computed by `setup::setup_a` during construction.
    pub fn gamma(&self) -> f32 {
        self.gamma
    }

    /// Seeds the next `solve` with a copy of `us` and clears the working set to all zeros.
    pub fn set_warmstart(&mut self, us: &OVector<f32, Const<NU>>) {
        self.us.copy_from(us);
        self.ws.fill(0);
    }

    /// Restores the post-construction warm start: `us = 0` and an all-zero working set.
    pub fn reset_warmstart(&mut self) {
        self.us.fill(0.0);
        self.ws.fill(0);
    }
}