use crate::constraint::ViolationComputable;
use crate::error::LogicResult;
use scirs2_core::ndarray::{Array1, Array2};
use std::collections::VecDeque;
/// Tunable parameters for the gradient-based MPC solver.
#[derive(Debug, Clone)]
pub struct MPCConfig {
/// Number of steps the state trajectory is predicted over.
pub prediction_horizon: usize,
/// Number of control inputs optimized (length of the control sequence).
pub control_horizon: usize,
/// Dimensionality of the state vector.
pub state_dim: usize,
/// Dimensionality of the control vector.
pub control_dim: usize,
/// Maximum gradient-descent iterations per `solve` call.
pub max_iterations: usize,
/// Convergence threshold on the change in total cost between iterations.
pub tolerance: f32,
/// When true, each solve is seeded with the previous (shifted) solution.
pub warm_start: bool,
/// Multiplier applied to the terminal cost in the total objective.
pub terminal_weight: f32,
}
impl Default for MPCConfig {
fn default() -> Self {
Self {
prediction_horizon: 10,
control_horizon: 10,
state_dim: 1,
control_dim: 1,
max_iterations: 100,
tolerance: 1e-4,
warm_start: true,
terminal_weight: 1.0,
}
}
}
/// Cost model for MPC: per-step stage cost, terminal cost, and their
/// analytic gradients with respect to state and control.
pub trait MPCCost: Send + Sync {
/// Cost of being in `state` while applying `control` at `time_step`.
fn stage_cost(&self, state: &Array1<f32>, control: &Array1<f32>, time_step: usize) -> f32;
/// Cost assigned to the final predicted state.
fn terminal_cost(&self, state: &Array1<f32>) -> f32;
/// Gradient of the stage cost with respect to the state vector.
fn stage_cost_grad_state(
&self,
state: &Array1<f32>,
control: &Array1<f32>,
time_step: usize,
) -> Array1<f32>;
/// Gradient of the stage cost with respect to the control vector.
fn stage_cost_grad_control(
&self,
state: &Array1<f32>,
control: &Array1<f32>,
time_step: usize,
) -> Array1<f32>;
}
/// Discrete-time dynamics model: a one-step transition function plus its
/// Jacobians, used by the controller to roll out predicted trajectories.
pub trait DynamicsModel: Send + Sync {
/// Next state given the current state and applied control.
fn step(&self, state: &Array1<f32>, control: &Array1<f32>) -> Array1<f32>;
/// Jacobian of `step` with respect to the state.
fn jacobian_state(&self, state: &Array1<f32>, control: &Array1<f32>) -> Array2<f32>;
/// Jacobian of `step` with respect to the control.
fn jacobian_control(&self, state: &Array1<f32>, control: &Array1<f32>) -> Array2<f32>;
}
/// Quadratic tracking cost with diagonal weights: penalizes squared
/// deviation from a reference trajectory and from a reference control.
pub struct QuadraticCost {
/// Reference state per time step; the last entry is held past the end.
pub x_ref: Vec<Array1<f32>>,
/// Reference control (same at every time step).
pub u_ref: Array1<f32>,
/// Diagonal state weights (one entry per state dimension).
pub q_weights: Array1<f32>,
/// Diagonal control weights (one entry per control dimension).
pub r_weights: Array1<f32>,
/// Diagonal terminal-state weights.
pub q_terminal: Array1<f32>,
}
impl QuadraticCost {
pub fn new(
x_ref: Vec<Array1<f32>>,
u_ref: Array1<f32>,
q_weights: Array1<f32>,
r_weights: Array1<f32>,
q_terminal: Array1<f32>,
) -> Self {
Self {
x_ref,
u_ref,
q_weights,
r_weights,
q_terminal,
}
}
}
impl MPCCost for QuadraticCost {
    /// Weighted squared tracking error plus control effort:
    /// `sum_i q_i (x_i - xref_i)^2 + sum_j r_j (u_j - uref_j)^2`.
    fn stage_cost(&self, state: &Array1<f32>, control: &Array1<f32>, time_step: usize) -> f32 {
        // Past the end of the reference trajectory, hold the last reference.
        let x_ref = if time_step < self.x_ref.len() {
            &self.x_ref[time_step]
        } else {
            self.x_ref.last().expect("x_ref must be non-empty")
        };
        let state_error = state - x_ref;
        let control_error = control - &self.u_ref;
        let state_cost: f32 = state_error
            .iter()
            .zip(self.q_weights.iter())
            .map(|(e, q)| e * e * q)
            .sum();
        let control_cost: f32 = control_error
            .iter()
            .zip(self.r_weights.iter())
            .map(|(e, r)| e * e * r)
            .sum();
        state_cost + control_cost
    }

