Skip to main content

arael_sym/
linalg.rs

1use std::collections::HashMap;
2use std::fmt;
3use std::ops;
4use super::{AsVarName, E, constant};
5
6// ============================================================
7// SymVec
8// ============================================================
9
10/// Symbolic vector of expressions.
11///
12/// Supports element-wise arithmetic, dot product, differentiation,
13/// evaluation, and code generation. Indexing is zero-based.
14#[derive(Debug, Clone, PartialEq)]
15pub struct SymVec(pub Vec<E>);
16
17impl SymVec {
18    /// Create a symbolic vector from a list of expressions. Accepts
19    /// any iterable of `Into<E>` so you can pass bare numeric
20    /// literals: `SymVec::new([1.0, 2.0, 3.0])`.
21    pub fn new<I>(elems: I) -> Self
22    where
23        I: IntoIterator,
24        I::Item: Into<E>,
25    {
26        SymVec(elems.into_iter().map(Into::into).collect())
27    }
28
29    /// Return the number of elements.
30    pub fn len(&self) -> usize {
31        self.0.len()
32    }
33
34    /// Return true if the vector has no elements.
35    pub fn is_empty(&self) -> bool {
36        self.0.is_empty()
37    }
38
39    /// Get a reference to the element at index `i`.
40    pub fn get(&self, i: usize) -> &E {
41        &self.0[i]
42    }
43
44    /// Compute the dot product with another symbolic vector.
45    pub fn dot(&self, other: &SymVec) -> E {
46        assert_eq!(self.len(), other.len(), "dot product: length mismatch");
47        let mut terms: Vec<E> = Vec::with_capacity(self.len());
48        for i in 0..self.len() {
49            terms.push(self.0[i].clone() * other.0[i].clone());
50        }
51        terms.into_iter().reduce(|a, b| a + b).unwrap_or_else(|| constant(0.0))
52    }
53
54    /// Differentiate each element with respect to a variable.
55    pub fn diff(&self, var: impl AsVarName) -> SymVec {
56        let v = var.var_name();
57        SymVec(self.0.iter().map(|e| e.diff(v)).collect())
58    }
59
60    /// Evaluate each element numerically given variable bindings.
61    pub fn eval(&self, vars: &HashMap<&str, f64>) -> Result<Vec<f64>, String> {
62        self.0.iter().map(|e| e.eval(vars)).collect()
63    }
64
65    /// Simplify each element.
66    pub fn simplify(&self) -> SymVec {
67        SymVec(self.0.iter().map(|e| e.simplify()).collect())
68    }
69
70    /// Expand each element (distribute products over sums).
71    pub fn expand(&self) -> SymVec {
72        SymVec(self.0.iter().map(|e| e.expand()).collect())
73    }
74
75    /// Substitute a variable in each element.
76    pub fn subs(&self, var: impl AsVarName, replacement: &E) -> SymVec {
77        let name = var.var_name();
78        SymVec(self.0.iter().map(|e| e.subs(name, replacement)).collect())
79    }
80
81    /// Format the vector as a LaTeX column vector (pmatrix).
82    pub fn to_latex(&self) -> String {
83        let mut buf = String::from("\\begin{pmatrix} ");
84        for (i, e) in self.0.iter().enumerate() {
85            if i > 0 { buf.push_str(" \\\\ "); }
86            buf.push_str(&e.to_latex());
87        }
88        buf.push_str(" \\end{pmatrix}");
89        buf
90    }
91
92    /// Generate Rust source code for the vector as a literal array.
93    pub fn to_rust(&self, ft: &str) -> String {
94        let mut buf = String::from("[");
95        for (i, e) in self.0.iter().enumerate() {
96            if i > 0 { buf.push_str(", "); }
97            buf.push_str(&e.to_rust(ft));
98        }
99        buf.push(']');
100        buf
101    }
102}
103
104impl ops::Index<usize> for SymVec {
105    type Output = E;
106    fn index(&self, i: usize) -> &E {
107        &self.0[i]
108    }
109}
110
111impl ops::Add for SymVec {
112    type Output = SymVec;
113    fn add(self, rhs: SymVec) -> SymVec {
114        assert_eq!(self.len(), rhs.len(), "SymVec add: length mismatch");
115        SymVec(
116            self.0.into_iter().zip(rhs.0)
117                .map(|(a, b)| a + b)
118                .collect()
119        )
120    }
121}
122
123impl ops::Mul<E> for SymVec {
124    type Output = SymVec;
125    fn mul(self, rhs: E) -> SymVec {
126        SymVec(self.0.into_iter().map(|e| e * rhs.clone()).collect())
127    }
128}
129
130impl ops::Mul<SymVec> for E {
131    type Output = SymVec;
132    fn mul(self, rhs: SymVec) -> SymVec {
133        SymVec(rhs.0.into_iter().map(|e| self.clone() * e).collect())
134    }
135}
136
137impl fmt::Display for SymVec {
138    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
139        write!(f, "[")?;
140        for (i, e) in self.0.iter().enumerate() {
141            if i > 0 { write!(f, ", ")?; }
142            fmt::Display::fmt(e, f)?;
143        }
144        write!(f, "]")
145    }
146}
147
148// ============================================================
149// SymMat
150// ============================================================
151
152/// Symbolic matrix of expressions.
153///
154/// Stored in row-major order. Supports transpose, matrix multiplication,
155/// matrix-vector product, element-wise differentiation, and code generation.
156#[derive(Debug, Clone, PartialEq)]
157pub struct SymMat {
158    /// Number of rows.
159    pub rows: usize,
160    /// Number of columns.
161    pub cols: usize,
162    /// Row-major element data.
163    pub data: Vec<E>,
164}
165
166impl SymMat {
167    /// Create a matrix from dimensions and row-major data. Accepts
168    /// any iterable of `Into<E>` so you can pass bare numeric literals:
169    /// `SymMat::new(2, 2, [1.0, 2.0, 3.0, 4.0])`.
170    pub fn new<I>(rows: usize, cols: usize, data: I) -> Self
171    where
172        I: IntoIterator,
173        I::Item: Into<E>,
174    {
175        let data: Vec<E> = data.into_iter().map(Into::into).collect();
176        assert_eq!(data.len(), rows * cols, "SymMat::new: data size mismatch");
177        SymMat { rows, cols, data }
178    }
179
180    /// Create a zero matrix of the given dimensions.
181    pub fn zeros(rows: usize, cols: usize) -> Self {
182        SymMat {
183            rows,
184            cols,
185            data: vec![constant(0.0); rows * cols],
186        }
187    }
188
189    /// Create an n-by-n identity matrix.
190    pub fn identity(n: usize) -> Self {
191        let mut data = vec![constant(0.0); n * n];
192        for i in 0..n {
193            data[i * n + i] = constant(1.0);
194        }
195        SymMat { rows: n, cols: n, data }
196    }
197
198    /// Get a reference to the element at row `i`, column `j`.
199    pub fn get(&self, i: usize, j: usize) -> &E {
200        &self.data[i * self.cols + j]
201    }
202
203    /// Set the element at row `i`, column `j`.
204    pub fn set(&mut self, i: usize, j: usize, val: E) {
205        self.data[i * self.cols + j] = val;
206    }
207
208    /// Return the transpose of this matrix.
209    pub fn transpose(&self) -> SymMat {
210        let mut data = Vec::with_capacity(self.rows * self.cols);
211        for j in 0..self.cols {
212            for i in 0..self.rows {
213                data.push(self.get(i, j).clone());
214            }
215        }
216        SymMat { rows: self.cols, cols: self.rows, data }
217    }
218
219    /// Differentiate every element with respect to a variable.
220    pub fn diff(&self, var: impl AsVarName) -> SymMat {
221        let v = var.var_name();
222        SymMat {
223            rows: self.rows,
224            cols: self.cols,
225            data: self.data.iter().map(|e| e.diff(v)).collect(),
226        }
227    }
228
229    /// Evaluate every element numerically, returning a nested `Vec<Vec<f64>>`.
230    pub fn eval(&self, vars: &HashMap<&str, f64>) -> Result<Vec<Vec<f64>>, String> {
231        let mut result = Vec::with_capacity(self.rows);
232        for i in 0..self.rows {
233            let mut row = Vec::with_capacity(self.cols);
234            for j in 0..self.cols {
235                row.push(self.get(i, j).eval(vars)?);
236            }
237            result.push(row);
238        }
239        Ok(result)
240    }
241
242    /// Simplify every element.
243    pub fn simplify(&self) -> SymMat {
244        SymMat {
245            rows: self.rows,
246            cols: self.cols,
247            data: self.data.iter().map(|e| e.simplify()).collect(),
248        }
249    }
250
251    /// Expand every element (distribute products over sums).
252    pub fn expand(&self) -> SymMat {
253        SymMat {
254            rows: self.rows,
255            cols: self.cols,
256            data: self.data.iter().map(|e| e.expand()).collect(),
257        }
258    }
259
260    /// Substitute a variable in every element.
261    pub fn subs(&self, var: impl AsVarName, replacement: &E) -> SymMat {
262        let name = var.var_name();
263        SymMat {
264            rows: self.rows,
265            cols: self.cols,
266            data: self.data.iter().map(|e| e.subs(name, replacement)).collect(),
267        }
268    }
269
270    /// Format the matrix as a LaTeX pmatrix.
271    pub fn to_latex(&self) -> String {
272        let mut buf = String::from("\\begin{pmatrix} ");
273        for i in 0..self.rows {
274            if i > 0 { buf.push_str(" \\\\ "); }
275            for j in 0..self.cols {
276                if j > 0 { buf.push_str(" & "); }
277                buf.push_str(&self.get(i, j).to_latex());
278            }
279        }
280        buf.push_str(" \\end{pmatrix}");
281        buf
282    }
283
284    /// Generate Rust source code for the matrix as a nested array literal.
285    pub fn to_rust(&self, ft: &str) -> String {
286        let mut buf = String::from("[");
287        for i in 0..self.rows {
288            if i > 0 { buf.push_str(", "); }
289            buf.push('[');
290            for j in 0..self.cols {
291                if j > 0 { buf.push_str(", "); }
292                buf.push_str(&self.get(i, j).to_rust(ft));
293            }
294            buf.push(']');
295        }
296        buf.push(']');
297        buf
298    }
299}
300
301// SymMat + SymMat
302impl ops::Add for SymMat {
303    type Output = SymMat;
304    fn add(self, rhs: SymMat) -> SymMat {
305        assert_eq!((self.rows, self.cols), (rhs.rows, rhs.cols), "SymMat add: dimension mismatch");
306        SymMat {
307            rows: self.rows,
308            cols: self.cols,
309            data: self.data.into_iter().zip(rhs.data)
310                .map(|(a, b)| a + b)
311                .collect(),
312        }
313    }
314}
315
316// SymMat * SymMat
317impl ops::Mul for SymMat {
318    type Output = SymMat;
319    fn mul(self, rhs: SymMat) -> SymMat {
320        assert_eq!(self.cols, rhs.rows, "SymMat mul: dimension mismatch");
321        let mut data = Vec::with_capacity(self.rows * rhs.cols);
322        for i in 0..self.rows {
323            for j in 0..rhs.cols {
324                let mut sum: Option<E> = None;
325                for k in 0..self.cols {
326                    let prod = self.get(i, k).clone() * rhs.get(k, j).clone();
327                    sum = Some(match sum {
328                        Some(acc) => acc + prod,
329                        None => prod,
330                    });
331                }
332                data.push(sum.unwrap_or_else(|| constant(0.0)));
333            }
334        }
335        SymMat { rows: self.rows, cols: rhs.cols, data }
336    }
337}
338
339// SymMat * SymVec
340impl ops::Mul<SymVec> for SymMat {
341    type Output = SymVec;
342    fn mul(self, rhs: SymVec) -> SymVec {
343        assert_eq!(self.cols, rhs.len(), "SymMat * SymVec: dimension mismatch");
344        let mut result = Vec::with_capacity(self.rows);
345        for i in 0..self.rows {
346            let mut sum: Option<E> = None;
347            for j in 0..self.cols {
348                let prod = self.get(i, j).clone() * rhs[j].clone();
349                sum = Some(match sum {
350                    Some(acc) => acc + prod,
351                    None => prod,
352                });
353            }
354            result.push(sum.unwrap_or_else(|| constant(0.0)));
355        }
356        SymVec(result)
357    }
358}
359
360// SymMat * E (scalar)
361impl ops::Mul<E> for SymMat {
362    type Output = SymMat;
363    fn mul(self, rhs: E) -> SymMat {
364        SymMat {
365            rows: self.rows,
366            cols: self.cols,
367            data: self.data.into_iter().map(|e| e * rhs.clone()).collect(),
368        }
369    }
370}
371
372// E * SymMat (scalar)
373impl ops::Mul<SymMat> for E {
374    type Output = SymMat;
375    fn mul(self, rhs: SymMat) -> SymMat {
376        SymMat {
377            rows: rhs.rows,
378            cols: rhs.cols,
379            data: rhs.data.into_iter().map(|e| self.clone() * e).collect(),
380        }
381    }
382}
383
384impl fmt::Display for SymMat {
385    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
386        write!(f, "[")?;
387        for i in 0..self.rows {
388            if i > 0 { write!(f, "; ")?; }
389            for j in 0..self.cols {
390                if j > 0 { write!(f, ", ")?; }
391                fmt::Display::fmt(self.get(i, j), f)?;
392            }
393        }
394        write!(f, "]")
395    }
396}
397
398// ============================================================
399// Jacobian
400// ============================================================
401
402/// Compute the Jacobian matrix: partial derivatives of each expression with
403/// respect to each variable.
404///
405/// Returns a [`SymMat`] with `exprs.len()` rows and `vars.len()` columns,
406/// where element (i, j) is `d(exprs[i]) / d(vars[j])`.
407pub fn jacobian(exprs: &[E], vars: &[&str]) -> SymMat {
408    let rows = exprs.len();
409    let cols = vars.len();
410    let mut data = Vec::with_capacity(rows * cols);
411    for expr in exprs {
412        for var in vars {
413            data.push(expr.diff(var));
414        }
415    }
416    SymMat { rows, cols, data }
417}