Skip to main content

gamlss_core/
objective.rs

1use std::ops::Range;
2
3use crate::ModelError;
4
5/// Независимый от оптимизатора оракул над плоским вектором параметров.
6///
7/// Методы принимают `&mut self`, чтобы реализации могли переиспользовать
8/// временные буферы, не раскрывая состояние, специфичное для оптимизатора,
9/// в `gamlss-core`.
10pub trait Objective {
11    /// Recoverable error returned by objective evaluation.
12    type Error;
13
14    /// Dimension of the flat parameter vector accepted by this objective.
15    fn dim(&self) -> usize;
16
17    /// Objective value at `theta`.
18    fn value(&mut self, theta: &[f64]) -> Result<f64, Self::Error>;
19
20    /// Writes the gradient at `theta` into preallocated `grad`.
21    fn gradient(&mut self, theta: &[f64], grad: &mut [f64]) -> Result<(), Self::Error>;
22
23    /// Computes objective value and gradient at `theta`.
24    fn value_gradient(&mut self, theta: &[f64], grad: &mut [f64]) -> Result<f64, Self::Error> {
25        let value = self.value(theta)?;
26        self.gradient(theta, grad)?;
27        Ok(value)
28    }
29}
30
31/// Objective по одному блоку коэффициентов при фиксированных остальных блоках.
32///
33/// Оборачивает полный objective и проецирует вызовы на диапазон одного
34/// параметрического блока, копируя коэффициенты блока в общий `full_beta`
35/// перед вычислением.
36#[derive(Debug)]
37pub struct BlockObjective<'a, O> {
38    /// Полный objective.
39    pub full_objective: &'a mut O,
40    /// Текущий полный beta-вектор.
41    pub full_beta: Vec<f64>,
42    /// Диапазон оптимизируемого блока.
43    pub block: Range<usize>,
44}
45
46impl<'a, O> BlockObjective<'a, O> {
47    /// Создаёт block objective поверх полного objective.
48    pub fn new(full_objective: &'a mut O, full_beta: Vec<f64>, block: Range<usize>) -> Self {
49        Self {
50            full_objective,
51            full_beta,
52            block,
53        }
54    }
55}
56
57impl<O> Objective for BlockObjective<'_, O>
58where
59    O: Objective,
60    O::Error: From<ModelError>,
61{
62    type Error = O::Error;
63
64    fn dim(&self) -> usize {
65        self.block.len()
66    }
67
68    fn value(&mut self, block_beta: &[f64]) -> Result<f64, Self::Error> {
69        validate_block_len("theta", block_beta.len(), self.block.len())?;
70
71        let mut beta = self.full_beta.clone();
72        beta[self.block.clone()].copy_from_slice(block_beta);
73        self.full_objective.value(&beta)
74    }
75
76    fn gradient(&mut self, block_beta: &[f64], grad: &mut [f64]) -> Result<(), Self::Error> {
77        validate_block_len("theta", block_beta.len(), self.block.len())?;
78        validate_block_len("gradient", grad.len(), self.block.len())?;
79
80        let mut beta = self.full_beta.clone();
81        beta[self.block.clone()].copy_from_slice(block_beta);
82
83        let mut full_grad = vec![0.0; self.full_objective.dim()];
84        self.full_objective.gradient(&beta, &mut full_grad)?;
85        grad.copy_from_slice(&full_grad[self.block.clone()]);
86        Ok(())
87    }
88}
89
90fn validate_block_len(
91    name: &'static str,
92    actual: usize,
93    expected: usize,
94) -> Result<(), ModelError> {
95    if actual == expected {
96        Ok(())
97    } else if name == "gradient" {
98        Err(ModelError::GradientLength { expected, actual })
99    } else {
100        Err(ModelError::BetaLength { expected, actual })
101    }
102}