Skip to main content

pounce_linalg/
compound_vector.rs

1//! Block vector — port of `LinAlg/IpCompoundVector.{hpp,cpp}`.
2//!
3//! Stacks zero-or-more component vectors into a single virtual vector.
4//! Every operation dispatches block-by-block in the order the
5//! components were registered, matching upstream's iteration order
6//! exactly so that summed reductions (`nrm2`, `dot`, `sum`) preserve
7//! bit-equivalence under the same component layout.
8//!
9//! Component construction uses a `Vec` of factory closures rather
10//! than a `VectorSpace` trait — see the docstring on
11//! [`CompoundVectorSpace::set_comp`] for usage.
12
13use crate::vector::{Vector, VectorCache};
14use pounce_common::tagged::{Tag, TaggedObject};
15use pounce_common::types::{Index, Number};
16use std::any::Any;
17use std::cell::RefCell;
18use std::rc::Rc;
19
20type CompFactory = Box<dyn Fn() -> Box<dyn Vector>>;
21
22/// Vector space describing the block layout. Constructed by
23/// [`CompoundVectorSpace::new`], populated with [`set_comp`], then
24/// passed to [`CompoundVector::new`] to create new compound vectors.
25pub struct CompoundVectorSpace {
26    total_dim: Index,
27    n_comp_spaces: Index,
28    /// Component dimensions; `Index::MIN` until set.
29    comp_dims: RefCell<Vec<Index>>,
30    factories: RefCell<Vec<Option<CompFactory>>>,
31}
32
33impl std::fmt::Debug for CompoundVectorSpace {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        f.debug_struct("CompoundVectorSpace")
36            .field("total_dim", &self.total_dim)
37            .field("n_comp_spaces", &self.n_comp_spaces)
38            .field("comp_dims", &self.comp_dims.borrow())
39            .finish()
40    }
41}
42
43impl CompoundVectorSpace {
44    pub fn new(n_comp_spaces: Index, total_dim: Index) -> Rc<Self> {
45        let mut factories: Vec<Option<CompFactory>> = Vec::with_capacity(n_comp_spaces as usize);
46        for _ in 0..n_comp_spaces {
47            factories.push(None);
48        }
49        Rc::new(Self {
50            total_dim,
51            n_comp_spaces,
52            comp_dims: RefCell::new(vec![0; n_comp_spaces as usize]),
53            factories: RefCell::new(factories),
54        })
55    }
56
57    pub fn dim(&self) -> Index {
58        self.total_dim
59    }
60
61    pub fn n_comp_spaces(&self) -> Index {
62        self.n_comp_spaces
63    }
64
65    pub fn comp_dim(&self, icomp: Index) -> Index {
66        self.comp_dims.borrow()[icomp as usize]
67    }
68
69    /// Register the factory that builds a fresh component at slot
70    /// `icomp`. The factory closure must capture (typically by `Rc`)
71    /// any subspace it needs. Mirrors upstream
72    /// `CompoundVectorSpace::SetCompSpace`.
73    pub fn set_comp<F>(&self, icomp: Index, dim: Index, factory: F)
74    where
75        F: Fn() -> Box<dyn Vector> + 'static,
76    {
77        assert!(icomp < self.n_comp_spaces);
78        self.comp_dims.borrow_mut()[icomp as usize] = dim;
79        self.factories.borrow_mut()[icomp as usize] = Some(Box::new(factory));
80    }
81}
82
83/// Compound (block) vector. Owns its components.
84#[derive(Debug)]
85pub struct CompoundVector {
86    space: Rc<CompoundVectorSpace>,
87    cache: VectorCache,
88    comps: Vec<Box<dyn Vector>>,
89}
90
91impl CompoundVector {
92    /// Construct, calling each registered factory once. Equivalent to
93    /// upstream `CompoundVector(owner_space, /*create_new=*/true)`.
94    pub fn new(space: Rc<CompoundVectorSpace>) -> Self {
95        let n = space.n_comp_spaces() as usize;
96        let mut comps: Vec<Box<dyn Vector>> = Vec::with_capacity(n);
97        let factories = space.factories.borrow();
98        let mut dim_check: Index = 0;
99        for f in factories.iter() {
100            let factory = match f.as_ref() {
101                Some(fac) => fac,
102                None => panic!("CompoundVectorSpace component not set — call set_comp on every component before constructing a CompoundVector"),
103            };
104            let v = factory();
105            dim_check += v.dim();
106            comps.push(v);
107        }
108        debug_assert_eq!(dim_check, space.total_dim);
109        drop(factories);
110        Self {
111            space,
112            cache: VectorCache::new(),
113            comps,
114        }
115    }
116
117    pub fn n_comps(&self) -> Index {
118        self.comps.len() as Index
119    }
120
121    pub fn comp(&self, i: Index) -> &dyn Vector {
122        self.comps[i as usize].as_ref()
123    }
124
125    /// Mutable access to a component. Marks the compound as changed,
126    /// matching upstream's `GetCompNonConst` which calls
127    /// `ObjectChanged()` because the caller is about to mutate.
128    pub fn comp_mut(&mut self, i: Index) -> &mut dyn Vector {
129        self.cache.bump();
130        self.comps[i as usize].as_mut()
131    }
132
133    pub fn space(&self) -> &Rc<CompoundVectorSpace> {
134        &self.space
135    }
136}
137
138fn downcast_compound(x: &dyn Vector) -> &CompoundVector {
139    match x.as_any().downcast_ref::<CompoundVector>() {
140        Some(v) => v,
141        None => panic!("Vector argument is not a CompoundVector"),
142    }
143}
144
145impl TaggedObject for CompoundVector {
146    fn get_tag(&self) -> Tag {
147        self.cache.tag()
148    }
149}
150
151impl Vector for CompoundVector {
152    fn dim(&self) -> Index {
153        self.space.total_dim
154    }
155
156    fn cache(&self) -> &VectorCache {
157        &self.cache
158    }
159
160    fn make_new(&self) -> Box<dyn Vector> {
161        Box::new(CompoundVector::new(Rc::clone(&self.space)))
162    }
163
164    fn as_any(&self) -> &dyn Any {
165        self
166    }
167
168    fn as_any_mut(&mut self) -> &mut dyn Any {
169        self
170    }
171
172    fn as_tagged(&self) -> &dyn TaggedObject {
173        self
174    }
175
176    fn as_dyn_vector(&self) -> &dyn Vector {
177        self
178    }
179
180    fn copy_impl(&mut self, x: &dyn Vector) {
181        let cx = downcast_compound(x);
182        debug_assert_eq!(self.n_comps(), cx.n_comps());
183        for i in 0..self.comps.len() {
184            self.comps[i].copy(cx.comps[i].as_ref());
185        }
186    }
187
188    fn scal_impl(&mut self, alpha: Number) {
189        for c in &mut self.comps {
190            c.scal(alpha);
191        }
192    }
193
194    fn axpy_impl(&mut self, alpha: Number, x: &dyn Vector) {
195        let cx = downcast_compound(x);
196        debug_assert_eq!(self.n_comps(), cx.n_comps());
197        for i in 0..self.comps.len() {
198            self.comps[i].axpy(alpha, cx.comps[i].as_ref());
199        }
200    }
201
202    fn dot_impl(&self, x: &dyn Vector) -> Number {
203        let cx = downcast_compound(x);
204        debug_assert_eq!(self.n_comps(), cx.n_comps());
205        let mut s = 0.0;
206        for i in 0..self.comps.len() {
207            s += self.comps[i].dot(cx.comps[i].as_ref());
208        }
209        s
210    }
211
212    fn nrm2_impl(&self) -> Number {
213        let mut sum_sq = 0.0;
214        for c in &self.comps {
215            let n = c.nrm2();
216            sum_sq += n * n;
217        }
218        sum_sq.sqrt()
219    }
220
221    fn asum_impl(&self) -> Number {
222        let mut s = 0.0;
223        for c in &self.comps {
224            s += c.asum();
225        }
226        s
227    }
228
229    fn amax_impl(&self) -> Number {
230        let mut m: Number = 0.0;
231        for c in &self.comps {
232            let v = c.amax();
233            if v > m {
234                m = v;
235            }
236        }
237        m
238    }
239
240    fn set_impl(&mut self, value: Number) {
241        for c in &mut self.comps {
242            c.set(value);
243        }
244    }
245
246    fn element_wise_divide_impl(&mut self, x: &dyn Vector) {
247        let cx = downcast_compound(x);
248        for i in 0..self.comps.len() {
249            self.comps[i].element_wise_divide(cx.comps[i].as_ref());
250        }
251    }
252    fn element_wise_multiply_impl(&mut self, x: &dyn Vector) {
253        let cx = downcast_compound(x);
254        for i in 0..self.comps.len() {
255            self.comps[i].element_wise_multiply(cx.comps[i].as_ref());
256        }
257    }
258    fn element_wise_select_impl(&mut self, x: &dyn Vector) {
259        let cx = downcast_compound(x);
260        for i in 0..self.comps.len() {
261            self.comps[i].element_wise_select(cx.comps[i].as_ref());
262        }
263    }
264    fn element_wise_max_impl(&mut self, x: &dyn Vector) {
265        let cx = downcast_compound(x);
266        for i in 0..self.comps.len() {
267            self.comps[i].element_wise_max(cx.comps[i].as_ref());
268        }
269    }
270    fn element_wise_min_impl(&mut self, x: &dyn Vector) {
271        let cx = downcast_compound(x);
272        for i in 0..self.comps.len() {
273            self.comps[i].element_wise_min(cx.comps[i].as_ref());
274        }
275    }
276    fn element_wise_reciprocal_impl(&mut self) {
277        for c in &mut self.comps {
278            c.element_wise_reciprocal();
279        }
280    }
281    fn element_wise_abs_impl(&mut self) {
282        for c in &mut self.comps {
283            c.element_wise_abs();
284        }
285    }
286    fn element_wise_sqrt_impl(&mut self) {
287        for c in &mut self.comps {
288            c.element_wise_sqrt();
289        }
290    }
291    fn element_wise_sgn_impl(&mut self) {
292        for c in &mut self.comps {
293            c.element_wise_sgn();
294        }
295    }
296    fn add_scalar_impl(&mut self, scalar: Number) {
297        for c in &mut self.comps {
298            c.add_scalar(scalar);
299        }
300    }
301
302    fn max_impl(&self) -> Number {
303        debug_assert!(!self.comps.is_empty() && self.dim() > 0);
304        let mut m = -Number::MAX;
305        for c in &self.comps {
306            if c.dim() != 0 {
307                let v = c.max();
308                if v > m {
309                    m = v;
310                }
311            }
312        }
313        m
314    }
315
316    fn min_impl(&self) -> Number {
317        debug_assert!(!self.comps.is_empty() && self.dim() > 0);
318        let mut m = Number::MAX;
319        for c in &self.comps {
320            if c.dim() != 0 {
321                let v = c.min();
322                if v < m {
323                    m = v;
324                }
325            }
326        }
327        m
328    }
329
330    fn sum_impl(&self) -> Number {
331        let mut s = 0.0;
332        for c in &self.comps {
333            s += c.sum();
334        }
335        s
336    }
337
338    fn sum_logs_impl(&self) -> Number {
339        let mut s = 0.0;
340        for c in &self.comps {
341            s += c.sum_logs();
342        }
343        s
344    }
345
346    fn add_two_vectors_impl(
347        &mut self,
348        a: Number,
349        v1: &dyn Vector,
350        b: Number,
351        v2: &dyn Vector,
352        c: Number,
353    ) {
354        let cv1 = downcast_compound(v1);
355        let cv2 = downcast_compound(v2);
356        debug_assert_eq!(self.n_comps(), cv1.n_comps());
357        debug_assert_eq!(self.n_comps(), cv2.n_comps());
358        for i in 0..self.comps.len() {
359            self.comps[i].add_two_vectors(a, cv1.comps[i].as_ref(), b, cv2.comps[i].as_ref(), c);
360        }
361    }
362
363    fn frac_to_bound_impl(&self, delta: &dyn Vector, tau: Number) -> Number {
364        let cd = downcast_compound(delta);
365        debug_assert_eq!(self.n_comps(), cd.n_comps());
366        let mut alpha: Number = 1.0;
367        for i in 0..self.comps.len() {
368            let a = self.comps[i].frac_to_bound(cd.comps[i].as_ref(), tau);
369            if a < alpha {
370                alpha = a;
371            }
372        }
373        alpha
374    }
375
376    fn add_vector_quotient_impl(&mut self, a: Number, z: &dyn Vector, s: &dyn Vector, c: Number) {
377        let cz = downcast_compound(z);
378        let cs = downcast_compound(s);
379        for i in 0..self.comps.len() {
380            self.comps[i].add_vector_quotient(a, cz.comps[i].as_ref(), cs.comps[i].as_ref(), c);
381        }
382    }
383
384    fn has_valid_numbers_impl(&self) -> bool {
385        for c in &self.comps {
386            if !c.has_valid_numbers() {
387                return false;
388            }
389        }
390        true
391    }
392}
393
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dense_vector::{DenseVector, DenseVectorSpace};

    /// Build a two-block space whose blocks are dense vectors of
    /// dimension `d1` and `d2`.
    fn make_2block_space(d1: Index, d2: Index) -> Rc<CompoundVectorSpace> {
        let space = CompoundVectorSpace::new(2, d1 + d2);
        let s1 = DenseVectorSpace::new(d1);
        let s2 = DenseVectorSpace::new(d2);
        space.set_comp(0, d1, {
            let s = Rc::clone(&s1);
            move || Box::new(DenseVector::new(Rc::clone(&s)))
        });
        space.set_comp(1, d2, {
            let s = Rc::clone(&s2);
            move || Box::new(DenseVector::new(Rc::clone(&s)))
        });
        space
    }

    /// Overwrite the entries of a component that is known to be dense.
    fn fill_dense(v: &mut dyn Vector, vals: &[Number]) {
        let dv = v
            .as_any_mut()
            .downcast_mut::<DenseVector>()
            .expect("DenseVector");
        dv.set_values(vals);
    }

    #[test]
    fn nrm2_combines_blocks() {
        let space = make_2block_space(2, 3);
        let mut v = CompoundVector::new(space);
        fill_dense(v.comp_mut(0), &[3.0, 4.0]); // block nrm2 = 5
        fill_dense(v.comp_mut(1), &[0.0, 0.0, 12.0]); // block nrm2 = 12
        // sqrt(25 + 144) = 13
        assert!((v.nrm2() - 13.0).abs() < 1e-15);
    }

    #[test]
    fn dot_routes_to_blocks() {
        let space = make_2block_space(2, 2);
        let mut x = CompoundVector::new(Rc::clone(&space));
        fill_dense(x.comp_mut(0), &[1.0, 2.0]);
        fill_dense(x.comp_mut(1), &[3.0, 4.0]);
        let mut y = CompoundVector::new(Rc::clone(&space));
        fill_dense(y.comp_mut(0), &[10.0, 20.0]);
        fill_dense(y.comp_mut(1), &[100.0, 1000.0]);
        // 1*10 + 2*20 + 3*100 + 4*1000 = 10 + 40 + 300 + 4000 = 4350
        assert_eq!(x.dot(&y), 4350.0);
    }

    #[test]
    fn axpy_propagates_to_blocks() {
        let space = make_2block_space(2, 1);
        let mut x = CompoundVector::new(Rc::clone(&space));
        fill_dense(x.comp_mut(0), &[1.0, 1.0]);
        fill_dense(x.comp_mut(1), &[1.0]);
        let mut y = CompoundVector::new(Rc::clone(&space));
        fill_dense(y.comp_mut(0), &[10.0, 20.0]);
        fill_dense(y.comp_mut(1), &[30.0]);
        y.axpy(2.0, &x);
        let dy0 = y.comp(0).as_any().downcast_ref::<DenseVector>().unwrap();
        let dy1 = y.comp(1).as_any().downcast_ref::<DenseVector>().unwrap();
        assert_eq!(dy0.values(), &[12.0, 22.0]);
        assert_eq!(dy1.values(), &[32.0]);
    }

    #[test]
    fn asum_sums_block_asums() {
        let space = make_2block_space(2, 2);
        let mut x = CompoundVector::new(space);
        fill_dense(x.comp_mut(0), &[-1.0, 2.0]); // asum = 3
        fill_dense(x.comp_mut(1), &[3.0, -4.0]); // asum = 7
        assert_eq!(x.asum(), 10.0);
    }

    #[test]
    fn amax_takes_max_across_blocks() {
        let space = make_2block_space(2, 3);
        let mut x = CompoundVector::new(space);
        fill_dense(x.comp_mut(0), &[1.0, -2.0]);
        fill_dense(x.comp_mut(1), &[0.5, -10.0, 3.0]);
        assert_eq!(x.amax(), 10.0);
    }

    #[test]
    fn sum_adds_block_sums() {
        let space = make_2block_space(2, 1);
        let mut x = CompoundVector::new(space);
        fill_dense(x.comp_mut(0), &[1.0, 2.0]);
        fill_dense(x.comp_mut(1), &[3.0]);
        assert_eq!(Vector::sum(&x), 6.0);
    }

    #[test]
    fn max_and_min_skip_empty_blocks() {
        // Block 0 has dimension 0 and must not take part in max/min.
        let space = make_2block_space(0, 3);
        let mut x = CompoundVector::new(space);
        fill_dense(x.comp_mut(1), &[2.0, -5.0, 7.0]);
        assert_eq!(Vector::max(&x), 7.0);
        assert_eq!(Vector::min(&x), -5.0);
    }

    #[test]
    fn element_wise_multiply_is_blockwise() {
        let space = make_2block_space(2, 1);
        let mut x = CompoundVector::new(Rc::clone(&space));
        fill_dense(x.comp_mut(0), &[1.0, 2.0]);
        fill_dense(x.comp_mut(1), &[3.0]);
        let mut y = CompoundVector::new(space);
        fill_dense(y.comp_mut(0), &[2.0, 3.0]);
        fill_dense(y.comp_mut(1), &[4.0]);
        x.element_wise_multiply(&y);
        let dx0 = x.comp(0).as_any().downcast_ref::<DenseVector>().unwrap();
        let dx1 = x.comp(1).as_any().downcast_ref::<DenseVector>().unwrap();
        assert_eq!(dx0.values(), &[2.0, 6.0]);
        assert_eq!(dx1.values(), &[12.0]);
    }

    #[test]
    #[should_panic]
    fn copy_rejects_mismatched_block_count() {
        let two_blocks = make_2block_space(1, 1);
        let one_block = CompoundVectorSpace::new(1, 2);
        let s = DenseVectorSpace::new(2);
        one_block.set_comp(0, 2, move || Box::new(DenseVector::new(Rc::clone(&s))));
        let mut dst = CompoundVector::new(two_blocks);
        let src = CompoundVector::new(one_block);
        dst.copy(&src);
    }

    #[test]
    fn frac_to_bound_takes_min_across_blocks() {
        let space = make_2block_space(2, 1);
        let mut x = CompoundVector::new(Rc::clone(&space));
        fill_dense(x.comp_mut(0), &[1.0, 2.0]);
        fill_dense(x.comp_mut(1), &[3.0]);
        let mut delta = CompoundVector::new(space);
        // Block 0: only the first entry moves toward the bound: tau * 1 / 2 = 0.5 * tau.
        fill_dense(delta.comp_mut(0), &[-2.0, 0.0]);
        // Block 1: tau * 3 / 1.5 = 2 * tau.
        fill_dense(delta.comp_mut(1), &[-1.5]);
        // min(0.5, 2 * tau) for tau = 1 -> 0.5
        let alpha = x.frac_to_bound(&delta, 1.0);
        assert!((alpha - 0.5).abs() < 1e-15);
    }

    #[test]
    fn make_new_creates_uninitialized_compound() {
        let space = make_2block_space(2, 1);
        let mut x = CompoundVector::new(Rc::clone(&space));
        fill_dense(x.comp_mut(0), &[1.0, 2.0]);
        fill_dense(x.comp_mut(1), &[3.0]);
        let y = x.make_new();
        let cy = y.as_any().downcast_ref::<CompoundVector>().unwrap();
        assert_eq!(cy.n_comps(), 2);
        assert_eq!(cy.comp(0).dim(), 2);
        assert_eq!(cy.comp(1).dim(), 1);
    }
}
505}