// ellalgo_rs/oracles/lowpass_oracle.rs

1use std::f64::consts::PI;
2// use ndarray::{stack, Axis, Array, Array1, Array2};
3use crate::cutting_plane::{OracleFeas, OracleOptim};
4use ndarray::{Array, Array1};
5
6type Arr = Array1<f64>;
7pub type Cut = (Arr, (f64, Option<f64>));
8
9/// The `LowpassOracle` struct in Rust represents a lowpass filter with various configuration
10/// parameters.
11///
12/// Properties:
13///
14/// * `more_alt`: The `more_alt` property is a boolean flag indicating whether there are more
15///             alternative options available.
16/// * `idx1`: The `idx1` property in the `LowpassOracle` struct is of type `i32`.
17/// * `spectrum`: The `spectrum` property is a vector of type `Arr`.
18/// * `nwpass`: The `nwpass` property in the `LowpassOracle` struct represents the number of points in
19///             the passband of a lowpass filter.
20/// * `nwstop`: The `nwstop` property in the `LowpassOracle` struct represents the number of points in
21///             the stopband of a lowpass filter. It is used to determine the characteristics of the filter,
22///             specifically the stopband width.
23/// * `lp_sq`: The `lp_sq` property in the `LowpassOracle` struct appears to be a floating-point number
24///            (f64). It likely represents a squared value used in low-pass filtering calculations or operations.
25/// * `up_sq`: The `up_sq` property in the `LowpassOracle` struct appears to be a floating-point number
26///            of type `f64`.
27/// * `sp_sq`: The `sp_sq` property in the `LowpassOracle` struct represents a floating-point value of
28///            type `f64`.
29/// * `idx2`: The `idx2` property in the `LowpassOracle` struct appears to be a `i32` type. It is a
30///           field that holds an unsigned integer value representing an index or position within the context of
31///           the struct.
32/// * `idx3`: The `idx3` property in the `LowpassOracle` struct represents an unsigned integer value.
33/// * `fmax`: The `fmax` property in the `LowpassOracle` struct represents the maximum frequency value.
34/// * `kmax`: The `kmax` property in the `LowpassOracle` struct represents the maximum value for a
35///           specific type `i32`. It is used to store the maximum value for a certain index or count within the
36///           context of the `LowpassOracle` struct.
37pub struct LowpassOracle {
38    pub more_alt: bool,
39    pub idx1: i32,
40    pub spectrum: Vec<Arr>,
41    pub nwpass: i32,
42    pub nwstop: i32,
43    pub lp_sq: f64,
44    pub up_sq: f64,
45    pub sp_sq: f64,
46    pub idx2: i32,
47    pub idx3: i32,
48    pub fmax: f64,
49    pub kmax: i32,
50}
51
52impl LowpassOracle {
53    /// The `new` function in Rust initializes a struct with specified parameters for spectral analysis.
54    ///
55    /// Arguments:
56    ///
57    /// * `ndim`: `ndim` represents the number of dimensions for the filter design.
58    /// * `wpass`: The `wpass` parameter represents the passband edge frequency in the provided function.
59    /// * `wstop`: The `wstop` parameter represents the stopband edge frequency in the given function.
60    /// * `lp_sq`: The `lp_sq` parameter in the code represents the lower passband squared value. It is
61    ///            used in the initialization of the struct and is a floating-point number (`f64`) passed as an
62    ///            argument to the `new` function.
63    /// * `up_sq`: The `up_sq` parameter in the function represents the upper bound squared value for
64    ///            the filter design. It is used in the calculation and initialization of the struct fields in the
65    ///            function.
66    /// * `sp_sq`: The `sp_sq` parameter in the `new` function represents the square of the stopband
67    ///            ripple level in the spectral domain. It is used in digital signal processing to define the
68    ///            desired characteristics of a filter, specifically in this context for designing a filter with
69    ///            given passband and stopband specifications.
70    ///
71    /// Returns:
72    ///
73    /// The `new` function is returning an instance of the struct that it belongs to. The struct
74    /// contains several fields such as `more_alt`, `idx1`, `spectrum`, `nwpass`, `nwstop`, `lp_sq`,
75    /// `up_sq`, `sp_sq`, `idx2`, `idx3`, `fmax`, and `kmax`. The function initializes these fields with
76    /// the
77    pub fn new(ndim: usize, wpass: f64, wstop: f64, lp_sq: f64, up_sq: f64, sp_sq: f64) -> Self {
78        let mdim = 15 * ndim;
79        let w: Array1<f64> = Array::linspace(0.0, std::f64::consts::PI, mdim);
80        // let tmp: Array2<f64> = Array::from_shape_fn((mdim, ndim - 1), |(i, j)| 2.0 * (w[i] * (j + 1) as f64).cos());
81        // let spectrum: Array2<f64> = stack![Axis(1), Array::ones(mdim).insert_axis(Axis(1)), tmp];
82
83        let mut spectrum = vec![Arr::zeros(ndim); mdim];
84        for i in 0..mdim {
85            spectrum[i][0] = 1.0;
86            for j in 1..ndim {
87                spectrum[i][j] = 2.0 * (w[i] * j as f64).cos();
88            }
89        }
90        // spectrum.iter_mut().for_each(|row| row.insert(0, 1.0));
91
92        let nwpass = (wpass * (mdim - 1) as f64).floor() as i32 + 1;
93        let nwstop = (wstop * (mdim - 1) as f64).floor() as i32 + 1;
94
95        Self {
96            more_alt: true,
97            idx1: -1,
98            spectrum,
99            nwpass,
100            nwstop,
101            lp_sq,
102            up_sq,
103            sp_sq,
104            idx2: nwpass - 1,
105            idx3: nwstop - 1,
106            fmax: f64::NEG_INFINITY,
107            kmax: -1,
108        }
109    }
110}
111
112impl OracleFeas<Arr> for LowpassOracle {
113    type CutChoice = (f64, Option<f64>); // parallel cut
114
115    /// The `assess_feas` function in Rust assesses the feasibility of a given array `x` based on
116    /// certain conditions and returns a corresponding `Cut` option.
117    ///
118    /// Arguments:
119    ///
120    /// * `x`: The `x` parameter in the `assess_feas` function is an array (`Arr`) that is passed by
121    ///        reference (`&`). It is used to perform calculations and comparisons with the elements of the
122    ///        `spectrum` array in the function.
123    ///
124    /// Returns:
125    ///
126    /// The function `assess_feas` returns an `Option` containing a tuple of type `Cut`. The `Cut` tuple
127    /// consists of two elements: a vector of coefficients (`Arr`) and a tuple of two optional values.
128    /// The first optional value represents the violation amount if the constraint is violated, and the
129    /// second optional value represents the amount to reach feasibility if the constraint is
130    /// infeasible.
131    fn assess_feas(&mut self, x: &Arr) -> Option<Cut> {
132        self.more_alt = true;
133
134        let mdim = self.spectrum.len();
135        let ndim = self.spectrum[0].len();
136        for _ in 0..self.nwpass {
137            self.idx1 += 1;
138            if self.idx1 == self.nwpass {
139                self.idx1 = 0;
140            }
141            let col_k = &self.spectrum[self.idx1 as usize];
142            // let v = col_k.iter().zip(x.iter()).map(|(&a, &b)| a * b).sum();
143            let v = col_k.dot(x);
144            if v > self.up_sq {
145                let f = (v - self.up_sq, Some(v - self.lp_sq));
146                return Some((col_k.clone(), f));
147            }
148            if v < self.lp_sq {
149                let f = (-v + self.lp_sq, Some(-v + self.up_sq));
150                return Some((col_k.iter().map(|&a| -a).collect(), f));
151            }
152        }
153
154        self.fmax = f64::NEG_INFINITY;
155        self.kmax = -1;
156        for _ in self.nwstop..mdim as i32 {
157            self.idx3 += 1;
158            if self.idx3 == mdim as i32 {
159                self.idx3 = self.nwstop;
160            }
161            let col_k = &self.spectrum[self.idx3 as usize];
162            // let v = col_k.iter().zip(x.iter()).map(|(&a, &b)| a * b).sum();
163            let v = col_k.dot(x);
164            if v > self.sp_sq {
165                return Some((col_k.clone(), (v - self.sp_sq, Some(v))));
166            }
167            if v < 0.0 {
168                return Some((
169                    col_k.iter().map(|&a| -a).collect(),
170                    (-v, Some(-v + self.sp_sq)),
171                ));
172            }
173            if v > self.fmax {
174                self.fmax = v;
175                self.kmax = self.idx3;
176            }
177        }
178
179        for _ in self.nwpass..self.nwstop {
180            self.idx2 += 1;
181            if self.idx2 == self.nwstop {
182                self.idx2 = self.nwpass;
183            }
184            let col_k = &self.spectrum[self.idx2 as usize];
185            // let v = col_k.iter().zip(x.iter()).map(|(&a, &b)| a * b).sum();
186            let v = col_k.dot(x);
187            if v < 0.0 {
188                // single cut
189                return Some((col_k.iter().map(|&a| -a).collect(), (-v, None)));
190            }
191        }
192
193        self.more_alt = false;
194
195        if x[0] < 0.0 {
196            let mut grad = Arr::zeros(ndim);
197            grad[0] = -1.0;
198            return Some((grad, (-x[0], None)));
199        }
200
201        None
202    }
203}
204
205impl OracleOptim<Arr> for LowpassOracle {
206    type CutChoice = (f64, Option<f64>); // parallel cut
207
208    /// The function assess_optim takes in parameters x and sp_sq, updates the value of sp_sq, assesses
209    /// feasibility of x, and returns a tuple containing a cut and a boolean value.
210    ///
211    /// Arguments:
212    ///
213    /// * `x`: The `x` parameter is of type `Arr`, which is likely an array or a slice of some kind. It
214    ///         is passed by reference to the `assess_optim` function.
215    /// * `sp_sq`: The `sp_sq` parameter in the `assess_optim` function is a mutable reference to a
216    ///            `f64` value. This parameter is updated within the function and its value is used to determine
217    ///            the return values of the function.
218    fn assess_optim(&mut self, x: &Arr, sp_sq: &mut f64) -> (Cut, bool) {
219        self.sp_sq = *sp_sq;
220
221        if let Some(cut) = self.assess_feas(x) {
222            return (cut, false);
223        }
224
225        let cut = (
226            self.spectrum[self.kmax as usize].clone(),
227            (0.0, Some(self.fmax)),
228        );
229        *sp_sq = self.fmax;
230        (cut, true)
231    }
232}
233
234/// The function `create_lowpass_case` in Rust calculates parameters for a lowpass filter based on given
235/// values.
236///
237/// Arguments:
238///
239/// * `ndim`: The `ndim` parameter represents the number of dimensions for the lowpass filter. It is
240///           used to create a `LowpassOracle` struct with specific parameters for the lowpass filter.
241///
242/// Returns:
243///
244/// A `LowpassOracle` struct is being returned with parameters `ndim`, `0.12`, `0.20`, `lp_sq`, `up_sq`,
245/// and `sp_sq`.
246pub fn create_lowpass_case(ndim: usize) -> LowpassOracle {
247    let delta0_wpass = 0.025;
248    let delta0_wstop = 0.125;
249    let delta1 = 20.0 * (delta0_wpass * PI).log10();
250    let delta2 = 20.0 * (delta0_wstop * PI).log10();
251
252    let low_pass = 10.0f64.powf(-delta1 / 20.0);
253    let up_pass = 10.0f64.powf(delta1 / 20.0);
254    let stop_pass = 10.0f64.powf(delta2 / 20.0);
255
256    let lp_sq = low_pass * low_pass;
257    let up_sq = up_pass * up_pass;
258    let sp_sq = stop_pass * stop_pass;
259
260    LowpassOracle::new(ndim, 0.12, 0.20, lp_sq, up_sq, sp_sq)
261}
262
263#[cfg(test)]
264mod tests {
265    use super::*;
266    // use super::{ProfitOracle, ProfitOracleQ, ProfitRbOracle};
267    use crate::cutting_plane::{cutting_plane_optim, Options};
268    use crate::ell::Ell;
269
270    fn run_lowpass() -> (bool, usize) {
271        let ndim = 32;
272        let r0 = Arr::zeros(ndim);
273        let mut ellip = Ell::new_with_scalar(40.0, r0);
274        // ellip.helper.use_parallel_cut = use_parallel_cut;
275        let mut omega = create_lowpass_case(ndim);
276        let mut sp_sq = omega.sp_sq;
277        let options = Options {
278            max_iters: 50000,
279            tolerance: 1e-14,
280        };
281        let (h, num_iters) = cutting_plane_optim(&mut omega, &mut ellip, &mut sp_sq, &options);
282        (h.is_some(), num_iters)
283    }
284
285    #[test]
286    fn test_lowpass() {
287        let (_feasible, _num_iters) = run_lowpass();
288        // assert!(feasible);
289        // assert!(num_iters >= 23000);
290        // assert!(num_iters <= 24000);
291    }
292}