    /// Weighted squared deviation of the final state from the last reference.
    fn terminal_cost(&self, state: &Array1<f32>) -> f32 {
        // Same non-empty invariant (and message) as the stage-cost lookups.
        let x_ref = self.x_ref.last().expect("x_ref must be non-empty");
        let error = state - x_ref;
        error
            .iter()
            .zip(self.q_terminal.iter())
            .map(|(e, q)| e * e * q)
            .sum()
    }

    /// Analytic gradient of the stage cost w.r.t. the state: `2 Q (x - xref)`.
    fn stage_cost_grad_state(
        &self,
        state: &Array1<f32>,
        _control: &Array1<f32>,
        time_step: usize,
    ) -> Array1<f32> {
        let x_ref = if time_step < self.x_ref.len() {
            &self.x_ref[time_step]
        } else {
            self.x_ref.last().expect("x_ref must be non-empty")
        };
        let error = state - x_ref;
        &error * &(&self.q_weights * 2.0)
    }

    /// Analytic gradient of the stage cost w.r.t. the control: `2 R (u - uref)`.
    fn stage_cost_grad_control(
        &self,
        _state: &Array1<f32>,
        control: &Array1<f32>,
        _time_step: usize,
    ) -> Array1<f32> {
        let error = control - &self.u_ref;
        &error * &(&self.r_weights * 2.0)
    }
}
/// Linear time-invariant dynamics `x' = A x + B u`.
pub struct LinearDynamics {
/// State transition matrix A (state_dim x state_dim).
pub a_matrix: Array2<f32>,
/// Control input matrix B (state_dim x control_dim).
pub b_matrix: Array2<f32>,
}
impl LinearDynamics {
pub fn new(a_matrix: Array2<f32>, b_matrix: Array2<f32>) -> Self {
Self { a_matrix, b_matrix }
}
}
impl DynamicsModel for LinearDynamics {
    /// One step of the linear recurrence: `x_{k+1} = A x_k + B u_k`.
    fn step(&self, state: &Array1<f32>, control: &Array1<f32>) -> Array1<f32> {
        self.a_matrix.dot(state) + self.b_matrix.dot(control)
    }

    /// For a linear system the state Jacobian is the constant A matrix.
    fn jacobian_state(&self, _state: &Array1<f32>, _control: &Array1<f32>) -> Array2<f32> {
        self.a_matrix.clone()
    }

    /// For a linear system the control Jacobian is the constant B matrix.
    fn jacobian_control(&self, _state: &Array1<f32>, _control: &Array1<f32>) -> Array2<f32> {
        self.b_matrix.clone()
    }
}
/// Model predictive controller that optimizes a control sequence by
/// projected gradient descent over a finite horizon.
pub struct MPCController<D: DynamicsModel, C: MPCCost> {
// Solver and problem-size parameters.
config: MPCConfig,
// Model used to roll out predicted trajectories.
dynamics: D,
// Objective evaluated along the predicted trajectory.
cost: C,
// Soft constraints on predicted states (penalized, not enforced exactly).
state_constraints: Vec<Box<dyn ViolationComputable + Send + Sync>>,
// Soft constraints on controls (also consulted during projection).
control_constraints: Vec<Box<dyn ViolationComputable + Send + Sync>>,
// Last optimized control sequence, kept for warm starting the next solve.
previous_controls: Option<VecDeque<Array1<f32>>>,
}
impl<D: DynamicsModel, C: MPCCost> MPCController<D, C> {
    /// Creates a controller with no constraints and no warm-start history.
    pub fn new(config: MPCConfig, dynamics: D, cost: C) -> Self {
        Self {
            config,
            dynamics,
            cost,
            state_constraints: Vec::new(),
            control_constraints: Vec::new(),
            previous_controls: None,
        }
    }

