1use std::{marker::PhantomData, ops::Range};
2
3use crate::{DesignMatrix, LinearPredictorBlock, ModelError, PredictorBlock};
4
/// Associates a type-level parameter marker with its string name.
///
/// The name is reported in diagnostics such as
/// `ModelError::BlockRangeOverflow { parameter, .. }`.
pub trait ParameterName {
    /// Name of the parameter as a static string (for example `"mu"`).
    const NAME: &'static str;
}
10
/// Data-free namespace for laying out a tuple of parameter blocks.
///
/// Use [`ParameterBlocks::new`] (offsets from `0`) or
/// [`ParameterBlocks::with_start`] (offsets from a chosen index).
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct ParameterBlocks;
20
/// Values whose parameter blocks can be given offsets from a starting index.
///
/// This module implements it for tuples of 1 to 8 blocks.
pub trait AssignParameterOffsets: Sized {
    /// Consumes `self` and returns it with offsets assigned beginning at `start`.
    #[must_use]
    fn assign_offsets(self, start: usize) -> Self;
}
27
28impl ParameterBlocks {
29 #[allow(clippy::new_ret_no_self)]
31 #[must_use]
32 #[inline]
33 pub fn new<Blocks>(blocks: Blocks) -> Blocks
34 where
35 Blocks: AssignParameterOffsets,
36 {
37 Self::with_start(0, blocks)
38 }
39
40 #[must_use]
42 #[inline]
43 pub fn with_start<Blocks>(start: usize, blocks: Blocks) -> Blocks
44 where
45 Blocks: AssignParameterOffsets,
46 {
47 blocks.assign_offsets(start)
48 }
49}
50
51#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
53pub struct Mu;
54
55#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
57pub struct Sigma;
58
59#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
61pub struct Nu;
62
63#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
65pub struct Tau;
66
67#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
69pub struct Rate;
70
71#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
73pub struct Shape;
74
75#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
77pub struct Scale;
78
79#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
84pub struct Precision;
85
86impl ParameterName for Mu {
87 const NAME: &'static str = "mu";
88}
89
90impl ParameterName for Sigma {
91 const NAME: &'static str = "sigma";
92}
93
94impl ParameterName for Nu {
95 const NAME: &'static str = "nu";
96}
97
98impl ParameterName for Tau {
99 const NAME: &'static str = "tau";
100}
101
102impl ParameterName for Rate {
103 const NAME: &'static str = "rate";
104}
105
106impl ParameterName for Shape {
107 const NAME: &'static str = "shape";
108}
109
110impl ParameterName for Scale {
111 const NAME: &'static str = "scale";
112}
113
114impl ParameterName for Precision {
115 const NAME: &'static str = "precision";
116}
117
/// A run of coefficients belonging to the parameter named by `P`.
///
/// The block occupies the index range `offset..offset + len` (see
/// [`ParameterBlock::range`]).
///
/// Type parameters:
/// - `P`: parameter marker; implements [`ParameterName`] wherever the name is needed.
/// - `L`: carried only at the type level. NOTE(review): presumably a link
///   marker, since tests use `Identity`; confirm.
/// - `X`: the predictor; the constructors require `X: PredictorBlock`.
/// - `Penalty`: stored as-is and not inspected by this module.
#[derive(Clone, Debug, PartialEq)]
pub struct ParameterBlock<P, L, X, Penalty> {
    /// Predictor for this block.
    pub x: X,
    /// Penalty stored alongside the predictor.
    pub penalty: Penalty,
    /// Index of the first coefficient owned by this block.
    pub offset: usize,
    /// Number of coefficients in this block (`x.nparams()` when built via `new`).
    pub len: usize,
    /// Binds `P` and `L` to the type without storing values of them.
    marker: PhantomData<(P, L)>,
}
135
136impl<P, L, X, Penalty> ParameterBlock<P, L, X, Penalty>
137where
138 X: PredictorBlock,
139{
140 #[must_use]
142 #[inline]
143 pub fn new(x: X, penalty: Penalty, offset: usize) -> Self {
144 let len = x.nparams();
145 Self::from_len(x, penalty, offset, len)
146 }
147
148 #[must_use]
153 #[inline]
154 pub fn from_predictor(x: X, penalty: Penalty, offset: usize) -> Self {
155 Self::new(x, penalty, offset)
156 }
157}
158
159impl<P, L, X, Penalty> ParameterBlock<P, L, LinearPredictorBlock<X>, Penalty>
160where
161 X: DesignMatrix,
162{
163 #[must_use]
165 #[inline]
166 pub fn linear(x: X, penalty: Penalty, offset: usize) -> Self {
167 Self::new(LinearPredictorBlock::new(x), penalty, offset)
168 }
169}
170
171impl<P, L, X, Penalty> ParameterBlock<P, L, X, Penalty> {
172 #[inline]
173 fn from_len(x: X, penalty: Penalty, offset: usize, len: usize) -> Self {
174 Self {
175 x,
176 penalty,
177 offset,
178 len,
179 marker: PhantomData,
180 }
181 }
182
183 #[must_use]
185 #[inline]
186 pub fn with_offset(mut self, offset: usize) -> Self {
187 self.offset = offset;
188 self
189 }
190
191 #[must_use]
198 #[inline]
199 pub fn range(&self) -> Range<usize> {
200 self.offset..self.end()
201 }
202
203 #[must_use]
210 #[inline]
211 pub fn end(&self) -> usize {
212 self.offset
213 .checked_add(self.len)
214 .expect("parameter block range end must fit in usize")
215 }
216
217 #[must_use]
219 #[inline]
220 pub fn len(&self) -> usize {
221 self.len
222 }
223
224 #[must_use]
226 #[inline]
227 pub fn is_empty(&self) -> bool {
228 self.len == 0
229 }
230}
231
232impl<P, L, X, Penalty> ParameterBlock<P, L, X, Penalty>
233where
234 P: ParameterName,
235{
236 #[inline]
243 pub fn try_range(&self) -> Result<Range<usize>, ModelError> {
244 let end = self
245 .offset
246 .checked_add(self.len)
247 .ok_or(ModelError::BlockRangeOverflow {
248 parameter: P::NAME,
249 offset: self.offset,
250 len: self.len,
251 })?;
252 Ok(self.offset..end)
253 }
254}
255
// Implements `AssignParameterOffsets` for one tuple arity.
//
// `types` lists the tuple's generic parameters and `vars` the matching binding
// names, in the same order. Blocks are placed back to back, starting at
// `start`. The running offset uses `saturating_add`, so it clamps at
// `usize::MAX` instead of wrapping.
macro_rules! impl_assign_offsets {
    (
        types = ($($block:ident),+);
        vars = ($($var:ident),+)
    ) => {
        impl<$($block),+> AssignParameterOffsets for ($($block,)+)
        where
            $($block: OffsetAssignable),+
        {
            #[inline]
            fn assign_offsets(self, start: usize) -> Self {
                let ($($var,)+) = self;
                let next = start;
                $(
                    // Place this block at `next`, then advance past its length.
                    let $var = $var.with_assigned_offset(next);
                    let next = next.saturating_add($var.assigned_len());
                )+
                // The last advance has no following block; read it to keep
                // the unused-variable lint quiet.
                let _ = next;
                ($($var,)+)
            }
        }
    };
}
279
/// Per-block hooks used by the tuple `AssignParameterOffsets` impls.
///
/// This trait is private, so only types in this module can take part in
/// tuple offset assignment.
trait OffsetAssignable: Sized {
    /// Returns `self` with its starting offset set to `offset`.
    fn with_assigned_offset(self, offset: usize) -> Self;

