Skip to main content

pounce_linalg/
vector.rs

//! Vector trait + cache machinery.
//!
//! Mirrors `LinAlg/IpVector.{hpp,cpp}`. The trait separates the public BLAS-1
//! and element-wise routines (which manage the change tag and the cached
//! scalars) from the `*_impl` methods that concrete vector types implement.
//! Concrete implementations must **never** bump the tag from inside their
//! own `*_impl` bodies — only the public wrappers do that, exactly mirroring
//! upstream's split between `Vector::Foo` and `Vector::FooImpl`.
//!
10
11use pounce_common::cached::Cache;
12use pounce_common::tagged::{Tag, TaggedCell, TaggedObject};
13use pounce_common::types::{Index, Number};
14use std::any::Any;
15use std::cell::{Cell, RefCell};
16use std::fmt::Debug;
17
18/// Cached scalar reductions + dot cache + change tag, embedded by
19/// every concrete vector type. Mirrors the mutable members of
20/// upstream `Vector` (the `cached_*` fields and `dot_cache_`).
21#[derive(Debug)]
22pub struct VectorCache {
23    tag: TaggedCell,
24    nrm2: Cell<Option<(Tag, Number)>>,
25    asum: Cell<Option<(Tag, Number)>>,
26    amax: Cell<Option<(Tag, Number)>>,
27    max: Cell<Option<(Tag, Number)>>,
28    min: Cell<Option<(Tag, Number)>>,
29    sum: Cell<Option<(Tag, Number)>>,
30    sum_logs: Cell<Option<(Tag, Number)>>,
31    valid: Cell<Option<(Tag, bool)>>,
32    dot: RefCell<Cache<Number>>,
33}
34
35impl Default for VectorCache {
36    fn default() -> Self {
37        Self::new()
38    }
39}
40
41impl VectorCache {
42    /// Upstream `Vector` constructs `dot_cache_(10)`.
43    pub fn new() -> Self {
44        Self {
45            tag: TaggedCell::new(),
46            nrm2: Cell::new(None),
47            asum: Cell::new(None),
48            amax: Cell::new(None),
49            max: Cell::new(None),
50            min: Cell::new(None),
51            sum: Cell::new(None),
52            sum_logs: Cell::new(None),
53            valid: Cell::new(None),
54            dot: RefCell::new(Cache::new(10)),
55        }
56    }
57
58    pub fn tag(&self) -> Tag {
59        self.tag.tag()
60    }
61
62    /// Equivalent to `TaggedObject::ObjectChanged()`.
63    pub fn bump(&self) {
64        self.tag.bump();
65    }
66}
67
68/// Vector trait — full Ipopt `Vector` API. Object-safe.
69pub trait Vector: TaggedObject + Debug + 'static {
70    fn dim(&self) -> Index;
71    fn cache(&self) -> &VectorCache;
72
73    /// Create a new uninitialized vector belonging to the same
74    /// `VectorSpace`. Equivalent to `Vector::MakeNew`.
75    fn make_new(&self) -> Box<dyn Vector>;
76
77    fn as_any(&self) -> &dyn Any;
78    fn as_any_mut(&mut self) -> &mut dyn Any;
79    fn as_tagged(&self) -> &dyn TaggedObject;
80    fn as_dyn_vector(&self) -> &dyn Vector;
81
82    // ---- pure-virtual implementations ----
83
84    fn copy_impl(&mut self, x: &dyn Vector);
85    fn scal_impl(&mut self, alpha: Number);
86    fn axpy_impl(&mut self, alpha: Number, x: &dyn Vector);
87    fn dot_impl(&self, x: &dyn Vector) -> Number;
88    fn nrm2_impl(&self) -> Number;
89    fn asum_impl(&self) -> Number;
90    fn amax_impl(&self) -> Number;
91    fn set_impl(&mut self, alpha: Number);
92    fn element_wise_divide_impl(&mut self, x: &dyn Vector);
93    fn element_wise_multiply_impl(&mut self, x: &dyn Vector);
94    fn element_wise_select_impl(&mut self, x: &dyn Vector);
95    fn element_wise_max_impl(&mut self, x: &dyn Vector);
96    fn element_wise_min_impl(&mut self, x: &dyn Vector);
97    fn element_wise_reciprocal_impl(&mut self);
98    fn element_wise_abs_impl(&mut self);
99    fn element_wise_sqrt_impl(&mut self);
100    fn element_wise_sgn_impl(&mut self);
101    fn add_scalar_impl(&mut self, scalar: Number);
102    fn max_impl(&self) -> Number;
103    fn min_impl(&self) -> Number;
104    fn sum_impl(&self) -> Number;
105    fn sum_logs_impl(&self) -> Number;
106    fn frac_to_bound_impl(&self, delta: &dyn Vector, tau: Number) -> Number;
107    fn add_vector_quotient_impl(&mut self, a: Number, z: &dyn Vector, s: &dyn Vector, c: Number);
108
109    // ---- defaultable implementations ----
110
111    /// Default fallback. Concrete types override for efficiency, but
112    /// the result must remain bit-identical to upstream's
113    /// `DenseVector::AddTwoVectorsImpl` decision tree.
114    fn add_two_vectors_impl(
115        &mut self,
116        a: Number,
117        v1: &dyn Vector,
118        b: Number,
119        v2: &dyn Vector,
120        c: Number,
121    ) {
122        if c == 0.0 {
123            self.set_impl(0.0);
124        } else if c != 1.0 {
125            self.scal_impl(c);
126        }
127        if a != 0.0 {
128            self.axpy_impl(a, v1);
129        }
130        if b != 0.0 {
131            self.axpy_impl(b, v2);
132        }
133    }
134
135    /// Default uses `Asum` finiteness — matches upstream
136    /// `Vector::HasValidNumbersImpl`.
137    fn has_valid_numbers_impl(&self) -> bool {
138        self.asum_impl().is_finite()
139    }
140
141    // ---- public API (cache-aware wrappers) ----
142
143    fn copy(&mut self, x: &dyn Vector) {
144        self.copy_impl(x);
145        self.cache().bump();
146    }
147
148    fn make_new_copy(&self) -> Box<dyn Vector> {
149        let mut c = self.make_new();
150        c.copy(self.as_dyn_vector());
151        c
152    }
153
154    fn scal(&mut self, alpha: Number) {
155        self.scal_impl(alpha);
156        self.cache().bump();
157    }
158
159    fn axpy(&mut self, alpha: Number, x: &dyn Vector) {
160        self.axpy_impl(alpha, x);
161        self.cache().bump();
162    }
163
164    fn dot(&self, x: &dyn Vector) -> Number {
165        // Same-vector shortcut — upstream uses `this == &x` with the
166        // explanation that the cache cannot key on a self-dependency.
167        // We compare cache addresses (each Vector owns a unique
168        // VectorCache at a stable address).
169        if std::ptr::eq(self.cache() as *const _, x.cache() as *const _) {
170            let n = self.nrm2();
171            return n * n;
172        }
173        let mut dot_cache = self.cache().dot.borrow_mut();
174        if let Some(v) = dot_cache.get(&[self.as_tagged(), x.as_tagged()], &[]) {
175            return v;
176        }
177        let v = self.dot_impl(x);
178        dot_cache.add(v, &[self.as_tagged(), x.as_tagged()], &[]);
179        v
180    }
181
182    fn nrm2(&self) -> Number {
183        let cur = self.cache().tag();
184        if let Some((t, v)) = self.cache().nrm2.get() {
185            if t == cur {
186                return v;
187            }
188        }
189        let v = self.nrm2_impl();
190        self.cache().nrm2.set(Some((cur, v)));
191        v
192    }
193
194    fn asum(&self) -> Number {
195        let cur = self.cache().tag();
196        if let Some((t, v)) = self.cache().asum.get() {
197            if t == cur {
198                return v;
199            }
200        }
201        let v = self.asum_impl();
202        self.cache().asum.set(Some((cur, v)));
203        v
204    }
205
206    fn amax(&self) -> Number {
207        let cur = self.cache().tag();
208        if let Some((t, v)) = self.cache().amax.get() {
209            if t == cur {
210                return v;
211            }
212        }
213        let v = self.amax_impl();
214        self.cache().amax.set(Some((cur, v)));
215        v
216    }
217
218    fn set(&mut self, alpha: Number) {
219        self.set_impl(alpha);
220        self.cache().bump();
221    }
222
223    fn element_wise_divide(&mut self, x: &dyn Vector) {
224        self.element_wise_divide_impl(x);
225        self.cache().bump();
226    }
227    fn element_wise_multiply(&mut self, x: &dyn Vector) {
228        self.element_wise_multiply_impl(x);
229        self.cache().bump();
230    }
231    fn element_wise_select(&mut self, x: &dyn Vector) {
232        self.element_wise_select_impl(x);
233        self.cache().bump();
234    }
235    fn element_wise_max(&mut self, x: &dyn Vector) {
236        self.element_wise_max_impl(x);
237        self.cache().bump();
238    }
239    fn element_wise_min(&mut self, x: &dyn Vector) {
240        self.element_wise_min_impl(x);
241        self.cache().bump();
242    }
243    fn element_wise_reciprocal(&mut self) {
244        self.element_wise_reciprocal_impl();
245        self.cache().bump();
246    }
247    fn element_wise_abs(&mut self) {
248        self.element_wise_abs_impl();
249        self.cache().bump();
250    }
251    fn element_wise_sqrt(&mut self) {
252        self.element_wise_sqrt_impl();
253        self.cache().bump();
254    }
255    fn element_wise_sgn(&mut self) {
256        self.element_wise_sgn_impl();
257        self.cache().bump();
258    }
259    fn add_scalar(&mut self, scalar: Number) {
260        self.add_scalar_impl(scalar);
261        self.cache().bump();
262    }
263
264    fn max(&self) -> Number {
265        let cur = self.cache().tag();
266        if let Some((t, v)) = self.cache().max.get() {
267            if t == cur {
268                return v;
269            }
270        }
271        let v = self.max_impl();
272        self.cache().max.set(Some((cur, v)));
273        v
274    }
275
276    fn min(&self) -> Number {
277        let cur = self.cache().tag();
278        if let Some((t, v)) = self.cache().min.get() {
279            if t == cur {
280                return v;
281            }
282        }
283        let v = self.min_impl();
284        self.cache().min.set(Some((cur, v)));
285        v
286    }
287
288    fn sum(&self) -> Number {
289        let cur = self.cache().tag();
290        if let Some((t, v)) = self.cache().sum.get() {
291            if t == cur {
292                return v;
293            }
294        }
295        let v = self.sum_impl();
296        self.cache().sum.set(Some((cur, v)));
297        v
298    }
299
300    fn sum_logs(&self) -> Number {
301        let cur = self.cache().tag();
302        if let Some((t, v)) = self.cache().sum_logs.get() {
303            if t == cur {
304                return v;
305            }
306        }
307        let v = self.sum_logs_impl();
308        self.cache().sum_logs.set(Some((cur, v)));
309        v
310    }
311
312    fn add_one_vector(&mut self, a: Number, v1: &dyn Vector, c: Number) {
313        // Upstream: AddTwoVectors(a, v1, 0., v1, c).
314        self.add_two_vectors(a, v1, 0.0, v1, c);
315    }
316
317    fn add_two_vectors(
318        &mut self,
319        a: Number,
320        v1: &dyn Vector,
321        b: Number,
322        v2: &dyn Vector,
323        c: Number,
324    ) {
325        self.add_two_vectors_impl(a, v1, b, v2, c);
326        self.cache().bump();
327    }
328
329    /// No cache (matches upstream comment in `IpVector.hpp:820` —
330    /// caching here interferes with the quality function search).
331    fn frac_to_bound(&self, delta: &dyn Vector, tau: Number) -> Number {
332        self.frac_to_bound_impl(delta, tau)
333    }
334
335    fn add_vector_quotient(&mut self, a: Number, z: &dyn Vector, s: &dyn Vector, c: Number) {
336        self.add_vector_quotient_impl(a, z, s, c);
337        self.cache().bump();
338    }
339
340    fn has_valid_numbers(&self) -> bool {
341        let cur = self.cache().tag();
342        if let Some((t, v)) = self.cache().valid.get() {
343            if t == cur {
344                return v;
345            }
346        }
347        let v = self.has_valid_numbers_impl();
348        self.cache().valid.set(Some((cur, v)));
349        v
350    }
351}