use crate::as_entity::AsEntity;
use crate::dimension::Dimension;
use crate::entity::Entity;
use crate::utilities::{find_median, max_min_xyz, xyz_distances};
use serde::{Deserialize, Serialize};
/// One node of the spatial tree.
///
/// A leaf holds its members in `points` and has no children. An internal node has
/// `points == None` and two children. Every node caches aggregate data about its
/// subtree: total mass, centre of mass, largest member radius, and the coordinate
/// extremes of its members.
#[derive(Serialize, Deserialize, Clone)]
pub(crate) struct Node<T: AsEntity + Clone> {
    /// Axis along which this node's points were divided; `None` for leaves.
    split_dimension: Option<Dimension>,
    /// Coordinate reported by `find_median` along `split_dimension`
    /// (presumably the median); `0.0` for leaves.
    split_value: f64,
    /// Child built from the points before the split index.
    pub(crate) left: Option<Box<Node<T>>>,
    /// Child built from the points from the split index onward.
    pub(crate) right: Option<Box<Node<T>>>,
    /// Members stored directly in this node; `Some` only for leaves.
    pub(crate) points: Option<Vec<T>>,
    /// Mass-weighted mean position `(x, y, z)` of the subtree.
    pub(crate) center_of_mass: (f64, f64, f64),
    /// Sum of member masses in the subtree.
    total_mass: f64,
    /// Largest member radius found anywhere in the subtree.
    r_max: f64,
    // Coordinate extremes of the subtree: taken from `max_min_xyz` for leaves
    // and from the union of both children's extremes for internal nodes.
    x_min: f64,
    x_max: f64,
    y_min: f64,
    y_max: f64,
    z_min: f64,
    z_max: f64,
}
impl<T: AsEntity + Clone> Node<T> {
    /// Creates an empty node: no children, no points, zero mass, and a
    /// degenerate bounding box at the origin.
    pub(crate) fn new() -> Node<T> {
        Node {
            split_dimension: None,
            split_value: 0.0,
            left: None,
            right: None,
            points: None,
            center_of_mass: (0.0, 0.0, 0.0),
            total_mass: 0.0,
            r_max: 0.0,
            x_min: 0.0,
            x_max: 0.0,
            y_min: 0.0,
            y_max: 0.0,
            z_min: 0.0,
            z_max: 0.0,
        }
    }

    /// Recomputes this node's coordinate extremes and `r_max` from its two children.
    ///
    /// The extremes become the union of the children's extremes, and `r_max`
    /// becomes the larger of the children's `r_max`.
    ///
    /// # Panics
    ///
    /// Panics if either child is missing.
    pub(crate) fn set_max_mins(&mut self) {
        let left = self.left.as_ref().expect("unexpected null node #7");
        let right = self.right.as_ref().expect("unexpected null node #8");
        let x_min = f64::min(left.x_min, right.x_min);
        let x_max = f64::max(left.x_max, right.x_max);
        let y_min = f64::min(left.y_min, right.y_min);
        let y_max = f64::max(left.y_max, right.y_max);
        let z_min = f64::min(left.z_min, right.z_min);
        let z_max = f64::max(left.z_max, right.z_max);
        let r_max = f64::max(left.r_max, right.r_max);
        self.x_min = x_min;
        self.x_max = x_max;
        self.y_min = y_min;
        self.y_max = y_max;
        self.z_min = z_min;
        self.z_max = z_max;
        self.r_max = r_max;
    }

    /// Collapses this subtree into a single stationary pseudo-entity.
    ///
    /// The entity sits at the centre of mass and carries the total mass. Its radius
    /// is half of the widest coordinate extent plus the largest member radius.
    pub(crate) fn as_entity(&self) -> Entity {
        let super_radius = self.max_distance() / 2f64 + self.r_max;
        Entity {
            x: self.center_of_mass.0,
            y: self.center_of_mass.1,
            z: self.center_of_mass.2,
            vx: 0.0,
            vy: 0.0,
            vz: 0.0,
            mass: self.total_mass,
            radius: super_radius,
        }
    }

    /// Returns the largest of the node's x, y and z extents (`max - min` per axis).
    pub(crate) fn max_distance(&self) -> f64 {
        let x_distance = self.x_max - self.x_min;
        let y_distance = self.y_max - self.y_min;
        let z_distance = self.z_max - self.z_min;
        f64::max(x_distance, f64::max(y_distance, z_distance))
    }

