Skip to main content

pounce_linalg/
matrix.rs

1//! Matrix trait + cache machinery.
2//!
3//! Mirrors `LinAlg/IpMatrix.{hpp,cpp}` and `LinAlg/IpSymMatrix.hpp`.
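//! Like [`crate::vector::Vector`], the trait splits the public BLAS-2 /
//! reduction entry points (which add the bookkeeping — zero-initialising
//! the amax reductions and memoising `has_valid_numbers` against the
//! change tag) from the `*_impl` kernels that concrete matrix types
//! override.
//!
//! `SymMatrix` is a refinement of `Matrix` — concrete symmetric types
//! must impl both. Upstream gives `SymMatrix` a default
//! `TransMultVectorImpl` that calls `MultVector`. A Rust subtrait cannot
//! override its supertrait's default methods, so we provide the same
//! behaviour through the free helpers
//! [`sym_default_trans_mult_vector_impl`] and
//! [`sym_default_compute_col_amax_impl`]. Concrete symmetric types call
//! them from their own `Matrix::*_impl` bodies.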
13//!
14//! Print is deferred to the iteration-output phase; the public method
15//! is intentionally absent here so we don't drag the Journalist into
16//! every concrete matrix in Phase 2.
17
18use crate::vector::Vector;
19use pounce_common::tagged::{Tag, TaggedCell, TaggedObject};
20use pounce_common::types::{Index, Number};
21use std::any::Any;
22use std::cell::Cell;
23use std::fmt::Debug;
24
25/// Cached `valid_numbers` bit + change tag, embedded by every concrete
26/// matrix type. Mirrors `Matrix::valid_cache_tag_` / `cached_valid_`.
27#[derive(Debug)]
28pub struct MatrixCache {
29    tag: TaggedCell,
30    valid: Cell<Option<(Tag, bool)>>,
31}
32
33impl Default for MatrixCache {
34    fn default() -> Self {
35        Self::new()
36    }
37}
38
39impl MatrixCache {
40    pub fn new() -> Self {
41        Self {
42            tag: TaggedCell::new(),
43            valid: Cell::new(None),
44        }
45    }
46
47    pub fn tag(&self) -> Tag {
48        self.tag.tag()
49    }
50
51    /// Equivalent to `TaggedObject::ObjectChanged()`.
52    pub fn bump(&self) {
53        self.tag.bump();
54    }
55}
56
57/// Matrix trait — full Ipopt `Matrix` API minus printing. Object-safe.
58pub trait Matrix: TaggedObject + Debug + 'static {
59    fn n_rows(&self) -> Index;
60    fn n_cols(&self) -> Index;
61    fn cache(&self) -> &MatrixCache;
62
63    fn as_any(&self) -> &dyn Any;
64    fn as_any_mut(&mut self) -> &mut dyn Any;
65    fn as_tagged(&self) -> &dyn TaggedObject;
66    fn as_dyn_matrix(&self) -> &dyn Matrix;
67
68    // ---- pure-virtual implementations ----
69
70    /// `y ← α · M · x + β · y`.
71    fn mult_vector_impl(&self, alpha: Number, x: &dyn Vector, beta: Number, y: &mut dyn Vector);
72
73    /// `y ← α · Mᵀ · x + β · y`.
74    fn trans_mult_vector_impl(
75        &self,
76        alpha: Number,
77        x: &dyn Vector,
78        beta: Number,
79        y: &mut dyn Vector,
80    );
81
82    /// `rows_norms[i] ← max(rows_norms[i], maxⱼ |M[i,j]|)`. Caller has
83    /// already zeroed `rows_norms` if `init`.
84    fn compute_row_amax_impl(&self, rows_norms: &mut dyn Vector, init: bool);
85
86    /// `cols_norms[j] ← max(cols_norms[j], maxᵢ |M[i,j]|)`. Caller has
87    /// already zeroed `cols_norms` if `init`.
88    fn compute_col_amax_impl(&self, cols_norms: &mut dyn Vector, init: bool);
89
90    // ---- defaultable implementations ----
91
92    /// Default returns true. Concrete matrices override when they hold
93    /// floating-point storage that may go NaN/Inf.
94    fn has_valid_numbers_impl(&self) -> bool {
95        true
96    }
97
98    /// `X = X + α · M · S⁻¹ · Z`. Default: build `tmp = Z./S`, then
99    /// `MultVector(α, tmp, 1, X)`. Override for ExpansionMatrix etc.
100    fn add_m_sinv_z_impl(&self, alpha: Number, s: &dyn Vector, z: &dyn Vector, x: &mut dyn Vector) {
101        let mut tmp = s.make_new_copy();
102        // tmp ← (1)*Z/S + (0)*tmp  ≡  tmp = Z/S
103        tmp.set(0.0);
104        tmp.add_vector_quotient(1.0, z, s, 0.0);
105        self.mult_vector(alpha, tmp.as_dyn_vector(), 1.0, x);
106    }
107
108    /// `X = S⁻¹ · (R + α · Z · Mᵀ · D)`. Default per upstream
109    /// `Matrix::SinvBlrmZMTdBrImpl`.
110    fn sinv_blrm_zmt_dbr_impl(
111        &self,
112        alpha: Number,
113        s: &dyn Vector,
114        r: &dyn Vector,
115        z: &dyn Vector,
116        d: &dyn Vector,
117        x: &mut dyn Vector,
118    ) {
119        self.trans_mult_vector(alpha, d, 0.0, x);
120        x.element_wise_multiply(z);
121        x.axpy(1.0, r);
122        x.element_wise_divide(s);
123    }
124
125    // ---- public API (cache-aware wrappers) ----
126
127    fn mult_vector(&self, alpha: Number, x: &dyn Vector, beta: Number, y: &mut dyn Vector) {
128        self.mult_vector_impl(alpha, x, beta, y);
129    }
130
131    fn trans_mult_vector(&self, alpha: Number, x: &dyn Vector, beta: Number, y: &mut dyn Vector) {
132        self.trans_mult_vector_impl(alpha, x, beta, y);
133    }
134
135    fn compute_row_amax(&self, rows_norms: &mut dyn Vector, init: bool) {
136        if init {
137            rows_norms.set(0.0);
138        }
139        self.compute_row_amax_impl(rows_norms, init);
140    }
141
142    fn compute_col_amax(&self, cols_norms: &mut dyn Vector, init: bool) {
143        if init {
144            cols_norms.set(0.0);
145        }
146        self.compute_col_amax_impl(cols_norms, init);
147    }
148
149    fn add_m_sinv_z(&self, alpha: Number, s: &dyn Vector, z: &dyn Vector, x: &mut dyn Vector) {
150        self.add_m_sinv_z_impl(alpha, s, z, x);
151    }
152
153    fn sinv_blrm_zmt_dbr(
154        &self,
155        alpha: Number,
156        s: &dyn Vector,
157        r: &dyn Vector,
158        z: &dyn Vector,
159        d: &dyn Vector,
160        x: &mut dyn Vector,
161    ) {
162        self.sinv_blrm_zmt_dbr_impl(alpha, s, r, z, d, x);
163    }
164
165    fn has_valid_numbers(&self) -> bool {
166        let cur = self.cache().tag();
167        if let Some((t, v)) = self.cache().valid.get() {
168            if t == cur {
169                return v;
170            }
171        }
172        let v = self.has_valid_numbers_impl();
173        self.cache().valid.set(Some((cur, v)));
174        v
175    }
176}
177
178/// Symmetric refinement of [`Matrix`]. A concrete symmetric matrix
179/// implements both [`Matrix`] and [`SymMatrix`]; convenience helpers
180/// for the symmetric-derived overrides are provided as free functions
181/// `sym_default_*` which concrete impls can call from their own
182/// `Matrix::*_impl` bodies.
183pub trait SymMatrix: Matrix {
184    /// `dim()` is identical to `n_rows()` and `n_cols()` by symmetry,
185    /// but mirroring upstream's `SymMatrix::Dim()` accessor is helpful
186    /// for clarity at call sites.
187    fn dim(&self) -> Index {
188        debug_assert_eq!(self.n_rows(), self.n_cols());
189        self.n_rows()
190    }
191}
192
193/// Helper that concrete symmetric matrices may forward to from
194/// [`Matrix::trans_mult_vector_impl`] — exactly what upstream
195/// `SymMatrix::TransMultVectorImpl` does.
196#[inline]
197pub fn sym_default_trans_mult_vector_impl<M: Matrix + ?Sized>(
198    m: &M,
199    alpha: Number,
200    x: &dyn Vector,
201    beta: Number,
202    y: &mut dyn Vector,
203) {
204    m.mult_vector_impl(alpha, x, beta, y);
205}
206
207/// Helper that concrete symmetric matrices may forward to from
208/// [`Matrix::compute_col_amax_impl`] — exactly what upstream
209/// `SymMatrix::ComputeColAMaxImpl` does (row==col by symmetry).
210#[inline]
211pub fn sym_default_compute_col_amax_impl<M: Matrix + ?Sized>(
212    m: &M,
213    cols_norms: &mut dyn Vector,
214    init: bool,
215) {
216    m.compute_row_amax_impl(cols_norms, init);
217}