use crate::{
linalg::{temp_mat_req, temp_mat_uninit},
ComplexField, MatMut, MatRef, Parallelism,
};
use dyn_stack::{PodStack, SizeOverflow, StackReq};
use reborrow::*;
pub mod bicgstab;
pub mod conjugate_gradient;
pub mod lsmr;
mod linop_impl;
/// Hint telling an iterative solver what is known about the initial guess it is given.
///
/// NOTE(review): presumably consumed by the `bicgstab`, `conjugate_gradient` and `lsmr`
/// solvers to decide whether the incoming guess has to be read — confirm against them.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum InitialGuessStatus {
    /// The initial guess is known to be exactly zero.
    Zero,
    /// The initial guess may hold arbitrary values. This is the `Default`.
    #[default]
    MaybeNonZero,
}
/// Identity operator acting on vectors of length `dim`.
///
/// Usable as a do-nothing preconditioner: its `apply*` methods copy the input into the
/// output unchanged, and its `*_in_place` methods leave the argument untouched.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct IdentityPrecond {
    /// Number of rows, and of columns, of the (square) operator.
    pub dim: usize,
}
/// `IdentityPrecond` seen as a linear operator: a square `dim × dim` map whose application
/// is a plain copy, so no scratch space is ever required.
impl<E: ComplexField> LinOp<E> for IdentityPrecond {
    #[inline]
    #[track_caller]
    fn apply_req(
        &self,
        _rhs_ncols: usize,
        _parallelism: Parallelism,
    ) -> Result<StackReq, SizeOverflow> {
        // A copy needs no workspace.
        Ok(StackReq::empty())
    }

    #[inline]
    fn nrows(&self) -> usize {
        let Self { dim } = *self;
        dim
    }

    #[inline]
    fn ncols(&self) -> usize {
        let Self { dim } = *self;
        dim
    }

    /// Writes `rhs` into `out` unchanged.
    #[inline]
    #[track_caller]
    fn apply(
        &self,
        mut out: MatMut<'_, E>,
        rhs: MatRef<'_, E>,
        _parallelism: Parallelism,
        _stack: PodStack<'_>,
    ) {
        out.copy_from(rhs);
    }

    /// The identity is its own conjugate, so this performs the same copy as `apply`.
    #[inline]
    #[track_caller]
    fn conj_apply(
        &self,
        out: MatMut<'_, E>,
        rhs: MatRef<'_, E>,
        parallelism: Parallelism,
        stack: PodStack<'_>,
    ) {
        <Self as LinOp<E>>::apply(self, out, rhs, parallelism, stack);
    }
}
/// The transpose and the adjoint of the identity are the identity again, so both methods
/// copy `rhs` into `out` and need no workspace.
impl<E: ComplexField> BiLinOp<E> for IdentityPrecond {
    #[inline]
    fn transpose_apply_req(
        &self,
        _rhs_ncols: usize,
        _parallelism: Parallelism,
    ) -> Result<StackReq, SizeOverflow> {
        // A copy needs no workspace.
        Ok(StackReq::empty())
    }

    /// Writes `rhs` into `out` unchanged.
    #[inline]
    #[track_caller]
    fn transpose_apply(
        &self,
        mut out: MatMut<'_, E>,
        rhs: MatRef<'_, E>,
        _parallelism: Parallelism,
        _stack: PodStack<'_>,
    ) {
        out.copy_from(rhs);
    }

    /// Same copy as `transpose_apply`, since the identity's adjoint equals its transpose.
    #[inline]
    #[track_caller]
    fn adjoint_apply(
        &self,
        out: MatMut<'_, E>,
        rhs: MatRef<'_, E>,
        parallelism: Parallelism,
        stack: PodStack<'_>,
    ) {
        <Self as BiLinOp<E>>::transpose_apply(self, out, rhs, parallelism, stack);
    }
}
/// In-place application of the identity: `rhs` already holds the result, so nothing is
/// done and no workspace is requested.
impl<E: ComplexField> Precond<E> for IdentityPrecond {
    fn apply_in_place_req(
        &self,
        _rhs_ncols: usize,
        _parallelism: Parallelism,
    ) -> Result<StackReq, SizeOverflow> {
        Ok(StackReq::empty())
    }

    /// No-op: `I * rhs == rhs`.
    fn apply_in_place(&self, _rhs: MatMut<'_, E>, _parallelism: Parallelism, _stack: PodStack<'_>) {}

