use core::f32;
use derive_builder::Builder;
use ndarray::{ArrayD, Axis, IxDyn, stack};
use std::convert::TryFrom;
use tracing::{debug, trace};
use crate::dataset::LazySubset;
use crate::error::Result;
/// Dynamic-dimensional `f32` array; the type returned by the cost and
/// friction helpers in this file.
type CostArray = ndarray::Array<f32, ndarray::Dim<ndarray::IxDynImpl>>;
/// Value written by `CostFunction::compute` over summed costs that are `<= 0`
/// when `ignore_invalid_costs` is `false`.
const HIGH_FRICTION_INVALID_COST: f32 = 1e10;
/// Default provider for boolean fields that must be `true` when the key is
/// missing from the deserialized input (referenced through
/// `#[serde(default = "true_option")]`).
fn true_option() -> bool {
    true
}
/// Cost definition made of cost layers and optional friction layers.
///
/// Instances are deserialized with serde; `CostFunction::from_json` parses
/// one from JSON text and `CostFunction::compute` evaluates it.
#[derive(Clone, Debug, serde::Deserialize)]
pub(crate) struct CostFunction {
/// Layers that `compute` builds individually and sums into the cost.
cost_layers: Vec<CostLayer>,
/// Optional layers that `compute` sums into a friction term, applied as
/// `cost * (1 + friction)`.
friction_layers: Option<Vec<FrictionLayer>>,
/// Selects the replacement for summed costs that are `<= 0`: NaN when
/// `true`, `HIGH_FRICTION_INVALID_COST` when `false`. Defaults to `true`
/// when the key is absent from the deserialized input.
#[serde(default = "true_option")]
pub(crate) ignore_invalid_costs: bool,
}
/// One term of the cost sum: a feature layer with optional multipliers.
///
/// Deserializable with serde; the `Builder` derive also generates
/// `CostLayerBuilder`.
#[derive(Builder, Clone, Debug, serde::Deserialize)]
struct CostLayer {
/// Name used to fetch the layer values from the features.
layer_name: String,
/// Optional scalar the layer values are multiplied by.
#[builder(setter(strip_option), default)]
multiplier_scalar: Option<f32>,
/// Optional name of a second feature layer the values are multiplied by.
#[builder(setter(strip_option, into), default)]
multiplier_layer: Option<String>,
/// Selects in which `compute` pass the layer takes part; `compute` treats
/// a missing value as `false`.
#[builder(setter(strip_option), default)]
is_invariant: Option<bool>,
}
/// One term of the friction sum: a feature layer with an optional scalar.
#[derive(Builder, Clone, Debug, serde::Deserialize)]
struct FrictionLayer {
/// Name used to fetch the friction values from the features.
multiplier_layer: String,
/// Optional scalar the friction values are multiplied by.
#[builder(setter(strip_option), default)]
multiplier_scalar: Option<f32>,
}
impl CostFunction {
    /// Parses a cost definition from its JSON text representation.
    ///
    /// # Panics
    ///
    /// Panics when `json` cannot be deserialized into a `CostFunction`: the
    /// deserialization result is unwrapped, so this function currently never
    /// returns `Err`.
    // NOTE(review): mapping the serde error into this crate's error type would
    // let callers handle malformed input; confirm which conversions the error
    // type offers before replacing the `unwrap`.
    pub(super) fn from_json(json: &str) -> Result<Self> {
        trace!("Parsing cost definition from json: {}", json);
        let cost = serde_json::from_str(json).unwrap();
        Ok(cost)
    }

