#![allow(clippy::module_name_repetitions)]
use crate::bounds::Bounds1;
use crate::deeppoly::deep_poly;
use crate::dnn::dnn::DNN;
use crate::dnn::dnn_iter::DNNIndex;
use crate::dnn::dnn_iter::DNNIterator;
use crate::gaussian::GaussianDistribution;
use crate::num::Float;
use crate::polytope::Polytope;
use crate::star::Star;
use crate::NNVFloat;
use log::trace;
use ndarray::Array1;
use ndarray::ArrayView1;
use ndarray::ArrayView2;
use ndarray::Dimension;
use ndarray::Ix2;
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
use truncnorm::tilting::TiltingSolution;
/// Kind of a node in the star-set search structure, together with the indices
/// of its children (see `get_child_ids`).
///
/// NOTE(review): presumably the child indices refer to an arena/vector of nodes
/// owned elsewhere — confirm. Variant and field order are part of the derived
/// serde encoding (non-self-describing formats encode the variant index), so do
/// not reorder them.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum StarNodeType {
/// Node with exactly one child; presumably produced by an interpolation step — confirm.
Interpolate {
child_idx: usize,
},
/// Node whose two child slots are both optional; `get_child_ids` skips the `None`s.
Leaf {
/// Child for the safe part, if any (presumably built from `StarNode::get_safe_star` — confirm).
safe_idx: Option<usize>,
/// Child for the unsafe part, if any (presumably built from `StarNode::get_unsafe_star` — confirm).
unsafe_idx: Option<usize>,
},
/// Node with exactly one child; presumably produced by an affine layer — confirm.
Affine {
child_idx: usize,
},
/// Node with exactly one child; presumably produced by a convolution layer — confirm.
Conv {
child_idx: usize,
},
/// ReLU step node with one mandatory child and one optional child.
StepRelu {
/// Dimension the step acts on (presumably the neuron index within the layer — confirm).
dim: usize,
fst_child_idx: usize,
snd_child_idx: Option<usize>,
},
/// ReLU step with dropout: one mandatory child and up to two optional children.
StepReluDropOut {
/// Dimension the step acts on (presumably the neuron index within the layer — confirm).
dim: usize,
/// Dropout probability associated with this step.
dropout_prob: NNVFloat,
fst_child_idx: usize,
snd_child_idx: Option<usize>,
trd_child_idx: Option<usize>,
},
}
impl StarNodeType {
pub fn get_child_ids(&self) -> Vec<usize> {
match self {
StarNodeType::Leaf {
safe_idx,
unsafe_idx,
} => IntoIterator::into_iter(vec![safe_idx, unsafe_idx])
.filter_map(|&x| x)
.collect(),
StarNodeType::Affine { child_idx }
| StarNodeType::Conv { child_idx }
| StarNodeType::Interpolate { child_idx } => {
vec![*child_idx]
}
StarNodeType::StepRelu {
dim: _,
fst_child_idx,
snd_child_idx,
} => {
let mut child_ids: Vec<usize> = vec![*fst_child_idx];
if let Some(idx) = snd_child_idx {
child_ids.push(*idx);
}
child_ids
}
StarNodeType::StepReluDropOut {
fst_child_idx,
snd_child_idx,
trd_child_idx,
..
} => {
let mut child_ids: Vec<usize> = vec![*fst_child_idx];
if let Some(idx) = snd_child_idx {
child_ids.push(*idx);
}
if let Some(idx) = trd_child_idx {
child_ids.push(*idx);
}
child_ids
}
}
}
}
/// A star set paired with its position in the DNN and several lazily filled caches.
///
/// NOTE(review): field order is part of the derived serde encoding — do not reorder.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct StarNode<D: Dimension> {
/// The star set represented by this node.
star: Star<D>,
/// Position in the DNN; `get_output_bounds` starts its `DNNIterator` from here.
dnn_index: DNNIndex,
/// Cached base CDF value; `None` until `set_cdf` or `gaussian_cdf` fills it.
star_cdf: Option<NNVFloat>,
/// Running adjustment added on top of `star_cdf` (via `add_cdf`); reset to 0 by `reset_cdf`.
cdf_delta: NNVFloat,
/// Cached axis-aligned bounding box, filled lazily by `get_axis_aligned_input_bounds`.
axis_aligned_input_bounds: Option<Bounds1>,
/// Cached result of `output_fn` in `get_output_bounds` (presumably (lower, upper) — confirm).
output_bounds: Option<(NNVFloat, NNVFloat)>,
/// Cached input distribution, filled lazily by `get_gaussian_distribution`.
gaussian_distribution: Option<GaussianDistribution>,
}
impl<D: Dimension> StarNode<D> {
    /// Creates a node wrapping `star`, positioned at `initial_idx` in the DNN.
    ///
    /// `axis_aligned_input_bounds` is stored as given, so a caller can pre-seed
    /// that cache. Every other cache (CDF, output bounds, distribution) starts
    /// empty, and the CDF adjustment starts at zero.
    pub fn default(
        star: Star<D>,
        axis_aligned_input_bounds: Option<Bounds1>,
        initial_idx: DNNIndex,
    ) -> Self {
        let dnn_index = initial_idx;
        Self {
            dnn_index,
            star,
            axis_aligned_input_bounds,
            star_cdf: None,
            cdf_delta: 0.,
            output_bounds: None,
            gaussian_distribution: None,
        }
    }