    /// No-op as well; forwards to `apply_in_place` since the identity is self-conjugate.
    fn conj_apply_in_place(
        &self,
        rhs: MatMut<'_, E>,
        parallelism: Parallelism,
        stack: PodStack<'_>,
    ) {
        <Self as Precond<E>>::apply_in_place(self, rhs, parallelism, stack);
    }
}
/// In-place transpose/adjoint application of the identity: both leave `rhs` untouched and
/// request no workspace.
impl<E: ComplexField> BiPrecond<E> for IdentityPrecond {
    fn transpose_apply_in_place_req(
        &self,
        _rhs_ncols: usize,
        _parallelism: Parallelism,
    ) -> Result<StackReq, SizeOverflow> {
        Ok(StackReq::empty())
    }

    /// No-op: the identity's transpose is the identity.
    fn transpose_apply_in_place(
        &self,
        _rhs: MatMut<'_, E>,
        _parallelism: Parallelism,
        _stack: PodStack<'_>,
    ) {
    }

    /// No-op; forwards to `transpose_apply_in_place` since the adjoint is also the identity.
    fn adjoint_apply_in_place(
        &self,
        rhs: MatMut<'_, E>,
        parallelism: Parallelism,
        stack: PodStack<'_>,
    ) {
        <Self as BiPrecond<E>>::transpose_apply_in_place(self, rhs, parallelism, stack);
    }
}
/// A linear operator that is applied to a dense right-hand side, writing its result into a
/// caller-provided output matrix and drawing any scratch memory from a caller-provided
/// `PodStack`.
///
/// Implementors must also be `Sync` and `Debug`.
pub trait LinOp<E: ComplexField>: Sync + core::fmt::Debug {
/// Returns the scratch space that `apply` and `conj_apply` need when the right-hand side
/// has `rhs_ncols` columns, or `SizeOverflow` if that size cannot be represented.
///
/// There is no separate requirement method for `conj_apply`; this one covers both.
fn apply_req(
&self,
rhs_ncols: usize,
parallelism: Parallelism,
) -> Result<StackReq, SizeOverflow>;
/// Number of rows of the operator (the number of rows of `out` in `apply`).
fn nrows(&self) -> usize;
/// Number of columns of the operator (conventionally the number of rows `rhs` is expected
/// to have — implementors define the exact check).
fn ncols(&self) -> usize;
/// Applies the operator to `rhs` and writes the result into `out`.
///
/// `stack` is expected to provide at least `self.apply_req(rhs.ncols(), parallelism)` of
/// scratch space.
fn apply(
&self,
out: MatMut<'_, E>,
rhs: MatRef<'_, E>,
parallelism: Parallelism,
stack: PodStack<'_>,
);
/// Like `apply`, but with the operator conjugated; implementors are expected to write
/// `conj(self) * rhs` into `out`.
///
/// Its scratch requirement is the one reported by `apply_req`.
fn conj_apply(
&self,
out: MatMut<'_, E>,
rhs: MatRef<'_, E>,
parallelism: Parallelism,
stack: PodStack<'_>,
);
}
/// A `LinOp` that can additionally apply its transpose and its adjoint (conjugate
/// transpose).
pub trait BiLinOp<E: ComplexField>: LinOp<E> {
/// Returns the scratch space that `transpose_apply` and `adjoint_apply` need when the
/// right-hand side has `rhs_ncols` columns, or `SizeOverflow` on size overflow.
fn transpose_apply_req(
&self,
rhs_ncols: usize,
parallelism: Parallelism,
) -> Result<StackReq, SizeOverflow>;
/// Applies the transpose of the operator to `rhs`, writing the result into `out`.
///
/// `stack` is expected to provide at least
/// `self.transpose_apply_req(rhs.ncols(), parallelism)` of scratch space.
fn transpose_apply(
&self,
out: MatMut<'_, E>,
rhs: MatRef<'_, E>,
parallelism: Parallelism,
stack: PodStack<'_>,
);
/// Applies the adjoint (conjugate transpose) of the operator to `rhs`, writing the result
/// into `out`.
///
/// There is no separate requirement method; `transpose_apply_req` covers this too.
fn adjoint_apply(
&self,
out: MatMut<'_, E>,
rhs: MatRef<'_, E>,
parallelism: Parallelism,
stack: PodStack<'_>,
);
}
/// A linear operator that can also be applied in place, overwriting its right-hand side
/// with the result. This only makes sense for square operators (`nrows == ncols`).
///
/// The provided methods are a generic fallback. Each one applies the operator into a
/// temporary matrix carved from `stack`, then copies the temporary back into `rhs`.
/// Implementors able to work truly in place should override them.
pub trait Precond<E: ComplexField>: LinOp<E> {
    /// Scratch space needed by `apply_in_place` and `conj_apply_in_place` for a right-hand
    /// side with `rhs_ncols` columns.
    ///
    /// The default reserves an `nrows × rhs_ncols` temporary together with whatever
    /// `apply_req` reports.
    fn apply_in_place_req(
        &self,
        rhs_ncols: usize,
        parallelism: Parallelism,
    ) -> Result<StackReq, SizeOverflow> {
        let temporary = temp_mat_req::<E>(self.nrows(), rhs_ncols)?;
        let operator = self.apply_req(rhs_ncols, parallelism)?;
        temporary.try_and(operator)
    }