    /// Returns clones of every point stored in this subtree.
    ///
    /// Points from the left subtree come first, then points from the right subtree.
    /// A node without a right child contributes its own `points`.
    ///
    /// # Panics
    ///
    /// Panics if a node has no right child and no `points`.
    pub(crate) fn traverse_tree_helper(&self) -> Vec<T> {
        let mut collected: Vec<T> = Vec::new();
        if let Some(left) = &self.left {
            collected.extend(left.traverse_tree_helper());
        }
        match &self.right {
            Some(right) => collected.extend(right.traverse_tree_helper()),
            None => collected.extend_from_slice(
                self.points
                    .as_ref()
                    .expect("unexpected null node #10"),
            ),
        }
        collected
    }

    /// Builds a tree over `pts`.
    ///
    /// A slice with at most `max_entities` points becomes a leaf. A larger slice is
    /// split at the index returned by `find_median`, along the axis with the
    /// greatest spread, and each side is built recursively. The input is copied once
    /// and the copy is reordered in place, so each child receives exactly the points
    /// on its side of the split.
    ///
    /// # Panics
    ///
    /// Panics with "invalid mass of 0" if an internal node's children have zero
    /// combined mass.
    pub(crate) fn new_root_node(pts: &[T], max_entities: i32) -> Node<T> {
        let mut working = pts.to_vec();
        Self::build_subtree(&mut working, max_entities)
    }

    /// Builds the node for `pts`; the slice may be reordered.
    fn build_subtree(pts: &mut [T], max_entities: i32) -> Node<T> {
        let entities: Vec<Entity> = pts.iter().map(|pt| pt.as_entity()).collect();
        // Widen before comparing so huge slices cannot wrap to a negative length.
        if pts.len() as i64 <= i64::from(max_entities) {
            Self::leaf_node(pts, entities)
        } else {
            Self::internal_node(pts, entities, max_entities)
        }
    }

    /// Builds a leaf that stores clones of `pts`; `entities` is `pts` converted via `as_entity`.
    fn leaf_node(pts: &[T], entities: Vec<Entity>) -> Node<T> {
        let (x_total, y_total, z_total, max_radius, total_mass) =
            entities.iter().fold((0.0, 0.0, 0.0, 0.0, 0.0), |acc, e| {
                (
                    acc.0 + (e.x * e.mass),
                    acc.1 + (e.y * e.mass),
                    acc.2 + (e.z * e.mass),
                    if acc.3 > e.radius { acc.3 } else { e.radius },
                    acc.4 + e.mass,
                )
            });
        let center_of_mass = if total_mass == 0.0 {
            // A weightless leaf would get a 0/0 = NaN centre. Because 0.0 * NaN is
            // NaN, that value would poison every ancestor's centre of mass, so use
            // the plain mean position. It carries zero weight in the parent anyway.
            let count = entities.len() as f64;
            let (sum_x, sum_y, sum_z) = entities
                .iter()
                .fold((0.0, 0.0, 0.0), |acc, e| (acc.0 + e.x, acc.1 + e.y, acc.2 + e.z));
            (sum_x / count, sum_y / count, sum_z / count)
        } else {
            (
                x_total / total_mass,
                y_total / total_mass,
                z_total / total_mass,
            )
        };
        let (x_max, x_min, y_max, y_min, z_max, z_min) = max_min_xyz(&entities);
        Node {
            center_of_mass,
            total_mass,
            r_max: max_radius,
            points: Some(pts.to_vec()),
            left: None,
            right: None,
            split_dimension: None,
            split_value: 0.0,
            x_max: *x_max,
            x_min: *x_min,
            y_max: *y_max,
            y_min: *y_min,
            z_max: *z_max,
            z_min: *z_min,
        }
    }