    /// Evaluates the cost for the region covered by `features`.
    ///
    /// Only cost layers whose `is_invariant` flag (a missing flag counts as
    /// `false`) equals the `is_invariant` argument take part. The selected
    /// layers are built one by one and summed. Summed costs that are `<= 0`
    /// are replaced by NaN when `ignore_invalid_costs` is set and by
    /// `HIGH_FRICTION_INVALID_COST` otherwise. Friction layers, when present,
    /// are summed too, values `<= -1` are raised to just above `-1`, and the
    /// returned array is `cost * (1 + friction)`.
    ///
    /// When no cost layer is selected, an all-zero array with the shape of
    /// the subset is returned and the friction layers are not read.
    ///
    /// # Panics
    ///
    /// Panics if a layer named by the definition cannot be fetched from
    /// `features`, or if the fetched layers cannot be stacked for summation.
    pub(crate) fn compute(&self, features: &mut LazySubset<f32>, is_invariant: bool) -> CostArray {
        // Lowest friction kept after clamping; staying strictly above -1
        // keeps the `1 + friction` multiplier positive.
        const MIN_FRICTION: f32 = -1.0 + 1e-7;

        debug!(
            "Calculating (is_invariant={}) cost for ({})",
            is_invariant,
            features.subset()
        );

        let cost_data: Vec<CostArray> = self
            .cost_layers
            .iter()
            .filter(|layer| layer.is_invariant.unwrap_or(false) == is_invariant)
            .map(|layer| build_single_cost_layer(layer, features))
            .collect();
        if cost_data.is_empty() {
            return empty_cost_array(features);
        }

        let mut total_cost = reduce_layers(cost_data);
        let invalid_cost = if self.ignore_invalid_costs {
            f32::NAN
        } else {
            HIGH_FRICTION_INVALID_COST
        };
        total_cost.mapv_inplace(|v| if v <= 0_f32 { invalid_cost } else { v });

        let friction_data: Vec<CostArray> = self
            .friction_layers
            .iter()
            .flatten()
            .map(|layer| build_single_friction_layer(layer, features))
            .collect();
        let mut friction = if friction_data.is_empty() {
            ArrayD::<f32>::zeros(IxDyn(total_cost.shape()))
        } else {
            reduce_layers(friction_data)
        };

        if friction.iter().any(|v| *v <= -1.0) {
            tracing::warn!("Friction layer contains values <= -1; clamping to -1");
            friction.mapv_inplace(|v| if v <= -1.0 { MIN_FRICTION } else { v });
        }

        // Scalar addition reuses the friction buffer instead of allocating a
        // separate array of ones.
        total_cost * (friction + 1.0_f32)
    }
}
/// Builds an all-zero cost array whose shape equals the shape of the subset
/// covered by `features`.
///
/// # Panics
///
/// Panics if a subset dimension does not fit in `usize`.
fn empty_cost_array(features: &LazySubset<f32>) -> CostArray {
    let mut dims: Vec<usize> = Vec::new();
    for &extent in features.subset().shape().iter() {
        let extent = usize::try_from(extent).expect("subset dimension exceeds usize range");
        dims.push(extent);
    }
    ArrayD::<f32>::zeros(IxDyn(&dims))
}
/// Builds the values of one cost layer.
///
/// Fetches `layer.layer_name` from `features`, multiplies it by
/// `multiplier_scalar` when one is set, and then by the feature layer named
/// in `multiplier_layer` when one is set.
///
/// # Panics
///
/// Panics if the layer or its multiplier layer cannot be fetched from
/// `features`.
fn build_single_cost_layer(layer: &CostLayer, features: &mut LazySubset<f32>) -> CostArray {
    let name = &layer.layer_name;
    trace!("Layer name: {}", name);

    let base = features.get(name).expect("Layer not found in features");

    let scaled = match layer.multiplier_scalar {
        Some(factor) => {
            trace!("Layer {} has multiplier scalar {}", name, factor);
            base * factor
        }
        None => base,
    };

    match &layer.multiplier_layer {
        Some(weight_name) => {
            trace!("Layer {} has multiplier layer {}", name, weight_name);
            let weights = features
                .get(weight_name)
                .expect("Multiplier layer not found in features");
            scaled * weights
        }
        None => scaled,
    }
}
/// Builds the values of one friction layer.
///
/// Fetches `layer.multiplier_layer` from `features` and multiplies it by
/// `multiplier_scalar` when one is set.
///
/// # Panics
///
/// Panics if the layer cannot be fetched from `features`.
fn build_single_friction_layer(layer: &FrictionLayer, features: &mut LazySubset<f32>) -> CostArray {
    trace!("Building friction layer: {:?}", layer);

    let base = features
        .get(&layer.multiplier_layer)
        .expect("Multiplier layer not found in features");

    match layer.multiplier_scalar {
        Some(factor) => {
            trace!("\t- Layer has multiplier scalar {}", factor);
            base * factor
        }
        None => base,
    }
}
/// Sums a collection of layers element by element.
///
/// The layers are stacked along a new leading axis, which is then summed
/// away.
///
/// # Panics
///
/// Panics if the layers cannot be stacked.
// NOTE(review): stacking is expected to fail for an empty `data` or for
// layers of differing shapes -- confirm against the ndarray version in use.
fn reduce_layers(data: Vec<CostArray>) -> CostArray {
    let views: Vec<_> = data.iter().map(|layer| layer.view()).collect();
    let stacked = stack(Axis(0), &views)
        .expect("layers must be non-empty and share one shape to be stacked");
    trace!("Stack shape: {:?}", stacked.shape());
    let final_layer = stacked.sum_axis(Axis(0));
    // Report the shape after the reduction; logging the stacked shape a
    // second time carried no information.
    trace!("Reduced shape: {:?}", final_layer.shape());
    final_layer
}
/// Sample cost definitions shared by tests.
#[cfg(test)]
pub(crate) mod sample {
    use super::*;

