1use crate::vector::{Vector, VectorCache};
14use pounce_common::tagged::{Tag, TaggedObject};
15use pounce_common::types::{Index, Number};
16use std::any::Any;
17use std::cell::RefCell;
18use std::rc::Rc;
19
20type CompFactory = Box<dyn Fn() -> Box<dyn Vector>>;
21
22pub struct CompoundVectorSpace {
26 total_dim: Index,
27 n_comp_spaces: Index,
28 comp_dims: RefCell<Vec<Index>>,
30 factories: RefCell<Vec<Option<CompFactory>>>,
31}
32
33impl std::fmt::Debug for CompoundVectorSpace {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 f.debug_struct("CompoundVectorSpace")
36 .field("total_dim", &self.total_dim)
37 .field("n_comp_spaces", &self.n_comp_spaces)
38 .field("comp_dims", &self.comp_dims.borrow())
39 .finish()
40 }
41}
42
43impl CompoundVectorSpace {
44 pub fn new(n_comp_spaces: Index, total_dim: Index) -> Rc<Self> {
45 let mut factories: Vec<Option<CompFactory>> = Vec::with_capacity(n_comp_spaces as usize);
46 for _ in 0..n_comp_spaces {
47 factories.push(None);
48 }
49 Rc::new(Self {
50 total_dim,
51 n_comp_spaces,
52 comp_dims: RefCell::new(vec![0; n_comp_spaces as usize]),
53 factories: RefCell::new(factories),
54 })
55 }
56
57 pub fn dim(&self) -> Index {
58 self.total_dim
59 }
60
61 pub fn n_comp_spaces(&self) -> Index {
62 self.n_comp_spaces
63 }
64
65 pub fn comp_dim(&self, icomp: Index) -> Index {
66 self.comp_dims.borrow()[icomp as usize]
67 }
68
69 pub fn set_comp<F>(&self, icomp: Index, dim: Index, factory: F)
74 where
75 F: Fn() -> Box<dyn Vector> + 'static,
76 {
77 assert!(icomp < self.n_comp_spaces);
78 self.comp_dims.borrow_mut()[icomp as usize] = dim;
79 self.factories.borrow_mut()[icomp as usize] = Some(Box::new(factory));
80 }
81}
82
83#[derive(Debug)]
85pub struct CompoundVector {
86 space: Rc<CompoundVectorSpace>,
87 cache: VectorCache,
88 comps: Vec<Box<dyn Vector>>,
89}
90
91impl CompoundVector {
92 pub fn new(space: Rc<CompoundVectorSpace>) -> Self {
95 let n = space.n_comp_spaces() as usize;
96 let mut comps: Vec<Box<dyn Vector>> = Vec::with_capacity(n);
97 let factories = space.factories.borrow();
98 let mut dim_check: Index = 0;
99 for f in factories.iter() {
100 let factory = match f.as_ref() {
101 Some(fac) => fac,
102 None => panic!("CompoundVectorSpace component not set — call set_comp on every component before constructing a CompoundVector"),
103 };
104 let v = factory();
105 dim_check += v.dim();
106 comps.push(v);
107 }
108 debug_assert_eq!(dim_check, space.total_dim);
109 drop(factories);
110 Self {
111 space,
112 cache: VectorCache::new(),
113 comps,
114 }
115 }
116
117 pub fn n_comps(&self) -> Index {
118 self.comps.len() as Index
119 }
120
121 pub fn comp(&self, i: Index) -> &dyn Vector {
122 self.comps[i as usize].as_ref()
123 }
124
125 pub fn comp_mut(&mut self, i: Index) -> &mut dyn Vector {
129 self.cache.bump();
130 self.comps[i as usize].as_mut()
131 }
132
133 pub fn space(&self) -> &Rc<CompoundVectorSpace> {
134 &self.space
135 }
136}
137
138fn downcast_compound(x: &dyn Vector) -> &CompoundVector {
139 match x.as_any().downcast_ref::<CompoundVector>() {
140 Some(v) => v,
141 None => panic!("Vector argument is not a CompoundVector"),
142 }
143}
144
145impl TaggedObject for CompoundVector {
146 fn get_tag(&self) -> Tag {
147 self.cache.tag()
148 }
149}
150
151impl Vector for CompoundVector {
152 fn dim(&self) -> Index {
153 self.space.total_dim
154 }
155
156 fn cache(&self) -> &VectorCache {
157 &self.cache
158 }
159
160 fn make_new(&self) -> Box<dyn Vector> {
161 Box::new(CompoundVector::new(Rc::clone(&self.space)))
162 }
163
164 fn as_any(&self) -> &dyn Any {
165 self
166 }
167
168 fn as_any_mut(&mut self) -> &mut dyn Any {
169 self
170 }
171
172 fn as_tagged(&self) -> &dyn TaggedObject {
173 self
174 }
175
176 fn as_dyn_vector(&self) -> &dyn Vector {
177 self
178 }
179
180 fn copy_impl(&mut self, x: &dyn Vector) {
181 let cx = downcast_compound(x);
182 debug_assert_eq!(self.n_comps(), cx.n_comps());
183 for i in 0..self.comps.len() {
184 self.comps[i].copy(cx.comps[i].as_ref());
185 }
186 }
187
188 fn scal_impl(&mut self, alpha: Number) {
189 for c in &mut self.comps {
190 c.scal(alpha);
191 }
192 }
193
194 fn axpy_impl(&mut self, alpha: Number, x: &dyn Vector) {
195 let cx = downcast_compound(x);
196 debug_assert_eq!(self.n_comps(), cx.n_comps());
197 for i in 0..self.comps.len() {
198 self.comps[i].axpy(alpha, cx.comps[i].as_ref());
199 }
200 }
201
202 fn dot_impl(&self, x: &dyn Vector) -> Number {
203 let cx = downcast_compound(x);
204 debug_assert_eq!(self.n_comps(), cx.n_comps());
205 let mut s = 0.0;
206 for i in 0..self.comps.len() {
207 s += self.comps[i].dot(cx.comps[i].as_ref());
208 }
209 s
210 }
211
212 fn nrm2_impl(&self) -> Number {
213 let mut sum_sq = 0.0;
214 for c in &self.comps {
215 let n = c.nrm2();
216 sum_sq += n * n;
217 }
218 sum_sq.sqrt()
219 }
220
221 fn asum_impl(&self) -> Number {
222 let mut s = 0.0;
223 for c in &self.comps {
224 s += c.asum();
225 }
226 s
227 }
228
229 fn amax_impl(&self) -> Number {
230 let mut m: Number = 0.0;
231 for c in &self.comps {
232 let v = c.amax();
233 if v > m {
234 m = v;
235 }
236 }
237 m
238 }
239
240 fn set_impl(&mut self, value: Number) {
241 for c in &mut self.comps {
242 c.set(value);
243 }
244 }
245
246 fn element_wise_divide_impl(&mut self, x: &dyn Vector) {
247 let cx = downcast_compound(x);
248 for i in 0..self.comps.len() {
249 self.comps[i].element_wise_divide(cx.comps[i].as_ref());
250 }
251 }
252 fn element_wise_multiply_impl(&mut self, x: &dyn Vector) {
253 let cx = downcast_compound(x);
254 for i in 0..self.comps.len() {
255 self.comps[i].element_wise_multiply(cx.comps[i].as_ref());
256 }
257 }
258 fn element_wise_select_impl(&mut self, x: &dyn Vector) {
259 let cx = downcast_compound(x);
260 for i in 0..self.comps.len() {
261 self.comps[i].element_wise_select(cx.comps[i].as_ref());
262 }
263 }
264 fn element_wise_max_impl(&mut self, x: &dyn Vector) {
265 let cx = downcast_compound(x);
266 for i in 0..self.comps.len() {
267 self.comps[i].element_wise_max(cx.comps[i].as_ref());
268 }
269 }
270 fn element_wise_min_impl(&mut self, x: &dyn Vector) {
271 let cx = downcast_compound(x);
272 for i in 0..self.comps.len() {
273 self.comps[i].element_wise_min(cx.comps[i].as_ref());
274 }
275 }
276 fn element_wise_reciprocal_impl(&mut self) {
277 for c in &mut self.comps {
278 c.element_wise_reciprocal();
279 }
280 }
281 fn element_wise_abs_impl(&mut self) {
282 for c in &mut self.comps {
283 c.element_wise_abs();
284 }
285 }
286 fn element_wise_sqrt_impl(&mut self) {
287 for c in &mut self.comps {
288 c.element_wise_sqrt();
289 }
290 }
291 fn element_wise_sgn_impl(&mut self) {
292 for c in &mut self.comps {
293 c.element_wise_sgn();
294 }
295 }
296 fn add_scalar_impl(&mut self, scalar: Number) {
297 for c in &mut self.comps {
298 c.add_scalar(scalar);
299 }
300 }
301
302 fn max_impl(&self) -> Number {
303 debug_assert!(!self.comps.is_empty() && self.dim() > 0);
304 let mut m = -Number::MAX;
305 for c in &self.comps {
306 if c.dim() != 0 {
307 let v = c.max();
308 if v > m {
309 m = v;
310 }
311 }
312 }
313 m
314 }
315
316 fn min_impl(&self) -> Number {
317 debug_assert!(!self.comps.is_empty() && self.dim() > 0);
318 let mut m = Number::MAX;
319 for c in &self.comps {
320 if c.dim() != 0 {
321 let v = c.min();
322 if v < m {
323 m = v;
324 }
325 }
326 }
327 m
328 }
329
330 fn sum_impl(&self) -> Number {
331 let mut s = 0.0;
332 for c in &self.comps {
333 s += c.sum();
334 }
335 s
336 }
337
338 fn sum_logs_impl(&self) -> Number {
339 let mut s = 0.0;
340 for c in &self.comps {
341 s += c.sum_logs();
342 }
343 s
344 }
345
346 fn add_two_vectors_impl(
347 &mut self,
348 a: Number,
349 v1: &dyn Vector,
350 b: Number,
351 v2: &dyn Vector,
352 c: Number,
353 ) {
354 let cv1 = downcast_compound(v1);
355 let cv2 = downcast_compound(v2);
356 debug_assert_eq!(self.n_comps(), cv1.n_comps());
357 debug_assert_eq!(self.n_comps(), cv2.n_comps());
358 for i in 0..self.comps.len() {
359 self.comps[i].add_two_vectors(a, cv1.comps[i].as_ref(), b, cv2.comps[i].as_ref(), c);
360 }
361 }
362
363 fn frac_to_bound_impl(&self, delta: &dyn Vector, tau: Number) -> Number {
364 let cd = downcast_compound(delta);
365 debug_assert_eq!(self.n_comps(), cd.n_comps());
366 let mut alpha: Number = 1.0;
367 for i in 0..self.comps.len() {
368 let a = self.comps[i].frac_to_bound(cd.comps[i].as_ref(), tau);
369 if a < alpha {
370 alpha = a;
371 }
372 }
373 alpha
374 }
375
376 fn add_vector_quotient_impl(&mut self, a: Number, z: &dyn Vector, s: &dyn Vector, c: Number) {
377 let cz = downcast_compound(z);
378 let cs = downcast_compound(s);
379 for i in 0..self.comps.len() {
380 self.comps[i].add_vector_quotient(a, cz.comps[i].as_ref(), cs.comps[i].as_ref(), c);
381 }
382 }
383
384 fn has_valid_numbers_impl(&self) -> bool {
385 for c in &self.comps {
386 if !c.has_valid_numbers() {
387 return false;
388 }
389 }
390 true
391 }
392}
393
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dense_vector::{DenseVector, DenseVectorSpace};

    /// Builds a two-component space whose blocks are dense vectors of
    /// lengths `d1` and `d2`.
    fn make_2block_space(d1: Index, d2: Index) -> Rc<CompoundVectorSpace> {
        let space = CompoundVectorSpace::new(2, d1 + d2);
        for (icomp, block_dim) in [(0, d1), (1, d2)] {
            let dense_space = DenseVectorSpace::new(block_dim);
            space.set_comp(icomp, block_dim, move || {
                Box::new(DenseVector::new(Rc::clone(&dense_space)))
            });
        }
        space
    }

    /// Overwrites the values of a block that is known to be a `DenseVector`.
    fn fill_dense(v: &mut dyn Vector, vals: &[Number]) {
        let dense = v
            .as_any_mut()
            .downcast_mut::<DenseVector>()
            .expect("DenseVector");
        dense.set_values(vals);
    }

    #[test]
    fn nrm2_combines_blocks() {
        let mut vec = CompoundVector::new(make_2block_space(2, 3));
        fill_dense(vec.comp_mut(0), &[3.0, 4.0]);
        fill_dense(vec.comp_mut(1), &[0.0, 0.0, 12.0]);
        // sqrt(3^2 + 4^2 + 12^2) = 13
        assert!((vec.nrm2() - 13.0).abs() < 1e-15);
    }

    #[test]
    fn dot_routes_to_blocks() {
        let space = make_2block_space(2, 2);
        let mut lhs = CompoundVector::new(Rc::clone(&space));
        fill_dense(lhs.comp_mut(0), &[1.0, 2.0]);
        fill_dense(lhs.comp_mut(1), &[3.0, 4.0]);
        let mut rhs = CompoundVector::new(Rc::clone(&space));
        fill_dense(rhs.comp_mut(0), &[10.0, 20.0]);
        fill_dense(rhs.comp_mut(1), &[100.0, 1000.0]);
        // 1*10 + 2*20 + 3*100 + 4*1000
        assert_eq!(lhs.dot(&rhs), 4350.0);
    }

    #[test]
    fn axpy_propagates_to_blocks() {
        let space = make_2block_space(2, 1);
        let mut ones = CompoundVector::new(Rc::clone(&space));
        fill_dense(ones.comp_mut(0), &[1.0, 1.0]);
        fill_dense(ones.comp_mut(1), &[1.0]);
        let mut target = CompoundVector::new(Rc::clone(&space));
        fill_dense(target.comp_mut(0), &[10.0, 20.0]);
        fill_dense(target.comp_mut(1), &[30.0]);

        // target <- target + 2 * ones
        target.axpy(2.0, &ones);

        let head = target.comp(0).as_any().downcast_ref::<DenseVector>().unwrap();
        let tail = target.comp(1).as_any().downcast_ref::<DenseVector>().unwrap();
        assert_eq!(head.values(), &[12.0, 22.0]);
        assert_eq!(tail.values(), &[32.0]);
    }

    #[test]
    fn asum_sums_block_asums() {
        let mut vec = CompoundVector::new(make_2block_space(2, 2));
        fill_dense(vec.comp_mut(0), &[-1.0, 2.0]);
        fill_dense(vec.comp_mut(1), &[3.0, -4.0]);
        // |-1| + |2| + |3| + |-4|
        assert_eq!(vec.asum(), 10.0);
    }

    #[test]
    fn amax_takes_max_across_blocks() {
        let mut vec = CompoundVector::new(make_2block_space(2, 3));
        fill_dense(vec.comp_mut(0), &[1.0, -2.0]);
        fill_dense(vec.comp_mut(1), &[0.5, -10.0, 3.0]);
        assert_eq!(vec.amax(), 10.0);
    }

    #[test]
    fn frac_to_bound_takes_min_across_blocks() {
        let space = make_2block_space(2, 1);
        let mut point = CompoundVector::new(Rc::clone(&space));
        fill_dense(point.comp_mut(0), &[1.0, 2.0]);
        fill_dense(point.comp_mut(1), &[3.0]);
        let mut step = CompoundVector::new(space);
        fill_dense(step.comp_mut(0), &[-2.0, 0.0]);
        fill_dense(step.comp_mut(1), &[-1.5]);

        // Block 0 limits the step to 1 / 2; block 1 would allow 3 / 1.5 = 2.
        let alpha = point.frac_to_bound(&step, 1.0);
        assert!((alpha - 0.5).abs() < 1e-15);
    }

    #[test]
    fn make_new_creates_uninitialized_compound() {
        let space = make_2block_space(2, 1);
        let mut original = CompoundVector::new(Rc::clone(&space));
        fill_dense(original.comp_mut(0), &[1.0, 2.0]);
        fill_dense(original.comp_mut(1), &[3.0]);

        let fresh = original.make_new();
        let compound = fresh.as_any().downcast_ref::<CompoundVector>().unwrap();
        assert_eq!(compound.n_comps(), 2);
        assert_eq!(compound.comp(0).dim(), 2);
        assert_eq!(compound.comp(1).dim(), 1);
    }
}