use hisab::Vec2;
use serde::{Deserialize, Serialize};
/// A circular obstacle considered by [`avoid_obstacles`].
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct Obstacle {
/// Centre of the circle, in the same coordinate space as agent positions.
pub center: Vec2,
/// Radius of the circle.
pub radius: f32,
}
/// A single steering behaviour, evaluated by [`compute_steer`].
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum SteerBehavior {
/// Head toward `target` at the full `max_speed` passed to [`compute_steer`].
Seek { target: Vec2 },
/// Head directly away from `target` at the full `max_speed`.
Flee { target: Vec2 },
/// Head toward `target`, slowing down linearly once closer than `slow_radius`.
Arrive {
/// Point to arrive at.
target: Vec2,
/// Distance from `target` inside which speed is scaled by `distance / slow_radius`.
slow_radius: f32,
},
}
/// Result of a steering computation: the desired velocity for the agent.
///
/// `SteerOutput::default()` is what the steering functions in this module return when there
/// is nothing to steer toward (for example, when already at the target).
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
pub struct SteerOutput {
/// Desired velocity vector; its length is the desired speed.
pub velocity: Vec2,
}
impl SteerOutput {
    /// Creates an output whose velocity has components `(vx, vy)`.
    #[cfg_attr(feature = "logging", tracing::instrument)]
    #[must_use]
    pub fn new(vx: f32, vy: f32) -> Self {
        let velocity = Vec2::new(vx, vy);
        Self { velocity }
    }

    /// Wraps an already-computed velocity vector.
    #[cfg_attr(feature = "logging", tracing::instrument)]
    #[must_use]
    pub fn from_vec2(velocity: Vec2) -> Self {
        Self { velocity }
    }

    /// Magnitude of the desired velocity.
    #[cfg_attr(feature = "logging", tracing::instrument(skip(self)))]
    #[inline]
    #[must_use]
    pub fn speed(&self) -> f32 {
        let Self { velocity } = self;
        velocity.length()
    }
}
/// Computes the desired velocity for one [`SteerBehavior`].
///
/// - `Seek`: unit direction toward `target`, scaled to `max_speed`.
/// - `Flee`: unit direction away from `target`, scaled to `max_speed`.
/// - `Arrive`: unit direction toward `target`. When the distance is below `slow_radius`, it is
///   scaled to `max_speed * distance / slow_radius`; otherwise it is scaled to `max_speed`.
///
/// If `position` is within `f32::EPSILON` of the target, the direction is undefined and a
/// zero (default) output is returned.
#[cfg_attr(feature = "logging", tracing::instrument)]
#[inline]
#[must_use]
pub fn compute_steer(behavior: &SteerBehavior, position: Vec2, max_speed: f32) -> SteerOutput {
    // Each behaviour reduces to "an offset to travel along" plus an optional slow-down radius.
    let (offset, slowdown) = match *behavior {
        SteerBehavior::Seek { target } => (target - position, None),
        SteerBehavior::Flee { target } => (position - target, None),
        SteerBehavior::Arrive {
            target,
            slow_radius,
        } => (target - position, Some(slow_radius)),
    };

    let distance = offset.length();
    if distance < f32::EPSILON {
        return SteerOutput::default();
    }

    let speed = match slowdown {
        Some(radius) if distance < radius => max_speed * (distance / radius),
        _ => max_speed,
    };
    SteerOutput::from_vec2(offset / distance * speed)
}
/// Steers toward the point a moving target is predicted to reach ("pursuit").
///
/// The look-ahead time is `distance / max_speed`, roughly how long this agent needs to cover
/// the current gap. The target is extrapolated linearly along `target_vel` for that long, and
/// the result is a [`SteerBehavior::Seek`] toward the predicted point at `max_speed`.
///
/// Returns a zero output when the agent already coincides with `target_pos`, or when
/// `max_speed` is (near) zero. A zero speed would otherwise make the look-ahead time infinite
/// and the resulting velocity NaN.
#[inline]
#[must_use]
pub fn pursuit(position: Vec2, target_pos: Vec2, target_vel: Vec2, max_speed: f32) -> SteerOutput {
    match predict_target_position(position, target_pos, target_vel, max_speed) {
        Some(predicted) => compute_steer(
            &SteerBehavior::Seek { target: predicted },
            position,
            max_speed,
        ),
        None => SteerOutput::default(),
    }
}