    /// Returns a JSON cost definition with five cost layers, the last one
    /// flagged as invariant.
    pub(crate) fn as_text_v1() -> String {
        String::from(
            r#"
{
"cost_layers": [
{"layer_name": "A"},
{"layer_name": "B", "multiplier_scalar": 100},
{"layer_name": "A",
"multiplier_layer": "B"},
{"layer_name": "C", "multiplier_scalar": 2,
"multiplier_layer": "A"},
{"layer_name": "C", "multiplier_scalar": 100,
"is_invariant": true}
]
}
"#,
        )
    }

    /// Parses the definition returned by `as_text_v1`.
    pub(crate) fn cost_function() -> CostFunction {
        CostFunction::from_json(&as_text_v1()).unwrap()
    }
}
/// Tests for the builder generated for `CostLayer`.
#[cfg(test)]
mod test_builder {
    use super::*;

    /// Every setter stores the value it was given.
    #[test]
    fn costlayer() {
        let built = CostLayerBuilder::default()
            .layer_name(String::from("A"))
            .multiplier_scalar(2.0)
            .multiplier_layer("B")
            .is_invariant(false)
            .build();
        let layer = built.unwrap();

        assert_eq!(layer.layer_name, String::from("A"));
        assert_eq!(layer.multiplier_scalar, Some(2.0));
        assert_eq!(layer.multiplier_layer, Some(String::from("B")));
        assert_eq!(layer.is_invariant, Some(false));
    }

    /// Optional fields stay `None` when their setters are not called.
    #[test]
    fn defaults() {
        let built = CostLayerBuilder::default()
            .layer_name(String::from("A"))
            .build();
        let layer = built.unwrap();

        assert_eq!(layer.layer_name, String::from("A"));
        assert_eq!(layer.multiplier_scalar, None);
        assert_eq!(layer.multiplier_layer, None);
        assert_eq!(layer.is_invariant, None);
    }
}
/// Tests for parsing and evaluating `CostFunction`.
#[cfg(test)]
mod test {
    use super::*;
    use crate::dataset::{make_lazy_subset_for_tests, samples};
    use std::sync::Arc;
    use zarrs::array_subset::ArraySubset;
    use zarrs::filesystem::FilesystemStore;
    use zarrs::storage::ReadableListableStorage;

