use super::{DistanceFunction, Generator, GeneratorWrapper, Hybrid};
use crate::{safe::SafeNode, Node};
/// Settings for a `Constant` node.
#[derive(Debug, Clone)]
pub struct Constant {
    /// Written to the node's `Value` member when built.
    pub value: f32,
}
/// Settings for a `White` node.
#[derive(Debug, Clone)]
pub struct White {
    /// Written to the node's `SeedOffset` member when built.
    pub seed_offset: i32,
    /// Lower bound of the output range (`OutputMin` member).
    pub output_min: f32,
    /// Upper bound of the output range (`OutputMax` member).
    pub output_max: f32,
}
/// Settings for a `Checkerboard` node.
#[derive(Debug, Clone)]
pub struct Checkerboard {
    /// Written to the node's `FeatureScale` member when built.
    pub feature_scale: f32,
    /// Lower bound of the output range (`OutputMin` member).
    pub output_min: f32,
    /// Upper bound of the output range (`OutputMax` member).
    pub output_max: f32,
}
/// Settings for a `SineWave` node.
#[derive(Debug, Clone)]
pub struct SineWave {
    /// Written to the node's `FeatureScale` member when built.
    pub feature_scale: f32,
    /// Lower bound of the output range (`OutputMin` member).
    pub output_min: f32,
    /// Upper bound of the output range (`OutputMax` member).
    pub output_max: f32,
}
/// Settings for a `Gradient` node: one multiplier and one offset per axis, written to the
/// node's `Multiplier{X,Y,Z,W}` / `Offset{X,Y,Z,W}` members when built.
///
/// `Default` is derived, so every field defaults to `0.0`. (A hand-written impl doing the same
/// thing triggers clippy's `derivable_impls` lint.)
#[derive(Clone, Debug, Default)]
pub struct Gradient {
    /// `MultiplierX` member.
    pub multiplier_x: f32,
    /// `MultiplierY` member.
    pub multiplier_y: f32,
    /// `MultiplierZ` member.
    pub multiplier_z: f32,
    /// `MultiplierW` member.
    pub multiplier_w: f32,
    /// `OffsetX` member.
    pub offset_x: f32,
    /// `OffsetY` member.
    pub offset_y: f32,
    /// `OffsetZ` member.
    pub offset_z: f32,
    /// `OffsetW` member.
    pub offset_w: f32,
}

