image_recovery/
solvers.rs

1// Copyright (C) 2022  Lílian Ferreira de Freitas & Emilia L. K. Blåsten
2//
3// This program is free software: you can redistribute it and/or modify
4// it under the terms of the GNU Affero General Public License as published
5// by the Free Software Foundation, either version 3 of the License, or
6// (at your option) any later version.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11// GNU Affero General Public License for more details.
12//
13// You should have received a copy of the GNU Affero General Public License
14// along with this program.  If not, see <https://www.gnu.org/licenses/>.
15
16//! Implementation of algorithms for image recovery.
17use std::ops::Deref;
18
19use ndarray::{
20    Array3,
21    ShapeError,
22};
23
24use crate::{
25    image_array::ImageArray,
26    ops::{
27        Average,
28        Gradient,
29        Norm,
30        VectorLen,
31    },
32};
33
34impl ImageArray<Array3<f64>> {
35    /// Image denoising algorithm for 2 dimentional shapes with 1 dimention of
36    /// information (pixels) as an arbitrarily sized vector. Assumes axes 0
37    /// and 1 and the x and y coordinates of the image, and axis 2 is the
38    /// pixel vector coordinate of the image.
39    ///
40    /// # inputs
41    /// `lambda` is the target value of the dual objective function,
42    /// i.e. how close you want the output to be to the input:
43    /// approaching 0, the output should be completely smooth (flat),
44    /// approaching "infinifty", the output should be the same as
45    /// the original input.
46    ///
47    /// `tau` and `sigma` affect how fast the algorithm converges,
48    /// according to Chambolle, A. and Pock, T. (2011) these should
49    /// be chosen such that `tau * lambda * L2 norm^2 <= 1` where
50    /// `L2 norm^2 <= 8`.
51    ///
52    /// `gamma` updates the algorithm's internal variables,
53    /// for the accelerated algorithm of Chambolle, A. and Pock, T. (2011)
54    /// the chosen value is `0.35 * lambda`.
55    ///
56    /// `max_iter` and `convergence_threshold` bound the runtime of the
57    /// algorithm, i.e. it runs until `convergence_threshold < norm(current -
58    /// previous) / norm(previous)` or `max_iter` is hit.
59    pub fn denoise(
60        &self,
61        lambda: f64,
62        mut tau: f64,
63        mut sigma: f64,
64        gamma: f64,
65        max_iter: u32,
66        convergence_threshold: f64,
67    ) -> Result<Self, ShapeError> {
68        // primal variable (two copies, for storing value of iteration n-1)
69        let mut current: Array3<f64> = self.deref().clone();
70        let mut previous: Array3<f64>;
71        // primal variable "bar"
72        let mut current_bar = current.clone();
73        // dual variables
74        let mut dual_a = current.positive_gradient_on_axis(0)?;
75        let mut dual_b = current.positive_gradient_on_axis(1)?;
76        // theta will be set upon first iteration
77        let mut theta: f64;
78
79        let mut iter: u32 = 1;
80        loop {
81            // update the dual variable
82            dual_a =
83                &dual_a + (sigma * current_bar.positive_gradient_on_axis(0)?);
84            dual_b =
85                &dual_b + (sigma * current_bar.positive_gradient_on_axis(1)?);
86            // project dual variables color axis into L2 ball (-1, 1).
87            // assumes axis 2 is color axis of image.
88            let max = dual_a
89                .vector_len_on_axis(&dual_b, 2)?
90                .map(|&x| 1_f64.max(x));
91            dual_a /= &max;
92            dual_b /= &max;
93
94            // update the primal variable
95            previous = current.clone();
96            current = &current
97                - (tau
98                    * (dual_a.negative_gradient_on_axis(0)?
99                        + dual_b.negative_gradient_on_axis(1)?));
100            current = self.weighted_average(&current, tau, lambda);
101
102            // update theta
103            theta = 1_f64 / (1_f64 + (2_f64 * gamma * tau));
104            // update tau
105            tau *= theta;
106            // update sigma
107            sigma /= theta;
108
109            // update the primal variable bar
110            current_bar = &current + &(theta * (&current - &previous));
111
112            // check for convergence or max_iter iterations
113            let c = (&current - &previous).norm() / previous.norm();
114            if c < convergence_threshold || iter >= max_iter {
115                log::debug!(
116                    "returned at iteration = {}; where max = {}",
117                    iter,
118                    max_iter
119                );
120                log::debug!(
121                    "convergence = {}; where threshold = {}",
122                    c,
123                    convergence_threshold
124                );
125                break;
126            }
127            iter += 1;
128        }
129
130        Ok(ImageArray::from(&current))
131    }
132}