1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
use std::collections::HashMap;

mod node;
mod physics;
pub mod point2d;
pub mod integration;

use node::*;
use physics::*;
use point2d::*;
use integration::*;

pub use integration::ExplicitEuler;
pub use node::Node;
pub use physics::FnThatComputesForceFromNode;

/// A 2D particle system: a set of nodes, the forces acting on them, and the
/// integration method used to advance them in time.
pub struct ParSys {
  /// All nodes in the system, keyed by their id.
  nodes: HashMap<NodeId, Node>,
  /// Id that `add_node` assigns to the next node; incremented after each insert.
  id_counter: NodeId,
  /// Length of one integration step; `simulate` counts elapsed time in
  /// multiples of this value.
  pub timestep: f32,
  /// Forces evaluated by `compute_acc`.
  forces: Vec<Force>, // Cannot be a HashSet, because Eq is not implementable.
  /// Integrator that `step` dispatches to.
  integration_method: IntegrationMethod
}

impl ParSys {
  /// Creates an empty particle system.
  ///
  /// * `timestep` - length of a single integration step taken by [`ParSys::step`].
  /// * `integration_method` - integrator that [`ParSys::step`] dispatches to.
  pub fn new(timestep: f32, integration_method: IntegrationMethod) -> ParSys {
    ParSys {
      nodes: HashMap::new(),
      id_counter: 0,
      timestep,
      forces: vec![],
      integration_method
    }
  }

  /// Computes the acceleration of every node from the registered forces.
  ///
  /// `xs` and `vs` map each node id to the position and velocity at which the
  /// forces are evaluated. These are passed explicitly, not read from the nodes,
  /// so that integrators can evaluate during-integration states.
  ///
  /// For each node, every force whose `selector` accepts that node adds the
  /// sum of `force_from_node` over all nodes, including the node itself.
  /// The total force is then divided by the node's mass.
  ///
  /// # Panics
  ///
  /// Panics if `xs` or `vs` does not hold exactly one entry per node.
  fn compute_acc(&self, xs: &HashMap<NodeId, Point2D>, vs: &HashMap<NodeId, Point2D>) -> HashMap<NodeId, Point2D> {
    assert_eq!(xs.len(), self.nodes.len());
    assert_eq!(vs.len(), self.nodes.len());
    let mut accs = HashMap::with_capacity(xs.len());
    for (n1_id, n1) in self.nodes.iter() {
      let x1 = xs.get(n1_id).expect("compute_acc: xs has no position for a node");
      let v1 = vs.get(n1_id).expect("compute_acc: vs has no velocity for a node");

      let mut f = Point2D(0.0, 0.0);
      for force in self.forces.iter() {
        if (force.selector)(n1, x1, v1) {
          // For now we iterate over all pairs of nodes for each force, but this
          // can be made more efficient by iterating over a subset here. (E.g.,
          // for spring forces the nodes could store from which other nodes
          // they receive it.)
          for (n2_id, n2) in self.nodes.iter() {
            // Compute the force based on the two nodes' during-integration
            // positions and velocities, and the nodes themselves.
            let x2 = xs.get(n2_id).expect("compute_acc: xs has no position for a node");
            let v2 = vs.get(n2_id).expect("compute_acc: vs has no velocity for a node");
            f += (force.force_from_node)((x1, v1), (x2, v2), n1, n2);
          }
        }
      }
      accs.insert(*n1_id, f / n1.physics.mass);
    }
    accs
  }

  /// Adds a node with the given center and radius to the system.
  ///
  /// Returns the new node's id. Ids are assigned sequentially, starting at 0.
  pub fn add_node(&mut self, center: Point2D, radius: f32) -> NodeId {
    let id = self.id_counter;
    self.nodes.insert(id, Node::new(id, center, radius));
    self.id_counter += 1;
    id
  }