    /// Registers a soft constraint on predicted states; violations are
    /// penalized in the cost rather than enforced exactly.
    pub fn add_state_constraint(&mut self, constraint: Box<dyn ViolationComputable + Send + Sync>) {
        self.state_constraints.push(constraint);
    }

    /// Registers a soft constraint on controls (also used by projection).
    pub fn add_control_constraint(
        &mut self,
        constraint: Box<dyn ViolationComputable + Send + Sync>,
    ) {
        self.control_constraints.push(constraint);
    }

    /// Optimizes a control sequence from `current_state` by projected
    /// gradient descent and returns it with the predicted trajectory.
    ///
    /// When `warm_start` is enabled, the previous solution (shifted forward
    /// one step, zero-padded at the end) seeds the optimization.
    pub fn solve(&mut self, current_state: &Array1<f32>) -> LogicResult<MPCSolution> {
        let horizon = self.config.control_horizon;
        let mut controls = if self.config.warm_start {
            if let Some(prev_controls) = &self.previous_controls {
                // Shift the last solution forward one step, pad with zeros.
                let mut prev = prev_controls.clone();
                if !prev.is_empty() {
                    prev.pop_front();
                    prev.push_back(Array1::zeros(self.config.control_dim));
                }
                prev.into_iter().collect::<Vec<_>>()
            } else {
                vec![Array1::zeros(self.config.control_dim); horizon]
            }
        } else {
            vec![Array1::zeros(self.config.control_dim); horizon]
        };
        // Fixed step size for the gradient descent.
        let step_size = 0.01;
        let mut best_cost = f32::INFINITY;
        for iteration in 0..self.config.max_iterations {
            let states = self.simulate_trajectory(current_state, &controls);
            let cost = self.compute_total_cost(&states, &controls);
            // Convergence check MUST run before best_cost is updated.
            // (Previously best_cost was set to cost first, making the
            // difference zero whenever the iterate improved, so the loop
            // exited after a single gradient step.)
            if iteration > 0 && (best_cost - cost).abs() < self.config.tolerance {
                break;
            }
            if cost < best_cost {
                best_cost = cost;
            }
            for t in 0..horizon {
                let grad = self.compute_control_gradient(&states, &controls, t);
                let new_control = &controls[t] - &(&grad * step_size);
                controls[t] = self.project_control(&new_control);
            }
        }
        if self.config.warm_start {
            self.previous_controls = Some(controls.iter().cloned().collect());
        }
        let final_states = self.simulate_trajectory(current_state, &controls);
        let final_cost = self.compute_total_cost(&final_states, &controls);
        Ok(MPCSolution {
            controls,
            predicted_states: final_states,
            total_cost: final_cost,
            horizon,
        })
    }

    /// Rolls the dynamics forward; returns `controls.len() + 1` states,
    /// the first being `initial_state`.
    fn simulate_trajectory(
        &self,
        initial_state: &Array1<f32>,
        controls: &[Array1<f32>],
    ) -> Vec<Array1<f32>> {
        let mut states = vec![initial_state.clone()];
        for control in controls {
            let next_state = self.dynamics.step(states.last().unwrap(), control);
            states.push(next_state);
        }
        states
    }

    /// Sum of stage costs, soft-constraint penalties, and the weighted
    /// terminal cost along a trajectory.
    fn compute_total_cost(&self, states: &[Array1<f32>], controls: &[Array1<f32>]) -> f32 {
        let mut cost = 0.0;
        for (t, control) in controls.iter().enumerate() {
            cost += self.cost.stage_cost(&states[t], control, t);
            cost += self.constraint_violation_cost(&states[t], control);
        }
        cost += self.cost.terminal_cost(states.last().unwrap()) * self.config.terminal_weight;
        cost
    }