/// Steers away from the point a moving target is predicted to reach ("evade").
///
/// Uses the same linear prediction as [`pursuit`], then issues a [`SteerBehavior::Flee`] from
/// the predicted point at `max_speed`.
///
/// Returns a zero output when the agent already coincides with `target_pos`, or when
/// `max_speed` is (near) zero. A zero speed would otherwise make the look-ahead time infinite
/// and the resulting velocity NaN.
#[inline]
#[must_use]
pub fn evade(position: Vec2, target_pos: Vec2, target_vel: Vec2, max_speed: f32) -> SteerOutput {
    match predict_target_position(position, target_pos, target_vel, max_speed) {
        Some(predicted) => compute_steer(
            &SteerBehavior::Flee { target: predicted },
            position,
            max_speed,
        ),
        None => SteerOutput::default(),
    }
}

/// Linearly extrapolates `target_pos` along `target_vel` by `distance / max_speed`.
///
/// Returns `None` when the agent is within `f32::EPSILON` of the target, or when `|max_speed|`
/// is below `f32::EPSILON`. In the second case the division would yield an infinite look-ahead
/// time, and `target_vel * inf` produces infinite/NaN components (`0.0 * inf` is NaN). A
/// near-stationary agent would have produced a near-zero steering magnitude anyway.
#[inline]
fn predict_target_position(
    position: Vec2,
    target_pos: Vec2,
    target_vel: Vec2,
    max_speed: f32,
) -> Option<Vec2> {
    let dist = (target_pos - position).length();
    if dist < f32::EPSILON || max_speed.abs() < f32::EPSILON {
        return None;
    }
    let prediction_time = dist / max_speed;
    Some(target_pos + target_vel * prediction_time)
}
/// Steers toward a point on a circle projected ahead of the agent ("wander").
///
/// The circle is centred `wander_distance` units along the current heading. When the agent is
/// (nearly) stationary, the +x axis is used as the heading. The seek target sits on the
/// circle's rim at angle `wander_angle` (radians, fed to `sin`/`cos`). Note that the angle is
/// measured from the +x axis and is not rotated by the heading. The result is a
/// [`SteerBehavior::Seek`] toward that point at `max_speed`.
#[inline]
#[must_use]
pub fn wander(
    position: Vec2,
    velocity: Vec2,
    max_speed: f32,
    wander_distance: f32,
    wander_radius: f32,
    wander_angle: f32,
) -> SteerOutput {
    let current_speed = velocity.length();
    let heading = if current_speed <= f32::EPSILON {
        Vec2::new(1.0, 0.0)
    } else {
        velocity / current_speed
    };

    let (sin_a, cos_a) = wander_angle.sin_cos();
    let rim_offset = Vec2::new(cos_a, sin_a) * wander_radius;
    let wander_point = position + heading * wander_distance + rim_offset;

    compute_steer(
        &SteerBehavior::Seek {
            target: wander_point,
        },
        position,
        max_speed,
    )
}
/// Pushes the agent away from nearby neighbours ("separation").
///
/// Each neighbour strictly closer than `radius`, and not coincident with `position`,
/// contributes `(position - neighbour) / distance²`. That is a unit away-vector weighted by
/// `1 / distance`. The contributions are averaged and the mean is rescaled to `max_force`.
/// The result therefore has magnitude `max_force`, or is zero when nothing is in range or the
/// pushes cancel out.
#[inline]
#[must_use]
pub fn separation(position: Vec2, neighbors: &[Vec2], radius: f32, max_force: f32) -> SteerOutput {
    let (push_sum, in_range) =
        neighbors
            .iter()
            .fold((Vec2::ZERO, 0usize), |(acc, hits), &other| {
                let away = position - other;
                let gap = away.length();
                if gap > f32::EPSILON && gap < radius {
                    (acc + away / (gap * gap), hits + 1)
                } else {
                    (acc, hits)
                }
            });

    if in_range == 0 {
        return SteerOutput::from_vec2(push_sum);
    }

    let mean_push = push_sum / in_range as f32;
    let magnitude = mean_push.length();
    let scaled = if magnitude > f32::EPSILON {
        mean_push / magnitude * max_force
    } else {
        mean_push
    };
    SteerOutput::from_vec2(scaled)
}
/// Steers the agent's velocity toward the average neighbour velocity ("alignment").
///
/// Returns `(mean_neighbour_velocity - velocity)` rescaled to `max_force`. A zero output is
/// returned when there are no neighbours, or when the agent already matches the mean (the
/// difference is shorter than `f32::EPSILON`).
#[inline]
#[must_use]
pub fn alignment(velocity: Vec2, neighbor_velocities: &[Vec2], max_force: f32) -> SteerOutput {
    let n = neighbor_velocities.len();
    if n == 0 {
        return SteerOutput::default();
    }

    let mean_velocity = neighbor_velocities.iter().copied().sum::<Vec2>() / n as f32;
    let correction = mean_velocity - velocity;
    let magnitude = correction.length();

    if magnitude >= f32::EPSILON {
        SteerOutput::from_vec2(correction / magnitude * max_force)
    } else {
        SteerOutput::default()
    }
}
/// Steers toward the centroid of the neighbours ("cohesion").
///
/// Returns a [`SteerBehavior::Seek`] toward the arithmetic mean of `neighbors` at `max_speed`,
/// or a zero output when there are no neighbours.
#[inline]
#[must_use]
pub fn cohesion(position: Vec2, neighbors: &[Vec2], max_speed: f32) -> SteerOutput {
    match neighbors.len() {
        0 => SteerOutput::default(),
        count => {
            let centroid = neighbors.iter().copied().sum::<Vec2>() / count as f32;
            compute_steer(
                &SteerBehavior::Seek { target: centroid },
                position,
                max_speed,
            )
        }
    }
}
/// Produces a lateral push away from the first obstacle in the agent's path.
///
/// The agent is modelled as a line along its current heading; its own size is not
/// considered. An obstacle is a candidate when both of these hold:
/// - its centre projects onto the heading within `[-radius, look_ahead + radius]`;
/// - its circle crosses the heading line (lateral offset of the centre `<= radius`).
///
/// Among candidates, the one with the smallest projection wins; ties go to the nearer centre.
/// The push is perpendicular to the heading, toward the side opposite the obstacle's centre.
/// An obstacle dead ahead is pushed toward the clockwise-perpendicular side. The magnitude is
/// `max_force * urgency`, where `urgency = 1 - clamp(t / look_ahead, 0, 1)` and `t` is the
/// winning projection.
///
/// Returns a zero output when the agent is (nearly) stationary or no obstacle qualifies.
#[inline]
#[must_use]
pub fn avoid_obstacles(
    position: Vec2,
    velocity: Vec2,
    obstacles: &[Obstacle],
    look_ahead: f32,
    max_force: f32,
) -> SteerOutput {
    let speed = velocity.length();
    if speed < f32::EPSILON {
        return SteerOutput::default();
    }
    let heading = velocity / speed;
    // Counter-clockwise perpendicular of the heading.
    let side = Vec2::new(-heading.y, heading.x);

    // Best candidate as (projection along heading, squared distance to centre, push sign).
    // The INFINITY projection doubles as the "nothing found" marker.
    let (closest_t, _, push_sign) = obstacles
        .iter()
        .filter_map(|obstacle| {
            let offset = obstacle.center - position;
            let along = offset.dot(heading);
            if along < -obstacle.radius || along > look_ahead + obstacle.radius {
                return None;
            }
            let across = offset.dot(side);
            if obstacle.radius - across.abs() < 0.0 {
                return None;
            }
            let sign = if across >= 0.0 { -1.0 } else { 1.0 };
            Some((along, offset.length_squared(), sign))
        })
        .fold(
            (f32::INFINITY, f32::INFINITY, 0.0f32),
            |best, candidate| {
                let nearer =
                    candidate.0 < best.0 || (candidate.0 == best.0 && candidate.1 < best.1);
                if nearer {
                    candidate
                } else {
                    best
                }
            },
        );

    if closest_t == f32::INFINITY {
        return SteerOutput::default();
    }
    let urgency = 1.0 - (closest_t / (look_ahead + f32::EPSILON)).clamp(0.0, 1.0);
    SteerOutput::from_vec2(side * push_sign * max_force * urgency)
}
#[cfg(test)]
mod tests {
    //! Behavioural tests for the steering functions. Tolerances of `0.01` absorb f32
    //! rounding in normalisation.
    use super::*;

