#![allow(unused_variables)]
#![allow(unused_assignments)]
#![allow(unused_mut)]

use crate::error::{SparseError, SparseResult};
use crate::sparray::SparseArray;
use scirs2_core::numeric::{Float, NumAssign};
use scirs2_core::SparseElement;
use std::fmt::Debug;
use std::iter::Sum;
use std::marker::PhantomData;

type MatVecFn<F> = Box<dyn Fn(&[F]) -> SparseResult<Vec<F>> + Send + Sync>;

type SolverFn<F> = Box<dyn Fn(&[F]) -> SparseResult<Vec<F>> + Send + Sync>;

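/// A matrix-free linear operator: anything that can apply `A * x` (and,
/// optionally, `A^T * x`) to a dense vector without materializing `A`.
///
/// The sketch below implements the trait for a simple shift operator. It is
/// illustrative only; the import path `scirs2_sparse::linalg` is an
/// assumption and may differ from this crate's actual module layout.
///
/// ```ignore
/// use scirs2_sparse::error::SparseResult;
/// use scirs2_sparse::linalg::LinearOperator;
///
/// // Hypothetical operator that shifts every entry down by one index
/// // (no input-length validation, to keep the sketch short).
/// struct ShiftDown {
///     n: usize,
/// }
///
/// impl LinearOperator<f64> for ShiftDown {
///     fn shape(&self) -> (usize, usize) {
///         (self.n, self.n)
///     }
///
///     fn matvec(&self, x: &[f64]) -> SparseResult<Vec<f64>> {
///         let mut y = vec![0.0; self.n];
///         y[1..].copy_from_slice(&x[..self.n - 1]);
///         Ok(y)
///     }
/// }
/// ```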
pub trait LinearOperator<F: Float> {
    fn shape(&self) -> (usize, usize);

    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>>;

    fn matmat(&self, x: &[Vec<F>]) -> SparseResult<Vec<Vec<F>>> {
        let mut result = Vec::new();
        for col in x {
            result.push(self.matvec(col)?);
        }
        Ok(result)
    }

    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        Err(crate::error::SparseError::OperationNotSupported(
            "adjoint not implemented for this operator".to_string(),
        ))
    }

    fn rmatmat(&self, x: &[Vec<F>]) -> SparseResult<Vec<Vec<F>>> {
        let mut result = Vec::new();
        for col in x {
            result.push(self.rmatvec(col)?);
        }
        Ok(result)
    }

    fn has_adjoint(&self) -> bool {
        false
    }
}

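/// The identity operator `I`: `matvec` returns its input unchanged.
///
/// A minimal usage sketch; the `scirs2_sparse::linalg` import path is an
/// assumption, not confirmed by this file.
///
/// ```ignore
/// use scirs2_sparse::linalg::{IdentityOperator, LinearOperator};
///
/// let op = IdentityOperator::<f64>::new(3);
/// assert_eq!(op.matvec(&[1.0, 2.0, 3.0]).unwrap(), vec![1.0, 2.0, 3.0]);
/// ```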
#[derive(Clone)]
pub struct IdentityOperator<F> {
    size: usize,
    phantom: PhantomData<F>,
}

impl<F> IdentityOperator<F> {
    pub fn new(size: usize) -> Self {
        Self {
            size,
            phantom: PhantomData,
        }
    }
}

impl<F: Float> LinearOperator<F> for IdentityOperator<F> {
    fn shape(&self) -> (usize, usize) {
        (self.size, self.size)
    }

    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        if x.len() != self.size {
            return Err(crate::error::SparseError::DimensionMismatch {
                expected: self.size,
                found: x.len(),
            });
        }
        Ok(x.to_vec())
    }

    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        self.matvec(x)
    }

    fn has_adjoint(&self) -> bool {
        true
    }
}

#[derive(Clone)]
pub struct ScaledIdentityOperator<F> {
    size: usize,
    scale: F,
}

impl<F: Float> ScaledIdentityOperator<F> {
    pub fn new(size: usize, scale: F) -> Self {
        Self { size, scale }
    }
}

impl<F: Float + NumAssign> LinearOperator<F> for ScaledIdentityOperator<F> {
    fn shape(&self) -> (usize, usize) {
        (self.size, self.size)
    }

    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        if x.len() != self.size {
            return Err(crate::error::SparseError::DimensionMismatch {
                expected: self.size,
                found: x.len(),
            });
        }
        Ok(x.iter().map(|&xi| xi * self.scale).collect())
    }

    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        self.matvec(x)
    }

    fn has_adjoint(&self) -> bool {
        true
    }
}

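/// A diagonal operator `D = diag(d)`: `matvec` multiplies element-wise,
/// `y[i] = d[i] * x[i]`, and the operator is its own transpose.
///
/// Sketch (module path assumed):
///
/// ```ignore
/// use scirs2_sparse::linalg::{DiagonalOperator, LinearOperator};
///
/// let d = DiagonalOperator::new(vec![2.0, 3.0, 4.0]);
/// assert_eq!(d.matvec(&[1.0, 1.0, 1.0]).unwrap(), vec![2.0, 3.0, 4.0]);
/// assert_eq!(d.rmatvec(&[1.0, 1.0, 1.0]).unwrap(), vec![2.0, 3.0, 4.0]);
/// ```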
#[derive(Clone)]
pub struct DiagonalOperator<F> {
    diagonal: Vec<F>,
}

impl<F: Float> DiagonalOperator<F> {
    pub fn new(diagonal: Vec<F>) -> Self {
        Self { diagonal }
    }

    pub fn diagonal(&self) -> &[F] {
        &self.diagonal
    }
}

impl<F: Float + NumAssign> LinearOperator<F> for DiagonalOperator<F> {
    fn shape(&self) -> (usize, usize) {
        let n = self.diagonal.len();
        (n, n)
    }

    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        if x.len() != self.diagonal.len() {
            return Err(crate::error::SparseError::DimensionMismatch {
                expected: self.diagonal.len(),
                found: x.len(),
            });
        }
        Ok(x.iter()
            .zip(&self.diagonal)
            .map(|(&xi, &di)| xi * di)
            .collect())
    }

    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        self.matvec(x)
    }

    fn has_adjoint(&self) -> bool {
        true
    }
}

#[derive(Clone)]
pub struct ZeroOperator<F> {
    shape: (usize, usize),
    _phantom: PhantomData<F>,
}

impl<F> ZeroOperator<F> {
    #[allow(dead_code)]
    pub fn new(rows: usize, cols: usize) -> Self {
        Self {
            shape: (rows, cols),
            _phantom: PhantomData,
        }
    }
}

impl<F: Float + SparseElement> LinearOperator<F> for ZeroOperator<F> {
    fn shape(&self) -> (usize, usize) {
        self.shape
    }

    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        if x.len() != self.shape.1 {
            return Err(crate::error::SparseError::DimensionMismatch {
                expected: self.shape.1,
                found: x.len(),
            });
        }
        Ok(vec![F::sparse_zero(); self.shape.0])
    }

    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        if x.len() != self.shape.0 {
            return Err(crate::error::SparseError::DimensionMismatch {
                expected: self.shape.0,
                found: x.len(),
            });
        }
        Ok(vec![F::sparse_zero(); self.shape.1])
    }

    fn has_adjoint(&self) -> bool {
        true
    }
}

pub trait AsLinearOperator<F: Float> {
    fn as_linear_operator(&self) -> Box<dyn LinearOperator<F>>;
}

pub struct MatrixLinearOperator<F, M> {
    matrix: M,
    phantom: PhantomData<F>,
}

impl<F, M> MatrixLinearOperator<F, M> {
    pub fn new(matrix: M) -> Self {
        Self {
            matrix,
            phantom: PhantomData,
        }
    }
}

use crate::csr::CsrMatrix;

impl<F: Float + SparseElement + NumAssign + Sum + 'static + Debug> LinearOperator<F>
    for MatrixLinearOperator<F, CsrMatrix<F>>
{
    fn shape(&self) -> (usize, usize) {
        (self.matrix.rows(), self.matrix.cols())
    }

    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        if x.len() != self.matrix.cols() {
            return Err(SparseError::DimensionMismatch {
                expected: self.matrix.cols(),
                found: x.len(),
            });
        }

        let mut result = vec![F::sparse_zero(); self.matrix.rows()];
        for (row, result_elem) in result.iter_mut().enumerate().take(self.matrix.rows()) {
            let row_range = self.matrix.row_range(row);
            let row_indices = &self.matrix.colindices()[row_range.clone()];
            let row_data = &self.matrix.data[row_range];

            let mut sum = F::sparse_zero();
            for (col_idx, &col) in row_indices.iter().enumerate() {
                sum += row_data[col_idx] * x[col];
            }
            *result_elem = sum;
        }
        Ok(result)
    }

    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        let transposed = self.matrix.transpose();
        MatrixLinearOperator::new(transposed).matvec(x)
    }

    fn has_adjoint(&self) -> bool {
        true
    }
}

use crate::csr_array::CsrArray;

impl<F: Float + SparseElement + NumAssign + Sum + 'static + Debug> LinearOperator<F>
    for MatrixLinearOperator<F, CsrArray<F>>
{
    fn shape(&self) -> (usize, usize) {
        self.matrix.shape()
    }

    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        if x.len() != self.matrix.shape().1 {
            return Err(SparseError::DimensionMismatch {
                expected: self.matrix.shape().1,
                found: x.len(),
            });
        }

        use scirs2_core::ndarray::Array1;
        let x_array = Array1::from_vec(x.to_vec());
        let result = self.matrix.dot_vector(&x_array.view())?;
        Ok(result.to_vec())
    }

    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        if x.len() != self.matrix.shape().0 {
            return Err(SparseError::DimensionMismatch {
                expected: self.matrix.shape().0,
                found: x.len(),
            });
        }

        let mut result = vec![F::sparse_zero(); self.matrix.shape().1];

        for (row_idx, &x_val) in x.iter().enumerate() {
            if x_val != F::sparse_zero() {
                let row_start = self.matrix.get_indptr()[row_idx];
                let row_end = self.matrix.get_indptr()[row_idx + 1];

                for idx in row_start..row_end {
                    let col_idx = self.matrix.get_indices()[idx];
                    let data_val = self.matrix.get_data()[idx];
                    result[col_idx] += data_val * x_val;
                }
            }
        }

        Ok(result)
    }

    fn has_adjoint(&self) -> bool {
        true
    }
}

impl<F: Float + SparseElement + NumAssign + Sum + 'static + Debug> AsLinearOperator<F>
    for CsrMatrix<F>
{
    fn as_linear_operator(&self) -> Box<dyn LinearOperator<F>> {
        Box::new(MatrixLinearOperator::new(self.clone()))
    }
}

impl<F: Float + SparseElement + NumAssign + Sum + 'static + Debug> AsLinearOperator<F>
    for crate::csr_array::CsrArray<F>
{
    fn as_linear_operator(&self) -> Box<dyn LinearOperator<F>> {
        Box::new(MatrixLinearOperator::new(self.clone()))
    }
}

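/// The sum of two operators, `(A + B) x = A x + B x`. Both operands must have
/// the same shape; the adjoint exists only if both operands expose one.
///
/// Sketch (module path assumed):
///
/// ```ignore
/// use scirs2_sparse::linalg::{
///     IdentityOperator, LinearOperator, ScaledIdentityOperator, SumOperator,
/// };
///
/// let a = Box::new(IdentityOperator::<f64>::new(2));
/// let b = Box::new(ScaledIdentityOperator::new(2, 2.0));
/// let sum = SumOperator::new(a, b).unwrap(); // acts like 3 * I
/// assert_eq!(sum.matvec(&[1.0, 2.0]).unwrap(), vec![3.0, 6.0]);
/// ```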
pub struct SumOperator<F> {
    a: Box<dyn LinearOperator<F>>,
    b: Box<dyn LinearOperator<F>>,
}

impl<F: Float + NumAssign> SumOperator<F> {
    #[allow(dead_code)]
    pub fn new(a: Box<dyn LinearOperator<F>>, b: Box<dyn LinearOperator<F>>) -> SparseResult<Self> {
        if a.shape() != b.shape() {
            return Err(crate::error::SparseError::ShapeMismatch {
                expected: a.shape(),
                found: b.shape(),
            });
        }
        Ok(Self { a, b })
    }
}

impl<F: Float + NumAssign> LinearOperator<F> for SumOperator<F> {
    fn shape(&self) -> (usize, usize) {
        self.a.shape()
    }

    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        let a_result = self.a.matvec(x)?;
        let b_result = self.b.matvec(x)?;
        Ok(a_result
            .iter()
            .zip(&b_result)
            .map(|(&a, &b)| a + b)
            .collect())
    }

    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        if !self.a.has_adjoint() || !self.b.has_adjoint() {
            return Err(crate::error::SparseError::OperationNotSupported(
                "adjoint not supported for one or both operators".to_string(),
            ));
        }
        let a_result = self.a.rmatvec(x)?;
        let b_result = self.b.rmatvec(x)?;
        Ok(a_result
            .iter()
            .zip(&b_result)
            .map(|(&a, &b)| a + b)
            .collect())
    }

    fn has_adjoint(&self) -> bool {
        self.a.has_adjoint() && self.b.has_adjoint()
    }
}

pub struct ProductOperator<F> {
    a: Box<dyn LinearOperator<F>>,
    b: Box<dyn LinearOperator<F>>,
}

impl<F: Float + NumAssign> ProductOperator<F> {
    #[allow(dead_code)]
    pub fn new(a: Box<dyn LinearOperator<F>>, b: Box<dyn LinearOperator<F>>) -> SparseResult<Self> {
        let (_a_rows, a_cols) = a.shape();
        let (b_rows, b_cols) = b.shape();
        if a_cols != b_rows {
            return Err(crate::error::SparseError::DimensionMismatch {
                expected: a_cols,
                found: b_rows,
            });
        }
        Ok(Self { a, b })
    }
}

impl<F: Float + NumAssign> LinearOperator<F> for ProductOperator<F> {
    fn shape(&self) -> (usize, usize) {
        let (a_rows, _) = self.a.shape();
        let (_, b_cols) = self.b.shape();
        (a_rows, b_cols)
    }

    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        let b_result = self.b.matvec(x)?;
        self.a.matvec(&b_result)
    }

    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        if !self.a.has_adjoint() || !self.b.has_adjoint() {
            return Err(crate::error::SparseError::OperationNotSupported(
                "adjoint not supported for one or both operators".to_string(),
            ));
        }
        let a_result = self.a.rmatvec(x)?;
        self.b.rmatvec(&a_result)
    }

    fn has_adjoint(&self) -> bool {
        self.a.has_adjoint() && self.b.has_adjoint()
    }
}

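/// An operator defined by closures: one for `matvec` and, optionally, one for
/// `rmatvec`. Useful when the action of `A` is cheap to compute but `A`
/// itself is never formed.
///
/// Sketch wrapping a closure that doubles its input (module path assumed):
///
/// ```ignore
/// use scirs2_sparse::linalg::{FunctionOperator, LinearOperator};
///
/// let double = FunctionOperator::from_function((3, 3), |x: &[f64]| {
///     Ok(x.iter().map(|&v| 2.0 * v).collect())
/// });
/// assert_eq!(double.matvec(&[1.0, 2.0, 3.0]).unwrap(), vec![2.0, 4.0, 6.0]);
/// assert!(!double.has_adjoint()); // no rmatvec closure was supplied
/// ```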
pub struct FunctionOperator<F> {
    shape: (usize, usize),
    matvec_fn: MatVecFn<F>,
    rmatvec_fn: Option<MatVecFn<F>>,
}

impl<F: Float + 'static> FunctionOperator<F> {
    #[allow(dead_code)]
    pub fn new<MV, RMV>(shape: (usize, usize), matvec_fn: MV, rmatvec_fn: Option<RMV>) -> Self
    where
        MV: Fn(&[F]) -> SparseResult<Vec<F>> + Send + Sync + 'static,
        RMV: Fn(&[F]) -> SparseResult<Vec<F>> + Send + Sync + 'static,
    {
        Self {
            shape,
            matvec_fn: Box::new(matvec_fn),
            rmatvec_fn: rmatvec_fn
                .map(|f| Box::new(f) as Box<dyn Fn(&[F]) -> SparseResult<Vec<F>> + Send + Sync>),
        }
    }

    #[allow(dead_code)]
    pub fn from_function<FMv>(shape: (usize, usize), matvec_fn: FMv) -> Self
    where
        FMv: Fn(&[F]) -> SparseResult<Vec<F>> + Send + Sync + 'static,
    {
        Self::new(shape, matvec_fn, None::<fn(&[F]) -> SparseResult<Vec<F>>>)
    }
}

impl<F: Float> LinearOperator<F> for FunctionOperator<F> {
    fn shape(&self) -> (usize, usize) {
        self.shape
    }

    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        (self.matvec_fn)(x)
    }

    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        match &self.rmatvec_fn {
            Some(f) => f(x),
            None => Err(SparseError::OperationNotSupported(
                "adjoint not implemented for this function operator".to_string(),
            )),
        }
    }

    fn has_adjoint(&self) -> bool {
        self.rmatvec_fn.is_some()
    }
}

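/// The inverse of a square operator, represented implicitly by a solver
/// closure: `matvec(x)` returns the solution of `A y = x` rather than
/// multiplying by an explicit inverse.
///
/// Sketch inverting a diagonal operator by dividing element-wise (module path
/// assumed):
///
/// ```ignore
/// use scirs2_sparse::linalg::{DiagonalOperator, InverseOperator, LinearOperator};
///
/// let d = vec![2.0_f64, 4.0, 8.0];
/// let a = Box::new(DiagonalOperator::new(d.clone()));
/// let inv = InverseOperator::new(a, move |x: &[f64]| {
///     Ok(x.iter().zip(&d).map(|(&xi, &di)| xi / di).collect())
/// })
/// .unwrap();
/// assert_eq!(inv.matvec(&[2.0, 4.0, 8.0]).unwrap(), vec![1.0, 1.0, 1.0]);
/// ```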
pub struct InverseOperator<F> {
    original: Box<dyn LinearOperator<F>>,
    solver_fn: SolverFn<F>,
}

impl<F: Float> InverseOperator<F> {
    #[allow(dead_code)]
    pub fn new<S>(original: Box<dyn LinearOperator<F>>, solver_fn: S) -> SparseResult<Self>
    where
        S: Fn(&[F]) -> SparseResult<Vec<F>> + Send + Sync + 'static,
    {
        let (rows, cols) = original.shape();
        if rows != cols {
            return Err(SparseError::ValueError(
                "Cannot invert non-square operator".to_string(),
            ));
        }

        Ok(Self {
            original,
            solver_fn: Box::new(solver_fn),
        })
    }
}

impl<F: Float> LinearOperator<F> for InverseOperator<F> {
    fn shape(&self) -> (usize, usize) {
        self.original.shape()
    }

    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        (self.solver_fn)(x)
    }

    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        if !self.original.has_adjoint() {
            return Err(SparseError::OperationNotSupported(
                "adjoint not supported for original operator".to_string(),
            ));
        }

        Err(SparseError::OperationNotSupported(
            "adjoint of inverse operator not yet implemented".to_string(),
        ))
    }

    fn has_adjoint(&self) -> bool {
        false
    }
}

pub struct TransposeOperator<F> {
    original: Box<dyn LinearOperator<F>>,
}

impl<F: Float + NumAssign> TransposeOperator<F> {
    pub fn new(original: Box<dyn LinearOperator<F>>) -> Self {
        Self { original }
    }
}

impl<F: Float + NumAssign> LinearOperator<F> for TransposeOperator<F> {
    fn shape(&self) -> (usize, usize) {
        let (rows, cols) = self.original.shape();
        (cols, rows)
    }

    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        self.original.rmatvec(x)
    }

    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        self.original.matvec(x)
    }

    fn has_adjoint(&self) -> bool {
        true
    }
}

pub struct AdjointOperator<F> {
    original: Box<dyn LinearOperator<F>>,
}

impl<F: Float + NumAssign> AdjointOperator<F> {
    pub fn new(original: Box<dyn LinearOperator<F>>) -> SparseResult<Self> {
        if !original.has_adjoint() {
            return Err(SparseError::OperationNotSupported(
                "Original operator does not support adjoint operations".to_string(),
            ));
        }
        Ok(Self { original })
    }
}

impl<F: Float + NumAssign> LinearOperator<F> for AdjointOperator<F> {
    fn shape(&self) -> (usize, usize) {
        let (rows, cols) = self.original.shape();
        (cols, rows)
    }

    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        self.original.rmatvec(x)
    }

    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        self.original.matvec(x)
    }

    fn has_adjoint(&self) -> bool {
        true
    }
}

pub struct DifferenceOperator<F> {
    a: Box<dyn LinearOperator<F>>,
    b: Box<dyn LinearOperator<F>>,
}

impl<F: Float + NumAssign> DifferenceOperator<F> {
    pub fn new(a: Box<dyn LinearOperator<F>>, b: Box<dyn LinearOperator<F>>) -> SparseResult<Self> {
        if a.shape() != b.shape() {
            return Err(SparseError::ShapeMismatch {
                expected: a.shape(),
                found: b.shape(),
            });
        }
        Ok(Self { a, b })
    }
}

impl<F: Float + NumAssign> LinearOperator<F> for DifferenceOperator<F> {
    fn shape(&self) -> (usize, usize) {
        self.a.shape()
    }

    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        let a_result = self.a.matvec(x)?;
        let b_result = self.b.matvec(x)?;
        Ok(a_result
            .iter()
            .zip(&b_result)
            .map(|(&a, &b)| a - b)
            .collect())
    }

    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        if !self.a.has_adjoint() || !self.b.has_adjoint() {
            return Err(SparseError::OperationNotSupported(
                "adjoint not supported for one or both operators".to_string(),
            ));
        }
        let a_result = self.a.rmatvec(x)?;
        let b_result = self.b.rmatvec(x)?;
        Ok(a_result
            .iter()
            .zip(&b_result)
            .map(|(&a, &b)| a - b)
            .collect())
    }

    fn has_adjoint(&self) -> bool {
        self.a.has_adjoint() && self.b.has_adjoint()
    }
}

pub struct ScaledOperator<F> {
    alpha: F,
    operator: Box<dyn LinearOperator<F>>,
}

impl<F: Float + NumAssign> ScaledOperator<F> {
    pub fn new(alpha: F, operator: Box<dyn LinearOperator<F>>) -> Self {
        Self { alpha, operator }
    }
}

impl<F: Float + NumAssign> LinearOperator<F> for ScaledOperator<F> {
    fn shape(&self) -> (usize, usize) {
        self.operator.shape()
    }

    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        let result = self.operator.matvec(x)?;
        Ok(result.iter().map(|&val| self.alpha * val).collect())
    }

    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        if !self.operator.has_adjoint() {
            return Err(SparseError::OperationNotSupported(
                "adjoint not supported for underlying operator".to_string(),
            ));
        }
        let result = self.operator.rmatvec(x)?;
        Ok(result.iter().map(|&val| self.alpha * val).collect())
    }

    fn has_adjoint(&self) -> bool {
        self.operator.has_adjoint()
    }
}

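/// A composition of operators applied right to left: for `[A, B, C]`,
/// `matvec` computes `A * (B * (C * x))`, so the chain behaves like the
/// product `A B C`.
///
/// Sketch (module path assumed):
///
/// ```ignore
/// use scirs2_sparse::linalg::{
///     ChainOperator, DiagonalOperator, LinearOperator, ScaledIdentityOperator,
/// };
///
/// let ops: Vec<Box<dyn LinearOperator<f64>>> = vec![
///     Box::new(ScaledIdentityOperator::new(2, 3.0)),
///     Box::new(DiagonalOperator::new(vec![1.0, 2.0])),
/// ];
/// let chain = ChainOperator::new(ops).unwrap(); // acts like 3 * diag(1, 2)
/// assert_eq!(chain.matvec(&[1.0, 1.0]).unwrap(), vec![3.0, 6.0]);
/// ```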
pub struct ChainOperator<F> {
    operators: Vec<Box<dyn LinearOperator<F>>>,
    totalshape: (usize, usize),
}

impl<F: Float + NumAssign> ChainOperator<F> {
    #[allow(dead_code)]
    pub fn new(operators: Vec<Box<dyn LinearOperator<F>>>) -> SparseResult<Self> {
        if operators.is_empty() {
            return Err(SparseError::ValueError(
                "Cannot create chain with no operators".to_string(),
            ));
        }

        #[allow(clippy::needless_range_loop)]
        for i in 0..operators.len() - 1 {
            let (_, a_cols) = operators[i].shape();
            let (b_rows, _) = operators[i + 1].shape();
            if a_cols != b_rows {
                return Err(SparseError::DimensionMismatch {
                    expected: a_cols,
                    found: b_rows,
                });
            }
        }

        let (first_rows, _) = operators[0].shape();
        let (_, last_cols) = operators.last().unwrap().shape();
        let totalshape = (first_rows, last_cols);

        Ok(Self {
            operators,
            totalshape,
        })
    }
}

impl<F: Float + NumAssign> LinearOperator<F> for ChainOperator<F> {
    fn shape(&self) -> (usize, usize) {
        self.totalshape
    }

    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        let mut result = x.to_vec();
        for op in self.operators.iter().rev() {
            result = op.matvec(&result)?;
        }
        Ok(result)
    }

    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        for op in &self.operators {
            if !op.has_adjoint() {
                return Err(SparseError::OperationNotSupported(
                    "adjoint not supported for all operators in chain".to_string(),
                ));
            }
        }

        let mut result = x.to_vec();
        for op in &self.operators {
            result = op.rmatvec(&result)?;
        }
        Ok(result)
    }

    fn has_adjoint(&self) -> bool {
        self.operators.iter().all(|op| op.has_adjoint())
    }
}

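/// Repeated application of a square operator: `matvec` applies `A` to the
/// input `power` times, i.e. it computes `A^n x` without forming `A^n`.
///
/// Sketch (module path assumed):
///
/// ```ignore
/// use scirs2_sparse::linalg::{LinearOperator, PowerOperator, ScaledIdentityOperator};
///
/// let a = Box::new(ScaledIdentityOperator::new(2, 2.0));
/// let a_cubed = PowerOperator::new(a, 3).unwrap(); // (2I)^3 = 8I
/// assert_eq!(a_cubed.matvec(&[1.0, 1.0]).unwrap(), vec![8.0, 8.0]);
/// ```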
pub struct PowerOperator<F> {
    operator: Box<dyn LinearOperator<F>>,
    power: usize,
}

impl<F: Float + NumAssign> PowerOperator<F> {
    pub fn new(operator: Box<dyn LinearOperator<F>>, power: usize) -> SparseResult<Self> {
        let (rows, cols) = operator.shape();
        if rows != cols {
            return Err(SparseError::ValueError(
                "Can only compute powers of square operators".to_string(),
            ));
        }
        if power == 0 {
            return Err(SparseError::ValueError(
                "Power must be positive".to_string(),
            ));
        }
        Ok(Self { operator, power })
    }
}

impl<F: Float + NumAssign> LinearOperator<F> for PowerOperator<F> {
    fn shape(&self) -> (usize, usize) {
        self.operator.shape()
    }

    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        let mut result = x.to_vec();
        for _ in 0..self.power {
            result = self.operator.matvec(&result)?;
        }
        Ok(result)
    }

    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        if !self.operator.has_adjoint() {
            return Err(SparseError::OperationNotSupported(
                "adjoint not supported for underlying operator".to_string(),
            ));
        }
        let mut result = x.to_vec();
        for _ in 0..self.power {
            result = self.operator.rmatvec(&result)?;
        }
        Ok(result)
    }

    fn has_adjoint(&self) -> bool {
        self.operator.has_adjoint()
    }
}

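/// Builder-style combinators for composing operators (`add`, `sub`, `mul`,
/// `scale`, `transpose`, `adjoint`, `pow`). Each method boxes the receiver
/// and wraps it in the corresponding composite operator.
///
/// Sketch (module path assumed):
///
/// ```ignore
/// use scirs2_sparse::linalg::{DiagonalOperator, LinearOperator, LinearOperatorExt};
///
/// let d = DiagonalOperator::new(vec![1.0, 2.0, 3.0]);
/// // 2 * D, built with the extension trait.
/// let two_d = d.scale(2.0);
/// assert_eq!(two_d.matvec(&[1.0, 1.0, 1.0]).unwrap(), vec![2.0, 4.0, 6.0]);
/// ```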
#[allow(dead_code)]
pub trait LinearOperatorExt<F: Float + NumAssign>: LinearOperator<F> {
    fn add(&self, other: Box<dyn LinearOperator<F>>) -> SparseResult<Box<dyn LinearOperator<F>>>;

    fn sub(&self, other: Box<dyn LinearOperator<F>>) -> SparseResult<Box<dyn LinearOperator<F>>>;

    fn mul(&self, other: Box<dyn LinearOperator<F>>) -> SparseResult<Box<dyn LinearOperator<F>>>;

    fn scale(&self, alpha: F) -> Box<dyn LinearOperator<F>>;

    fn transpose(&self) -> Box<dyn LinearOperator<F>>;

    fn adjoint(&self) -> SparseResult<Box<dyn LinearOperator<F>>>;

    fn pow(&self, n: usize) -> SparseResult<Box<dyn LinearOperator<F>>>;
}

macro_rules! impl_linear_operator_ext {
    ($typ:ty) => {
        impl<F: Float + NumAssign + Copy + 'static> LinearOperatorExt<F> for $typ {
            fn add(
                &self,
                other: Box<dyn LinearOperator<F>>,
            ) -> SparseResult<Box<dyn LinearOperator<F>>> {
                let self_box = Box::new(self.clone());
                Ok(Box::new(SumOperator::new(self_box, other)?))
            }

            fn sub(
                &self,
                other: Box<dyn LinearOperator<F>>,
            ) -> SparseResult<Box<dyn LinearOperator<F>>> {
                let self_box = Box::new(self.clone());
                Ok(Box::new(DifferenceOperator::new(self_box, other)?))
            }

            fn mul(
                &self,
                other: Box<dyn LinearOperator<F>>,
            ) -> SparseResult<Box<dyn LinearOperator<F>>> {
                let self_box = Box::new(self.clone());
                Ok(Box::new(ProductOperator::new(self_box, other)?))
            }

            fn scale(&self, alpha: F) -> Box<dyn LinearOperator<F>> {
                let self_box = Box::new(self.clone());
                Box::new(ScaledOperator::new(alpha, self_box))
            }

            fn transpose(&self) -> Box<dyn LinearOperator<F>> {
                let self_box = Box::new(self.clone());
                Box::new(TransposeOperator::new(self_box))
            }

            fn adjoint(&self) -> SparseResult<Box<dyn LinearOperator<F>>> {
                let self_box = Box::new(self.clone());
                Ok(Box::new(AdjointOperator::new(self_box)?))
            }

            fn pow(&self, n: usize) -> SparseResult<Box<dyn LinearOperator<F>>> {
                let self_box = Box::new(self.clone());
                Ok(Box::new(PowerOperator::new(self_box, n)?))
            }
        }
    };
}

impl_linear_operator_ext!(IdentityOperator<F>);
impl_linear_operator_ext!(ScaledIdentityOperator<F>);
impl_linear_operator_ext!(DiagonalOperator<F>);

impl<F: Float + NumAssign + Copy + SparseElement + 'static> LinearOperatorExt<F>
    for ZeroOperator<F>
{
    fn add(&self, other: Box<dyn LinearOperator<F>>) -> SparseResult<Box<dyn LinearOperator<F>>> {
        let self_box = Box::new(self.clone());
        Ok(Box::new(SumOperator::new(self_box, other)?))
    }

    fn sub(&self, other: Box<dyn LinearOperator<F>>) -> SparseResult<Box<dyn LinearOperator<F>>> {
        let self_box = Box::new(self.clone());
        Ok(Box::new(DifferenceOperator::new(self_box, other)?))
    }

    fn mul(&self, other: Box<dyn LinearOperator<F>>) -> SparseResult<Box<dyn LinearOperator<F>>> {
        let self_box = Box::new(self.clone());
        Ok(Box::new(ProductOperator::new(self_box, other)?))
    }

    fn scale(&self, alpha: F) -> Box<dyn LinearOperator<F>> {
        let self_box = Box::new(self.clone());
        Box::new(ScaledOperator::new(alpha, self_box))
    }

    fn transpose(&self) -> Box<dyn LinearOperator<F>> {
        let self_box = Box::new(self.clone());
        Box::new(TransposeOperator::new(self_box))
    }

    fn adjoint(&self) -> SparseResult<Box<dyn LinearOperator<F>>> {
        let self_box = Box::new(self.clone());
        Ok(Box::new(AdjointOperator::new(self_box)?))
    }

    fn pow(&self, n: usize) -> SparseResult<Box<dyn LinearOperator<F>>> {
        let self_box = Box::new(self.clone());
        Ok(Box::new(PowerOperator::new(self_box, n)?))
    }
}

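/// Returns the sum `A + B` of two boxed operators. Together with the other
/// free helpers below (`subtract_operators`, `multiply_operators`,
/// `scale_operator`, `transpose_operator`, `adjoint_operator`,
/// `compose_operators`, `power_operator`), this mirrors `LinearOperatorExt`
/// for `dyn LinearOperator` trait objects.
///
/// Sketch (module path assumed):
///
/// ```ignore
/// use scirs2_sparse::linalg::{
///     add_operators, scale_operator, IdentityOperator, LinearOperator,
/// };
///
/// let i: Box<dyn LinearOperator<f64>> = Box::new(IdentityOperator::<f64>::new(3));
/// let two_i = scale_operator(2.0, Box::new(IdentityOperator::<f64>::new(3)));
/// let three_i = add_operators(i, two_i).unwrap();
/// assert_eq!(three_i.matvec(&[1.0, 1.0, 1.0]).unwrap(), vec![3.0, 3.0, 3.0]);
/// ```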
#[allow(dead_code)]
pub fn add_operators<F: Float + NumAssign + 'static>(
    left: Box<dyn LinearOperator<F>>,
    right: Box<dyn LinearOperator<F>>,
) -> SparseResult<Box<dyn LinearOperator<F>>> {
    Ok(Box::new(SumOperator::new(left, right)?))
}

#[allow(dead_code)]
pub fn subtract_operators<F: Float + NumAssign + 'static>(
    left: Box<dyn LinearOperator<F>>,
    right: Box<dyn LinearOperator<F>>,
) -> SparseResult<Box<dyn LinearOperator<F>>> {
    Ok(Box::new(DifferenceOperator::new(left, right)?))
}

#[allow(dead_code)]
pub fn multiply_operators<F: Float + NumAssign + 'static>(
    left: Box<dyn LinearOperator<F>>,
    right: Box<dyn LinearOperator<F>>,
) -> SparseResult<Box<dyn LinearOperator<F>>> {
    Ok(Box::new(ProductOperator::new(left, right)?))
}

#[allow(dead_code)]
pub fn scale_operator<F: Float + NumAssign + 'static>(
    alpha: F,
    operator: Box<dyn LinearOperator<F>>,
) -> Box<dyn LinearOperator<F>> {
    Box::new(ScaledOperator::new(alpha, operator))
}

#[allow(dead_code)]
pub fn transpose_operator<F: Float + NumAssign + 'static>(
    operator: Box<dyn LinearOperator<F>>,
) -> Box<dyn LinearOperator<F>> {
    Box::new(TransposeOperator::new(operator))
}

#[allow(dead_code)]
pub fn adjoint_operator<F: Float + NumAssign + 'static>(
    operator: Box<dyn LinearOperator<F>>,
) -> SparseResult<Box<dyn LinearOperator<F>>> {
    Ok(Box::new(AdjointOperator::new(operator)?))
}

#[allow(dead_code)]
pub fn compose_operators<F: Float + NumAssign + 'static>(
    operators: Vec<Box<dyn LinearOperator<F>>>,
) -> SparseResult<Box<dyn LinearOperator<F>>> {
    Ok(Box::new(ChainOperator::new(operators)?))
}

#[allow(dead_code)]
pub fn power_operator<F: Float + NumAssign + 'static>(
    operator: Box<dyn LinearOperator<F>>,
    n: usize,
) -> SparseResult<Box<dyn LinearOperator<F>>> {
    Ok(Box::new(PowerOperator::new(operator, n)?))
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_identity_operator() {
        let op = IdentityOperator::<f64>::new(3);
        let x = vec![1.0, 2.0, 3.0];
        let y = op.matvec(&x).unwrap();
        assert_eq!(x, y);
    }

    #[test]
    fn test_scaled_identity_operator() {
        let op = ScaledIdentityOperator::new(3, 2.0);
        let x = vec![1.0, 2.0, 3.0];
        let y = op.matvec(&x).unwrap();
        assert_eq!(y, vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_diagonal_operator() {
        let op = DiagonalOperator::new(vec![2.0, 3.0, 4.0]);
        let x = vec![1.0, 2.0, 3.0];
        let y = op.matvec(&x).unwrap();
        assert_eq!(y, vec![2.0, 6.0, 12.0]);
    }

    #[test]
    fn test_zero_operator() {
        let op = ZeroOperator::<f64>::new(3, 3);
        let x = vec![1.0, 2.0, 3.0];
        let y = op.matvec(&x).unwrap();
        assert_eq!(y, vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_sum_operator() {
        let id = Box::new(IdentityOperator::<f64>::new(3));
        let scaled = Box::new(ScaledIdentityOperator::new(3, 2.0));
        let sum = SumOperator::new(id, scaled).unwrap();
        let x = vec![1.0, 2.0, 3.0];
        let y = sum.matvec(&x).unwrap();
        assert_eq!(y, vec![3.0, 6.0, 9.0]);
    }

    #[test]
    fn test_product_operator() {
        let id = Box::new(IdentityOperator::<f64>::new(3));
        let scaled = Box::new(ScaledIdentityOperator::new(3, 2.0));
        let product = ProductOperator::new(scaled, id).unwrap();
        let x = vec![1.0, 2.0, 3.0];
        let y = product.matvec(&x).unwrap();
        assert_eq!(y, vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_difference_operator() {
        let scaled_3 = Box::new(ScaledIdentityOperator::new(3, 3.0));
        let scaled_2 = Box::new(ScaledIdentityOperator::new(3, 2.0));
        let diff = DifferenceOperator::new(scaled_3, scaled_2).unwrap();
        let x = vec![1.0, 2.0, 3.0];
        let y = diff.matvec(&x).unwrap();
        assert_eq!(y, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_scaled_operator() {
        let id = Box::new(IdentityOperator::<f64>::new(3));
        let scaled = ScaledOperator::new(5.0, id);
        let x = vec![1.0, 2.0, 3.0];
        let y = scaled.matvec(&x).unwrap();
        assert_eq!(y, vec![5.0, 10.0, 15.0]);
    }

    #[test]
    fn test_transpose_operator() {
        let diag = Box::new(DiagonalOperator::new(vec![2.0, 3.0, 4.0]));
        let transpose = TransposeOperator::new(diag);
        let x = vec![1.0, 2.0, 3.0];
        let y = transpose.matvec(&x).unwrap();
        assert_eq!(y, vec![2.0, 6.0, 12.0]);
    }

    #[test]
    fn test_adjoint_operator() {
        let diag = Box::new(DiagonalOperator::new(vec![2.0, 3.0, 4.0]));
        let adjoint = AdjointOperator::new(diag).unwrap();
        let x = vec![1.0, 2.0, 3.0];
        let y = adjoint.matvec(&x).unwrap();
        assert_eq!(y, vec![2.0, 6.0, 12.0]);
    }

    #[test]
    fn test_chain_operator() {
        let op1 = Box::new(ScaledIdentityOperator::new(3, 2.0));
        let op2 = Box::new(ScaledIdentityOperator::new(3, 3.0));
        let chain = ChainOperator::new(vec![op1, op2]).unwrap();
        let x = vec![1.0, 2.0, 3.0];
        let y = chain.matvec(&x).unwrap();
        assert_eq!(y, vec![6.0, 12.0, 18.0]);
    }

    #[test]
    fn test_power_operator() {
        let scaled = Box::new(ScaledIdentityOperator::new(3, 2.0));
        let power = PowerOperator::new(scaled, 3).unwrap();
        let x = vec![1.0, 2.0, 3.0];
        let y = power.matvec(&x).unwrap();
        assert_eq!(y, vec![8.0, 16.0, 24.0]);
    }

    #[test]
    fn test_composition_utility_functions() {
        let id = Box::new(IdentityOperator::<f64>::new(3));
        let scaled = Box::new(ScaledIdentityOperator::new(3, 2.0));

        let sum = add_operators(id.clone(), scaled.clone()).unwrap();
        let x = vec![1.0, 2.0, 3.0];
        let y = sum.matvec(&x).unwrap();
        assert_eq!(y, vec![3.0, 6.0, 9.0]);

        let diff = subtract_operators(scaled.clone(), id.clone()).unwrap();
        let y2 = diff.matvec(&x).unwrap();
        assert_eq!(y2, vec![1.0, 2.0, 3.0]);

        let product = multiply_operators(scaled.clone(), id.clone()).unwrap();
        let y3 = product.matvec(&x).unwrap();
        assert_eq!(y3, vec![2.0, 4.0, 6.0]);

        let scaled_op = scale_operator(3.0, id.clone());
        let y4 = scaled_op.matvec(&x).unwrap();
        assert_eq!(y4, vec![3.0, 6.0, 9.0]);

        let transpose = transpose_operator(scaled.clone());
        let y5 = transpose.matvec(&x).unwrap();
        assert_eq!(y5, vec![2.0, 4.0, 6.0]);

        let ops: Vec<Box<dyn LinearOperator<f64>>> = vec![scaled.clone(), id.clone()];
        let composed = compose_operators(ops).unwrap();
        let y6 = composed.matvec(&x).unwrap();
        assert_eq!(y6, vec![2.0, 4.0, 6.0]);

        let power = power_operator(scaled.clone(), 2).unwrap();
        let y7 = power.matvec(&x).unwrap();
        assert_eq!(y7, vec![4.0, 8.0, 12.0]);
    }

    #[test]
    fn test_dimension_mismatch_errors() {
        let op1 = Box::new(IdentityOperator::<f64>::new(3));
        let op2 = Box::new(IdentityOperator::<f64>::new(4));

        assert!(SumOperator::new(op1.clone(), op2.clone()).is_err());

        assert!(DifferenceOperator::new(op1.clone(), op2.clone()).is_err());

        let rect1 = Box::new(ZeroOperator::<f64>::new(3, 4));
        let rect2 = Box::new(ZeroOperator::<f64>::new(5, 3));
        assert!(ProductOperator::new(rect1, rect2).is_err());
    }

    #[test]
    fn test_adjoint_not_supported_error() {
        let func_op = Box::new(FunctionOperator::from_function((3, 3), |x: &[f64]| {
            Ok(x.to_vec())
        }));

        assert!(AdjointOperator::new(func_op).is_err());
    }

    #[test]
    fn test_power_operator_errors() {
        let rect_op = Box::new(ZeroOperator::<f64>::new(3, 4));

        assert!(PowerOperator::new(rect_op, 2).is_err());

        let square_op = Box::new(IdentityOperator::<f64>::new(3));

        assert!(PowerOperator::new(square_op, 0).is_err());
    }
}