image_recovery/solvers.rs
// Copyright (C) 2022 Lílian Ferreira de Freitas & Emilia L. K. Blåsten
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! Implementation of algorithms for image recovery.
17use std::ops::Deref;
18
19use ndarray::{
20 Array3,
21 ShapeError,
22};
23
24use crate::{
25 image_array::ImageArray,
26 ops::{
27 Average,
28 Gradient,
29 Norm,
30 VectorLen,
31 },
32};
33
34impl ImageArray<Array3<f64>> {
35 /// Image denoising algorithm for 2 dimentional shapes with 1 dimention of
36 /// information (pixels) as an arbitrarily sized vector. Assumes axes 0
37 /// and 1 and the x and y coordinates of the image, and axis 2 is the
38 /// pixel vector coordinate of the image.
39 ///
40 /// # inputs
41 /// `lambda` is the target value of the dual objective function,
42 /// i.e. how close you want the output to be to the input:
43 /// approaching 0, the output should be completely smooth (flat),
44 /// approaching "infinifty", the output should be the same as
45 /// the original input.
46 ///
47 /// `tau` and `sigma` affect how fast the algorithm converges,
48 /// according to Chambolle, A. and Pock, T. (2011) these should
49 /// be chosen such that `tau * lambda * L2 norm^2 <= 1` where
50 /// `L2 norm^2 <= 8`.
51 ///
52 /// `gamma` updates the algorithm's internal variables,
53 /// for the accelerated algorithm of Chambolle, A. and Pock, T. (2011)
54 /// the chosen value is `0.35 * lambda`.
55 ///
56 /// `max_iter` and `convergence_threshold` bound the runtime of the
57 /// algorithm, i.e. it runs until `convergence_threshold < norm(current -
58 /// previous) / norm(previous)` or `max_iter` is hit.
59 pub fn denoise(
60 &self,
61 lambda: f64,
62 mut tau: f64,
63 mut sigma: f64,
64 gamma: f64,
65 max_iter: u32,
66 convergence_threshold: f64,
67 ) -> Result<Self, ShapeError> {
68 // primal variable (two copies, for storing value of iteration n-1)
69 let mut current: Array3<f64> = self.deref().clone();
70 let mut previous: Array3<f64>;
71 // primal variable "bar"
72 let mut current_bar = current.clone();
73 // dual variables
74 let mut dual_a = current.positive_gradient_on_axis(0)?;
75 let mut dual_b = current.positive_gradient_on_axis(1)?;
76 // theta will be set upon first iteration
77 let mut theta: f64;
78
79 let mut iter: u32 = 1;
80 loop {
81 // update the dual variable
82 dual_a =
83 &dual_a + (sigma * current_bar.positive_gradient_on_axis(0)?);
84 dual_b =
85 &dual_b + (sigma * current_bar.positive_gradient_on_axis(1)?);
86 // project dual variables color axis into L2 ball (-1, 1).
87 // assumes axis 2 is color axis of image.
88 let max = dual_a
89 .vector_len_on_axis(&dual_b, 2)?
90 .map(|&x| 1_f64.max(x));
91 dual_a /= &max;
92 dual_b /= &max;
93
94 // update the primal variable
95 previous = current.clone();
96 current = ¤t
97 - (tau
98 * (dual_a.negative_gradient_on_axis(0)?
99 + dual_b.negative_gradient_on_axis(1)?));
100 current = self.weighted_average(¤t, tau, lambda);
101
102 // update theta
103 theta = 1_f64 / (1_f64 + (2_f64 * gamma * tau));
104 // update tau
105 tau *= theta;
106 // update sigma
107 sigma /= theta;
108
109 // update the primal variable bar
110 current_bar = ¤t + &(theta * (¤t - &previous));
111
112 // check for convergence or max_iter iterations
113 let c = (¤t - &previous).norm() / previous.norm();
114 if c < convergence_threshold || iter >= max_iter {
115 log::debug!(
116 "returned at iteration = {}; where max = {}",
117 iter,
118 max_iter
119 );
120 log::debug!(
121 "convergence = {}; where threshold = {}",
122 c,
123 convergence_threshold
124 );
125 break;
126 }
127 iter += 1;
128 }
129
130 Ok(ImageArray::from(¤t))
131 }
132}