/// The three weighting coefficients returned by [`find_weights`] for a given
/// optimization iteration.
///
/// The struct is a plain bundle of `f32` values; it is `Copy` and can be
/// compared field-by-field with `==` (note that, as with any float
/// comparison, a `NaN` field never compares equal).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Weights {
    /// Weight named `w_mn`; presumably applies to mid-near pairs — confirm
    /// against the code that consumes it.
    pub w_mn: f32,
    /// Weight named `w_neighbors`; presumably applies to nearest-neighbor
    /// pairs — confirm against the code that consumes it.
    pub w_neighbors: f32,
    /// Weight named `w_fp`; presumably applies to far pairs — confirm against
    /// the code that consumes it.
    pub w_fp: f32,
}
/// Returns the weights to use at iteration `itr` of a three-phase schedule.
///
/// * **Phase 1** (`itr < phase_1_iters`): `w_mn` is linearly interpolated
///   from `w_mn_init` (at `itr == 0`) towards `3.0`; `w_neighbors` is `2.0`
///   and `w_fp` is `1.0`.
/// * **Phase 2** (the next `phase_2_iters` iterations): `w_mn` and
///   `w_neighbors` are both `3.0`, `w_fp` is `1.0`.
/// * **Phase 3** (every later iteration): `w_mn` is `0.0`, `w_neighbors` and
///   `w_fp` are both `1.0`.
///
/// When `phase_1_iters` is `0` the first phase is skipped entirely, so no
/// division by zero can occur.
///
/// # Arguments
///
/// * `w_mn_init` - value of `w_mn` at the very first iteration.
/// * `itr` - zero-based index of the current iteration.
/// * `phase_1_iters` - number of iterations in the first phase.
/// * `phase_2_iters` - number of iterations in the second phase.
// Iteration counts are converted to `f32` only to form a progress ratio;
// losing low-order bits on very large counts is acceptable there.
#[allow(clippy::cast_precision_loss)]
#[must_use]
pub fn find_weights(
    w_mn_init: f32,
    itr: usize,
    phase_1_iters: usize,
    phase_2_iters: usize,
) -> Weights {
    if itr < phase_1_iters {
        let progress = itr as f32 / phase_1_iters as f32;
        return Weights {
            w_mn: (1.0 - progress) * w_mn_init + progress * 3.0,
            w_neighbors: 2.0,
            w_fp: 1.0,
        };
    }

    // `itr >= phase_1_iters` holds here, so the subtraction cannot underflow.
    // Comparing the offset into phase 2 (rather than summing the two phase
    // lengths) avoids a `usize` overflow for very large iteration counts.
    if itr - phase_1_iters < phase_2_iters {
        Weights {
            w_mn: 3.0,
            w_neighbors: 3.0,
            w_fp: 1.0,
        }
    } else {
        Weights {
            w_mn: 0.0,
            w_neighbors: 1.0,
            w_fp: 1.0,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Samples the schedule once at the start, once mid-way through phase 1,
    /// once inside phase 2 and once after both phases have elapsed.
    #[test]
    fn test_find_weight() {
        let w_mn_init = 1000.0;

        // Each row: (iteration, expected w_mn, expected w_neighbors, expected w_fp),
        // evaluated with both phases lasting 100 iterations.
        let expectations: [(usize, f32, f32, f32); 4] = [
            (0, 1000.0, 2.0, 1.0),
            (50, 501.5, 2.0, 1.0),
            (150, 3.0, 3.0, 1.0),
            (300, 0.0, 1.0, 1.0),
        ];

        for &(itr, want_mn, want_neighbors, want_fp) in &expectations {
            let got = find_weights(w_mn_init, itr, 100, 100);
            assert_abs_diff_eq!(got.w_mn, want_mn);
            assert_abs_diff_eq!(got.w_neighbors, want_neighbors);
            assert_abs_diff_eq!(got.w_fp, want_fp);
        }
    }
}