    /// Number of indices this value occupies; the next block starts after them.
    fn assigned_len(&self) -> usize;
}
284
285impl<P, L, X, Penalty> OffsetAssignable for ParameterBlock<P, L, X, Penalty> {
286 fn with_assigned_offset(self, offset: usize) -> Self {
287 self.with_offset(offset)
288 }
289
290 fn assigned_len(&self) -> usize {
291 self.len()
292 }
293}
294
295impl_assign_offsets!(types = (B1); vars = (b1));
296impl_assign_offsets!(types = (B1, B2); vars = (b1, b2));
297impl_assign_offsets!(types = (B1, B2, B3); vars = (b1, b2, b3));
298impl_assign_offsets!(types = (B1, B2, B3, B4); vars = (b1, b2, b3, b4));
299impl_assign_offsets!(types = (B1, B2, B3, B4, B5); vars = (b1, b2, b3, b4, b5));
300impl_assign_offsets!(types = (B1, B2, B3, B4, B5, B6); vars = (b1, b2, b3, b4, b5, b6));
301impl_assign_offsets!(
302 types = (B1, B2, B3, B4, B5, B6, B7);
303 vars = (b1, b2, b3, b4, b5, b6, b7)
304);
305impl_assign_offsets!(
306 types = (B1, B2, B3, B4, B5, B6, B7, B8);
307 vars = (b1, b2, b3, b4, b5, b6, b7, b8)
308);
309
#[cfg(test)]
mod tests {
    use crate::{DenseDesign, Identity, LinearPredictorBlock, ModelError, NoPenalty};

    use super::{
        Mu, Nu, ParameterBlock, ParameterBlocks, ParameterName, Precision, Rate, Scale, Shape,
        Sigma, Tau,
    };