    /// Overwrites `rhs` with the result of `apply` on `rhs`.
    ///
    /// `stack` is expected to hold at least `self.apply_in_place_req(rhs.ncols(), parallelism)`.
    #[track_caller]
    fn apply_in_place(&self, mut rhs: MatMut<'_, E>, parallelism: Parallelism, stack: PodStack<'_>) {
        let width = rhs.ncols();
        let (mut product, remaining) = temp_mat_uninit::<E>(self.nrows(), width, stack);
        self.apply(product.rb_mut(), rhs.rb(), parallelism, remaining);
        rhs.copy_from(&product);
    }

    /// Overwrites `rhs` with the result of `conj_apply` on `rhs`.
    ///
    /// Uses the same workspace requirement as `apply_in_place` (`apply_in_place_req`).
    #[track_caller]
    fn conj_apply_in_place(
        &self,
        mut rhs: MatMut<'_, E>,
        parallelism: Parallelism,
        stack: PodStack<'_>,
    ) {
        let width = rhs.ncols();
        let (mut product, remaining) = temp_mat_uninit::<E>(self.nrows(), width, stack);
        self.conj_apply(product.rb_mut(), rhs.rb(), parallelism, remaining);
        rhs.copy_from(&product);
    }
}
/// A `Precond` whose transpose and adjoint can also be applied in place.
///
/// The provided methods are a generic fallback, as in `Precond`. Each one writes into a
/// temporary matrix from `stack` and copies it back into `rhs`. Implementors able to work
/// truly in place should override them.
pub trait BiPrecond<E: ComplexField>: Precond<E> + BiLinOp<E> {
    /// Scratch space needed by `transpose_apply_in_place` and `adjoint_apply_in_place` for a
    /// right-hand side with `rhs_ncols` columns.
    ///
    /// The default reserves the temporary output of the transposed operator, which has
    /// `ncols × rhs_ncols` entries, together with whatever `transpose_apply_req` reports.
    fn transpose_apply_in_place_req(
        &self,
        rhs_ncols: usize,
        parallelism: Parallelism,
    ) -> Result<StackReq, SizeOverflow> {
        // The transposed operator maps to a space with `ncols` rows, so size for that.
        temp_mat_req::<E>(self.ncols(), rhs_ncols)?
            .try_and(self.transpose_apply_req(rhs_ncols, parallelism)?)
    }

    /// Overwrites `rhs` with the result of `transpose_apply` on `rhs`.
    ///
    /// `stack` is expected to hold at least
    /// `self.transpose_apply_in_place_req(rhs.ncols(), parallelism)`.
    #[track_caller]
    fn transpose_apply_in_place(
        &self,
        rhs: MatMut<'_, E>,
        parallelism: Parallelism,
        stack: PodStack<'_>,
    ) {
        // `out` of `transpose_apply` has `ncols` rows (the shape of the transposed product).
        // For the square operators this trait is meant for, that equals `nrows`. For a
        // non-square operator, the failure then surfaces at the copy back into `rhs`
        // (assuming `copy_from` checks shapes). The original instead handed
        // `transpose_apply` a wrongly-shaped output buffer.
        let (mut tmp, stack) = temp_mat_uninit::<E>(self.ncols(), rhs.ncols(), stack);
        self.transpose_apply(tmp.rb_mut(), rhs.rb(), parallelism, stack);
        { rhs }.copy_from(&tmp);
    }

    /// Overwrites `rhs` with the result of `adjoint_apply` on `rhs`.
    ///
    /// Uses the same workspace requirement as `transpose_apply_in_place`.
    #[track_caller]
    fn adjoint_apply_in_place(
        &self,
        rhs: MatMut<'_, E>,
        parallelism: Parallelism,
        stack: PodStack<'_>,
    ) {
        // Same shape reasoning as `transpose_apply_in_place`: the adjoint outputs `ncols` rows.
        let (mut tmp, stack) = temp_mat_uninit::<E>(self.ncols(), rhs.ncols(), stack);
        self.adjoint_apply(tmp.rb_mut(), rhs.rb(), parallelism, stack);
        { rhs }.copy_from(&tmp);
    }
}
/// A shared reference to an operator is itself an operator. Every method forwards to the
/// referenced `T`, so `&A` can be passed wherever an `impl LinOp<E>` is accepted.
impl<E: ComplexField, T: ?Sized + LinOp<E>> LinOp<E> for &T {
    #[inline]
    #[track_caller]
    fn apply_req(
        &self,
        rhs_ncols: usize,
        parallelism: Parallelism,
    ) -> Result<StackReq, SizeOverflow> {
        <T as LinOp<E>>::apply_req(*self, rhs_ncols, parallelism)
    }

