use hisab::Vec2;
use serde::{Deserialize, Serialize};
use crate::follow::PathFollower;
use crate::steer::{Obstacle, SteerOutput, avoid_obstacles};
/// A 2D steering agent that can follow a path and steer around obstacles.
///
/// All state is advanced by [`Agent::update`]; the limits below are applied
/// on every step.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Agent {
/// Current position; advanced by `velocity * dt` on each update.
pub position: Vec2,
/// Velocity applied during the most recent update (zero for a new agent).
pub velocity: Vec2,
/// Upper bound on the velocity magnitude after each update. Also passed to
/// the path follower and used to size the obstacle look-ahead.
pub max_speed: f32,
/// Upper bound on the magnitude of the combined (path + avoidance) steering
/// computed in each update.
pub max_force: f32,
// Path currently attached to the agent, if any; it may already be finished.
follower: Option<PathFollower>,
}
impl Agent {
    /// Creates an agent at `position` with zero velocity and no path.
    #[must_use]
    #[cfg_attr(feature = "logging", tracing::instrument)]
    pub fn new(position: Vec2, max_speed: f32, max_force: f32) -> Self {
        Self {
            position,
            velocity: Vec2::ZERO,
            max_speed,
            max_force,
            follower: None,
        }
    }

    /// Attaches `follower` as the agent's path, replacing any previous one.
    #[cfg_attr(feature = "logging", tracing::instrument(skip(self)))]
    pub fn set_path(&mut self, follower: PathFollower) {
        self.follower = Some(follower);
    }

    /// Detaches the current path, if any.
    #[cfg_attr(feature = "logging", tracing::instrument(skip(self)))]
    pub fn clear_path(&mut self) {
        self.follower = None;
    }

    /// Returns `true` while a path is attached and its follower is not finished.
    #[must_use]
    #[cfg_attr(feature = "logging", tracing::instrument(skip(self)))]
    pub fn has_path(&self) -> bool {
        self.follower.as_ref().is_some_and(|f| !f.is_finished())
    }

    /// Returns the attached path follower, whether or not it has finished.
    #[must_use]
    pub fn follower(&self) -> Option<&PathFollower> {
        self.follower.as_ref()
    }

