csparse21/
lib.rs

1#![crate_name = "csparse21"]
2
3//! Solving large systems of linear equations using sparse matrix methods.
4//!
//! [![Docs](https://docs.rs/csparse21/badge.svg)](https://docs.rs/csparse21)
6//!
7//! ```rust
8//! let mut m = csparse21::Matrix::from_entries(vec![
9//!             (0, 0, 1.0),
10//!             (0, 1, 1.0),
11//!             (0, 2, 1.0),
12//!             (1, 1, 2.0),
13//!             (1, 2, 5.0),
14//!             (2, 0, 2.0),
15//!             (2, 1, 5.0),
16//!             (2, 2, -1.0),
17//!         ]);
18//!
19//! let soln = m.solve(vec![6.0, -4.0, 27.0]);
20//! // => vec![5.0, 3.0, -2.0]
21//! ```
22//!
23//! Sparse methods are primarily valuable for systems in which the number of non-zero entries is substantially less than the overall size of the matrix. Such situations are common in physical systems, including electronic circuit simulation. All elements of a sparse matrix are assumed to be zero-valued unless indicated otherwise.
24//!
25//! ## Usage
26//!
27//! CSparse21 exposes two primary data structures:
28//!
//! * `Matrix` represents a `Complex64`-valued sparse matrix
30//! * `System` represents a system of linear equations of the form `Ax=b`, including a `Matrix` (A) and right-hand-side `Vec` (b).
31//!
32//! Once matrices and systems have been created, their primary public method is `solve`, which returns a (dense) `Vec` solution-vector.
33//!
34
35use num_complex::Complex64;
36use std::cmp::{max, min};
37use std::error::Error;
38use std::fmt;
39use std::ops::{Index, IndexMut};
40use std::usize::MAX;
41
/// Index of an `Element` within a `Matrix`'s element storage.
///
/// Elements refer to one another (and the matrix refers to them) through these
/// indices rather than through references.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Eindex(usize);
44
45// `Entry`s are a type alias for tuples of (row, col, val).
46type Entry = (usize, usize, Complex64);
47
/// One of the two dimensions of a matrix: rows or columns.
#[derive(Clone, Copy, Debug)]
enum Axis {
    ROWS = 0,
    COLS = 1,
}

use Axis::*;

impl Axis {
    /// The opposite axis: `ROWS` for `COLS`, and `COLS` for `ROWS`.
    fn other(&self) -> Axis {
        if let Axis::ROWS = *self {
            Axis::COLS
        } else {
            Axis::ROWS
        }
    }
}
64
65struct AxisPair<T> {
66    rows: T,
67    cols: T,
68}
69
70impl<T> Index<Axis> for AxisPair<T> {
71    type Output = T;
72
73    fn index(&self, ax: Axis) -> &Self::Output {
74        match ax {
75            Axis::ROWS => &self.rows,
76            Axis::COLS => &self.cols,
77        }
78    }
79}
80
81impl<T> IndexMut<Axis> for AxisPair<T> {
82    fn index_mut(&mut self, ax: Axis) -> &mut Self::Output {
83        match ax {
84            Axis::ROWS => &mut self.rows,
85            Axis::COLS => &mut self.cols,
86        }
87    }
88}
89
/// Lifecycle of a `Matrix` with respect to LU factorization.
#[derive(Clone, Copy, Debug, PartialEq)]
enum MatrixState {
    /// Initial state, before factoring begins.
    CREATED = 0,
    /// Factoring in progress; `insert` also maintains Markowitz counts in this state.
    FACTORING = 1,
    /// Set once `lu_factorize` completes.
    FACTORED = 2,
}
96
97#[derive(Debug)]
98#[allow(dead_code)]
99pub struct Element {
100    index: Eindex,
101    row: usize,
102    col: usize,
103    val: Complex64,
104    fillin: bool,
105    orig: (usize, usize, Complex64),
106    next_in_row: Option<Eindex>,
107    next_in_col: Option<Eindex>,
108}
109
110impl PartialEq for Element {
111    fn eq(&self, other: &Self) -> bool {
112        return self.row == other.row && self.col == other.col && self.val == other.val;
113    }
114}
115
116impl Element {
117    fn new(index: Eindex, row: usize, col: usize, val: Complex64, fillin: bool) -> Element {
118        Element {
119            index,
120            row,
121            col,
122            val,
123            fillin,
124            orig: (row, col, val),
125            next_in_row: None,
126            next_in_col: None,
127        }
128    }
129    fn loc(&self, ax: Axis) -> usize {
130        match ax {
131            Axis::ROWS => self.row,
132            Axis::COLS => self.col,
133        }
134    }
135    fn set_loc(&mut self, ax: Axis, to: usize) {
136        match ax {
137            Axis::ROWS => self.row = to,
138            Axis::COLS => self.col = to,
139        }
140    }
141    fn next(&self, ax: Axis) -> Option<Eindex> {
142        match ax {
143            Axis::ROWS => self.next_in_row,
144            Axis::COLS => self.next_in_col,
145        }
146    }
147    fn set_next(&mut self, ax: Axis, e: Option<Eindex>) {
148        match ax {
149            Axis::ROWS => self.next_in_row = e,
150            Axis::COLS => self.next_in_col = e,
151        }
152    }
153}
154
/// Permutation of indices along one axis, kept in both directions.
///
/// `e2i` and `i2e` are maintained as inverse permutations of each other
/// (`e2i[i2e[k]] == k`). By name, they map external to internal and internal
/// to external indices.
struct AxisMapping {
    e2i: Vec<usize>,
    i2e: Vec<usize>,
    // Every `(x, y)` pair passed to `swap_int`, in call order.
    history: Vec<(usize, usize)>,
}

