use std::ops::Range;
use crate::ModelError;
/// A scalar objective evaluated over a parameter vector of length [`Objective::dim`].
pub trait Objective {
    /// Error returned when an evaluation fails.
    type Error;

    /// Number of parameters this objective expects.
    fn dim(&self) -> usize;

    /// Evaluates the objective at `theta` and returns its value.
    fn value(&mut self, theta: &[f64]) -> Result<f64, Self::Error>;

    /// Evaluates the objective at `theta`, filling `grad`, and returns the objective value.
    ///
    /// Implementations are expected to write the gradient at `theta` into `grad`.
    fn value_gradient(&mut self, theta: &[f64], grad: &mut [f64]) -> Result<f64, Self::Error>;

    /// Fills `grad` at `theta`, discarding the objective value.
    ///
    /// The default implementation delegates to [`Objective::value_gradient`] and propagates
    /// its error unchanged.
    fn gradient(&mut self, theta: &[f64], grad: &mut [f64]) -> Result<(), Self::Error> {
        self.value_gradient(theta, grad)?;
        Ok(())
    }
}
/// Exposes a contiguous block of a full objective's parameters as an [`Objective`] of its own.
///
/// Each evaluation writes the caller's block parameters into `working_beta[block]` and then
/// evaluates `full_objective` on the whole of `working_beta`. Coordinates outside `block` are
/// not modified by evaluations.
#[derive(Debug)]
pub struct BlockObjective<'a, O> {
    /// Objective over the full parameter vector.
    pub full_objective: &'a mut O,
    /// Full parameter vector passed to `full_objective`; only the `block` slice is rewritten.
    pub working_beta: Vec<f64>,
    /// Scratch buffer for the full gradient, allocated once with length `full_objective.dim()`.
    pub full_grad: Vec<f64>,
    /// Indices of `working_beta` that make up this block's parameters.
    pub block: Range<usize>,
}

impl<'a, O> BlockObjective<'a, O>
where
    O: Objective,
{
    /// Creates a block view over `full_objective`, starting from the parameters `full_beta`.
    ///
    /// The full-gradient buffer is allocated here and reused by every later evaluation.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `block` is inverted (`start > end`), if it extends past the
    /// end of `full_beta`, or if `full_beta.len()` differs from `full_objective.dim()`.
    pub fn new(full_objective: &'a mut O, full_beta: Vec<f64>, block: Range<usize>) -> Self {
        let full_grad = vec![0.0; full_objective.dim()];
        // An inverted range reports `len() == 0` and would otherwise only fail later, when
        // `update_block_beta` slices `working_beta` — catch it at construction instead.
        debug_assert!(
            block.start <= block.end,
            "block range {:?} is inverted",
            block
        );
        debug_assert!(
            block.end <= full_beta.len(),
            "block range {:?} exceeds parameter length {}",
            block,
            full_beta.len()
        );
        // Together with the check above this also guarantees `block.end <= full_grad.len()`.
        debug_assert_eq!(
            full_beta.len(),
            full_grad.len(),
            "full_beta length must equal full_objective.dim()"
        );
        Self {
            full_objective,
            working_beta: full_beta,
            full_grad,
            block,
        }
    }

    /// Copies `block_beta` into the `block` coordinates of `working_beta`.
    ///
    /// Panics if `block_beta.len() != self.block.len()`, so callers must validate the length
    /// first.
    fn update_block_beta(&mut self, block_beta: &[f64]) {
        self.working_beta[self.block.start..self.block.end].copy_from_slice(block_beta);
    }
}
impl<O> Objective for BlockObjective<'_, O>
where
    O: Objective,
    O::Error: From<ModelError>,
{
    type Error = O::Error;

    /// Number of coordinates in the block.
    fn dim(&self) -> usize {
        self.block.len()
    }

    /// Evaluates the full objective with `block_beta` written into the block coordinates.
    ///
    /// # Errors
    ///
    /// Returns [`ModelError::BetaLength`] (converted into `O::Error`) if `block_beta` does not
    /// have exactly `self.dim()` entries; in that case `working_beta` is left unchanged.
    /// Otherwise forwards any error from the full objective.
    fn value(&mut self, block_beta: &[f64]) -> Result<f64, Self::Error> {
        validate_block_len(BlockBuffer::Theta, block_beta.len(), self.block.len())?;
        self.update_block_beta(block_beta);
        self.full_objective.value(&self.working_beta)
    }

    // `gradient` comes from the trait default, which delegates to `value_gradient` below.

    /// Evaluates the full objective and gradient, copying the block's slice of the full
    /// gradient into `grad`, and returns the objective value.
    ///
    /// # Errors
    ///
    /// Returns [`ModelError::BetaLength`] if `block_beta` has the wrong length, or
    /// [`ModelError::GradientLength`] if `grad` has the wrong length (both converted into
    /// `O::Error`); in either case `working_beta` is left unchanged. Otherwise forwards any
    /// error from the full objective.
    fn value_gradient(&mut self, block_beta: &[f64], grad: &mut [f64]) -> Result<f64, Self::Error> {
        let block_len = self.block.len();
        validate_block_len(BlockBuffer::Theta, block_beta.len(), block_len)?;
        validate_block_len(BlockBuffer::Gradient, grad.len(), block_len)?;
        self.update_block_beta(block_beta);
        // Disjoint field borrows: read `working_beta`, write the reusable `full_grad` buffer.
        let value = self
            .full_objective
            .value_gradient(&self.working_beta, &mut self.full_grad)?;
        grad.copy_from_slice(&self.full_grad[self.block.start..self.block.end]);
        Ok(value)
    }
}

