1use std::collections::HashMap;
2use std::fmt;
3use std::ops;
4use super::{AsVarName, E, constant};
5
6#[derive(Debug, Clone, PartialEq)]
15pub struct SymVec(pub Vec<E>);
16
17impl SymVec {
18 pub fn new<I>(elems: I) -> Self
22 where
23 I: IntoIterator,
24 I::Item: Into<E>,
25 {
26 SymVec(elems.into_iter().map(Into::into).collect())
27 }
28
29 pub fn len(&self) -> usize {
31 self.0.len()
32 }
33
34 pub fn is_empty(&self) -> bool {
36 self.0.is_empty()
37 }
38
39 pub fn get(&self, i: usize) -> &E {
41 &self.0[i]
42 }
43
44 pub fn dot(&self, other: &SymVec) -> E {
46 assert_eq!(self.len(), other.len(), "dot product: length mismatch");
47 let mut terms: Vec<E> = Vec::with_capacity(self.len());
48 for i in 0..self.len() {
49 terms.push(self.0[i].clone() * other.0[i].clone());
50 }
51 terms.into_iter().reduce(|a, b| a + b).unwrap_or_else(|| constant(0.0))
52 }
53
54 pub fn diff(&self, var: impl AsVarName) -> SymVec {
56 let v = var.var_name();
57 SymVec(self.0.iter().map(|e| e.diff(v)).collect())
58 }
59
60 pub fn eval(&self, vars: &HashMap<&str, f64>) -> Result<Vec<f64>, String> {
62 self.0.iter().map(|e| e.eval(vars)).collect()
63 }
64
65 pub fn simplify(&self) -> SymVec {
67 SymVec(self.0.iter().map(|e| e.simplify()).collect())
68 }
69
70 pub fn expand(&self) -> SymVec {
72 SymVec(self.0.iter().map(|e| e.expand()).collect())
73 }
74
75 pub fn subs(&self, var: impl AsVarName, replacement: &E) -> SymVec {
77 let name = var.var_name();
78 SymVec(self.0.iter().map(|e| e.subs(name, replacement)).collect())
79 }
80
81 pub fn to_latex(&self) -> String {
83 let mut buf = String::from("\\begin{pmatrix} ");
84 for (i, e) in self.0.iter().enumerate() {
85 if i > 0 { buf.push_str(" \\\\ "); }
86 buf.push_str(&e.to_latex());
87 }
88 buf.push_str(" \\end{pmatrix}");
89 buf
90 }
91
92 pub fn to_rust(&self, ft: &str) -> String {
94 let mut buf = String::from("[");
95 for (i, e) in self.0.iter().enumerate() {
96 if i > 0 { buf.push_str(", "); }
97 buf.push_str(&e.to_rust(ft));
98 }
99 buf.push(']');
100 buf
101 }
102}
103
104impl ops::Index<usize> for SymVec {
105 type Output = E;
106 fn index(&self, i: usize) -> &E {
107 &self.0[i]
108 }
109}
110
111impl ops::Add for SymVec {
112 type Output = SymVec;
113 fn add(self, rhs: SymVec) -> SymVec {
114 assert_eq!(self.len(), rhs.len(), "SymVec add: length mismatch");
115 SymVec(
116 self.0.into_iter().zip(rhs.0)
117 .map(|(a, b)| a + b)
118 .collect()
119 )
120 }
121}
122
123impl ops::Mul<E> for SymVec {
124 type Output = SymVec;
125 fn mul(self, rhs: E) -> SymVec {
126 SymVec(self.0.into_iter().map(|e| e * rhs.clone()).collect())
127 }
128}
129
130impl ops::Mul<SymVec> for E {
131 type Output = SymVec;
132 fn mul(self, rhs: SymVec) -> SymVec {
133 SymVec(rhs.0.into_iter().map(|e| self.clone() * e).collect())
134 }
135}
136
137impl fmt::Display for SymVec {
138 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
139 write!(f, "[")?;
140 for (i, e) in self.0.iter().enumerate() {
141 if i > 0 { write!(f, ", ")?; }
142 fmt::Display::fmt(e, f)?;
143 }
144 write!(f, "]")
145 }
146}
147
148#[derive(Debug, Clone, PartialEq)]
157pub struct SymMat {
158 pub rows: usize,
160 pub cols: usize,
162 pub data: Vec<E>,
164}
165
166impl SymMat {
167 pub fn new<I>(rows: usize, cols: usize, data: I) -> Self
171 where
172 I: IntoIterator,
173 I::Item: Into<E>,
174 {
175 let data: Vec<E> = data.into_iter().map(Into::into).collect();
176 assert_eq!(data.len(), rows * cols, "SymMat::new: data size mismatch");
177 SymMat { rows, cols, data }
178 }
179
180 pub fn zeros(rows: usize, cols: usize) -> Self {
182 SymMat {
183 rows,
184 cols,
185 data: vec![constant(0.0); rows * cols],
186 }
187 }
188
189 pub fn identity(n: usize) -> Self {
191 let mut data = vec![constant(0.0); n * n];
192 for i in 0..n {
193 data[i * n + i] = constant(1.0);
194 }
195 SymMat { rows: n, cols: n, data }
196 }
197
198 pub fn get(&self, i: usize, j: usize) -> &E {
200 &self.data[i * self.cols + j]
201 }
202
203 pub fn set(&mut self, i: usize, j: usize, val: E) {
205 self.data[i * self.cols + j] = val;
206 }
207
208 pub fn transpose(&self) -> SymMat {
210 let mut data = Vec::with_capacity(self.rows * self.cols);
211 for j in 0..self.cols {
212 for i in 0..self.rows {
213 data.push(self.get(i, j).clone());
214 }
215 }
216 SymMat { rows: self.cols, cols: self.rows, data }
217 }
218
219 pub fn diff(&self, var: impl AsVarName) -> SymMat {
221 let v = var.var_name();
222 SymMat {
223 rows: self.rows,
224 cols: self.cols,
225 data: self.data.iter().map(|e| e.diff(v)).collect(),
226 }
227 }
228
229 pub fn eval(&self, vars: &HashMap<&str, f64>) -> Result<Vec<Vec<f64>>, String> {
231 let mut result = Vec::with_capacity(self.rows);
232 for i in 0..self.rows {
233 let mut row = Vec::with_capacity(self.cols);
234 for j in 0..self.cols {
235 row.push(self.get(i, j).eval(vars)?);
236 }
237 result.push(row);
238 }
239 Ok(result)
240 }
241
242 pub fn simplify(&self) -> SymMat {
244 SymMat {
245 rows: self.rows,
246 cols: self.cols,
247 data: self.data.iter().map(|e| e.simplify()).collect(),
248 }
249 }
250
251 pub fn expand(&self) -> SymMat {
253 SymMat {
254 rows: self.rows,
255 cols: self.cols,
256 data: self.data.iter().map(|e| e.expand()).collect(),
257 }
258 }
259
260 pub fn subs(&self, var: impl AsVarName, replacement: &E) -> SymMat {
262 let name = var.var_name();
263 SymMat {
264 rows: self.rows,
265 cols: self.cols,
266 data: self.data.iter().map(|e| e.subs(name, replacement)).collect(),
267 }
268 }
269
270 pub fn to_latex(&self) -> String {
272 let mut buf = String::from("\\begin{pmatrix} ");
273 for i in 0..self.rows {
274 if i > 0 { buf.push_str(" \\\\ "); }
275 for j in 0..self.cols {
276 if j > 0 { buf.push_str(" & "); }
277 buf.push_str(&self.get(i, j).to_latex());
278 }
279 }
280 buf.push_str(" \\end{pmatrix}");
281 buf
282 }
283
284 pub fn to_rust(&self, ft: &str) -> String {
286 let mut buf = String::from("[");
287 for i in 0..self.rows {
288 if i > 0 { buf.push_str(", "); }
289 buf.push('[');
290 for j in 0..self.cols {
291 if j > 0 { buf.push_str(", "); }
292 buf.push_str(&self.get(i, j).to_rust(ft));
293 }
294 buf.push(']');
295 }
296 buf.push(']');
297 buf
298 }
299}
300
301impl ops::Add for SymMat {
303 type Output = SymMat;
304 fn add(self, rhs: SymMat) -> SymMat {
305 assert_eq!((self.rows, self.cols), (rhs.rows, rhs.cols), "SymMat add: dimension mismatch");
306 SymMat {
307 rows: self.rows,
308 cols: self.cols,
309 data: self.data.into_iter().zip(rhs.data)
310 .map(|(a, b)| a + b)
311 .collect(),
312 }
313 }
314}
315
316impl ops::Mul for SymMat {
318 type Output = SymMat;
319 fn mul(self, rhs: SymMat) -> SymMat {
320 assert_eq!(self.cols, rhs.rows, "SymMat mul: dimension mismatch");
321 let mut data = Vec::with_capacity(self.rows * rhs.cols);
322 for i in 0..self.rows {
323 for j in 0..rhs.cols {
324 let mut sum: Option<E> = None;
325 for k in 0..self.cols {
326 let prod = self.get(i, k).clone() * rhs.get(k, j).clone();
327 sum = Some(match sum {
328 Some(acc) => acc + prod,
329 None => prod,
330 });
331 }
332 data.push(sum.unwrap_or_else(|| constant(0.0)));
333 }
334 }
335 SymMat { rows: self.rows, cols: rhs.cols, data }
336 }
337}
338
339impl ops::Mul<SymVec> for SymMat {
341 type Output = SymVec;
342 fn mul(self, rhs: SymVec) -> SymVec {
343 assert_eq!(self.cols, rhs.len(), "SymMat * SymVec: dimension mismatch");
344 let mut result = Vec::with_capacity(self.rows);
345 for i in 0..self.rows {
346 let mut sum: Option<E> = None;
347 for j in 0..self.cols {
348 let prod = self.get(i, j).clone() * rhs[j].clone();
349 sum = Some(match sum {
350 Some(acc) => acc + prod,
351 None => prod,
352 });
353 }
354 result.push(sum.unwrap_or_else(|| constant(0.0)));
355 }
356 SymVec(result)
357 }
358}
359
360impl ops::Mul<E> for SymMat {
362 type Output = SymMat;
363 fn mul(self, rhs: E) -> SymMat {
364 SymMat {
365 rows: self.rows,
366 cols: self.cols,
367 data: self.data.into_iter().map(|e| e * rhs.clone()).collect(),
368 }
369 }
370}
371
372impl ops::Mul<SymMat> for E {
374 type Output = SymMat;
375 fn mul(self, rhs: SymMat) -> SymMat {
376 SymMat {
377 rows: rhs.rows,
378 cols: rhs.cols,
379 data: rhs.data.into_iter().map(|e| self.clone() * e).collect(),
380 }
381 }
382}
383
384impl fmt::Display for SymMat {
385 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
386 write!(f, "[")?;
387 for i in 0..self.rows {
388 if i > 0 { write!(f, "; ")?; }
389 for j in 0..self.cols {
390 if j > 0 { write!(f, ", ")?; }
391 fmt::Display::fmt(self.get(i, j), f)?;
392 }
393 }
394 write!(f, "]")
395 }
396}
397
398pub fn jacobian(exprs: &[E], vars: &[&str]) -> SymMat {
408 let rows = exprs.len();
409 let cols = vars.len();
410 let mut data = Vec::with_capacity(rows * cols);
411 for expr in exprs {
412 for var in vars {
413 data.push(expr.diff(var));
414 }
415 }
416 SymMat { rows, cols, data }
417}