impl AxisMapping {
    /// Identity mapping over `size` indices, with empty history.
    fn new(size: usize) -> AxisMapping {
        let identity: Vec<usize> = (0..size).collect();
        AxisMapping {
            e2i: identity.clone(),
            i2e: identity,
            history: Vec::new(),
        }
    }
    /// Swap internal indices `x` and `y`, fix up the inverse map, and record the swap.
    fn swap_int(&mut self, x: usize, y: usize) {
        self.i2e.swap(x, y);
        let (ext_x, ext_y) = (self.i2e[x], self.i2e[y]);
        self.e2i[ext_x] = x;
        self.e2i[ext_y] = y;
        self.history.push((x, y));
    }
}
179
180#[allow(dead_code)]
181struct AxisData {
182    ax: Axis,
183    hdrs: Vec<Option<Eindex>>,
184    qtys: Vec<usize>,
185    markowitz: Vec<usize>,
186    mapping: Option<AxisMapping>,
187}
188
189impl AxisData {
190    fn new(ax: Axis) -> AxisData {
191        AxisData {
192            ax,
193            hdrs: vec![],
194            qtys: vec![],
195            markowitz: vec![],
196            mapping: None,
197        }
198    }
199    fn grow(&mut self, to: usize) {
200        if to <= self.hdrs.len() {
201            return;
202        }
203        let by = to - self.hdrs.len();
204        for _ in 0..by {
205            self.hdrs.push(None);
206            self.qtys.push(0);
207            self.markowitz.push(0);
208        }
209    }
210    fn setup_factoring(&mut self) {
211        self.markowitz.copy_from_slice(&self.qtys);
212        self.mapping = Some(AxisMapping::new(self.hdrs.len()));
213    }
214    fn swap(&mut self, x: usize, y: usize) {
215        self.hdrs.swap(x, y);
216        self.qtys.swap(x, y);
217        self.markowitz.swap(x, y);
218        if let Some(m) = &mut self.mapping {
219            m.swap_int(x, y);
220        }
221    }
222}
223
/// Result type returned by fallible `Matrix` operations; errors are static message strings.
type SpResult<T> = std::result::Result<T, &'static str>;
225
226/// Sparse Matrix
227pub struct Matrix {
228    // Matrix.elements is the owner of all `Element`s.
229    // Everything else gets referenced via `Eindex`es.
230    state: MatrixState,
231    elements: Vec<Element>,
232    axes: AxisPair<AxisData>,
233    diag: Vec<Option<Eindex>>,
234    fillins: Vec<Eindex>,
235}
236
237impl Matrix {
238    /// Create a new, initially empty `Matrix`
239    pub fn new() -> Matrix {
240        Matrix {
241            state: MatrixState::CREATED,
242            axes: AxisPair {
243                rows: AxisData::new(Axis::ROWS),
244                cols: AxisData::new(Axis::COLS),
245            },
246            diag: vec![],
247            elements: vec![],
248            fillins: vec![],
249        }
250    }
251    /// Create a new `Matrix` from a vector of (row, col, val) `entries`.
252    pub fn from_entries(entries: Vec<Entry>) -> Matrix {
253        let mut m = Matrix::new();
254        for e in entries.iter() {
255            m.add_element(e.0, e.1, e.2);
256        }
257        return m;
258    }
259    /// Create an n*n identity `Matrix`
260    ///
261    pub fn identity(n: usize) -> Matrix {
262        let one = Complex64 { re: 1.0, im: 0.0 };
263        let mut m = Matrix::new();
264        for k in 0..n {
265            m.add_element(k, k, one);
266        }
267        return m;
268    }
269    /// Add an element at location `(row, col)` with value `val`.
270    pub fn add_element(&mut self, row: usize, col: usize, val: Complex64) {
271        self._add_element(row, col, val, false);
272    }
273    /// Add elements correspoding to each triplet `(row, col, val)`
274    /// Rows and columns are `usize`, and `vals` are `Complex64`.
275    pub fn add_elements(&mut self, elements: Vec<Entry>) {
276        for e in elements.iter() {
277            self.add_element(e.0, e.1, e.2);
278        }
279    }
280    /// Create a zero-valued element at `(row, col)`,
281    /// or return existing Element index if present
282    pub fn make(&mut self, row: usize, col: usize) -> Eindex {
283        let zero = Complex64 { re: 0.0, im: 0.0 };
284        return match self.get_elem(row, col) {
285            Some(ei) => ei,
286            None => self._add_element(row, col, zero, false),
287        };
288    }
289    /// Reset all Elements to zero value.
290    pub fn reset(&mut self) {
291        let zero = Complex64 { re: 0.0, im: 0.0 };
292        for e in self.elements.iter_mut() {
293            e.val = zero;
294        }
295        self.set_state(MatrixState::CREATED).unwrap();
296    }
297    /// Update `Element` `ei` by `val`
298    pub fn update(&mut self, ei: Eindex, val: Complex64) {
299        self[ei].val += val;
300    }
301    /// Multiply by Vec
302    pub fn vecmul(&self, x: &Vec<Complex64>) -> SpResult<Vec<Complex64>> {
303        if x.len() != self.num_cols() {
304            return Err("Invalid Dimensions");
305        }
306        let zero = Complex64 { re: 0.0, im: 0.0 };
307        let mut y: Vec<Complex64> = vec![zero; self.num_rows()];
308        for row in 0..self.num_rows() {
309            let mut ep = self.hdr(ROWS, row);
310            while let Some(ei) = ep {
311                y[row] += self[ei].val * x[self[ei].col];
312                ep = self[ei].next_in_row;
313            }
314        }
315        return Ok(y);
316    }
317    pub fn res(&self, x: &Vec<Complex64>, rhs: &Vec<Complex64>) -> SpResult<Vec<Complex64>> {
318        println!("X:");
319        println!("{:?}", x);
320
321        let zero = Complex64 { re: 0.0, im: 0.0 };
322        let mut xi: Vec<Complex64> = vec![zero; self.num_cols()];
323        if self.state == MatrixState::FACTORED {
324            // If we have factored, unwind any column-swaps
325            let col_mapping = self.axes[COLS].mapping.as_ref().unwrap();
326            let row_mapping = self.axes[ROWS].mapping.as_ref().unwrap();
327            println!("COL_MAP:");
328            println!("{:?}", col_mapping.e2i);
329            println!("ROW_MAP:");
330            println!("{:?}", row_mapping.e2i);
331            for k in 0..xi.len() {
332                xi[k] = x[col_mapping.e2i[k]];
333            }
334        } else {
335            for k in 0..xi.len() {
336                xi[k] = x[k];
337            }
338        }
339
340        println!("XI:");
341        println!("{:?}", xi);
342
343        let m: Vec<Complex64> = self.vecmul(&xi)?;
344        let zero = Complex64 { re: 0.0, im: 0.0 };
345        let mut res = vec![zero; m.len()];
346
347        if self.state == MatrixState::FACTORED {
348            let row_mapping = self.axes[ROWS].mapping.as_ref().unwrap();
349            for k in 0..xi.len() {
350                res[k] = rhs[row_mapping.e2i[k]] - m[k];
351            }
352        } else {
353            for k in 0..xi.len() {
354                res[k] = rhs[k] - m[k];
355            }
356        }
357        // for k in 0..self.x.len() {
358        //     res[k] = -1.0 * res[k];
359        //     // res[k] += self.rhs[k];
360        // }
361        // println!("RES_BEFORE_RHS:");
362        // println!("{:?}", res);
363        // for k in 0..self.x.len() {
364        //     res[k] += self.rhs[k];
365        // }
366        // println!("RES_WITH_RHS:");
367        // println!("{:?}", res);
368        // // res[0] += self.rhs[0];
369        // // res[1] += self.rhs[2];
370        // // res[2] += self.rhs[1];
371        return Ok(res);
372    }
373    fn insert(&mut self, e: &mut Element) {
374        let mut expanded = false;
375        if e.row + 1 > self.num_rows() {
376            self.axes[Axis::ROWS].grow(e.row + 1);
377            expanded = true;
378        }
379        if e.col + 1 > self.num_cols() {
380            self.axes[Axis::COLS].grow(e.col + 1);
381            expanded = true;
382        }
383        if expanded {
384            let new_diag_len = std::cmp::min(self.num_rows(), self.num_cols());
385            for _ in 0..new_diag_len - self.diag.len() {
386                self.diag.push(None);
387            }
388        }
389        // Insert along each Axis
390        self.insert_axis(Axis::COLS, e);
391        self.insert_axis(Axis::ROWS, e);
392        // Update row & col qtys
393        self.axes[ROWS].qtys[e.row] += 1;
394        self.axes[COLS].qtys[e.col] += 1;
395        if self.state == MatrixState::FACTORING {
396            self.axes[ROWS].markowitz[e.row] += 1;
397            self.axes[COLS].markowitz[e.col] += 1;
398        }
399        // Update our special arrays
400        if e.row == e.col {
401            self.diag[e.row] = Some(e.index);
402        }
403        if e.fillin {
404            self.fillins.push(e.index);
405        }
406    }
407    fn insert_axis(&mut self, ax: Axis, e: &mut Element) {
408        // Insert Element `e` along Axis `ax`
409
410        let head_ptr = self.axes[ax].hdrs[e.loc(ax)];
411        let head_idx = match head_ptr {
412            Some(h) => h,
413            None => {
414                // Adding first element in this row/col
415                return self.set_hdr(ax, e.loc(ax), Some(e.index));
416            }
417        };
418        let off_ax = ax.other();
419        if self[head_idx].loc(off_ax) > e.loc(off_ax) {
420            // `e` is the new first element
421            e.set_next(ax, head_ptr);
422            return self.set_hdr(ax, e.loc(ax), Some(e.index));
423        }
424
425        // `e` comes after at least one Element.  Search for its position.
426        let mut prev = head_idx;
427        while let Some(next) = self[prev].next(ax) {
428            if self[next].loc(off_ax) >= e.loc(off_ax) {
429                break;
430            }
431            prev = next;
432        }
433        // And splice it in-between `prev` and `nxt`
434        e.set_next(ax, self[prev].next(ax));
435        self[prev].set_next(ax, Some(e.index));
436    }
437    fn add_fillin(&mut self, row: usize, col: usize) -> Eindex {
438        let zero = Complex64 { re: 0.0, im: 0.0 };
439        return self._add_element(row, col, zero, true);
440    }
441    fn _add_element(&mut self, row: usize, col: usize, val: Complex64, fillin: bool) -> Eindex {
442        // Element creation & insertion, used by `add_fillin` and the public `add_element`.
443        let index = Eindex(self.elements.len());
444        let mut e = Element::new(index.clone(), row, col, val, fillin);
445        self.insert(&mut e);
446        self.elements.push(e);
447        return index;
448    }
449    /// Returns the Element-index at `(row, col)` if present, or None if not.
450    pub fn get_elem(&self, row: usize, col: usize) -> Option<Eindex> {
451        if row >= self.num_rows() {
452            return None;
453        }
454        if col >= self.num_cols() {
455            return None;
456        }
457
458        if row == col {
459            // On diagonal; easy access
460            return self.diag[row];
461        }
462        // Off-diagonal. Search across `row`.
463        let mut ep = self.hdr(ROWS, row);
464        while let Some(ei) = ep {
465            let e = &self[ei];
466            if e.col == col {
467                return Some(ei);
468            } else if e.col > col {
469                return None;
470            }
471            ep = e.next_in_row;
472        }
473        return None;
474    }
475    /// Returns the Element-value at `(row, col)` if present, or None if not.
476    pub fn get(&self, row: usize, col: usize) -> Option<Complex64> {
477        return match self.get_elem(row, col) {
478            None => None,
479            Some(ei) => Some(self[ei].val),
480        };
481    }
482    /// Make major state transitions
483    fn set_state(&mut self, state: MatrixState) -> Result<(), &'static str> {
484        match state {
485            MatrixState::CREATED => return Ok(()), //Err("Matrix State Error"),
486            MatrixState::FACTORING => {
487                if self.state == MatrixState::FACTORING {
488                    return Ok(());
489                }
490                if self.state == MatrixState::FACTORED {
491                    return Err("Already Factored");
492                }
493
494                self.axes[Axis::ROWS].setup_factoring();
495                self.axes[Axis::COLS].setup_factoring();
496
497                self.state = state;
498                return Ok(());
499            }
500            MatrixState::FACTORED => {
501                if self.state == MatrixState::FACTORING {
502                    self.state = state;
503                    return Ok(());
504                } else {
505                    return Err("Matrix State Error");
506                }
507            }
508        }
509    }
510    fn move_element(&mut self, ax: Axis, idx: Eindex, to: usize) {
511        let loc = self[idx].loc(ax);
512        if loc == to {
513            return;
514        }
515        let off_ax = ax.other();
516        let y = self[idx].loc(off_ax);
517
518        if loc < to {
519            let br = match self.before_loc(off_ax, y, to, Some(idx)) {
520                Some(ei) => ei,
521                None => panic!("ERROR"),
522            };
523            if br != idx {
524                let be = self.prev(off_ax, idx, None);
525                let nxt = self[idx].next(off_ax);
526                match be {
527                    None => self.set_hdr(off_ax, y, nxt),
528                    Some(be) => self[be].set_next(off_ax, nxt),
529                };
530                let brn = self[br].next(off_ax);
531                self[idx].set_next(off_ax, brn);
532                self[br].set_next(off_ax, Some(idx));
533            }
534        } else {
535            let br = self.before_loc(off_ax, y, to, None);
536            let be = self.prev(off_ax, idx, None);
537
538            if br != be {
539                // We (may) need some pointer updates
540                if let Some(ei) = be {
541                    let nxt = self[idx].next(off_ax);
542                    self[ei].set_next(off_ax, nxt);
543                }
544                match br {
545                    None => {
546                        // New first in row/col
547                        let first = self.hdr(off_ax, y);
548                        self[idx].set_next(off_ax, first);
549                        self.axes[off_ax].hdrs[y] = Some(idx);
550                    }
551                    Some(br) => {
552                        if br != idx {
553                            // Splice `idx` in after `br`
554                            let nxt = self[br].next(off_ax);
555                            self[idx].set_next(off_ax, nxt);
556                            self[br].set_next(off_ax, Some(idx));
557                        }
558                    }
559                };
560            }
561        }
562
563        // Update the moved-Element's location
564        self[idx].set_loc(ax, to);
565
566        if loc == y {
567            // If idx was on our diagonal, remove it
568            self.diag[loc] = None;
569        } else if to == y {
570            // Or if it's now on the diagonal, add it
571            self.diag[to] = Some(idx);
572        }
573    }
574    fn exchange_elements(&mut self, ax: Axis, ix: Eindex, iy: Eindex) {
575        // Swap two elements `ax` indices.
576        // Elements must be in the same off-axis vector,
577        // and the first argument `ex` must be the lower-indexed off-axis.
578        // E.g. exchange_elements(Axis.rows, ex, ey) exchanges the rows of ex and ey.
579
580        let off_ax = ax.other();
581        let off_loc = self[ix].loc(off_ax);
582
583        let bx = self.prev(off_ax, ix, None);
584        let by = match self.prev(off_ax, iy, Some(ix)) {
585            Some(e) => e,
586            None => panic!("ERROR!"),
587        };
588
589        let locx = self[ix].loc(ax);
590        let locy = self[iy].loc(ax);
591        self[iy].set_loc(ax, locx);
592        self[ix].set_loc(ax, locy);
593
594        match bx {
595            None => {
596                // If `ex` is the *first* entry in the column, replace it to our header-list
597                self.set_hdr(off_ax, off_loc, Some(iy));
598            }
599            Some(bxe) => {
600                // Otherwise patch ey into bx
601                self[bxe].set_next(off_ax, Some(iy));
602            }
603        }
604
605        if by == ix {
606            // `ex` and `ey` are adjacent
607            let tmp = self[iy].next(off_ax);
608            self[iy].set_next(off_ax, Some(ix));
609            self[ix].set_next(off_ax, tmp);
610        } else {
611            // Elements in-between `ex` and `ey`.  Update the last one.
612            let xnxt = self[ix].next(off_ax);
613            let ynxt = self[iy].next(off_ax);
614            self[iy].set_next(off_ax, xnxt);
615            self[ix].set_next(off_ax, ynxt);
616            self[by].set_next(off_ax, Some(ix));
617        }
618
619        // Update our diagonal array, if necessary
620        if locx == off_loc {
621            self.diag[off_loc] = Some(iy);
622        } else if locy == off_loc {
623            self.diag[off_loc] = Some(ix);
624        }
625    }
626    fn prev(&self, ax: Axis, idx: Eindex, hint: Option<Eindex>) -> Option<Eindex> {
627        // Find the element previous to `idx` along axis `ax`.
628        // If provided, `hint` *must* be before `idx`, or search will fail.
629        let prev: Option<Eindex> = match hint {
630            Some(_) => hint,
631            None => self.hdr(ax, self[idx].loc(ax)),
632        };
633        let mut pi: Eindex = match prev {
634            None => {
635                return None;
636            }
637            Some(pi) if pi == idx => {
638                return None;
639            }
640            Some(pi) => pi,
641        };
642        while let Some(nxt) = self[pi].next(ax) {
643            if nxt == idx {
644                break;
645            }
646            pi = nxt;
647        }
648        return Some(pi);
649    }
650    fn before_loc(
651        &self,
652        ax: Axis,
653        loc: usize,
654        before: usize,
655        hint: Option<Eindex>,
656    ) -> Option<Eindex> {
657        let prev: Option<Eindex> = match hint {
658            Some(_) => hint,
659            None => self.hdr(ax, loc),
660        };
661        let off_ax = ax.other();
662        let mut pi: Eindex = match prev {
663            None => {
664                return None;
665            }
666            Some(pi) if self[pi].loc(off_ax) >= before => {
667                return None;
668            }
669            Some(pi) => pi,
670        };
671        while let Some(nxt) = self[pi].next(ax) {
672            if self[nxt].loc(off_ax) >= before {
673                break;
674            }
675            pi = nxt;
676        }
677        return Some(pi);
678    }
679    fn swap(&mut self, ax: Axis, a: usize, b: usize) {
680        if a == b {
681            return;
682        }
683        let x = min(a, b);
684        let y = max(a, b);
685
686        let hdrs = &self.axes[ax].hdrs;
687        let mut ix = hdrs[x];
688        let mut iy = hdrs[y];
689        let off_ax = ax.other();
690
691        loop {
692            match (ix, iy) {
693                (Some(ex), Some(ey)) => {
694                    let ox = self[ex].loc(off_ax);
695                    let oy = self[ey].loc(off_ax);
696                    if ox < oy {
697                        self.move_element(ax, ex, y);
698                        ix = self[ex].next(ax);
699                    } else if oy < ox {
700                        self.move_element(ax, ey, x);
701                        iy = self[ey].next(ax);
702                    } else {
703                        self.exchange_elements(ax, ex, ey);
704                        ix = self[ex].next(ax);
705                        iy = self[ey].next(ax);
706                    }
707                }
708                (None, Some(ey)) => {
709                    self.move_element(ax, ey, x);
710                    iy = self[ey].next(ax);
711                }
712                (Some(ex), None) => {
713                    self.move_element(ax, ex, y);
714                    ix = self[ex].next(ax);
715                }
716                (None, None) => {
717                    break;
718                }
719            }
720        }
721        // Swap all the relevant pointers & counters
722        self.axes[ax].swap(x, y);
723    }
724    /// Updates self to S = L + U - I.
725    /// Diagonal entries are those of U;
726    /// L has diagonal entries equal to one.
727    fn lu_factorize(&mut self) -> SpResult<()> {
728        assert(self.diag.len()).gt(0);
729        for k in 0..self.axes[ROWS].hdrs.len() {
730            if self.hdr(ROWS, k).is_none() {
731                return Err("Singular Matrix");
732            }
733        }
734        for k in 0..self.axes[COLS].hdrs.len() {
735            if self.hdr(COLS, k).is_none() {
736                return Err("Singular Matrix");
737            }
738        }
739        self.set_state(MatrixState::FACTORING)?;
740
741        for n in 0..self.diag.len() - 1 {
742            let pivot = match self.search_for_pivot(n) {
743                None => return Err("Pivot Search Fail"),
744                Some(p) => p,
745            };
746            self.swap(ROWS, self[pivot].row, n);
747            self.swap(COLS, self[pivot].col, n);
748            self.row_col_elim(pivot, n)?;
749        }
750        self.set_state(MatrixState::FACTORED)?;
751        return Ok(());
752    }
753
754    fn search_for_pivot(&self, n: usize) -> Option<Eindex> {
755        let mut ei = self.markowitz_search_diagonal(n);
756        if let Some(_) = ei {
757            return ei;
758        }
759        ei = self.markowitz_search_submatrix(n);
760        if let Some(_) = ei {
761            return ei;
762        }
763        return self.find_max(n);
764    }
765
766    fn max_after(&self, ax: Axis, after: Eindex) -> Eindex {
767        let mut best = after;
768        let mut best_val = self[after].val.norm();
769        let mut e = self[after].next(ax);
770
771        while let Some(ei) = e {
772            let val = self[ei].val.norm();
773            if val > best_val {
774                best = ei;
775                best_val = val;
776            }
777            e = self[ei].next(ax);
778        }
779        return best;
780    }
781
782    fn markowitz_product(&self, ei: Eindex) -> usize {
783        let e = &self[ei];
784        let mr = self[Axis::ROWS].markowitz[e.row];
785        let mc = self[Axis::COLS].markowitz[e.col];
786        assert(mr).gt(0);
787        assert(mc).gt(0);
788        return (mr - 1) * (mc - 1);
789    }
790
791    #[allow(non_snake_case)]
792    fn markowitz_search_diagonal(&self, n: usize) -> Option<Eindex> {
793        let REL_THRESHOLD = 1e-3;
794        let ABS_THRESHOLD = 0.0;
795        let TIES_MULT = 5;
796
797        let mut best_elem = None;
798        let mut best_mark = MAX; // Actually use usize::MAX!
799        let mut best_ratio = 0.0;
800        let mut num_ties = 0;
801
802        for k in n..self.diag.len() {
803            let d = match self.diag[k] {
804                None => {
805                    continue;
806                }
807                Some(d) => d,
808            };
809
810            // Check whether this element meets our threshold criteria
811            let max_in_col = self.max_after(COLS, d);
812            let threshold = REL_THRESHOLD * self[max_in_col].val.norm() + ABS_THRESHOLD;
813            if self[d].val.norm() < threshold {
814                continue;
815            }
816
817            // If so, compute and compare its Markowitz product to our best
818            let mark = self.markowitz_product(d);
819            if mark < best_mark {
820                num_ties = 0;
821                best_elem = self.diag[k];
822                best_mark = mark;
823                best_ratio = (self[d].val / self[max_in_col].val).norm();
824            } else if mark == best_mark {
825                num_ties += 1;
826                let ratio = (self[d].val / self[max_in_col].val).norm();
827                if ratio > best_ratio {
828                    best_elem = self.diag[k];
829                    best_mark = mark;
830                    best_ratio = ratio;
831                }
832                if num_ties >= best_mark * TIES_MULT {
833                    return best_elem;
834                }
835            }
836        }
837        return best_elem;
838    }
839
840    #[allow(non_snake_case)]
841    fn markowitz_search_submatrix(&self, n: usize) -> Option<Eindex> {
842        let REL_THRESHOLD = 1e-3;
843        let ABS_THRESHOLD = 0.0;
844
845        let mut best_elem = None;
846        let mut best_mark = MAX; // Actually use usize::MAX!
847        let mut best_ratio = 0.0;
848        let mut _num_ties = 0;
849
850        for _ in n..self.axes[COLS].hdrs.len() {
851            let mut e = self.hdr(COLS, n);
852            // Advance to a row ≥ n
853            while let Some(ei) = e {
854                if self[ei].row >= n {
855                    break;
856                }
857                e = self[ei].next_in_col;
858            }
859            let ei = match e {
860                None => {
861                    continue;
862                }
863                Some(d) => d,
864            };
865
866            // Check whether this element meets our threshold criteria
867            let max_in_col = self.max_after(COLS, ei);
868            let _threshold = REL_THRESHOLD * self[max_in_col].val.norm() + ABS_THRESHOLD;
869
870            while let Some(ei) = e {
871                // If so, compute and compare its Markowitz product to our best
872                let mark = self.markowitz_product(ei);
873                if mark < best_mark {
874                    _num_ties = 0;
875                    best_elem = e;
876                    best_mark = mark;
877                    best_ratio = (self[ei].val / self[max_in_col].val).norm();
878                } else if mark == best_mark {
879                    _num_ties += 1;
880                    let ratio = (self[ei].val / self[max_in_col].val).norm();
881                    if ratio > best_ratio {
882                        best_elem = e;
883                        best_mark = mark;
884                        best_ratio = ratio;
885                    }
886                    //                    // FIXME: do we want tie-counting in here?
887                    //                    if _num_ties >= best_mark * TIES_MULT { return best_elem; }
888                }
889                e = self[ei].next_in_col;
890            }
891        }
892        return best_elem;
893    }
894    /// Find the max (abs value) element in sub-matrix of indices ≥ `n`.
895    /// Returns `None` if no elements present.
896    fn find_max(&self, n: usize) -> Option<Eindex> {
897        let mut max_elem = None;
898        let mut max_val = 0.0;
899
900        // Search each column ≥ n
901        for k in n..self.axes[COLS].hdrs.len() {
902            let mut ep = self.hdr(COLS, k);
903
904            // Advance to a row ≥ n
905            while let Some(ei) = ep {
906                if self[ei].row >= n {
907                    break;
908                }
909                ep = self[ei].next_in_col;
910            }
911            // And search over remaining elements
912            while let Some(ei) = ep {
913                let val = self[ei].val.norm();
914                if val > max_val {
915                    max_elem = ep;
916                    max_val = val;
917                }
918                ep = self[ei].next_in_col;
919            }
920        }
921        return max_elem;
922    }
923
    /// Perform one elimination step of the LU factorization at position `n`,
    /// using `pivot` as the pivot element.
    ///
    /// Steps:
    /// 1. Every element after `pivot` in its column list (`next_in_col`) is
    ///    divided by the pivot value.
    /// 2. For each element `pue` after `pivot` in its row list (`next_in_row`)
    ///    and each element `ple` after `pivot` in its column list, the product
    ///    `pue.val * ple.val` is subtracted from the element at
    ///    `(ple.row, pue.col)`. If that element is absent it is created through
    ///    `add_fillin`.
    /// 3. Markowitz counts are decremented:
    ///    - the column count of every column touched in step 2;
    ///    - the row count of every row with an element after `pivot` in the
    ///      pivot column;
    ///    - the row-`n` and column-`n` counts.
    ///
    /// Returns `Err("Singular Matrix")` if `self.diag[n]` is `None`.
    /// Panics (via the local `assert` helper) if `pivot` is not `self.diag[n]`,
    /// or if the pivot value is exactly zero.
    fn row_col_elim(&mut self, pivot: Eindex, n: usize) -> SpResult<()> {
        let de = match self.diag[n] {
            Some(de) => de,
            None => return Err("Singular Matrix"),
        };
        // The caller is expected to have already moved the chosen pivot onto the diagonal.
        assert(de).eq(pivot);
        let pivot_val = self[pivot].val;
        let zero = Complex64{re: 0.0, im: 0.0};
        assert(pivot_val).ne(zero);

        // Divide elements in the pivot column by the pivot-value
        let mut plower = self[pivot].next_in_col;
        while let Some(ple) = plower {
            self[ple].val /= pivot_val;
            plower = self[ple].next_in_col;
        }

        // Outer loop: each element `pue` to the right of the pivot in its row.
        let mut pupper = self[pivot].next_in_row;
        while let Some(pue) = pupper {
            let pupper_col = self[pue].col;
            // Restart the walk down the pivot column for each upper element;
            // `psub` walks down column `pupper_col` in step with it.
            plower = self[pivot].next_in_col;
            let mut psub = self[pue].next_in_col;
            while let Some(ple) = plower {
                // Walk `psub` down to the lower pointer
                // NOTE(review): this walk assumes each column list is ordered
                // by increasing row — confirm against the insertion code.
                while let Some(pse) = psub {
                    if self[pse].row >= self[ple].row {
                        break;
                    }
                    psub = self[pse].next_in_col;
                }
                // Either `psub` sits exactly on row `ple.row`, or no element
                // exists there yet and one is created. Presumably `add_fillin`
                // inserts a zero-valued element and returns its index — TODO
                // confirm.
                let pse = match psub {
                    None => self.add_fillin(self[ple].row, pupper_col),
                    Some(pse) if self[pse].row > self[ple].row => {
                        self.add_fillin(self[ple].row, pupper_col)
                    }
                    Some(pse) => pse,
                };

                // Update the `psub` element value
                let result = (self[pue].val) * (self[ple].val);
                self[pse].val -= result;
                // Resume the column search just past the element just updated.
                psub = self[pse].next_in_col;
                plower = self[ple].next_in_col;
            }
            // One decrement per column touched by the pivot row.
            self.axes[COLS].markowitz[pupper_col] -= 1;
            pupper = self[pue].next_in_row;
        }
        // Update remaining Markowitz counts
        self.axes[ROWS].markowitz[n] -= 1;
        self.axes[COLS].markowitz[n] -= 1;
        // One decrement per row touched by the pivot column.
        plower = self[pivot].next_in_col;
        while let Some(ple) = plower {
            let plower_row = self[ple].row;
            self.axes[ROWS].markowitz[plower_row] -= 1;
            plower = self[ple].next_in_col;
        }
        return Ok(());
    }