    #[test]
    fn seek_toward_target() {
        let out = compute_steer(
            &SteerBehavior::Seek {
                target: Vec2::new(10.0, 0.0),
            },
            Vec2::ZERO,
            5.0,
        );
        assert!((out.velocity.x - 5.0).abs() < 0.01);
        assert!(out.velocity.y.abs() < 0.01);
    }

    #[test]
    fn flee_from_target() {
        let out = compute_steer(
            &SteerBehavior::Flee {
                target: Vec2::new(10.0, 0.0),
            },
            Vec2::ZERO,
            5.0,
        );
        assert!((out.velocity.x - (-5.0)).abs() < 0.01);
    }

    #[test]
    fn arrive_full_speed() {
        let out = compute_steer(
            &SteerBehavior::Arrive {
                target: Vec2::new(100.0, 0.0),
                slow_radius: 10.0,
            },
            Vec2::ZERO,
            5.0,
        );
        assert!((out.speed() - 5.0).abs() < 0.01);
    }

    #[test]
    fn arrive_slow_down() {
        let out = compute_steer(
            &SteerBehavior::Arrive {
                target: Vec2::new(5.0, 0.0),
                slow_radius: 10.0,
            },
            Vec2::ZERO,
            10.0,
        );
        assert!((out.speed() - 5.0).abs() < 0.01);
    }