/// Settings for a `DistanceToPoint` node.
///
/// Each point coordinate and the Minkowski exponent is [`Hybrid`]: a plain `f32` or another
/// generator (the tests pass `simplex()` for `minkowski_p` and `point_x`).
#[derive(Clone, Debug)]
pub struct DistanceToPoint<X, Y, Z, W, M>
where
    X: Hybrid,
    Y: Hybrid,
    Z: Hybrid,
    W: Hybrid,
    M: Hybrid,
{
    /// Distance metric, sent to the node's `DistanceFunction` member as its string form.
    pub distance_function: DistanceFunction,
    /// `PointX` member.
    pub point_x: X,
    /// `PointY` member.
    pub point_y: Y,
    /// `PointZ` member.
    pub point_z: Z,
    /// `PointW` member.
    pub point_w: W,
    /// `MinkowskiP` member.
    pub minkowski_p: M,
}
impl Default for White {
fn default() -> Self {
Self {
seed_offset: 0,
output_min: -1.0,
output_max: 1.0,
}
}
}
impl Default for DistanceToPoint<f32, f32, f32, f32, f32> {
fn default() -> Self {
Self {
distance_function: DistanceFunction::EuclideanSquared,
point_x: 0.0,
point_y: 0.0,
point_z: 0.0,
point_w: 0.0,
minkowski_p: 1.5,
}
}
}
impl Default for Checkerboard {
fn default() -> Self {
Self {
feature_scale: 100.0,
output_min: -1.0,
output_max: 1.0,
}
}
}
impl Default for SineWave {
fn default() -> Self {
Self {
feature_scale: 100.0,
output_min: -1.0,
output_max: 1.0,
}
}
}
impl Generator for Constant {
    /// Instantiates the `"Constant"` node and writes `self.value` into its `Value` member.
    ///
    /// Panics if the node name or member is unknown to the native library.
    #[cfg_attr(feature = "trace", tracing::instrument(level = "trace"))]
    fn build(&self) -> GeneratorWrapper<SafeNode> {
        let mut constant_node = Node::from_name("Constant").unwrap();
        constant_node.set("Value", self.value).unwrap();
        let safe = SafeNode(constant_node.into());
        safe.into()
    }
}
impl Generator for White {
#[cfg_attr(feature = "trace", tracing::instrument(level = "trace"))]
fn build(&self) -> GeneratorWrapper<SafeNode> {
let mut node = Node::from_name("White").unwrap();
node.set("SeedOffset", self.seed_offset).unwrap();
node.set("OutputMin", self.output_min).unwrap();
node.set("OutputMax", self.output_max).unwrap();
SafeNode(node.into()).into()
}
}
impl Generator for Checkerboard {
    /// Instantiates the `"Checkerboard"` node and writes `FeatureScale`, `OutputMin` and
    /// `OutputMax` from this struct, in that order.
    ///
    /// Panics if the node name or any member is unknown to the native library.
    #[cfg_attr(feature = "trace", tracing::instrument(level = "trace"))]
    fn build(&self) -> GeneratorWrapper<SafeNode> {
        let mut node = Node::from_name("Checkerboard").unwrap();
        let members = [
            ("FeatureScale", self.feature_scale),
            ("OutputMin", self.output_min),
            ("OutputMax", self.output_max),
        ];
        for (member, value) in members {
            node.set(member, value).unwrap();
        }
        SafeNode(node.into()).into()
    }
}
impl Generator for SineWave {
    /// Instantiates the `"SineWave"` node and writes `FeatureScale`, `OutputMin` and
    /// `OutputMax` from this struct, in that order.
    ///
    /// Panics if the node name or any member is unknown to the native library.
    #[cfg_attr(feature = "trace", tracing::instrument(level = "trace"))]
    fn build(&self) -> GeneratorWrapper<SafeNode> {
        let mut wave = Node::from_name("SineWave").unwrap();
        let settings = [
            ("FeatureScale", self.feature_scale),
            ("OutputMin", self.output_min),
            ("OutputMax", self.output_max),
        ];
        for &(name, value) in &settings {
            wave.set(name, value).unwrap();
        }
        SafeNode(wave.into()).into()
    }
}
impl Generator for Gradient {
    /// Instantiates the `"Gradient"` node and writes the four multipliers, then the four
    /// offsets, into the matching members.
    ///
    /// Panics if the node name or any member is unknown to the native library.
    #[cfg_attr(feature = "trace", tracing::instrument(level = "trace"))]
    fn build(&self) -> GeneratorWrapper<SafeNode> {
        let mut gradient_node = Node::from_name("Gradient").unwrap();
        let table = [
            ("MultiplierX", self.multiplier_x),
            ("MultiplierY", self.multiplier_y),
            ("MultiplierZ", self.multiplier_z),
            ("MultiplierW", self.multiplier_w),
            ("OffsetX", self.offset_x),
            ("OffsetY", self.offset_y),
            ("OffsetZ", self.offset_z),
            ("OffsetW", self.offset_w),
        ];
        for (member, value) in table {
            gradient_node.set(member, value).unwrap();
        }
        SafeNode(gradient_node.into()).into()
    }
}
impl<X, Y, Z, W, M> Generator for DistanceToPoint<X, Y, Z, W, M>
where
    X: Hybrid,
    Y: Hybrid,
    Z: Hybrid,
    W: Hybrid,
    M: Hybrid,
{
    /// Instantiates the `"DistanceToPoint"` node.
    ///
    /// The distance function is passed by its string form. Each hybrid coordinate and the
    /// Minkowski exponent is cloned into its member, so `self` is left untouched.
    ///
    /// Panics if the node name or any member is unknown to the native library.
    #[cfg_attr(feature = "trace", tracing::instrument(level = "trace"))]
    fn build(&self) -> GeneratorWrapper<SafeNode> {
        let mut distance_node = Node::from_name("DistanceToPoint").unwrap();
        let metric_name = self.distance_function.to_string();
        distance_node
            .set("DistanceFunction", metric_name.as_str())
            .unwrap();

        distance_node.set("PointX", self.point_x.clone()).unwrap();
        distance_node.set("PointY", self.point_y.clone()).unwrap();
        distance_node.set("PointZ", self.point_z.clone()).unwrap();
        distance_node.set("PointW", self.point_w.clone()).unwrap();
        distance_node
            .set("MinkowskiP", self.minkowski_p.clone())
            .unwrap();

        let safe = SafeNode(distance_node.into());
        safe.into()
    }
}
/// Creates a `Constant` generator configured with `value`.
pub fn constant(value: f32) -> GeneratorWrapper<Constant> {
    let settings = Constant { value };
    settings.into()
}
pub fn white() -> GeneratorWrapper<White> {
White::default().into()
}
pub fn checkerboard(feature_scale: f32) -> GeneratorWrapper<Checkerboard> {
Checkerboard {
feature_scale,
..Default::default()
}
.into()
}
pub fn sinewave(feature_scale: f32) -> GeneratorWrapper<SineWave> {
SineWave {
feature_scale,
..Default::default()
}
.into()
}
pub fn gradient() -> GeneratorWrapper<Gradient> {
Gradient::default().into()
}
/// Creates a `DistanceToPoint` generator with all-constant parameters taken from its
/// `Default` impl.
pub fn distance_to_point() -> GeneratorWrapper<DistanceToPoint<f32, f32, f32, f32, f32>> {
    let settings = DistanceToPoint::<f32, f32, f32, f32, f32>::default();
    settings.into()
}
/// Builder methods for [`White`].
impl GeneratorWrapper<White> {
    /// Sets the value written to the node's `SeedOffset` member.
    pub fn with_seed_offset(self, offset: i32) -> Self {
        let mut wrapper = self;
        wrapper.0.seed_offset = offset;
        wrapper
    }

