/// A distribution family scored by negative log-likelihood.
///
/// Values are supplied on the `eta` scale and converted to the `theta`
/// parameters consumed by [`Family::nll`] through [`Family::theta`].
pub trait Family {
    /// One observation, possibly borrowing from caller-owned data for `'obs`.
    type Observation<'obs>;
    /// Parameters on the `eta` scale (the input of [`Family::theta`]).
    type Eta;
    /// Parameters consumed by [`Family::nll`] (the output of [`Family::theta`]).
    type Theta;
    /// Gradient of the negative log-likelihood with respect to `eta`.
    type NllGradientEta;

    /// Converts `eta`-scale parameters into the `theta` parameters.
    fn theta(&self, eta: Self::Eta) -> Self::Theta;

    /// Negative log-likelihood of `observation` under `theta`.
    fn nll<'obs>(&self, observation: Self::Observation<'obs>, theta: Self::Theta) -> f64;

    /// Negative log-likelihood of `observation`, parameterised directly by `eta`.
    ///
    /// By default this maps `eta` through [`Family::theta`] and then calls
    /// [`Family::nll`]. Implementors may override it with a direct formula.
    fn nll_eta<'obs>(&self, observation: Self::Observation<'obs>, eta: Self::Eta) -> f64 {
        let theta = self.theta(eta);
        self.nll(observation, theta)
    }

    /// Negative log-likelihood of `observation` together with its gradient
    /// with respect to `eta`.
    fn nll_and_gradient_eta(
        &self,
        observation: Self::Observation<'_>,
        eta: Self::Eta,
    ) -> (f64, Self::NllGradientEta);
}
/// A dense `K x K` information matrix, addressed as `values[row][col]`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DenseInformation<const K: usize> {
    values: [[f64; K]; K],
}

impl<const K: usize> DenseInformation<K> {
    /// Wraps an already-populated `K x K` matrix given as rows.
    #[must_use]
    pub const fn new(values: [[f64; K]; K]) -> Self {
        Self { values }
    }

    /// Builds a matrix with `diagonal` on the main diagonal and `0.0` elsewhere.
    ///
    /// This is a `const fn`, matching [`DenseInformation::new`], so diagonal
    /// matrices can be built in constant contexts. The loop is a `while`
    /// loop because `for` loops are not allowed in `const fn`.
    #[must_use]
    pub const fn diagonal(diagonal: [f64; K]) -> Self {
        let mut values = [[0.0; K]; K];
        let mut index = 0;
        while index < K {
            values[index][index] = diagonal[index];
            index += 1;
        }
        Self { values }
    }

    /// Returns the entry in row `row`, column `col`.
    ///
    /// # Panics
    ///
    /// Panics if `row >= K` or `col >= K`.
    #[must_use]
    pub const fn get(&self, row: usize, col: usize) -> f64 {
        self.values[row][col]
    }

    /// Borrows the underlying rows.
    #[must_use]
    pub const fn as_array(&self) -> &[[f64; K]; K] {
        &self.values
    }
}
/// Families that can report a diagonal Fisher-information term alongside the
/// negative log-likelihood and its `eta`-gradient.
pub trait HasDiagonalFisherInfo: Family {
    /// Returns three values for one observation at `eta`, in the order the
    /// method name lists them: the negative log-likelihood, its gradient with
    /// respect to `eta`, and the diagonal Fisher information.
    ///
    /// The diagonal is carried in the same type as the gradient
    /// ([`Family::NllGradientEta`]). NOTE(review): presumably one entry per
    /// `eta` component; confirm.
    fn nll_gradient_and_diagonal_fisher_eta(
        &self,
        observation: Self::Observation<'_>,
        eta: Self::Eta,
    ) -> (f64, Self::NllGradientEta, Self::NllGradientEta);
}
/// Families whose `eta` and gradient split into `K` scalar parts, and that
/// can return a dense `K x K` expected-information matrix.
///
/// The `where` clauses require both [`Family::Eta`] and
/// [`Family::NllGradientEta`] to implement [`ParameterParts<K>`].
pub trait HasExpectedInformation<const K: usize>: Family
where
    Self::Eta: ParameterParts<K>,
    Self::NllGradientEta: ParameterParts<K>,
{
    /// For one observation at `eta`, returns the negative log-likelihood, its
    /// gradient with respect to `eta`, and the expected information as a
    /// [`DenseInformation<K>`].
    fn nll_gradient_and_expected_information_eta(
        &self,
        observation: Self::Observation<'_>,
        eta: Self::Eta,
    ) -> (f64, Self::NllGradientEta, DenseInformation<K>);
}
/// Fixed-arity conversion between a parameter representation and its `K`
/// scalar components.
pub trait ParameterParts<const K: usize>: Sized {
    /// Builds the representation from its `K` components, in index order.
    fn from_array(values: [f64; K]) -> Self;

