1use crate::dense_gen_matrix::DenseGenMatrix;
27use crate::dense_vector::{DenseVector, DenseVectorSpace};
28use crate::matrix::{Matrix, MatrixCache};
29use crate::vector::Vector;
30use pounce_common::tagged::{Tag, TaggedObject};
31use pounce_common::types::{Index, Number};
32use std::any::Any;
33use std::rc::Rc;
34
/// Shape descriptor shared by [`MultiVectorMatrix`] instances.
///
/// Records the number of columns and the [`DenseVectorSpace`] every column
/// vector belongs to. The row count is taken from that column space's
/// dimension when the space is constructed.
#[derive(Debug)]
pub struct MultiVectorMatrixSpace {
    /// Number of rows; copied from `col_space.dim()` by the constructor.
    n_rows: Index,
    /// Number of columns (column vectors) of matrices in this space.
    n_cols: Index,
    /// Vector space of each column.
    col_space: Rc<DenseVectorSpace>,
}
41
42impl MultiVectorMatrixSpace {
43 pub fn new(n_cols: Index, col_space: Rc<DenseVectorSpace>) -> Rc<Self> {
44 Rc::new(Self {
45 n_rows: col_space.dim(),
46 n_cols,
47 col_space,
48 })
49 }
50
51 pub fn n_rows(&self) -> Index {
52 self.n_rows
53 }
54 pub fn n_cols(&self) -> Index {
55 self.n_cols
56 }
57 pub fn col_vector_space(&self) -> &Rc<DenseVectorSpace> {
58 &self.col_space
59 }
60
61 pub fn make_new_multi_vector(self: &Rc<Self>) -> MultiVectorMatrix {
62 MultiVectorMatrix::new(Rc::clone(self))
63 }
64}
65
/// Matrix stored as a list of column vectors.
///
/// Each column slot holds an `Option<Rc<dyn Vector>>`. A slot is `None` until
/// it is populated, for example via `set_vector` or `fill_with_new_vectors`.
#[derive(Debug)]
pub struct MultiVectorMatrix {
    /// Shared space describing dimensions and the column vector space.
    space: Rc<MultiVectorMatrixSpace>,
    /// Change-tracking cache; the mutating methods bump it after modifying columns.
    cache: MatrixCache,
    /// Column vectors; `None` marks a column that has not been set.
    cols: Vec<Option<Rc<dyn Vector>>>,
}
75
76impl MultiVectorMatrix {
77 pub fn new(space: Rc<MultiVectorMatrixSpace>) -> Self {
78 let n = space.n_cols.max(0) as usize;
79 Self {
80 space,
81 cache: MatrixCache::new(),
82 cols: (0..n).map(|_| None).collect(),
83 }
84 }
85
86 pub fn space(&self) -> &Rc<MultiVectorMatrixSpace> {
87 &self.space
88 }
89
90 pub fn col_vector_space(&self) -> &Rc<DenseVectorSpace> {
91 self.space.col_vector_space()
92 }
93
94 pub fn set_vector(&mut self, i: Index, vec: Rc<dyn Vector>) {
99 let idx = i as usize;
100 debug_assert!(idx < self.cols.len());
101 debug_assert_eq!(vec.dim(), self.space.n_rows);
102 self.cols[idx] = Some(vec);
103 self.cache.bump();
104 }
105
106 pub fn get_vector(&self, i: Index) -> &Rc<dyn Vector> {
107 let idx = i as usize;
108 debug_assert!(idx < self.cols.len());
109 self.cols[idx]
110 .as_ref()
111 .expect("MultiVectorMatrix column is unset")
112 }
113
114 pub fn fill_with_new_vectors(&mut self) {
118 for slot in self.cols.iter_mut() {
119 let v = self.space.col_space.make_new_dense();
120 *slot = Some(Rc::new(v) as Rc<dyn Vector>);
121 }
122 self.cache.bump();
123 }
124
125 fn col(&self, i: usize) -> &dyn Vector {
126 self.cols[i]
127 .as_ref()
128 .expect("MultiVectorMatrix column is unset")
129 .as_ref()
130 }
131
132 fn col_mut(&mut self, i: usize) -> &mut dyn Vector {
138 let slot = self.cols[i]
139 .as_mut()
140 .expect("MultiVectorMatrix column is unset");
141 let inner: &mut dyn Vector = Rc::get_mut(slot)
142 .expect("MultiVectorMatrix column is shared; cannot mutate (clone first)");
143 inner
144 }
145
146 pub fn lr_mult_vector(&self, alpha: Number, x: &dyn Vector, beta: Number, y: &mut dyn Vector) {
148 debug_assert_eq!(self.space.n_rows, x.dim());
149 debug_assert_eq!(self.space.n_rows, y.dim());
150 if beta != 0.0 {
151 y.scal(beta);
152 } else {
153 y.set(0.0);
154 }
155 for i in 0..self.cols.len() {
156 let ci = self.col(i);
157 let coef = alpha * ci.dot(x);
158 y.add_one_vector(coef, ci, 1.0);
159 }
160 }
161
162 pub fn scale_columns(&mut self, scal: &DenseVector) {
166 debug_assert_eq!(scal.dim(), self.space.n_cols);
167 let nc = self.cols.len();
168 if scal.is_homogeneous() {
169 let s = scal.scalar();
170 for i in 0..nc {
171 self.col_mut(i).scal(s);
172 }
173 } else {
174 let vals = scal.values().to_vec();
175 for i in 0..nc {
176 self.col_mut(i).scal(vals[i]);
177 }
178 }
179 self.cache.bump();
180 }
181
182 pub fn scale_rows(&mut self, scal: &dyn Vector) {
185 debug_assert_eq!(scal.dim(), self.space.n_rows);
186 let nc = self.cols.len();
187 for i in 0..nc {
188 self.col_mut(i).element_wise_multiply(scal);
189 }
190 self.cache.bump();
191 }
192
193 pub fn add_one_multi_vector_matrix(&mut self, a: Number, mv1: &MultiVectorMatrix, c: Number) {
196 debug_assert_eq!(self.space.n_rows, mv1.space.n_rows);
197 debug_assert_eq!(self.space.n_cols, mv1.space.n_cols);
198 if c == 0.0 {
199 self.fill_with_new_vectors();
200 }
201 let nc = self.cols.len();
202 for i in 0..nc {
203 let src = Rc::clone(&mv1.cols[i].as_ref().expect("source column unset").clone());
206 self.col_mut(i).add_one_vector(a, src.as_ref(), c);
207 }
208 self.cache.bump();
209 }
210
211 pub fn add_right_mult_matrix(
215 &mut self,
216 a: Number,
217 u: &MultiVectorMatrix,
218 c_mat: &DenseGenMatrix,
219 b: Number,
220 ) {
221 debug_assert_eq!(self.space.n_rows, u.space.n_rows);
222 debug_assert_eq!(u.space.n_cols, c_mat.n_rows());
223 debug_assert_eq!(self.space.n_cols, c_mat.n_cols());
224
225 if b == 0.0 {
226 self.fill_with_new_vectors();
227 }
228
229 let c_n_rows = c_mat.n_rows() as usize;
230 let c_values = c_mat.values().to_vec();
231 let temp_space = DenseVectorSpace::new(c_mat.n_rows());
232 let mut tmp_dv = temp_space.make_new_dense();
233 let nc = self.cols.len();
234 for i in 0..nc {
235 let base = i * c_n_rows;
237 let col_slice: Vec<Number> = c_values[base..base + c_n_rows].to_vec();
238 tmp_dv.set_values(&col_slice);
239 u.mult_vector(a, &tmp_dv, b, self.col_mut(i));
246 }
247 self.cache.bump();
248 }
249}
250
251impl TaggedObject for MultiVectorMatrix {
252 fn get_tag(&self) -> Tag {
253 self.cache.tag()
254 }
255}
256
257impl Matrix for MultiVectorMatrix {
258 fn n_rows(&self) -> Index {
259 self.space.n_rows
260 }
261 fn n_cols(&self) -> Index {
262 self.space.n_cols
263 }
264 fn cache(&self) -> &MatrixCache {
265 &self.cache
266 }
267 fn as_any(&self) -> &dyn Any {
268 self
269 }
270 fn as_any_mut(&mut self) -> &mut dyn Any {
271 self
272 }
273 fn as_tagged(&self) -> &dyn TaggedObject {
274 self
275 }
276 fn as_dyn_matrix(&self) -> &dyn Matrix {
277 self
278 }
279
280 fn mult_vector_impl(&self, alpha: Number, x: &dyn Vector, beta: Number, y: &mut dyn Vector) {
285 debug_assert_eq!(self.space.n_cols, x.dim());
286 debug_assert_eq!(self.space.n_rows, y.dim());
287
288 if beta != 0.0 {
289 y.scal(beta);
290 } else {
291 y.set(0.0);
292 }
293
294 let dx = x
295 .as_any()
296 .downcast_ref::<DenseVector>()
297 .expect("MultiVectorMatrix expects DenseVector input");
298
299 if dx.is_homogeneous() {
300 let val = dx.scalar();
301 for i in 0..self.cols.len() {
302 y.add_one_vector(alpha * val, self.col(i), 1.0);
303 }
304 } else {
305 let values = dx.values();
306 for i in 0..self.cols.len() {
307 y.add_one_vector(alpha * values[i], self.col(i), 1.0);
308 }
309 }
310 }
311
312 fn trans_mult_vector_impl(
316 &self,
317 alpha: Number,
318 x: &dyn Vector,
319 beta: Number,
320 y: &mut dyn Vector,
321 ) {
322 debug_assert_eq!(self.space.n_cols, y.dim());
323 debug_assert_eq!(self.space.n_rows, x.dim());
324
325 let nc = self.cols.len();
328 let mut dots = Vec::with_capacity(nc);
329 for i in 0..nc {
330 dots.push(self.col(i).dot(x));
331 }
332
333 let dy = y
334 .as_any_mut()
335 .downcast_mut::<DenseVector>()
336 .expect("MultiVectorMatrix expects DenseVector output");
337 let yvals = dy.values_mut();
339 if beta != 0.0 {
340 for i in 0..nc {
341 yvals[i] = alpha * dots[i] + beta * yvals[i];
342 }
343 } else {
344 for i in 0..nc {
345 yvals[i] = alpha * dots[i];
346 }
347 }
348 }
349
350 fn has_valid_numbers_impl(&self) -> bool {
351 for slot in &self.cols {
352 match slot {
353 Some(v) => {
354 if !v.has_valid_numbers() {
355 return false;
356 }
357 }
358 None => return false,
359 }
360 }
361 true
362 }
363
364 fn compute_row_amax_impl(&self, _rows_norms: &mut dyn Vector, _init: bool) {
365 unimplemented!("MultiVectorMatrix::compute_row_amax — upstream throws UNIMPLEMENTED");
366 }
367
368 fn compute_col_amax_impl(&self, _cols_norms: &mut dyn Vector, _init: bool) {
369 unimplemented!("MultiVectorMatrix::compute_col_amax — upstream throws UNIMPLEMENTED");
370 }
371}
372
#[cfg(test)]
mod tests {
    //! Unit tests for `MultiVectorMatrix`. Each test builds small dense
    //! columns and checks products against hand-computed values.
    use super::*;
    use crate::dense_gen_matrix::DenseGenMatrixSpace;

