1use std::ops::Range;
2
3use crate::ModelError;
4
/// A scalar objective over a parameter vector, evaluated together with its gradient.
///
/// Implementors supply [`Objective::value`] and [`Objective::gradient`];
/// [`Objective::value_gradient`] has a default that simply calls both and may be
/// overridden when the two evaluations can share work.
pub trait Objective {
    /// Error type produced by any evaluation method.
    type Error;

    /// Length of the parameter vector (and of the gradient buffer) this objective uses.
    fn dim(&self) -> usize;

    /// Evaluates the objective at `theta`.
    fn value(&mut self, theta: &[f64]) -> Result<f64, Self::Error>;

    /// Writes the gradient of the objective at `theta` into `grad`.
    fn gradient(&mut self, theta: &[f64], grad: &mut [f64]) -> Result<(), Self::Error>;

    /// Returns the value at `theta` and writes the gradient into `grad`.
    ///
    /// The default runs [`Objective::value`] first and [`Objective::gradient`] second.
    /// If `value` fails its error is returned and `gradient` is never called; if
    /// `gradient` fails its error is returned instead of the computed value.
    fn value_gradient(&mut self, theta: &[f64], grad: &mut [f64]) -> Result<f64, Self::Error> {
        let objective_value = self.value(theta)?;
        self.gradient(theta, grad).map(|()| objective_value)
    }
}
30
/// Restricts an objective to a contiguous block of its parameters.
///
/// Coordinates outside `block` stay at the values stored in `full_beta`; an
/// evaluation overwrites a copy of `full_beta[block]` with the block parameters
/// and hands the resulting full vector to `full_objective`.
#[derive(Debug)]
pub struct BlockObjective<'a, O> {
    /// Objective defined over the complete parameter vector.
    pub full_objective: &'a mut O,
    /// Complete parameter vector; supplies the coordinates outside `block`.
    pub full_beta: Vec<f64>,
    /// Index range, within `full_beta`, of the coordinates that are free.
    pub block: Range<usize>,
}

impl<'a, O> BlockObjective<'a, O> {
    /// Creates a view of `full_objective` whose free coordinates are `full_beta[block]`.
    ///
    /// Nothing is validated here: an out-of-range `block` only surfaces when the
    /// objective is evaluated, where slicing `full_beta` with it panics.
    pub fn new(full_objective: &'a mut O, full_beta: Vec<f64>, block: Range<usize>) -> Self {
        BlockObjective {
            block,
            full_beta,
            full_objective,
        }
    }
}
56
57impl<O> Objective for BlockObjective<'_, O>
58where
59 O: Objective,
60 O::Error: From<ModelError>,
61{
62 type Error = O::Error;
63
64 fn dim(&self) -> usize {
65 self.block.len()
66 }
67
68 fn value(&mut self, block_beta: &[f64]) -> Result<f64, Self::Error> {
69 validate_block_len("theta", block_beta.len(), self.block.len())?;
70
71 let mut beta = self.full_beta.clone();
72 beta[self.block.clone()].copy_from_slice(block_beta);
73 self.full_objective.value(&beta)
74 }
75
76 fn gradient(&mut self, block_beta: &[f64], grad: &mut [f64]) -> Result<(), Self::Error> {
77 validate_block_len("theta", block_beta.len(), self.block.len())?;
78 validate_block_len("gradient", grad.len(), self.block.len())?;
79
80 let mut beta = self.full_beta.clone();
81 beta[self.block.clone()].copy_from_slice(block_beta);
82
83 let mut full_grad = vec![0.0; self.full_objective.dim()];
84 self.full_objective.gradient(&beta, &mut full_grad)?;
85 grad.copy_from_slice(&full_grad[self.block.clone()]);
86 Ok(())
87 }
88}
89
90fn validate_block_len(
91 name: &'static str,
92 actual: usize,
93 expected: usize,
94) -> Result<(), ModelError> {
95 if actual == expected {
96 Ok(())
97 } else if name == "gradient" {
98 Err(ModelError::GradientLength { expected, actual })
99 } else {
100 Err(ModelError::BetaLength { expected, actual })
101 }
102}