  /// Returns the node with the given id.
  ///
  /// # Errors
  ///
  /// Returns `Err("Node not found.")` if no node has this id.
  pub fn get_node(&self, id: NodeId) -> Result<&Node, &'static str> {
    self.nodes.get(&id).ok_or("Node not found.")
  }

  /// Advances the system until at least `seconds` of simulated time have elapsed.
  ///
  /// Calls [`ParSys::step`] repeatedly and adds `timestep` to the elapsed
  /// time after each step. Nothing happens if `seconds` is not positive.
  ///
  /// The elapsed time is accumulated in `f64`. An `f32` sum drifts, and once
  /// it is large enough that `timestep` is below half an ulp, adding
  /// `timestep` no longer changes it. For example, with `timestep = 0.01`
  /// that happens around 262 144 s, and the loop would never terminate.
  ///
  /// # Panics
  ///
  /// Panics if a step would be taken while `timestep` is not positive
  /// (including NaN), since the loop could then never reach `seconds`.
  pub fn simulate(&mut self, seconds: f32) {
    let target = f64::from(seconds);
    let mut elapsed = 0.0_f64;
    while elapsed < target {
      assert!(
        self.timestep > 0.0,
        "ParSys::simulate needs a positive timestep to make progress, got {}",
        self.timestep
      );
      self.step();
      elapsed += f64::from(self.timestep);
    }
  }

  /// Performs a single integration step using the configured integration method.
  pub fn step(&mut self) {
    match self.integration_method {
      IntegrationMethod::ExplicitEuler => self.step_explicit_euler(),
      IntegrationMethod::RungeKutta4 => self.step_rk4(),
    }
  }

  /// Sets the velocity of the node with the given id.
  ///
  /// # Errors
  ///
  /// Returns `Err("Node not found.")` if no node has this id. In that case
  /// the system is unchanged.
  pub fn set_velocity(&mut self, id: NodeId, velocity: Point2D) -> Result<(), &'static str> {
    let node = self.nodes.get_mut(&id).ok_or("Node not found.")?;
    node.physics.velocity = velocity;
    Ok(())
  }

  /// Adds a custom force to the system.
  ///
  /// * `selector` - decides, from a node and its current position and velocity,
  ///   whether the force acts on that node.
  /// * `force_from_node` - the force contribution that one node exerts on a
  ///   selected node. It is evaluated against every node in the system,
  ///   including the selected node itself.
  pub fn add_force(&mut self,
    selector: fn(n: &Node, pos: &Point2D, vel: &Point2D) -> bool,
    force_from_node: FnThatComputesForceFromNode
  ) {
    self.forces.push(Force {
      selector, force_from_node
    });
  }

  /// Returns the distance between the centers of the two given nodes.
  ///
  /// # Errors
  ///
  /// Returns `Err("Invalid NodeId.")` if either id does not belong to a node.
  pub fn get_dist(&self, id1: NodeId, id2: NodeId) -> Result<f32, &'static str> {
    match (self.nodes.get(&id1), self.nodes.get(&id2)) {
      (Some(n1), Some(n2)) => {
        Ok((n1.physics.geometry.center - n2.physics.geometry.center).norm())
      }
      _ => Err("Invalid NodeId."),
    }
  }
}

impl Default for ParSys {
  /// Builds a system that steps by 0.01 per step using explicit Euler integration.
  fn default() -> Self {
    let timestep = 0.01;
    Self::new(timestep, IntegrationMethod::ExplicitEuler)
  }
}

#[cfg(test)]
mod tests {
  use super::*;

  #[test]
  fn ids_are_different() {
    let mut s = ParSys::default();
    let n1_id = s.add_node(Point2D(0.0, 0.0), 1.0);
    let n2_id = s.add_node(Point2D(0.0, 0.0), 1.0);
    assert_ne!(n1_id, n2_id);
  }

  #[test]
  fn moving_a_node() {
    let mut s = ParSys::default();
    let n_id = s.add_node(Point2D(0.0, 0.0), 1.0);
    s.set_velocity(n_id, Point2D(1.0, 0.0)).unwrap();

    let node_before = s.get_node(n_id).unwrap();
    assert!(node_before.physics.geometry.center.0 < 1.0);

    s.simulate(1.01);

    let node_after = s.get_node(n_id).unwrap();
    assert!(node_after.physics.geometry.center.0 > 1.0);
    assert!(node_after.physics.geometry.center.0 < 1.1);
  }

  #[test]
  fn simulating_zero_seconds_does_not_move_nodes() {
    let mut s = ParSys::default();
    let n_id = s.add_node(Point2D(0.0, 0.0), 1.0);
    s.set_velocity(n_id, Point2D(1.0, 0.0)).unwrap();

    s.simulate(0.0);

    assert_eq!(s.get_node(n_id).unwrap().physics.geometry.center.0, 0.0);
  }

  #[test]
  fn unknown_ids_are_errors() {
    // An id issued by a different system is unknown to an empty one.
    let mut other = ParSys::default();
    let foreign_id = other.add_node(Point2D(0.0, 0.0), 1.0);

    let mut s = ParSys::default();
    assert!(s.get_node(foreign_id).is_err());
    assert!(s.set_velocity(foreign_id, Point2D(1.0, 0.0)).is_err());
    assert!(s.get_dist(foreign_id, foreign_id).is_err());
  }

  #[test]
  fn attractive_force() {
    let mut s = ParSys::new(0.01, IntegrationMethod::ExplicitEuler);

    let n1_id = s.add_node(Point2D(1.0, 0.0), 1.0);
    let n2_id = s.add_node(Point2D(0.0, 0.0), 1.0);

    s.add_force(|_, _, _| true, |(pos1, _v1), (pos2, _v2), n, m| {
      // Don't try to apply this force to oneself.
      if n.id == m.id {
        return Point2D(0.0, 0.0);
      }

      // Vector from n to m
      let delta = *pos2 - *pos1;
      // Return a vector in the direction of m, with magnitude 1/distance
      // (Inverse distance)
      delta.normalize() / delta.norm()
    });

    let dist_before = s.get_dist(n1_id, n2_id).unwrap();
    assert!(dist_before < 1.01 && dist_before > 0.99);

    // Perform two steps: the first gives the nodes a velocity towards each
    // other, and the second moves them with that velocity.
    s.step();
    s.step();

    let dist_after = s.get_dist(n1_id, n2_id).unwrap();
    assert!(
      dist_after < 1.0,
      "nodes should have moved closer, distance is {}",
      dist_after
    );
  }
}