use crate::core::{ArgminFloat, Error, SendAlias, SyncAlias};
#[cfg(feature = "rayon")]
use rayon::prelude::*;
use std::collections::HashMap;
/// Wrapper around a user-defined problem of type `O` that also tallies evaluations.
///
/// The problem is stored in an `Option` so that it can be moved out with
/// `take_problem` and handed to another `Problem` via `consume_problem`.
#[derive(Clone, Debug, Default)]
pub struct Problem<O> {
/// The wrapped problem; `None` when no problem is stored (e.g. after `take_problem`).
pub problem: Option<O>,
/// Evaluation counters, keyed by a static counter name such as `"cost_count"`.
pub counts: HashMap<&'static str, u64>,
}
impl<O> Problem<O> {
pub fn new(problem: O) -> Self {
Problem {
problem: Some(problem),
counts: HashMap::new(),
}
}
pub fn problem<T, F: FnOnce(&O) -> Result<T, Error>>(
&mut self,
counts_string: &'static str,
func: F,
) -> Result<T, Error> {
let count = self.counts.entry(counts_string).or_insert(0);
*count += 1;
func(self.problem.as_ref().unwrap())
}
pub fn bulk_problem<T, F: FnOnce(&O) -> Result<T, Error>>(
&mut self,
counts_string: &'static str,
num_param_vecs: usize,
func: F,
) -> Result<T, Error> {
let count = self.counts.entry(counts_string).or_insert(0);
*count += num_param_vecs as u64;
func(self.problem.as_ref().unwrap())
}
pub fn take_problem(&mut self) -> Option<O> {
self.problem.take()
}
pub fn consume_problem(&mut self, mut other: Problem<O>) {
self.problem = Some(other.take_problem().unwrap());
self.consume_func_counts(other);
}
pub fn consume_func_counts<O2>(&mut self, other: Problem<O2>) {
for (k, v) in other.counts.iter() {
let count = self.counts.entry(k).or_insert(0);
*count += v
}
}
pub fn reset(&mut self) {
for (_, v) in self.counts.iter_mut() {
*v = 0;
}
}
pub fn get_problem(self) -> Option<O> {
self.problem
}
}
/// A problem that can be applied to a parameter to produce an output.
///
/// Wrapping an implementor in `Problem` exposes `Problem::apply` and
/// `Problem::bulk_apply`, which additionally count evaluations under
/// `"operator_count"`.
pub trait Operator {
/// Parameter type accepted by `apply`.
type Param;
/// Value type produced by `apply`.
type Output;
/// Applies the operator to `param`.
///
/// # Errors
///
/// Implementations return an `Error` when the evaluation fails.
fn apply(&self, param: &Self::Param) -> Result<Self::Output, Error>;
// NOTE(review): `bulk!` is defined outside this section. `Problem::bulk_apply`
// calls a `bulk_apply` method on the implementor, so this invocation presumably
// generates that slice-based counterpart of `apply` — confirm in the macro.
bulk!(apply, Self::Param, Self::Output);
}
/// A problem that assigns a cost to a parameter.
///
/// Wrapping an implementor in `Problem` exposes `Problem::cost` and
/// `Problem::bulk_cost`, which additionally count evaluations under
/// `"cost_count"`.
pub trait CostFunction {
/// Parameter type accepted by `cost`.
type Param;
/// Type of the cost value returned by `cost`.
type Output;
/// Returns the cost for `param`, as computed by the implementor.
///
/// # Errors
///
/// Implementations return an `Error` when the evaluation fails.
fn cost(&self, param: &Self::Param) -> Result<Self::Output, Error>;
// NOTE(review): `bulk!` is defined outside this section. `Problem::bulk_cost`
// calls a `bulk_cost` method on the implementor, so this invocation presumably
// generates that slice-based counterpart of `cost` — confirm in the macro.
bulk!(cost, Self::Param, Self::Output);
}
/// A problem that provides a gradient for a parameter.
///
/// Wrapping an implementor in `Problem` exposes `Problem::gradient` and
/// `Problem::bulk_gradient`, which additionally count evaluations under
/// `"gradient_count"`.
pub trait Gradient {
/// Parameter type accepted by `gradient`.
type Param;
/// Type of the gradient returned by `gradient`.
type Gradient;
/// Returns the gradient for `param`, as computed by the implementor.
///
/// # Errors
///
/// Implementations return an `Error` when the evaluation fails.
fn gradient(&self, param: &Self::Param) -> Result<Self::Gradient, Error>;
// NOTE(review): `bulk!` is defined outside this section. `Problem::bulk_gradient`
// calls a `bulk_gradient` method on the implementor, so this invocation presumably
// generates that slice-based counterpart of `gradient` — confirm in the macro.
bulk!(gradient, Self::Param, Self::Gradient);
}
/// A problem that provides a Hessian for a parameter.
///
/// Wrapping an implementor in `Problem` exposes `Problem::hessian` and
/// `Problem::bulk_hessian`, which additionally count evaluations under
/// `"hessian_count"`.
pub trait Hessian {
/// Parameter type accepted by `hessian`.
type Param;
/// Type of the Hessian returned by `hessian`.
type Hessian;
/// Returns the Hessian for `param`, as computed by the implementor.
///
/// # Errors
///
/// Implementations return an `Error` when the evaluation fails.
fn hessian(&self, param: &Self::Param) -> Result<Self::Hessian, Error>;
// NOTE(review): `bulk!` is defined outside this section. `Problem::bulk_hessian`
// calls a `bulk_hessian` method on the implementor, so this invocation presumably
// generates that slice-based counterpart of `hessian` — confirm in the macro.
bulk!(hessian, Self::Param, Self::Hessian);
}
/// A problem that provides a Jacobian for a parameter.
///
/// Wrapping an implementor in `Problem` exposes `Problem::jacobian` and
/// `Problem::bulk_jacobian`, which additionally count evaluations under
/// `"jacobian_count"`.
pub trait Jacobian {
/// Parameter type accepted by `jacobian`.
type Param;
/// Type of the Jacobian returned by `jacobian`.
type Jacobian;
/// Returns the Jacobian for `param`, as computed by the implementor.
///
/// # Errors
///
/// Implementations return an `Error` when the evaluation fails.
fn jacobian(&self, param: &Self::Param) -> Result<Self::Jacobian, Error>;
// NOTE(review): `bulk!` is defined outside this section. `Problem::bulk_jacobian`
// calls a `bulk_jacobian` method on the implementor, so this invocation presumably
// generates that slice-based counterpart of `jacobian` — confirm in the macro.
bulk!(jacobian, Self::Param, Self::Jacobian);
}
/// Describes a linear program through the accessors `c`, `b` and `A`.
///
/// Every accessor has a default body that returns a `NotImplemented` error, so
/// an implementor that does not override one gets that error at call time.
pub trait LinearProgram {
/// Parameter type of the problem (not used by the methods declared here).
type Param;
/// Floating point type of the returned coefficients.
type Float: ArgminFloat;
/// Returns the vector `c` of the linear program.
///
/// # Errors
///
/// The default implementation always returns a `NotImplemented` error.
// NOTE(review): presumably the objective coefficients — confirm against callers.
fn c(&self) -> Result<Vec<Self::Float>, Error> {
Err(argmin_error!(
NotImplemented,
"Method `c` of LinearProgram trait not implemented!"
))
}
/// Returns the vector `b` of the linear program.
///
/// # Errors
///
/// The default implementation always returns a `NotImplemented` error.
// NOTE(review): presumably the right-hand side of the constraints — confirm against callers.
fn b(&self) -> Result<Vec<Self::Float>, Error> {
Err(argmin_error!(
NotImplemented,
"Method `b` of LinearProgram trait not implemented!"
))
}
/// Returns the matrix `A` of the linear program as nested vectors.
///
/// # Errors
///
/// The default implementation always returns a `NotImplemented` error.
// NOTE(review): presumably the constraint matrix; whether the inner vectors are
// rows or columns is not visible here — confirm against callers.
// The lint is silenced because the method name is the uppercase letter `A`.
#[allow(non_snake_case)]
fn A(&self) -> Result<Vec<Vec<Self::Float>>, Error> {
Err(argmin_error!(
NotImplemented,
"Method `A` of LinearProgram trait not implemented!"
))
}
}
/// Counting wrappers around the [`Operator`] methods of the wrapped problem.
impl<O: Operator> Problem<O> {
    /// Calls `apply` on the wrapped problem and adds one to `"operator_count"`.
    ///
    /// # Errors
    ///
    /// Forwards the error returned by the wrapped problem.
    ///
    /// # Panics
    ///
    /// Panics if no problem is stored.
    pub fn apply(&mut self, param: &O::Param) -> Result<O::Output, Error> {
        self.problem("operator_count", |op| op.apply(param))
    }