    /// Sets the output range written to `OutputMin` / `OutputMax`.
    pub fn with_output_range(mut self, min: f32, max: f32) -> Self {
        let settings = &mut self.0;
        settings.output_min = min;
        settings.output_max = max;
        self
    }
}
/// Builder methods for [`Checkerboard`].
impl GeneratorWrapper<Checkerboard> {
    /// Overrides the value written to the node's `FeatureScale` member.
    pub fn with_feature_scale(self, scale: f32) -> Self {
        let mut wrapper = self;
        wrapper.0.feature_scale = scale;
        wrapper
    }

    /// Sets the output range written to `OutputMin` / `OutputMax`.
    pub fn with_output_range(self, min: f32, max: f32) -> Self {
        let mut wrapper = self;
        wrapper.0.output_max = max;
        wrapper.0.output_min = min;
        wrapper
    }
}
/// Builder methods for [`SineWave`].
impl GeneratorWrapper<SineWave> {
    /// Overrides the value written to the node's `FeatureScale` member.
    pub fn with_feature_scale(mut self, scale: f32) -> Self {
        let wave = &mut self.0;
        wave.feature_scale = scale;
        self
    }

    /// Sets the output range written to `OutputMin` / `OutputMax`.
    pub fn with_output_range(mut self, min: f32, max: f32) -> Self {
        let wave = &mut self.0;
        wave.output_min = min;
        wave.output_max = max;
        self
    }
}
/// Builder methods for [`Gradient`]. Each setter overwrites only the named field(s).
impl GeneratorWrapper<Gradient> {
    /// Sets `multiplier_x` (`MultiplierX` member).
    pub fn with_multiplier_x(mut self, multiplier: f32) -> Self {
        let gradient = &mut self.0;
        gradient.multiplier_x = multiplier;
        self
    }