    /// Creates a temporary zarr store with random layers "A", "B", "C" and
    /// "cost", and a lazy subset of shape `[1, 2, 2]` over it.
    ///
    /// The temporary directory is returned too, so the caller can keep it
    /// for as long as the subset is in use.
    fn make_features_for_costs_tests() -> (tempfile::TempDir, LazySubset<f32>) {
        let tmp = samples::ZarrTestBuilder::new()
            .dimensions(1, 8, 8)
            .chunks(1, 4, 4)
            .layer(samples::LayerConfig::random("A", 0.0, 1.0))
            .layer(samples::LayerConfig::random("B", 0.0, 1.0))
            .layer(samples::LayerConfig::random("C", 0.0, 1.0))
            .layer(samples::LayerConfig::random("cost", 0.0, 1.0))
            .build()
            .expect("Failed to create multi-variable zarr");
        let store: ReadableListableStorage = Arc::new(FilesystemStore::new(tmp.path()).unwrap());
        let subset = ArraySubset::new_with_start_shape(vec![0, 0, 0], vec![1, 2, 2]).unwrap();
        (tmp, make_lazy_subset_for_tests(store, subset))
    }

    /// The sample definition parses into the expected five cost layers.
    #[test]
    fn test_cost() {
        let json = sample::as_text_v1();
        let cost = CostFunction::from_json(&json).unwrap();
        assert_eq!(cost.cost_layers.len(), 5);
        assert_eq!(cost.cost_layers[0].layer_name, "A".to_string());
        assert_eq!(cost.cost_layers[0].is_invariant, None);
        assert_eq!(cost.cost_layers[1].layer_name, "B".to_string());
        assert_eq!(cost.cost_layers[1].multiplier_scalar, Some(100.0));
        assert_eq!(cost.cost_layers[1].is_invariant, None);
        assert_eq!(cost.cost_layers[2].layer_name, "A".to_string());
        assert_eq!(cost.cost_layers[2].multiplier_layer, Some("B".to_string()));
        assert_eq!(cost.cost_layers[2].is_invariant, None);
        assert_eq!(cost.cost_layers[3].layer_name, "C".to_string());
        assert_eq!(cost.cost_layers[3].multiplier_layer, Some("A".to_string()));
        assert_eq!(cost.cost_layers[3].multiplier_scalar, Some(2.0));
        assert_eq!(cost.cost_layers[3].is_invariant, None);
        assert_eq!(cost.cost_layers[4].layer_name, "C".to_string());
        assert_eq!(cost.cost_layers[4].multiplier_layer, None);
        assert_eq!(cost.cost_layers[4].multiplier_scalar, Some(100.0));
        assert_eq!(cost.cost_layers[4].is_invariant, Some(true));
    }

    /// Without cost layers the result is all zeros, whatever the friction.
    #[test]
    fn test_friction_only_returns_zeros() {
        let (_tmp, mut features) = make_features_for_costs_tests();
        let json = r#"
{
"cost_layers": [],
"friction_layers": [
{"multiplier_layer": "B", "multiplier_scalar": -3.0}
]
}
"#;
        let cost_fn = CostFunction::from_json(json).unwrap();
        let result = cost_fn.compute(&mut features, false);
        assert_eq!(result.shape(), &[1, 2, 2]);
        for v in result.iter() {
            assert_eq!(*v, 0.0_f32);
        }
    }

    /// Friction values at or below -1 are clamped, so positive costs stay
    /// positive after the friction is applied.
    #[test]
    fn test_friction_clamp_with_cost_layer() {
        use ndarray::Zip;
        let (_tmp, mut features) = make_features_for_costs_tests();
        let json = r#"
{
"cost_layers": [
{"layer_name": "A"}
],
"friction_layers": [
{"multiplier_layer": "B", "multiplier_scalar": -3.0}
]
}
"#;
        let cost_fn = CostFunction::from_json(json).unwrap();
        let result = cost_fn.compute(&mut features, false);
        let a = features.get("A").unwrap();
        let b = features.get("B").unwrap();
        Zip::from(&result)
            .and(&a)
            .and(&b)
            .for_each(|r, a_item, b_item| {
                let mut friction = b_item * -3.0;
                // Mirror the implementation, which clamps values `<= -1`.
                if friction <= -1.0 {
                    friction = -1.0 + 1e-7;
                }
                let truth = a_item * (1.0 + friction);
                if *a_item > 0.0_f32 {
                    assert!(*r > 0.0_f32);
                }
                let diff = (*r - truth).abs();
                assert!(diff < 1e-6, "mismatch {} vs {} (diff={})", r, truth, diff);
            });
    }
}