    /// Calls `bulk_apply` on the wrapped problem with all of `params` and adds
    /// `params.len()` to `"operator_count"`.
    ///
    /// # Errors
    ///
    /// Forwards the error returned by the wrapped problem.
    ///
    /// # Panics
    ///
    /// Panics if no problem is stored.
    pub fn bulk_apply<P>(&mut self, params: &[P]) -> Result<Vec<O::Output>, Error>
    where
        P: std::borrow::Borrow<O::Param> + SyncAlias,
        O::Output: SendAlias,
        O: SyncAlias,
    {
        let batch_size = params.len();
        self.bulk_problem("operator_count", batch_size, |op| op.bulk_apply(params))
    }
}
/// Counting wrappers around the [`CostFunction`] methods of the wrapped problem.
impl<O: CostFunction> Problem<O> {
    /// Calls `cost` on the wrapped problem and adds one to `"cost_count"`.
    ///
    /// # Errors
    ///
    /// Forwards the error returned by the wrapped problem.
    ///
    /// # Panics
    ///
    /// Panics if no problem is stored.
    pub fn cost(&mut self, param: &O::Param) -> Result<O::Output, Error> {
        self.problem("cost_count", |cost_fn| cost_fn.cost(param))
    }

    /// Calls `bulk_cost` on the wrapped problem with all of `params` and adds
    /// `params.len()` to `"cost_count"`.
    ///
    /// # Errors
    ///
    /// Forwards the error returned by the wrapped problem.
    ///
    /// # Panics
    ///
    /// Panics if no problem is stored.
    pub fn bulk_cost<P>(&mut self, params: &[P]) -> Result<Vec<O::Output>, Error>
    where
        P: std::borrow::Borrow<O::Param> + SyncAlias,
        O::Output: SendAlias,
        O: SyncAlias,
    {
        let batch_size = params.len();
        self.bulk_problem("cost_count", batch_size, |cost_fn| cost_fn.bulk_cost(params))
    }
}
/// Counting wrappers around the [`Gradient`] methods of the wrapped problem.
impl<O: Gradient> Problem<O> {
    /// Calls `gradient` on the wrapped problem and adds one to `"gradient_count"`.
    ///
    /// # Errors
    ///
    /// Forwards the error returned by the wrapped problem.
    ///
    /// # Panics
    ///
    /// Panics if no problem is stored.
    pub fn gradient(&mut self, param: &O::Param) -> Result<O::Gradient, Error> {
        self.problem("gradient_count", |grad| grad.gradient(param))
    }