982    /// Solve the system `Ax=b`, where:
983    /// * `A` is `self`
984    /// * `b` is argument `rhs`
985    /// * `x` is the return value.
986    ///
987    /// Returns a `Result` containing the `Vec<Complex64>` representing `x` if successful.
988    /// Returns an `Err` if unsuccessful.
989    ///
990    /// Performs LU factorization, forward and backward substitution.
991    pub fn solve(&mut self, rhs: &[Complex64]) -> SpResult<Vec<Complex64>> {
992        if self.state == MatrixState::CREATED {
993            self.lu_factorize()?;
994        }
995        assert(self.state).eq(MatrixState::FACTORED);
996
997        // Unwind any row-swaps
998        let zero = Complex64{re: 0.0, im: 0.0};
999        let mut c: Vec<Complex64> = vec![zero; rhs.len()];
1000        let row_mapping = self.axes[ROWS].mapping.as_ref().unwrap();
1001        for k in 0..c.len() {
1002            c[row_mapping.e2i[k]] = rhs[k];
1003        }
1004
1005        // Forward substitution: Lc=b
1006        for k in 0..self.diag.len() {
1007            // Walk down each column, update c
1008            if c[k] == zero {
1009                continue;
1010            } // No updates to make on this iteration
1011
1012            // c[d.row] /= d.val
1013
1014            let di = match self.diag[k] {
1015                Some(di) => di,
1016                None => return Err("Singular Matrix"),
1017            };
1018            let mut e = self[di].next_in_col;
1019            while let Some(ei) = e {
1020                let result = c[k] * self[ei].val;
1021                c[self[ei].row] -= result;
1022                e = self[ei].next_in_col;
1023            }
1024        }
1025
1026        // Backward substitution: Ux=c
1027        for k in (0..self.diag.len()).rev() {
1028            // Walk each row, update c
1029            let di = match self.diag[k] {
1030                Some(di) => di,
1031                None => return Err("Singular Matrix"),
1032            };
1033            let mut ep = self[di].next_in_row;
1034            while let Some(ei) = ep {
1035                let result = c[self[ei].col] * self[ei].val;
1036                c[k] -= result;
1037                ep = self[ei].next_in_row;
1038            }
1039            c[k] /= self[di].val;
1040        }
1041
1042        // Unwind any column-swaps
1043        let zero = Complex64{re: 0.0, im: 0.0};
1044        let mut soln: Vec<Complex64> = vec![zero; c.len()];
1045        let col_mapping = self.axes[COLS].mapping.as_ref().unwrap();
1046        for k in 0..c.len() {
1047            soln[k] = c[col_mapping.e2i[k]];
1048        }
1049        return Ok(soln);
1050    }
1051    fn hdr(&self, ax: Axis, loc: usize) -> Option<Eindex> {
1052        self.axes[ax].hdrs[loc]
1053    }
1054    fn set_hdr(&mut self, ax: Axis, loc: usize, ei: Option<Eindex>) {
1055        self.axes[ax].hdrs[loc] = ei;
1056    }
1057    fn _swap_rows(&mut self, x: usize, y: usize) {
1058        self.swap(ROWS, x, y)
1059    }
1060    fn _swap_cols(&mut self, x: usize, y: usize) {
1061        self.swap(COLS, x, y)
1062    }
1063    fn num_rows(&self) -> usize {
1064        self.axes[ROWS].hdrs.len()
1065    }
1066    fn num_cols(&self) -> usize {
1067        self.axes[COLS].hdrs.len()
1068    }
1069    fn _size(&self) -> (usize, usize) {
1070        (self.num_rows(), self.num_cols())
1071    }
1072}
1073
1074impl Index<Eindex> for Matrix {
1075    type Output = Element;
1076    fn index(&self, index: Eindex) -> &Self::Output {
1077        &self.elements[index.0]
1078    }
1079}
1080
1081impl IndexMut<Eindex> for Matrix {
1082    fn index_mut(&mut self, index: Eindex) -> &mut Self::Output {
1083        &mut self.elements[index.0]
1084    }
1085}
1086
1087impl Index<Axis> for Matrix {
1088    type Output = AxisData;
1089    fn index(&self, ax: Axis) -> &Self::Output {
1090        &self.axes[ax]
1091    }
1092}
1093
1094impl IndexMut<Axis> for Matrix {
1095    fn index_mut(&mut self, ax: Axis) -> &mut Self::Output {
1096        &mut self.axes[ax]
1097    }
1098}
1099
1100impl fmt::Debug for Matrix {
1101    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1102        write!(
1103            f,
1104            "csparse21::Matrix (rows={}, cols={}, elems={}\n",
1105            self.num_rows(),
1106            self.num_cols(),
1107            self.elements.len()
1108        )?;
1109        for e in self.elements.iter() {
1110            write!(f, "({}, {}, {}) \n", e.row, e.col, e.val)?;
1111        }
1112        write!(f, "\n")
1113    }
1114}
1115
/// Unit error type whose `Display` text is "invalid first item to double".
///
/// NOTE(review): nothing in this part of the file constructs it; confirm
/// whether it is still needed.
#[derive(Debug, Clone)]
struct NonRealNumError;