    /// Splits `pts` along its widest axis and builds both children recursively.
    fn internal_node(pts: &mut [T], mut entities: Vec<Entity>, max_entities: i32) -> Node<T> {
        let (xdistance, ydistance, zdistance) = xyz_distances(entities.as_slice());
        // Pick the widest axis. Ties prefer X, then Y.
        let (split_dimension, split_value, split_index) =
            if xdistance >= ydistance && xdistance >= zdistance {
                let (value, index) = find_median(Dimension::X, &mut entities);
                let value = *value;
                Self::partition_at(pts, index, |e| e.x);
                (Dimension::X, value, index)
            } else if ydistance >= zdistance {
                let (value, index) = find_median(Dimension::Y, &mut entities);
                let value = *value;
                Self::partition_at(pts, index, |e| e.y);
                (Dimension::Y, value, index)
            } else {
                let (value, index) = find_median(Dimension::Z, &mut entities);
                let value = *value;
                Self::partition_at(pts, index, |e| e.z);
                (Dimension::Z, value, index)
            };

        let (below_split, above_split) = pts.split_at_mut(split_index);
        let left = Self::build_subtree(below_split, max_entities);
        let right = Self::build_subtree(above_split, max_entities);

        let left_mass = left.total_mass;
        let right_mass = right.total_mass;
        let total_mass = left_mass + right_mass;
        assert!(total_mass != 0., "invalid mass of 0");
        let (left_x, left_y, left_z) = left.center_of_mass;
        let (right_x, right_y, right_z) = right.center_of_mass;
        let center_of_mass = (
            ((left_mass * left_x) + (right_mass * right_x)) / total_mass,
            ((left_mass * left_y) + (right_mass * right_y)) / total_mass,
            ((left_mass * left_z) + (right_mass * right_z)) / total_mass,
        );

        let mut node = Self::new();
        node.split_dimension = Some(split_dimension);
        node.split_value = split_value;
        node.left = Some(Box::new(left));
        node.right = Some(Box::new(right));
        node.center_of_mass = center_of_mass;
        node.set_max_mins();
        node.total_mass = total_mass;
        node
    }

    /// Reorders `pts` so that the element of rank `index` (by `coordinate`) sits at
    /// `index`, with no larger coordinate before it and no smaller one after it.
    ///
    /// Does nothing when `index` is out of range.
    fn partition_at(pts: &mut [T], index: usize, coordinate: impl Fn(&Entity) -> f64) {
        if index < pts.len() {
            // `total_cmp` is a total order, so NaN coordinates cannot make the
            // selection panic.
            pts.select_nth_unstable_by(index, |a, b| {
                coordinate(&a.as_entity()).total_cmp(&coordinate(&b.as_entity()))
            });
        }
    }
}
#[test]
fn test() {
    use crate::{collisions::soft_body, Responsive, SimulationResult};
    impl Responsive for Entity {
        fn respond(&self, simulation_result: SimulationResult<Self>, time_step: f64) -> Self {
            let mut vx = self.vx;
            let mut vy = self.vy;
            let mut vz = self.vz;
            let (mut ax, mut ay, mut az) = simulation_result.gravitational_acceleration;
            for other in simulation_result.collisions {
                let (collision_ax, collision_ay, collision_az) = soft_body(self, other, 50f64);
                ax += collision_ax;
                ay += collision_ay;
                az += collision_az;
            }
            vx += ax * time_step;
            vy += ay * time_step;
            vz += az * time_step;
            Entity {
                vx,
                vy,
                vz,
                x: self.x + (vx * time_step),
                y: self.y + (vy * time_step),
                z: self.z + (vz * time_step),
                radius: self.radius,
                mass: self.mass,
            }
        }
    }

    // Ten entities along a diagonal; entity 0 is deliberately massless.
    let test_vec: Vec<Entity> = (0..10)
        .map(|i| Entity {
            x: i as f64,
            y: (10 - i) as f64,
            z: i as f64,
            vx: i as f64,
            vy: i as f64,
            vz: i as f64,
            mass: i as f64,
            radius: i as f64,
        })
        .collect();
    let tree = crate::GravTree::new(&test_vec, 0.2, 3, 0.2);

    // Every input entity must survive the round trip through the tree.
    let post_tree_vec = tree.as_vec();
    for entity in &test_vec {
        assert!(post_tree_vec.contains(entity));
    }

    // Pre-order walk (node, then left subtree, then right subtree) by reference,
    // so no subtree is cloned.
    let mut nodes: Vec<&Node<Entity>> = Vec::new();
    let mut pending: Vec<&Node<Entity>> = vec![&tree.root];
    while let Some(node) = pending.pop() {
        nodes.push(node);
        // Push right first so the left child is visited first.
        if let Some(right) = node.right.as_deref() {
            pending.push(right);
        }
        if let Some(left) = node.left.as_deref() {
            pending.push(left);
        }
    }
    assert_eq!(8, nodes.len());
    for node in nodes.iter().skip(1) {
        assert!(node.total_mass > 0.);
    }

    let total_mass = test_vec.iter().fold(0., |acc, x| acc + x.mass);
    assert_eq!(
        total_mass,
        tree.root
            .left
            .as_ref()
            .expect("root should have a left child")
            .total_mass
    );
}