    /// Sets `multiplier_y` (`MultiplierY` member).
    pub fn with_multiplier_y(mut self, multiplier: f32) -> Self {
        let gradient = &mut self.0;
        gradient.multiplier_y = multiplier;
        self
    }

    /// Sets `multiplier_z` (`MultiplierZ` member).
    pub fn with_multiplier_z(mut self, multiplier: f32) -> Self {
        let gradient = &mut self.0;
        gradient.multiplier_z = multiplier;
        self
    }

    /// Sets `multiplier_w` (`MultiplierW` member).
    pub fn with_multiplier_w(mut self, multiplier: f32) -> Self {
        let gradient = &mut self.0;
        gradient.multiplier_w = multiplier;
        self
    }

    /// Sets all four multipliers from `[x, y, z, w]`.
    pub fn with_multipliers(self, multipliers: [f32; 4]) -> Self {
        let [x, y, z, w] = multipliers;
        self.with_multiplier_x(x)
            .with_multiplier_y(y)
            .with_multiplier_z(z)
            .with_multiplier_w(w)
    }

    /// Sets `offset_x` (`OffsetX` member).
    pub fn with_offset_x(mut self, offset: f32) -> Self {
        let gradient = &mut self.0;
        gradient.offset_x = offset;
        self
    }

    /// Sets `offset_y` (`OffsetY` member).
    pub fn with_offset_y(mut self, offset: f32) -> Self {
        let gradient = &mut self.0;
        gradient.offset_y = offset;
        self
    }

    /// Sets `offset_z` (`OffsetZ` member).
    pub fn with_offset_z(mut self, offset: f32) -> Self {
        let gradient = &mut self.0;
        gradient.offset_z = offset;
        self
    }

    /// Sets `offset_w` (`OffsetW` member).
    pub fn with_offset_w(mut self, offset: f32) -> Self {
        let gradient = &mut self.0;
        gradient.offset_w = offset;
        self
    }