    /// Calls `bulk_gradient` on the wrapped problem with all of `params` and
    /// adds `params.len()` to `"gradient_count"`.
    ///
    /// # Errors
    ///
    /// Forwards the error returned by the wrapped problem.
    ///
    /// # Panics
    ///
    /// Panics if no problem is stored.
    pub fn bulk_gradient<P>(&mut self, params: &[P]) -> Result<Vec<O::Gradient>, Error>
    where
        P: std::borrow::Borrow<O::Param> + SyncAlias,
        O::Gradient: SendAlias,
        O: SyncAlias,
    {
        let batch_size = params.len();
        self.bulk_problem("gradient_count", batch_size, |grad| grad.bulk_gradient(params))
    }
}
/// Counting wrappers around the [`Hessian`] methods of the wrapped problem.
impl<O: Hessian> Problem<O> {
    /// Calls `hessian` on the wrapped problem and adds one to `"hessian_count"`.
    ///
    /// # Errors
    ///
    /// Forwards the error returned by the wrapped problem.
    ///
    /// # Panics
    ///
    /// Panics if no problem is stored.
    pub fn hessian(&mut self, param: &O::Param) -> Result<O::Hessian, Error> {
        self.problem("hessian_count", |hess| hess.hessian(param))
    }

    /// Calls `bulk_hessian` on the wrapped problem with all of `params` and
    /// adds `params.len()` to `"hessian_count"`.
    ///
    /// # Errors
    ///
    /// Forwards the error returned by the wrapped problem.
    ///
    /// # Panics
    ///
    /// Panics if no problem is stored.
    pub fn bulk_hessian<P>(&mut self, params: &[P]) -> Result<Vec<O::Hessian>, Error>
    where
        P: std::borrow::Borrow<O::Param> + SyncAlias,
        O::Hessian: SendAlias,
        O: SyncAlias,
    {
        let batch_size = params.len();
        self.bulk_problem("hessian_count", batch_size, |hess| hess.bulk_hessian(params))
    }
}
/// Counting wrappers around the [`Jacobian`] methods of the wrapped problem.
impl<O: Jacobian> Problem<O> {
    /// Calls `jacobian` on the wrapped problem and adds one to `"jacobian_count"`.
    ///
    /// # Errors
    ///
    /// Forwards the error returned by the wrapped problem.
    ///
    /// # Panics
    ///
    /// Panics if no problem is stored.
    pub fn jacobian(&mut self, param: &O::Param) -> Result<O::Jacobian, Error> {
        self.problem("jacobian_count", |jac| jac.jacobian(param))
    }

