1use std::{marker::PhantomData, ops::Range};
2
3use crate::{DesignMatrix, LinearPredictorBlock, ModelError, PredictorBlock};
4
/// Attaches a compile-time identifier to a parameter marker type.
///
/// Every implementor in this module uses a short lowercase name (for example
/// `"mu"`); the name is what `ParameterBlock::try_range` reports when a
/// block's range overflows.
pub trait ParameterName {
    /// Identifier of the parameter, e.g. `"sigma"`.
    const NAME: &'static str;
}
10
/// Zero-sized namespace for laying out a tuple of parameter blocks.
///
/// Use [`ParameterBlocks::new`] or [`ParameterBlocks::with_start`] to give each
/// block in a tuple a consecutive, non-overlapping offset.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct ParameterBlocks;
20
/// Rewrites the offsets of a collection of blocks so they sit back to back.
///
/// In this module it is implemented for tuples of one to eight parameter
/// blocks: the first element is placed at `start` and each following element
/// begins where the previous one ends (the running offset saturates at
/// `usize::MAX`).
pub trait AssignParameterOffsets: Sized {
    /// Returns `self` with every element re-offset, beginning at `start`.
    #[must_use]
    fn assign_offsets(self, start: usize) -> Self;
}
27
28impl ParameterBlocks {
29 #[allow(clippy::new_ret_no_self)]
31 #[must_use]
32 pub fn new<Blocks>(blocks: Blocks) -> Blocks
33 where
34 Blocks: AssignParameterOffsets,
35 {
36 Self::with_start(0, blocks)
37 }
38
39 #[must_use]
41 pub fn with_start<Blocks>(start: usize, blocks: Blocks) -> Blocks
42 where
43 Blocks: AssignParameterOffsets,
44 {
45 blocks.assign_offsets(start)
46 }
47}
48
49#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
51pub struct Mu;
52
53#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
55pub struct Sigma;
56
57#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
59pub struct Nu;
60
61#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
63pub struct Tau;
64
65#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
67pub struct Rate;
68
69#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
71pub struct Shape;
72
73#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
75pub struct Scale;
76
77#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
82pub struct Precision;
83
84impl ParameterName for Mu {
85 const NAME: &'static str = "mu";
86}
87
88impl ParameterName for Sigma {
89 const NAME: &'static str = "sigma";
90}
91
92impl ParameterName for Nu {
93 const NAME: &'static str = "nu";
94}
95
96impl ParameterName for Tau {
97 const NAME: &'static str = "tau";
98}
99
100impl ParameterName for Rate {
101 const NAME: &'static str = "rate";
102}
103
104impl ParameterName for Shape {
105 const NAME: &'static str = "shape";
106}
107
108impl ParameterName for Scale {
109 const NAME: &'static str = "scale";
110}
111
112impl ParameterName for Precision {
113 const NAME: &'static str = "precision";
114}
115
/// Coefficient block for a single distribution parameter.
///
/// `P` is a marker type naming the parameter (e.g. [`Mu`]) and `L` is a second
/// type-level tag (the tests instantiate it with `Identity`, presumably a link
/// function — confirm). Neither is stored; both live only in `PhantomData`.
pub struct ParameterBlock<P, L, X, Penalty> {
    /// Predictor for this block; `ParameterBlock::new` takes `len` from its
    /// `nparams()`.
    pub x: X,
    /// Penalty value carried alongside the block.
    pub penalty: Penalty,
    /// First index of the block's range `offset..offset + len`.
    pub offset: usize,
    /// Number of indices covered by the block.
    pub len: usize,
    marker: PhantomData<(P, L)>,
}

// The trait impls below are written by hand instead of derived. `#[derive]`
// would add `P: Trait` and `L: Trait` bounds even though those parameters
// only appear inside `PhantomData`. The output and comparison semantics match
// what the derives produced.
impl<P, L, X, Penalty> Clone for ParameterBlock<P, L, X, Penalty>
where
    X: Clone,
    Penalty: Clone,
{
    fn clone(&self) -> Self {
        Self {
            x: self.x.clone(),
            penalty: self.penalty.clone(),
            offset: self.offset,
            len: self.len,
            marker: PhantomData,
        }
    }
}

impl<P, L, X, Penalty> PartialEq for ParameterBlock<P, L, X, Penalty>
where
    X: PartialEq,
    Penalty: PartialEq,
{
    fn eq(&self, other: &Self) -> bool {
        // `marker` is a zero-sized `PhantomData`, so it always compares equal.
        self.x == other.x
            && self.penalty == other.penalty
            && self.offset == other.offset
            && self.len == other.len
    }
}

