bevy_steering/
speed.rs

1use std::f32::consts::PI;
2
3use avian3d::prelude::*;
4use bevy::{ecs::query::QueryData, prelude::*};
5use derivative::Derivative;
6#[cfg(feature = "serialize")]
7use serde::{Deserialize, Serialize};
8
9use crate::prelude::NearbyObstacles;
10
/// Number of evenly spaced directions (slots) stored in a [SpeedMask].
const NUM_SLOTS: usize = 16;
12
/// Maximum speed from 0.0-1.0 for each direction around the agent.
/// This is computed by the `speed_control` system from the agent's
/// [SpeedController] settings to help avoid running into obstacles.
///
/// The mask holds `NUM_SLOTS` directions evenly spaced around the XZ plane;
/// slot `i` points along `SpeedMask::slot_to_dir(i)`. Speeds between slots
/// are linearly interpolated by `SpeedMask::get_speed`.
///
/// The derived `Default` sets every slot to 0.0.
#[derive(Component, Debug, Copy, Clone, Reflect, Default, Deref)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serialize", serde(default))]
pub(crate) struct SpeedMask([f32; NUM_SLOTS]);
20
21impl SpeedMask {
22    pub(crate) fn slot_to_dir(slot: usize) -> Vec3 {
23        const ANGLE_PER_SLOT: f32 = 2.0 * PI / NUM_SLOTS as f32;
24        let angle = slot as f32 * ANGLE_PER_SLOT;
25        Vec3::new(angle.cos(), 0.0, angle.sin())
26    }
27
28    pub(crate) fn subtract_speed(&mut self, dir: Vec3, amount: f32) {
29        for i in 0..NUM_SLOTS {
30            let slot_dir = Self::slot_to_dir(i);
31            let factor = slot_dir.dot(dir).max(0.0);
32            self.0[i] = (self.0[i] - amount * factor).max(0.0);
33        }
34    }
35
36    pub(crate) fn get_speed(&self, dir: Vec3) -> f32 {
37        const ANGLE_PER_SLOT: f32 = 2.0 * PI / NUM_SLOTS as f32;
38
39        // Get the angle from the direction vector
40        let angle = dir.z.atan2(dir.x);
41        // Normalize to [0, 2π)
42        let angle = if angle < 0.0 { angle + 2.0 * PI } else { angle };
43
44        // Find fractional slot position
45        let slot_f = angle / ANGLE_PER_SLOT;
46        let slot_low = slot_f.floor() as usize % NUM_SLOTS;
47        let slot_high = (slot_low + 1) % NUM_SLOTS;
48        let t = slot_f.fract();
49
50        // Linearly interpolate between the two adjacent slots
51        self.0[slot_low] * (1.0 - t) + self.0[slot_high] * t
52    }
53}
54
/// The desired maximum speed as a fraction in [0.0-1.0]. This acts as
/// a global speed limit shared by everything that sets it: if multiple
/// systems try to set this, only the minimum is kept.
///
/// Defaults to 1.0 (full speed). The `speed_control` system uses this value
/// as the starting speed for every direction of the agent's speed mask.
#[derive(Component, Debug, Copy, Clone, Reflect, Derivative, Deref)]
#[derivative(Default)]
#[reflect(Component)]
#[require(SpeedMask)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serialize", serde(default))]
pub struct SpeedOverride(#[derivative(Default(value = "1.0"))] f32);
65
66impl SpeedOverride {
67    /// Set a new speed override value.
68    pub fn set(&mut self, value: f32) {
69        self.0 = self.0.min(value);
70    }
71
72    /// Reset the speed override to 1.0. This happens
73    /// automatically during FixedPreUpdate.
74    pub fn reset(&mut self) {
75        self.0 = 1.0;
76    }
77}
78
/// The speed controller component. This component
/// manages the [TargetSpeed] of an entity, to avoid
/// crashing into obstacles.
///
/// The `speed_control` system reads these settings and combines them with
/// the entity's [SpeedOverride] and its nearby obstacles to build a
/// per-direction speed mask. Inserting this component also inserts a
/// [SpeedOverride] and a speed mask through `#[require]`.
#[derive(Component, Debug, Copy, Clone, Reflect, Derivative)]
#[derivative(Default)]
#[reflect(Component)]
#[require(SpeedMask, SpeedOverride)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serialize", serde(default))]
pub struct SpeedController {
    /// The weight of the obstacle avoidance behavior. If
    /// it's greater than 0.0, the agent will slow down
    /// near obstacles. The default value of 1.0 will cause an agent to
    /// completely stop to avoid a collision.
    ///
    /// Each obstacle subtracts `obstacle_weight * threat` from the
    /// directions facing it. `threat` is 1.0 at distance 0.0 and falls
    /// linearly to 0.0 at `stopping_distance`.
    #[derivative(Default(value = "1.0"))]
    pub obstacle_weight: f32,

    /// The distance at which the agent will start slowing down
    /// if it's in the direction of an obstacle.
    ///
    /// Obstacles at or beyond this distance cause no slowdown. Closer
    /// obstacles cause a slowdown that grows linearly until it reaches full
    /// strength at distance 0.0. Defaults to 1.0.
    #[derivative(Default(value = "1.0"))]
    pub stopping_distance: f32,
}
101
102impl SpeedController {
103    pub fn with_obstacle_weight(mut self, weight: f32) -> Self {
104        self.obstacle_weight = weight;
105        self
106    }
107
108    pub fn with_stopping_distance(mut self, distance: f32) -> Self {
109        self.stopping_distance = distance;
110        self
111    }
112}
113
/// Per-agent data read by [speed_control] to rebuild the agent's speed mask.
#[derive(QueryData)]
pub(crate) struct SpeedControllerQuery {
    // Target entity for the recomputed speed mask.
    entity: Entity,
    // Nearby obstacles. When absent, the mask is just the speed override.
    obstacles: Option<&'static NearbyObstacles>,
    speed_controller: &'static SpeedController,
    speed_override: &'static SpeedOverride,
    // NOTE(review): `transform` and `velocity` are not read by `speed_control`
    // in this file; presumably they are kept so the query only matches
    // entities that have both components. Confirm before removing them.
    transform: &'static GlobalTransform,
    velocity: &'static LinearVelocity,
}
123
124pub(crate) fn speed_control(query: Query<SpeedControllerQuery>, mut commands: Commands) {
125    for item in query.iter() {
126        let speed_override = item.speed_override;
127        let mut speed_mask = SpeedMask([speed_override.0; NUM_SLOTS]);
128
129        let Some(obstacles) = item.obstacles else {
130            commands.entity(item.entity).insert(speed_mask);
131            continue;
132        };
133
134        let reference_distance = item.speed_controller.stopping_distance;
135        for obstacle in obstacles.values() {
136            // Slow down if the direction is towards the obstacle
137            let Some((impact_point, agent_point)) = obstacle.impact_points else {
138                continue;
139            };
140            let to_obstacle_dir = (impact_point - agent_point).normalize_or_zero();
141
142            let dist = obstacle.distance;
143            let normalized_dist = (dist / reference_distance).clamp(0.0, 1.0);
144            let threat = 1.0 - normalized_dist;
145            let slowdown = item.speed_controller.obstacle_weight * threat;
146            speed_mask.subtract_speed(to_obstacle_dir, slowdown);
147        }
148
149        commands.entity(item.entity).insert(speed_mask);
150    }
151}
152
153pub(crate) fn reset_speed_override(mut query: Query<&mut SpeedOverride>) {
154    for mut speed_override in query.iter_mut() {
155        speed_override.reset();
156    }
157}
158
159pub(crate) fn debug_speed(mut gizmos: Gizmos, query: Query<(&GlobalTransform, &SpeedMask)>) {
160    const BASE_LINE_LENGTH: f32 = 8.0;
161
162    for (transform, speed_mask) in query.iter() {
163        let agent_position = transform.translation();
164
165        // Draw speed mask (yellow, offset slightly)
166        for i in 0..NUM_SLOTS {
167            let direction = SpeedMask::slot_to_dir(i);
168
169            let mask_value = speed_mask[i];
170            if mask_value > 0.01 {
171                let danger_length = BASE_LINE_LENGTH * mask_value;
172                let end_point = agent_position + direction * danger_length;
173                let offset = Vec3::new(-0.1, 0.0, -0.1);
174                gizmos.line(
175                    agent_position + offset,
176                    end_point + offset,
177                    Color::srgb(1.0, 1.0, 0.0),
178                );
179            }
180        }
181    }
182}