    /// Calls `bulk_jacobian` on the wrapped problem with all of `params` and
    /// adds `params.len()` to `"jacobian_count"`.
    ///
    /// # Errors
    ///
    /// Forwards the error returned by the wrapped problem.
    ///
    /// # Panics
    ///
    /// Panics if no problem is stored.
    pub fn bulk_jacobian<P>(&mut self, params: &[P]) -> Result<Vec<O::Jacobian>, Error>
    where
        P: std::borrow::Borrow<O::Param> + SyncAlias,
        O::Jacobian: SendAlias,
        O: SyncAlias,
    {
        let batch_size = params.len();
        self.bulk_problem("jacobian_count", batch_size, |jac| jac.bulk_jacobian(params))
    }
}
/// Pass-through accessors for a wrapped [`LinearProgram`]; no counters are updated.
impl<O: LinearProgram> Problem<O> {
    /// Returns the result of `c` on the wrapped problem.
    ///
    /// # Errors
    ///
    /// Forwards the error returned by the wrapped problem.
    ///
    /// # Panics
    ///
    /// Panics if no problem is stored.
    pub fn c(&self) -> Result<Vec<O::Float>, Error> {
        let program = self.problem.as_ref().unwrap();
        program.c()
    }

    /// Returns the result of `b` on the wrapped problem.
    ///
    /// # Errors
    ///
    /// Forwards the error returned by the wrapped problem.
    ///
    /// # Panics
    ///
    /// Panics if no problem is stored.
    pub fn b(&self) -> Result<Vec<O::Float>, Error> {
        let program = self.problem.as_ref().unwrap();
        program.b()
    }

    /// Returns the result of `A` on the wrapped problem.
    ///
    /// # Errors
    ///
    /// Forwards the error returned by the wrapped problem.
    ///
    /// # Panics
    ///
    /// Panics if no problem is stored.
    // The lint is silenced because the method mirrors the trait's uppercase `A`.
    #[allow(non_snake_case)]
    pub fn A(&self) -> Result<Vec<Vec<O::Float>>, Error> {
        let program = self.problem.as_ref().unwrap();
        program.A()
    }
}