1use std::f32::consts::PI;
2
3use avian3d::prelude::*;
4use bevy::{ecs::query::QueryData, prelude::*};
5use derivative::Derivative;
6#[cfg(feature = "serialize")]
7use serde::{Deserialize, Serialize};
8
9use crate::prelude::NearbyObstacles;
10
/// Number of evenly spaced horizontal directions sampled by a `SpeedMask`.
const NUM_SLOTS: usize = 16;
12
/// Per-direction speed limits for an agent.
///
/// Holds one value per slot; slot `i` corresponds to the unit direction returned by
/// `SpeedMask::slot_to_dir(i)`, i.e. the XZ-plane direction at angle `i * 2π / NUM_SLOTS`
/// measured from +X towards +Z. `speed_control` rebuilds this mask for every agent with a
/// `SpeedController`. The derived `Default` is all zeros.
#[derive(Component, Debug, Copy, Clone, Reflect, Default, Deref)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serialize", serde(default))]
pub(crate) struct SpeedMask([f32; NUM_SLOTS]);
20
21impl SpeedMask {
22 pub(crate) fn slot_to_dir(slot: usize) -> Vec3 {
23 const ANGLE_PER_SLOT: f32 = 2.0 * PI / NUM_SLOTS as f32;
24 let angle = slot as f32 * ANGLE_PER_SLOT;
25 Vec3::new(angle.cos(), 0.0, angle.sin())
26 }
27
28 pub(crate) fn subtract_speed(&mut self, dir: Vec3, amount: f32) {
29 for i in 0..NUM_SLOTS {
30 let slot_dir = Self::slot_to_dir(i);
31 let factor = slot_dir.dot(dir).max(0.0);
32 self.0[i] = (self.0[i] - amount * factor).max(0.0);
33 }
34 }
35
36 pub(crate) fn get_speed(&self, dir: Vec3) -> f32 {
37 const ANGLE_PER_SLOT: f32 = 2.0 * PI / NUM_SLOTS as f32;
38
39 let angle = dir.z.atan2(dir.x);
41 let angle = if angle < 0.0 { angle + 2.0 * PI } else { angle };
43
44 let slot_f = angle / ANGLE_PER_SLOT;
46 let slot_low = slot_f.floor() as usize % NUM_SLOTS;
47 let slot_high = (slot_low + 1) % NUM_SLOTS;
48 let t = slot_f.fract();
49
50 self.0[slot_low] * (1.0 - t) + self.0[slot_high] * t
52 }
53}
54
/// Scalar cap that seeds every slot of an agent's `SpeedMask` in `speed_control`.
///
/// Defaults to `1.0`. [`SpeedOverride::set`] can only lower it, so when several callers
/// request a cap the smallest one wins until [`SpeedOverride::reset`] restores `1.0`.
/// NOTE(review): presumably a fraction of the agent's maximum speed — confirm against the
/// movement code that consumes `SpeedMask`.
#[derive(Component, Debug, Copy, Clone, Reflect, Derivative, Deref)]
#[derivative(Default)]
#[reflect(Component)]
#[require(SpeedMask)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serialize", serde(default))]
pub struct SpeedOverride(#[derivative(Default(value = "1.0"))] f32);
65
66impl SpeedOverride {
67 pub fn set(&mut self, value: f32) {
69 self.0 = self.0.min(value);
70 }
71
72 pub fn reset(&mut self) {
75 self.0 = 1.0;
76 }
77}
78
/// Configures how `speed_control` reduces an agent's speed near obstacles.
///
/// Requiring this component also inserts `SpeedMask` and [`SpeedOverride`] if they are
/// missing (via `#[require]`).
#[derive(Component, Debug, Copy, Clone, Reflect, Derivative)]
#[derivative(Default)]
#[reflect(Component)]
#[require(SpeedMask, SpeedOverride)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serialize", serde(default))]
pub struct SpeedController {
    /// Multiplier on an obstacle's threat (0..=1) giving the speed subtracted in the
    /// obstacle's direction. Defaults to `1.0`.
    #[derivative(Default(value = "1.0"))]
    pub obstacle_weight: f32,

    /// Distance at which obstacles start to slow the agent; the threat ramps linearly from
    /// 0 at this distance to 1 at distance zero. Defaults to `1.0`.
    #[derivative(Default(value = "1.0"))]
    pub stopping_distance: f32,
}
101
102impl SpeedController {
103 pub fn with_obstacle_weight(mut self, weight: f32) -> Self {
104 self.obstacle_weight = weight;
105 self
106 }
107
108 pub fn with_stopping_distance(mut self, distance: f32) -> Self {
109 self.stopping_distance = distance;
110 self
111 }
112}
113
/// Per-agent data read by `speed_control`.
///
/// Matches entities that have a `SpeedController`, `SpeedOverride`, `GlobalTransform` and
/// `LinearVelocity`; `NearbyObstacles` is optional.
#[derive(QueryData)]
pub(crate) struct SpeedControllerQuery {
    entity: Entity,
    obstacles: Option<&'static NearbyObstacles>,
    speed_controller: &'static SpeedController,
    speed_override: &'static SpeedOverride,
    // NOTE(review): `transform` and `velocity` are not read by `speed_control`; as required
    // fields they still restrict which entities match. Confirm whether that filtering is
    // intended before removing them.
    transform: &'static GlobalTransform,
    velocity: &'static LinearVelocity,
}
123
124pub(crate) fn speed_control(query: Query<SpeedControllerQuery>, mut commands: Commands) {
125 for item in query.iter() {
126 let speed_override = item.speed_override;
127 let mut speed_mask = SpeedMask([speed_override.0; NUM_SLOTS]);
128
129 let Some(obstacles) = item.obstacles else {
130 commands.entity(item.entity).insert(speed_mask);
131 continue;
132 };
133
134 let reference_distance = item.speed_controller.stopping_distance;
135 for obstacle in obstacles.values() {
136 let Some((impact_point, agent_point)) = obstacle.impact_points else {
138 continue;
139 };
140 let to_obstacle_dir = (impact_point - agent_point).normalize_or_zero();
141
142 let dist = obstacle.distance;
143 let normalized_dist = (dist / reference_distance).clamp(0.0, 1.0);
144 let threat = 1.0 - normalized_dist;
145 let slowdown = item.speed_controller.obstacle_weight * threat;
146 speed_mask.subtract_speed(to_obstacle_dir, slowdown);
147 }
148
149 commands.entity(item.entity).insert(speed_mask);
150 }
151}
152
153pub(crate) fn reset_speed_override(mut query: Query<&mut SpeedOverride>) {
154 for mut speed_override in query.iter_mut() {
155 speed_override.reset();
156 }
157}
158
159pub(crate) fn debug_speed(mut gizmos: Gizmos, query: Query<(&GlobalTransform, &SpeedMask)>) {
160 const BASE_LINE_LENGTH: f32 = 8.0;
161
162 for (transform, speed_mask) in query.iter() {
163 let agent_position = transform.translation();
164
165 for i in 0..NUM_SLOTS {
167 let direction = SpeedMask::slot_to_dir(i);
168
169 let mask_value = speed_mask[i];
170 if mask_value > 0.01 {
171 let danger_length = BASE_LINE_LENGTH * mask_value;
172 let end_point = agent_position + direction * danger_length;
173 let offset = Vec3::new(-0.1, 0.0, -0.1);
174 gizmos.line(
175 agent_position + offset,
176 end_point + offset,
177 Color::srgb(1.0, 1.0, 0.0),
178 );
179 }
180 }
181 }
182}