1use crate::vector::Vector;
19use pounce_common::tagged::{Tag, TaggedCell, TaggedObject};
20use pounce_common::types::{Index, Number};
21use std::any::Any;
22use std::cell::Cell;
23use std::fmt::Debug;
24
25#[derive(Debug)]
28pub struct MatrixCache {
29 tag: TaggedCell,
30 valid: Cell<Option<(Tag, bool)>>,
31}
32
33impl Default for MatrixCache {
34 fn default() -> Self {
35 Self::new()
36 }
37}
38
39impl MatrixCache {
40 pub fn new() -> Self {
41 Self {
42 tag: TaggedCell::new(),
43 valid: Cell::new(None),
44 }
45 }
46
47 pub fn tag(&self) -> Tag {
48 self.tag.tag()
49 }
50
51 pub fn bump(&self) {
53 self.tag.bump();
54 }
55}
56
57pub trait Matrix: TaggedObject + Debug + 'static {
59 fn n_rows(&self) -> Index;
60 fn n_cols(&self) -> Index;
61 fn cache(&self) -> &MatrixCache;
62
63 fn as_any(&self) -> &dyn Any;
64 fn as_any_mut(&mut self) -> &mut dyn Any;
65 fn as_tagged(&self) -> &dyn TaggedObject;
66 fn as_dyn_matrix(&self) -> &dyn Matrix;
67
68 fn mult_vector_impl(&self, alpha: Number, x: &dyn Vector, beta: Number, y: &mut dyn Vector);
72
73 fn trans_mult_vector_impl(
75 &self,
76 alpha: Number,
77 x: &dyn Vector,
78 beta: Number,
79 y: &mut dyn Vector,
80 );
81
82 fn compute_row_amax_impl(&self, rows_norms: &mut dyn Vector, init: bool);
85
86 fn compute_col_amax_impl(&self, cols_norms: &mut dyn Vector, init: bool);
89
90 fn has_valid_numbers_impl(&self) -> bool {
95 true
96 }
97
98 fn add_m_sinv_z_impl(&self, alpha: Number, s: &dyn Vector, z: &dyn Vector, x: &mut dyn Vector) {
101 let mut tmp = s.make_new_copy();
102 tmp.set(0.0);
104 tmp.add_vector_quotient(1.0, z, s, 0.0);
105 self.mult_vector(alpha, tmp.as_dyn_vector(), 1.0, x);
106 }
107
108 fn sinv_blrm_zmt_dbr_impl(
111 &self,
112 alpha: Number,
113 s: &dyn Vector,
114 r: &dyn Vector,
115 z: &dyn Vector,
116 d: &dyn Vector,
117 x: &mut dyn Vector,
118 ) {
119 self.trans_mult_vector(alpha, d, 0.0, x);
120 x.element_wise_multiply(z);
121 x.axpy(1.0, r);
122 x.element_wise_divide(s);
123 }
124
125 fn mult_vector(&self, alpha: Number, x: &dyn Vector, beta: Number, y: &mut dyn Vector) {
128 self.mult_vector_impl(alpha, x, beta, y);
129 }
130
131 fn trans_mult_vector(&self, alpha: Number, x: &dyn Vector, beta: Number, y: &mut dyn Vector) {
132 self.trans_mult_vector_impl(alpha, x, beta, y);
133 }
134
135 fn compute_row_amax(&self, rows_norms: &mut dyn Vector, init: bool) {
136 if init {
137 rows_norms.set(0.0);
138 }
139 self.compute_row_amax_impl(rows_norms, init);
140 }
141
142 fn compute_col_amax(&self, cols_norms: &mut dyn Vector, init: bool) {
143 if init {
144 cols_norms.set(0.0);
145 }
146 self.compute_col_amax_impl(cols_norms, init);
147 }
148
149 fn add_m_sinv_z(&self, alpha: Number, s: &dyn Vector, z: &dyn Vector, x: &mut dyn Vector) {
150 self.add_m_sinv_z_impl(alpha, s, z, x);
151 }
152
153 fn sinv_blrm_zmt_dbr(
154 &self,
155 alpha: Number,
156 s: &dyn Vector,
157 r: &dyn Vector,
158 z: &dyn Vector,
159 d: &dyn Vector,
160 x: &mut dyn Vector,
161 ) {
162 self.sinv_blrm_zmt_dbr_impl(alpha, s, r, z, d, x);
163 }
164
165 fn has_valid_numbers(&self) -> bool {
166 let cur = self.cache().tag();
167 if let Some((t, v)) = self.cache().valid.get() {
168 if t == cur {
169 return v;
170 }
171 }
172 let v = self.has_valid_numbers_impl();
173 self.cache().valid.set(Some((cur, v)));
174 v
175 }
176}
177
178pub trait SymMatrix: Matrix {
184 fn dim(&self) -> Index {
188 debug_assert_eq!(self.n_rows(), self.n_cols());
189 self.n_rows()
190 }
191}
192
193#[inline]
197pub fn sym_default_trans_mult_vector_impl<M: Matrix + ?Sized>(
198 m: &M,
199 alpha: Number,
200 x: &dyn Vector,
201 beta: Number,
202 y: &mut dyn Vector,
203) {
204 m.mult_vector_impl(alpha, x, beta, y);
205}
206
207#[inline]
211pub fn sym_default_compute_col_amax_impl<M: Matrix + ?Sized>(
212 m: &M,
213 cols_norms: &mut dyn Vector,
214 init: bool,
215) {
216 m.compute_row_amax_impl(cols_norms, init);
217}