use std::ops::Range;
use crate::ParameterName;
/// A named, half-open index range.
///
/// NOTE(review): presumably the range addresses positions in a flat
/// coefficient vector; the code that builds or consumes these ranges is not
/// visible here, so confirm against the caller.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParameterSlice {
/// Name of the slice; `ParameterLayout::slice` matches it by string equality.
pub name: &'static str,
/// Half-open `start..end` index range associated with `name`.
pub range: Range<usize>,
}
/// An ordered collection of [`ParameterSlice`] entries.
///
/// The field is private; the inherent methods on this type provide read-only
/// access and name-based lookup.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParameterLayout {
// Slices in the order they were supplied to `ParameterLayout::new`.
slices: Vec<ParameterSlice>,
}
impl ParameterLayout {
    /// Wraps `slices` in a layout, keeping them in the order given.
    ///
    /// No validation is performed: names may repeat and ranges may overlap or
    /// leave gaps.
    #[must_use]
    pub fn new(slices: Vec<ParameterSlice>) -> Self {
        Self { slices }
    }

    /// Returns the number of slices in the layout.
    ///
    /// This counts slices, not coefficients; see
    /// [`ncoefficients`](Self::ncoefficients) for the latter.
    #[must_use]
    pub fn len(&self) -> usize {
        self.slices().len()
    }

    /// Returns `true` when the layout holds no slices.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.slices().is_empty()
    }

    /// Returns the largest `range.end` over all slices, or `0` for an empty
    /// layout.
    ///
    /// When the slices tile `0..n` this is the total coefficient count `n`.
    #[must_use]
    pub fn ncoefficients(&self) -> usize {
        // Starting the fold at 0 yields the empty-layout result directly.
        self.slices
            .iter()
            .fold(0, |largest_end, entry| largest_end.max(entry.range.end))
    }

    /// Borrows every slice, in insertion order.
    #[must_use]
    pub fn slices(&self) -> &[ParameterSlice] {
        self.slices.as_slice()
    }

    /// Invokes `visit` once per slice, in insertion order.
    ///
    /// The callback receives the slice's position in the layout, its name and
    /// a clone of its range.
    pub fn visit_slices(&self, mut visit: impl FnMut(usize, &'static str, Range<usize>)) {
        self.slices
            .iter()
            .enumerate()
            .for_each(|(position, entry)| visit(position, entry.name, entry.range.clone()));
    }

    /// Looks up the range of the first slice whose name equals `name`.
    ///
    /// Returns `None` when no slice carries that name.
    #[must_use]
    pub fn slice(&self, name: &str) -> Option<Range<usize>> {
        self.slices
            .iter()
            .find_map(|entry| (entry.name == name).then(|| entry.range.clone()))
    }

    /// Same as [`slice`](Self::slice), with the name taken from `P::NAME`.
    #[must_use]
    pub fn slice_of<P>(&self) -> Option<Range<usize>>
    where
        P: ParameterName,
    {
        let wanted = P::NAME;
        self.slice(wanted)
    }
}
/// A vector of `f64` coefficients tagged with a name.
#[derive(Debug, Clone, PartialEq)]
pub struct ParameterCoefficients {
/// Name of the block; `UnpackedTheta::block` matches it by string equality.
pub name: &'static str,
/// Coefficient values stored under `name`.
pub coefficients: Vec<f64>,
}
/// A list of named coefficient blocks.
///
/// NOTE(review): the name suggests this is a flat parameter vector split into
/// per-parameter blocks (presumably following a `ParameterLayout`); the code
/// that produces it is not visible here, so confirm.
#[derive(Debug, Clone, PartialEq)]
pub struct UnpackedTheta {
/// The blocks, searched front to back by the name-based lookup methods.
pub blocks: Vec<ParameterCoefficients>,
}
impl UnpackedTheta {
    /// Finds the first block whose name equals `name`.
    ///
    /// Returns `None` when no block carries that name.
    #[must_use]
    pub fn block(&self, name: &str) -> Option<&ParameterCoefficients> {
        for candidate in &self.blocks {
            if candidate.name == name {
                return Some(candidate);
            }
        }
        None
    }

    /// Same as [`block`](Self::block), with the name taken from `P::NAME`.
    #[must_use]
    pub fn block_of<P>(&self) -> Option<&ParameterCoefficients>
    where
        P: ParameterName,
    {
        let wanted = P::NAME;
        self.block(wanted)
    }

    /// Borrows the coefficients of the first block named `name`.
    ///
    /// Returns `None` when no block carries that name.
    #[must_use]
    pub fn coefficients(&self, name: &str) -> Option<&[f64]> {
        let found = self.block(name)?;
        Some(found.coefficients.as_slice())
    }

    /// Same as [`coefficients`](Self::coefficients), with the name taken from
    /// `P::NAME`.
    #[must_use]
    pub fn coefficients_of<P>(&self) -> Option<&[f64]>
    where
        P: ParameterName,
    {
        let wanted = P::NAME;
        self.coefficients(wanted)
    }
}
/// A small, `Copy` bundle of scalar diagnostics.
///
/// NOTE(review): the field descriptions below are inferred from the field
/// names only; the code that fills this struct is not visible here, so
/// confirm them against the producer.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TrainingDiagnostics {
/// Presumably the total objective value (possibly `train_nll + penalty`) — TODO confirm.
pub objective: f64,
/// Presumably the negative log-likelihood on the training data — TODO confirm.
pub train_nll: f64,
/// Presumably the penalty (regularisation) term of the objective — TODO confirm.
pub penalty: f64,
/// Presumably a norm of the gradient; which norm is not shown here — TODO confirm.
pub gradient_norm: f64,
/// Presumably the number of gradient entries that were NaN or infinite — TODO confirm.
pub nonfinite_gradient_count: usize,
}