    /// Borrows the star set held by this node.
    pub fn get_star(&self) -> &Star<D> {
        let star = &self.star;
        star
    }

    /// Returns this node's position in the DNN.
    pub fn get_dnn_index(&self) -> DNNIndex {
        let index = self.dnn_index;
        index
    }

    /// Returns the cached base CDF, without the accumulated `add_cdf` delta.
    pub fn try_get_cdf(&self) -> Option<NNVFloat> {
        let cached = self.star_cdf;
        cached
    }

    /// Stores `val` as the base CDF. The accumulated delta is left untouched.
    pub fn set_cdf(&mut self, val: NNVFloat) {
        let new_cdf = Some(val);
        self.star_cdf = new_cdf;
    }

    /// Forgets the cached base CDF and zeroes the accumulated delta.
    pub fn reset_cdf(&mut self) {
        self.cdf_delta = 0.;
        self.star_cdf = None;
    }

    /// Adds `add` to the running CDF adjustment applied on top of the base CDF.
    pub fn add_cdf(&mut self, add: NNVFloat) {
        let delta = &mut self.cdf_delta;
        *delta += add;
    }

    /// Returns the cached output bounds, if they have been computed or set.
    pub fn try_get_output_bounds(&self) -> Option<(NNVFloat, NNVFloat)> {
        let cached = self.output_bounds;
        cached
    }