    #[test]
    fn arrive_at_target() {
        let out = compute_steer(
            &SteerBehavior::Arrive {
                target: Vec2::new(5.0, 5.0),
                slow_radius: 10.0,
            },
            Vec2::new(5.0, 5.0),
            10.0,
        );
        assert!(out.speed() < f32::EPSILON);
    }

    #[test]
    fn seek_at_target() {
        let out = compute_steer(&SteerBehavior::Seek { target: Vec2::ZERO }, Vec2::ZERO, 5.0);
        assert!(out.speed() < f32::EPSILON);
    }

    #[test]
    fn flee_at_target() {
        let out = compute_steer(&SteerBehavior::Flee { target: Vec2::ZERO }, Vec2::ZERO, 5.0);
        assert!(out.speed() < f32::EPSILON);
    }

    #[test]
    fn steer_output_speed() {
        let out = SteerOutput::new(3.0, 4.0);
        assert!((out.speed() - 5.0).abs() < 0.01);
    }

    #[test]
    fn seek_diagonal() {
        let out = compute_steer(
            &SteerBehavior::Seek {
                target: Vec2::new(10.0, 10.0),
            },
            Vec2::ZERO,
            1.0,
        );
        assert!((out.speed() - 1.0).abs() < 0.01);
        assert!((out.velocity.x - out.velocity.y).abs() < 0.01);
    }

    #[test]
    fn flee_negative_coords() {
        let out = compute_steer(
            &SteerBehavior::Flee {
                target: Vec2::new(-10.0, -10.0),
            },
            Vec2::ZERO,
            5.0,
        );
        assert!(out.velocity.x > 0.0);
        assert!(out.velocity.y > 0.0);
        assert!((out.speed() - 5.0).abs() < 0.01);
    }

    #[test]
    fn arrive_at_slow_radius_boundary() {
        let out = compute_steer(
            &SteerBehavior::Arrive {
                target: Vec2::new(10.0, 0.0),
                slow_radius: 10.0,
            },
            Vec2::ZERO,
            10.0,
        );
        assert!((out.speed() - 10.0).abs() < 0.01);
    }

    #[test]
    fn steer_output_zero_speed() {
        let out = SteerOutput::default();
        assert!(out.speed() < f32::EPSILON);
    }