    /// Sets all four offsets from `[x, y, z, w]`.
    pub fn with_offsets(self, offsets: [f32; 4]) -> Self {
        let [x, y, z, w] = offsets;
        self.with_offset_x(x)
            .with_offset_y(y)
            .with_offset_z(z)
            .with_offset_w(w)
    }
}
/// Builder methods for [`DistanceToPoint`].
///
/// These methods are implemented for every combination of [`Hybrid`] parameter types, so
/// setters can be chained in any order. For example, both `point_x` and `point_y` can be
/// generator-driven, and the distance function can still be changed after a hybrid Minkowski
/// exponent has been set.
///
/// Previously the setters existed only while every parameter was a plain `f32`, so chains
/// like `.with_point_x(gen).with_point_y(gen)` did not compile. For all-`f32` wrappers, every
/// method returns exactly the same type as before.
impl<X, Y, Z, W, M> GeneratorWrapper<DistanceToPoint<X, Y, Z, W, M>>
where
    X: Hybrid,
    Y: Hybrid,
    Z: Hybrid,
    W: Hybrid,
    M: Hybrid,
{
    /// Sets the distance metric, keeping every other parameter.
    pub fn with_distance_function(mut self, distance_function: DistanceFunction) -> Self {
        self.0.distance_function = distance_function;
        self
    }

    /// Replaces the Minkowski exponent with `minkowski_p`, either a constant or a generator.
    ///
    /// NOTE(review): presumably only consulted when the distance function is
    /// `DistanceFunction::Minkowski`. Confirm against the native node's documentation.
    pub fn with_minkowski_p<P: Hybrid>(
        self,
        minkowski_p: P,
    ) -> GeneratorWrapper<DistanceToPoint<X, Y, Z, W, P>> {
        let DistanceToPoint {
            distance_function,
            point_x,
            point_y,
            point_z,
            point_w,
            ..
        } = self.0;
        DistanceToPoint {
            distance_function,
            point_x,
            point_y,
            point_z,
            point_w,
            minkowski_p,
        }
        .into()
    }

    /// Replaces the point's X coordinate with `point_x`, either a constant or a generator.
    pub fn with_point_x<PX: Hybrid>(
        self,
        point_x: PX,
    ) -> GeneratorWrapper<DistanceToPoint<PX, Y, Z, W, M>> {
        let DistanceToPoint {
            distance_function,
            point_y,
            point_z,
            point_w,
            minkowski_p,
            ..
        } = self.0;
        DistanceToPoint {
            distance_function,
            point_x,
            point_y,
            point_z,
            point_w,
            minkowski_p,
        }
        .into()
    }

    /// Replaces the point's Y coordinate with `point_y`, either a constant or a generator.
    pub fn with_point_y<PY: Hybrid>(
        self,
        point_y: PY,
    ) -> GeneratorWrapper<DistanceToPoint<X, PY, Z, W, M>> {
        let DistanceToPoint {
            distance_function,
            point_x,
            point_z,
            point_w,
            minkowski_p,
            ..
        } = self.0;
        DistanceToPoint {
            distance_function,
            point_x,
            point_y,
            point_z,
            point_w,
            minkowski_p,
        }
        .into()
    }

    /// Replaces the point's Z coordinate with `point_z`, either a constant or a generator.
    pub fn with_point_z<PZ: Hybrid>(
        self,
        point_z: PZ,
    ) -> GeneratorWrapper<DistanceToPoint<X, Y, PZ, W, M>> {
        let DistanceToPoint {
            distance_function,
            point_x,
            point_y,
            point_w,
            minkowski_p,
            ..
        } = self.0;
        DistanceToPoint {
            distance_function,
            point_x,
            point_y,
            point_z,
            point_w,
            minkowski_p,
        }
        .into()
    }

    /// Replaces the point's W coordinate with `point_w`, either a constant or a generator.
    pub fn with_point_w<PW: Hybrid>(
        self,
        point_w: PW,
    ) -> GeneratorWrapper<DistanceToPoint<X, Y, Z, PW, M>> {
        let DistanceToPoint {
            distance_function,
            point_x,
            point_y,
            point_z,
            minkowski_p,
            ..
        } = self.0;
        DistanceToPoint {
            distance_function,
            point_x,
            point_y,
            point_z,
            point_w,
            minkowski_p,
        }
        .into()
    }

    /// Sets all four coordinates to the constants `[x, y, z, w]`.
    ///
    /// Any generator-driven coordinates are replaced. The distance function and Minkowski
    /// exponent are kept.
    pub fn with_point(
        self,
        point: [f32; 4],
    ) -> GeneratorWrapper<DistanceToPoint<f32, f32, f32, f32, M>> {
        let [point_x, point_y, point_z, point_w] = point;
        let DistanceToPoint {
            distance_function,
            minkowski_p,
            ..
        } = self.0;
        DistanceToPoint {
            distance_function,
            point_x,
            point_y,
            point_z,
            point_w,
            minkowski_p,
        }
        .into()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{generator::simplex::simplex, test_utils::*};

    /// Sum of absolute element-wise differences between two generator outputs.
    fn total_abs_diff(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(x, y)| (x - y).abs()).sum()
    }

    #[test]
    fn test_constant() {
        let node = constant(0.5).build();
        test_generator_produces_output(node.0);
    }

    #[test]
    fn test_white_noise() {
        let node = white().build();
        test_generator_produces_output(node.0);
    }

    #[test]
    fn test_checkerboard() {
        let node = checkerboard(10.0).build();
        test_generator_produces_output(node.0);
    }

    #[test]
    fn test_sinewave() {
        let node = sinewave(10.0).build();
        test_generator_produces_output(node.0);
    }

    #[test]
    fn test_gradient() {
        let node = gradient()
            .with_multipliers([0.01, 0.01, 0.0, 0.0])
            .with_offsets([0.0, 0.0, 0.0, 0.0])
            .build();
        test_generator_produces_output(node.0);
    }

    #[test]
    fn test_distance_to_point() {
        let node = distance_to_point()
            .with_distance_function(DistanceFunction::Euclidean)
            .with_point([0.0, 0.0, 0.0, 0.0])
            .build();
        test_generator_produces_output(node.0);
    }

    #[test]
    fn test_gradient_builder_patterns() {
        // Distinct names so the `gradient` constructor is never shadowed.
        let gradient1 = gradient().with_multiplier_x(0.5).with_offset_y(1.0).build();
        test_generator_produces_output(gradient1.0);
        let gradient2 = gradient()
            .with_multipliers([0.1, 0.2, 0.3, 0.4])
            .with_offsets([1.0, 2.0, 3.0, 4.0])
            .build();
        test_generator_produces_output(gradient2.0);
    }

    #[test]
    fn test_distance_to_point_builder_patterns() {
        let distance = distance_to_point()
            .with_distance_function(DistanceFunction::Euclidean)
            .with_point([1.0, 2.0, 3.0, 4.0])
            .build();
        test_generator_produces_output(distance.0);
        let distance2 = distance_to_point()
            .with_point_x(5.0)
            .with_point_y(10.0)
            .with_distance_function(DistanceFunction::Manhattan)
            .build();
        test_generator_produces_output(distance2.0);
    }

    #[test]
    fn test_builder_patterns_produce_different_outputs() {
        let gradient1 = gradient().build();
        let gradient2 = gradient().with_multiplier_x(1.0).build();
        let output1 = generate_output(&gradient1.0);
        let output2 = generate_output(&gradient2.0);
        assert_outputs_differ(&output1, &output2, "Gradient builder patterns");
        let distance1 = distance_to_point().build();
        let distance2 = distance_to_point().with_point([1.0, 0.0, 0.0, 0.0]).build();
        let output3 = generate_output(&distance1.0);
        let output4 = generate_output(&distance2.0);
        assert_outputs_differ(&output3, &output4, "DistanceToPoint builder patterns");
    }

    #[test]
    fn test_checkerboard_default_feature_scale() {
        // Compare the pure `Default` configuration against an explicit scale of 100, so the
        // test actually pins the default. Previously both sides were explicitly 100.0, which
        // made the test tautological.
        let from_default = Checkerboard::default().build();
        let explicit_100 = checkerboard(100.0).build();
        let output1 = generate_output(&from_default.0);
        let output2 = generate_output(&explicit_100.0);
        let diff = total_abs_diff(&output1, &output2);
        assert!(
            diff < 0.01,
            "Checkerboard default feature scale test failed: outputs differ by {}",
            diff
        );
    }

    #[test]
    fn test_sinewave_default_feature_scale() {
        // See `test_checkerboard_default_feature_scale` for why `Default` is compared here.
        let from_default = SineWave::default().build();
        let explicit_100 = sinewave(100.0).build();
        let output1 = generate_output(&from_default.0);
        let output2 = generate_output(&explicit_100.0);
        let diff = total_abs_diff(&output1, &output2);
        assert!(
            diff < 0.01,
            "SineWave default feature scale test failed: outputs differ by {}",
            diff
        );
    }

    #[test]
    fn test_param_constant_value() {
        let node1 = constant(0.5).build();
        let node2 = constant(0.8).build();
        let output1 = generate_output(&node1.0);
        let output2 = generate_output(&node2.0);
        assert_outputs_differ(&output1, &output2, "Constant.Value");
    }

    #[test]
    fn test_param_checkerboard_feature_scale() {
        let node1 = checkerboard(0.5).build();
        let node2 = checkerboard(2.0).build();
        let output1 = generate_output(&node1.0);
        let output2 = generate_output(&node2.0);
        assert_outputs_differ(&output1, &output2, "Checkerboard.Feature Scale");
    }

    #[test]
    fn test_param_sinewave_feature_scale() {
        let node1 = sinewave(10.0).build();
        let node2 = sinewave(20.0).build();
        let output1 = generate_output(&node1.0);
        let output2 = generate_output(&node2.0);
        assert_outputs_differ(&output1, &output2, "SineWave.Feature Scale");
    }

    #[test]
    fn test_param_gradient_multipliers() {
        let node1 = gradient()
            .with_multipliers([0.01, 0.01, 0.0, 0.0])
            .with_offsets([0.0, 0.0, 0.0, 0.0])
            .build();
        let node2 = gradient()
            .with_multipliers([0.05, 0.02, 0.0, 0.0])
            .with_offsets([0.0, 0.0, 0.0, 0.0])
            .build();
        let output1 = generate_output(&node1.0);
        let output2 = generate_output(&node2.0);
        assert_outputs_differ(&output1, &output2, "Gradient.MultiplierX/Y");
    }

    #[test]
    fn test_param_gradient_offsets() {
        let node1 = gradient()
            .with_multipliers([0.01, 0.01, 0.0, 0.0])
            .with_offsets([0.0, 0.0, 0.0, 0.0])
            .build();
        let node2 = gradient()
            .with_multipliers([0.01, 0.01, 0.0, 0.0])
            .with_offsets([1.0, 1.0, 0.0, 0.0])
            .build();
        let output1 = generate_output(&node1.0);
        let output2 = generate_output(&node2.0);
        assert_outputs_differ(&output1, &output2, "Gradient.OffsetX/Y");
    }

    #[test]
    fn test_param_distance_to_point_point() {
        let node1 = distance_to_point()
            .with_distance_function(DistanceFunction::Euclidean)
            .with_point([0.0, 0.0, 0.0, 0.0])
            .build();
        let node2 = distance_to_point()
            .with_distance_function(DistanceFunction::Euclidean)
            .with_point([5.0, 5.0, 0.0, 0.0])
            .build();
        let output1 = generate_output(&node1.0);
        let output2 = generate_output(&node2.0);
        assert_outputs_differ(&output1, &output2, "DistanceToPoint.PointX/Y");
    }

    #[test]
    fn test_param_distance_to_point_distance_function() {
        let node1 = distance_to_point()
            .with_distance_function(DistanceFunction::Euclidean)
            .with_point([0.0, 0.0, 0.0, 0.0])
            .build();
        let node2 = distance_to_point()
            .with_distance_function(DistanceFunction::Manhattan)
            .with_point([0.0, 0.0, 0.0, 0.0])
            .build();
        let output1 = generate_output(&node1.0);
        let output2 = generate_output(&node2.0);
        assert_outputs_differ(&output1, &output2, "DistanceToPoint.Distance Function");
    }

    #[test]
    fn test_white_builder_methods() {
        let node = white()
            .with_seed_offset(42)
            .with_output_range(0.0, 1.0)
            .build();
        test_generator_produces_output(node.0);
    }

    #[test]
    fn test_checkerboard_builder_methods() {
        let node = checkerboard(10.0)
            .with_feature_scale(5.0)
            .with_output_range(0.0, 1.0)
            .build();
        test_generator_produces_output(node.0);
    }

    #[test]
    fn test_sinewave_builder_methods() {
        let node = sinewave(10.0)
            .with_feature_scale(5.0)
            .with_output_range(0.0, 1.0)
            .build();
        test_generator_produces_output(node.0);
    }

    #[test]
    fn test_distance_to_point_minkowski() {
        let node = distance_to_point()
            .with_distance_function(DistanceFunction::Minkowski)
            .with_point([0.0, 0.0, 0.0, 0.0])
            .with_minkowski_p(2.0)
            .build();
        test_generator_produces_output(node.0);
    }

    #[test]
    fn test_distance_to_point_hybrid_minkowski() {
        let p_gen = simplex();
        let node = distance_to_point()
            .with_distance_function(DistanceFunction::Minkowski)
            .with_point([0.0, 0.0, 0.0, 0.0])
            .with_minkowski_p(p_gen)
            .build();
        test_generator_produces_output(node.0);
    }

    #[test]
    fn test_distance_to_point_hybrid_coords() {
        let x_gen = simplex();
        let node = distance_to_point()
            .with_distance_function(DistanceFunction::Euclidean)
            .with_point([0.0, 0.0, 0.0, 0.0])
            .with_point_x(x_gen)
            .build();
        test_generator_produces_output(node.0);
    }
}