    /// Overwrites the cached output bounds with `val`.
    pub fn set_output_bounds(&mut self, val: (NNVFloat, NNVFloat)) {
        let new_bounds = Some(val);
        self.output_bounds = new_bounds;
    }
}
impl StarNode<Ix2> {
pub fn is_input_member(&self, point: &ArrayView1<NNVFloat>) -> bool {
match self.star.input_space_polytope() {
Some(poly) => poly.is_member(point),
None => true,
}
}
pub fn get_reduced_input_polytope(&self, bounds: &Option<Bounds1>) -> Option<Polytope> {
self.star
.input_space_polytope()
.and_then(|x| x.reduce_fixed_inputs(bounds))
}
pub const fn try_get_gaussian_distribution(&self) -> Option<&GaussianDistribution> {
self.gaussian_distribution.as_ref()
}
pub fn set_gaussian_distribution(&mut self, val: GaussianDistribution) {
self.gaussian_distribution = Some(val);
}
pub fn get_gaussian_distribution(
&mut self,
loc: ArrayView1<NNVFloat>,
scale: ArrayView2<NNVFloat>,
max_accept_reject_iters: usize,
stability_eps: NNVFloat,
input_bounds_opt: &Option<Bounds1>,
) -> &mut GaussianDistribution {
if self.gaussian_distribution.is_none() {
self.gaussian_distribution = self.star.get_input_trunc_gaussian(
loc,
scale,
max_accept_reject_iters,
stability_eps,
input_bounds_opt,
);
if self.gaussian_distribution.is_none() {
self.gaussian_distribution = Some(GaussianDistribution::Gaussian {
loc: loc.to_owned(),
scale: scale.diag().to_owned(),
});
}
}
self.gaussian_distribution.as_mut().unwrap()
}
pub fn forward(&self, x: &Array1<NNVFloat>) -> Array1<NNVFloat> {
self.star.get_representation().apply(&x.view())
}
#[must_use]
pub fn get_unsafe_star(&self, safe_value: NNVFloat) -> Self {
let safe_star = self.star.get_safe_subset(safe_value);
Self {
star: safe_star,
dnn_index: self.dnn_index,
star_cdf: None,
cdf_delta: 0.,
axis_aligned_input_bounds: None,
output_bounds: None,
gaussian_distribution: None,
}
}
#[must_use]
pub fn get_safe_star(&self, safe_value: NNVFloat) -> Self {
let safe_star = self.star.get_safe_subset(safe_value);
Self {
star: safe_star,
dnn_index: self.dnn_index,
star_cdf: None,
cdf_delta: 0.,
axis_aligned_input_bounds: None,
output_bounds: None,
gaussian_distribution: None,
}
}
pub fn gaussian_cdf<R: Rng>(
&mut self,
mu: ArrayView1<NNVFloat>,
sigma: ArrayView2<NNVFloat>,
n: usize,
max_iters: usize,
rng: &mut R,
stability_eps: NNVFloat,
input_bounds_opt: &Option<Bounds1>,
) -> NNVFloat {
let cdf = self.star_cdf.unwrap_or_else(|| {
let cdf: NNVFloat = self
.get_gaussian_distribution(mu, sigma, max_iters, stability_eps, input_bounds_opt)
.cdf(n, rng);
debug_assert!(cdf.is_sign_positive());
self.star_cdf = Some(cdf);
cdf
});
let cdf_sum = cdf + self.cdf_delta;
if cdf_sum.is_sign_negative() {
NNVFloat::epsilon()
} else {
cdf_sum
}
}
pub fn gaussian_sample<R: Rng>(
&mut self,
rng: &mut R,
mu: ArrayView1<NNVFloat>,
sigma: ArrayView2<NNVFloat>,
n: usize,
max_iters: usize,
tilting_initialization: Option<&TiltingSolution>,
stability_eps: NNVFloat,
input_bounds_opt: &Option<Bounds1>,
) -> Vec<Array1<NNVFloat>> {
let distribution =
self.get_gaussian_distribution(mu, sigma, max_iters, stability_eps, input_bounds_opt);
distribution.populate_tilting_solution(tilting_initialization);
distribution.sample_n(n, rng)
}
pub const fn try_get_axis_aligned_input_bounds(&self) -> &Option<Bounds1> {
&self.axis_aligned_input_bounds
}
pub fn get_axis_aligned_input_bounds(&mut self, outer_bounds: &Bounds1) -> &Bounds1 {
if self.axis_aligned_input_bounds.is_none() {
self.axis_aligned_input_bounds = Some(
self.star
.calculate_output_axis_aligned_bounding_box(outer_bounds),
);
}
self.axis_aligned_input_bounds.as_ref().unwrap()
}
pub fn get_output_bounds(
&mut self,
dnn: &DNN,
output_fn: &dyn Fn(Bounds1) -> (NNVFloat, NNVFloat),
outer_input_bounds: &Bounds1,
) -> (NNVFloat, NNVFloat) {
if self.output_bounds.is_none() {
trace!("get_output_bounds on star {:?}", self.star);
let dnn_iter = DNNIterator::new(dnn, self.dnn_index);
self.output_bounds = Some(output_fn(deep_poly(
self.get_axis_aligned_input_bounds(outer_input_bounds),
dnn,
dnn_iter,
)));
}
self.output_bounds.unwrap()
}
}