    /// Soft-constraint penalty for one (state, control) pair; each unit of
    /// violation contributes `PENALTY_WEIGHT` to the cost.
    fn constraint_violation_cost(&self, state: &Array1<f32>, control: &Array1<f32>) -> f32 {
        // Multiplier turning raw constraint violations into cost units.
        const PENALTY_WEIGHT: f32 = 100.0;
        let mut violation = 0.0;
        let state_slice: Vec<f32> = state.iter().copied().collect();
        for constraint in &self.state_constraints {
            violation += constraint.violation(&state_slice) * PENALTY_WEIGHT;
        }
        let control_slice: Vec<f32> = control.iter().copied().collect();
        for constraint in &self.control_constraints {
            violation += constraint.violation(&control_slice) * PENALTY_WEIGHT;
        }
        violation
    }

    /// Gradient of the total cost w.r.t. the control at step `t`: analytic
    /// stage-cost gradient plus a forward-difference estimate of the
    /// constraint-penalty gradient.
    fn compute_control_gradient(
        &self,
        states: &[Array1<f32>],
        controls: &[Array1<f32>],
        t: usize,
    ) -> Array1<f32> {
        let direct_grad = self
            .cost
            .stage_cost_grad_control(&states[t], &controls[t], t);
        let mut constraint_grad = Array1::zeros(self.config.control_dim);
        let eps = 1e-5;
        // The unperturbed penalty is invariant across control dimensions:
        // compute it once instead of once per coordinate.
        let cost_base = self.constraint_violation_cost(&states[t], &controls[t]);
        for i in 0..self.config.control_dim {
            let mut control_plus = controls[t].clone();
            control_plus[i] += eps;
            let cost_plus = self.constraint_violation_cost(&states[t], &control_plus);
            constraint_grad[i] = (cost_plus - cost_base) / eps;
        }
        &direct_grad + &constraint_grad
    }

    /// Crude projection onto the control-constraint set: if any constraint
    /// reports a violation, clamp every component to [-10, 10].
    /// NOTE(review): the clamp bounds are hard-coded and unrelated to the
    /// registered constraints — consider deriving them from the constraint set.
    fn project_control(&self, control: &Array1<f32>) -> Array1<f32> {
        let mut projected = control.clone();
        for constraint in &self.control_constraints {
            let control_slice: Vec<f32> = projected.iter().copied().collect();
            if constraint.violation(&control_slice) > 0.0 {
                projected = projected.mapv(|x| x.clamp(-10.0, 10.0));
            }
        }
        projected
    }