    /// Returns the component at `index`.
    ///
    /// # Panics
    ///
    /// Implementations may panic when `index >= K`.
    fn part(&self, index: usize) -> f64;

    /// Collects all `K` components into an array, in index order.
    ///
    /// This is the counterpart of [`ParameterParts::from_array`]. The
    /// provided implementation queries [`ParameterParts::part`] once for
    /// each index in `0..K`, so it never requests an out-of-range index.
    #[must_use]
    fn to_array(&self) -> [f64; K] {
        core::array::from_fn(|index| self.part(index))
    }
}
/// A bare `f64` is its own single component.
impl ParameterParts<1> for f64 {
    #[inline(always)]
    fn from_array(values: [f64; 1]) -> Self {
        let [p0] = values;
        p0
    }

    #[inline(always)]
    fn part(&self, index: usize) -> f64 {
        if index == 0 {
            *self
        } else {
            unreachable!("one-parameter parts only have index 0")
        }
    }
}
/// Two components stored as a `(f64, f64)` pair, in index order.
impl ParameterParts<2> for (f64, f64) {
    #[inline(always)]
    fn from_array(values: [f64; 2]) -> Self {
        let [p0, p1] = values;
        (p0, p1)
    }

    #[inline(always)]
    fn part(&self, index: usize) -> f64 {
        let (p0, p1) = *self;
        match index {
            0 => p0,
            1 => p1,
            _ => unreachable!("two-parameter parts only have indices 0 and 1"),
        }
    }
}
/// Three components stored as an `f64` triple, in index order.
impl ParameterParts<3> for (f64, f64, f64) {
    #[inline(always)]
    fn from_array(values: [f64; 3]) -> Self {
        let [p0, p1, p2] = values;
        (p0, p1, p2)
    }

    #[inline(always)]
    fn part(&self, index: usize) -> f64 {
        let (p0, p1, p2) = *self;
        match index {
            0 => p0,
            1 => p1,
            2 => p2,
            _ => unreachable!("three-parameter parts only have indices 0, 1 and 2"),
        }
    }
}
/// Four components stored as an `f64` 4-tuple, in index order.
impl ParameterParts<4> for (f64, f64, f64, f64) {
    #[inline(always)]
    fn from_array(values: [f64; 4]) -> Self {
        let [p0, p1, p2, p3] = values;
        (p0, p1, p2, p3)
    }

    #[inline(always)]
    fn part(&self, index: usize) -> f64 {
        let (p0, p1, p2, p3) = *self;
        match index {
            0 => p0,
            1 => p1,
            2 => p2,
            3 => p3,
            _ => unreachable!("four-parameter parts only have indices 0, 1, 2 and 3"),
        }
    }
}
/// Five components stored as an `f64` 5-tuple, in index order.
impl ParameterParts<5> for (f64, f64, f64, f64, f64) {
    #[inline(always)]
    fn from_array(values: [f64; 5]) -> Self {
        let [p0, p1, p2, p3, p4] = values;
        (p0, p1, p2, p3, p4)
    }

    #[inline(always)]
    fn part(&self, index: usize) -> f64 {
        let (p0, p1, p2, p3, p4) = *self;
        match index {
            0 => p0,
            1 => p1,
            2 => p2,
            3 => p3,
            4 => p4,
            _ => unreachable!("five-parameter parts only have indices 0 through 4"),
        }
    }
}
/// Six components stored as an `f64` 6-tuple, in index order.
impl ParameterParts<6> for (f64, f64, f64, f64, f64, f64) {
    #[inline(always)]
    fn from_array(values: [f64; 6]) -> Self {
        let [p0, p1, p2, p3, p4, p5] = values;
        (p0, p1, p2, p3, p4, p5)
    }