impl<P, L, X, Penalty> std::fmt::Debug for ParameterBlock<P, L, X, Penalty>
where
    X: std::fmt::Debug,
    Penalty: std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ParameterBlock")
            .field("x", &self.x)
            .field("penalty", &self.penalty)
            .field("offset", &self.offset)
            .field("len", &self.len)
            .field("marker", &self.marker)
            .finish()
    }
}
133
134impl<P, L, X, Penalty> ParameterBlock<P, L, X, Penalty>
135where
136 X: PredictorBlock,
137{
138 #[must_use]
140 pub fn new(x: X, penalty: Penalty, offset: usize) -> Self {
141 let len = x.nparams();
142 Self::from_len(x, penalty, offset, len)
143 }
144
145 #[must_use]
150 pub fn from_predictor(x: X, penalty: Penalty, offset: usize) -> Self {
151 Self::new(x, penalty, offset)
152 }
153}
154
155impl<P, L, X, Penalty> ParameterBlock<P, L, LinearPredictorBlock<X>, Penalty>
156where
157 X: DesignMatrix,
158{
159 #[must_use]
161 pub fn linear(x: X, penalty: Penalty, offset: usize) -> Self {
162 Self::new(LinearPredictorBlock::new(x), penalty, offset)
163 }
164}
165
166impl<P, L, X, Penalty> ParameterBlock<P, L, X, Penalty> {
167 fn from_len(x: X, penalty: Penalty, offset: usize, len: usize) -> Self {
168 Self {
169 x,
170 penalty,
171 offset,
172 len,
173 marker: PhantomData,
174 }
175 }
176
177 #[must_use]
179 pub fn with_offset(mut self, offset: usize) -> Self {
180 self.offset = offset;
181 self
182 }
183
184 #[must_use]
191 pub fn range(&self) -> Range<usize> {
192 self.offset..self.end()
193 }
194
195 #[must_use]
202 pub fn end(&self) -> usize {
203 self.offset
204 .checked_add(self.len)
205 .expect("parameter block range end must fit in usize")
206 }
207
208 #[must_use]
210 pub fn len(&self) -> usize {
211 self.len
212 }
213
214 #[must_use]
216 pub fn is_empty(&self) -> bool {
217 self.len == 0
218 }
219}
220
221impl<P, L, X, Penalty> ParameterBlock<P, L, X, Penalty>
222where
223 P: ParameterName,
224{
225 pub fn try_range(&self) -> Result<Range<usize>, ModelError> {
232 let end = self
233 .offset
234 .checked_add(self.len)
235 .ok_or(ModelError::BlockRangeOverflow {
236 parameter: P::NAME,
237 offset: self.offset,
238 len: self.len,
239 })?;
240 Ok(self.offset..end)
241 }
242}
243
// Generates `AssignParameterOffsets` for one tuple arity.
//
// `types` lists the tuple's type parameters and `vars` gives one binding name
// per element; both lists must have the same length. The generated impl
// places the first element at `start` and each later element right after its
// predecessor. The cursor advances with `saturating_add`, so an overflowing
// layout pins later offsets at `usize::MAX` instead of wrapping.
macro_rules! impl_assign_offsets {
    (types = ($($generic:ident),+); vars = ($($binding:ident),+)) => {
        impl<$($generic),+> AssignParameterOffsets for ($($generic,)+)
        where
            $($generic: OffsetAssignable),+
        {
            fn assign_offsets(self, start: usize) -> Self {
                let mut cursor = start;
                let ($($binding,)+) = self;
                $(
                    let $binding = $binding.with_assigned_offset(cursor);
                    cursor = cursor.saturating_add($binding.assigned_len());
                )+
                // The final advance is never consumed; reading it here keeps
                // `unused_assignments` from firing on the last element.
                let _ = cursor;
                ($($binding,)+)
            }
        }
    };
}
266
/// Module-private hook used by the tuple impls of [`AssignParameterOffsets`].
///
/// Each element must be able to move to a new starting index and report how
/// many indices it occupies.
trait OffsetAssignable: Sized {
    /// Returns `self` repositioned to start at `offset`.
    fn with_assigned_offset(self, offset: usize) -> Self;