/// Identifies which caller-supplied buffer a length check refers to, selecting the error variant.
#[derive(Clone, Copy, Debug)]
enum BlockBuffer {
    /// The block parameter vector passed to `value` / `value_gradient`.
    Theta,
    /// The block gradient output buffer passed to `value_gradient`.
    Gradient,
}

/// Checks that a caller-supplied buffer has exactly `expected` entries.
///
/// # Errors
///
/// Returns [`ModelError::BetaLength`] for a mismatched [`BlockBuffer::Theta`] buffer and
/// [`ModelError::GradientLength`] for a mismatched [`BlockBuffer::Gradient`] buffer.
fn validate_block_len(
    buffer: BlockBuffer,
    actual: usize,
    expected: usize,
) -> Result<(), ModelError> {
    if actual == expected {
        return Ok(());
    }
    Err(match buffer {
        BlockBuffer::Theta => ModelError::BetaLength { expected, actual },
        BlockBuffer::Gradient => ModelError::GradientLength { expected, actual },
    })
}
#[cfg(test)]
mod tests {
    use super::{BlockObjective, Objective};
    use crate::ModelError;

    /// `f(theta) = 0.5 * |theta|^2`, whose gradient is `theta` itself.
    #[derive(Debug)]
    struct QuadraticObjective {
        dim: usize,
    }

    impl Objective for QuadraticObjective {
        type Error = ModelError;
        fn dim(&self) -> usize {
            self.dim
        }
        fn value(&mut self, theta: &[f64]) -> Result<f64, Self::Error> {
            Ok(0.5 * theta.iter().map(|value| value * value).sum::<f64>())
        }
        fn value_gradient(&mut self, theta: &[f64], grad: &mut [f64]) -> Result<f64, Self::Error> {
            grad.copy_from_slice(theta);
            self.value(theta)
        }
    }

    #[test]
    fn block_objective_reuses_working_buffers_on_repeated_calls() {
        let mut full = QuadraticObjective { dim: 3 };
        let mut objective = BlockObjective::new(&mut full, vec![1.0, 2.0, 3.0], 1..3);
        let beta_capacity = objective.working_beta.capacity();
        let grad_capacity = objective.full_grad.capacity();
        let mut grad = vec![0.0; objective.dim()];
        assert_eq!(objective.dim(), 2);
        assert_eq!(objective.value(&[4.0, 5.0]).unwrap(), 21.0);
        assert_eq!(
            objective.value_gradient(&[6.0, 7.0], &mut grad).unwrap(),
            43.0
        );
        assert_eq!(grad, vec![6.0, 7.0]);
        assert_eq!(objective.working_beta, vec![1.0, 6.0, 7.0]);
        assert_eq!(objective.working_beta.capacity(), beta_capacity);
        assert_eq!(objective.full_grad.capacity(), grad_capacity);
        assert_eq!(objective.value(&[8.0, 9.0]).unwrap(), 73.0);
        assert_eq!(objective.working_beta.capacity(), beta_capacity);
        assert_eq!(objective.full_grad.capacity(), grad_capacity);
    }

    #[test]
    fn block_objective_gradient_writes_block_slice() {
        let mut full = QuadraticObjective { dim: 3 };
        let mut objective = BlockObjective::new(&mut full, vec![1.0, 2.0, 3.0], 0..2);
        let mut grad = vec![0.0; objective.dim()];
        objective.gradient(&[4.0, 5.0], &mut grad).unwrap();
        assert_eq!(grad, vec![4.0, 5.0]);
        // Coordinates outside the block keep their initial values.
        assert_eq!(objective.working_beta, vec![4.0, 5.0, 3.0]);
    }

    #[test]
    fn block_objective_rejects_mismatched_lengths_without_mutating_state() {
        let mut full = QuadraticObjective { dim: 3 };
        let mut objective = BlockObjective::new(&mut full, vec![1.0, 2.0, 3.0], 1..3);
        assert!(matches!(
            objective.value(&[1.0]),
            Err(ModelError::BetaLength { expected: 2, actual: 1 })
        ));
        let mut short_grad = vec![0.0; 1];
        assert!(matches!(
            objective.value_gradient(&[1.0, 2.0], &mut short_grad),
            Err(ModelError::GradientLength { expected: 2, actual: 1 })
        ));
        let mut grad = vec![0.0; 2];
        assert!(matches!(
            objective.value_gradient(&[1.0, 2.0, 3.0], &mut grad),
            Err(ModelError::BetaLength { expected: 2, actual: 3 })
        ));
        // Validation runs before any write, so rejected calls leave the state untouched.
        assert_eq!(objective.working_beta, vec![1.0, 2.0, 3.0]);
    }
}