    #[inline(always)]
    fn part(&self, index: usize) -> f64 {
        let (p0, p1, p2, p3, p4, p5) = *self;
        match index {
            0 => p0,
            1 => p1,
            2 => p2,
            3 => p3,
            4 => p4,
            5 => p5,
            _ => unreachable!("six-parameter parts only have indices 0 through 5"),
        }
    }
}
/// Seven components stored as an `f64` 7-tuple, in index order.
impl ParameterParts<7> for (f64, f64, f64, f64, f64, f64, f64) {
    #[inline(always)]
    fn from_array(values: [f64; 7]) -> Self {
        let [p0, p1, p2, p3, p4, p5, p6] = values;
        (p0, p1, p2, p3, p4, p5, p6)
    }

    #[inline(always)]
    fn part(&self, index: usize) -> f64 {
        let (p0, p1, p2, p3, p4, p5, p6) = *self;
        match index {
            0 => p0,
            1 => p1,
            2 => p2,
            3 => p3,
            4 => p4,
            5 => p5,
            6 => p6,
            _ => unreachable!("seven-parameter parts only have indices 0 through 6"),
        }
    }
}
/// Eight components stored as an `f64` 8-tuple, in index order.
impl ParameterParts<8> for (f64, f64, f64, f64, f64, f64, f64, f64) {
    #[inline(always)]
    fn from_array(values: [f64; 8]) -> Self {
        let [p0, p1, p2, p3, p4, p5, p6, p7] = values;
        (p0, p1, p2, p3, p4, p5, p6, p7)
    }

    #[inline(always)]
    fn part(&self, index: usize) -> f64 {
        let (p0, p1, p2, p3, p4, p5, p6, p7) = *self;
        match index {
            0 => p0,
            1 => p1,
            2 => p2,
            3 => p3,
            4 => p4,
            5 => p5,
            6 => p6,
            7 => p7,
            _ => unreachable!("eight-parameter parts only have indices 0 through 7"),
        }
    }
}
/// Attaches parameter and link descriptions to a family whose `eta` and
/// gradient each split into `K` scalar parts.
///
/// The `where` clauses require both [`Family::Eta`] and
/// [`Family::NllGradientEta`] to implement [`ParameterParts<K>`].
pub trait ParameterizedFamily<const K: usize>: Family
where
    Self::Eta: ParameterParts<K>,
    Self::NllGradientEta: ParameterParts<K>,
{
    /// Link description for the family's components.
    /// NOTE(review): presumably one link per component; confirm at the impls.
    type Links;
    /// Parameter description for the family.
    /// NOTE(review): exact meaning is not visible in this file; confirm at the impls.
    type Params;
}
/// Families that expose a cumulative distribution function.
pub trait HasCdf: Family {
    /// Evaluates the CDF at `y` for parameters `theta`.
    fn cdf(&self, y: f64, theta: Self::Theta) -> f64;
}
/// Families that expose a quantile (inverse-CDF) function.
pub trait HasQuantile: Family {
    /// Returns the quantile at level `probability` for parameters `theta`.
    ///
    /// NOTE(review): `probability` is presumably expected in `[0, 1]`; the
    /// behaviour outside that range is up to each implementor.
    fn quantile(&self, probability: f64, theta: Self::Theta) -> f64;
}
/// Families that can score an observation with a CRPS value.
///
/// NOTE(review): "CRPS" is presumably the continuous ranked probability score;
/// confirm against the implementations.
pub trait HasCrps: Family {
    /// Scores `observation` against the distribution given by `theta`.
    fn crps<'obs>(&self, observation: Self::Observation<'obs>, theta: Self::Theta) -> f64;
}
/// Families that can draw random variates using a generator of type `Rng`.
pub trait CanSimulate<Rng>: Family {
    /// Draws one value for parameters `theta`, advancing `generator`.
    fn sample(&self, generator: &mut Rng, theta: Self::Theta) -> f64;
}
/// Families that report a deviance value per observation.
pub trait HasDeviance: Family {
    /// Deviance of `observation` under `theta`.
    ///
    /// NOTE(review): the scaling convention (e.g. unit vs. scaled deviance) is
    /// not visible here; confirm against the implementations.
    fn deviance<'obs>(&self, observation: Self::Observation<'obs>, theta: Self::Theta) -> f64;
}
/// Families that can propose a starting `eta` from an observation.
pub trait HasInitialEta: Family {
    /// Derives an initial [`Family::Eta`] value from `observation`.
    ///
    /// NOTE(review): presumably used to seed an iterative fit; confirm at call sites.
    fn initial_eta<'obs>(&self, observation: Self::Observation<'obs>) -> Self::Eta;
}