    /// Number of indices this element occupies; used to advance the layout cursor.
    fn assigned_len(&self) -> usize;
}
271
272impl<P, L, X, Penalty> OffsetAssignable for ParameterBlock<P, L, X, Penalty> {
273 fn with_assigned_offset(self, offset: usize) -> Self {
274 self.with_offset(offset)
275 }
276
277 fn assigned_len(&self) -> usize {
278 self.len()
279 }
280}
281
282impl_assign_offsets!(types = (B1); vars = (b1));
283impl_assign_offsets!(types = (B1, B2); vars = (b1, b2));
284impl_assign_offsets!(types = (B1, B2, B3); vars = (b1, b2, b3));
285impl_assign_offsets!(types = (B1, B2, B3, B4); vars = (b1, b2, b3, b4));
286impl_assign_offsets!(types = (B1, B2, B3, B4, B5); vars = (b1, b2, b3, b4, b5));
287impl_assign_offsets!(types = (B1, B2, B3, B4, B5, B6); vars = (b1, b2, b3, b4, b5, b6));
288impl_assign_offsets!(
289 types = (B1, B2, B3, B4, B5, B6, B7);
290 vars = (b1, b2, b3, b4, b5, b6, b7)
291);
292impl_assign_offsets!(
293 types = (B1, B2, B3, B4, B5, B6, B7, B8);
294 vars = (b1, b2, b3, b4, b5, b6, b7, b8)
295);
296
#[cfg(test)]
mod tests {
    use super::{
        Mu, Nu, ParameterBlock, ParameterBlocks, Precision, Rate, Scale, Shape, Sigma, Tau,
    };
    use crate::{DenseDesign, Identity, LinearPredictorBlock, ModelError, NoPenalty};

    /// Placeholder offset given to every block before layout. The layout tests
    /// check that it gets overwritten.
    const STALE_OFFSET: usize = 99;

    #[test]
    fn parameter_blocks_assign_offsets_for_one_block() {
        let design = DenseDesign::from_rows(&[[1.0, 2.0]]);
        let mu = ParameterBlock::<Mu, Identity, _, _>::linear(design, NoPenalty, STALE_OFFSET);

        let (mu,) = ParameterBlocks::new((mu,));

        assert_eq!(mu.range(), 0..2);
    }

    #[test]
    fn parameter_blocks_assign_offsets_for_two_blocks() {
        let mu_design = DenseDesign::from_rows(&[[1.0, 2.0]]);
        let sigma_design = DenseDesign::from_rows(&[[1.0, 2.0, 3.0]]);
        let mu =
            ParameterBlock::<Mu, Identity, _, _>::linear(mu_design, NoPenalty, STALE_OFFSET);
        let sigma =
            ParameterBlock::<Sigma, Identity, _, _>::linear(sigma_design, NoPenalty, STALE_OFFSET);

        let (mu, sigma) = ParameterBlocks::new((mu, sigma));

        assert_eq!((mu.range(), sigma.range()), (0..2, 2..5));
    }

    #[test]
    fn parameter_blocks_assign_offsets_for_eight_blocks_with_start() {
        let (b1, b2, b3, b4, b5, b6, b7, b8) = ParameterBlocks::with_start(
            10,
            (
                intercept_block::<Mu>(),
                intercept_block::<Sigma>(),
                intercept_block::<Nu>(),
                intercept_block::<Tau>(),
                intercept_block::<Shape>(),
                intercept_block::<Scale>(),
                intercept_block::<Rate>(),
                intercept_block::<Precision>(),
            ),
        );

        let actual = [
            b1.range(),
            b2.range(),
            b3.range(),
            b4.range(),
            b5.range(),
            b6.range(),
            b7.range(),
            b8.range(),
        ];
        let expected: Vec<std::ops::Range<usize>> =
            (10_usize..18).map(|start| start..start + 1).collect();
        assert_eq!(actual.to_vec(), expected);
    }

    #[test]
    fn parameter_block_try_range_reports_overflow() {
        let design = DenseDesign::from_rows(&[[1.0, 2.0]]);
        let block = ParameterBlock::<Mu, Identity, _, _>::linear(design, NoPenalty, usize::MAX);

        let expected = ModelError::BlockRangeOverflow {
            parameter: "mu",
            offset: usize::MAX,
            len: 2,
        };
        assert_eq!(block.try_range().unwrap_err(), expected);
    }

    /// Intercept-only block for parameter `P` (one parameter wide, per the
    /// range assertions above), placed at the stale offset.
    fn intercept_block<P>()
    -> ParameterBlock<P, Identity, LinearPredictorBlock<DenseDesign>, NoPenalty> {
        ParameterBlock::linear(DenseDesign::intercept(1), NoPenalty, STALE_OFFSET)
    }
}