    /// Builds an `Rc`-wrapped dense vector holding `values`.
    fn dvec(values: &[Number]) -> Rc<DenseVector> {
        let space = DenseVectorSpace::new(values.len() as Index);
        let mut v = space.make_new_dense();
        v.set_values(values);
        Rc::new(v)
    }

    /// Builds a boxed dense vector holding `values` (usable as `&mut dyn Vector`).
    fn dvec_box(values: &[Number]) -> Box<DenseVector> {
        let space = DenseVectorSpace::new(values.len() as Index);
        let mut v = space.make_new_dense();
        v.set_values(values);
        Box::new(v)
    }

    /// Builds a multi-vector matrix whose columns are `cols`.
    fn build_mv(cols: &[&[Number]]) -> MultiVectorMatrix {
        let n_rows = cols[0].len() as Index;
        let n_cols = cols.len() as Index;
        let cs = DenseVectorSpace::new(n_rows);
        let space = MultiVectorMatrixSpace::new(n_cols, cs);
        let mut mv = space.make_new_multi_vector();
        for (i, c) in cols.iter().enumerate() {
            mv.set_vector(i as Index, dvec(c) as Rc<dyn Vector>);
        }
        mv
    }

    #[test]
    fn dimensions_match_space() {
        let mv = build_mv(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]]);
        assert_eq!(mv.n_rows(), 3);
        assert_eq!(mv.n_cols(), 2);
    }

    #[test]
    fn mult_vector_combines_columns() {
        let mv = build_mv(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]]);
        let x = dvec_box(&[10.0, 100.0]);
        let mut y = dvec_box(&[0.0, 0.0, 0.0]);
        mv.mult_vector(1.0, x.as_dyn_vector(), 0.0, y.as_mut());
        assert_eq!(y.expanded_values(), vec![410.0, 520.0, 630.0]);
    }

    #[test]
    fn mult_vector_alpha_beta_reduction_order() {
        let mv = build_mv(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]]);
        let x = dvec_box(&[10.0, 100.0]);
        let mut y = dvec_box(&[10.0, 20.0, 30.0]);
        mv.mult_vector(2.0, x.as_dyn_vector(), 0.5, y.as_mut());
        assert_eq!(y.expanded_values(), vec![825.0, 1050.0, 1275.0]);
    }

    #[test]
    fn mult_vector_with_constant_input_sums_columns() {
        let mv = build_mv(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]]);
        // `set` gives every entry the same value, which may route through
        // the homogeneous branch of `mult_vector_impl`.
        let mut x = dvec_box(&[0.0, 0.0]);
        x.set(1.0);
        let mut y = dvec_box(&[0.0, 0.0, 0.0]);
        mv.mult_vector(2.0, x.as_dyn_vector(), 0.0, y.as_mut());
        assert_eq!(y.expanded_values(), vec![10.0, 14.0, 18.0]);
    }

    #[test]
    fn trans_mult_vector_dot_products() {
        let mv = build_mv(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]]);
        let x = dvec_box(&[1.0, 1.0, 1.0]);
        let mut y = dvec_box(&[0.0, 0.0]);
        mv.trans_mult_vector(1.0, x.as_dyn_vector(), 0.0, y.as_mut());
        assert_eq!(y.expanded_values(), vec![6.0, 15.0]);
    }

    #[test]
    fn trans_mult_vector_with_beta() {
        let mv = build_mv(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]]);
        let x = dvec_box(&[1.0, 1.0, 1.0]);
        let mut y = dvec_box(&[100.0, 200.0]);
        mv.trans_mult_vector(2.0, x.as_dyn_vector(), 0.5, y.as_mut());
        assert_eq!(y.expanded_values(), vec![62.0, 130.0]);
    }

    #[test]
    fn lr_mult_vector_yields_v_v_t_x() {
        let mv = build_mv(&[&[1.0, 0.0, 0.0], &[0.0, 1.0, 0.0]]);
        let x = dvec_box(&[3.0, 5.0, 7.0]);
        let mut y = dvec_box(&[0.0, 0.0, 0.0]);
        mv.lr_mult_vector(1.0, x.as_dyn_vector(), 0.0, y.as_mut());
        assert_eq!(y.expanded_values(), vec![3.0, 5.0, 0.0]);
    }

    #[test]
    fn lr_mult_vector_alpha_beta() {
        let mv = build_mv(&[&[1.0, 0.0], &[0.0, 2.0]]);
        let x = dvec_box(&[10.0, 10.0]);
        let mut y = dvec_box(&[1.0, 1.0]);
        mv.lr_mult_vector(2.0, x.as_dyn_vector(), 3.0, y.as_mut());
        assert_eq!(y.expanded_values(), vec![23.0, 83.0]);
    }

    #[test]
    fn fill_with_new_vectors_initializes_columns() {
        let cs = DenseVectorSpace::new(3);
        let space = MultiVectorMatrixSpace::new(2, cs);
        let mut mv = space.make_new_multi_vector();
        mv.fill_with_new_vectors();
        assert!(mv.cols[0].is_some());
        assert!(mv.cols[1].is_some());
    }

    #[test]
    fn scale_columns_per_index() {
        let mv0 = build_mv(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]]);
        let cs = DenseVectorSpace::new(3);
        let space = MultiVectorMatrixSpace::new(2, cs);
        let mut mv = space.make_new_multi_vector();
        mv.fill_with_new_vectors();
        mv.col_mut(0).copy(mv0.get_vector(0).as_ref());
        mv.col_mut(1).copy(mv0.get_vector(1).as_ref());

        let scal = {
            let s = DenseVectorSpace::new(2);
            let mut v = s.make_new_dense();
            v.set_values(&[10.0, 100.0]);
            v
        };
        mv.scale_columns(&scal);
        let mut probe = dvec_box(&[1.0, 0.0]);
        let mut y = dvec_box(&[0.0, 0.0, 0.0]);
        mv.mult_vector(1.0, probe.as_dyn_vector(), 0.0, y.as_mut());
        assert_eq!(y.expanded_values(), vec![10.0, 20.0, 30.0]);
        probe.set_values(&[0.0, 1.0]);
        mv.mult_vector(1.0, probe.as_dyn_vector(), 0.0, y.as_mut());
        assert_eq!(y.expanded_values(), vec![400.0, 500.0, 600.0]);
    }

    #[test]
    fn scale_rows_multiplies_each_column() {
        let cs = DenseVectorSpace::new(3);
        let space = MultiVectorMatrixSpace::new(2, cs);
        let mut mv = space.make_new_multi_vector();
        mv.fill_with_new_vectors();
        let v0 = dvec(&[1.0, 2.0, 3.0]);
        let v1 = dvec(&[4.0, 5.0, 6.0]);
        mv.col_mut(0).copy(v0.as_ref());
        mv.col_mut(1).copy(v1.as_ref());

        let scal = dvec(&[10.0, 1.0, 1.0]);
        mv.scale_rows(scal.as_ref());
        let mut x = dvec_box(&[1.0, 0.0]);
        let mut y = dvec_box(&[0.0, 0.0, 0.0]);
        mv.mult_vector(1.0, x.as_dyn_vector(), 0.0, y.as_mut());
        assert_eq!(y.expanded_values(), vec![10.0, 2.0, 3.0]);
        x.set_values(&[0.0, 1.0]);
        mv.mult_vector(1.0, x.as_dyn_vector(), 0.0, y.as_mut());
        assert_eq!(y.expanded_values(), vec![40.0, 5.0, 6.0]);
    }

    #[test]
    fn add_one_multi_vector_matrix_accumulates_scaled_source() {
        let src = build_mv(&[&[1.0, 2.0], &[3.0, 4.0]]);
        let cs = DenseVectorSpace::new(2);
        let space = MultiVectorMatrixSpace::new(2, cs);
        let mut mv = space.make_new_multi_vector();
        mv.fill_with_new_vectors();
        let v0 = dvec(&[10.0, 20.0]);
        let v1 = dvec(&[30.0, 40.0]);
        mv.col_mut(0).copy(v0.as_ref());
        mv.col_mut(1).copy(v1.as_ref());

        // mv <- 2 * src + 1 * mv
        mv.add_one_multi_vector_matrix(2.0, &src, 1.0);

        let mut probe = dvec_box(&[1.0, 0.0]);
        let mut y = dvec_box(&[0.0, 0.0]);
        mv.mult_vector(1.0, probe.as_dyn_vector(), 0.0, y.as_mut());
        assert_eq!(y.expanded_values(), vec![12.0, 24.0]);
        probe.set_values(&[0.0, 1.0]);
        mv.mult_vector(1.0, probe.as_dyn_vector(), 0.0, y.as_mut());
        assert_eq!(y.expanded_values(), vec![36.0, 48.0]);
    }

    #[test]
    fn add_right_mult_matrix_v_eq_u_times_c() {
        let cs = DenseVectorSpace::new(2);
        let u_space = MultiVectorMatrixSpace::new(2, Rc::clone(&cs));
        let mut u = u_space.make_new_multi_vector();
        u.set_vector(0, dvec(&[1.0, 0.0]) as Rc<dyn Vector>);
        u.set_vector(1, dvec(&[0.0, 1.0]) as Rc<dyn Vector>);

        // Column-major storage: C = [[2, 3], [4, 5]].
        let c_space = DenseGenMatrixSpace::new(2, 2);
        let mut c_mat = c_space.make_new_dense_gen();
        c_mat.values_mut().copy_from_slice(&[2.0, 4.0, 3.0, 5.0]);

        let v_space = MultiVectorMatrixSpace::new(2, cs);
        let mut v = v_space.make_new_multi_vector();
        v.add_right_mult_matrix(1.0, &u, &c_mat, 0.0);

        let probe = dvec_box(&[1.0, 0.0]);
        let mut y = dvec_box(&[0.0, 0.0]);
        v.mult_vector(1.0, probe.as_dyn_vector(), 0.0, y.as_mut());
        assert_eq!(y.expanded_values(), vec![2.0, 4.0]);

        let probe1 = dvec_box(&[0.0, 1.0]);
        v.mult_vector(1.0, probe1.as_dyn_vector(), 0.0, y.as_mut());
        assert_eq!(y.expanded_values(), vec![3.0, 5.0]);
    }

    #[test]
    fn has_valid_numbers_detects_nan_in_column() {
        let cs = DenseVectorSpace::new(2);
        let space = MultiVectorMatrixSpace::new(2, cs);
        let mut mv = space.make_new_multi_vector();
        mv.set_vector(0, dvec(&[1.0, 2.0]) as Rc<dyn Vector>);
        mv.set_vector(1, dvec(&[f64::NAN, 0.0]) as Rc<dyn Vector>);
        assert!(!mv.has_valid_numbers());
    }

    #[test]
    fn has_valid_numbers_rejects_unset_column() {
        let cs = DenseVectorSpace::new(2);
        let space = MultiVectorMatrixSpace::new(2, cs);
        let mut mv = space.make_new_multi_vector();
        mv.set_vector(0, dvec(&[1.0, 2.0]) as Rc<dyn Vector>);
        assert!(!mv.has_valid_numbers());
    }
}