    #[inline]
    fn nrows(&self) -> usize {
        <T as LinOp<E>>::nrows(*self)
    }

    #[inline]
    fn ncols(&self) -> usize {
        <T as LinOp<E>>::ncols(*self)
    }

    #[inline]
    #[track_caller]
    fn apply(
        &self,
        out: MatMut<'_, E>,
        rhs: MatRef<'_, E>,
        parallelism: Parallelism,
        stack: PodStack<'_>,
    ) {
        <T as LinOp<E>>::apply(*self, out, rhs, parallelism, stack)
    }

    #[inline]
    #[track_caller]
    fn conj_apply(
        &self,
        out: MatMut<'_, E>,
        rhs: MatRef<'_, E>,
        parallelism: Parallelism,
        stack: PodStack<'_>,
    ) {
        <T as LinOp<E>>::conj_apply(*self, out, rhs, parallelism, stack)
    }
}
/// Forwards the transpose/adjoint operations of `&T` to the referenced `T`.
impl<E: ComplexField, T: ?Sized + BiLinOp<E>> BiLinOp<E> for &T {
    #[inline]
    #[track_caller]
    fn transpose_apply_req(
        &self,
        rhs_ncols: usize,
        parallelism: Parallelism,
    ) -> Result<StackReq, SizeOverflow> {
        <T as BiLinOp<E>>::transpose_apply_req(*self, rhs_ncols, parallelism)
    }

    #[inline]
    #[track_caller]
    fn transpose_apply(
        &self,
        out: MatMut<'_, E>,
        rhs: MatRef<'_, E>,
        parallelism: Parallelism,
        stack: PodStack<'_>,
    ) {
        <T as BiLinOp<E>>::transpose_apply(*self, out, rhs, parallelism, stack)
    }

    #[inline]
    #[track_caller]
    fn adjoint_apply(
        &self,
        out: MatMut<'_, E>,
        rhs: MatRef<'_, E>,
        parallelism: Parallelism,
        stack: PodStack<'_>,
    ) {
        <T as BiLinOp<E>>::adjoint_apply(*self, out, rhs, parallelism, stack)
    }
}
/// Forwards the in-place operations of `&T` to the referenced `T`, so any override the
/// referenced type provides is used rather than the trait's fallback.
impl<E: ComplexField, T: ?Sized + Precond<E>> Precond<E> for &T {
    fn apply_in_place_req(
        &self,
        rhs_ncols: usize,
        parallelism: Parallelism,
    ) -> Result<StackReq, SizeOverflow> {
        <T as Precond<E>>::apply_in_place_req(*self, rhs_ncols, parallelism)
    }

    fn apply_in_place(&self, rhs: MatMut<'_, E>, parallelism: Parallelism, stack: PodStack<'_>) {
        <T as Precond<E>>::apply_in_place(*self, rhs, parallelism, stack);
    }

    fn conj_apply_in_place(
        &self,
        rhs: MatMut<'_, E>,
        parallelism: Parallelism,
        stack: PodStack<'_>,
    ) {
        <T as Precond<E>>::conj_apply_in_place(*self, rhs, parallelism, stack);
    }
}
/// Forwards the in-place transpose/adjoint operations of `&T` to the referenced `T`, so
/// any override the referenced type provides is used rather than the trait's fallback.
impl<E: ComplexField, T: ?Sized + BiPrecond<E>> BiPrecond<E> for &T {
    fn transpose_apply_in_place_req(
        &self,
        rhs_ncols: usize,
        parallelism: Parallelism,
    ) -> Result<StackReq, SizeOverflow> {
        <T as BiPrecond<E>>::transpose_apply_in_place_req(*self, rhs_ncols, parallelism)
    }

    fn transpose_apply_in_place(
        &self,
        rhs: MatMut<'_, E>,
        parallelism: Parallelism,
        stack: PodStack<'_>,
    ) {
        <T as BiPrecond<E>>::transpose_apply_in_place(*self, rhs, parallelism, stack);
    }

    fn adjoint_apply_in_place(
        &self,
        rhs: MatMut<'_, E>,
        parallelism: Parallelism,
        stack: PodStack<'_>,
    ) {
        <T as BiPrecond<E>>::adjoint_apply_in_place(*self, rhs, parallelism, stack);
    }
}