1use ndarray::{Array1, Array2, Axis};
5use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
6
/// A square penalty matrix held in one of several storage representations.
///
/// The structured variants keep only the data needed to reconstruct the full
/// matrix; `to_dense` materializes any variant as a plain `Array2<f64>`.
#[derive(Clone, Debug)]
pub enum PenaltyMatrix {
    /// A fully materialized matrix.
    Dense(Array2<f64>),
    /// The Kronecker product `left ⊗ right`, stored as its two factors.
    KroneckerFactored {
        /// Left Kronecker factor.
        left: Array2<f64>,
        /// Right Kronecker factor.
        right: Array2<f64>,
    },
    /// A matrix that is zero everywhere except for one diagonal block.
    Blockwise {
        /// The non-zero block.
        local: Array2<f64>,
        /// Index range (used for both rows and columns) that `local` occupies
        /// inside the full matrix.
        col_range: std::ops::Range<usize>,
        /// Side length of the full matrix.
        total_dim: usize,
    },
    /// Another penalty matrix tagged with a string label.
    ///
    /// The label is reported by `precision_label`; numeric operations such as
    /// `dot` and `to_dense` are forwarded to `inner`.
    Labeled {
        /// The attached label.
        label: String,
        /// The wrapped matrix.
        inner: Box<PenaltyMatrix>,
    },
    /// Another penalty matrix tagged with a `log_lambda` value.
    ///
    /// The value is reported by `fixed_log_lambda`; numeric operations such as
    /// `dot` and `to_dense` are forwarded to `inner`.
    Fixed {
        /// The attached value. NOTE(review): presumably the log of a smoothing
        /// parameter that is held fixed rather than estimated — confirm
        /// against callers.
        log_lambda: f64,
        /// The wrapped matrix.
        inner: Box<PenaltyMatrix>,
    },
}
45
46impl PenaltyMatrix {
47 pub fn dim(&self) -> usize {
49 match self {
50 Self::Dense(m) => m.nrows(),
51 Self::KroneckerFactored { left, right } => left.nrows() * right.nrows(),
52 Self::Blockwise { total_dim, .. } => *total_dim,
53 Self::Labeled { inner, .. } | Self::Fixed { inner, .. } => inner.dim(),
54 }
55 }
56
57 pub fn shape(&self) -> (usize, usize) {
59 let d = self.dim();
60 (d, d)
61 }
62
63 pub fn to_dense(&self) -> Array2<f64> {
65 match self {
66 Self::Dense(m) => m.clone(),
67 Self::KroneckerFactored { left, right } => kronecker_product(left, right),
68 Self::Blockwise {
69 local,
70 col_range,
71 total_dim,
72 } => {
73 let mut g = Array2::zeros((*total_dim, *total_dim));
74 g.slice_mut(ndarray::s![
75 col_range.start..col_range.end,
76 col_range.start..col_range.end
77 ])
78 .assign(local);
79 g
80 }
81 Self::Labeled { inner, .. } | Self::Fixed { inner, .. } => inner.to_dense(),
82 }
83 }
84
85 pub fn as_dense_cow(&self) -> std::borrow::Cow<'_, Array2<f64>> {
87 match self {
88 Self::Dense(m) => std::borrow::Cow::Borrowed(m),
89 Self::KroneckerFactored { .. }
90 | Self::Blockwise { .. }
91 | Self::Labeled { .. }
92 | Self::Fixed { .. } => std::borrow::Cow::Owned(self.to_dense()),
93 }
94 }
95
96 pub fn as_dense_ref(&self) -> Option<&Array2<f64>> {
98 match self {
99 Self::Dense(m) => Some(m),
100 Self::Fixed { inner, .. } => inner.as_dense_ref(),
101 Self::KroneckerFactored { .. } | Self::Blockwise { .. } | Self::Labeled { .. } => None,
102 }
103 }
104
105 pub fn with_precision_label(self, label: impl Into<String>) -> Self {
106 Self::Labeled {
107 label: label.into(),
108 inner: Box::new(self),
109 }
110 }
111
112 pub fn precision_label(&self) -> Option<&str> {
113 match self {
114 Self::Labeled { label, .. } => Some(label.as_str()),
115 Self::Fixed { .. } => None,
116 _ => None,
117 }
118 }
119
120 pub fn with_fixed_log_lambda(self, log_lambda: f64) -> Self {
121 Self::Fixed {
122 log_lambda,
123 inner: Box::new(self),
124 }
125 }
126
127 pub fn fixed_log_lambda(&self) -> Option<f64> {
128 match self {
129 Self::Fixed { log_lambda, .. } => Some(*log_lambda),
130 Self::Labeled { inner, .. } => inner.fixed_log_lambda(),
131 _ => None,
132 }
133 }
134
135 pub fn dot(&self, v: &Array1<f64>) -> Array1<f64> {
139 match self {
140 Self::Dense(m) => m.dot(v),
141 Self::KroneckerFactored { left, right } => {
142 let p_left = left.nrows();
143 let p_right = right.nrows();
144 let v_mat =
146 ndarray::ArrayView2::from_shape((p_left, p_right), v.as_slice().unwrap())
147 .unwrap();
148 let avbt = left.dot(&v_mat).dot(&right.t());
149 let standard = avbt.as_standard_layout();
150 Array1::from_iter(standard.iter().copied())
151 }
152 Self::Blockwise {
153 local,
154 col_range,
155 total_dim,
156 } => {
157 let mut out = Array1::zeros(*total_dim);
158 let v_block = v.slice(ndarray::s![col_range.clone()]);
159 let result_block = local.dot(&v_block);
160 out.slice_mut(ndarray::s![col_range.clone()])
161 .assign(&result_block);
162 out
163 }
164 Self::Labeled { inner, .. } | Self::Fixed { inner, .. } => inner.dot(v),
165 }
166 }
167
168 pub fn add_scaled_to(&self, lambda: f64, target: &mut Array2<f64>) {
170 match self {
171 Self::Dense(m) => {
172 target.scaled_add(lambda, m);
173 }
174 Self::KroneckerFactored { left, right } => {
175 let p_left = left.nrows();
176 let p_right = right.nrows();
177 for i1 in 0..p_left {
178 for j1 in 0..p_left {
179 let a_ij = left[[i1, j1]];
180 if a_ij == 0.0 {
181 continue;
182 }
183 let scaled_a = lambda * a_ij;
184 for i2 in 0..p_right {
185 let row = i1 * p_right + i2;
186 for j2 in 0..p_right {
187 let col = j1 * p_right + j2;
188 target[[row, col]] += scaled_a * right[[i2, j2]];
189 }
190 }
191 }
192 }
193 }
194 Self::Blockwise {
195 local, col_range, ..
196 } => {
197 target
198 .slice_mut(ndarray::s![col_range.clone(), col_range.clone()])
199 .scaled_add(lambda, local);
200 }
201 Self::Labeled { inner, .. } | Self::Fixed { inner, .. } => {
202 inner.add_scaled_to(lambda, target)
203 }
204 }
205 }
206
207 pub fn add_scaled_diag_to(&self, lambda: f64, target: &mut Array1<f64>) {
209 match self {
210 Self::Dense(m) => {
211 let p = m.nrows().min(target.len());
212 for j in 0..p {
213 target[j] += lambda * m[[j, j]];
214 }
215 }
216 Self::KroneckerFactored { left, right } => {
217 let p_left = left.nrows();
218 let p_right = right.nrows();
219 assert_eq!(target.len(), p_left * p_right);
220 for i_left in 0..p_left {
221 let left_diag = left[[i_left, i_left]];
222 if left_diag == 0.0 {
223 continue;
224 }
225 let scaled_left = lambda * left_diag;
226 for i_right in 0..p_right {
227 target[i_left * p_right + i_right] +=
228 scaled_left * right[[i_right, i_right]];
229 }
230 }
231 }
232 Self::Blockwise {
233 local, col_range, ..
234 } => {
235 let width = local.nrows().min(col_range.len());
236 for local_idx in 0..width {
237 target[col_range.start + local_idx] += lambda * local[[local_idx, local_idx]];
238 }
239 }
240 Self::Labeled { inner, .. } | Self::Fixed { inner, .. } => {
241 inner.add_scaled_diag_to(lambda, target)
242 }
243 }
244 }
245
246 pub fn quadratic_form(&self, beta: &Array1<f64>) -> f64 {
248 match self {
249 Self::Dense(m) => beta.dot(&m.dot(beta)),
250 Self::KroneckerFactored { .. } => {
251 let sv = self.dot(beta);
252 beta.dot(&sv)
253 }
254 Self::Blockwise {
255 local, col_range, ..
256 } => {
257 let beta_block = beta.slice(ndarray::s![col_range.clone()]);
258 let sv = local.dot(&beta_block);
259 beta_block.dot(&sv)
260 }
261 Self::Labeled { inner, .. } | Self::Fixed { inner, .. } => inner.quadratic_form(beta),
262 }
263 }
264
265 pub fn nrows(&self) -> usize {
267 self.dim()
268 }
269
270 pub fn ncols(&self) -> usize {
271 self.dim()
272 }
273}
274
275impl From<Array2<f64>> for PenaltyMatrix {
276 fn from(m: Array2<f64>) -> Self {
277 Self::Dense(m)
278 }
279}
280
281fn kronecker_product(a: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
285 let (arows, a_cols) = a.dim();
286 let (brows, b_cols) = b.dim();
287 if arows == 0 || a_cols == 0 || brows == 0 || b_cols == 0 {
288 return Array2::zeros((arows * brows, a_cols * b_cols));
289 }
290 let mut result = Array2::zeros((arows * brows, a_cols * b_cols));
291
292 result
293 .axis_chunks_iter_mut(Axis(0), brows)
294 .into_par_iter()
295 .enumerate()
296 .for_each(|(i, mut row_block)| {
297 let arow = a.row(i);
298 let col_chunks = row_block.axis_chunks_iter_mut(Axis(1), b_cols);
299 for (j, mut block) in col_chunks.into_iter().enumerate() {
300 let aval = arow[j];
301 if aval == 0.0 {
302 continue;
303 }
304 for (dest, &src) in block.iter_mut().zip(b.iter()) {
305 *dest = aval * src;
306 }
307 }
308 });
309
310 result
311}