    /// Advances the agent by one simulation step of length `dt`.
    ///
    /// Two steering contributions are summed:
    ///
    /// 1. path following, from the attached [`PathFollower`] while it is not
    ///    finished (zero otherwise);
    /// 2. obstacle avoidance via [`avoid_obstacles`] with a look-ahead of
    ///    `2 * max_speed`, computed only when `obstacles` is non-empty and the
    ///    agent is already moving.
    ///
    /// The sum is clamped to `max_force` and becomes the new velocity (a
    /// near-zero result stops the agent). That velocity is clamped to
    /// `max_speed`, and the position is advanced by `velocity * dt`.
    ///
    /// A non-positive `max_force` or `max_speed` is treated as a zero limit,
    /// so a misconfigured agent stays put instead of turning into `NaN`.
    ///
    /// Returns the clamped steering force computed for this step.
    pub fn update(&mut self, dt: f32, obstacles: &[Obstacle]) -> SteerOutput {
        let path_steer = match self.follower.as_mut() {
            Some(f) if !f.is_finished() => f.steer(self.position, self.max_speed),
            _ => SteerOutput::default(),
        };

        let avoid_steer = if !obstacles.is_empty() && self.velocity.length_squared() > f32::EPSILON
        {
            let look_ahead = self.max_speed * 2.0;
            avoid_obstacles(
                self.position,
                self.velocity,
                obstacles,
                look_ahead,
                self.max_force,
            )
        } else {
            SteerOutput::default()
        };

        let combined = path_steer.velocity + avoid_steer.velocity;
        let combined_len = combined.length();
        let force = if self.max_force <= 0.0 {
            // Guard: with a negative limit, `combined_len > max_force` holds even
            // for a zero vector, and `0 * (neg / 0)` would yield NaN (or, for a
            // non-zero vector, a force pointing the wrong way).
            Vec2::ZERO
        } else if combined_len > self.max_force {
            combined * (self.max_force / combined_len)
        } else {
            combined
        };

        // Negligible steering snaps the agent to a full stop.
        self.velocity = if force.length() > f32::EPSILON {
            force
        } else {
            Vec2::ZERO
        };

        let speed = self.velocity.length();
        if self.max_speed <= 0.0 {
            // Guard: a negative limit would otherwise divide a zero velocity by
            // a zero speed, leaving velocity and position permanently NaN.
            self.velocity = Vec2::ZERO;
        } else if speed > self.max_speed {
            self.velocity = self.velocity / speed * self.max_speed;
        }

        self.position += self.velocity * dt;
        SteerOutput::from_vec2(force)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a single-waypoint path at `(target_x, 0)` with the parameters
    /// shared by most tests.
    fn straight_path(target_x: f32) -> PathFollower {
        PathFollower::new(vec![Vec2::new(target_x, 0.0)], 0.5, 2.0)
    }

    #[test]
    fn agent_creation() {
        let a = Agent::new(Vec2::ZERO, 5.0, 10.0);
        assert_eq!(a.position, Vec2::ZERO);
        assert_eq!(a.velocity, Vec2::ZERO);
        assert!(!a.has_path());
        assert!(a.follower().is_none());
    }

    #[test]
    fn agent_set_path() {
        let mut a = Agent::new(Vec2::ZERO, 5.0, 10.0);
        a.set_path(straight_path(10.0));
        assert!(a.has_path());
    }

    #[test]
    fn agent_clear_path() {
        let mut a = Agent::new(Vec2::ZERO, 5.0, 10.0);
        a.set_path(straight_path(10.0));
        a.clear_path();
        assert!(!a.has_path());
        assert!(a.follower().is_none());
    }

    #[test]
    fn agent_follower_accessor() {
        let mut a = Agent::new(Vec2::ZERO, 5.0, 10.0);
        assert!(a.follower().is_none());
        a.set_path(straight_path(10.0));
        assert!(a.follower().is_some());
    }

    #[test]
    fn agent_moves_toward_target() {
        let mut a = Agent::new(Vec2::ZERO, 5.0, 10.0);
        a.set_path(straight_path(10.0));
        for _ in 0..100 {
            a.update(0.1, &[]);
        }
        assert!(a.position.x > 5.0, "agent stalled at {:?}", a.position);
    }

    #[test]
    fn agent_reaches_target() {
        let mut a = Agent::new(Vec2::ZERO, 10.0, 20.0);
        a.set_path(straight_path(5.0));
        for _ in 0..200 {
            a.update(0.05, &[]);
            if !a.has_path() {
                break;
            }
        }
        let target = Vec2::new(5.0, 0.0);
        assert!(
            a.position.distance(target) < 1.0,
            "agent ended at {:?}, target {:?}",
            a.position,
            target
        );
    }

    #[test]
    fn agent_no_path_stays_still() {
        let start = Vec2::new(3.0, 3.0);
        let mut a = Agent::new(start, 5.0, 10.0);
        a.update(0.1, &[]);
        assert!(a.position.distance(start) < f32::EPSILON);
    }

    #[test]
    fn agent_avoids_obstacle() {
        let mut a = Agent::new(Vec2::ZERO, 5.0, 10.0);
        a.set_path(straight_path(20.0));
        let obstacle = Obstacle {
            center: Vec2::new(5.0, 0.0),
            radius: 2.0,
        };
        for _ in 0..200 {
            a.update(0.05, &[obstacle]);
        }
        assert!(
            a.position.x > 5.0,
            "agent did not get past the obstacle: {:?}",
            a.position
        );
    }

    #[test]
    fn agent_speed_clamped() {
        let max_speed = 5.0_f32;
        let mut a = Agent::new(Vec2::ZERO, max_speed, 100.0);
        a.set_path(straight_path(100.0));
        for _ in 0..10 {
            a.update(0.1, &[]);
        }
        // `max_speed + f32::EPSILON` rounds back to exactly `max_speed` (EPSILON
        // is below half an ULP at 5.0), so use a tolerance relative to the limit.
        let speed = a.velocity.length();
        assert!(
            speed <= max_speed * (1.0 + 1e-5),
            "speed {speed} exceeds limit {max_speed}"
        );
    }

    #[test]
    fn agent_serde_roundtrip() {
        let mut a = Agent::new(Vec2::new(1.0, 2.0), 5.0, 10.0);
        a.set_path(straight_path(10.0));
        let json = serde_json::to_string(&a).unwrap();
        let deserialized: Agent = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.position, Vec2::new(1.0, 2.0));
        assert!(deserialized.has_path());
    }

    #[test]
    fn agent_multi_waypoint() {
        let mut a = Agent::new(Vec2::ZERO, 10.0, 20.0);
        a.set_path(PathFollower::new(
            vec![
                Vec2::new(5.0, 0.0),
                Vec2::new(5.0, 5.0),
                Vec2::new(10.0, 5.0),
            ],
            1.0,
            2.0,
        ));
        for _ in 0..1000 {
            a.update(0.02, &[]);
            if !a.has_path() {
                break;
            }
        }
        let last = Vec2::new(10.0, 5.0);
        assert!(
            a.position.distance(last) < 3.0,
            "agent ended at {:?}, final waypoint {:?}",
            a.position,
            last
        );
    }
}