1use fixedbitset as fb;
4use nalgebra as na;
5use nalgebra_sparse as nas;
6
7use crate::{cochain::CochainImpl, mesh::SubsetImpl, Dual, DualCellView, Primal, SimplexView};
8use itertools::izip;
9
10pub trait Operator {
16 type Input: Operand;
18 type Output: Operand;
20
21 fn apply(&self, input: &Self::Input) -> Self::Output;
23 fn into_csr(self) -> nas::CsrMatrix<f64>;
25}
26
27pub trait Operand {
30 type Dimension;
32 type Primality;
34 fn values(&self) -> &na::DVector<f64>;
36 fn from_values(values: na::DVector<f64>) -> Self;
38}
39
40#[derive(Clone, Debug)]
48pub struct DiagonalOperator<Input, Output> {
49 pub diagonal: na::DVector<f64>,
74 _marker: std::marker::PhantomData<(Input, Output)>,
75}
76
77impl<Input, Output> Operator for DiagonalOperator<Input, Output>
78where
79 Input: Operand,
80 Output: Operand,
81{
82 type Input = Input;
83 type Output = Output;
84
85 fn apply(&self, input: &Self::Input) -> Self::Output {
86 let input = input.values();
87 let ret = na::DVector::from_iterator(
88 input.len(),
89 izip!(self.diagonal.iter(), input.iter()).map(|(&diag_val, &in_val)| diag_val * in_val),
90 );
91 Self::Output::from_values(ret)
92 }
93
94 fn into_csr(self) -> nas::CsrMatrix<f64> {
95 let mut csr = nas::CsrMatrix::identity(self.diagonal.len());
99 for (&diag, mat_diag) in self.diagonal.iter().zip(csr.values_mut()) {
100 *mat_diag = diag;
101 }
102 csr
103 }
104}
105
106impl<Input, Output> From<na::DVector<f64>> for DiagonalOperator<Input, Output> {
107 fn from(diagonal: na::DVector<f64>) -> Self {
108 Self {
109 diagonal,
110 _marker: std::marker::PhantomData,
111 }
112 }
113}
114
115impl<Input, Output> DiagonalOperator<Input, Output>
116where
117 Input: Operand,
118 Output: Operand,
119{
120 pub fn exclude_subset(
125 mut self,
126 set: &SubsetImpl<
127 <<Self as Operator>::Output as Operand>::Dimension,
128 <<Self as Operator>::Output as Operand>::Primality,
129 >,
130 ) -> Self {
131 for row_idx in set.indices.ones() {
132 self.diagonal[row_idx] = 0.0;
133 }
134 self
135 }
136
137 pub fn restrict_to_subset(
139 mut self,
140 set: &SubsetImpl<
141 <<Self as Operator>::Output as Operand>::Dimension,
142 <<Self as Operator>::Output as Operand>::Primality,
143 >,
144 ) -> Self {
145 for row_idx in (0..self.diagonal.len()).filter(|i| !set.indices.contains(*i)) {
146 self.diagonal[row_idx] = 0.0;
147 }
148 self
149 }
150
151 pub fn inverse(&self) -> DiagonalOperator<Output, Input> {
160 self.diagonal.map(|el| 1. / el).into()
161 }
162
163 pub fn recip(&self) -> Self {
172 self.diagonal.map(|el| 1. / el).into()
173 }
174
175 pub fn inner_product(&self, lhs: &Input, rhs: &Input) -> f64 {
182 let lhs = self.apply(lhs);
183 lhs.values().dot(rhs.values())
184 }
185
186 pub fn norm(&self, c: &Input) -> f64 {
193 self.inner_product(c, c)
194 }
195}
196
197impl<Input, Output> PartialEq for DiagonalOperator<Input, Output> {
198 fn eq(&self, other: &Self) -> bool {
199 self.diagonal == other.diagonal
200 }
201}
202
203impl<Input, Output> std::ops::Neg for DiagonalOperator<Input, Output> {
204 type Output = Self;
205
206 fn neg(self) -> Self::Output {
207 Self::from(-self.diagonal)
208 }
209}
210
211impl<'a, In, OutDim, const MESH_DIM: usize> std::ops::Index<SimplexView<'a, OutDim, MESH_DIM>>
214 for DiagonalOperator<In, CochainImpl<OutDim, Primal>>
215where
216 OutDim: na::DimName,
217 na::Const<MESH_DIM>: na::DimNameSub<OutDim>,
218{
219 type Output = f64;
220
221 fn index(&self, index: SimplexView<'a, OutDim, MESH_DIM>) -> &Self::Output {
222 &self.diagonal[index.index()]
223 }
224}
225
226impl<'a, In, OutDim, const MESH_DIM: usize> std::ops::IndexMut<SimplexView<'a, OutDim, MESH_DIM>>
227 for DiagonalOperator<In, CochainImpl<OutDim, Primal>>
228where
229 OutDim: na::DimName,
230 na::Const<MESH_DIM>: na::DimNameSub<OutDim>,
231{
232 fn index_mut(&mut self, index: SimplexView<'a, OutDim, MESH_DIM>) -> &mut Self::Output {
233 &mut self.diagonal[index.index()]
234 }
235}
236
237impl<'a, In, OutDim, const MESH_DIM: usize> std::ops::Index<DualCellView<'a, OutDim, MESH_DIM>>
238 for DiagonalOperator<In, CochainImpl<OutDim, Dual>>
239where
240 OutDim: na::DimName,
241 na::Const<MESH_DIM>: na::DimNameSub<OutDim>,
242{
243 type Output = f64;
244
245 fn index(&self, index: DualCellView<'a, OutDim, MESH_DIM>) -> &Self::Output {
246 &self.diagonal[index.index()]
247 }
248}
249
250impl<'a, In, OutDim, const MESH_DIM: usize> std::ops::IndexMut<DualCellView<'a, OutDim, MESH_DIM>>
251 for DiagonalOperator<In, CochainImpl<OutDim, Dual>>
252where
253 OutDim: na::DimName,
254 na::Const<MESH_DIM>: na::DimNameSub<OutDim>,
255{
256 fn index_mut(&mut self, index: DualCellView<'a, OutDim, MESH_DIM>) -> &mut Self::Output {
257 &mut self.diagonal[index.index()]
258 }
259}
260
261#[derive(Clone, Debug)]
282pub struct MatrixOperator<Input, Output> {
283 pub mat: nas::CsrMatrix<f64>,
288 _marker: std::marker::PhantomData<(Input, Output)>,
289}
290
291pub type Op<Input, Output> = MatrixOperator<Input, Output>;
294
295impl<Input, Output> Operator for MatrixOperator<Input, Output>
296where
297 Input: Operand,
298 Output: Operand,
299{
300 type Input = Input;
301 type Output = Output;
302
303 fn apply(&self, input: &Self::Input) -> Self::Output {
304 Self::Output::from_values(&self.mat * input.values())
305 }
306
307 fn into_csr(self) -> nas::CsrMatrix<f64> {
308 self.mat
309 }
310}
311
312impl<Input, Output> MatrixOperator<Input, Output>
313where
314 Input: Operand,
315 Output: Operand,
316{
317 pub fn exclude_subset(
321 mut self,
322 set: &SubsetImpl<<Output as Operand>::Dimension, <Output as Operand>::Primality>,
323 ) -> Self {
324 self.mat = drop_csr_rows(self.mat, &set.indices);
325 self
326 }
327
328 pub fn restrict_to_subset(
330 mut self,
331 set: &SubsetImpl<
332 <<Self as Operator>::Output as Operand>::Dimension,
333 <<Self as Operator>::Output as Operand>::Primality,
334 >,
335 ) -> Self {
336 let mut inv_set = set.indices.clone();
337 inv_set.grow(self.mat.nrows());
339 inv_set.toggle_range(..);
340
341 self.mat = drop_csr_rows(self.mat, &inv_set);
342 self
343 }
344}
345
346impl<L, R> PartialEq for MatrixOperator<L, R> {
347 fn eq(&self, other: &Self) -> bool {
348 self.mat == other.mat
349 }
350}
351
352impl<Input, Output> std::ops::Neg for MatrixOperator<Input, Output> {
353 type Output = Self;
354
355 fn neg(self) -> Self::Output {
356 Self::from(-self.mat)
357 }
358}
359
360impl<Input, Output> From<nas::CsrMatrix<f64>> for MatrixOperator<Input, Output> {
363 fn from(mat: nas::CsrMatrix<f64>) -> Self {
364 Self {
365 mat,
366 _marker: std::marker::PhantomData,
367 }
368 }
369}
370
371impl<Input, Output> From<DiagonalOperator<Input, Output>> for MatrixOperator<Input, Output>
372where
373 DiagonalOperator<Input, Output>: Operator,
374{
375 fn from(s: DiagonalOperator<Input, Output>) -> Self {
376 Self {
377 mat: s.into_csr(),
378 _marker: std::marker::PhantomData,
379 }
380 }
381}
382
383pub fn compose<Left, Right>(l: Left, r: Right) -> MatrixOperator<Right::Input, Left::Output>
399where
400 Left: Operator<Input = Right::Output>,
401 Right: Operator,
402{
403 MatrixOperator {
404 mat: l.into_csr() * r.into_csr(),
405 _marker: std::marker::PhantomData,
406 }
407}
408
409pub fn compose_diagonal<LeftOut, LeftIn, RightIn>(
434 l: DiagonalOperator<LeftIn, LeftOut>,
435 r: DiagonalOperator<RightIn, LeftIn>,
436) -> DiagonalOperator<RightIn, LeftOut> {
437 DiagonalOperator {
438 diagonal: l.diagonal.component_mul(&r.diagonal),
439 _marker: std::marker::PhantomData,
440 }
441}
442
443fn drop_csr_rows(mat: nas::CsrMatrix<f64>, set_to_drop: &fb::FixedBitSet) -> nas::CsrMatrix<f64> {
450 let num_rows = mat.nrows();
451 let num_cols = mat.ncols();
452 let (mut row_offsets, mut col_indices, mut values) = mat.disassemble();
454
455 let mut retained_value_idx = 0;
459 let mut prev_row_offset = 0;
462 for row_idx in 0..num_rows {
463 let old_row_range = prev_row_offset..row_offsets[row_idx + 1];
464 prev_row_offset = row_offsets[row_idx + 1];
465
466 if set_to_drop.contains(row_idx) {
467 row_offsets[row_idx + 1] = row_offsets[row_idx];
468 } else {
469 for old_val_idx in old_row_range {
470 col_indices[retained_value_idx] = col_indices[old_val_idx];
471 values[retained_value_idx] = values[old_val_idx];
472 retained_value_idx += 1;
473 }
474 row_offsets[row_idx + 1] = retained_value_idx;
475 }
476 }
477
478 col_indices.truncate(retained_value_idx);
480 values.truncate(retained_value_idx);
481
482 nas::CsrMatrix::try_from_csr_data(num_rows, num_cols, row_offsets, col_indices, values).unwrap()
483}
484
485impl<In, Out, OtherIn> std::ops::Mul<DiagonalOperator<OtherIn, In>> for DiagonalOperator<In, Out>
496where
497 In: Operand,
498 Out: Operand,
499 OtherIn: Operand,
500{
501 type Output = DiagonalOperator<OtherIn, Out>;
502
503 fn mul(self, rhs: DiagonalOperator<OtherIn, In>) -> Self::Output {
504 compose_diagonal(self, rhs)
505 }
506}
507
508impl<In, Out, OtherIn> std::ops::Mul<MatrixOperator<OtherIn, In>> for DiagonalOperator<In, Out>
509where
510 In: Operand,
511 Out: Operand,
512 OtherIn: Operand,
513{
514 type Output = MatrixOperator<OtherIn, Out>;
515
516 fn mul(self, rhs: MatrixOperator<OtherIn, In>) -> Self::Output {
517 compose(self, rhs)
518 }
519}
520
521impl<In, Out, Op> std::ops::Mul<Op> for MatrixOperator<In, Out>
522where
523 In: Operand,
524 Out: Operand,
525 Op: Operator<Output = <Self as Operator>::Input>,
526{
527 type Output = MatrixOperator<Op::Input, <Self as Operator>::Output>;
528
529 fn mul(self, rhs: Op) -> Self::Output {
530 compose(self, rhs)
531 }
532}
533
534impl<Input, Output> std::ops::Mul<DiagonalOperator<Input, Output>> for f64 {
537 type Output = DiagonalOperator<Input, Output>;
538
539 fn mul(self, mut rhs: DiagonalOperator<Input, Output>) -> Self::Output {
540 rhs.diagonal *= self;
541 rhs
542 }
543}
544
545impl<L, R> std::ops::Mul<MatrixOperator<L, R>> for f64 {
546 type Output = MatrixOperator<L, R>;
547
548 fn mul(self, mut rhs: MatrixOperator<L, R>) -> Self::Output {
549 rhs.mat *= self;
550 rhs
551 }
552}
553
554impl<Out, D, P> std::ops::Mul<&CochainImpl<D, P>> for DiagonalOperator<CochainImpl<D, P>, Out>
559where
560 Out: Operand,
561{
562 type Output = Out;
563
564 fn mul(self, rhs: &CochainImpl<D, P>) -> Self::Output {
565 self.apply(rhs)
566 }
567}
568
569impl<Out, D, P> std::ops::Mul<&CochainImpl<D, P>> for &DiagonalOperator<CochainImpl<D, P>, Out>
570where
571 Out: Operand,
572{
573 type Output = Out;
574
575 fn mul(self, rhs: &CochainImpl<D, P>) -> Self::Output {
576 self.apply(rhs)
577 }
578}
579
580impl<O, D, P> std::ops::Mul<&CochainImpl<D, P>> for MatrixOperator<CochainImpl<D, P>, O>
581where
582 O: Operand,
583{
584 type Output = <Self as Operator>::Output;
585
586 fn mul(self, rhs: &CochainImpl<D, P>) -> Self::Output {
587 self.apply(rhs)
588 }
589}
590
591impl<O, D, P> std::ops::Mul<&CochainImpl<D, P>> for &MatrixOperator<CochainImpl<D, P>, O>
592where
593 O: Operand,
594{
595 type Output = O;
596
597 fn mul(self, rhs: &CochainImpl<D, P>) -> Self::Output {
598 self.apply(rhs)
599 }
600}
601
602impl<D, P, O> std::ops::Mul<&CochainImpl<D, P>>
605 for &dyn Operator<Input = CochainImpl<D, P>, Output = O>
606where
607 O: Operand,
608{
609 type Output = O;
610
611 fn mul(self, rhs: &CochainImpl<D, P>) -> Self::Output {
612 self.apply(rhs)
613 }
614}
615
616impl<D, P, O> std::ops::Mul<&CochainImpl<D, P>>
617 for Box<dyn Operator<Input = CochainImpl<D, P>, Output = O>>
618where
619 O: Operand,
620{
621 type Output = O;
622
623 fn mul(self, rhs: &CochainImpl<D, P>) -> Self::Output {
624 self.apply(rhs)
625 }
626}
627
628impl<D, P, O> std::ops::Mul<&CochainImpl<D, P>>
629 for &Box<dyn Operator<Input = CochainImpl<D, P>, Output = O>>
630where
631 O: Operand,
632{
633 type Output = O;
634
635 fn mul(self, rhs: &CochainImpl<D, P>) -> Self::Output {
636 self.apply(rhs)
637 }
638}
639
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mesh::{tiny_mesh_2d, Dual, Primal};

    /// Check that `excluded` matches `full` row by row, except that every row
    /// whose index is set in `dropped` has no stored entries.
    fn assert_rows_excluded(
        full: &nas::CsrMatrix<f64>,
        excluded: &nas::CsrMatrix<f64>,
        dropped: &fb::FixedBitSet,
    ) {
        for (row_idx, (full_row, excluded_row)) in
            izip!(full.row_iter(), excluded.row_iter()).enumerate()
        {
            if dropped.contains(row_idx) {
                assert!(excluded_row.nnz() == 0);
            } else {
                assert_eq!(full_row, excluded_row);
            }
        }
    }

    #[test]
    fn exterior_derivative_works_in_2d() {
        let mesh = tiny_mesh_2d();

        // Set the i-th coefficient of c0 to i. Each entry of d_0 c0 is then the
        // difference of two such indices.
        let mut c0 = mesh.new_zero_cochain::<0, Primal>();
        c0.values
            .iter_mut()
            .enumerate()
            .for_each(|(i, val)| *val = i as f64);

        let c1 = mesh.d::<0, Primal>().apply(&c0);
        #[rustfmt::skip]
        let expected_c1 = na::DVector::from_vec(vec![
            1.-0., 2.-0., 3.-0.,
            3.-1., 4.-1.,
            3.-2., 5.-2.,
            4.-3., 5.-3., 6.-3.,
            6.-4., 6.-5.,
        ]);
        assert_eq!(c1.values, expected_c1, "d_0 gave unexpected results");

        let c2 = mesh.d() * &c1;
        assert!(
            c2.values.iter().all(|v| *v == 0.0),
            "d twice should always be zero"
        );

        assert_eq!(
            mesh.d::<0, Dual>().mat,
            mesh.d::<1, Primal>().mat.transpose(),
            "dual d_0 should be the transpose of primal d_1",
        );
        assert_eq!(
            mesh.d::<1, Dual>().mat,
            mesh.d::<0, Primal>().mat.transpose(),
            "dual d_1 should be the transpose of primal d_0",
        );
    }

    #[test]
    fn exclude_subsets() {
        let mesh = tiny_mesh_2d();

        // Matrix operator: excluded rows must be emptied, the rest untouched.
        let d0_full = mesh.d::<0, Primal>();
        let edge_boundary = mesh.boundary::<1>();
        let d0_excluded = d0_full.clone().exclude_subset(&edge_boundary);
        assert_rows_excluded(&d0_full.mat, &d0_excluded.mat, &edge_boundary.indices);

        // Diagonal operator: excluded entries must be zeroed, the rest untouched.
        let star_full = mesh.star::<2, Dual>();
        let vertex_boundary = mesh.boundary::<0>();
        let star_excluded = star_full.clone().exclude_subset(&vertex_boundary);
        for (row_idx, (full_diag, excluded_diag)) in
            izip!(star_full.diagonal.iter(), star_excluded.diagonal.iter()).enumerate()
        {
            if vertex_boundary.indices.contains(row_idx) {
                assert!(*excluded_diag == 0.0);
            } else {
                assert_eq!(full_diag, excluded_diag);
            }
        }

        // Composed operator: same row-exclusion behavior as a plain matrix operator.
        let comp_full: MatrixOperator<crate::Cochain<0, Primal>, crate::Cochain<0, Primal>> =
            mesh.star() * mesh.d() * mesh.star() * mesh.d();
        let comp_boundary = mesh.boundary::<0>();
        let comp_excluded = comp_full.clone().exclude_subset(&comp_boundary);
        assert_rows_excluded(&comp_full.mat, &comp_excluded.mat, &comp_boundary.indices);
    }
}
754}