impl fmt::Display for NonRealNumError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "invalid first item to double")
    }
}

impl Error for NonRealNumError {
    // Kept so callers of the deprecated `Error::description` see the same
    // text; new code should rely on `Display` instead.
    fn description(&self) -> &str {
        "invalid first item to double"
    }

    /// Generic error: no underlying cause is tracked.
    ///
    /// Implemented via `source` rather than the deprecated `cause`. The
    /// default `cause()` forwards here, so it still returns `None`.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        None
    }
}
1135
1136/// Sparse Matrix System
1137///
1138/// Represents a linear system of the form `Ax=b`
1139///
1140#[allow(dead_code)]
1141pub struct System {
1142    mat: Matrix,
1143    rhs: Vec<Complex64>,
1144    title: Option<String>,
1145    size: usize,
1146}
1147
1148impl System {
1149    /// Splits a `System` into a two-tuple of `self.matrix` and `self.rhs`.
1150    /// Nothing is copied; `self` is consumed in the process.
1151    pub fn split(self) -> (Matrix, Vec<Complex64>) {
1152        (self.mat, self.rhs)
1153    }
1154
1155    /// Solve the system `Ax=b`, where:
1156    /// * `A` is `self.matrix`
1157    /// * `b` is `self.rhs`
1158    /// * `x` is the return value.
1159    ///
1160    /// Returns a `Result` containing the `Vec<Complex64>` representing `x` if successful.
1161    /// Returns an `Err` if unsuccessful.
1162    ///
1163    /// Performs LU factorization, forward and backward substitution.
1164    pub fn solve(mut self) -> SpResult<Vec<Complex64>> {
1165        self.mat.solve(&self.rhs)
1166    }
1167}
1168
/// Minimal fluent assertion helper: `assert(x).eq(y)` panics unless `x == y`.
///
/// Every failed check funnels through `raise`, so a single breakpoint there
/// catches all failures. `#[track_caller]` makes the panic report the
/// location of the original `assert(..)` call rather than this helper.
struct Assert<T> {
    val: T,
}

