1use pounce_common::cached::Cache;
12use pounce_common::tagged::{Tag, TaggedCell, TaggedObject};
13use pounce_common::types::{Index, Number};
14use std::any::Any;
15use std::cell::{Cell, RefCell};
16use std::fmt::Debug;
17
18#[derive(Debug)]
22pub struct VectorCache {
23 tag: TaggedCell,
24 nrm2: Cell<Option<(Tag, Number)>>,
25 asum: Cell<Option<(Tag, Number)>>,
26 amax: Cell<Option<(Tag, Number)>>,
27 max: Cell<Option<(Tag, Number)>>,
28 min: Cell<Option<(Tag, Number)>>,
29 sum: Cell<Option<(Tag, Number)>>,
30 sum_logs: Cell<Option<(Tag, Number)>>,
31 valid: Cell<Option<(Tag, bool)>>,
32 dot: RefCell<Cache<Number>>,
33}
34
35impl Default for VectorCache {
36 fn default() -> Self {
37 Self::new()
38 }
39}
40
41impl VectorCache {
42 pub fn new() -> Self {
44 Self {
45 tag: TaggedCell::new(),
46 nrm2: Cell::new(None),
47 asum: Cell::new(None),
48 amax: Cell::new(None),
49 max: Cell::new(None),
50 min: Cell::new(None),
51 sum: Cell::new(None),
52 sum_logs: Cell::new(None),
53 valid: Cell::new(None),
54 dot: RefCell::new(Cache::new(10)),
55 }
56 }
57
58 pub fn tag(&self) -> Tag {
59 self.tag.tag()
60 }
61
62 pub fn bump(&self) {
64 self.tag.bump();
65 }
66}
67
68pub trait Vector: TaggedObject + Debug + 'static {
70 fn dim(&self) -> Index;
71 fn cache(&self) -> &VectorCache;
72
73 fn make_new(&self) -> Box<dyn Vector>;
76
77 fn as_any(&self) -> &dyn Any;
78 fn as_any_mut(&mut self) -> &mut dyn Any;
79 fn as_tagged(&self) -> &dyn TaggedObject;
80 fn as_dyn_vector(&self) -> &dyn Vector;
81
82 fn copy_impl(&mut self, x: &dyn Vector);
85 fn scal_impl(&mut self, alpha: Number);
86 fn axpy_impl(&mut self, alpha: Number, x: &dyn Vector);
87 fn dot_impl(&self, x: &dyn Vector) -> Number;
88 fn nrm2_impl(&self) -> Number;
89 fn asum_impl(&self) -> Number;
90 fn amax_impl(&self) -> Number;
91 fn set_impl(&mut self, alpha: Number);
92 fn element_wise_divide_impl(&mut self, x: &dyn Vector);
93 fn element_wise_multiply_impl(&mut self, x: &dyn Vector);
94 fn element_wise_select_impl(&mut self, x: &dyn Vector);
95 fn element_wise_max_impl(&mut self, x: &dyn Vector);
96 fn element_wise_min_impl(&mut self, x: &dyn Vector);
97 fn element_wise_reciprocal_impl(&mut self);
98 fn element_wise_abs_impl(&mut self);
99 fn element_wise_sqrt_impl(&mut self);
100 fn element_wise_sgn_impl(&mut self);
101 fn add_scalar_impl(&mut self, scalar: Number);
102 fn max_impl(&self) -> Number;
103 fn min_impl(&self) -> Number;
104 fn sum_impl(&self) -> Number;
105 fn sum_logs_impl(&self) -> Number;
106 fn frac_to_bound_impl(&self, delta: &dyn Vector, tau: Number) -> Number;
107 fn add_vector_quotient_impl(&mut self, a: Number, z: &dyn Vector, s: &dyn Vector, c: Number);
108
109 fn add_two_vectors_impl(
115 &mut self,
116 a: Number,
117 v1: &dyn Vector,
118 b: Number,
119 v2: &dyn Vector,
120 c: Number,
121 ) {
122 if c == 0.0 {
123 self.set_impl(0.0);
124 } else if c != 1.0 {
125 self.scal_impl(c);
126 }
127 if a != 0.0 {
128 self.axpy_impl(a, v1);
129 }
130 if b != 0.0 {
131 self.axpy_impl(b, v2);
132 }
133 }
134
135 fn has_valid_numbers_impl(&self) -> bool {
138 self.asum_impl().is_finite()
139 }
140
141 fn copy(&mut self, x: &dyn Vector) {
144 self.copy_impl(x);
145 self.cache().bump();
146 }
147
148 fn make_new_copy(&self) -> Box<dyn Vector> {
149 let mut c = self.make_new();
150 c.copy(self.as_dyn_vector());
151 c
152 }
153
154 fn scal(&mut self, alpha: Number) {
155 self.scal_impl(alpha);
156 self.cache().bump();
157 }
158
159 fn axpy(&mut self, alpha: Number, x: &dyn Vector) {
160 self.axpy_impl(alpha, x);
161 self.cache().bump();
162 }
163
164 fn dot(&self, x: &dyn Vector) -> Number {
165 if std::ptr::eq(self.cache() as *const _, x.cache() as *const _) {
170 let n = self.nrm2();
171 return n * n;
172 }
173 let mut dot_cache = self.cache().dot.borrow_mut();
174 if let Some(v) = dot_cache.get(&[self.as_tagged(), x.as_tagged()], &[]) {
175 return v;
176 }
177 let v = self.dot_impl(x);
178 dot_cache.add(v, &[self.as_tagged(), x.as_tagged()], &[]);
179 v
180 }
181
182 fn nrm2(&self) -> Number {
183 let cur = self.cache().tag();
184 if let Some((t, v)) = self.cache().nrm2.get() {
185 if t == cur {
186 return v;
187 }
188 }
189 let v = self.nrm2_impl();
190 self.cache().nrm2.set(Some((cur, v)));
191 v
192 }
193
194 fn asum(&self) -> Number {
195 let cur = self.cache().tag();
196 if let Some((t, v)) = self.cache().asum.get() {
197 if t == cur {
198 return v;
199 }
200 }
201 let v = self.asum_impl();
202 self.cache().asum.set(Some((cur, v)));
203 v
204 }
205
206 fn amax(&self) -> Number {
207 let cur = self.cache().tag();
208 if let Some((t, v)) = self.cache().amax.get() {
209 if t == cur {
210 return v;
211 }
212 }
213 let v = self.amax_impl();
214 self.cache().amax.set(Some((cur, v)));
215 v
216 }
217
218 fn set(&mut self, alpha: Number) {
219 self.set_impl(alpha);
220 self.cache().bump();
221 }
222
223 fn element_wise_divide(&mut self, x: &dyn Vector) {
224 self.element_wise_divide_impl(x);
225 self.cache().bump();
226 }
227 fn element_wise_multiply(&mut self, x: &dyn Vector) {
228 self.element_wise_multiply_impl(x);
229 self.cache().bump();
230 }
231 fn element_wise_select(&mut self, x: &dyn Vector) {
232 self.element_wise_select_impl(x);
233 self.cache().bump();
234 }
235 fn element_wise_max(&mut self, x: &dyn Vector) {
236 self.element_wise_max_impl(x);
237 self.cache().bump();
238 }
239 fn element_wise_min(&mut self, x: &dyn Vector) {
240 self.element_wise_min_impl(x);
241 self.cache().bump();
242 }
243 fn element_wise_reciprocal(&mut self) {
244 self.element_wise_reciprocal_impl();
245 self.cache().bump();
246 }
247 fn element_wise_abs(&mut self) {
248 self.element_wise_abs_impl();
249 self.cache().bump();
250 }
251 fn element_wise_sqrt(&mut self) {
252 self.element_wise_sqrt_impl();
253 self.cache().bump();
254 }
255 fn element_wise_sgn(&mut self) {
256 self.element_wise_sgn_impl();
257 self.cache().bump();
258 }
259 fn add_scalar(&mut self, scalar: Number) {
260 self.add_scalar_impl(scalar);
261 self.cache().bump();
262 }
263
264 fn max(&self) -> Number {
265 let cur = self.cache().tag();
266 if let Some((t, v)) = self.cache().max.get() {
267 if t == cur {
268 return v;
269 }
270 }
271 let v = self.max_impl();
272 self.cache().max.set(Some((cur, v)));
273 v
274 }
275
276 fn min(&self) -> Number {
277 let cur = self.cache().tag();
278 if let Some((t, v)) = self.cache().min.get() {
279 if t == cur {
280 return v;
281 }
282 }
283 let v = self.min_impl();
284 self.cache().min.set(Some((cur, v)));
285 v
286 }
287
288 fn sum(&self) -> Number {
289 let cur = self.cache().tag();
290 if let Some((t, v)) = self.cache().sum.get() {
291 if t == cur {
292 return v;
293 }
294 }
295 let v = self.sum_impl();
296 self.cache().sum.set(Some((cur, v)));
297 v
298 }
299
300 fn sum_logs(&self) -> Number {
301 let cur = self.cache().tag();
302 if let Some((t, v)) = self.cache().sum_logs.get() {
303 if t == cur {
304 return v;
305 }
306 }
307 let v = self.sum_logs_impl();
308 self.cache().sum_logs.set(Some((cur, v)));
309 v
310 }
311
312 fn add_one_vector(&mut self, a: Number, v1: &dyn Vector, c: Number) {
313 self.add_two_vectors(a, v1, 0.0, v1, c);
315 }
316
317 fn add_two_vectors(
318 &mut self,
319 a: Number,
320 v1: &dyn Vector,
321 b: Number,
322 v2: &dyn Vector,
323 c: Number,
324 ) {
325 self.add_two_vectors_impl(a, v1, b, v2, c);
326 self.cache().bump();
327 }
328
329 fn frac_to_bound(&self, delta: &dyn Vector, tau: Number) -> Number {
332 self.frac_to_bound_impl(delta, tau)
333 }
334
335 fn add_vector_quotient(&mut self, a: Number, z: &dyn Vector, s: &dyn Vector, c: Number) {
336 self.add_vector_quotient_impl(a, z, s, c);
337 self.cache().bump();
338 }
339
340 fn has_valid_numbers(&self) -> bool {
341 let cur = self.cache().tag();
342 if let Some((t, v)) = self.cache().valid.get() {
343 if t == cur {
344 return v;
345 }
346 }
347 let v = self.has_valid_numbers_impl();
348 self.cache().valid.set(Some((cur, v)));
349 v
350 }
351}