    #[test]
    fn steer_serde_roundtrip() {
        let b = SteerBehavior::Arrive {
            target: Vec2::new(1.0, 2.0),
            slow_radius: 5.0,
        };
        let json = serde_json::to_string(&b).unwrap();
        let deserialized: SteerBehavior = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, b);
    }

    #[test]
    fn avoid_no_obstacles() {
        let out = avoid_obstacles(Vec2::ZERO, Vec2::new(1.0, 0.0), &[], 10.0, 5.0);
        assert!(out.speed() < f32::EPSILON);
    }

    #[test]
    fn avoid_obstacle_ahead() {
        let obs = Obstacle {
            center: Vec2::new(5.0, 0.5),
            radius: 1.0,
        };
        let out = avoid_obstacles(Vec2::ZERO, Vec2::new(1.0, 0.0), &[obs], 10.0, 5.0);
        assert!(out.speed() > 0.0);
        assert!(out.velocity.y < 0.0);
    }

    #[test]
    fn avoid_obstacle_on_right_steers_left() {
        let obs = Obstacle {
            center: Vec2::new(5.0, -0.5),
            radius: 1.0,
        };
        let out = avoid_obstacles(Vec2::ZERO, Vec2::new(1.0, 0.0), &[obs], 10.0, 5.0);
        assert!(out.velocity.y > 0.0);
    }

    #[test]
    fn avoid_obstacle_behind() {
        let obs = Obstacle {
            center: Vec2::new(-5.0, 0.0),
            radius: 1.0,
        };
        let out = avoid_obstacles(Vec2::ZERO, Vec2::new(1.0, 0.0), &[obs], 10.0, 5.0);
        assert!(out.speed() < f32::EPSILON);
    }

    #[test]
    fn avoid_obstacle_far_lateral() {
        let obs = Obstacle {
            center: Vec2::new(5.0, 10.0),
            radius: 1.0,
        };
        let out = avoid_obstacles(Vec2::ZERO, Vec2::new(1.0, 0.0), &[obs], 10.0, 5.0);
        assert!(out.speed() < f32::EPSILON);
    }

    #[test]
    fn avoid_nearer_obstacle_preferred() {
        let near = Obstacle {
            center: Vec2::new(3.0, 0.5),
            radius: 1.0,
        };
        let far = Obstacle {
            center: Vec2::new(8.0, 0.5),
            radius: 1.0,
        };
        let out_both = avoid_obstacles(Vec2::ZERO, Vec2::new(1.0, 0.0), &[far, near], 10.0, 5.0);
        let out_near = avoid_obstacles(Vec2::ZERO, Vec2::new(1.0, 0.0), &[near], 10.0, 5.0);
        assert!((out_both.velocity.y - out_near.velocity.y).abs() < 0.01);
    }

    #[test]
    fn avoid_zero_velocity() {
        let obs = Obstacle {
            center: Vec2::new(5.0, 0.0),
            radius: 1.0,
        };
        let out = avoid_obstacles(Vec2::ZERO, Vec2::ZERO, &[obs], 10.0, 5.0);
        assert!(out.speed() < f32::EPSILON);
    }

    #[test]
    fn avoid_urgency_scales_with_distance() {
        let close = Obstacle {
            center: Vec2::new(2.0, 0.5),
            radius: 1.0,
        };
        let far = Obstacle {
            center: Vec2::new(9.0, 0.5),
            radius: 1.0,
        };
        let out_close = avoid_obstacles(Vec2::ZERO, Vec2::new(1.0, 0.0), &[close], 10.0, 5.0);
        let out_far = avoid_obstacles(Vec2::ZERO, Vec2::new(1.0, 0.0), &[far], 10.0, 5.0);
        assert!(out_close.speed() > out_far.speed());
    }

    #[test]
    fn avoid_beyond_look_ahead() {
        let obs = Obstacle {
            center: Vec2::new(15.0, 0.0),
            radius: 1.0,
        };
        let out = avoid_obstacles(Vec2::ZERO, Vec2::new(1.0, 0.0), &[obs], 10.0, 5.0);
        assert!(out.speed() < f32::EPSILON);
    }

    #[test]
    fn obstacle_serde_roundtrip() {
        let obs = Obstacle {
            center: Vec2::new(1.0, 2.0),
            radius: 3.0,
        };
        let json = serde_json::to_string(&obs).unwrap();
        let deserialized: Obstacle = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, obs);
    }

    #[test]
    fn pursuit_intercepts() {
        let out = pursuit(
            Vec2::new(0.0, -5.0),
            Vec2::new(5.0, 0.0),
            Vec2::new(1.0, 0.0),
            5.0,
        );
        assert!(out.velocity.x > 0.0);
        assert!(out.speed() > 0.0);
    }

    #[test]
    fn pursuit_stationary_target_matches_seek() {
        let target = Vec2::new(10.0, 0.0);
        let chased = pursuit(Vec2::ZERO, target, Vec2::ZERO, 5.0);
        let sought = compute_steer(&SteerBehavior::Seek { target }, Vec2::ZERO, 5.0);
        assert!((chased.velocity.x - sought.velocity.x).abs() < 0.01);
        assert!((chased.velocity.y - sought.velocity.y).abs() < 0.01);
    }

    #[test]
    fn evade_flees_predicted() {
        let out = evade(
            Vec2::ZERO,
            Vec2::new(5.0, 0.0),
            Vec2::new(-1.0, 0.0),
            5.0,
        );
        assert!(out.velocity.x < 0.0);
    }

    #[test]
    fn evade_stationary_target_matches_flee() {
        let target = Vec2::new(10.0, 0.0);
        let evaded = evade(Vec2::ZERO, target, Vec2::ZERO, 5.0);
        let fled = compute_steer(&SteerBehavior::Flee { target }, Vec2::ZERO, 5.0);
        assert!((evaded.velocity.x - fled.velocity.x).abs() < 0.01);
        assert!((evaded.velocity.y - fled.velocity.y).abs() < 0.01);
    }

    #[test]
    fn wander_produces_movement() {
        let out = wander(Vec2::ZERO, Vec2::new(1.0, 0.0), 5.0, 2.0, 1.0, 0.5);
        assert!(out.speed() > 0.0);
    }

    #[test]
    fn wander_zero_angle_heads_straight() {
        // Circle centre (2, 0) plus rim offset (1, 0): the target lies straight ahead.
        let out = wander(Vec2::ZERO, Vec2::new(1.0, 0.0), 5.0, 2.0, 1.0, 0.0);
        assert!((out.velocity.x - 5.0).abs() < 0.01);
        assert!(out.velocity.y.abs() < 0.01);
    }

    #[test]
    fn wander_stationary_agent() {
        let out = wander(Vec2::ZERO, Vec2::ZERO, 5.0, 2.0, 1.0, 0.0);
        assert!(out.speed() > 0.0);
    }

    #[test]
    fn separation_symmetric_neighbors_cancel() {
        // Equal and opposite pushes sum to exactly zero, so no force is applied.
        let neighbors = [Vec2::new(1.0, 0.0), Vec2::new(-1.0, 0.0)];
        let out = separation(Vec2::ZERO, &neighbors, 5.0, 10.0);
        assert!(out.speed() < f32::EPSILON);
    }

    #[test]
    fn separation_one_neighbor() {
        let neighbors = [Vec2::new(1.0, 0.0)];
        let out = separation(Vec2::ZERO, &neighbors, 5.0, 10.0);
        assert!(out.velocity.x < 0.0);
        assert!((out.speed() - 10.0).abs() < 0.01);
    }

    #[test]
    fn separation_ignores_coincident_neighbor() {
        let neighbors = [Vec2::ZERO];
        let out = separation(Vec2::ZERO, &neighbors, 5.0, 10.0);
        assert!(out.speed() < f32::EPSILON);
    }

    #[test]
    fn separation_no_neighbors() {
        let out = separation(Vec2::ZERO, &[], 5.0, 10.0);
        assert!(out.speed() < f32::EPSILON);
    }

    #[test]
    fn separation_out_of_range() {
        let neighbors = [Vec2::new(100.0, 0.0)];
        let out = separation(Vec2::ZERO, &neighbors, 5.0, 10.0);
        assert!(out.speed() < f32::EPSILON);
    }

    #[test]
    fn alignment_matches_heading() {
        let neighbor_vels = [Vec2::new(3.0, 0.0), Vec2::new(3.0, 0.0)];
        let out = alignment(Vec2::new(1.0, 0.0), &neighbor_vels, 5.0);
        assert!(out.velocity.x > 0.0);
        assert!((out.speed() - 5.0).abs() < 0.01);
    }

    #[test]
    fn alignment_no_neighbors() {
        let out = alignment(Vec2::new(1.0, 0.0), &[], 5.0);
        assert!(out.speed() < f32::EPSILON);
    }

    #[test]
    fn cohesion_toward_center() {
        let neighbors = [Vec2::new(10.0, 0.0), Vec2::new(10.0, 10.0)];
        let out = cohesion(Vec2::ZERO, &neighbors, 5.0);
        assert!(out.velocity.x > 0.0);
        assert!(out.velocity.y > 0.0);
    }

    #[test]
    fn cohesion_at_centroid_is_zero() {
        let neighbors = [Vec2::new(1.0, 0.0), Vec2::new(-1.0, 0.0)];
        let out = cohesion(Vec2::ZERO, &neighbors, 5.0);
        assert!(out.speed() < f32::EPSILON);
    }

    #[test]
    fn cohesion_no_neighbors() {
        let out = cohesion(Vec2::ZERO, &[], 5.0);
        assert!(out.speed() < f32::EPSILON);
    }

    #[test]
    fn steer_output_serde_roundtrip() {
        let out = SteerOutput::new(3.0, 4.0);
        let json = serde_json::to_string(&out).unwrap();
        let deserialized: SteerOutput = serde_json::from_str(&json).unwrap();
        assert!((deserialized.speed() - 5.0).abs() < 0.01);
    }
}