    #[test]
    fn parameter_blocks_assign_offsets_for_one_block() {
        let mu = ParameterBlock::<Mu, Identity, _, _>::linear(
            DenseDesign::from_rows(&[[1.0, 2.0]]),
            NoPenalty,
            99,
        );

        let (mu,) = ParameterBlocks::new((mu,));

        assert_eq!(mu.range(), 0..2);
    }

    #[test]
    fn parameter_blocks_assign_offsets_for_two_blocks() {
        let mu = ParameterBlock::<Mu, Identity, _, _>::linear(
            DenseDesign::from_rows(&[[1.0, 2.0]]),
            NoPenalty,
            99,
        );
        let sigma = ParameterBlock::<Sigma, Identity, _, _>::linear(
            DenseDesign::from_rows(&[[1.0, 2.0, 3.0]]),
            NoPenalty,
            99,
        );

        let (mu, sigma) = ParameterBlocks::new((mu, sigma));

        assert_eq!(mu.range(), 0..2);
        assert_eq!(sigma.range(), 2..5);
    }

    #[test]
    fn parameter_blocks_assign_offsets_for_eight_blocks_with_start() {
        let blocks = (
            intercept_block::<Mu>(),
            intercept_block::<Sigma>(),
            intercept_block::<Nu>(),
            intercept_block::<Tau>(),
            intercept_block::<Shape>(),
            intercept_block::<Scale>(),
            intercept_block::<Rate>(),
            intercept_block::<Precision>(),
        );

        let (b1, b2, b3, b4, b5, b6, b7, b8) = ParameterBlocks::with_start(10, blocks);

        assert_eq!(b1.range(), 10..11);
        assert_eq!(b2.range(), 11..12);
        assert_eq!(b3.range(), 12..13);
        assert_eq!(b4.range(), 13..14);
        assert_eq!(b5.range(), 14..15);
        assert_eq!(b6.range(), 15..16);
        assert_eq!(b7.range(), 16..17);
        assert_eq!(b8.range(), 17..18);
    }

    /// Offsets near `usize::MAX` clamp instead of wrapping; the overflow is
    /// then reported by `try_range` on the affected block.
    #[test]
    fn parameter_blocks_saturate_offsets_instead_of_wrapping() {
        let (b1, b2) = ParameterBlocks::with_start(
            usize::MAX,
            (intercept_block::<Mu>(), intercept_block::<Sigma>()),
        );

        assert_eq!(b1.offset, usize::MAX);
        assert_eq!(b2.offset, usize::MAX);
        assert!(b1.try_range().is_err());
    }

    #[test]
    fn parameter_names_match_marker_types() {
        assert_eq!(Mu::NAME, "mu");
        assert_eq!(Sigma::NAME, "sigma");
        assert_eq!(Nu::NAME, "nu");
        assert_eq!(Tau::NAME, "tau");
        assert_eq!(Rate::NAME, "rate");
        assert_eq!(Shape::NAME, "shape");
        assert_eq!(Scale::NAME, "scale");
        assert_eq!(Precision::NAME, "precision");
    }

    #[test]
    fn parameter_block_len_comes_from_predictor() {
        let block = ParameterBlock::<Mu, Identity, _, _>::linear(
            DenseDesign::from_rows(&[[1.0, 2.0, 3.0]]),
            NoPenalty,
            4,
        );

        assert_eq!(block.len(), 3);
        assert!(!block.is_empty());
        assert_eq!(block.end(), 7);
        assert_eq!(block.range(), 4..7);
    }

    #[test]
    fn parameter_block_try_range_matches_range_when_in_bounds() {
        let block = intercept_block::<Sigma>().with_offset(3);

        assert_eq!(block.try_range(), Ok(3..4));
        assert_eq!(block.try_range().unwrap(), block.range());
    }

    #[test]
    fn parameter_block_try_range_reports_overflow() {
        let block = ParameterBlock::<Mu, Identity, _, _>::linear(
            DenseDesign::from_rows(&[[1.0, 2.0]]),
            NoPenalty,
            usize::MAX,
        );

        assert_eq!(
            block.try_range().unwrap_err(),
            ModelError::BlockRangeOverflow {
                parameter: "mu",
                offset: usize::MAX,
                len: 2,
            }
        );
    }

    #[test]
    #[should_panic(expected = "parameter block range end must fit in usize")]
    fn parameter_block_range_panics_on_overflow() {
        let block = intercept_block::<Mu>().with_offset(usize::MAX);
        let _ = block.range();
    }

    /// One-coefficient intercept block for marker `P`, with a placeholder offset
    /// that the tests overwrite.
    fn intercept_block<P>()
    -> ParameterBlock<P, Identity, LinearPredictorBlock<DenseDesign>, NoPenalty> {
        ParameterBlock::linear(DenseDesign::intercept(1), NoPenalty, 99)
    }
}