    /// Clears warm-start history (useful after a large state jump).
    pub fn reset(&mut self) {
        self.previous_controls = None;
    }
}
/// Result of one MPC solve: the optimized control sequence and the
/// trajectory it predicts.
#[derive(Debug, Clone)]
pub struct MPCSolution {
/// Optimized controls, one per horizon step.
pub controls: Vec<Array1<f32>>,
/// Predicted states; length is `horizon + 1` (includes the initial state).
pub predicted_states: Vec<Array1<f32>>,
/// Total objective value of the returned solution.
pub total_cost: f32,
/// Control horizon the solution was computed for.
pub horizon: usize,
}
impl MPCSolution {
/// The control to actually apply: receding-horizon MPC executes only the
/// first element of the optimized sequence.
///
/// # Panics
/// Panics if the control sequence is empty.
pub fn first_control(&self) -> &Array1<f32> {
&self.controls[0]
}
/// Predicted state `time_step` steps ahead, or `None` past the horizon.
pub fn predicted_state(&self, time_step: usize) -> Option<&Array1<f32>> {
self.predicted_states.get(time_step)
}
/// A solution whose cost stayed below infinity is considered feasible.
pub fn is_feasible(&self) -> bool {
self.total_cost < f32::INFINITY
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mpc_config() {
        // Defaults from MPCConfig::default().
        let cfg = MPCConfig::default();
        assert_eq!(cfg.prediction_horizon, 10);
        assert_eq!(cfg.control_horizon, 10);
        assert!(cfg.warm_start);
    }

    #[test]
    fn test_quadratic_cost() {
        let cost = QuadraticCost::new(
            vec![Array1::from_vec(vec![1.0])],
            Array1::from_vec(vec![0.0]),
            Array1::from_vec(vec![1.0]),
            Array1::from_vec(vec![0.1]),
            Array1::from_vec(vec![10.0]),
        );
        let state = Array1::from_vec(vec![2.0]);
        let control = Array1::from_vec(vec![1.0]);
        // Off-reference state and control must incur positive cost.
        assert!(cost.stage_cost(&state, &control, 0) > 0.0);
        assert!(cost.terminal_cost(&state) > 0.0);
    }

    #[test]
    fn test_linear_dynamics() {
        let model = LinearDynamics::new(
            Array2::from_shape_vec((1, 1), vec![0.9]).unwrap(),
            Array2::from_shape_vec((1, 1), vec![0.1]).unwrap(),
        );
        // 0.9 * 1.0 + 0.1 * 2.0 = 1.1
        let next = model.step(&Array1::from_vec(vec![1.0]), &Array1::from_vec(vec![2.0]));
        assert!((next[0] - 1.1).abs() < 1e-5);
    }

    #[test]
    fn test_mpc_controller_creation() {
        let config = MPCConfig {
            prediction_horizon: 5,
            control_horizon: 5,
            state_dim: 1,
            control_dim: 1,
            ..Default::default()
        };
        let model = LinearDynamics::new(
            Array2::from_shape_vec((1, 1), vec![1.0]).unwrap(),
            Array2::from_shape_vec((1, 1), vec![0.1]).unwrap(),
        );
        let cost = QuadraticCost::new(
            vec![Array1::from_vec(vec![0.0]); 5],
            Array1::from_vec(vec![0.0]),
            Array1::from_vec(vec![1.0]),
            Array1::from_vec(vec![0.1]),
            Array1::from_vec(vec![10.0]),
        );
        let mut mpc = MPCController::new(config, model, cost);
        let solution = mpc.solve(&Array1::from_vec(vec![1.0])).unwrap();
        // One control per horizon step; states include the initial state.
        assert_eq!(solution.controls.len(), 5);
        assert_eq!(solution.predicted_states.len(), 6);
        assert!(solution.total_cost < f32::INFINITY);
    }

    #[test]
    fn test_mpc_warm_start() {
        let config = MPCConfig {
            prediction_horizon: 3,
            control_horizon: 3,
            state_dim: 1,
            control_dim: 1,
            warm_start: true,
            ..Default::default()
        };
        let model = LinearDynamics::new(
            Array2::from_shape_vec((1, 1), vec![0.95]).unwrap(),
            Array2::from_shape_vec((1, 1), vec![0.05]).unwrap(),
        );
        let cost = QuadraticCost::new(
            vec![Array1::from_vec(vec![0.0]); 3],
            Array1::from_vec(vec![0.0]),
            Array1::from_vec(vec![1.0]),
            Array1::from_vec(vec![0.01]),
            Array1::from_vec(vec![5.0]),
        );
        let mut mpc = MPCController::new(config, model, cost);
        // First solve should store a control sequence for warm starting.
        let _first = mpc.solve(&Array1::from_vec(vec![1.0])).unwrap();
        assert!(mpc.previous_controls.is_some());
        // Second solve should succeed while consuming the warm start.
        let _second = mpc.solve(&Array1::from_vec(vec![0.9])).unwrap();
    }

    #[test]
    fn test_mpc_solution_methods() {
        let solution = MPCSolution {
            controls: vec![
                Array1::from_vec(vec![1.0]),
                Array1::from_vec(vec![0.5]),
                Array1::from_vec(vec![0.2]),
            ],
            predicted_states: vec![
                Array1::from_vec(vec![0.0]),
                Array1::from_vec(vec![0.1]),
                Array1::from_vec(vec![0.15]),
                Array1::from_vec(vec![0.17]),
            ],
            total_cost: 1.5,
            horizon: 3,
        };
        assert_eq!(solution.first_control()[0], 1.0);
        assert_eq!(solution.predicted_state(0).unwrap()[0], 0.0);
        assert_eq!(solution.predicted_state(2).unwrap()[0], 0.15);
        assert!(solution.is_feasible());
    }
}