/// Wrap `val` so it can be checked with the methods on `Assert`.
fn assert<T>(val: T) -> Assert<T> {
    Assert { val }
}

impl<T> Assert<T> {
    /// Panic with "Assertion Failed". Breakpoint here to stop on any failed check.
    #[track_caller]
    fn raise(&self) {
        panic!("Assertion Failed");
    }
}

impl<T: PartialEq> Assert<T> {
    /// Panics unless `val == other`.
    #[track_caller]
    fn eq(&self, other: T) {
        if self.val != other {
            self.raise();
        }
    }
    /// Panics if `val == other`.
    #[track_caller]
    fn ne(&self, other: T) {
        if self.val == other {
            self.raise();
        }
    }
}

// Not every ordering check is necessarily called, hence the dead-code allowance.
#[allow(dead_code)]
impl<T: PartialOrd> Assert<T> {
    /// Panics unless `val.partial_cmp(&other)` is one of `accepted`.
    ///
    /// Testing the ordering positively means incomparable values (e.g. a NaN
    /// float, for which `partial_cmp` returns `None`) fail every check.
    /// Testing the negated operator (`val <= other` for `gt`) would let them
    /// slip through.
    #[track_caller]
    fn check_order(&self, other: T, accepted: &[std::cmp::Ordering]) {
        match self.val.partial_cmp(&other) {
            Some(ord) if accepted.contains(&ord) => {}
            _ => self.raise(),
        }
    }
    /// Panics unless `val > other`.
    #[track_caller]
    fn gt(&self, other: T) {
        self.check_order(other, &[std::cmp::Ordering::Greater]);
    }
    /// Panics unless `val < other`.
    #[track_caller]
    fn lt(&self, other: T) {
        self.check_order(other, &[std::cmp::Ordering::Less]);
    }
    /// Panics unless `val >= other`.
    #[track_caller]
    fn ge(&self, other: T) {
        self.check_order(
            other,
            &[std::cmp::Ordering::Greater, std::cmp::Ordering::Equal],
        );
    }
    /// Panics unless `val <= other`.
    #[track_caller]
    fn le(&self, other: T) {
        self.check_order(
            other,
            &[std::cmp::Ordering::Less, std::cmp::Ordering::Equal],
        );
    }
}